From c2aae2e0f4ece23502c11f4e3d09c83e3e33ccc2 Mon Sep 17 00:00:00 2001 From: Joe Jevnik Date: Tue, 11 Nov 2014 17:46:50 -0500 Subject: [PATCH] BUG: rolling panel data became misaligned after extend_back --- tests/test_rolling_panel.py | 68 +++++++++++++++++++++++++++- zipline/history/history_container.py | 19 +++----- zipline/utils/data.py | 61 ++++++++++++++++--------- 3 files changed, 113 insertions(+), 35 deletions(-) diff --git a/tests/test_rolling_panel.py b/tests/test_rolling_panel.py index ab87bf09..60f42cfa 100644 --- a/tests/test_rolling_panel.py +++ b/tests/test_rolling_panel.py @@ -22,7 +22,73 @@ import numpy as np import pandas as pd import pandas.util.testing as tm -from zipline.utils.data import MutableIndexRollingPanel +from zipline.utils.data import MutableIndexRollingPanel, RollingPanel +from zipline.finance.trading import with_environment + + +class TestRollingPanel(unittest.TestCase): + @with_environment() + def test_alignment(self, env): + items = ('a', 'b') + sids = (1, 2) + + dts = env.market_minute_window( + env.open_and_closes.market_open[0], 4, + ).values + rp = RollingPanel(2, items, sids, initial_dates=dts[1:-1]) + + frame = pd.DataFrame( + data=np.arange(4).reshape((2, 2)), + columns=sids, + index=items, + ) + + nan_arr = np.empty((2, 6)) + nan_arr.fill(np.nan) + + rp.add_frame(dts[-1], frame) + + cur = rp.get_current() + data = np.array((((np.nan, np.nan), + (0, 1)), + ((np.nan, np.nan), + (2, 3))), + float) + expected = pd.Panel( + data, + major_axis=dts[2:], + minor_axis=sids, + items=items, + ) + expected.major_axis = expected.major_axis.tz_localize('utc') + tm.assert_panel_equal( + cur, + expected, + ) + + rp.extend_back(dts[:-2]) + + cur = rp.get_current() + data = np.array((((np.nan, np.nan), + (np.nan, np.nan), + (np.nan, np.nan), + (0, 1)), + ((np.nan, np.nan), + (np.nan, np.nan), + (np.nan, np.nan), + (2, 3))), + float) + expected = pd.Panel( + data, + major_axis=dts, + minor_axis=sids, + items=items, + ) + expected.major_axis = expected.major_axis.tz_localize('utc') + tm.assert_panel_equal( + cur, + expected, + ) class TestMutableIndexRollingPanel(unittest.TestCase): diff --git a/zipline/history/history_container.py b/zipline/history/history_container.py index 7ed961ee..77f1441d 100644 --- a/zipline/history/history_container.py +++ b/zipline/history/history_container.py @@ -439,24 +439,19 @@ class HistoryContainer(object): window = spec.bar_count - 1 - # everything after dt is going to be filled from calling update, no - # need to precompute these dates. - second = np.empty(window, dtype='datetime64[ns]') - date_buf = np.hstack( - (self._create_window_date_buf( - window, - spec.frequency.unit_str, - spec.frequency.data_frequency, - dt, - env=env, - ), second), + date_buf = self._create_window_date_buf( + window, + spec.frequency.unit_str, + spec.frequency.data_frequency, + dt, + env=env, ) panel = RollingPanel( window=window, items=self.fields, sids=self.sids, - date_buf=date_buf, + initial_dates=date_buf, ) return panel diff --git a/zipline/utils/data.py b/zipline/utils/data.py index 6f308120..797493b8 100644 --- a/zipline/utils/data.py +++ b/zipline/utils/data.py @@ -38,7 +38,7 @@ class RollingPanel(object): sids, cap_multiple=2, dtype=np.float64, - date_buf=None): + initial_dates=None): self._pos = window self._window = window @@ -49,8 +49,20 @@ class RollingPanel(object): self.cap_multiple = cap_multiple self.dtype = dtype - self.date_buf = np.empty(self.cap, dtype='M8[ns]') \ - if date_buf is None else date_buf + if initial_dates is None: + self.date_buf = np.empty(self.cap, dtype='M8[ns]') * pd.NaT + elif len(initial_dates) != window: + raise ValueError('initial_dates must be of length window') + else: + self.date_buf = np.hstack( + ( + initial_dates, + np.empty( + window * (cap_multiple - 1), + dtype='datetime64[ns]', + ), + ), + ) self.buffer = self._create_buffer() @@ -59,18 +71,18 @@ class RollingPanel(object): return self.cap_multiple * self._window @property - def start_index(self): + def _start_index(self): return self._pos - self._window @property def start_date(self): - return self.date_buf[self.start_index] + return self.date_buf[self._start_index] def oldest_frame(self): """ Get the oldest frame in the panel. """ - return self.buffer.iloc[:, self.start_index, :] + return self.buffer.iloc[:, self._start_index, :] def set_minor_axis(self, minor_axis): self.minor_axis = _ensure_index(minor_axis) @@ -109,22 +121,27 @@ class RollingPanel(object): self.date_buf.resize(self.cap) self.date_buf = np.roll(self.date_buf, delta) - self.buffer = pd.concat( - [ - pd.Panel( - items=self.items, - minor_axis=self.minor_axis, - major_axis=np.arange(delta * self.cap_multiple), - dtype=self.dtype, - ), - self.buffer - ], - axis=1, + old_vals = self.buffer.values + shape = old_vals.shape + nan_arr = np.empty((shape[0], delta, shape[2])) + nan_arr.fill(np.nan) + + new_vals = np.column_stack( + (nan_arr, + old_vals, + np.empty((shape[0], delta * (self.cap_multiple - 1), shape[2]))), + ) + + self.buffer = pd.Panel( + data=new_vals, + items=self.items, + minor_axis=self.minor_axis, + major_axis=np.arange(self.cap), + dtype=self.dtype, ) - self.buffer.major_axis = pd.Int64Index(range(self.cap)) # Fill the delta with the dates we calculated. - where = slice(self.start_index, self.start_index + delta) + where = slice(self._start_index, self._start_index + delta) self.date_buf[where] = missing_dts def add_frame(self, tick, frame): @@ -144,7 +161,7 @@ class RollingPanel(object): these objects because internal data might change """ - where = slice(self.start_index, self._pos) + where = slice(self._start_index, self._pos) major_axis = pd.DatetimeIndex(deepcopy(self.date_buf[where]), tz='utc') return pd.Panel(self.buffer.values[:, where, :], self.items, major_axis, self.minor_axis, dtype=self.dtype) @@ -155,11 +172,11 @@ class RollingPanel(object): passed panel. The passed panel must have the same indices as the panel that would be returned by self.get_current. """ - where = slice(self.start_index, self._pos) + where = slice(self._start_index, self._pos) self.buffer.values[:, where, :] = panel.values def current_dates(self): - where = slice(self.start_index, self._pos) + where = slice(self._start_index, self._pos) return pd.DatetimeIndex(deepcopy(self.date_buf[where]), tz='utc') def _roll_data(self):