diff --git a/tests/test_algorithm.py b/tests/test_algorithm.py index 555e900d..b19eb601 100644 --- a/tests/test_algorithm.py +++ b/tests/test_algorithm.py @@ -132,6 +132,7 @@ from zipline.test_algorithms import ( bad_type_history_fields, bad_type_history_bar_count, bad_type_history_frequency, + bad_type_history_assets_kwarg_list, bad_type_current_assets, bad_type_current_fields, bad_type_can_trade_assets, @@ -1333,6 +1334,8 @@ class TestAlgoScript(WithLogger, ('can_trade__assets', (bad_type_can_trade_assets, 'Asset', True)), ('history_kwarg__assets', (bad_type_history_assets_kwarg, 'Asset, str', True)), + ('history_kwarg_bad_list__assets', + (bad_type_history_assets_kwarg_list, 'Asset, str', True)), ('history_kwarg__fields', (bad_type_history_fields_kwarg, 'str', True)), ('history_kwarg__bar_count', diff --git a/zipline/_protocol.pyx b/zipline/_protocol.pyx index 486f9931..64446caf 100644 --- a/zipline/_protocol.pyx +++ b/zipline/_protocol.pyx @@ -61,7 +61,7 @@ cdef class check_parameters(object): raise TypeError("%s() got an unexpected keyword argument" " '%s'" % (func.__name__, field)) - # verify type of each arg + # verify type of each argument for i, arg in enumerate(args[1:]): expected_type = self.types[i] @@ -93,12 +93,12 @@ cdef class check_parameters(object): if len(arg) == 0: continue - if isinstance(arg[0], self.keys_to_types[i]): + if isinstance(arg[0], self.keys_to_types[keyword]): continue expected_type = self.keys_to_types[keyword].__name__ \ if not _is_iterable(self.keys_to_types[keyword]) \ - else ', '.join([type.__name__ for type in + else ', '.join([type_.__name__ for type_ in self.keys_to_types[keyword]]) raise TypeError("Expected %s argument to be of type %s%s" % diff --git a/zipline/test_algorithms.py b/zipline/test_algorithms.py index b2aa6edd..e969684a 100644 --- a/zipline/test_algorithms.py +++ b/zipline/test_algorithms.py @@ -1088,3 +1088,11 @@ def initialize(context): def handle_data(context, data): data.current(fields=10, assets=symbol('TEST')) """ + +bad_type_history_assets_kwarg_list = """ +def initialize(context): + pass + +def handle_data(context, data): + data.history(assets=[1,2], fields='price', bar_count=5, frequency="1d") +"""