From c40bbfae03026ffed3f7c8f0baf369b64f2acd4c Mon Sep 17 00:00:00 2001 From: Scott Sanderson Date: Mon, 2 May 2016 15:14:01 -0400 Subject: [PATCH] TEST: More tests for string predicates. --- tests/pipeline/test_classifier.py | 131 ++++++++++++-- zipline/lib/labelarray.py | 27 ++- zipline/pipeline/classifiers/classifier.py | 190 ++++++++++++++++----- 3 files changed, 290 insertions(+), 58 deletions(-) diff --git a/tests/pipeline/test_classifier.py b/tests/pipeline/test_classifier.py index 77f76079..1fa1b9b9 100644 --- a/tests/pipeline/test_classifier.py +++ b/tests/pipeline/test_classifier.py @@ -256,17 +256,128 @@ class ClassifierTestCase(BasePipelineTestCase): missing_value=missing, ) + terms = { + 'startswith': c.startswith(compval), + 'endswith': c.endswith(compval), + 'has_substring': c.has_substring(compval), + # Equivalent filters using regex matching. + 'startswith_re': c.matches('^' + compval + '.*'), + 'endswith_re': c.matches('.*' + compval + '$'), + 'has_substring_re': c.matches('.*' + compval + '.*'), + } + + expected = { + 'startswith': (data.startswith(compval) & (data != missing)), + 'endswith': (data.endswith(compval) & (data != missing)), + 'has_substring': (data.has_substring(compval) & (data != missing)), + } + for key in list(expected): + expected[key + '_re'] = expected[key] + self.check_terms( - terms={ - 'startswith': c.startswith(compval), - 'endswith': c.endswith(compval), - 'contains': c.contains(compval), - }, - expected={ - 'startswith': (data.startswith(compval) & (data != missing)), - 'endswith': (data.endswith(compval) & (data != missing)), - 'contains': (data.contains(compval) & (data != missing)), - }, + terms=terms, + expected=expected, initial_workspace={c: data}, mask=self.build_mask(self.ones_mask(shape=data.shape)), ) + + @parameter_space( + __fail_fast=True, + container_type=(set, list, tuple, frozenset), + labelarray_dtype=(categorical_dtype, bytes_dtype, unicode_dtype), + ) + def test_element_of(self, container_type, labelarray_dtype): + + missing = labelarray_dtype.type("not in the array") + + class C(Classifier): + dtype = categorical_dtype + missing_value = missing + inputs = () + window_length = 0 + + c = C() + + raw = np.asarray( + [['', 'a', 'ab', 'ba'], + ['z', 'ab', 'a', 'ab'], + ['aa', 'ab', '', 'ab'], + ['aa', 'a', 'ba', 'ba']], + dtype=labelarray_dtype, + ) + data = LabelArray(raw, missing_value=missing) + + choices = [ + container_type(choices) for choices in [ + [], + ['a', ''], + ['a', 'a', 'a', 'ab', 'a'], + set(data.reverse_categories) - {missing}, + ['random value', 'ab'], + ['_' * i for i in range(30)], + ] + ] + + def make_expected(choice_set): + return np.vectorize(choice_set.__contains__, otypes=[bool])(raw) + + terms = {str(i): c.element_of(s) for i, s in enumerate(choices)} + expected = {str(i): make_expected(s) for i, s in enumerate(choices)} + + self.check_terms( + terms=terms, + expected=expected, + initial_workspace={c: data}, + mask=self.build_mask(self.ones_mask(shape=data.shape)), + ) + + def test_element_of_rejects_missing_value(self): + """ + Test that element_of raises a useful error if we attempt to pass it an + array of choices that include the classifier's missing_value. + """ + missing = "not in the array" + + class C(Classifier): + dtype = categorical_dtype + missing_value = missing + inputs = () + window_length = 0 + + c = C() + + for bad_elems in ([missing], [missing, 'random other value']): + with self.assertRaises(ValueError) as e: + c.element_of(bad_elems) + errmsg = str(e.exception) + expected = ( + "Found self.missing_value ('not in the array') in choices" + " supplied to C.is_element().\n" + "Missing values have NaN semantics, so the requested" + " comparison would always produce False.\n" + "Use the isnull() method to check for missing values.\n" + "Received choices were {}.".format(bad_elems) + ) + self.assertEqual(errmsg, expected) + + def test_element_of_rejects_unhashable_type(self): + + class C(Classifier): + dtype = categorical_dtype + missing_value = '' + inputs = () + window_length = 0 + + c = C() + + with self.assertRaises(TypeError) as e: + c.element_of([{'a': 1}]) + + errmsg = str(e.exception) + expected = ( + "Expected `choices` to be an iterable of strings," + " but got [{'a': 1}] instead.\n" + "This caused the following error: " + "TypeError(\"unhashable type: 'dict'\",)." + ) + self.assertEqual(errmsg, expected) diff --git a/zipline/lib/labelarray.py b/zipline/lib/labelarray.py index 837a39c2..e5910a89 100644 --- a/zipline/lib/labelarray.py +++ b/zipline/lib/labelarray.py @@ -448,7 +448,7 @@ class LabelArray(ndarray): """ return self.apply(lambda elem: elem.endswith(suffix), dtype=bool) - def contains(self, substring): + def has_substring(self, substring): """ Elementwise contains. @@ -464,7 +464,7 @@ class LabelArray(ndarray): """ return self.apply(lambda elem: substring in elem, dtype=bool) - @preprocess(pattern=re.compile) + @preprocess(pattern=coerce(from_=(bytes, unicode), to=re.compile)) def matches(self, pattern): """ Elementwise regex match. @@ -479,4 +479,25 @@ class LabelArray(ndarray): An array with the same shape as self indicating whether each element of self was matched by ``pattern``. """ - return self.apply(compose(bool, pattern.match)) + return self.apply(compose(bool, pattern.match), dtype=bool) + + # These types all implement an O(N) __contains__, so pre-emptively + # coerce to `set`. + @preprocess(container=coerce((list, tuple, np.ndarray), set)) + def element_of(self, container): + """ + Check if each element of self is an of ``container``. + + Parameters + ---------- + container : object + An object implementing a __contains__ to call on each element of + ``self``. + + Returns + ------- + is_contained : np.ndarray[bool] + An array with the same shape as self indicating whether each + element of self was an element of ``container``. + """ + return self.apply(container.__contains__, dtype=bool) diff --git a/zipline/pipeline/classifiers/classifier.py b/zipline/pipeline/classifiers/classifier.py index e27016eb..332fc206 100644 --- a/zipline/pipeline/classifiers/classifier.py +++ b/zipline/pipeline/classifiers/classifier.py @@ -1,9 +1,9 @@ """ classifier.py """ -from functools import wraps from numbers import Number import operator +import re from numpy import where, isnan, nan, zeros @@ -28,7 +28,7 @@ from ..mixins import ( ) -strings_only = restrict_to_dtype( +string_classifiers_only = restrict_to_dtype( dtype=categorical_dtype, message_template=( "{method_name}() is only defined on Classifiers producing strings" @@ -95,7 +95,7 @@ class Classifier(RestrictedDTypeMixin, ComputableTerm): binds=(self,), ) else: - return ScalarStringPredicate( + return StringPredicate( classifier=self, op=operator.eq, compval=other, @@ -118,47 +118,152 @@ class Classifier(RestrictedDTypeMixin, ComputableTerm): binds=(self,), ) else: - return ScalarStringPredicate( + return StringPredicate( classifier=self, op=operator.ne, compval=other, ) - def _string_predicate(f): + @string_classifiers_only + @expect_types(prefix=(bytes, unicode)) + def startswith(self, prefix): """ - Decorator for converting a function from (LabelArray, str) -> bool - into a Classifier method that returns a ScalarStringPredicate filter. + Construct a Filter matching values starting with ``prefix``. - This mainly exists to avoid replicating shared boilerplate - (e.g. argument type validation). + Parameters + ---------- + prefix : str + String prefix against which to compare values produced by ``self``. + + Returns + ------- + matches : Filter + Filter returning True for all sid/date pairs for which ``self`` + produces a string starting with ``prefix``. """ - @wraps(f) - @expect_types(compval=(bytes, unicode)) - @strings_only - def method(self, compval): - return ScalarStringPredicate( - classifier=self, - op=f, - compval=compval, + return StringPredicate( + classifier=self, + op=LabelArray.startswith, + compval=prefix, + ) + + @string_classifiers_only + @expect_types(suffix=(bytes, unicode)) + def endswith(self, suffix): + """ + Construct a Filter matching values ending with ``suffix``. + + Parameters + ---------- + suffix : str + String suffix against which to compare values produced by ``self``. + + Returns + ------- + matches : Filter + Filter returning True for all sid/date pairs for which ``self`` + produces a string ending with ``prefix``. + """ + return StringPredicate( + classifier=self, + op=LabelArray.endswith, + compval=suffix, + ) + + @string_classifiers_only + @expect_types(substring=(bytes, unicode)) + def has_substring(self, substring): + """ + Construct a Filter matching values containing ``substring``. + + Parameters + ---------- + substring : str + Sub-string against which to compare values produced by ``self``. + + Returns + ------- + matches : Filter + Filter returning True for all sid/date pairs for which ``self`` + produces a string containing ``substring``. + """ + return StringPredicate( + classifier=self, + op=LabelArray.has_substring, + compval=substring, + ) + + @string_classifiers_only + @expect_types(pattern=(bytes, unicode, type(re.compile('')))) + def matches(self, pattern): + """ + Construct a Filter that checks regex matches against ``pattern``. + + Parameters + ---------- + pattern : str + Regex pattern against which to compare values produced by ``self``. + + Returns + ------- + matches : Filter + Filter returning True for all sid/date pairs for which ``self`` + produces a string matched by ``pattern``. + + See Also + -------- + https://docs.python.org/library/re.html + """ + return StringPredicate( + classifier=self, + op=LabelArray.matches, + compval=pattern, + ) + + @string_classifiers_only + def element_of(self, choices): + """ + Construct a Filter indicating whether values are in ``choices``. + + Parameters + ---------- + choices : iterable[str] + An iterable of choices. + + Returns + ------- + matches : Filter + Filter returning True for all sid/date pairs for which ``self`` + produces a string in ``choices``. + """ + try: + choices = frozenset(choices) + except Exception as e: + raise TypeError( + "Expected `choices` to be an iterable of strings," + " but got {} instead.\n" + "This caused the following error: {!r}.".format(choices, e) ) - return method - @_string_predicate - @expect_types(label_array=LabelArray) - def startswith(label_array, other): - return label_array.startswith(other) + if self.missing_value in choices: + raise ValueError( + "Found self.missing_value ({mv!r}) in choices supplied to" + " {typename}.is_element().\n" + "Missing values have NaN semantics, so the" + " requested comparison would always produce False.\n" + "Use the isnull() method to check for missing values.\n" + "Received choices were {choices}.".format( + mv=self.missing_value, + typename=(type(self).__name__), + choices=sorted(choices), + ) + ) - @_string_predicate - @expect_types(label_array=LabelArray) - def endswith(label_array, other): - return label_array.endswith(other) - - @_string_predicate - @expect_types(label_array=LabelArray) - def contains(label_array, other): - return label_array.contains(other) - - del _string_predicate + return StringPredicate( + classifier=self, + op=LabelArray.element_of, + compval=choices, + ) def postprocess(self, data): if self.dtype == int64_dtype: @@ -208,22 +313,17 @@ class Quantiles(SingleInputMixin, Classifier): return type(self).__name__ + '(%d)' % self.params['bins'] -class ScalarStringPredicate(SingleInputMixin, Filter): +class StringPredicate(SingleInputMixin, Filter): """ - A filter that applies a function from (LabelArray, str) -> ndarray[bool]. + A filter applying a function from (LabelArray, hashable) -> ndarray[bool]. - Examples include ``==, !=, startswith, and endswith``. - - This exists because we represent string arrays with - ``zipline.lib.LabelArray``s, which numexpr doesn't know about, so we can't - use the generic NumExprFilter implementation here. + Examples include ``==, !=, startswith, and is_element``. """ window_length = 0 - @expect_types(classifier=Classifier, compval=(bytes, unicode)) def __new__(cls, classifier, op, compval): - return super(ScalarStringPredicate, cls).__new__( - ScalarStringPredicate, + return super(StringPredicate, cls).__new__( + StringPredicate, compval=compval, op=op, inputs=(classifier,), @@ -233,12 +333,12 @@ class ScalarStringPredicate(SingleInputMixin, Filter): def _init(self, op, compval, *args, **kwargs): self._op = op self._compval = compval - return super(ScalarStringPredicate, self)._init(*args, **kwargs) + return super(StringPredicate, self)._init(*args, **kwargs) @classmethod def static_identity(cls, op, compval, *args, **kwargs): return ( - super(ScalarStringPredicate, cls).static_identity(*args, **kwargs), + super(StringPredicate, cls).static_identity(*args, **kwargs), op, compval, )