Merge pull request #1604 from quantopian/use-linked-list-for-contracts

MAINT: Use a doubly linked list for contract chain.
This commit is contained in:
Eddie Hebert
2016-11-30 06:54:28 -05:00
committed by GitHub
4 changed files with 178 additions and 138 deletions
+53 -29
View File
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import deque
from textwrap import dedent
from numpy import (
@@ -31,6 +31,7 @@ from zipline import TradingAlgorithm
from zipline.assets.continuous_futures import OrderedContracts
from zipline.data.minute_bars import FUTURES_MINUTES_PER_DAY
from zipline.testing.fixtures import (
WithAssetFinder,
WithCreateBarData,
WithDataPortal,
WithBcolzFutureMinuteBarReader,
@@ -415,15 +416,15 @@ def record_current_contract(algo, data):
result = results.iloc[0]
self.assertEqual(result.primary_len,
5,
'There should be only 5 contracts in the chain for '
'the primary, there are 6 contracts defined in the '
6,
'There should be only 6 contracts in the chain for '
'the primary, there are 7 contracts defined in the '
'fixture, but one has a start after the simulation '
'date.')
self.assertEqual(result.secondary_len,
4,
'There should be only 4 contracts in the chain for '
'the primary, there are 6 contracts defined in the '
5,
'There should be only 5 contracts in the chain for '
'the primary, there are 7 contracts defined in the '
'fixture, but one has a start after the simulation '
'date. And the first is not included because it is '
'the primary on that date.')
@@ -438,11 +439,11 @@ def record_current_contract(algo, data):
'session.')
self.assertEqual(result.primary_last,
'FOK16',
'FOG22',
'End of primary chain should be FOK16 on first '
'session.')
self.assertEqual(result.secondary_last,
'FOK16',
'FOG22',
'End of secondary chain should be FOK16 on first '
'session.')
@@ -450,15 +451,15 @@ def record_current_contract(algo, data):
result = results.iloc[1]
self.assertEqual(result.primary_len,
4,
'There should be only 4 contracts in the chain for '
'the primary, there are 6 contracts defined in the '
5,
'There should be only 5 contracts in the chain for '
'the primary, there are 7 contracts defined in the '
'fixture, but one has a start after the simulation '
'date. The first is not included because of roll.')
self.assertEqual(result.secondary_len,
3,
'There should be only 3 contracts in the chain for '
'the primary, there are 6 contracts defined in the '
4,
'There should be only 4 contracts in the chain for '
'the primary, there are 7 contracts defined in the '
'fixture, but one has a start after the simulation '
'date. The first is not included because of roll, '
'the second is the primary on that date.')
@@ -475,11 +476,11 @@ def record_current_contract(algo, data):
# These values remain FOJ16 because fixture data is not exhaustive
# enough to move the end of the chain.
self.assertEqual(result.primary_last,
'FOK16',
'FOG22',
'End of primary chain should be FOK16 on second '
'session.')
self.assertEqual(result.secondary_last,
'FOK16',
'FOG22',
'End of secondary chain should be FOK16 on second '
'session.')
@@ -968,17 +969,43 @@ def record_current_contract(algo, data):
"Should remain FOH16 on next session.")
class OrderedContractsTestCase(ZiplineTestCase):
class OrderedContractsTestCase(WithAssetFinder,
ZiplineTestCase):
@classmethod
def make_root_symbols_info(self):
return pd.DataFrame({
'root_symbol': ['FO'],
'root_symbol_id': [1],
'exchange': ['CME']})
@classmethod
def make_futures_info(self):
return DataFrame({
'root_symbol': ['FO'] * 4,
'asset_name': ['Foo'] * 4,
'sid': range(1, 5),
'start_date': pd.date_range('2015-01-01', periods=4, tz="UTC"),
'end_date': pd.date_range('2016-01-01', periods=4, tz="UTC"),
'notice_date': pd.date_range('2016-01-01', periods=4, tz="UTC"),
'expiration_date': pd.date_range(
'2016-01-01', periods=4, tz="UTC"),
'expiration_date': pd.date_range(
'2016-01-01', periods=4, tz="UTC"),
'auto_close_date': pd.date_range(
'2016-01-01', periods=4, tz="UTC"),
'tick_size': [0.001] * 4,
'multiplier': [1000.0] * 4,
'exchange': ['CME'] * 4,
})
def test_contract_at_offset(self):
contract_sids = array([1, 2, 3, 4], dtype=int64)
start_dates = pd.date_range('2015-01-01', periods=4, tz="UTC")
auto_close_dates = pd.date_range('2016-04-01', periods=4, tz="UTC")
oc = OrderedContracts('FO',
contract_sids,
start_dates.astype('int64'),
auto_close_dates.astype('int64'))
contracts = deque(self.asset_finder.retrieve_all(contract_sids))
oc = OrderedContracts('FO', contracts)
self.assertEquals(1,
oc.contract_at_offset(1, 0, start_dates[-1].value),
@@ -994,13 +1021,10 @@ class OrderedContractsTestCase(ZiplineTestCase):
def test_active_chain(self):
contract_sids = array([1, 2, 3, 4], dtype=int64)
start_dates = pd.date_range('2015-01-01', periods=4, tz="UTC")
auto_close_dates = pd.date_range('2016-04-01', periods=4, tz="UTC")
oc = OrderedContracts('FO',
contract_sids,
start_dates.astype('int64'),
auto_close_dates.astype('int64'))
contracts = deque(self.asset_finder.retrieve_all(contract_sids))
oc = OrderedContracts('FO', contracts)
# Test sid 1 as days increment, as the sessions march forward
# a contract should be added per day, until all defined contracts
+16 -32
View File
@@ -15,7 +15,7 @@
from abc import ABCMeta
import array
import binascii
from collections import namedtuple
from collections import deque, namedtuple
from numbers import Integral
from operator import itemgetter, attrgetter
import struct
@@ -880,15 +880,14 @@ class AssetFinder(object):
return sids
def _get_contract_info(self, root_symbol):
def _get_contract_sids(self, root_symbol):
fc_cols = self.futures_contracts.c
fields = (fc_cols.sid, fc_cols.start_date, fc_cols.auto_close_date)
return list(sa.select(fields).where(
(fc_cols.root_symbol == root_symbol) &
(fc_cols.start_date != pd.NaT.value))
.order_by(fc_cols.auto_close_date).execute().fetchall())
return [r.sid for r in
list(sa.select((fc_cols.sid,)).where(
(fc_cols.root_symbol == root_symbol) &
(fc_cols.start_date != pd.NaT.value)).order_by(
fc_cols.sid).execute().fetchall())]
def _get_root_symbol_exchange(self, root_symbol):
fc_cols = self.futures_root_symbols.c
@@ -902,29 +901,14 @@ class AssetFinder(object):
try:
return self._ordered_contracts[root_symbol]
except KeyError:
contract_info = self._get_contract_info(root_symbol)
size = len(contract_info)
sids = np.full(size, 0, dtype=np.int64)
start_dates = np.full(size, 0, dtype=np.int64)
auto_close_dates = np.full(size, 0, dtype=np.int64)
self._size = size
for i, info in enumerate(contract_info):
sid, start_date, auto_close_date = info
sids[i] = sid
start_dates[i] = start_date
auto_close_dates[i] = auto_close_date
oc = OrderedContracts(root_symbol,
sids,
start_dates,
auto_close_dates)
contract_sids = self._get_contract_sids(root_symbol)
contracts = deque(self.retrieve_all(contract_sids))
oc = OrderedContracts(root_symbol, contracts)
self._ordered_contracts[root_symbol] = oc
return oc
def create_continuous_future(self, root_symbol, offset, roll_style):
oc = self.get_ordered_contracts(root_symbol)
contracts = self.retrieve_all(oc.contract_sids)
start_date = min(c.start_date for c in contracts)
end_date = max(c.end_date for c in contracts)
exchange = self._get_root_symbol_exchange(root_symbol)
sid = _encode_continuous_future_sid(root_symbol, offset,
@@ -940,24 +924,24 @@ class AssetFinder(object):
root_symbol,
offset,
roll_style,
start_date,
end_date,
oc.start_date,
oc.end_date,
exchange,
'mul')
add_cf = ContinuousFuture(add_sid,
root_symbol,
offset,
roll_style,
start_date,
end_date,
oc.start_date,
oc.end_date,
exchange,
'add')
cf = ContinuousFuture(sid,
root_symbol,
offset,
roll_style,
start_date,
end_date,
oc.start_date,
oc.end_date,
exchange,
adjustment_children={
'mul': mul_cf,
+92 -51
View File
@@ -29,8 +29,9 @@ from cpython.object cimport (
)
from cpython cimport bool
from numpy import empty
from numpy import array, empty, iinfo
from numpy cimport long_t, int64_t
from pandas import Timestamp
import warnings
from zipline.utils.calendars import get_calendar
@@ -241,6 +242,33 @@ cdef class ContinuousFuture:
except KeyError:
return None
cdef class ContractNode(object):
cdef readonly object contract
cdef public object prev
cdef public object next
def __init__(self, contract):
self.contract = contract
self.prev = None
self.next = None
def __rshift__(self, offset):
i = 0
curr = self
while i < offset and curr is not None:
curr = curr.next
i += 1
return curr
def __lshift__(self, offset):
i = 0
curr = self
while i < offset and curr is not None:
curr = curr.prev
i += 1
return curr
cdef class OrderedContracts(object):
"""
@@ -249,16 +277,12 @@ cdef class OrderedContracts(object):
Used to get answers about contracts in relation to their auto close
dates and start dates.
The number of contracts for a given root symbol is ~250,
which is why search by comparison over the range of contracts is
used. At this size, this is faster than using an index or np.searchsorted.
Members
-------
root_symbol : str
The root symbol of the future contract chain.
contract_sids : long[:]
The contract sids in sorted order of occurrence.
contracts : deque
The contracts in the chain in order of occurrence.
start_dates : long[:]
The start dates of the contracts in the chain.
Corresponds by index with contract_sids.
@@ -271,68 +295,85 @@ cdef class OrderedContracts(object):
"""
cdef readonly object root_symbol
cdef int _size
cdef readonly long_t[:] contract_sids
cdef readonly long_t[:] start_dates
cdef readonly long_t[:] auto_close_dates
cdef readonly object head_contract
cdef readonly dict sid_to_contract
cdef readonly int64_t _start_date
cdef readonly int64_t _end_date
def __init__(self, object root_symbol, object contracts):
def __init__(self,
object root_symbol,
long_t[:] contract_sids,
long_t[:] start_dates,
long_t[:] auto_close_dates):
self._size = len(contract_sids)
self.root_symbol = root_symbol
self.contract_sids = contract_sids
self.start_dates = start_dates
self.auto_close_dates = auto_close_dates
self.sid_to_contract = {}
self._start_date = iinfo('int64').max
self._end_date = 0
contract = contracts.popleft()
self.head_contract = ContractNode(contract)
self._start_date = min(contract.start_date.value, self._start_date)
self._end_date = max(contract.end_date.value, self._end_date)
self.sid_to_contract[contract.sid] = self.head_contract
prev = self.head_contract
while contracts:
contract = contracts.popleft()
# Here is where a predicate would go to ensure continuity of the chain.
self._start_date = min(contract.start_date.value, self._start_date)
self._end_date = max(contract.end_date.value, self._end_date)
curr = ContractNode(contract)
curr.prev = prev
prev.next = curr
prev = curr
self.sid_to_contract[contract.sid] = curr
cpdef long_t contract_before_auto_close(self, long_t dt_value):
"""
Get the contract with next upcoming auto close date.
"""
cdef Py_ssize_t i, auto_close_date
for i, auto_close_date in enumerate(self.auto_close_dates):
if auto_close_date > dt_value:
curr = self.head_contract
while curr.next is not None:
if curr.contract.auto_close_date.value > dt_value:
break
return self.contract_sids[i]
curr = curr.next
return curr.contract.sid
cpdef contract_at_offset(self, long_t sid, Py_ssize_t offset, int64_t start_cap):
"""
Get the sid which is the given sid plus the offset distance.
An offset of 0 should be reflexive.
"""
cdef Py_ssize_t i, j
cdef long_t[:] sids
sids = self.contract_sids
start_dates = self.start_dates
cdef Py_ssize_t i
curr = self.sid_to_contract[sid]
i = 0
j = i + offset
while j < self._size:
if sid == sids[i]:
if start_dates[j] < start_cap:
return sids[j]
else:
return None
while i < offset:
if curr.next is None:
return None
curr = curr.next
i += 1
j += 1
if curr.contract.start_date.value <= start_cap:
return curr.contract.sid
else:
return None
cpdef long_t[:] active_chain(self, long_t starting_sid, long_t dt_value):
cdef Py_ssize_t left, right, i, j
cdef long_t[:] sids, start_dates
left = 0
right = self._size
sids = self.contract_sids
start_dates = self.start_dates
curr = self.sid_to_contract[starting_sid]
cdef list contracts = []
for i in range(self._size):
if starting_sid == sids[i]:
left = i
break
while curr is not None:
if curr.contract.start_date.value <= dt_value:
contracts.append(curr.contract.sid)
curr = curr.next
return array(contracts, dtype='int64')
for j in range(i, self._size):
if start_dates[j] > dt_value:
right = j
break
property start_date:
def __get__(self):
return Timestamp(self._start_date, tz='UTC')
return sids[left:right]
property end_date:
def __get__(self):
return Timestamp(self._end_date, tz='UTC')
+17 -26
View File
@@ -15,8 +15,6 @@
from abc import ABCMeta, abstractmethod
from six import with_metaclass
from pandas import Timestamp
class RollFinder(with_metaclass(ABCMeta, object)):
"""
@@ -85,32 +83,30 @@ class RollFinder(with_metaclass(ABCMeta, object)):
first = self._active_contract(oc, front, back, end)
else:
first = front
for i, sid in enumerate(oc.contract_sids):
if sid == first:
break
rolls = [(first + offset, None)]
first_contract = oc.sid_to_contract[first]
rolls = [((first_contract >> offset).contract.sid, None)]
tc = self.trading_calendar
sessions = tc.sessions_in_range(tc.minute_to_session_label(start),
tc.minute_to_session_label(end))
if first == front:
i -= 1
curr = first_contract << 1
else:
i -= 2
curr = sessions[-1]
while curr > start and i > -1:
session_loc = sessions.searchsorted(curr)
front = oc.contract_sids[i]
back = oc.contract_sids[i + 1]
curr = first_contract << 2
sess = sessions[-1]
while sess > start and curr is not None:
session_loc = sessions.searchsorted(sess)
front = curr.contract.sid
back = curr.next.contract.sid
while session_loc > 0:
session = sessions[session_loc]
prev = sessions[session_loc - 1]
if back != self._active_contract(oc, front, back, prev):
rolls.insert(0, (oc.contract_sids[i + offset], session))
rolls.insert(0, ((curr >> offset).contract.sid, session))
break
session_loc -= 1
i -= 1
curr = Timestamp(oc.auto_close_dates[i],
tz='UTC')
curr = curr.prev
if curr is not None:
sess = curr.contract.auto_close_date
return rolls
@@ -125,10 +121,8 @@ class CalendarRollFinder(RollFinder):
self.asset_finder = asset_finder
def _active_contract(self, oc, front, back, dt):
for i, sid in enumerate(oc.contract_sids):
if sid == front:
break
auto_close_date = Timestamp(oc.auto_close_dates[i], tz='UTC')
contract = oc.sid_to_contract[front].contract
auto_close_date = contract.auto_close_date
auto_closed = dt >= auto_close_date
return back if auto_closed else front
@@ -153,9 +147,6 @@ class VolumeRollFinder(RollFinder):
if back_vol > front_vol:
return back
else:
for i, sid in enumerate(oc.contract_sids):
if sid == front:
break
auto_close_date = Timestamp(oc.auto_close_dates[i], tz='UTC')
auto_closed = dt >= auto_close_date
contract = oc.sid_to_contract[front].contract
auto_closed = dt >= contract.auto_close_date
return back if auto_closed else front