mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-29 13:47:54 +08:00
added generator-style transforms
This commit is contained in:
@@ -12,7 +12,7 @@ from datetime import datetime, timedelta
|
||||
from collections import deque, defaultdict
|
||||
|
||||
from zipline import ndict
|
||||
from zipline.gens.utils import stringify_args, assert_datasource_protocol, \
|
||||
from zipline.gens.utils import hash_args, assert_datasource_protocol, \
|
||||
assert_trade_protocol, assert_datasource_unframe_protocol
|
||||
|
||||
import zipline.protocol as zp
|
||||
@@ -33,10 +33,7 @@ def FeedGen(stream_in, source_ids):
|
||||
assert isinstance(id, basestring), "Bad source_id %s" % source_id
|
||||
sources[id] = deque()
|
||||
|
||||
namestring = "FeedGen" + stringify_args(source_ids)
|
||||
|
||||
# Process incoming streams.
|
||||
|
||||
for message in stream_in:
|
||||
# Incoming messages should be the output of DATASOURCE_UNFRAME.
|
||||
assert_datasource_unframe_protocol(message), \
|
||||
@@ -52,6 +49,7 @@ def FeedGen(stream_in, source_ids):
|
||||
|
||||
while full(sources) and not done(sources):
|
||||
message = pop_oldest(sources)
|
||||
assert feed_protocol(message)
|
||||
yield message
|
||||
|
||||
# We should have only a done message left in each queue.
|
||||
|
||||
@@ -10,7 +10,7 @@ from pymongo import ASCENDING
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from zipline import ndict
|
||||
from zipline.gens.utils import stringify_args, assert_datasource_protocol, \
|
||||
from zipline.gens.utils import hash_args, assert_datasource_protocol, \
|
||||
assert_trade_protocol
|
||||
|
||||
import zipline.protocol as zp
|
||||
@@ -33,7 +33,7 @@ def MongoTradeHistoryGen(collection, filter, start_date, end_date):
|
||||
|
||||
# Create unique identifier string that can be used to break
|
||||
# sorting ties deterministically
|
||||
argstring = stringify_args(collection, filter, start_date, end_date)
|
||||
argstring = hash_args(collection, filter, start_date, end_date)
|
||||
source_id = "MongoTradeHistoryGen" + argstring
|
||||
|
||||
# All datasources
|
||||
|
||||
@@ -0,0 +1,200 @@
|
||||
"""
|
||||
Generator versions of transforms.
|
||||
"""
|
||||
|
||||
import pytz
|
||||
import logbook
|
||||
import pymongo
|
||||
import types
|
||||
|
||||
from pymongo import ASCENDING
|
||||
from datetime import datetime, timedelta
|
||||
from collections import deque, defaultdict
|
||||
from numbers import Number
|
||||
|
||||
from zipline import ndict
|
||||
from zipline.gens.utils import hash_args, assert_datasource_protocol, \
|
||||
assert_trade_protocol, assert_datasource_unframe_protocol, \
|
||||
assert_feed_protocol
|
||||
|
||||
import zipline.protocol as zp
|
||||
|
||||
def PassthroughTransformGen(stream_in):
|
||||
"""Trivial transform for event forwarding."""
|
||||
|
||||
# hash_args with no arguments is the same as hashlib.md5.update(":"); hashlib.md5.digest().
|
||||
namestring = "Passthrough" + hash_args()
|
||||
|
||||
for message in stream_in:
|
||||
assert_feed_unframe_protocol(message)
|
||||
out_value = message
|
||||
assert_transform_protocol(out_value)
|
||||
yield (namestring, out_value)
|
||||
|
||||
def FunctionalTransformGen(stream_in, fun, *args, **kwargs):
|
||||
"""
|
||||
Generic transform generator that takes each message from an in-stream
|
||||
and yields the output of a function on that message.
|
||||
"""
|
||||
|
||||
# TODO: Distinguish between functions and classes in hash_args.
|
||||
namestring = fun.__name__ + hash_args(*args, **kwargs)
|
||||
|
||||
for message in stream_in:
|
||||
assert_feed_unframe_protocol(message)
|
||||
out_value = fun(message, *args, **kwargs)
|
||||
assert_transform_protocol(out_value)
|
||||
yield(namestring, out_value)
|
||||
|
||||
|
||||
def StatefulTransformGen(stream_in, tnfm_class, *args, **kwargs):
|
||||
"""
|
||||
Generic transform generator that takes each message from an in-stream
|
||||
and feeds it to a state class. For each call to update, the state
|
||||
class must produce a message to be fed downstream.
|
||||
"""
|
||||
# Create an instance of our transform class.
|
||||
state = tnfm_class(*args, **kwargs)
|
||||
|
||||
# Generate the string associated with this generator's output.
|
||||
namestring = tnfm_class.__name__ + hash_args(*args, **kwargs)
|
||||
|
||||
for message in stream_in:
|
||||
assert_feed_unframe_protocol(message)
|
||||
out_value = state.update(message)
|
||||
assert_transform_protocol(out_value)
|
||||
yield (namestring, out_value)
|
||||
|
||||
def MovingAverageTransformGen(stream_in, days, fields):
|
||||
"""
|
||||
Generator that uses the MovingAverage state class to calculate
|
||||
a moving average for all stocks over a specified number of days.
|
||||
"""
|
||||
return StatefulTransformGen(stream_in, MovingAverage, timedelta(days=days), fields)
|
||||
|
||||
class MovingAverage(object):
|
||||
"""
|
||||
Class that maintains a dictionary from sids to EventWindows
|
||||
calculating an average value for the specified fields over the
|
||||
specified time window. Upon receipt of each message we update the
|
||||
corresponding window and return the calculated average.
|
||||
"""
|
||||
|
||||
def __init__(self, delta, fields):
|
||||
self.delta = delta
|
||||
self.fields = fields
|
||||
|
||||
# No way to pass arguments to the defaultdict factory, so we
|
||||
# need to define a method to generate the correct EventWindows.
|
||||
self.sid_windows = defaultdict(self.create_window)
|
||||
|
||||
def create_window(self):
|
||||
"""Factory method for self.sid_windows."""
|
||||
return EventWindow(self.delta, self.fields)
|
||||
|
||||
def update(self, event):
|
||||
"""
|
||||
Update the event window for this event's sid. Return an ndict from
|
||||
tracked fields to averages.
|
||||
"""
|
||||
|
||||
assert isinstance(event, ndict),"Bad event in MovingAverage: %s" % event
|
||||
assert event.has_key('sid'), "No sid in MovingAverage: %s" % event
|
||||
|
||||
# This will create a new EventWindow if this is the first
|
||||
# message for this sid.
|
||||
window = self.sid_windows[event.sid]
|
||||
window.update(event)
|
||||
return window.get_averages()
|
||||
|
||||
|
||||
class EventWindow(object):
|
||||
"""
|
||||
Maintains a list of events that are within a certain timedelta
|
||||
of the most recent tick. Also maintains a rolling average as
|
||||
a float on any specified fields. Events must arrive sorted by
|
||||
dt.
|
||||
"""
|
||||
def __init__(self, delta, fields):
|
||||
self.ticks = deque()
|
||||
self.delta = delta
|
||||
self.fields = fields
|
||||
self.totals = defaultdict(float)
|
||||
|
||||
|
||||
def update(self, event):
|
||||
self.assert_well_formed(event)
|
||||
# Add new event and increment totals.
|
||||
self.ticks.append(event)
|
||||
for field in self.fields:
|
||||
self.totals[field] += event[field]
|
||||
|
||||
# Clear out expired events, decrementing totals.
|
||||
# newest oldest
|
||||
# | |
|
||||
# V V
|
||||
while (self.ticks[-1].dt - self.ticks[0].dt) > self.delta:
|
||||
# popleft get ticks[0]
|
||||
popped = self.ticks.popleft()
|
||||
# Decrement totals
|
||||
for field in self.fields:
|
||||
self.totals[field] -= popped[field]
|
||||
|
||||
def average(self, field):
|
||||
assert field in self.fields
|
||||
if len(self.ticks) == 0:
|
||||
return 0.0
|
||||
else:
|
||||
return self.totals[field] / len(self.ticks)
|
||||
|
||||
def get_averages(self):
|
||||
"""
|
||||
Return an ndict of all our tracked averages.
|
||||
"""
|
||||
out = ndict()
|
||||
|
||||
for field in self.fields:
|
||||
out[field] = self.average(field)
|
||||
|
||||
return out
|
||||
|
||||
def assert_well_formed(self, event):
|
||||
assert isinstance(event, ndict), "Bad event in EventWindow:%s" % event
|
||||
assert event.has_key('dt'), "Missing dt in EventWindow:%s" % event
|
||||
assert isinstance(event.dt, datetime),"Bad dt in EventWindow:%s" % event
|
||||
if len(self.ticks) > 0:
|
||||
# Something is wrong if new event is older than previous.
|
||||
assert event.dt >= self.ticks[-1].dt, \
|
||||
"Events arrived out of order in EventWindow: %s -> %s" % (event, self.ticks[0])
|
||||
for field in self.fields:
|
||||
assert event.has_key(field), \
|
||||
"Event missing [%s] in EventWindow" % field
|
||||
assert isinstance(event[field], Number), \
|
||||
"Got %s for %s in EventWindow" % (event[field], field)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
averages = MovingAverage(timedelta(minutes = 1), ['price', 'vol'])
|
||||
|
||||
e = ndict()
|
||||
e.price = 1
|
||||
e.vol = 2
|
||||
e.sid = "foo"
|
||||
e.dt = datetime.now()
|
||||
|
||||
averages.update(e)
|
||||
|
||||
e = ndict()
|
||||
e.price = 2
|
||||
e.vol = 3
|
||||
e.sid = "foo"
|
||||
e.dt = datetime.now()
|
||||
averages.update(e)
|
||||
|
||||
e = ndict()
|
||||
e.price = 3
|
||||
e.vol = 1
|
||||
e.sid = "foo"
|
||||
e.dt = datetime.now() + timedelta(hours =1)
|
||||
averages.update(e)
|
||||
@@ -26,7 +26,7 @@ def alternate(g1, g2):
|
||||
if e2 != None:
|
||||
yield e2
|
||||
|
||||
def stringify_args(*args, **kwargs):
|
||||
def hash_args(*args, **kwargs):
|
||||
"""Define a unique string for any set of representable args."""
|
||||
arg_string = '_'.join([str(arg) for arg in args])
|
||||
kwarg_string = '_'.join([str(key) + '=' + str(value) for key, value in kwargs.iteritems()])
|
||||
@@ -62,11 +62,14 @@ def assert_trade_protocol(event):
|
||||
|
||||
def assert_datasource_unframe_protocol(event):
|
||||
"""Assert that an event is valid output of zp.DATASOURCE_UNFRAME."""
|
||||
|
||||
assert isinstance(event, ndict)
|
||||
assert isinstance(event.source_id, basestring)
|
||||
assert event.type in DATASOURCE_TYPE
|
||||
assert event.has_key('dt')
|
||||
|
||||
def assert_feed_protocol(event):
|
||||
pass
|
||||
"""Assert that an event is valid input to zp.FEED_FRAME."""
|
||||
assert isinstance(feed, ndict)
|
||||
assert isinstance(event.source_id, basestring)
|
||||
assert event.type in DATASOURCE_TYPE
|
||||
assert event.has_key('dt')
|
||||
|
||||
@@ -316,6 +316,7 @@ def FEED_FRAME(event):
|
||||
|
||||
- source_id
|
||||
- type
|
||||
- dt
|
||||
"""
|
||||
assert isinstance(event, ndict), 'unknown type %s' % str(event)
|
||||
source_id = event.source_id
|
||||
|
||||
Reference in New Issue
Block a user