mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-30 18:49:42 +08:00
PERF: only query for the columns requested + metadata
BUG: choose last event date for quarter shift
This commit is contained in:
@@ -746,7 +746,7 @@ class WithEstimateWindows(WithEstimates):
|
||||
columns=[SID_FIELD_NAME,
|
||||
'estimate',
|
||||
'knowledge_date'])
|
||||
df = df.pivot_table(columns='sid',
|
||||
df = df.pivot_table(columns=SID_FIELD_NAME,
|
||||
values='estimate',
|
||||
index='knowledge_date')
|
||||
df = df.reindex(
|
||||
@@ -796,8 +796,8 @@ class WithEstimateWindows(WithEstimates):
|
||||
engine.run_pipeline(
|
||||
Pipeline({'est': SomeFactor()}),
|
||||
start_date=start_idx,
|
||||
end_date=pd.Timestamp('2015-01-20', tz='utc'), # last event date
|
||||
# we have
|
||||
# last event date we have
|
||||
end_date=pd.Timestamp('2015-01-20', tz='utc'),
|
||||
)
|
||||
|
||||
|
||||
@@ -938,7 +938,7 @@ class QuarterShiftTestCase(ZiplineTestCase):
|
||||
backwards/forwards from a starting point.
|
||||
"""
|
||||
def test_quarter_normalization(self):
|
||||
input_yrs = pd.Series([0] * 4, dtype=np.int64)
|
||||
input_yrs = pd.Series(range(2011, 2015), dtype=np.int64)
|
||||
input_qtrs = pd.Series(range(1, 5), dtype=np.int64)
|
||||
result_years, result_quarters = split_normalized_quarters(
|
||||
normalize_quarters(input_yrs, input_qtrs)
|
||||
|
||||
@@ -16,6 +16,7 @@ from zipline.pipeline.loaders.earnings_estimates import (
|
||||
NextEarningsEstimatesLoader,
|
||||
PreviousEarningsEstimatesLoader,
|
||||
required_estimates_fields,
|
||||
metadata_columns,
|
||||
)
|
||||
from zipline.pipeline.loaders.utils import (
|
||||
check_data_query_args,
|
||||
@@ -104,17 +105,20 @@ class BlazeEstimatesLoader(PipelineLoader):
|
||||
self._checkpoints = checkpoints
|
||||
|
||||
def load_adjusted_array(self, columns, dates, assets, mask):
|
||||
raw = load_raw_data(assets,
|
||||
dates,
|
||||
self._data_query_time,
|
||||
self._data_query_tz,
|
||||
self._expr,
|
||||
self._odo_kwargs,
|
||||
checkpoints=self._checkpoints)
|
||||
column_names = [column.name for column in columns]
|
||||
raw = load_raw_data(
|
||||
assets,
|
||||
dates,
|
||||
self._data_query_time,
|
||||
self._data_query_tz,
|
||||
self._expr[sorted(metadata_columns.union(column_names))],
|
||||
self._odo_kwargs,
|
||||
checkpoints=self._checkpoints,
|
||||
)
|
||||
|
||||
return self.loader(
|
||||
raw,
|
||||
self._columns,
|
||||
{k: self._columns[k] for k in column_names}
|
||||
).load_adjusted_array(
|
||||
columns,
|
||||
dates,
|
||||
|
||||
@@ -57,7 +57,7 @@ class BlazeEventsLoader(PipelineLoader):
|
||||
start the backtest with knowledge of all announcements.
|
||||
"""
|
||||
|
||||
__doc__ == __doc__.format(SID_FIELD_NAME=SID_FIELD_NAME,
|
||||
__doc__ = __doc__.format(SID_FIELD_NAME=SID_FIELD_NAME,
|
||||
TS_FIELD_NAME=TS_FIELD_NAME,
|
||||
EVENT_DATE_FIELD_NAME=EVENT_DATE_FIELD_NAME)
|
||||
|
||||
|
||||
@@ -50,23 +50,24 @@ def split_normalized_quarters(normalized_quarters):
|
||||
return years, quarters + 1
|
||||
|
||||
|
||||
# These metadata columns are used to align event indexers.
|
||||
metadata_columns = frozenset({
|
||||
TS_FIELD_NAME,
|
||||
SID_FIELD_NAME,
|
||||
EVENT_DATE_FIELD_NAME,
|
||||
FISCAL_QUARTER_FIELD_NAME,
|
||||
FISCAL_YEAR_FIELD_NAME,
|
||||
})
|
||||
|
||||
|
||||
def required_estimates_fields(columns):
|
||||
"""
|
||||
Compute the set of resource columns required to serve
|
||||
`columns`.
|
||||
"""
|
||||
# These metadata columns are used to align event indexers.
|
||||
return {
|
||||
TS_FIELD_NAME,
|
||||
SID_FIELD_NAME,
|
||||
EVENT_DATE_FIELD_NAME,
|
||||
FISCAL_QUARTER_FIELD_NAME,
|
||||
FISCAL_YEAR_FIELD_NAME
|
||||
}.union(
|
||||
# We also expect any of the field names that our loadable columns
|
||||
# are mapped to.
|
||||
viewvalues(columns),
|
||||
)
|
||||
# We also expect any of the field names that our loadable columns
|
||||
# are mapped to.
|
||||
return metadata_columns.union(viewvalues(columns))
|
||||
|
||||
|
||||
def validate_column_specs(events, columns):
|
||||
@@ -269,18 +270,13 @@ class EarningsEstimatesLoader(PipelineLoader):
|
||||
The array of data and overwrites for the given column.
|
||||
"""
|
||||
col_to_overwrites = defaultdict(dict)
|
||||
# We no longer need NORMALIZED_QUARTERS in the index, but we do need it
|
||||
# as a column to calculate adjustments.
|
||||
zero_qtr_data = zero_qtr_data.reset_index(level=NORMALIZED_QUARTERS)
|
||||
zero_qtr_data.sort_index(inplace=True)
|
||||
|
||||
quarter_shifts = zero_qtr_data.loc[
|
||||
zero_qtr_data.index[
|
||||
zero_qtr_data.groupby(level=SID_FIELD_NAME)[
|
||||
NORMALIZED_QUARTERS
|
||||
].diff().nonzero()
|
||||
]
|
||||
]
|
||||
# Here we want to get the LAST record from each group of records
|
||||
# corresponding to a single quarter. This is to ensure that we select
|
||||
# the most up-to-date event date in case the event date changes.
|
||||
quarter_shifts = zero_qtr_data.groupby(
|
||||
level=[SID_FIELD_NAME, NORMALIZED_QUARTERS]
|
||||
).nth(-1)
|
||||
|
||||
sid_to_idx = dict(zip(assets, range(len(assets))))
|
||||
|
||||
@@ -290,7 +286,9 @@ class EarningsEstimatesLoader(PipelineLoader):
|
||||
side=self.searchsorted_side,
|
||||
)
|
||||
sid = int(group.name)
|
||||
qtrs_with_estimates = group[NORMALIZED_QUARTERS].values
|
||||
qtrs_with_estimates = group.index.get_level_values(
|
||||
NORMALIZED_QUARTERS
|
||||
).values
|
||||
for idx in next_qtr_start_indices:
|
||||
if 0 < idx < len(dates):
|
||||
# Only add adjustments if the next quarter starts somewhere
|
||||
@@ -584,13 +582,14 @@ class NextEarningsEstimatesLoader(EarningsEstimatesLoader):
|
||||
An index of calendar dates, sid, and normalized quarters, for only
|
||||
the rows that have a next event.
|
||||
"""
|
||||
|
||||
next_releases_per_date = stacked_last_per_qtr.loc[
|
||||
stacked_last_per_qtr[EVENT_DATE_FIELD_NAME] >=
|
||||
stacked_last_per_qtr.index.get_level_values(SIMULTATION_DATES)
|
||||
].groupby(
|
||||
level=[SIMULTATION_DATES, SID_FIELD_NAME],
|
||||
as_index=False,
|
||||
# Here we take advantage of the fact that `stacked_last_per_qtr` is
|
||||
# sorted by event date.
|
||||
).nth(0)
|
||||
return next_releases_per_date.index
|
||||
|
||||
@@ -635,12 +634,13 @@ class PreviousEarningsEstimatesLoader(EarningsEstimatesLoader):
|
||||
An index of calendar dates, sid, and normalized quarters, for only
|
||||
the rows that have a previous event.
|
||||
"""
|
||||
|
||||
previous_releases_per_date = stacked_last_per_qtr.loc[
|
||||
stacked_last_per_qtr[EVENT_DATE_FIELD_NAME] <=
|
||||
stacked_last_per_qtr.index.get_level_values(SIMULTATION_DATES)
|
||||
].groupby(
|
||||
level=[SIMULTATION_DATES, SID_FIELD_NAME],
|
||||
as_index=False,
|
||||
# Here we take advantage of the fact that `stacked_last_per_qtr` is
|
||||
# sorted by event date.
|
||||
).nth(-1)
|
||||
return previous_releases_per_date.index
|
||||
|
||||
Reference in New Issue
Block a user