mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-29 09:21:05 +08:00
MAINT: Fail fast on unsupported dtypes.
This commit is contained in:
@@ -8,8 +8,9 @@ from unittest import TestCase
|
||||
from zipline.errors import (
|
||||
DTypeNotSpecified,
|
||||
InputTermNotAtomic,
|
||||
InvalidDType,
|
||||
NotDType,
|
||||
TermInputsNotSpecified,
|
||||
UnsupportedDType,
|
||||
WindowLengthNotSpecified,
|
||||
)
|
||||
from zipline.pipeline import Factor, Filter, TermGraph
|
||||
@@ -19,6 +20,7 @@ from zipline.pipeline.term import AssetExists, NotSpecified
|
||||
from zipline.pipeline.expression import NUMEXPR_MATH_FUNCS
|
||||
from zipline.utils.numpy_utils import (
|
||||
bool_dtype,
|
||||
complex128_dtype,
|
||||
datetime64ns_dtype,
|
||||
float64_dtype,
|
||||
int64_dtype,
|
||||
@@ -332,9 +334,15 @@ class ObjectIdentityTestCase(TestCase):
|
||||
with self.assertRaises(DTypeNotSpecified):
|
||||
SomeFactorNoDType()
|
||||
|
||||
with self.assertRaises(InvalidDType):
|
||||
with self.assertRaises(NotDType):
|
||||
SomeFactor(dtype=1)
|
||||
|
||||
with self.assertRaises(NoDefaultMissingValue):
|
||||
SomeFactor(dtype=int64_dtype)
|
||||
|
||||
with self.assertRaises(UnsupportedDType):
|
||||
SomeFactor(dtype=complex128_dtype)
|
||||
|
||||
def test_latest_on_different_dtypes(self):
|
||||
factor_dtypes = (int64_dtype, float64_dtype, datetime64ns_dtype)
|
||||
for column in TestingDataSet.columns:
|
||||
@@ -350,11 +358,10 @@ class ObjectIdentityTestCase(TestCase):
|
||||
# property of correctly handling `NaN`.
|
||||
self.assertIs(column.missing_value, column.latest.missing_value)
|
||||
|
||||
def test_failure_timing_on_bad_missing_values(self):
|
||||
def test_failure_timing_on_bad_dtypes(self):
|
||||
|
||||
# Just constructing a bad column shouldn't fail.
|
||||
Column(dtype=int64_dtype)
|
||||
|
||||
with self.assertRaises(NoDefaultMissingValue) as e:
|
||||
class BadDataSet(DataSet):
|
||||
bad_column = Column(dtype=int64_dtype)
|
||||
@@ -367,6 +374,13 @@ class ObjectIdentityTestCase(TestCase):
|
||||
)
|
||||
)
|
||||
|
||||
Column(dtype=complex128_dtype)
|
||||
with self.assertRaises(UnsupportedDType):
|
||||
class BadDataSetComplex(DataSet):
|
||||
bad_column = Column(dtype=complex128_dtype)
|
||||
float_column = Column(dtype=float64_dtype)
|
||||
int_column = Column(dtype=int64_dtype, missing_value=3)
|
||||
|
||||
|
||||
class SubDataSetTestCase(TestCase):
|
||||
def test_subdataset(self):
|
||||
|
||||
+12
-1
@@ -396,7 +396,7 @@ class DTypeNotSpecified(ZiplineError):
|
||||
)
|
||||
|
||||
|
||||
class InvalidDType(ZiplineError):
|
||||
class NotDType(ZiplineError):
|
||||
"""
|
||||
Raised when a pipeline Term is constructed with a dtype that isn't a numpy
|
||||
dtype object.
|
||||
@@ -407,6 +407,17 @@ class InvalidDType(ZiplineError):
|
||||
)
|
||||
|
||||
|
||||
class UnsupportedDType(ZiplineError):
|
||||
"""
|
||||
Raised when a pipeline Term is constructed with a dtype that's not
|
||||
supported.
|
||||
"""
|
||||
msg = (
|
||||
"Failed to construct {termname}.\n"
|
||||
"Pipeline terms of dtype {dtype} are not yet supported."
|
||||
)
|
||||
|
||||
|
||||
class BadPercentileBounds(ZiplineError):
|
||||
"""
|
||||
Raised by API functions accepting percentile bounds when the passed bounds
|
||||
|
||||
@@ -7,6 +7,8 @@ from numpy import (
|
||||
float64,
|
||||
int32,
|
||||
int64,
|
||||
int16,
|
||||
uint16,
|
||||
ndarray,
|
||||
uint32,
|
||||
uint8,
|
||||
@@ -29,13 +31,33 @@ from ._int64window import AdjustedArrayWindow as Int64Window
|
||||
from ._uint8window import AdjustedArrayWindow as UInt8Window
|
||||
|
||||
NOMASK = None
|
||||
BOOL_DTYPES = frozenset(
|
||||
map(dtype, [bool_]),
|
||||
)
|
||||
FLOAT_DTYPES = frozenset(
|
||||
map(dtype, [float32, float64, int32]),
|
||||
map(dtype, [float32, float64]),
|
||||
)
|
||||
INT_DTYPES = frozenset(
|
||||
# NOTE: uint64 not supported because it can't be safely cast to int64.
|
||||
map(dtype, [int32, int64, uint32]),
|
||||
map(dtype, [int16, uint16, int32, int64, uint32]),
|
||||
)
|
||||
DATETIME_DTYPES = frozenset(
|
||||
map(dtype, ['datetime64[ns]', 'datetime64[D]']),
|
||||
)
|
||||
REPRESENTABLE_DTYPES = BOOL_DTYPES.union(
|
||||
FLOAT_DTYPES,
|
||||
INT_DTYPES,
|
||||
DATETIME_DTYPES
|
||||
)
|
||||
|
||||
|
||||
def can_represent_dtype(dtype):
|
||||
"""
|
||||
Can we build an AdjustedArray for a baseline of dtype ``dtype``?
|
||||
"""
|
||||
return dtype in REPRESENTABLE_DTYPES
|
||||
|
||||
|
||||
CONCRETE_WINDOW_TYPES = {
|
||||
float64_dtype: Float64Window,
|
||||
int64_dtype: Int64Window,
|
||||
|
||||
@@ -7,7 +7,11 @@ from six import (
|
||||
with_metaclass,
|
||||
)
|
||||
|
||||
from zipline.pipeline.term import Term, AssetExists, NotSpecified
|
||||
from zipline.pipeline.term import (
|
||||
Term,
|
||||
AssetExists,
|
||||
NotSpecified,
|
||||
)
|
||||
from zipline.utils.input_validation import ensure_dtype
|
||||
from zipline.utils.numpy_utils import (
|
||||
bool_dtype,
|
||||
@@ -47,26 +51,26 @@ class _BoundColumnDescr(object):
|
||||
parent classes.
|
||||
"""
|
||||
def __init__(self, dtype, missing_value, name):
|
||||
self.dtype = dtype
|
||||
|
||||
# Calculating missing values here guarantees that we fail quickly if
|
||||
# the user fails to provide a missing value for a dtype that requires
|
||||
# one (e.g. int64), but still enables us to provide an error message
|
||||
# that points to the name of the failing column.
|
||||
if missing_value is NotSpecified:
|
||||
try:
|
||||
missing_value = default_missing_value_for_dtype(dtype)
|
||||
except NoDefaultMissingValue:
|
||||
# Re-raise with a better message.
|
||||
raise NoDefaultMissingValue(
|
||||
"Failed to create Column with name {name!r} and"
|
||||
" dtype {dtype} because no missing_value was provided\n\n"
|
||||
"Columns with dtype {dtype} require a missing_value.\n"
|
||||
"Please pass missing_value to Column() or use a different"
|
||||
" dtype.".format(dtype=dtype, name=name)
|
||||
)
|
||||
|
||||
self.missing_value = missing_value
|
||||
# Validating and calculating default missing values here guarantees
|
||||
# that we fail quickly if the user passes an unsupporte dtype or fails
|
||||
# to provide a missing value for a dtype that requires one
|
||||
# (e.g. int64), but still enables us to provide an error message that
|
||||
# points to the name of the failing column.
|
||||
try:
|
||||
self.dtype, self.missing_value = Term.validate_dtype(
|
||||
termname="Column(name={name!r})".format(name=name),
|
||||
dtype=dtype,
|
||||
missing_value=missing_value,
|
||||
)
|
||||
except NoDefaultMissingValue:
|
||||
# Re-raise with a more specific message.
|
||||
raise NoDefaultMissingValue(
|
||||
"Failed to create Column with name {name!r} and"
|
||||
" dtype {dtype} because no missing_value was provided\n\n"
|
||||
"Columns with dtype {dtype} require a missing_value.\n"
|
||||
"Please pass missing_value to Column() or use a different"
|
||||
" dtype.".format(dtype=dtype, name=name)
|
||||
)
|
||||
self.name = name
|
||||
|
||||
def __get__(self, instance, owner):
|
||||
|
||||
+24
-16
@@ -6,14 +6,15 @@ from weakref import WeakValueDictionary
|
||||
|
||||
from numpy import dtype as dtype_class
|
||||
from six import with_metaclass
|
||||
|
||||
from zipline.errors import (
|
||||
DTypeNotSpecified,
|
||||
InputTermNotAtomic,
|
||||
InvalidDType,
|
||||
NotDType,
|
||||
TermInputsNotSpecified,
|
||||
UnsupportedDType,
|
||||
WindowLengthNotSpecified,
|
||||
)
|
||||
from zipline.lib.adjusted_array import can_represent_dtype
|
||||
from zipline.utils.memoize import lazyval
|
||||
from zipline.utils.numpy_utils import (
|
||||
bool_dtype,
|
||||
@@ -69,7 +70,11 @@ class Term(with_metaclass(ABCMeta, object)):
|
||||
if missing_value is NotSpecified:
|
||||
missing_value = cls.missing_value
|
||||
|
||||
dtype, missing_value = cls._validate_dtype(dtype, missing_value)
|
||||
dtype, missing_value = cls.validate_dtype(
|
||||
cls.__name__,
|
||||
dtype,
|
||||
missing_value,
|
||||
)
|
||||
params = cls._pop_params(kwargs)
|
||||
|
||||
identity = cls.static_identity(
|
||||
@@ -141,37 +146,40 @@ class Term(with_metaclass(ABCMeta, object)):
|
||||
)
|
||||
return tuple(zip(cls.params, param_values))
|
||||
|
||||
@classmethod
|
||||
def _validate_dtype(cls, passed_dtype, missing_value):
|
||||
@staticmethod
|
||||
def validate_dtype(termname, dtype, missing_value):
|
||||
"""
|
||||
Validate `dtype` passed to Term.__new__.
|
||||
Validate a `dtype` and `missing_value` passed to Term.__new__.
|
||||
|
||||
If passed_dtype is NotSpecified, then we try to fall back to a
|
||||
class-level attribute. If a value is found at that point, we pass it
|
||||
to np.dtype so that users can pass `float` or `bool` and have them
|
||||
coerce to the appropriate numpy types.
|
||||
Ensures that we know how to represent ``dtype``, and that missing_value
|
||||
is specified for types without default missing values.
|
||||
|
||||
Returns
|
||||
-------
|
||||
validated : np.dtype
|
||||
The dtype to use for the new term.
|
||||
validated_dtype, validated_missing_value : np.dtype, any
|
||||
The dtype and missing_value to use for the new term.
|
||||
|
||||
Raises
|
||||
------
|
||||
DTypeNotSpecified
|
||||
When no dtype was passed to the instance, and the class doesn't
|
||||
provide a default.
|
||||
InvalidDType
|
||||
NotDType
|
||||
When either the class or the instance provides a value not
|
||||
coercible to a numpy dtype.
|
||||
NoDefaultMissingValue
|
||||
When dtype requires an explicit missing_value, but
|
||||
``missing_value`` is NotSpecified.
|
||||
"""
|
||||
dtype = passed_dtype
|
||||
if dtype is NotSpecified:
|
||||
raise DTypeNotSpecified(termname=cls.__name__)
|
||||
raise DTypeNotSpecified(termname=termname)
|
||||
try:
|
||||
dtype = dtype_class(dtype)
|
||||
except TypeError:
|
||||
raise InvalidDType(dtype=dtype, termname=cls.__name__)
|
||||
raise NotDType(dtype=dtype, termname=termname)
|
||||
|
||||
if not can_represent_dtype(dtype):
|
||||
raise UnsupportedDType(dtype=dtype, termname=termname)
|
||||
|
||||
if missing_value is NotSpecified:
|
||||
missing_value = default_missing_value_for_dtype(dtype)
|
||||
|
||||
@@ -15,11 +15,14 @@ from toolz import flip
|
||||
|
||||
uint8_dtype = dtype('uint8')
|
||||
bool_dtype = dtype('bool')
|
||||
|
||||
int64_dtype = dtype('int64')
|
||||
|
||||
float32_dtype = dtype('float32')
|
||||
float64_dtype = dtype('float64')
|
||||
|
||||
complex128_dtype = dtype('complex128')
|
||||
|
||||
datetime64D_dtype = dtype('datetime64[D]')
|
||||
datetime64ns_dtype = dtype('datetime64[ns]')
|
||||
|
||||
|
||||
Reference in New Issue
Block a user