mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-04 05:50:56 +08:00
BUG: fix blaze query in ffill_query_in_range to correct issue in events loader
This commit is contained in:
@@ -330,9 +330,51 @@ class EventsLoaderTestCase(WithAssetFinder,
|
||||
|
||||
for c in EventDataSet.columns:
|
||||
if c in self.next_value_columns:
|
||||
self.check_next_value_results(c, results[c.name].unstack())
|
||||
self.check_next_value_results(
|
||||
c,
|
||||
results[c.name].unstack(),
|
||||
self.trading_days,
|
||||
)
|
||||
elif c in self.previous_value_columns:
|
||||
self.check_previous_value_results(c, results[c.name].unstack())
|
||||
self.check_previous_value_results(
|
||||
c,
|
||||
results[c.name].unstack(),
|
||||
self.trading_days,
|
||||
)
|
||||
else:
|
||||
raise AssertionError("Unexpected column %s." % c)
|
||||
|
||||
def test_load_properly_forward_fills(self):
    """
    Check that event columns are forward filled correctly when the
    pipeline is run over only the second half of the trading days.
    """
    engine = SimplePipelineEngine(
        lambda x: self.loader,
        self.trading_days,
        self.asset_finder,
    )

    # Cut the dates in half so we need to forward fill some data which
    # is not in our window. The results should be computed the same as if
    # we had computed across the entire window and then sliced after the
    # computation.
    #
    # Use floor division for the midpoint: true division produces a
    # float (on Python 3 or under ``from __future__ import division``),
    # and a float is not a valid slice bound.
    dates = self.trading_days[len(self.trading_days) // 2:]
    results = engine.run_pipeline(
        Pipeline({c.name: c.latest for c in EventDataSet.columns}),
        start_date=dates[0],
        end_date=dates[-1],
    )

    for c in EventDataSet.columns:
        if c in self.next_value_columns:
            self.check_next_value_results(
                c,
                results[c.name].unstack(),
                dates,
            )
        elif c in self.previous_value_columns:
            self.check_previous_value_results(
                c,
                results[c.name].unstack(),
                dates,
            )
        else:
            raise AssertionError("Unexpected column %s." % c)
|
||||
|
||||
@@ -342,7 +384,7 @@ class EventsLoaderTestCase(WithAssetFinder,
|
||||
self.ASSET_FINDER_EQUITY_SIDS,
|
||||
)
|
||||
|
||||
def check_previous_value_results(self, column, results):
|
||||
def check_previous_value_results(self, column, results, dates):
|
||||
"""
|
||||
Check previous value results for a single column.
|
||||
"""
|
||||
@@ -352,7 +394,7 @@ class EventsLoaderTestCase(WithAssetFinder,
|
||||
events = self.raw_events_no_nulls
|
||||
# Remove timezone info from trading days, since the outputs
|
||||
# from pandas won't be tz_localized.
|
||||
dates = self.trading_days.tz_localize(None)
|
||||
dates = dates.tz_localize(None)
|
||||
|
||||
for asset, asset_result in results.iteritems():
|
||||
relevant_events = events[events.sid == asset.sid]
|
||||
@@ -387,7 +429,7 @@ class EventsLoaderTestCase(WithAssetFinder,
|
||||
allow_datetime_coercions=True,
|
||||
)
|
||||
|
||||
def check_next_value_results(self, column, results):
|
||||
def check_next_value_results(self, column, results, dates):
|
||||
"""
|
||||
Check results for a single column.
|
||||
"""
|
||||
@@ -396,7 +438,7 @@ class EventsLoaderTestCase(WithAssetFinder,
|
||||
events = self.raw_events_no_nulls
|
||||
# Remove timezone info from trading days, since the outputs
|
||||
# from pandas won't be tz_localized.
|
||||
dates = self.trading_days.tz_localize(None)
|
||||
dates = dates.tz_localize(None)
|
||||
for asset, asset_result in results.iteritems():
|
||||
relevant_events = events[events.sid == asset.sid]
|
||||
self.assertEqual(len(relevant_events), 2)
|
||||
|
||||
@@ -1181,6 +1181,8 @@ def get_materialized_checkpoints(checkpoints, colnames, lower_dt, odo_kwargs):
|
||||
**odo_kwargs
|
||||
)
|
||||
if pd.isnull(checkpoints_ts):
|
||||
# We don't have a checkpoint for before our start date so just
|
||||
# don't constrain the lower date.
|
||||
materialized_checkpoints = pd.DataFrame(columns=colnames)
|
||||
lower = None
|
||||
else:
|
||||
@@ -1192,7 +1194,7 @@ def get_materialized_checkpoints(checkpoints, colnames, lower_dt, odo_kwargs):
|
||||
lower = checkpoints_ts
|
||||
else:
|
||||
materialized_checkpoints = pd.DataFrame(columns=colnames)
|
||||
lower = None
|
||||
lower = None # we don't have a good lower date constraint
|
||||
return lower, materialized_checkpoints
|
||||
|
||||
|
||||
@@ -1229,23 +1231,28 @@ def ffill_query_in_range(expr,
|
||||
"""
|
||||
odo_kwargs = odo_kwargs or {}
|
||||
computed_lower, materialized_checkpoints = get_materialized_checkpoints(
|
||||
checkpoints, expr.fields, lower, odo_kwargs
|
||||
checkpoints,
|
||||
expr.fields,
|
||||
lower,
|
||||
odo_kwargs,
|
||||
)
|
||||
if pd.isnull(computed_lower):
|
||||
# If there is no lower date, just query for data in the date
|
||||
# range. It must all be null anyways.
|
||||
computed_lower = lower
|
||||
|
||||
pred = expr[ts_field] <= upper
|
||||
|
||||
if computed_lower is not None:
|
||||
# only constrain the lower date if we computed a new lower date
|
||||
pred &= expr[ts_field] >= computed_lower
|
||||
|
||||
raw = pd.concat(
|
||||
[materialized_checkpoints,
|
||||
odo(
|
||||
expr[
|
||||
(expr[ts_field] >= computed_lower) &
|
||||
(expr[ts_field] <= upper)
|
||||
],
|
||||
pd.DataFrame,
|
||||
**odo_kwargs
|
||||
)]
|
||||
(
|
||||
materialized_checkpoints,
|
||||
odo(
|
||||
expr[pred],
|
||||
pd.DataFrame,
|
||||
**odo_kwargs
|
||||
),
|
||||
),
|
||||
ignore_index=True,
|
||||
)
|
||||
raw.loc[:, ts_field] = raw.loc[:, ts_field].astype('datetime64[ns]')
|
||||
return raw
|
||||
|
||||
Reference in New Issue
Block a user