TEST: More tests for string predicates.

This commit is contained in:
Scott Sanderson
2016-05-02 15:14:01 -04:00
parent bb6f908036
commit c40bbfae03
3 changed files with 290 additions and 58 deletions
+121 -10
View File
@@ -256,17 +256,128 @@ class ClassifierTestCase(BasePipelineTestCase):
missing_value=missing,
)
terms = {
'startswith': c.startswith(compval),
'endswith': c.endswith(compval),
'has_substring': c.has_substring(compval),
# Equivalent filters using regex matching.
'startswith_re': c.matches('^' + compval + '.*'),
'endswith_re': c.matches('.*' + compval + '$'),
'has_substring_re': c.matches('.*' + compval + '.*'),
}
expected = {
'startswith': (data.startswith(compval) & (data != missing)),
'endswith': (data.endswith(compval) & (data != missing)),
'has_substring': (data.has_substring(compval) & (data != missing)),
}
for key in list(expected):
expected[key + '_re'] = expected[key]
self.check_terms(
terms={
'startswith': c.startswith(compval),
'endswith': c.endswith(compval),
'contains': c.contains(compval),
},
expected={
'startswith': (data.startswith(compval) & (data != missing)),
'endswith': (data.endswith(compval) & (data != missing)),
'contains': (data.contains(compval) & (data != missing)),
},
terms=terms,
expected=expected,
initial_workspace={c: data},
mask=self.build_mask(self.ones_mask(shape=data.shape)),
)
@parameter_space(
__fail_fast=True,
container_type=(set, list, tuple, frozenset),
labelarray_dtype=(categorical_dtype, bytes_dtype, unicode_dtype),
)
def test_element_of(self, container_type, labelarray_dtype):
missing = labelarray_dtype.type("not in the array")
class C(Classifier):
dtype = categorical_dtype
missing_value = missing
inputs = ()
window_length = 0
c = C()
raw = np.asarray(
[['', 'a', 'ab', 'ba'],
['z', 'ab', 'a', 'ab'],
['aa', 'ab', '', 'ab'],
['aa', 'a', 'ba', 'ba']],
dtype=labelarray_dtype,
)
data = LabelArray(raw, missing_value=missing)
choices = [
container_type(choices) for choices in [
[],
['a', ''],
['a', 'a', 'a', 'ab', 'a'],
set(data.reverse_categories) - {missing},
['random value', 'ab'],
['_' * i for i in range(30)],
]
]
def make_expected(choice_set):
return np.vectorize(choice_set.__contains__, otypes=[bool])(raw)
terms = {str(i): c.element_of(s) for i, s in enumerate(choices)}
expected = {str(i): make_expected(s) for i, s in enumerate(choices)}
self.check_terms(
terms=terms,
expected=expected,
initial_workspace={c: data},
mask=self.build_mask(self.ones_mask(shape=data.shape)),
)
def test_element_of_rejects_missing_value(self):
"""
Test that element_of raises a useful error if we attempt to pass it an
array of choices that include the classifier's missing_value.
"""
missing = "not in the array"
class C(Classifier):
dtype = categorical_dtype
missing_value = missing
inputs = ()
window_length = 0
c = C()
for bad_elems in ([missing], [missing, 'random other value']):
with self.assertRaises(ValueError) as e:
c.element_of(bad_elems)
errmsg = str(e.exception)
expected = (
"Found self.missing_value ('not in the array') in choices"
" supplied to C.is_element().\n"
"Missing values have NaN semantics, so the requested"
" comparison would always produce False.\n"
"Use the isnull() method to check for missing values.\n"
"Received choices were {}.".format(bad_elems)
)
self.assertEqual(errmsg, expected)
def test_element_of_rejects_unhashable_type(self):
class C(Classifier):
dtype = categorical_dtype
missing_value = ''
inputs = ()
window_length = 0
c = C()
with self.assertRaises(TypeError) as e:
c.element_of([{'a': 1}])
errmsg = str(e.exception)
expected = (
"Expected `choices` to be an iterable of strings,"
" but got [{'a': 1}] instead.\n"
"This caused the following error: "
"TypeError(\"unhashable type: 'dict'\",)."
)
self.assertEqual(errmsg, expected)
+24 -3
View File
@@ -448,7 +448,7 @@ class LabelArray(ndarray):
"""
return self.apply(lambda elem: elem.endswith(suffix), dtype=bool)
def contains(self, substring):
def has_substring(self, substring):
"""
Elementwise contains.
@@ -464,7 +464,7 @@ class LabelArray(ndarray):
"""
return self.apply(lambda elem: substring in elem, dtype=bool)
@preprocess(pattern=re.compile)
@preprocess(pattern=coerce(from_=(bytes, unicode), to=re.compile))
def matches(self, pattern):
"""
Elementwise regex match.
@@ -479,4 +479,25 @@ class LabelArray(ndarray):
An array with the same shape as self indicating whether each
element of self was matched by ``pattern``.
"""
return self.apply(compose(bool, pattern.match))
return self.apply(compose(bool, pattern.match), dtype=bool)
# These types all implement an O(N) __contains__, so pre-emptively
# coerce to `set`.
@preprocess(container=coerce((list, tuple, np.ndarray), set))
def element_of(self, container):
"""
Check if each element of self is an of ``container``.
Parameters
----------
container : object
An object implementing a __contains__ to call on each element of
``self``.
Returns
-------
is_contained : np.ndarray[bool]
An array with the same shape as self indicating whether each
element of self was an element of ``container``.
"""
return self.apply(container.__contains__, dtype=bool)
+145 -45
View File
@@ -1,9 +1,9 @@
"""
classifier.py
"""
from functools import wraps
from numbers import Number
import operator
import re
from numpy import where, isnan, nan, zeros
@@ -28,7 +28,7 @@ from ..mixins import (
)
strings_only = restrict_to_dtype(
string_classifiers_only = restrict_to_dtype(
dtype=categorical_dtype,
message_template=(
"{method_name}() is only defined on Classifiers producing strings"
@@ -95,7 +95,7 @@ class Classifier(RestrictedDTypeMixin, ComputableTerm):
binds=(self,),
)
else:
return ScalarStringPredicate(
return StringPredicate(
classifier=self,
op=operator.eq,
compval=other,
@@ -118,47 +118,152 @@ class Classifier(RestrictedDTypeMixin, ComputableTerm):
binds=(self,),
)
else:
return ScalarStringPredicate(
return StringPredicate(
classifier=self,
op=operator.ne,
compval=other,
)
def _string_predicate(f):
@string_classifiers_only
@expect_types(prefix=(bytes, unicode))
def startswith(self, prefix):
"""
Decorator for converting a function from (LabelArray, str) -> bool
into a Classifier method that returns a ScalarStringPredicate filter.
Construct a Filter matching values starting with ``prefix``.
This mainly exists to avoid replicating shared boilerplate
(e.g. argument type validation).
Parameters
----------
prefix : str
String prefix against which to compare values produced by ``self``.
Returns
-------
matches : Filter
Filter returning True for all sid/date pairs for which ``self``
produces a string starting with ``prefix``.
"""
@wraps(f)
@expect_types(compval=(bytes, unicode))
@strings_only
def method(self, compval):
return ScalarStringPredicate(
classifier=self,
op=f,
compval=compval,
return StringPredicate(
classifier=self,
op=LabelArray.startswith,
compval=prefix,
)
@string_classifiers_only
@expect_types(suffix=(bytes, unicode))
def endswith(self, suffix):
"""
Construct a Filter matching values ending with ``suffix``.
Parameters
----------
suffix : str
String suffix against which to compare values produced by ``self``.
Returns
-------
matches : Filter
Filter returning True for all sid/date pairs for which ``self``
produces a string ending with ``prefix``.
"""
return StringPredicate(
classifier=self,
op=LabelArray.endswith,
compval=suffix,
)
@string_classifiers_only
@expect_types(substring=(bytes, unicode))
def has_substring(self, substring):
"""
Construct a Filter matching values containing ``substring``.
Parameters
----------
substring : str
Sub-string against which to compare values produced by ``self``.
Returns
-------
matches : Filter
Filter returning True for all sid/date pairs for which ``self``
produces a string containing ``substring``.
"""
return StringPredicate(
classifier=self,
op=LabelArray.has_substring,
compval=substring,
)
@string_classifiers_only
@expect_types(pattern=(bytes, unicode, type(re.compile(''))))
def matches(self, pattern):
"""
Construct a Filter that checks regex matches against ``pattern``.
Parameters
----------
pattern : str
Regex pattern against which to compare values produced by ``self``.
Returns
-------
matches : Filter
Filter returning True for all sid/date pairs for which ``self``
produces a string matched by ``pattern``.
See Also
--------
https://docs.python.org/library/re.html
"""
return StringPredicate(
classifier=self,
op=LabelArray.matches,
compval=pattern,
)
@string_classifiers_only
def element_of(self, choices):
"""
Construct a Filter indicating whether values are in ``choices``.
Parameters
----------
choices : iterable[str]
An iterable of choices.
Returns
-------
matches : Filter
Filter returning True for all sid/date pairs for which ``self``
produces a string in ``choices``.
"""
try:
choices = frozenset(choices)
except Exception as e:
raise TypeError(
"Expected `choices` to be an iterable of strings,"
" but got {} instead.\n"
"This caused the following error: {!r}.".format(choices, e)
)
return method
@_string_predicate
@expect_types(label_array=LabelArray)
def startswith(label_array, other):
return label_array.startswith(other)
if self.missing_value in choices:
raise ValueError(
"Found self.missing_value ({mv!r}) in choices supplied to"
" {typename}.is_element().\n"
"Missing values have NaN semantics, so the"
" requested comparison would always produce False.\n"
"Use the isnull() method to check for missing values.\n"
"Received choices were {choices}.".format(
mv=self.missing_value,
typename=(type(self).__name__),
choices=sorted(choices),
)
)
@_string_predicate
@expect_types(label_array=LabelArray)
def endswith(label_array, other):
return label_array.endswith(other)
@_string_predicate
@expect_types(label_array=LabelArray)
def contains(label_array, other):
return label_array.contains(other)
del _string_predicate
return StringPredicate(
classifier=self,
op=LabelArray.element_of,
compval=choices,
)
def postprocess(self, data):
if self.dtype == int64_dtype:
@@ -208,22 +313,17 @@ class Quantiles(SingleInputMixin, Classifier):
return type(self).__name__ + '(%d)' % self.params['bins']
class ScalarStringPredicate(SingleInputMixin, Filter):
class StringPredicate(SingleInputMixin, Filter):
"""
A filter that applies a function from (LabelArray, str) -> ndarray[bool].
A filter applying a function from (LabelArray, hashable) -> ndarray[bool].
Examples include ``==, !=, startswith, and endswith``.
This exists because we represent string arrays with
``zipline.lib.LabelArray``s, which numexpr doesn't know about, so we can't
use the generic NumExprFilter implementation here.
Examples include ``==, !=, startswith, and is_element``.
"""
window_length = 0
@expect_types(classifier=Classifier, compval=(bytes, unicode))
def __new__(cls, classifier, op, compval):
return super(ScalarStringPredicate, cls).__new__(
ScalarStringPredicate,
return super(StringPredicate, cls).__new__(
StringPredicate,
compval=compval,
op=op,
inputs=(classifier,),
@@ -233,12 +333,12 @@ class ScalarStringPredicate(SingleInputMixin, Filter):
def _init(self, op, compval, *args, **kwargs):
self._op = op
self._compval = compval
return super(ScalarStringPredicate, self)._init(*args, **kwargs)
return super(StringPredicate, self)._init(*args, **kwargs)
@classmethod
def static_identity(cls, op, compval, *args, **kwargs):
return (
super(ScalarStringPredicate, cls).static_identity(*args, **kwargs),
super(StringPredicate, cls).static_identity(*args, **kwargs),
op,
compval,
)