mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-04 17:21:10 +08:00
ENH: Add PanelMinuteBarReader, use it in TradingAlgorithm.run.
TradingAlgorithm.run didn't support Panel minute bar data, and assumed all Panel data was daily. To rectify this, adding PanelMinuteBarReader class. TradingAlgorithm.run decides whether to use it or PanelDailyBarReader by assuming data is daily if and only if the time of day of every Timestamp is identical.
This commit is contained in:
+33
-12
@@ -38,6 +38,7 @@ from zipline._protocol import handle_non_market_minutes
|
||||
from zipline.assets.synthetic import make_simple_equity_info
|
||||
from zipline.data.data_portal import DataPortal
|
||||
from zipline.data.us_equity_pricing import PanelDailyBarReader
|
||||
from zipline.data.minute_bars import PanelMinuteBarReader
|
||||
from zipline.errors import (
|
||||
AttachPipelineAfterInitialize,
|
||||
HistoryInInitialize,
|
||||
@@ -615,8 +616,8 @@ class TradingAlgorithm(object):
|
||||
# to be inferred.
|
||||
if overwrite_sim_params:
|
||||
self.sim_params = self.sim_params.create_new(
|
||||
data.major_axis[0],
|
||||
data.major_axis[-1]
|
||||
normalize_date(data.major_axis[0]),
|
||||
normalize_date(data.major_axis[-1])
|
||||
)
|
||||
|
||||
copy_panel = data.rename(
|
||||
@@ -634,16 +635,36 @@ class TradingAlgorithm(object):
|
||||
copy_panel.items
|
||||
)
|
||||
)
|
||||
equity_daily_reader = PanelDailyBarReader(
|
||||
self.trading_calendar.all_sessions,
|
||||
copy_panel,
|
||||
)
|
||||
self.data_portal = DataPortal(
|
||||
self.asset_finder,
|
||||
self.trading_calendar,
|
||||
first_trading_day=equity_daily_reader.first_trading_day,
|
||||
equity_daily_reader=equity_daily_reader,
|
||||
)
|
||||
|
||||
# Assume data is daily if timestamp times are
|
||||
# standardized, otherwise assume minute bars.
|
||||
times = copy_panel.major_axis.time
|
||||
if np.all(times == times[0]):
|
||||
equity_daily_reader = PanelDailyBarReader(
|
||||
self.trading_calendar.all_sessions,
|
||||
copy_panel,
|
||||
)
|
||||
self.data_portal = DataPortal(
|
||||
self.asset_finder,
|
||||
self.trading_calendar,
|
||||
first_trading_day=equity_daily_reader
|
||||
.first_trading_day,
|
||||
equity_daily_reader=equity_daily_reader,
|
||||
)
|
||||
else:
|
||||
if overwrite_sim_params:
|
||||
self.sim_params.data_frequency = 'minute'
|
||||
equity_minute_reader = PanelMinuteBarReader(
|
||||
self.trading_calendar.all_minutes,
|
||||
copy_panel,
|
||||
)
|
||||
self.data_portal = DataPortal(
|
||||
self.asset_finder,
|
||||
self.trading_calendar,
|
||||
first_trading_day=equity_minute_reader
|
||||
.first_trading_day,
|
||||
equity_minute_reader=equity_minute_reader,
|
||||
)
|
||||
|
||||
# Force a reset of the performance tracker, in case
|
||||
# this is a repeat run of the algorithm.
|
||||
|
||||
@@ -21,7 +21,9 @@ import bcolz
|
||||
from bcolz import ctable
|
||||
from intervaltree import IntervalTree
|
||||
import numpy as np
|
||||
from numpy import zeros
|
||||
import pandas as pd
|
||||
from pandas import NaT
|
||||
|
||||
from zipline.data._minute_bar_internal import (
|
||||
minute_value,
|
||||
@@ -30,6 +32,11 @@ from zipline.data._minute_bar_internal import (
|
||||
)
|
||||
|
||||
from zipline.gens.sim_engine import NANOS_IN_MINUTE
|
||||
from zipline.utils.preprocess import call
|
||||
from zipline.utils.input_validation import (
|
||||
preprocess,
|
||||
verify_indices_all_unique,
|
||||
)
|
||||
from zipline.utils.cli import maybe_show_progress
|
||||
from zipline.utils.memoize import lazyval
|
||||
|
||||
@@ -979,3 +986,107 @@ class BcolzMinuteBarReader(object):
|
||||
out *= self._ohlc_inverse
|
||||
results.append(out)
|
||||
return results
|
||||
|
||||
|
||||
class PanelMinuteBarReader(object):
|
||||
"""
|
||||
Reader for data passed as Panel.
|
||||
|
||||
DataPanel Structure
|
||||
-------
|
||||
items : Int64Index
|
||||
Asset identifiers. Must be unique.
|
||||
major_axis : DatetimeIndex
|
||||
Datetimes for data provided by the Panel. Must be unique.
|
||||
minor_axis : ['open', 'high', 'low', 'close', 'volume']
|
||||
Price attributes. Must be unique.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
The table with which this loader interacts contains the following
|
||||
attributes:
|
||||
|
||||
panel : pd.Panel
|
||||
The panel from which to read OHLCV data.
|
||||
first_trading_day : pd.Timestamp
|
||||
The first trading day in the dataset.
|
||||
"""
|
||||
@preprocess(panel=call(verify_indices_all_unique))
|
||||
def __init__(self, calendar, panel):
|
||||
|
||||
panel = panel.copy()
|
||||
if 'volume' not in panel.minor_axis:
|
||||
# Fake volume if it does not exist.
|
||||
panel.loc[:, :, 'volume'] = int(1e9)
|
||||
|
||||
self.first_trading_day = pd.datetools.normalize_date(
|
||||
panel.major_axis[0]
|
||||
)
|
||||
self._calendar = calendar
|
||||
|
||||
self.panel = panel
|
||||
|
||||
self._ohlc_inverse = 1. / OHLC_RATIO
|
||||
|
||||
@property
|
||||
def last_available_dt(self):
|
||||
return self.panel.major_axis[-1]
|
||||
|
||||
def load_raw_arrays(self, columns, start_dt, end_dt, assets):
|
||||
columns = list(columns)
|
||||
dts = self.panel.major_axis
|
||||
index = dts[dts.slice_indexer(start_dt, end_dt)]
|
||||
shape = (len(index), len(assets))
|
||||
results = []
|
||||
for col in columns:
|
||||
outbuf = zeros(shape=shape)
|
||||
for i, asset in enumerate(assets):
|
||||
data = self.panel.loc[asset, start_dt:end_dt, col]
|
||||
data = data.reindex_axis(index).values
|
||||
outbuf[:, i] = data
|
||||
results.append(outbuf)
|
||||
return results
|
||||
|
||||
def spot_price(self, sid, dt, colname):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
sid : int
|
||||
The asset identifier.
|
||||
dt : datetime64-like
|
||||
Midnight of the day for which data is requested.
|
||||
colname : string
|
||||
The price field. e.g. ('open', 'high', 'low', 'close', 'volume')
|
||||
|
||||
Returns
|
||||
-------
|
||||
float
|
||||
The spot price for colname of the given sid on the given day.
|
||||
Raises a NoDataOnDate exception if the given day and sid is before
|
||||
or after the date range of the equity.
|
||||
Returns -1 if the day is within the date range, but the price is
|
||||
0.
|
||||
"""
|
||||
return self.panel.loc[sid, dt, colname]
|
||||
|
||||
get_value = spot_price
|
||||
|
||||
def get_last_traded_dt(self, sid, dt):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
sid : int
|
||||
The asset identifier.
|
||||
dt : datetime64-like
|
||||
Midnight of the day for which data is requested.
|
||||
|
||||
Returns
|
||||
-------
|
||||
pd.Timestamp : The last known dt for the asset and dt;
|
||||
NaT if no trade is found before the given dt.
|
||||
"""
|
||||
for ts in self.panel.major_axis[self.panel.major_axis
|
||||
.slice_indexer(end=dt)][::-1]:
|
||||
if not pd.isnull(self.panel.loc[sid, ts, 'close']):
|
||||
return ts
|
||||
return NaT
|
||||
|
||||
Reference in New Issue
Block a user