diff --git a/tests/test_transforms.py b/tests/test_transforms.py index a83785e0..1b2536dc 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -346,6 +346,14 @@ class TestBatchTransform(TestCase): self.assertIn(0, data.columns) self.assertNotIn(1, data.columns) + for data in algo.history_return_field_filter[wl:]: + self.assertIn('price', data.items) + self.assertNotIn('ignore', data.items) + + for data in algo.history_return_field_no_filter[wl:]: + self.assertIn('price', data.items) + self.assertIn('ignore', data.items) + # test overloaded class for test_history in [algo.history_return_price_class, algo.history_return_price_decorator]: diff --git a/zipline/test_algorithms.py b/zipline/test_algorithms.py index 7dbc1385..735e3348 100644 --- a/zipline/test_algorithms.py +++ b/zipline/test_algorithms.py @@ -269,6 +269,8 @@ class BatchTransformAlgorithm(TradingAlgorithm): self.history_return_arbitrary_fields = [] self.history_return_nan = [] self.history_return_sid_filter = [] + self.history_return_field_filter = [] + self.history_return_field_no_filter = [] self.return_price_class = ReturnPriceBatchTransform( refresh_period=self.refresh_period, @@ -313,6 +315,19 @@ class BatchTransformAlgorithm(TradingAlgorithm): sids=[0] ) + self.return_field_filter = return_data( + refresh_period=self.refresh_period, + window_length=self.window_length, + clean_nans=True, + fields=['price'] + ) + + self.return_field_no_filter = return_data( + refresh_period=self.refresh_period, + window_length=self.window_length, + clean_nans=True + ) + self.iter = 0 self.set_slippage(FixedSlippage()) @@ -354,6 +369,16 @@ class BatchTransformAlgorithm(TradingAlgorithm): self.return_sid_filter.handle_data(extra_sid_data) ) + # Add a field to check that it does not get included + extra_field_data = deepcopy(data) + extra_field_data[0]['ignore'] = extra_sid_data[0]['price'] + self.history_return_field_filter.append( + self.return_field_filter.handle_data(extra_field_data) + ) + self.history_return_field_no_filter.append( + self.return_field_no_filter.handle_data(extra_field_data) + ) + class SetPortfolioAlgorithm(TradingAlgorithm): """ diff --git a/zipline/transforms/utils.py b/zipline/transforms/utils.py index 5f4c706a..117fe8c8 100644 --- a/zipline/transforms/utils.py +++ b/zipline/transforms/utils.py @@ -344,7 +344,8 @@ class BatchTransform(EventWindow): refresh_period=None, window_length=None, clean_nans=True, - sids=None): + sids=None, + fields=None): """Instantiate new batch_transform object. :Arguments: @@ -362,7 +363,10 @@ class BatchTransform(EventWindow): Which sids to include in the moving window. If not supplied sids will be extracted from incoming events. - + fields : list + Which fields to include in the moving window + (e.g. 'price'). If not supplied, fields will be + extracted from incoming events. """ super(BatchTransform, self).__init__(True, @@ -388,7 +392,7 @@ class BatchTransform(EventWindow): self.updated = False self.cached = None - self.field_names = None + self.field_names = fields def handle_data(self, data, *args, **kwargs): """ @@ -432,7 +436,8 @@ class BatchTransform(EventWindow): def handle_add(self, event): if not self.last_dt: - self.field_names = self._extract_field_names(event) + if self.field_names is None: + self.field_names = self._extract_field_names(event) self.last_dt = event.dt return