From 4e2039c9b0fd0448db74fb892cf1df4cb6e3a1c2 Mon Sep 17 00:00:00 2001 From: Stewart Douglas Date: Mon, 5 Oct 2015 17:46:02 -0400 Subject: [PATCH] ENH: Coerce user input with API method decorator Previously we have capitalized input strings at different levels in our code: in the user-facing API methods and in the asset finder. This commit moves input string capitalization exclusively to the API method to which the string was supplied. Specifically, the string is capitalized by a preprocess API method decorator. The preprocess decorator passes the input string to the newly defined ensure_upper_case() method, which returns a TypeError if the argument supplied is not a string. ensure_upper_case() is defined in a new file, zipline/utils/input_validation.py. The existing expect_types() method is also moved there. Various tests in tests/test_assets.py are modified to account for the fact that the asset finder method lookup_symol() no longer capitalizes its supplied argument. --- tests/test_algorithm.py | 51 ++++++++++++ tests/test_assets.py | 34 ++++---- tests/utils/test_preprocess.py | 3 +- zipline/algorithm.py | 7 +- zipline/assets/assets.py | 1 - zipline/pipeline/pipeline.py | 2 +- zipline/utils/input_validation.py | 124 ++++++++++++++++++++++++++++++ zipline/utils/preprocess.py | 98 +---------------------- 8 files changed, 202 insertions(+), 118 deletions(-) create mode 100644 zipline/utils/input_validation.py diff --git a/tests/test_algorithm.py b/tests/test_algorithm.py index 53a20df9..e8d67335 100644 --- a/tests/test_algorithm.py +++ b/tests/test_algorithm.py @@ -425,6 +425,23 @@ class TestMiscellaneousAPI(TestCase): self.assertIsInstance(algo.sid(3), Equity) self.assertIsInstance(algo.sid(4), Equity) + # Supplying a non-string argument to symbol() + # should result in a TypeError. + with self.assertRaises(TypeError): + algo.symbol(1) + + with self.assertRaises(TypeError): + algo.symbol((1,)) + + with self.assertRaises(TypeError): + algo.symbol({1}) + + with self.assertRaises(TypeError): + algo.symbol([1]) + + with self.assertRaises(TypeError): + algo.symbol({'foo': 'bar'}) + def test_future_symbol(self): """ Tests the future_symbol API function. """ @@ -450,6 +467,23 @@ class TestMiscellaneousAPI(TestCase): with self.assertRaises(SymbolNotFound): algo.future_symbol('FOOBAR') + # Supplying a non-string argument to future_symbol() + # should result in a TypeError. + with self.assertRaises(TypeError): + algo.future_symbol(1) + + with self.assertRaises(TypeError): + algo.future_symbol((1,)) + + with self.assertRaises(TypeError): + algo.future_symbol({1}) + + with self.assertRaises(TypeError): + algo.future_symbol([1]) + + with self.assertRaises(TypeError): + algo.future_symbol({'foo': 'bar'}) + def test_future_chain(self): """ Tests the future_chain API function. """ @@ -493,6 +527,23 @@ class TestMiscellaneousAPI(TestCase): with self.assertRaises(UnsupportedDatetimeFormat): algo.future_chain('CL', '2015-09-') + # Supplying a non-string argument to future_chain() + # should result in a TypeError. + with self.assertRaises(TypeError): + algo.future_chain(1) + + with self.assertRaises(TypeError): + algo.future_chain((1,)) + + with self.assertRaises(TypeError): + algo.future_chain({1}) + + with self.assertRaises(TypeError): + algo.future_chain([1]) + + with self.assertRaises(TypeError): + algo.future_chain({'foo': 'bar'}) + def test_set_symbol_lookup_date(self): """ Test the set_symbol_lookup_date API method. diff --git a/tests/test_assets.py b/tests/test_assets.py index a34292c7..ae873f8a 100644 --- a/tests/test_assets.py +++ b/tests/test_assets.py @@ -109,11 +109,11 @@ def build_lookup_generic_cases(): (finder, 1, None, assets[1]), (finder, 2, None, assets[2]), # Duplicated symbol with resolution date - (finder, 'duplicated', dupe_0_start, dupe_0), - (finder, 'duplicated', dupe_1_start, dupe_1), + (finder, 'DUPLICATED', dupe_0_start, dupe_0), + (finder, 'DUPLICATED', dupe_1_start, dupe_1), # Unique symbol, with or without resolution date. - (finder, 'unique', unique_start, unique), - (finder, 'unique', None, unique), + (finder, 'UNIQUE', unique_start, unique), + (finder, 'UNIQUE', None, unique), ## # Iterables @@ -125,11 +125,11 @@ def build_lookup_generic_cases(): (finder, (0, 1), None, assets[:-1]), (finder, iter((0, 1)), None, assets[:-1]), # Iterables of symbols. - (finder, ('duplicated', 'unique'), dupe_0_start, [dupe_0, unique]), - (finder, ('duplicated', 'unique'), dupe_1_start, [dupe_1, unique]), + (finder, ('DUPLICATED', 'UNIQUE'), dupe_0_start, [dupe_0, unique]), + (finder, ('DUPLICATED', 'UNIQUE'), dupe_1_start, [dupe_1, unique]), # Mixed types (finder, - ('duplicated', 2, 'unique', 1, dupe_1), + ('DUPLICATED', 2, 'UNIQUE', 1, dupe_1), dupe_0_start, [dupe_0, assets[2], unique, assets[1], dupe_1]), ] @@ -360,18 +360,18 @@ class AssetFinderTestCase(TestCase): # we do it twice to catch caching bugs for i in range(2): with self.assertRaises(SymbolNotFound): - finder.lookup_symbol('test', as_of) + finder.lookup_symbol('TEST', as_of) with self.assertRaises(SymbolNotFound): - finder.lookup_symbol('test1', as_of) + finder.lookup_symbol('TEST1', as_of) # '@' is not a supported delimiter with self.assertRaises(SymbolNotFound): - finder.lookup_symbol('test@1', as_of) + finder.lookup_symbol('TEST@1', as_of) # Adding an unnecessary fuzzy shouldn't matter. for fuzzy_char in ['-', '/', '_', '.']: self.assertEqual( asset_1, - finder.lookup_symbol('test%s1' % fuzzy_char, as_of) + finder.lookup_symbol('TEST%s1' % fuzzy_char, as_of) ) def test_lookup_symbol_fuzzy(self): @@ -434,15 +434,15 @@ class AssetFinderTestCase(TestCase): finder = AssetFinder(self.env.engine) for _ in range(2): # Run checks twice to test for caching bugs. with self.assertRaises(SymbolNotFound): - finder.lookup_symbol('non_existing', dates[0]) + finder.lookup_symbol('NON_EXISTING', dates[0]) with self.assertRaises(MultipleSymbolsFound): - finder.lookup_symbol('existing', None) + finder.lookup_symbol('EXISTING', None) for i, date in enumerate(dates): # Verify that we correctly resolve multiple symbols using # the supplied date - result = finder.lookup_symbol('existing', date) + result = finder.lookup_symbol('EXISTING', date) self.assertEqual(result.symbol, 'EXISTING') self.assertEqual(result.sid, i) @@ -497,7 +497,7 @@ class AssetFinderTestCase(TestCase): self.env.write_data(equities_df=data) finder = AssetFinder(self.env.engine) results, missing = finder.lookup_generic( - ['real', 1, 'fake', 'real_but_old', 'real_but_in_the_future'], + ['REAL', 1, 'FAKE', 'REAL_BUT_OLD', 'REAL_BUT_IN_THE_FUTURE'], pd.Timestamp('2013-02-01', tz='UTC'), ) @@ -510,8 +510,8 @@ class AssetFinderTestCase(TestCase): self.assertEqual(results[2].sid, 2) self.assertEqual(len(missing), 2) - self.assertEqual(missing[0], 'fake') - self.assertEqual(missing[1], 'real_but_in_the_future') + self.assertEqual(missing[0], 'FAKE') + self.assertEqual(missing[1], 'REAL_BUT_IN_THE_FUTURE') def test_insert_metadata(self): data = {0: {'asset_type': 'equity', diff --git a/tests/utils/test_preprocess.py b/tests/utils/test_preprocess.py index b76538e9..4ed686fb 100644 --- a/tests/utils/test_preprocess.py +++ b/tests/utils/test_preprocess.py @@ -5,7 +5,8 @@ from types import FunctionType from unittest import TestCase from nose_parameterized import parameterized -from zipline.utils.preprocess import call, expect_types, preprocess, optional +from zipline.utils.preprocess import call, preprocess +from zipline.utils.input_validation import expect_types, optional def noop(func, argname, argvalue): diff --git a/zipline/algorithm.py b/zipline/algorithm.py index 5fd87da7..c6e967a3 100644 --- a/zipline/algorithm.py +++ b/zipline/algorithm.py @@ -85,6 +85,7 @@ from zipline.utils.api_support import ( require_not_initialized, ZiplineAPI, ) +from zipline.utils.input_validation import ensure_upper_case from zipline.utils.cache import CachedObject, Expired import zipline.utils.events from zipline.utils.events import ( @@ -95,6 +96,7 @@ from zipline.utils.events import ( ) from zipline.utils.factory import create_simulation_parameters from zipline.utils.math_utils import tolerant_equals +from zipline.utils.preprocess import preprocess import zipline.protocol from zipline.protocol import Event @@ -738,6 +740,7 @@ class TradingAlgorithm(object): self._recorded_vars[name] = value @api_method + @preprocess(symbol_str=ensure_upper_case) def symbol(self, symbol_str): """ Default symbol lookup for any source that directly maps the @@ -770,6 +773,7 @@ class TradingAlgorithm(object): return self.asset_finder.retrieve_asset(a_sid) @api_method + @preprocess(symbol=ensure_upper_case) def future_symbol(self, symbol): """ Lookup a futures contract with a given symbol. @@ -792,6 +796,7 @@ class TradingAlgorithm(object): return self.asset_finder.lookup_future_symbol(symbol) @api_method + @preprocess(root_symbol=ensure_upper_case) def future_chain(self, root_symbol, as_of_date=None): """ Look up a future chain with the specified parameters. @@ -823,7 +828,7 @@ class TradingAlgorithm(object): return FutureChain( asset_finder=self.asset_finder, get_datetime=self.get_datetime, - root_symbol=root_symbol.upper(), + root_symbol=root_symbol, as_of_date=as_of_date ) diff --git a/zipline/assets/assets.py b/zipline/assets/assets.py index 5b19ccea..12eb68d1 100644 --- a/zipline/assets/assets.py +++ b/zipline/assets/assets.py @@ -209,7 +209,6 @@ class AssetFinder(object): """ # Format inputs - symbol = symbol.upper() if as_of_date is not None: as_of_date = pd.Timestamp(normalize_date(as_of_date)) diff --git a/zipline/pipeline/pipeline.py b/zipline/pipeline/pipeline.py index 6b3f87f4..d64166d6 100644 --- a/zipline/pipeline/pipeline.py +++ b/zipline/pipeline/pipeline.py @@ -1,4 +1,4 @@ -from zipline.utils.preprocess import expect_types, optional +from zipline.utils.input_validation import expect_types, optional from .term import Term from .filters import Filter diff --git a/zipline/utils/input_validation.py b/zipline/utils/input_validation.py new file mode 100644 index 00000000..9ee0e109 --- /dev/null +++ b/zipline/utils/input_validation.py @@ -0,0 +1,124 @@ +# Copyright 2015 Quantopian, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from six import iteritems, string_types +from toolz import valmap + +from zipline.utils.preprocess import preprocess + + +def ensure_upper_case(func, argname, arg): + if isinstance(arg, string_types): + return arg.upper() + else: + raise TypeError( + "{0}() expected argument '{1}' to" + " be a string, but got {2} instead.".format( + func.__name__, argname, arg,) + ) + + +def expect_types(*_pos, **named): + """ + Preprocessing decorator that verifies inputs have expected types. + + Usage + ----- + >>> @expect_types(x=int, y=str) + ... def foo(x, y): + ... return x, y + ... + >>> foo(2, '3') + (2, '3') + >>> foo(2.0, '3') + Traceback (most recent call last): + ... + TypeError: foo() expected an argument of type 'int' for argument 'x', but got float instead. # noqa + """ + if _pos: + raise TypeError("expect_types() only takes keyword arguments.") + + for name, type_ in iteritems(named): + if not isinstance(type_, (type, tuple)): + raise TypeError( + "expect_types() expected a type or tuple of types for " + "argument '{name}', but got {type_} instead.".format( + name=name, type_=type_, + ) + ) + + return preprocess(**valmap(_expect_type, named)) + + +def _qualified_name(obj): + """ + Return the fully-qualified name (ignoring inner classes) of a type. + """ + module = obj.__module__ + if module in ('__builtin__', '__main__', 'builtins'): + return obj.__name__ + return '.'.join([module, obj.__name__]) + + +def _expect_type(type_): + """ + Factory for type-checking functions that work the @preprocess decorator. + """ + # Slightly different messages for type and tuple of types. + _template = ( + "{{funcname}}() expected a value of type {type_or_types} " + "for argument '{{argname}}', but got {{actual}} instead." + ) + if isinstance(type_, tuple): + template = _template.format( + type_or_types=' or '.join(map(_qualified_name, type_)) + ) + else: + template = _template.format(type_or_types=_qualified_name(type_)) + + def _check_type(func, argname, argvalue): + if not isinstance(argvalue, type_): + raise TypeError( + template.format( + funcname=_qualified_name(func), + argname=argname, + actual=_qualified_name(type(argvalue)), + ) + ) + return argvalue + return _check_type + + +def optional(type_): + """ + Helper for use with `expect_types` when an input can be `type_` or `None`. + + Returns an object such that both `None` and instances of `type_` pass + checks of the form `isinstance(obj, optional(type_))`. + + Parameters + ---------- + type_ : type + Type for which to produce an option. + + Examples + -------- + >>> isinstance({}, optional(dict)) + True + >>> isinstance(None, optional(dict)) + True + >>> isinstance(1, optional(dict)) + False + """ + return (type_, type(None)) diff --git a/zipline/utils/preprocess.py b/zipline/utils/preprocess.py index f9607bdb..f3cc0679 100644 --- a/zipline/utils/preprocess.py +++ b/zipline/utils/preprocess.py @@ -6,45 +6,12 @@ from functools import wraps from inspect import getargspec from uuid import uuid4 -from six import iteritems, viewkeys, exec_ -from toolz import valmap +from six import viewkeys, exec_ NO_DEFAULT = object() -def expect_types(*_pos, **named): - """ - Preprocessing decorator that verifies inputs have expected types. - - Usage - ----- - >>> @expect_types(x=int, y=str) - ... def foo(x, y): - ... return x, y - ... - >>> foo(2, '3') - (2, '3') - >>> foo(2.0, '3') - Traceback (most recent call last): - ... - TypeError: foo() expected an argument of type 'int' for argument 'x', but got float instead. # noqa - """ - if _pos: - raise TypeError("expect_types() only takes keyword arguments.") - - for name, type_ in iteritems(named): - if not isinstance(type_, (type, tuple)): - raise TypeError( - "expect_types() expected a type or tuple of types for " - "argument '{name}', but got {type_} instead.".format( - name=name, type_=type_, - ) - ) - - return preprocess(**valmap(_expect_type, named)) - - def preprocess(*_unused, **processors): """ Decorator that applies pre-processors to the arguments of a function before @@ -157,69 +124,6 @@ def call(f): return processor -def _qualified_name(obj): - """ - Return the fully-qualified name (ignoring inner classes) of a type. - """ - module = obj.__module__ - if module in ('__builtin__', '__main__', 'builtins'): - return obj.__name__ - return '.'.join([module, obj.__name__]) - - -def _expect_type(type_): - """ - Factory for type-checking functions that work the @preprocess decorator. - """ - # Slightly different messages for type and tuple of types. - _template = ( - "{{funcname}}() expected a value of type {type_or_types} " - "for argument '{{argname}}', but got {{actual}} instead." - ) - if isinstance(type_, tuple): - template = _template.format( - type_or_types=' or '.join(map(_qualified_name, type_)) - ) - else: - template = _template.format(type_or_types=_qualified_name(type_)) - - def _check_type(func, argname, argvalue): - if not isinstance(argvalue, type_): - raise TypeError( - template.format( - funcname=_qualified_name(func), - argname=argname, - actual=_qualified_name(type(argvalue)), - ) - ) - return argvalue - return _check_type - - -def optional(type_): - """ - Helper for use with `expect_types` when an input can be `type_` or `None`. - - Returns an object such that both `None` and instances of `type_` pass - checks of the form `isinstance(obj, optional(type_))`. - - Parameters - ---------- - type_ : type - Type for which to produce an option. - - Examples - -------- - >>> isinstance({}, optional(dict)) - True - >>> isinstance(None, optional(dict)) - True - >>> isinstance(1, optional(dict)) - False - """ - return (type_, type(None)) - - def _build_preprocessed_function(func, processors, args_defaults): """ Build a preprocessed function with the same signature as `func`.