ENH: Make element_of work for ints too.

This commit is contained in:
Scott Sanderson
2016-05-04 16:02:47 -04:00
parent 4357673221
commit 5a1ed7b1d3
5 changed files with 169 additions and 75 deletions
+43 -4
View File
@@ -1,3 +1,5 @@
from operator import or_
import numpy as np
from zipline.lib.labelarray import LabelArray
@@ -286,7 +288,7 @@ class ClassifierTestCase(BasePipelineTestCase):
container_type=(set, list, tuple, frozenset),
labelarray_dtype=(categorical_dtype, bytes_dtype, unicode_dtype),
)
def test_element_of(self, container_type, labelarray_dtype):
def test_element_of_strings(self, container_type, labelarray_dtype):
missing = labelarray_dtype.type("not in the array")
@@ -331,6 +333,42 @@ class ClassifierTestCase(BasePipelineTestCase):
mask=self.build_mask(self.ones_mask(shape=data.shape)),
)
def test_element_of_integral(self):
    """
    Element of is well-defined for integral classifiers.

    Builds an int64 classifier, asks for ``element_of`` filters over
    several tuples of choices, and compares each filter's output against
    an independently computed "equal to any choice" boolean array.
    """
    # ``reduce`` is only a builtin on Python 2. Importing it from
    # functools works on Python 2.6+ and on Python 3.
    from functools import reduce

    class C(Classifier):
        dtype = int64_dtype
        missing_value = -1
        inputs = ()
        window_length = 0

    c = C()

    # There's no significance to the values here other than that they
    # contain a mix of missing and non-missing values.
    data = np.array([[-1, 1, 0, 2],
                     [3, 0, 1, 0],
                     [-5, 0, -1, 0],
                     [-3, 1, 2, 2]], dtype=int64_dtype)

    terms = {}
    expected = {}
    for choices in [(0,), (0, 1), (0, 1, 2)]:
        terms[str(choices)] = c.element_of(choices)
        # OR together one equality mask per choice, starting from an
        # all-False array of the same shape as ``data``.
        expected[str(choices)] = reduce(
            or_,
            (data == elem for elem in choices),
            np.zeros_like(data, dtype=bool),
        )

    self.check_terms(
        terms=terms,
        expected=expected,
        initial_workspace={c: data},
        mask=self.build_mask(self.ones_mask(shape=data.shape)),
    )
def test_element_of_rejects_missing_value(self):
"""
Test that element_of raises a useful error if we attempt to pass it an
@@ -360,10 +398,11 @@ class ClassifierTestCase(BasePipelineTestCase):
)
self.assertEqual(errmsg, expected)
def test_element_of_rejects_unhashable_type(self):
@parameter_space(dtype_=Classifier.ALLOWED_DTYPES)
def test_element_of_rejects_unhashable_type(self, dtype_):
class C(Classifier):
dtype = categorical_dtype
dtype = dtype_
missing_value = ''
inputs = ()
window_length = 0
@@ -375,7 +414,7 @@ class ClassifierTestCase(BasePipelineTestCase):
errmsg = str(e.exception)
expected = (
"Expected `choices` to be an iterable of strings,"
"Expected `choices` to be an iterable of hashable values,"
" but got [{'a': 1}] instead.\n"
"This caused the following error: "
"TypeError(\"unhashable type: 'dict'\",)."
+56 -68
View File
@@ -16,9 +16,10 @@ from zipline.utils.input_validation import expect_types
from zipline.utils.numpy_utils import (
categorical_dtype,
int64_dtype,
vectorized_is_element,
)
from ..filters import Filter, NullFilter, NumExprFilter
from ..filters import ArrayPredicate, NullFilter, NumExprFilter
from ..mixins import (
CustomTermMixin,
LatestMixin,
@@ -96,10 +97,10 @@ class Classifier(RestrictedDTypeMixin, ComputableTerm):
binds=(self,),
)
else:
return StringPredicate(
classifier=self,
return ArrayPredicate(
term=self,
op=operator.eq,
compval=other,
opargs=(other,),
)
def __ne__(self, other):
@@ -119,11 +120,8 @@ class Classifier(RestrictedDTypeMixin, ComputableTerm):
binds=(self,),
)
else:
return StringPredicate(
classifier=self,
op=operator.ne,
compval=other,
)
# Numexpr doesn't know how to use LabelArrays.
return ArrayPredicate(term=self, op=operator.ne, opargs=(other,))
@string_classifiers_only
@expect_types(prefix=(bytes, unicode))
@@ -142,10 +140,10 @@ class Classifier(RestrictedDTypeMixin, ComputableTerm):
Filter returning True for all sid/date pairs for which ``self``
produces a string starting with ``prefix``.
"""
return StringPredicate(
classifier=self,
return ArrayPredicate(
term=self,
op=LabelArray.startswith,
compval=prefix,
opargs=(prefix,),
)
@string_classifiers_only
@@ -165,10 +163,10 @@ class Classifier(RestrictedDTypeMixin, ComputableTerm):
Filter returning True for all sid/date pairs for which ``self``
produces a string ending with ``prefix``.
"""
return StringPredicate(
classifier=self,
return ArrayPredicate(
term=self,
op=LabelArray.endswith,
compval=suffix,
opargs=(suffix,),
)
@string_classifiers_only
@@ -188,10 +186,10 @@ class Classifier(RestrictedDTypeMixin, ComputableTerm):
Filter returning True for all sid/date pairs for which ``self``
produces a string containing ``substring``.
"""
return StringPredicate(
classifier=self,
return ArrayPredicate(
term=self,
op=LabelArray.has_substring,
compval=substring,
opargs=(substring,),
)
@string_classifiers_only
@@ -215,33 +213,32 @@ class Classifier(RestrictedDTypeMixin, ComputableTerm):
--------
https://docs.python.org/library/re.html
"""
return StringPredicate(
classifier=self,
return ArrayPredicate(
term=self,
op=LabelArray.matches,
compval=pattern,
opargs=(pattern,),
)
@string_classifiers_only
def element_of(self, choices):
"""
Construct a Filter indicating whether values are in ``choices``.
Parameters
----------
choices : iterable[str]
choices : iterable[str or int]
An iterable of choices.
Returns
-------
matches : Filter
Filter returning True for all sid/date pairs for which ``self``
produces a string in ``choices``.
produces an entry in ``choices``.
"""
try:
choices = frozenset(choices)
except Exception as e:
raise TypeError(
"Expected `choices` to be an iterable of strings,"
"Expected `choices` to be an iterable of hashable values,"
" but got {} instead.\n"
"This caused the following error: {!r}.".format(choices, e)
)
@@ -260,11 +257,40 @@ class Classifier(RestrictedDTypeMixin, ComputableTerm):
)
)
return StringPredicate(
classifier=self,
op=LabelArray.element_of,
compval=choices,
)
def only_contains(type_, values):
return all(isinstance(v, type_) for v in values)
if self.dtype == int64_dtype:
if only_contains(int, choices):
return ArrayPredicate(
term=self,
op=vectorized_is_element,
opargs=(choices,),
)
else:
raise TypeError(
"Found non-int in choices for {typename}.element_of.\n"
"Supplied choices were {choices}.".format(
typename=type(self).__name__,
choices=choices,
)
)
elif self.dtype == categorical_dtype:
if only_contains((bytes, unicode), choices):
return ArrayPredicate(
term=self,
op=LabelArray.element_of,
opargs=(choices,),
)
else:
raise TypeError(
"Found non-string in choices for {typename}.element_of.\n"
"Supplied choices were {choices}.".format(
typename=type(self).__name__,
choices=choices,
)
)
assert False, "Unknown dtype in Classifier.element_of %s." % self.dtype
def postprocess(self, data):
if self.dtype == int64_dtype:
@@ -314,44 +340,6 @@ class Quantiles(SingleInputMixin, Classifier):
return type(self).__name__ + '(%d)' % self.params['bins']
class StringPredicate(SingleInputMixin, Filter):
    """
    A filter applying a function from (LabelArray, hashable) -> ndarray[bool].

    The wrapped function is called with the classifier's output and a
    single comparison value; the result is combined with the mask.
    """
    window_length = 0

    def __new__(cls, classifier, op, compval):
        # The classifier is the filter's only input, and its mask is reused.
        parent = super(StringPredicate, cls)
        return parent.__new__(
            StringPredicate,
            inputs=(classifier,),
            mask=classifier.mask,
            op=op,
            compval=compval,
        )

    def _init(self, op, compval, *args, **kwargs):
        # Keep the predicate and its comparison value for ``_compute``.
        self._compval = compval
        self._op = op
        return super(StringPredicate, self)._init(*args, **kwargs)

    @classmethod
    def static_identity(cls, op, compval, *args, **kwargs):
        # Extend the parent's identity with the predicate and its argument.
        base_identity = super(StringPredicate, cls).static_identity(
            *args, **kwargs
        )
        return (base_identity, op, compval)

    def _compute(self, arrays, dates, assets, mask):
        labels = arrays[0]
        matched = self._op(labels, self._compval)
        return matched & mask
class CustomClassifier(PositiveWindowLengthMixin,
StandardOutputs,
CustomTermMixin,
+2
View File
@@ -1,4 +1,5 @@
from .filter import (
ArrayPredicate,
CustomFilter,
Filter,
Latest,
@@ -8,6 +9,7 @@ from .filter import (
)
__all__ = [
'ArrayPredicate',
'CustomFilter',
'Filter',
'Latest',
+49 -3
View File
@@ -1,14 +1,15 @@
"""
filter.py
"""
from itertools import chain
from operator import attrgetter
from numpy import (
float64,
nan,
nanpercentile,
)
from itertools import chain
from operator import attrgetter
from zipline.errors import (
BadPercentileBounds,
UnsupportedDataType,
@@ -28,6 +29,7 @@ from zipline.pipeline.expression import (
method_name_for_op,
NumericalExpression,
)
from zipline.utils.input_validation import expect_types
from zipline.utils.numpy_utils import bool_dtype
@@ -372,6 +374,50 @@ class CustomFilter(PositiveWindowLengthMixin, CustomTermMixin, Filter):
"""
class ArrayPredicate(SingleInputMixin, Filter):
    """
    A filter applying a function from (ndarray, *args) -> ndarray[bool].

    Parameters
    ----------
    term : zipline.pipeline.Term
        Term producing the array over which the predicate will be computed.
    op : function(ndarray, *args) -> ndarray[bool]
        Function to apply to the result of `term`.
    opargs : tuple[hashable]
        Additional arguments to pass to ``op`` after the array.
    """
    window_length = 0

    @expect_types(term=Term, opargs=tuple)
    def __new__(cls, term, op, opargs):
        hash(opargs)  # fail fast if opargs isn't hashable.
        # Pass ``cls`` (not a hard-coded ArrayPredicate) so that a subclass
        # constructs an instance of itself and its overrides are honoured.
        return super(ArrayPredicate, cls).__new__(
            cls,
            op=op,
            opargs=opargs,
            inputs=(term,),
            mask=term.mask,
        )

    def _init(self, op, opargs, *args, **kwargs):
        # Store the predicate and its extra arguments for ``_compute``.
        self._op = op
        self._opargs = opargs
        return super(ArrayPredicate, self)._init(*args, **kwargs)

    @classmethod
    def static_identity(cls, op, opargs, *args, **kwargs):
        # Extend the parent's identity with the predicate and its arguments.
        return (
            super(ArrayPredicate, cls).static_identity(*args, **kwargs),
            op,
            opargs,
        )

    def _compute(self, arrays, dates, assets, mask):
        # Apply the predicate to the single input array, then AND the result
        # with ``mask``.
        data = arrays[0]
        return self._op(data, *self._opargs) & mask
class Latest(LatestMixin, CustomFilter):
"""
Filter producing the most recently-known value of `inputs[0]` on each day.
+19
View File
@@ -14,6 +14,7 @@ from numpy import (
dtype,
empty,
nan,
vectorize,
where
)
from numpy.lib.stride_tricks import as_strided
@@ -344,3 +345,21 @@ def ignore_nanwarnings():
{'category': RuntimeWarning, 'module': 'numpy.lib.nanfunctions'},
)
)
def vectorized_is_element(array, choices):
    """
    Test each element of ``array`` for membership in ``choices``.

    Parameters
    ----------
    array : np.ndarray
        Array whose elements will be looked up in ``choices``.
    choices : object
        Object implementing __contains__.

    Returns
    -------
    was_element : np.ndarray[bool]
        Boolean array with one entry per element of ``array``, True where
        that element was found in ``choices``.
    """
    # Wrap the bound ``__contains__`` so that numpy calls it once per
    # element; ``otypes`` pins the output dtype to bool.
    is_member = vectorize(choices.__contains__, otypes=[bool])
    return is_member(array)