mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-29 13:47:54 +08:00
BUG: fix deltas in blaze core loader
This commit is contained in:
@@ -858,6 +858,58 @@ class BlazeToPipelineTestCase(TestCase):
|
||||
compute_fn=np.nanmax,
|
||||
)
|
||||
|
||||
@with_extra_sid
|
||||
def test_deltas_only_one_delta_in_universe(self, asset_info):
# Runs the pipeline over self.df together with a deltas table holding two
# overwrite rows (sid 65 -> value 10, sid 66 -> value 11) and checks both
# the per-date windows and the final output via self._run_pipeline.
# NOTE(review): with_extra_sid presumably supplies asset_info both with and
# without an additional asset (see the nassets == 4 branch) -- confirm.
|
||||
expr = bz.Data(self.df, name='expr', dshape=self.dshape)
|
||||
# Each delta row carries its own asof_date and timestamp taken from
# self.dates; the two rows use different dates.
deltas = pd.DataFrame({
|
||||
'sid': [65, 66],
|
||||
'asof_date': [self.dates[1], self.dates[0]],
|
||||
'timestamp': [self.dates[2], self.dates[1]],
|
||||
'value': [10, 11],
|
||||
})
|
||||
deltas = bz.Data(deltas, name='deltas', dshape=self.dshape)
|
||||
# Expected two-row window for each date; keymap converts the string keys to
# pd.Timestamp. The 11.0 and 10.0 entries are the delta values from above.
expected_views = keymap(pd.Timestamp, {
|
||||
'2014-01-02': np.array([[0.0, 11.0, 2.0],
|
||||
[1.0, 2.0, 3.0]]),
|
||||
'2014-01-03': np.array([[10.0, 2.0, 3.0],
|
||||
[2.0, 3.0, 4.0]]),
|
||||
'2014-01-04': np.array([[2.0, 3.0, 4.0],
|
||||
[2.0, 3.0, 4.0]]),
|
||||
})
|
||||
|
||||
nassets = len(asset_info)
|
||||
# With four assets, append an all-NaN column to every expected view.
if nassets == 4:
|
||||
expected_views = valmap(
|
||||
lambda view: np.c_[view, [np.nan, np.nan]],
|
||||
expected_views,
|
||||
)
|
||||
|
||||
with tmp_asset_finder(equities=asset_info) as finder:
|
||||
# One 'value' per (date, asset) pair: 11, 10 and 4, each repeated once per
# asset. Each number equals the overall np.nanmax of the corresponding
# expected view above.
expected_output = pd.DataFrame(
|
||||
columns=[
|
||||
'value',
|
||||
],
|
||||
data=np.array([11, 10, 4]).repeat(len(asset_info.index)),
|
||||
index=pd.MultiIndex.from_product((
|
||||
sorted(expected_views.keys()),
|
||||
finder.retrieve_all(asset_info.index),
|
||||
)),
|
||||
)
|
||||
dates = self.dates
|
||||
# Extend the calendar by one day past the last entry of self.dates.
dates = dates.insert(len(dates), dates[-1] + timedelta(days=1))
|
||||
self._run_pipeline(
|
||||
expr,
|
||||
deltas,
|
||||
expected_views,
|
||||
expected_output,
|
||||
finder,
|
||||
calendar=dates,
|
||||
start=dates[1],
|
||||
end=dates[-1],
|
||||
window_length=2,
|
||||
compute_fn=np.nanmax,
|
||||
)
|
||||
|
||||
def test_deltas_macro(self):
|
||||
asset_info = asset_infos[0][0]
|
||||
expr = bz.Data(self.macro_df, name='expr', dshape=self.macro_dshape)
|
||||
|
||||
@@ -143,6 +143,7 @@ from datashape import (
|
||||
isscalar,
|
||||
promote,
|
||||
)
|
||||
import numpy as np
|
||||
from odo import odo
|
||||
import pandas as pd
|
||||
from six import with_metaclass, PY2, itervalues, iteritems
|
||||
@@ -675,6 +676,11 @@ def overwrite_from_dates(asof, dense_dates, sparse_dates, asset_idx, value):
|
||||
|
||||
Then the overwrite will apply to indexes: 1, 2, 3, 4
|
||||
"""
|
||||
if asof is pd.NaT:
|
||||
# Not an actual delta.
|
||||
# This happens due to the groupby we do on the deltas.
|
||||
return
|
||||
|
||||
first_row = dense_dates.searchsorted(asof)
|
||||
next_idx = sparse_dates.searchsorted(asof.asm8, 'right')
|
||||
if next_idx == len(sparse_dates):
|
||||
@@ -697,7 +703,7 @@ def adjustments_from_deltas_no_sids(dense_dates,
|
||||
sparse_dates,
|
||||
column_idx,
|
||||
column_name,
|
||||
assets,
|
||||
asset_idx,
|
||||
deltas):
|
||||
"""Collect all the adjustments that occur in a dataset that does not
|
||||
have a sid column.
|
||||
@@ -712,6 +718,8 @@ def adjustments_from_deltas_no_sids(dense_dates,
|
||||
The index of the column in the dataset.
|
||||
column_name : str
|
||||
The name of the column to compute deltas for.
|
||||
asset_idx : pd.Series[int -> int]
|
||||
The mapping of sids to their index in the output.
|
||||
deltas : pd.DataFrame
|
||||
The overwrites that should be applied to the dataset.
|
||||
|
||||
@@ -721,13 +729,13 @@ def adjustments_from_deltas_no_sids(dense_dates,
|
||||
The adjustments dictionary to feed to the adjusted array.
|
||||
"""
|
||||
ad_series = deltas[AD_FIELD_NAME]
|
||||
asset_idx = 0, len(assets) - 1
|
||||
idx = 0, len(asset_idx) - 1
|
||||
return {
|
||||
dense_dates.get_loc(kd): overwrite_from_dates(
|
||||
ad_series.loc[kd],
|
||||
dense_dates,
|
||||
sparse_dates,
|
||||
asset_idx,
|
||||
idx,
|
||||
v,
|
||||
) for kd, v in deltas[column_name].iteritems()
|
||||
}
|
||||
@@ -737,7 +745,7 @@ def adjustments_from_deltas_with_sids(dense_dates,
|
||||
sparse_dates,
|
||||
column_idx,
|
||||
column_name,
|
||||
assets,
|
||||
asset_idx,
|
||||
deltas):
|
||||
"""Collect all the adjustments that occur in a dataset that has a
|
||||
sid column.
|
||||
@@ -752,6 +760,8 @@ def adjustments_from_deltas_with_sids(dense_dates,
|
||||
The index of the column in the dataset.
|
||||
column_name : str
|
||||
The name of the column to compute deltas for.
|
||||
asset_idx : pd.Series[int -> int]
|
||||
The mapping of sids to their index in the output.
|
||||
deltas : pd.DataFrame
|
||||
The overwrites that should be applied to the dataset.
|
||||
|
||||
@@ -762,14 +772,15 @@ def adjustments_from_deltas_with_sids(dense_dates,
|
||||
"""
|
||||
ad_series = deltas[AD_FIELD_NAME]
|
||||
adjustments = defaultdict(list)
|
||||
for sid_idx, (sid, per_sid) in enumerate(deltas[column_name].iteritems()):
|
||||
for sid, per_sid in deltas[column_name].iteritems():
|
||||
idx = asset_idx[sid]
|
||||
for kd, v in per_sid.iteritems():
|
||||
adjustments[dense_dates.searchsorted(kd)].extend(
|
||||
overwrite_from_dates(
|
||||
ad_series.loc[kd, sid],
|
||||
dense_dates,
|
||||
sparse_dates,
|
||||
(sid_idx, sid_idx),
|
||||
(idx, idx),
|
||||
v,
|
||||
),
|
||||
)
|
||||
@@ -829,6 +840,7 @@ class BlazeLoader(dict):
|
||||
|
||||
expr, deltas, resources = self[dataset]
|
||||
have_sids = SID_FIELD_NAME in expr.fields
|
||||
asset_idx = pd.Series(index=assets, data=np.arange(len(assets)))
|
||||
assets = list(map(int, assets)) # coerce from numpy.int64
|
||||
added_query_fields = [AD_FIELD_NAME, TS_FIELD_NAME] + (
|
||||
[SID_FIELD_NAME] if have_sids else []
|
||||
@@ -1011,7 +1023,7 @@ class BlazeLoader(dict):
|
||||
sparse_output[TS_FIELD_NAME].values,
|
||||
column_idx,
|
||||
column_name,
|
||||
assets,
|
||||
asset_idx,
|
||||
sparse_deltas,
|
||||
)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user