博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
股票量化交易回测框架pyalgotrade源码阅读(一)
阅读量:6209 次
发布时间:2019-06-21

本文共 27867 字,大约阅读时间需要 92 分钟。

PyAlgoTrade是什么呢?

一个股票量化交易的策略回测框架。

而作者的说明如下。

To make it easy to backtest stock trading strategies.

  简单的来说,是一个用于验证自己交易策略的框架。

适用以下场景:

  我有个前无古人后无来者的想法,我觉得我按照这个想法去买股票稳赚不赔,但是为了稳妥起见,我需要测试一下这个我的这个想法到底用没有用,怎么测试呢?

大概下面两种方法

一:弄个模拟交易的软件,每天按照自己的想法买入卖出,然后看看一个月或者一年后的收益如何。

优点:更贴近现实,至少当下的现实

缺点:测试周期大,数据有限

二:我相信我的这个想法不是针对现在或者未来有用,甚至是在以前应该也是起作用的,那么我可以将历史数据调出来,用于测试,看看在历史行情中收益如何。

优点:数据充分,可以反复测试。

缺点:可能不能贴近现实

  而pyalgotrade就是为了提供给使用者基于历史数据回测的框架,即为了让你更好的使用上述的第二种方法。

注:无论怎么测,肯定都有偏差的, 因为都是猜,就像赌博,你算好了各种概率,想好了各种策略,但是你能保证的只是你赢钱的概率大一些,而不是必赢,因为在没有欺诈的情况下,未来是不可测,也不能确定的,谁也不能预知未来~吧~

文章目录

  1. 官方示例

  2. 设计模式之观察者模式

  3. 源码解析

官方示例

sma_crossover.py文件

from pyalgotrade import strategyfrom pyalgotrade.technical import mafrom pyalgotrade.technical import crossclass SMACrossOver(strategy.BacktestingStrategy):    def __init__(self, feed, instrument, smaPeriod):        super(SMACrossOver, self).__init__(feed)        self.__instrument = instrument        self.__position = None        # We'll use adjusted close values instead of regular close values.        self.setUseAdjustedValues(True)        self.__prices = feed[instrument].getPriceDataSeries()        self.__sma = ma.SMA(self.__prices, smaPeriod)    def getSMA(self):        return self.__sma    def onEnterCanceled(self, position):        self.__position = None    def onExitOk(self, position):        self.__position = None    def onExitCanceled(self, position):        # If the exit was canceled, re-submit it.        self.__position.exitMarket()    def onBars(self, bars):        # If a position was not opened, check if we should enter a long position.        if self.__position is None:            if cross.cross_above(self.__prices, self.__sma) > 0:                shares = int(self.getBroker().getCash() * 0.9 / bars[self.__instrument].getPrice())                # Enter a buy market order. The order is good till canceled.                self.__position = self.enterLong(self.__instrument, shares, True)        # Check if we have to exit the position.        elif not self.__position.exitActive() and cross.cross_below(self.__prices, self.__sma) > 0:            self.__position.exitMarket()

sma_crossover_sample.py

import sma_crossoverfrom pyalgotrade import plotterfrom pyalgotrade.tools import yahoofinancefrom pyalgotrade.stratanalyzer import sharpedef main(plot):    instrument = "aapl"    smaPeriod = 163    # Download the bars.    feed = yahoofinance.build_feed([instrument], 2011, 2012, ".")    strat = sma_crossover.SMACrossOver(feed, instrument, smaPeriod)    sharpeRatioAnalyzer = sharpe.SharpeRatio()    strat.attachAnalyzer(sharpeRatioAnalyzer)    if plot:        plt = plotter.StrategyPlotter(strat, True, False, True)        plt.getInstrumentSubplot(instrument).addDataSeries("sma", strat.getSMA())    strat.run()    print "Sharpe ratio: %.2f" % sharpeRatioAnalyzer.getSharpeRatio(0.05)    if plot:        plt.plot()if __name__ == "__main__":    main(True)

  上面的代码主要做一件这样的事。

  创建了一个策略,这个策略就是你的想法,这个想法是什么呢?

  想法是,当价格高于近163日内的平均价格就买入,低于近163日内的平均价格就卖出(平仓)。

  其实还做了其他的事,比如策略分析之类的,但是这篇文章暂时忽略。

设计模式之观察者模式

#!/usr/bin/python#coding:utf8'''Observer'''  class Subject(object):    def __init__(self):        self._observers = []     def attach(self, observer):        if not observer in self._observers:            self._observers.append(observer)     def detach(self, observer):        try:            self._observers.remove(observer)        except ValueError:            pass     def notify(self, modifier=None):        for observer in self._observers:            if modifier != observer:                observer.update(self) # Example usageclass Data(Subject):    def __init__(self, name=''):        Subject.__init__(self)        self.name = name        self._data = 0     @property    def data(self):        return self._data     @data.setter    def data(self, value):        self._data = value        self.notify() class HexViewer:    def update(self, subject):        print('HexViewer: Subject %s has data 0x%x' %              (subject.name, subject.data)) class DecimalViewer:    def update(self, subject):        print('DecimalViewer: Subject %s has data %d' %              (subject.name, subject.data)) # Example usage...def main():    data1 = Data('Data 1')    data2 = Data('Data 2')    view1 = DecimalViewer()    view2 = HexViewer()    data1.attach(view1)    data1.attach(view2)    data2.attach(view2)    data2.attach(view1)     print("Setting Data 1 = 10")    data1.data = 10    print("Setting Data 2 = 15")    data2.data = 15    print("Setting Data 1 = 3")    data1.data = 3    print("Setting Data 2 = 5")    data2.data = 5    print("Detach HexViewer from data1 and data2.")    data1.detach(view2)    data2.detach(view2)    print("Setting Data 1 = 10")    data1.data = 10    print("Setting Data 2 = 15")    data2.data = 15 if __name__ == '__main__':    main()

意图:

  定义对象间的一种一对多的依赖关系,当一个对象的状态发生改变时, 所有依赖于它的对象都得到通知并被自动更新。

适用性:

  当一个抽象模型有两个方面, 其中一个方面依赖于另一方面。将这二者封装在独立的对象中以使它们可以各自独立地改变和复用。

  当对一个对象的改变需要同时改变其它对象, 而不知道具体有多少对象有待改变。

  当一个对象必须通知其它对象,而它又不能假定其它对象是谁。换言之, 你不希望这些对象是紧密耦合的。

摘自:

  如果你看得懂就略过吧。

  上面的代码想做个上面事情呢?

  想达到事件的目的,即,在更新数据的时候,会触发相关的事件

  上面定义了主要三个种类型的类,subject,data,viewer。

  其中subject是data的父类。

  通过attach的操作,将不同的viewer加入到self.__observers列表里面,当data对象要更新数据的时候,就回调用notify方法,而notify方法则会遍历self.__observers列表的每个observer,然后依次调用其update方法。

  这也是为毛hexViewer,DecimalViewer都要实现自身的update方法。

  为毛要这么写?

  前人总结的经验~

  能不能不这么写?

  可以的。

  如果看不懂这个设计模式,那么pyalgotrade的源码看起来可能会比较吃力,但是也只是可能而已,因为很多人看不懂,只是因为没有实际的有用场景而已。

源码解析

  首先是框架,看一遍,比如那些模块,不过个人经验之谈就是,看完之后,一般都会有一下迷思。

  为毛这么写?

  这里到底想干什么?

  这么复杂有毛用~

  恩,我也是这种感觉~

  一般是pdb跟一遍流程或者一个一个找继承关系。

  pdb这里就不讲了,主要就是跟每个方法调用死磕到底,当然了,你也许有你得方法,我比较较真就是这样看源代码的,至少现在是这样的。

  在看源代码之前,官方文档,示例什么的最好也看一下,这样才能跟接近作者的意思。

  这里面有个对象,需要着重声明,那就是bar。

  什么是bar呢?

  每个bar都是一个时刻股票各个价格的集合,即,当前价格,当前时间,最高价,最低价,成交量什么的。

  而这些属性都是通过get_xxx的方法获取的。

获取数据

很明显数据是通过下面这行代码获取的。

feed = yahoofinance.build_feed([instrument], 2011, 2012, ".")

build_feed方法在tools/yahoofinance.py

def build_feed(instruments, fromYear, toYear, storage, frequency=bar.Frequency.DAY, timezone=None, skipErrors=False):logger = pyalgotrade.logger.getLogger("yahoofinance")    logger = pyalgotrade.logger.getLogger("yahoofinance")    ret = yahoofeed.Feed(frequency, timezone)    for year in range(fromYear, toYear+1):        for instrument in instruments:            fileName = os.path.join(storage, "%s-%d-yahoofinance.csv" % (instrument, year))            if not os.path.exists(fileName):                logger.info("Downloading %s %d to %s" % (instrument, year, fileName))                try:                    if frequency == bar.Frequency.DAY:                        download_daily_bars(instrument, year, fileName)                    elif frequency == bar.Frequency.WEEK:                        download_weekly_bars(instrument, year, fileName)                    else:                        raise Exception("Invalid frequency")                except Exception, e:                    if skipErrors:                        logger.error(str(e))                        continue                    else:                        raise e            ret.addBarsFromCSV(instrument, fileName)    return ret

在build_feed函数里面又根据情况调用了相应的下载函数

def download_csv(instrument, begin, end, frequency):    url = "http://ichart.finance.yahoo.com/table.csv?s=%s&a=%d&b=%d&c=%d&d=%d&e=%d&f=%d&g=%s&ignore=.csv" % (instrument, __adjust_month(begin.month), begin.day, begin.year, __adjust_month(end.month), end.day, end.year, frequency)    return csvutils.download_csv(url)

  而最终执行的下载函数为download_csv,通过这个函数我们可以访问yahoo的api,最终下载函数,当然了,可以进一步的查看csvutils.download_csv函数。

  这里我们知道数据是通过download_csv这个函数,将相应的股票代码,开始结束时间及频率传入,然后访问相应的url,得到相应的数据。

feed对象

  在tools/yahoofinance.py中我们可以看到,返回的结果并不是一个csv的对象,而是一个ret即,Feed对象,而Feed对象通过addBarsFromCSV将下载的数据加载到内存。

  从这里你也许会开始抓狂了为毛一层一层的继承。

其中yahoofeed.Feed在barfeed/yahoofeed.py

class Feed(csvfeed.BarFeed):    def addBarsFromCSV(self, instrument, path, timezone=None):        rowParser = RowParser(            self.getDailyBarTime(), self.getFrequency(), timezone, self.__sanitizeBars, self.__barClass        )        super(Feed, self).addBarsFromCSV(instrument, path, rowParser)

上面调用了父类的addBarsFromCSV方法。

父类的addBarsFromCSV在barfeed/csvfeed.py

class BarFeed(membf.BarFeed):    def addBarsFromCSV(self, instrument, path, rowParser):        # Load the csv file        loadedBars = []        reader = csvutils.FastDictReader(open(path, "r"), fieldnames=rowParser.getFieldNames(), delimiter=rowParser.getDelimiter())        for row in reader:            bar_ = rowParser.parseBar(row)            if bar_ is not None and (self.__barFilter is None or self.__barFilter.includeBar(bar_)):                loadedBars.append(bar_)        self.addBarsFromSequence(instrument, loadedBars)

然后csvfeed又调用了父类的方法~

值得注意的是,上面的rowParser.parseBar方法在子类实现的 。。。后面会在提及。

addBarsFromSequence方法在barfeed/membf.py

class BarFeed(barfeed.BaseBarFeed):    def addBarsFromSequence(self, instrument, bars):        if self.__started:            raise Exception("Can't add more bars once you started consuming bars")        self.__bars.setdefault(instrument, [])        self.__nextPos.setdefault(instrument, 0)        # Add and sort the bars        self.__bars[instrument].extend(bars)        barCmp = lambda x, y: cmp(x.getDateTime(), y.getDateTime())        self.__bars[instrument].sort(barCmp)        self.registerInstrument(instrument)

然后又调用了父类的方法~

值得注意的是这里将yahoo的数据存在了self.__bars中,至于bars是什么对象,后面再说。

registerInstrument方法在barfeed/__init__.py

class BaseBarFeed(feed.BaseFeed):    def registerInstrument(self, instrument):        self.__defaultInstrument = instrument        self.registerDataSeries(instrument)

然后又调用了父类的方法~

registerDataSeries方法在feed/__init__.py

class BaseFeed(observer.Subject):    def __init__(self, maxLen):        super(BaseFeed, self).__init__()        maxLen = dataseries.get_checked_max_len(maxLen)        self.__ds = {}        self.__event = observer.Event()        self.__maxLen = maxLen    def registerDataSeries(self, key):        if key not in self.__ds:            self.__ds[key] = self.createDataSeries(key, self.__maxLen)

  恩,这里就是逻辑的终点了,虽然它还是继承,不过pyalgotrade里面大多数对象都是是继承observer.Subject,之所以继承,是为了完成类似观察者的设计模式里面的事件操作。

  简单总结一下继承关系。

barfeed/yahoofeed.Feed -> barfeed/csvfeed.BarFeed -> barfeed/membf.BarFeed -> barfeed/__init__.py.BaseFeed -> feed/__init.py.BaseFeed

  然后yahoo的数据结果,最终是由RowParser的parseBar方法依次导入,而RowPaser.parseBar方法是在barfeed/yahoofeed.py中。

  然后我们再来走一遍加载数据的流程,不过这次不只是整个逻辑,而这次我们关注于具体的数据是啥。

其中barfeed/yahoofeed里面的RowParser的逻辑及parsrBar的具体的具体实现,截取如下。

class RowParser(csvfeed.RowParser):    def __init__(self, dailyBarTime, frequency, timezone=None, sanitize=False, barClass=bar.BasicBar):        self.__dailyBarTime = dailyBarTime        self.__frequency = frequency        self.__timezone = timezone        self.__sanitize = sanitize        self.__barClass = barClass    def __parseDate(self, dateString):        ret = parse_date(dateString)        # Time on Yahoo! Finance CSV files is empty. If told to set one, do it.        if self.__dailyBarTime is not None:            ret = datetime.datetime.combine(ret, self.__dailyBarTime)        # Localize the datetime if a timezone was given.        if self.__timezone:            ret = dt.localize(ret, self.__timezone)        return ret    def getFieldNames(self):        # It is expected for the first row to have the field names.        return None    def getDelimiter(self):        return ","    def parseBar(self, csvRowDict):        dateTime = self.__parseDate(csvRowDict["Date"])        close = float(csvRowDict["Close"])        open_ = float(csvRowDict["Open"])        high = float(csvRowDict["High"])        low = float(csvRowDict["Low"])        volume = float(csvRowDict["Volume"])        adjClose = float(csvRowDict["Adj Close"])        if self.__sanitize:            open_, high, low, close = common.sanitize_ohlc(open_, high, low, close)        return self.__barClass(dateTime, open_, high, low, close, volume, adjClose, self.__frequency)

  其中解析后返回的结果是一个bar.BasicBar对象。

  然后调用父类barfeed/csvfeed里面的addBarsFromCSV方法,得到一个bar.BasicBar对象的列表,即loadBars。传入继承于父类的addBarsFromSequence方法,截取如下。

class BarFeed(membf.BarFeed):    def addBarsFromCSV(self, instrument, path, rowParser):        # Load the csv file        loadedBars = []        reader = csvutils.FastDictReader(open(path, "r"), fieldnames=rowParser.getFieldNames(), delimiter=rowParser.getDelimiter())        for row in reader:            bar_ = rowParser.parseBar(row)            if bar_ is not None and (self.__barFilter is None or self.__barFilter.includeBar(bar_)):                loadedBars.append(bar_)        self.addBarsFromSequence(instrument, loadedBars)

下面则是处理addBarsFromSequence的操作,主要是创建了一个self.__bars的字典,每个股票代码对应相应时间段的bar.BasicBar对象的列表,然后调用父类的registerInstrument方法,传入相应的股票代码。

barfeed/membf.py --> BarFeed

class BarFeed(barfeed.BaseBarFeed):    def addBarsFromSequence(self, instrument, bars):        if self.__started:            raise Exception("Can't add more bars once you started consuming bars")        self.__bars.setdefault(instrument, [])        self.__nextPos.setdefault(instrument, 0)        # Add and sort the bars        self.__bars[instrument].extend(bars)        barCmp = lambda x, y: cmp(x.getDateTime(), y.getDateTime())        self.__bars[instrument].sort(barCmp)        self.registerInstrument(instrument)

下面则是registerInstrument的具体逻辑,即注册DataSeries对象,而registerDataSeries方法是在父类实现。

barfeed/__init__.py --->BaseBarFeed

BaseBarFeed(feed.BaseFeed):    def registerInstrument(self, instrument):        self.__defaultInstrument = instrument        self.registerDataSeries(instrument)

下面则是最终的registerDataSeries操作,创建了一个dataseries的对象。

feed/__init__.py  --->BaseFeed

class BaseFeed(observer.Subject):    def registerDataSeries(self, key):        if key not in self.__ds:            self.__ds[key] = self.createDataSeries(key, self.__maxLen)

而createDataSeries方法并没有在基类中实现。

@abc.abstractmethoddef createDataSeries(self, key, maxLen):    raise NotImplementedError()

createDataSeries的具体实现则是在barfeed/__init__.py --->BaseBarFeed

    def createDataSeries(self, key, maxLen):        ret = bards.BarDataSeries(maxLen)        ret.setUseAdjustedValues(self.__useAdjustedValues)        return ret

所以最终,feed对象有两个重要的数据集。

一:

self.__bars

里面的数据结构大概是{"instrument_xx":[bar1,bar2,bar3]}

self.__ds = {}

里面的数据结构大概是self.__ds = {"instrument_xx": dataseries_xx}

其中instrument指特定的股票代码,比如aapl,bar1,bar2则是bar.BasicBar对象,dataseries则是bards.BarDataSeries对象。

至于bar.BasicBar以及dataseries的数据结构到底是什么,大家可以自行瞧瞧。

值得注意的是,父类与基类之间数据获取不会通过共享变量的方式获得,比如最终通过基类self.__ds的数据是通过基类的getKeys的方法暴露给子类去获取实际的数据。。

策略

初始化策略

strat = sma_crossover.SMACrossOver(feed, instrument, smaPeriod)

策略最终继承于strategy.BacktestingStrategy

analyzer

创建一个stratanalyzer的实例并attach

sharpeRatioAnalyzer = sharpe.SharpeRatio()strat.attachAnalyzer(sharpeRatioAnalyzer)

analyzer这里暂时不说,因为,这里主要将具体的策略实现,以及feed对象,analyzer以及broker的内容会放在下一篇文章讲。

run

运行策略。

strat.run()

run方法在strategy/__init__.py里面的BaseStrategy类。

class BaseStrategy(object):    def run(self):    """Call once (**and only once**) to run the strategy."""        self.__dispatcher.run()    if self.__barFeed.getCurrentBars() is not None:        self.onFinish(self.__barFeed.getCurrentBars())    else:        raise Exception("Feed was empty")

而run方法会调用self.__dispatcher的run方法,即dispatcher.py里面的Dispatcher类,在说Dispatcher类之前,我们得先看看BaseStrategy在初始化的时候到底初始化了啥。

class BaseStrategy(object):    def __init__(self, barFeed, broker):        self.__barFeed = barFeed        self.__broker = broker        self.__activePositions = set()        self.__orderToPosition = {}        self.__barsProcessedEvent = observer.Event()        self.__analyzers = []        self.__namedAnalyzers = {}        self.__resampledBarFeeds = []        self.__dispatcher = dispatcher.Dispatcher()        self.__broker.getOrderUpdatedEvent().subscribe(self.__onOrderEvent)        self.__barFeed.getNewValuesEvent().subscribe(self.__onBars)        self.__dispatcher.getStartEvent().subscribe(self.onStart)        self.__dispatcher.getIdleEvent().subscribe(self.__onIdle)        # It is important to dispatch broker events before feed events, specially if we're backtesting.        self.__dispatcher.addSubject(self.__broker)        self.__dispatcher.addSubject(self.__barFeed)

  绑定barFeed,broker到self,初始化__activePositions,OderToPosition,__analyzers,__namedAnlyzers,__resampledBarFeeds的值,并初始化一个observer.Event的实例。

  创建一个dispatcher的实例,并在dispatcher的初始化过程中创建两个observer.Event,observer.Event的实例。

  其中broker实例通过getOrderUpdatedEvent方法得到一个event实例,并订阅策略的onOrderEvent的事件

  barFeed实例通过getNewValuesEvent方法得到一个event实例,并订阅策略的onBars的事件。

  dispatcher的实例分别获得startEvent,IdleEvent并订阅onStart,__onIdle事件。

  最后dispatcher实例将broker,barFeed两个subject分别加入到dispatcher的subjects列表中。

  然后我们在回到Dispatcher类的run方法,这里主要是首先遍历自己__subjects列表里面的subject,然后调用每个subject的start方法,由BaseStrategy类的初始化方法可知,dispatcher加入了两个subject,分别是broker,barFeed。

具体实现如下。

class Dispatcher(object):    def run(self):    try:        for subject in self.__subjects:            subject.start()        self.__startEvent.emit()                while not self.__stop:            eof, eventsDispatched = self.__dispatch()        if eof:            self.__stop = True        elif not eventsDispatched:            self.__idleEvent.emit()    finally:        for subject in self.__subjects:            subject.stop()        for subject in self.__subjects:            subject.join()

整个回测策略的逻辑基本就是在dispatcher调度各个subject并触发事件的过程。

调用完每个subject的start方法后,执行自身的self.__startEvent.emit方法。

然后通过while循环启动整个运转逻辑。

在循环结束后依次启动每个subject并等待所有subject关闭。

现在再次回到初始化过程,查看各个event,subject的内容到底是什么。

self.__broker.getOrderUpdatedEvent().subscribe(self.__onOrderEvent)    def __onOrderEvent(self, broker_, orderEvent):        order = orderEvent.getOrder()        self.onOrderUpdated(order)        self.__barFeed.getNewValuesEvent().subscribe(self.__onBars)    def __onBars(self, dateTime, bars):        # THE ORDER HERE IS VERY IMPORTANT        # 1: Let analyzers process bars.        self.__notifyAnalyzers(lambda s: s.beforeOnBars(self, bars))        # 2: Let the strategy process current bars and submit orders.        self.onBars(bars)        # 3: Notify that the bars were processed.        self.__barsProcessedEvent.emit(self, bars)        self.__dispatcher.getStartEvent().subscribe(self.onStart)    def onStart(self):        """Override (optional) to get notified when the strategy starts executing. The default implementation is empty. """        pass        self.__dispatcher.getIdleEvent().subscribe(self.__onIdle)        def __onIdle(self):        # Force a resample check to avoid depending solely on the underlying        # barfeed events.        for resampledBarFeed in self.__resampledBarFeeds:        resampledBarFeed.checkNow(self.getCurrentDateTime())        self.onIdle()

上面是各个event订阅的subject,是相应的handler函数。

然后现在瞧瞧每个subject的start方法。

其中observer.py里面定义的Subject类似一个抽象工厂,只是定义了各个方法但是并没有实现具体方法的逻辑。

我们首先来看看broker这个subject的start方法的处理逻辑。

而继承observer.Subject的Broker也只是一个抽象工厂,定义了一系列的接口。

在此策略中,我们据代码得知,我们初始化的broker是一个backtesting的broker,代码如下。

class BacktestingStrategy(BaseStrategy):    def __init__(self, barFeed, cash_or_brk=1000000):        # The broker should subscribe to barFeed events before the strategy.        # This is to avoid executing orders submitted in the current tick.        if isinstance(cash_or_brk, pyalgotrade.broker.Broker):            broker = cash_or_brk        else:          broker = backtesting.Broker(cash_or_brk, barFeed)        查看backtesting的broker        broker/backtesting.py        class Broker(broker.Broker):        def start(self):            super(Broker, self).start()

 

查看backtesting的broker -> broker/backtesting.py

        class Broker(broker.Broker):        def start(self):            super(Broker, self).start()

其中基类的start如下

observer.pyclass Subject(object):@abc.abstractmethoddef start(self):pass

然后再来看barFeed的subject的start

其中barFeed也没有自己定义start方法,即,start方法也是如上。

在每个subject调用start方法后,dispatcher就会调用自身self.__startEvent.emit。然后到循环eof, eventsDispatched = self.__dispatch()

    def __dispatch(self):        smallestDateTime = None        eof = True        eventsDispatched = False        # Scan for the lowest datetime.        for subject in self.__subjects:            if not subject.eof():                eof = False                smallestDateTime = utils.safe_min(smallestDateTime, subject.peekDateTime())

再次实例创建的feed为yahoofeed

而依次继承于csvfeed.BarFeed,membf.BarFeed,barfeed.BaseBaseFeed,feed.BaseFeed

其中membf.BarFeed,BaseBarFeed都实现了eof方法。

通过代码追踪,我们发现eof主要为了判断是否以及迭代完每一个bar

代码如下

    def eof(self):        ret = True        # Check if there is at least one more bar to return.        for instrument, bars in self.__bars.iteritems():            nextPos = self.__nextPos[instrument]            if nextPos < len(bars):                ret = False                break        return ret

其中self.__nextPos在addBarsFromSequence函数里面已经将其定义为0,也就是说,这个nextPos是为了在迭代每个bar的同时记录迭代的位置,即索引位置。

当判断完eof之后,则调用__dispatchSubject方法,迭代每个subject并调用其dispatch方法。

其中dispatch的实现在基类feed/__init__.py

class BaseFeed(observer.Subject):    def dispatch(self):        dateTime, values = self.getNextValuesAndUpdateDS()        if dateTime is not None:            self.__event.emit(dateTime, values)        return dateTime is not None

getNextValuesAndUpdateDS方法实现在feed/__init__.py

   def getNextValuesAndUpdateDS(self):        dateTime, values = self.getNextValues()        if dateTime is not None:            for key, value in values.items():                # Get or create the datseries for each key.                try:                    ds = self.__ds[key]                except KeyError:                    ds = self.createDataSeries(key, self.__maxLen)                    self.__ds[key] = ds                ds.appendWithDateTime(dateTime, value)        return (dateTime, values)    def __iter__(self):        return feed_iterator(self)

而getNextValues的方法实现在barfeed/__init__.py

class BaseBarFeed(feed.BaseFeed):    def getNextValues(self):        dateTime = None        bars = self.getNextBars()        if bars is not None:            dateTime = bars.getDateTime()            # Check that current bar datetimes are greater than the previous one.            if self.__currentBars is not None and self.__currentBars.getDateTime() >= dateTime:                raise Exception(                    "Bar date times are not in order. Previous datetime was %s and current datetime is %s" % (                        self.__currentBars.getDateTime(),                        dateTime                    )                )            # Update self.__currentBars and self.__lastBars            self.__currentBars = bars            for instrument in bars.getInstruments():                self.__lastBars[instrument] = bars[instrument]        return (dateTime, bars)

其中 getNextBars的方法实现在barfeed/membf.py

class BarFeed(barfeed.BaseBarFeed):    def getNextBars(self):        # All bars must have the same datetime. We will return all the ones with the smallest datetime.        smallestDateTime = self.peekDateTime()        if smallestDateTime is None:            return None        # Make a second pass to get all the bars that had the smallest datetime.        ret = {}        for instrument, bars in self.__bars.iteritems():            nextPos = self.__nextPos[instrument]            if nextPos < len(bars) and bars[nextPos].getDateTime() == smallestDateTime:                ret[instrument] = bars[nextPos]                self.__nextPos[instrument] += 1        if self.__currDateTime == smallestDateTime:            raise Exception("Duplicate bars found for %s on %s" % (ret.keys(), smallestDateTime))        self.__currDateTime = smallestDateTime        return bar.Bars(ret)

其中Bars对象则是对bar的进一层封装

提供方法如下。

def __getitem__(self, instrument):return self.__barDict[instrument]def __contains__(self, instrument):return instrument in self.__barDictdef items(self):def keys(self):def getInstruments(self):def getDateTime(self):def getBar(self, instrument):

至此,我们了解到了feed对象,以及每个bar是怎么迭代的,但是还没有看到每个bar的处理操作。

所以在回到feed的dispatch方法,处理流程如下

    def dispatch(self):        dateTime, values = self.getNextValuesAndUpdateDS()        if dateTime is not None:            self.__event.emit(dateTime, values)        return dateTime is not None

需要着重说明的就是self.__event.emit(dateTime, values)

其中values是一个bar.Bars实例。

broker的dispatch方法

def dispatch(self):# All events were already emitted while handling barfeed events.pass

这里,我们可以看到如果dataTime不是None的话,就会通过emit提交时间

而feed里面注册了__onBars的handlers

所以在每次迭代的时候都会触发event的emit操作,即执行每个在feed中注册了的handler,这里只注册了一个handler--->__onBars

def __onBars(self, dateTime, bars):    # THE ORDER HERE IS VERY IMPORTANT    # 1: Let analyzers process bars.    self.__notifyAnalyzers(lambda s: s.beforeOnBars(self, bars))    # 2: Let the strategy process current bars and submit orders.    self.onBars(bars)    # 3: Notify that the bars were processed.    self.__barsProcessedEvent.emit(self, bars)

所以迭代每一个bar的时候,都会执行onBar的函数。

而onBar函数是自己定义的,在本示例中,onBar的函数内容如下

def onBars(self, bars):    def onBars(self, bars):        # If a position was not opened, check if we should enter a long position.        if self.__position is None:            if cross.cross_above(self.__prices, self.__sma) > 0:                shares = int(self.getBroker().getCash() * 0.9 / bars[self.__instrument].getPrice())                # Enter a buy market order. The order is good till canceled.                self.__position = self.enterLong(self.__instrument, shares, True)        # Check if we have to exit the position.        elif not self.__position.exitActive() and cross.cross_below(self.__prices, self.__sma) > 0:            self.__position.exitMarket()

bar是每个指定频率的open,close,low,high,adj close,volume数据集合对象。

DataSeries是一个随着迭代,不断增加datetime,以及bar的序列。

而technical的触发是在feed/__init__.py里面的ds.appendWithDateTime。

    def getNextValuesAndUpdateDS(self):        dateTime, values = self.getNextValues()        if dateTime is not None:            for key, value in values.items():                # Get or create the datseries for each key.                try:                    ds = self.__ds[key]                except KeyError:                    ds = self.createDataSeries(key, self.__maxLen)                    self.__ds[key] = ds                ds.appendWithDateTime(dateTime, value)        return (dateTime, values)

然后ma.py

class SMA(technical.EventBasedFilter):    def __init__(self, dataSeries, period, maxLen=None):    super(SMA, self).__init__(dataSeries, SMAEventWindow(period), maxLen)

然后technical/__init__.py

class EventBasedFilter(dataseries.SequenceDataSeries):    def __init__(self, windowSize, dtype=float, skipNone=True):        assert(windowSize > 0)        assert(isinstance(windowSize, int))        self.__values = collections.NumPyDeque(windowSize, dtype)        self.__windowSize = windowSize        self.__skipNone = skipNone    def __onNewValue(self, dataSeries, dateTime, value):        # Let the event window perform calculations.        self.__eventWindow.onNewValue(dateTime, value)        # Get the resulting value        newValue = self.__eventWindow.getValue()        # Add the new value.        self.appendWithDateTime(dateTime, newValue)

而__eventWindow.onNewValue在technical/ma.py

class SMAEventWindow(technical.EventWindow):    def __init__(self, period):        assert(period > 0)        super(SMAEventWindow, self).__init__(period)        self.__value = None    def onNewValue(self, dateTime, value):        firstValue = None        if len(self.getValues()) > 0:            firstValue = self.getValues()[0]            assert(firstValue is not None)        super(SMAEventWindow, self).onNewValue(dateTime, value)        if value is not None and self.windowFull():            if self.__value is None:                self.__value = self.getValues().mean()            else:                self.__value = self.__value + value / float(self.getWindowSize()) - firstValue / float(self.getWindowSize())    def getValue(self):        return self.__value

至此基于pyalgotrade的一个简单示例,按照其执行流程的源码解读到此完毕。

后记:后面有点乱了,写篇文章还是蛮费时间的,太长了,pyalgotrade的源码解读估计还得写一段时间去了。

这就是系列的衍生篇了。

参考链接:

Python设计模式: 

PyAlgoTrade 文档: 

如果觉得不错,并有所收获,请我喝杯茶呗

wKioL1lU4MXwELckAADg-gB3Tsc583.jpg-wh_50wKiom1lU4Mqg8rxIAADzypnX0FU518.jpg-wh_50

转载地址:http://hbbja.baihongyu.com/

你可能感兴趣的文章
Redis安全
查看>>
Lua面向对象设计(转)
查看>>
动态载入Layout 与 论Activity、 Window、View的关系
查看>>
发展中的生命力——Leo鉴书69
查看>>
iOS计算两个时间的时间差
查看>>
细说C#多线程那些事 - 线程同步和多线程优先级
查看>>
Woobuntu woobuntu_build.sh hacking
查看>>
接口与抽象类的区别
查看>>
CORS 专题
查看>>
检查给定串是否存在于由区间及点集的结合内
查看>>
美团团购订单系统优化记
查看>>
Iptables防火墙规则使用梳理
查看>>
使用FileReader接口读取文件内容
查看>>
Spring_使用XML文件的方式配置事务
查看>>
css 点点加载demo
查看>>
TCP/IP 协议族的简介
查看>>
简单单层bp神经网络
查看>>
eclipse Maven 使用记录 ------ 建立 webapp项目
查看>>
解决Python交叉编译后,键盘方向键乱码的问题
查看>>
idea svn 不见的问题
查看>>