diff --git a/zipline/test_algorithms.py b/zipline/test_algorithms.py index 470e9bf8..e4d9b559 100644 --- a/zipline/test_algorithms.py +++ b/zipline/test_algorithms.py @@ -273,6 +273,11 @@ def uses_ufunc(data, *args, **kwargs): return np.log(data) +@batch_transform +def price_multiple(data, multiplier, keyarg=1): + return data.price * multiplier * keyarg + + class BatchTransformAlgorithm(TradingAlgorithm): def initialize(self, *args, **kwargs): self.refresh_period = kwargs.pop('refresh_period', 1) @@ -354,6 +359,12 @@ class BatchTransformAlgorithm(TradingAlgorithm): clean_nans=False ) + self.price_multiple = price_multiple( + refresh_period=self.refresh_period, + window_length=self.window_length, + clean_nans=False + ) + self.iter = 0 self.set_slippage(FixedSlippage()) @@ -370,6 +381,29 @@ class BatchTransformAlgorithm(TradingAlgorithm): self.return_not_full.handle_data(data)) self.uses_ufunc.handle_data(data) + # check that calling transforms with the same arguments + # is idempotent + self.price_multiple.handle_data(data, 1, keyarg=1) + + if self.price_multiple.full: + pre = len(self.price_multiple.ticks) + result1 = self.price_multiple.handle_data(data, 1, keyarg=1) + post = len(self.price_multiple.ticks) + assert pre == post, "batch transform is appending redundant events" + result2 = self.price_multiple.handle_data(data, 1, keyarg=1) + assert result1 is result2, "batch transform is not idempotent" + + # check that calling transform with the same data, but + # different supplemental arguments results in new + # results. + result3 = self.price_multiple.handle_data(data, 2, keyarg=1) + assert result1 is not result3, \ + "batch transform is not updating for new args" + + result4 = self.price_multiple.handle_data(data, 1, keyarg=2) + assert result1 is not result4,\ + "batch transform is not updating for new kwargs" + new_data = deepcopy(data) for sid in new_data: new_data[sid]['arbitrary'] = 123 diff --git a/zipline/transforms/utils.py b/zipline/transforms/utils.py index 56212ae4..cac1c119 100644 --- a/zipline/transforms/utils.py +++ b/zipline/transforms/utils.py @@ -228,8 +228,6 @@ class EventWindow(object): # Subclasses should override handle_add to define behavior for # adding new ticks. self.handle_add(event) - #if len(self.ticks) > self.window_length: - # import nose.tools; nose.tools.set_trace() # Clear out any expired events. # # oldest newest @@ -406,9 +404,19 @@ class BatchTransform(EventWindow): # functionality to zipline if len(v)} - # append data frame to window. update() will call handle_add() and - # handle_remove() appropriately - self.update(event) + # only modify the trailing window if this is + # a new event. This is intended to make handle_data + # idempotent. + if event not in self.ticks: + # append data frame to window. update() will call handle_add() and + # handle_remove() appropriately, and self.updated + # will be modified based on the refresh_period + self.update(event) + else: + # we are recalculating based on an old event, so + # there is no change in the contents of the trailing + # window + self.updated = False # return newly computed or cached value return self.get_transform_value(*args, **kwargs) @@ -449,7 +457,6 @@ class BatchTransform(EventWindow): # to call the user-defined batch-transform with the most # recent datapanel self.updated = True - self.trading_days_since_update = 0 else: self.updated = False @@ -516,10 +523,10 @@ class BatchTransform(EventWindow): if self.updated: # Create new pandas panel self.window = self.get_data() + # reset our counter for refresh_period + self.trading_days_since_update = 0 - args_changed = args != self.last_args - args_changed = args_changed or kwargs != self.last_kwargs - + args_changed = args != self.last_args or kwargs != self.last_kwargs if self.updated or args_changed: self.cached = self.compute_transform_value( self.window,