mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-30 18:49:42 +08:00
ENH: Allow future chains to only use certain delivery months.
To support contracts such as `PL` which should roll from F->J->N->V, add the ability to pass a predicate function to the ordered contract chain contstrution which returns `True` if the contract is allowed in the chain.
This commit is contained in:
@@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from collections import deque
|
||||
from functools import partial
|
||||
from textwrap import dedent
|
||||
|
||||
from numpy import (
|
||||
@@ -28,7 +29,10 @@ import pandas as pd
|
||||
from pandas import Timestamp, DataFrame
|
||||
|
||||
from zipline import TradingAlgorithm
|
||||
from zipline.assets.continuous_futures import OrderedContracts
|
||||
from zipline.assets.continuous_futures import (
|
||||
OrderedContracts,
|
||||
delivery_predicate
|
||||
)
|
||||
from zipline.data.minute_bars import FUTURES_MINUTES_PER_DAY
|
||||
from zipline.testing.fixtures import (
|
||||
WithAssetFinder,
|
||||
@@ -55,12 +59,16 @@ class ContinuousFuturesTestCase(WithCreateBarData,
|
||||
TRADING_CALENDAR_STRS = ('us_futures',)
|
||||
TRADING_CALENDAR_PRIMARY_CAL = 'us_futures'
|
||||
|
||||
TRADING_ENV_FUTURE_CHAIN_PREDICATES = {
|
||||
'BZ': partial(delivery_predicate, set(['F', 'H'])),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def make_root_symbols_info(self):
|
||||
return pd.DataFrame({
|
||||
'root_symbol': ['FO', 'BA'],
|
||||
'root_symbol_id': [1, 2],
|
||||
'exchange': ['CME', 'CME']})
|
||||
'root_symbol': ['FO', 'BA', 'BZ'],
|
||||
'root_symbol_id': [1, 2, 3],
|
||||
'exchange': ['CME', 'CME', 'CME']})
|
||||
|
||||
@classmethod
|
||||
def make_futures_info(self):
|
||||
@@ -144,7 +152,34 @@ class ContinuousFuturesTestCase(WithCreateBarData,
|
||||
'exchange': ['CME'] * 3,
|
||||
})
|
||||
|
||||
return pd.concat([fo_frame, ba_frame])
|
||||
# BZ is set up to test chain predicates, for futures such as PL which
|
||||
# only use a subset of contracts for the roll chain.
|
||||
bz_frame = DataFrame({
|
||||
'symbol': ['BZF16', 'BZG16', 'BZH16'],
|
||||
'root_symbol': ['BZ'] * 3,
|
||||
'asset_name': ['Baz'] * 3,
|
||||
'sid': range(10, 13),
|
||||
'start_date': [Timestamp('2005-01-01', tz='UTC'),
|
||||
Timestamp('2005-01-21', tz='UTC'),
|
||||
Timestamp('2005-01-21', tz='UTC')],
|
||||
'end_date': [Timestamp('2016-08-19', tz='UTC'),
|
||||
Timestamp('2016-11-21', tz='UTC'),
|
||||
Timestamp('2016-10-19', tz='UTC')],
|
||||
'notice_date': [Timestamp('2016-01-11', tz='UTC'),
|
||||
Timestamp('2016-02-08', tz='UTC'),
|
||||
Timestamp('2016-03-09', tz='UTC')],
|
||||
'expiration_date': [Timestamp('2016-01-11', tz='UTC'),
|
||||
Timestamp('2016-02-08', tz='UTC'),
|
||||
Timestamp('2016-03-09', tz='UTC')],
|
||||
'auto_close_date': [Timestamp('2016-01-11', tz='UTC'),
|
||||
Timestamp('2016-02-08', tz='UTC'),
|
||||
Timestamp('2016-03-09', tz='UTC')],
|
||||
'tick_size': [0.001] * 3,
|
||||
'multiplier': [1000.0] * 3,
|
||||
'exchange': ['CME'] * 3,
|
||||
})
|
||||
|
||||
return pd.concat([fo_frame, ba_frame, bz_frame])
|
||||
|
||||
@classmethod
|
||||
def make_future_minute_bar_data(cls):
|
||||
@@ -593,6 +628,27 @@ def record_current_contract(algo, data):
|
||||
9,
|
||||
"Should have remained BAM16")
|
||||
|
||||
def test_history_sid_session_delivery_predicate(self):
|
||||
cf = self.data_portal.asset_finder.create_continuous_future(
|
||||
'BZ', 0, 'calendar')
|
||||
window = self.data_portal.get_history_window(
|
||||
[cf],
|
||||
Timestamp('2016-01-11 18:01', tz='US/Eastern').tz_convert('UTC'),
|
||||
3, '1d', 'sid')
|
||||
|
||||
self.assertEqual(window.loc['2016-01-08', cf],
|
||||
10,
|
||||
"Should be BZF16 at beginning of window.")
|
||||
|
||||
self.assertEqual(window.loc['2016-01-11', cf],
|
||||
12,
|
||||
"Should be BZH16 after first roll, having skipped "
|
||||
"over BZG16.")
|
||||
|
||||
self.assertEqual(window.loc['2016-01-12', cf],
|
||||
12,
|
||||
"Should have remained BZG16")
|
||||
|
||||
def test_history_sid_session_secondary(self):
|
||||
cf = self.data_portal.asset_finder.create_continuous_future(
|
||||
'FO', 1, 'calendar')
|
||||
@@ -1094,15 +1150,16 @@ class OrderedContractsTestCase(WithAssetFinder,
|
||||
@classmethod
|
||||
def make_root_symbols_info(self):
|
||||
return pd.DataFrame({
|
||||
'root_symbol': ['FO'],
|
||||
'root_symbol_id': [1],
|
||||
'exchange': ['CME']})
|
||||
'root_symbol': ['FO', 'BA'],
|
||||
'root_symbol_id': [1, 2],
|
||||
'exchange': ['CME', 'CME']})
|
||||
|
||||
@classmethod
|
||||
def make_futures_info(self):
|
||||
return DataFrame({
|
||||
fo_frame = DataFrame({
|
||||
'root_symbol': ['FO'] * 4,
|
||||
'asset_name': ['Foo'] * 4,
|
||||
'symbol': ['FOF16', 'FOG16', 'FOH16', 'FOJ16'],
|
||||
'sid': range(1, 5),
|
||||
'start_date': pd.date_range('2015-01-01', periods=4, tz="UTC"),
|
||||
'end_date': pd.date_range('2016-01-01', periods=4, tz="UTC"),
|
||||
@@ -1117,6 +1174,29 @@ class OrderedContractsTestCase(WithAssetFinder,
|
||||
'multiplier': [1000.0] * 4,
|
||||
'exchange': ['CME'] * 4,
|
||||
})
|
||||
# BA is set up to test a quarterly roll, to test Eurodollar-like
|
||||
# behavior
|
||||
# The roll should go from BAH16 -> BAM16
|
||||
ba_frame = DataFrame({
|
||||
'root_symbol': ['BA'] * 3,
|
||||
'asset_name': ['Bar'] * 3,
|
||||
'symbol': ['BAF16', 'BAG16', 'BAH16'],
|
||||
'sid': range(5, 8),
|
||||
'start_date': pd.date_range('2015-01-01', periods=3, tz="UTC"),
|
||||
'end_date': pd.date_range('2016-01-01', periods=3, tz="UTC"),
|
||||
'notice_date': pd.date_range('2016-01-01', periods=3, tz="UTC"),
|
||||
'expiration_date': pd.date_range(
|
||||
'2016-01-01', periods=3, tz="UTC"),
|
||||
'expiration_date': pd.date_range(
|
||||
'2016-01-01', periods=3, tz="UTC"),
|
||||
'auto_close_date': pd.date_range(
|
||||
'2016-01-01', periods=3, tz="UTC"),
|
||||
'tick_size': [0.001] * 3,
|
||||
'multiplier': [1000.0] * 3,
|
||||
'exchange': ['CME'] * 3,
|
||||
})
|
||||
|
||||
return pd.concat([fo_frame, ba_frame])
|
||||
|
||||
def test_contract_at_offset(self):
|
||||
contract_sids = array([1, 2, 3, 4], dtype=int64)
|
||||
@@ -1197,6 +1277,23 @@ class OrderedContractsTestCase(WithAssetFinder,
|
||||
self.assertEquals([4], list(chain),
|
||||
"[4] should be active beginning at its start date.")
|
||||
|
||||
def test_delivery_predicate(self):
|
||||
contract_sids = range(5, 8)
|
||||
contracts = deque(self.asset_finder.retrieve_all(contract_sids))
|
||||
|
||||
oc = OrderedContracts('BA', contracts,
|
||||
chain_predicate=partial(delivery_predicate,
|
||||
set(['F', 'H'])))
|
||||
|
||||
# Test sid 1 as days increment, as the sessions march forward
|
||||
# a contract should be added per day, until all defined contracts
|
||||
# are returned.
|
||||
chain = oc.active_chain(5, pd.Timestamp('2015-01-05', tz='UTC').value)
|
||||
self.assertEquals(
|
||||
[5, 7], list(chain),
|
||||
"Contract BAG16 (sid=6) should be ommitted from chain, since "
|
||||
"it does not satisfy the roll predicate.")
|
||||
|
||||
|
||||
class NoPrefetchContinuousFuturesTestCase(ContinuousFuturesTestCase):
|
||||
DATA_PORTAL_MINUTE_HISTORY_PREFETCH = 0
|
||||
|
||||
@@ -49,7 +49,11 @@ from zipline.errors import (
|
||||
from . import (
|
||||
Asset, Equity, Future,
|
||||
)
|
||||
from . continuous_futures import OrderedContracts, ContinuousFuture
|
||||
from . continuous_futures import (
|
||||
OrderedContracts,
|
||||
ContinuousFuture,
|
||||
CHAIN_PREDICATES
|
||||
)
|
||||
from .asset_writer import (
|
||||
check_version_info,
|
||||
split_delimited_symbol,
|
||||
@@ -183,6 +187,10 @@ class AssetFinder(object):
|
||||
engine : str or SQLAlchemy.engine
|
||||
An engine with a connection to the asset database to use, or a string
|
||||
that can be parsed by SQLAlchemy as a URI.
|
||||
future_chain_predicates : dict
|
||||
A dict mapping future root symbol to a predicate function which accepts
|
||||
a contract as a parameter and returns whether or not the contract should be
|
||||
included in the chain.
|
||||
|
||||
See Also
|
||||
--------
|
||||
@@ -193,7 +201,7 @@ class AssetFinder(object):
|
||||
PERSISTENT_TOKEN = "<AssetFinder>"
|
||||
|
||||
@preprocess(engine=coerce_string_to_eng)
|
||||
def __init__(self, engine):
|
||||
def __init__(self, engine, future_chain_predicates=CHAIN_PREDICATES):
|
||||
self.engine = engine
|
||||
metadata = sa.MetaData(bind=engine)
|
||||
metadata.reflect(only=asset_db_table_names)
|
||||
@@ -213,6 +221,8 @@ class AssetFinder(object):
|
||||
# retrieve_asset will populate the cache on first retrieval.
|
||||
self._caches = (self._asset_cache, self._asset_type_cache) = {}, {}
|
||||
|
||||
self._future_chain_predicates = future_chain_predicates \
|
||||
if future_chain_predicates is not None else {}
|
||||
self._ordered_contracts = {}
|
||||
|
||||
# Populated on first call to `lifetimes`.
|
||||
@@ -903,7 +913,9 @@ class AssetFinder(object):
|
||||
except KeyError:
|
||||
contract_sids = self._get_contract_sids(root_symbol)
|
||||
contracts = deque(self.retrieve_all(contract_sids))
|
||||
oc = OrderedContracts(root_symbol, contracts)
|
||||
chain_predicate = self._future_chain_predicates.get(root_symbol,
|
||||
None)
|
||||
oc = OrderedContracts(root_symbol, contracts, chain_predicate)
|
||||
self._ordered_contracts[root_symbol] = oc
|
||||
return oc
|
||||
|
||||
|
||||
@@ -29,6 +29,8 @@ from cpython.object cimport (
|
||||
)
|
||||
from cpython cimport bool
|
||||
|
||||
from functools import partial
|
||||
|
||||
from numpy import array, empty, iinfo
|
||||
from numpy cimport long_t, int64_t
|
||||
from pandas import Timestamp
|
||||
@@ -37,6 +39,21 @@ import warnings
|
||||
from zipline.utils.calendars import get_calendar
|
||||
|
||||
|
||||
def delivery_predicate(codes, contract):
|
||||
# This relies on symbols that are construct following a pattern of
|
||||
# root symbol + delivery code + year, e.g. PLF16
|
||||
# This check would be more robust if the future contract class had
|
||||
# a 'delivery_month' member.
|
||||
delivery_code = contract.symbol[-3]
|
||||
return delivery_code in codes
|
||||
|
||||
CHAIN_PREDICATES = {
|
||||
'ME': partial(delivery_predicate, set(['H', 'M', 'U', 'Z'])),
|
||||
'PL': partial(delivery_predicate, set(['F', 'J', 'N', 'V'])),
|
||||
'PA': partial(delivery_predicate, set(['H', 'M', 'U', 'Z']))
|
||||
}
|
||||
|
||||
|
||||
cdef class ContinuousFuture:
|
||||
"""
|
||||
Represents a specifier for a chain of future contracts, where the
|
||||
@@ -289,18 +306,23 @@ cdef class OrderedContracts(object):
|
||||
auto_close_dates : long[:]
|
||||
The auto close dates of the contracts in the chain.
|
||||
Corresponds by index with contract_sids.
|
||||
future_chain_predicates : dict
|
||||
A dict mapping root symbol to a predicate function which accepts a contract
|
||||
as a parameter and returns whether or not the contract should be included in the
|
||||
chain.
|
||||
|
||||
Instances of this class are used by the simulation engine, but not
|
||||
exposed to the algorithm.
|
||||
"""
|
||||
|
||||
cdef readonly object root_symbol
|
||||
cdef readonly object head_contract
|
||||
cdef readonly object _head_contract
|
||||
cdef readonly dict sid_to_contract
|
||||
cdef readonly int64_t _start_date
|
||||
cdef readonly int64_t _end_date
|
||||
cdef readonly object chain_predicate
|
||||
|
||||
def __init__(self, object root_symbol, object contracts):
|
||||
def __init__(self, object root_symbol, object contracts, object chain_predicate=None):
|
||||
|
||||
self.root_symbol = root_symbol
|
||||
|
||||
@@ -309,12 +331,11 @@ cdef class OrderedContracts(object):
|
||||
self._start_date = iinfo('int64').max
|
||||
self._end_date = 0
|
||||
|
||||
contract = contracts.popleft()
|
||||
self.head_contract = ContractNode(contract)
|
||||
self._start_date = min(contract.start_date.value, self._start_date)
|
||||
self._end_date = max(contract.end_date.value, self._end_date)
|
||||
self.sid_to_contract[contract.sid] = self.head_contract
|
||||
prev = self.head_contract
|
||||
if chain_predicate is None:
|
||||
chain_predicate = lambda x: True
|
||||
|
||||
self._head_contract = None
|
||||
prev = None
|
||||
while contracts:
|
||||
contract = contracts.popleft()
|
||||
|
||||
@@ -322,24 +343,30 @@ cdef class OrderedContracts(object):
|
||||
# next contract.
|
||||
# This is in lieu of more explicit support for
|
||||
# contracts with quarterly rolls. e.g. Eurodollar
|
||||
if contract.start_date > prev.contract.auto_close_date:
|
||||
if prev is not None and contract.start_date > prev.contract.auto_close_date:
|
||||
continue
|
||||
|
||||
if not chain_predicate(contract):
|
||||
continue
|
||||
|
||||
self._start_date = min(contract.start_date.value, self._start_date)
|
||||
self._end_date = max(contract.end_date.value, self._end_date)
|
||||
|
||||
curr = ContractNode(contract)
|
||||
self.sid_to_contract[contract.sid] = curr
|
||||
if self._head_contract is None:
|
||||
self._head_contract = curr
|
||||
prev = curr
|
||||
continue
|
||||
curr.prev = prev
|
||||
prev.next = curr
|
||||
prev = curr
|
||||
|
||||
self.sid_to_contract[contract.sid] = curr
|
||||
|
||||
cpdef long_t contract_before_auto_close(self, long_t dt_value):
|
||||
"""
|
||||
Get the contract with next upcoming auto close date.
|
||||
"""
|
||||
curr = self.head_contract
|
||||
curr = self._head_contract
|
||||
while curr.next is not None:
|
||||
if curr.contract.auto_close_date.value > dt_value:
|
||||
break
|
||||
|
||||
@@ -20,6 +20,7 @@ from six import string_types
|
||||
from sqlalchemy import create_engine
|
||||
|
||||
from zipline.assets import AssetDBWriter, AssetFinder
|
||||
from zipline.assets.continuous_futures import CHAIN_PREDICATES
|
||||
from zipline.data.loader import load_market_data
|
||||
from zipline.utils.calendars import get_calendar
|
||||
from zipline.utils.memoize import remember_last
|
||||
@@ -80,7 +81,8 @@ class TradingEnvironment(object):
|
||||
bm_symbol='^GSPC',
|
||||
exchange_tz="US/Eastern",
|
||||
trading_calendar=None,
|
||||
asset_db_path=':memory:'
|
||||
asset_db_path=':memory:',
|
||||
future_chain_predicates=CHAIN_PREDICATES,
|
||||
):
|
||||
|
||||
self.bm_symbol = bm_symbol
|
||||
@@ -106,7 +108,9 @@ class TradingEnvironment(object):
|
||||
|
||||
if engine is not None:
|
||||
AssetDBWriter(engine).init_db()
|
||||
self.asset_finder = AssetFinder(engine)
|
||||
self.asset_finder = AssetFinder(
|
||||
engine,
|
||||
future_chain_predicates=future_chain_predicates)
|
||||
else:
|
||||
self.asset_finder = None
|
||||
|
||||
|
||||
@@ -451,6 +451,8 @@ class WithTradingEnvironment(WithAssetFinder,
|
||||
The max date to forward to the constructed TradingEnvironment.
|
||||
TRADING_ENV_TRADING_CALENDAR : pd.DatetimeIndex
|
||||
The trading calendar to use for the class's TradingEnvironment.
|
||||
TRADING_ENV_FUTURE_CHAIN_PREDICATES : dict
|
||||
The roll predicates to apply when creating contract chains.
|
||||
|
||||
Methods
|
||||
-------
|
||||
@@ -468,6 +470,7 @@ class WithTradingEnvironment(WithAssetFinder,
|
||||
--------
|
||||
:class:`zipline.finance.trading.TradingEnvironment`
|
||||
"""
|
||||
TRADING_ENV_FUTURE_CHAIN_PREDICATES = None
|
||||
|
||||
@classmethod
|
||||
def make_load_function(cls):
|
||||
@@ -479,6 +482,7 @@ class WithTradingEnvironment(WithAssetFinder,
|
||||
load=cls.make_load_function(),
|
||||
asset_db_path=cls.asset_finder.engine,
|
||||
trading_calendar=cls.trading_calendar,
|
||||
future_chain_predicates=cls.TRADING_ENV_FUTURE_CHAIN_PREDICATES,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
Reference in New Issue
Block a user