BUG: fix blaze query in ffill_query_in_range to correct issue in events loader

This commit is contained in:
Joe Jevnik
2016-10-13 15:12:46 -04:00
parent 8f85bef9fe
commit 95a56663d0
2 changed files with 70 additions and 21 deletions
+48 -6
View File
@@ -330,9 +330,51 @@ class EventsLoaderTestCase(WithAssetFinder,
for c in EventDataSet.columns:
if c in self.next_value_columns:
self.check_next_value_results(c, results[c.name].unstack())
self.check_next_value_results(
c,
results[c.name].unstack(),
self.trading_days,
)
elif c in self.previous_value_columns:
self.check_previous_value_results(c, results[c.name].unstack())
self.check_previous_value_results(
c,
results[c.name].unstack(),
self.trading_days,
)
else:
raise AssertionError("Unexpected column %s." % c)
def test_load_properly_forward_fills(self):
engine = SimplePipelineEngine(
lambda x: self.loader,
self.trading_days,
self.asset_finder,
)
# Cut the dates in half so we need to forward fill some data which
# is not in our window. The results should be computed the same as if
# we had computed across the entire window and then sliced after the
# computation.
dates = self.trading_days[len(self.trading_days) / 2:]
results = engine.run_pipeline(
Pipeline({c.name: c.latest for c in EventDataSet.columns}),
start_date=dates[0],
end_date=dates[-1],
)
for c in EventDataSet.columns:
if c in self.next_value_columns:
self.check_next_value_results(
c,
results[c.name].unstack(),
dates,
)
elif c in self.previous_value_columns:
self.check_previous_value_results(
c,
results[c.name].unstack(),
dates,
)
else:
raise AssertionError("Unexpected column %s." % c)
@@ -342,7 +384,7 @@ class EventsLoaderTestCase(WithAssetFinder,
self.ASSET_FINDER_EQUITY_SIDS,
)
def check_previous_value_results(self, column, results):
def check_previous_value_results(self, column, results, dates):
"""
Check previous value results for a single column.
"""
@@ -352,7 +394,7 @@ class EventsLoaderTestCase(WithAssetFinder,
events = self.raw_events_no_nulls
# Remove timezone info from trading days, since the outputs
# from pandas won't be tz_localized.
dates = self.trading_days.tz_localize(None)
dates = dates.tz_localize(None)
for asset, asset_result in results.iteritems():
relevant_events = events[events.sid == asset.sid]
@@ -387,7 +429,7 @@ class EventsLoaderTestCase(WithAssetFinder,
allow_datetime_coercions=True,
)
def check_next_value_results(self, column, results):
def check_next_value_results(self, column, results, dates):
"""
Check results for a single column.
"""
@@ -396,7 +438,7 @@ class EventsLoaderTestCase(WithAssetFinder,
events = self.raw_events_no_nulls
# Remove timezone info from trading days, since the outputs
# from pandas won't be tz_localized.
dates = self.trading_days.tz_localize(None)
dates = dates.tz_localize(None)
for asset, asset_result in results.iteritems():
relevant_events = events[events.sid == asset.sid]
self.assertEqual(len(relevant_events), 2)
+22 -15
View File
@@ -1181,6 +1181,8 @@ def get_materialized_checkpoints(checkpoints, colnames, lower_dt, odo_kwargs):
**odo_kwargs
)
if pd.isnull(checkpoints_ts):
# We don't have a checkpoint for before our start date so just
# don't constrain the lower date.
materialized_checkpoints = pd.DataFrame(columns=colnames)
lower = None
else:
@@ -1192,7 +1194,7 @@ def get_materialized_checkpoints(checkpoints, colnames, lower_dt, odo_kwargs):
lower = checkpoints_ts
else:
materialized_checkpoints = pd.DataFrame(columns=colnames)
lower = None
lower = None # we don't have a good lower date constraint
return lower, materialized_checkpoints
@@ -1229,23 +1231,28 @@ def ffill_query_in_range(expr,
"""
odo_kwargs = odo_kwargs or {}
computed_lower, materialized_checkpoints = get_materialized_checkpoints(
checkpoints, expr.fields, lower, odo_kwargs
checkpoints,
expr.fields,
lower,
odo_kwargs,
)
if pd.isnull(computed_lower):
# If there is no lower date, just query for data in the date
# range. It must all be null anyways.
computed_lower = lower
pred = expr[ts_field] <= upper
if computed_lower is not None:
# only constrain the lower date if we computed a new lower date
pred &= expr[ts_field] >= computed_lower
raw = pd.concat(
[materialized_checkpoints,
odo(
expr[
(expr[ts_field] >= computed_lower) &
(expr[ts_field] <= upper)
],
pd.DataFrame,
**odo_kwargs
)]
(
materialized_checkpoints,
odo(
expr[pred],
pd.DataFrame,
**odo_kwargs
),
),
ignore_index=True,
)
raw.loc[:, ts_field] = raw.loc[:, ts_field].astype('datetime64[ns]')
return raw