From 0115cdc46c48e86286659bcf3f79f99ea24d3fca Mon Sep 17 00:00:00 2001 From: Scott Sanderson Date: Thu, 21 Jan 2016 17:55:21 -0500 Subject: [PATCH] MAINT: Fail fast on unsupported dtypes. --- tests/pipeline/test_term.py | 22 ++++++++++++--- zipline/errors.py | 13 ++++++++- zipline/lib/adjusted_array.py | 26 ++++++++++++++++-- zipline/pipeline/data/dataset.py | 46 +++++++++++++++++--------------- zipline/pipeline/term.py | 40 ++++++++++++++++----------- zipline/utils/numpy_utils.py | 3 +++ 6 files changed, 106 insertions(+), 44 deletions(-) diff --git a/tests/pipeline/test_term.py b/tests/pipeline/test_term.py index ce86b24a..3f386907 100644 --- a/tests/pipeline/test_term.py +++ b/tests/pipeline/test_term.py @@ -8,8 +8,9 @@ from unittest import TestCase from zipline.errors import ( DTypeNotSpecified, InputTermNotAtomic, - InvalidDType, + NotDType, TermInputsNotSpecified, + UnsupportedDType, WindowLengthNotSpecified, ) from zipline.pipeline import Factor, Filter, TermGraph @@ -19,6 +20,7 @@ from zipline.pipeline.term import AssetExists, NotSpecified from zipline.pipeline.expression import NUMEXPR_MATH_FUNCS from zipline.utils.numpy_utils import ( bool_dtype, + complex128_dtype, datetime64ns_dtype, float64_dtype, int64_dtype, @@ -332,9 +334,15 @@ class ObjectIdentityTestCase(TestCase): with self.assertRaises(DTypeNotSpecified): SomeFactorNoDType() - with self.assertRaises(InvalidDType): + with self.assertRaises(NotDType): SomeFactor(dtype=1) + with self.assertRaises(NoDefaultMissingValue): + SomeFactor(dtype=int64_dtype) + + with self.assertRaises(UnsupportedDType): + SomeFactor(dtype=complex128_dtype) + def test_latest_on_different_dtypes(self): factor_dtypes = (int64_dtype, float64_dtype, datetime64ns_dtype) for column in TestingDataSet.columns: @@ -350,11 +358,10 @@ class ObjectIdentityTestCase(TestCase): # property of correctly handling `NaN`. self.assertIs(column.missing_value, column.latest.missing_value) - def test_failure_timing_on_bad_missing_values(self): + def test_failure_timing_on_bad_dtypes(self): # Just constructing a bad column shouldn't fail. Column(dtype=int64_dtype) - with self.assertRaises(NoDefaultMissingValue) as e: class BadDataSet(DataSet): bad_column = Column(dtype=int64_dtype) @@ -367,6 +374,13 @@ class ObjectIdentityTestCase(TestCase): ) ) + Column(dtype=complex128_dtype) + with self.assertRaises(UnsupportedDType): + class BadDataSetComplex(DataSet): + bad_column = Column(dtype=complex128_dtype) + float_column = Column(dtype=float64_dtype) + int_column = Column(dtype=int64_dtype, missing_value=3) + class SubDataSetTestCase(TestCase): def test_subdataset(self): diff --git a/zipline/errors.py b/zipline/errors.py index 57e2df7d..95feb581 100644 --- a/zipline/errors.py +++ b/zipline/errors.py @@ -396,7 +396,7 @@ class DTypeNotSpecified(ZiplineError): ) -class InvalidDType(ZiplineError): +class NotDType(ZiplineError): """ Raised when a pipeline Term is constructed with a dtype that isn't a numpy dtype object. @@ -407,6 +407,17 @@ class InvalidDType(ZiplineError): ) +class UnsupportedDType(ZiplineError): + """ + Raised when a pipeline Term is constructed with a dtype that's not + supported. + """ + msg = ( + "Failed to construct {termname}.\n" + "Pipeline terms of dtype {dtype} are not yet supported." + ) + + class BadPercentileBounds(ZiplineError): """ Raised by API functions accepting percentile bounds when the passed bounds diff --git a/zipline/lib/adjusted_array.py b/zipline/lib/adjusted_array.py index 03556528..4c0ee5b8 100644 --- a/zipline/lib/adjusted_array.py +++ b/zipline/lib/adjusted_array.py @@ -7,6 +7,8 @@ from numpy import ( float64, int32, int64, + int16, + uint16, ndarray, uint32, uint8, @@ -29,13 +31,33 @@ from ._int64window import AdjustedArrayWindow as Int64Window from ._uint8window import AdjustedArrayWindow as UInt8Window NOMASK = None +BOOL_DTYPES = frozenset( + map(dtype, [bool_]), +) FLOAT_DTYPES = frozenset( - map(dtype, [float32, float64, int32]), + map(dtype, [float32, float64]), ) INT_DTYPES = frozenset( # NOTE: uint64 not supported because it can't be safely cast to int64. - map(dtype, [int32, int64, uint32]), + map(dtype, [int16, uint16, int32, int64, uint32]), ) +DATETIME_DTYPES = frozenset( + map(dtype, ['datetime64[ns]', 'datetime64[D]']), +) +REPRESENTABLE_DTYPES = BOOL_DTYPES.union( + FLOAT_DTYPES, + INT_DTYPES, + DATETIME_DTYPES +) + + +def can_represent_dtype(dtype): + """ + Can we build an AdjustedArray for a baseline of dtype ``dtype``? + """ + return dtype in REPRESENTABLE_DTYPES + + CONCRETE_WINDOW_TYPES = { float64_dtype: Float64Window, int64_dtype: Int64Window, diff --git a/zipline/pipeline/data/dataset.py b/zipline/pipeline/data/dataset.py index 7183dea8..6ded1f13 100644 --- a/zipline/pipeline/data/dataset.py +++ b/zipline/pipeline/data/dataset.py @@ -7,7 +7,11 @@ from six import ( with_metaclass, ) -from zipline.pipeline.term import Term, AssetExists, NotSpecified +from zipline.pipeline.term import ( + Term, + AssetExists, + NotSpecified, +) from zipline.utils.input_validation import ensure_dtype from zipline.utils.numpy_utils import ( bool_dtype, @@ -47,26 +51,26 @@ class _BoundColumnDescr(object): parent classes. """ def __init__(self, dtype, missing_value, name): - self.dtype = dtype - - # Calculating missing values here guarantees that we fail quickly if - # the user fails to provide a missing value for a dtype that requires - # one (e.g. int64), but still enables us to provide an error message - # that points to the name of the failing column. - if missing_value is NotSpecified: - try: - missing_value = default_missing_value_for_dtype(dtype) - except NoDefaultMissingValue: - # Re-raise with a better message. - raise NoDefaultMissingValue( - "Failed to create Column with name {name!r} and" - " dtype {dtype} because no missing_value was provided\n\n" - "Columns with dtype {dtype} require a missing_value.\n" - "Please pass missing_value to Column() or use a different" - " dtype.".format(dtype=dtype, name=name) - ) - - self.missing_value = missing_value + # Validating and calculating default missing values here guarantees + # that we fail quickly if the user passes an unsupporte dtype or fails + # to provide a missing value for a dtype that requires one + # (e.g. int64), but still enables us to provide an error message that + # points to the name of the failing column. + try: + self.dtype, self.missing_value = Term.validate_dtype( + termname="Column(name={name!r})".format(name=name), + dtype=dtype, + missing_value=missing_value, + ) + except NoDefaultMissingValue: + # Re-raise with a more specific message. + raise NoDefaultMissingValue( + "Failed to create Column with name {name!r} and" + " dtype {dtype} because no missing_value was provided\n\n" + "Columns with dtype {dtype} require a missing_value.\n" + "Please pass missing_value to Column() or use a different" + " dtype.".format(dtype=dtype, name=name) + ) self.name = name def __get__(self, instance, owner): diff --git a/zipline/pipeline/term.py b/zipline/pipeline/term.py index 5b730232..3cb9cc63 100644 --- a/zipline/pipeline/term.py +++ b/zipline/pipeline/term.py @@ -6,14 +6,15 @@ from weakref import WeakValueDictionary from numpy import dtype as dtype_class from six import with_metaclass - from zipline.errors import ( DTypeNotSpecified, InputTermNotAtomic, - InvalidDType, + NotDType, TermInputsNotSpecified, + UnsupportedDType, WindowLengthNotSpecified, ) +from zipline.lib.adjusted_array import can_represent_dtype from zipline.utils.memoize import lazyval from zipline.utils.numpy_utils import ( bool_dtype, @@ -69,7 +70,11 @@ class Term(with_metaclass(ABCMeta, object)): if missing_value is NotSpecified: missing_value = cls.missing_value - dtype, missing_value = cls._validate_dtype(dtype, missing_value) + dtype, missing_value = cls.validate_dtype( + cls.__name__, + dtype, + missing_value, + ) params = cls._pop_params(kwargs) identity = cls.static_identity( @@ -141,37 +146,40 @@ class Term(with_metaclass(ABCMeta, object)): ) return tuple(zip(cls.params, param_values)) - @classmethod - def _validate_dtype(cls, passed_dtype, missing_value): + @staticmethod + def validate_dtype(termname, dtype, missing_value): """ - Validate `dtype` passed to Term.__new__. + Validate a `dtype` and `missing_value` passed to Term.__new__. - If passed_dtype is NotSpecified, then we try to fall back to a - class-level attribute. If a value is found at that point, we pass it - to np.dtype so that users can pass `float` or `bool` and have them - coerce to the appropriate numpy types. + Ensures that we know how to represent ``dtype``, and that missing_value + is specified for types without default missing values. Returns ------- - validated : np.dtype - The dtype to use for the new term. + validated_dtype, validated_missing_value : np.dtype, any + The dtype and missing_value to use for the new term. Raises ------ DTypeNotSpecified When no dtype was passed to the instance, and the class doesn't provide a default. - InvalidDType + NotDType When either the class or the instance provides a value not coercible to a numpy dtype. + NoDefaultMissingValue + When dtype requires an explicit missing_value, but + ``missing_value`` is NotSpecified. """ - dtype = passed_dtype if dtype is NotSpecified: - raise DTypeNotSpecified(termname=cls.__name__) + raise DTypeNotSpecified(termname=termname) try: dtype = dtype_class(dtype) except TypeError: - raise InvalidDType(dtype=dtype, termname=cls.__name__) + raise NotDType(dtype=dtype, termname=termname) + + if not can_represent_dtype(dtype): + raise UnsupportedDType(dtype=dtype, termname=termname) if missing_value is NotSpecified: missing_value = default_missing_value_for_dtype(dtype) diff --git a/zipline/utils/numpy_utils.py b/zipline/utils/numpy_utils.py index 4cc37aff..56315cf7 100644 --- a/zipline/utils/numpy_utils.py +++ b/zipline/utils/numpy_utils.py @@ -15,11 +15,14 @@ from toolz import flip uint8_dtype = dtype('uint8') bool_dtype = dtype('bool') + int64_dtype = dtype('int64') float32_dtype = dtype('float32') float64_dtype = dtype('float64') +complex128_dtype = dtype('complex128') + datetime64D_dtype = dtype('datetime64[D]') datetime64ns_dtype = dtype('datetime64[ns]')