MAINT: Fail fast on unsupported dtypes.

This commit is contained in:
Scott Sanderson
2016-01-21 17:55:21 -05:00
parent 3cfc22ed77
commit 0115cdc46c
6 changed files with 106 additions and 44 deletions
+18 -4
View File
@@ -8,8 +8,9 @@ from unittest import TestCase
from zipline.errors import (
DTypeNotSpecified,
InputTermNotAtomic,
InvalidDType,
NotDType,
TermInputsNotSpecified,
UnsupportedDType,
WindowLengthNotSpecified,
)
from zipline.pipeline import Factor, Filter, TermGraph
@@ -19,6 +20,7 @@ from zipline.pipeline.term import AssetExists, NotSpecified
from zipline.pipeline.expression import NUMEXPR_MATH_FUNCS
from zipline.utils.numpy_utils import (
bool_dtype,
complex128_dtype,
datetime64ns_dtype,
float64_dtype,
int64_dtype,
@@ -332,9 +334,15 @@ class ObjectIdentityTestCase(TestCase):
with self.assertRaises(DTypeNotSpecified):
SomeFactorNoDType()
with self.assertRaises(InvalidDType):
with self.assertRaises(NotDType):
SomeFactor(dtype=1)
with self.assertRaises(NoDefaultMissingValue):
SomeFactor(dtype=int64_dtype)
with self.assertRaises(UnsupportedDType):
SomeFactor(dtype=complex128_dtype)
def test_latest_on_different_dtypes(self):
factor_dtypes = (int64_dtype, float64_dtype, datetime64ns_dtype)
for column in TestingDataSet.columns:
@@ -350,11 +358,10 @@ class ObjectIdentityTestCase(TestCase):
# property of correctly handling `NaN`.
self.assertIs(column.missing_value, column.latest.missing_value)
def test_failure_timing_on_bad_missing_values(self):
def test_failure_timing_on_bad_dtypes(self):
# Just constructing a bad column shouldn't fail.
Column(dtype=int64_dtype)
with self.assertRaises(NoDefaultMissingValue) as e:
class BadDataSet(DataSet):
bad_column = Column(dtype=int64_dtype)
@@ -367,6 +374,13 @@ class ObjectIdentityTestCase(TestCase):
)
)
Column(dtype=complex128_dtype)
with self.assertRaises(UnsupportedDType):
class BadDataSetComplex(DataSet):
bad_column = Column(dtype=complex128_dtype)
float_column = Column(dtype=float64_dtype)
int_column = Column(dtype=int64_dtype, missing_value=3)
class SubDataSetTestCase(TestCase):
def test_subdataset(self):
+12 -1
View File
@@ -396,7 +396,7 @@ class DTypeNotSpecified(ZiplineError):
)
class InvalidDType(ZiplineError):
class NotDType(ZiplineError):
"""
Raised when a pipeline Term is constructed with a dtype that isn't a numpy
dtype object.
@@ -407,6 +407,17 @@ class InvalidDType(ZiplineError):
)
class UnsupportedDType(ZiplineError):
    """
    Raised when a pipeline Term is constructed with a dtype that's not
    supported.

    The message template expects two format fields: ``termname`` (the name
    of the term being constructed) and ``dtype`` (the rejected dtype).
    """
    # Template rendered with the ``termname`` and ``dtype`` keyword
    # arguments supplied when the error is raised.
    msg = (
        "Failed to construct {termname}.\n"
        "Pipeline terms of dtype {dtype} are not yet supported."
    )
class BadPercentileBounds(ZiplineError):
"""
Raised by API functions accepting percentile bounds when the passed bounds
+24 -2
View File
@@ -7,6 +7,8 @@ from numpy import (
float64,
int32,
int64,
int16,
uint16,
ndarray,
uint32,
uint8,
@@ -29,13 +31,33 @@ from ._int64window import AdjustedArrayWindow as Int64Window
from ._uint8window import AdjustedArrayWindow as UInt8Window
NOMASK = None
BOOL_DTYPES = frozenset(
map(dtype, [bool_]),
)
FLOAT_DTYPES = frozenset(
map(dtype, [float32, float64, int32]),
map(dtype, [float32, float64]),
)
INT_DTYPES = frozenset(
# NOTE: uint64 not supported because it can't be safely cast to int64.
map(dtype, [int32, int64, uint32]),
map(dtype, [int16, uint16, int32, int64, uint32]),
)
DATETIME_DTYPES = frozenset(
map(dtype, ['datetime64[ns]', 'datetime64[D]']),
)
REPRESENTABLE_DTYPES = BOOL_DTYPES.union(
FLOAT_DTYPES,
INT_DTYPES,
DATETIME_DTYPES
)
def can_represent_dtype(dtype):
    """
    Can we build an AdjustedArray for a baseline of dtype ``dtype``?

    Parameters
    ----------
    dtype : np.dtype
        The dtype to check. No coercion is performed here, so callers are
        expected to pass an actual numpy dtype object.

    Returns
    -------
    can_represent : bool
        True if ``dtype`` is a member of ``REPRESENTABLE_DTYPES`` (the union
        of the bool, float, int, and datetime dtype sets), False otherwise.
    """
    # NOTE: the parameter name shadows the ``dtype`` constructor imported
    # from numpy at module scope; that is harmless here because the
    # constructor is not used inside this function.
    return dtype in REPRESENTABLE_DTYPES
CONCRETE_WINDOW_TYPES = {
float64_dtype: Float64Window,
int64_dtype: Int64Window,
+25 -21
View File
@@ -7,7 +7,11 @@ from six import (
with_metaclass,
)
from zipline.pipeline.term import Term, AssetExists, NotSpecified
from zipline.pipeline.term import (
Term,
AssetExists,
NotSpecified,
)
from zipline.utils.input_validation import ensure_dtype
from zipline.utils.numpy_utils import (
bool_dtype,
@@ -47,26 +51,26 @@ class _BoundColumnDescr(object):
parent classes.
"""
def __init__(self, dtype, missing_value, name):
self.dtype = dtype
# Calculating missing values here guarantees that we fail quickly if
# the user fails to provide a missing value for a dtype that requires
# one (e.g. int64), but still enables us to provide an error message
# that points to the name of the failing column.
if missing_value is NotSpecified:
try:
missing_value = default_missing_value_for_dtype(dtype)
except NoDefaultMissingValue:
# Re-raise with a better message.
raise NoDefaultMissingValue(
"Failed to create Column with name {name!r} and"
" dtype {dtype} because no missing_value was provided\n\n"
"Columns with dtype {dtype} require a missing_value.\n"
"Please pass missing_value to Column() or use a different"
" dtype.".format(dtype=dtype, name=name)
)
self.missing_value = missing_value
# Validating and calculating default missing values here guarantees
# that we fail quickly if the user passes an unsupported dtype or fails
# to provide a missing value for a dtype that requires one
# (e.g. int64), but still enables us to provide an error message that
# points to the name of the failing column.
try:
self.dtype, self.missing_value = Term.validate_dtype(
termname="Column(name={name!r})".format(name=name),
dtype=dtype,
missing_value=missing_value,
)
except NoDefaultMissingValue:
# Re-raise with a more specific message.
raise NoDefaultMissingValue(
"Failed to create Column with name {name!r} and"
" dtype {dtype} because no missing_value was provided\n\n"
"Columns with dtype {dtype} require a missing_value.\n"
"Please pass missing_value to Column() or use a different"
" dtype.".format(dtype=dtype, name=name)
)
self.name = name
def __get__(self, instance, owner):
+24 -16
View File
@@ -6,14 +6,15 @@ from weakref import WeakValueDictionary
from numpy import dtype as dtype_class
from six import with_metaclass
from zipline.errors import (
DTypeNotSpecified,
InputTermNotAtomic,
InvalidDType,
NotDType,
TermInputsNotSpecified,
UnsupportedDType,
WindowLengthNotSpecified,
)
from zipline.lib.adjusted_array import can_represent_dtype
from zipline.utils.memoize import lazyval
from zipline.utils.numpy_utils import (
bool_dtype,
@@ -69,7 +70,11 @@ class Term(with_metaclass(ABCMeta, object)):
if missing_value is NotSpecified:
missing_value = cls.missing_value
dtype, missing_value = cls._validate_dtype(dtype, missing_value)
dtype, missing_value = cls.validate_dtype(
cls.__name__,
dtype,
missing_value,
)
params = cls._pop_params(kwargs)
identity = cls.static_identity(
@@ -141,37 +146,40 @@ class Term(with_metaclass(ABCMeta, object)):
)
return tuple(zip(cls.params, param_values))
@classmethod
def _validate_dtype(cls, passed_dtype, missing_value):
@staticmethod
def validate_dtype(termname, dtype, missing_value):
"""
Validate `dtype` passed to Term.__new__.
Validate a `dtype` and `missing_value` passed to Term.__new__.
If passed_dtype is NotSpecified, then we try to fall back to a
class-level attribute. If a value is found at that point, we pass it
to np.dtype so that users can pass `float` or `bool` and have them
coerce to the appropriate numpy types.
Ensures that we know how to represent ``dtype``, and that missing_value
is specified for types without default missing values.
Returns
-------
validated : np.dtype
The dtype to use for the new term.
validated_dtype, validated_missing_value : np.dtype, any
The dtype and missing_value to use for the new term.
Raises
------
DTypeNotSpecified
When no dtype was passed to the instance, and the class doesn't
provide a default.
InvalidDType
NotDType
When either the class or the instance provides a value not
coercible to a numpy dtype.
UnsupportedDType
When ``dtype`` is a valid numpy dtype, but not one that we know how
to represent.
NoDefaultMissingValue
When dtype requires an explicit missing_value, but
``missing_value`` is NotSpecified.
"""
dtype = passed_dtype
if dtype is NotSpecified:
raise DTypeNotSpecified(termname=cls.__name__)
raise DTypeNotSpecified(termname=termname)
try:
dtype = dtype_class(dtype)
except TypeError:
raise InvalidDType(dtype=dtype, termname=cls.__name__)
raise NotDType(dtype=dtype, termname=termname)
if not can_represent_dtype(dtype):
raise UnsupportedDType(dtype=dtype, termname=termname)
if missing_value is NotSpecified:
missing_value = default_missing_value_for_dtype(dtype)
+3
View File
@@ -15,11 +15,14 @@ from toolz import flip
uint8_dtype = dtype('uint8')
bool_dtype = dtype('bool')
int64_dtype = dtype('int64')
float32_dtype = dtype('float32')
float64_dtype = dtype('float64')
complex128_dtype = dtype('complex128')
datetime64D_dtype = dtype('datetime64[D]')
datetime64ns_dtype = dtype('datetime64[ns]')