From 126e9fdf2604009a2b962af8a79d838e93ec5935 Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Tue, 6 Nov 2012 12:04:01 -0500 Subject: [PATCH] Fixed batch_transform window length not being market aware. Added accompanying unittest. Minor refactoring of unittests and factory. --- tests/test_transforms.py | 35 +++++++++++++++++++++++++++++++---- zipline/examples/pairtrade.py | 2 +- zipline/test_algorithms.py | 26 ++++++++++++++++++-------- zipline/transforms/utils.py | 16 ++++++++++------ zipline/utils/factory.py | 3 ++- 5 files changed, 62 insertions(+), 20 deletions(-) diff --git a/tests/test_transforms.py b/tests/test_transforms.py index e77ad66f..3ad29435 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -15,6 +15,7 @@ import pytz import numpy as np +import pandas as pd from datetime import timedelta, datetime from unittest2 import TestCase @@ -57,8 +58,7 @@ class NoopEventWindow(EventWindow): self.removed.append(event) -class EventWindowTestCase(TestCase): - +class TestEventWindow(TestCase): def setUp(self): setup_logger(self) @@ -154,7 +154,7 @@ class EventWindowTestCase(TestCase): setup_logger(self) -class FinanceTransformsTestCase(TestCase): +class TestFinanceTransforms(TestCase): def setUp(self): self.trading_environment = factory.create_trading_environment() @@ -309,7 +309,7 @@ class FinanceTransformsTestCase(TestCase): ############################################################ # Test BatchTransform -class BatchTransformTestCase(TestCase): +class TestBatchTransform(TestCase): def setUp(self): setup_logger(self) self.source, self.df = factory.create_test_df_source() @@ -324,6 +324,9 @@ class BatchTransformTestCase(TestCase): self.assertEqual(algo.history_return_price_decorator[:2], [None, None], "First two iterations should return None") + self.assertEqual(algo.history_return_price_market_aware[:2], + [None, None], + "First two iterations should return None") # test overloaded class for test_history in 
[algo.history_return_price_class, @@ -354,3 +357,27 @@ class BatchTransformTestCase(TestCase): algo.history_return_args, [None, None, expected_item, expected_item, expected_item, expected_item]) + + +class TestBatchTransformMarketAware(TestCase): + def setUp(self): + setup_logger(self) + start = pd.datetime(1993, 1, 1, 0, 0, 0, 0, pytz.utc) + end = pd.datetime(1994, 1, 1, 0, 0, 0, 0, pytz.utc) + + self.data = factory.load_from_yahoo(stocks=['AAPL'], + indexes={}, + start=start, end=end) + + def test_event_window(self): + days = 50 + algo = BatchTransformAlgorithm(days=days, refresh_period=days) + algo.run(self.data) + + self.assertEqual(algo.history_return_price_market_aware[:days], + [None] * days, + "First {days} iterations should return None" + .format(days=days)) + self.assertFalse(algo.history_return_price_market_aware[days + 1] + is None, + "Window contains too many Nones.") diff --git a/zipline/examples/pairtrade.py b/zipline/examples/pairtrade.py index 990b0cf2..cb37a605 100755 --- a/zipline/examples/pairtrade.py +++ b/zipline/examples/pairtrade.py @@ -117,7 +117,7 @@ if __name__ == '__main__': pairtrade = Pairtrade() results = pairtrade.run(data) data['spreads'] = np.nan - data.spreads[70:] = pairtrade.spreads + data.spreads[pairtrade.window_length:] = pairtrade.spreads ax1 = plt.subplot(211) data[['PEP', 'KO']].plot(ax=ax1) diff --git a/zipline/test_algorithms.py b/zipline/test_algorithms.py index f6fcaee9..98d6b497 100644 --- a/zipline/test_algorithms.py +++ b/zipline/test_algorithms.py @@ -246,33 +246,41 @@ def return_args_batch_decorator(data, *args, **kwargs): class BatchTransformAlgorithm(TradingAlgorithm): def initialize(self, *args, **kwargs): - self.history_return_price_class = [] - self.history_return_price_decorator = [] - self.history_return_args = [] - - self.days = 3 + self.refresh_period = kwargs.pop('refresh_period', 2) + self.days = kwargs.pop('days', 3) self.args = args self.kwargs = kwargs + self.history_return_price_class = [] + 
self.history_return_price_decorator = [] + self.history_return_args = [] + self.history_return_price_market_aware = [] + self.return_price_class = ReturnPriceBatchTransform( market_aware=False, - refresh_period=2, + refresh_period=self.refresh_period, delta=timedelta(days=self.days) ) self.return_price_decorator = return_price_batch_decorator( market_aware=False, - refresh_period=2, + refresh_period=self.refresh_period, delta=timedelta(days=self.days) ) self.return_args_batch = return_args_batch_decorator( market_aware=False, - refresh_period=2, + refresh_period=self.refresh_period, delta=timedelta(days=self.days) ) + self.return_price_market_aware = ReturnPriceBatchTransform( + market_aware=True, + refresh_period=self.refresh_period, + days=self.days + ) + self.set_slippage(FixedSlippage()) def handle_data(self, data): @@ -283,6 +291,8 @@ class BatchTransformAlgorithm(TradingAlgorithm): self.history_return_args.append( self.return_args_batch.handle_data( data, *self.args, **self.kwargs)) + self.history_return_price_market_aware.append( + self.return_price_market_aware.handle_data(data)) class SetPortfolioAlgorithm(TradingAlgorithm): diff --git a/zipline/transforms/utils.py b/zipline/transforms/utils.py index 43a38745..242a2b18 100644 --- a/zipline/transforms/utils.py +++ b/zipline/transforms/utils.py @@ -360,9 +360,10 @@ class BatchTransform(EventWindow): self.refresh_period = refresh_period self.days = days + self.trading_days_since_update = 0 self.full = False - self.last_refresh = None + self.last_dt = None self.updated = False self.data = None @@ -392,12 +393,15 @@ class BatchTransform(EventWindow): return self.get_transform_value(*args, **kwargs) def handle_add(self, event): - if not self.last_refresh: - self.last_refresh = event.dt + if not self.last_dt: + self.last_dt = event.dt return - age = event.dt - self.last_refresh - if age.days >= self.refresh_period: + if self.last_dt.day != event.dt.day: + self.last_dt = event.dt + self.trading_days_since_update 
+= 1 + + if self.trading_days_since_update >= self.refresh_period: # Create a pandas.Panel (i.e. 3d DataFrame) from the # events in the current window. # @@ -432,7 +436,7 @@ class BatchTransform(EventWindow): self.data = pd.Panel.from_dict(fields, orient='items') self.updated = True - self.last_refresh = event.dt + self.trading_days_since_update = 0 else: self.updated = False diff --git a/zipline/utils/factory.py b/zipline/utils/factory.py index 86a19183..ebaa4898 100644 --- a/zipline/utils/factory.py +++ b/zipline/utils/factory.py @@ -274,7 +274,8 @@ def create_test_df_source(): start = pd.datetime(1990, 1, 3, 0, 0, 0, 0, pytz.utc) end = pd.datetime(1990, 1, 8, 0, 0, 0, 0, pytz.utc) index = pd.DatetimeIndex(start=start, end=end, freq=pd.datetools.day) - x = np.arange(2., 14.).reshape((6, 2)) + x = np.arange(2., len(index) * 2 + 2).reshape((-1, 2)) + df = pd.DataFrame(x, index=index, columns=[0, 1]) return DataFrameSource(df), df