mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-28 09:59:22 +08:00
TEST: Parameterize over window_length.
This commit is contained in:
@@ -26,7 +26,7 @@ from numpy.random import randn, seed as random_seed
|
||||
from zipline.errors import BadPercentileBounds
|
||||
from zipline.pipeline import Filter, Factor, TermGraph
|
||||
from zipline.pipeline.factors import CustomFactor
|
||||
from zipline.testing import check_arrays
|
||||
from zipline.testing import check_arrays, parameter_space
|
||||
from zipline.utils.numpy_utils import float64_dtype
|
||||
from .base import BasePipelineTestCase, with_default_shape
|
||||
|
||||
@@ -383,13 +383,11 @@ class FilterTestCase(BasePipelineTestCase):
|
||||
)
|
||||
check_arrays(results['isfinite'], isfinite(data))
|
||||
|
||||
def test_window_safe(self):
|
||||
@parameter_space(factor_len=[2, 3, 4])
|
||||
def test_window_safe(self, factor_len):
|
||||
# all true data set of (days, securities)
|
||||
data = full(self.default_shape, True, dtype=bool)
|
||||
|
||||
# rolling window length for TestFactor
|
||||
k = 8
|
||||
|
||||
class InputFilter(Filter):
|
||||
inputs = ()
|
||||
window_length = 0
|
||||
@@ -397,7 +395,7 @@ class FilterTestCase(BasePipelineTestCase):
|
||||
class TestFactor(CustomFactor):
|
||||
dtype = float64_dtype
|
||||
inputs = (InputFilter(), )
|
||||
window_length = k
|
||||
window_length = factor_len
|
||||
|
||||
def compute(self, today, assets, out, filter_):
|
||||
# sum for each column
|
||||
@@ -412,5 +410,8 @@ class FilterTestCase(BasePipelineTestCase):
|
||||
n = self.default_shape[0]
|
||||
|
||||
# shape of output array
|
||||
output_shape = ((n - k + 1), self.default_shape[1])
|
||||
check_arrays(results['windowsafe'], full(output_shape, k))
|
||||
output_shape = ((n - factor_len + 1), self.default_shape[1])
|
||||
check_arrays(
|
||||
results['windowsafe'],
|
||||
full(output_shape, factor_len, dtype=float64)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user