From 95a56663d008c06f3b0e43f160bade4aa68bbae7 Mon Sep 17 00:00:00 2001
From: Joe Jevnik
Date: Thu, 13 Oct 2016 15:12:46 -0400
Subject: [PATCH] BUG: fix blaze query in ffill_query_in_range to correct issue in events loader

---
Notes (ignored when the patch is applied):

* ffill_query_in_range used to fall back to the caller's ``lower`` date when
  no checkpoint-derived lower bound was available, and always queried
  ``ts >= lower & ts <= upper``.  It now always constrains ``ts <= upper``
  and only adds the ``>=`` constraint when get_materialized_checkpoints
  returned a lower date, so rows from before the requested window are
  still loaded and can be forward filled.
* pd.concat is now called with ``ignore_index=True``.
* check_next_value_results / check_previous_value_results take the expected
  ``dates`` explicitly so the new test can run over the second half of the
  trading days only.
* The half-way index in the new test uses floor division (``//``) so the
  slice bound is an integer regardless of the division semantics in effect.

 tests/pipeline/test_events.py          | 54 +++++++++++++++++++++++---
 zipline/pipeline/loaders/blaze/core.py | 37 +++++++++++-------
 2 files changed, 70 insertions(+), 21 deletions(-)

diff --git a/tests/pipeline/test_events.py b/tests/pipeline/test_events.py
index c10b15c5..8381df06 100644
--- a/tests/pipeline/test_events.py
+++ b/tests/pipeline/test_events.py
@@ -330,9 +330,51 @@ class EventsLoaderTestCase(WithAssetFinder,
 
         for c in EventDataSet.columns:
             if c in self.next_value_columns:
-                self.check_next_value_results(c, results[c.name].unstack())
+                self.check_next_value_results(
+                    c,
+                    results[c.name].unstack(),
+                    self.trading_days,
+                )
             elif c in self.previous_value_columns:
-                self.check_previous_value_results(c, results[c.name].unstack())
+                self.check_previous_value_results(
+                    c,
+                    results[c.name].unstack(),
+                    self.trading_days,
+                )
+            else:
+                raise AssertionError("Unexpected column %s." % c)
+
+    def test_load_properly_forward_fills(self):
+        engine = SimplePipelineEngine(
+            lambda x: self.loader,
+            self.trading_days,
+            self.asset_finder,
+        )
+
+        # Cut the dates in half so we need to forward fill some data which
+        # is not in our window. The results should be computed the same as if
+        # we had computed across the entire window and then sliced after the
+        # computation.
+        dates = self.trading_days[len(self.trading_days) // 2:]
+        results = engine.run_pipeline(
+            Pipeline({c.name: c.latest for c in EventDataSet.columns}),
+            start_date=dates[0],
+            end_date=dates[-1],
+        )
+
+        for c in EventDataSet.columns:
+            if c in self.next_value_columns:
+                self.check_next_value_results(
+                    c,
+                    results[c.name].unstack(),
+                    dates,
+                )
+            elif c in self.previous_value_columns:
+                self.check_previous_value_results(
+                    c,
+                    results[c.name].unstack(),
+                    dates,
+                )
             else:
                 raise AssertionError("Unexpected column %s." % c)
 
@@ -342,7 +384,7 @@ class EventsLoaderTestCase(WithAssetFinder,
             self.ASSET_FINDER_EQUITY_SIDS,
         )
 
-    def check_previous_value_results(self, column, results):
+    def check_previous_value_results(self, column, results, dates):
         """
         Check previous value results for a single column.
         """
@@ -352,7 +394,7 @@ class EventsLoaderTestCase(WithAssetFinder,
         events = self.raw_events_no_nulls
         # Remove timezone info from trading days, since the outputs
         # from pandas won't be tz_localized.
-        dates = self.trading_days.tz_localize(None)
+        dates = dates.tz_localize(None)
 
         for asset, asset_result in results.iteritems():
             relevant_events = events[events.sid == asset.sid]
@@ -387,7 +429,7 @@ class EventsLoaderTestCase(WithAssetFinder,
                         allow_datetime_coercions=True,
                     )
 
-    def check_next_value_results(self, column, results):
+    def check_next_value_results(self, column, results, dates):
         """
         Check results for a single column.
         """
@@ -396,7 +438,7 @@ class EventsLoaderTestCase(WithAssetFinder,
         events = self.raw_events_no_nulls
         # Remove timezone info from trading days, since the outputs
         # from pandas won't be tz_localized.
-        dates = self.trading_days.tz_localize(None)
+        dates = dates.tz_localize(None)
         for asset, asset_result in results.iteritems():
             relevant_events = events[events.sid == asset.sid]
             self.assertEqual(len(relevant_events), 2)
diff --git a/zipline/pipeline/loaders/blaze/core.py b/zipline/pipeline/loaders/blaze/core.py
index 06b59cf4..145c503d 100644
--- a/zipline/pipeline/loaders/blaze/core.py
+++ b/zipline/pipeline/loaders/blaze/core.py
@@ -1181,6 +1181,8 @@ def get_materialized_checkpoints(checkpoints, colnames, lower_dt, odo_kwargs):
             **odo_kwargs
         )
         if pd.isnull(checkpoints_ts):
+            # We don't have a checkpoint for before our start date so just
+            # don't constrain the lower date.
             materialized_checkpoints = pd.DataFrame(columns=colnames)
             lower = None
         else:
@@ -1192,7 +1194,7 @@ def get_materialized_checkpoints(checkpoints, colnames, lower_dt, odo_kwargs):
             lower = checkpoints_ts
     else:
         materialized_checkpoints = pd.DataFrame(columns=colnames)
-        lower = None
+        lower = None  # we don't have a good lower date constraint
     return lower, materialized_checkpoints
 
 
@@ -1229,23 +1231,28 @@ def ffill_query_in_range(expr,
     """
     odo_kwargs = odo_kwargs or {}
     computed_lower, materialized_checkpoints = get_materialized_checkpoints(
-        checkpoints, expr.fields, lower, odo_kwargs
+        checkpoints,
+        expr.fields,
+        lower,
+        odo_kwargs,
     )
-    if pd.isnull(computed_lower):
-        # If there is no lower date, just query for data in the date
-        # range. It must all be null anyways.
-        computed_lower = lower
+
+    pred = expr[ts_field] <= upper
+
+    if computed_lower is not None:
+        # only constrain the lower date if we computed a new lower date
+        pred &= expr[ts_field] >= computed_lower
 
     raw = pd.concat(
-        [materialized_checkpoints,
-         odo(
-             expr[
-                 (expr[ts_field] >= computed_lower) &
-                 (expr[ts_field] <= upper)
-             ],
-             pd.DataFrame,
-             **odo_kwargs
-         )]
+        (
+            materialized_checkpoints,
+            odo(
+                expr[pred],
+                pd.DataFrame,
+                **odo_kwargs
+            ),
+        ),
+        ignore_index=True,
     )
     raw.loc[:, ts_field] = raw.loc[:, ts_field].astype('datetime64[ns]')
     return raw