From 5f6839beeaa068024aa4f1e74b33cfd94331171b Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Thu, 6 Dec 2012 12:36:47 -0500 Subject: [PATCH] BUG: Refactored batch_transform unittests and fixed some bugs. --- tests/test_transforms.py | 26 +++++++++----------------- zipline/test_algorithms.py | 4 ++-- zipline/transforms/utils.py | 33 +++++++++++++++++---------------- zipline/utils/factory.py | 4 ++-- 4 files changed, 30 insertions(+), 37 deletions(-) diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 2d729f87..cb620d84 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -340,27 +340,19 @@ class TestBatchTransform(TestCase): ) self.assertTrue(all( - field['arbitrary'].values.flatten() == ['test'] * 8), + field['arbitrary'].values.flatten() == + ['test'] * algo.window_length), 'arbitrary dataframe should contain only "test"' ) # test overloaded class for test_history in [algo.history_return_price_class, algo.history_return_price_decorator]: - np.testing.assert_array_equal( - range(2, 8), - test_history[2].values.flatten() - ) - - np.testing.assert_array_equal( - range(2, 8), - test_history[3].values.flatten() - ) - - np.testing.assert_array_equal( - range(4, 12), - test_history[4].values.flatten() - ) + for i in range(3, 6): + np.testing.assert_array_equal( + range(i - algo.window_length + 1, i + 1), + test_history[i].values.flatten() + ) def test_passing_of_args(self): algo = BatchTransformAlgorithm(1, kwarg='str') @@ -371,8 +363,8 @@ class TestBatchTransform(TestCase): expected_item = ((1, ), {'kwarg': 'str'}) self.assertEqual( algo.history_return_args, - [None, None, expected_item, expected_item, - expected_item, expected_item]) + [None, None, None, expected_item, expected_item, + expected_item]) class TestBatchTransformMarketAware(TestCase): diff --git a/zipline/test_algorithms.py b/zipline/test_algorithms.py index e771ca5e..d3959d12 100644 --- a/zipline/test_algorithms.py +++ b/zipline/test_algorithms.py @@ -214,7 +214,6 @@ class TimeoutAlgorithm(TradingAlgorithm): time.sleep(100) pass -from datetime import timedelta from zipline.algorithm import TradingAlgorithm from zipline.transforms import BatchTransform, batch_transform from zipline.transforms import MovingAverage @@ -237,6 +236,7 @@ class TestRegisterTransformAlgorithm(TradingAlgorithm): class ReturnPriceBatchTransform(BatchTransform): def get_value(self, data): + assert data.shape[1] == self.window_length return data.price @@ -257,7 +257,7 @@ def return_data(data, *args, **kwargs): class BatchTransformAlgorithm(TradingAlgorithm): def initialize(self, *args, **kwargs): - self.refresh_period = kwargs.pop('refresh_period', 2) + self.refresh_period = kwargs.pop('refresh_period', 1) self.window_length = kwargs.pop('window_length', 3) self.args = args diff --git a/zipline/transforms/utils.py b/zipline/transforms/utils.py index 1785400d..9fc3c4e2 100644 --- a/zipline/transforms/utils.py +++ b/zipline/transforms/utils.py @@ -345,7 +345,7 @@ class BatchTransform(EventWindow): self.last_dt = None self.updated = False - self.data = None + self.cached = None self.field_names = None @@ -373,20 +373,22 @@ class BatchTransform(EventWindow): # return newly computed or cached value return self.get_transform_value(*args, **kwargs) - def handle_add(self, event): - if not self.last_dt: - self.last_dt = event.dt - return - + def _extract_field_names(self, event): # extract field names from sids (price, volume etc), make sure # every sid has the same fields. sid_keys = [set(sid.keys()) for sid in event.data.itervalues()] assert sid_keys[0] == set.intersection(*sid_keys),\ "Each sid must have the same keys." - if self.field_names is None: - unwanted_fields = set(['portfolio', 'sid', 'dt', 'type', - 'datetime']) - self.field_names = sid_keys[0] - unwanted_fields + + unwanted_fields = set(['portfolio', 'sid', 'dt', 'type', + 'datetime']) + return sid_keys[0] - unwanted_fields + + def handle_add(self, event): + if not self.last_dt: + self.field_names = self._extract_field_names(event) + self.last_dt = event.dt + return # update trading day counters if self.last_dt.day != event.dt.day: @@ -398,13 +400,11 @@ class BatchTransform(EventWindow): self.trading_days_total >= self.window_length and self.trading_days_since_update >= self.refresh_period ): - - # Create datapanel of running event window. - self.data = self.get_data() # Setting updated to True will cause get_transform_value() # to call the user-defined batch-transform with the most # recent datapanel self.updated = True + self.full = True self.trading_days_since_update = 0 else: self.updated = False @@ -427,7 +427,8 @@ class BatchTransform(EventWindow): fields = {} for field_name in self.field_names: - sids = self.ticks[0].data.keys() + # Extract all used sids + sids = set.union(*[set(tick.data.keys()) for tick in self.ticks]) values_per_sid = {} @@ -471,11 +472,11 @@ class BatchTransform(EventWindow): has actually been updated. Otherwise, the previously, cached value will be returned. """ - if self.data is None: + if not self.full: return None if self.updated: - self.cached = self.compute_transform_value(self.data, + self.cached = self.compute_transform_value(self.get_data(), *args, **kwargs) return self.cached diff --git a/zipline/utils/factory.py b/zipline/utils/factory.py index ac2e94b2..dea8c6d0 100644 --- a/zipline/utils/factory.py +++ b/zipline/utils/factory.py @@ -271,9 +271,9 @@ def create_test_df_source(): start = pd.datetime(1990, 1, 3, 0, 0, 0, 0, pytz.utc) end = pd.datetime(1990, 1, 8, 0, 0, 0, 0, pytz.utc) index = pd.DatetimeIndex(start=start, end=end, freq=pd.datetools.day) - x = np.arange(2., len(index) * 2 + 2).reshape((-1, 2)) + x = np.arange(0, len(index)) - df = pd.DataFrame(x, index=index, columns=[0, 1]) + df = pd.DataFrame(x, index=index, columns=[0]) return DataFrameSource(df), df