From 5a1ed7b1d34d47b77403ee63e904474c9d2cf4f2 Mon Sep 17 00:00:00 2001 From: Scott Sanderson Date: Wed, 4 May 2016 16:02:47 -0400 Subject: [PATCH] ENH: Make element_of work for ints too. --- tests/pipeline/test_classifier.py | 47 +++++++- zipline/pipeline/classifiers/classifier.py | 124 ++++++++++----------- zipline/pipeline/filters/__init__.py | 2 + zipline/pipeline/filters/filter.py | 52 ++++++++- zipline/utils/numpy_utils.py | 19 ++++ 5 files changed, 169 insertions(+), 75 deletions(-) diff --git a/tests/pipeline/test_classifier.py b/tests/pipeline/test_classifier.py index 1fa1b9b9..f52e801e 100644 --- a/tests/pipeline/test_classifier.py +++ b/tests/pipeline/test_classifier.py @@ -1,3 +1,5 @@ +from operator import or_ + import numpy as np from zipline.lib.labelarray import LabelArray @@ -286,7 +288,7 @@ class ClassifierTestCase(BasePipelineTestCase): container_type=(set, list, tuple, frozenset), labelarray_dtype=(categorical_dtype, bytes_dtype, unicode_dtype), ) - def test_element_of(self, container_type, labelarray_dtype): + def test_element_of_strings(self, container_type, labelarray_dtype): missing = labelarray_dtype.type("not in the array") @@ -331,6 +333,42 @@ class ClassifierTestCase(BasePipelineTestCase): mask=self.build_mask(self.ones_mask(shape=data.shape)), ) + def test_element_of_integral(self): + """ + Element of is well-defined for integral classifiers. + """ + class C(Classifier): + dtype = int64_dtype + missing_value = -1 + inputs = () + window_length = 0 + + c = C() + + # There's no significance to the values here other than that they + # contain a mix of missing and non-missing values. 
+ data = np.array([[-1, 1, 0, 2], + [3, 0, 1, 0], + [-5, 0, -1, 0], + [-3, 1, 2, 2]], dtype=int64_dtype) + + terms = {} + expected = {} + for choices in [(0,), (0, 1), (0, 1, 2)]: + terms[str(choices)] = c.element_of(choices) + expected[str(choices)] = reduce( + or_, + (data == elem for elem in choices), + np.zeros_like(data, dtype=bool), + ) + + self.check_terms( + terms=terms, + expected=expected, + initial_workspace={c: data}, + mask=self.build_mask(self.ones_mask(shape=data.shape)), + ) + def test_element_of_rejects_missing_value(self): """ Test that element_of raises a useful error if we attempt to pass it an @@ -360,10 +398,11 @@ class ClassifierTestCase(BasePipelineTestCase): ) self.assertEqual(errmsg, expected) - def test_element_of_rejects_unhashable_type(self): + @parameter_space(dtype_=Classifier.ALLOWED_DTYPES) + def test_element_of_rejects_unhashable_type(self, dtype_): class C(Classifier): - dtype = categorical_dtype + dtype = dtype_ missing_value = '' inputs = () window_length = 0 @@ -375,7 +414,7 @@ class ClassifierTestCase(BasePipelineTestCase): errmsg = str(e.exception) expected = ( - "Expected `choices` to be an iterable of strings," + "Expected `choices` to be an iterable of hashable values," " but got [{'a': 1}] instead.\n" "This caused the following error: " "TypeError(\"unhashable type: 'dict'\",)." 
diff --git a/zipline/pipeline/classifiers/classifier.py b/zipline/pipeline/classifiers/classifier.py index 5b33dba3..05a99593 100644 --- a/zipline/pipeline/classifiers/classifier.py +++ b/zipline/pipeline/classifiers/classifier.py @@ -16,9 +16,10 @@ from zipline.utils.input_validation import expect_types from zipline.utils.numpy_utils import ( categorical_dtype, int64_dtype, + vectorized_is_element, ) -from ..filters import Filter, NullFilter, NumExprFilter +from ..filters import ArrayPredicate, NullFilter, NumExprFilter from ..mixins import ( CustomTermMixin, LatestMixin, @@ -96,10 +97,10 @@ class Classifier(RestrictedDTypeMixin, ComputableTerm): binds=(self,), ) else: - return StringPredicate( - classifier=self, + return ArrayPredicate( + term=self, op=operator.eq, - compval=other, + opargs=(other,), ) def __ne__(self, other): @@ -119,11 +120,8 @@ class Classifier(RestrictedDTypeMixin, ComputableTerm): binds=(self,), ) else: - return StringPredicate( - classifier=self, - op=operator.ne, - compval=other, - ) + # Numexpr doesn't know how to use LabelArrays. + return ArrayPredicate(term=self, op=operator.ne, opargs=(other,)) @string_classifiers_only @expect_types(prefix=(bytes, unicode)) @@ -142,10 +140,10 @@ class Classifier(RestrictedDTypeMixin, ComputableTerm): Filter returning True for all sid/date pairs for which ``self`` produces a string starting with ``prefix``. """ - return StringPredicate( - classifier=self, + return ArrayPredicate( + term=self, op=LabelArray.startswith, - compval=prefix, + opargs=(prefix,), ) @string_classifiers_only @@ -165,10 +163,10 @@ class Classifier(RestrictedDTypeMixin, ComputableTerm): Filter returning True for all sid/date pairs for which ``self`` produces a string ending with ``prefix``. 
""" - return StringPredicate( - classifier=self, + return ArrayPredicate( + term=self, op=LabelArray.endswith, - compval=suffix, + opargs=(suffix,), ) @string_classifiers_only @@ -188,10 +186,10 @@ class Classifier(RestrictedDTypeMixin, ComputableTerm): Filter returning True for all sid/date pairs for which ``self`` produces a string containing ``substring``. """ - return StringPredicate( - classifier=self, + return ArrayPredicate( + term=self, op=LabelArray.has_substring, - compval=substring, + opargs=(substring,), ) @string_classifiers_only @@ -215,33 +213,32 @@ class Classifier(RestrictedDTypeMixin, ComputableTerm): -------- https://docs.python.org/library/re.html """ - return StringPredicate( - classifier=self, + return ArrayPredicate( + term=self, op=LabelArray.matches, - compval=pattern, + opargs=(pattern,), ) - @string_classifiers_only def element_of(self, choices): """ Construct a Filter indicating whether values are in ``choices``. Parameters ---------- - choices : iterable[str] + choices : iterable[str or int] An iterable of choices. Returns ------- matches : Filter Filter returning True for all sid/date pairs for which ``self`` - produces a string in ``choices``. + produces an entry in ``choices``. 
""" try: choices = frozenset(choices) except Exception as e: raise TypeError( - "Expected `choices` to be an iterable of strings," + "Expected `choices` to be an iterable of hashable values," " but got {} instead.\n" "This caused the following error: {!r}.".format(choices, e) ) @@ -260,11 +257,40 @@ class Classifier(RestrictedDTypeMixin, ComputableTerm): ) ) - return StringPredicate( - classifier=self, - op=LabelArray.element_of, - compval=choices, - ) + def only_contains(type_, values): + return all(isinstance(v, type_) for v in values) + + if self.dtype == int64_dtype: + if only_contains(int, choices): + return ArrayPredicate( + term=self, + op=vectorized_is_element, + opargs=(choices,), + ) + else: + raise TypeError( + "Found non-int in choices for {typename}.element_of.\n" + "Supplied choices were {choices}.".format( + typename=type(self).__name__, + choices=choices, + ) + ) + elif self.dtype == categorical_dtype: + if only_contains((bytes, unicode), choices): + return ArrayPredicate( + term=self, + op=LabelArray.element_of, + opargs=(choices,), + ) + else: + raise TypeError( + "Found non-string in choices for {typename}.element_of.\n" + "Supplied choices were {choices}.".format( + typename=type(self).__name__, + choices=choices, + ) + ) + assert False, "Unknown dtype in Classifier.element_of %s." % self.dtype def postprocess(self, data): if self.dtype == int64_dtype: @@ -314,44 +340,6 @@ class Quantiles(SingleInputMixin, Classifier): return type(self).__name__ + '(%d)' % self.params['bins'] -class StringPredicate(SingleInputMixin, Filter): - """ - A filter applying a function from (LabelArray, hashable) -> ndarray[bool]. - - Examples include ``==, !=, startswith, and is_element``. 
- """ - window_length = 0 - - def __new__(cls, classifier, op, compval): - return super(StringPredicate, cls).__new__( - StringPredicate, - compval=compval, - op=op, - inputs=(classifier,), - mask=classifier.mask, - ) - - def _init(self, op, compval, *args, **kwargs): - self._op = op - self._compval = compval - return super(StringPredicate, self)._init(*args, **kwargs) - - @classmethod - def static_identity(cls, op, compval, *args, **kwargs): - return ( - super(StringPredicate, cls).static_identity(*args, **kwargs), - op, - compval, - ) - - def _compute(self, arrays, dates, assets, mask): - data = arrays[0] - return ( - self._op(data, self._compval) - & mask - ) - - class CustomClassifier(PositiveWindowLengthMixin, StandardOutputs, CustomTermMixin, diff --git a/zipline/pipeline/filters/__init__.py b/zipline/pipeline/filters/__init__.py index 4f05fc6f..be7a63a3 100644 --- a/zipline/pipeline/filters/__init__.py +++ b/zipline/pipeline/filters/__init__.py @@ -1,4 +1,5 @@ from .filter import ( + ArrayPredicate, CustomFilter, Filter, Latest, @@ -8,6 +9,7 @@ from .filter import ( ) __all__ = [ + 'ArrayPredicate', 'CustomFilter', 'Filter', 'Latest', diff --git a/zipline/pipeline/filters/filter.py b/zipline/pipeline/filters/filter.py index 5c3a8783..360667f4 100644 --- a/zipline/pipeline/filters/filter.py +++ b/zipline/pipeline/filters/filter.py @@ -1,14 +1,15 @@ """ filter.py """ +from itertools import chain +from operator import attrgetter + + from numpy import ( float64, nan, nanpercentile, ) -from itertools import chain -from operator import attrgetter - from zipline.errors import ( BadPercentileBounds, UnsupportedDataType, @@ -28,6 +29,7 @@ from zipline.pipeline.expression import ( method_name_for_op, NumericalExpression, ) +from zipline.utils.input_validation import expect_types from zipline.utils.numpy_utils import bool_dtype @@ -372,6 +374,50 @@ class CustomFilter(PositiveWindowLengthMixin, CustomTermMixin, Filter): """ +class ArrayPredicate(SingleInputMixin, 
Filter): + """ + A filter applying a function from (ndarray, *args) -> ndarray[bool]. + + Parameters + ---------- + term : zipline.pipeline.Term + Term producing the array over which the predicate will be computed. + op : function(ndarray, *args) -> ndarray[bool] + Function to apply to the result of `term`. + opargs : tuple[hashable] + Additional arguments to pass to ``op``. + """ + window_length = 0 + + @expect_types(term=Term, opargs=tuple) + def __new__(cls, term, op, opargs): + hash(opargs) # fail fast if opargs isn't hashable. + return super(ArrayPredicate, cls).__new__( + ArrayPredicate, + op=op, + opargs=opargs, + inputs=(term,), + mask=term.mask, + ) + + def _init(self, op, opargs, *args, **kwargs): + self._op = op + self._opargs = opargs + return super(ArrayPredicate, self)._init(*args, **kwargs) + + @classmethod + def static_identity(cls, op, opargs, *args, **kwargs): + return ( + super(ArrayPredicate, cls).static_identity(*args, **kwargs), + op, + opargs, + ) + + def _compute(self, arrays, dates, assets, mask): + data = arrays[0] + return self._op(data, *self._opargs) & mask + + class Latest(LatestMixin, CustomFilter): """ Filter producing the most recently-known value of `inputs[0]` on each day. diff --git a/zipline/utils/numpy_utils.py b/zipline/utils/numpy_utils.py index 63be0aa5..fbadc5c7 100644 --- a/zipline/utils/numpy_utils.py +++ b/zipline/utils/numpy_utils.py @@ -14,6 +14,7 @@ from numpy import ( dtype, empty, nan, + vectorize, where ) from numpy.lib.stride_tricks import as_strided @@ -344,3 +345,21 @@ def ignore_nanwarnings(): {'category': RuntimeWarning, 'module': 'numpy.lib.nanfunctions'}, ) ) + + +def vectorized_is_element(array, choices): + """ + Check if each element of ``array`` is in choices. + + Parameters + ---------- + array : np.ndarray + choices : object + Object implementing __contains__. + + Returns + ------- + was_element : np.ndarray[bool] + Array indicating whether each element of ``array`` was in ``choices``. 
+ """ + return vectorize(choices.__contains__, otypes=[bool])(array)