BUG: Fix crash in transforms on malformed CUSTOM events.

Fixes a crash in various transforms when providing CUSTOM events whose fields
don't match the fields required for the transform.

This is fixed by requiring all `EventWindow` subclasses to supply a `fields`
property, which returns a list of strings that are required keys for any event
that can be processed by the window.  Any CUSTOM events that don't supply the
required fields for a transform window are ignored by that window.
This commit is contained in:
Scott Sanderson
2014-07-21 18:07:26 -04:00
parent f9b1fe6ff6
commit d02dd972d5
6 changed files with 147 additions and 64 deletions
+79 -12
View File
@@ -17,13 +17,18 @@ import pytz
import numpy as np
from datetime import timedelta, datetime
from itertools import chain
from unittest import TestCase
from nose_parameterized import parameterized
from six.moves import range
from zipline.utils.test_utils import setup_logger
from zipline.protocol import Event
from zipline.protocol import (
DATASOURCE_TYPE,
Event,
)
from zipline.sources import SpecificEquityTrades
from zipline.transforms.utils import StatefulTransform, EventWindow
from zipline.transforms import MovingVWAP
@@ -47,6 +52,11 @@ class NoopEventWindow(EventWindow):
self.added = []
self.removed = []
self._fields = []
@property
def fields(self):
return self._fields
def handle_add(self, event):
self.added.append(event)
@@ -134,20 +144,53 @@ class TestFinanceTransforms(TestCase):
timedelta(days=1),
self.sim_params
)
self.source = SpecificEquityTrades(event_list=trade_history)
self.source = trade_history
def intersperse_custom_events(self, events):
    """
    Return a list in which every event from `events` is immediately
    followed by a bare-bones CUSTOM event.

    Each inserted event copies `dt` and `sid` from the event it follows
    and carries nothing else apart from a fixed `source_id` and the CUSTOM
    type, so it lacks the fields the transforms under test require.  Used
    to check that such events are handled gracefully.
    """
    interleaved = []
    for original in events:
        # Minimal CUSTOM event: deliberately no price/volume fields.
        filler = Event(
            initial_values={
                'dt': original.dt,
                'sid': original.sid,
                'source_id': "fake_custom_source",
                'type': DATASOURCE_TYPE.CUSTOM
            }
        )
        interleaved.extend((original, filler))
    return interleaved
def tearDown(self):
self.log_handler.pop_application()
def test_vwap(self):
@parameterized.expand([
('with_custom', True),
('without_custom', False),
])
def test_vwap(self, name, add_custom_events):
vwap = MovingVWAP(
market_aware=True,
window_length=2
)
if add_custom_events:
self.source = self.intersperse_custom_events(self.source)
transformed = list(vwap.transform(self.source))
# Output values
tnfm_vals = [message[vwap.get_hash()] for message in transformed]
# Output values. Unprocessed custom events will not have a field
# corresponding to the transform hash.
tnfm_vals = [message[vwap.get_hash()] for message in transformed
if message.type != DATASOURCE_TYPE.CUSTOM]
# "Hand calculated" values.
expected = [
(10.0 * 100) / 100.0,
@@ -161,12 +204,20 @@ class TestFinanceTransforms(TestCase):
# Output should match the expected.
self.assertEquals(tnfm_vals, expected)
def test_returns(self):
@parameterized.expand([
('with_custom', True),
('without_custom', False),
])
def test_returns(self, name, add_custom_events):
# Daily returns.
returns = Returns(1)
if add_custom_events:
self.source = self.intersperse_custom_events(self.source)
transformed = list(returns.transform(self.source))
tnfm_vals = [message[returns.get_hash()] for message in transformed]
tnfm_vals = [message[returns.get_hash()] for message in transformed
if message.type != DATASOURCE_TYPE.CUSTOM]
# No returns for the first event because we don't have a
# previous close.
@@ -202,7 +253,11 @@ class TestFinanceTransforms(TestCase):
self.assertEquals(tnfm_vals, expected)
def test_moving_average(self):
@parameterized.expand([
('with_custom', True),
('without_custom', False),
])
def test_moving_average(self, name, add_custom_events):
mavg = MovingAverage(
market_aware=True,
@@ -210,12 +265,17 @@ class TestFinanceTransforms(TestCase):
window_length=2
)
if add_custom_events:
self.source = self.intersperse_custom_events(self.source)
transformed = list(mavg.transform(self.source))
# Output values.
tnfm_prices = [message[mavg.get_hash()].price
for message in transformed]
for message in transformed
if message.type != DATASOURCE_TYPE.CUSTOM]
tnfm_volumes = [message[mavg.get_hash()].volume
for message in transformed]
for message in transformed
if message.type != DATASOURCE_TYPE.CUSTOM]
# "Hand-calculated" values
expected_prices = [
@@ -238,7 +298,11 @@ class TestFinanceTransforms(TestCase):
self.assertEquals(tnfm_prices, expected_prices)
self.assertEquals(tnfm_volumes, expected_volumes)
def test_moving_stddev(self):
@parameterized.expand([
('with_custom', True),
('without_custom', False),
])
def test_moving_stddev(self, name, add_custom_events):
trade_history = factory.create_trade_history(
133,
[10.0, 15.0, 13.0, 12.0],
@@ -253,10 +317,13 @@ class TestFinanceTransforms(TestCase):
)
self.source = SpecificEquityTrades(event_list=trade_history)
if add_custom_events:
self.source = self.intersperse_custom_events(self.source)
transformed = list(stddev.transform(self.source))
vals = [message[stddev.get_hash()] for message in transformed]
vals = [message[stddev.get_hash()] for message in transformed
if message.type != DATASOURCE_TYPE.CUSTOM]
expected = [
None,
+5 -13
View File
@@ -18,7 +18,6 @@ from collections import defaultdict
from six import string_types, with_metaclass
from zipline.transforms.utils import EventWindow, TransformMeta
from zipline.errors import WrongDataForTransform
class MovingAverage(with_metaclass(TransformMeta)):
@@ -108,13 +107,16 @@ class MovingAverageEventWindow(EventWindow):
# We maintain a dictionary of totals for each of our tracked
# fields.
self.fields = fields
self._fields = fields
self.totals = defaultdict(float)
@property
def fields(self):
return self._fields
# Subclass customization for adding new events.
def handle_add(self, event):
# Sanity check on the event.
self.assert_required_fields(event)
# Increment our running totals with data from the event.
for field in self.fields:
self.totals[field] += event[field]
@@ -148,13 +150,3 @@ class MovingAverageEventWindow(EventWindow):
for field in self.fields:
out.__dict__[field] = self.average(field)
return out
def assert_required_fields(self, event):
"""
We only allow events with all of our tracked fields.
"""
for field in self.fields:
if field not in event:
raise WrongDataForTransform(
transform="MovingAverageEventWindow",
fields=self.fields)
+3 -3
View File
@@ -59,7 +59,7 @@ class ReturnsFromPriorClose(object):
self.window_length = window_length
def update(self, event):
self.assert_required_fields(event)
self.check_required_fields(event)
if self.last_event:
# Day has changed since the last event we saw. Treat
@@ -91,7 +91,7 @@ class ReturnsFromPriorClose(object):
# the current event is now the last_event
self.last_event = event
def assert_required_fields(self, event):
def check_required_fields(self, event):
"""
We only allow events with a price field to be run through
the returns transform.
@@ -99,4 +99,4 @@ class ReturnsFromPriorClose(object):
if 'price' not in event:
raise WrongDataForTransform(
transform="ReturnsEventWindow",
fields='price')
fields=['price'])
+4 -12
View File
@@ -18,7 +18,6 @@ from math import sqrt
from six import with_metaclass
from zipline.errors import WrongDataForTransform
from zipline.transforms.utils import EventWindow, TransformMeta
import zipline.utils.math_utils as zp_math
@@ -89,8 +88,11 @@ class MovingStandardDevWindow(EventWindow):
self.sum = 0.0
self.sum_sqr = 0.0
@property
def fields(self):
return ['price']
def handle_add(self, event):
self.assert_required_fields(event)
self.sum += event.price
self.sum_sqr += event.price ** 2
@@ -113,13 +115,3 @@ class MovingStandardDevWindow(EventWindow):
return 0.0
stddev = sqrt(s_squared)
return stddev
def assert_required_fields(self, event):
"""
We only allow events with a price field to be run through
the returns transform.
"""
if 'price' not in event:
raise WrongDataForTransform(
transform="StdDevEventWindow",
fields='price')
+51 -12
View File
@@ -24,12 +24,13 @@ from numbers import Integral
from datetime import datetime
from collections import deque
from abc import ABCMeta, abstractmethod
from abc import ABCMeta, abstractmethod, abstractproperty
from six import with_metaclass
from zipline.protocol import DATASOURCE_TYPE
from zipline.errors import WrongDataForTransform
from zipline.gens.utils import assert_sort_unframe_protocol, hash_args
from zipline.protocol import DATASOURCE_TYPE
from zipline.finance import trading
log = logbook.Logger('Transform')
@@ -128,7 +129,7 @@ class StatefulTransform(object):
# other streams. Transforms that modify their input
# messages should only manipulate copies.
for message in stream_in:
# we only handle TRADE events.
# we only handle TRADE and CUSTOM events.
if (hasattr(message, 'type')
and message.type not in (
DATASOURCE_TYPE.TRADE,
@@ -142,7 +143,22 @@ class StatefulTransform(object):
assert_sort_unframe_protocol(message)
tnfm_value = self.state.update(message)
try:
tnfm_value = self.state.update(message)
except WrongDataForTransform:
# Transform classes should raise WrongDataForTransform if they
# are unable to process the event BEFORE performing any state
# modifications, because we continue the simulation if a
# WrongDataForTransform is raised on a CUSTOM event.
if message.type == DATASOURCE_TYPE.CUSTOM:
# Pass through custom events that are not applicable to
# this transform.
yield message
continue
else:
# If a TRADE event raises a WrongDataForTransform,
# something bad has happened.
raise
out_message = message
out_message[self.namestring] = tnfm_value
@@ -193,6 +209,10 @@ class EventWindow(with_metaclass(ABCMeta)):
def handle_add(self, event):
raise NotImplementedError()
@abstractproperty
def fields(self):
raise NotImplementedError()
@abstractmethod
def handle_remove(self, event):
raise NotImplementedError()
@@ -202,13 +222,8 @@ class EventWindow(with_metaclass(ABCMeta)):
def update(self, event):
if (hasattr(event, 'type')
and event.type not in (
DATASOURCE_TYPE.TRADE,
DATASOURCE_TYPE.CUSTOM)):
return
self.assert_well_formed(event)
# Add new event and increment totals.
self.ticks.append(event)
@@ -245,9 +260,13 @@ class EventWindow(with_metaclass(ABCMeta)):
return trading_days_between >= self.window_length
# All event windows expect to receive events with datetime fields
# that arrive in sorted order.
def assert_well_formed(self, event):
"""
Verify that the supplied event contains all the fields required by this
EventWindow to be processed.
"""
self.check_required_fields(event)
assert isinstance(event.dt, datetime), \
"Bad dt in EventWindow:%s" % event
if len(self.ticks) > 0:
@@ -255,3 +274,23 @@ class EventWindow(with_metaclass(ABCMeta)):
assert event.dt >= self.ticks[-1].dt, \
"Events arrived out of order in EventWindow: %s -> %s" % \
(event, self.ticks[0])
def check_required_fields(self, event):
    """
    Reject events that lack any field this window needs.

    Raises WrongDataForTransform, tagged with this window's class name,
    when `event` has no 'dt' key (reported fields: ['dt']) or when it is
    missing any key listed by the `fields` property (reported fields: the
    full `fields` value).  Returns None when every required key is present.
    """
    # Every window needs 'dt', regardless of what the subclass tracks.
    if 'dt' not in event:
        raise WrongDataForTransform(
            transform=self.__class__.__name__,
            fields=['dt'],
        )

    # The remaining requirements come from the subclass-provided 'fields'
    # property.
    if any(name not in event for name in self.fields):
        raise WrongDataForTransform(
            transform=self.__class__.__name__,
            fields=self.fields,
        )
+5 -12
View File
@@ -17,7 +17,6 @@ from collections import defaultdict
from six import with_metaclass
from zipline.errors import WrongDataForTransform
from zipline.transforms.utils import EventWindow, TransformMeta
@@ -73,14 +72,16 @@ class VWAPEventWindow(EventWindow):
"""
def __init__(self, market_aware=True, window_length=None, delta=None):
EventWindow.__init__(self, market_aware, window_length, delta)
self.fields = ('price', 'volume')
self._fields = ('price', 'volume')
self.flux = 0.0
self.totalvolume = 0.0
@property
def fields(self):
return self._fields
# Subclass customization for adding new events.
def handle_add(self, event):
# Sanity check on the event.
self.assert_required_fields(event)
self.flux += event.volume * event.price
self.totalvolume += event.volume
@@ -98,11 +99,3 @@ class VWAPEventWindow(EventWindow):
return None
else:
return (self.flux / self.totalvolume)
# We need numerical price and volume to calculate a vwap.
def assert_required_fields(self, event):
for field in self.fields:
if field not in event:
raise WrongDataForTransform(
transform="VWAPEventWindow",
fields=self.fields)