mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-05 19:03:30 +08:00
Merge pull request #1117 from quantopian/error_messages2
FIX: Check types of args passed to api methods on data
This commit is contained in:
@@ -107,7 +107,22 @@ from zipline.test_algorithms import (
|
||||
call_without_kwargs,
|
||||
call_with_bad_kwargs_current,
|
||||
call_with_bad_kwargs_history,
|
||||
bad_type_history_assets,
|
||||
bad_type_history_fields,
|
||||
bad_type_history_bar_count,
|
||||
bad_type_history_frequency,
|
||||
bad_type_current_assets,
|
||||
bad_type_current_fields,
|
||||
bad_type_can_trade_assets,
|
||||
bad_type_is_stale_assets,
|
||||
bad_type_history_assets_kwarg,
|
||||
bad_type_history_fields_kwarg,
|
||||
bad_type_history_bar_count_kwarg,
|
||||
bad_type_history_frequency_kwarg,
|
||||
bad_type_current_assets_kwarg,
|
||||
bad_type_current_fields_kwarg,
|
||||
no_handle_data)
|
||||
|
||||
from zipline.testing import (
|
||||
make_jagged_equity_info,
|
||||
to_utc,
|
||||
@@ -1431,6 +1446,29 @@ class TestBeforeTradingStart(TestCase):
|
||||
|
||||
class TestAlgoScript(TestCase):
|
||||
|
||||
ARG_TYPE_TEST_CASES = [
|
||||
('history__assets', (bad_type_history_assets, 'Asset, str', True)),
|
||||
('history__fields', (bad_type_history_fields, 'str', True)),
|
||||
('history__bar_count', (bad_type_history_bar_count, 'int', False)),
|
||||
('history__frequency', (bad_type_history_frequency, 'str', False)),
|
||||
('current__assets', (bad_type_current_assets, 'Asset, str', True)),
|
||||
('current__fields', (bad_type_current_fields, 'str', True)),
|
||||
('is_stale__assets', (bad_type_is_stale_assets, 'Asset', True)),
|
||||
('can_trade__assets', (bad_type_can_trade_assets, 'Asset', True)),
|
||||
('history_kwarg__assets',
|
||||
(bad_type_history_assets_kwarg, 'Asset, str', True)),
|
||||
('history_kwarg__fields',
|
||||
(bad_type_history_fields_kwarg, 'str', True)),
|
||||
('history_kwarg__bar_count',
|
||||
(bad_type_history_bar_count_kwarg, 'int', False)),
|
||||
('history_kwarg__frequency',
|
||||
(bad_type_history_frequency_kwarg, 'str', False)),
|
||||
('current_kwarg__assets',
|
||||
(bad_type_current_assets_kwarg, 'Asset, str', True)),
|
||||
('current_kwarg__fields',
|
||||
(bad_type_current_fields_kwarg, 'str', True)),
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
setup_logger(cls)
|
||||
@@ -1846,6 +1884,27 @@ def handle_data(context, data):
|
||||
self.assertEqual("%s() got an unexpected keyword argument 'blahblah'"
|
||||
% name, cm.exception.args[0])
|
||||
|
||||
@parameterized.expand(ARG_TYPE_TEST_CASES)
|
||||
def test_arg_types(self, name, inputs):
|
||||
|
||||
keyword = name.split('__')[1]
|
||||
|
||||
with self.assertRaises(TypeError) as cm:
|
||||
algo = TradingAlgorithm(
|
||||
script=inputs[0],
|
||||
sim_params=self.sim_params,
|
||||
env=self.env
|
||||
)
|
||||
algo.run(self.data_portal)
|
||||
|
||||
expected = "Expected %s argument to be of type %s%s" % (
|
||||
keyword,
|
||||
'or iterable of type ' if inputs[2] else '',
|
||||
inputs[1]
|
||||
)
|
||||
|
||||
self.assertEqual(expected, cm.exception.args[0])
|
||||
|
||||
|
||||
class TestGetDatetime(TestCase):
|
||||
|
||||
|
||||
+77
-13
@@ -21,27 +21,91 @@ import numpy as np
|
||||
|
||||
from six import iteritems
|
||||
from cpython cimport bool
|
||||
from collections import Iterable
|
||||
|
||||
from zipline.assets import Asset
|
||||
from zipline.zipline_warnings import ZiplineDeprecationWarning
|
||||
|
||||
|
||||
class assert_keywords(object):
|
||||
cdef bool _is_iterable(obj):
|
||||
return isinstance(obj, Iterable) and not isinstance(obj, str)
|
||||
|
||||
|
||||
cdef class check_parameters(object):
|
||||
"""
|
||||
Asserts that the keywords passed into the wrapped function are included
|
||||
in those passed into this decorator. If not, raise a TypeError with a
|
||||
meaningful message, unlike the one cython returns by default.
|
||||
"""
|
||||
meaningful message, unlike the one Cython returns by default.
|
||||
|
||||
def __init__(self, *args):
|
||||
self.names = args
|
||||
Also asserts that the arguments passed into the wrapped function are
|
||||
consistent with the types passed into this decorator. If not, raise a
|
||||
TypeError with a meaningful message.
|
||||
"""
|
||||
cdef tuple keyword_names
|
||||
cdef tuple types
|
||||
cdef dict keys_to_types
|
||||
|
||||
def __init__(self, keyword_names, types):
|
||||
self.keyword_names = keyword_names
|
||||
self.types = types
|
||||
|
||||
self.keys_to_types = dict(zip(keyword_names, types))
|
||||
|
||||
def __call__(self, func):
|
||||
def assert_keywords_and_call(*args, **kwargs):
|
||||
cdef short i
|
||||
|
||||
# verify all the keyword arguments
|
||||
for field in kwargs:
|
||||
if field not in self.names:
|
||||
if field not in self.keyword_names:
|
||||
raise TypeError("%s() got an unexpected keyword argument"
|
||||
" '%s'" % (func.__name__, field))
|
||||
|
||||
# verify type of each arg
|
||||
i = 0
|
||||
while i < (len(args) - 1):
|
||||
arg = args[i + 1]
|
||||
expected_type = self.types[i]
|
||||
|
||||
if isinstance(arg, expected_type):
|
||||
i += 1
|
||||
continue
|
||||
|
||||
elif (i == 0 or i == 1) and _is_iterable(arg):
|
||||
if isinstance(arg[0], expected_type):
|
||||
i += 1
|
||||
continue
|
||||
|
||||
expected_type_name = expected_type.__name__ \
|
||||
if not _is_iterable(expected_type) \
|
||||
else ', '.join([type_.__name__ for type_ in expected_type])
|
||||
|
||||
raise TypeError("Expected %s argument to be of type %s%s" %
|
||||
(self.keyword_names[i],
|
||||
'or iterable of type ' if i in (0, 1) else '',
|
||||
expected_type_name)
|
||||
)
|
||||
|
||||
# verify type of each kwarg
|
||||
for keyword, arg in iteritems(kwargs):
|
||||
if isinstance(arg, self.keys_to_types[keyword]):
|
||||
continue
|
||||
elif keyword in ('assets', 'fields') and _is_iterable(arg):
|
||||
if isinstance(arg[0], self.keys_to_types[i]):
|
||||
continue
|
||||
|
||||
expected_type = self.keys_to_types[keyword].__name__ \
|
||||
if not _is_iterable(self.keys_to_types[keyword]) \
|
||||
else ', '.join([type.__name__ for type in
|
||||
self.keys_to_types[keyword]])
|
||||
|
||||
raise TypeError("Expected %s argument to be of type %s%s" %
|
||||
(keyword,
|
||||
'or iterable of type ' if keyword in
|
||||
('assets', 'fields') else '',
|
||||
expected_type)
|
||||
)
|
||||
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return assert_keywords_and_call
|
||||
@@ -149,7 +213,7 @@ cdef class BarData:
|
||||
|
||||
return dt
|
||||
|
||||
@assert_keywords('assets', 'fields')
|
||||
@check_parameters(('assets', 'fields'), ((Asset, str), str))
|
||||
def current(self, assets, fields):
|
||||
"""
|
||||
Returns the current value of the given assets for the given fields
|
||||
@@ -206,8 +270,8 @@ cdef class BarData:
|
||||
the current trade bar. If there is no current trade bar, NaN is
|
||||
returned.
|
||||
"""
|
||||
multiple_assets = self._is_iterable(assets)
|
||||
multiple_fields = self._is_iterable(fields)
|
||||
multiple_assets = _is_iterable(assets)
|
||||
multiple_fields = _is_iterable(fields)
|
||||
|
||||
# There's some overly verbose code in here, particularly around
|
||||
# 'do something if self._adjust_minutes is False, otherwise do
|
||||
@@ -324,9 +388,7 @@ cdef class BarData:
|
||||
|
||||
return pd.DataFrame(data)
|
||||
|
||||
cdef bool _is_iterable(self, obj):
|
||||
return hasattr(obj, '__iter__') and not isinstance(obj, str)
|
||||
|
||||
@check_parameters(('assets',), (Asset,))
|
||||
def can_trade(self, assets):
|
||||
"""
|
||||
For the given asset or iterable of assets, returns true if the asset
|
||||
@@ -373,6 +435,7 @@ cdef class BarData:
|
||||
|
||||
return False
|
||||
|
||||
@check_parameters(('assets',), (Asset,))
|
||||
def is_stale(self, assets):
|
||||
"""
|
||||
For the given asset or iterable of assets, returns true if the asset
|
||||
@@ -432,7 +495,8 @@ cdef class BarData:
|
||||
|
||||
return not (last_traded_dt is pd.NaT)
|
||||
|
||||
@assert_keywords('assets', 'fields', 'bar_count', 'frequency')
|
||||
@check_parameters(('assets', 'fields', 'bar_count', 'frequency'),
|
||||
((Asset, str), str, int, str))
|
||||
def history(self, assets, fields, bar_count, frequency):
|
||||
"""
|
||||
Returns a window of data for the given assets and fields.
|
||||
|
||||
@@ -694,9 +694,6 @@ class DataPortal(object):
|
||||
if field not in BASE_FIELDS:
|
||||
raise KeyError("Invalid column: " + str(field))
|
||||
|
||||
if isinstance(asset, int):
|
||||
asset = self.env.asset_finder.retrieve_asset(asset)
|
||||
|
||||
if dt < asset.start_date or \
|
||||
(data_frequency == "daily" and dt > asset.end_date) or \
|
||||
(data_frequency == "minute" and
|
||||
@@ -827,8 +824,6 @@ class DataPortal(object):
|
||||
-------
|
||||
The value of the desired field at the desired time.
|
||||
"""
|
||||
if isinstance(asset, int):
|
||||
asset = self._asset_finder.retrieve_asset(asset)
|
||||
|
||||
if spot_value is None:
|
||||
spot_value = self.get_spot_value(asset, field, dt, data_frequency)
|
||||
|
||||
@@ -957,3 +957,134 @@ def initialize(context):
|
||||
def handle_data(context, data):
|
||||
current = data.current(assets=symbol('TEST'), blahblah="price")
|
||||
"""
|
||||
|
||||
bad_type_history_assets = """
|
||||
def initialize(context):
|
||||
pass
|
||||
|
||||
def handle_data(context, data):
|
||||
data.history(1, 'price', 5, '1d')
|
||||
"""
|
||||
|
||||
bad_type_history_fields = """
|
||||
from zipline.api import symbol
|
||||
|
||||
def initialize(context):
|
||||
pass
|
||||
|
||||
def handle_data(context, data):
|
||||
data.history(symbol('TEST'), 10 , 5, '1d')
|
||||
"""
|
||||
|
||||
bad_type_history_bar_count = """
|
||||
from zipline.api import symbol
|
||||
|
||||
def initialize(context):
|
||||
pass
|
||||
|
||||
def handle_data(context, data):
|
||||
data.history(symbol('TEST'), 'price', '5', '1d')
|
||||
"""
|
||||
|
||||
bad_type_history_frequency = """
|
||||
from zipline.api import symbol
|
||||
|
||||
def initialize(context):
|
||||
pass
|
||||
|
||||
def handle_data(context, data):
|
||||
data.history(symbol('TEST'), 'price', 5, 1)
|
||||
"""
|
||||
|
||||
bad_type_current_assets = """
|
||||
def initialize(context):
|
||||
pass
|
||||
|
||||
def handle_data(context, data):
|
||||
data.current(1, 'price')
|
||||
"""
|
||||
|
||||
bad_type_current_fields = """
|
||||
from zipline.api import symbol
|
||||
|
||||
def initialize(context):
|
||||
pass
|
||||
|
||||
def handle_data(context, data):
|
||||
data.current(symbol('TEST'), 10)
|
||||
"""
|
||||
|
||||
bad_type_is_stale_assets = """
|
||||
def initialize(context):
|
||||
pass
|
||||
|
||||
def handle_data(context, data):
|
||||
data.is_stale('TEST')
|
||||
"""
|
||||
|
||||
bad_type_can_trade_assets = """
|
||||
def initialize(context):
|
||||
pass
|
||||
|
||||
def handle_data(context, data):
|
||||
data.can_trade('TEST')
|
||||
"""
|
||||
|
||||
bad_type_history_assets_kwarg = """
|
||||
def initialize(context):
|
||||
pass
|
||||
|
||||
def handle_data(context, data):
|
||||
data.history(frequency='1d', fields='price', assets=1, bar_count=5)
|
||||
"""
|
||||
|
||||
bad_type_history_fields_kwarg = """
|
||||
from zipline.api import symbol
|
||||
|
||||
def initialize(context):
|
||||
pass
|
||||
|
||||
def handle_data(context, data):
|
||||
data.history(frequency='1d', fields=10, assets=symbol('TEST'),
|
||||
bar_count=5)
|
||||
"""
|
||||
|
||||
bad_type_history_bar_count_kwarg = """
|
||||
from zipline.api import symbol
|
||||
|
||||
def initialize(context):
|
||||
pass
|
||||
|
||||
def handle_data(context, data):
|
||||
data.history(frequency='1d', fields='price', assets=symbol('TEST'),
|
||||
bar_count='5')
|
||||
"""
|
||||
|
||||
bad_type_history_frequency_kwarg = """
|
||||
from zipline.api import symbol
|
||||
|
||||
def initialize(context):
|
||||
pass
|
||||
|
||||
def handle_data(context, data):
|
||||
data.history(frequency=1, fields='price', assets=symbol('TEST'),
|
||||
bar_count=5)
|
||||
"""
|
||||
|
||||
bad_type_current_assets_kwarg = """
|
||||
def initialize(context):
|
||||
pass
|
||||
|
||||
def handle_data(context, data):
|
||||
data.current(fields='price', assets=1)
|
||||
"""
|
||||
|
||||
bad_type_current_fields_kwarg = """
|
||||
from zipline.api import symbol
|
||||
|
||||
def initialize(context):
|
||||
pass
|
||||
|
||||
def handle_data(context, data):
|
||||
data.current(fields=10, assets=symbol('TEST'))
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user