From 58fb830ebda0b93fa8614d14c9c55eb7001d2fb6 Mon Sep 17 00:00:00 2001 From: Maya Tydykov Date: Tue, 14 Mar 2017 15:19:48 -0400 Subject: [PATCH] BUG: sort data on asof_date to resolve ts conflicts MAINT: fix arg default and update docstring --- tests/pipeline/test_blaze.py | 55 +++++++++++++++++++++++++- zipline/pipeline/loaders/blaze/core.py | 5 +++ zipline/pipeline/loaders/utils.py | 10 +++-- 3 files changed, 65 insertions(+), 5 deletions(-) diff --git a/tests/pipeline/test_blaze.py b/tests/pipeline/test_blaze.py index f4ce7df6..3f5f7e08 100644 --- a/tests/pipeline/test_blaze.py +++ b/tests/pipeline/test_blaze.py @@ -776,8 +776,9 @@ class BlazeToPipelineTestCase(WithAssetFinder, ZiplineTestCase): check_dtype=False, ) - def _test_id_macro(self, df, dshape, expected, finder, add): - dates = self.dates + def _test_id_macro(self, df, dshape, expected, finder, add, dates=None): + if dates is None: + dates = self.dates expr = bz.data(df, name='expr', dshape=dshape) loader = BlazeLoader() ds = from_blaze( @@ -1875,6 +1876,56 @@ class BlazeToPipelineTestCase(WithAssetFinder, ZiplineTestCase): self._test_checkpoints(checkpoints) + def test_id_take_last_in_group_sorted(self): + """ + input + asof_date timestamp other value + 2014-01-03 2014-01-04 00 3 3 + 2014-01-02 2014-01-04 00 2 2 + + output (expected): + + other value + 2014-01-02 NaN NaN + 2014-01-03 NaN NaN + 2014-01-06 3 3 + """ + + dates = pd.DatetimeIndex([ + pd.Timestamp('2014-01-02'), + pd.Timestamp('2014-01-03'), + pd.Timestamp('2014-01-06'), + ]) + + T = pd.Timestamp + df = pd.DataFrame( + columns=['asof_date', 'timestamp', 'other', 'value'], + data=[ + # asof-dates are flipped in terms of order so that if we + # don't sort on asof-date before getting the last in group, + # we will get the wrong result. + [T('2014-01-03'), T('2014-01-04 00'), 3, 3], + [T('2014-01-02'), T('2014-01-04 00'), 2, 2], + ], + ) + fields = OrderedDict(self.macro_dshape.measure.fields) + fields['other'] = fields['value'] + expected = pd.DataFrame( + data=[[np.nan, np.nan], # 2014-01-02 + [np.nan, np.nan], # 2014-01-03 + [3, 3]], # 2014-01-06 + columns=['other', 'value'], + index=dates, + ) + self._test_id_macro( + df, + var * Record(fields), + expected, + self.asset_finder, + ('other', 'value'), + dates=dates, + ) + class MiscTestCase(ZiplineTestCase): def test_exprdata_repr(self): diff --git a/zipline/pipeline/loaders/blaze/core.py b/zipline/pipeline/loaders/blaze/core.py index 3909b752..6ba1df0f 100644 --- a/zipline/pipeline/loaders/blaze/core.py +++ b/zipline/pipeline/loaders/blaze/core.py @@ -1104,6 +1104,11 @@ class BlazeLoader(dict): materialized_deltas, dates, ) + # If we ever have cases where we find out about multiple asof_dates' + # data on the same TS, we want to make sure that last_in_date_group + # selects the correct last asof_date's value. + sparse_output.sort_values(AD_FIELD_NAME, inplace=True) + non_novel_deltas.sort_values(AD_FIELD_NAME, inplace=True) if AD_FIELD_NAME not in requested_columns: sparse_output.drop(AD_FIELD_NAME, axis=1, inplace=True) diff --git a/zipline/pipeline/loaders/utils.py b/zipline/pipeline/loaders/utils.py index 028da74c..f5385bf3 100644 --- a/zipline/pipeline/loaders/utils.py +++ b/zipline/pipeline/loaders/utils.py @@ -281,15 +281,17 @@ def last_in_date_group(df, assets, reindex=True, have_sids=True, - extra_groupers=[]): + extra_groupers=None): """ Determine the last piece of information known on each date in the date - index for each group. + index for each group. Input df MUST be sorted such that the correct last + item is chosen from each group. Parameters ---------- df : pd.DataFrame - The DataFrame containing the data to be grouped. + The DataFrame containing the data to be grouped. Must be sorted so that + the correct last item is chosen from each group. dates : pd.DatetimeIndex The dates to use for grouping and reindexing. assets : pd.Int64Index @@ -316,6 +318,8 @@ def last_in_date_group(df, )]] if have_sids: idx += [SID_FIELD_NAME] + if extra_groupers is None: + extra_groupers = [] idx += extra_groupers last_in_group = df.drop(TS_FIELD_NAME, axis=1).groupby(