mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-28 13:17:43 +08:00
BUG: rolling panel data became misaligned after extend_back
This commit is contained in:
@@ -22,7 +22,73 @@ import numpy as np
|
||||
import pandas as pd
|
||||
import pandas.util.testing as tm
|
||||
|
||||
from zipline.utils.data import MutableIndexRollingPanel
|
||||
from zipline.utils.data import MutableIndexRollingPanel, RollingPanel
|
||||
from zipline.finance.trading import with_environment
|
||||
|
||||
|
||||
class TestRollingPanel(unittest.TestCase):
    """Regression tests for date/data alignment in RollingPanel."""

    @staticmethod
    def _expected_panel(data, major_axis, items, sids):
        """Build the UTC-localised panel that get_current should return."""
        expected = pd.Panel(
            data,
            major_axis=major_axis,
            minor_axis=sids,
            items=items,
        )
        # get_current() returns a tz-aware ('utc') major axis, while the
        # raw datetime64 values used to build this panel are naive.
        expected.major_axis = expected.major_axis.tz_localize('utc')
        return expected

    @with_environment()
    def test_alignment(self, env):
        """
        Data must stay attached to its date after extend_back.

        A panel is seeded with the two middle dates of a four-date span,
        a single frame is added at the last date, and the panel is then
        extended back over the leading dates.  At every step the frame's
        values must sit on the last date, with NaN everywhere else.
        """
        items = ('a', 'b')
        sids = (1, 2)

        dts = env.market_minute_window(
            env.open_and_closes.market_open[0], 4,
        ).values
        rp = RollingPanel(2, items, sids, initial_dates=dts[1:-1])

        frame = pd.DataFrame(
            data=np.arange(4).reshape((2, 2)),
            columns=sids,
            index=items,
        )

        rp.add_frame(dts[-1], frame)

        # Before extending: the current view covers the last two dates,
        # and only the newest one carries data.
        data = np.array((((np.nan, np.nan),
                          (0, 1)),
                         ((np.nan, np.nan),
                          (2, 3))),
                        float)
        tm.assert_panel_equal(
            rp.get_current(),
            self._expected_panel(data, dts[2:], items, sids),
        )

        rp.extend_back(dts[:-2])

        # After extending: all four dates are visible, the new leading
        # rows are NaN, and the frame is still on the last date.
        data = np.array((((np.nan, np.nan),
                          (np.nan, np.nan),
                          (np.nan, np.nan),
                          (0, 1)),
                         ((np.nan, np.nan),
                          (np.nan, np.nan),
                          (np.nan, np.nan),
                          (2, 3))),
                        float)
        tm.assert_panel_equal(
            rp.get_current(),
            self._expected_panel(data, dts, items, sids),
        )
|
||||
|
||||
|
||||
class TestMutableIndexRollingPanel(unittest.TestCase):
|
||||
|
||||
@@ -439,24 +439,19 @@ class HistoryContainer(object):
|
||||
|
||||
window = spec.bar_count - 1
|
||||
|
||||
# everything after dt is going to be filled from calling update, no
|
||||
# need to precompute these dates.
|
||||
second = np.empty(window, dtype='datetime64[ns]')
|
||||
date_buf = np.hstack(
|
||||
(self._create_window_date_buf(
|
||||
window,
|
||||
spec.frequency.unit_str,
|
||||
spec.frequency.data_frequency,
|
||||
dt,
|
||||
env=env,
|
||||
), second),
|
||||
date_buf = self._create_window_date_buf(
|
||||
window,
|
||||
spec.frequency.unit_str,
|
||||
spec.frequency.data_frequency,
|
||||
dt,
|
||||
env=env,
|
||||
)
|
||||
|
||||
panel = RollingPanel(
|
||||
window=window,
|
||||
items=self.fields,
|
||||
sids=self.sids,
|
||||
date_buf=date_buf,
|
||||
initial_dates=date_buf,
|
||||
)
|
||||
|
||||
return panel
|
||||
|
||||
+39
-22
@@ -38,7 +38,7 @@ class RollingPanel(object):
|
||||
sids,
|
||||
cap_multiple=2,
|
||||
dtype=np.float64,
|
||||
date_buf=None):
|
||||
initial_dates=None):
|
||||
|
||||
self._pos = window
|
||||
self._window = window
|
||||
@@ -49,8 +49,20 @@ class RollingPanel(object):
|
||||
self.cap_multiple = cap_multiple
|
||||
|
||||
self.dtype = dtype
|
||||
self.date_buf = np.empty(self.cap, dtype='M8[ns]') \
|
||||
if date_buf is None else date_buf
|
||||
if initial_dates is None:
|
||||
self.date_buf = np.empty(self.cap, dtype='M8[ns]') * pd.NaT
|
||||
elif len(initial_dates) != window:
|
||||
raise ValueError('initial_dates must be of length window')
|
||||
else:
|
||||
self.date_buf = np.hstack(
|
||||
(
|
||||
initial_dates,
|
||||
np.empty(
|
||||
window * (cap_multiple - 1),
|
||||
dtype='datetime64[ns]',
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
self.buffer = self._create_buffer()
|
||||
|
||||
@@ -59,18 +71,18 @@ class RollingPanel(object):
|
||||
return self.cap_multiple * self._window
|
||||
|
||||
@property
|
||||
def start_index(self):
|
||||
def _start_index(self):
|
||||
return self._pos - self._window
|
||||
|
||||
@property
|
||||
def start_date(self):
|
||||
return self.date_buf[self.start_index]
|
||||
return self.date_buf[self._start_index]
|
||||
|
||||
def oldest_frame(self):
|
||||
"""
|
||||
Get the oldest frame in the panel.
|
||||
"""
|
||||
return self.buffer.iloc[:, self.start_index, :]
|
||||
return self.buffer.iloc[:, self._start_index, :]
|
||||
|
||||
def set_minor_axis(self, minor_axis):
|
||||
self.minor_axis = _ensure_index(minor_axis)
|
||||
@@ -109,22 +121,27 @@ class RollingPanel(object):
|
||||
self.date_buf.resize(self.cap)
|
||||
self.date_buf = np.roll(self.date_buf, delta)
|
||||
|
||||
self.buffer = pd.concat(
|
||||
[
|
||||
pd.Panel(
|
||||
items=self.items,
|
||||
minor_axis=self.minor_axis,
|
||||
major_axis=np.arange(delta * self.cap_multiple),
|
||||
dtype=self.dtype,
|
||||
),
|
||||
self.buffer
|
||||
],
|
||||
axis=1,
|
||||
old_vals = self.buffer.values
|
||||
shape = old_vals.shape
|
||||
nan_arr = np.empty((shape[0], delta, shape[2]))
|
||||
nan_arr.fill(np.nan)
|
||||
|
||||
new_vals = np.column_stack(
|
||||
(nan_arr,
|
||||
old_vals,
|
||||
np.empty((shape[0], delta * (self.cap_multiple - 1), shape[2]))),
|
||||
)
|
||||
|
||||
self.buffer = pd.Panel(
|
||||
data=new_vals,
|
||||
items=self.items,
|
||||
minor_axis=self.minor_axis,
|
||||
major_axis=np.arange(self.cap),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.buffer.major_axis = pd.Int64Index(range(self.cap))
|
||||
|
||||
# Fill the delta with the dates we calculated.
|
||||
where = slice(self.start_index, self.start_index + delta)
|
||||
where = slice(self._start_index, self._start_index + delta)
|
||||
self.date_buf[where] = missing_dts
|
||||
|
||||
def add_frame(self, tick, frame):
|
||||
@@ -144,7 +161,7 @@ class RollingPanel(object):
|
||||
these objects because internal data might change
|
||||
"""
|
||||
|
||||
where = slice(self.start_index, self._pos)
|
||||
where = slice(self._start_index, self._pos)
|
||||
major_axis = pd.DatetimeIndex(deepcopy(self.date_buf[where]), tz='utc')
|
||||
return pd.Panel(self.buffer.values[:, where, :], self.items,
|
||||
major_axis, self.minor_axis, dtype=self.dtype)
|
||||
@@ -155,11 +172,11 @@ class RollingPanel(object):
|
||||
passed panel. The passed panel must have the same indices as the panel
|
||||
that would be returned by self.get_current.
|
||||
"""
|
||||
where = slice(self.start_index, self._pos)
|
||||
where = slice(self._start_index, self._pos)
|
||||
self.buffer.values[:, where, :] = panel.values
|
||||
|
||||
def current_dates(self):
|
||||
where = slice(self.start_index, self._pos)
|
||||
where = slice(self._start_index, self._pos)
|
||||
return pd.DatetimeIndex(deepcopy(self.date_buf[where]), tz='utc')
|
||||
|
||||
def _roll_data(self):
|
||||
|
||||
Reference in New Issue
Block a user