mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-02 08:27:53 +08:00
added tests to ensure:
- repeated calls with the same data window do not update batch transform
windows.
- repeated calls with the same data and same supplemental parameters do
not update batch transform results
- repeated calls with the same data and different supplemental params
do update batch transform results
This commit is contained in:
@@ -273,6 +273,11 @@ def uses_ufunc(data, *args, **kwargs):
|
||||
return np.log(data)
|
||||
|
||||
|
||||
@batch_transform
|
||||
def price_multiple(data, multiplier, keyarg=1):
|
||||
return data.price * multiplier * keyarg
|
||||
|
||||
|
||||
class BatchTransformAlgorithm(TradingAlgorithm):
|
||||
def initialize(self, *args, **kwargs):
|
||||
self.refresh_period = kwargs.pop('refresh_period', 1)
|
||||
@@ -354,6 +359,12 @@ class BatchTransformAlgorithm(TradingAlgorithm):
|
||||
clean_nans=False
|
||||
)
|
||||
|
||||
self.price_multiple = price_multiple(
|
||||
refresh_period=self.refresh_period,
|
||||
window_length=self.window_length,
|
||||
clean_nans=False
|
||||
)
|
||||
|
||||
self.iter = 0
|
||||
|
||||
self.set_slippage(FixedSlippage())
|
||||
@@ -370,6 +381,29 @@ class BatchTransformAlgorithm(TradingAlgorithm):
|
||||
self.return_not_full.handle_data(data))
|
||||
self.uses_ufunc.handle_data(data)
|
||||
|
||||
# check that calling transforms with the same arguments
|
||||
# is idempotent
|
||||
self.price_multiple.handle_data(data, 1, keyarg=1)
|
||||
|
||||
if self.price_multiple.full:
|
||||
pre = len(self.price_multiple.ticks)
|
||||
result1 = self.price_multiple.handle_data(data, 1, keyarg=1)
|
||||
post = len(self.price_multiple.ticks)
|
||||
assert pre == post, "batch transform is appending redundant events"
|
||||
result2 = self.price_multiple.handle_data(data, 1, keyarg=1)
|
||||
assert result1 is result2, "batch transform is not idempotent"
|
||||
|
||||
# check that calling transform with the same data, but
|
||||
# different supplemental arguments results in new
|
||||
# results.
|
||||
result3 = self.price_multiple.handle_data(data, 2, keyarg=1)
|
||||
assert result1 is not result3, \
|
||||
"batch transform is not updating for new args"
|
||||
|
||||
result4 = self.price_multiple.handle_data(data, 1, keyarg=2)
|
||||
assert result1 is not result4,\
|
||||
"batch transform is not updating for new kwargs"
|
||||
|
||||
new_data = deepcopy(data)
|
||||
for sid in new_data:
|
||||
new_data[sid]['arbitrary'] = 123
|
||||
|
||||
@@ -228,8 +228,6 @@ class EventWindow(object):
|
||||
# Subclasses should override handle_add to define behavior for
|
||||
# adding new ticks.
|
||||
self.handle_add(event)
|
||||
#if len(self.ticks) > self.window_length:
|
||||
# import nose.tools; nose.tools.set_trace()
|
||||
# Clear out any expired events.
|
||||
#
|
||||
# oldest newest
|
||||
@@ -406,9 +404,19 @@ class BatchTransform(EventWindow):
|
||||
# functionality to zipline
|
||||
if len(v)}
|
||||
|
||||
# append data frame to window. update() will call handle_add() and
|
||||
# handle_remove() appropriately
|
||||
self.update(event)
|
||||
# only modify the trailing window if this is
|
||||
# a new event. This is intended to make handle_data
|
||||
# idempotent.
|
||||
if event not in self.ticks:
|
||||
# append data frame to window. update() will call handle_add() and
|
||||
# handle_remove() appropriately, and self.updated
|
||||
# will be modified based on the refresh_period
|
||||
self.update(event)
|
||||
else:
|
||||
# we are recalculating based on an old event, so
|
||||
# there is no change in the contents of the trailing
|
||||
# window
|
||||
self.updated = False
|
||||
|
||||
# return newly computed or cached value
|
||||
return self.get_transform_value(*args, **kwargs)
|
||||
@@ -449,7 +457,6 @@ class BatchTransform(EventWindow):
|
||||
# to call the user-defined batch-transform with the most
|
||||
# recent datapanel
|
||||
self.updated = True
|
||||
self.trading_days_since_update = 0
|
||||
else:
|
||||
self.updated = False
|
||||
|
||||
@@ -516,10 +523,10 @@ class BatchTransform(EventWindow):
|
||||
if self.updated:
|
||||
# Create new pandas panel
|
||||
self.window = self.get_data()
|
||||
# reset our counter for refresh_period
|
||||
self.trading_days_since_update = 0
|
||||
|
||||
args_changed = args != self.last_args
|
||||
args_changed = args_changed or kwargs != self.last_kwargs
|
||||
|
||||
args_changed = args != self.last_args or kwargs != self.last_kwargs
|
||||
if self.updated or args_changed:
|
||||
self.cached = self.compute_transform_value(
|
||||
self.window,
|
||||
|
||||
Reference in New Issue
Block a user