FIX: Check types of args passed to api methods on data

This commit is contained in:
Andrew Liang
2016-04-07 14:29:38 -04:00
parent fac5905c10
commit 8dc3ed73ab
3 changed files with 261 additions and 5 deletions
+59
View File
@@ -107,7 +107,22 @@ from zipline.test_algorithms import (
call_without_kwargs,
call_with_bad_kwargs_current,
call_with_bad_kwargs_history,
bad_type_history_assets,
bad_type_history_fields,
bad_type_history_bar_count,
bad_type_history_frequency,
bad_type_current_assets,
bad_type_current_fields,
bad_type_can_trade_assets,
bad_type_is_stale_assets,
bad_type_history_assets_kwarg,
bad_type_history_fields_kwarg,
bad_type_history_bar_count_kwarg,
bad_type_history_frequency_kwarg,
bad_type_current_assets_kwarg,
bad_type_current_fields_kwarg,
no_handle_data)
from zipline.testing import (
make_jagged_equity_info,
to_utc,
@@ -1431,6 +1446,29 @@ class TestBeforeTradingStart(TestCase):
class TestAlgoScript(TestCase):
ARG_TYPE_TEST_CASES = [
('history__assets', (bad_type_history_assets, 'Asset, str', True)),
('history__fields', (bad_type_history_fields, 'str', True)),
('history__bar_count', (bad_type_history_bar_count, 'int', False)),
('history__frequency', (bad_type_history_frequency, 'str', False)),
('current__assets', (bad_type_current_assets, 'Asset, str', True)),
('current__fields', (bad_type_current_fields, 'str', True)),
('is_stale__assets', (bad_type_is_stale_assets, 'Asset', True)),
('can_trade__assets', (bad_type_can_trade_assets, 'Asset', True)),
('history_kwarg__assets',
(bad_type_history_assets_kwarg, 'Asset, str', True)),
('history_kwarg__fields',
(bad_type_history_fields_kwarg, 'str', True)),
('history_kwarg__bar_count',
(bad_type_history_bar_count_kwarg, 'int', False)),
('history_kwarg__frequency',
(bad_type_history_frequency_kwarg, 'str', False)),
('current_kwarg__assets',
(bad_type_current_assets_kwarg, 'Asset, str', True)),
('current_kwarg__fields',
(bad_type_current_fields_kwarg, 'str', True)),
]
@classmethod
def setUpClass(cls):
setup_logger(cls)
@@ -1846,6 +1884,27 @@ def handle_data(context, data):
self.assertEqual("%s() got an unexpected keyword argument 'blahblah'"
% name, cm.exception.args[0])
@parameterized.expand(ARG_TYPE_TEST_CASES)
def test_arg_types(self, name, inputs):
keyword = name.split('__')[1]
with self.assertRaises(TypeError) as cm:
algo = TradingAlgorithm(
script=inputs[0],
sim_params=self.sim_params,
env=self.env
)
algo.run(self.data_portal)
expected = "Expected %s argument to be of type %s%s" % (
keyword,
'or iterable of type ' if inputs[2] else '',
inputs[1]
)
self.assertEqual(expected, cm.exception.args[0])
class TestGetDatetime(TestCase):
+71 -5
View File
@@ -21,6 +21,7 @@ import numpy as np
from six import iteritems
from cpython cimport bool
from collections import Iterable
from zipline.assets import Asset
from zipline.zipline_warnings import ZiplineDeprecationWarning
@@ -33,20 +34,80 @@ class assert_keywords(object):
meaningful message, unlike the one cython returns by default.
"""
def __init__(self, *args):
self.names = args
def __init__(self, arg_names, method_name):
self.names = arg_names
self.method = method_name
def __call__(self, func):
def assert_keywords_and_call(*args, **kwargs):
for field in kwargs:
if field not in self.names:
raise TypeError("%s() got an unexpected keyword argument"
" '%s'" % (func.__name__, field))
" '%s'" % (self.method, field))
return func(*args, **kwargs)
return assert_keywords_and_call
KEYWORDS = ['assets', 'fields', 'bar_count', 'frequency']
class assert_types(object):
"""
Asserts that the arguments passed into the wrapped function are consistent
with the types passed into this decorator. If not, raise a TypeError with
a meaningful message.
"""
def __init__(self, *args):
self.types = args
self.keys_to_types = dict(zip(KEYWORDS, args))
def _is_iterable(self, obj):
return isinstance(obj, Iterable) and not isinstance(obj, str)
def __call__(self, func):
def assert_types_and_call(*args, **kwargs):
for i, arg in enumerate(args[1:]):
if isinstance(arg, self.types[i]):
continue
elif i in (0, 1) and self._is_iterable(arg):
if isinstance(arg[0], self.types[i]):
continue
expected_type = self.types[i].__name__ \
if not self._is_iterable(self.types[i]) \
else ', '.join([type.__name__ for type in self.types[i]])
raise TypeError("Expected %s argument to be of type %s%s" %
(KEYWORDS[i],
'or iterable of type ' if i in (0, 1) else '',
expected_type)
)
for keyword, arg in iteritems(kwargs):
if isinstance(arg, self.keys_to_types[keyword]):
continue
elif keyword in ('assets', 'fields') and \
self._is_iterable(arg):
if isinstance(arg[0], self.keys_to_types[i]):
continue
expected_type = self.keys_to_types[keyword].__name__ \
if not self._is_iterable(self.keys_to_types[keyword]) \
else ', '.join([type.__name__ for type in
self.keys_to_types[keyword]])
raise TypeError("Expected %s argument to be of type %s%s" %
(keyword,
'or iterable of type ' if keyword in
('assets', 'fields') else '',
expected_type)
)
return func(*args, **kwargs)
return assert_types_and_call
@contextmanager
def handle_non_market_minutes(bar_data):
try:
@@ -149,7 +210,8 @@ cdef class BarData:
return dt
@assert_keywords('assets', 'fields')
@assert_keywords(arg_names=('assets', 'fields'), method_name='current')
@assert_types((Asset, str), str)
def current(self, assets, fields):
"""
Returns the current value of the given assets for the given fields
@@ -327,6 +389,7 @@ cdef class BarData:
cdef bool _is_iterable(self, obj):
return hasattr(obj, '__iter__') and not isinstance(obj, str)
@assert_types(Asset)
def can_trade(self, assets):
"""
For the given asset or iterable of assets, returns true if the asset
@@ -373,6 +436,7 @@ cdef class BarData:
return False
@assert_types(Asset)
def is_stale(self, assets):
"""
For the given asset or iterable of assets, returns true if the asset
@@ -432,7 +496,9 @@ cdef class BarData:
return not (last_traded_dt is pd.NaT)
@assert_keywords('assets', 'fields', 'bar_count', 'frequency')
@assert_keywords(arg_names=('assets', 'fields', 'bar_count', 'frequency'),
method_name='history')
@assert_types((Asset, str), str, int, str)
def history(self, assets, fields, bar_count, frequency):
"""
Returns a window of data for the given assets and fields.
+131
View File
@@ -957,3 +957,134 @@ def initialize(context):
def handle_data(context, data):
current = data.current(assets=symbol('TEST'), blahblah="price")
"""
bad_type_history_assets = """
def initialize(context):
pass
def handle_data(context, data):
data.history(1, 'price', 5, '1d')
"""
bad_type_history_fields = """
from zipline.api import symbol
def initialize(context):
pass
def handle_data(context, data):
data.history(symbol('TEST'), 10 , 5, '1d')
"""
bad_type_history_bar_count = """
from zipline.api import symbol
def initialize(context):
pass
def handle_data(context, data):
data.history(symbol('TEST'), 'price', '5', '1d')
"""
bad_type_history_frequency = """
from zipline.api import symbol
def initialize(context):
pass
def handle_data(context, data):
data.history(symbol('TEST'), 'price', 5, 1)
"""
bad_type_current_assets = """
def initialize(context):
pass
def handle_data(context, data):
data.current(1, 'price')
"""
bad_type_current_fields = """
from zipline.api import symbol
def initialize(context):
pass
def handle_data(context, data):
data.current(symbol('TEST'), 10)
"""
bad_type_is_stale_assets = """
def initialize(context):
pass
def handle_data(context, data):
data.is_stale('TEST')
"""
bad_type_can_trade_assets = """
def initialize(context):
pass
def handle_data(context, data):
data.can_trade('TEST')
"""
bad_type_history_assets_kwarg = """
def initialize(context):
pass
def handle_data(context, data):
data.history(frequency='1d', fields='price', assets=1, bar_count=5)
"""
bad_type_history_fields_kwarg = """
from zipline.api import symbol
def initialize(context):
pass
def handle_data(context, data):
data.history(frequency='1d', fields=10, assets=symbol('TEST'),
bar_count=5)
"""
bad_type_history_bar_count_kwarg = """
from zipline.api import symbol
def initialize(context):
pass
def handle_data(context, data):
data.history(frequency='1d', fields='price', assets=symbol('TEST'),
bar_count='5')
"""
bad_type_history_frequency_kwarg = """
from zipline.api import symbol
def initialize(context):
pass
def handle_data(context, data):
data.history(frequency=1, fields='price', assets=symbol('TEST'),
bar_count=5)
"""
bad_type_current_assets_kwarg = """
def initialize(context):
pass
def handle_data(context, data):
data.current(fields='price', assets=1)
"""
bad_type_current_fields_kwarg = """
from zipline.api import symbol
def initialize(context):
pass
def handle_data(context, data):
data.current(fields=10, assets=symbol('TEST'))
"""