diff --git a/tests/test_transforms.py b/tests/test_transforms.py index f3b2ffe2..8b775990 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections import deque + import pytz import numpy as np import pandas as pd @@ -342,6 +344,24 @@ class TestBatchTransform(TestCase): 'arbitrary dataframe should contain only "test"' ) + for data in algo.history_return_sid_filter[wl:]: + self.assertIn(0, data.columns) + self.assertNotIn(1, data.columns) + + for data in algo.history_return_field_filter[wl:]: + self.assertIn('price', data.items) + self.assertNotIn('ignore', data.items) + + for data in algo.history_return_field_no_filter[wl:]: + self.assertIn('price', data.items) + self.assertIn('ignore', data.items) + + for data in algo.history_return_ticks[wl:]: + self.assertTrue(isinstance(data, deque)) + + for data in algo.history_return_not_full: + self.assertIsNot(data, None) + # test overloaded class for test_history in [algo.history_return_price_class, algo.history_return_price_decorator]: diff --git a/zipline/test_algorithms.py b/zipline/test_algorithms.py index ee3e41a2..b22d2719 100644 --- a/zipline/test_algorithms.py +++ b/zipline/test_algorithms.py @@ -72,6 +72,7 @@ The algorithm must expose methods: """ from copy import deepcopy +import numpy as np from zipline.algorithm import TradingAlgorithm from zipline.finance.slippage import FixedSlippage @@ -268,6 +269,11 @@ class BatchTransformAlgorithm(TradingAlgorithm): self.history_return_args = [] self.history_return_arbitrary_fields = [] self.history_return_nan = [] + self.history_return_sid_filter = [] + self.history_return_field_filter = [] + self.history_return_field_no_filter = [] + self.history_return_ticks = [] + self.history_return_not_full = [] self.return_price_class = ReturnPriceBatchTransform( refresh_period=self.refresh_period, @@ -305,6 +311,38 @@ class 
BatchTransformAlgorithm(TradingAlgorithm): clean_nans=True ) + self.return_sid_filter = return_price_batch_decorator( + refresh_period=self.refresh_period, + window_length=self.window_length, + clean_nans=True, + sids=[0] + ) + + self.return_field_filter = return_data( + refresh_period=self.refresh_period, + window_length=self.window_length, + clean_nans=True, + fields=['price'] + ) + + self.return_field_no_filter = return_data( + refresh_period=self.refresh_period, + window_length=self.window_length, + clean_nans=True + ) + + self.return_ticks = return_data( + refresh_period=self.refresh_period, + window_length=self.window_length, + create_panel=False + ) + + self.return_not_full = return_data( + refresh_period=0, + window_length=self.window_length, + compute_only_full=False + ) + self.iter = 0 self.set_slippage(FixedSlippage()) @@ -317,6 +355,10 @@ class BatchTransformAlgorithm(TradingAlgorithm): self.history_return_args.append( self.return_args_batch.handle_data( data, *self.args, **self.kwargs)) + self.history_return_ticks.append( + self.return_ticks.handle_data(data)) + self.history_return_not_full.append( + self.return_not_full.handle_data(data)) new_data = deepcopy(data) for sid in new_data: @@ -331,7 +373,6 @@ class BatchTransformAlgorithm(TradingAlgorithm): self.return_nan.handle_data(data)) else: nan_data = deepcopy(data) - import numpy as np for sid in nan_data.iterkeys(): nan_data[sid].price = np.nan self.history_return_nan.append( @@ -339,6 +380,23 @@ class BatchTransformAlgorithm(TradingAlgorithm): self.iter += 1 + # Add a new sid to check that it does not get included + extra_sid_data = deepcopy(data) + extra_sid_data[1] = extra_sid_data[0] + self.history_return_sid_filter.append( + self.return_sid_filter.handle_data(extra_sid_data) + ) + + # Add a field to check that it does not get included + extra_field_data = deepcopy(data) + extra_field_data[0]['ignore'] = extra_sid_data[0]['price'] + self.history_return_field_filter.append( + 
self.return_field_filter.handle_data(extra_field_data) + ) + self.history_return_field_no_filter.append( + self.return_field_no_filter.handle_data(extra_field_data) + ) + class SetPortfolioAlgorithm(TradingAlgorithm): """ diff --git a/zipline/transforms/utils.py b/zipline/transforms/utils.py index 369f8054..e343a40c 100644 --- a/zipline/transforms/utils.py +++ b/zipline/transforms/utils.py @@ -24,6 +24,7 @@ from copy import deepcopy from datetime import datetime from collections import deque from abc import ABCMeta, abstractmethod +from numbers import Integral import pandas as pd @@ -343,7 +344,43 @@ class BatchTransform(EventWindow): func=None, refresh_period=None, window_length=None, - clean_nans=True): + clean_nans=True, + sids=None, + fields=None, + create_panel=True, + compute_only_full=True): + + """Instantiate new batch_transform object. + + :Arguments: + func : python function + If supplied will be called after each refresh_period + with the data panel and all args and kwargs supplied + to the handle_data() call. + refresh_period : int + Interval to call batch_transform function. + window_length : int + How many days the trailing window should have. + clean_nans : bool + Whether to (forward) fill in nans. + sids : list + Which sids to include in the moving window. If not + supplied sids will be extracted from incoming + events. + fields : list + Which fields to include in the moving window + (e.g. 'price'). If not supplied, fields will be + extracted from incoming events. + create_panel : bool + If True, will create a pandas panel every refresh + period and pass it to the user-defined function. + If False, will pass the underlying deque reference + directly to the function which will be significantly + faster. + compute_only_full : bool + Only call the user-defined function once the window is + full. Returns None if window is not full yet. 
+ """ super(BatchTransform, self).__init__(True, window_length=window_length) @@ -354,6 +391,16 @@ class BatchTransform(EventWindow): self.compute_transform_value = self.get_value self.clean_nans = clean_nans + self.create_panel = create_panel + self.compute_only_full = compute_only_full + + self.sids = sids + if isinstance(self.sids, (basestring, Integral)): + self.sids = [self.sids] + + self.field_names = fields + if isinstance(self.field_names, str): + self.field_names = [self.field_names] self.refresh_period = refresh_period self.window_length = window_length @@ -366,8 +413,6 @@ class BatchTransform(EventWindow): self.updated = False self.cached = None - self.field_names = None - def handle_data(self, data, *args, **kwargs): """ New method to handle a data frame as sent to the algorithm's @@ -410,9 +455,9 @@ class BatchTransform(EventWindow): def handle_add(self, event): if not self.last_dt: - self.field_names = self._extract_field_names(event) + if self.field_names is None: + self.field_names = self._extract_field_names(event) self.last_dt = event.dt - return # update trading day counters if self.last_dt.day != event.dt.day: @@ -420,15 +465,14 @@ class BatchTransform(EventWindow): self.trading_days_since_update += 1 self.trading_days_total += 1 - if ( - self.trading_days_total >= self.window_length and - self.trading_days_since_update >= self.refresh_period - ): + if self.trading_days_total >= self.window_length: + self.full = True + + if self.trading_days_since_update >= self.refresh_period: # Setting updated to True will cause get_transform_value() # to call the user-defined batch-transform with the most # recent datapanel self.updated = True - self.full = True self.trading_days_since_update = 0 else: self.updated = False @@ -445,7 +489,13 @@ class BatchTransform(EventWindow): """ # This Panel data structure ultimately gets passed to the # user-overloaded get_value() method. 
- sids = set.union(*[set(tick.data.keys()) for tick in self.ticks]) + + # If sids are set, use those. Otherwise extract. + if self.sids is not None: + sids = self.sids + else: + sids = set.union(*[set(tick.data.keys()) for tick in self.ticks]) + dts = [tick.dt for tick in self.ticks] data = pd.Panel(items=self.field_names, major_axis=dts, @@ -454,9 +504,10 @@ class BatchTransform(EventWindow): # Fill data panel for tick in self.ticks: dt = tick.dt - for sid, fields in tick.data.iteritems(): + for sid in sids: + fields = tick.data[sid] for field_name in self.field_names: - data[field_name][sid].ix[dt] = fields[field_name] + data[field_name][sid].ix[dt] = fields[field_name] if self.clean_nans: # Fills in gaps of missing data during transform @@ -471,8 +522,7 @@ class BatchTransform(EventWindow): return data def handle_remove(self, event): - # since an event is expiring, we know the window is full - self.full = True + pass def get_value(self, *args, **kwargs): raise NotImplementedError( @@ -486,12 +536,15 @@ class BatchTransform(EventWindow): has actually been updated. Otherwise, the previously, cached value will be returned. """ - if not self.full: + if self.compute_only_full and not self.full: return None if self.updated: - self.cached = self.compute_transform_value(self.get_data(), - *args, **kwargs) + # Either create new pandas panel or pass ticks deque + # directly + data = self.get_data() if self.create_panel else self.ticks + self.cached = self.compute_transform_value(data, *args, + **kwargs) return self.cached