mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-05 22:16:17 +08:00
moving toward abstract base for event window tnfms
This commit is contained in:
@@ -27,7 +27,7 @@ def date_sort(stream_in, source_ids):
|
||||
# Incoming messages should be the output of DATASOURCE_UNFRAME.
|
||||
assert_datasource_unframe_protocol(message), \
|
||||
"Bad message in date_sort: %s" % message
|
||||
|
||||
|
||||
# Only allow messages from sources we expect.
|
||||
assert message.source_id in sources, "Unexpected source: %s" % message
|
||||
|
||||
|
||||
+37
-48
@@ -174,20 +174,39 @@ class MovingAverage(object):
|
||||
window.update(event)
|
||||
return window.get_averages()
|
||||
|
||||
class EventWindow(object):
|
||||
"""
|
||||
Maintains a list of events that are within a certain timedelta
|
||||
of the most recent tick. The expected use of this class is to
|
||||
track events associated with a single sid. We provide simple
|
||||
functionality for averages, but anything more complicated
|
||||
should be handled by a containing class.
|
||||
class EventWindow:
|
||||
"""
|
||||
Abstract base class for transform classes that calculate iterative
|
||||
metrics on events within a given timedelta. Maintains a list of
|
||||
events that are within a certain timedelta of the most recent
|
||||
tick. Calls self.handle_add(event) for each event added to the
|
||||
window. Calls self.handle_remove(event) for each event removed
|
||||
from the window. Subclass these methods along with init(*args,
|
||||
**kwargs) to calculate metrics over the window.
|
||||
|
||||
def __init__(self, delta, fields):
|
||||
See zipline/gens/mavg.py and zipline/gens/vwap.py for example
|
||||
implementations of moving average and volume-weighted average
|
||||
price.
|
||||
"""
|
||||
# Mark this as an abstract base class.
|
||||
__metaclass__ = ABCMeta
|
||||
|
||||
def __init__(self, delta, *args, **kwargs):
|
||||
self.ticks = deque()
|
||||
self.delta = delta
|
||||
self.fields = fields
|
||||
self.totals = defaultdict(float)
|
||||
self.init(*args, **kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def init(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def handle_add(self, event):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def handle_remove(self, event):
|
||||
raise NotImplementedError()
|
||||
|
||||
def __len__(self):
|
||||
return len(self.ticks)
|
||||
@@ -196,44 +215,19 @@ class EventWindow(object):
|
||||
self.assert_well_formed(event)
|
||||
# Add new event and increment totals.
|
||||
self.ticks.append(event)
|
||||
for field in self.fields:
|
||||
self.totals[field] += event[field]
|
||||
self.handle_add(event)
|
||||
|
||||
# We return a list of all out-of-range events we removed.
|
||||
out_of_range = []
|
||||
|
||||
# Clear out expired events, decrementing totals.
|
||||
# Clear out expired event.
|
||||
#
|
||||
# newest oldest
|
||||
# | |
|
||||
# V V
|
||||
|
||||
while (self.ticks[-1].dt - self.ticks[0].dt) >= self.delta:
|
||||
# popleft removes and returns ticks[0]
|
||||
while (self.ticks[-1].dt - self.ticks[0].dt) > self.delta:
|
||||
# popleft removes and returns the oldest tick in self.ticks
|
||||
popped = self.ticks.popleft()
|
||||
# Decrement totals
|
||||
for field in self.fields:
|
||||
self.totals[field] -= popped[field]
|
||||
# Add the popped element to the list of dropped events.
|
||||
out_of_range.append(popped)
|
||||
|
||||
return out_of_range
|
||||
|
||||
def average(self, field):
|
||||
assert field in self.fields
|
||||
if len(self.ticks) == 0:
|
||||
return 0.0
|
||||
else:
|
||||
return self.totals[field] / len(self.ticks)
|
||||
|
||||
def get_averages(self):
|
||||
"""
|
||||
Return an ndict of all our tracked averages.
|
||||
"""
|
||||
out = ndict()
|
||||
# out.ticks = len(self.ticks)
|
||||
for field in self.fields:
|
||||
out[field] = self.average(field)
|
||||
return out
|
||||
# Subclasses should override handle_remove to define
|
||||
# behavior for removing ticks.
|
||||
self.handle_remove(popped)
|
||||
|
||||
def assert_well_formed(self, event):
|
||||
assert isinstance(event, ndict), "Bad event in EventWindow:%s" % event
|
||||
@@ -243,8 +237,3 @@ class EventWindow(object):
|
||||
# Something is wrong if new event is older than previous.
|
||||
assert event.dt >= self.ticks[-1].dt, \
|
||||
"Events arrived out of order in EventWindow: %s -> %s" % (event, self.ticks[0])
|
||||
for field in self.fields:
|
||||
assert event.has_key(field), \
|
||||
"Event missing [%s] in EventWindow" % field
|
||||
assert isinstance(event[field], Number), \
|
||||
"Got %s for %s in EventWindow" % (event[field], field)
|
||||
|
||||
Reference in New Issue
Block a user