diff --git a/tests/test_algorithm.py b/tests/test_algorithm.py index 53a20df9..e8d67335 100644 --- a/tests/test_algorithm.py +++ b/tests/test_algorithm.py @@ -425,6 +425,23 @@ class TestMiscellaneousAPI(TestCase): self.assertIsInstance(algo.sid(3), Equity) self.assertIsInstance(algo.sid(4), Equity) + # Supplying a non-string argument to symbol() + # should result in a TypeError. + with self.assertRaises(TypeError): + algo.symbol(1) + + with self.assertRaises(TypeError): + algo.symbol((1,)) + + with self.assertRaises(TypeError): + algo.symbol({1}) + + with self.assertRaises(TypeError): + algo.symbol([1]) + + with self.assertRaises(TypeError): + algo.symbol({'foo': 'bar'}) + def test_future_symbol(self): """ Tests the future_symbol API function. """ @@ -450,6 +467,23 @@ class TestMiscellaneousAPI(TestCase): with self.assertRaises(SymbolNotFound): algo.future_symbol('FOOBAR') + # Supplying a non-string argument to future_symbol() + # should result in a TypeError. + with self.assertRaises(TypeError): + algo.future_symbol(1) + + with self.assertRaises(TypeError): + algo.future_symbol((1,)) + + with self.assertRaises(TypeError): + algo.future_symbol({1}) + + with self.assertRaises(TypeError): + algo.future_symbol([1]) + + with self.assertRaises(TypeError): + algo.future_symbol({'foo': 'bar'}) + def test_future_chain(self): """ Tests the future_chain API function. """ @@ -493,6 +527,23 @@ class TestMiscellaneousAPI(TestCase): with self.assertRaises(UnsupportedDatetimeFormat): algo.future_chain('CL', '2015-09-') + # Supplying a non-string argument to future_chain() + # should result in a TypeError. + with self.assertRaises(TypeError): + algo.future_chain(1) + + with self.assertRaises(TypeError): + algo.future_chain((1,)) + + with self.assertRaises(TypeError): + algo.future_chain({1}) + + with self.assertRaises(TypeError): + algo.future_chain([1]) + + with self.assertRaises(TypeError): + algo.future_chain({'foo': 'bar'}) + def test_set_symbol_lookup_date(self): """ Test the set_symbol_lookup_date API method. diff --git a/tests/test_assets.py b/tests/test_assets.py index a34292c7..ae873f8a 100644 --- a/tests/test_assets.py +++ b/tests/test_assets.py @@ -109,11 +109,11 @@ def build_lookup_generic_cases(): (finder, 1, None, assets[1]), (finder, 2, None, assets[2]), # Duplicated symbol with resolution date - (finder, 'duplicated', dupe_0_start, dupe_0), - (finder, 'duplicated', dupe_1_start, dupe_1), + (finder, 'DUPLICATED', dupe_0_start, dupe_0), + (finder, 'DUPLICATED', dupe_1_start, dupe_1), # Unique symbol, with or without resolution date. - (finder, 'unique', unique_start, unique), - (finder, 'unique', None, unique), + (finder, 'UNIQUE', unique_start, unique), + (finder, 'UNIQUE', None, unique), ## # Iterables @@ -125,11 +125,11 @@ def build_lookup_generic_cases(): (finder, (0, 1), None, assets[:-1]), (finder, iter((0, 1)), None, assets[:-1]), # Iterables of symbols. - (finder, ('duplicated', 'unique'), dupe_0_start, [dupe_0, unique]), - (finder, ('duplicated', 'unique'), dupe_1_start, [dupe_1, unique]), + (finder, ('DUPLICATED', 'UNIQUE'), dupe_0_start, [dupe_0, unique]), + (finder, ('DUPLICATED', 'UNIQUE'), dupe_1_start, [dupe_1, unique]), # Mixed types (finder, - ('duplicated', 2, 'unique', 1, dupe_1), + ('DUPLICATED', 2, 'UNIQUE', 1, dupe_1), dupe_0_start, [dupe_0, assets[2], unique, assets[1], dupe_1]), ] @@ -360,18 +360,18 @@ class AssetFinderTestCase(TestCase): # we do it twice to catch caching bugs for i in range(2): with self.assertRaises(SymbolNotFound): - finder.lookup_symbol('test', as_of) + finder.lookup_symbol('TEST', as_of) with self.assertRaises(SymbolNotFound): - finder.lookup_symbol('test1', as_of) + finder.lookup_symbol('TEST1', as_of) # '@' is not a supported delimiter with self.assertRaises(SymbolNotFound): - finder.lookup_symbol('test@1', as_of) + finder.lookup_symbol('TEST@1', as_of) # Adding an unnecessary fuzzy shouldn't matter. for fuzzy_char in ['-', '/', '_', '.']: self.assertEqual( asset_1, - finder.lookup_symbol('test%s1' % fuzzy_char, as_of) + finder.lookup_symbol('TEST%s1' % fuzzy_char, as_of) ) def test_lookup_symbol_fuzzy(self): @@ -434,15 +434,15 @@ class AssetFinderTestCase(TestCase): finder = AssetFinder(self.env.engine) for _ in range(2): # Run checks twice to test for caching bugs. with self.assertRaises(SymbolNotFound): - finder.lookup_symbol('non_existing', dates[0]) + finder.lookup_symbol('NON_EXISTING', dates[0]) with self.assertRaises(MultipleSymbolsFound): - finder.lookup_symbol('existing', None) + finder.lookup_symbol('EXISTING', None) for i, date in enumerate(dates): # Verify that we correctly resolve multiple symbols using # the supplied date - result = finder.lookup_symbol('existing', date) + result = finder.lookup_symbol('EXISTING', date) self.assertEqual(result.symbol, 'EXISTING') self.assertEqual(result.sid, i) @@ -497,7 +497,7 @@ class AssetFinderTestCase(TestCase): self.env.write_data(equities_df=data) finder = AssetFinder(self.env.engine) results, missing = finder.lookup_generic( - ['real', 1, 'fake', 'real_but_old', 'real_but_in_the_future'], + ['REAL', 1, 'FAKE', 'REAL_BUT_OLD', 'REAL_BUT_IN_THE_FUTURE'], pd.Timestamp('2013-02-01', tz='UTC'), ) @@ -510,8 +510,8 @@ class AssetFinderTestCase(TestCase): self.assertEqual(results[2].sid, 2) self.assertEqual(len(missing), 2) - self.assertEqual(missing[0], 'fake') - self.assertEqual(missing[1], 'real_but_in_the_future') + self.assertEqual(missing[0], 'FAKE') + self.assertEqual(missing[1], 'REAL_BUT_IN_THE_FUTURE') def test_insert_metadata(self): data = {0: {'asset_type': 'equity', diff --git a/tests/utils/test_preprocess.py b/tests/utils/test_preprocess.py index b76538e9..4ed686fb 100644 --- a/tests/utils/test_preprocess.py +++ b/tests/utils/test_preprocess.py @@ -5,7 +5,8 @@ from types import FunctionType from unittest import TestCase from nose_parameterized import parameterized -from zipline.utils.preprocess import call, expect_types, preprocess, optional +from zipline.utils.preprocess import call, preprocess +from zipline.utils.input_validation import expect_types, optional def noop(func, argname, argvalue): diff --git a/zipline/algorithm.py b/zipline/algorithm.py index 5fd87da7..c6e967a3 100644 --- a/zipline/algorithm.py +++ b/zipline/algorithm.py @@ -85,6 +85,7 @@ from zipline.utils.api_support import ( require_not_initialized, ZiplineAPI, ) +from zipline.utils.input_validation import ensure_upper_case from zipline.utils.cache import CachedObject, Expired import zipline.utils.events from zipline.utils.events import ( @@ -95,6 +96,7 @@ from zipline.utils.events import ( ) from zipline.utils.factory import create_simulation_parameters from zipline.utils.math_utils import tolerant_equals +from zipline.utils.preprocess import preprocess import zipline.protocol from zipline.protocol import Event @@ -738,6 +740,7 @@ class TradingAlgorithm(object): self._recorded_vars[name] = value @api_method + @preprocess(symbol_str=ensure_upper_case) def symbol(self, symbol_str): """ Default symbol lookup for any source that directly maps the @@ -770,6 +773,7 @@ class TradingAlgorithm(object): return self.asset_finder.retrieve_asset(a_sid) @api_method + @preprocess(symbol=ensure_upper_case) def future_symbol(self, symbol): """ Lookup a futures contract with a given symbol. @@ -792,6 +796,7 @@ class TradingAlgorithm(object): return self.asset_finder.lookup_future_symbol(symbol) @api_method + @preprocess(root_symbol=ensure_upper_case) def future_chain(self, root_symbol, as_of_date=None): """ Look up a future chain with the specified parameters. @@ -823,7 +828,7 @@ class TradingAlgorithm(object): return FutureChain( asset_finder=self.asset_finder, get_datetime=self.get_datetime, - root_symbol=root_symbol.upper(), + root_symbol=root_symbol, as_of_date=as_of_date ) diff --git a/zipline/assets/assets.py b/zipline/assets/assets.py index 5b19ccea..12eb68d1 100644 --- a/zipline/assets/assets.py +++ b/zipline/assets/assets.py @@ -209,7 +209,6 @@ class AssetFinder(object): """ # Format inputs - symbol = symbol.upper() if as_of_date is not None: as_of_date = pd.Timestamp(normalize_date(as_of_date)) diff --git a/zipline/pipeline/pipeline.py b/zipline/pipeline/pipeline.py index 6b3f87f4..d64166d6 100644 --- a/zipline/pipeline/pipeline.py +++ b/zipline/pipeline/pipeline.py @@ -1,4 +1,4 @@ -from zipline.utils.preprocess import expect_types, optional +from zipline.utils.input_validation import expect_types, optional from .term import Term from .filters import Filter diff --git a/zipline/utils/input_validation.py b/zipline/utils/input_validation.py new file mode 100644 index 00000000..9ee0e109 --- /dev/null +++ b/zipline/utils/input_validation.py @@ -0,0 +1,124 @@ +# Copyright 2015 Quantopian, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from six import iteritems, string_types +from toolz import valmap + +from zipline.utils.preprocess import preprocess + + +def ensure_upper_case(func, argname, arg): + if isinstance(arg, string_types): + return arg.upper() + else: + raise TypeError( + "{0}() expected argument '{1}' to" + " be a string, but got {2} instead.".format( + func.__name__, argname, arg,) + ) + + +def expect_types(*_pos, **named): + """ + Preprocessing decorator that verifies inputs have expected types. + + Usage + ----- + >>> @expect_types(x=int, y=str) + ... def foo(x, y): + ... return x, y + ... + >>> foo(2, '3') + (2, '3') + >>> foo(2.0, '3') + Traceback (most recent call last): + ... + TypeError: foo() expected an argument of type 'int' for argument 'x', but got float instead. # noqa + """ + if _pos: + raise TypeError("expect_types() only takes keyword arguments.") + + for name, type_ in iteritems(named): + if not isinstance(type_, (type, tuple)): + raise TypeError( + "expect_types() expected a type or tuple of types for " + "argument '{name}', but got {type_} instead.".format( + name=name, type_=type_, + ) + ) + + return preprocess(**valmap(_expect_type, named)) + + +def _qualified_name(obj): + """ + Return the fully-qualified name (ignoring inner classes) of a type. + """ + module = obj.__module__ + if module in ('__builtin__', '__main__', 'builtins'): + return obj.__name__ + return '.'.join([module, obj.__name__]) + + +def _expect_type(type_): + """ + Factory for type-checking functions that work the @preprocess decorator. + """ + # Slightly different messages for type and tuple of types. + _template = ( + "{{funcname}}() expected a value of type {type_or_types} " + "for argument '{{argname}}', but got {{actual}} instead." + ) + if isinstance(type_, tuple): + template = _template.format( + type_or_types=' or '.join(map(_qualified_name, type_)) + ) + else: + template = _template.format(type_or_types=_qualified_name(type_)) + + def _check_type(func, argname, argvalue): + if not isinstance(argvalue, type_): + raise TypeError( + template.format( + funcname=_qualified_name(func), + argname=argname, + actual=_qualified_name(type(argvalue)), + ) + ) + return argvalue + return _check_type + + +def optional(type_): + """ + Helper for use with `expect_types` when an input can be `type_` or `None`. + + Returns an object such that both `None` and instances of `type_` pass + checks of the form `isinstance(obj, optional(type_))`. + + Parameters + ---------- + type_ : type + Type for which to produce an option. + + Examples + -------- + >>> isinstance({}, optional(dict)) + True + >>> isinstance(None, optional(dict)) + True + >>> isinstance(1, optional(dict)) + False + """ + return (type_, type(None)) diff --git a/zipline/utils/preprocess.py b/zipline/utils/preprocess.py index f9607bdb..f3cc0679 100644 --- a/zipline/utils/preprocess.py +++ b/zipline/utils/preprocess.py @@ -6,45 +6,12 @@ from functools import wraps from inspect import getargspec from uuid import uuid4 -from six import iteritems, viewkeys, exec_ -from toolz import valmap +from six import viewkeys, exec_ NO_DEFAULT = object() -def expect_types(*_pos, **named): - """ - Preprocessing decorator that verifies inputs have expected types. - - Usage - ----- - >>> @expect_types(x=int, y=str) - ... def foo(x, y): - ... return x, y - ... - >>> foo(2, '3') - (2, '3') - >>> foo(2.0, '3') - Traceback (most recent call last): - ... - TypeError: foo() expected an argument of type 'int' for argument 'x', but got float instead. # noqa - """ - if _pos: - raise TypeError("expect_types() only takes keyword arguments.") - - for name, type_ in iteritems(named): - if not isinstance(type_, (type, tuple)): - raise TypeError( - "expect_types() expected a type or tuple of types for " - "argument '{name}', but got {type_} instead.".format( - name=name, type_=type_, - ) - ) - - return preprocess(**valmap(_expect_type, named)) - - def preprocess(*_unused, **processors): """ Decorator that applies pre-processors to the arguments of a function before @@ -157,69 +124,6 @@ def call(f): return processor -def _qualified_name(obj): - """ - Return the fully-qualified name (ignoring inner classes) of a type. - """ - module = obj.__module__ - if module in ('__builtin__', '__main__', 'builtins'): - return obj.__name__ - return '.'.join([module, obj.__name__]) - - -def _expect_type(type_): - """ - Factory for type-checking functions that work the @preprocess decorator. - """ - # Slightly different messages for type and tuple of types. - _template = ( - "{{funcname}}() expected a value of type {type_or_types} " - "for argument '{{argname}}', but got {{actual}} instead." - ) - if isinstance(type_, tuple): - template = _template.format( - type_or_types=' or '.join(map(_qualified_name, type_)) - ) - else: - template = _template.format(type_or_types=_qualified_name(type_)) - - def _check_type(func, argname, argvalue): - if not isinstance(argvalue, type_): - raise TypeError( - template.format( - funcname=_qualified_name(func), - argname=argname, - actual=_qualified_name(type(argvalue)), - ) - ) - return argvalue - return _check_type - - -def optional(type_): - """ - Helper for use with `expect_types` when an input can be `type_` or `None`. - - Returns an object such that both `None` and instances of `type_` pass - checks of the form `isinstance(obj, optional(type_))`. - - Parameters - ---------- - type_ : type - Type for which to produce an option. - - Examples - -------- - >>> isinstance({}, optional(dict)) - True - >>> isinstance(None, optional(dict)) - True - >>> isinstance(1, optional(dict)) - False - """ - return (type_, type(None)) - - def _build_preprocessed_function(func, processors, args_defaults): """ Build a preprocessed function with the same signature as `func`.