From 107369696597c38e266263bc097608ecc3edb87d Mon Sep 17 00:00:00 2001 From: scottsanderson Date: Mon, 6 Aug 2012 13:11:16 -0400 Subject: [PATCH] moving toward abstract base for event window tnfms --- zipline/gens/sort.py | 2 +- zipline/gens/transform.py | 85 +++++++++++++++++---------------------- 2 files changed, 38 insertions(+), 49 deletions(-) diff --git a/zipline/gens/sort.py b/zipline/gens/sort.py index 3ff5ee3f..9755da74 100644 --- a/zipline/gens/sort.py +++ b/zipline/gens/sort.py @@ -27,7 +27,7 @@ def date_sort(stream_in, source_ids): # Incoming messages should be the output of DATASOURCE_UNFRAME. assert_datasource_unframe_protocol(message), \ "Bad message in date_sort: %s" % message - + # Only allow messages from sources we expect. assert message.source_id in sources, "Unexpected source: %s" % message diff --git a/zipline/gens/transform.py b/zipline/gens/transform.py index 36e15689..ee8cd649 100644 --- a/zipline/gens/transform.py +++ b/zipline/gens/transform.py @@ -174,20 +174,39 @@ class MovingAverage(object): window.update(event) return window.get_averages() -class EventWindow(object): - """ - Maintains a list of events that are within a certain timedelta - of the most recent tick. The expected use of this class is to - track events associated with a single sid. We provide simple - functionality for averages, but anything more complicated - should be handled by a containing class. +class EventWindow: """ + Abstract base class for transform classes that calculate iterative + metrics on events within a given timedelta. Maintains a list of + events that are within a certain timedelta of the most recent + tick. Calls self.handle_add(event) for each event added to the + window. Calls self.handle_remove(event) for each event removed + from the window. Subclass these methods along with init(*args, + **kwargs) to calculate metrics over the window. - def __init__(self, delta, fields): + See zipline/gens/mavg.py and zipline/gens/vwap.py for example + implementations of moving average and volume-weighted average + price. + """ + # Mark this as an abstract base class. + __metaclass__ = ABCMeta + + def __init__(self, delta, *args, **kwargs): self.ticks = deque() self.delta = delta - self.fields = fields - self.totals = defaultdict(float) + self.init(*args, **kwargs) + + @abstractmethod + def init(self): + raise NotImplementedError() + + @abstractmethod + def handle_add(self, event): + raise NotImplementedError() + + @abstractmethod + def handle_remove(self, event): + raise NotImplementedError() def __len__(self): return len(self.ticks) @@ -196,44 +215,19 @@ class EventWindow(object): self.assert_well_formed(event) # Add new event and increment totals. self.ticks.append(event) - for field in self.fields: - self.totals[field] += event[field] + self.handle_add(event) - # We return a list of all out-of-range events we removed. - out_of_range = [] - - # Clear out expired events, decrementing totals. + # Clear out expired event. + # # newest oldest # | | # V V - - while (self.ticks[-1].dt - self.ticks[0].dt) >= self.delta: - # popleft removes and returns ticks[0] + while (self.ticks[-1].dt - self.ticks[0].dt) > self.delta: + # popleft removes and returns the oldest tick in self.ticks popped = self.ticks.popleft() - # Decrement totals - for field in self.fields: - self.totals[field] -= popped[field] - # Add the popped element to the list of dropped events. - out_of_range.append(popped) - - return out_of_range - - def average(self, field): - assert field in self.fields - if len(self.ticks) == 0: - return 0.0 - else: - return self.totals[field] / len(self.ticks) - - def get_averages(self): - """ - Return an ndict of all our tracked averages. - """ - out = ndict() - # out.ticks = len(self.ticks) - for field in self.fields: - out[field] = self.average(field) - return out + # Subclasses should override handle_remove to define + # behavior for removing ticks. + self.handle_remove(popped) def assert_well_formed(self, event): assert isinstance(event, ndict), "Bad event in EventWindow:%s" % event @@ -243,8 +237,3 @@ class EventWindow(object): # Something is wrong if new event is older than previous. assert event.dt >= self.ticks[-1].dt, \ "Events arrived out of order in EventWindow: %s -> %s" % (event, self.ticks[0]) - for field in self.fields: - assert event.has_key(field), \ - "Event missing [%s] in EventWindow" % field - assert isinstance(event[field], Number), \ - "Got %s for %s in EventWindow" % (event[field], field)