From f0cb4eaaedb12a22ebe3281f799e1c8068493de9 Mon Sep 17 00:00:00 2001 From: scottsanderson Date: Fri, 27 Jul 2012 17:06:07 -0400 Subject: [PATCH] movingaverage implemented as transform --- zipline/gens/transform.py | 99 ++++++++++++++++++++++----------------- zipline/gens/utils.py | 14 +++++- zipline/protocol.py | 3 ++ 3 files changed, 71 insertions(+), 45 deletions(-) diff --git a/zipline/gens/transform.py b/zipline/gens/transform.py index 83cc09da..bdd85ee9 100644 --- a/zipline/gens/transform.py +++ b/zipline/gens/transform.py @@ -1,7 +1,7 @@ """ Generator versions of transforms. """ - +import random import pytz import logbook import pymongo @@ -11,18 +11,22 @@ from pymongo import ASCENDING from datetime import datetime, timedelta from collections import deque, defaultdict from numbers import Number +from itertools import izip from zipline import ndict -from zipline.gens.utils import hash_args, assert_datasource_protocol, \ - assert_trade_protocol, assert_datasource_unframe_protocol, \ - assert_feed_protocol +from zipline.gens.utils import hash_args, date_gen +from zipline.gens.utils import assert_feed_unframe_protocol, assert_transform_protocol import zipline.protocol as zp def PassthroughTransformGen(stream_in): """Trivial transform for event forwarding.""" - # hash_args with no arguments is the same as hashlib.md5.update(":"); hashlib.md5.digest(). + # hash_args with no arguments is the same as: + # hasher = hashlib.md5() + # hasher.update(":"); + # hashlib.md5.digest(). + namestring = "Passthrough" + hash_args() for message in stream_in: @@ -34,7 +38,8 @@ def PassthroughTransformGen(stream_in): def FunctionalTransformGen(stream_in, fun, *args, **kwargs): """ Generic transform generator that takes each message from an in-stream - and yields the output of a function on that message. + and yields the output of a function on that message. Not sure how + useful this will be in reality, but good for testing. """ # TODO: Distinguish between functions and classes in hash_args. @@ -46,13 +51,13 @@ def FunctionalTransformGen(stream_in, fun, *args, **kwargs): assert_transform_protocol(out_value) yield(namestring, out_value) - def StatefulTransformGen(stream_in, tnfm_class, *args, **kwargs): """ Generic transform generator that takes each message from an in-stream and feeds it to a state class. For each call to update, the state class must produce a message to be fed downstream. - """ + """ + # Create an instance of our transform class. state = tnfm_class(*args, **kwargs) @@ -75,8 +80,7 @@ def MovingAverageTransformGen(stream_in, days, fields): class MovingAverage(object): """ Class that maintains a dictionary from sids to EventWindows - calculating an average value for the specified fields over the - specified time window. Upon receipt of each message we update the + Upon receipt of each message we update the corresponding window and return the calculated average. """ @@ -105,41 +109,53 @@ class MovingAverage(object): # message for this sid. window = self.sid_windows[event.sid] window.update(event) + return window.get_averages() - class EventWindow(object): """ Maintains a list of events that are within a certain timedelta - of the most recent tick. Also maintains a rolling average as - a float on any specified fields. Events must arrive sorted by - dt. + of the most recent tick. The expected use of this class is to + track events associated with a single sid. We provide simple + functionality for averages, but anything more complicated + should be handled by a containing class. """ + def __init__(self, delta, fields): self.ticks = deque() self.delta = delta self.fields = fields self.totals = defaultdict(float) - + def __len__(self): + return len(self.ticks) + def update(self, event): self.assert_well_formed(event) # Add new event and increment totals. self.ticks.append(event) for field in self.fields: self.totals[field] += event[field] + + # We return a list of all out-of-range events we removed. + out_of_range = [] # Clear out expired events, decrementing totals. # newest oldest # | | # V V - while (self.ticks[-1].dt - self.ticks[0].dt) > self.delta: - # popleft get ticks[0] + + while (self.ticks[-1].dt - self.ticks[0].dt) >= self.delta: + # popleft removes and returns ticks[0] popped = self.ticks.popleft() # Decrement totals for field in self.fields: self.totals[field] -= popped[field] - + # Add the popped element to the list of dropped events. + out_of_range.append(popped) + + return out_of_range + def average(self, field): assert field in self.fields if len(self.ticks) == 0: @@ -152,10 +168,8 @@ class EventWindow(object): Return an ndict of all our tracked averages. """ out = ndict() - for field in self.fields: out[field] = self.average(field) - return out def assert_well_formed(self, event): @@ -171,30 +185,27 @@ class EventWindow(object): "Event missing [%s] in EventWindow" % field assert isinstance(event[field], Number), \ "Got %s for %s in EventWindow" % (event[field], field) - - + if __name__ == "__main__": - - averages = MovingAverage(timedelta(minutes = 1), ['price', 'vol']) - - e = ndict() - e.price = 1 - e.vol = 2 - e.sid = "foo" - e.dt = datetime.now() - averages.update(e) - - e = ndict() - e.price = 2 - e.vol = 3 - e.sid = "foo" - e.dt = datetime.now() - averages.update(e) + def make_event(**kwargs): + e = ndict() + for key, value in kwargs.iteritems(): + e[key] = value + return e - e = ndict() - e.price = 3 - e.vol = 1 - e.sid = "foo" - e.dt = datetime.now() + timedelta(hours =1) - averages.update(e) + dates = date_gen(delta = timedelta(hours = 12)) + events = ( + make_event( + sid = 'foo', price = random.random(), + dt = date, + type = zp.DATASOURCE_TYPE.TRADE, + source_id = 'ds', + vol = i + ) + for date, i in izip(dates, xrange(100)) + ) + + gen = MovingAverageTransformGen(events, 1, ['price', 'vol']) + + diff --git a/zipline/gens/utils.py b/zipline/gens/utils.py index e65823ef..8d168d8e 100644 --- a/zipline/gens/utils.py +++ b/zipline/gens/utils.py @@ -69,7 +69,19 @@ def assert_datasource_unframe_protocol(event): def assert_feed_protocol(event): """Assert that an event is valid input to zp.FEED_FRAME.""" - assert isinstance(feed, ndict) + assert isinstance(event, ndict) assert isinstance(event.source_id, basestring) assert event.type in DATASOURCE_TYPE assert event.has_key('dt') + + +def assert_feed_unframe_protocol(event): + """Same as above.""" + assert isinstance(event, ndict) + assert event.type in DATASOURCE_TYPE + assert event.has_key('dt') + + +def assert_transform_protocol(event): + pass + diff --git a/zipline/protocol.py b/zipline/protocol.py index ffcc9fd3..073ae0ce 100644 --- a/zipline/protocol.py +++ b/zipline/protocol.py @@ -331,6 +331,9 @@ def FEED_UNFRAME(msg): #TODO: anything we can do to assert more about the content of the dict? assert isinstance(payload, dict) rval = ndict(payload) + assert rval.source_id + assert rval.type in DATASOURCE_TYPE + assert rval.dt UNPACK_DATE(rval) return rval except TypeError: