mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-04 13:38:20 +08:00
TEST: Add more populate_initial_workspace tests.
- Tests different pipeline lengths and window lengths. - Tests a term that depends on a window of a term that's been precomputed.
This commit is contained in:
committed by
Joe Jevnik
parent
7f40f7a99d
commit
0f57dac4ab
@@ -72,6 +72,7 @@ from zipline.testing import (
|
||||
make_alternating_boolean_array,
|
||||
make_cascading_boolean_array,
|
||||
OpenPrice,
|
||||
parameter_space,
|
||||
product_upper_triangle,
|
||||
)
|
||||
from zipline.testing.fixtures import (
|
||||
@@ -1322,20 +1323,85 @@ class StringColumnTestCase(WithSeededRandomPipelineEngine,
|
||||
assert_frame_equal(result.c.unstack(), expected_final_result)
|
||||
|
||||
|
||||
class WindowSafetyPropagationTestCase(WithSeededRandomPipelineEngine,
|
||||
ZiplineTestCase):
|
||||
|
||||
SEEDED_RANDOM_PIPELINE_SEED = 5
|
||||
|
||||
def test_window_safety_propagation(self):
|
||||
dates = self.trading_days[-30:]
|
||||
start_date, end_date = dates[[-10, -1]]
|
||||
|
||||
col = TestingDataSet.float_col
|
||||
pipe = Pipeline(
|
||||
columns={
|
||||
'average_of_rank_plus_one': SimpleMovingAverage(
|
||||
inputs=[col.latest.rank() + 1],
|
||||
window_length=10,
|
||||
),
|
||||
'average_of_aliased_rank_plus_one': SimpleMovingAverage(
|
||||
inputs=[col.latest.rank().alias('some_alias') + 1],
|
||||
window_length=10,
|
||||
),
|
||||
'average_of_rank_plus_one_aliased': SimpleMovingAverage(
|
||||
inputs=[(col.latest.rank() + 1).alias('some_alias')],
|
||||
window_length=10,
|
||||
),
|
||||
}
|
||||
)
|
||||
results = self.run_pipeline(pipe, start_date, end_date).unstack()
|
||||
|
||||
expected_ranks = DataFrame(
|
||||
self.raw_expected_values(
|
||||
col,
|
||||
dates[-19],
|
||||
dates[-1],
|
||||
),
|
||||
index=dates[-19:],
|
||||
columns=self.asset_finder.retrieve_all(
|
||||
self.ASSET_FINDER_EQUITY_SIDS,
|
||||
)
|
||||
).rank(axis='columns')
|
||||
|
||||
# All three expressions should be equivalent and evaluate to this.
|
||||
expected_result = (
|
||||
(expected_ranks + 1)
|
||||
.rolling(10)
|
||||
.mean()
|
||||
.dropna(how='any')
|
||||
)
|
||||
|
||||
for colname in results.columns.levels[0]:
|
||||
assert_equal(expected_result, results[colname])
|
||||
|
||||
|
||||
class PopulateInitialWorkspaceTestCase(WithConstantInputs, ZiplineTestCase):
|
||||
def test_populate_default_workspace(self):
|
||||
window_length = 5
|
||||
|
||||
@parameter_space(window_length=[3, 5], pipeline_length=[5, 10])
|
||||
def test_populate_initial_workspace(self, window_length, pipeline_length):
|
||||
column = USEquityPricing.low
|
||||
base_term = column.latest
|
||||
precomputed_term = (base_term + 1).alias('precomputed_term')
|
||||
|
||||
# Take a Z-Score here so that the precomputed term is window-safe. The
|
||||
# z-score will never actually get computed because we swap it out.
|
||||
precomputed_term = (base_term.zscore()).alias('precomputed_term')
|
||||
|
||||
# A term that has `precomputed_term` as an input.
|
||||
depends_on_precomputed_term = precomputed_term + 1
|
||||
# A term that requires a window of `precomputed_term`.
|
||||
depends_on_window_of_precomputed_term = SimpleMovingAverage(
|
||||
inputs=[precomputed_term],
|
||||
window_length=window_length,
|
||||
)
|
||||
|
||||
precomputed_term_with_window = SimpleMovingAverage(
|
||||
inputs=(column,),
|
||||
window_length=window_length,
|
||||
).alias('precomputed_term_with_window')
|
||||
depends_on_precomputed_term = precomputed_term + 1
|
||||
depends_on_precomputed_term_with_window = (
|
||||
precomputed_term_with_window + 1
|
||||
)
|
||||
|
||||
column_value = self.constants[column]
|
||||
precomputed_term_value = -column_value
|
||||
precomputed_term_with_window_value = -(column_value + 1)
|
||||
@@ -1345,30 +1411,24 @@ class PopulateInitialWorkspaceTestCase(WithConstantInputs, ZiplineTestCase):
|
||||
execution_plan,
|
||||
dates,
|
||||
assets):
|
||||
def shape_for_term(term):
|
||||
ndates = len(execution_plan.mask_and_dates_for_term(
|
||||
term,
|
||||
root_mask_term,
|
||||
initial_workspace,
|
||||
dates,
|
||||
)[1])
|
||||
nassets = len(assets)
|
||||
return (ndates, nassets)
|
||||
|
||||
ws = initial_workspace.copy()
|
||||
_, precomputed_term_dates = execution_plan.mask_and_dates_for_term(
|
||||
precomputed_term,
|
||||
root_mask_term,
|
||||
initial_workspace,
|
||||
dates,
|
||||
)
|
||||
ws[precomputed_term] = full(
|
||||
(len(precomputed_term_dates), len(assets)),
|
||||
shape_for_term(precomputed_term),
|
||||
precomputed_term_value,
|
||||
dtype=float64,
|
||||
)
|
||||
(
|
||||
_,
|
||||
precomputed_term_with_window_dates,
|
||||
) = execution_plan.mask_and_dates_for_term(
|
||||
precomputed_term,
|
||||
root_mask_term,
|
||||
initial_workspace,
|
||||
dates,
|
||||
)
|
||||
|
||||
ws[precomputed_term_with_window] = full(
|
||||
(len(precomputed_term_with_window_dates), len(assets)),
|
||||
shape_for_term(precomputed_term_with_window),
|
||||
precomputed_term_with_window_value,
|
||||
dtype=float64,
|
||||
)
|
||||
@@ -1395,8 +1455,10 @@ class PopulateInitialWorkspaceTestCase(WithConstantInputs, ZiplineTestCase):
|
||||
'depends_on_precomputed_term': depends_on_precomputed_term,
|
||||
'depends_on_precomputed_term_with_window':
|
||||
depends_on_precomputed_term_with_window,
|
||||
'depends_on_window_of_precomputed_term':
|
||||
depends_on_window_of_precomputed_term,
|
||||
}),
|
||||
self.dates[window_length - 1],
|
||||
self.dates[-pipeline_length],
|
||||
self.dates[-1],
|
||||
)
|
||||
|
||||
@@ -1428,3 +1490,11 @@ class PopulateInitialWorkspaceTestCase(WithConstantInputs, ZiplineTestCase):
|
||||
precomputed_term_with_window_value + 1,
|
||||
),
|
||||
)
|
||||
assert_equal(
|
||||
results['depends_on_window_of_precomputed_term'].values,
|
||||
full_like(
|
||||
results['depends_on_window_of_precomputed_term'],
|
||||
precomputed_term_value,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -255,6 +255,7 @@ class AliasedMixin(SingleInputMixin):
|
||||
dtype=term.dtype,
|
||||
missing_value=term.missing_value,
|
||||
ndim=term.ndim,
|
||||
window_safe=term.window_safe,
|
||||
)
|
||||
|
||||
def _init(self, name, *args, **kwargs):
|
||||
|
||||
Reference in New Issue
Block a user