mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-29 07:56:38 +08:00
MAINT: Switch PanelBarReader to take trading calendar and freq args
This commit is contained in:
@@ -22,13 +22,12 @@ from zipline.data.us_equity_pricing import PanelBarReader
|
||||
from zipline.testing import ExplodingObject
|
||||
from zipline.testing.fixtures import (
|
||||
WithAssetFinder,
|
||||
WithNYSETradingDays,
|
||||
ZiplineTestCase,
|
||||
)
|
||||
from zipline.utils.calendars import get_calendar
|
||||
|
||||
|
||||
class TestPanelDailyBarReader(WithAssetFinder,
|
||||
WithNYSETradingDays,
|
||||
ZiplineTestCase):
|
||||
|
||||
START_DATE = pd.Timestamp('2006-01-03', tz='utc')
|
||||
@@ -39,10 +38,13 @@ class TestPanelDailyBarReader(WithAssetFinder,
|
||||
super(TestPanelDailyBarReader, cls).init_class_fixtures()
|
||||
|
||||
finder = cls.asset_finder
|
||||
days = cls.trading_days
|
||||
trading_calendar = get_calendar('NYSE')
|
||||
|
||||
items = finder.retrieve_all(finder.sids)
|
||||
major_axis = days
|
||||
major_axis = trading_calendar.sessions_in_range(
|
||||
cls.START_DATE,
|
||||
cls.END_DATE
|
||||
)
|
||||
minor_axis = ['open', 'high', 'low', 'close', 'volume']
|
||||
|
||||
shape = tuple(map(len, [items, major_axis, minor_axis]))
|
||||
@@ -55,7 +57,7 @@ class TestPanelDailyBarReader(WithAssetFinder,
|
||||
minor_axis=minor_axis,
|
||||
)
|
||||
|
||||
cls.reader = PanelBarReader(days, cls.panel)
|
||||
cls.reader = PanelBarReader(trading_calendar, cls.panel, 'daily')
|
||||
|
||||
def test_spot_price(self):
|
||||
panel = self.panel
|
||||
@@ -83,7 +85,7 @@ class TestPanelDailyBarReader(WithAssetFinder,
|
||||
for axis_order in permutations((0, 1, 2)):
|
||||
transposed = panel.transpose(*axis_order)
|
||||
with self.assertRaises(ValueError) as e:
|
||||
PanelBarReader(unused, transposed)
|
||||
PanelBarReader(unused, transposed, 'daily')
|
||||
|
||||
expected = (
|
||||
"Duplicate entries in Panel.{name}: ['a', 'b'].".format(
|
||||
|
||||
+11
-5
@@ -619,8 +619,12 @@ class TradingAlgorithm(object):
|
||||
# to be inferred.
|
||||
if overwrite_sim_params:
|
||||
self.sim_params = self.sim_params.create_new(
|
||||
normalize_date(data.major_axis[0]),
|
||||
normalize_date(data.major_axis[-1])
|
||||
self.trading_calendar.minute_to_session_label(
|
||||
data.major_axis[0]
|
||||
),
|
||||
self.trading_calendar.minute_to_session_label(
|
||||
data.major_axis[-1]
|
||||
),
|
||||
)
|
||||
|
||||
# Assume data is daily if timestamp times are
|
||||
@@ -649,11 +653,13 @@ class TradingAlgorithm(object):
|
||||
|
||||
if self.sim_params.data_frequency == 'daily':
|
||||
equity_reader_arg = 'equity_daily_reader'
|
||||
calendar = self.trading_calendar.all_sessions
|
||||
elif self.sim_params.data_frequency == 'minute':
|
||||
equity_reader_arg = 'equity_minute_reader'
|
||||
calendar = self.trading_calendar.all_minutes
|
||||
equity_reader = PanelBarReader(calendar, copy_panel)
|
||||
equity_reader = PanelBarReader(
|
||||
self.trading_calendar,
|
||||
copy_panel,
|
||||
self.sim_params.data_frequency,
|
||||
)
|
||||
|
||||
self.data_portal = DataPortal(
|
||||
self.asset_finder,
|
||||
|
||||
@@ -43,7 +43,6 @@ from pandas import (
|
||||
NaT,
|
||||
DatetimeIndex
|
||||
)
|
||||
from pandas.core.datetools import normalize_date
|
||||
from pandas.tslib import iNaT
|
||||
from six import (
|
||||
iteritems,
|
||||
@@ -770,15 +769,34 @@ class PanelBarReader(DailyBarReader):
|
||||
The first trading day in the dataset.
|
||||
"""
|
||||
@preprocess(panel=call(verify_indices_all_unique))
|
||||
def __init__(self, calendar, panel):
|
||||
@expect_element(data_frequency={'daily', 'minute'})
|
||||
def __init__(self, trading_calendar, panel, data_frequency):
|
||||
|
||||
panel = panel.copy()
|
||||
if 'volume' not in panel.minor_axis:
|
||||
# Fake volume if it does not exist.
|
||||
panel.loc[:, :, 'volume'] = int(1e9)
|
||||
|
||||
self.first_trading_day = normalize_date(panel.major_axis[0])
|
||||
self._calendar = calendar
|
||||
self.trading_calendar = trading_calendar
|
||||
self.first_trading_day = trading_calendar.minute_to_session_label(
|
||||
panel.major_axis[0]
|
||||
)
|
||||
last_trading_day = trading_calendar.minute_to_session_label(
|
||||
panel.major_axis[-1]
|
||||
)
|
||||
|
||||
self._sessions = trading_calendar.sessions_in_range(
|
||||
self.first_trading_day,
|
||||
last_trading_day
|
||||
)
|
||||
|
||||
if data_frequency == 'daily':
|
||||
self._calendar = self._sessions
|
||||
elif data_frequency == 'minute':
|
||||
self._calendar = trading_calendar.minutes_for_sessions_in_range(
|
||||
self.first_trading_day,
|
||||
last_trading_day
|
||||
)
|
||||
|
||||
self.panel = panel
|
||||
|
||||
@@ -788,18 +806,9 @@ class PanelBarReader(DailyBarReader):
|
||||
|
||||
@property
|
||||
def last_available_dt(self):
|
||||
# Returns the last Panel index that is on the calendar.
|
||||
# The slice end is converted from dt to date string so that
|
||||
# dts on the last day of the calendar get included.
|
||||
return self.panel.major_axis[
|
||||
self.panel.major_axis.slice_indexer(
|
||||
end=self._calendar[-1].strftime('%Y-%m-%d')
|
||||
)
|
||||
][-1]
|
||||
return self._calendar[-1]
|
||||
|
||||
@property
|
||||
def trading_calendar(self):
|
||||
return None
|
||||
trading_calendar = None
|
||||
|
||||
def load_raw_arrays(self, columns, start_dt, end_dt, assets):
|
||||
cal = self._calendar
|
||||
|
||||
Reference in New Issue
Block a user