diff --git a/tests/test_algorithm.py b/tests/test_algorithm.py index cb00e3a3..3ac47f12 100644 --- a/tests/test_algorithm.py +++ b/tests/test_algorithm.py @@ -1076,6 +1076,7 @@ class TestBeforeTradingStart(WithDataPortal, DATA_PORTAL_FIRST_TRADING_DAY = pd.Timestamp("2016-01-05", tz='UTC') EQUITY_MINUTE_BAR_START_DATE = pd.Timestamp("2016-01-05", tz='UTC') + FUTURE_MINUTE_BAR_START_DATE = pd.Timestamp("2016-01-05", tz='UTC') data_start = ASSET_FINDER_EQUITY_START_DATE = pd.Timestamp( '2016-01-05', diff --git a/zipline/data/data_portal.py b/zipline/data/data_portal.py index d12632f8..4f7867c4 100644 --- a/zipline/data/data_portal.py +++ b/zipline/data/data_portal.py @@ -23,7 +23,15 @@ from six import iteritems from six.moves import reduce from zipline.assets import Asset, Future, Equity -from zipline.data.resample import DailyHistoryAggregator +from zipline.data.dispatch_bar_reader import ( + AssetDispatchMinuteBarReader, + AssetDispatchSessionBarReader +) +from zipline.data.resample import ( + DailyHistoryAggregator, + ReindexMinuteBarReader, + ReindexSessionBarReader, +) from zipline.data.history_loader import ( DailyHistoryLoader, MinuteHistoryLoader, @@ -124,36 +132,68 @@ class DataPortal(object): self._augmented_sources_map = {} self._extra_source_df = None - self._equity_daily_reader = equity_daily_reader - if self._equity_daily_reader is not None: - self._history_loader = DailyHistoryLoader( - self.trading_calendar, - self._equity_daily_reader, - self._adjustment_reader - ) - self._equity_minute_reader = equity_minute_reader - self._future_daily_reader = future_daily_reader - self._future_minute_reader = future_minute_reader + self._first_trading_session = first_trading_day + + _last_sessions = [r.last_available_dt + for r in [equity_daily_reader, future_daily_reader] + if r is not None] + if _last_sessions: + self._last_trading_session = min(_last_sessions) + else: + self._last_trading_session = None + + aligned_equity_minute_reader = self._ensure_reader_aligned( + equity_minute_reader) + aligned_equity_session_reader = self._ensure_reader_aligned( + equity_daily_reader) + aligned_future_minute_reader = self._ensure_reader_aligned( + future_minute_reader) + aligned_future_session_reader = self._ensure_reader_aligned( + future_daily_reader) + + aligned_minute_readers = {} + aligned_session_readers = {} + + if aligned_equity_minute_reader is not None: + aligned_minute_readers[Equity] = aligned_equity_minute_reader + if aligned_equity_session_reader is not None: + aligned_session_readers[Equity] = aligned_equity_session_reader + + if aligned_future_minute_reader is not None: + aligned_minute_readers[Future] = aligned_future_minute_reader + if aligned_future_session_reader is not None: + aligned_session_readers[Future] = aligned_future_session_reader + + _dispatch_minute_reader = AssetDispatchMinuteBarReader( + self.trading_calendar, + self.asset_finder, + aligned_minute_readers, + ) + + _dispatch_session_reader = AssetDispatchSessionBarReader( + self.trading_calendar, + self.asset_finder, + aligned_session_readers, + ) self._pricing_readers = { - Equity: { - 'minute': equity_minute_reader, - 'daily': equity_daily_reader, - }, - Future: { - 'minute': future_minute_reader, - 'daily': future_daily_reader - } + 'minute': _dispatch_minute_reader, + 'daily': _dispatch_session_reader, } self._daily_aggregator = DailyHistoryAggregator( self.trading_calendar.schedule.market_open, - self._equity_minute_reader, + _dispatch_minute_reader, self.trading_calendar ) + self._history_loader = DailyHistoryLoader( + self.trading_calendar, + _dispatch_session_reader, + self._adjustment_reader + ) self._minute_history_loader = MinuteHistoryLoader( self.trading_calendar, - self._equity_minute_reader, + _dispatch_minute_reader, self._adjustment_reader ) @@ -179,6 +219,27 @@ class DataPortal(object): if self._first_trading_minute is not None else None ) + def _ensure_reader_aligned(self, reader): + if reader is None: + return + + if reader.trading_calendar.name == self.trading_calendar.name: + return reader + elif reader.data_frequency == 'minute': + return ReindexMinuteBarReader( + self.trading_calendar, + reader, + self._first_trading_session, + self._last_trading_session + ) + elif reader.data_frequency == 'session': + return ReindexSessionBarReader( + self.trading_calendar, + reader, + self._first_trading_session, + self._last_trading_session + ) + def _reindex_extra_source(self, df, source_date_index): return df.reindex(index=source_date_index, method='ffill') @@ -263,8 +324,8 @@ class DataPortal(object): self._extra_source_df = extra_source_df - def _get_pricing_reader(self, asset, data_frequency): - return self._pricing_readers[type(asset)][data_frequency] + def _get_pricing_reader(self, data_frequency): + return self._pricing_readers[data_frequency] def get_last_traded_dt(self, asset, dt, data_frequency): """ @@ -273,8 +334,8 @@ class DataPortal(object): If there is a trade on the dt, the answer is dt provided. """ - return self._get_pricing_reader(asset, data_frequency).\ - get_last_traded_dt(asset, dt) + return self._get_pricing_reader(data_frequency).get_last_traded_dt( + asset, dt) @staticmethod def _is_extra_source(asset, field, map): @@ -471,7 +532,7 @@ class DataPortal(object): return spot_value def _get_minute_spot_value(self, asset, column, dt, ffill=False): - reader = self._get_pricing_reader(asset, 'minute') + reader = self._get_pricing_reader('minute') result = reader.get_value( asset.sid, dt, column ) @@ -510,7 +571,7 @@ class DataPortal(object): ) def _get_daily_data(self, asset, column, dt): - reader = self._pricing_readers[type(asset)]['daily'] + reader = self._get_pricing_reader('daily') if column == "last_traded": last_traded_dt = reader.get_last_traded_dt(asset, dt) diff --git a/zipline/data/dispatch_bar_reader.py b/zipline/data/dispatch_bar_reader.py index aedf808f..4a6ecf77 100644 --- a/zipline/data/dispatch_bar_reader.py +++ b/zipline/data/dispatch_bar_reader.py @@ -72,20 +72,20 @@ class AssetDispatchBarReader(with_metaclass(ABCMeta)): @lazyval def last_available_dt(self): - return min(r.last_available_dt for r in self._readers.values) + return min(r.last_available_dt for r in self._readers.values()) @lazyval def first_trading_day(self): - return max(r.first_trading_day for r in self._readers.values) + return max(r.first_trading_day for r in self._readers.values()) def get_value(self, sid, dt, field): - asset = self.asset_finder.retrieve_asset(sid) + asset = self._asset_finder.retrieve_asset(sid) r = self._readers[type(asset)] return r.get_value(sid, dt, field) def get_last_traded_dt(self, asset, dt): r = self._readers[type(asset)] - return r.get_value(asset, dt) + return r.get_last_traded_dt(asset, dt) def load_raw_arrays(self, fields, start_dt, end_dt, sids): asset_types = self._asset_types @@ -128,3 +128,9 @@ class AssetDispatchSessionBarReader(AssetDispatchBarReader): def _dt_window_size(self, start_dt, end_dt): return len(self.trading_calendar.sessions_in_range(start_dt, end_dt)) + + @lazyval + def sessions(self): + return self.trading_calendar.sessions_in_range( + self.first_trading_day, + self.last_available_dt) diff --git a/zipline/data/minute_bars.py b/zipline/data/minute_bars.py index 25d8f582..db0637a4 100644 --- a/zipline/data/minute_bars.py +++ b/zipline/data/minute_bars.py @@ -55,6 +55,12 @@ class BcolzMinuteWriterColumnMismatch(Exception): class MinuteBarReader(with_metaclass(ABCMeta)): + _data_frequency = 'minute' + + @property + def data_frequency(self): + return self._data_frequency + @abstractproperty def last_available_dt(self): """ diff --git a/zipline/data/resample.py b/zipline/data/resample.py index c9f4605a..73aacf94 100644 --- a/zipline/data/resample.py +++ b/zipline/data/resample.py @@ -523,6 +523,10 @@ class MinuteResampleSessionBarReader(SessionBarReader): self._minute_bar_reader.last_available_dt ) + @property + def first_trading_day(self): + return self._minute_bar_reader.first_trading_day + class ReindexBarReader(with_metaclass(ABCMeta)): """ diff --git a/zipline/data/session_bars.py b/zipline/data/session_bars.py index f4fef4ab..82d9c533 100644 --- a/zipline/data/session_bars.py +++ b/zipline/data/session_bars.py @@ -19,6 +19,12 @@ class SessionBarReader(with_metaclass(ABCMeta)): """ Reader for OHCLV pricing data at a session frequency. """ + _data_frequency = 'session' + + @property + def data_frequency(self): + return self._data_frequency + @abstractmethod def load_raw_arrays(self, columns, start_date, end_date, assets): """