from numbers import Number
from datetime import datetime, timedelta
from collections import defaultdict

from zipline import ndict
from zipline.gens.transform import EventWindow


class MovingAverage(object):
    """
    Class that maintains a dictionary from sids to
    MovingAverageEventWindows. For each sid, we maintain moving
    averages over any number of distinct fields (for example, we can
    maintain a sid's average volume as well as its average price).

    :param delta: window length, passed unchanged to every
        MovingAverageEventWindow (and from there to EventWindow).
        NOTE(review): presumably a datetime.timedelta -- confirm
        against EventWindow.
    :param fields: collection of event keys whose values are averaged.
    """

    def __init__(self, delta, fields):
        self.delta = delta
        self.fields = fields

        # No way to pass arguments to the defaultdict factory, so we
        # use a bound method that reads self.delta / self.fields to
        # generate the correct EventWindows.
        self.sid_windows = defaultdict(self.create_window)

    def create_window(self):
        """Factory method for self.sid_windows."""
        return MovingAverageEventWindow(self.delta, self.fields)

    def update(self, event):
        """
        Update the event window for this event's sid.

        :param event: an event exposing ``.sid`` and item access
            (``event[field]``) for every tracked field.
        :returns: an ndict mapping each tracked field to its current
            moving average for ``event.sid``.
        """
        # Indexing the defaultdict creates a new EventWindow if this is
        # the first message for this sid.
        window = self.sid_windows[event.sid]
        window.update(event)
        return window.get_averages()


class MovingAverageEventWindow(EventWindow):
    """
    Iteratively calculates moving averages for a particular sid over a
    given time window. We can maintain averages for arbitrarily many
    fields on a single sid (for example, we might track average price
    as well as average volume for a single sid).

    The expected functionality of this class is to be instantiated
    inside a MovingAverage transform.

    Running per-field totals are kept so each average is a single
    division rather than a pass over every tick in the window.
    """

    def __init__(self, delta, fields):
        # Call the superclass constructor to set up base EventWindow
        # infrastructure.
        EventWindow.__init__(self, delta)

        # We maintain a dictionary of totals for each of our tracked
        # fields; missing fields start at 0.0.
        self.fields = fields
        self.totals = defaultdict(float)

    # Subclass customization for adding new events.
    def handle_add(self, event):
        """Validate ``event`` and add its tracked values to the totals."""
        # Sanity check on the event, done before mutating any total so
        # a rejected event leaves the totals untouched.
        self.assert_required_fields(event)

        # Increment our running totals with data from the event.
        for field in self.fields:
            self.totals[field] += event[field]

    # Subclass customization for removing expired events.
    def handle_remove(self, event):
        """Subtract a departing event's tracked values from the totals."""
        # No validation here: every event in the window was validated
        # by handle_add when it entered.
        for field in self.fields:
            self.totals[field] -= event[field]

    def average(self, field):
        """
        Calculate the average value of our ticks over a single field.

        :param field: one of the tracked fields.
        :returns: the running total divided by the number of ticks
            currently in the window, or 0.0 if the window is empty.
        :raises AssertionError: if ``field`` is not tracked (only when
            assertions are enabled).
        """
        # Sanity check.
        assert field in self.fields

        # By convention the average of an empty window is 0.0 (not
        # None), which keeps the return type numeric for callers.
        # NOTE(review): self.ticks is provided by EventWindow; assumed
        # to hold the events currently inside the window.
        if len(self.ticks) == 0:
            return 0.0

        # Calculate and return the average. len(self.ticks) is O(1).
        return self.totals[field] / len(self.ticks)

    def get_averages(self):
        """
        Return an ndict of all our tracked averages, keyed by field.
        """
        out = ndict()
        for field in self.fields:
            out[field] = self.average(field)
        return out

    def assert_required_fields(self, event):
        """
        We only allow events that carry all of our tracked fields, each
        holding a numbers.Number value.

        :raises AssertionError: if a field is missing or non-numeric.
            These checks are skipped when Python runs with ``-O``.
        """
        for field in self.fields:
            # Use ``in`` rather than ``has_key``: has_key was removed
            # from dict in Python 3, while membership testing works for
            # dicts and Mapping types on both Python 2 and 3.
            assert field in event, \
                "Event missing [%s] in MovingAverageEventWindow" % field
            assert isinstance(event[field], Number), \
                "Got %s for %s in MovingAverageEventWindow" % (event[field], field)