Reverting changes to MovingStandardDevWindow.

Though the addition of tracking multiple values in the window
is powerful, the changes broke behavior of existing algorithms
by changing method signatures and names.

So these changes are temporarily reverted, to be pulled back in once
there is a way to track multiple fields with the existing API,
or once a cutover of the API has been worked out.
This commit is contained in:
Eddie Hebert
2013-01-21 00:12:33 -05:00
parent 9e31d83084
commit 39f44a44f8
2 changed files with 39 additions and 117 deletions
+14 -27
View File
@@ -273,57 +273,44 @@ class TestFinanceTransforms(TestCase):
assert tnfm_volumes == expected_volumes
def test_moving_stddev(self):
stddev = MovingStandardDev(
fields=['price', 'volume'],
market_aware=False,
delta=timedelta(days=3),
)
trade_history = factory.create_trade_history(
133,
[10.0, 15.0, 13.0, 12.0],
[100, 200, 100, 200],
timedelta(days=1),
[100, 100, 100, 100],
timedelta(hours=1),
self.trading_environment
)
stddev = MovingStandardDev(
market_aware=False,
delta=timedelta(minutes=150),
)
self.source = SpecificEquityTrades(event_list=trade_history)
transformed = list(stddev.transform(self.source))
# Output values
tnfm_prices = [message.tnfm_value.price for message in transformed]
tnfm_volumes = [message.tnfm_value.volume for message in transformed]
expected_prices = [
vals = [message.tnfm_value for message in transformed]
expected = [
None,
np.std([10.0, 15.0], ddof=1),
np.std([10.0, 15.0, 13.0], ddof=1),
np.std([15.0, 13.0, 12.0], ddof=1),
]
expected_volumes = [
None,
np.std([100, 200], ddof=1),
np.std([100, 200, 100], ddof=1),
np.std([200, 100, 200], ddof=1),
]
# numpy has odd rounding behavior, cf.
# http://docs.scipy.org/doc/numpy/reference/generated/numpy.std.html
for v1, v2 in zip(vals, expected):
for v1, v2 in zip(tnfm_prices, expected_prices):
if v1 is None:
assert v2 is None
continue
self.assertAlmostEqual(v1, v2)
assert round(v1, 5) == round(v2, 5)
for v1, v2 in zip(tnfm_volumes, expected_volumes):
if v1 is None:
assert v2 is None
continue
self.assertAlmostEqual(v1, v2)
############################################################
# Test BatchTransform
class TestBatchTransform(TestCase):
def setUp(self):
setup_logger(self)
+25 -90
View File
@@ -17,25 +17,19 @@ from numbers import Number
from collections import defaultdict
from math import sqrt
from zipline import ndict
from zipline.transforms.utils import EventWindow, TransformMeta
class MovingStandardDev(object):
"""
Class that maintains a dictionary from sids to
MovingStandardDevWindows. For each sid, we maintain standard
deviations over any number of distinct fields. (For example, we can
maintain a sid's moving standard deviation of returns as well as its
moving standard deviation of prices.
Class that maintains a dictionary from sids to
MovingStandardDevWindows. For each sid, we maintain the
standard deviation of all events falling within the specified
window.
"""
__metaclass__ = TransformMeta
def __init__(self, fields='price', market_aware=True, window_length=None,
delta=None):
if isinstance(fields, basestring):
fields = [fields]
self.fields = fields
def __init__(self, market_aware=True, window_length=None, delta=None):
self.market_aware = market_aware
@@ -44,9 +38,6 @@ class MovingStandardDev(object):
# Market-aware mode only works with full-day windows.
if self.market_aware:
# Window length must be 1 or greater
assert self.window_length >= 1
assert self.window_length and not self.delta,\
"Market-aware mode only works with full-day windows."
@@ -66,113 +57,57 @@ class MovingStandardDev(object):
return MovingStandardDevWindow(
self.market_aware,
self.window_length,
self.delta,
fields=self.fields
self.delta
)
def update(self, event):
"""
Update the event window for this event's sid. Return an ndict
from tracked fields to moving standard deviations.
from tracked fields to moving averages.
"""
# This will create a new EventWindow if this is the first
# message for this sid.
window = self.sid_windows[event.sid]
window.update(event)
return window.get_stddevs()
return window.get_stddev()
class MovingStandardDevWindow(EventWindow):
"""
Iteratively calculates moving standard deviations for a particular sid
over a given time window. We can maintain standard deviations for
arbitrarily many fields on a single sid. (For example, we might track
moving standard deviation of returns as well as its moving standard
deviation of prices.) The expected functionality of this class is to be
instantiated inside a MovingStandardDev.
Iteratively calculates standard deviation for a particular sid
over a given time window. The expected functionality of this
class is to be instantiated inside a MovingStandardDev.
"""
def __init__(self, market_aware, window_length, delta, fields='price'):
def __init__(self, market_aware, days, delta):
# Call the superclass constructor to set up base EventWindow
# infrastructure.
EventWindow.__init__(self, market_aware, window_length, delta)
if isinstance(fields, basestring):
fields = [fields]
self.fields = fields
EventWindow.__init__(self, market_aware, days, delta)
self.sum = defaultdict(float)
self.sum_sqr = defaultdict(float)
self.sum = 0.0
self.sum_sqr = 0.0
def handle_add(self, event):
# Sanity check on the event.
self.assert_required_fields(event)
assert isinstance(event.price, Number)
# Increment our running totals with data from the event.
for field in self.fields:
self.sum[field] += event[field]
self.sum_sqr[field] += event[field] ** 2
self.sum += event.price
self.sum_sqr += event.price ** 2
def handle_remove(self, event):
# Sanity check on the event.
self.assert_required_fields(event)
assert isinstance(event.price, Number)
# Decrement our running totals with data from the event.
for field in self.fields:
self.sum[field] -= event[field]
self.sum_sqr[field] -= event[field] ** 2
def stdev(self, field):
"""
Calculate the standard deviation of our ticks over a single field
using a naive algorithm (see http://goo.gl/wPFtf).
"""
# Sanity check.
assert field in self.fields
# Standard deviation is undefined for no event and 0 for one event
if len(self.ticks) <= 1:
return None
# Calculate and return the standard deviation.
else:
_mean = self.sum[field] / len(self.ticks)
_var = (self.sum_sqr[field] -
self.sum[field] * _mean) / (len(self.ticks) - 1)
return sqrt(_var)
self.sum -= event.price
self.sum_sqr -= event.price ** 2
def get_stddev(self):
"""
Returns stddev with 'price' for existing algorithms.
Could possibly use existing algos.
"""
# Sample standard deviation is undefined for a single event or
# no events.
if len(self.ticks) <= 1:
if len(self) <= 1:
return None
else:
average = self.sum['price'] / len(self.ticks)
s_squared = (self.sum_sqr['price'] - self.sum['price'] * average) \
/ (len(self.ticks) - 1)
average = self.sum / len(self)
s_squared = (self.sum_sqr - self.sum * average) \
/ (len(self) - 1)
stddev = sqrt(s_squared)
return stddev
def get_stddevs(self):
"""
Return an ndict of all our tracked standard deviations.
"""
out = ndict()
for field in self.fields:
out[field] = self.stdev(field)
return out
def assert_required_fields(self, event):
"""
We only allow events with all of our tracked fields.
"""
for field in self.fields:
assert field in event, \
"Event missing [%s] in MovingStandardDevEventWindow" % field
assert isinstance(event[field], Number), \
"Got %s for %s in MovingStandardDevEventWindow" \
% (event[field], field)