diff --git a/tests/test_transforms.py b/tests/test_transforms.py index ea1024f6..e3cc1961 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -294,15 +294,23 @@ class BatchTransformTestCase(TestCase): setup_logger(self) self.source, self.df = factory.create_test_df_source() - def test_batch_inherit(self): + def test_event_window(self): algo = BatchTransformAlgorithm(sids=[0, 1]) algo.run(self.source) - assert algo.history_class[:2] == algo.history_decorator[:2] == [None, None], "First two iterations should return None" + assert algo.history_return_price_class[:2] == algo.history_return_price_decorator[:2] == [None, None], "First two iterations should return None" # test overloaded class - for test_history in [algo.history_class, algo.history_decorator]: + for test_history in [algo.history_return_price_class, algo.history_return_price_decorator]: self.assertTrue(np.all(test_history[2].values.flatten() == range(4, 10))) self.assertTrue(np.all(test_history[3].values.flatten() == range(4, 10))) self.assertTrue(np.all(test_history[4].values.flatten() == range(6, 14))) + def test_passing_of_args(self): + algo = BatchTransformAlgorithm([0, 1], 1, kwarg='str') + algo.run(self.source) + self.assertEqual(algo.args, (1,)) + self.assertEqual(algo.kwargs, {'kwarg':'str'}) + expected_item = ((1, ), {'kwarg': 'str'}) + self.assertEqual(algo.history_return_args, [None, None, expected_item, expected_item, expected_item]) + diff --git a/zipline/algorithm.py b/zipline/algorithm.py index 08d388da..f54ba0cb 100644 --- a/zipline/algorithm.py +++ b/zipline/algorithm.py @@ -24,7 +24,7 @@ class TradingAlgorithm(object): ``` To then to run this algorithm: - >>> my_algo = MyAlgo(100, sids=[0]) + >>> my_algo = MyAlgo([0], 100) # first argument has to be list of sids >>> stats = my_algo.run(data) """ @@ -32,7 +32,7 @@ class TradingAlgorithm(object): """ Initialize sids and other state variables. - Calls user-defined initialize and forwarding *args and **kwargs. + Calls user-defined initialize() forwarding *args and **kwargs. """ self.sids = sids self.done = False @@ -45,6 +45,8 @@ class TradingAlgorithm(object): # call to user-defined initialize method self.initialize(*args, **kwargs) + self.initialized = True + def _create_simulator(self, start, end): """ Create trading environment, transforms and SimulatedTrading object. diff --git a/zipline/gens/tradesimulation.py b/zipline/gens/tradesimulation.py index 5efb1505..5b5cec52 100644 --- a/zipline/gens/tradesimulation.py +++ b/zipline/gens/tradesimulation.py @@ -219,9 +219,11 @@ class AlgorithmSimulator(object): # snapshot time to any log record generated. #with self.processor.threadbound(), self.stdout_capture(Logger('Print'),''): - # Call user's initialize method with a timeout. - with Timeout(INIT_TIMEOUT, message="Call to initialize timed out"): - self.algo.initialize() + # Call user's initialize method with a timeout (only if + # initialize wasn't called already). + if not getattr(self.algo, 'initialized', False): + with Timeout(INIT_TIMEOUT, message="Call to initialize timed out"): + self.algo.initialize() # Group together events with the same dt field. This depends on the # events already being sorted. diff --git a/zipline/gens/transform.py b/zipline/gens/transform.py index 8ec862eb..3ff259cc 100644 --- a/zipline/gens/transform.py +++ b/zipline/gens/transform.py @@ -339,7 +339,7 @@ class BatchTransform(EventWindow): self.updated = False self.data = None - def handle_data(self, data): + def handle_data(self, data, *args, **kwargs): """ New method to handle a data frame as sent to the algorithm's handle_data method. @@ -356,7 +356,7 @@ class BatchTransform(EventWindow): self.update(data) # return newly computed or cached value - return self.get_transform_value() + return self.get_transform_value(*args, **kwargs) def handle_add(self, event): if not self.last_refresh: diff --git a/zipline/test_algorithms.py b/zipline/test_algorithms.py index 10795e33..a8bc6ab4 100644 --- a/zipline/test_algorithms.py +++ b/zipline/test_algorithms.py @@ -399,33 +399,53 @@ class TestRegisterTransformAlgorithm(TradingAlgorithm): def handle_data(self, data): pass -class NoopBatchTransform(BatchTransform): +########################################## +# Algorithm using simple batch transforms + +class ReturnPriceBatchTransform(BatchTransform): def get_value(self, data): return data.price @batch_transform -def noop_batch_decorator(data): +def return_price_batch_decorator(data): return data.price +@batch_transform +def return_args_batch_decorator(data, *args, **kwargs): + return args, kwargs + class BatchTransformAlgorithm(TradingAlgorithm): def initialize(self, *args, **kwargs): - self.history_class = [] - self.history_decorator = [] - self.days = 3 - self.noop_class = NoopBatchTransform(sids=[0, 1], - market_aware=False, - refresh_period=2, - delta=timedelta(days=self.days)) + self.history_return_price_class = [] + self.history_return_price_decorator = [] + self.history_return_args = [] - self.noop_decorator = noop_batch_decorator(sids=[0, 1], - market_aware=False, - refresh_period=2, - delta=timedelta(days=self.days)) + self.days = 3 + + self.args = args + self.kwargs = kwargs + + self.return_price_class = ReturnPriceBatchTransform(sids=self.sids, + market_aware=False, + refresh_period=2, + delta=timedelta(days=self.days) + ) + + self.return_price_decorator = return_price_batch_decorator(sids=self.sids, + market_aware=False, + refresh_period=2, + delta=timedelta(days=self.days) + ) + + self.return_args_batch = return_args_batch_decorator(sids=self.sids, + market_aware=False, + refresh_period=2, + delta=timedelta(days=self.days) + ) def handle_data(self, data): - window_class = self.noop_class.handle_data(data) - window_decorator = self.noop_decorator.handle_data(data) - self.history_class.append(window_class) - self.history_decorator.append(window_decorator) + self.history_return_price_class.append(self.return_price_class.handle_data(data)) + self.history_return_price_decorator.append(self.return_price_decorator.handle_data(data)) + self.history_return_args.append(self.return_args_batch.handle_data(data, *self.args, **self.kwargs))