diff --git a/tests/test_history.py b/tests/test_history.py index 8db6d495..9a2f21cb 100644 --- a/tests/test_history.py +++ b/tests/test_history.py @@ -935,6 +935,102 @@ def handle_data(context, data): ' HistorySpec', ) + @parameterized.expand([ + (1,), + (2,), + ]) + def test_history_grow_length_inter_bar(self, incr): + """ + Tests growing the length of a digest panel with different date_buf + deltas once per bar. + """ + algo_text = dedent( + """\ + from zipline.api import history + + + def initialize(context): + context.bar_count = 1 + + + def handle_data(context, data): + prices = history(context.bar_count, '1d', 'price') + context.test_case.assertEqual(len(prices), context.bar_count) + context.bar_count += {incr} + """ + ).format(incr=incr) + start = pd.Timestamp('2007-04-05', tz='UTC') + end = pd.Timestamp('2007-04-10', tz='UTC') + + sim_params = SimulationParameters( + period_start=start, + period_end=end, + capital_base=float("1.0e5"), + data_frequency='minute', + emission_rate='daily' + ) + + test_algo = TradingAlgorithm( + script=algo_text, + data_frequency='minute', + sim_params=sim_params + ) + test_algo.test_case = self + + source = RandomWalkSource(start=start, end=end) + + self.assertIsNone(test_algo.history_container) + test_algo.run(source) + + @parameterized.expand([ + (1,), + (2,), + ]) + def test_history_grow_length_intra_bar(self, incr): + """ + Tests growing the length of a digest panel with different date_buf + deltas in a single bar. + """ + algo_text = dedent( + """\ + from zipline.api import history + + + def initialize(context): + context.bar_count = 1 + + + def handle_data(context, data): + prices = history(context.bar_count, '1d', 'price') + context.test_case.assertEqual(len(prices), context.bar_count) + context.bar_count += {incr} + prices = history(context.bar_count, '1d', 'price') + context.test_case.assertEqual(len(prices), context.bar_count) + """ + ).format(incr=incr) + start = pd.Timestamp('2007-04-05', tz='UTC') + end = pd.Timestamp('2007-04-10', tz='UTC') + + sim_params = SimulationParameters( + period_start=start, + period_end=end, + capital_base=float("1.0e5"), + data_frequency='minute', + emission_rate='daily' + ) + + test_algo = TradingAlgorithm( + script=algo_text, + data_frequency='minute', + sim_params=sim_params + ) + test_algo.test_case = self + + source = RandomWalkSource(start=start, end=end) + + self.assertIsNone(test_algo.history_container) + test_algo.run(source) + class TestHistoryContainerResize(TestCase): @parameterized.expand( diff --git a/zipline/history/history_container.py b/zipline/history/history_container.py index 29e06d25..7ed961ee 100644 --- a/zipline/history/history_container.py +++ b/zipline/history/history_container.py @@ -387,24 +387,15 @@ class HistoryContainer(object): """ # This is the oldest datetime that will be shown in the current window # of the panel. - oldest_idx = panel._oldest_frame_idx - oldest_dt = pd.Timestamp( - panel.date_buf[oldest_idx], tz='utc', - ) - old_cap = panel.cap - panel.resize(size) + oldest_dt = pd.Timestamp(panel.start_date, tz='utc',) + delta = size - panel.window_length - delta = (old_cap - oldest_idx) - panel._oldest_frame_idx - - # Backfill the missing dates of the new current window. + # Construct the missing dates. missing_dts = self._create_window_date_buf( delta, freq.unit_str, freq.data_frequency, oldest_dt, ) - # Fill the dates in between the new oldest index and adjusted oldest - # index. - where = slice(panel._oldest_frame_idx, -(old_cap - oldest_idx)) - panel.date_buf[where] = missing_dts + panel.extend_back(missing_dts) @with_environment() def _create_window_date_buf(self, diff --git a/zipline/utils/data.py b/zipline/utils/data.py index ef5f80f8..6f308120 100644 --- a/zipline/utils/data.py +++ b/zipline/utils/data.py @@ -47,7 +47,6 @@ class RollingPanel(object): self.minor_axis = _ensure_index(sids) self.cap_multiple = cap_multiple - self.cap = cap_multiple * window self.dtype = dtype self.date_buf = np.empty(self.cap, dtype='M8[ns]') \ @@ -56,14 +55,22 @@ class RollingPanel(object): self.buffer = self._create_buffer() @property - def _oldest_frame_idx(self): + def cap(self): + return self.cap_multiple * self._window + + @property + def start_index(self): return self._pos - self._window + @property + def start_date(self): + return self.date_buf[self.start_index] + def oldest_frame(self): """ Get the oldest frame in the panel. """ - return self.buffer.iloc[:, self._oldest_frame_idx, :] + return self.buffer.iloc[:, self.start_index, :] def set_minor_axis(self, minor_axis): self.minor_axis = _ensure_index(minor_axis) @@ -82,16 +89,19 @@ class RollingPanel(object): ) return panel - def resize(self, window): + def extend_back(self, missing_dts): """ Resizes the buffer to hold a new window with a new cap_multiple. If cap_multiple is None, then the old cap_multiple is used. """ - self._window = window + delta = len(missing_dts) - pre = self.cap - self.cap = self.cap_multiple * window - delta = self.cap - pre + if not delta: + raise ValueError( + 'missing_dts must be a non-empty index', + ) + + self._window += delta self._pos += delta @@ -104,7 +114,7 @@ class RollingPanel(object): pd.Panel( items=self.items, minor_axis=self.minor_axis, - major_axis=np.arange(delta), + major_axis=np.arange(delta * self.cap_multiple), dtype=self.dtype, ), self.buffer @@ -113,6 +123,10 @@ class RollingPanel(object): ) self.buffer.major_axis = pd.Int64Index(range(self.cap)) + # Fill the delta with the dates we calculated. + where = slice(self.start_index, self.start_index + delta) + self.date_buf[where] = missing_dts + def add_frame(self, tick, frame): """ """ @@ -130,7 +144,7 @@ class RollingPanel(object): these objects because internal data might change """ - where = slice(self._oldest_frame_idx, self._pos) + where = slice(self.start_index, self._pos) major_axis = pd.DatetimeIndex(deepcopy(self.date_buf[where]), tz='utc') return pd.Panel(self.buffer.values[:, where, :], self.items, major_axis, self.minor_axis, dtype=self.dtype) @@ -141,11 +155,11 @@ class RollingPanel(object): passed panel. The passed panel must have the same indices as the panel that would be returned by self.get_current. """ - where = slice(self._oldest_frame_idx, self._pos) + where = slice(self.start_index, self._pos) self.buffer.values[:, where, :] = panel.values def current_dates(self): - where = slice(self._oldest_frame_idx, self._pos) + where = slice(self.start_index, self._pos) return pd.DatetimeIndex(deepcopy(self.date_buf[where]), tz='utc') def _roll_data(self):