Merge pull request #1195 from quantopian/test-string-groupby

String Classifier Cleanup
This commit is contained in:
Scott Sanderson
2016-05-10 20:54:24 -04:00
7 changed files with 111 additions and 13 deletions
+4 -4
View File
@@ -8,7 +8,6 @@ from zipline.pipeline import Classifier
from zipline.testing import parameter_space
from zipline.utils.numpy_utils import (
categorical_dtype,
coerce_to_dtype,
int64_dtype,
)
@@ -162,7 +161,8 @@ class ClassifierTestCase(BasePipelineTestCase):
dtype_=[int64_dtype, categorical_dtype],
)
def test_disallow_comparison_to_missing_value(self, missing, dtype_):
missing = coerce_to_dtype(dtype_, missing)
if dtype_ == categorical_dtype:
missing = str(missing)
class C(Classifier):
dtype = dtype_
@@ -434,7 +434,7 @@ class ClassifierTestCase(BasePipelineTestCase):
errmsg = str(e.exception)
expected = (
"Found self.missing_value ('not in the array') in choices"
" supplied to C.is_element().\n"
" supplied to C.element_of().\n"
"Missing values have NaN semantics, so the requested"
" comparison would always produce False.\n"
"Use the isnull() method to check for missing values.\n"
@@ -447,7 +447,7 @@ class ClassifierTestCase(BasePipelineTestCase):
class C(Classifier):
dtype = dtype_
missing_value = ''
missing_value = dtype.type('1')
inputs = ()
window_length = 0
+13
View File
@@ -23,6 +23,7 @@ from numpy import (
from numpy.random import randn, seed
from zipline.errors import UnknownRankMethod
from zipline.lib.labelarray import LabelArray
from zipline.lib.rank import masked_rankdata_2d
from zipline.lib.normalize import naive_grouped_rowwise_apply as grouped_apply
from zipline.pipeline import Classifier, Factor, Filter, TermGraph
@@ -38,6 +39,7 @@ from zipline.testing import (
)
from zipline.utils.functional import dzip_exact
from zipline.utils.numpy_utils import (
categorical_dtype,
datetime64ns_dtype,
float64_dtype,
int64_dtype,
@@ -442,6 +444,7 @@ class FactorTestCase(BasePipelineTestCase):
f = self.f
m = Mask()
c = C()
str_c = C(dtype=categorical_dtype, missing_value=None)
factor_data = array(
[[1.0, 2.0, 3.0, 4.0],
@@ -463,12 +466,18 @@ class FactorTestCase(BasePipelineTestCase):
[1, 1, 2, 2]],
dtype=int64_dtype,
)
string_classifier_data = LabelArray(
classifier_data.astype(str).astype(object),
missing_value=None,
)
terms = {
'vanilla': f.demean(),
'masked': f.demean(mask=m),
'grouped': f.demean(groupby=c),
'grouped_str': f.demean(groupby=str_c),
'grouped_masked': f.demean(mask=m, groupby=c),
'grouped_masked_str': f.demean(mask=m, groupby=str_c),
}
expected = {
'vanilla': array(
@@ -496,6 +505,9 @@ class FactorTestCase(BasePipelineTestCase):
[-0.500, 0.500, 0.000, nan]]
)
}
# Changing the classifier dtype shouldn't affect anything.
expected['grouped_str'] = expected['grouped']
expected['grouped_masked_str'] = expected['grouped_masked']
graph = TermGraph(terms)
results = self.run_graph(
@@ -503,6 +515,7 @@ class FactorTestCase(BasePipelineTestCase):
initial_workspace={
f: factor_data,
c: classifier_data,
str_c: string_classifier_data,
m: filter_data,
},
mask=self.build_mask(self.ones_mask(shape=factor_data.shape)),
+23
View File
@@ -602,3 +602,26 @@ class SubDataSetTestCase(TestCase):
with self.assertRaises(ValueError) as e:
SomeClassifier()
self.assertEqual(str(e.exception), expected_error)
def test_unreasonable_missing_values(self):
for base_type, dtype_, bad_mv in ((Factor, float64_dtype, 'ayy'),
(Filter, bool_dtype, 'lmao'),
(Classifier, int64_dtype, 'lolwut'),
(Classifier, categorical_dtype, 7)):
class SomeTerm(base_type):
inputs = ()
window_length = 0
missing_value = bad_mv
dtype = dtype_
with self.assertRaises(TypeError) as e:
SomeTerm()
prefix = (
"^Missing value {mv!r} is not a valid choice "
"for term SomeTerm with dtype {dtype}.\n\n"
"Coercion attempt failed with:"
).format(mv=bad_mv, dtype=dtype_)
self.assertRegexpMatches(str(e.exception), prefix)
+4
View File
@@ -211,6 +211,10 @@ class LabelArray(ndarray):
# This is a property because it should be immutable.
return self._missing_value
@property
def missing_value_code(self):
return self.reverse_categories[self.missing_value]
def has_label(self, value):
return value in self.reverse_categories
+2 -1
View File
@@ -246,7 +246,7 @@ class Classifier(RestrictedDTypeMixin, ComputableTerm):
if self.missing_value in choices:
raise ValueError(
"Found self.missing_value ({mv!r}) in choices supplied to"
" {typename}.is_element().\n"
" {typename}.{meth_name}().\n"
"Missing values have NaN semantics, so the"
" requested comparison would always produce False.\n"
"Use the isnull() method to check for missing values.\n"
@@ -254,6 +254,7 @@ class Classifier(RestrictedDTypeMixin, ComputableTerm):
mv=self.missing_value,
typename=(type(self).__name__),
choices=sorted(choices),
meth_name=self.element_of.__name__,
)
)
+19 -7
View File
@@ -47,6 +47,7 @@ from zipline.utils.input_validation import expect_types
from zipline.utils.math_utils import nanmean, nanstd
from zipline.utils.numpy_utils import (
bool_dtype,
categorical_dtype,
coerce_to_dtype,
datetime64ns_dtype,
float64_dtype,
@@ -939,15 +940,26 @@ class GroupedRowTransform(Factor):
def _compute(self, arrays, dates, assets, mask):
data = arrays[0]
null_group_value = self.inputs[1].missing_value
group_labels = where(
mask,
arrays[1],
null_group_value,
)
groupby_expr = self.inputs[1]
if groupby_expr.dtype == int64_dtype:
group_labels = arrays[1]
null_label = self.inputs[1].missing_value
elif groupby_expr.dtype == categorical_dtype:
# Coerce our LabelArray into an isomorphic array of ints. This is
# necessary because np.where doesn't know about LabelArrays or the
# void dtype.
group_labels = arrays[1].as_int_array()
null_label = arrays[1].missing_value_code
else:
raise TypeError(
"Unexpected groupby dtype: %s." % groupby_expr.dtype
)
# Make a copy with the null code written to masked locations.
group_labels = where(mask, group_labels, null_label)
return where(
group_labels != null_group_value,
group_labels != null_label,
naive_grouped_rowwise_apply(
data=data,
group_labels=group_labels,
+46 -1
View File
@@ -4,7 +4,7 @@ Base class for Filters, Factors and Classifiers
from abc import ABCMeta, abstractproperty
from weakref import WeakValueDictionary
from numpy import dtype as dtype_class, ndarray
from numpy import array, dtype as dtype_class, ndarray
from six import with_metaclass
from zipline.errors import (
DTypeNotSpecified,
@@ -16,10 +16,12 @@ from zipline.errors import (
WindowLengthNotSpecified,
)
from zipline.lib.adjusted_array import can_represent_dtype
from zipline.lib.labelarray import LabelArray
from zipline.utils.input_validation import expect_types
from zipline.utils.memoize import lazyval
from zipline.utils.numpy_utils import (
bool_dtype,
categorical_dtype,
default_missing_value_for_dtype,
)
from zipline.utils.sentinel import sentinel
@@ -177,6 +179,7 @@ class Term(with_metaclass(ABCMeta, object)):
"""
if dtype is NotSpecified:
raise DTypeNotSpecified(termname=termname)
try:
dtype = dtype_class(dtype)
except TypeError:
@@ -188,6 +191,31 @@ class Term(with_metaclass(ABCMeta, object)):
if missing_value is NotSpecified:
missing_value = default_missing_value_for_dtype(dtype)
try:
if (dtype == categorical_dtype):
# This check is necessary because we use object dtype for
# categoricals, and numpy will allow us to promote numerical
# values to object even though we don't support them.
_assert_valid_categorical_missing_value(missing_value)
# For any other type, we can check if the missing_value is safe by
# making an array of that value and trying to safely convert it to
# the desired type.
# 'same_kind' allows casting between things like float32 and
# float64, but not str and int.
array([missing_value]).astype(dtype=dtype, casting='same_kind')
except TypeError as e:
raise TypeError(
"Missing value {value!r} is not a valid choice "
"for term {termname} with dtype {dtype}.\n\n"
"Coercion attempt failed with: {error}".format(
termname=termname,
value=missing_value,
dtype=dtype,
error=e,
)
)
return dtype, missing_value
def __init__(self, *args, **kwargs):
@@ -498,3 +526,20 @@ class ComputableTerm(Term):
inputs=self.inputs,
window_length=self.window_length,
)
def _assert_valid_categorical_missing_value(value):
"""
Check that value is a valid categorical missing_value.
Raises a TypeError if the value is cannot be used as the missing_value for
a categorical_dtype Term.
"""
label_types = LabelArray.SUPPORTED_SCALAR_TYPES
if not isinstance(value, label_types):
raise TypeError(
"Categorical terms must have missing values of type "
"{types}.".format(
types=' or '.join([t.__name__ for t in label_types]),
)
)