mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-06 05:14:38 +08:00
BUG: Fix crash in transforms on malformed CUSTOM events.
Fixes a crash in various transforms when providing CUSTOM events whose fields don't match the fields required for the transform. This is fixed by requiring all `EventWindow` subclasses to supply a `fields` property, which returns a list of strings that are required keys for any event that can be processed by the window. Any CUSTOM events that don't supply the required fields for a transform window are ignored by that window.
This commit is contained in:
+79
-12
@@ -17,13 +17,18 @@ import pytz
|
||||
import numpy as np
|
||||
|
||||
from datetime import timedelta, datetime
|
||||
from itertools import chain
|
||||
from unittest import TestCase
|
||||
|
||||
from nose_parameterized import parameterized
|
||||
from six.moves import range
|
||||
|
||||
from zipline.utils.test_utils import setup_logger
|
||||
|
||||
from zipline.protocol import Event
|
||||
from zipline.protocol import (
|
||||
DATASOURCE_TYPE,
|
||||
Event,
|
||||
)
|
||||
from zipline.sources import SpecificEquityTrades
|
||||
from zipline.transforms.utils import StatefulTransform, EventWindow
|
||||
from zipline.transforms import MovingVWAP
|
||||
@@ -47,6 +52,11 @@ class NoopEventWindow(EventWindow):
|
||||
|
||||
self.added = []
|
||||
self.removed = []
|
||||
self._fields = []
|
||||
|
||||
@property
|
||||
def fields(self):
|
||||
return self._fields
|
||||
|
||||
def handle_add(self, event):
|
||||
self.added.append(event)
|
||||
@@ -134,20 +144,53 @@ class TestFinanceTransforms(TestCase):
|
||||
timedelta(days=1),
|
||||
self.sim_params
|
||||
)
|
||||
self.source = SpecificEquityTrades(event_list=trade_history)
|
||||
self.source = trade_history
|
||||
|
||||
def intersperse_custom_events(self, events):
|
||||
"""
|
||||
Take a stream of events and return the same stream with a minimal event
|
||||
of type CUSTOM following each trade event. Used to test graceful
|
||||
handling of CUSTOM events that are missing required transform fields.
|
||||
"""
|
||||
return list(
|
||||
chain.from_iterable(
|
||||
(
|
||||
event,
|
||||
Event(
|
||||
initial_values={
|
||||
'dt': event.dt,
|
||||
'sid': event.sid,
|
||||
'source_id': "fake_custom_source",
|
||||
'type': DATASOURCE_TYPE.CUSTOM
|
||||
}
|
||||
)
|
||||
)
|
||||
for event in events
|
||||
)
|
||||
)
|
||||
|
||||
def tearDown(self):
|
||||
self.log_handler.pop_application()
|
||||
|
||||
def test_vwap(self):
|
||||
@parameterized.expand([
|
||||
('with_custom', True),
|
||||
('without_custom', False),
|
||||
])
|
||||
def test_vwap(self, name, add_custom_events):
|
||||
vwap = MovingVWAP(
|
||||
market_aware=True,
|
||||
window_length=2
|
||||
)
|
||||
|
||||
if add_custom_events:
|
||||
self.source = self.intersperse_custom_events(self.source)
|
||||
|
||||
transformed = list(vwap.transform(self.source))
|
||||
|
||||
# Output values
|
||||
tnfm_vals = [message[vwap.get_hash()] for message in transformed]
|
||||
# Output values. Unprocessed custom events will not have a field
|
||||
# corresponding to the transform hash.
|
||||
tnfm_vals = [message[vwap.get_hash()] for message in transformed
|
||||
if message.type != DATASOURCE_TYPE.CUSTOM]
|
||||
# "Hand calculated" values.
|
||||
expected = [
|
||||
(10.0 * 100) / 100.0,
|
||||
@@ -161,12 +204,20 @@ class TestFinanceTransforms(TestCase):
|
||||
# Output should match the expected.
|
||||
self.assertEquals(tnfm_vals, expected)
|
||||
|
||||
def test_returns(self):
|
||||
@parameterized.expand([
|
||||
('with_custom', True),
|
||||
('without_custom', False),
|
||||
])
|
||||
def test_returns(self, name, add_custom_events):
|
||||
# Daily returns.
|
||||
returns = Returns(1)
|
||||
|
||||
if add_custom_events:
|
||||
self.source = self.intersperse_custom_events(self.source)
|
||||
|
||||
transformed = list(returns.transform(self.source))
|
||||
tnfm_vals = [message[returns.get_hash()] for message in transformed]
|
||||
tnfm_vals = [message[returns.get_hash()] for message in transformed
|
||||
if message.type != DATASOURCE_TYPE.CUSTOM]
|
||||
|
||||
# No returns for the first event because we don't have a
|
||||
# previous close.
|
||||
@@ -202,7 +253,11 @@ class TestFinanceTransforms(TestCase):
|
||||
|
||||
self.assertEquals(tnfm_vals, expected)
|
||||
|
||||
def test_moving_average(self):
|
||||
@parameterized.expand([
|
||||
('with_custom', True),
|
||||
('without_custom', False),
|
||||
])
|
||||
def test_moving_average(self, name, add_custom_events):
|
||||
|
||||
mavg = MovingAverage(
|
||||
market_aware=True,
|
||||
@@ -210,12 +265,17 @@ class TestFinanceTransforms(TestCase):
|
||||
window_length=2
|
||||
)
|
||||
|
||||
if add_custom_events:
|
||||
self.source = self.intersperse_custom_events(self.source)
|
||||
|
||||
transformed = list(mavg.transform(self.source))
|
||||
# Output values.
|
||||
tnfm_prices = [message[mavg.get_hash()].price
|
||||
for message in transformed]
|
||||
for message in transformed
|
||||
if message.type != DATASOURCE_TYPE.CUSTOM]
|
||||
tnfm_volumes = [message[mavg.get_hash()].volume
|
||||
for message in transformed]
|
||||
for message in transformed
|
||||
if message.type != DATASOURCE_TYPE.CUSTOM]
|
||||
|
||||
# "Hand-calculated" values
|
||||
expected_prices = [
|
||||
@@ -238,7 +298,11 @@ class TestFinanceTransforms(TestCase):
|
||||
self.assertEquals(tnfm_prices, expected_prices)
|
||||
self.assertEquals(tnfm_volumes, expected_volumes)
|
||||
|
||||
def test_moving_stddev(self):
|
||||
@parameterized.expand([
|
||||
('with_custom', True),
|
||||
('without_custom', False),
|
||||
])
|
||||
def test_moving_stddev(self, name, add_custom_events):
|
||||
trade_history = factory.create_trade_history(
|
||||
133,
|
||||
[10.0, 15.0, 13.0, 12.0],
|
||||
@@ -253,10 +317,13 @@ class TestFinanceTransforms(TestCase):
|
||||
)
|
||||
|
||||
self.source = SpecificEquityTrades(event_list=trade_history)
|
||||
if add_custom_events:
|
||||
self.source = self.intersperse_custom_events(self.source)
|
||||
|
||||
transformed = list(stddev.transform(self.source))
|
||||
|
||||
vals = [message[stddev.get_hash()] for message in transformed]
|
||||
vals = [message[stddev.get_hash()] for message in transformed
|
||||
if message.type != DATASOURCE_TYPE.CUSTOM]
|
||||
|
||||
expected = [
|
||||
None,
|
||||
|
||||
@@ -18,7 +18,6 @@ from collections import defaultdict
|
||||
from six import string_types, with_metaclass
|
||||
|
||||
from zipline.transforms.utils import EventWindow, TransformMeta
|
||||
from zipline.errors import WrongDataForTransform
|
||||
|
||||
|
||||
class MovingAverage(with_metaclass(TransformMeta)):
|
||||
@@ -108,13 +107,16 @@ class MovingAverageEventWindow(EventWindow):
|
||||
|
||||
# We maintain a dictionary of totals for each of our tracked
|
||||
# fields.
|
||||
self.fields = fields
|
||||
self._fields = fields
|
||||
self.totals = defaultdict(float)
|
||||
|
||||
@property
|
||||
def fields(self):
|
||||
return self._fields
|
||||
|
||||
# Subclass customization for adding new events.
|
||||
def handle_add(self, event):
|
||||
# Sanity check on the event.
|
||||
self.assert_required_fields(event)
|
||||
# Increment our running totals with data from the event.
|
||||
for field in self.fields:
|
||||
self.totals[field] += event[field]
|
||||
@@ -148,13 +150,3 @@ class MovingAverageEventWindow(EventWindow):
|
||||
for field in self.fields:
|
||||
out.__dict__[field] = self.average(field)
|
||||
return out
|
||||
|
||||
def assert_required_fields(self, event):
|
||||
"""
|
||||
We only allow events with all of our tracked fields.
|
||||
"""
|
||||
for field in self.fields:
|
||||
if field not in event:
|
||||
raise WrongDataForTransform(
|
||||
transform="MovingAverageEventWindow",
|
||||
fields=self.fields)
|
||||
|
||||
@@ -59,7 +59,7 @@ class ReturnsFromPriorClose(object):
|
||||
self.window_length = window_length
|
||||
|
||||
def update(self, event):
|
||||
self.assert_required_fields(event)
|
||||
self.check_required_fields(event)
|
||||
if self.last_event:
|
||||
|
||||
# Day has changed since the last event we saw. Treat
|
||||
@@ -91,7 +91,7 @@ class ReturnsFromPriorClose(object):
|
||||
# the current event is now the last_event
|
||||
self.last_event = event
|
||||
|
||||
def assert_required_fields(self, event):
|
||||
def check_required_fields(self, event):
|
||||
"""
|
||||
We only allow events with a price field to be run through
|
||||
the returns transform.
|
||||
@@ -99,4 +99,4 @@ class ReturnsFromPriorClose(object):
|
||||
if 'price' not in event:
|
||||
raise WrongDataForTransform(
|
||||
transform="ReturnsEventWindow",
|
||||
fields='price')
|
||||
fields=['price'])
|
||||
|
||||
@@ -18,7 +18,6 @@ from math import sqrt
|
||||
|
||||
from six import with_metaclass
|
||||
|
||||
from zipline.errors import WrongDataForTransform
|
||||
from zipline.transforms.utils import EventWindow, TransformMeta
|
||||
import zipline.utils.math_utils as zp_math
|
||||
|
||||
@@ -89,8 +88,11 @@ class MovingStandardDevWindow(EventWindow):
|
||||
self.sum = 0.0
|
||||
self.sum_sqr = 0.0
|
||||
|
||||
@property
|
||||
def fields(self):
|
||||
return ['price']
|
||||
|
||||
def handle_add(self, event):
|
||||
self.assert_required_fields(event)
|
||||
self.sum += event.price
|
||||
self.sum_sqr += event.price ** 2
|
||||
|
||||
@@ -113,13 +115,3 @@ class MovingStandardDevWindow(EventWindow):
|
||||
return 0.0
|
||||
stddev = sqrt(s_squared)
|
||||
return stddev
|
||||
|
||||
def assert_required_fields(self, event):
|
||||
"""
|
||||
We only allow events with a price field to be run through
|
||||
the returns transform.
|
||||
"""
|
||||
if 'price' not in event:
|
||||
raise WrongDataForTransform(
|
||||
transform="StdDevEventWindow",
|
||||
fields='price')
|
||||
|
||||
+51
-12
@@ -24,12 +24,13 @@ from numbers import Integral
|
||||
|
||||
from datetime import datetime
|
||||
from collections import deque
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from abc import ABCMeta, abstractmethod, abstractproperty
|
||||
|
||||
from six import with_metaclass
|
||||
|
||||
from zipline.protocol import DATASOURCE_TYPE
|
||||
from zipline.errors import WrongDataForTransform
|
||||
from zipline.gens.utils import assert_sort_unframe_protocol, hash_args
|
||||
from zipline.protocol import DATASOURCE_TYPE
|
||||
from zipline.finance import trading
|
||||
|
||||
log = logbook.Logger('Transform')
|
||||
@@ -128,7 +129,7 @@ class StatefulTransform(object):
|
||||
# other streams. Transforms that modify their input
|
||||
# messages should only manipulate copies.
|
||||
for message in stream_in:
|
||||
# we only handle TRADE events.
|
||||
# we only handle TRADE and CUSTOM events.
|
||||
if (hasattr(message, 'type')
|
||||
and message.type not in (
|
||||
DATASOURCE_TYPE.TRADE,
|
||||
@@ -142,7 +143,22 @@ class StatefulTransform(object):
|
||||
|
||||
assert_sort_unframe_protocol(message)
|
||||
|
||||
tnfm_value = self.state.update(message)
|
||||
try:
|
||||
tnfm_value = self.state.update(message)
|
||||
except WrongDataForTransform:
|
||||
# Transform classes should raise WrongDataForTransform if they
|
||||
# are unable to process the event BEFORE performing any state
|
||||
# modifications, because we continue the simulation if a
|
||||
# WrongDataForTransform is raised on a CUSTOM event.
|
||||
if message.type == DATASOURCE_TYPE.CUSTOM:
|
||||
# Pass through custom events that are not applicable to
|
||||
# this transform.
|
||||
yield message
|
||||
continue
|
||||
else:
|
||||
# If a TRADE event raises a WrongDataForTransform,
|
||||
# something bad has happened.
|
||||
raise
|
||||
|
||||
out_message = message
|
||||
out_message[self.namestring] = tnfm_value
|
||||
@@ -193,6 +209,10 @@ class EventWindow(with_metaclass(ABCMeta)):
|
||||
def handle_add(self, event):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractproperty
|
||||
def fields(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def handle_remove(self, event):
|
||||
raise NotImplementedError()
|
||||
@@ -202,13 +222,8 @@ class EventWindow(with_metaclass(ABCMeta)):
|
||||
|
||||
def update(self, event):
|
||||
|
||||
if (hasattr(event, 'type')
|
||||
and event.type not in (
|
||||
DATASOURCE_TYPE.TRADE,
|
||||
DATASOURCE_TYPE.CUSTOM)):
|
||||
return
|
||||
|
||||
self.assert_well_formed(event)
|
||||
|
||||
# Add new event and increment totals.
|
||||
self.ticks.append(event)
|
||||
|
||||
@@ -245,9 +260,13 @@ class EventWindow(with_metaclass(ABCMeta)):
|
||||
|
||||
return trading_days_between >= self.window_length
|
||||
|
||||
# All event windows expect to receive events with datetime fields
|
||||
# that arrive in sorted order.
|
||||
def assert_well_formed(self, event):
|
||||
"""
|
||||
Verify that the supplied event contains all the fields required by this
|
||||
EventWindow to be processed.
|
||||
"""
|
||||
self.check_required_fields(event)
|
||||
|
||||
assert isinstance(event.dt, datetime), \
|
||||
"Bad dt in EventWindow:%s" % event
|
||||
if len(self.ticks) > 0:
|
||||
@@ -255,3 +274,23 @@ class EventWindow(with_metaclass(ABCMeta)):
|
||||
assert event.dt >= self.ticks[-1].dt, \
|
||||
"Events arrived out of order in EventWindow: %s -> %s" % \
|
||||
(event, self.ticks[0])
|
||||
|
||||
def check_required_fields(self, event):
|
||||
"""
|
||||
We only allow events with all of our tracked fields.
|
||||
"""
|
||||
# All events require a 'dt' field.
|
||||
if 'dt' not in event:
|
||||
raise WrongDataForTransform(
|
||||
transform=self.__class__.__name__,
|
||||
fields=['dt'],
|
||||
)
|
||||
|
||||
# Subclasses must implement the 'fields' property to specify other
|
||||
# required fields.
|
||||
for field in self.fields:
|
||||
if field not in event:
|
||||
raise WrongDataForTransform(
|
||||
transform=self.__class__.__name__,
|
||||
fields=self.fields,
|
||||
)
|
||||
|
||||
@@ -17,7 +17,6 @@ from collections import defaultdict
|
||||
|
||||
from six import with_metaclass
|
||||
|
||||
from zipline.errors import WrongDataForTransform
|
||||
from zipline.transforms.utils import EventWindow, TransformMeta
|
||||
|
||||
|
||||
@@ -73,14 +72,16 @@ class VWAPEventWindow(EventWindow):
|
||||
"""
|
||||
def __init__(self, market_aware=True, window_length=None, delta=None):
|
||||
EventWindow.__init__(self, market_aware, window_length, delta)
|
||||
self.fields = ('price', 'volume')
|
||||
self._fields = ('price', 'volume')
|
||||
self.flux = 0.0
|
||||
self.totalvolume = 0.0
|
||||
|
||||
@property
|
||||
def fields(self):
|
||||
return self._fields
|
||||
|
||||
# Subclass customization for adding new events.
|
||||
def handle_add(self, event):
|
||||
# Sanity check on the event.
|
||||
self.assert_required_fields(event)
|
||||
self.flux += event.volume * event.price
|
||||
self.totalvolume += event.volume
|
||||
|
||||
@@ -98,11 +99,3 @@ class VWAPEventWindow(EventWindow):
|
||||
return None
|
||||
else:
|
||||
return (self.flux / self.totalvolume)
|
||||
|
||||
# We need numerical price and volume to calculate a vwap.
|
||||
def assert_required_fields(self, event):
|
||||
for field in self.fields:
|
||||
if field not in event:
|
||||
raise WrongDataForTransform(
|
||||
transform="VWAPEventWindow",
|
||||
fields=self.fields)
|
||||
|
||||
Reference in New Issue
Block a user