diff --git a/tests/utils/test_preprocess.py b/tests/utils/test_preprocess.py index c3ad15df..a1f2676f 100644 --- a/tests/utils/test_preprocess.py +++ b/tests/utils/test_preprocess.py @@ -207,6 +207,27 @@ class PreprocessTestCase(TestCase): with self.assertRaises(TypeError): foo(not_int(1), not_int(2), 3) + def test_expect_types_custom_funcname(self): + + class Foo(object): + @expect_types(__funcname='ArgleBargle', a=int) + def __init__(self, a): + self.a = a + + foo = Foo(1) + self.assertEqual(foo.a, 1) + + for not_int in (str, float): + with self.assertRaises(TypeError) as e: + Foo(not_int(1)) + self.assertEqual( + e.exception.args[0], + "ArgleBargle() expected a value of type " + "int for argument 'a', but got {t} instead.".format( + t=not_int.__name__, + ) + ) + def test_expect_types_with_tuple(self): @expect_types(a=(int, float)) def foo(a): @@ -269,6 +290,27 @@ class PreprocessTestCase(TestCase): ) self.assertEqual(e.exception.args[0], expected_message) + def test_expect_element_custom_funcname(self): + + set_ = {'a', 'b'} + + class Foo(object): + @expect_element(__funcname='ArgleBargle', a=set_) + def __init__(self, a): + self.a = a + + with self.assertRaises(ValueError) as e: + Foo('c') + + expected_message = ( + "ArgleBargle() expected a value in {set_!r}" + " for argument 'a', but got 'c' instead." + ).format( + # We special-case set to show a tuple instead of the set repr. + set_=tuple(sorted(set_)), + ) + self.assertEqual(e.exception.args[0], expected_message) + def test_expect_dtypes(self): @expect_dtypes(a=dtype(float), b=dtype('datetime64[ns]')) @@ -326,6 +368,24 @@ class PreprocessTestCase(TestCase): ).format(qualname=qualname(foo)) self.assertEqual(e.exception.args[0], expected_message) + def test_expect_dtypes_custom_funcname(self): + + allowed_dtypes = (dtype('datetime64[ns]'), dtype('float')) + + class Foo(object): + @expect_dtypes(__funcname='Foo', a=allowed_dtypes) + def __init__(self, a): + self.a = a + + with self.assertRaises(TypeError) as e: + Foo(arange(3, dtype='uint32')) + + expected_message = ( + "Foo() expected a value with dtype 'datetime64[ns]' " + "or 'float64' for argument 'a', but got 'uint32' instead." + ) + self.assertEqual(e.exception.args[0], expected_message) + def test_ensure_timezone(self): @preprocess(tz=ensure_timezone) def f(tz): @@ -407,3 +467,18 @@ class PreprocessTestCase(TestCase): " a scalar instead.".format(qualname=qualname(foo)) ) self.assertEqual(errmsg, expected) + + def test_expect_dimensions_custom_name(self): + + @expect_dimensions(__funcname='fizzbuzz', x=2) + def foo(x, y): + return x[0, 0] + + with self.assertRaises(ValueError) as e: + foo(arange(1), 1) + errmsg = str(e.exception) + expected = ( + "fizzbuzz() expected a 2-D array for argument 'x', but got" + " a 1-D array instead.".format(qualname=qualname(foo)) + ) + self.assertEqual(errmsg, expected) diff --git a/zipline/utils/input_validation.py b/zipline/utils/input_validation.py index 2dbbf9ef..f466b381 100644 --- a/zipline/utils/input_validation.py +++ b/zipline/utils/input_validation.py @@ -26,6 +26,19 @@ from zipline.utils.functional import getattrs from zipline.utils.preprocess import call, preprocess +if PY3: + _qualified_name = attrgetter('__qualname__') +else: + def _qualified_name(obj): + """ + Return the fully-qualified name (ignoring inner classes) of a type. + """ + module = obj.__module__ + if module in ('__builtin__', '__main__', 'builtins'): + return obj.__name__ + return '.'.join([module, obj.__name__]) + + def verify_indices_all_unique(obj): """ Check that all axes of a pandas object are unique. @@ -203,7 +216,7 @@ def ensure_timestamp(func, argname, arg): ) -def expect_dtypes(**named): +def expect_dtypes(__funcname=_qualified_name, **named): """ Preprocessing decorator that verifies inputs have expected numpy dtypes. @@ -232,6 +245,11 @@ def expect_dtypes(**named): ) ) + if isinstance(__funcname, str): + get_funcname = lambda _: __funcname + else: + get_funcname = __funcname + @preprocess(dtypes=call(lambda x: x if isinstance(x, tuple) else (x,))) def _expect_dtype(dtypes): """ @@ -249,7 +267,7 @@ def expect_dtypes(**named): "{funcname}() expected a value with dtype {dtype_str} " "for argument {argname!r}, but got {value!r} instead." ).format( - funcname=_qualified_name(func), + funcname=get_funcname(func), dtype_str=' or '.join(repr(d.name) for d in dtypes), argname=argname, value=value_to_show, @@ -328,7 +346,7 @@ def expect_kinds(**named): return preprocess(**valmap(_expect_kind, named)) -def expect_types(*_pos, **named): +def expect_types(__funcname=_qualified_name, **named): """ Preprocessing decorator that verifies inputs have expected types. @@ -345,10 +363,14 @@ def expect_types(*_pos, **named): ... TypeError: ...foo() expected a value of type int for argument 'x', but got float instead. - """ - if _pos: - raise TypeError("expect_types() only takes keyword arguments.") + Notes + ----- + A special argument, __funcname, can be provided as a string to override the + function name shown in error messages. This is most often used on __init__ + or __new__ methods to make errors refer to the class name instead of the + function name. + """ for name, type_ in iteritems(named): if not isinstance(type_, (type, tuple)): raise TypeError( @@ -372,29 +394,17 @@ def expect_types(*_pos, **named): template = _template.format(type_or_types=_qualified_name(type_)) return make_check( - TypeError, - template, - lambda v: not isinstance(v, type_), - compose(_qualified_name, type), + exc_type=TypeError, + template=template, + pred=lambda v: not isinstance(v, type_), + actual=compose(_qualified_name, type), + funcname=__funcname, ) return preprocess(**valmap(_expect_type, named)) -if PY3: - _qualified_name = attrgetter('__qualname__') -else: - def _qualified_name(obj): - """ - Return the fully-qualified name (ignoring inner classes) of a type. - """ - module = obj.__module__ - if module in ('__builtin__', '__main__', 'builtins'): - return obj.__name__ - return '.'.join([module, obj.__name__]) - - -def make_check(exc_type, template, pred, actual): +def make_check(exc_type, template, pred, actual, funcname): """ Factory for making preprocessing functions that check a predicate on the input value. @@ -413,13 +423,22 @@ def make_check(exc_type, template, pred, actual): actual : function[object -> object] A function to call on bad values to produce the value to display in the error message. + funcname : str or callable + Name to use in error messages, or function to call on decorated + functions to produce a name. Passing an explicit name is useful when + creating checks for __init__ or __new__ methods when you want the error + to refer to the class name instead of the method name. """ + if isinstance(funcname, str): + get_funcname = lambda _: funcname + else: + get_funcname = funcname def _check(func, argname, argvalue): if pred(argvalue): raise exc_type( template % { - 'funcname': _qualified_name(func), + 'funcname': get_funcname(func), 'argname': argname, 'actual': actual(argvalue), }, @@ -452,7 +471,7 @@ def optional(type_): return (type_, type(None)) -def expect_element(*_pos, **named): +def expect_element(__funcname=_qualified_name, **named): """ Preprocessing decorator that verifies inputs are elements of some expected collection. @@ -475,13 +494,15 @@ def expect_element(*_pos, **named): Notes ----- + A special argument, __funcname, can be provided as a string to override the + function name shown in error messages. This is most often used on __init__ + or __new__ methods to make errors refer to the class name instead of the + function name. + This uses the `in` operator (__contains__) to make the containment check. This allows us to use any custom container as long as the object supports the container protocol. """ - if _pos: - raise TypeError("expect_element() only takes keyword arguments.") - def _expect_element(collection): if isinstance(collection, (set, frozenset)): # Special case the error message for set and frozen set to make it @@ -499,11 +520,12 @@ def expect_element(*_pos, **named): template, complement(op.contains(collection)), repr, + funcname=__funcname, ) return preprocess(**valmap(_expect_element, named)) -def expect_bounded(**named): +def expect_bounded(__funcname=_qualified_name, **named): """ Preprocessing decorator verifying that inputs fall between bounds. @@ -590,11 +612,12 @@ def expect_bounded(**named): template=template, pred=should_fail, actual=repr, + funcname=__funcname, ) return preprocess(**valmap(_expect_bounded, named)) -def expect_dimensions(**dimensions): +def expect_dimensions(__funcname=_qualified_name, **dimensions): """ Preprocessing decorator that verifies inputs are numpy arrays with a specific dimensionality. @@ -615,9 +638,13 @@ def expect_dimensions(**dimensions): ValueError: ...foo() expected a 2-D array for argument 'y', but got a 1-D array instead. """ + if isinstance(__funcname, str): + get_funcname = lambda _: __funcname + else: + get_funcname = __funcname + def _expect_dimension(expected_ndim): def _check(func, argname, argvalue): - funcname = _qualified_name(func) actual_ndim = argvalue.ndim if actual_ndim != expected_ndim: if actual_ndim == 0: @@ -628,7 +655,7 @@ def expect_dimensions(**dimensions): "{func}() expected a {expected:d}-D array" " for argument {argname!r}, but got a {actual}" " instead.".format( - func=funcname, + func=get_funcname(func), expected=expected_ndim, argname=argname, actual=actual_repr,