From 4c02fea6e2d30b9f89d1780d058a5cf20c791399 Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Mon, 19 Nov 2012 10:20:24 -0500 Subject: [PATCH] BUG: batch_transform was wrongly updating when days < refresh_period. --- tests/test_transforms.py | 8 ++++ zipline/test_algorithms.py | 9 ++++ zipline/transforms/utils.py | 86 ++++++++++++++++++++++--------------- 3 files changed, 69 insertions(+), 34 deletions(-) diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 8c979095..17e30285 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -327,6 +327,14 @@ class TestBatchTransform(TestCase): self.assertEqual(algo.history_return_price_market_aware[:2], [None, None], "First two iterations should return None") + self.assertEqual(algo.history_return_more_days_than_refresh[:3], + [None, None, None], + "First five iterations should return None") + self.assertTrue(isinstance( + algo.history_return_more_days_than_refresh[4], + pd.DataFrame), + "Sixth iteration should not be None" + ) # test overloaded class for test_history in [algo.history_return_price_class, diff --git a/zipline/test_algorithms.py b/zipline/test_algorithms.py index 97a9157f..c8a5d30b 100644 --- a/zipline/test_algorithms.py +++ b/zipline/test_algorithms.py @@ -260,6 +260,7 @@ class BatchTransformAlgorithm(TradingAlgorithm): self.history_return_price_decorator = [] self.history_return_args = [] self.history_return_price_market_aware = [] + self.history_return_more_days_than_refresh = [] self.return_price_class = ReturnPriceBatchTransform( market_aware=False, @@ -285,6 +286,12 @@ class BatchTransformAlgorithm(TradingAlgorithm): days=self.days ) + self.return_price_more_days_than_refresh = ReturnPriceBatchTransform( + market_aware=True, + refresh_period=1, + days=3 + ) + self.set_slippage(FixedSlippage()) def handle_data(self, data): @@ -297,6 +304,8 @@ class BatchTransformAlgorithm(TradingAlgorithm): data, *self.args, **self.kwargs)) self.history_return_price_market_aware.append( self.return_price_market_aware.handle_data(data)) + self.history_return_more_days_than_refresh.append( + self.return_price_more_days_than_refresh.handle_data(data)) class SetPortfolioAlgorithm(TradingAlgorithm): diff --git a/zipline/transforms/utils.py b/zipline/transforms/utils.py index 242a2b18..4da8926a 100644 --- a/zipline/transforms/utils.py +++ b/zipline/transforms/utils.py @@ -361,6 +361,7 @@ class BatchTransform(EventWindow): self.refresh_period = refresh_period self.days = days self.trading_days_since_update = 0 + self.trading_days_total = 0 self.full = False self.last_dt = None @@ -397,49 +398,59 @@ class BatchTransform(EventWindow): self.last_dt = event.dt return + # update trading day counters if self.last_dt.day != event.dt.day: self.last_dt = event.dt self.trading_days_since_update += 1 + self.trading_days_total += 1 - if self.trading_days_since_update >= self.refresh_period: - # Create a pandas.Panel (i.e. 3d DataFrame) from the - # events in the current window. - # - # The resulting panel looks like this: - # index : field_name (e.g. price) - # major axis/rows : dt - # minor axis/colums : sid - # - # This Panel data structure ultimately gets passed to the - # user-overloaded get_value() method. - # - # self.ticks contains ndicts with data, dt keys. - # event parameter is an ndict with data, dt keys. - fields = {} - for field_name in ['price', 'volume']: - sids = self.ticks[0].data.keys() - # Skip non-existant fields - if field_name not in self.ticks[0].data[sids[0]]: - continue - - values_per_sid = {} - - for sid in sids: - values_per_sid[sid] = pd.Series( - {tick.data[sid].dt: tick.data[sid][field_name] - for tick in self.ticks} - ) - - # concatenate different sids into one df - fields[field_name] = pd.DataFrame.from_dict(values_per_sid) - - self.data = pd.Panel.from_dict(fields, orient='items') - + if self.trading_days_since_update >= self.refresh_period and\ + self.trading_days_total >= self.days: + # Create datapanel of running event window. + self.data = self.get_data() + # Setting updated to True will cause get_transform_value() + # to call the user-defined batch-transform with the most + # recent datapanel self.updated = True self.trading_days_since_update = 0 else: self.updated = False + def get_data(self): + # Create a pandas.Panel (i.e. 3d DataFrame) from the + # events in the current window. + # + # The resulting panel looks like this: + # index : field_name (e.g. price) + # major axis/rows : dt + # minor axis/colums : sid + # + # This Panel data structure ultimately gets passed to the + # user-overloaded get_value() method. + # + # self.ticks contains ndicts with data, dt keys. + # event parameter is an ndict with data, dt keys. + fields = {} + for field_name in ['price', 'volume']: + sids = self.ticks[0].data.keys() + # Skip non-existant fields + if field_name not in self.ticks[0].data[sids[0]]: + continue + + values_per_sid = {} + + for sid in sids: + values_per_sid[sid] = pd.Series( + {tick.data[sid].dt: tick.data[sid][field_name] + for tick in self.ticks} + ) + + # concatenate different sids into one df + fields[field_name] = pd.DataFrame.from_dict(values_per_sid) + + data = pd.Panel.from_dict(fields, orient='items') + return data + def handle_remove(self, event): # since an event is expiring, we know the window is full self.full = True @@ -449,6 +460,13 @@ class BatchTransform(EventWindow): "Either overwrite get_value or provide a func argument.") def get_transform_value(self, *args, **kwargs): + """Call user-defined batch-transform function passing all + arguments. + + Note that this will only call the transform if the datapanel + has actually been updated. Otherwise, the previously, cached + value will be returned. + """ if self.data is None: return None