ENH: Allow future chains to only use certain delivery months.

To support contracts such as `PL` which should roll from F->J->N->V, add the
ability to pass a predicate function to the ordered contract chain construction
which returns `True` if the contract is allowed in the chain.
This commit is contained in:
Eddie Hebert
2016-11-30 23:59:31 -05:00
parent 3575351306
commit 1f71c8d068
5 changed files with 170 additions and 26 deletions
+106 -9
View File
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import deque
from functools import partial
from textwrap import dedent
from numpy import (
@@ -28,7 +29,10 @@ import pandas as pd
from pandas import Timestamp, DataFrame
from zipline import TradingAlgorithm
from zipline.assets.continuous_futures import OrderedContracts
from zipline.assets.continuous_futures import (
OrderedContracts,
delivery_predicate
)
from zipline.data.minute_bars import FUTURES_MINUTES_PER_DAY
from zipline.testing.fixtures import (
WithAssetFinder,
@@ -55,12 +59,16 @@ class ContinuousFuturesTestCase(WithCreateBarData,
TRADING_CALENDAR_STRS = ('us_futures',)
TRADING_CALENDAR_PRIMARY_CAL = 'us_futures'
TRADING_ENV_FUTURE_CHAIN_PREDICATES = {
'BZ': partial(delivery_predicate, set(['F', 'H'])),
}
@classmethod
def make_root_symbols_info(self):
return pd.DataFrame({
'root_symbol': ['FO', 'BA'],
'root_symbol_id': [1, 2],
'exchange': ['CME', 'CME']})
'root_symbol': ['FO', 'BA', 'BZ'],
'root_symbol_id': [1, 2, 3],
'exchange': ['CME', 'CME', 'CME']})
@classmethod
def make_futures_info(self):
@@ -144,7 +152,34 @@ class ContinuousFuturesTestCase(WithCreateBarData,
'exchange': ['CME'] * 3,
})
return pd.concat([fo_frame, ba_frame])
# BZ is set up to test chain predicates, for futures such as PL which
# only use a subset of contracts for the roll chain.
bz_frame = DataFrame({
'symbol': ['BZF16', 'BZG16', 'BZH16'],
'root_symbol': ['BZ'] * 3,
'asset_name': ['Baz'] * 3,
'sid': range(10, 13),
'start_date': [Timestamp('2005-01-01', tz='UTC'),
Timestamp('2005-01-21', tz='UTC'),
Timestamp('2005-01-21', tz='UTC')],
'end_date': [Timestamp('2016-08-19', tz='UTC'),
Timestamp('2016-11-21', tz='UTC'),
Timestamp('2016-10-19', tz='UTC')],
'notice_date': [Timestamp('2016-01-11', tz='UTC'),
Timestamp('2016-02-08', tz='UTC'),
Timestamp('2016-03-09', tz='UTC')],
'expiration_date': [Timestamp('2016-01-11', tz='UTC'),
Timestamp('2016-02-08', tz='UTC'),
Timestamp('2016-03-09', tz='UTC')],
'auto_close_date': [Timestamp('2016-01-11', tz='UTC'),
Timestamp('2016-02-08', tz='UTC'),
Timestamp('2016-03-09', tz='UTC')],
'tick_size': [0.001] * 3,
'multiplier': [1000.0] * 3,
'exchange': ['CME'] * 3,
})
return pd.concat([fo_frame, ba_frame, bz_frame])
@classmethod
def make_future_minute_bar_data(cls):
@@ -593,6 +628,27 @@ def record_current_contract(algo, data):
9,
"Should have remained BAM16")
def test_history_sid_session_delivery_predicate(self):
    """
    Check that the 'sid' history field honors the chain predicate for BZ.

    The fixture restricts the BZ chain to 'F' and 'H' deliveries, so the
    continuous future should roll from BZF16 (sid 10) directly to BZH16
    (sid 12), never reporting BZG16 (sid 11).
    """
    cf = self.data_portal.asset_finder.create_continuous_future(
        'BZ', 0, 'calendar')
    window = self.data_portal.get_history_window(
        [cf],
        Timestamp('2016-01-11 18:01', tz='US/Eastern').tz_convert('UTC'),
        3, '1d', 'sid')

    self.assertEqual(window.loc['2016-01-08', cf],
                     10,
                     "Should be BZF16 at beginning of window.")

    self.assertEqual(window.loc['2016-01-11', cf],
                     12,
                     "Should be BZH16 after first roll, having skipped "
                     "over BZG16.")

    # sid 12 is BZH16; the chain must stay on it once the roll happened.
    self.assertEqual(window.loc['2016-01-12', cf],
                     12,
                     "Should have remained BZH16")
def test_history_sid_session_secondary(self):
cf = self.data_portal.asset_finder.create_continuous_future(
'FO', 1, 'calendar')
@@ -1094,15 +1150,16 @@ class OrderedContractsTestCase(WithAssetFinder,
@classmethod
def make_root_symbols_info(self):
return pd.DataFrame({
'root_symbol': ['FO'],
'root_symbol_id': [1],
'exchange': ['CME']})
'root_symbol': ['FO', 'BA'],
'root_symbol_id': [1, 2],
'exchange': ['CME', 'CME']})
@classmethod
def make_futures_info(self):
return DataFrame({
fo_frame = DataFrame({
'root_symbol': ['FO'] * 4,
'asset_name': ['Foo'] * 4,
'symbol': ['FOF16', 'FOG16', 'FOH16', 'FOJ16'],
'sid': range(1, 5),
'start_date': pd.date_range('2015-01-01', periods=4, tz="UTC"),
'end_date': pd.date_range('2016-01-01', periods=4, tz="UTC"),
@@ -1117,6 +1174,29 @@ class OrderedContractsTestCase(WithAssetFinder,
'multiplier': [1000.0] * 4,
'exchange': ['CME'] * 4,
})
# BA supplies three consecutive monthly contracts (F, G and H) so that
# chain predicates can be exercised: with a predicate allowing only
# 'F' and 'H' deliveries, BAG16 (sid 6) should be left out of the chain.
ba_frame = DataFrame({
    'root_symbol': ['BA'] * 3,
    'asset_name': ['Bar'] * 3,
    'symbol': ['BAF16', 'BAG16', 'BAH16'],
    'sid': range(5, 8),
    'start_date': pd.date_range('2015-01-01', periods=3, tz="UTC"),
    'end_date': pd.date_range('2016-01-01', periods=3, tz="UTC"),
    'notice_date': pd.date_range('2016-01-01', periods=3, tz="UTC"),
    'expiration_date': pd.date_range(
        '2016-01-01', periods=3, tz="UTC"),
    'auto_close_date': pd.date_range(
        '2016-01-01', periods=3, tz="UTC"),
    'tick_size': [0.001] * 3,
    'multiplier': [1000.0] * 3,
    'exchange': ['CME'] * 3,
})
return pd.concat([fo_frame, ba_frame])
def test_contract_at_offset(self):
contract_sids = array([1, 2, 3, 4], dtype=int64)
@@ -1197,6 +1277,23 @@ class OrderedContractsTestCase(WithAssetFinder,
self.assertEquals([4], list(chain),
"[4] should be active beginning at its start date.")
def test_delivery_predicate(self):
    """
    Check that contracts rejected by the chain predicate are skipped.
    """
    contract_sids = range(5, 8)
    contracts = deque(self.asset_finder.retrieve_all(contract_sids))

    oc = OrderedContracts('BA', contracts,
                          chain_predicate=partial(delivery_predicate,
                                                  set(['F', 'H'])))

    # The BA contracts are BAF16 (sid 5), BAG16 (sid 6) and BAH16 (sid 7).
    # Only the 'F' and 'H' deliveries satisfy the predicate, so the chain
    # starting at sid 5 should step straight to sid 7.
    chain = oc.active_chain(5, pd.Timestamp('2015-01-05', tz='UTC').value)
    self.assertEqual(
        [5, 7], list(chain),
        "Contract BAG16 (sid=6) should be omitted from chain, since "
        "it does not satisfy the roll predicate.")
class NoPrefetchContinuousFuturesTestCase(ContinuousFuturesTestCase):
DATA_PORTAL_MINUTE_HISTORY_PREFETCH = 0
+15 -3
View File
@@ -49,7 +49,11 @@ from zipline.errors import (
from . import (
Asset, Equity, Future,
)
from . continuous_futures import OrderedContracts, ContinuousFuture
from . continuous_futures import (
OrderedContracts,
ContinuousFuture,
CHAIN_PREDICATES
)
from .asset_writer import (
check_version_info,
split_delimited_symbol,
@@ -183,6 +187,10 @@ class AssetFinder(object):
engine : str or SQLAlchemy.engine
An engine with a connection to the asset database to use, or a string
that can be parsed by SQLAlchemy as a URI.
future_chain_predicates : dict
A dict mapping future root symbol to a predicate function which accepts
a contract as a parameter and returns whether or not the contract should be
included in the chain.
See Also
--------
@@ -193,7 +201,7 @@ class AssetFinder(object):
PERSISTENT_TOKEN = "<AssetFinder>"
@preprocess(engine=coerce_string_to_eng)
def __init__(self, engine):
def __init__(self, engine, future_chain_predicates=CHAIN_PREDICATES):
self.engine = engine
metadata = sa.MetaData(bind=engine)
metadata.reflect(only=asset_db_table_names)
@@ -213,6 +221,8 @@ class AssetFinder(object):
# retrieve_asset will populate the cache on first retrieval.
self._caches = (self._asset_cache, self._asset_type_cache) = {}, {}
self._future_chain_predicates = future_chain_predicates \
if future_chain_predicates is not None else {}
self._ordered_contracts = {}
# Populated on first call to `lifetimes`.
@@ -903,7 +913,9 @@ class AssetFinder(object):
except KeyError:
contract_sids = self._get_contract_sids(root_symbol)
contracts = deque(self.retrieve_all(contract_sids))
oc = OrderedContracts(root_symbol, contracts)
chain_predicate = self._future_chain_predicates.get(root_symbol,
None)
oc = OrderedContracts(root_symbol, contracts, chain_predicate)
self._ordered_contracts[root_symbol] = oc
return oc
+39 -12
View File
@@ -29,6 +29,8 @@ from cpython.object cimport (
)
from cpython cimport bool
from functools import partial
from numpy import array, empty, iinfo
from numpy cimport long_t, int64_t
from pandas import Timestamp
@@ -37,6 +39,21 @@ import warnings
from zipline.utils.calendars import get_calendar
def delivery_predicate(codes, contract):
    """
    Return whether a contract's delivery month code is one of `codes`.

    Parameters
    ----------
    codes : container of str
        The delivery month codes which are allowed in the chain,
        e.g. ``set(['F', 'J', 'N', 'V'])``.
    contract : object
        The contract to check. Only its ``symbol`` attribute is read.

    Returns
    -------
    allowed : bool
        True if the delivery code parsed from the symbol is in `codes`.
        False if the symbol is too short to contain a delivery code.
    """
    # This relies on symbols that are constructed following a pattern of
    # root symbol + delivery code + two digit year, e.g. PLF16.
    # This check would be more robust if the future contract class had
    # a 'delivery_month' member.
    symbol = contract.symbol
    if len(symbol) < 3:
        # No room for a delivery code ahead of the two digit year; treat
        # the contract as not allowed rather than raising IndexError.
        return False
    return symbol[-3] in codes
# Root symbols whose roll chain only uses a subset of delivery months,
# mapped to the predicate deciding whether a contract joins the chain.
CHAIN_PREDICATES = {
    'ME': partial(delivery_predicate, {'H', 'M', 'U', 'Z'}),
    'PL': partial(delivery_predicate, {'F', 'J', 'N', 'V'}),
    'PA': partial(delivery_predicate, {'H', 'M', 'U', 'Z'}),
}
cdef class ContinuousFuture:
"""
Represents a specifier for a chain of future contracts, where the
@@ -289,18 +306,23 @@ cdef class OrderedContracts(object):
auto_close_dates : long[:]
The auto close dates of the contracts in the chain.
Corresponds by index with contract_sids.
chain_predicate : callable, optional
A predicate function which accepts a contract as a parameter and returns
whether or not the contract should be included in the chain. If None,
every contract is considered allowed by the predicate.
Instances of this class are used by the simulation engine, but not
exposed to the algorithm.
"""
cdef readonly object root_symbol
cdef readonly object head_contract
cdef readonly object _head_contract
cdef readonly dict sid_to_contract
cdef readonly int64_t _start_date
cdef readonly int64_t _end_date
cdef readonly object chain_predicate
def __init__(self, object root_symbol, object contracts):
def __init__(self, object root_symbol, object contracts, object chain_predicate=None):
self.root_symbol = root_symbol
@@ -309,12 +331,11 @@ cdef class OrderedContracts(object):
self._start_date = iinfo('int64').max
self._end_date = 0
contract = contracts.popleft()
self.head_contract = ContractNode(contract)
self._start_date = min(contract.start_date.value, self._start_date)
self._end_date = max(contract.end_date.value, self._end_date)
self.sid_to_contract[contract.sid] = self.head_contract
prev = self.head_contract
if chain_predicate is None:
chain_predicate = lambda x: True
self._head_contract = None
prev = None
while contracts:
contract = contracts.popleft()
@@ -322,24 +343,30 @@ cdef class OrderedContracts(object):
# next contract.
# This is in lieu of more explicit support for
# contracts with quarterly rolls. e.g. Eurodollar
if contract.start_date > prev.contract.auto_close_date:
if prev is not None and contract.start_date > prev.contract.auto_close_date:
continue
if not chain_predicate(contract):
continue
self._start_date = min(contract.start_date.value, self._start_date)
self._end_date = max(contract.end_date.value, self._end_date)
curr = ContractNode(contract)
self.sid_to_contract[contract.sid] = curr
if self._head_contract is None:
self._head_contract = curr
prev = curr
continue
curr.prev = prev
prev.next = curr
prev = curr
self.sid_to_contract[contract.sid] = curr
cpdef long_t contract_before_auto_close(self, long_t dt_value):
"""
Get the contract with next upcoming auto close date.
"""
curr = self.head_contract
curr = self._head_contract
while curr.next is not None:
if curr.contract.auto_close_date.value > dt_value:
break
+6 -2
View File
@@ -20,6 +20,7 @@ from six import string_types
from sqlalchemy import create_engine
from zipline.assets import AssetDBWriter, AssetFinder
from zipline.assets.continuous_futures import CHAIN_PREDICATES
from zipline.data.loader import load_market_data
from zipline.utils.calendars import get_calendar
from zipline.utils.memoize import remember_last
@@ -80,7 +81,8 @@ class TradingEnvironment(object):
bm_symbol='^GSPC',
exchange_tz="US/Eastern",
trading_calendar=None,
asset_db_path=':memory:'
asset_db_path=':memory:',
future_chain_predicates=CHAIN_PREDICATES,
):
self.bm_symbol = bm_symbol
@@ -106,7 +108,9 @@ class TradingEnvironment(object):
if engine is not None:
AssetDBWriter(engine).init_db()
self.asset_finder = AssetFinder(engine)
self.asset_finder = AssetFinder(
engine,
future_chain_predicates=future_chain_predicates)
else:
self.asset_finder = None
+4
View File
@@ -451,6 +451,8 @@ class WithTradingEnvironment(WithAssetFinder,
The max date to forward to the constructed TradingEnvironment.
TRADING_ENV_TRADING_CALENDAR : pd.DatetimeIndex
The trading calendar to use for the class's TradingEnvironment.
TRADING_ENV_FUTURE_CHAIN_PREDICATES : dict
The roll predicates to apply when creating contract chains.
Methods
-------
@@ -468,6 +470,7 @@ class WithTradingEnvironment(WithAssetFinder,
--------
:class:`zipline.finance.trading.TradingEnvironment`
"""
TRADING_ENV_FUTURE_CHAIN_PREDICATES = None
@classmethod
def make_load_function(cls):
@@ -479,6 +482,7 @@ class WithTradingEnvironment(WithAssetFinder,
load=cls.make_load_function(),
asset_db_path=cls.asset_finder.engine,
trading_calendar=cls.trading_calendar,
future_chain_predicates=cls.TRADING_ENV_FUTURE_CHAIN_PREDICATES,
)
@classmethod