From 40c7deb697bbcdbe65b5cef33afebd640be3a590 Mon Sep 17 00:00:00 2001 From: Eddie Hebert Date: Fri, 26 Aug 2016 09:12:33 -0400 Subject: [PATCH] ENH: Add asset dispatch to data portal. Combine the equity and future readers into asset dispatch readers, so that simulations that use both asset types can access data for each. This patch enables `history` for future assets in algorithms; however, it does not add extra coverage in the `test_data_portal` or `test_history` to cover future assets. Those tests will follow, however putting this in separately since it shows that the wrapping of the readers in the asset dispatch reader does not break existing equity strategies. --- tests/test_algorithm.py | 1 + zipline/data/data_portal.py | 115 +++++++++++++++++++++------- zipline/data/dispatch_bar_reader.py | 14 +++- zipline/data/minute_bars.py | 6 ++ zipline/data/resample.py | 4 + zipline/data/session_bars.py | 6 ++ 6 files changed, 115 insertions(+), 31 deletions(-) diff --git a/tests/test_algorithm.py b/tests/test_algorithm.py index cb00e3a3..3ac47f12 100644 --- a/tests/test_algorithm.py +++ b/tests/test_algorithm.py @@ -1076,6 +1076,7 @@ class TestBeforeTradingStart(WithDataPortal, DATA_PORTAL_FIRST_TRADING_DAY = pd.Timestamp("2016-01-05", tz='UTC') EQUITY_MINUTE_BAR_START_DATE = pd.Timestamp("2016-01-05", tz='UTC') + FUTURE_MINUTE_BAR_START_DATE = pd.Timestamp("2016-01-05", tz='UTC') data_start = ASSET_FINDER_EQUITY_START_DATE = pd.Timestamp( '2016-01-05', diff --git a/zipline/data/data_portal.py b/zipline/data/data_portal.py index d12632f8..4f7867c4 100644 --- a/zipline/data/data_portal.py +++ b/zipline/data/data_portal.py @@ -23,7 +23,15 @@ from six import iteritems from six.moves import reduce from zipline.assets import Asset, Future, Equity -from zipline.data.resample import DailyHistoryAggregator +from zipline.data.dispatch_bar_reader import ( + AssetDispatchMinuteBarReader, + AssetDispatchSessionBarReader +) +from zipline.data.resample import ( + DailyHistoryAggregator, + ReindexMinuteBarReader, + ReindexSessionBarReader, +) from zipline.data.history_loader import ( DailyHistoryLoader, MinuteHistoryLoader, @@ -124,36 +132,68 @@ class DataPortal(object): self._augmented_sources_map = {} self._extra_source_df = None - self._equity_daily_reader = equity_daily_reader - if self._equity_daily_reader is not None: - self._history_loader = DailyHistoryLoader( - self.trading_calendar, - self._equity_daily_reader, - self._adjustment_reader - ) - self._equity_minute_reader = equity_minute_reader - self._future_daily_reader = future_daily_reader - self._future_minute_reader = future_minute_reader + self._first_trading_session = first_trading_day + + _last_sessions = [r.last_available_dt + for r in [equity_daily_reader, future_daily_reader] + if r is not None] + if _last_sessions: + self._last_trading_session = min(_last_sessions) + else: + self._last_trading_session = None + + aligned_equity_minute_reader = self._ensure_reader_aligned( + equity_minute_reader) + aligned_equity_session_reader = self._ensure_reader_aligned( + equity_daily_reader) + aligned_future_minute_reader = self._ensure_reader_aligned( + future_minute_reader) + aligned_future_session_reader = self._ensure_reader_aligned( + future_daily_reader) + + aligned_minute_readers = {} + aligned_session_readers = {} + + if aligned_equity_minute_reader is not None: + aligned_minute_readers[Equity] = aligned_equity_minute_reader + if aligned_equity_session_reader is not None: + aligned_session_readers[Equity] = aligned_equity_session_reader + + if aligned_future_minute_reader is not None: + aligned_minute_readers[Future] = aligned_future_minute_reader + if aligned_future_session_reader is not None: + aligned_session_readers[Future] = aligned_future_session_reader + + _dispatch_minute_reader = AssetDispatchMinuteBarReader( + self.trading_calendar, + self.asset_finder, + aligned_minute_readers, + ) + + _dispatch_session_reader = AssetDispatchSessionBarReader( + self.trading_calendar, + self.asset_finder, + aligned_session_readers, + ) self._pricing_readers = { - Equity: { - 'minute': equity_minute_reader, - 'daily': equity_daily_reader, - }, - Future: { - 'minute': future_minute_reader, - 'daily': future_daily_reader - } + 'minute': _dispatch_minute_reader, + 'daily': _dispatch_session_reader, } self._daily_aggregator = DailyHistoryAggregator( self.trading_calendar.schedule.market_open, - self._equity_minute_reader, + _dispatch_minute_reader, self.trading_calendar ) + self._history_loader = DailyHistoryLoader( + self.trading_calendar, + _dispatch_session_reader, + self._adjustment_reader + ) self._minute_history_loader = MinuteHistoryLoader( self.trading_calendar, - self._equity_minute_reader, + _dispatch_minute_reader, self._adjustment_reader ) @@ -179,6 +219,27 @@ class DataPortal(object): if self._first_trading_minute is not None else None ) + def _ensure_reader_aligned(self, reader): + if reader is None: + return + + if reader.trading_calendar.name == self.trading_calendar.name: + return reader + elif reader.data_frequency == 'minute': + return ReindexMinuteBarReader( + self.trading_calendar, + reader, + self._first_trading_session, + self._last_trading_session + ) + elif reader.data_frequency == 'session': + return ReindexSessionBarReader( + self.trading_calendar, + reader, + self._first_trading_session, + self._last_trading_session + ) + def _reindex_extra_source(self, df, source_date_index): return df.reindex(index=source_date_index, method='ffill') @@ -263,8 +324,8 @@ class DataPortal(object): self._extra_source_df = extra_source_df - def _get_pricing_reader(self, asset, data_frequency): - return self._pricing_readers[type(asset)][data_frequency] + def _get_pricing_reader(self, data_frequency): + return self._pricing_readers[data_frequency] def get_last_traded_dt(self, asset, dt, data_frequency): """ @@ -273,8 +334,8 @@ class DataPortal(object): If there is a trade on the dt, the answer is dt provided. """ - return self._get_pricing_reader(asset, data_frequency).\ - get_last_traded_dt(asset, dt) + return self._get_pricing_reader(data_frequency).get_last_traded_dt( + asset, dt) @staticmethod def _is_extra_source(asset, field, map): @@ -471,7 +532,7 @@ class DataPortal(object): return spot_value def _get_minute_spot_value(self, asset, column, dt, ffill=False): - reader = self._get_pricing_reader(asset, 'minute') + reader = self._get_pricing_reader('minute') result = reader.get_value( asset.sid, dt, column ) @@ -510,7 +571,7 @@ class DataPortal(object): ) def _get_daily_data(self, asset, column, dt): - reader = self._pricing_readers[type(asset)]['daily'] + reader = self._get_pricing_reader('daily') if column == "last_traded": last_traded_dt = reader.get_last_traded_dt(asset, dt) diff --git a/zipline/data/dispatch_bar_reader.py b/zipline/data/dispatch_bar_reader.py index aedf808f..4a6ecf77 100644 --- a/zipline/data/dispatch_bar_reader.py +++ b/zipline/data/dispatch_bar_reader.py @@ -72,20 +72,20 @@ class AssetDispatchBarReader(with_metaclass(ABCMeta)): @lazyval def last_available_dt(self): - return min(r.last_available_dt for r in self._readers.values) + return min(r.last_available_dt for r in self._readers.values()) @lazyval def first_trading_day(self): - return max(r.first_trading_day for r in self._readers.values) + return max(r.first_trading_day for r in self._readers.values()) def get_value(self, sid, dt, field): - asset = self.asset_finder.retrieve_asset(sid) + asset = self._asset_finder.retrieve_asset(sid) r = self._readers[type(asset)] return r.get_value(sid, dt, field) def get_last_traded_dt(self, asset, dt): r = self._readers[type(asset)] - return r.get_value(asset, dt) + return r.get_last_traded_dt(asset, dt) def load_raw_arrays(self, fields, start_dt, end_dt, sids): asset_types = self._asset_types @@ -128,3 +128,9 @@ class AssetDispatchSessionBarReader(AssetDispatchBarReader): def _dt_window_size(self, start_dt, end_dt): return len(self.trading_calendar.sessions_in_range(start_dt, end_dt)) + + @lazyval + def sessions(self): + return self.trading_calendar.sessions_in_range( + self.first_trading_day, + self.last_available_dt) diff --git a/zipline/data/minute_bars.py b/zipline/data/minute_bars.py index 25d8f582..db0637a4 100644 --- a/zipline/data/minute_bars.py +++ b/zipline/data/minute_bars.py @@ -55,6 +55,12 @@ class BcolzMinuteWriterColumnMismatch(Exception): class MinuteBarReader(with_metaclass(ABCMeta)): + _data_frequency = 'minute' + + @property + def data_frequency(self): + return self._data_frequency + @abstractproperty def last_available_dt(self): """ diff --git a/zipline/data/resample.py b/zipline/data/resample.py index c9f4605a..73aacf94 100644 --- a/zipline/data/resample.py +++ b/zipline/data/resample.py @@ -523,6 +523,10 @@ class MinuteResampleSessionBarReader(SessionBarReader): self._minute_bar_reader.last_available_dt ) + @property + def first_trading_day(self): + return self._minute_bar_reader.first_trading_day + class ReindexBarReader(with_metaclass(ABCMeta)): """ diff --git a/zipline/data/session_bars.py b/zipline/data/session_bars.py index f4fef4ab..82d9c533 100644 --- a/zipline/data/session_bars.py +++ b/zipline/data/session_bars.py @@ -19,6 +19,12 @@ class SessionBarReader(with_metaclass(ABCMeta)): """ Reader for OHCLV pricing data at a session frequency. """ + _data_frequency = 'session' + + @property + def data_frequency(self): + return self._data_frequency + @abstractmethod def load_raw_arrays(self, columns, start_date, end_date, assets): """