Files
catalyst/zipline/finance/movingaverage.py
T

73 lines
1.9 KiB
Python

from datetime import timedelta
from collections import defaultdict
from zipline.transforms.base import BaseTransform
class MovingAverageTransform(BaseTransform):
def init(self, name, days=3):
self.state = {}
self.state['name'] = name
self.days = days
self.by_sid = defaultdict(self._create)
@property
def get_id(self):
return self.state['name']
def transform(self, event):
cur = self.by_sid[event.sid]
cur.update(event)
self.state['value'] = cur.average
return self.state
def _create(self):
return MovingAverage(self.days)
class MovingAverage(object):
def __init__(self, days):
self.window = EventWindow(days)
self.total = 0.0
self.average = 0.0
def update(self, event):
self.window.update(event)
self.total += event.price
for dropped in self.window.dropped_ticks:
self.total -= dropped.price
if len(self.window.ticks) > 0:
self.average = self.total / len(self.window.ticks)
else:
self.average = 0.0
class EventWindow(object):
"""
Tracks a window of the event history. Use an instance to track the events
inside your window to efficiently calculate rolling statistics.
"""
def __init__(self, days):
self.ticks = []
self.dropped_ticks = []
self.delta = timedelta(days=days)
def update(self, event):
# add new event
self.ticks.append(event)
# determine which events are expired
last_date = event['dt']
first_date = last_date - self.delta
self.dropped_ticks = []
for tick in self.ticks:
if tick['dt'] <= first_date:
self.dropped_ticks.append(tick)
# remove the expired events
slice_index = len(self.dropped_ticks)
self.ticks = self.ticks[slice_index:]