mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-01 04:14:47 +08:00
ENH: Coerce user input with API method decorator
Previously we have capitalized input strings at different levels in our code: in the user-facing API methods and in the asset finder. This commit moves input string capitalization exclusively to the API method to which the string was supplied. Specifically, the string is capitalized by a preprocess API method decorator. The preprocess decorator passes the input string to the newly defined ensure_upper_case() method, which returns a TypeError if the argument supplied is not a string. ensure_upper_case() is defined in a new file, zipline/utils/input_validation.py. The existing expect_types() method is also moved there. Various tests in tests/test_assets.py are modified to account for the fact that the asset finder method lookup_symol() no longer capitalizes its supplied argument.
This commit is contained in:
@@ -425,6 +425,23 @@ class TestMiscellaneousAPI(TestCase):
|
||||
self.assertIsInstance(algo.sid(3), Equity)
|
||||
self.assertIsInstance(algo.sid(4), Equity)
|
||||
|
||||
# Supplying a non-string argument to symbol()
|
||||
# should result in a TypeError.
|
||||
with self.assertRaises(TypeError):
|
||||
algo.symbol(1)
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
algo.symbol((1,))
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
algo.symbol({1})
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
algo.symbol([1])
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
algo.symbol({'foo': 'bar'})
|
||||
|
||||
def test_future_symbol(self):
|
||||
""" Tests the future_symbol API function.
|
||||
"""
|
||||
@@ -450,6 +467,23 @@ class TestMiscellaneousAPI(TestCase):
|
||||
with self.assertRaises(SymbolNotFound):
|
||||
algo.future_symbol('FOOBAR')
|
||||
|
||||
# Supplying a non-string argument to future_symbol()
|
||||
# should result in a TypeError.
|
||||
with self.assertRaises(TypeError):
|
||||
algo.future_symbol(1)
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
algo.future_symbol((1,))
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
algo.future_symbol({1})
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
algo.future_symbol([1])
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
algo.future_symbol({'foo': 'bar'})
|
||||
|
||||
def test_future_chain(self):
|
||||
""" Tests the future_chain API function.
|
||||
"""
|
||||
@@ -493,6 +527,23 @@ class TestMiscellaneousAPI(TestCase):
|
||||
with self.assertRaises(UnsupportedDatetimeFormat):
|
||||
algo.future_chain('CL', '2015-09-')
|
||||
|
||||
# Supplying a non-string argument to future_chain()
|
||||
# should result in a TypeError.
|
||||
with self.assertRaises(TypeError):
|
||||
algo.future_chain(1)
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
algo.future_chain((1,))
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
algo.future_chain({1})
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
algo.future_chain([1])
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
algo.future_chain({'foo': 'bar'})
|
||||
|
||||
def test_set_symbol_lookup_date(self):
|
||||
"""
|
||||
Test the set_symbol_lookup_date API method.
|
||||
|
||||
+17
-17
@@ -109,11 +109,11 @@ def build_lookup_generic_cases():
|
||||
(finder, 1, None, assets[1]),
|
||||
(finder, 2, None, assets[2]),
|
||||
# Duplicated symbol with resolution date
|
||||
(finder, 'duplicated', dupe_0_start, dupe_0),
|
||||
(finder, 'duplicated', dupe_1_start, dupe_1),
|
||||
(finder, 'DUPLICATED', dupe_0_start, dupe_0),
|
||||
(finder, 'DUPLICATED', dupe_1_start, dupe_1),
|
||||
# Unique symbol, with or without resolution date.
|
||||
(finder, 'unique', unique_start, unique),
|
||||
(finder, 'unique', None, unique),
|
||||
(finder, 'UNIQUE', unique_start, unique),
|
||||
(finder, 'UNIQUE', None, unique),
|
||||
|
||||
##
|
||||
# Iterables
|
||||
@@ -125,11 +125,11 @@ def build_lookup_generic_cases():
|
||||
(finder, (0, 1), None, assets[:-1]),
|
||||
(finder, iter((0, 1)), None, assets[:-1]),
|
||||
# Iterables of symbols.
|
||||
(finder, ('duplicated', 'unique'), dupe_0_start, [dupe_0, unique]),
|
||||
(finder, ('duplicated', 'unique'), dupe_1_start, [dupe_1, unique]),
|
||||
(finder, ('DUPLICATED', 'UNIQUE'), dupe_0_start, [dupe_0, unique]),
|
||||
(finder, ('DUPLICATED', 'UNIQUE'), dupe_1_start, [dupe_1, unique]),
|
||||
# Mixed types
|
||||
(finder,
|
||||
('duplicated', 2, 'unique', 1, dupe_1),
|
||||
('DUPLICATED', 2, 'UNIQUE', 1, dupe_1),
|
||||
dupe_0_start,
|
||||
[dupe_0, assets[2], unique, assets[1], dupe_1]),
|
||||
]
|
||||
@@ -360,18 +360,18 @@ class AssetFinderTestCase(TestCase):
|
||||
# we do it twice to catch caching bugs
|
||||
for i in range(2):
|
||||
with self.assertRaises(SymbolNotFound):
|
||||
finder.lookup_symbol('test', as_of)
|
||||
finder.lookup_symbol('TEST', as_of)
|
||||
with self.assertRaises(SymbolNotFound):
|
||||
finder.lookup_symbol('test1', as_of)
|
||||
finder.lookup_symbol('TEST1', as_of)
|
||||
# '@' is not a supported delimiter
|
||||
with self.assertRaises(SymbolNotFound):
|
||||
finder.lookup_symbol('test@1', as_of)
|
||||
finder.lookup_symbol('TEST@1', as_of)
|
||||
|
||||
# Adding an unnecessary fuzzy shouldn't matter.
|
||||
for fuzzy_char in ['-', '/', '_', '.']:
|
||||
self.assertEqual(
|
||||
asset_1,
|
||||
finder.lookup_symbol('test%s1' % fuzzy_char, as_of)
|
||||
finder.lookup_symbol('TEST%s1' % fuzzy_char, as_of)
|
||||
)
|
||||
|
||||
def test_lookup_symbol_fuzzy(self):
|
||||
@@ -434,15 +434,15 @@ class AssetFinderTestCase(TestCase):
|
||||
finder = AssetFinder(self.env.engine)
|
||||
for _ in range(2): # Run checks twice to test for caching bugs.
|
||||
with self.assertRaises(SymbolNotFound):
|
||||
finder.lookup_symbol('non_existing', dates[0])
|
||||
finder.lookup_symbol('NON_EXISTING', dates[0])
|
||||
|
||||
with self.assertRaises(MultipleSymbolsFound):
|
||||
finder.lookup_symbol('existing', None)
|
||||
finder.lookup_symbol('EXISTING', None)
|
||||
|
||||
for i, date in enumerate(dates):
|
||||
# Verify that we correctly resolve multiple symbols using
|
||||
# the supplied date
|
||||
result = finder.lookup_symbol('existing', date)
|
||||
result = finder.lookup_symbol('EXISTING', date)
|
||||
self.assertEqual(result.symbol, 'EXISTING')
|
||||
self.assertEqual(result.sid, i)
|
||||
|
||||
@@ -497,7 +497,7 @@ class AssetFinderTestCase(TestCase):
|
||||
self.env.write_data(equities_df=data)
|
||||
finder = AssetFinder(self.env.engine)
|
||||
results, missing = finder.lookup_generic(
|
||||
['real', 1, 'fake', 'real_but_old', 'real_but_in_the_future'],
|
||||
['REAL', 1, 'FAKE', 'REAL_BUT_OLD', 'REAL_BUT_IN_THE_FUTURE'],
|
||||
pd.Timestamp('2013-02-01', tz='UTC'),
|
||||
)
|
||||
|
||||
@@ -510,8 +510,8 @@ class AssetFinderTestCase(TestCase):
|
||||
self.assertEqual(results[2].sid, 2)
|
||||
|
||||
self.assertEqual(len(missing), 2)
|
||||
self.assertEqual(missing[0], 'fake')
|
||||
self.assertEqual(missing[1], 'real_but_in_the_future')
|
||||
self.assertEqual(missing[0], 'FAKE')
|
||||
self.assertEqual(missing[1], 'REAL_BUT_IN_THE_FUTURE')
|
||||
|
||||
def test_insert_metadata(self):
|
||||
data = {0: {'asset_type': 'equity',
|
||||
|
||||
@@ -5,7 +5,8 @@ from types import FunctionType
|
||||
from unittest import TestCase
|
||||
from nose_parameterized import parameterized
|
||||
|
||||
from zipline.utils.preprocess import call, expect_types, preprocess, optional
|
||||
from zipline.utils.preprocess import call, preprocess
|
||||
from zipline.utils.input_validation import expect_types, optional
|
||||
|
||||
|
||||
def noop(func, argname, argvalue):
|
||||
|
||||
@@ -85,6 +85,7 @@ from zipline.utils.api_support import (
|
||||
require_not_initialized,
|
||||
ZiplineAPI,
|
||||
)
|
||||
from zipline.utils.input_validation import ensure_upper_case
|
||||
from zipline.utils.cache import CachedObject, Expired
|
||||
import zipline.utils.events
|
||||
from zipline.utils.events import (
|
||||
@@ -95,6 +96,7 @@ from zipline.utils.events import (
|
||||
)
|
||||
from zipline.utils.factory import create_simulation_parameters
|
||||
from zipline.utils.math_utils import tolerant_equals
|
||||
from zipline.utils.preprocess import preprocess
|
||||
|
||||
import zipline.protocol
|
||||
from zipline.protocol import Event
|
||||
@@ -738,6 +740,7 @@ class TradingAlgorithm(object):
|
||||
self._recorded_vars[name] = value
|
||||
|
||||
@api_method
|
||||
@preprocess(symbol_str=ensure_upper_case)
|
||||
def symbol(self, symbol_str):
|
||||
"""
|
||||
Default symbol lookup for any source that directly maps the
|
||||
@@ -770,6 +773,7 @@ class TradingAlgorithm(object):
|
||||
return self.asset_finder.retrieve_asset(a_sid)
|
||||
|
||||
@api_method
|
||||
@preprocess(symbol=ensure_upper_case)
|
||||
def future_symbol(self, symbol):
|
||||
""" Lookup a futures contract with a given symbol.
|
||||
|
||||
@@ -792,6 +796,7 @@ class TradingAlgorithm(object):
|
||||
return self.asset_finder.lookup_future_symbol(symbol)
|
||||
|
||||
@api_method
|
||||
@preprocess(root_symbol=ensure_upper_case)
|
||||
def future_chain(self, root_symbol, as_of_date=None):
|
||||
""" Look up a future chain with the specified parameters.
|
||||
|
||||
@@ -823,7 +828,7 @@ class TradingAlgorithm(object):
|
||||
return FutureChain(
|
||||
asset_finder=self.asset_finder,
|
||||
get_datetime=self.get_datetime,
|
||||
root_symbol=root_symbol.upper(),
|
||||
root_symbol=root_symbol,
|
||||
as_of_date=as_of_date
|
||||
)
|
||||
|
||||
|
||||
@@ -209,7 +209,6 @@ class AssetFinder(object):
|
||||
"""
|
||||
|
||||
# Format inputs
|
||||
symbol = symbol.upper()
|
||||
if as_of_date is not None:
|
||||
as_of_date = pd.Timestamp(normalize_date(as_of_date))
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from zipline.utils.preprocess import expect_types, optional
|
||||
from zipline.utils.input_validation import expect_types, optional
|
||||
|
||||
from .term import Term
|
||||
from .filters import Filter
|
||||
|
||||
@@ -0,0 +1,124 @@
|
||||
# Copyright 2015 Quantopian, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from six import iteritems, string_types
|
||||
from toolz import valmap
|
||||
|
||||
from zipline.utils.preprocess import preprocess
|
||||
|
||||
|
||||
def ensure_upper_case(func, argname, arg):
|
||||
if isinstance(arg, string_types):
|
||||
return arg.upper()
|
||||
else:
|
||||
raise TypeError(
|
||||
"{0}() expected argument '{1}' to"
|
||||
" be a string, but got {2} instead.".format(
|
||||
func.__name__, argname, arg,)
|
||||
)
|
||||
|
||||
|
||||
def expect_types(*_pos, **named):
|
||||
"""
|
||||
Preprocessing decorator that verifies inputs have expected types.
|
||||
|
||||
Usage
|
||||
-----
|
||||
>>> @expect_types(x=int, y=str)
|
||||
... def foo(x, y):
|
||||
... return x, y
|
||||
...
|
||||
>>> foo(2, '3')
|
||||
(2, '3')
|
||||
>>> foo(2.0, '3')
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
TypeError: foo() expected an argument of type 'int' for argument 'x', but got float instead. # noqa
|
||||
"""
|
||||
if _pos:
|
||||
raise TypeError("expect_types() only takes keyword arguments.")
|
||||
|
||||
for name, type_ in iteritems(named):
|
||||
if not isinstance(type_, (type, tuple)):
|
||||
raise TypeError(
|
||||
"expect_types() expected a type or tuple of types for "
|
||||
"argument '{name}', but got {type_} instead.".format(
|
||||
name=name, type_=type_,
|
||||
)
|
||||
)
|
||||
|
||||
return preprocess(**valmap(_expect_type, named))
|
||||
|
||||
|
||||
def _qualified_name(obj):
|
||||
"""
|
||||
Return the fully-qualified name (ignoring inner classes) of a type.
|
||||
"""
|
||||
module = obj.__module__
|
||||
if module in ('__builtin__', '__main__', 'builtins'):
|
||||
return obj.__name__
|
||||
return '.'.join([module, obj.__name__])
|
||||
|
||||
|
||||
def _expect_type(type_):
|
||||
"""
|
||||
Factory for type-checking functions that work the @preprocess decorator.
|
||||
"""
|
||||
# Slightly different messages for type and tuple of types.
|
||||
_template = (
|
||||
"{{funcname}}() expected a value of type {type_or_types} "
|
||||
"for argument '{{argname}}', but got {{actual}} instead."
|
||||
)
|
||||
if isinstance(type_, tuple):
|
||||
template = _template.format(
|
||||
type_or_types=' or '.join(map(_qualified_name, type_))
|
||||
)
|
||||
else:
|
||||
template = _template.format(type_or_types=_qualified_name(type_))
|
||||
|
||||
def _check_type(func, argname, argvalue):
|
||||
if not isinstance(argvalue, type_):
|
||||
raise TypeError(
|
||||
template.format(
|
||||
funcname=_qualified_name(func),
|
||||
argname=argname,
|
||||
actual=_qualified_name(type(argvalue)),
|
||||
)
|
||||
)
|
||||
return argvalue
|
||||
return _check_type
|
||||
|
||||
|
||||
def optional(type_):
|
||||
"""
|
||||
Helper for use with `expect_types` when an input can be `type_` or `None`.
|
||||
|
||||
Returns an object such that both `None` and instances of `type_` pass
|
||||
checks of the form `isinstance(obj, optional(type_))`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
type_ : type
|
||||
Type for which to produce an option.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> isinstance({}, optional(dict))
|
||||
True
|
||||
>>> isinstance(None, optional(dict))
|
||||
True
|
||||
>>> isinstance(1, optional(dict))
|
||||
False
|
||||
"""
|
||||
return (type_, type(None))
|
||||
@@ -6,45 +6,12 @@ from functools import wraps
|
||||
from inspect import getargspec
|
||||
from uuid import uuid4
|
||||
|
||||
from six import iteritems, viewkeys, exec_
|
||||
from toolz import valmap
|
||||
from six import viewkeys, exec_
|
||||
|
||||
|
||||
NO_DEFAULT = object()
|
||||
|
||||
|
||||
def expect_types(*_pos, **named):
|
||||
"""
|
||||
Preprocessing decorator that verifies inputs have expected types.
|
||||
|
||||
Usage
|
||||
-----
|
||||
>>> @expect_types(x=int, y=str)
|
||||
... def foo(x, y):
|
||||
... return x, y
|
||||
...
|
||||
>>> foo(2, '3')
|
||||
(2, '3')
|
||||
>>> foo(2.0, '3')
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
TypeError: foo() expected an argument of type 'int' for argument 'x', but got float instead. # noqa
|
||||
"""
|
||||
if _pos:
|
||||
raise TypeError("expect_types() only takes keyword arguments.")
|
||||
|
||||
for name, type_ in iteritems(named):
|
||||
if not isinstance(type_, (type, tuple)):
|
||||
raise TypeError(
|
||||
"expect_types() expected a type or tuple of types for "
|
||||
"argument '{name}', but got {type_} instead.".format(
|
||||
name=name, type_=type_,
|
||||
)
|
||||
)
|
||||
|
||||
return preprocess(**valmap(_expect_type, named))
|
||||
|
||||
|
||||
def preprocess(*_unused, **processors):
|
||||
"""
|
||||
Decorator that applies pre-processors to the arguments of a function before
|
||||
@@ -157,69 +124,6 @@ def call(f):
|
||||
return processor
|
||||
|
||||
|
||||
def _qualified_name(obj):
|
||||
"""
|
||||
Return the fully-qualified name (ignoring inner classes) of a type.
|
||||
"""
|
||||
module = obj.__module__
|
||||
if module in ('__builtin__', '__main__', 'builtins'):
|
||||
return obj.__name__
|
||||
return '.'.join([module, obj.__name__])
|
||||
|
||||
|
||||
def _expect_type(type_):
|
||||
"""
|
||||
Factory for type-checking functions that work the @preprocess decorator.
|
||||
"""
|
||||
# Slightly different messages for type and tuple of types.
|
||||
_template = (
|
||||
"{{funcname}}() expected a value of type {type_or_types} "
|
||||
"for argument '{{argname}}', but got {{actual}} instead."
|
||||
)
|
||||
if isinstance(type_, tuple):
|
||||
template = _template.format(
|
||||
type_or_types=' or '.join(map(_qualified_name, type_))
|
||||
)
|
||||
else:
|
||||
template = _template.format(type_or_types=_qualified_name(type_))
|
||||
|
||||
def _check_type(func, argname, argvalue):
|
||||
if not isinstance(argvalue, type_):
|
||||
raise TypeError(
|
||||
template.format(
|
||||
funcname=_qualified_name(func),
|
||||
argname=argname,
|
||||
actual=_qualified_name(type(argvalue)),
|
||||
)
|
||||
)
|
||||
return argvalue
|
||||
return _check_type
|
||||
|
||||
|
||||
def optional(type_):
|
||||
"""
|
||||
Helper for use with `expect_types` when an input can be `type_` or `None`.
|
||||
|
||||
Returns an object such that both `None` and instances of `type_` pass
|
||||
checks of the form `isinstance(obj, optional(type_))`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
type_ : type
|
||||
Type for which to produce an option.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> isinstance({}, optional(dict))
|
||||
True
|
||||
>>> isinstance(None, optional(dict))
|
||||
True
|
||||
>>> isinstance(1, optional(dict))
|
||||
False
|
||||
"""
|
||||
return (type_, type(None))
|
||||
|
||||
|
||||
def _build_preprocessed_function(func, processors, args_defaults):
|
||||
"""
|
||||
Build a preprocessed function with the same signature as `func`.
|
||||
|
||||
Reference in New Issue
Block a user