mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-01 13:09:32 +08:00
movingaverage implemented as transform
This commit is contained in:
+55
-44
@@ -1,7 +1,7 @@
|
||||
"""
|
||||
Generator versions of transforms.
|
||||
"""
|
||||
|
||||
import random
|
||||
import pytz
|
||||
import logbook
|
||||
import pymongo
|
||||
@@ -11,18 +11,22 @@ from pymongo import ASCENDING
|
||||
from datetime import datetime, timedelta
|
||||
from collections import deque, defaultdict
|
||||
from numbers import Number
|
||||
from itertools import izip
|
||||
|
||||
from zipline import ndict
|
||||
from zipline.gens.utils import hash_args, assert_datasource_protocol, \
|
||||
assert_trade_protocol, assert_datasource_unframe_protocol, \
|
||||
assert_feed_protocol
|
||||
from zipline.gens.utils import hash_args, date_gen
|
||||
from zipline.gens.utils import assert_feed_unframe_protocol, assert_transform_protocol
|
||||
|
||||
import zipline.protocol as zp
|
||||
|
||||
def PassthroughTransformGen(stream_in):
|
||||
"""Trivial transform for event forwarding."""
|
||||
|
||||
# hash_args with no arguments is the same as hashlib.md5.update(":"); hashlib.md5.digest().
|
||||
# hash_args with no arguments is the same as:
|
||||
# hasher = hashlib.md5()
|
||||
# hasher.update(":");
|
||||
# hasher.digest()
|
||||
|
||||
namestring = "Passthrough" + hash_args()
|
||||
|
||||
for message in stream_in:
|
||||
@@ -34,7 +38,8 @@ def PassthroughTransformGen(stream_in):
|
||||
def FunctionalTransformGen(stream_in, fun, *args, **kwargs):
|
||||
"""
|
||||
Generic transform generator that takes each message from an in-stream
|
||||
and yields the output of a function on that message.
|
||||
and yields the output of a function on that message. Not sure how
|
||||
useful this will be in reality, but good for testing.
|
||||
"""
|
||||
|
||||
# TODO: Distinguish between functions and classes in hash_args.
|
||||
@@ -46,13 +51,13 @@ def FunctionalTransformGen(stream_in, fun, *args, **kwargs):
|
||||
assert_transform_protocol(out_value)
|
||||
yield(namestring, out_value)
|
||||
|
||||
|
||||
def StatefulTransformGen(stream_in, tnfm_class, *args, **kwargs):
|
||||
"""
|
||||
Generic transform generator that takes each message from an in-stream
|
||||
and feeds it to a state class. For each call to update, the state
|
||||
class must produce a message to be fed downstream.
|
||||
"""
|
||||
"""
|
||||
|
||||
# Create an instance of our transform class.
|
||||
state = tnfm_class(*args, **kwargs)
|
||||
|
||||
@@ -75,8 +80,7 @@ def MovingAverageTransformGen(stream_in, days, fields):
|
||||
class MovingAverage(object):
|
||||
"""
|
||||
Class that maintains a dictionary from sids to EventWindows
|
||||
calculating an average value for the specified fields over the
|
||||
specified time window. Upon receipt of each message we update the
|
||||
Upon receipt of each message we update the
|
||||
corresponding window and return the calculated average.
|
||||
"""
|
||||
|
||||
@@ -105,41 +109,53 @@ class MovingAverage(object):
|
||||
# message for this sid.
|
||||
window = self.sid_windows[event.sid]
|
||||
window.update(event)
|
||||
|
||||
return window.get_averages()
|
||||
|
||||
|
||||
class EventWindow(object):
|
||||
"""
|
||||
Maintains a list of events that are within a certain timedelta
|
||||
of the most recent tick. Also maintains a rolling average as
|
||||
a float on any specified fields. Events must arrive sorted by
|
||||
dt.
|
||||
of the most recent tick. The expected use of this class is to
|
||||
track events associated with a single sid. We provide simple
|
||||
functionality for averages, but anything more complicated
|
||||
should be handled by a containing class.
|
||||
"""
|
||||
|
||||
def __init__(self, delta, fields):
|
||||
self.ticks = deque()
|
||||
self.delta = delta
|
||||
self.fields = fields
|
||||
self.totals = defaultdict(float)
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return len(self.ticks)
|
||||
|
||||
def update(self, event):
|
||||
self.assert_well_formed(event)
|
||||
# Add new event and increment totals.
|
||||
self.ticks.append(event)
|
||||
for field in self.fields:
|
||||
self.totals[field] += event[field]
|
||||
|
||||
# We return a list of all out-of-range events we removed.
|
||||
out_of_range = []
|
||||
|
||||
# Clear out expired events, decrementing totals.
|
||||
# newest oldest
|
||||
# | |
|
||||
# V V
|
||||
while (self.ticks[-1].dt - self.ticks[0].dt) > self.delta:
|
||||
# popleft get ticks[0]
|
||||
|
||||
while (self.ticks[-1].dt - self.ticks[0].dt) >= self.delta:
|
||||
# popleft removes and returns ticks[0]
|
||||
popped = self.ticks.popleft()
|
||||
# Decrement totals
|
||||
for field in self.fields:
|
||||
self.totals[field] -= popped[field]
|
||||
|
||||
# Add the popped element to the list of dropped events.
|
||||
out_of_range.append(popped)
|
||||
|
||||
return out_of_range
|
||||
|
||||
def average(self, field):
|
||||
assert field in self.fields
|
||||
if len(self.ticks) == 0:
|
||||
@@ -152,10 +168,8 @@ class EventWindow(object):
|
||||
Return an ndict of all our tracked averages.
|
||||
"""
|
||||
out = ndict()
|
||||
|
||||
for field in self.fields:
|
||||
out[field] = self.average(field)
|
||||
|
||||
return out
|
||||
|
||||
def assert_well_formed(self, event):
|
||||
@@ -171,30 +185,27 @@ class EventWindow(object):
|
||||
"Event missing [%s] in EventWindow" % field
|
||||
assert isinstance(event[field], Number), \
|
||||
"Got %s for %s in EventWindow" % (event[field], field)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
averages = MovingAverage(timedelta(minutes = 1), ['price', 'vol'])
|
||||
|
||||
e = ndict()
|
||||
e.price = 1
|
||||
e.vol = 2
|
||||
e.sid = "foo"
|
||||
e.dt = datetime.now()
|
||||
|
||||
averages.update(e)
|
||||
|
||||
e = ndict()
|
||||
e.price = 2
|
||||
e.vol = 3
|
||||
e.sid = "foo"
|
||||
e.dt = datetime.now()
|
||||
averages.update(e)
|
||||
def make_event(**kwargs):
|
||||
e = ndict()
|
||||
for key, value in kwargs.iteritems():
|
||||
e[key] = value
|
||||
return e
|
||||
|
||||
e = ndict()
|
||||
e.price = 3
|
||||
e.vol = 1
|
||||
e.sid = "foo"
|
||||
e.dt = datetime.now() + timedelta(hours =1)
|
||||
averages.update(e)
|
||||
dates = date_gen(delta = timedelta(hours = 12))
|
||||
events = (
|
||||
make_event(
|
||||
sid = 'foo', price = random.random(),
|
||||
dt = date,
|
||||
type = zp.DATASOURCE_TYPE.TRADE,
|
||||
source_id = 'ds',
|
||||
vol = i
|
||||
)
|
||||
for date, i in izip(dates, xrange(100))
|
||||
)
|
||||
|
||||
gen = MovingAverageTransformGen(events, 1, ['price', 'vol'])
|
||||
|
||||
|
||||
|
||||
+13
-1
@@ -69,7 +69,19 @@ def assert_datasource_unframe_protocol(event):
|
||||
|
||||
def assert_feed_protocol(event):
|
||||
"""Assert that an event is valid input to zp.FEED_FRAME."""
|
||||
assert isinstance(feed, ndict)
|
||||
assert isinstance(event, ndict)
|
||||
assert isinstance(event.source_id, basestring)
|
||||
assert event.type in DATASOURCE_TYPE
|
||||
assert event.has_key('dt')
|
||||
|
||||
|
||||
def assert_feed_unframe_protocol(event):
|
||||
"""Assert that an event is an ndict whose type is in DATASOURCE_TYPE and that has a 'dt' key."""
|
||||
assert isinstance(event, ndict)
|
||||
assert event.type in DATASOURCE_TYPE
|
||||
assert event.has_key('dt')
|
||||
|
||||
|
||||
def assert_transform_protocol(event):
|
||||
pass
|
||||
|
||||
|
||||
@@ -331,6 +331,9 @@ def FEED_UNFRAME(msg):
|
||||
#TODO: anything we can do to assert more about the content of the dict?
|
||||
assert isinstance(payload, dict)
|
||||
rval = ndict(payload)
|
||||
assert rval.source_id
|
||||
assert rval.type in DATASOURCE_TYPE
|
||||
assert rval.dt
|
||||
UNPACK_DATE(rval)
|
||||
return rval
|
||||
except TypeError:
|
||||
|
||||
Reference in New Issue
Block a user