mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-28 18:27:14 +08:00
TEST: More tests for string predicates.
This commit is contained in:
@@ -256,17 +256,128 @@ class ClassifierTestCase(BasePipelineTestCase):
|
||||
missing_value=missing,
|
||||
)
|
||||
|
||||
terms = {
|
||||
'startswith': c.startswith(compval),
|
||||
'endswith': c.endswith(compval),
|
||||
'has_substring': c.has_substring(compval),
|
||||
# Equivalent filters using regex matching.
|
||||
'startswith_re': c.matches('^' + compval + '.*'),
|
||||
'endswith_re': c.matches('.*' + compval + '$'),
|
||||
'has_substring_re': c.matches('.*' + compval + '.*'),
|
||||
}
|
||||
|
||||
expected = {
|
||||
'startswith': (data.startswith(compval) & (data != missing)),
|
||||
'endswith': (data.endswith(compval) & (data != missing)),
|
||||
'has_substring': (data.has_substring(compval) & (data != missing)),
|
||||
}
|
||||
for key in list(expected):
|
||||
expected[key + '_re'] = expected[key]
|
||||
|
||||
self.check_terms(
|
||||
terms={
|
||||
'startswith': c.startswith(compval),
|
||||
'endswith': c.endswith(compval),
|
||||
'contains': c.contains(compval),
|
||||
},
|
||||
expected={
|
||||
'startswith': (data.startswith(compval) & (data != missing)),
|
||||
'endswith': (data.endswith(compval) & (data != missing)),
|
||||
'contains': (data.contains(compval) & (data != missing)),
|
||||
},
|
||||
terms=terms,
|
||||
expected=expected,
|
||||
initial_workspace={c: data},
|
||||
mask=self.build_mask(self.ones_mask(shape=data.shape)),
|
||||
)
|
||||
|
||||
@parameter_space(
|
||||
__fail_fast=True,
|
||||
container_type=(set, list, tuple, frozenset),
|
||||
labelarray_dtype=(categorical_dtype, bytes_dtype, unicode_dtype),
|
||||
)
|
||||
def test_element_of(self, container_type, labelarray_dtype):
|
||||
|
||||
missing = labelarray_dtype.type("not in the array")
|
||||
|
||||
class C(Classifier):
|
||||
dtype = categorical_dtype
|
||||
missing_value = missing
|
||||
inputs = ()
|
||||
window_length = 0
|
||||
|
||||
c = C()
|
||||
|
||||
raw = np.asarray(
|
||||
[['', 'a', 'ab', 'ba'],
|
||||
['z', 'ab', 'a', 'ab'],
|
||||
['aa', 'ab', '', 'ab'],
|
||||
['aa', 'a', 'ba', 'ba']],
|
||||
dtype=labelarray_dtype,
|
||||
)
|
||||
data = LabelArray(raw, missing_value=missing)
|
||||
|
||||
choices = [
|
||||
container_type(choices) for choices in [
|
||||
[],
|
||||
['a', ''],
|
||||
['a', 'a', 'a', 'ab', 'a'],
|
||||
set(data.reverse_categories) - {missing},
|
||||
['random value', 'ab'],
|
||||
['_' * i for i in range(30)],
|
||||
]
|
||||
]
|
||||
|
||||
def make_expected(choice_set):
|
||||
return np.vectorize(choice_set.__contains__, otypes=[bool])(raw)
|
||||
|
||||
terms = {str(i): c.element_of(s) for i, s in enumerate(choices)}
|
||||
expected = {str(i): make_expected(s) for i, s in enumerate(choices)}
|
||||
|
||||
self.check_terms(
|
||||
terms=terms,
|
||||
expected=expected,
|
||||
initial_workspace={c: data},
|
||||
mask=self.build_mask(self.ones_mask(shape=data.shape)),
|
||||
)
|
||||
|
||||
def test_element_of_rejects_missing_value(self):
|
||||
"""
|
||||
Test that element_of raises a useful error if we attempt to pass it an
|
||||
array of choices that include the classifier's missing_value.
|
||||
"""
|
||||
missing = "not in the array"
|
||||
|
||||
class C(Classifier):
|
||||
dtype = categorical_dtype
|
||||
missing_value = missing
|
||||
inputs = ()
|
||||
window_length = 0
|
||||
|
||||
c = C()
|
||||
|
||||
for bad_elems in ([missing], [missing, 'random other value']):
|
||||
with self.assertRaises(ValueError) as e:
|
||||
c.element_of(bad_elems)
|
||||
errmsg = str(e.exception)
|
||||
expected = (
|
||||
"Found self.missing_value ('not in the array') in choices"
|
||||
" supplied to C.is_element().\n"
|
||||
"Missing values have NaN semantics, so the requested"
|
||||
" comparison would always produce False.\n"
|
||||
"Use the isnull() method to check for missing values.\n"
|
||||
"Received choices were {}.".format(bad_elems)
|
||||
)
|
||||
self.assertEqual(errmsg, expected)
|
||||
|
||||
def test_element_of_rejects_unhashable_type(self):
|
||||
|
||||
class C(Classifier):
|
||||
dtype = categorical_dtype
|
||||
missing_value = ''
|
||||
inputs = ()
|
||||
window_length = 0
|
||||
|
||||
c = C()
|
||||
|
||||
with self.assertRaises(TypeError) as e:
|
||||
c.element_of([{'a': 1}])
|
||||
|
||||
errmsg = str(e.exception)
|
||||
expected = (
|
||||
"Expected `choices` to be an iterable of strings,"
|
||||
" but got [{'a': 1}] instead.\n"
|
||||
"This caused the following error: "
|
||||
"TypeError(\"unhashable type: 'dict'\",)."
|
||||
)
|
||||
self.assertEqual(errmsg, expected)
|
||||
|
||||
@@ -448,7 +448,7 @@ class LabelArray(ndarray):
|
||||
"""
|
||||
return self.apply(lambda elem: elem.endswith(suffix), dtype=bool)
|
||||
|
||||
def contains(self, substring):
|
||||
def has_substring(self, substring):
|
||||
"""
|
||||
Elementwise contains.
|
||||
|
||||
@@ -464,7 +464,7 @@ class LabelArray(ndarray):
|
||||
"""
|
||||
return self.apply(lambda elem: substring in elem, dtype=bool)
|
||||
|
||||
@preprocess(pattern=re.compile)
|
||||
@preprocess(pattern=coerce(from_=(bytes, unicode), to=re.compile))
|
||||
def matches(self, pattern):
|
||||
"""
|
||||
Elementwise regex match.
|
||||
@@ -479,4 +479,25 @@ class LabelArray(ndarray):
|
||||
An array with the same shape as self indicating whether each
|
||||
element of self was matched by ``pattern``.
|
||||
"""
|
||||
return self.apply(compose(bool, pattern.match))
|
||||
return self.apply(compose(bool, pattern.match), dtype=bool)
|
||||
|
||||
# These types all implement an O(N) __contains__, so pre-emptively
|
||||
# coerce to `set`.
|
||||
@preprocess(container=coerce((list, tuple, np.ndarray), set))
|
||||
def element_of(self, container):
|
||||
"""
|
||||
Check if each element of self is an element of ``container``.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
container : object
|
||||
An object implementing a __contains__ to call on each element of
|
||||
``self``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
is_contained : np.ndarray[bool]
|
||||
An array with the same shape as self indicating whether each
|
||||
element of self was an element of ``container``.
|
||||
"""
|
||||
return self.apply(container.__contains__, dtype=bool)
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
"""
|
||||
classifier.py
|
||||
"""
|
||||
from functools import wraps
|
||||
from numbers import Number
|
||||
import operator
|
||||
import re
|
||||
|
||||
from numpy import where, isnan, nan, zeros
|
||||
|
||||
@@ -28,7 +28,7 @@ from ..mixins import (
|
||||
)
|
||||
|
||||
|
||||
strings_only = restrict_to_dtype(
|
||||
string_classifiers_only = restrict_to_dtype(
|
||||
dtype=categorical_dtype,
|
||||
message_template=(
|
||||
"{method_name}() is only defined on Classifiers producing strings"
|
||||
@@ -95,7 +95,7 @@ class Classifier(RestrictedDTypeMixin, ComputableTerm):
|
||||
binds=(self,),
|
||||
)
|
||||
else:
|
||||
return ScalarStringPredicate(
|
||||
return StringPredicate(
|
||||
classifier=self,
|
||||
op=operator.eq,
|
||||
compval=other,
|
||||
@@ -118,47 +118,152 @@ class Classifier(RestrictedDTypeMixin, ComputableTerm):
|
||||
binds=(self,),
|
||||
)
|
||||
else:
|
||||
return ScalarStringPredicate(
|
||||
return StringPredicate(
|
||||
classifier=self,
|
||||
op=operator.ne,
|
||||
compval=other,
|
||||
)
|
||||
|
||||
def _string_predicate(f):
|
||||
@string_classifiers_only
|
||||
@expect_types(prefix=(bytes, unicode))
|
||||
def startswith(self, prefix):
|
||||
"""
|
||||
Decorator for converting a function from (LabelArray, str) -> bool
|
||||
into a Classifier method that returns a ScalarStringPredicate filter.
|
||||
Construct a Filter matching values starting with ``prefix``.
|
||||
|
||||
This mainly exists to avoid replicating shared boilerplate
|
||||
(e.g. argument type validation).
|
||||
Parameters
|
||||
----------
|
||||
prefix : str
|
||||
String prefix against which to compare values produced by ``self``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
matches : Filter
|
||||
Filter returning True for all sid/date pairs for which ``self``
|
||||
produces a string starting with ``prefix``.
|
||||
"""
|
||||
@wraps(f)
|
||||
@expect_types(compval=(bytes, unicode))
|
||||
@strings_only
|
||||
def method(self, compval):
|
||||
return ScalarStringPredicate(
|
||||
classifier=self,
|
||||
op=f,
|
||||
compval=compval,
|
||||
return StringPredicate(
|
||||
classifier=self,
|
||||
op=LabelArray.startswith,
|
||||
compval=prefix,
|
||||
)
|
||||
|
||||
@string_classifiers_only
|
||||
@expect_types(suffix=(bytes, unicode))
|
||||
def endswith(self, suffix):
|
||||
"""
|
||||
Construct a Filter matching values ending with ``suffix``.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
suffix : str
|
||||
String suffix against which to compare values produced by ``self``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
matches : Filter
|
||||
Filter returning True for all sid/date pairs for which ``self``
|
||||
produces a string ending with ``suffix``.
|
||||
"""
|
||||
return StringPredicate(
|
||||
classifier=self,
|
||||
op=LabelArray.endswith,
|
||||
compval=suffix,
|
||||
)
|
||||
|
||||
@string_classifiers_only
|
||||
@expect_types(substring=(bytes, unicode))
|
||||
def has_substring(self, substring):
|
||||
"""
|
||||
Construct a Filter matching values containing ``substring``.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
substring : str
|
||||
Sub-string against which to compare values produced by ``self``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
matches : Filter
|
||||
Filter returning True for all sid/date pairs for which ``self``
|
||||
produces a string containing ``substring``.
|
||||
"""
|
||||
return StringPredicate(
|
||||
classifier=self,
|
||||
op=LabelArray.has_substring,
|
||||
compval=substring,
|
||||
)
|
||||
|
||||
@string_classifiers_only
|
||||
@expect_types(pattern=(bytes, unicode, type(re.compile(''))))
|
||||
def matches(self, pattern):
|
||||
"""
|
||||
Construct a Filter that checks regex matches against ``pattern``.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
pattern : str or compiled regex
|
||||
Regex pattern against which to compare values produced by ``self``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
matches : Filter
|
||||
Filter returning True for all sid/date pairs for which ``self``
|
||||
produces a string matched by ``pattern``.
|
||||
|
||||
See Also
|
||||
--------
|
||||
https://docs.python.org/library/re.html
|
||||
"""
|
||||
return StringPredicate(
|
||||
classifier=self,
|
||||
op=LabelArray.matches,
|
||||
compval=pattern,
|
||||
)
|
||||
|
||||
@string_classifiers_only
|
||||
def element_of(self, choices):
|
||||
"""
|
||||
Construct a Filter indicating whether values are in ``choices``.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
choices : iterable[str]
|
||||
An iterable of choices.
|
||||
|
||||
Returns
|
||||
-------
|
||||
matches : Filter
|
||||
Filter returning True for all sid/date pairs for which ``self``
|
||||
produces a string in ``choices``.
|
||||
"""
|
||||
try:
|
||||
choices = frozenset(choices)
|
||||
except Exception as e:
|
||||
raise TypeError(
|
||||
"Expected `choices` to be an iterable of strings,"
|
||||
" but got {} instead.\n"
|
||||
"This caused the following error: {!r}.".format(choices, e)
|
||||
)
|
||||
return method
|
||||
|
||||
@_string_predicate
|
||||
@expect_types(label_array=LabelArray)
|
||||
def startswith(label_array, other):
|
||||
return label_array.startswith(other)
|
||||
if self.missing_value in choices:
|
||||
raise ValueError(
|
||||
"Found self.missing_value ({mv!r}) in choices supplied to"
|
||||
" {typename}.is_element().\n"
|
||||
"Missing values have NaN semantics, so the"
|
||||
" requested comparison would always produce False.\n"
|
||||
"Use the isnull() method to check for missing values.\n"
|
||||
"Received choices were {choices}.".format(
|
||||
mv=self.missing_value,
|
||||
typename=(type(self).__name__),
|
||||
choices=sorted(choices),
|
||||
)
|
||||
)
|
||||
|
||||
@_string_predicate
|
||||
@expect_types(label_array=LabelArray)
|
||||
def endswith(label_array, other):
|
||||
return label_array.endswith(other)
|
||||
|
||||
@_string_predicate
|
||||
@expect_types(label_array=LabelArray)
|
||||
def contains(label_array, other):
|
||||
return label_array.contains(other)
|
||||
|
||||
del _string_predicate
|
||||
return StringPredicate(
|
||||
classifier=self,
|
||||
op=LabelArray.element_of,
|
||||
compval=choices,
|
||||
)
|
||||
|
||||
def postprocess(self, data):
|
||||
if self.dtype == int64_dtype:
|
||||
@@ -208,22 +313,17 @@ class Quantiles(SingleInputMixin, Classifier):
|
||||
return type(self).__name__ + '(%d)' % self.params['bins']
|
||||
|
||||
|
||||
class ScalarStringPredicate(SingleInputMixin, Filter):
|
||||
class StringPredicate(SingleInputMixin, Filter):
|
||||
"""
|
||||
A filter that applies a function from (LabelArray, str) -> ndarray[bool].
|
||||
A filter applying a function from (LabelArray, hashable) -> ndarray[bool].
|
||||
|
||||
Examples include ``==, !=, startswith, and endswith``.
|
||||
|
||||
This exists because we represent string arrays with
|
||||
``zipline.lib.LabelArray``s, which numexpr doesn't know about, so we can't
|
||||
use the generic NumExprFilter implementation here.
|
||||
Examples include ``==, !=, startswith, and element_of``.
|
||||
"""
|
||||
window_length = 0
|
||||
|
||||
@expect_types(classifier=Classifier, compval=(bytes, unicode))
|
||||
def __new__(cls, classifier, op, compval):
|
||||
return super(ScalarStringPredicate, cls).__new__(
|
||||
ScalarStringPredicate,
|
||||
return super(StringPredicate, cls).__new__(
|
||||
StringPredicate,
|
||||
compval=compval,
|
||||
op=op,
|
||||
inputs=(classifier,),
|
||||
@@ -233,12 +333,12 @@ class ScalarStringPredicate(SingleInputMixin, Filter):
|
||||
def _init(self, op, compval, *args, **kwargs):
|
||||
self._op = op
|
||||
self._compval = compval
|
||||
return super(ScalarStringPredicate, self)._init(*args, **kwargs)
|
||||
return super(StringPredicate, self)._init(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def static_identity(cls, op, compval, *args, **kwargs):
|
||||
return (
|
||||
super(ScalarStringPredicate, cls).static_identity(*args, **kwargs),
|
||||
super(StringPredicate, cls).static_identity(*args, **kwargs),
|
||||
op,
|
||||
compval,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user