ENH: Coerce user input with API method decorator

Previously we have capitalized input strings at different levels in
our code: in the user-facing API methods and in the asset finder.
This commit moves input string capitalization exclusively to the API
method to which the string was supplied. Specifically, the string is
capitalized by a preprocess API method decorator. The preprocess
decorator passes the input string to the newly defined
ensure_upper_case() method, which returns a TypeError if the argument
supplied is not a string.

ensure_upper_case() is defined in a new file, zipline/utils/input_validation.py.
The existing expect_types() method is also moved there.

Various tests in tests/test_assets.py are modified to account for the
fact that the asset finder method lookup_symol() no longer capitalizes
its supplied argument.
This commit is contained in:
Stewart Douglas
2015-10-05 17:46:02 -04:00
parent 69b734129f
commit 4e2039c9b0
8 changed files with 202 additions and 118 deletions
+51
View File
@@ -425,6 +425,23 @@ class TestMiscellaneousAPI(TestCase):
self.assertIsInstance(algo.sid(3), Equity)
self.assertIsInstance(algo.sid(4), Equity)
# Supplying a non-string argument to symbol()
# should result in a TypeError.
with self.assertRaises(TypeError):
algo.symbol(1)
with self.assertRaises(TypeError):
algo.symbol((1,))
with self.assertRaises(TypeError):
algo.symbol({1})
with self.assertRaises(TypeError):
algo.symbol([1])
with self.assertRaises(TypeError):
algo.symbol({'foo': 'bar'})
def test_future_symbol(self):
""" Tests the future_symbol API function.
"""
@@ -450,6 +467,23 @@ class TestMiscellaneousAPI(TestCase):
with self.assertRaises(SymbolNotFound):
algo.future_symbol('FOOBAR')
# Supplying a non-string argument to future_symbol()
# should result in a TypeError.
with self.assertRaises(TypeError):
algo.future_symbol(1)
with self.assertRaises(TypeError):
algo.future_symbol((1,))
with self.assertRaises(TypeError):
algo.future_symbol({1})
with self.assertRaises(TypeError):
algo.future_symbol([1])
with self.assertRaises(TypeError):
algo.future_symbol({'foo': 'bar'})
def test_future_chain(self):
""" Tests the future_chain API function.
"""
@@ -493,6 +527,23 @@ class TestMiscellaneousAPI(TestCase):
with self.assertRaises(UnsupportedDatetimeFormat):
algo.future_chain('CL', '2015-09-')
# Supplying a non-string argument to future_chain()
# should result in a TypeError.
with self.assertRaises(TypeError):
algo.future_chain(1)
with self.assertRaises(TypeError):
algo.future_chain((1,))
with self.assertRaises(TypeError):
algo.future_chain({1})
with self.assertRaises(TypeError):
algo.future_chain([1])
with self.assertRaises(TypeError):
algo.future_chain({'foo': 'bar'})
def test_set_symbol_lookup_date(self):
"""
Test the set_symbol_lookup_date API method.
+17 -17
View File
@@ -109,11 +109,11 @@ def build_lookup_generic_cases():
(finder, 1, None, assets[1]),
(finder, 2, None, assets[2]),
# Duplicated symbol with resolution date
(finder, 'duplicated', dupe_0_start, dupe_0),
(finder, 'duplicated', dupe_1_start, dupe_1),
(finder, 'DUPLICATED', dupe_0_start, dupe_0),
(finder, 'DUPLICATED', dupe_1_start, dupe_1),
# Unique symbol, with or without resolution date.
(finder, 'unique', unique_start, unique),
(finder, 'unique', None, unique),
(finder, 'UNIQUE', unique_start, unique),
(finder, 'UNIQUE', None, unique),
##
# Iterables
@@ -125,11 +125,11 @@ def build_lookup_generic_cases():
(finder, (0, 1), None, assets[:-1]),
(finder, iter((0, 1)), None, assets[:-1]),
# Iterables of symbols.
(finder, ('duplicated', 'unique'), dupe_0_start, [dupe_0, unique]),
(finder, ('duplicated', 'unique'), dupe_1_start, [dupe_1, unique]),
(finder, ('DUPLICATED', 'UNIQUE'), dupe_0_start, [dupe_0, unique]),
(finder, ('DUPLICATED', 'UNIQUE'), dupe_1_start, [dupe_1, unique]),
# Mixed types
(finder,
('duplicated', 2, 'unique', 1, dupe_1),
('DUPLICATED', 2, 'UNIQUE', 1, dupe_1),
dupe_0_start,
[dupe_0, assets[2], unique, assets[1], dupe_1]),
]
@@ -360,18 +360,18 @@ class AssetFinderTestCase(TestCase):
# we do it twice to catch caching bugs
for i in range(2):
with self.assertRaises(SymbolNotFound):
finder.lookup_symbol('test', as_of)
finder.lookup_symbol('TEST', as_of)
with self.assertRaises(SymbolNotFound):
finder.lookup_symbol('test1', as_of)
finder.lookup_symbol('TEST1', as_of)
# '@' is not a supported delimiter
with self.assertRaises(SymbolNotFound):
finder.lookup_symbol('test@1', as_of)
finder.lookup_symbol('TEST@1', as_of)
# Adding an unnecessary fuzzy shouldn't matter.
for fuzzy_char in ['-', '/', '_', '.']:
self.assertEqual(
asset_1,
finder.lookup_symbol('test%s1' % fuzzy_char, as_of)
finder.lookup_symbol('TEST%s1' % fuzzy_char, as_of)
)
def test_lookup_symbol_fuzzy(self):
@@ -434,15 +434,15 @@ class AssetFinderTestCase(TestCase):
finder = AssetFinder(self.env.engine)
for _ in range(2): # Run checks twice to test for caching bugs.
with self.assertRaises(SymbolNotFound):
finder.lookup_symbol('non_existing', dates[0])
finder.lookup_symbol('NON_EXISTING', dates[0])
with self.assertRaises(MultipleSymbolsFound):
finder.lookup_symbol('existing', None)
finder.lookup_symbol('EXISTING', None)
for i, date in enumerate(dates):
# Verify that we correctly resolve multiple symbols using
# the supplied date
result = finder.lookup_symbol('existing', date)
result = finder.lookup_symbol('EXISTING', date)
self.assertEqual(result.symbol, 'EXISTING')
self.assertEqual(result.sid, i)
@@ -497,7 +497,7 @@ class AssetFinderTestCase(TestCase):
self.env.write_data(equities_df=data)
finder = AssetFinder(self.env.engine)
results, missing = finder.lookup_generic(
['real', 1, 'fake', 'real_but_old', 'real_but_in_the_future'],
['REAL', 1, 'FAKE', 'REAL_BUT_OLD', 'REAL_BUT_IN_THE_FUTURE'],
pd.Timestamp('2013-02-01', tz='UTC'),
)
@@ -510,8 +510,8 @@ class AssetFinderTestCase(TestCase):
self.assertEqual(results[2].sid, 2)
self.assertEqual(len(missing), 2)
self.assertEqual(missing[0], 'fake')
self.assertEqual(missing[1], 'real_but_in_the_future')
self.assertEqual(missing[0], 'FAKE')
self.assertEqual(missing[1], 'REAL_BUT_IN_THE_FUTURE')
def test_insert_metadata(self):
data = {0: {'asset_type': 'equity',
+2 -1
View File
@@ -5,7 +5,8 @@ from types import FunctionType
from unittest import TestCase
from nose_parameterized import parameterized
from zipline.utils.preprocess import call, expect_types, preprocess, optional
from zipline.utils.preprocess import call, preprocess
from zipline.utils.input_validation import expect_types, optional
def noop(func, argname, argvalue):
+6 -1
View File
@@ -85,6 +85,7 @@ from zipline.utils.api_support import (
require_not_initialized,
ZiplineAPI,
)
from zipline.utils.input_validation import ensure_upper_case
from zipline.utils.cache import CachedObject, Expired
import zipline.utils.events
from zipline.utils.events import (
@@ -95,6 +96,7 @@ from zipline.utils.events import (
)
from zipline.utils.factory import create_simulation_parameters
from zipline.utils.math_utils import tolerant_equals
from zipline.utils.preprocess import preprocess
import zipline.protocol
from zipline.protocol import Event
@@ -738,6 +740,7 @@ class TradingAlgorithm(object):
self._recorded_vars[name] = value
@api_method
@preprocess(symbol_str=ensure_upper_case)
def symbol(self, symbol_str):
"""
Default symbol lookup for any source that directly maps the
@@ -770,6 +773,7 @@ class TradingAlgorithm(object):
return self.asset_finder.retrieve_asset(a_sid)
@api_method
@preprocess(symbol=ensure_upper_case)
def future_symbol(self, symbol):
""" Lookup a futures contract with a given symbol.
@@ -792,6 +796,7 @@ class TradingAlgorithm(object):
return self.asset_finder.lookup_future_symbol(symbol)
@api_method
@preprocess(root_symbol=ensure_upper_case)
def future_chain(self, root_symbol, as_of_date=None):
""" Look up a future chain with the specified parameters.
@@ -823,7 +828,7 @@ class TradingAlgorithm(object):
return FutureChain(
asset_finder=self.asset_finder,
get_datetime=self.get_datetime,
root_symbol=root_symbol.upper(),
root_symbol=root_symbol,
as_of_date=as_of_date
)
-1
View File
@@ -209,7 +209,6 @@ class AssetFinder(object):
"""
# Format inputs
symbol = symbol.upper()
if as_of_date is not None:
as_of_date = pd.Timestamp(normalize_date(as_of_date))
+1 -1
View File
@@ -1,4 +1,4 @@
from zipline.utils.preprocess import expect_types, optional
from zipline.utils.input_validation import expect_types, optional
from .term import Term
from .filters import Filter
+124
View File
@@ -0,0 +1,124 @@
# Copyright 2015 Quantopian, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from six import iteritems, string_types
from toolz import valmap
from zipline.utils.preprocess import preprocess
def ensure_upper_case(func, argname, arg):
if isinstance(arg, string_types):
return arg.upper()
else:
raise TypeError(
"{0}() expected argument '{1}' to"
" be a string, but got {2} instead.".format(
func.__name__, argname, arg,)
)
def expect_types(*_pos, **named):
"""
Preprocessing decorator that verifies inputs have expected types.
Usage
-----
>>> @expect_types(x=int, y=str)
... def foo(x, y):
... return x, y
...
>>> foo(2, '3')
(2, '3')
>>> foo(2.0, '3')
Traceback (most recent call last):
...
TypeError: foo() expected an argument of type 'int' for argument 'x', but got float instead. # noqa
"""
if _pos:
raise TypeError("expect_types() only takes keyword arguments.")
for name, type_ in iteritems(named):
if not isinstance(type_, (type, tuple)):
raise TypeError(
"expect_types() expected a type or tuple of types for "
"argument '{name}', but got {type_} instead.".format(
name=name, type_=type_,
)
)
return preprocess(**valmap(_expect_type, named))
def _qualified_name(obj):
"""
Return the fully-qualified name (ignoring inner classes) of a type.
"""
module = obj.__module__
if module in ('__builtin__', '__main__', 'builtins'):
return obj.__name__
return '.'.join([module, obj.__name__])
def _expect_type(type_):
"""
Factory for type-checking functions that work the @preprocess decorator.
"""
# Slightly different messages for type and tuple of types.
_template = (
"{{funcname}}() expected a value of type {type_or_types} "
"for argument '{{argname}}', but got {{actual}} instead."
)
if isinstance(type_, tuple):
template = _template.format(
type_or_types=' or '.join(map(_qualified_name, type_))
)
else:
template = _template.format(type_or_types=_qualified_name(type_))
def _check_type(func, argname, argvalue):
if not isinstance(argvalue, type_):
raise TypeError(
template.format(
funcname=_qualified_name(func),
argname=argname,
actual=_qualified_name(type(argvalue)),
)
)
return argvalue
return _check_type
def optional(type_):
"""
Helper for use with `expect_types` when an input can be `type_` or `None`.
Returns an object such that both `None` and instances of `type_` pass
checks of the form `isinstance(obj, optional(type_))`.
Parameters
----------
type_ : type
Type for which to produce an option.
Examples
--------
>>> isinstance({}, optional(dict))
True
>>> isinstance(None, optional(dict))
True
>>> isinstance(1, optional(dict))
False
"""
return (type_, type(None))
+1 -97
View File
@@ -6,45 +6,12 @@ from functools import wraps
from inspect import getargspec
from uuid import uuid4
from six import iteritems, viewkeys, exec_
from toolz import valmap
from six import viewkeys, exec_
NO_DEFAULT = object()
def expect_types(*_pos, **named):
"""
Preprocessing decorator that verifies inputs have expected types.
Usage
-----
>>> @expect_types(x=int, y=str)
... def foo(x, y):
... return x, y
...
>>> foo(2, '3')
(2, '3')
>>> foo(2.0, '3')
Traceback (most recent call last):
...
TypeError: foo() expected an argument of type 'int' for argument 'x', but got float instead. # noqa
"""
if _pos:
raise TypeError("expect_types() only takes keyword arguments.")
for name, type_ in iteritems(named):
if not isinstance(type_, (type, tuple)):
raise TypeError(
"expect_types() expected a type or tuple of types for "
"argument '{name}', but got {type_} instead.".format(
name=name, type_=type_,
)
)
return preprocess(**valmap(_expect_type, named))
def preprocess(*_unused, **processors):
"""
Decorator that applies pre-processors to the arguments of a function before
@@ -157,69 +124,6 @@ def call(f):
return processor
def _qualified_name(obj):
"""
Return the fully-qualified name (ignoring inner classes) of a type.
"""
module = obj.__module__
if module in ('__builtin__', '__main__', 'builtins'):
return obj.__name__
return '.'.join([module, obj.__name__])
def _expect_type(type_):
"""
Factory for type-checking functions that work the @preprocess decorator.
"""
# Slightly different messages for type and tuple of types.
_template = (
"{{funcname}}() expected a value of type {type_or_types} "
"for argument '{{argname}}', but got {{actual}} instead."
)
if isinstance(type_, tuple):
template = _template.format(
type_or_types=' or '.join(map(_qualified_name, type_))
)
else:
template = _template.format(type_or_types=_qualified_name(type_))
def _check_type(func, argname, argvalue):
if not isinstance(argvalue, type_):
raise TypeError(
template.format(
funcname=_qualified_name(func),
argname=argname,
actual=_qualified_name(type(argvalue)),
)
)
return argvalue
return _check_type
def optional(type_):
"""
Helper for use with `expect_types` when an input can be `type_` or `None`.
Returns an object such that both `None` and instances of `type_` pass
checks of the form `isinstance(obj, optional(type_))`.
Parameters
----------
type_ : type
Type for which to produce an option.
Examples
--------
>>> isinstance({}, optional(dict))
True
>>> isinstance(None, optional(dict))
True
>>> isinstance(1, optional(dict))
False
"""
return (type_, type(None))
def _build_preprocessed_function(func, processors, args_defaults):
"""
Build a preprocessed function with the same signature as `func`.