movingaverage implemented as transform

This commit is contained in:
scottsanderson
2012-07-27 17:06:07 -04:00
parent 4ff943eb34
commit f0cb4eaaed
3 changed files with 71 additions and 45 deletions
+55 -44
View File
@@ -1,7 +1,7 @@
"""
Generator versions of transforms.
"""
import random
import pytz
import logbook
import pymongo
@@ -11,18 +11,22 @@ from pymongo import ASCENDING
from datetime import datetime, timedelta
from collections import deque, defaultdict
from numbers import Number
from itertools import izip
from zipline import ndict
from zipline.gens.utils import hash_args, assert_datasource_protocol, \
assert_trade_protocol, assert_datasource_unframe_protocol, \
assert_feed_protocol
from zipline.gens.utils import hash_args, date_gen
from zipline.gens.utils import assert_feed_unframe_protocol, assert_transform_protocol
import zipline.protocol as zp
def PassthroughTransformGen(stream_in):
"""Trivial transform for event forwarding."""
# hash_args with no arguments is the same as hashlib.md5.update(":"); hashlib.md5.digest().
# hash_args with no arguments is the same as:
# hasher = hashlib.md5()
# hasher.update(":")
# hasher.digest()
namestring = "Passthrough" + hash_args()
for message in stream_in:
@@ -34,7 +38,8 @@ def PassthroughTransformGen(stream_in):
def FunctionalTransformGen(stream_in, fun, *args, **kwargs):
"""
Generic transform generator that takes each message from an in-stream
and yields the output of a function on that message.
and yields the output of a function on that message. Not sure how
useful this will be in reality, but good for testing.
"""
# TODO: Distinguish between functions and classes in hash_args.
@@ -46,13 +51,13 @@ def FunctionalTransformGen(stream_in, fun, *args, **kwargs):
assert_transform_protocol(out_value)
yield(namestring, out_value)
def StatefulTransformGen(stream_in, tnfm_class, *args, **kwargs):
"""
Generic transform generator that takes each message from an in-stream
and feeds it to a state class. For each call to update, the state
class must produce a message to be fed downstream.
"""
"""
# Create an instance of our transform class.
state = tnfm_class(*args, **kwargs)
@@ -75,8 +80,7 @@ def MovingAverageTransformGen(stream_in, days, fields):
class MovingAverage(object):
"""
Class that maintains a dictionary from sids to EventWindows
calculating an average value for the specified fields over the
specified time window. Upon receipt of each message we update the
Upon receipt of each message we update the
corresponding window and return the calculated average.
"""
@@ -105,41 +109,53 @@ class MovingAverage(object):
# message for this sid.
window = self.sid_windows[event.sid]
window.update(event)
return window.get_averages()
class EventWindow(object):
"""
Maintains a list of events that are within a certain timedelta
of the most recent tick. Also maintains a rolling average as
a float on any specified fields. Events must arrive sorted by
dt.
of the most recent tick. The expected use of this class is to
track events associated with a single sid. We provide simple
functionality for averages, but anything more complicated
should be handled by a containing class.
"""
def __init__(self, delta, fields):
    """
    Set up an empty window.

    delta  -- span compared against the dt difference between the
              newest and oldest held events when expiring old ticks.
    fields -- names of the event fields whose running totals are kept.
    """
    # Events in arrival order: oldest at index 0, newest at index -1.
    self.ticks = deque()
    self.delta = delta
    self.fields = fields
    # Running sum per tracked field; defaultdict(float) starts each at 0.0.
    self.totals = defaultdict(float)
def __len__(self):
    """Return the number of events currently held in the window."""
    return len(self.ticks)
def update(self, event):
    """
    Add a new event to the window and expire the events it pushes out.

    The event is validated with assert_well_formed, appended as the
    newest tick, and its tracked fields are added to the running totals.
    Every tick whose dt is at least self.delta older than the newest
    tick is then removed and its fields subtracted from the totals.

    Returns a list of the removed events, oldest first (empty if nothing
    expired).
    """
    self.assert_well_formed(event)

    # Add new event and increment totals.
    self.ticks.append(event)
    for field in self.fields:
        self.totals[field] += event[field]

    # We return a list of all out-of-range events we removed.
    out_of_range = []

    # Clear out expired events, decrementing totals.
    #          newest               oldest
    #            |                    |
    #            V                    V
    # The emptiness guard matters when delta is zero or negative: the
    # newest tick then expires itself, and indexing an empty deque would
    # raise IndexError.
    while self.ticks and \
            (self.ticks[-1].dt - self.ticks[0].dt) >= self.delta:
        # popleft removes and returns ticks[0]
        popped = self.ticks.popleft()

        # Decrement totals
        for field in self.fields:
            self.totals[field] -= popped[field]

        # Add the popped element to the list of dropped events.
        out_of_range.append(popped)

    return out_of_range
def average(self, field):
assert field in self.fields
if len(self.ticks) == 0:
@@ -152,10 +168,8 @@ class EventWindow(object):
Return an ndict of all our tracked averages.
"""
out = ndict()
for field in self.fields:
out[field] = self.average(field)
return out
def assert_well_formed(self, event):
@@ -171,30 +185,27 @@ class EventWindow(object):
"Event missing [%s] in EventWindow" % field
assert isinstance(event[field], Number), \
"Got %s for %s in EventWindow" % (event[field], field)
if __name__ == "__main__":
    # Ad-hoc demo code; runs only when this module is executed directly.

    # Push hand-built events for sid "foo" through a MovingAverage
    # constructed with a one-minute timedelta.
    averages = MovingAverage(timedelta(minutes = 1), ['price', 'vol'])
    e = ndict()
    e.price = 1
    e.vol = 2
    e.sid = "foo"
    e.dt = datetime.now()
    averages.update(e)
    e = ndict()
    e.price = 2
    e.vol = 3
    e.sid = "foo"
    e.dt = datetime.now()
    averages.update(e)
    # Helper: build an ndict whose keys/values are the keyword arguments.
    def make_event(**kwargs):
        e = ndict()
        for key, value in kwargs.iteritems():
            e[key] = value
        return e
    # Third hand-built event, stamped one hour after the current time.
    e = ndict()
    e.price = 3
    e.vol = 1
    e.sid = "foo"
    e.dt = datetime.now() + timedelta(hours =1)
    averages.update(e)
    # NOTE(review): date_gen is defined elsewhere; presumably it yields
    # datetimes spaced `delta` apart -- confirm against zipline.gens.utils.
    dates = date_gen(delta = timedelta(hours = 12))
    # Lazy stream of at most 100 synthetic TRADE events for sid 'foo':
    # random price, vol equal to the event's index (izip stops at the
    # shorter of its inputs).
    events = (
        make_event(
            sid = 'foo', price = random.random(),
            dt = date,
            type = zp.DATASOURCE_TYPE.TRADE,
            source_id = 'ds',
            vol = i
        )
        for date, i in izip(dates, xrange(100))
    )
    # Build the transform over the synthetic stream (days=1, per the
    # MovingAverageTransformGen(stream_in, days, fields) signature).
    # Nothing here iterates `gen`.
    gen = MovingAverageTransformGen(events, 1, ['price', 'vol'])
+13 -1
View File
@@ -69,7 +69,19 @@ def assert_datasource_unframe_protocol(event):
def assert_feed_protocol(event):
    """
    Assert that an event is valid input to zp.FEED_FRAME.

    The event must be an ndict with a string source_id, a type that is a
    member of DATASOURCE_TYPE, and a 'dt' key. Raises AssertionError on
    the first check that fails.
    """
    # The type check must name the parameter (`event`); a check against
    # any other name would raise NameError before validating anything.
    assert isinstance(event, ndict)
    assert isinstance(event.source_id, basestring)
    assert event.type in DATASOURCE_TYPE
    assert event.has_key('dt')
def assert_feed_unframe_protocol(event):
    """
    Assert that an event has the shape expected of an unframed feed
    message: an ndict whose type is a member of DATASOURCE_TYPE and
    which has a 'dt' key. Unlike assert_feed_protocol, source_id is not
    checked here.
    """
    assert isinstance(event, ndict)
    assert event.type in DATASOURCE_TYPE
    assert event.has_key('dt')
def assert_transform_protocol(event):
    """Placeholder check for transform output; currently accepts any event."""
    pass
+3
View File
@@ -331,6 +331,9 @@ def FEED_UNFRAME(msg):
#TODO: anything we can do to assert more about the content of the dict?
assert isinstance(payload, dict)
rval = ndict(payload)
assert rval.source_id
assert rval.type in DATASOURCE_TYPE
assert rval.dt
UNPACK_DATE(rval)
return rval
except TypeError: