diff --git a/tests/test_continuous_futures.py b/tests/test_continuous_futures.py index fe3aaf8c..a90688d9 100644 --- a/tests/test_continuous_futures.py +++ b/tests/test_continuous_futures.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +from collections import deque from textwrap import dedent from numpy import ( @@ -31,6 +31,7 @@ from zipline import TradingAlgorithm from zipline.assets.continuous_futures import OrderedContracts from zipline.data.minute_bars import FUTURES_MINUTES_PER_DAY from zipline.testing.fixtures import ( + WithAssetFinder, WithCreateBarData, WithDataPortal, WithBcolzFutureMinuteBarReader, @@ -415,15 +416,15 @@ def record_current_contract(algo, data): result = results.iloc[0] self.assertEqual(result.primary_len, - 5, - 'There should be only 5 contracts in the chain for ' - 'the primary, there are 6 contracts defined in the ' + 6, + 'There should be only 6 contracts in the chain for ' + 'the primary, there are 7 contracts defined in the ' 'fixture, but one has a start after the simulation ' 'date.') self.assertEqual(result.secondary_len, - 4, - 'There should be only 4 contracts in the chain for ' - 'the primary, there are 6 contracts defined in the ' + 5, + 'There should be only 5 contracts in the chain for ' + 'the primary, there are 7 contracts defined in the ' 'fixture, but one has a start after the simulation ' 'date. And the first is not included because it is ' 'the primary on that date.') @@ -438,11 +439,11 @@ def record_current_contract(algo, data): 'session.') self.assertEqual(result.primary_last, - 'FOK16', + 'FOG22', 'End of primary chain should be FOK16 on first ' 'session.') self.assertEqual(result.secondary_last, - 'FOK16', + 'FOG22', 'End of secondary chain should be FOK16 on first ' 'session.') @@ -450,15 +451,15 @@ def record_current_contract(algo, data): result = results.iloc[1] self.assertEqual(result.primary_len, - 4, - 'There should be only 4 contracts in the chain for ' - 'the primary, there are 6 contracts defined in the ' + 5, + 'There should be only 5 contracts in the chain for ' + 'the primary, there are 7 contracts defined in the ' 'fixture, but one has a start after the simulation ' 'date. The first is not included because of roll.') self.assertEqual(result.secondary_len, - 3, - 'There should be only 3 contracts in the chain for ' - 'the primary, there are 6 contracts defined in the ' + 4, + 'There should be only 4 contracts in the chain for ' + 'the primary, there are 7 contracts defined in the ' 'fixture, but one has a start after the simulation ' 'date. The first is not included because of roll, ' 'the second is the primary on that date.') @@ -475,11 +476,11 @@ def record_current_contract(algo, data): # These values remain FOJ16 because fixture data is not exhaustive # enough to move the end of the chain. self.assertEqual(result.primary_last, - 'FOK16', + 'FOG22', 'End of primary chain should be FOK16 on second ' 'session.') self.assertEqual(result.secondary_last, - 'FOK16', + 'FOG22', 'End of secondary chain should be FOK16 on second ' 'session.') @@ -968,17 +969,43 @@ def record_current_contract(algo, data): "Should remain FOH16 on next session.") -class OrderedContractsTestCase(ZiplineTestCase): +class OrderedContractsTestCase(WithAssetFinder, + ZiplineTestCase): + + @classmethod + def make_root_symbols_info(self): + return pd.DataFrame({ + 'root_symbol': ['FO'], + 'root_symbol_id': [1], + 'exchange': ['CME']}) + + @classmethod + def make_futures_info(self): + return DataFrame({ + 'root_symbol': ['FO'] * 4, + 'asset_name': ['Foo'] * 4, + 'sid': range(1, 5), + 'start_date': pd.date_range('2015-01-01', periods=4, tz="UTC"), + 'end_date': pd.date_range('2016-01-01', periods=4, tz="UTC"), + 'notice_date': pd.date_range('2016-01-01', periods=4, tz="UTC"), + 'expiration_date': pd.date_range( + '2016-01-01', periods=4, tz="UTC"), + 'expiration_date': pd.date_range( + '2016-01-01', periods=4, tz="UTC"), + 'auto_close_date': pd.date_range( + '2016-01-01', periods=4, tz="UTC"), + 'tick_size': [0.001] * 4, + 'multiplier': [1000.0] * 4, + 'exchange': ['CME'] * 4, + }) def test_contract_at_offset(self): contract_sids = array([1, 2, 3, 4], dtype=int64) start_dates = pd.date_range('2015-01-01', periods=4, tz="UTC") - auto_close_dates = pd.date_range('2016-04-01', periods=4, tz="UTC") - oc = OrderedContracts('FO', - contract_sids, - start_dates.astype('int64'), - auto_close_dates.astype('int64')) + contracts = deque(self.asset_finder.retrieve_all(contract_sids)) + + oc = OrderedContracts('FO', contracts) self.assertEquals(1, oc.contract_at_offset(1, 0, start_dates[-1].value), @@ -994,13 +1021,10 @@ class OrderedContractsTestCase(ZiplineTestCase): def test_active_chain(self): contract_sids = array([1, 2, 3, 4], dtype=int64) - start_dates = pd.date_range('2015-01-01', periods=4, tz="UTC") - auto_close_dates = pd.date_range('2016-04-01', periods=4, tz="UTC") - oc = OrderedContracts('FO', - contract_sids, - start_dates.astype('int64'), - auto_close_dates.astype('int64')) + contracts = deque(self.asset_finder.retrieve_all(contract_sids)) + + oc = OrderedContracts('FO', contracts) # Test sid 1 as days increment, as the sessions march forward # a contract should be added per day, until all defined contracts diff --git a/zipline/assets/assets.py b/zipline/assets/assets.py index 95ab7562..46e3c28d 100644 --- a/zipline/assets/assets.py +++ b/zipline/assets/assets.py @@ -15,7 +15,7 @@ from abc import ABCMeta import array import binascii -from collections import namedtuple +from collections import deque, namedtuple from numbers import Integral from operator import itemgetter, attrgetter import struct @@ -880,15 +880,14 @@ class AssetFinder(object): return sids - def _get_contract_info(self, root_symbol): + def _get_contract_sids(self, root_symbol): fc_cols = self.futures_contracts.c - fields = (fc_cols.sid, fc_cols.start_date, fc_cols.auto_close_date) - - return list(sa.select(fields).where( - (fc_cols.root_symbol == root_symbol) & - (fc_cols.start_date != pd.NaT.value)) - .order_by(fc_cols.auto_close_date).execute().fetchall()) + return [r.sid for r in + list(sa.select((fc_cols.sid,)).where( + (fc_cols.root_symbol == root_symbol) & + (fc_cols.start_date != pd.NaT.value)).order_by( + fc_cols.sid).execute().fetchall())] def _get_root_symbol_exchange(self, root_symbol): fc_cols = self.futures_root_symbols.c @@ -902,29 +901,14 @@ class AssetFinder(object): try: return self._ordered_contracts[root_symbol] except KeyError: - contract_info = self._get_contract_info(root_symbol) - size = len(contract_info) - sids = np.full(size, 0, dtype=np.int64) - start_dates = np.full(size, 0, dtype=np.int64) - auto_close_dates = np.full(size, 0, dtype=np.int64) - self._size = size - for i, info in enumerate(contract_info): - sid, start_date, auto_close_date = info - sids[i] = sid - start_dates[i] = start_date - auto_close_dates[i] = auto_close_date - oc = OrderedContracts(root_symbol, - sids, - start_dates, - auto_close_dates) + contract_sids = self._get_contract_sids(root_symbol) + contracts = deque(self.retrieve_all(contract_sids)) + oc = OrderedContracts(root_symbol, contracts) self._ordered_contracts[root_symbol] = oc return oc def create_continuous_future(self, root_symbol, offset, roll_style): oc = self.get_ordered_contracts(root_symbol) - contracts = self.retrieve_all(oc.contract_sids) - start_date = min(c.start_date for c in contracts) - end_date = max(c.end_date for c in contracts) exchange = self._get_root_symbol_exchange(root_symbol) sid = _encode_continuous_future_sid(root_symbol, offset, @@ -940,24 +924,24 @@ class AssetFinder(object): root_symbol, offset, roll_style, - start_date, - end_date, + oc.start_date, + oc.end_date, exchange, 'mul') add_cf = ContinuousFuture(add_sid, root_symbol, offset, roll_style, - start_date, - end_date, + oc.start_date, + oc.end_date, exchange, 'add') cf = ContinuousFuture(sid, root_symbol, offset, roll_style, - start_date, - end_date, + oc.start_date, + oc.end_date, exchange, adjustment_children={ 'mul': mul_cf, diff --git a/zipline/assets/continuous_futures.pyx b/zipline/assets/continuous_futures.pyx index f8eeae34..3c1e1721 100644 --- a/zipline/assets/continuous_futures.pyx +++ b/zipline/assets/continuous_futures.pyx @@ -29,8 +29,9 @@ from cpython.object cimport ( ) from cpython cimport bool -from numpy import empty +from numpy import array, empty, iinfo from numpy cimport long_t, int64_t +from pandas import Timestamp import warnings from zipline.utils.calendars import get_calendar @@ -241,6 +242,33 @@ cdef class ContinuousFuture: except KeyError: return None +cdef class ContractNode(object): + + cdef readonly object contract + cdef public object prev + cdef public object next + + def __init__(self, contract): + self.contract = contract + self.prev = None + self.next = None + + def __rshift__(self, offset): + i = 0 + curr = self + while i < offset and curr is not None: + curr = curr.next + i += 1 + return curr + + def __lshift__(self, offset): + i = 0 + curr = self + while i < offset and curr is not None: + curr = curr.prev + i += 1 + return curr + cdef class OrderedContracts(object): """ @@ -249,16 +277,12 @@ cdef class OrderedContracts(object): Used to get answers about contracts in relation to their auto close dates and start dates. - The number of contracts for a given root symbol is ~250, - which is why search by comparison over the range of contracts is - used. At this size, this is faster than using an index or np.searchsorted. - Members ------- root_symbol : str The root symbol of the future contract chain. - contract_sids : long[:] - The contract sids in sorted order of occurrence. + contracts : deque + The contracts in the chain in order of occurrence. start_dates : long[:] The start dates of the contracts in the chain. Corresponds by index with contract_sids. @@ -271,68 +295,85 @@ cdef class OrderedContracts(object): """ cdef readonly object root_symbol - cdef int _size - cdef readonly long_t[:] contract_sids - cdef readonly long_t[:] start_dates - cdef readonly long_t[:] auto_close_dates + cdef readonly object head_contract + cdef readonly dict sid_to_contract + cdef readonly int64_t _start_date + cdef readonly int64_t _end_date + + def __init__(self, object root_symbol, object contracts): - def __init__(self, - object root_symbol, - long_t[:] contract_sids, - long_t[:] start_dates, - long_t[:] auto_close_dates): - self._size = len(contract_sids) self.root_symbol = root_symbol - self.contract_sids = contract_sids - self.start_dates = start_dates - self.auto_close_dates = auto_close_dates + + self.sid_to_contract = {} + + self._start_date = iinfo('int64').max + self._end_date = 0 + + contract = contracts.popleft() + self.head_contract = ContractNode(contract) + self._start_date = min(contract.start_date.value, self._start_date) + self._end_date = max(contract.end_date.value, self._end_date) + self.sid_to_contract[contract.sid] = self.head_contract + prev = self.head_contract + while contracts: + contract = contracts.popleft() + + # Here is where a predicate would go to ensure continuity of the chain. + + self._start_date = min(contract.start_date.value, self._start_date) + self._end_date = max(contract.end_date.value, self._end_date) + + curr = ContractNode(contract) + curr.prev = prev + prev.next = curr + prev = curr + + self.sid_to_contract[contract.sid] = curr cpdef long_t contract_before_auto_close(self, long_t dt_value): """ Get the contract with next upcoming auto close date. """ - cdef Py_ssize_t i, auto_close_date - for i, auto_close_date in enumerate(self.auto_close_dates): - if auto_close_date > dt_value: + curr = self.head_contract + while curr.next is not None: + if curr.contract.auto_close_date.value > dt_value: break - return self.contract_sids[i] + curr = curr.next + return curr.contract.sid cpdef contract_at_offset(self, long_t sid, Py_ssize_t offset, int64_t start_cap): """ Get the sid which is the given sid plus the offset distance. An offset of 0 should be reflexive. """ - cdef Py_ssize_t i, j - cdef long_t[:] sids - sids = self.contract_sids - start_dates = self.start_dates + cdef Py_ssize_t i + curr = self.sid_to_contract[sid] i = 0 - j = i + offset - while j < self._size: - if sid == sids[i]: - if start_dates[j] < start_cap: - return sids[j] - else: - return None + while i < offset: + if curr.next is None: + return None + curr = curr.next i += 1 - j += 1 + if curr.contract.start_date.value <= start_cap: + return curr.contract.sid + else: + return None cpdef long_t[:] active_chain(self, long_t starting_sid, long_t dt_value): - cdef Py_ssize_t left, right, i, j - cdef long_t[:] sids, start_dates - left = 0 - right = self._size - sids = self.contract_sids - start_dates = self.start_dates + curr = self.sid_to_contract[starting_sid] + cdef list contracts = [] - for i in range(self._size): - if starting_sid == sids[i]: - left = i - break + while curr is not None: + if curr.contract.start_date.value <= dt_value: + contracts.append(curr.contract.sid) + curr = curr.next + + return array(contracts, dtype='int64') - for j in range(i, self._size): - if start_dates[j] > dt_value: - right = j - break + property start_date: + def __get__(self): + return Timestamp(self._start_date, tz='UTC') - return sids[left:right] + property end_date: + def __get__(self): + return Timestamp(self._end_date, tz='UTC') diff --git a/zipline/assets/roll_finder.py b/zipline/assets/roll_finder.py index 679ce9c5..f3ff13af 100644 --- a/zipline/assets/roll_finder.py +++ b/zipline/assets/roll_finder.py @@ -15,8 +15,6 @@ from abc import ABCMeta, abstractmethod from six import with_metaclass -from pandas import Timestamp - class RollFinder(with_metaclass(ABCMeta, object)): """ @@ -85,32 +83,30 @@ class RollFinder(with_metaclass(ABCMeta, object)): first = self._active_contract(oc, front, back, end) else: first = front - for i, sid in enumerate(oc.contract_sids): - if sid == first: - break - rolls = [(first + offset, None)] + first_contract = oc.sid_to_contract[first] + rolls = [((first_contract >> offset).contract.sid, None)] tc = self.trading_calendar sessions = tc.sessions_in_range(tc.minute_to_session_label(start), tc.minute_to_session_label(end)) if first == front: - i -= 1 + curr = first_contract << 1 else: - i -= 2 - curr = sessions[-1] - while curr > start and i > -1: - session_loc = sessions.searchsorted(curr) - front = oc.contract_sids[i] - back = oc.contract_sids[i + 1] + curr = first_contract << 2 + sess = sessions[-1] + while sess > start and curr is not None: + session_loc = sessions.searchsorted(sess) + front = curr.contract.sid + back = curr.next.contract.sid while session_loc > 0: session = sessions[session_loc] prev = sessions[session_loc - 1] if back != self._active_contract(oc, front, back, prev): - rolls.insert(0, (oc.contract_sids[i + offset], session)) + rolls.insert(0, ((curr >> offset).contract.sid, session)) break session_loc -= 1 - i -= 1 - curr = Timestamp(oc.auto_close_dates[i], - tz='UTC') + curr = curr.prev + if curr is not None: + sess = curr.contract.auto_close_date return rolls @@ -125,10 +121,8 @@ class CalendarRollFinder(RollFinder): self.asset_finder = asset_finder def _active_contract(self, oc, front, back, dt): - for i, sid in enumerate(oc.contract_sids): - if sid == front: - break - auto_close_date = Timestamp(oc.auto_close_dates[i], tz='UTC') + contract = oc.sid_to_contract[front].contract + auto_close_date = contract.auto_close_date auto_closed = dt >= auto_close_date return back if auto_closed else front @@ -153,9 +147,6 @@ class VolumeRollFinder(RollFinder): if back_vol > front_vol: return back else: - for i, sid in enumerate(oc.contract_sids): - if sid == front: - break - auto_close_date = Timestamp(oc.auto_close_dates[i], tz='UTC') - auto_closed = dt >= auto_close_date + contract = oc.sid_to_contract[front].contract + auto_closed = dt >= contract.auto_close_date return back if auto_closed else front