mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-04 01:56:36 +08:00
ENH: Make element_of work for ints too.
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
from operator import or_
|
||||
|
||||
import numpy as np
|
||||
|
||||
from zipline.lib.labelarray import LabelArray
|
||||
@@ -286,7 +288,7 @@ class ClassifierTestCase(BasePipelineTestCase):
|
||||
container_type=(set, list, tuple, frozenset),
|
||||
labelarray_dtype=(categorical_dtype, bytes_dtype, unicode_dtype),
|
||||
)
|
||||
def test_element_of(self, container_type, labelarray_dtype):
|
||||
def test_element_of_strings(self, container_type, labelarray_dtype):
|
||||
|
||||
missing = labelarray_dtype.type("not in the array")
|
||||
|
||||
@@ -331,6 +333,42 @@ class ClassifierTestCase(BasePipelineTestCase):
|
||||
mask=self.build_mask(self.ones_mask(shape=data.shape)),
|
||||
)
|
||||
|
||||
def test_element_of_integral(self):
|
||||
"""
|
||||
Element of is well-defined for integral classifiers.
|
||||
"""
|
||||
class C(Classifier):
|
||||
dtype = int64_dtype
|
||||
missing_value = -1
|
||||
inputs = ()
|
||||
window_length = 0
|
||||
|
||||
c = C()
|
||||
|
||||
# There's no significance to the values here other than that they
|
||||
# contain a mix of missing and non-missing values.
|
||||
data = np.array([[-1, 1, 0, 2],
|
||||
[3, 0, 1, 0],
|
||||
[-5, 0, -1, 0],
|
||||
[-3, 1, 2, 2]], dtype=int64_dtype)
|
||||
|
||||
terms = {}
|
||||
expected = {}
|
||||
for choices in [(0,), (0, 1), (0, 1, 2)]:
|
||||
terms[str(choices)] = c.element_of(choices)
|
||||
expected[str(choices)] = reduce(
|
||||
or_,
|
||||
(data == elem for elem in choices),
|
||||
np.zeros_like(data, dtype=bool),
|
||||
)
|
||||
|
||||
self.check_terms(
|
||||
terms=terms,
|
||||
expected=expected,
|
||||
initial_workspace={c: data},
|
||||
mask=self.build_mask(self.ones_mask(shape=data.shape)),
|
||||
)
|
||||
|
||||
def test_element_of_rejects_missing_value(self):
|
||||
"""
|
||||
Test that element_of raises a useful error if we attempt to pass it an
|
||||
@@ -360,10 +398,11 @@ class ClassifierTestCase(BasePipelineTestCase):
|
||||
)
|
||||
self.assertEqual(errmsg, expected)
|
||||
|
||||
def test_element_of_rejects_unhashable_type(self):
|
||||
@parameter_space(dtype_=Classifier.ALLOWED_DTYPES)
|
||||
def test_element_of_rejects_unhashable_type(self, dtype_):
|
||||
|
||||
class C(Classifier):
|
||||
dtype = categorical_dtype
|
||||
dtype = dtype_
|
||||
missing_value = ''
|
||||
inputs = ()
|
||||
window_length = 0
|
||||
@@ -375,7 +414,7 @@ class ClassifierTestCase(BasePipelineTestCase):
|
||||
|
||||
errmsg = str(e.exception)
|
||||
expected = (
|
||||
"Expected `choices` to be an iterable of strings,"
|
||||
"Expected `choices` to be an iterable of hashable values,"
|
||||
" but got [{'a': 1}] instead.\n"
|
||||
"This caused the following error: "
|
||||
"TypeError(\"unhashable type: 'dict'\",)."
|
||||
|
||||
@@ -16,9 +16,10 @@ from zipline.utils.input_validation import expect_types
|
||||
from zipline.utils.numpy_utils import (
|
||||
categorical_dtype,
|
||||
int64_dtype,
|
||||
vectorized_is_element,
|
||||
)
|
||||
|
||||
from ..filters import Filter, NullFilter, NumExprFilter
|
||||
from ..filters import ArrayPredicate, NullFilter, NumExprFilter
|
||||
from ..mixins import (
|
||||
CustomTermMixin,
|
||||
LatestMixin,
|
||||
@@ -96,10 +97,10 @@ class Classifier(RestrictedDTypeMixin, ComputableTerm):
|
||||
binds=(self,),
|
||||
)
|
||||
else:
|
||||
return StringPredicate(
|
||||
classifier=self,
|
||||
return ArrayPredicate(
|
||||
term=self,
|
||||
op=operator.eq,
|
||||
compval=other,
|
||||
opargs=(other,),
|
||||
)
|
||||
|
||||
def __ne__(self, other):
|
||||
@@ -119,11 +120,8 @@ class Classifier(RestrictedDTypeMixin, ComputableTerm):
|
||||
binds=(self,),
|
||||
)
|
||||
else:
|
||||
return StringPredicate(
|
||||
classifier=self,
|
||||
op=operator.ne,
|
||||
compval=other,
|
||||
)
|
||||
# Numexpr doesn't know how to use LabelArrays.
|
||||
return ArrayPredicate(term=self, op=operator.ne, opargs=(other,))
|
||||
|
||||
@string_classifiers_only
|
||||
@expect_types(prefix=(bytes, unicode))
|
||||
@@ -142,10 +140,10 @@ class Classifier(RestrictedDTypeMixin, ComputableTerm):
|
||||
Filter returning True for all sid/date pairs for which ``self``
|
||||
produces a string starting with ``prefix``.
|
||||
"""
|
||||
return StringPredicate(
|
||||
classifier=self,
|
||||
return ArrayPredicate(
|
||||
term=self,
|
||||
op=LabelArray.startswith,
|
||||
compval=prefix,
|
||||
opargs=(prefix,),
|
||||
)
|
||||
|
||||
@string_classifiers_only
|
||||
@@ -165,10 +163,10 @@ class Classifier(RestrictedDTypeMixin, ComputableTerm):
|
||||
Filter returning True for all sid/date pairs for which ``self``
|
||||
produces a string ending with ``prefix``.
|
||||
"""
|
||||
return StringPredicate(
|
||||
classifier=self,
|
||||
return ArrayPredicate(
|
||||
term=self,
|
||||
op=LabelArray.endswith,
|
||||
compval=suffix,
|
||||
opargs=(suffix,),
|
||||
)
|
||||
|
||||
@string_classifiers_only
|
||||
@@ -188,10 +186,10 @@ class Classifier(RestrictedDTypeMixin, ComputableTerm):
|
||||
Filter returning True for all sid/date pairs for which ``self``
|
||||
produces a string containing ``substring``.
|
||||
"""
|
||||
return StringPredicate(
|
||||
classifier=self,
|
||||
return ArrayPredicate(
|
||||
term=self,
|
||||
op=LabelArray.has_substring,
|
||||
compval=substring,
|
||||
opargs=(substring,),
|
||||
)
|
||||
|
||||
@string_classifiers_only
|
||||
@@ -215,33 +213,32 @@ class Classifier(RestrictedDTypeMixin, ComputableTerm):
|
||||
--------
|
||||
https://docs.python.org/library/re.html
|
||||
"""
|
||||
return StringPredicate(
|
||||
classifier=self,
|
||||
return ArrayPredicate(
|
||||
term=self,
|
||||
op=LabelArray.matches,
|
||||
compval=pattern,
|
||||
opargs=(pattern,),
|
||||
)
|
||||
|
||||
@string_classifiers_only
|
||||
def element_of(self, choices):
|
||||
"""
|
||||
Construct a Filter indicating whether values are in ``choices``.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
choices : iterable[str]
|
||||
choices : iterable[str or int]
|
||||
An iterable of choices.
|
||||
|
||||
Returns
|
||||
-------
|
||||
matches : Filter
|
||||
Filter returning True for all sid/date pairs for which ``self``
|
||||
produces a string in ``choices``.
|
||||
produces an entry in ``choices``.
|
||||
"""
|
||||
try:
|
||||
choices = frozenset(choices)
|
||||
except Exception as e:
|
||||
raise TypeError(
|
||||
"Expected `choices` to be an iterable of strings,"
|
||||
"Expected `choices` to be an iterable of hashable values,"
|
||||
" but got {} instead.\n"
|
||||
"This caused the following error: {!r}.".format(choices, e)
|
||||
)
|
||||
@@ -260,11 +257,40 @@ class Classifier(RestrictedDTypeMixin, ComputableTerm):
|
||||
)
|
||||
)
|
||||
|
||||
return StringPredicate(
|
||||
classifier=self,
|
||||
op=LabelArray.element_of,
|
||||
compval=choices,
|
||||
)
|
||||
def only_contains(type_, values):
|
||||
return all(isinstance(v, type_) for v in values)
|
||||
|
||||
if self.dtype == int64_dtype:
|
||||
if only_contains(int, choices):
|
||||
return ArrayPredicate(
|
||||
term=self,
|
||||
op=vectorized_is_element,
|
||||
opargs=(choices,),
|
||||
)
|
||||
else:
|
||||
raise TypeError(
|
||||
"Found non-int in choices for {typename}.element_of.\n"
|
||||
"Supplied choices were {choices}.".format(
|
||||
typename=type(self).__name__,
|
||||
choices=choices,
|
||||
)
|
||||
)
|
||||
elif self.dtype == categorical_dtype:
|
||||
if only_contains((bytes, unicode), choices):
|
||||
return ArrayPredicate(
|
||||
term=self,
|
||||
op=LabelArray.element_of,
|
||||
opargs=(choices,),
|
||||
)
|
||||
else:
|
||||
raise TypeError(
|
||||
"Found non-string in choices for {typename}.element_of.\n"
|
||||
"Supplied choices were {choices}.".format(
|
||||
typename=type(self).__name__,
|
||||
choices=choices,
|
||||
)
|
||||
)
|
||||
assert False, "Unknown dtype in Classifier.element_of %s." % self.dtype
|
||||
|
||||
def postprocess(self, data):
|
||||
if self.dtype == int64_dtype:
|
||||
@@ -314,44 +340,6 @@ class Quantiles(SingleInputMixin, Classifier):
|
||||
return type(self).__name__ + '(%d)' % self.params['bins']
|
||||
|
||||
|
||||
class StringPredicate(SingleInputMixin, Filter):
|
||||
"""
|
||||
A filter applying a function from (LabelArray, hashable) -> ndarray[bool].
|
||||
|
||||
Examples include ``==, !=, startswith, and is_element``.
|
||||
"""
|
||||
window_length = 0
|
||||
|
||||
def __new__(cls, classifier, op, compval):
|
||||
return super(StringPredicate, cls).__new__(
|
||||
StringPredicate,
|
||||
compval=compval,
|
||||
op=op,
|
||||
inputs=(classifier,),
|
||||
mask=classifier.mask,
|
||||
)
|
||||
|
||||
def _init(self, op, compval, *args, **kwargs):
|
||||
self._op = op
|
||||
self._compval = compval
|
||||
return super(StringPredicate, self)._init(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def static_identity(cls, op, compval, *args, **kwargs):
|
||||
return (
|
||||
super(StringPredicate, cls).static_identity(*args, **kwargs),
|
||||
op,
|
||||
compval,
|
||||
)
|
||||
|
||||
def _compute(self, arrays, dates, assets, mask):
|
||||
data = arrays[0]
|
||||
return (
|
||||
self._op(data, self._compval)
|
||||
& mask
|
||||
)
|
||||
|
||||
|
||||
class CustomClassifier(PositiveWindowLengthMixin,
|
||||
StandardOutputs,
|
||||
CustomTermMixin,
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from .filter import (
|
||||
ArrayPredicate,
|
||||
CustomFilter,
|
||||
Filter,
|
||||
Latest,
|
||||
@@ -8,6 +9,7 @@ from .filter import (
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'ArrayPredicate',
|
||||
'CustomFilter',
|
||||
'Filter',
|
||||
'Latest',
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
"""
|
||||
filter.py
|
||||
"""
|
||||
from itertools import chain
|
||||
from operator import attrgetter
|
||||
|
||||
|
||||
from numpy import (
|
||||
float64,
|
||||
nan,
|
||||
nanpercentile,
|
||||
)
|
||||
from itertools import chain
|
||||
from operator import attrgetter
|
||||
|
||||
from zipline.errors import (
|
||||
BadPercentileBounds,
|
||||
UnsupportedDataType,
|
||||
@@ -28,6 +29,7 @@ from zipline.pipeline.expression import (
|
||||
method_name_for_op,
|
||||
NumericalExpression,
|
||||
)
|
||||
from zipline.utils.input_validation import expect_types
|
||||
from zipline.utils.numpy_utils import bool_dtype
|
||||
|
||||
|
||||
@@ -372,6 +374,50 @@ class CustomFilter(PositiveWindowLengthMixin, CustomTermMixin, Filter):
|
||||
"""
|
||||
|
||||
|
||||
class ArrayPredicate(SingleInputMixin, Filter):
|
||||
"""
|
||||
A filter applying a function from (ndarray, *args) -> ndarray[bool].
|
||||
|
||||
Parameters
|
||||
----------
|
||||
term : zipline.pipeline.Term
|
||||
Term producing the array over which the predicate will be computed.
|
||||
op : function(ndarray, *args) -> ndarray[bool]
|
||||
Function to apply to the result of `term`.
|
||||
opargs : tuple[hashable]
|
||||
Additional argument to apply to ``op``.
|
||||
"""
|
||||
window_length = 0
|
||||
|
||||
@expect_types(term=Term, opargs=tuple)
|
||||
def __new__(cls, term, op, opargs):
|
||||
hash(opargs) # fail fast if opargs isn't hashable.
|
||||
return super(ArrayPredicate, cls).__new__(
|
||||
ArrayPredicate,
|
||||
op=op,
|
||||
opargs=opargs,
|
||||
inputs=(term,),
|
||||
mask=term.mask,
|
||||
)
|
||||
|
||||
def _init(self, op, opargs, *args, **kwargs):
|
||||
self._op = op
|
||||
self._opargs = opargs
|
||||
return super(ArrayPredicate, self)._init(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def static_identity(cls, op, opargs, *args, **kwargs):
|
||||
return (
|
||||
super(ArrayPredicate, cls).static_identity(*args, **kwargs),
|
||||
op,
|
||||
opargs,
|
||||
)
|
||||
|
||||
def _compute(self, arrays, dates, assets, mask):
|
||||
data = arrays[0]
|
||||
return self._op(data, *self._opargs) & mask
|
||||
|
||||
|
||||
class Latest(LatestMixin, CustomFilter):
|
||||
"""
|
||||
Filter producing the most recently-known value of `inputs[0]` on each day.
|
||||
|
||||
@@ -14,6 +14,7 @@ from numpy import (
|
||||
dtype,
|
||||
empty,
|
||||
nan,
|
||||
vectorize,
|
||||
where
|
||||
)
|
||||
from numpy.lib.stride_tricks import as_strided
|
||||
@@ -344,3 +345,21 @@ def ignore_nanwarnings():
|
||||
{'category': RuntimeWarning, 'module': 'numpy.lib.nanfunctions'},
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def vectorized_is_element(array, choices):
|
||||
"""
|
||||
Check if each element of ``array`` is in choices.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
array : np.ndarray
|
||||
choices : object
|
||||
Object implementing __contains__.
|
||||
|
||||
Returns
|
||||
-------
|
||||
was_element : np.ndarray[bool]
|
||||
Array indicating whether each element of ``array`` was in ``choices``.
|
||||
"""
|
||||
return vectorize(choices.__contains__, otypes=[bool])(array)
|
||||
|
||||
Reference in New Issue
Block a user