From d14739798392af9a87890f8342e3524a295bff8f Mon Sep 17 00:00:00 2001 From: Eddie Hebert Date: Mon, 28 Nov 2016 13:46:08 -0500 Subject: [PATCH] MAINT: Use a doubly linked list for contract chain. Instead of requiring the roll finder to juggle the indices into the ordered contracts, use a doubly linked list where the nodes element is the contract with members pointing to the previous and next contracts in the chain. Besides improving legibility in the roll finder code, this change is on the path to adding a predicate to exclude contracts from the chain, e.g. contracts in ED which are not in the roll schedule. Change test results for primary chain, since new implementaton does not stop at contract in which has not yet started when constructing the chain. --- tests/test_continuous_futures.py | 82 +++++++++------ zipline/assets/assets.py | 48 +++------ zipline/assets/continuous_futures.pyx | 143 +++++++++++++++++--------- zipline/assets/roll_finder.py | 43 +++----- 4 files changed, 178 insertions(+), 138 deletions(-) diff --git a/tests/test_continuous_futures.py b/tests/test_continuous_futures.py index fe3aaf8c..a90688d9 100644 --- a/tests/test_continuous_futures.py +++ b/tests/test_continuous_futures.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +from collections import deque from textwrap import dedent from numpy import ( @@ -31,6 +31,7 @@ from zipline import TradingAlgorithm from zipline.assets.continuous_futures import OrderedContracts from zipline.data.minute_bars import FUTURES_MINUTES_PER_DAY from zipline.testing.fixtures import ( + WithAssetFinder, WithCreateBarData, WithDataPortal, WithBcolzFutureMinuteBarReader, @@ -415,15 +416,15 @@ def record_current_contract(algo, data): result = results.iloc[0] self.assertEqual(result.primary_len, - 5, - 'There should be only 5 contracts in the chain for ' - 'the primary, there are 6 contracts defined in the ' + 6, + 'There should be only 6 contracts in the chain for ' + 'the primary, there are 7 contracts defined in the ' 'fixture, but one has a start after the simulation ' 'date.') self.assertEqual(result.secondary_len, - 4, - 'There should be only 4 contracts in the chain for ' - 'the primary, there are 6 contracts defined in the ' + 5, + 'There should be only 5 contracts in the chain for ' + 'the primary, there are 7 contracts defined in the ' 'fixture, but one has a start after the simulation ' 'date. And the first is not included because it is ' 'the primary on that date.') @@ -438,11 +439,11 @@ def record_current_contract(algo, data): 'session.') self.assertEqual(result.primary_last, - 'FOK16', + 'FOG22', 'End of primary chain should be FOK16 on first ' 'session.') self.assertEqual(result.secondary_last, - 'FOK16', + 'FOG22', 'End of secondary chain should be FOK16 on first ' 'session.') @@ -450,15 +451,15 @@ def record_current_contract(algo, data): result = results.iloc[1] self.assertEqual(result.primary_len, - 4, - 'There should be only 4 contracts in the chain for ' - 'the primary, there are 6 contracts defined in the ' + 5, + 'There should be only 5 contracts in the chain for ' + 'the primary, there are 7 contracts defined in the ' 'fixture, but one has a start after the simulation ' 'date. The first is not included because of roll.') self.assertEqual(result.secondary_len, - 3, - 'There should be only 3 contracts in the chain for ' - 'the primary, there are 6 contracts defined in the ' + 4, + 'There should be only 4 contracts in the chain for ' + 'the primary, there are 7 contracts defined in the ' 'fixture, but one has a start after the simulation ' 'date. The first is not included because of roll, ' 'the second is the primary on that date.') @@ -475,11 +476,11 @@ def record_current_contract(algo, data): # These values remain FOJ16 because fixture data is not exhaustive # enough to move the end of the chain. self.assertEqual(result.primary_last, - 'FOK16', + 'FOG22', 'End of primary chain should be FOK16 on second ' 'session.') self.assertEqual(result.secondary_last, - 'FOK16', + 'FOG22', 'End of secondary chain should be FOK16 on second ' 'session.') @@ -968,17 +969,43 @@ def record_current_contract(algo, data): "Should remain FOH16 on next session.") -class OrderedContractsTestCase(ZiplineTestCase): +class OrderedContractsTestCase(WithAssetFinder, + ZiplineTestCase): + + @classmethod + def make_root_symbols_info(self): + return pd.DataFrame({ + 'root_symbol': ['FO'], + 'root_symbol_id': [1], + 'exchange': ['CME']}) + + @classmethod + def make_futures_info(self): + return DataFrame({ + 'root_symbol': ['FO'] * 4, + 'asset_name': ['Foo'] * 4, + 'sid': range(1, 5), + 'start_date': pd.date_range('2015-01-01', periods=4, tz="UTC"), + 'end_date': pd.date_range('2016-01-01', periods=4, tz="UTC"), + 'notice_date': pd.date_range('2016-01-01', periods=4, tz="UTC"), + 'expiration_date': pd.date_range( + '2016-01-01', periods=4, tz="UTC"), + 'expiration_date': pd.date_range( + '2016-01-01', periods=4, tz="UTC"), + 'auto_close_date': pd.date_range( + '2016-01-01', periods=4, tz="UTC"), + 'tick_size': [0.001] * 4, + 'multiplier': [1000.0] * 4, + 'exchange': ['CME'] * 4, + }) def test_contract_at_offset(self): contract_sids = array([1, 2, 3, 4], dtype=int64) start_dates = pd.date_range('2015-01-01', periods=4, tz="UTC") - auto_close_dates = pd.date_range('2016-04-01', periods=4, tz="UTC") - oc = OrderedContracts('FO', - contract_sids, - start_dates.astype('int64'), - auto_close_dates.astype('int64')) + contracts = deque(self.asset_finder.retrieve_all(contract_sids)) + + oc = OrderedContracts('FO', contracts) self.assertEquals(1, oc.contract_at_offset(1, 0, start_dates[-1].value), @@ -994,13 +1021,10 @@ class OrderedContractsTestCase(ZiplineTestCase): def test_active_chain(self): contract_sids = array([1, 2, 3, 4], dtype=int64) - start_dates = pd.date_range('2015-01-01', periods=4, tz="UTC") - auto_close_dates = pd.date_range('2016-04-01', periods=4, tz="UTC") - oc = OrderedContracts('FO', - contract_sids, - start_dates.astype('int64'), - auto_close_dates.astype('int64')) + contracts = deque(self.asset_finder.retrieve_all(contract_sids)) + + oc = OrderedContracts('FO', contracts) # Test sid 1 as days increment, as the sessions march forward # a contract should be added per day, until all defined contracts diff --git a/zipline/assets/assets.py b/zipline/assets/assets.py index 95ab7562..46e3c28d 100644 --- a/zipline/assets/assets.py +++ b/zipline/assets/assets.py @@ -15,7 +15,7 @@ from abc import ABCMeta import array import binascii -from collections import namedtuple +from collections import deque, namedtuple from numbers import Integral from operator import itemgetter, attrgetter import struct @@ -880,15 +880,14 @@ class AssetFinder(object): return sids - def _get_contract_info(self, root_symbol): + def _get_contract_sids(self, root_symbol): fc_cols = self.futures_contracts.c - fields = (fc_cols.sid, fc_cols.start_date, fc_cols.auto_close_date) - - return list(sa.select(fields).where( - (fc_cols.root_symbol == root_symbol) & - (fc_cols.start_date != pd.NaT.value)) - .order_by(fc_cols.auto_close_date).execute().fetchall()) + return [r.sid for r in + list(sa.select((fc_cols.sid,)).where( + (fc_cols.root_symbol == root_symbol) & + (fc_cols.start_date != pd.NaT.value)).order_by( + fc_cols.sid).execute().fetchall())] def _get_root_symbol_exchange(self, root_symbol): fc_cols = self.futures_root_symbols.c @@ -902,29 +901,14 @@ class AssetFinder(object): try: return self._ordered_contracts[root_symbol] except KeyError: - contract_info = self._get_contract_info(root_symbol) - size = len(contract_info) - sids = np.full(size, 0, dtype=np.int64) - start_dates = np.full(size, 0, dtype=np.int64) - auto_close_dates = np.full(size, 0, dtype=np.int64) - self._size = size - for i, info in enumerate(contract_info): - sid, start_date, auto_close_date = info - sids[i] = sid - start_dates[i] = start_date - auto_close_dates[i] = auto_close_date - oc = OrderedContracts(root_symbol, - sids, - start_dates, - auto_close_dates) + contract_sids = self._get_contract_sids(root_symbol) + contracts = deque(self.retrieve_all(contract_sids)) + oc = OrderedContracts(root_symbol, contracts) self._ordered_contracts[root_symbol] = oc return oc def create_continuous_future(self, root_symbol, offset, roll_style): oc = self.get_ordered_contracts(root_symbol) - contracts = self.retrieve_all(oc.contract_sids) - start_date = min(c.start_date for c in contracts) - end_date = max(c.end_date for c in contracts) exchange = self._get_root_symbol_exchange(root_symbol) sid = _encode_continuous_future_sid(root_symbol, offset, @@ -940,24 +924,24 @@ class AssetFinder(object): root_symbol, offset, roll_style, - start_date, - end_date, + oc.start_date, + oc.end_date, exchange, 'mul') add_cf = ContinuousFuture(add_sid, root_symbol, offset, roll_style, - start_date, - end_date, + oc.start_date, + oc.end_date, exchange, 'add') cf = ContinuousFuture(sid, root_symbol, offset, roll_style, - start_date, - end_date, + oc.start_date, + oc.end_date, exchange, adjustment_children={ 'mul': mul_cf, diff --git a/zipline/assets/continuous_futures.pyx b/zipline/assets/continuous_futures.pyx index f8eeae34..3c1e1721 100644 --- a/zipline/assets/continuous_futures.pyx +++ b/zipline/assets/continuous_futures.pyx @@ -29,8 +29,9 @@ from cpython.object cimport ( ) from cpython cimport bool -from numpy import empty +from numpy import array, empty, iinfo from numpy cimport long_t, int64_t +from pandas import Timestamp import warnings from zipline.utils.calendars import get_calendar @@ -241,6 +242,33 @@ cdef class ContinuousFuture: except KeyError: return None +cdef class ContractNode(object): + + cdef readonly object contract + cdef public object prev + cdef public object next + + def __init__(self, contract): + self.contract = contract + self.prev = None + self.next = None + + def __rshift__(self, offset): + i = 0 + curr = self + while i < offset and curr is not None: + curr = curr.next + i += 1 + return curr + + def __lshift__(self, offset): + i = 0 + curr = self + while i < offset and curr is not None: + curr = curr.prev + i += 1 + return curr + cdef class OrderedContracts(object): """ @@ -249,16 +277,12 @@ cdef class OrderedContracts(object): Used to get answers about contracts in relation to their auto close dates and start dates. - The number of contracts for a given root symbol is ~250, - which is why search by comparison over the range of contracts is - used. At this size, this is faster than using an index or np.searchsorted. - Members ------- root_symbol : str The root symbol of the future contract chain. - contract_sids : long[:] - The contract sids in sorted order of occurrence. + contracts : deque + The contracts in the chain in order of occurrence. start_dates : long[:] The start dates of the contracts in the chain. Corresponds by index with contract_sids. @@ -271,68 +295,85 @@ cdef class OrderedContracts(object): """ cdef readonly object root_symbol - cdef int _size - cdef readonly long_t[:] contract_sids - cdef readonly long_t[:] start_dates - cdef readonly long_t[:] auto_close_dates + cdef readonly object head_contract + cdef readonly dict sid_to_contract + cdef readonly int64_t _start_date + cdef readonly int64_t _end_date + + def __init__(self, object root_symbol, object contracts): - def __init__(self, - object root_symbol, - long_t[:] contract_sids, - long_t[:] start_dates, - long_t[:] auto_close_dates): - self._size = len(contract_sids) self.root_symbol = root_symbol - self.contract_sids = contract_sids - self.start_dates = start_dates - self.auto_close_dates = auto_close_dates + + self.sid_to_contract = {} + + self._start_date = iinfo('int64').max + self._end_date = 0 + + contract = contracts.popleft() + self.head_contract = ContractNode(contract) + self._start_date = min(contract.start_date.value, self._start_date) + self._end_date = max(contract.end_date.value, self._end_date) + self.sid_to_contract[contract.sid] = self.head_contract + prev = self.head_contract + while contracts: + contract = contracts.popleft() + + # Here is where a predicate would go to ensure continuity of the chain. + + self._start_date = min(contract.start_date.value, self._start_date) + self._end_date = max(contract.end_date.value, self._end_date) + + curr = ContractNode(contract) + curr.prev = prev + prev.next = curr + prev = curr + + self.sid_to_contract[contract.sid] = curr cpdef long_t contract_before_auto_close(self, long_t dt_value): """ Get the contract with next upcoming auto close date. """ - cdef Py_ssize_t i, auto_close_date - for i, auto_close_date in enumerate(self.auto_close_dates): - if auto_close_date > dt_value: + curr = self.head_contract + while curr.next is not None: + if curr.contract.auto_close_date.value > dt_value: break - return self.contract_sids[i] + curr = curr.next + return curr.contract.sid cpdef contract_at_offset(self, long_t sid, Py_ssize_t offset, int64_t start_cap): """ Get the sid which is the given sid plus the offset distance. An offset of 0 should be reflexive. """ - cdef Py_ssize_t i, j - cdef long_t[:] sids - sids = self.contract_sids - start_dates = self.start_dates + cdef Py_ssize_t i + curr = self.sid_to_contract[sid] i = 0 - j = i + offset - while j < self._size: - if sid == sids[i]: - if start_dates[j] < start_cap: - return sids[j] - else: - return None + while i < offset: + if curr.next is None: + return None + curr = curr.next i += 1 - j += 1 + if curr.contract.start_date.value <= start_cap: + return curr.contract.sid + else: + return None cpdef long_t[:] active_chain(self, long_t starting_sid, long_t dt_value): - cdef Py_ssize_t left, right, i, j - cdef long_t[:] sids, start_dates - left = 0 - right = self._size - sids = self.contract_sids - start_dates = self.start_dates + curr = self.sid_to_contract[starting_sid] + cdef list contracts = [] - for i in range(self._size): - if starting_sid == sids[i]: - left = i - break + while curr is not None: + if curr.contract.start_date.value <= dt_value: + contracts.append(curr.contract.sid) + curr = curr.next + + return array(contracts, dtype='int64') - for j in range(i, self._size): - if start_dates[j] > dt_value: - right = j - break + property start_date: + def __get__(self): + return Timestamp(self._start_date, tz='UTC') - return sids[left:right] + property end_date: + def __get__(self): + return Timestamp(self._end_date, tz='UTC') diff --git a/zipline/assets/roll_finder.py b/zipline/assets/roll_finder.py index 679ce9c5..f3ff13af 100644 --- a/zipline/assets/roll_finder.py +++ b/zipline/assets/roll_finder.py @@ -15,8 +15,6 @@ from abc import ABCMeta, abstractmethod from six import with_metaclass -from pandas import Timestamp - class RollFinder(with_metaclass(ABCMeta, object)): """ @@ -85,32 +83,30 @@ class RollFinder(with_metaclass(ABCMeta, object)): first = self._active_contract(oc, front, back, end) else: first = front - for i, sid in enumerate(oc.contract_sids): - if sid == first: - break - rolls = [(first + offset, None)] + first_contract = oc.sid_to_contract[first] + rolls = [((first_contract >> offset).contract.sid, None)] tc = self.trading_calendar sessions = tc.sessions_in_range(tc.minute_to_session_label(start), tc.minute_to_session_label(end)) if first == front: - i -= 1 + curr = first_contract << 1 else: - i -= 2 - curr = sessions[-1] - while curr > start and i > -1: - session_loc = sessions.searchsorted(curr) - front = oc.contract_sids[i] - back = oc.contract_sids[i + 1] + curr = first_contract << 2 + sess = sessions[-1] + while sess > start and curr is not None: + session_loc = sessions.searchsorted(sess) + front = curr.contract.sid + back = curr.next.contract.sid while session_loc > 0: session = sessions[session_loc] prev = sessions[session_loc - 1] if back != self._active_contract(oc, front, back, prev): - rolls.insert(0, (oc.contract_sids[i + offset], session)) + rolls.insert(0, ((curr >> offset).contract.sid, session)) break session_loc -= 1 - i -= 1 - curr = Timestamp(oc.auto_close_dates[i], - tz='UTC') + curr = curr.prev + if curr is not None: + sess = curr.contract.auto_close_date return rolls @@ -125,10 +121,8 @@ class CalendarRollFinder(RollFinder): self.asset_finder = asset_finder def _active_contract(self, oc, front, back, dt): - for i, sid in enumerate(oc.contract_sids): - if sid == front: - break - auto_close_date = Timestamp(oc.auto_close_dates[i], tz='UTC') + contract = oc.sid_to_contract[front].contract + auto_close_date = contract.auto_close_date auto_closed = dt >= auto_close_date return back if auto_closed else front @@ -153,9 +147,6 @@ class VolumeRollFinder(RollFinder): if back_vol > front_vol: return back else: - for i, sid in enumerate(oc.contract_sids): - if sid == front: - break - auto_close_date = Timestamp(oc.auto_close_dates[i], tz='UTC') - auto_closed = dt >= auto_close_date + contract = oc.sid_to_contract[front].contract + auto_closed = dt >= contract.auto_close_date return back if auto_closed else front