From 987d6d4e48cac5a62ada4e8f3d038fc37bfc3f3c Mon Sep 17 00:00:00 2001 From: warren-oneill Date: Mon, 16 Nov 2015 16:38:11 +0100 Subject: [PATCH] TST: tests removing of expired data and removes ffill in DataPanelSource --- tests/test_algorithm.py | 40 ++++++++++++++++++++++++++++ tests/test_sources.py | 9 ++++--- zipline/gens/tradesimulation.py | 11 ++++++++ zipline/sources/data_frame_source.py | 4 +-- zipline/test_algorithms.py | 10 +++++++ 5 files changed, 68 insertions(+), 6 deletions(-) diff --git a/tests/test_algorithm.py b/tests/test_algorithm.py index 074d947f..ca587370 100644 --- a/tests/test_algorithm.py +++ b/tests/test_algorithm.py @@ -59,6 +59,7 @@ from zipline.test_algorithms import ( TestTargetAlgorithm, TestTargetPercentAlgorithm, TestTargetValueAlgorithm, + TestRemoveDataAlgo, SetLongOnlyAlgorithm, SetAssetDateBoundsAlgorithm, SetMaxPositionSizeAlgorithm, @@ -1965,3 +1966,42 @@ class TestTradingAlgorithm(TestCase): analyze=analyze) results = algo.run(self.panel) self.assertIs(results, self.perf_ref) + + +class TestRemoveData(TestCase): + """ + tests if futures data is removed after expiry + """ + def setUp(self): + dt = pd.Timestamp('2015-01-01', tz='UTC') + metadata = {0: {'symbol': 'X', + 'expiration_date': dt + timedelta(days=5), + 'end_date': dt + timedelta(days=5)}, + 1: {'symbol': 'Y', + 'expiration_date': dt + timedelta(days=7), + 'end_date': dt + timedelta(days=7)}} + + env = TradingEnvironment() + env.write_data(futures_data=metadata) + + index_x = pd.date_range(dt, periods=5) + data_x = pd.DataFrame([[1, 100], [2, 100], [3, 100], [4, 100], + [5, 100]], + index=index_x, columns=['price', 'volume']) + index_y = index_x.shift(2) + data_y = pd.DataFrame([[6, 100], [7, 100], [8, 100], [9, 100], + [10, 100]], + index=index_y, columns=['price', 'volume']) + + pan = pd.Panel({0: data_x, 1: data_y}) + self.source = DataPanelSource(pan) + self.algo = TestRemoveDataAlgo(env=env) + + def test_remove_data(self): + self.algo.run(self.source) + + expected_length = [1, 2, 2, 2, 2, 1] + # initially only data for X should be sent and on the last day only + # data for Y should be sent since X is expired + for i, length in enumerate(self.algo.data): + self.assertEqual(expected_length[i], length, i) diff --git a/tests/test_sources.py b/tests/test_sources.py index ff82e2fe..0ffcf5c0 100644 --- a/tests/test_sources.py +++ b/tests/test_sources.py @@ -123,9 +123,12 @@ class TestDataFrameSource(TestCase): self.assertEqual(5, event.sid) event = next(source) self.assertEqual(4, event.sid) - event = next(source) - self.assertEqual(5, event.sid) - self.assertFalse(np.isnan(event.price)) + try: + x = False + event = next(source) + except StopIteration: + x = True + self.assertTrue(x) class TestRandomWalkSource(TestCase): diff --git a/zipline/gens/tradesimulation.py b/zipline/gens/tradesimulation.py index e63ca853..5fc4974c 100644 --- a/zipline/gens/tradesimulation.py +++ b/zipline/gens/tradesimulation.py @@ -26,6 +26,7 @@ from zipline.protocol import ( SIDData, DATASOURCE_TYPE ) +from zipline.errors import SidNotFound log = Logger('Trade Simulation') @@ -100,6 +101,16 @@ class AlgorithmSimulator(object): self.simulation_dt = date self.on_dt_changed(date) + # removing expired futures + for sid in self.current_data.keys(): + try: + if self.env.asset_finder.retrieve_asset(sid).end_date \ + < self.simulation_dt: + del self.current_data[sid] + except (AttributeError, TypeError, ValueError, + SidNotFound): + continue + # If we're still in the warmup period. Use the event to # update our universe, but don't yield any perf messages, # and don't send a snapshot to handle_data. diff --git a/zipline/sources/data_frame_source.py b/zipline/sources/data_frame_source.py index 50c0b9b7..eab04a61 100644 --- a/zipline/sources/data_frame_source.py +++ b/zipline/sources/data_frame_source.py @@ -114,7 +114,6 @@ class DataPanelSource(DataSource): # TODO is ffilling correct/necessary? # forward fill with volumes of 0 self.data = data.fillna(value={'volume': 0}) - self.data = self.data.fillna(method='ffill') # Unpack config dictionary with default values. self.start = kwargs.get('start', self.data.major_axis[0]) self.end = kwargs.get('end', self.data.major_axis[-1]) @@ -153,8 +152,7 @@ class DataPanelSource(DataSource): df = self.data.major_xs(dt) for sid, series in df.iteritems(): # Skip SIDs that can not be forward filled - if np.isnan(series['price']) and \ - sid not in self.started_sids: + if np.isnan(series['price']): continue self.started_sids.add(sid) diff --git a/zipline/test_algorithms.py b/zipline/test_algorithms.py index a56d001f..23965449 100644 --- a/zipline/test_algorithms.py +++ b/zipline/test_algorithms.py @@ -937,6 +937,16 @@ class InvalidOrderAlgorithm(TradingAlgorithm): style=style) +class TestRemoveDataAlgo(TradingAlgorithm): + def initialize(self, *args, **kwargs): + self.data = np.zeros(6) + self.i = 0 + + def handle_data(self, data): + self.data[self.i] = len(data) + self.i += 1 + + ############################## # Quantopian style algorithms