moving toward abstract base for event window tnfms

This commit is contained in:
scottsanderson
2012-08-06 13:11:16 -04:00
parent 4655e643a4
commit 1073696965
2 changed files with 38 additions and 49 deletions
+1 -1
View File
@@ -27,7 +27,7 @@ def date_sort(stream_in, source_ids):
# Incoming messages should be the output of DATASOURCE_UNFRAME.
assert_datasource_unframe_protocol(message), \
"Bad message in date_sort: %s" % message
# Only allow messages from sources we expect.
assert message.source_id in sources, "Unexpected source: %s" % message
+37 -48
View File
@@ -174,20 +174,39 @@ class MovingAverage(object):
window.update(event)
return window.get_averages()
class EventWindow(object):
"""
Maintains a list of events that are within a certain timedelta
of the most recent tick. The expected use of this class is to
track events associated with a single sid. We provide simple
functionality for averages, but anything more complicated
should be handled by a containing class.
class EventWindow:
"""
Abstract base class for transform classes that calculate iterative
metrics on events within a given timedelta. Maintains a list of
events that are within a certain timedelta of the most recent
tick. Calls self.handle_add(event) for each event added to the
window. Calls self.handle_remove(event) for each event removed
from the window. Subclass these methods along with init(*args,
**kwargs) to calculate metrics over the window.
def __init__(self, delta, fields):
See zipline/gens/mavg.py and zipline/gens/vwap.py for example
implementations of moving average and volume-weighted average
price.
"""
# Mark this as an abstract base class.
__metaclass__ = ABCMeta
def __init__(self, delta, *args, **kwargs):
self.ticks = deque()
self.delta = delta
self.fields = fields
self.totals = defaultdict(float)
self.init(*args, **kwargs)
@abstractmethod
def init(self):
raise NotImplementedError()
@abstractmethod
def handle_add(self, event):
raise NotImplementedError()
@abstractmethod
def handle_remove(self, event):
raise NotImplementedError()
def __len__(self):
return len(self.ticks)
@@ -196,44 +215,19 @@ class EventWindow(object):
self.assert_well_formed(event)
# Add new event and increment totals.
self.ticks.append(event)
for field in self.fields:
self.totals[field] += event[field]
self.handle_add(event)
# We return a list of all out-of-range events we removed.
out_of_range = []
# Clear out expired events, decrementing totals.
# Clear out expired event.
#
# newest oldest
# | |
# V V
while (self.ticks[-1].dt - self.ticks[0].dt) >= self.delta:
# popleft removes and returns ticks[0]
while (self.ticks[-1].dt - self.ticks[0].dt) > self.delta:
# popleft removes and returns the oldest tick in self.ticks
popped = self.ticks.popleft()
# Decrement totals
for field in self.fields:
self.totals[field] -= popped[field]
# Add the popped element to the list of dropped events.
out_of_range.append(popped)
return out_of_range
def average(self, field):
assert field in self.fields
if len(self.ticks) == 0:
return 0.0
else:
return self.totals[field] / len(self.ticks)
def get_averages(self):
"""
Return an ndict of all our tracked averages.
"""
out = ndict()
# out.ticks = len(self.ticks)
for field in self.fields:
out[field] = self.average(field)
return out
# Subclasses should override handle_remove to define
# behavior for removing ticks.
self.handle_remove(popped)
def assert_well_formed(self, event):
assert isinstance(event, ndict), "Bad event in EventWindow:%s" % event
@@ -243,8 +237,3 @@ class EventWindow(object):
# Something is wrong if new event is older than previous.
assert event.dt >= self.ticks[-1].dt, \
"Events arrived out of order in EventWindow: %s -> %s" % (event, self.ticks[0])
for field in self.fields:
assert event.has_key(field), \
"Event missing [%s] in EventWindow" % field
assert isinstance(event[field], Number), \
"Got %s for %s in EventWindow" % (event[field], field)