mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-29 21:24:16 +08:00
ENH: Add isnull and notnull methods to Factor.
This commit is contained in:
@@ -23,11 +23,11 @@ from zipline.lib.adjustment import (
|
||||
)
|
||||
from zipline.lib.adjusted_array import AdjustedArray, NOMASK
|
||||
from zipline.utils.numpy_utils import (
|
||||
coerce_to_dtype,
|
||||
datetime64ns_dtype,
|
||||
default_missing_value_for_dtype,
|
||||
float64_dtype,
|
||||
int64_dtype,
|
||||
make_datetime64ns,
|
||||
)
|
||||
from zipline.utils.test_utils import check_arrays, parameter_space
|
||||
|
||||
@@ -62,18 +62,6 @@ def valid_window_lengths(underlying_buffer_length):
|
||||
return iter(range(1, underlying_buffer_length + 1))
|
||||
|
||||
|
||||
def value_with_dtype(dtype, value):
|
||||
"""
|
||||
Make a value with the specified numpy dtype.
|
||||
"""
|
||||
name = dtype.name
|
||||
if name.startswith('datetime64'):
|
||||
if name != 'datetime64[ns]':
|
||||
raise TypeError("Expected datetime64[ns], but got %s." % name)
|
||||
return make_datetime64ns(value)
|
||||
return dtype.type(value)
|
||||
|
||||
|
||||
def _gen_unadjusted_cases(dtype):
|
||||
|
||||
nrows = 6
|
||||
@@ -124,7 +112,7 @@ def _gen_multiplicative_adjustment_cases(dtype):
|
||||
|
||||
# Note that row indices are inclusive!
|
||||
adjustments[1] = [
|
||||
adjustment_type(0, 0, 0, 0, value_with_dtype(dtype, 2)),
|
||||
adjustment_type(0, 0, 0, 0, coerce_to_dtype(dtype, 2)),
|
||||
]
|
||||
buffer_as_of[1] = array([[2, 1, 1],
|
||||
[1, 1, 1],
|
||||
@@ -137,8 +125,8 @@ def _gen_multiplicative_adjustment_cases(dtype):
|
||||
buffer_as_of[2] = buffer_as_of[1]
|
||||
|
||||
adjustments[3] = [
|
||||
adjustment_type(1, 2, 1, 1, value_with_dtype(dtype, 3)),
|
||||
adjustment_type(0, 1, 0, 0, value_with_dtype(dtype, 4)),
|
||||
adjustment_type(1, 2, 1, 1, coerce_to_dtype(dtype, 3)),
|
||||
adjustment_type(0, 1, 0, 0, coerce_to_dtype(dtype, 4)),
|
||||
]
|
||||
buffer_as_of[3] = array([[8, 1, 1],
|
||||
[4, 3, 1],
|
||||
@@ -148,7 +136,7 @@ def _gen_multiplicative_adjustment_cases(dtype):
|
||||
[1, 1, 1]], dtype=dtype)
|
||||
|
||||
adjustments[4] = [
|
||||
adjustment_type(0, 3, 2, 2, value_with_dtype(dtype, 5))
|
||||
adjustment_type(0, 3, 2, 2, coerce_to_dtype(dtype, 5))
|
||||
]
|
||||
buffer_as_of[4] = array([[8, 1, 5],
|
||||
[4, 3, 5],
|
||||
@@ -158,8 +146,8 @@ def _gen_multiplicative_adjustment_cases(dtype):
|
||||
[1, 1, 1]], dtype=dtype)
|
||||
|
||||
adjustments[5] = [
|
||||
adjustment_type(0, 4, 1, 1, value_with_dtype(dtype, 6)),
|
||||
adjustment_type(2, 2, 2, 2, value_with_dtype(dtype, 7)),
|
||||
adjustment_type(0, 4, 1, 1, coerce_to_dtype(dtype, 6)),
|
||||
adjustment_type(2, 2, 2, 2, coerce_to_dtype(dtype, 7)),
|
||||
]
|
||||
buffer_as_of[5] = array([[8, 6, 5],
|
||||
[4, 18, 5],
|
||||
@@ -191,7 +179,7 @@ def _gen_overwrite_adjustment_cases(dtype):
|
||||
|
||||
# Note that row indices are inclusive!
|
||||
adjustments[1] = [
|
||||
adjustment_type(0, 0, 0, 0, value_with_dtype(dtype, 1)),
|
||||
adjustment_type(0, 0, 0, 0, coerce_to_dtype(dtype, 1)),
|
||||
]
|
||||
buffer_as_of[1] = array([[1, 2, 2],
|
||||
[2, 2, 2],
|
||||
@@ -204,8 +192,8 @@ def _gen_overwrite_adjustment_cases(dtype):
|
||||
buffer_as_of[2] = buffer_as_of[1]
|
||||
|
||||
adjustments[3] = [
|
||||
adjustment_type(1, 2, 1, 1, value_with_dtype(dtype, 3)),
|
||||
adjustment_type(0, 1, 0, 0, value_with_dtype(dtype, 4)),
|
||||
adjustment_type(1, 2, 1, 1, coerce_to_dtype(dtype, 3)),
|
||||
adjustment_type(0, 1, 0, 0, coerce_to_dtype(dtype, 4)),
|
||||
]
|
||||
buffer_as_of[3] = array([[4, 2, 2],
|
||||
[4, 3, 2],
|
||||
@@ -215,7 +203,7 @@ def _gen_overwrite_adjustment_cases(dtype):
|
||||
[2, 2, 2]], dtype=dtype)
|
||||
|
||||
adjustments[4] = [
|
||||
adjustment_type(0, 3, 2, 2, value_with_dtype(dtype, 5))
|
||||
adjustment_type(0, 3, 2, 2, coerce_to_dtype(dtype, 5))
|
||||
]
|
||||
buffer_as_of[4] = array([[4, 2, 5],
|
||||
[4, 3, 5],
|
||||
@@ -225,8 +213,8 @@ def _gen_overwrite_adjustment_cases(dtype):
|
||||
[2, 2, 2]], dtype=dtype)
|
||||
|
||||
adjustments[5] = [
|
||||
adjustment_type(0, 4, 1, 1, value_with_dtype(dtype, 6)),
|
||||
adjustment_type(2, 2, 2, 2, value_with_dtype(dtype, 7)),
|
||||
adjustment_type(0, 4, 1, 1, coerce_to_dtype(dtype, 6)),
|
||||
adjustment_type(2, 2, 2, 2, coerce_to_dtype(dtype, 7)),
|
||||
]
|
||||
buffer_as_of[5] = array([[4, 6, 5],
|
||||
[4, 6, 5],
|
||||
@@ -335,7 +323,7 @@ class AdjustedArrayTestCase(TestCase):
|
||||
window_length=[2, 3],
|
||||
)
|
||||
def test_masking(self, dtype, missing_value, window_length):
|
||||
missing_value = value_with_dtype(dtype, missing_value)
|
||||
missing_value = coerce_to_dtype(dtype, missing_value)
|
||||
baseline_ints = arange(15).reshape(5, 3)
|
||||
baseline = baseline_ints.astype(dtype)
|
||||
mask = (baseline_ints % 2).astype(bool)
|
||||
|
||||
@@ -22,10 +22,15 @@ from zipline.pipeline.factors import (
|
||||
Returns,
|
||||
RSI,
|
||||
)
|
||||
from zipline.utils.test_utils import check_allclose, check_arrays
|
||||
from zipline.utils.test_utils import (
|
||||
check_allclose,
|
||||
check_arrays,
|
||||
parameter_space,
|
||||
)
|
||||
from zipline.utils.numpy_utils import (
|
||||
datetime64ns_dtype,
|
||||
float64_dtype,
|
||||
int64_dtype,
|
||||
NaTns,
|
||||
)
|
||||
|
||||
@@ -59,6 +64,73 @@ class FactorTestCase(BasePipelineTestCase):
|
||||
with self.assertRaises(UnknownRankMethod):
|
||||
self.f.rank("not a real rank method")
|
||||
|
||||
@parameter_space(method_name=['isnan', 'notnan', 'isfinite'])
|
||||
def test_float64_only_ops(self, method_name):
|
||||
class NotFloat(Factor):
|
||||
dtype = datetime64ns_dtype
|
||||
inputs = ()
|
||||
window_length = 0
|
||||
|
||||
nf = NotFloat()
|
||||
meth = getattr(nf, method_name)
|
||||
with self.assertRaises(TypeError):
|
||||
meth()
|
||||
|
||||
@parameter_space(custom_missing_value=[-1, 0])
|
||||
def test_isnull_int_dtype(self, custom_missing_value):
|
||||
|
||||
class CustomMissingValue(Factor):
|
||||
dtype = int64_dtype
|
||||
window_length = 0
|
||||
missing_value = custom_missing_value
|
||||
inputs = ()
|
||||
|
||||
factor = CustomMissingValue()
|
||||
|
||||
data = arange(25).reshape(5, 5)
|
||||
data[eye(5, dtype=bool)] = custom_missing_value
|
||||
|
||||
graph = TermGraph(
|
||||
{
|
||||
'isnull': factor.isnull(),
|
||||
'notnull': factor.notnull(),
|
||||
}
|
||||
)
|
||||
|
||||
results = self.run_graph(
|
||||
graph,
|
||||
initial_workspace={factor: data},
|
||||
mask=self.build_mask(ones((5, 5))),
|
||||
)
|
||||
check_arrays(results['isnull'], eye(5, dtype=bool))
|
||||
check_arrays(results['notnull'], ~eye(5, dtype=bool))
|
||||
|
||||
def test_isnull_datetime_dtype(self):
|
||||
class DatetimeFactor(Factor):
|
||||
dtype = datetime64ns_dtype
|
||||
window_length = 0
|
||||
inputs = ()
|
||||
|
||||
factor = DatetimeFactor()
|
||||
|
||||
data = arange(25).reshape(5, 5).astype('datetime64[ns]')
|
||||
data[eye(5, dtype=bool)] = NaTns
|
||||
|
||||
graph = TermGraph(
|
||||
{
|
||||
'isnull': factor.isnull(),
|
||||
'notnull': factor.notnull(),
|
||||
}
|
||||
)
|
||||
|
||||
results = self.run_graph(
|
||||
graph,
|
||||
initial_workspace={factor: data},
|
||||
mask=self.build_mask(ones((5, 5))),
|
||||
)
|
||||
check_arrays(results['isnull'], eye(5, dtype=bool))
|
||||
check_arrays(results['notnull'], ~eye(5, dtype=bool))
|
||||
|
||||
@for_each_factor_dtype
|
||||
def test_rank_ascending(self, name, factor_dtype):
|
||||
|
||||
|
||||
@@ -345,10 +345,14 @@ class FilterTestCase(BasePipelineTestCase):
|
||||
data[diag] = nan
|
||||
|
||||
results = self.run_graph(
|
||||
TermGraph({'isnan': self.f.isnan()}),
|
||||
TermGraph({
|
||||
'isnan': self.f.isnan(),
|
||||
'isnull': self.f.isnull(),
|
||||
}),
|
||||
initial_workspace={self.f: data},
|
||||
)
|
||||
check_arrays(results['isnan'], diag)
|
||||
check_arrays(results['isnull'], diag)
|
||||
|
||||
def test_notnan(self):
|
||||
data = self.randn_data(seed=10)
|
||||
@@ -356,10 +360,14 @@ class FilterTestCase(BasePipelineTestCase):
|
||||
data[diag] = nan
|
||||
|
||||
results = self.run_graph(
|
||||
TermGraph({'notnan': self.f.notnan()}),
|
||||
TermGraph({
|
||||
'notnan': self.f.notnan(),
|
||||
'notnull': self.f.notnull(),
|
||||
}),
|
||||
initial_workspace={self.f: data},
|
||||
)
|
||||
check_arrays(results['notnan'], ~diag)
|
||||
check_arrays(results['notnull'], ~diag)
|
||||
|
||||
def test_isfinite(self):
|
||||
data = self.randn_data(seed=10)
|
||||
|
||||
+17
-6
@@ -17,10 +17,26 @@ from numpy cimport (
|
||||
from numpy import apply_along_axis, float64, isnan, nan
|
||||
from scipy.stats import rankdata
|
||||
|
||||
from zipline.utils.numpy_utils import (
|
||||
is_float,
|
||||
float64_dtype,
|
||||
int64_dtype,
|
||||
datetime64ns_dtype,
|
||||
)
|
||||
|
||||
|
||||
import_array()
|
||||
|
||||
|
||||
cpdef ismissing(ndarray data, object missing_value):
|
||||
"""
|
||||
Generic ismissing function that handles quirks with NaN.
|
||||
"""
|
||||
if is_float(data) and isnan(missing_value):
|
||||
return isnan(data)
|
||||
return (data == missing_value)
|
||||
|
||||
|
||||
def masked_rankdata_2d(ndarray data,
|
||||
ndarray mask,
|
||||
object missing_value,
|
||||
@@ -35,12 +51,7 @@ def masked_rankdata_2d(ndarray data,
|
||||
"Can't compute rankdata on array of dtype %r." % dtype_name
|
||||
)
|
||||
|
||||
cdef ndarray missing_locations = ~mask
|
||||
# Mask out any entries that are equal to the missing value.
|
||||
if dtype_name == 'float64' and isnan(missing_value):
|
||||
missing_locations |= isnan(data)
|
||||
else:
|
||||
missing_locations |= (data == missing_value)
|
||||
cdef ndarray missing_locations = (~mask | ismissing(data, missing_value))
|
||||
|
||||
# Interpret the bytes of integral data as floats for sorting.
|
||||
data = data.copy().view(float64)
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
"""
|
||||
factor.py
|
||||
"""
|
||||
from functools import wraps
|
||||
from operator import attrgetter
|
||||
from numbers import Number
|
||||
|
||||
from numpy import float64, inf
|
||||
from numpy import inf
|
||||
from toolz import curry
|
||||
|
||||
from zipline.errors import (
|
||||
@@ -32,30 +33,43 @@ from zipline.pipeline.expression import (
|
||||
from zipline.pipeline.filters import (
|
||||
NumExprFilter,
|
||||
PercentileFilter,
|
||||
NullFilter,
|
||||
)
|
||||
from zipline.utils.control_flow import nullctx
|
||||
from zipline.utils.numpy_utils import (
|
||||
bool_dtype,
|
||||
coerce_to_dtype,
|
||||
datetime64ns_dtype,
|
||||
float64_dtype,
|
||||
int64_dtype,
|
||||
)
|
||||
from zipline.utils.preprocess import preprocess
|
||||
|
||||
|
||||
_RANK_METHODS = frozenset(['average', 'min', 'max', 'dense', 'ordinal'])
|
||||
|
||||
|
||||
def numbers_to_float64(func, argname, argvalue):
|
||||
def coerce_numbers_to_my_dtype(f):
|
||||
"""
|
||||
Preprocessor for converting numerical inputs into floats.
|
||||
A decorator for methods whose signature is f(self, other) that coerces
|
||||
``other`` to ``self.dtype``.
|
||||
|
||||
This is used in the binary operator constructors for Factor so that
|
||||
`2 + Factor()` has the same behavior as `2.0 + Factor()`.
|
||||
This is used to make comparison operations between numbers and `Factor`
|
||||
instances work independently of whether the user supplies a float or
|
||||
integer literal.
|
||||
|
||||
For example, if I write::
|
||||
|
||||
my_filter = my_factor > 3
|
||||
|
||||
my_factor probably has dtype float64, but 3 is an int, so we want to coerce
|
||||
to float64 before doing the comparison.
|
||||
"""
|
||||
if isinstance(argvalue, Number):
|
||||
return float64(argvalue)
|
||||
return argvalue
|
||||
@wraps(f)
|
||||
def method(self, other):
|
||||
if isinstance(other, Number):
|
||||
other = coerce_to_dtype(self.dtype, other)
|
||||
return f(self, other)
|
||||
return method
|
||||
|
||||
|
||||
@curry
|
||||
@@ -148,9 +162,9 @@ def binary_operator(op):
|
||||
# NumericalExpression operator.
|
||||
commuted_method_getter = attrgetter(method_name_for_op(op, commute=True))
|
||||
|
||||
@preprocess(other=numbers_to_float64)
|
||||
@with_doc("Binary Operator: '%s'" % op)
|
||||
@with_name(method_name_for_op(op))
|
||||
@coerce_numbers_to_my_dtype
|
||||
def binary_operator(self, other):
|
||||
# This can't be hoisted up a scope because the types returned by
|
||||
# binop_return_type aren't defined when the top-level function is
|
||||
@@ -207,8 +221,8 @@ def reflected_binary_operator(op):
|
||||
"""
|
||||
assert not is_comparison(op)
|
||||
|
||||
@preprocess(other=numbers_to_float64)
|
||||
@with_name(method_name_for_op(op, commute=True))
|
||||
@coerce_numbers_to_my_dtype
|
||||
def reflected_binary_operator(self, other):
|
||||
|
||||
if isinstance(self, NumericalExpression):
|
||||
@@ -304,6 +318,28 @@ def function_application(func):
|
||||
return mathfunc
|
||||
|
||||
|
||||
def if_not_float64_tell_caller_to_use_isnull(f):
|
||||
"""
|
||||
Factor method decorator that checks if self.dtype if float64.
|
||||
|
||||
If the factor instance is of another dtype, this raises a TypeError
|
||||
directing the user to `isnull` or `notnull` instead.
|
||||
"""
|
||||
@wraps(f)
|
||||
def wrapped_method(self, *args, **kwargs):
|
||||
if self.dtype != float64_dtype:
|
||||
raise TypeError(
|
||||
"{meth}() was called on a factor of dtype {dtype}.\n"
|
||||
"{meth}() is only defined for dtype float64."
|
||||
"To filter missing data, use isnull() or notnull().".format(
|
||||
meth=f.__name__,
|
||||
dtype=self.dtype,
|
||||
),
|
||||
)
|
||||
return f(self, *args, **kwargs)
|
||||
return wrapped_method
|
||||
|
||||
|
||||
FACTOR_DTYPES = frozenset([datetime64ns_dtype, float64_dtype, int64_dtype])
|
||||
|
||||
|
||||
@@ -476,6 +512,34 @@ class Factor(CompositeTerm):
|
||||
mask=mask,
|
||||
)
|
||||
|
||||
def isnull(self):
|
||||
"""
|
||||
A Filter producing True for values where this Factor has missing data.
|
||||
|
||||
Equivalent to self.isnan() when ``self.dtype`` is float64.
|
||||
Otherwise equivalent to ``self.eq(self.missing_value)``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
filter : zipline.pipeline.filters.Filter
|
||||
"""
|
||||
if self.dtype == float64_dtype:
|
||||
# Using isnan is more efficient when possible because we can fold
|
||||
# the isnan computation with other NumExpr expressions.
|
||||
return self.isnan()
|
||||
else:
|
||||
return NullFilter(self)
|
||||
|
||||
def notnull(self):
|
||||
"""
|
||||
A Filter producing True for values where this Factor has complete data.
|
||||
|
||||
Equivalent to ``~self.isnan()` when ``self.dtype`` is float64.
|
||||
Otherwise equivalent to ``(self != self.missing_value)``.
|
||||
"""
|
||||
return ~self.isnull()
|
||||
|
||||
@if_not_float64_tell_caller_to_use_isnull
|
||||
def isnan(self):
|
||||
"""
|
||||
A Filter producing True for all values where this Factor is NaN.
|
||||
@@ -486,6 +550,7 @@ class Factor(CompositeTerm):
|
||||
"""
|
||||
return self != self
|
||||
|
||||
@if_not_float64_tell_caller_to_use_isnull
|
||||
def notnan(self):
|
||||
"""
|
||||
A Filter producing True for values where this Factor is not NaN.
|
||||
@@ -496,6 +561,7 @@ class Factor(CompositeTerm):
|
||||
"""
|
||||
return ~self.isnan()
|
||||
|
||||
@if_not_float64_tell_caller_to_use_isnull
|
||||
def isfinite(self):
|
||||
"""
|
||||
A Filter producing True for values where this Factor is anything but
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
from .filter import Filter, NumExprFilter, PercentileFilter
|
||||
from .filter import Filter, NumExprFilter, NullFilter, PercentileFilter
|
||||
from .latest import Latest
|
||||
|
||||
__all__ = [
|
||||
'Filter',
|
||||
'Latest',
|
||||
'NumExprFilter',
|
||||
'NullFilter',
|
||||
'PercentileFilter',
|
||||
]
|
||||
|
||||
@@ -13,6 +13,7 @@ from zipline.errors import (
|
||||
BadPercentileBounds,
|
||||
UnsupportedDataType,
|
||||
)
|
||||
from zipline.lib.rank import ismissing
|
||||
from zipline.pipeline.mixins import (
|
||||
CustomTermMixin,
|
||||
PositiveWindowLengthMixin,
|
||||
@@ -173,6 +174,27 @@ class NumExprFilter(NumericalExpression, Filter):
|
||||
) & mask
|
||||
|
||||
|
||||
class NullFilter(SingleInputMixin, Filter):
|
||||
"""
|
||||
A Filter indicating whether an input input values are missing.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
factor zipline.pipeline.factor.Factor
|
||||
The factor to compare with null.
|
||||
"""
|
||||
window_length = 0
|
||||
|
||||
def __new__(cls, factor):
|
||||
return super(NullFilter, cls).__new__(
|
||||
cls,
|
||||
inputs=(factor,),
|
||||
)
|
||||
|
||||
def _compute(self, arrays, dates, assets, mask):
|
||||
return ismissing(arrays[0], self.inputs[0].missing_value)
|
||||
|
||||
|
||||
class PercentileFilter(SingleInputMixin, Filter):
|
||||
"""
|
||||
A Filter representing assets falling between percentile bounds of a Factor.
|
||||
|
||||
Reference in New Issue
Block a user