BUG: When increasing the length dynamically, the rolling panel was

getting filled with the wrong datetimes and causing errors.

Updates the logic for addressing missing datetimes and adds unit tests
for the 2 main cases (no missing datetimes, and some missing datetimes).
This commit is contained in:
Joe Jevnik
2014-11-10 14:32:57 -05:00
parent 54b1f15983
commit 8df1a49031
3 changed files with 126 additions and 25 deletions
+96
View File
@@ -935,6 +935,102 @@ def handle_data(context, data):
' HistorySpec',
)
@parameterized.expand([
    (1,),
    (2,),
])
def test_history_grow_length_inter_bar(self, incr):
    """
    Grow the requested history length by `incr` once per bar.

    The embedded algorithm asks `history` for `context.bar_count` daily
    price bars, asserts that the result has exactly that many rows, and
    only then bumps `context.bar_count` by `incr`, so each handle_data
    call issues a single request that is larger than the previous one.
    """
    # The script reports failures through `context.test_case`, which is
    # attached to the algorithm object below.
    script = dedent(
        """\
        from zipline.api import history
        def initialize(context):
            context.bar_count = 1
        def handle_data(context, data):
            prices = history(context.bar_count, '1d', 'price')
            context.test_case.assertEqual(len(prices), context.bar_count)
            context.bar_count += {incr}
        """
    ).format(incr=incr)

    first_day = pd.Timestamp('2007-04-05', tz='UTC')
    last_day = pd.Timestamp('2007-04-10', tz='UTC')

    algo = TradingAlgorithm(
        script=script,
        data_frequency='minute',
        sim_params=SimulationParameters(
            period_start=first_day,
            period_end=last_day,
            capital_base=float("1.0e5"),
            data_frequency='minute',
            emission_rate='daily'
        ),
    )
    algo.test_case = self

    random_walk = RandomWalkSource(start=first_day, end=last_day)

    # No history container may exist before the run starts.
    self.assertIsNone(algo.history_container)
    algo.run(random_walk)
@parameterized.expand([
    (1,),
    (2,),
])
def test_history_grow_length_intra_bar(self, incr):
    """
    Grow the requested history length by `incr` within a single bar.

    The embedded algorithm asks `history` for `context.bar_count` daily
    price bars and checks the row count, bumps `context.bar_count` by
    `incr`, and then repeats the request and the check inside the same
    handle_data call, so the larger request is made before the next bar
    arrives.
    """
    # The script reports failures through `context.test_case`, which is
    # attached to the algorithm object below.
    script = dedent(
        """\
        from zipline.api import history
        def initialize(context):
            context.bar_count = 1
        def handle_data(context, data):
            prices = history(context.bar_count, '1d', 'price')
            context.test_case.assertEqual(len(prices), context.bar_count)
            context.bar_count += {incr}
            prices = history(context.bar_count, '1d', 'price')
            context.test_case.assertEqual(len(prices), context.bar_count)
        """
    ).format(incr=incr)

    first_day = pd.Timestamp('2007-04-05', tz='UTC')
    last_day = pd.Timestamp('2007-04-10', tz='UTC')

    algo = TradingAlgorithm(
        script=script,
        data_frequency='minute',
        sim_params=SimulationParameters(
            period_start=first_day,
            period_end=last_day,
            capital_base=float("1.0e5"),
            data_frequency='minute',
            emission_rate='daily'
        ),
    )
    algo.test_case = self

    random_walk = RandomWalkSource(start=first_day, end=last_day)

    # No history container may exist before the run starts.
    self.assertIsNone(algo.history_container)
    algo.run(random_walk)
class TestHistoryContainerResize(TestCase):
@parameterized.expand(
+4 -13
View File
@@ -387,24 +387,15 @@ class HistoryContainer(object):
"""
# This is the oldest datetime that will be shown in the current window
# of the panel.
oldest_idx = panel._oldest_frame_idx
oldest_dt = pd.Timestamp(
panel.date_buf[oldest_idx], tz='utc',
)
old_cap = panel.cap
panel.resize(size)
oldest_dt = pd.Timestamp(panel.start_date, tz='utc',)
delta = size - panel.window_length
delta = (old_cap - oldest_idx) - panel._oldest_frame_idx
# Backfill the missing dates of the new current window.
# Construct the missing dates.
missing_dts = self._create_window_date_buf(
delta, freq.unit_str, freq.data_frequency, oldest_dt,
)
# Fill the dates in between the new oldest index and adjusted oldest
# index.
where = slice(panel._oldest_frame_idx, -(old_cap - oldest_idx))
panel.date_buf[where] = missing_dts
panel.extend_back(missing_dts)
@with_environment()
def _create_window_date_buf(self,
+26 -12
View File
@@ -47,7 +47,6 @@ class RollingPanel(object):
self.minor_axis = _ensure_index(sids)
self.cap_multiple = cap_multiple
self.cap = cap_multiple * window
self.dtype = dtype
self.date_buf = np.empty(self.cap, dtype='M8[ns]') \
@@ -56,14 +55,22 @@ class RollingPanel(object):
self.buffer = self._create_buffer()
@property
def _oldest_frame_idx(self):
def cap(self):
return self.cap_multiple * self._window
@property
def start_index(self):
    """
    Index reached by stepping `self._window` entries back from
    `self._pos`.
    """
    position = self._pos
    return position - self._window
@property
def start_date(self):
    """
    Entry of `self.date_buf` found at `self.start_index`.
    """
    dates = self.date_buf
    return dates[self.start_index]
def oldest_frame(self):
"""
Get the oldest frame in the panel.
"""
return self.buffer.iloc[:, self._oldest_frame_idx, :]
return self.buffer.iloc[:, self.start_index, :]
def set_minor_axis(self, minor_axis):
self.minor_axis = _ensure_index(minor_axis)
@@ -82,16 +89,19 @@ class RollingPanel(object):
)
return panel
def resize(self, window):
def extend_back(self, missing_dts):
"""
Resizes the buffer to hold a new window with a new cap_multiple.
If cap_multiple is None, then the old cap_multiple is used.
"""
self._window = window
delta = len(missing_dts)
pre = self.cap
self.cap = self.cap_multiple * window
delta = self.cap - pre
if not delta:
raise ValueError(
'missing_dts must be a non-empty index',
)
self._window += delta
self._pos += delta
@@ -104,7 +114,7 @@ class RollingPanel(object):
pd.Panel(
items=self.items,
minor_axis=self.minor_axis,
major_axis=np.arange(delta),
major_axis=np.arange(delta * self.cap_multiple),
dtype=self.dtype,
),
self.buffer
@@ -113,6 +123,10 @@ class RollingPanel(object):
)
self.buffer.major_axis = pd.Int64Index(range(self.cap))
# Fill the delta with the dates we calculated.
where = slice(self.start_index, self.start_index + delta)
self.date_buf[where] = missing_dts
def add_frame(self, tick, frame):
"""
"""
@@ -130,7 +144,7 @@ class RollingPanel(object):
these objects because internal data might change
"""
where = slice(self._oldest_frame_idx, self._pos)
where = slice(self.start_index, self._pos)
major_axis = pd.DatetimeIndex(deepcopy(self.date_buf[where]), tz='utc')
return pd.Panel(self.buffer.values[:, where, :], self.items,
major_axis, self.minor_axis, dtype=self.dtype)
@@ -141,11 +155,11 @@ class RollingPanel(object):
passed panel. The passed panel must have the same indices as the panel
that would be returned by self.get_current.
"""
where = slice(self._oldest_frame_idx, self._pos)
where = slice(self.start_index, self._pos)
self.buffer.values[:, where, :] = panel.values
def current_dates(self):
where = slice(self._oldest_frame_idx, self._pos)
where = slice(self.start_index, self._pos)
return pd.DatetimeIndex(deepcopy(self.date_buf[where]), tz='utc')
def _roll_data(self):