mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-30 19:23:51 +08:00
fixed bug with BaseTransform state and name properties.
This commit is contained in:
@@ -1,7 +1,5 @@
|
||||
from feed import Feed
|
||||
|
||||
import zipline.protocol as zp
|
||||
from zipline.protocol import COMPONENT_TYPE
|
||||
from zipline.components.aggregator import Aggregate
|
||||
|
||||
from collections import Counter
|
||||
|
||||
@@ -5,10 +5,17 @@ from zipline.transforms.base import BaseTransform
|
||||
|
||||
class MovingAverageTransform(BaseTransform):
|
||||
|
||||
def init(self, name, daycount=3):
|
||||
self.daycount = daycount
|
||||
|
||||
def init(self, name, days=3):
|
||||
self.state = {}
|
||||
self.state['name'] = name
|
||||
self.days = days
|
||||
self.by_sid = defaultdict(self._create)
|
||||
|
||||
@property
|
||||
def get_id(self):
|
||||
return self.state['name']
|
||||
|
||||
def transform(self, event):
|
||||
cur = self.by_sid[event.sid]
|
||||
cur.update(event)
|
||||
@@ -16,12 +23,12 @@ class MovingAverageTransform(BaseTransform):
|
||||
return self.state
|
||||
|
||||
def _create(self):
|
||||
return MovingAverage(self.daycount)
|
||||
return MovingAverage(self.days)
|
||||
|
||||
class MovingAverage(object):
|
||||
|
||||
def init(self, daycount):
|
||||
self.window = EventWindow(daycount)
|
||||
def __init__(self, days):
|
||||
self.window = EventWindow(days)
|
||||
self.total = 0.0
|
||||
self.average = 0.0
|
||||
|
||||
@@ -43,10 +50,10 @@ class EventWindow(object):
|
||||
Tracks a window of the event history. Use an instance to track the events
|
||||
inside your window to efficiently calculate rolling statistics.
|
||||
"""
|
||||
def init(self, daycount):
|
||||
def __init__(self, days):
|
||||
self.ticks = []
|
||||
self.dropped_ticks = []
|
||||
self.delta = timedelta(days=daycount)
|
||||
self.delta = timedelta(days=days)
|
||||
|
||||
def update(self, event):
|
||||
# add new event
|
||||
|
||||
@@ -4,8 +4,15 @@ from zipline.transforms.base import BaseTransform
|
||||
class ReturnsTransform(BaseTransform):
|
||||
|
||||
def init(self, name):
|
||||
self.state = {}
|
||||
self.state['name'] = name
|
||||
self.by_sid = defaultdict(self._create)
|
||||
|
||||
@property
|
||||
def get_id(self):
|
||||
return self.state['name']
|
||||
|
||||
|
||||
def transform(self, event):
|
||||
cur = self.by_sid[event.sid]
|
||||
cur.update(event)
|
||||
@@ -27,7 +34,6 @@ class ReturnsFromPriorClose(object):
|
||||
self.returns = 0.0
|
||||
|
||||
def update(self, event):
|
||||
next_close = None
|
||||
if self.last_close:
|
||||
change = event.price - self.last_close.price
|
||||
self.returns = change / self.last_close.price
|
||||
|
||||
@@ -7,9 +7,15 @@ from zipline.finance.movingaverage import EventWindow
|
||||
class VWAPTransform(BaseTransform):
|
||||
|
||||
def init(self, name, daycount=3):
|
||||
self.state = {}
|
||||
self.state['name'] = name
|
||||
self.daycount = daycount
|
||||
self.by_sid = defaultdict(self.create_vwap)
|
||||
|
||||
@property
|
||||
def get_id(self):
|
||||
return self.state['name']
|
||||
|
||||
def transform(self, event):
|
||||
cur = self.by_sid[event.sid]
|
||||
cur.update(event)
|
||||
@@ -24,12 +30,12 @@ class DailyVWAP(object):
|
||||
A class that tracks the volume weighted average price based on tick
|
||||
updates.
|
||||
"""
|
||||
def init(self, name, daycount=3):
|
||||
self.window = EventWindow(daycount)
|
||||
def __init__(self, days=3):
|
||||
self.window = EventWindow(days)
|
||||
self.flux = 0.0
|
||||
self.volume = 0
|
||||
self.vwap = 0.0
|
||||
self.delta = timedelta(days=daycount)
|
||||
self.delta = timedelta(days=days)
|
||||
|
||||
def update(self, event):
|
||||
|
||||
|
||||
@@ -20,7 +20,6 @@ class BaseTransform(Component):
|
||||
Parent class for feed transforms. Subclass and override transform
|
||||
method to create a new derived value from the combined feed.
|
||||
"""
|
||||
|
||||
def init(self):
|
||||
pass
|
||||
|
||||
|
||||
Reference in New Issue
Block a user