diff --git a/tests/test_transforms.py b/tests/test_transforms.py index ea1024f6..4f6cae0f 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -294,15 +294,24 @@ class BatchTransformTestCase(TestCase): setup_logger(self) self.source, self.df = factory.create_test_df_source() - def test_batch_inherit(self): + def test_event_window(self): algo = BatchTransformAlgorithm(sids=[0, 1]) algo.run(self.source) - assert algo.history_class[:2] == algo.history_decorator[:2] == [None, None], "First two iterations should return None" + self.assertEqual(algo.history_return_price_class[:2], [None, None], "First two iterations should return None") + self.assertEqual(algo.history_return_price_decorator[:2], [None, None], "First two iterations should return None") # test overloaded class - for test_history in [algo.history_class, algo.history_decorator]: + for test_history in [algo.history_return_price_class, algo.history_return_price_decorator]: self.assertTrue(np.all(test_history[2].values.flatten() == range(4, 10))) self.assertTrue(np.all(test_history[3].values.flatten() == range(4, 10))) self.assertTrue(np.all(test_history[4].values.flatten() == range(6, 14))) + def test_passing_of_args(self): + algo = BatchTransformAlgorithm([0, 1], 1, kwarg='str') + self.assertEqual(algo.args, (1,)) + self.assertEqual(algo.kwargs, {'kwarg':'str'}) + + algo.run(self.source) + expected_item = ((1, ), {'kwarg': 'str'}) + self.assertEqual(algo.history_return_args, [None, None, expected_item, expected_item, expected_item]) diff --git a/zipline/algorithm.py b/zipline/algorithm.py index 08d388da..f54ba0cb 100644 --- a/zipline/algorithm.py +++ b/zipline/algorithm.py @@ -24,7 +24,7 @@ class TradingAlgorithm(object): ``` To then to run this algorithm: - >>> my_algo = MyAlgo(100, sids=[0]) + >>> my_algo = MyAlgo([0], 100) # first argument has to be list of sids >>> stats = my_algo.run(data) """ @@ -32,7 +32,7 @@ class TradingAlgorithm(object): """ Initialize sids and other state variables. - Calls user-defined initialize and forwarding *args and **kwargs. + Calls user-defined initialize() forwarding *args and **kwargs. """ self.sids = sids self.done = False @@ -45,6 +45,8 @@ class TradingAlgorithm(object): # call to user-defined initialize method self.initialize(*args, **kwargs) + self.initialized = True + def _create_simulator(self, start, end): """ Create trading environment, transforms and SimulatedTrading object. diff --git a/zipline/gens/tradesimulation.py b/zipline/gens/tradesimulation.py index 5efb1505..5b5cec52 100644 --- a/zipline/gens/tradesimulation.py +++ b/zipline/gens/tradesimulation.py @@ -219,9 +219,11 @@ class AlgorithmSimulator(object): # snapshot time to any log record generated. #with self.processor.threadbound(), self.stdout_capture(Logger('Print'),''): - # Call user's initialize method with a timeout. - with Timeout(INIT_TIMEOUT, message="Call to initialize timed out"): - self.algo.initialize() + # Call user's initialize method with a timeout (only if + # initialize wasn't called already). + if not getattr(self.algo, 'initialized', False): + with Timeout(INIT_TIMEOUT, message="Call to initialize timed out"): + self.algo.initialize() # Group together events with the same dt field. This depends on the # events already being sorted. diff --git a/zipline/gens/transform.py b/zipline/gens/transform.py index 8ec862eb..3ff259cc 100644 --- a/zipline/gens/transform.py +++ b/zipline/gens/transform.py @@ -339,7 +339,7 @@ class BatchTransform(EventWindow): self.updated = False self.data = None - def handle_data(self, data): + def handle_data(self, data, *args, **kwargs): """ New method to handle a data frame as sent to the algorithm's handle_data method. @@ -356,7 +356,7 @@ class BatchTransform(EventWindow): self.update(data) # return newly computed or cached value - return self.get_transform_value() + return self.get_transform_value(*args, **kwargs) def handle_add(self, event): if not self.last_refresh: diff --git a/zipline/lines.py b/zipline/lines.py index a3dfa9ba..a8fedc56 100644 --- a/zipline/lines.py +++ b/zipline/lines.py @@ -60,7 +60,6 @@ before invoking simulate. +---------------------------------+ """ -from zipline.test_algorithms import TestAlgorithm from zipline.utils import factory from zipline.gens.composites import ( @@ -144,6 +143,8 @@ class SimulatedTrading(object): - transforms: optional parameter that provides a list of StatefulTransform objects. """ + from zipline.test_algorithms import TestAlgorithm + assert isinstance(config, dict) sid_list = config.get('sid_list') if not sid_list: diff --git a/zipline/optimize/example.py b/zipline/optimize/example.py index c8912558..ef84fc0d 100644 --- a/zipline/optimize/example.py +++ b/zipline/optimize/example.py @@ -8,13 +8,7 @@ import numpy as np import matplotlib.pyplot as plt import cProfile from zipline.gens.mavg import MovingAverage -from zipline.gens.cov import CovTransform, cov from zipline.algorithm import TradingAlgorithm -from zipline.gens.transform import BatchTransform, batch_transform - -@batch_transform -def cov(data): - return data.price.cov() class DMA(TradingAlgorithm): """Dual Moving Average algorithm. @@ -35,14 +29,9 @@ class DMA(TradingAlgorithm): market_aware=True, days=long_window) - self.cov = cov(sids=self.sids, refresh_period=1, days=5) - def handle_data(self, data): self.events += 1 - cov = self.cov.handle_data(data) - print cov - for sid in self.sids: # access transforms via their user-defined tag if (data[sid].short_mavg['price'] > data[sid].long_mavg['price']) and not self.invested[sid]: @@ -56,17 +45,17 @@ class DMA(TradingAlgorithm): def load_close_px(indexes=None, stocks=None): from pandas.io.data import DataReader import pytz + from collections import OrderedDict if indexes is None: indexes = {'SPX' : '^GSPC'} if stocks is None: - stocks = ['AAPL'] #, 'GE', 'IBM', 'MSFT', 'XOM', 'AA', 'JNJ', 'PEP'] + stocks = ['AAPL', 'GE', 'IBM', 'MSFT', 'XOM', 'AA', 'JNJ', 'PEP'] - #start = pd.datetime(1990, 1, 1) start = pd.datetime(1990, 1, 1, 0, 0, 0, 0, pytz.utc) - end = pd.datetime(1992, 1, 1, 0, 0, 0, 0, pytz.utc) #pd.datetime.today() + end = pd.datetime(1992, 1, 1, 0, 0, 0, 0, pytz.utc) - data = {} + data = OrderedDict() for stock in stocks: print stock @@ -82,13 +71,13 @@ def load_close_px(indexes=None, stocks=None): df = pd.DataFrame({i: d['Close'] for i, d in enumerate(data.itervalues())}) df.index = df.index.tz_localize(pytz.utc) + df.save('close_px.dat') return df def run((short_window, long_window)): - #data = pd.DataFrame.from_csv('SP500.csv') - #data = pd.DataFrame.from_csv('aapl.csv') #load_close_px() - data = load_close_px() + data = pd.load('close_px.dat') + #data = load_close_px() myalgo = DMA([0, 1], amount=100, short_window=short_window, long_window=long_window) stats = myalgo.run(data) stats['sw'] = short_window diff --git a/zipline/optimize/profile.sh b/zipline/optimize/profile.sh new file mode 100755 index 00000000..0e0b602b --- /dev/null +++ b/zipline/optimize/profile.sh @@ -0,0 +1,2 @@ +#!/bin/bash +python -m cProfile -o example.prof example.py diff --git a/zipline/test_algorithms.py b/zipline/test_algorithms.py index 10795e33..a8bc6ab4 100644 --- a/zipline/test_algorithms.py +++ b/zipline/test_algorithms.py @@ -399,33 +399,53 @@ class TestRegisterTransformAlgorithm(TradingAlgorithm): def handle_data(self, data): pass -class NoopBatchTransform(BatchTransform): +########################################## +# Algorithm using simple batch transforms + +class ReturnPriceBatchTransform(BatchTransform): def get_value(self, data): return data.price @batch_transform -def noop_batch_decorator(data): +def return_price_batch_decorator(data): return data.price +@batch_transform +def return_args_batch_decorator(data, *args, **kwargs): + return args, kwargs + class BatchTransformAlgorithm(TradingAlgorithm): def initialize(self, *args, **kwargs): - self.history_class = [] - self.history_decorator = [] - self.days = 3 - self.noop_class = NoopBatchTransform(sids=[0, 1], - market_aware=False, - refresh_period=2, - delta=timedelta(days=self.days)) + self.history_return_price_class = [] + self.history_return_price_decorator = [] + self.history_return_args = [] - self.noop_decorator = noop_batch_decorator(sids=[0, 1], - market_aware=False, - refresh_period=2, - delta=timedelta(days=self.days)) + self.days = 3 + + self.args = args + self.kwargs = kwargs + + self.return_price_class = ReturnPriceBatchTransform(sids=self.sids, + market_aware=False, + refresh_period=2, + delta=timedelta(days=self.days) + ) + + self.return_price_decorator = return_price_batch_decorator(sids=self.sids, + market_aware=False, + refresh_period=2, + delta=timedelta(days=self.days) + ) + + self.return_args_batch = return_args_batch_decorator(sids=self.sids, + market_aware=False, + refresh_period=2, + delta=timedelta(days=self.days) + ) def handle_data(self, data): - window_class = self.noop_class.handle_data(data) - window_decorator = self.noop_decorator.handle_data(data) - self.history_class.append(window_class) - self.history_decorator.append(window_decorator) + self.history_return_price_class.append(self.return_price_class.handle_data(data)) + self.history_return_price_decorator.append(self.return_price_decorator.handle_data(data)) + self.history_return_args.append(self.return_args_batch.handle_data(data, *self.args, **self.kwargs))