mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-05 06:46:10 +08:00
Merge pull request #1195 from quantopian/test-string-groupby
String Classifier Cleanup
This commit is contained in:
@@ -8,7 +8,6 @@ from zipline.pipeline import Classifier
|
||||
from zipline.testing import parameter_space
|
||||
from zipline.utils.numpy_utils import (
|
||||
categorical_dtype,
|
||||
coerce_to_dtype,
|
||||
int64_dtype,
|
||||
)
|
||||
|
||||
@@ -162,7 +161,8 @@ class ClassifierTestCase(BasePipelineTestCase):
|
||||
dtype_=[int64_dtype, categorical_dtype],
|
||||
)
|
||||
def test_disallow_comparison_to_missing_value(self, missing, dtype_):
|
||||
missing = coerce_to_dtype(dtype_, missing)
|
||||
if dtype_ == categorical_dtype:
|
||||
missing = str(missing)
|
||||
|
||||
class C(Classifier):
|
||||
dtype = dtype_
|
||||
@@ -434,7 +434,7 @@ class ClassifierTestCase(BasePipelineTestCase):
|
||||
errmsg = str(e.exception)
|
||||
expected = (
|
||||
"Found self.missing_value ('not in the array') in choices"
|
||||
" supplied to C.is_element().\n"
|
||||
" supplied to C.element_of().\n"
|
||||
"Missing values have NaN semantics, so the requested"
|
||||
" comparison would always produce False.\n"
|
||||
"Use the isnull() method to check for missing values.\n"
|
||||
@@ -447,7 +447,7 @@ class ClassifierTestCase(BasePipelineTestCase):
|
||||
|
||||
class C(Classifier):
|
||||
dtype = dtype_
|
||||
missing_value = ''
|
||||
missing_value = dtype.type('1')
|
||||
inputs = ()
|
||||
window_length = 0
|
||||
|
||||
|
||||
@@ -23,6 +23,7 @@ from numpy import (
|
||||
from numpy.random import randn, seed
|
||||
|
||||
from zipline.errors import UnknownRankMethod
|
||||
from zipline.lib.labelarray import LabelArray
|
||||
from zipline.lib.rank import masked_rankdata_2d
|
||||
from zipline.lib.normalize import naive_grouped_rowwise_apply as grouped_apply
|
||||
from zipline.pipeline import Classifier, Factor, Filter, TermGraph
|
||||
@@ -38,6 +39,7 @@ from zipline.testing import (
|
||||
)
|
||||
from zipline.utils.functional import dzip_exact
|
||||
from zipline.utils.numpy_utils import (
|
||||
categorical_dtype,
|
||||
datetime64ns_dtype,
|
||||
float64_dtype,
|
||||
int64_dtype,
|
||||
@@ -442,6 +444,7 @@ class FactorTestCase(BasePipelineTestCase):
|
||||
f = self.f
|
||||
m = Mask()
|
||||
c = C()
|
||||
str_c = C(dtype=categorical_dtype, missing_value=None)
|
||||
|
||||
factor_data = array(
|
||||
[[1.0, 2.0, 3.0, 4.0],
|
||||
@@ -463,12 +466,18 @@ class FactorTestCase(BasePipelineTestCase):
|
||||
[1, 1, 2, 2]],
|
||||
dtype=int64_dtype,
|
||||
)
|
||||
string_classifier_data = LabelArray(
|
||||
classifier_data.astype(str).astype(object),
|
||||
missing_value=None,
|
||||
)
|
||||
|
||||
terms = {
|
||||
'vanilla': f.demean(),
|
||||
'masked': f.demean(mask=m),
|
||||
'grouped': f.demean(groupby=c),
|
||||
'grouped_str': f.demean(groupby=str_c),
|
||||
'grouped_masked': f.demean(mask=m, groupby=c),
|
||||
'grouped_masked_str': f.demean(mask=m, groupby=str_c),
|
||||
}
|
||||
expected = {
|
||||
'vanilla': array(
|
||||
@@ -496,6 +505,9 @@ class FactorTestCase(BasePipelineTestCase):
|
||||
[-0.500, 0.500, 0.000, nan]]
|
||||
)
|
||||
}
|
||||
# Changing the classifier dtype shouldn't affect anything.
|
||||
expected['grouped_str'] = expected['grouped']
|
||||
expected['grouped_masked_str'] = expected['grouped_masked']
|
||||
|
||||
graph = TermGraph(terms)
|
||||
results = self.run_graph(
|
||||
@@ -503,6 +515,7 @@ class FactorTestCase(BasePipelineTestCase):
|
||||
initial_workspace={
|
||||
f: factor_data,
|
||||
c: classifier_data,
|
||||
str_c: string_classifier_data,
|
||||
m: filter_data,
|
||||
},
|
||||
mask=self.build_mask(self.ones_mask(shape=factor_data.shape)),
|
||||
|
||||
@@ -602,3 +602,26 @@ class SubDataSetTestCase(TestCase):
|
||||
with self.assertRaises(ValueError) as e:
|
||||
SomeClassifier()
|
||||
self.assertEqual(str(e.exception), expected_error)
|
||||
|
||||
def test_unreasonable_missing_values(self):
|
||||
|
||||
for base_type, dtype_, bad_mv in ((Factor, float64_dtype, 'ayy'),
|
||||
(Filter, bool_dtype, 'lmao'),
|
||||
(Classifier, int64_dtype, 'lolwut'),
|
||||
(Classifier, categorical_dtype, 7)):
|
||||
class SomeTerm(base_type):
|
||||
inputs = ()
|
||||
window_length = 0
|
||||
missing_value = bad_mv
|
||||
dtype = dtype_
|
||||
|
||||
with self.assertRaises(TypeError) as e:
|
||||
SomeTerm()
|
||||
|
||||
prefix = (
|
||||
"^Missing value {mv!r} is not a valid choice "
|
||||
"for term SomeTerm with dtype {dtype}.\n\n"
|
||||
"Coercion attempt failed with:"
|
||||
).format(mv=bad_mv, dtype=dtype_)
|
||||
|
||||
self.assertRegexpMatches(str(e.exception), prefix)
|
||||
|
||||
@@ -211,6 +211,10 @@ class LabelArray(ndarray):
|
||||
# This is a property because it should be immutable.
|
||||
return self._missing_value
|
||||
|
||||
@property
|
||||
def missing_value_code(self):
|
||||
return self.reverse_categories[self.missing_value]
|
||||
|
||||
def has_label(self, value):
|
||||
return value in self.reverse_categories
|
||||
|
||||
|
||||
@@ -246,7 +246,7 @@ class Classifier(RestrictedDTypeMixin, ComputableTerm):
|
||||
if self.missing_value in choices:
|
||||
raise ValueError(
|
||||
"Found self.missing_value ({mv!r}) in choices supplied to"
|
||||
" {typename}.is_element().\n"
|
||||
" {typename}.{meth_name}().\n"
|
||||
"Missing values have NaN semantics, so the"
|
||||
" requested comparison would always produce False.\n"
|
||||
"Use the isnull() method to check for missing values.\n"
|
||||
@@ -254,6 +254,7 @@ class Classifier(RestrictedDTypeMixin, ComputableTerm):
|
||||
mv=self.missing_value,
|
||||
typename=(type(self).__name__),
|
||||
choices=sorted(choices),
|
||||
meth_name=self.element_of.__name__,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -47,6 +47,7 @@ from zipline.utils.input_validation import expect_types
|
||||
from zipline.utils.math_utils import nanmean, nanstd
|
||||
from zipline.utils.numpy_utils import (
|
||||
bool_dtype,
|
||||
categorical_dtype,
|
||||
coerce_to_dtype,
|
||||
datetime64ns_dtype,
|
||||
float64_dtype,
|
||||
@@ -939,15 +940,26 @@ class GroupedRowTransform(Factor):
|
||||
|
||||
def _compute(self, arrays, dates, assets, mask):
|
||||
data = arrays[0]
|
||||
null_group_value = self.inputs[1].missing_value
|
||||
group_labels = where(
|
||||
mask,
|
||||
arrays[1],
|
||||
null_group_value,
|
||||
)
|
||||
groupby_expr = self.inputs[1]
|
||||
if groupby_expr.dtype == int64_dtype:
|
||||
group_labels = arrays[1]
|
||||
null_label = self.inputs[1].missing_value
|
||||
elif groupby_expr.dtype == categorical_dtype:
|
||||
# Coerce our LabelArray into an isomorphic array of ints. This is
|
||||
# necessary because np.where doesn't know about LabelArrays or the
|
||||
# void dtype.
|
||||
group_labels = arrays[1].as_int_array()
|
||||
null_label = arrays[1].missing_value_code
|
||||
else:
|
||||
raise TypeError(
|
||||
"Unexpected groupby dtype: %s." % groupby_expr.dtype
|
||||
)
|
||||
|
||||
# Make a copy with the null code written to masked locations.
|
||||
group_labels = where(mask, group_labels, null_label)
|
||||
|
||||
return where(
|
||||
group_labels != null_group_value,
|
||||
group_labels != null_label,
|
||||
naive_grouped_rowwise_apply(
|
||||
data=data,
|
||||
group_labels=group_labels,
|
||||
|
||||
@@ -4,7 +4,7 @@ Base class for Filters, Factors and Classifiers
|
||||
from abc import ABCMeta, abstractproperty
|
||||
from weakref import WeakValueDictionary
|
||||
|
||||
from numpy import dtype as dtype_class, ndarray
|
||||
from numpy import array, dtype as dtype_class, ndarray
|
||||
from six import with_metaclass
|
||||
from zipline.errors import (
|
||||
DTypeNotSpecified,
|
||||
@@ -16,10 +16,12 @@ from zipline.errors import (
|
||||
WindowLengthNotSpecified,
|
||||
)
|
||||
from zipline.lib.adjusted_array import can_represent_dtype
|
||||
from zipline.lib.labelarray import LabelArray
|
||||
from zipline.utils.input_validation import expect_types
|
||||
from zipline.utils.memoize import lazyval
|
||||
from zipline.utils.numpy_utils import (
|
||||
bool_dtype,
|
||||
categorical_dtype,
|
||||
default_missing_value_for_dtype,
|
||||
)
|
||||
from zipline.utils.sentinel import sentinel
|
||||
@@ -177,6 +179,7 @@ class Term(with_metaclass(ABCMeta, object)):
|
||||
"""
|
||||
if dtype is NotSpecified:
|
||||
raise DTypeNotSpecified(termname=termname)
|
||||
|
||||
try:
|
||||
dtype = dtype_class(dtype)
|
||||
except TypeError:
|
||||
@@ -188,6 +191,31 @@ class Term(with_metaclass(ABCMeta, object)):
|
||||
if missing_value is NotSpecified:
|
||||
missing_value = default_missing_value_for_dtype(dtype)
|
||||
|
||||
try:
|
||||
if (dtype == categorical_dtype):
|
||||
# This check is necessary because we use object dtype for
|
||||
# categoricals, and numpy will allow us to promote numerical
|
||||
# values to object even though we don't support them.
|
||||
_assert_valid_categorical_missing_value(missing_value)
|
||||
|
||||
# For any other type, we can check if the missing_value is safe by
|
||||
# making an array of that value and trying to safely convert it to
|
||||
# the desired type.
|
||||
# 'same_kind' allows casting between things like float32 and
|
||||
# float64, but not str and int.
|
||||
array([missing_value]).astype(dtype=dtype, casting='same_kind')
|
||||
except TypeError as e:
|
||||
raise TypeError(
|
||||
"Missing value {value!r} is not a valid choice "
|
||||
"for term {termname} with dtype {dtype}.\n\n"
|
||||
"Coercion attempt failed with: {error}".format(
|
||||
termname=termname,
|
||||
value=missing_value,
|
||||
dtype=dtype,
|
||||
error=e,
|
||||
)
|
||||
)
|
||||
|
||||
return dtype, missing_value
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
@@ -498,3 +526,20 @@ class ComputableTerm(Term):
|
||||
inputs=self.inputs,
|
||||
window_length=self.window_length,
|
||||
)
|
||||
|
||||
|
||||
def _assert_valid_categorical_missing_value(value):
|
||||
"""
|
||||
Check that value is a valid categorical missing_value.
|
||||
|
||||
Raises a TypeError if the value is cannot be used as the missing_value for
|
||||
a categorical_dtype Term.
|
||||
"""
|
||||
label_types = LabelArray.SUPPORTED_SCALAR_TYPES
|
||||
if not isinstance(value, label_types):
|
||||
raise TypeError(
|
||||
"Categorical terms must have missing values of type "
|
||||
"{types}.".format(
|
||||
types=' or '.join([t.__name__ for t in label_types]),
|
||||
)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user