mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-28 11:19:32 +08:00
106 lines
3.6 KiB
Python
106 lines
3.6 KiB
Python
from numbers import Number
|
|
from datetime import datetime, timedelta
|
|
from collections import defaultdict
|
|
|
|
from zipline import ndict
|
|
from zipline.gens.transform import EventWindow
|
|
|
|
class MovingAverage(object):
    """
    Per-sid moving-average transform.

    Holds a mapping from each sid to its own MovingAverageEventWindow.
    A single window can average several fields at once, so one sid may
    carry, say, an average volume and an average price side by side.
    """

    def __init__(self, delta, fields):
        # Both settings are handed unchanged to every window built below.
        self.delta, self.fields = delta, fields
        # defaultdict invokes its factory with no arguments; using a bound
        # method lets that factory see this transform's delta and fields.
        self.sid_windows = defaultdict(self.create_window)

    def create_window(self):
        """Factory for self.sid_windows: a new window with our delta/fields."""
        return MovingAverageEventWindow(delta=self.delta, fields=self.fields)

    def update(self, event):
        """
        Push `event` into the window for its sid, then return that
        window's current averages (an ndict keyed by tracked field).
        """
        # Indexing with a sid not seen before makes the defaultdict build
        # (and store) a fresh window via create_window.
        target = self.sid_windows[event.sid]
        target.update(event)
        averages = target.get_averages()
        return averages
|
|
|
|
class MovingAverageEventWindow(EventWindow):
    """
    Iteratively calculates moving averages for a particular sid over a
    given time window. We can maintain averages for arbitrarily many
    fields on a single sid. (For example, we might track average
    price as well as average volume for a single sid.) The expected
    functionality of this class is to be instantiated inside a
    MovingAverage transform.

    Averages are maintained incrementally: a running total per field is
    increased when an event is added to the window (handle_add) and
    decreased when an expired event is removed (handle_remove), so each
    average is a single division instead of a pass over the window.
    """

    def __init__(self, delta, fields):
        # Call the superclass constructor to set up base EventWindow
        # infrastructure (delta is passed straight through to it).
        EventWindow.__init__(self, delta)

        # Fields whose moving averages this window tracks.
        self.fields = fields
        # Running sum of each tracked field over the events currently in
        # the window. defaultdict(float) starts every total at 0.0, which
        # also guarantees the division in average() is a float division.
        self.totals = defaultdict(float)

    # Subclass customization for adding new events.
    def handle_add(self, event):
        """Validate `event`, then add its tracked values to the totals."""
        # Every field is validated before any total is touched, so a bad
        # event cannot leave the totals partially updated.
        self.assert_required_fields(event)
        for field in self.fields:
            self.totals[field] += event[field]

    # Subclass customization for removing expired events.
    def handle_remove(self, event):
        """Subtract an expiring event's tracked values from the totals."""
        # No re-validation here: presumably every removed event was
        # previously accepted by handle_add.
        for field in self.fields:
            self.totals[field] -= event[field]

    def average(self, field):
        """
        Calculate the average value of our ticks over a single field.

        Returns 0.0 (not None) when the window currently holds no ticks.
        """
        # Sanity check: only fields given at construction are tracked.
        assert field in self.fields, \
            "Field [%s] is not tracked by MovingAverageEventWindow" % field

        num_ticks = len(self.ticks)
        # By convention an empty window averages to 0.0.
        if num_ticks == 0:
            return 0.0
        return self.totals[field] / num_ticks

    def get_averages(self):
        """
        Return an ndict mapping each tracked field to its current average.
        """
        out = ndict()
        for field in self.fields:
            out[field] = self.average(field)
        return out

    def assert_required_fields(self, event):
        """
        We only allow events that carry all of our tracked fields, each
        holding a numbers.Number value.

        Raises AssertionError naming the offending field otherwise. These
        are assert statements, so the checks are skipped under `python -O`.
        """
        for field in self.fields:
            # Membership via `in` works on any mapping; dict.has_key() is
            # Python 2-only and absent from Python 3 mappings.
            assert field in event, \
                "Event missing [%s] in MovingAverageEventWindow" % field
            assert isinstance(event[field], Number), \
                "Got %s for %s in MovingAverageEventWindow" % (event[field], field)
|