diff --git a/tests/test_algorithm.py b/tests/test_algorithm.py
index 3d8c9a0c..e5c4e2ed 100644
--- a/tests/test_algorithm.py
+++ b/tests/test_algorithm.py
@@ -1440,7 +1440,7 @@ class TestAlgoScript(WithLogger,
                                                  STRING_TYPE_NAMES)
     ARG_TYPE_TEST_CASES = (
         ('history__assets', (bad_type_history_assets,
-                             ASSET_OR_STRING_TYPE_NAMES,
+                             ASSET_OR_STRING_OR_CF_TYPE_NAMES,
                              True)),
         ('history__fields', (bad_type_history_fields,
                              STRING_TYPE_NAMES_STRING,
@@ -1458,10 +1458,12 @@ class TestAlgoScript(WithLogger,
         ('is_stale__assets', (bad_type_is_stale_assets, 'Asset', True)),
         ('can_trade__assets', (bad_type_can_trade_assets, 'Asset', True)),
         ('history_kwarg__assets',
-         (bad_type_history_assets_kwarg, ASSET_OR_STRING_TYPE_NAMES, True)),
+         (bad_type_history_assets_kwarg,
+          ASSET_OR_STRING_OR_CF_TYPE_NAMES,
+          True)),
         ('history_kwarg_bad_list__assets',
          (bad_type_history_assets_kwarg_list,
-          ASSET_OR_STRING_TYPE_NAMES,
+          ASSET_OR_STRING_OR_CF_TYPE_NAMES,
           True)),
         ('history_kwarg__fields',
          (bad_type_history_fields_kwarg, STRING_TYPE_NAMES_STRING, True)),
diff --git a/tests/test_continuous_futures.py b/tests/test_continuous_futures.py
index 6d8ea811..aae3c41e 100644
--- a/tests/test_continuous_futures.py
+++ b/tests/test_continuous_futures.py
@@ -15,14 +15,23 @@
 
 from textwrap import dedent
 
-from numpy import array, int64
+from numpy import (
+    arange,
+    array,
+    int64,
+    full,
+    repeat,
+)
+from numpy.testing import assert_almost_equal
 import pandas as pd
 from pandas import Timestamp, DataFrame
 
 from zipline import TradingAlgorithm
 from zipline.assets.continuous_futures import OrderedContracts
+from zipline.data.minute_bars import FUTURES_MINUTES_PER_DAY
 from zipline.testing.fixtures import (
     WithCreateBarData,
+    WithBcolzFutureMinuteBarReader,
     WithSimParams,
     ZiplineTestCase,
 )
@@ -30,6 +39,7 @@ from zipline.testing.fixtures import (
 
 class ContinuousFuturesTestCase(WithCreateBarData,
                                 WithSimParams,
+                                WithBcolzFutureMinuteBarReader,
                                 ZiplineTestCase):
 
     START_DATE = pd.Timestamp('2015-01-05', tz='UTC')
@@ -66,17 +76,17 @@ class ContinuousFuturesTestCase(WithCreateBarData,
                          Timestamp('2022-08-19', tz='UTC')],
             'notice_date': [Timestamp('2016-01-26', tz='UTC'),
                             Timestamp('2016-02-26', tz='UTC'),
-                            Timestamp('2016-03-26', tz='UTC'),
+                            Timestamp('2016-03-24', tz='UTC'),
                             Timestamp('2016-04-26', tz='UTC'),
                             Timestamp('2022-01-26', tz='UTC')],
             'expiration_date': [Timestamp('2016-01-26', tz='UTC'),
                                 Timestamp('2016-02-26', tz='UTC'),
-                                Timestamp('2016-03-26', tz='UTC'),
+                                Timestamp('2016-03-24', tz='UTC'),
                                 Timestamp('2016-04-26', tz='UTC'),
                                 Timestamp('2022-01-26', tz='UTC')],
             'auto_close_date': [Timestamp('2016-01-26', tz='UTC'),
                                 Timestamp('2016-02-26', tz='UTC'),
-                                Timestamp('2016-03-26', tz='UTC'),
+                                Timestamp('2016-03-24', tz='UTC'),
                                 Timestamp('2016-04-26', tz='UTC'),
                                 Timestamp('2022-01-26', tz='UTC')],
             'tick_size': [0.001] * 5,
@@ -84,6 +94,36 @@ class ContinuousFuturesTestCase(WithCreateBarData,
             'exchange': ['CME'] * 5,
         })
 
+    @classmethod
+    def make_future_minute_bar_data(cls):
+        tc = cls.trading_calendar
+        start = pd.Timestamp('2016-01-26', tz='UTC')
+        end = pd.Timestamp('2016-04-29', tz='UTC')
+        dts = tc.minutes_for_sessions_in_range(start, end)
+        sessions = tc.sessions_in_range(start, end)
+        # Generate values in the .0XX space such that the first session
+        # has 0.001 added to all values, the second session has 0.002,
+        # etc.
+        markers = repeat(
+            arange(0.001, 0.001 * (len(sessions) + 1), 0.001),
+            FUTURES_MINUTES_PER_DAY)
+        vol_markers = repeat(
+            arange(1, (len(sessions) + 1), 1, dtype=int64),
+            FUTURES_MINUTES_PER_DAY)
+        base_df = pd.DataFrame(
+            {
+                'open': full(len(dts), 100.2) + markers,
+                'high': full(len(dts), 100.9) + markers,
+                'low': full(len(dts), 100.1) + markers,
+                'close': full(len(dts), 100.5) + markers,
+                'volume': full(len(dts), 1000, dtype=int64) + vol_markers,
+            },
+            index=dts)
+        # Add the sid to the ones place of the prices, so that the ones
+        # place can be used to eyeball the source contract.
+        for i in range(5):
+            yield i, base_df + i
+
     def test_create_continuous_future(self):
         cf_primary = self.asset_finder.create_continuous_future(
             'FO', 0, 'calendar')
@@ -287,6 +327,180 @@ def record_current_contract(algo, data):
                          'End of secondary chain should be FOJ16 on second '
                          'session.')
 
+    def test_history_sid_session(self):
+        cf = self.data_portal.asset_finder.create_continuous_future(
+            'FO', 0, 'calendar')
+        window = self.data_portal.get_history_window(
+            [cf],
+            Timestamp('2016-03-03 18:01', tz='US/Eastern').tz_convert('UTC'),
+            30, '1d', 'sid')
+
+        self.assertEqual(window.loc['2016-01-25', cf],
+                         0,
+                         "Should be FOF16 at beginning of window.")
+
+        self.assertEqual(window.loc['2016-01-26', cf],
+                         1,
+                         "Should be FOG16 after first roll.")
+
+        self.assertEqual(window.loc['2016-02-25', cf],
+                         1,
+                         "Should be FOG16 on session before roll.")
+
+        self.assertEqual(window.loc['2016-02-26', cf],
+                         2,
+                         "Should be FOH16 on session with roll.")
+
+        self.assertEqual(window.loc['2016-02-29', cf],
+                         2,
+                         "Should be FOH16 on session after roll.")
+
+        # Advance the window a month.
+        window = self.data_portal.get_history_window(
+            [cf],
+            Timestamp('2016-04-06 18:01', tz='US/Eastern').tz_convert('UTC'),
+            30, '1d', 'sid')
+
+        self.assertEqual(window.loc['2016-02-25', cf],
+                         1,
+                         "Should be FOG16 at beginning of window.")
+
+        self.assertEqual(window.loc['2016-02-26', cf],
+                         2,
+                         "Should be FOH16 on session with roll.")
+
+        self.assertEqual(window.loc['2016-02-29', cf],
+                         2,
+                         "Should be FOH16 on session after roll.")
+
+        self.assertEqual(window.loc['2016-03-24', cf],
+                         3,
+                         "Should be FOJ16 on session with roll.")
+
+        self.assertEqual(window.loc['2016-03-28', cf],
+                         3,
+                         "Should be FOJ16 on session after roll.")
+
+    def test_history_sid_minute(self):
+        cf = self.data_portal.asset_finder.create_continuous_future(
+            'FO', 0, 'calendar')
+        window = self.data_portal.get_history_window(
+            [cf.sid],
+            Timestamp('2016-01-25 18:01', tz='US/Eastern').tz_convert('UTC'),
+            30, '1m', 'sid')
+
+        self.assertEqual(window.loc['2016-01-25 22:32', cf],
+                         0,
+                         "Should be FOF16 at beginning of window. A minute "
+                         "which is in the 01-25 session, before the roll.")
+
+        self.assertEqual(window.loc['2016-01-25 23:00', cf],
+                         0,
+                         "Should be FOF16 on minute before roll minute.")
+
+        self.assertEqual(window.loc['2016-01-25 23:01', cf],
+                         1,
+                         "Should be FOG16 on minute after roll.")
+
+        # Advance the window a day.
+        window = self.data_portal.get_history_window(
+            [cf],
+            Timestamp('2016-01-26 18:01', tz='US/Eastern').tz_convert('UTC'),
+            30, '1m', 'sid')
+
+        self.assertEqual(window.loc['2016-01-26 22:32', cf],
+                         1,
+                         "Should be FOG16 at beginning of window.")
+
+        self.assertEqual(window.loc['2016-01-26 23:01', cf],
+                         1,
+                         "Should remain FOG16 on next session.")
+
+    def test_history_close_session(self):
+        cf = self.data_portal.asset_finder.create_continuous_future(
+            'FO', 0, 'calendar')
+        window = self.data_portal.get_history_window(
+            [cf.sid], Timestamp('2016-03-06', tz='UTC'), 30, '1d', 'close')
+
+        assert_almost_equal(
+            window.loc['2016-01-26', cf],
+            101.501,
+            err_msg="At beginning of window, should be FOG16's first value.")
+
+        assert_almost_equal(
+            window.loc['2016-02-26', cf],
+            102.524,
+            err_msg="On session with roll, should be FOH16's 24th value.")
+
+        assert_almost_equal(
+            window.loc['2016-02-29', cf],
+            102.525,
+            err_msg="After roll, Should be FOH16's 25th value.")
+
+        # Advance the window a month.
+        window = self.data_portal.get_history_window(
+            [cf.sid], Timestamp('2016-04-06', tz='UTC'), 30, '1d', 'close')
+
+        assert_almost_equal(
+            window.loc['2016-02-24', cf],
+            101.522,
+            err_msg="At beginning of window, should be FOG16's 22nd value.")
+
+        assert_almost_equal(
+            window.loc['2016-02-26', cf],
+            102.524,
+            err_msg="On session with roll, should be FOH16's 24th value.")
+
+        assert_almost_equal(
+            window.loc['2016-02-29', cf],
+            102.525,
+            err_msg="On session after roll, should be FOH16's 25th value.")
+
+        assert_almost_equal(
+            window.loc['2016-03-24', cf],
+            103.543,
+            err_msg="On session with roll, should be FOJ16's 43rd value.")
+
+        assert_almost_equal(
+            window.loc['2016-03-28', cf],
+            103.544,
+            err_msg="On session after roll, Should be FOJ16's 44th value.")
+
+    def test_history_close_minute(self):
+        cf = self.data_portal.asset_finder.create_continuous_future(
+            'FO', 0, 'calendar')
+        window = self.data_portal.get_history_window(
+            [cf.sid],
+            Timestamp('2016-02-25 18:01', tz='US/Eastern').tz_convert('UTC'),
+            30, '1m', 'close')
+
+        self.assertEqual(window.loc['2016-02-25 22:32', cf],
+                         101.523,
+                         "Should be FOG16 at beginning of window. A minute "
+                         "which is in the 02-25 session, before the roll.")
+
+        self.assertEqual(window.loc['2016-02-25 23:00', cf],
+                         101.523,
+                         "Should be FOG16 on minute before roll minute.")
+
+        self.assertEqual(window.loc['2016-02-25 23:01', cf],
+                         102.524,
+                         "Should be FOH16 on minute after roll.")
+
+        # Advance the window a session.
+        window = self.data_portal.get_history_window(
+            [cf],
+            Timestamp('2016-02-28 18:01', tz='US/Eastern').tz_convert('UTC'),
+            30, '1m', 'close')
+
+        self.assertEqual(window.loc['2016-02-26 22:32', cf],
+                         102.524,
+                         "Should be FOH16 at beginning of window.")
+
+        self.assertEqual(window.loc['2016-02-28 23:01', cf],
+                         102.525,
+                         "Should remain FOH16 on next session.")
+
 
 class OrderedContractsTestCase(ZiplineTestCase):
 
diff --git a/zipline/_protocol.pyx b/zipline/_protocol.pyx
index c90a5224..ff7033ff 100644
--- a/zipline/_protocol.pyx
+++ b/zipline/_protocol.pyx
@@ -588,7 +588,8 @@ cdef class BarData:
 
 
     @check_parameters(('assets', 'fields', 'bar_count', 'frequency'),
-                      ((Asset,) + string_types, string_types, int,
+                      ((Asset, ContinuousFuture) + string_types, string_types,
+                       int,
                        string_types))
     def history(self, assets, fields, bar_count, frequency):
         """
diff --git a/zipline/assets/continuous_futures.pyx b/zipline/assets/continuous_futures.pyx
index 2896403c..8c83f1c5 100644
--- a/zipline/assets/continuous_futures.pyx
+++ b/zipline/assets/continuous_futures.pyx
@@ -106,7 +106,7 @@ cdef class ContinuousFuture:
         Cython rich comparison method.  This is used in place of various
         equality checkers in pure python.
         """
-        cdef int x_as_int, y_as_int
+        cdef long_t x_as_int, y_as_int
 
         try:
             x_as_int = PyNumber_Index(x)
diff --git a/zipline/assets/roll_finder.py b/zipline/assets/roll_finder.py
index c26fb22a..47862900 100644
--- a/zipline/assets/roll_finder.py
+++ b/zipline/assets/roll_finder.py
@@ -15,6 +15,8 @@
 from abc import ABCMeta, abstractmethod
 from six import with_metaclass
 
+from pandas import Timestamp
+
 
 class RollFinder(with_metaclass(ABCMeta, object)):
     """
@@ -42,6 +44,33 @@ class RollFinder(with_metaclass(ABCMeta, object)):
         """
-        raise NotImplemented
+        raise NotImplementedError
 
+    @abstractmethod
+    def get_rolls(self, root_symbol, start, end, offset):
+        """
+        Get the rolls, i.e. the session at which to hop from contract to
+        contract in the chain.
+
+        Parameters
+        ----------
+        root_symbol : str
+            The root symbol for which to calculate rolls.
+        start : Timestamp
+            Start of the date range.
+        end : Timestamp
+            End of the date range.
+        offset : int
+            Offset from the primary.
+
+        Returns
+        -------
+        rolls : list[tuple(sid, roll_date)]
+            A list of rolls, where first value is the first active `sid`,
+            and the `roll_date` on which to hop to the next contract.
+            The last pair in the chain has a value of `None` since the roll
+            is after the range.
+        """
+        raise NotImplementedError
+
 
 class CalendarRollFinder(RollFinder):
     """
@@ -61,3 +90,23 @@ class CalendarRollFinder(RollFinder):
         # Here is where a volume check would be.
         primary = primary_candidate
         return oc.contract_at_offset(primary, offset)
+
+    def get_rolls(self, root_symbol, start, end, offset):
+        oc = self.asset_finder.get_ordered_contracts(root_symbol)
+        primary_at_end = self.get_contract_center(root_symbol, end, 0)
+        for i, sid in enumerate(oc.contract_sids):
+            if sid == primary_at_end:
+                break
+        i += offset
+        first = oc.contract_sids[i]
+        rolls = [(first, None)]
+        i -= 1
+        auto_close_date = Timestamp(oc.auto_close_dates[i - offset], tz='UTC')
+        while auto_close_date > start and i > -1:
+            rolls.insert(0, (oc.contract_sids[i - offset],
+                             auto_close_date))
+            i -= 1
+            auto_close_date = Timestamp(oc.auto_close_dates[i - offset],
+                                        tz='UTC')
+
+        return rolls
diff --git a/zipline/data/continuous_future_reader.py b/zipline/data/continuous_future_reader.py
new file mode 100644
index 00000000..0c960f28
--- /dev/null
+++ b/zipline/data/continuous_future_reader.py
@@ -0,0 +1,382 @@
+import numpy as np
+from zipline.data.session_bars import SessionBarReader
+
+
+class ContinuousFutureSessionBarReader(SessionBarReader):
+    """
+    Session bar reader for continuous futures.
+
+    Stitches together the bars of the underlying contracts, switching
+    contracts on the roll dates reported by each asset's roll finder.
+
+    Parameters
+    ----------
+    bar_reader : SessionBarReader
+        Reader for the session bars of the individual futures contracts.
+    roll_finders : dict
+        Mapping from roll style to the roll finder for that style.
+    """
+
+    def __init__(self, bar_reader, roll_finders):
+        self._bar_reader = bar_reader
+        self._roll_finders = roll_finders
+
+    def load_raw_arrays(self, columns, start_date, end_date, assets):
+        """
+        Parameters
+        ----------
+        columns : list of str
+            Bar fields to load, or 'sid' for the sid of the active contract.
+        start_date : Timestamp
+            Beginning of the window range.
+        end_date : Timestamp
+            End of the window range.
+        assets : list of ContinuousFuture
+            The continuous futures in the window.
+
+        Returns
+        -------
+        list of np.ndarray
+            A list with an entry per column of ndarrays with shape
+            (sessions in range, assets); float64 (NaN-filled) for prices and
+            int64 for 'volume' and 'sid'.
+        """
+        rolls_by_asset = {}
+        for asset in assets:
+            rf = self._roll_finders[asset.roll_style]
+            rolls_by_asset[asset] = rf.get_rolls(
+                asset.root_symbol, start_date, end_date, asset.offset)
+        num_sessions = len(
+            self.trading_calendar.sessions_in_range(start_date, end_date))
+        shape = num_sessions, len(assets)
+
+        results = []
+
+        tc = self._bar_reader.trading_calendar
+        sessions = tc.sessions_in_range(start_date, end_date)
+
+        # Get partitions
+        partitions_by_asset = {}
+        for asset in assets:
+            partitions = []
+            partitions_by_asset[asset] = partitions
+            rolls = rolls_by_asset[asset]
+            start = start_date
+            for roll in rolls:
+                sid, roll_date = roll
+                start_loc = sessions.get_loc(start)
+                if roll_date is not None:
+                    end = roll_date - sessions.freq
+                    end_loc = sessions.get_loc(end)
+                else:
+                    end = end_date
+                    end_loc = len(sessions) - 1
+                partitions.append((sid, start, end, start_loc, end_loc))
+                if roll[-1] is not None:
+                    start = sessions[end_loc + 1]
+
+        for column in columns:
+            if column != 'volume' and column != 'sid':
+                out = np.full(shape, np.nan)
+            else:
+                out = np.zeros(shape, dtype=np.int64)
+            for i, asset in enumerate(assets):
+                partitions = partitions_by_asset[asset]
+                for sid, start, end, start_loc, end_loc in partitions:
+                    if column != 'sid':
+                        result = self._bar_reader.load_raw_arrays(
+                            [column], start, end, [sid])[0][:, 0]
+                    else:
+                        result = int(sid)
+                    out[start_loc:end_loc + 1, i] = result
+            results.append(out)
+        return results
+
+    @property
+    def last_available_dt(self):
+        """
+        Returns
+        -------
+        dt : pd.Timestamp
+            The last session for which the reader can provide data.
+        """
+        return self._bar_reader.last_available_dt
+
+    @property
+    def trading_calendar(self):
+        """
+        Returns the zipline.utils.calendar.trading_calendar used to read
+        the data. Can be None (if the writer didn't specify it).
+        """
+        return self._bar_reader.trading_calendar
+
+    @property
+    def first_trading_day(self):
+        """
+        Returns
+        -------
+        dt : pd.Timestamp
+            The first trading day (session) for which the reader can provide
+            data.
+        """
+        return self._bar_reader.first_trading_day
+
+    def get_value(self, continuous_future, dt, field):
+        """
+        Retrieve the value at the given coordinates.
+
+        Parameters
+        ----------
+        continuous_future : zipline.assets.ContinuousFuture
+            The continuous future whose active contract is read.
+        dt : pd.Timestamp
+            The timestamp for the desired data point.
+        field : string
+            The OHLVC name for the desired data point.
+
+        Returns
+        -------
+        value : float|int
+            The value at the given coordinates, ``float`` for OHLC, ``int``
+            for 'volume'.
+
+        Raises
+        ------
+        NoDataOnDate
+            If the given dt is not a valid market minute (in minute mode) or
+            session (in daily mode) according to this reader's tradingcalendar.
+        """
+        rf = self._roll_finders[continuous_future.roll_style]
+        sid = (rf.get_contract_center(continuous_future.root_symbol,
+                                      dt,
+                                      continuous_future.offset))
+        return self._bar_reader.get_value(sid, dt, field)
+
+    def get_last_traded_dt(self, asset, dt):
+        """
+        Get the latest minute on or before ``dt`` in which ``asset`` traded.
+
+        If there are no trades on or before ``dt``, returns ``pd.NaT``.
+
+        Parameters
+        ----------
+        asset : zipline.asset.Asset
+            The asset for which to get the last traded minute.
+        dt : pd.Timestamp
+            The minute at which to start searching for the last traded minute.
+
+        Returns
+        -------
+        last_traded : pd.Timestamp
+            The dt of the last trade for the given asset, using the input
+            dt as a vantage point.
+        """
+        rf = self._roll_finders[asset.roll_style]
+        sid = (rf.get_contract_center(asset.root_symbol,
+                                      dt,
+                                      asset.offset))
+        contract = rf.asset_finder.retrieve_asset(sid)
+        return self._bar_reader.get_last_traded_dt(contract, dt)
+
+    @property
+    def sessions(self):
+        """
+        Returns
+        -------
+        sessions : DatetimeIndex
+           All session labels (unioning the range for all assets) which the
+           reader can provide.
+        """
+        return self._bar_reader.sessions
+
+
+class ContinuousFutureMinuteBarReader(SessionBarReader):
+    """
+    Minute bar reader for continuous futures.
+
+    Stitches together the bars of the underlying contracts, switching
+    contracts at the session boundaries reported by each asset's roll finder.
+
+    Parameters
+    ----------
+    bar_reader : minute bar reader
+        Reader for the minute bars of the individual futures contracts.
+    roll_finders : dict
+        Mapping from roll style to the roll finder for that style.
+    """
+
+    def __init__(self, bar_reader, roll_finders):
+        self._bar_reader = bar_reader
+        self._roll_finders = roll_finders
+
+    def load_raw_arrays(self, columns, start_date, end_date, assets):
+        """
+        Parameters
+        ----------
+        columns : list of str
+            Bar fields to load, or 'sid' for the sid of the active contract.
+        start_date : Timestamp
+            First minute of the window range.
+        end_date : Timestamp
+            Last minute of the window range.
+        assets : list of ContinuousFuture
+            The continuous futures in the window.
+
+        Returns
+        -------
+        list of np.ndarray
+            A list with an entry per column of ndarrays with shape
+            (minutes in range, assets); float64 (NaN-filled) for prices and
+            int64 for 'volume' and 'sid'.
+        """
+        rolls_by_asset = {}
+
+        tc = self.trading_calendar
+        start_session = tc.minute_to_session_label(start_date)
+        end_session = tc.minute_to_session_label(end_date)
+
+        # Rolls are defined per session, so look them up with the session
+        # labels of the window's first and last minutes.
+        for asset in assets:
+            rf = self._roll_finders[asset.roll_style]
+            rolls_by_asset[asset] = rf.get_rolls(
+                asset.root_symbol,
+                start_session,
+                end_session, asset.offset)
+
+        sessions = tc.sessions_in_range(start_session, end_session)
+
+        minutes = tc.minutes_in_range(start_date, end_date)
+        num_minutes = len(minutes)
+        shape = num_minutes, len(assets)
+
+        results = []
+
+        # Get partitions
+        partitions_by_asset = {}
+        for asset in assets:
+            partitions = []
+            partitions_by_asset[asset] = partitions
+            rolls = rolls_by_asset[asset]
+            start = start_date
+            for roll in rolls:
+                sid, roll_date = roll
+                start_loc = minutes.searchsorted(start)
+                if roll_date is not None:
+                    _, end = tc.open_and_close_for_session(
+                        roll_date - sessions.freq)
+                    end_loc = minutes.searchsorted(end)
+                else:
+                    end = end_date
+                    end_loc = len(minutes) - 1
+                partitions.append((sid, start, end, start_loc, end_loc))
+                if roll[-1] is not None:
+                    start, _ = tc.open_and_close_for_session(
+                        tc.minute_to_session_label(minutes[end_loc + 1]))
+
+        for column in columns:
+            if column != 'volume' and column != 'sid':
+                out = np.full(shape, np.nan)
+            else:
+                out = np.zeros(shape, dtype=np.int64)
+            for i, asset in enumerate(assets):
+                partitions = partitions_by_asset[asset]
+                for sid, start, end, start_loc, end_loc in partitions:
+                    if column != 'sid':
+                        result = self._bar_reader.load_raw_arrays(
+                            [column], start, end, [sid])[0][:, 0]
+                    else:
+                        result = int(sid)
+                    out[start_loc:end_loc + 1, i] = result
+            results.append(out)
+        return results
+
+    @property
+    def last_available_dt(self):
+        """
+        Returns
+        -------
+        dt : pd.Timestamp
+            The last session for which the reader can provide data.
+        """
+        return self._bar_reader.last_available_dt
+
+    @property
+    def trading_calendar(self):
+        """
+        Returns the zipline.utils.calendar.trading_calendar used to read
+        the data. Can be None (if the writer didn't specify it).
+        """
+        return self._bar_reader.trading_calendar
+
+    @property
+    def first_trading_day(self):
+        """
+        Returns
+        -------
+        dt : pd.Timestamp
+            The first trading day (session) for which the reader can provide
+            data.
+        """
+        return self._bar_reader.first_trading_day
+
+    def get_value(self, continuous_future, dt, field):
+        """
+        Retrieve the value at the given coordinates.
+
+        Parameters
+        ----------
+        continuous_future : zipline.assets.ContinuousFuture
+            The continuous future whose active contract is read.
+        dt : pd.Timestamp
+            The timestamp for the desired data point.
+        field : string
+            The OHLVC name for the desired data point.
+
+        Returns
+        -------
+        value : float|int
+            The value at the given coordinates, ``float`` for OHLC, ``int``
+            for 'volume'.
+
+        Raises
+        ------
+        NoDataOnDate
+            If the given dt is not a valid market minute (in minute mode) or
+            session (in daily mode) according to this reader's tradingcalendar.
+        """
+        rf = self._roll_finders[continuous_future.roll_style]
+        sid = (rf.get_contract_center(continuous_future.root_symbol,
+                                      dt,
+                                      continuous_future.offset))
+        return self._bar_reader.get_value(sid, dt, field)
+
+    def get_last_traded_dt(self, asset, dt):
+        """
+        Get the latest minute on or before ``dt`` in which ``asset`` traded.
+
+        If there are no trades on or before ``dt``, returns ``pd.NaT``.
+
+        Parameters
+        ----------
+        asset : zipline.asset.Asset
+            The asset for which to get the last traded minute.
+        dt : pd.Timestamp
+            The minute at which to start searching for the last traded minute.
+
+        Returns
+        -------
+        last_traded : pd.Timestamp
+            The dt of the last trade for the given asset, using the input
+            dt as a vantage point.
+        """
+        rf = self._roll_finders[asset.roll_style]
+        sid = (rf.get_contract_center(asset.root_symbol,
+                                      dt,
+                                      asset.offset))
+        contract = rf.asset_finder.retrieve_asset(sid)
+        return self._bar_reader.get_last_traded_dt(contract, dt)
+
+    @property
+    def sessions(self):
+        return self._bar_reader.sessions
diff --git a/zipline/data/data_portal.py b/zipline/data/data_portal.py
index 58fdc91b..f4f92931 100644
--- a/zipline/data/data_portal.py
+++ b/zipline/data/data_portal.py
@@ -17,6 +17,7 @@ from operator import mul
 
 from logbook import Logger
 import numpy as np
+from numpy import float64, int64
 import pandas as pd
 from pandas.tslib import normalize_date
 from six import iteritems
@@ -24,6 +25,10 @@ from six.moves import reduce
 
 from zipline.assets import Asset, Future, Equity
 from zipline.assets.continuous_futures import ContinuousFuture
+from zipline.data.continuous_future_reader import (
+    ContinuousFutureSessionBarReader,
+    ContinuousFutureMinuteBarReader
+)
 from zipline.assets.roll_finder import CalendarRollFinder
 from zipline.data.dispatch_bar_reader import (
     AssetDispatchMinuteBarReader,
@@ -63,6 +68,7 @@ BASE_FIELDS = frozenset([
     "volume",
     "price",
     "contract",
+    "sid",
     "last_traded",
 ])
 
@@ -182,8 +188,19 @@ class DataPortal(object):
 
         if aligned_future_minute_reader is not None:
             aligned_minute_readers[Future] = aligned_future_minute_reader
+            aligned_minute_readers[ContinuousFuture] = \
+                ContinuousFutureMinuteBarReader(
+                    aligned_future_minute_reader,
+                    self._roll_finders,
+                )
+
         if aligned_future_session_reader is not None:
             aligned_session_readers[Future] = aligned_future_session_reader
+            aligned_session_readers[ContinuousFuture] = \
+                ContinuousFutureSessionBarReader(
+                    aligned_future_session_reader,
+                    self._roll_finders,
+                )
 
         _dispatch_minute_reader = AssetDispatchMinuteBarReader(
             self.trading_calendar,
@@ -718,6 +735,10 @@ class DataPortal(object):
                 elif field_to_use == 'volume':
                     minute_value = self._daily_aggregator.volumes(
                         assets, end_dt)
+                elif field_to_use == 'sid':
+                    minute_value = [
+                        int(self._get_current_contract(asset, end_dt))
+                        for asset in assets]
 
                 # append the partial day.
                 daily_data[-1] = minute_value
@@ -801,7 +822,7 @@ class DataPortal(object):
         -------
         A dataframe containing the requested data.
         """
-        if field not in OHLCVP_FIELDS:
+        if field not in OHLCVP_FIELDS and field != 'sid':
             raise ValueError("Invalid field: {0}".format(field))
 
         if frequency == "1d":
@@ -929,10 +950,11 @@ class DataPortal(object):
         """
         bar_count = len(days_in_window)
         # create an np.array of size bar_count
+        dtype = float64 if field != 'sid' else int64
         if extra_slot:
-            return_array = np.zeros((bar_count + 1, len(assets)))
+            return_array = np.zeros((bar_count + 1, len(assets)), dtype=dtype)
         else:
-            return_array = np.zeros((bar_count, len(assets)))
+            return_array = np.zeros((bar_count, len(assets)), dtype=dtype)
 
         if field != "volume":
             # volumes default to 0, so we don't need to put NaNs in the array
diff --git a/zipline/data/dispatch_bar_reader.py b/zipline/data/dispatch_bar_reader.py
index dab38439..6a39a2ec 100644
--- a/zipline/data/dispatch_bar_reader.py
+++ b/zipline/data/dispatch_bar_reader.py
@@ -17,7 +17,7 @@ from abc import ABCMeta, abstractmethod
 from numpy import (
     full,
     nan,
-    uint32,
+    int64,
     zeros
 )
 from six import iteritems, with_metaclass
@@ -70,10 +70,10 @@ class AssetDispatchBarReader(with_metaclass(ABCMeta)):
         return self._dt_window_size(start_dt, end_dt), num_sids
 
     def _make_raw_array_out(self, field, shape):
-        if field != 'volume':
+        if field != 'volume' and field != 'sid':
             out = full(shape, nan)
         else:
-            out = zeros(shape, dtype=uint32)
+            out = zeros(shape, dtype=int64)
         return out
 
     @property
@@ -94,7 +94,7 @@ class AssetDispatchBarReader(with_metaclass(ABCMeta)):
     def get_value(self, sid, dt, field):
         asset = self._asset_finder.retrieve_asset(sid)
         r = self._readers[type(asset)]
-        return r.get_value(sid, dt, field)
+        return r.get_value(asset, dt, field)
 
     def get_last_traded_dt(self, asset, dt):
         r = self._readers[type(asset)]
diff --git a/zipline/data/history_loader.py b/zipline/data/history_loader.py
index ab31df80..b53f2ced 100644
--- a/zipline/data/history_loader.py
+++ b/zipline/data/history_loader.py
@@ -24,6 +24,7 @@ from pandas.tslib import normalize_date
 
 from six import with_metaclass
 
+from zipline.lib._int64window import AdjustedArrayWindow as Int64Window
 from zipline.lib._float64window import AdjustedArrayWindow as Float64Window
 from zipline.lib.adjustment import Float64Multiply
 from zipline.utils.cache import ExpiringCache
@@ -82,7 +83,7 @@ class HistoryLoader(with_metaclass(ABCMeta)):
     adjustment_reader : SQLiteAdjustmentReader
         Reader for adjustment data.
     """
-    FIELDS = ('open', 'high', 'low', 'close', 'volume')
+    FIELDS = ('open', 'high', 'low', 'close', 'volume', 'sid')
 
     def __init__(self, trading_calendar, reader, adjustment_reader,
                  sid_cache_size=1000):
@@ -270,6 +271,12 @@ class HistoryLoader(with_metaclass(ABCMeta)):
             prefetch_dts = cal[start_ix:prefetch_end_ix + 1]
             prefetch_len = len(prefetch_dts)
             array = self._array(prefetch_dts, needed_assets, field)
+
+            if field == 'sid':
+                window_type = Int64Window
+            else:
+                window_type = Float64Window
+
             view_kwargs = {}
             if field == 'volume':
                 array = array.astype(float64_dtype)
@@ -280,7 +287,7 @@ class HistoryLoader(with_metaclass(ABCMeta)):
                         asset, prefetch_dts, field, is_perspective_after)
                 else:
                     adjs = {}
-                window = Float64Window(
+                window = window_type(
                     array[:, i].reshape(prefetch_len, 1),
                     view_kwargs,
                     adjs,
diff --git a/zipline/testing/fixtures.py b/zipline/testing/fixtures.py
index 98709ab8..ed5ad62f 100644
--- a/zipline/testing/fixtures.py
+++ b/zipline/testing/fixtures.py
@@ -14,7 +14,10 @@ from .core import (
     tmp_dir,
 )
 from ..data.data_portal import DataPortal
-from ..data.resample import minute_to_session
+from ..data.resample import (
+    minute_to_session,
+    MinuteResampleSessionBarReader
+)
 from ..data.us_equity_pricing import (
     SQLiteAdjustmentReader,
     SQLiteAdjustmentWriter,
@@ -1303,6 +1306,12 @@ class WithDataPortal(WithAdjustmentReader,
                 if self.DATA_PORTAL_USE_MINUTE_DATA else
                 None
             ),
+            future_daily_reader=(
+                MinuteResampleSessionBarReader(
+                    self.bcolz_future_minute_bar_reader.trading_calendar,
+                    self.bcolz_future_minute_bar_reader)
+                if self.DATA_PORTAL_USE_MINUTE_DATA else None
+            ),
             last_available_session=self.DATA_PORTAL_LAST_AVAILABLE_SESSION,
             last_available_minute=self.DATA_PORTAL_LAST_AVAILABLE_MINUTE,
         )