Files
catalyst/zipline/gens/mavg.py
T
scottsanderson e061cb3a07 new-style vwap
2012-08-06 15:35:56 -04:00

106 lines
3.6 KiB
Python

from numbers import Number
from datetime import datetime, timedelta
from collections import defaultdict
from zipline import ndict
from zipline.gens.transform import EventWindow
class MovingAverage(object):
    """
    Track moving averages for many sids at once.

    Keeps a mapping (``self.sid_windows``) from sid to a
    MovingAverageEventWindow. Every window averages the same set of
    fields (for example both volume and price) over the same ``delta``.
    """

    def __init__(self, delta, fields):
        self.fields = fields
        self.delta = delta
        # defaultdict factories are called with no arguments, so a bound
        # method is used to carry self.delta and self.fields into each
        # newly created window.
        self.sid_windows = defaultdict(self.create_window)

    def create_window(self):
        """Build a fresh window for an unseen sid (factory for self.sid_windows)."""
        return MovingAverageEventWindow(delta=self.delta, fields=self.fields)

    def update(self, event):
        """
        Feed ``event`` into the window belonging to ``event.sid``.

        The window is created on demand the first time a sid appears.
        Returns whatever that window's ``get_averages()`` produces: a
        mapping from each tracked field to its current moving average.
        """
        sid_window = self.sid_windows[event.sid]  # auto-creates on first use
        sid_window.update(event)
        averages = sid_window.get_averages()
        return averages
class MovingAverageEventWindow(EventWindow):
    """
    Iteratively calculates moving averages for a particular sid over a
    given time window. We can maintain averages for arbitrarily many
    fields on a single sid. (For example, we might track average
    price as well as average volume for a single sid.) The expected
    functionality of this class is to be instantiated inside a
    MovingAverage transform.

    A running total is kept per field. It is increased in handle_add and
    decreased in handle_remove, so an average costs one division instead
    of a pass over the window.
    """

    def __init__(self, delta, fields):
        # Call the superclass constructor to set up base EventWindow
        # infrastructure (presumably the tick store read as self.ticks
        # below -- confirm against EventWindow).
        EventWindow.__init__(self, delta)
        # Fields whose averages we maintain.
        self.fields = fields
        # Running sum of each tracked field over the events currently in
        # the window. Starts at 0.0, so sums (and averages) are floats
        # even when event values are ints.
        # NOTE: repeated float add/subtract can accumulate rounding error.
        self.totals = defaultdict(float)

    # Subclass customization for adding new events.
    def handle_add(self, event):
        """Validate ``event`` and add its tracked fields to the running totals."""
        # Sanity check on the event.
        self.assert_required_fields(event)
        # Increment our running totals with data from the event.
        for field in self.fields:
            self.totals[field] += event[field]

    # Subclass customization for removing expired events.
    def handle_remove(self, event):
        """Subtract an expiring ``event``'s tracked fields from the running totals."""
        for field in self.fields:
            self.totals[field] -= event[field]

    def average(self, field):
        """
        Calculate the average value of our ticks over a single field.

        ``field`` must be one of the tracked fields. When the window holds
        no ticks, the average is 0.0 by convention (not None).
        """
        # Sanity check.
        assert field in self.fields
        tick_count = len(self.ticks)
        # An empty window averages to 0.0 by convention; this also avoids
        # dividing by zero.
        if tick_count == 0:
            return 0.0
        return self.totals[field] / tick_count

    def get_averages(self):
        """
        Return an ndict mapping each tracked field to its current average.
        """
        out = ndict()
        for field in self.fields:
            out[field] = self.average(field)
        return out

    def assert_required_fields(self, event):
        """
        We only allow events with all of our tracked fields.

        Each tracked field must be present in ``event`` and hold a
        numbers.Number. Violations raise AssertionError.
        NOTE(review): these are ``assert`` statements, so the checks are
        skipped entirely when Python runs with -O.
        """
        for field in self.fields:
            # ``in`` works for dicts and any Mapping; ``has_key`` was
            # Python-2-only and does not exist on Python 3 dicts.
            assert field in event, \
                "Event missing [%s] in MovingAverageEventWindow" % field
            assert isinstance(event[field], Number), \
                "Got %s for %s in MovingAverageEventWindow" % (event[field], field)