mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-02 12:29:43 +08:00
BUG: When increasing the length dynamically, the rolling panel was
getting filled with the wrong datetimes and causing errors. Updates the logic for addressing missing datetimes and adds unit tests for the 2 main cases (no missing datetimes, and some missing datetimes).
This commit is contained in:
@@ -935,6 +935,102 @@ def handle_data(context, data):
|
||||
' HistorySpec',
|
||||
)
|
||||
|
||||
@parameterized.expand([
|
||||
(1,),
|
||||
(2,),
|
||||
])
|
||||
def test_history_grow_length_inter_bar(self, incr):
|
||||
"""
|
||||
Tests growing the length of a digest panel with different date_buf
|
||||
deltas once per bar.
|
||||
"""
|
||||
algo_text = dedent(
|
||||
"""\
|
||||
from zipline.api import history
|
||||
|
||||
|
||||
def initialize(context):
|
||||
context.bar_count = 1
|
||||
|
||||
|
||||
def handle_data(context, data):
|
||||
prices = history(context.bar_count, '1d', 'price')
|
||||
context.test_case.assertEqual(len(prices), context.bar_count)
|
||||
context.bar_count += {incr}
|
||||
"""
|
||||
).format(incr=incr)
|
||||
start = pd.Timestamp('2007-04-05', tz='UTC')
|
||||
end = pd.Timestamp('2007-04-10', tz='UTC')
|
||||
|
||||
sim_params = SimulationParameters(
|
||||
period_start=start,
|
||||
period_end=end,
|
||||
capital_base=float("1.0e5"),
|
||||
data_frequency='minute',
|
||||
emission_rate='daily'
|
||||
)
|
||||
|
||||
test_algo = TradingAlgorithm(
|
||||
script=algo_text,
|
||||
data_frequency='minute',
|
||||
sim_params=sim_params
|
||||
)
|
||||
test_algo.test_case = self
|
||||
|
||||
source = RandomWalkSource(start=start, end=end)
|
||||
|
||||
self.assertIsNone(test_algo.history_container)
|
||||
test_algo.run(source)
|
||||
|
||||
@parameterized.expand([
|
||||
(1,),
|
||||
(2,),
|
||||
])
|
||||
def test_history_grow_length_intra_bar(self, incr):
|
||||
"""
|
||||
Tests growing the length of a digest panel with different date_buf
|
||||
deltas in a single bar.
|
||||
"""
|
||||
algo_text = dedent(
|
||||
"""\
|
||||
from zipline.api import history
|
||||
|
||||
|
||||
def initialize(context):
|
||||
context.bar_count = 1
|
||||
|
||||
|
||||
def handle_data(context, data):
|
||||
prices = history(context.bar_count, '1d', 'price')
|
||||
context.test_case.assertEqual(len(prices), context.bar_count)
|
||||
context.bar_count += {incr}
|
||||
prices = history(context.bar_count, '1d', 'price')
|
||||
context.test_case.assertEqual(len(prices), context.bar_count)
|
||||
"""
|
||||
).format(incr=incr)
|
||||
start = pd.Timestamp('2007-04-05', tz='UTC')
|
||||
end = pd.Timestamp('2007-04-10', tz='UTC')
|
||||
|
||||
sim_params = SimulationParameters(
|
||||
period_start=start,
|
||||
period_end=end,
|
||||
capital_base=float("1.0e5"),
|
||||
data_frequency='minute',
|
||||
emission_rate='daily'
|
||||
)
|
||||
|
||||
test_algo = TradingAlgorithm(
|
||||
script=algo_text,
|
||||
data_frequency='minute',
|
||||
sim_params=sim_params
|
||||
)
|
||||
test_algo.test_case = self
|
||||
|
||||
source = RandomWalkSource(start=start, end=end)
|
||||
|
||||
self.assertIsNone(test_algo.history_container)
|
||||
test_algo.run(source)
|
||||
|
||||
|
||||
class TestHistoryContainerResize(TestCase):
|
||||
@parameterized.expand(
|
||||
|
||||
@@ -387,24 +387,15 @@ class HistoryContainer(object):
|
||||
"""
|
||||
# This is the oldest datetime that will be shown in the current window
|
||||
# of the panel.
|
||||
oldest_idx = panel._oldest_frame_idx
|
||||
oldest_dt = pd.Timestamp(
|
||||
panel.date_buf[oldest_idx], tz='utc',
|
||||
)
|
||||
old_cap = panel.cap
|
||||
panel.resize(size)
|
||||
oldest_dt = pd.Timestamp(panel.start_date, tz='utc',)
|
||||
delta = size - panel.window_length
|
||||
|
||||
delta = (old_cap - oldest_idx) - panel._oldest_frame_idx
|
||||
|
||||
# Backfill the missing dates of the new current window.
|
||||
# Construct the missing dates.
|
||||
missing_dts = self._create_window_date_buf(
|
||||
delta, freq.unit_str, freq.data_frequency, oldest_dt,
|
||||
)
|
||||
|
||||
# Fill the dates in between the new oldest index and adjusted oldest
|
||||
# index.
|
||||
where = slice(panel._oldest_frame_idx, -(old_cap - oldest_idx))
|
||||
panel.date_buf[where] = missing_dts
|
||||
panel.extend_back(missing_dts)
|
||||
|
||||
@with_environment()
|
||||
def _create_window_date_buf(self,
|
||||
|
||||
+26
-12
@@ -47,7 +47,6 @@ class RollingPanel(object):
|
||||
self.minor_axis = _ensure_index(sids)
|
||||
|
||||
self.cap_multiple = cap_multiple
|
||||
self.cap = cap_multiple * window
|
||||
|
||||
self.dtype = dtype
|
||||
self.date_buf = np.empty(self.cap, dtype='M8[ns]') \
|
||||
@@ -56,14 +55,22 @@ class RollingPanel(object):
|
||||
self.buffer = self._create_buffer()
|
||||
|
||||
@property
|
||||
def _oldest_frame_idx(self):
|
||||
def cap(self):
|
||||
return self.cap_multiple * self._window
|
||||
|
||||
@property
|
||||
def start_index(self):
|
||||
return self._pos - self._window
|
||||
|
||||
@property
|
||||
def start_date(self):
|
||||
return self.date_buf[self.start_index]
|
||||
|
||||
def oldest_frame(self):
|
||||
"""
|
||||
Get the oldest frame in the panel.
|
||||
"""
|
||||
return self.buffer.iloc[:, self._oldest_frame_idx, :]
|
||||
return self.buffer.iloc[:, self.start_index, :]
|
||||
|
||||
def set_minor_axis(self, minor_axis):
|
||||
self.minor_axis = _ensure_index(minor_axis)
|
||||
@@ -82,16 +89,19 @@ class RollingPanel(object):
|
||||
)
|
||||
return panel
|
||||
|
||||
def resize(self, window):
|
||||
def extend_back(self, missing_dts):
|
||||
"""
|
||||
Resizes the buffer to hold a new window with a new cap_multiple.
|
||||
If cap_multiple is None, then the old cap_multiple is used.
|
||||
"""
|
||||
self._window = window
|
||||
delta = len(missing_dts)
|
||||
|
||||
pre = self.cap
|
||||
self.cap = self.cap_multiple * window
|
||||
delta = self.cap - pre
|
||||
if not delta:
|
||||
raise ValueError(
|
||||
'missing_dts must be a non-empty index',
|
||||
)
|
||||
|
||||
self._window += delta
|
||||
|
||||
self._pos += delta
|
||||
|
||||
@@ -104,7 +114,7 @@ class RollingPanel(object):
|
||||
pd.Panel(
|
||||
items=self.items,
|
||||
minor_axis=self.minor_axis,
|
||||
major_axis=np.arange(delta),
|
||||
major_axis=np.arange(delta * self.cap_multiple),
|
||||
dtype=self.dtype,
|
||||
),
|
||||
self.buffer
|
||||
@@ -113,6 +123,10 @@ class RollingPanel(object):
|
||||
)
|
||||
self.buffer.major_axis = pd.Int64Index(range(self.cap))
|
||||
|
||||
# Fill the delta with the dates we calculated.
|
||||
where = slice(self.start_index, self.start_index + delta)
|
||||
self.date_buf[where] = missing_dts
|
||||
|
||||
def add_frame(self, tick, frame):
|
||||
"""
|
||||
"""
|
||||
@@ -130,7 +144,7 @@ class RollingPanel(object):
|
||||
these objects because internal data might change
|
||||
"""
|
||||
|
||||
where = slice(self._oldest_frame_idx, self._pos)
|
||||
where = slice(self.start_index, self._pos)
|
||||
major_axis = pd.DatetimeIndex(deepcopy(self.date_buf[where]), tz='utc')
|
||||
return pd.Panel(self.buffer.values[:, where, :], self.items,
|
||||
major_axis, self.minor_axis, dtype=self.dtype)
|
||||
@@ -141,11 +155,11 @@ class RollingPanel(object):
|
||||
passed panel. The passed panel must have the same indices as the panel
|
||||
that would be returned by self.get_current.
|
||||
"""
|
||||
where = slice(self._oldest_frame_idx, self._pos)
|
||||
where = slice(self.start_index, self._pos)
|
||||
self.buffer.values[:, where, :] = panel.values
|
||||
|
||||
def current_dates(self):
|
||||
where = slice(self._oldest_frame_idx, self._pos)
|
||||
where = slice(self.start_index, self._pos)
|
||||
return pd.DatetimeIndex(deepcopy(self.date_buf[where]), tz='utc')
|
||||
|
||||
def _roll_data(self):
|
||||
|
||||
Reference in New Issue
Block a user