mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-03 17:13:37 +08:00
ENH: added smoothing to zipline
This commit is contained in:
@@ -395,6 +395,36 @@ class FilterTestCase(BasePipelineTestCase):
|
||||
)
|
||||
check_arrays(results['isfinite'], isfinite(data))
|
||||
|
||||
def test_smoothing_filter(self):
|
||||
from zipline.pipeline.filters import SmoothingFilter
|
||||
|
||||
data = full(self.default_shape, True, dtype=bool)
|
||||
# one column all false
|
||||
data[0, 0] = False
|
||||
data[1, 1] = False
|
||||
|
||||
class InputFilter(Filter):
|
||||
inputs = ()
|
||||
window_length = 0
|
||||
|
||||
smoothing_filter = SmoothingFilter(
|
||||
inputs=[InputFilter()],
|
||||
window_length=self.default_shape[0]
|
||||
)
|
||||
|
||||
results = self.run_graph(
|
||||
TermGraph({'smoothing': smoothing_filter}),
|
||||
initial_workspace={InputFilter(): data}
|
||||
)
|
||||
|
||||
expected_result = full(self.default_shape[1], True, dtype=bool)
|
||||
expected_result[0] = False
|
||||
expected_result[1] = False
|
||||
check_arrays(
|
||||
results['smoothing'].flatten(),
|
||||
expected_result,
|
||||
)
|
||||
|
||||
@parameter_space(factor_len=[2, 3, 4])
|
||||
def test_window_safe(self, factor_len):
|
||||
# all true data set of (days, securities)
|
||||
|
||||
@@ -8,6 +8,7 @@ from .filter import (
|
||||
NumExprFilter,
|
||||
PercentileFilter,
|
||||
SingleAsset,
|
||||
SmoothingFilter,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@@ -20,4 +21,5 @@ __all__ = [
|
||||
'NumExprFilter',
|
||||
'PercentileFilter',
|
||||
'SingleAsset',
|
||||
'SmoothingFilter',
|
||||
]
|
||||
|
||||
@@ -8,6 +8,7 @@ from numpy import (
|
||||
float64,
|
||||
nan,
|
||||
nanpercentile,
|
||||
sum as sum_
|
||||
)
|
||||
|
||||
from zipline.errors import (
|
||||
@@ -488,3 +489,18 @@ class SingleAsset(Filter):
|
||||
asset=self._asset, start_date=dates[0], end_date=dates[-1],
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
class SmoothingFilter(CustomFilter):
|
||||
"""
|
||||
A Filter that requires its inputs to have
|
||||
been True for the last `window_length` days.
|
||||
An integral part of the Q500US methodology
|
||||
|
||||
**Default Inputs**: None
|
||||
|
||||
**Default Window Length**: None
|
||||
"""
|
||||
|
||||
def compute(self, today, assets, out, arg):
|
||||
out[:] = (sum_(arg, axis=0) == self.window_length)
|
||||
|
||||
Reference in New Issue
Block a user