mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-28 11:03:38 +08:00
203 lines
7.8 KiB
Python
203 lines
7.8 KiB
Python
"""
|
|
Generator versions of transforms.
|
|
"""
|
|
import types
|
|
|
|
from copy import deepcopy
|
|
from datetime import datetime
|
|
from collections import deque, defaultdict
|
|
from numbers import Number
|
|
from abc import ABCMeta, abstractmethod
|
|
|
|
from zipline import ndict
|
|
from zipline.utils.tradingcalendar import trading_days_between
|
|
from zipline.gens.utils import assert_sort_unframe_protocol, \
|
|
assert_transform_protocol, hash_args
|
|
|
|
class Passthrough(object):
    """
    Trivial class for forwarding events.

    Sets the FORWARDER class flag, which StatefulTransform checks to decide
    that the original message should be forwarded with tnfm_id and
    tnfm_value attached. update() computes nothing, so the attached
    tnfm_value is None.
    """
    # Must come after the docstring: a docstring is only recognised when it
    # is the first statement of the class body.
    FORWARDER = True

    def __init__(self):
        pass

    def update(self, event):
        # Nothing to compute; the event itself is forwarded by the caller.
        pass
|
|
|
|
def functional_transform(stream_in, func, *args, **kwargs):
    """
    Yield ``func`` applied to every message of ``stream_in``.

    Each incoming message is checked with assert_sort_unframe_protocol,
    passed to ``func(message, *args, **kwargs)``, and the result is checked
    with assert_transform_protocol. Every output is a
    ``(namestring, result)`` tuple, where namestring is ``func.__name__``
    followed by ``hash_args(*args, **kwargs)``. Mainly useful for testing.

    Because this is a generator function, the checks below (including the
    ``func`` type check) run on the first ``next()``, not at call time.
    """
    assert isinstance(func, types.FunctionType), \
        "Functional"

    tag = func.__name__ + hash_args(*args, **kwargs)

    def _apply(msg):
        # Validate the input, run the function, validate its output.
        assert_sort_unframe_protocol(msg)
        result = func(msg, *args, **kwargs)
        assert_transform_protocol(result)
        return tag, result

    for msg in stream_in:
        yield _apply(msg)
|
|
|
|
class StatefulTransform(object):
    """
    Generic transform generator that takes each message from an
    in-stream and passes it to a state object. For each call to
    update, the state class must produce a value that is used to build
    the message fed downstream.

    The transform class may set at most one of these class variables
    (values inherited from a base class count) to choose the output:

    * FORWARDER -- forward all fields of the original message, plus
      tnfm_id and tnfm_value.
    * UPDATER -- yield exactly what state.update returned.
    * APPENDER -- forward the original message with one extra key, this
      transform's namestring, holding the value returned by state.update.

    With no flag set, a new message holding only dt, tnfm_id and
    tnfm_value is produced.
    """

    def __init__(self, tnfm_class, *args, **kwargs):
        import inspect

        # inspect.isclass accepts new-style and (Python 2) old-style
        # classes. Checking isinstance against types.ObjectType accepted
        # any value whatsoever, since everything is an instance of object.
        assert inspect.isclass(tnfm_class), \
            "Stateful transform requires a class."
        # getattr also finds an update method inherited from a base class
        # (e.g. EventWindow.update); a class-__dict__ lookup misses it.
        assert callable(getattr(tnfm_class, 'update', None)), \
            "Stateful transform requires the class to have an update method"

        self.forward_all = getattr(tnfm_class, 'FORWARDER', False)
        self.update_in_place = getattr(tnfm_class, 'UPDATER', False)
        self.append_value = getattr(tnfm_class, 'APPENDER', False)

        # At most one special behavior mode can be set.
        modes = (self.forward_all, self.update_in_place, self.append_value)
        assert sum(1 for mode in modes if mode) <= 1, \
            "Only one of FORWARDER, UPDATER and APPENDER may be set."

        # Create an instance of our transform class.
        self.state = tnfm_class(*args, **kwargs)

        # Create the string associated with this generator's output.
        self.namestring = tnfm_class.__name__ + hash_args(*args, **kwargs)

    def get_hash(self):
        """Return the name string identifying this transform's output."""
        return self.namestring

    def transform(self, stream_in):
        """Return a generator of transformed messages for ``stream_in``."""
        return self._gen(stream_in)

    def _gen(self, stream_in):
        # IMPORTANT: Messages may contain pointers that are shared with
        # other streams, so we never hand the original message to the
        # state object or downstream; only deep copies leave this method.
        for message in stream_in:

            # Allow upstream generators to yield None to avoid blocking.
            if message is None:
                continue

            assert_sort_unframe_protocol(message)

            # Copy only what the selected output mode actually uses, and
            # take that copy before the state object runs.
            if self.forward_all or self.append_value:
                out_message = deepcopy(message)
            elif not self.update_in_place:
                event_dt = deepcopy(message.dt)

            # The state object gets its own copy, since it may keep a
            # reference to the event (EventWindow stores its ticks).
            tnfm_value = self.state.update(deepcopy(message))

            # FORWARDER flag means we want to keep all original
            # values, plus append tnfm_id and tnfm_value. Used for
            # preserving the original event fields when our output
            # will be fed into a merge.
            if self.forward_all:
                out_message.tnfm_id = self.namestring
                out_message.tnfm_value = tnfm_value
                yield out_message

            # UPDATER flag should be used for transforms that
            # side-effectfully modify the event they are passed.
            # Updated messages are passed along exactly as they are
            # returned to us by our state class. Useful for chaining
            # specific transforms that won't be fed to a merge. (See
            # the implementation of TradeSimulationClient for example
            # usage of this flag with PerformanceTracker and
            # TransactionSimulator.)
            elif self.update_in_place:
                yield tnfm_value

            # APPENDER flag should be used to add a single new
            # key-value pair to the event. The new key is this
            # transform's namestring, and its value is the value
            # returned by state.update(event). This is almost
            # identical to the behavior of FORWARDER, except we
            # compress the two calculated values (tnfm_id, and
            # tnfm_value) into a single field.
            elif self.append_value:
                out_message[self.namestring] = tnfm_value
                yield out_message

            # If no flags are set, we create a new message containing
            # just the tnfm_id, the event's datetime, and the
            # calculated tnfm_value. This is the default behavior for
            # a transform being fed into a merge.
            else:
                out_message = ndict()
                out_message.tnfm_id = self.namestring
                out_message.tnfm_value = tnfm_value
                out_message.dt = event_dt
                yield out_message
|
|
|
|
class EventWindow:
    """
    Abstract base class for transform classes that calculate iterative
    metrics on events within a given timedelta. Maintains a deque of
    events whose dt lies within ``delta`` of the most recent tick.
    Calls self.handle_add(event) for each event added to the window
    and self.handle_remove(event) for each event removed from it.
    Subclasses implement these two methods (and may extend __init__,
    which must still set self.ticks and self.delta) to calculate
    metrics over the window.

    See zipline/gens/mavg.py and zipline/gens/vwap.py for example
    implementations of moving average and volume-weighted average
    price.
    """

    # Mark this as an abstract base class. (The __metaclass__ attribute is
    # honoured by Python 2 only; Python 3 ignores it.)
    __metaclass__ = ABCMeta

    def __init__(self, delta):
        # Events currently inside the window; oldest at the left end.
        self.ticks = deque()
        # Maximum allowed dt span between newest and oldest tick;
        # presumably a datetime.timedelta (it is compared against dt
        # differences) -- confirm against callers.
        self.delta = delta

    @abstractmethod
    def handle_add(self, event):
        raise NotImplementedError()

    @abstractmethod
    def handle_remove(self, event):
        raise NotImplementedError()

    def __len__(self):
        return len(self.ticks)

    def update(self, event):
        """
        Add ``event`` to the window, then evict ticks that fall outside it.

        A tick is evicted once the newest tick's dt exceeds its dt by more
        than ``self.delta``; ticks exactly ``delta`` apart are kept.
        """
        self.assert_well_formed(event)
        # Add new event and increment totals.
        self.ticks.append(event)
        self.handle_add(event)

        # Clear out expired events.
        #
        #                     newest             oldest
        #                       |                  |
        #                       V                  V
        while (self.ticks[-1].dt - self.ticks[0].dt) > self.delta:
            # popleft removes and returns the oldest tick in self.ticks
            popped = self.ticks.popleft()
            # Subclasses should override handle_remove to define
            # behavior for removing ticks.
            self.handle_remove(popped)

    # All event windows expect to receive events with datetime fields
    # that arrive in sorted order.
    def assert_well_formed(self, event):
        assert isinstance(event, ndict), "Bad event in EventWindow:%s" % event
        assert event.has_key('dt'), "Missing dt in EventWindow:%s" % event
        assert isinstance(event.dt, datetime),"Bad dt in EventWindow:%s" % event
        if len(self.ticks) > 0:
            # Something is wrong if new event is older than previous.
            # Report the tick we actually compared against (the newest).
            assert event.dt >= self.ticks[-1].dt, \
                "Events arrived out of order in EventWindow: %s -> %s" % \
                (event, self.ticks[-1])
|