BUG: fix loader bug for 1 day

This commit is contained in:
Maya Tydykov
2017-02-15 15:01:44 -05:00
parent fb85e4a2bd
commit 382eef0e3d
2 changed files with 132 additions and 6 deletions
+124 -2
View File
@@ -24,9 +24,10 @@ from zipline.pipeline.data import DataSet
from zipline.pipeline.data import Column
from zipline.pipeline.loaders.blaze.estimates import (
BlazeNextEstimatesLoader,
BlazePreviousEstimatesLoader,
BlazeNextSplitAdjustedEstimatesLoader,
BlazePreviousSplitAdjustedEstimatesLoader)
BlazePreviousEstimatesLoader,
BlazePreviousSplitAdjustedEstimatesLoader,
)
from zipline.pipeline.loaders.earnings_estimates import (
INVALID_NUM_QTRS_MESSAGE,
NextEarningsEstimatesLoader,
@@ -174,6 +175,127 @@ class WithEstimates(WithTradingSessions, WithAdjustmentReader):
cls.columns.items()})
class WithOneDayPipeline(WithEstimates):
    """
    ZiplineTestCase mixin providing a small single-sid events fixture and
    defining a test, for all inheritors to use, that runs a pipeline over
    exactly one date.

    Attributes
    ----------
    sid0 : object
        The asset returned by ``cls.asset_finder.retrieve_asset(0)``.
    expected_out : pd.DataFrame
        The frame the one-day pipeline run is expected to produce, as built
        by `make_expected_out`.

    Methods
    -------
    make_columns()
        Maps each `MultipleColumnsEstimates` column to its column name in the
        events frame.
    make_events()
        Builds a two-row events frame for sid 0.
    make_expected_out()
        Must be overridden by inheritors; raises NotImplementedError here.

    Tests
    -----
    test_load_one_day()
        Runs a pipeline whose start and end date are both 2015-01-15 and
        checks that the result equals `expected_out`.
    """
    @classmethod
    def make_columns(cls):
        return {
            MultipleColumnsEstimates.event_date: 'event_date',
            MultipleColumnsEstimates.fiscal_quarter: 'fiscal_quarter',
            MultipleColumnsEstimates.fiscal_year: 'fiscal_year',
            MultipleColumnsEstimates.estimate1: 'estimate1',
            MultipleColumnsEstimates.estimate2: 'estimate2',
        }

    @classmethod
    def make_events(cls):
        # Two rows for sid 0: both timestamps fall before the date queried in
        # `test_load_one_day` (2015-01-15), which itself lies between the two
        # event dates (2015-01-10 and 2015-01-20).
        return pd.DataFrame({
            SID_FIELD_NAME: [0] * 2,
            TS_FIELD_NAME: [pd.Timestamp('2015-01-01'),
                            pd.Timestamp('2015-01-06')],
            EVENT_DATE_FIELD_NAME: [pd.Timestamp('2015-01-10'),
                                    pd.Timestamp('2015-01-20')],
            'estimate1': [1., 2.],
            'estimate2': [3., 4.],
            FISCAL_QUARTER_FIELD_NAME: [1, 2],
            FISCAL_YEAR_FIELD_NAME: [2015, 2015],
        })

    @classmethod
    def make_expected_out(cls):
        # Inheritors supply the frame expected for their loader.
        raise NotImplementedError('make_expected_out')

    @classmethod
    def init_class_fixtures(cls):
        super(WithOneDayPipeline, cls).init_class_fixtures()
        # `sid0` has to be set before `make_expected_out` runs, because the
        # inheritors in this file use it to build the expected index.
        cls.sid0 = cls.asset_finder.retrieve_asset(0)
        cls.expected_out = cls.make_expected_out()

    def test_load_one_day(self):
        # We want to test multiple columns
        dataset = MultipleColumnsQuartersEstimates(1)
        engine = SimplePipelineEngine(
            lambda x: self.loader,
            self.trading_days,
            self.asset_finder,
        )
        # Start and end are the same date, so the pipeline covers one day.
        only_day = pd.Timestamp('2015-01-15', tz='utc')
        results = engine.run_pipeline(
            Pipeline({c.name: c.latest for c in dataset.columns}),
            start_date=only_day,
            end_date=only_day,
        )
        assert_frame_equal(results, self.expected_out)


class PreviousWithOneDayPipeline(WithOneDayPipeline, ZiplineTestCase):
    """
    Tests that the previous quarter loader produces the expected output when
    a pipeline is run over a single date.
    """
    @classmethod
    def make_loader(cls, events, columns):
        return PreviousEarningsEstimatesLoader(events, columns)

    @classmethod
    def make_expected_out(cls):
        # One row, indexed by (2015-01-15, sid 0), holding the values of the
        # event dated 2015-01-10 -- the fixture row whose event date precedes
        # the queried date.
        return pd.DataFrame(
            {
                EVENT_DATE_FIELD_NAME: pd.Timestamp('2015-01-10'),
                'estimate1': 1.,
                'estimate2': 3.,
                FISCAL_QUARTER_FIELD_NAME: 1.,
                FISCAL_YEAR_FIELD_NAME: 2015.,
            },
            index=pd.MultiIndex.from_tuples(
                ((pd.Timestamp('2015-01-15', tz='utc'), cls.sid0),)
            )
        )


class NextWithOneDayPipeline(WithOneDayPipeline, ZiplineTestCase):
    """
    Tests that the next quarter loader produces the expected output when a
    pipeline is run over a single date.
    """
    @classmethod
    def make_loader(cls, events, columns):
        return NextEarningsEstimatesLoader(events, columns)

    @classmethod
    def make_expected_out(cls):
        # One row, indexed by (2015-01-15, sid 0), holding the values of the
        # event dated 2015-01-20 -- the fixture row whose event date follows
        # the queried date.
        return pd.DataFrame(
            {
                EVENT_DATE_FIELD_NAME: pd.Timestamp('2015-01-20'),
                'estimate1': 2.,
                'estimate2': 4.,
                FISCAL_QUARTER_FIELD_NAME: 2.,
                FISCAL_YEAR_FIELD_NAME: 2015.,
            },
            index=pd.MultiIndex.from_tuples(
                ((pd.Timestamp('2015-01-15', tz='utc'), cls.sid0),)
            )
        )


dummy_df = pd.DataFrame({SID_FIELD_NAME: 0},
columns=[SID_FIELD_NAME,
TS_FIELD_NAME,
@@ -318,6 +318,11 @@ class EarningsEstimatesLoader(PipelineLoader):
sid : int
The sid for which overwrites should be computed.
"""
# If data was requested for only 1 date, there can never be any
# overwrites, so skip the extra work.
if len(dates) == 1:
return
next_qtr_start_indices = dates.searchsorted(
group[EVENT_DATE_FIELD_NAME].values,
side=self.searchsorted_side,
@@ -649,11 +654,12 @@ class EarningsEstimatesLoader(PipelineLoader):
:,
asset_indexer,
] = requested_qtr_data[column_name].values
out[col] = AdjustedArray(
output_array,
mask,
dict(col_to_adjustments[column_name]),
# There may not be any adjustments at all for this column (e.g. if
# data was requested for only a single date), so provide a default.
dict(col_to_adjustments.get(column_name, {})),
col.missing_value,
)
return out
@@ -727,8 +733,6 @@ class NextEarningsEstimatesLoader(EarningsEstimatesLoader):
sid_idx,
col_to_split_adjustments=None,
split_adjusted_asof_idx=None):
# if not isinstance(sid_idx, int):
# import pdb; pdb.set_trace()
return [self.array_overwrites_dict[column.dtype](
0,
next_qtr_start_idx - 1,