diff --git a/tests/pipeline/test_filter.py b/tests/pipeline/test_filter.py index ccf6211f..1c14ded2 100644 --- a/tests/pipeline/test_filter.py +++ b/tests/pipeline/test_filter.py @@ -395,6 +395,36 @@ class FilterTestCase(BasePipelineTestCase): ) check_arrays(results['isfinite'], isfinite(data)) + def test_smoothing_filter(self): + from zipline.pipeline.filters import SmoothingFilter + + data = full(self.default_shape, True, dtype=bool) + # one column all false + data[0, 0] = False + data[1, 1] = False + + class InputFilter(Filter): + inputs = () + window_length = 0 + + smoothing_filter = SmoothingFilter( + inputs=[InputFilter()], + window_length=self.default_shape[0] + ) + + results = self.run_graph( + TermGraph({'smoothing': smoothing_filter}), + initial_workspace={InputFilter(): data} + ) + + expected_result = full(self.default_shape[1], True, dtype=bool) + expected_result[0] = False + expected_result[1] = False + check_arrays( + results['smoothing'].flatten(), + expected_result, + ) + @parameter_space(factor_len=[2, 3, 4]) def test_window_safe(self, factor_len): # all true data set of (days, securities) diff --git a/zipline/pipeline/filters/__init__.py b/zipline/pipeline/filters/__init__.py index b1c06ec4..0f0bb52f 100644 --- a/zipline/pipeline/filters/__init__.py +++ b/zipline/pipeline/filters/__init__.py @@ -8,6 +8,7 @@ from .filter import ( NumExprFilter, PercentileFilter, SingleAsset, + SmoothingFilter, ) __all__ = [ @@ -20,4 +21,5 @@ __all__ = [ 'NumExprFilter', 'PercentileFilter', 'SingleAsset', + 'SmoothingFilter', ] diff --git a/zipline/pipeline/filters/filter.py b/zipline/pipeline/filters/filter.py index 502fd031..3cf8ae10 100644 --- a/zipline/pipeline/filters/filter.py +++ b/zipline/pipeline/filters/filter.py @@ -8,6 +8,7 @@ from numpy import ( float64, nan, nanpercentile, + sum as sum_ ) from zipline.errors import ( @@ -488,3 +489,18 @@ class SingleAsset(Filter): asset=self._asset, start_date=dates[0], end_date=dates[-1], ) return out + + +class SmoothingFilter(CustomFilter): + """ + A Filter that requires its inputs to have + been True for the last `window_length` days. + An integral part of the Q500US methodology + + **Default Inputs**: None + + **Default Window Length**: None + """ + + def compute(self, today, assets, out, arg): + out[:] = (sum_(arg, axis=0) == self.window_length)