diff --git a/tests/test_transforms.py b/tests/test_transforms.py index f3b2ffe2..a83785e0 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -342,6 +342,10 @@ class TestBatchTransform(TestCase): 'arbitrary dataframe should contain only "test"' ) + for data in algo.history_return_sid_filter[wl:]: + self.assertIn(0, data.columns) + self.assertNotIn(1, data.columns) + # test overloaded class for test_history in [algo.history_return_price_class, algo.history_return_price_decorator]: diff --git a/zipline/test_algorithms.py b/zipline/test_algorithms.py index ee3e41a2..7dbc1385 100644 --- a/zipline/test_algorithms.py +++ b/zipline/test_algorithms.py @@ -268,6 +268,7 @@ class BatchTransformAlgorithm(TradingAlgorithm): self.history_return_args = [] self.history_return_arbitrary_fields = [] self.history_return_nan = [] + self.history_return_sid_filter = [] self.return_price_class = ReturnPriceBatchTransform( refresh_period=self.refresh_period, @@ -305,6 +306,13 @@ class BatchTransformAlgorithm(TradingAlgorithm): clean_nans=True ) + self.return_sid_filter = return_price_batch_decorator( + refresh_period=self.refresh_period, + window_length=self.window_length, + clean_nans=True, + sids=[0] + ) + self.iter = 0 self.set_slippage(FixedSlippage()) @@ -339,6 +347,13 @@ class BatchTransformAlgorithm(TradingAlgorithm): self.iter += 1 + # Add a new sid to check that it does not get included + extra_sid_data = deepcopy(data) + extra_sid_data[1] = extra_sid_data[0] + self.history_return_sid_filter.append( + self.return_sid_filter.handle_data(extra_sid_data) + ) + class SetPortfolioAlgorithm(TradingAlgorithm): """ diff --git a/zipline/transforms/utils.py b/zipline/transforms/utils.py index 369f8054..5f4c706a 100644 --- a/zipline/transforms/utils.py +++ b/zipline/transforms/utils.py @@ -343,7 +343,27 @@ class BatchTransform(EventWindow): func=None, refresh_period=None, window_length=None, - clean_nans=True): + clean_nans=True, + sids=None): + """Instantiate new batch_transform object. + + :Arguments: + func : python function + If supplied will be called after each refresh_period + with the data panel and all args and kwargs supplied + to the handle_data() call. + refresh_period : int + Interval to call batch_transform function. + window_length : int + How many days the trailing window should have. + clean_nans : bool + Whether to (forward) fill in nans. + sids : list + Which sids to include in the moving window. If not + supplied sids will be extracted from incoming + events. + + """ super(BatchTransform, self).__init__(True, window_length=window_length) @@ -355,6 +375,8 @@ class BatchTransform(EventWindow): self.clean_nans = clean_nans + self.sids = sids + self.refresh_period = refresh_period self.window_length = window_length self.trading_days_since_update = 0 @@ -445,7 +467,13 @@ class BatchTransform(EventWindow): """ # This Panel data structure ultimately gets passed to the # user-overloaded get_value() method. - sids = set.union(*[set(tick.data.keys()) for tick in self.ticks]) + + # If sids are set, use those. Otherwise extract. + if self.sids is not None: + sids = self.sids + else: + sids = set.union(*[set(tick.data.keys()) for tick in self.ticks]) + dts = [tick.dt for tick in self.ticks] data = pd.Panel(items=self.field_names, major_axis=dts, @@ -454,9 +482,10 @@ class BatchTransform(EventWindow): # Fill data panel for tick in self.ticks: dt = tick.dt - for sid, fields in tick.data.iteritems(): + for sid in sids: + fields = tick.data[sid] for field_name in self.field_names: - data[field_name][sid].ix[dt] = fields[field_name] + data[field_name][sid].ix[dt] = fields[field_name] if self.clean_nans: # Fills in gaps of missing data during transform