From 8df1a49031297b99ed082a3fb6bdebbc787b0d3a Mon Sep 17 00:00:00 2001 From: Joe Jevnik Date: Mon, 10 Nov 2014 14:32:57 -0500 Subject: [PATCH] BUG: When increasing the length dynamically, the rolling panel was getting filled with the wrong datetimes and causing errors. Updates the logic for addressing missing datetimes and adds unit tests for the 2 main cases (no missing datetimes, and some missing datetimes). --- tests/test_history.py | 96 ++++++++++++++++++++++++++++ zipline/history/history_container.py | 17 ++--- zipline/utils/data.py | 38 +++++++---- 3 files changed, 126 insertions(+), 25 deletions(-) diff --git a/tests/test_history.py b/tests/test_history.py index 8db6d495..9a2f21cb 100644 --- a/tests/test_history.py +++ b/tests/test_history.py @@ -935,6 +935,102 @@ def handle_data(context, data): ' HistorySpec', ) + @parameterized.expand([ + (1,), + (2,), + ]) + def test_history_grow_length_inter_bar(self, incr): + """ + Tests growing the length of a digest panel with different date_buf + deltas once per bar. + """ + algo_text = dedent( + """\ + from zipline.api import history + + + def initialize(context): + context.bar_count = 1 + + + def handle_data(context, data): + prices = history(context.bar_count, '1d', 'price') + context.test_case.assertEqual(len(prices), context.bar_count) + context.bar_count += {incr} + """ + ).format(incr=incr) + start = pd.Timestamp('2007-04-05', tz='UTC') + end = pd.Timestamp('2007-04-10', tz='UTC') + + sim_params = SimulationParameters( + period_start=start, + period_end=end, + capital_base=float("1.0e5"), + data_frequency='minute', + emission_rate='daily' + ) + + test_algo = TradingAlgorithm( + script=algo_text, + data_frequency='minute', + sim_params=sim_params + ) + test_algo.test_case = self + + source = RandomWalkSource(start=start, end=end) + + self.assertIsNone(test_algo.history_container) + test_algo.run(source) + + @parameterized.expand([ + (1,), + (2,), + ]) + def test_history_grow_length_intra_bar(self, incr): + """ + Tests growing the length of a digest panel with different date_buf + deltas in a single bar. + """ + algo_text = dedent( + """\ + from zipline.api import history + + + def initialize(context): + context.bar_count = 1 + + + def handle_data(context, data): + prices = history(context.bar_count, '1d', 'price') + context.test_case.assertEqual(len(prices), context.bar_count) + context.bar_count += {incr} + prices = history(context.bar_count, '1d', 'price') + context.test_case.assertEqual(len(prices), context.bar_count) + """ + ).format(incr=incr) + start = pd.Timestamp('2007-04-05', tz='UTC') + end = pd.Timestamp('2007-04-10', tz='UTC') + + sim_params = SimulationParameters( + period_start=start, + period_end=end, + capital_base=float("1.0e5"), + data_frequency='minute', + emission_rate='daily' + ) + + test_algo = TradingAlgorithm( + script=algo_text, + data_frequency='minute', + sim_params=sim_params + ) + test_algo.test_case = self + + source = RandomWalkSource(start=start, end=end) + + self.assertIsNone(test_algo.history_container) + test_algo.run(source) + class TestHistoryContainerResize(TestCase): @parameterized.expand( diff --git a/zipline/history/history_container.py b/zipline/history/history_container.py index 29e06d25..7ed961ee 100644 --- a/zipline/history/history_container.py +++ b/zipline/history/history_container.py @@ -387,24 +387,15 @@ class HistoryContainer(object): """ # This is the oldest datetime that will be shown in the current window # of the panel. - oldest_idx = panel._oldest_frame_idx - oldest_dt = pd.Timestamp( - panel.date_buf[oldest_idx], tz='utc', - ) - old_cap = panel.cap - panel.resize(size) + oldest_dt = pd.Timestamp(panel.start_date, tz='utc',) + delta = size - panel.window_length - delta = (old_cap - oldest_idx) - panel._oldest_frame_idx - - # Backfill the missing dates of the new current window. + # Construct the missing dates. missing_dts = self._create_window_date_buf( delta, freq.unit_str, freq.data_frequency, oldest_dt, ) - # Fill the dates in between the new oldest index and adjusted oldest - # index. - where = slice(panel._oldest_frame_idx, -(old_cap - oldest_idx)) - panel.date_buf[where] = missing_dts + panel.extend_back(missing_dts) @with_environment() def _create_window_date_buf(self, diff --git a/zipline/utils/data.py b/zipline/utils/data.py index ef5f80f8..6f308120 100644 --- a/zipline/utils/data.py +++ b/zipline/utils/data.py @@ -47,7 +47,6 @@ class RollingPanel(object): self.minor_axis = _ensure_index(sids) self.cap_multiple = cap_multiple - self.cap = cap_multiple * window self.dtype = dtype self.date_buf = np.empty(self.cap, dtype='M8[ns]') \ @@ -56,14 +55,22 @@ class RollingPanel(object): self.buffer = self._create_buffer() @property - def _oldest_frame_idx(self): + def cap(self): + return self.cap_multiple * self._window + + @property + def start_index(self): return self._pos - self._window + @property + def start_date(self): + return self.date_buf[self.start_index] + def oldest_frame(self): """ Get the oldest frame in the panel. """ - return self.buffer.iloc[:, self._oldest_frame_idx, :] + return self.buffer.iloc[:, self.start_index, :] def set_minor_axis(self, minor_axis): self.minor_axis = _ensure_index(minor_axis) @@ -82,16 +89,19 @@ class RollingPanel(object): ) return panel - def resize(self, window): + def extend_back(self, missing_dts): """ Resizes the buffer to hold a new window with a new cap_multiple. If cap_multiple is None, then the old cap_multiple is used. """ - self._window = window + delta = len(missing_dts) - pre = self.cap - self.cap = self.cap_multiple * window - delta = self.cap - pre + if not delta: + raise ValueError( + 'missing_dts must be a non-empty index', + ) + + self._window += delta self._pos += delta @@ -104,7 +114,7 @@ class RollingPanel(object): pd.Panel( items=self.items, minor_axis=self.minor_axis, - major_axis=np.arange(delta), + major_axis=np.arange(delta * self.cap_multiple), dtype=self.dtype, ), self.buffer @@ -113,6 +123,10 @@ class RollingPanel(object): ) self.buffer.major_axis = pd.Int64Index(range(self.cap)) + # Fill the delta with the dates we calculated. + where = slice(self.start_index, self.start_index + delta) + self.date_buf[where] = missing_dts + def add_frame(self, tick, frame): """ """ @@ -130,7 +144,7 @@ class RollingPanel(object): these objects because internal data might change """ - where = slice(self._oldest_frame_idx, self._pos) + where = slice(self.start_index, self._pos) major_axis = pd.DatetimeIndex(deepcopy(self.date_buf[where]), tz='utc') return pd.Panel(self.buffer.values[:, where, :], self.items, major_axis, self.minor_axis, dtype=self.dtype) @@ -141,11 +155,11 @@ class RollingPanel(object): passed panel. The passed panel must have the same indices as the panel that would be returned by self.get_current. """ - where = slice(self._oldest_frame_idx, self._pos) + where = slice(self.start_index, self._pos) self.buffer.values[:, where, :] = panel.values def current_dates(self): - where = slice(self._oldest_frame_idx, self._pos) + where = slice(self.start_index, self._pos) return pd.DatetimeIndex(deepcopy(self.date_buf[where]), tz='utc') def _roll_data(self):