From ef323a31654cd0cb99ea479d2ddb5697673edccb Mon Sep 17 00:00:00 2001 From: warren-oneill Date: Thu, 26 Nov 2015 16:46:28 +0100 Subject: [PATCH] ENH: adds lookup_expired_futures to asset_finder --- tests/test_algorithm.py | 27 ++++++++++++++------------- tests/test_events_through_risk.py | 7 ++++--- tests/test_sources.py | 7 +------ zipline/assets/assets.py | 27 ++++++++++++++++++++++++--- zipline/gens/tradesimulation.py | 15 +++++++-------- zipline/test_algorithms.py | 2 +- 6 files changed, 51 insertions(+), 34 deletions(-) diff --git a/tests/test_algorithm.py b/tests/test_algorithm.py index ca587370..a4d2111e 100644 --- a/tests/test_algorithm.py +++ b/tests/test_algorithm.py @@ -1973,22 +1973,24 @@ class TestRemoveData(TestCase): tests if futures data is removed after expiry """ def setUp(self): - dt = pd.Timestamp('2015-01-01', tz='UTC') - metadata = {0: {'symbol': 'X', - 'expiration_date': dt + timedelta(days=5), - 'end_date': dt + timedelta(days=5)}, - 1: {'symbol': 'Y', - 'expiration_date': dt + timedelta(days=7), - 'end_date': dt + timedelta(days=7)}} - + dt = pd.Timestamp('2015-01-02', tz='UTC') env = TradingEnvironment() + ix = env.trading_days.get_loc(dt) + + metadata = {0: {'symbol': 'X', + 'expiration_date': env.trading_days[ix + 5], + 'end_date': env.trading_days[ix + 6]}, + 1: {'symbol': 'Y', + 'expiration_date': env.trading_days[ix + 7], + 'end_date': env.trading_days[ix + 8]}} + env.write_data(futures_data=metadata) - index_x = pd.date_range(dt, periods=5) + index_x = env.trading_days[ix:ix + 5] data_x = pd.DataFrame([[1, 100], [2, 100], [3, 100], [4, 100], [5, 100]], index=index_x, columns=['price', 'volume']) - index_y = index_x.shift(2) + index_y = env.trading_days[ix:ix + 5].shift(2) data_y = pd.DataFrame([[6, 100], [7, 100], [8, 100], [9, 100], [10, 100]], index=index_y, columns=['price', 'volume']) @@ -2000,8 +2002,7 @@ class TestRemoveData(TestCase): def test_remove_data(self): self.algo.run(self.source) - expected_length = [1, 2, 2, 2, 2, 1] + expected_lengths = [1, 1, 2, 2, 2, 2, 1] # initially only data for X should be sent and on the last day only # data for Y should be sent since X is expired - for i, length in enumerate(self.algo.data): - self.assertEqual(expected_length[i], length, i) + np.testing.assert_array_equal(self.algo.data, expected_lengths) diff --git a/tests/test_events_through_risk.py b/tests/test_events_through_risk.py index b40ea1ca..91649931 100644 --- a/tests/test_events_through_risk.py +++ b/tests/test_events_through_risk.py @@ -15,6 +15,7 @@ import unittest import datetime +import pandas as pd import pytz import numpy as np @@ -77,9 +78,9 @@ class TestEventsThroughRisk(unittest.TestCase): algo = BuyAndHoldAlgorithm(sim_params=sim_params, env=self.env) - first_date = datetime.datetime(2006, 1, 3, tzinfo=pytz.utc) - second_date = datetime.datetime(2006, 1, 4, tzinfo=pytz.utc) - third_date = datetime.datetime(2006, 1, 5, tzinfo=pytz.utc) + first_date = pd.Timestamp('2006-01-03', tz='UTC') + second_date = pd.Timestamp('2006-01-04', tz='UTC') + third_date = pd.Timestamp('2006-01-05', tz='UTC') trade_bar_data = [ Event({ diff --git a/tests/test_sources.py b/tests/test_sources.py index 0ffcf5c0..7c5c099e 100644 --- a/tests/test_sources.py +++ b/tests/test_sources.py @@ -123,12 +123,7 @@ class TestDataFrameSource(TestCase): self.assertEqual(5, event.sid) event = next(source) self.assertEqual(4, event.sid) - try: - x = False - event = next(source) - except StopIteration: - x = True - self.assertTrue(x) + self.assertRaises(StopIteration, next, source) class TestRandomWalkSource(TestCase): diff --git a/zipline/assets/assets.py b/zipline/assets/assets.py index 8d059072..f0958fa4 100644 --- a/zipline/assets/assets.py +++ b/zipline/assets/assets.py @@ -266,9 +266,9 @@ class AssetFinder(object): self.equities.c.share_class_symbol == share_class_symbol, self.equities.c.start_date <= ad_value), - ).order_by( - self.equities.c.end_date.desc(), - ).execute().fetchall() + ).order_by( + self.equities.c.end_date.desc(), + ).execute().fetchall() return candidates def _get_best_candidate(self, candidates): @@ -492,6 +492,26 @@ class AssetFinder(object): return list(map(self._retrieve_futures_contract, sids)) + def lookup_expired_futures(self, start, end): + start = start.value + end = end.value + + fc_cols = self.futures_contracts.c + + nd = sa.func.nullif(fc_cols.notice_date, pd.tslib.iNaT) + ed = sa.func.nullif(fc_cols.expiration_date, pd.tslib.iNaT) + date = sa.func.coalesce(sa.func.min(nd, ed), ed, nd) + + sids = list(map( + itemgetter('sid'), + sa.select((fc_cols.sid,)).where( + (date >= start) & (date < end)).order_by( + sa.func.coalesce(ed, nd).asc() + ).execute().fetchall() + )) + + return sids + @property def sids(self): return tuple(map( @@ -741,6 +761,7 @@ class AssetFinderCachedEquities(AssetFinder): into memory and overrides the methods that lookup_symbol uses to look up those equities. """ + def __init__(self, engine): super(AssetFinderCachedEquities, self).__init__(engine) self.fuzzy_symbol_hashed_equities = {} diff --git a/zipline/gens/tradesimulation.py b/zipline/gens/tradesimulation.py index 5fc4974c..f720333b 100644 --- a/zipline/gens/tradesimulation.py +++ b/zipline/gens/tradesimulation.py @@ -26,7 +26,6 @@ from zipline.protocol import ( SIDData, DATASOURCE_TYPE ) -from zipline.errors import SidNotFound log = Logger('Trade Simulation') @@ -65,6 +64,7 @@ class AlgorithmSimulator(object): # We don't have a datetime for the current snapshot until we # receive a message. self.simulation_dt = None + self.previous_dt = self.algo_start # ============= # Logging Setup @@ -97,18 +97,17 @@ class AlgorithmSimulator(object): self._call_before_trading_start(mkt_open) for date, snapshot in stream_in: - + expired_sids = self.env.asset_finder.lookup_expired_futures( + start=self.previous_dt, end=date) + self.previous_dt = date self.simulation_dt = date self.on_dt_changed(date) # removing expired futures - for sid in self.current_data.keys(): + for sid in expired_sids: try: - if self.env.asset_finder.retrieve_asset(sid).end_date \ - < self.simulation_dt: - del self.current_data[sid] - except (AttributeError, TypeError, ValueError, - SidNotFound): + del self.current_data[sid] + except KeyError: continue # If we're still in the warmup period. Use the event to diff --git a/zipline/test_algorithms.py b/zipline/test_algorithms.py index 23965449..ffe39226 100644 --- a/zipline/test_algorithms.py +++ b/zipline/test_algorithms.py @@ -939,7 +939,7 @@ class InvalidOrderAlgorithm(TradingAlgorithm): class TestRemoveDataAlgo(TradingAlgorithm): def initialize(self, *args, **kwargs): - self.data = np.zeros(6) + self.data = np.zeros(7) self.i = 0 def handle_data(self, data):