mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-28 12:43:47 +08:00
Merge pull request #1604 from quantopian/use-linked-list-for-contracts
MAINT: Use a doubly linked list for contract chain.
This commit is contained in:
@@ -12,7 +12,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections import deque
|
||||
from textwrap import dedent
|
||||
|
||||
from numpy import (
|
||||
@@ -31,6 +31,7 @@ from zipline import TradingAlgorithm
|
||||
from zipline.assets.continuous_futures import OrderedContracts
|
||||
from zipline.data.minute_bars import FUTURES_MINUTES_PER_DAY
|
||||
from zipline.testing.fixtures import (
|
||||
WithAssetFinder,
|
||||
WithCreateBarData,
|
||||
WithDataPortal,
|
||||
WithBcolzFutureMinuteBarReader,
|
||||
@@ -415,15 +416,15 @@ def record_current_contract(algo, data):
|
||||
result = results.iloc[0]
|
||||
|
||||
self.assertEqual(result.primary_len,
|
||||
5,
|
||||
'There should be only 5 contracts in the chain for '
|
||||
'the primary, there are 6 contracts defined in the '
|
||||
6,
|
||||
'There should be only 6 contracts in the chain for '
|
||||
'the primary, there are 7 contracts defined in the '
|
||||
'fixture, but one has a start after the simulation '
|
||||
'date.')
|
||||
self.assertEqual(result.secondary_len,
|
||||
4,
|
||||
'There should be only 4 contracts in the chain for '
|
||||
'the primary, there are 6 contracts defined in the '
|
||||
5,
|
||||
'There should be only 5 contracts in the chain for '
|
||||
'the primary, there are 7 contracts defined in the '
|
||||
'fixture, but one has a start after the simulation '
|
||||
'date. And the first is not included because it is '
|
||||
'the primary on that date.')
|
||||
@@ -438,11 +439,11 @@ def record_current_contract(algo, data):
|
||||
'session.')
|
||||
|
||||
self.assertEqual(result.primary_last,
|
||||
'FOK16',
|
||||
'FOG22',
|
||||
'End of primary chain should be FOK16 on first '
|
||||
'session.')
|
||||
self.assertEqual(result.secondary_last,
|
||||
'FOK16',
|
||||
'FOG22',
|
||||
'End of secondary chain should be FOK16 on first '
|
||||
'session.')
|
||||
|
||||
@@ -450,15 +451,15 @@ def record_current_contract(algo, data):
|
||||
result = results.iloc[1]
|
||||
|
||||
self.assertEqual(result.primary_len,
|
||||
4,
|
||||
'There should be only 4 contracts in the chain for '
|
||||
'the primary, there are 6 contracts defined in the '
|
||||
5,
|
||||
'There should be only 5 contracts in the chain for '
|
||||
'the primary, there are 7 contracts defined in the '
|
||||
'fixture, but one has a start after the simulation '
|
||||
'date. The first is not included because of roll.')
|
||||
self.assertEqual(result.secondary_len,
|
||||
3,
|
||||
'There should be only 3 contracts in the chain for '
|
||||
'the primary, there are 6 contracts defined in the '
|
||||
4,
|
||||
'There should be only 4 contracts in the chain for '
|
||||
'the primary, there are 7 contracts defined in the '
|
||||
'fixture, but one has a start after the simulation '
|
||||
'date. The first is not included because of roll, '
|
||||
'the second is the primary on that date.')
|
||||
@@ -475,11 +476,11 @@ def record_current_contract(algo, data):
|
||||
# These values remain FOJ16 because fixture data is not exhaustive
|
||||
# enough to move the end of the chain.
|
||||
self.assertEqual(result.primary_last,
|
||||
'FOK16',
|
||||
'FOG22',
|
||||
'End of primary chain should be FOK16 on second '
|
||||
'session.')
|
||||
self.assertEqual(result.secondary_last,
|
||||
'FOK16',
|
||||
'FOG22',
|
||||
'End of secondary chain should be FOK16 on second '
|
||||
'session.')
|
||||
|
||||
@@ -968,17 +969,43 @@ def record_current_contract(algo, data):
|
||||
"Should remain FOH16 on next session.")
|
||||
|
||||
|
||||
class OrderedContractsTestCase(ZiplineTestCase):
|
||||
class OrderedContractsTestCase(WithAssetFinder,
|
||||
ZiplineTestCase):
|
||||
|
||||
@classmethod
def make_root_symbols_info(cls):
    """Build the root-symbol table for this test case.

    Presumably consumed by the ``WithAssetFinder`` fixture when it builds
    the asset database -- confirm against the fixture.

    Returns
    -------
    pd.DataFrame
        A single row describing root symbol 'FO' (id 1) on exchange 'CME'.
    """
    # First argument of a classmethod is the class, so it is named ``cls``.
    return pd.DataFrame({
        'root_symbol': ['FO'],
        'root_symbol_id': [1],
        'exchange': ['CME'],
    })
|
||||
|
||||
@classmethod
def make_futures_info(cls):
    """Build the futures-contract table for this test case.

    Defines four 'FO' contracts with sids 1-4. Start dates are consecutive
    days beginning 2015-01-01; end, notice, expiration and auto-close dates
    are consecutive days beginning 2016-01-01 (all UTC).

    Presumably consumed by the ``WithAssetFinder`` fixture -- confirm
    against the fixture.

    Returns
    -------
    DataFrame
        One row per contract.
    """
    # NOTE: the original literal listed 'expiration_date' twice with the
    # same value; a duplicate dict key is silently overwritten, so it is
    # listed once here.
    return DataFrame({
        'root_symbol': ['FO'] * 4,
        'asset_name': ['Foo'] * 4,
        'sid': range(1, 5),
        'start_date': pd.date_range('2015-01-01', periods=4, tz="UTC"),
        'end_date': pd.date_range('2016-01-01', periods=4, tz="UTC"),
        'notice_date': pd.date_range('2016-01-01', periods=4, tz="UTC"),
        'expiration_date': pd.date_range(
            '2016-01-01', periods=4, tz="UTC"),
        'auto_close_date': pd.date_range(
            '2016-01-01', periods=4, tz="UTC"),
        'tick_size': [0.001] * 4,
        'multiplier': [1000.0] * 4,
        'exchange': ['CME'] * 4,
    })
|
||||
|
||||
def test_contract_at_offset(self):
|
||||
contract_sids = array([1, 2, 3, 4], dtype=int64)
|
||||
start_dates = pd.date_range('2015-01-01', periods=4, tz="UTC")
|
||||
auto_close_dates = pd.date_range('2016-04-01', periods=4, tz="UTC")
|
||||
|
||||
oc = OrderedContracts('FO',
|
||||
contract_sids,
|
||||
start_dates.astype('int64'),
|
||||
auto_close_dates.astype('int64'))
|
||||
contracts = deque(self.asset_finder.retrieve_all(contract_sids))
|
||||
|
||||
oc = OrderedContracts('FO', contracts)
|
||||
|
||||
self.assertEquals(1,
|
||||
oc.contract_at_offset(1, 0, start_dates[-1].value),
|
||||
@@ -994,13 +1021,10 @@ class OrderedContractsTestCase(ZiplineTestCase):
|
||||
|
||||
def test_active_chain(self):
|
||||
contract_sids = array([1, 2, 3, 4], dtype=int64)
|
||||
start_dates = pd.date_range('2015-01-01', periods=4, tz="UTC")
|
||||
auto_close_dates = pd.date_range('2016-04-01', periods=4, tz="UTC")
|
||||
|
||||
oc = OrderedContracts('FO',
|
||||
contract_sids,
|
||||
start_dates.astype('int64'),
|
||||
auto_close_dates.astype('int64'))
|
||||
contracts = deque(self.asset_finder.retrieve_all(contract_sids))
|
||||
|
||||
oc = OrderedContracts('FO', contracts)
|
||||
|
||||
# Test sid 1 as days increment, as the sessions march forward
|
||||
# a contract should be added per day, until all defined contracts
|
||||
|
||||
+16
-32
@@ -15,7 +15,7 @@
|
||||
from abc import ABCMeta
|
||||
import array
|
||||
import binascii
|
||||
from collections import namedtuple
|
||||
from collections import deque, namedtuple
|
||||
from numbers import Integral
|
||||
from operator import itemgetter, attrgetter
|
||||
import struct
|
||||
@@ -880,15 +880,14 @@ class AssetFinder(object):
|
||||
|
||||
return sids
|
||||
|
||||
def _get_contract_info(self, root_symbol):
|
||||
def _get_contract_sids(self, root_symbol):
|
||||
fc_cols = self.futures_contracts.c
|
||||
|
||||
fields = (fc_cols.sid, fc_cols.start_date, fc_cols.auto_close_date)
|
||||
|
||||
return list(sa.select(fields).where(
|
||||
(fc_cols.root_symbol == root_symbol) &
|
||||
(fc_cols.start_date != pd.NaT.value))
|
||||
.order_by(fc_cols.auto_close_date).execute().fetchall())
|
||||
return [r.sid for r in
|
||||
list(sa.select((fc_cols.sid,)).where(
|
||||
(fc_cols.root_symbol == root_symbol) &
|
||||
(fc_cols.start_date != pd.NaT.value)).order_by(
|
||||
fc_cols.sid).execute().fetchall())]
|
||||
|
||||
def _get_root_symbol_exchange(self, root_symbol):
|
||||
fc_cols = self.futures_root_symbols.c
|
||||
@@ -902,29 +901,14 @@ class AssetFinder(object):
|
||||
try:
|
||||
return self._ordered_contracts[root_symbol]
|
||||
except KeyError:
|
||||
contract_info = self._get_contract_info(root_symbol)
|
||||
size = len(contract_info)
|
||||
sids = np.full(size, 0, dtype=np.int64)
|
||||
start_dates = np.full(size, 0, dtype=np.int64)
|
||||
auto_close_dates = np.full(size, 0, dtype=np.int64)
|
||||
self._size = size
|
||||
for i, info in enumerate(contract_info):
|
||||
sid, start_date, auto_close_date = info
|
||||
sids[i] = sid
|
||||
start_dates[i] = start_date
|
||||
auto_close_dates[i] = auto_close_date
|
||||
oc = OrderedContracts(root_symbol,
|
||||
sids,
|
||||
start_dates,
|
||||
auto_close_dates)
|
||||
contract_sids = self._get_contract_sids(root_symbol)
|
||||
contracts = deque(self.retrieve_all(contract_sids))
|
||||
oc = OrderedContracts(root_symbol, contracts)
|
||||
self._ordered_contracts[root_symbol] = oc
|
||||
return oc
|
||||
|
||||
def create_continuous_future(self, root_symbol, offset, roll_style):
|
||||
oc = self.get_ordered_contracts(root_symbol)
|
||||
contracts = self.retrieve_all(oc.contract_sids)
|
||||
start_date = min(c.start_date for c in contracts)
|
||||
end_date = max(c.end_date for c in contracts)
|
||||
exchange = self._get_root_symbol_exchange(root_symbol)
|
||||
|
||||
sid = _encode_continuous_future_sid(root_symbol, offset,
|
||||
@@ -940,24 +924,24 @@ class AssetFinder(object):
|
||||
root_symbol,
|
||||
offset,
|
||||
roll_style,
|
||||
start_date,
|
||||
end_date,
|
||||
oc.start_date,
|
||||
oc.end_date,
|
||||
exchange,
|
||||
'mul')
|
||||
add_cf = ContinuousFuture(add_sid,
|
||||
root_symbol,
|
||||
offset,
|
||||
roll_style,
|
||||
start_date,
|
||||
end_date,
|
||||
oc.start_date,
|
||||
oc.end_date,
|
||||
exchange,
|
||||
'add')
|
||||
cf = ContinuousFuture(sid,
|
||||
root_symbol,
|
||||
offset,
|
||||
roll_style,
|
||||
start_date,
|
||||
end_date,
|
||||
oc.start_date,
|
||||
oc.end_date,
|
||||
exchange,
|
||||
adjustment_children={
|
||||
'mul': mul_cf,
|
||||
|
||||
@@ -29,8 +29,9 @@ from cpython.object cimport (
|
||||
)
|
||||
from cpython cimport bool
|
||||
|
||||
from numpy import empty
|
||||
from numpy import array, empty, iinfo
|
||||
from numpy cimport long_t, int64_t
|
||||
from pandas import Timestamp
|
||||
import warnings
|
||||
|
||||
from zipline.utils.calendars import get_calendar
|
||||
@@ -241,6 +242,33 @@ cdef class ContinuousFuture:
|
||||
except KeyError:
|
||||
return None
|
||||
|
||||
cdef class ContractNode(object):
    """
    One link in a doubly linked list of contracts.

    ``node >> n`` follows ``next`` n times and ``node << n`` follows
    ``prev`` n times. Either returns None once the walk runs off the end
    of the list, and returns the node itself when n is zero or negative.

    NOTE(review): some Cython versions may call arithmetic special methods
    on extension types with the operands swapped; these methods assume the
    left operand is the node -- confirm for the Cython version in use.
    """

    # The wrapped contract object; set once at construction.
    cdef readonly object contract
    # Neighbouring nodes in the chain (None at either end).
    cdef public object prev
    cdef public object next

    def __init__(self, contract):
        self.contract = contract
        self.prev = None
        self.next = None

    def __rshift__(self, offset):
        # Walk forward `offset` links, stopping early at the end of the list.
        node = self
        remaining = offset
        while remaining > 0 and node is not None:
            node = node.next
            remaining -= 1
        return node

    def __lshift__(self, offset):
        # Walk backward `offset` links, stopping early at the head of the list.
        node = self
        remaining = offset
        while remaining > 0 and node is not None:
            node = node.prev
            remaining -= 1
        return node
|
||||
|
||||
|
||||
cdef class OrderedContracts(object):
|
||||
"""
|
||||
@@ -249,16 +277,12 @@ cdef class OrderedContracts(object):
|
||||
Used to get answers about contracts in relation to their auto close
|
||||
dates and start dates.
|
||||
|
||||
The number of contracts for a given root symbol is ~250,
|
||||
which is why search by comparison over the range of contracts is
|
||||
used. At this size, this is faster than using an index or np.searchsorted.
|
||||
|
||||
Members
|
||||
-------
|
||||
root_symbol : str
|
||||
The root symbol of the future contract chain.
|
||||
contract_sids : long[:]
|
||||
The contract sids in sorted order of occurrence.
|
||||
contracts : deque
|
||||
The contracts in the chain in order of occurrence.
|
||||
start_dates : long[:]
|
||||
The start dates of the contracts in the chain.
|
||||
Corresponds by index with contract_sids.
|
||||
@@ -271,68 +295,85 @@ cdef class OrderedContracts(object):
|
||||
"""
|
||||
|
||||
cdef readonly object root_symbol
|
||||
cdef int _size
|
||||
cdef readonly long_t[:] contract_sids
|
||||
cdef readonly long_t[:] start_dates
|
||||
cdef readonly long_t[:] auto_close_dates
|
||||
cdef readonly object head_contract
|
||||
cdef readonly dict sid_to_contract
|
||||
cdef readonly int64_t _start_date
|
||||
cdef readonly int64_t _end_date
|
||||
|
||||
def __init__(self, object root_symbol, object contracts):
|
||||
|
||||
def __init__(self,
|
||||
object root_symbol,
|
||||
long_t[:] contract_sids,
|
||||
long_t[:] start_dates,
|
||||
long_t[:] auto_close_dates):
|
||||
self._size = len(contract_sids)
|
||||
self.root_symbol = root_symbol
|
||||
self.contract_sids = contract_sids
|
||||
self.start_dates = start_dates
|
||||
self.auto_close_dates = auto_close_dates
|
||||
|
||||
self.sid_to_contract = {}
|
||||
|
||||
self._start_date = iinfo('int64').max
|
||||
self._end_date = 0
|
||||
|
||||
contract = contracts.popleft()
|
||||
self.head_contract = ContractNode(contract)
|
||||
self._start_date = min(contract.start_date.value, self._start_date)
|
||||
self._end_date = max(contract.end_date.value, self._end_date)
|
||||
self.sid_to_contract[contract.sid] = self.head_contract
|
||||
prev = self.head_contract
|
||||
while contracts:
|
||||
contract = contracts.popleft()
|
||||
|
||||
# Here is where a predicate would go to ensure continuity of the chain.
|
||||
|
||||
self._start_date = min(contract.start_date.value, self._start_date)
|
||||
self._end_date = max(contract.end_date.value, self._end_date)
|
||||
|
||||
curr = ContractNode(contract)
|
||||
curr.prev = prev
|
||||
prev.next = curr
|
||||
prev = curr
|
||||
|
||||
self.sid_to_contract[contract.sid] = curr
|
||||
|
||||
cpdef long_t contract_before_auto_close(self, long_t dt_value):
|
||||
"""
|
||||
Get the contract with next upcoming auto close date.
|
||||
"""
|
||||
cdef Py_ssize_t i, auto_close_date
|
||||
for i, auto_close_date in enumerate(self.auto_close_dates):
|
||||
if auto_close_date > dt_value:
|
||||
curr = self.head_contract
|
||||
while curr.next is not None:
|
||||
if curr.contract.auto_close_date.value > dt_value:
|
||||
break
|
||||
return self.contract_sids[i]
|
||||
curr = curr.next
|
||||
return curr.contract.sid
|
||||
|
||||
cpdef contract_at_offset(self, long_t sid, Py_ssize_t offset, int64_t start_cap):
|
||||
"""
|
||||
Get the sid which is the given sid plus the offset distance.
|
||||
An offset of 0 should be reflexive.
|
||||
"""
|
||||
cdef Py_ssize_t i, j
|
||||
cdef long_t[:] sids
|
||||
sids = self.contract_sids
|
||||
start_dates = self.start_dates
|
||||
cdef Py_ssize_t i
|
||||
curr = self.sid_to_contract[sid]
|
||||
i = 0
|
||||
j = i + offset
|
||||
while j < self._size:
|
||||
if sid == sids[i]:
|
||||
if start_dates[j] < start_cap:
|
||||
return sids[j]
|
||||
else:
|
||||
return None
|
||||
while i < offset:
|
||||
if curr.next is None:
|
||||
return None
|
||||
curr = curr.next
|
||||
i += 1
|
||||
j += 1
|
||||
if curr.contract.start_date.value <= start_cap:
|
||||
return curr.contract.sid
|
||||
else:
|
||||
return None
|
||||
|
||||
cpdef long_t[:] active_chain(self, long_t starting_sid, long_t dt_value):
|
||||
cdef Py_ssize_t left, right, i, j
|
||||
cdef long_t[:] sids, start_dates
|
||||
left = 0
|
||||
right = self._size
|
||||
sids = self.contract_sids
|
||||
start_dates = self.start_dates
|
||||
curr = self.sid_to_contract[starting_sid]
|
||||
cdef list contracts = []
|
||||
|
||||
for i in range(self._size):
|
||||
if starting_sid == sids[i]:
|
||||
left = i
|
||||
break
|
||||
while curr is not None:
|
||||
if curr.contract.start_date.value <= dt_value:
|
||||
contracts.append(curr.contract.sid)
|
||||
curr = curr.next
|
||||
|
||||
return array(contracts, dtype='int64')
|
||||
|
||||
for j in range(i, self._size):
|
||||
if start_dates[j] > dt_value:
|
||||
right = j
|
||||
break
|
||||
property start_date:
|
||||
def __get__(self):
|
||||
return Timestamp(self._start_date, tz='UTC')
|
||||
|
||||
return sids[left:right]
|
||||
property end_date:
|
||||
def __get__(self):
|
||||
return Timestamp(self._end_date, tz='UTC')
|
||||
|
||||
@@ -15,8 +15,6 @@
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from six import with_metaclass
|
||||
|
||||
from pandas import Timestamp
|
||||
|
||||
|
||||
class RollFinder(with_metaclass(ABCMeta, object)):
|
||||
"""
|
||||
@@ -85,32 +83,30 @@ class RollFinder(with_metaclass(ABCMeta, object)):
|
||||
first = self._active_contract(oc, front, back, end)
|
||||
else:
|
||||
first = front
|
||||
for i, sid in enumerate(oc.contract_sids):
|
||||
if sid == first:
|
||||
break
|
||||
rolls = [(first + offset, None)]
|
||||
first_contract = oc.sid_to_contract[first]
|
||||
rolls = [((first_contract >> offset).contract.sid, None)]
|
||||
tc = self.trading_calendar
|
||||
sessions = tc.sessions_in_range(tc.minute_to_session_label(start),
|
||||
tc.minute_to_session_label(end))
|
||||
if first == front:
|
||||
i -= 1
|
||||
curr = first_contract << 1
|
||||
else:
|
||||
i -= 2
|
||||
curr = sessions[-1]
|
||||
while curr > start and i > -1:
|
||||
session_loc = sessions.searchsorted(curr)
|
||||
front = oc.contract_sids[i]
|
||||
back = oc.contract_sids[i + 1]
|
||||
curr = first_contract << 2
|
||||
sess = sessions[-1]
|
||||
while sess > start and curr is not None:
|
||||
session_loc = sessions.searchsorted(sess)
|
||||
front = curr.contract.sid
|
||||
back = curr.next.contract.sid
|
||||
while session_loc > 0:
|
||||
session = sessions[session_loc]
|
||||
prev = sessions[session_loc - 1]
|
||||
if back != self._active_contract(oc, front, back, prev):
|
||||
rolls.insert(0, (oc.contract_sids[i + offset], session))
|
||||
rolls.insert(0, ((curr >> offset).contract.sid, session))
|
||||
break
|
||||
session_loc -= 1
|
||||
i -= 1
|
||||
curr = Timestamp(oc.auto_close_dates[i],
|
||||
tz='UTC')
|
||||
curr = curr.prev
|
||||
if curr is not None:
|
||||
sess = curr.contract.auto_close_date
|
||||
return rolls
|
||||
|
||||
|
||||
@@ -125,10 +121,8 @@ class CalendarRollFinder(RollFinder):
|
||||
self.asset_finder = asset_finder
|
||||
|
||||
def _active_contract(self, oc, front, back, dt):
|
||||
for i, sid in enumerate(oc.contract_sids):
|
||||
if sid == front:
|
||||
break
|
||||
auto_close_date = Timestamp(oc.auto_close_dates[i], tz='UTC')
|
||||
contract = oc.sid_to_contract[front].contract
|
||||
auto_close_date = contract.auto_close_date
|
||||
auto_closed = dt >= auto_close_date
|
||||
return back if auto_closed else front
|
||||
|
||||
@@ -153,9 +147,6 @@ class VolumeRollFinder(RollFinder):
|
||||
if back_vol > front_vol:
|
||||
return back
|
||||
else:
|
||||
for i, sid in enumerate(oc.contract_sids):
|
||||
if sid == front:
|
||||
break
|
||||
auto_close_date = Timestamp(oc.auto_close_dates[i], tz='UTC')
|
||||
auto_closed = dt >= auto_close_date
|
||||
contract = oc.sid_to_contract[front].contract
|
||||
auto_closed = dt >= contract.auto_close_date
|
||||
return back if auto_closed else front
|
||||
|
||||
Reference in New Issue
Block a user