mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-28 17:33:01 +08:00
TST: add test for changing event dates and adjustments
BUG: get column names from column dict; BUG: fix name map
This commit is contained in:
@@ -3,6 +3,7 @@ import itertools
|
||||
from nose.tools import assert_true
|
||||
from nose_parameterized import parameterized
|
||||
import numpy as np
|
||||
from numpy.testing import assert_array_equal
|
||||
import pandas as pd
|
||||
from toolz import merge
|
||||
|
||||
@@ -44,9 +45,9 @@ class Estimates(DataSet):
|
||||
estimate = Column(dtype=float64_dtype)
|
||||
|
||||
|
||||
def QuartersEstimates(num_qtr):
|
||||
def QuartersEstimates(announcements_out):
|
||||
class QtrEstimates(Estimates):
|
||||
num_quarters = num_qtr
|
||||
num_announcements = announcements_out
|
||||
name = Estimates
|
||||
return QtrEstimates
|
||||
|
||||
@@ -123,11 +124,11 @@ class WithWrongLoaderDefinition(WithEstimates):
|
||||
|
||||
Tests
|
||||
------
|
||||
test_wrong_num_quarters_passed()
|
||||
test_wrong_num_announcements_passed()
|
||||
Tests that loading with an incorrect quarter number raises an error.
|
||||
test_no_num_quarters_attr()
|
||||
test_no_num_announcements_attr()
|
||||
Tests that the loader throws an AssertionError if the dataset being
|
||||
loaded has no `num_quarters` attribute.
|
||||
loaded has no `num_announcements` attribute.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@@ -141,7 +142,7 @@ class WithWrongLoaderDefinition(WithEstimates):
|
||||
'estimate'],
|
||||
index=[0])
|
||||
|
||||
def test_wrong_num_quarters_passed(self):
|
||||
def test_wrong_num_announcements_passed(self):
|
||||
bad_dataset1 = QuartersEstimates(-1)
|
||||
bad_dataset2 = QuartersEstimates(-2)
|
||||
good_dataset = QuartersEstimates(1)
|
||||
@@ -150,7 +151,7 @@ class WithWrongLoaderDefinition(WithEstimates):
|
||||
self.trading_days,
|
||||
self.asset_finder,
|
||||
)
|
||||
columns = {c.name + str(dataset.num_quarters): c.latest
|
||||
columns = {c.name + str(dataset.num_announcements): c.latest
|
||||
for dataset in (bad_dataset1,
|
||||
bad_dataset2,
|
||||
good_dataset)
|
||||
@@ -165,7 +166,7 @@ class WithWrongLoaderDefinition(WithEstimates):
|
||||
)
|
||||
assert_raises_regex(e, INVALID_NUM_QTRS_MESSAGE % "-1,-2")
|
||||
|
||||
def test_no_num_quarters_attr(self):
|
||||
def test_no_num_announcements_attr(self):
|
||||
dataset = QuartersEstimatesNoNumQuartersAttr(1)
|
||||
engine = SimplePipelineEngine(
|
||||
lambda x: self.loader,
|
||||
@@ -657,6 +658,119 @@ class PreviousEstimateMultipleQuarters(
|
||||
return expected
|
||||
|
||||
|
||||
class WithVaryingNumEstimates(WithEstimates):
|
||||
"""
|
||||
ZiplineTestCase mixin providing fixtures and a test to ensure that we
|
||||
have the correct overwrites when the event date changes. We want to make
|
||||
sure that if we have a quarter with an event date that gets pushed back,
|
||||
we don't start overwriting for the next quarter early. Likewise,
|
||||
if we have a quarter with an event date that gets pushed forward, we want
|
||||
to make sure that we start applying adjustments at the appropriate, earlier
|
||||
date, rather than the later date.
|
||||
|
||||
Methods
|
||||
-------
|
||||
assert_compute()
|
||||
Defines how to determine that results computed for the `SomeFactor`
|
||||
factor are correct.
|
||||
|
||||
Tests
|
||||
-----
|
||||
test_windows_with_varying_num_estimates()
|
||||
Tests that we create the correct overwrites from 2015-01-13 to
|
||||
2015-01-14 regardless of how event dates were updated for each
|
||||
quarter for each sid.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def make_events(cls):
|
||||
return pd.DataFrame({
|
||||
SID_FIELD_NAME: [0] * 3 + [1] * 3,
|
||||
TS_FIELD_NAME: [pd.Timestamp('2015-01-09'),
|
||||
pd.Timestamp('2015-01-12'),
|
||||
pd.Timestamp('2015-01-13')] * 2,
|
||||
EVENT_DATE_FIELD_NAME: [pd.Timestamp('2015-01-12'),
|
||||
pd.Timestamp('2015-01-13'),
|
||||
pd.Timestamp('2015-01-20'),
|
||||
pd.Timestamp('2015-01-13'),
|
||||
pd.Timestamp('2015-01-12'),
|
||||
pd.Timestamp('2015-01-20')],
|
||||
'estimate': [11., 12., 21.] * 2,
|
||||
FISCAL_QUARTER_FIELD_NAME: [1, 1, 2] * 2,
|
||||
FISCAL_YEAR_FIELD_NAME: [2015] * 6
|
||||
})
|
||||
|
||||
@classmethod
|
||||
def assert_compute(cls, estimate, today):
|
||||
raise NotImplementedError('assert_compute')
|
||||
|
||||
def test_windows_with_varying_num_estimates(self):
|
||||
dataset = QuartersEstimates(1)
|
||||
assert_compute = self.assert_compute
|
||||
|
||||
class SomeFactor(CustomFactor):
|
||||
inputs = [dataset.estimate]
|
||||
window_length = 3
|
||||
|
||||
def compute(self, today, assets, out, estimate):
|
||||
assert_compute(estimate, today)
|
||||
|
||||
engine = SimplePipelineEngine(
|
||||
lambda x: self.loader,
|
||||
self.trading_days,
|
||||
self.asset_finder,
|
||||
)
|
||||
engine.run_pipeline(
|
||||
Pipeline({'est': SomeFactor()}),
|
||||
start_date=pd.Timestamp('2015-01-13', tz='utc'),
|
||||
# last event date we have
|
||||
end_date=pd.Timestamp('2015-01-14', tz='utc'),
|
||||
)
|
||||
|
||||
|
||||
class PreviousVaryingNumEstimates(
|
||||
WithVaryingNumEstimates,
|
||||
ZiplineTestCase
|
||||
):
|
||||
def assert_compute(self, estimate, today):
|
||||
if today == pd.Timestamp('2015-01-13', tz='utc'):
|
||||
assert_array_equal(estimate[:, 0],
|
||||
np.array([np.NaN, np.NaN, 12]))
|
||||
assert_array_equal(estimate[:, 1],
|
||||
np.array([np.NaN, 12, 12]))
|
||||
else:
|
||||
assert_array_equal(estimate[:, 0],
|
||||
np.array([np.NaN, 12, 12]))
|
||||
assert_array_equal(estimate[:, 1],
|
||||
np.array([12, 12, 12]))
|
||||
|
||||
@classmethod
|
||||
def make_loader(cls, events, columns):
|
||||
return PreviousEarningsEstimatesLoader(events, columns)
|
||||
|
||||
|
||||
class NextVaryingNumEstimates(
|
||||
WithVaryingNumEstimates,
|
||||
ZiplineTestCase
|
||||
):
|
||||
|
||||
def assert_compute(self, estimate, today):
|
||||
if today == pd.Timestamp('2015-01-13', tz='utc'):
|
||||
assert_array_equal(estimate[:, 0],
|
||||
np.array([11, 12, 12]))
|
||||
assert_array_equal(estimate[:, 1],
|
||||
np.array([np.NaN, np.NaN, 21]))
|
||||
else:
|
||||
assert_array_equal(estimate[:, 0],
|
||||
np.array([np.NaN, 21, 21]))
|
||||
assert_array_equal(estimate[:, 1],
|
||||
np.array([np.NaN, 21, 21]))
|
||||
|
||||
@classmethod
|
||||
def make_loader(cls, events, columns):
|
||||
return NextEarningsEstimatesLoader(events, columns)
|
||||
|
||||
|
||||
class WithEstimateWindows(WithEstimates):
|
||||
"""
|
||||
ZiplineTestCase mixin providing fixtures and a test to test running a
|
||||
@@ -761,8 +875,8 @@ class WithEstimateWindows(WithEstimates):
|
||||
@parameterized.expand(window_test_cases)
|
||||
def test_estimate_windows_at_quarter_boundaries(self,
|
||||
start_idx,
|
||||
num_quarters_out):
|
||||
dataset = QuartersEstimates(num_quarters_out)
|
||||
num_announcements_out):
|
||||
dataset = QuartersEstimates(num_announcements_out)
|
||||
trading_days = self.trading_days
|
||||
timelines = self.timelines
|
||||
# The window length should be from the starting index back to the first
|
||||
@@ -781,7 +895,7 @@ class WithEstimateWindows(WithEstimates):
|
||||
def compute(self, today, assets, out, estimate):
|
||||
today_idx = trading_days.get_loc(today)
|
||||
today_timeline = timelines[
|
||||
num_quarters_out
|
||||
num_announcements_out
|
||||
].loc[today].reindex(
|
||||
trading_days[:today_idx + 1]
|
||||
).values
|
||||
|
||||
@@ -188,7 +188,7 @@ from zipline.utils.input_validation import (
|
||||
ensure_timezone,
|
||||
optionally,
|
||||
)
|
||||
from zipline.utils.numpy_utils import bool_dtype, categorical_dtype
|
||||
from zipline.utils.numpy_utils import bool_dtype
|
||||
from zipline.utils.pool import SequentialPool
|
||||
from zipline.utils.preprocess import preprocess
|
||||
|
||||
|
||||
@@ -105,20 +105,22 @@ class BlazeEstimatesLoader(PipelineLoader):
|
||||
self._checkpoints = checkpoints
|
||||
|
||||
def load_adjusted_array(self, columns, dates, assets, mask):
|
||||
column_names = [column.name for column in columns]
|
||||
# Only load requested columns.
|
||||
requested_column_names = [self._columns[column.name]
|
||||
for column in columns]
|
||||
raw = load_raw_data(
|
||||
assets,
|
||||
dates,
|
||||
self._data_query_time,
|
||||
self._data_query_tz,
|
||||
self._expr[sorted(metadata_columns.union(column_names))],
|
||||
self._expr[sorted(metadata_columns.union(requested_column_names))],
|
||||
self._odo_kwargs,
|
||||
checkpoints=self._checkpoints,
|
||||
)
|
||||
|
||||
return self.loader(
|
||||
raw,
|
||||
{k: self._columns[k] for k in column_names}
|
||||
{column.name: self._columns[column.name] for column in columns}
|
||||
).load_adjusted_array(
|
||||
columns,
|
||||
dates,
|
||||
|
||||
@@ -58,8 +58,8 @@ class BlazeEventsLoader(PipelineLoader):
|
||||
"""
|
||||
|
||||
__doc__ = __doc__.format(SID_FIELD_NAME=SID_FIELD_NAME,
|
||||
TS_FIELD_NAME=TS_FIELD_NAME,
|
||||
EVENT_DATE_FIELD_NAME=EVENT_DATE_FIELD_NAME)
|
||||
TS_FIELD_NAME=TS_FIELD_NAME,
|
||||
EVENT_DATE_FIELD_NAME=EVENT_DATE_FIELD_NAME)
|
||||
|
||||
@preprocess(data_query_tz=optionally(ensure_timezone))
|
||||
def __init__(self,
|
||||
|
||||
@@ -37,7 +37,7 @@ NORMALIZED_QUARTERS = 'normalized_quarters'
|
||||
PREVIOUS_FISCAL_QUARTER = 'previous_fiscal_quarter'
|
||||
PREVIOUS_FISCAL_YEAR = 'previous_fiscal_year'
|
||||
SHIFTED_NORMALIZED_QTRS = 'shifted_normalized_quarters'
|
||||
SIMULTATION_DATES = 'dates'
|
||||
SIMULATION_DATES = 'dates'
|
||||
|
||||
|
||||
def normalize_quarters(years, quarters):
|
||||
@@ -95,7 +95,7 @@ class EarningsEstimatesLoader(PipelineLoader):
|
||||
"""
|
||||
An abstract pipeline loader for estimates data that can load data a
|
||||
variable number of quarters forwards/backwards from calendar dates
|
||||
depending on the `num_quarters` attribute of the columns' dataset.
|
||||
depending on the `num_announcements` attribute of the columns' dataset.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -152,11 +152,11 @@ class EarningsEstimatesLoader(PipelineLoader):
|
||||
self.name_map = name_map
|
||||
|
||||
@abstractmethod
|
||||
def get_zeroth_quarter_idx(self, num_quarters, last, dates):
|
||||
def get_zeroth_quarter_idx(self, num_announcements, last, dates):
|
||||
raise NotImplementedError('get_zeroth_quarter_idx')
|
||||
|
||||
@abstractmethod
|
||||
def get_shifted_qtrs(self, zero_qtrs, num_quarters):
|
||||
def get_shifted_qtrs(self, zero_qtrs, num_announcements):
|
||||
raise NotImplementedError('get_shifted_qtrs')
|
||||
|
||||
@abstractmethod
|
||||
@@ -178,7 +178,7 @@ class EarningsEstimatesLoader(PipelineLoader):
|
||||
zero_qtr_data,
|
||||
zeroth_quarter_idx,
|
||||
stacked_last_per_qtr,
|
||||
num_quarters,
|
||||
num_announcements,
|
||||
dates):
|
||||
"""
|
||||
Selects the requested data for each date.
|
||||
@@ -212,7 +212,7 @@ class EarningsEstimatesLoader(PipelineLoader):
|
||||
zeroth_quarter_idx.get_level_values(
|
||||
NORMALIZED_QUARTERS,
|
||||
),
|
||||
num_quarters,
|
||||
num_announcements,
|
||||
),
|
||||
],
|
||||
names=[
|
||||
@@ -397,18 +397,19 @@ class EarningsEstimatesLoader(PipelineLoader):
|
||||
|
||||
def load_adjusted_array(self, columns, dates, assets, mask):
|
||||
# Separate out getting the columns' datasets and the datasets'
|
||||
# num_quarters attributes to ensure that we're catching the right
|
||||
# num_announcements attributes to ensure that we're catching the right
|
||||
# AttributeError.
|
||||
col_to_datasets = {col: col.dataset for col in columns}
|
||||
try:
|
||||
groups = groupby(lambda col: col_to_datasets[col].num_quarters,
|
||||
groups = groupby(lambda col:
|
||||
col_to_datasets[col].num_announcements,
|
||||
col_to_datasets)
|
||||
except AttributeError:
|
||||
raise AttributeError("Datasets loaded via the "
|
||||
"EarningsEstimatesLoader must define a "
|
||||
"`num_quarters` attribute that defines how "
|
||||
"many quarters out the loader should load "
|
||||
"the data relative to `dates`.")
|
||||
"`num_announcements` attribute that defines "
|
||||
"how many quarters out the loader should load"
|
||||
" the data relative to `dates`.")
|
||||
if any(num_qtr < 0 for num_qtr in groups):
|
||||
raise ValueError(
|
||||
INVALID_NUM_QTRS_MESSAGE % ','.join(
|
||||
@@ -430,12 +431,12 @@ class EarningsEstimatesLoader(PipelineLoader):
|
||||
zeroth_quarter_idx = self.get_zeroth_quarter_idx(stacked_last_per_qtr)
|
||||
zero_qtr_data = stacked_last_per_qtr.loc[zeroth_quarter_idx]
|
||||
|
||||
for num_quarters, columns in groups.items():
|
||||
for num_announcements, columns in groups.items():
|
||||
requested_qtr_data = self.get_requested_quarter_data(
|
||||
zero_qtr_data,
|
||||
zeroth_quarter_idx,
|
||||
stacked_last_per_qtr,
|
||||
num_quarters,
|
||||
num_announcements,
|
||||
dates,
|
||||
)
|
||||
|
||||
@@ -523,7 +524,7 @@ class EarningsEstimatesLoader(PipelineLoader):
|
||||
)
|
||||
# Set date index name for ease of reference
|
||||
stacked_last_per_qtr.index.set_names(
|
||||
SIMULTATION_DATES,
|
||||
SIMULATION_DATES,
|
||||
level=0,
|
||||
inplace=True,
|
||||
)
|
||||
@@ -560,8 +561,8 @@ class NextEarningsEstimatesLoader(EarningsEstimatesLoader):
|
||||
].values[:next_qtr_start_idx],
|
||||
)
|
||||
|
||||
def get_shifted_qtrs(self, zero_qtrs, num_quarters):
|
||||
return zero_qtrs + (num_quarters - 1)
|
||||
def get_shifted_qtrs(self, zero_qtrs, num_announcements):
|
||||
return zero_qtrs + (num_announcements - 1)
|
||||
|
||||
def get_zeroth_quarter_idx(self, stacked_last_per_qtr):
|
||||
"""
|
||||
@@ -584,9 +585,9 @@ class NextEarningsEstimatesLoader(EarningsEstimatesLoader):
|
||||
"""
|
||||
next_releases_per_date = stacked_last_per_qtr.loc[
|
||||
stacked_last_per_qtr[EVENT_DATE_FIELD_NAME] >=
|
||||
stacked_last_per_qtr.index.get_level_values(SIMULTATION_DATES)
|
||||
stacked_last_per_qtr.index.get_level_values(SIMULATION_DATES)
|
||||
].groupby(
|
||||
level=[SIMULTATION_DATES, SID_FIELD_NAME],
|
||||
level=[SIMULATION_DATES, SID_FIELD_NAME],
|
||||
as_index=False,
|
||||
# Here we take advantage of the fact that `stacked_last_per_qtr` is
|
||||
# sorted by event date.
|
||||
@@ -612,8 +613,8 @@ class PreviousEarningsEstimatesLoader(EarningsEstimatesLoader):
|
||||
sid_idx,
|
||||
)
|
||||
|
||||
def get_shifted_qtrs(self, zero_qtrs, num_quarters):
|
||||
return zero_qtrs - (num_quarters - 1)
|
||||
def get_shifted_qtrs(self, zero_qtrs, num_announcements):
|
||||
return zero_qtrs - (num_announcements - 1)
|
||||
|
||||
def get_zeroth_quarter_idx(self, stacked_last_per_qtr):
|
||||
"""
|
||||
@@ -636,9 +637,9 @@ class PreviousEarningsEstimatesLoader(EarningsEstimatesLoader):
|
||||
"""
|
||||
previous_releases_per_date = stacked_last_per_qtr.loc[
|
||||
stacked_last_per_qtr[EVENT_DATE_FIELD_NAME] <=
|
||||
stacked_last_per_qtr.index.get_level_values(SIMULTATION_DATES)
|
||||
stacked_last_per_qtr.index.get_level_values(SIMULATION_DATES)
|
||||
].groupby(
|
||||
level=[SIMULTATION_DATES, SID_FIELD_NAME],
|
||||
level=[SIMULATION_DATES, SID_FIELD_NAME],
|
||||
as_index=False,
|
||||
# Here we take advantage of the fact that `stacked_last_per_qtr` is
|
||||
# sorted by event date.
|
||||
|
||||
@@ -49,14 +49,8 @@ from zipline.pipeline.loaders.testing import make_seeded_random_loader
|
||||
from zipline.utils import security_list
|
||||
from zipline.utils.calendars import get_calendar
|
||||
from zipline.utils.input_validation import expect_dimensions
|
||||
<<<<<<< HEAD
|
||||
from zipline.utils.numpy_utils import as_column, isnat
|
||||
from zipline.utils.pandas_utils import timedelta_to_integral_seconds
|
||||
=======
|
||||
from zipline.utils.numpy_utils import (
|
||||
as_column,
|
||||
)
|
||||
>>>>>>> WIP
|
||||
from zipline.utils.sentinel import sentinel
|
||||
|
||||
import numpy as np
|
||||
|
||||
Reference in New Issue
Block a user