From 52d75c9d31fa350680930cc01fb42b7deeca5fee Mon Sep 17 00:00:00 2001 From: fawce Date: Fri, 18 May 2012 14:42:47 -0400 Subject: [PATCH] fixed bug with BaseTransform state and name properties. --- zipline/components/merge.py | 2 -- zipline/finance/movingaverage.py | 21 ++++++++++++++------- zipline/finance/returns.py | 8 +++++++- zipline/finance/vwap.py | 12 +++++++++--- zipline/transforms/base.py | 1 - 5 files changed, 30 insertions(+), 14 deletions(-) diff --git a/zipline/components/merge.py b/zipline/components/merge.py index 689ec8cf..8a2ae7c1 100644 --- a/zipline/components/merge.py +++ b/zipline/components/merge.py @@ -1,7 +1,5 @@ -from feed import Feed import zipline.protocol as zp -from zipline.protocol import COMPONENT_TYPE from zipline.components.aggregator import Aggregate from collections import Counter diff --git a/zipline/finance/movingaverage.py b/zipline/finance/movingaverage.py index 2aee0531..f0e401fe 100644 --- a/zipline/finance/movingaverage.py +++ b/zipline/finance/movingaverage.py @@ -5,10 +5,17 @@ from zipline.transforms.base import BaseTransform class MovingAverageTransform(BaseTransform): - def init(self, name, daycount=3): - self.daycount = daycount + + def init(self, name, days=3): + self.state = {} + self.state['name'] = name + self.days = days self.by_sid = defaultdict(self._create) + @property + def get_id(self): + return self.state['name'] + def transform(self, event): cur = self.by_sid[event.sid] cur.update(event) @@ -16,12 +23,12 @@ class MovingAverageTransform(BaseTransform): return self.state def _create(self): - return MovingAverage(self.daycount) + return MovingAverage(self.days) class MovingAverage(object): - def init(self, daycount): - self.window = EventWindow(daycount) + def __init__(self, days): + self.window = EventWindow(days) self.total = 0.0 self.average = 0.0 @@ -43,10 +50,10 @@ class EventWindow(object): Tracks a window of the event history. Use an instance to track the events inside your window to efficiently calculate rolling statistics. """ - def init(self, daycount): + def __init__(self, days): self.ticks = [] self.dropped_ticks = [] - self.delta = timedelta(days=daycount) + self.delta = timedelta(days=days) def update(self, event): # add new event diff --git a/zipline/finance/returns.py b/zipline/finance/returns.py index 01dfd7fd..5585f325 100644 --- a/zipline/finance/returns.py +++ b/zipline/finance/returns.py @@ -4,8 +4,15 @@ from zipline.transforms.base import BaseTransform class ReturnsTransform(BaseTransform): def init(self, name): + self.state = {} + self.state['name'] = name self.by_sid = defaultdict(self._create) + @property + def get_id(self): + return self.state['name'] + + def transform(self, event): cur = self.by_sid[event.sid] cur.update(event) @@ -27,7 +34,6 @@ class ReturnsFromPriorClose(object): self.returns = 0.0 def update(self, event): - next_close = None if self.last_close: change = event.price - self.last_close.price self.returns = change / self.last_close.price diff --git a/zipline/finance/vwap.py b/zipline/finance/vwap.py index 3b165a71..35eb9578 100644 --- a/zipline/finance/vwap.py +++ b/zipline/finance/vwap.py @@ -7,9 +7,15 @@ from zipline.finance.movingaverage import EventWindow class VWAPTransform(BaseTransform): def init(self, name, daycount=3): + self.state = {} + self.state['name'] = name self.daycount = daycount self.by_sid = defaultdict(self.create_vwap) + @property + def get_id(self): + return self.state['name'] + def transform(self, event): cur = self.by_sid[event.sid] cur.update(event) @@ -24,12 +30,12 @@ class DailyVWAP(object): A class that tracks the volume weighted average price based on tick updates. """ - def init(self, name, daycount=3): - self.window = EventWindow(daycount) + def __init__(self, days=3): + self.window = EventWindow(days) self.flux = 0.0 self.volume = 0 self.vwap = 0.0 - self.delta = timedelta(days=daycount) + self.delta = timedelta(days=days) def update(self, event): diff --git a/zipline/transforms/base.py b/zipline/transforms/base.py index 8ab922dc..082e52b9 100644 --- a/zipline/transforms/base.py +++ b/zipline/transforms/base.py @@ -20,7 +20,6 @@ class BaseTransform(Component): Parent class for feed transforms. Subclass and override transform method to create a new derived value from the combined feed. """ - def init(self): pass