mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-27 22:21:37 +08:00
BUG: fix loader bug for 1 day
This commit is contained in:
@@ -24,9 +24,10 @@ from zipline.pipeline.data import DataSet
|
||||
from zipline.pipeline.data import Column
|
||||
from zipline.pipeline.loaders.blaze.estimates import (
|
||||
BlazeNextEstimatesLoader,
|
||||
BlazePreviousEstimatesLoader,
|
||||
BlazeNextSplitAdjustedEstimatesLoader,
|
||||
BlazePreviousSplitAdjustedEstimatesLoader)
|
||||
BlazePreviousEstimatesLoader,
|
||||
BlazePreviousSplitAdjustedEstimatesLoader,
|
||||
)
|
||||
from zipline.pipeline.loaders.earnings_estimates import (
|
||||
INVALID_NUM_QTRS_MESSAGE,
|
||||
NextEarningsEstimatesLoader,
|
||||
@@ -174,6 +175,127 @@ class WithEstimates(WithTradingSessions, WithAdjustmentReader):
|
||||
cls.columns.items()})
|
||||
|
||||
|
||||
class WithOneDayPipeline(WithEstimates):
    """
    ZiplineTestCase mixin that runs an estimates pipeline over a single
    day and compares the result against a loader-specific expectation.

    Inheritors must implement `make_expected_out`; this mixin supplies the
    columns, the events and the shared test.

    Attributes
    ----------
    sid0 : object
        Whatever `cls.asset_finder.retrieve_asset(0)` returns; stored so
        that inheritors can use it when building `expected_out`.
    expected_out : pd.DataFrame
        The frame returned by `make_expected_out`, which the pipeline
        output is compared against.

    Tests
    ------
    test_load_one_day()
        Runs a pipeline whose start and end date are both 2015-01-15 and
        asserts that the output equals `expected_out`.
    """

    @classmethod
    def make_columns(cls):
        # Each dataset column maps to the identically named events field.
        field_names = (
            'event_date',
            'fiscal_quarter',
            'fiscal_year',
            'estimate1',
            'estimate2',
        )
        return {
            getattr(MultipleColumnsEstimates, field): field
            for field in field_names
        }

    @classmethod
    def make_events(cls):
        # Two rows for sid 0: timestamps on 2015-01-01 and 2015-01-06, with
        # event dates on 2015-01-10 and 2015-01-20 respectively.
        timestamps = [pd.Timestamp('2015-01-01'), pd.Timestamp('2015-01-06')]
        event_dates = [pd.Timestamp('2015-01-10'), pd.Timestamp('2015-01-20')]
        return pd.DataFrame({
            SID_FIELD_NAME: [0, 0],
            TS_FIELD_NAME: timestamps,
            EVENT_DATE_FIELD_NAME: event_dates,
            'estimate1': [1., 2.],
            'estimate2': [3., 4.],
            FISCAL_QUARTER_FIELD_NAME: [1, 2],
            FISCAL_YEAR_FIELD_NAME: [2015, 2015],
        })

    @classmethod
    def make_expected_out(cls):
        # Inheritors provide the frame expected for their loader.
        raise NotImplementedError('make_expected_out')

    @classmethod
    def init_class_fixtures(cls):
        super(WithOneDayPipeline, cls).init_class_fixtures()
        # `sid0` is assigned before `make_expected_out` is called because
        # the inheritors' implementations read `cls.sid0`.
        finder = cls.asset_finder
        cls.sid0 = finder.retrieve_asset(0)
        cls.expected_out = cls.make_expected_out()

    def test_load_one_day(self):
        # We want to test multiple columns
        dataset = MultipleColumnsQuartersEstimates(1)

        def get_loader(column):
            # Every column is served by the loader under test.
            return self.loader

        engine = SimplePipelineEngine(
            get_loader,
            self.trading_days,
            self.asset_finder,
        )

        latest_by_name = {
            column.name: column.latest for column in dataset.columns
        }
        # Start and end are the same session: a one-day pipeline run.
        only_day = pd.Timestamp('2015-01-15', tz='utc')
        results = engine.run_pipeline(
            Pipeline(latest_by_name),
            start_date=only_day,
            end_date=only_day,
        )
        assert_frame_equal(results, self.expected_out)
|
||||
|
||||
|
||||
class PreviousWithOneDayPipeline(WithOneDayPipeline, ZiplineTestCase):
    """
    Tests that the previous quarter loader returns the expected row when
    the pipeline is run over a single day.
    """
    @classmethod
    def make_loader(cls, events, columns):
        # Loader under test: previous-quarter estimates.
        return PreviousEarningsEstimatesLoader(events, columns)

    @classmethod
    def make_expected_out(cls):
        # One row, indexed by (2015-01-15 UTC, sid0), holding the values of
        # the event dated 2015-01-10.
        index_entries = (
            (pd.Timestamp('2015-01-15', tz='utc'), cls.sid0),
        )
        expected_values = {
            EVENT_DATE_FIELD_NAME: pd.Timestamp('2015-01-10'),
            'estimate1': 1.,
            'estimate2': 3.,
            FISCAL_QUARTER_FIELD_NAME: 1.,
            FISCAL_YEAR_FIELD_NAME: 2015.,
        }
        return pd.DataFrame(
            expected_values,
            index=pd.MultiIndex.from_tuples(index_entries),
        )
|
||||
|
||||
|
||||
class NextWithOneDayPipeline(WithOneDayPipeline, ZiplineTestCase):
    """
    Tests that the next quarter loader returns the expected row when the
    pipeline is run over a single day.
    """
    @classmethod
    def make_loader(cls, events, columns):
        # Loader under test: next-quarter estimates.
        return NextEarningsEstimatesLoader(events, columns)

    @classmethod
    def make_expected_out(cls):
        # One row, indexed by (2015-01-15 UTC, sid0), holding the values of
        # the event dated 2015-01-20.
        index_entries = (
            (pd.Timestamp('2015-01-15', tz='utc'), cls.sid0),
        )
        expected_values = {
            EVENT_DATE_FIELD_NAME: pd.Timestamp('2015-01-20'),
            'estimate1': 2.,
            'estimate2': 4.,
            FISCAL_QUARTER_FIELD_NAME: 2.,
            FISCAL_YEAR_FIELD_NAME: 2015.,
        }
        return pd.DataFrame(
            expected_values,
            index=pd.MultiIndex.from_tuples(index_entries),
        )
|
||||
|
||||
|
||||
dummy_df = pd.DataFrame({SID_FIELD_NAME: 0},
|
||||
columns=[SID_FIELD_NAME,
|
||||
TS_FIELD_NAME,
|
||||
|
||||
@@ -318,6 +318,11 @@ class EarningsEstimatesLoader(PipelineLoader):
|
||||
sid : int
|
||||
The sid for which overwrites should be computed.
|
||||
"""
|
||||
# If data was requested for only 1 date, there can never be any
|
||||
# overwrites, so skip the extra work.
|
||||
if len(dates) == 1:
|
||||
return
|
||||
|
||||
next_qtr_start_indices = dates.searchsorted(
|
||||
group[EVENT_DATE_FIELD_NAME].values,
|
||||
side=self.searchsorted_side,
|
||||
@@ -649,11 +654,12 @@ class EarningsEstimatesLoader(PipelineLoader):
|
||||
:,
|
||||
asset_indexer,
|
||||
] = requested_qtr_data[column_name].values
|
||||
|
||||
out[col] = AdjustedArray(
|
||||
output_array,
|
||||
mask,
|
||||
dict(col_to_adjustments[column_name]),
|
||||
# There may not be any adjustments at all (e.g. if
|
||||
# len(date) == 1), so provide a default.
|
||||
dict(col_to_adjustments.get(column_name, {})),
|
||||
col.missing_value,
|
||||
)
|
||||
return out
|
||||
@@ -727,8 +733,6 @@ class NextEarningsEstimatesLoader(EarningsEstimatesLoader):
|
||||
sid_idx,
|
||||
col_to_split_adjustments=None,
|
||||
split_adjusted_asof_idx=None):
|
||||
# if not isinstance(sid_idx, int):
|
||||
# import pdb; pdb.set_trace()
|
||||
return [self.array_overwrites_dict[column.dtype](
|
||||
0,
|
||||
next_qtr_start_idx - 1,
|
||||
|
||||
Reference in New Issue
Block a user