mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-28 14:47:08 +08:00
ENH: storing commits. test case added
This commit is contained in:
@@ -18,6 +18,7 @@ from numpy import (
|
||||
ones,
|
||||
ones_like,
|
||||
putmask,
|
||||
sum
|
||||
)
|
||||
from numpy.random import randn, seed as random_seed
|
||||
|
||||
@@ -379,3 +380,21 @@ class FilterTestCase(BasePipelineTestCase):
|
||||
initial_workspace={self.f: data},
|
||||
)
|
||||
check_arrays(results['isfinite'], isfinite(data))
|
||||
|
||||
def test_window_safe(self):
|
||||
"""
|
||||
Rationale : CustomFactors are *not all* window safe, while
|
||||
CustomFilters are
|
||||
"""
|
||||
data = array([[True, False, False, True],
|
||||
[False, True, False, True]], dtype=bool)
|
||||
|
||||
class CustomFilterBool(Filter):
|
||||
window_length = 0
|
||||
inputs = ()
|
||||
|
||||
filter_results = self.run_graph(
|
||||
TermGraph({'factor': sum(CustomFilterBool, axis=0)}),
|
||||
initial_workspace={CustomFilterBool: data},
|
||||
)
|
||||
check_arrays(filter_results['factor'], array([1, 1, 0, 2]))
|
||||
|
||||
@@ -170,7 +170,7 @@ class Filter(RestrictedDTypeMixin, ComputableTerm):
|
||||
|
||||
# make filters window safe
|
||||
window_safe = True
|
||||
|
||||
|
||||
ALLOWED_DTYPES = (bool_dtype,) # Used by RestrictedDTypeMixin
|
||||
dtype = bool_dtype
|
||||
|
||||
|
||||
Reference in New Issue
Block a user