diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 215c523b..b1a013ce 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -17,13 +17,18 @@ import pytz import numpy as np from datetime import timedelta, datetime +from itertools import chain from unittest import TestCase +from nose_parameterized import parameterized from six.moves import range from zipline.utils.test_utils import setup_logger -from zipline.protocol import Event +from zipline.protocol import ( + DATASOURCE_TYPE, + Event, +) from zipline.sources import SpecificEquityTrades from zipline.transforms.utils import StatefulTransform, EventWindow from zipline.transforms import MovingVWAP @@ -47,6 +52,11 @@ class NoopEventWindow(EventWindow): self.added = [] self.removed = [] + self._fields = [] + + @property + def fields(self): + return self._fields def handle_add(self, event): self.added.append(event) @@ -134,20 +144,53 @@ class TestFinanceTransforms(TestCase): timedelta(days=1), self.sim_params ) - self.source = SpecificEquityTrades(event_list=trade_history) + self.source = trade_history + + def intersperse_custom_events(self, events): + """ + Take a stream of events and return the same stream with a minimal event + of type CUSTOM following each trade event. Used to test graceful + handling of CUSTOM events that are missing required transform fields. + """ + return list( + chain.from_iterable( + ( + event, + Event( + initial_values={ + 'dt': event.dt, + 'sid': event.sid, + 'source_id': "fake_custom_source", + 'type': DATASOURCE_TYPE.CUSTOM + } + ) + ) + for event in events + ) + ) def tearDown(self): self.log_handler.pop_application() - def test_vwap(self): + @parameterized.expand([ + ('with_custom', True), + ('without_custom', False), + ]) + def test_vwap(self, name, add_custom_events): vwap = MovingVWAP( market_aware=True, window_length=2 ) + + if add_custom_events: + self.source = self.intersperse_custom_events(self.source) + transformed = list(vwap.transform(self.source)) - # Output values - tnfm_vals = [message[vwap.get_hash()] for message in transformed] + # Output values. Unprocessed custom events will not have a field + # corresponding to the transform hash. + tnfm_vals = [message[vwap.get_hash()] for message in transformed + if message.type != DATASOURCE_TYPE.CUSTOM] # "Hand calculated" values. expected = [ (10.0 * 100) / 100.0, @@ -161,12 +204,20 @@ class TestFinanceTransforms(TestCase): # Output should match the expected. self.assertEquals(tnfm_vals, expected) - def test_returns(self): + @parameterized.expand([ + ('with_custom', True), + ('without_custom', False), + ]) + def test_returns(self, name, add_custom_events): # Daily returns. returns = Returns(1) + if add_custom_events: + self.source = self.intersperse_custom_events(self.source) + transformed = list(returns.transform(self.source)) - tnfm_vals = [message[returns.get_hash()] for message in transformed] + tnfm_vals = [message[returns.get_hash()] for message in transformed + if message.type != DATASOURCE_TYPE.CUSTOM] # No returns for the first event because we don't have a # previous close. @@ -202,7 +253,11 @@ class TestFinanceTransforms(TestCase): self.assertEquals(tnfm_vals, expected) - def test_moving_average(self): + @parameterized.expand([ + ('with_custom', True), + ('without_custom', False), + ]) + def test_moving_average(self, name, add_custom_events): mavg = MovingAverage( market_aware=True, @@ -210,12 +265,17 @@ class TestFinanceTransforms(TestCase): window_length=2 ) + if add_custom_events: + self.source = self.intersperse_custom_events(self.source) + transformed = list(mavg.transform(self.source)) # Output values. tnfm_prices = [message[mavg.get_hash()].price - for message in transformed] + for message in transformed + if message.type != DATASOURCE_TYPE.CUSTOM] tnfm_volumes = [message[mavg.get_hash()].volume - for message in transformed] + for message in transformed + if message.type != DATASOURCE_TYPE.CUSTOM] # "Hand-calculated" values expected_prices = [ @@ -238,7 +298,11 @@ class TestFinanceTransforms(TestCase): self.assertEquals(tnfm_prices, expected_prices) self.assertEquals(tnfm_volumes, expected_volumes) - def test_moving_stddev(self): + @parameterized.expand([ + ('with_custom', True), + ('without_custom', False), + ]) + def test_moving_stddev(self, name, add_custom_events): trade_history = factory.create_trade_history( 133, [10.0, 15.0, 13.0, 12.0], @@ -253,10 +317,13 @@ class TestFinanceTransforms(TestCase): ) self.source = SpecificEquityTrades(event_list=trade_history) + if add_custom_events: + self.source = self.intersperse_custom_events(self.source) transformed = list(stddev.transform(self.source)) - vals = [message[stddev.get_hash()] for message in transformed] + vals = [message[stddev.get_hash()] for message in transformed + if message.type != DATASOURCE_TYPE.CUSTOM] expected = [ None, diff --git a/zipline/transforms/mavg.py b/zipline/transforms/mavg.py index dc79bcc1..fd9a5ec1 100644 --- a/zipline/transforms/mavg.py +++ b/zipline/transforms/mavg.py @@ -18,7 +18,6 @@ from collections import defaultdict from six import string_types, with_metaclass from zipline.transforms.utils import EventWindow, TransformMeta -from zipline.errors import WrongDataForTransform class MovingAverage(with_metaclass(TransformMeta)): @@ -108,13 +107,16 @@ class MovingAverageEventWindow(EventWindow): # We maintain a dictionary of totals for each of our tracked # fields. - self.fields = fields + self._fields = fields self.totals = defaultdict(float) + @property + def fields(self): + return self._fields + # Subclass customization for adding new events. def handle_add(self, event): # Sanity check on the event. - self.assert_required_fields(event) # Increment our running totals with data from the event. for field in self.fields: self.totals[field] += event[field] @@ -148,13 +150,3 @@ class MovingAverageEventWindow(EventWindow): for field in self.fields: out.__dict__[field] = self.average(field) return out - - def assert_required_fields(self, event): - """ - We only allow events with all of our tracked fields. - """ - for field in self.fields: - if field not in event: - raise WrongDataForTransform( - transform="MovingAverageEventWindow", - fields=self.fields) diff --git a/zipline/transforms/returns.py b/zipline/transforms/returns.py index be0bb8d9..401a9cda 100644 --- a/zipline/transforms/returns.py +++ b/zipline/transforms/returns.py @@ -59,7 +59,7 @@ class ReturnsFromPriorClose(object): self.window_length = window_length def update(self, event): - self.assert_required_fields(event) + self.check_required_fields(event) if self.last_event: # Day has changed since the last event we saw. Treat @@ -91,7 +91,7 @@ class ReturnsFromPriorClose(object): # the current event is now the last_event self.last_event = event - def assert_required_fields(self, event): + def check_required_fields(self, event): """ We only allow events with a price field to be run through the returns transform. @@ -99,4 +99,4 @@ class ReturnsFromPriorClose(object): if 'price' not in event: raise WrongDataForTransform( transform="ReturnsEventWindow", - fields='price') + fields=['price']) diff --git a/zipline/transforms/stddev.py b/zipline/transforms/stddev.py index 7571e1b1..06b767ce 100644 --- a/zipline/transforms/stddev.py +++ b/zipline/transforms/stddev.py @@ -18,7 +18,6 @@ from math import sqrt from six import with_metaclass -from zipline.errors import WrongDataForTransform from zipline.transforms.utils import EventWindow, TransformMeta import zipline.utils.math_utils as zp_math @@ -89,8 +88,11 @@ class MovingStandardDevWindow(EventWindow): self.sum = 0.0 self.sum_sqr = 0.0 + @property + def fields(self): + return ['price'] + def handle_add(self, event): - self.assert_required_fields(event) self.sum += event.price self.sum_sqr += event.price ** 2 @@ -113,13 +115,3 @@ class MovingStandardDevWindow(EventWindow): return 0.0 stddev = sqrt(s_squared) return stddev - - def assert_required_fields(self, event): - """ - We only allow events with a price field to be run through - the returns transform. - """ - if 'price' not in event: - raise WrongDataForTransform( - transform="StdDevEventWindow", - fields='price') diff --git a/zipline/transforms/utils.py b/zipline/transforms/utils.py index 70aed99e..fe2bf625 100644 --- a/zipline/transforms/utils.py +++ b/zipline/transforms/utils.py @@ -24,12 +24,13 @@ from numbers import Integral from datetime import datetime from collections import deque -from abc import ABCMeta, abstractmethod +from abc import ABCMeta, abstractmethod, abstractproperty from six import with_metaclass -from zipline.protocol import DATASOURCE_TYPE +from zipline.errors import WrongDataForTransform from zipline.gens.utils import assert_sort_unframe_protocol, hash_args +from zipline.protocol import DATASOURCE_TYPE from zipline.finance import trading log = logbook.Logger('Transform') @@ -128,7 +129,7 @@ class StatefulTransform(object): # other streams. Transforms that modify their input # messages should only manipulate copies. for message in stream_in: - # we only handle TRADE events. + # we only handle TRADE and CUSTOM events. if (hasattr(message, 'type') and message.type not in ( DATASOURCE_TYPE.TRADE, @@ -142,7 +143,22 @@ class StatefulTransform(object): assert_sort_unframe_protocol(message) - tnfm_value = self.state.update(message) + try: + tnfm_value = self.state.update(message) + except WrongDataForTransform: + # Transform classes should raise WrongDataForTransform if they + # are unable to process the event BEFORE performing any state + # modifications, because we continue the simulation if a + # WrongDataForTransform is raised on a CUSTOM event. + if message.type == DATASOURCE_TYPE.CUSTOM: + # Pass through custom events that are not applicable to + # this transform. + yield message + continue + else: + # If a TRADE event raises a WrongDataForTransform, + # something bad has happend. + raise out_message = message out_message[self.namestring] = tnfm_value @@ -193,6 +209,10 @@ class EventWindow(with_metaclass(ABCMeta)): def handle_add(self, event): raise NotImplementedError() + @abstractproperty + def fields(self): + raise NotImplementedError() + @abstractmethod def handle_remove(self, event): raise NotImplementedError() @@ -202,13 +222,8 @@ class EventWindow(with_metaclass(ABCMeta)): def update(self, event): - if (hasattr(event, 'type') - and event.type not in ( - DATASOURCE_TYPE.TRADE, - DATASOURCE_TYPE.CUSTOM)): - return - self.assert_well_formed(event) + # Add new event and increment totals. self.ticks.append(event) @@ -245,9 +260,13 @@ class EventWindow(with_metaclass(ABCMeta)): return trading_days_between >= self.window_length - # All event windows expect to receive events with datetime fields - # that arrive in sorted order. def assert_well_formed(self, event): + """ + Verify that the supplied event contains all the fields required by this + EventWindow to be processed. + """ + self.check_required_fields(event) + assert isinstance(event.dt, datetime), \ "Bad dt in EventWindow:%s" % event if len(self.ticks) > 0: @@ -255,3 +274,23 @@ class EventWindow(with_metaclass(ABCMeta)): assert event.dt >= self.ticks[-1].dt, \ "Events arrived out of order in EventWindow: %s -> %s" % \ (event, self.ticks[0]) + + def check_required_fields(self, event): + """ + We only allow events with all of our tracked fields. + """ + # All events require a 'dt' field. + if 'dt' not in event: + raise WrongDataForTransform( + transform=self.__class__.__name__, + fields=['dt'], + ) + + # Subclasses must implement the 'fields' property to specify other + # required fields. + for field in self.fields: + if field not in event: + raise WrongDataForTransform( + transform=self.__class__.__name__, + fields=self.fields, + ) diff --git a/zipline/transforms/vwap.py b/zipline/transforms/vwap.py index 545ab32e..aaa9a49b 100644 --- a/zipline/transforms/vwap.py +++ b/zipline/transforms/vwap.py @@ -17,7 +17,6 @@ from collections import defaultdict from six import with_metaclass -from zipline.errors import WrongDataForTransform from zipline.transforms.utils import EventWindow, TransformMeta @@ -73,14 +72,16 @@ class VWAPEventWindow(EventWindow): """ def __init__(self, market_aware=True, window_length=None, delta=None): EventWindow.__init__(self, market_aware, window_length, delta) - self.fields = ('price', 'volume') + self._fields = ('price', 'volume') self.flux = 0.0 self.totalvolume = 0.0 + @property + def fields(self): + return self._fields + # Subclass customization for adding new events. def handle_add(self, event): - # Sanity check on the event. - self.assert_required_fields(event) self.flux += event.volume * event.price self.totalvolume += event.volume @@ -98,11 +99,3 @@ class VWAPEventWindow(EventWindow): return None else: return (self.flux / self.totalvolume) - - # We need numerical price and volume to calculate a vwap. - def assert_required_fields(self, event): - for field in self.fields: - if field not in event: - raise WrongDataForTransform( - transform="VWAPEventWindow", - fields=self.fields)