mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-28 01:14:26 +08:00
Merge pull request #861 from quantopian/pipeline-isnan
ENH: Add isnan, notnan, and isfinite Factor methods.
This commit is contained in:
@@ -11,6 +11,8 @@ from numpy import (
|
||||
eye,
|
||||
float64,
|
||||
full_like,
|
||||
inf,
|
||||
isfinite,
|
||||
nan,
|
||||
nanpercentile,
|
||||
ones,
|
||||
@@ -333,3 +335,37 @@ class FilterTestCase(BasePipelineTestCase):
|
||||
dtype=bool,
|
||||
)
|
||||
check_arrays(results['with'], expected_with)
|
||||
|
||||
def test_isnan(self):
    """
    Check that ``self.f.isnan()`` is True exactly where the input is NaN.

    NaN is written onto the main diagonal of the random input, so the
    expected output is the same boolean diagonal mask.
    """
    values = self.randn_data(seed=10)
    # Boolean mask that is True on the main diagonal only.
    nan_locs = eye(*values.shape, dtype=bool)
    values[nan_locs] = nan

    graph = TermGraph({'isnan': self.f.isnan()})
    results = self.run_graph(graph, initial_workspace={self.f: values})

    check_arrays(results['isnan'], nan_locs)
|
||||
|
||||
def test_notnan(self):
    """
    Check that ``self.f.notnan()`` is False exactly where the input is NaN.

    NaN is written onto the main diagonal of the random input, so the
    expected output is the inverse of the diagonal mask.
    """
    values = self.randn_data(seed=10)
    # Boolean mask that is True on the main diagonal only.
    nan_locs = eye(*values.shape, dtype=bool)
    values[nan_locs] = nan

    graph = TermGraph({'notnan': self.f.notnan()})
    results = self.run_graph(graph, initial_workspace={self.f: values})

    check_arrays(results['notnan'], ~nan_locs)
|
||||
|
||||
def test_isfinite(self):
    """
    Check that ``self.f.isfinite()`` agrees with ``numpy.isfinite``.

    Columns 0, 2 and 4 of the random input are overwritten with NaN,
    +inf and -inf respectively; the other columns are left untouched.
    """
    values = self.randn_data(seed=10)
    for column, non_finite in ((0, nan), (2, inf), (4, -inf)):
        values[:, column] = non_finite

    graph = TermGraph({'isfinite': self.f.isfinite()})
    results = self.run_graph(graph, initial_workspace={self.f: values})

    # numpy's isfinite on the raw input is the reference answer.
    check_arrays(results['isfinite'], isfinite(values))
|
||||
|
||||
@@ -10,6 +10,7 @@ from numexpr.necompiler import getExprNames
|
||||
from numpy import (
|
||||
empty,
|
||||
find_common_type,
|
||||
inf,
|
||||
)
|
||||
|
||||
from zipline.pipeline.term import Term, NotSpecified, CompositeTerm
|
||||
@@ -212,6 +213,8 @@ class NumericalExpression(CompositeTerm):
|
||||
variable_names, _unused = getExprNames(self._expr, {})
|
||||
expr_indices = []
|
||||
for name in variable_names:
|
||||
if name == 'inf':
|
||||
continue
|
||||
match = _VARIABLE_NAME_RE.match(name)
|
||||
if not match:
|
||||
raise ValueError("%r is not a valid variable name" % name)
|
||||
@@ -239,7 +242,7 @@ class NumericalExpression(CompositeTerm):
|
||||
"x_%d" % idx: array
|
||||
for idx, array in enumerate(arrays)
|
||||
},
|
||||
global_dict={},
|
||||
global_dict={'inf': inf},
|
||||
out=out,
|
||||
)
|
||||
return out
|
||||
|
||||
@@ -8,6 +8,7 @@ from numpy import (
|
||||
apply_along_axis,
|
||||
float64,
|
||||
nan,
|
||||
inf,
|
||||
)
|
||||
from scipy.stats import rankdata
|
||||
|
||||
@@ -343,6 +344,29 @@ class Factor(CompositeTerm):
|
||||
mask=mask,
|
||||
)
|
||||
|
||||
def isnan(self):
    """
    Build a Filter that is True wherever this Factor's value is NaN.
    """
    # NaN is the only floating-point value that compares unequal to
    # itself, so a self-inequality test picks out exactly the NaNs.
    nan_filter = self != self
    return nan_filter
|
||||
|
||||
def notnan(self):
    """
    Build a Filter that is True wherever this Factor's value is not NaN.

    Returns
    -------
    nanfilter : zipline.pipeline.filters.Filter
    """
    # Defined as the inversion of ``isnan`` so the two stay complementary.
    nan_filter = self.isnan()
    return ~nan_filter
|
||||
|
||||
def isfinite(self):
    """
    Build a Filter that is True wherever this Factor's value is finite,
    i.e. anything other than NaN, inf, or -inf.
    """
    # Both bounds are strict, which excludes +/-inf; NaN fails every
    # ordered comparison, so it is excluded as well.
    above_neg_inf = -inf < self
    below_pos_inf = self < inf
    return above_neg_inf & below_pos_inf
|
||||
|
||||
|
||||
class NumExprFactor(NumericalExpression, Factor):
|
||||
"""
|
||||
|
||||
@@ -83,6 +83,30 @@ def binary_operator(op):
|
||||
return binary_operator
|
||||
|
||||
|
||||
def unary_operator(op):
    """
    Build the method that implements unary operator `op` on a Filter.

    Only '~' is supported; any other `op` raises ValueError.  The
    returned function takes the Filter as `self` and produces a
    NumExprFilter for the operator applied to it.
    """
    supported_ops = {'~'}
    if op not in supported_ops:
        raise ValueError("Invalid unary operator %s." % op)

    def unary_operator(self):
        # NumericalExpression and NumExprFilter are looked up when the
        # method is called rather than when the factory runs.
        if not isinstance(self, NumericalExpression):
            # Plain term: it becomes the sole input, bound to x_0.
            return NumExprFilter("{op}x_0".format(op=op), (self,))

        # Already an expression: wrap the existing expression string and
        # reuse its inputs.
        wrapped_expr = "{op}({expr})".format(op=op, expr=self._expr)
        return NumExprFilter(wrapped_expr, self.inputs)

    unary_operator.__doc__ = "Unary Operator: '%s'" % op
    return unary_operator
|
||||
|
||||
|
||||
class Filter(CompositeTerm):
|
||||
"""
|
||||
Pipeline API expression producing boolean-valued outputs.
|
||||
@@ -96,6 +120,7 @@ class Filter(CompositeTerm):
|
||||
for op in FILTER_BINOPS
|
||||
}
|
||||
)
|
||||
__invert__ = unary_operator('~')
|
||||
|
||||
|
||||
class NumExprFilter(NumericalExpression, Filter):
|
||||
|
||||
Reference in New Issue
Block a user