mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-02 14:46:54 +08:00
Reverting changes MovingStandardDevWindow.
Though the addition of tracking mulitple values in the window is powerful, the changes broke behavior of existing algorithms by changing method signatures and names. So temporarily reverting these changes, to be pulled back in when a way to have the multiple fields tracked with the existing API is written, or a cutover of the API is figured out and determined.
This commit is contained in:
+14
-27
@@ -273,57 +273,44 @@ class TestFinanceTransforms(TestCase):
|
||||
assert tnfm_volumes == expected_volumes
|
||||
|
||||
def test_moving_stddev(self):
|
||||
|
||||
stddev = MovingStandardDev(
|
||||
fields=['price', 'volume'],
|
||||
market_aware=False,
|
||||
delta=timedelta(days=3),
|
||||
)
|
||||
|
||||
trade_history = factory.create_trade_history(
|
||||
133,
|
||||
[10.0, 15.0, 13.0, 12.0],
|
||||
[100, 200, 100, 200],
|
||||
timedelta(days=1),
|
||||
[100, 100, 100, 100],
|
||||
timedelta(hours=1),
|
||||
self.trading_environment
|
||||
)
|
||||
|
||||
stddev = MovingStandardDev(
|
||||
market_aware=False,
|
||||
delta=timedelta(minutes=150),
|
||||
)
|
||||
self.source = SpecificEquityTrades(event_list=trade_history)
|
||||
|
||||
transformed = list(stddev.transform(self.source))
|
||||
# Output values
|
||||
tnfm_prices = [message.tnfm_value.price for message in transformed]
|
||||
tnfm_volumes = [message.tnfm_value.volume for message in transformed]
|
||||
|
||||
expected_prices = [
|
||||
vals = [message.tnfm_value for message in transformed]
|
||||
|
||||
expected = [
|
||||
None,
|
||||
np.std([10.0, 15.0], ddof=1),
|
||||
np.std([10.0, 15.0, 13.0], ddof=1),
|
||||
np.std([15.0, 13.0, 12.0], ddof=1),
|
||||
]
|
||||
|
||||
expected_volumes = [
|
||||
None,
|
||||
np.std([100, 200], ddof=1),
|
||||
np.std([100, 200, 100], ddof=1),
|
||||
np.std([200, 100, 200], ddof=1),
|
||||
]
|
||||
# np has odd rounding behavior, cf.
|
||||
# http://docs.scipy.org/doc/np/reference/generated/np.std.html
|
||||
for v1, v2 in zip(vals, expected):
|
||||
|
||||
for v1, v2 in zip(tnfm_prices, expected_prices):
|
||||
if v1 is None:
|
||||
assert v2 is None
|
||||
continue
|
||||
self.assertAlmostEqual(v1, v2)
|
||||
assert round(v1, 5) == round(v2, 5)
|
||||
|
||||
|
||||
for v1, v2 in zip(tnfm_volumes, expected_volumes):
|
||||
if v1 is None:
|
||||
assert v2 is None
|
||||
continue
|
||||
self.assertAlmostEqual(v1, v2)
|
||||
############################################################
|
||||
# Test BatchTransform
|
||||
|
||||
|
||||
class TestBatchTransform(TestCase):
|
||||
def setUp(self):
|
||||
setup_logger(self)
|
||||
|
||||
@@ -17,25 +17,19 @@ from numbers import Number
|
||||
from collections import defaultdict
|
||||
from math import sqrt
|
||||
|
||||
from zipline import ndict
|
||||
from zipline.transforms.utils import EventWindow, TransformMeta
|
||||
|
||||
|
||||
class MovingStandardDev(object):
|
||||
"""
|
||||
Class that maintains a dictionary from sids to
|
||||
MovingStandardDevWindows. For each sid, we maintain standard
|
||||
deviations over any number of distinct fields. (For example, we can
|
||||
maintain a sid's moving standard deviation of returns as well as its
|
||||
moving standard deviation of prices.
|
||||
Class that maintains a dicitonary from sids to
|
||||
MovingStandardDevWindows. For each sid, we maintain a the
|
||||
standard deviation of all events falling within the specified
|
||||
window.
|
||||
"""
|
||||
__metaclass__ = TransformMeta
|
||||
|
||||
def __init__(self, fields='price', market_aware=True, window_length=None,
|
||||
delta=None):
|
||||
if isinstance(fields, basestring):
|
||||
fields = [fields]
|
||||
self.fields = fields
|
||||
def __init__(self, market_aware=True, window_length=None, delta=None):
|
||||
|
||||
self.market_aware = market_aware
|
||||
|
||||
@@ -44,9 +38,6 @@ class MovingStandardDev(object):
|
||||
|
||||
# Market-aware mode only works with full-day windows.
|
||||
if self.market_aware:
|
||||
# Window length must be 1 or greater
|
||||
assert self.window_length >= 1
|
||||
|
||||
assert self.window_length and not self.delta,\
|
||||
"Market-aware mode only works with full-day windows."
|
||||
|
||||
@@ -66,113 +57,57 @@ class MovingStandardDev(object):
|
||||
return MovingStandardDevWindow(
|
||||
self.market_aware,
|
||||
self.window_length,
|
||||
self.delta,
|
||||
fields=self.fields
|
||||
self.delta
|
||||
)
|
||||
|
||||
def update(self, event):
|
||||
"""
|
||||
Update the event window for this event's sid. Return an ndict
|
||||
from tracked fields to moving standard deviations.
|
||||
from tracked fields to moving averages.
|
||||
"""
|
||||
# This will create a new EventWindow if this is the first
|
||||
# message for this sid.
|
||||
window = self.sid_windows[event.sid]
|
||||
window.update(event)
|
||||
return window.get_stddevs()
|
||||
return window.get_stddev()
|
||||
|
||||
|
||||
class MovingStandardDevWindow(EventWindow):
|
||||
"""
|
||||
Iteratively calculates moving standard deviations for a particular sid
|
||||
over a given time window. We can maintain standard deviations for
|
||||
arbitrarily many fields on a single sid. (For example, we might track
|
||||
moving standard deviation of returns as well as its moving standard
|
||||
deviation of prices.) The expected functionality of this class is to be
|
||||
instantiated inside a MovingStandardDev.
|
||||
Iteratively calculates standard deviation for a particular sid
|
||||
over a given time window. The expected functionality of this
|
||||
class is to be instantiated inside a MovingStandardDev.
|
||||
"""
|
||||
|
||||
def __init__(self, market_aware, window_length, delta, fields='price'):
|
||||
def __init__(self, market_aware, days, delta):
|
||||
# Call the superclass constructor to set up base EventWindow
|
||||
# infrastructure.
|
||||
EventWindow.__init__(self, market_aware, window_length, delta)
|
||||
if isinstance(fields, basestring):
|
||||
fields = [fields]
|
||||
self.fields = fields
|
||||
EventWindow.__init__(self, market_aware, days, delta)
|
||||
|
||||
self.sum = defaultdict(float)
|
||||
self.sum_sqr = defaultdict(float)
|
||||
self.sum = 0.0
|
||||
self.sum_sqr = 0.0
|
||||
|
||||
def handle_add(self, event):
|
||||
# Sanity check on the event.
|
||||
self.assert_required_fields(event)
|
||||
assert isinstance(event.price, Number)
|
||||
|
||||
# Increment our running totals with data from the event.
|
||||
for field in self.fields:
|
||||
self.sum[field] += event[field]
|
||||
self.sum_sqr[field] += event[field] ** 2
|
||||
self.sum += event.price
|
||||
self.sum_sqr += event.price ** 2
|
||||
|
||||
def handle_remove(self, event):
|
||||
# Sanity check on the event.
|
||||
self.assert_required_fields(event)
|
||||
assert isinstance(event.price, Number)
|
||||
|
||||
# Decrement our running totals with data from the event.
|
||||
for field in self.fields:
|
||||
self.sum[field] -= event[field]
|
||||
self.sum_sqr[field] -= event[field] ** 2
|
||||
|
||||
def stdev(self, field):
|
||||
"""
|
||||
Calculate the standard deviation of our ticks over a single field
|
||||
using a naive algorithm (see http://goo.gl/wPFtf).
|
||||
"""
|
||||
# Sanity check.
|
||||
assert field in self.fields
|
||||
# Standard deviation is undefined for no event and 0 for one event
|
||||
if len(self.ticks) <= 1:
|
||||
return None
|
||||
|
||||
# Calculate and return the standard deviation.
|
||||
else:
|
||||
_mean = self.sum[field] / len(self.ticks)
|
||||
_var = (self.sum_sqr[field] -
|
||||
self.sum[field] * _mean) / (len(self.ticks) - 1)
|
||||
return sqrt(_var)
|
||||
self.sum -= event.price
|
||||
self.sum_sqr -= event.price ** 2
|
||||
|
||||
def get_stddev(self):
|
||||
"""
|
||||
Returns stddev with 'price' for existing algorithms.
|
||||
|
||||
Could possibly use existing algos.
|
||||
"""
|
||||
# Sample standard deviation is undefined for a single event or
|
||||
# no events.
|
||||
if len(self.ticks) <= 1:
|
||||
if len(self) <= 1:
|
||||
return None
|
||||
|
||||
else:
|
||||
average = self.sum['price'] / len(self.ticks)
|
||||
s_squared = (self.sum_sqr['price'] - self.sum['price'] * average) \
|
||||
/ (len(self.ticks) - 1)
|
||||
average = self.sum / len(self)
|
||||
s_squared = (self.sum_sqr - self.sum * average) \
|
||||
/ (len(self) - 1)
|
||||
stddev = sqrt(s_squared)
|
||||
return stddev
|
||||
|
||||
def get_stddevs(self):
|
||||
"""
|
||||
Return an ndict of all our tracked standard deviations.
|
||||
"""
|
||||
out = ndict()
|
||||
for field in self.fields:
|
||||
out[field] = self.stdev(field)
|
||||
return out
|
||||
|
||||
def assert_required_fields(self, event):
|
||||
"""
|
||||
We only allow events with all of our tracked fields.
|
||||
"""
|
||||
for field in self.fields:
|
||||
assert field in event, \
|
||||
"Event missing [%s] in MovingStandardDevEventWindow" % field
|
||||
assert isinstance(event[field], Number), \
|
||||
"Got %s for %s in MovingStandardDevEventWindow" \
|
||||
% (event[field], field)
|
||||
|
||||
Reference in New Issue
Block a user