mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-30 11:42:40 +08:00
Merge pull request #1367 from quantopian/smoothing-at-least
ENH: Added AtLeastN filter
This commit is contained in:
@@ -31,6 +31,11 @@ Enhancements
|
||||
returns True if an asset produced a True for any/all days in the previous
|
||||
``window_length`` days (:issue:`1358`).
|
||||
|
||||
- Added new pipeline filter :class:`~zipline.pipeline.filters.AtLeastN`,
|
||||
which takes another filter and an int N and returns True if an asset
|
||||
produced a True on N or more days in the previous ``window_length``
|
||||
days (:issue:`1367`).
|
||||
|
||||
Bug Fixes
|
||||
~~~~~~~~~
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ from zipline.errors import BadPercentileBounds
|
||||
from zipline.pipeline import Filter, Factor, TermGraph
|
||||
from zipline.pipeline.classifiers import Classifier
|
||||
from zipline.pipeline.factors import CustomFactor
|
||||
from zipline.pipeline.filters import All, Any
|
||||
from zipline.pipeline.filters import All, Any, AtLeastN
|
||||
from zipline.testing import check_arrays, parameter_space, permute_rows
|
||||
from zipline.utils.numpy_utils import float64_dtype, int64_dtype
|
||||
from .base import BasePipelineTestCase, with_default_shape
|
||||
@@ -498,6 +498,81 @@ class FilterTestCase(BasePipelineTestCase):
|
||||
check_arrays(results['3'], expected_3)
|
||||
check_arrays(results['4'], expected_4)
|
||||
|
||||
def test_at_least_N(self):
|
||||
|
||||
# With a window_length of K, AtLeastN should return 1
|
||||
# if N or more 1's exist in the lookback window
|
||||
|
||||
# This smoothing filter gives customizable "stickiness"
|
||||
|
||||
data = array([[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 0],
|
||||
[1, 1, 1, 1, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0],
|
||||
[1, 1, 0, 0, 0, 0],
|
||||
[1, 0, 0, 0, 0, 0]], dtype=bool)
|
||||
|
||||
expected_1 = array([[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 0],
|
||||
[1, 1, 1, 1, 0, 0]], dtype=bool)
|
||||
|
||||
expected_2 = array([[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 0],
|
||||
[1, 1, 1, 1, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0]], dtype=bool)
|
||||
|
||||
expected_3 = array([[1, 1, 1, 1, 1, 0],
|
||||
[1, 1, 1, 1, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0],
|
||||
[1, 1, 0, 0, 0, 0]], dtype=bool)
|
||||
|
||||
expected_4 = array([[1, 1, 1, 1, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0],
|
||||
[1, 1, 0, 0, 0, 0],
|
||||
[1, 0, 0, 0, 0, 0]], dtype=bool)
|
||||
|
||||
class Input(Filter):
|
||||
inputs = ()
|
||||
window_length = 0
|
||||
|
||||
all_but_one = AtLeastN(inputs=[Input()],
|
||||
window_length=4,
|
||||
N=3)
|
||||
|
||||
all_but_two = AtLeastN(inputs=[Input()],
|
||||
window_length=4,
|
||||
N=2)
|
||||
|
||||
any_equiv = AtLeastN(inputs=[Input()],
|
||||
window_length=4,
|
||||
N=1)
|
||||
|
||||
all_equiv = AtLeastN(inputs=[Input()],
|
||||
window_length=4,
|
||||
N=4)
|
||||
|
||||
results = self.run_graph(
|
||||
TermGraph({
|
||||
'AllButOne': all_but_one,
|
||||
'AllButTwo': all_but_two,
|
||||
'AnyEquiv': any_equiv,
|
||||
'AllEquiv': all_equiv,
|
||||
'Any': Any(inputs=[Input()], window_length=4),
|
||||
'All': All(inputs=[Input()], window_length=4)
|
||||
}),
|
||||
initial_workspace={Input(): data},
|
||||
mask=self.build_mask(ones(shape=data.shape)),
|
||||
)
|
||||
|
||||
check_arrays(results['Any'], expected_1)
|
||||
check_arrays(results['AnyEquiv'], expected_1)
|
||||
check_arrays(results['AllButTwo'], expected_2)
|
||||
check_arrays(results['AllButOne'], expected_3)
|
||||
check_arrays(results['All'], expected_4)
|
||||
check_arrays(results['AllEquiv'], expected_4)
|
||||
|
||||
@parameter_space(factor_len=[2, 3, 4])
|
||||
def test_window_safe(self, factor_len):
|
||||
# all true data set of (days, securities)
|
||||
|
||||
@@ -9,12 +9,13 @@ from .filter import (
|
||||
PercentileFilter,
|
||||
SingleAsset,
|
||||
)
|
||||
from .smoothing import All, Any
|
||||
from .smoothing import All, Any, AtLeastN
|
||||
|
||||
__all__ = [
|
||||
'All',
|
||||
'Any',
|
||||
'ArrayPredicate',
|
||||
'AtLeastN',
|
||||
'CustomFilter',
|
||||
'Filter',
|
||||
'Latest',
|
||||
|
||||
@@ -33,3 +33,19 @@ class Any(CustomFilter):
|
||||
|
||||
def compute(self, today, assets, out, arg):
|
||||
out[:] = (arg.sum(axis=0) > 0)
|
||||
|
||||
|
||||
class AtLeastN(CustomFilter):
|
||||
"""
|
||||
A Filter requiring that assets produce True for at least N days in the
|
||||
last ``window_length`` days.
|
||||
|
||||
**Default Inputs:** None
|
||||
|
||||
**Default Window Length:** None
|
||||
"""
|
||||
|
||||
params = ('N',)
|
||||
|
||||
def compute(self, today, assets, out, arg, N):
|
||||
out[:] = (arg.sum(axis=0) >= N)
|
||||
|
||||
Reference in New Issue
Block a user