diff --git a/tests/test_transforms.py b/tests/test_transforms.py index ccd38fb7..8b775990 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -273,57 +273,44 @@ class TestFinanceTransforms(TestCase): assert tnfm_volumes == expected_volumes def test_moving_stddev(self): - - stddev = MovingStandardDev( - fields=['price', 'volume'], - market_aware=False, - delta=timedelta(days=3), - ) - trade_history = factory.create_trade_history( 133, [10.0, 15.0, 13.0, 12.0], - [100, 200, 100, 200], - timedelta(days=1), + [100, 100, 100, 100], + timedelta(hours=1), self.trading_environment ) + stddev = MovingStandardDev( + market_aware=False, + delta=timedelta(minutes=150), + ) self.source = SpecificEquityTrades(event_list=trade_history) transformed = list(stddev.transform(self.source)) - # Output values - tnfm_prices = [message.tnfm_value.price for message in transformed] - tnfm_volumes = [message.tnfm_value.volume for message in transformed] - expected_prices = [ + vals = [message.tnfm_value for message in transformed] + + expected = [ None, np.std([10.0, 15.0], ddof=1), np.std([10.0, 15.0, 13.0], ddof=1), np.std([15.0, 13.0, 12.0], ddof=1), ] - expected_volumes = [ - None, - np.std([100, 200], ddof=1), - np.std([100, 200, 100], ddof=1), - np.std([200, 100, 200], ddof=1), - ] + # np has odd rounding behavior, cf. 
+ # http://docs.scipy.org/doc/numpy/reference/generated/numpy.std.html + for v1, v2 in zip(vals, expected): - for v1, v2 in zip(tnfm_prices, expected_prices): if v1 is None: assert v2 is None continue - self.assertAlmostEqual(v1, v2) + assert round(v1, 5) == round(v2, 5) + - for v1, v2 in zip(tnfm_volumes, expected_volumes): - if v1 is None: - assert v2 is None - continue - self.assertAlmostEqual(v1, v2) ############################################################ # Test BatchTransform - class TestBatchTransform(TestCase): def setUp(self): setup_logger(self) diff --git a/zipline/transforms/stddev.py b/zipline/transforms/stddev.py index 1f144c07..bcb9fb38 100644 --- a/zipline/transforms/stddev.py +++ b/zipline/transforms/stddev.py @@ -17,25 +17,19 @@ from numbers import Number from collections import defaultdict from math import sqrt -from zipline import ndict from zipline.transforms.utils import EventWindow, TransformMeta class MovingStandardDev(object): """ - Class that maintains a dictionary from sids to - MovingStandardDevWindows. For each sid, we maintain standard - deviations over any number of distinct fields. (For example, we can - maintain a sid's moving standard deviation of returns as well as its - moving standard deviation of prices. + Class that maintains a dictionary from sids to + MovingStandardDevWindows. For each sid, we maintain the + standard deviation of all events falling within the specified + window. """ __metaclass__ = TransformMeta - def __init__(self, fields='price', market_aware=True, window_length=None, - delta=None): - if isinstance(fields, basestring): - fields = [fields] - self.fields = fields + def __init__(self, market_aware=True, window_length=None, delta=None): self.market_aware = market_aware @@ -44,9 +38,6 @@ class MovingStandardDev(object): # Market-aware mode only works with full-day windows. 
if self.market_aware: - # Window length must be 1 or greater - assert self.window_length >= 1 - assert self.window_length and not self.delta,\ "Market-aware mode only works with full-day windows." @@ -66,113 +57,57 @@ class MovingStandardDev(object): return MovingStandardDevWindow( self.market_aware, self.window_length, - self.delta, - fields=self.fields + self.delta ) def update(self, event): """ Update the event window for this event's sid. Return an ndict - from tracked fields to moving standard deviations. + from tracked fields to moving averages. """ # This will create a new EventWindow if this is the first # message for this sid. window = self.sid_windows[event.sid] window.update(event) - return window.get_stddevs() + return window.get_stddev() class MovingStandardDevWindow(EventWindow): """ - Iteratively calculates moving standard deviations for a particular sid - over a given time window. We can maintain standard deviations for - arbitrarily many fields on a single sid. (For example, we might track - moving standard deviation of returns as well as its moving standard - deviation of prices.) The expected functionality of this class is to be - instantiated inside a MovingStandardDev. + Iteratively calculates standard deviation for a particular sid + over a given time window. The expected functionality of this + class is to be instantiated inside a MovingStandardDev. """ - def __init__(self, market_aware, window_length, delta, fields='price'): + def __init__(self, market_aware, days, delta): # Call the superclass constructor to set up base EventWindow # infrastructure. - EventWindow.__init__(self, market_aware, window_length, delta) - if isinstance(fields, basestring): - fields = [fields] - self.fields = fields + EventWindow.__init__(self, market_aware, days, delta) - self.sum = defaultdict(float) - self.sum_sqr = defaultdict(float) + self.sum = 0.0 + self.sum_sqr = 0.0 def handle_add(self, event): - # Sanity check on the event. 
- self.assert_required_fields(event) + assert isinstance(event.price, Number) - # Increment our running totals with data from the event. - for field in self.fields: - self.sum[field] += event[field] - self.sum_sqr[field] += event[field] ** 2 + self.sum += event.price + self.sum_sqr += event.price ** 2 def handle_remove(self, event): - # Sanity check on the event. - self.assert_required_fields(event) + assert isinstance(event.price, Number) - # Decrement our running totals with data from the event. - for field in self.fields: - self.sum[field] -= event[field] - self.sum_sqr[field] -= event[field] ** 2 - - def stdev(self, field): - """ - Calculate the standard deviation of our ticks over a single field - using a naive algorithm (see http://goo.gl/wPFtf). - """ - # Sanity check. - assert field in self.fields - # Standard deviation is undefined for no event and 0 for one event - if len(self.ticks) <= 1: - return None - - # Calculate and return the standard deviation. - else: - _mean = self.sum[field] / len(self.ticks) - _var = (self.sum_sqr[field] - - self.sum[field] * _mean) / (len(self.ticks) - 1) - return sqrt(_var) + self.sum -= event.price + self.sum_sqr -= event.price ** 2 def get_stddev(self): - """ - Returns stddev with 'price' for existing algorithms. - - Could possibly use existing algos. - """ # Sample standard deviation is undefined for a single event or # no events. - if len(self.ticks) <= 1: + if len(self) <= 1: return None else: - average = self.sum['price'] / len(self.ticks) - s_squared = (self.sum_sqr['price'] - self.sum['price'] * average) \ - / (len(self.ticks) - 1) + average = self.sum / len(self) + s_squared = (self.sum_sqr - self.sum * average) \ + / (len(self) - 1) stddev = sqrt(s_squared) return stddev - - def get_stddevs(self): - """ - Return an ndict of all our tracked standard deviations. 
- """ - out = ndict() - for field in self.fields: - out[field] = self.stdev(field) - return out - - def assert_required_fields(self, event): - """ - We only allow events with all of our tracked fields. - """ - for field in self.fields: - assert field in event, \ - "Event missing [%s] in MovingStandardDevEventWindow" % field - assert isinstance(event[field], Number), \ - "Got %s for %s in MovingStandardDevEventWindow" \ - % (event[field], field)