mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-03 20:03:41 +08:00
ENH: Rename StrictlyTrue to All and add Any().
Also, moved All() and Any() to `zipline.pipeline.filters.smoothing`.
This commit is contained in:
@@ -26,8 +26,10 @@ Enhancements
|
||||
:meth:`~zipline.pipeline.factors.Factor.top`, and
|
||||
:meth:`~zipline.pipeline.factors.Factor.bottom`. (:issue:`1349`).
|
||||
|
||||
- Added a smoothing filter that adds 'stickiness' to its input,
|
||||
making boolean designations less volatile over time. (:issue:`1358`)
|
||||
- Added new pipeline filters, :class:`~zipline.pipeline.filters.All` and
|
||||
:class:`~zipline.pipeline.filters.Any`, which takes another filter and
|
||||
returns True if an asset produced a True for any/all days in the previous
|
||||
``window_length`` days (:issue:`1358`).
|
||||
|
||||
Bug Fixes
|
||||
~~~~~~~~~
|
||||
|
||||
@@ -21,7 +21,6 @@ from numpy import (
|
||||
ones,
|
||||
ones_like,
|
||||
putmask,
|
||||
reshape,
|
||||
rot90,
|
||||
sum as np_sum
|
||||
)
|
||||
@@ -31,6 +30,7 @@ from zipline.errors import BadPercentileBounds
|
||||
from zipline.pipeline import Filter, Factor, TermGraph
|
||||
from zipline.pipeline.classifiers import Classifier
|
||||
from zipline.pipeline.factors import CustomFactor
|
||||
from zipline.pipeline.filters import All, Any
|
||||
from zipline.testing import check_arrays, parameter_space, permute_rows
|
||||
from zipline.utils.numpy_utils import float64_dtype, int64_dtype
|
||||
from .base import BasePipelineTestCase, with_default_shape
|
||||
@@ -396,41 +396,107 @@ class FilterTestCase(BasePipelineTestCase):
|
||||
)
|
||||
check_arrays(results['isfinite'], isfinite(data))
|
||||
|
||||
def test_strictly_true_filter(self):
|
||||
from zipline.pipeline.filters import StrictlyTrueFilter
|
||||
def test_all(self):
|
||||
|
||||
data = ~eye(N=self.default_shape[0],
|
||||
M=self.default_shape[1],
|
||||
k=1,
|
||||
dtype=bool)
|
||||
data = array([[1, 1, 1, 1, 1, 1],
|
||||
[0, 1, 1, 1, 1, 1],
|
||||
[1, 0, 1, 1, 1, 1],
|
||||
[1, 1, 0, 1, 1, 1],
|
||||
[1, 1, 1, 0, 1, 1],
|
||||
[1, 1, 1, 1, 0, 1],
|
||||
[1, 1, 1, 1, 1, 0]], dtype=bool)
|
||||
|
||||
class InputFilter(Filter):
|
||||
# With a window_length of N, 0's should be "sticky" for the (N - 1)
|
||||
# days after the 0 in the base data.
|
||||
|
||||
# Note that, the way ``self.run_graph`` works, we compute the same
|
||||
# number of output rows for all inputs, so we only get the last 4
|
||||
# outputs for expected_3 even though we have enought input data to
|
||||
# compute 5 rows.
|
||||
expected_3 = array([[0, 0, 0, 1, 1, 1],
|
||||
[1, 0, 0, 0, 1, 1],
|
||||
[1, 1, 0, 0, 0, 1],
|
||||
[1, 1, 1, 0, 0, 0]], dtype=bool)
|
||||
|
||||
expected_4 = array([[0, 0, 0, 1, 1, 1],
|
||||
[0, 0, 0, 0, 1, 1],
|
||||
[1, 0, 0, 0, 0, 1],
|
||||
[1, 1, 0, 0, 0, 0]], dtype=bool)
|
||||
|
||||
class Input(Filter):
|
||||
inputs = ()
|
||||
window_length = 0
|
||||
|
||||
strictly_true_filter = StrictlyTrueFilter(
|
||||
inputs=(InputFilter(), ),
|
||||
window_length=(self.default_shape[0]-1)
|
||||
results = self.run_graph(
|
||||
TermGraph({
|
||||
'3': All(inputs=[Input()], window_length=3),
|
||||
'4': All(inputs=[Input()], window_length=4),
|
||||
}),
|
||||
initial_workspace={Input(): data},
|
||||
mask=self.build_mask(ones(shape=data.shape)),
|
||||
)
|
||||
|
||||
check_arrays(results['3'], expected_3)
|
||||
check_arrays(results['4'], expected_4)
|
||||
|
||||
def test_any(self):
|
||||
|
||||
# FUN FACT: The inputs and outputs here are exactly the negation of
|
||||
# the inputs and outputs for test_all above. This isn't a coincidence.
|
||||
#
|
||||
# By de Morgan's Laws, we have::
|
||||
#
|
||||
# ~(a & b) == (~a | ~b)
|
||||
#
|
||||
# negating both sides, we have::
|
||||
#
|
||||
# (a & b) == ~(a | ~b)
|
||||
#
|
||||
# Since all(a, b) is isomorphic to (a & b), and any(a, b) is isomorphic
|
||||
# to (a | b), we have::
|
||||
#
|
||||
# all(a, b) == ~(any(~a, ~b))
|
||||
#
|
||||
data = array([[0, 0, 0, 0, 0, 0],
|
||||
[1, 0, 0, 0, 0, 0],
|
||||
[0, 1, 0, 0, 0, 0],
|
||||
[0, 0, 1, 0, 0, 0],
|
||||
[0, 0, 0, 1, 0, 0],
|
||||
[0, 0, 0, 0, 1, 0],
|
||||
[0, 0, 0, 0, 0, 1]], dtype=bool)
|
||||
|
||||
# With a window_length of N, 1's should be "sticky" for the (N - 1)
|
||||
# days after the 1 in the base data.
|
||||
|
||||
# Note that, the way ``self.run_graph`` works, we compute the same
|
||||
# number of output rows for all inputs, so we only get the last 4
|
||||
# outputs for expected_3 even though we have enought input data to
|
||||
# compute 5 rows.
|
||||
expected_3 = array([[1, 1, 1, 0, 0, 0],
|
||||
[0, 1, 1, 1, 0, 0],
|
||||
[0, 0, 1, 1, 1, 0],
|
||||
[0, 0, 0, 1, 1, 1]], dtype=bool)
|
||||
|
||||
expected_4 = array([[1, 1, 1, 0, 0, 0],
|
||||
[1, 1, 1, 1, 0, 0],
|
||||
[0, 1, 1, 1, 1, 0],
|
||||
[0, 0, 1, 1, 1, 1]], dtype=bool)
|
||||
|
||||
class Input(Filter):
|
||||
inputs = ()
|
||||
window_length = 0
|
||||
|
||||
results = self.run_graph(
|
||||
TermGraph({'Filter': strictly_true_filter}),
|
||||
initial_workspace={InputFilter(): data}
|
||||
TermGraph({
|
||||
'3': Any(inputs=[Input()], window_length=3),
|
||||
'4': Any(inputs=[Input()], window_length=4),
|
||||
}),
|
||||
initial_workspace={Input(): data},
|
||||
mask=self.build_mask(ones(shape=data.shape)),
|
||||
)
|
||||
|
||||
expected_result_0 = full(self.default_shape[1], False, dtype=bool)
|
||||
expected_result_0[0] = True
|
||||
check_arrays(
|
||||
reshape(results['Filter'][0], expected_result_0.shape[0]),
|
||||
expected_result_0,
|
||||
)
|
||||
|
||||
expected_result_1 = full(self.default_shape[1], False, dtype=bool)
|
||||
expected_result_1[:2] = True
|
||||
check_arrays(
|
||||
reshape(results['Filter'][1], expected_result_1.shape[0]),
|
||||
expected_result_1,
|
||||
)
|
||||
check_arrays(results['3'], expected_3)
|
||||
check_arrays(results['4'], expected_4)
|
||||
|
||||
@parameter_space(factor_len=[2, 3, 4])
|
||||
def test_window_safe(self, factor_len):
|
||||
|
||||
@@ -8,10 +8,12 @@ from .filter import (
|
||||
NumExprFilter,
|
||||
PercentileFilter,
|
||||
SingleAsset,
|
||||
StrictlyTrueFilter,
|
||||
)
|
||||
from .smoothing import All, Any
|
||||
|
||||
__all__ = [
|
||||
'All',
|
||||
'Any',
|
||||
'ArrayPredicate',
|
||||
'CustomFilter',
|
||||
'Filter',
|
||||
@@ -21,5 +23,4 @@ __all__ = [
|
||||
'NumExprFilter',
|
||||
'PercentileFilter',
|
||||
'SingleAsset',
|
||||
'StrictlyTrueFilter',
|
||||
]
|
||||
|
||||
@@ -488,16 +488,3 @@ class SingleAsset(Filter):
|
||||
asset=self._asset, start_date=dates[0], end_date=dates[-1],
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
class StrictlyTrueFilter(CustomFilter):
|
||||
"""
|
||||
A Filter that requires its inputs to have
|
||||
been True for the last `window_length` days.
|
||||
|
||||
**Default Inputs**: None
|
||||
|
||||
**Default Window Length**: None
|
||||
"""
|
||||
def compute(self, today, assets, out, arg):
|
||||
out[:] = (arg.sum(axis=0) == self.window_length)
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
"""
|
||||
Filters that apply smoothing operations on other filters.
|
||||
|
||||
These are generally useful for controlling/minimizing turnover on existing
|
||||
Filters.
|
||||
"""
|
||||
from .filter import CustomFilter
|
||||
|
||||
|
||||
class All(CustomFilter):
|
||||
"""
|
||||
A Filter requiring that assets produce True for ``window_length``
|
||||
consecutive days.
|
||||
|
||||
**Default Inputs:** None
|
||||
|
||||
**Default Window Length:** None
|
||||
"""
|
||||
|
||||
def compute(self, today, assets, out, arg):
|
||||
out[:] = (arg.sum(axis=0) == self.window_length)
|
||||
|
||||
|
||||
class Any(CustomFilter):
|
||||
"""
|
||||
A Filter requiring that assets produce True for at least one day in the
|
||||
last ``window_length`` days.
|
||||
|
||||
**Default Inputs:** None
|
||||
|
||||
**Default Window Length:** None
|
||||
"""
|
||||
|
||||
def compute(self, today, assets, out, arg):
|
||||
out[:] = (arg.sum(axis=0) > 0)
|
||||
Reference in New Issue
Block a user