diff --git a/tests/test_algorithm.py b/tests/test_algorithm.py index ce79bfd4..3d8a64b4 100644 --- a/tests/test_algorithm.py +++ b/tests/test_algorithm.py @@ -107,7 +107,22 @@ from zipline.test_algorithms import ( call_without_kwargs, call_with_bad_kwargs_current, call_with_bad_kwargs_history, + bad_type_history_assets, + bad_type_history_fields, + bad_type_history_bar_count, + bad_type_history_frequency, + bad_type_current_assets, + bad_type_current_fields, + bad_type_can_trade_assets, + bad_type_is_stale_assets, + bad_type_history_assets_kwarg, + bad_type_history_fields_kwarg, + bad_type_history_bar_count_kwarg, + bad_type_history_frequency_kwarg, + bad_type_current_assets_kwarg, + bad_type_current_fields_kwarg, no_handle_data) + from zipline.testing import ( make_jagged_equity_info, to_utc, @@ -1431,6 +1446,29 @@ class TestBeforeTradingStart(TestCase): class TestAlgoScript(TestCase): + ARG_TYPE_TEST_CASES = [ + ('history__assets', (bad_type_history_assets, 'Asset, str', True)), + ('history__fields', (bad_type_history_fields, 'str', True)), + ('history__bar_count', (bad_type_history_bar_count, 'int', False)), + ('history__frequency', (bad_type_history_frequency, 'str', False)), + ('current__assets', (bad_type_current_assets, 'Asset, str', True)), + ('current__fields', (bad_type_current_fields, 'str', True)), + ('is_stale__assets', (bad_type_is_stale_assets, 'Asset', True)), + ('can_trade__assets', (bad_type_can_trade_assets, 'Asset', True)), + ('history_kwarg__assets', + (bad_type_history_assets_kwarg, 'Asset, str', True)), + ('history_kwarg__fields', + (bad_type_history_fields_kwarg, 'str', True)), + ('history_kwarg__bar_count', + (bad_type_history_bar_count_kwarg, 'int', False)), + ('history_kwarg__frequency', + (bad_type_history_frequency_kwarg, 'str', False)), + ('current_kwarg__assets', + (bad_type_current_assets_kwarg, 'Asset, str', True)), + ('current_kwarg__fields', + (bad_type_current_fields_kwarg, 'str', True)), + ] + @classmethod def setUpClass(cls): setup_logger(cls) @@ -1846,6 +1884,27 @@ def handle_data(context, data): self.assertEqual("%s() got an unexpected keyword argument 'blahblah'" % name, cm.exception.args[0]) + @parameterized.expand(ARG_TYPE_TEST_CASES) + def test_arg_types(self, name, inputs): + + keyword = name.split('__')[1] + + with self.assertRaises(TypeError) as cm: + algo = TradingAlgorithm( + script=inputs[0], + sim_params=self.sim_params, + env=self.env + ) + algo.run(self.data_portal) + + expected = "Expected %s argument to be of type %s%s" % ( + keyword, + 'or iterable of type ' if inputs[2] else '', + inputs[1] + ) + + self.assertEqual(expected, cm.exception.args[0]) + class TestGetDatetime(TestCase): diff --git a/zipline/_protocol.pyx b/zipline/_protocol.pyx index 6e7f0c92..9d981a34 100644 --- a/zipline/_protocol.pyx +++ b/zipline/_protocol.pyx @@ -21,6 +21,7 @@ import numpy as np from six import iteritems from cpython cimport bool +from collections import Iterable from zipline.assets import Asset from zipline.zipline_warnings import ZiplineDeprecationWarning @@ -33,20 +34,80 @@ class assert_keywords(object): meaningful message, unlike the one cython returns by default. """ - def __init__(self, *args): - self.names = args + def __init__(self, arg_names, method_name): + self.names = arg_names + self.method = method_name def __call__(self, func): def assert_keywords_and_call(*args, **kwargs): for field in kwargs: if field not in self.names: raise TypeError("%s() got an unexpected keyword argument" - " '%s'" % (func.__name__, field)) + " '%s'" % (self.method, field)) return func(*args, **kwargs) return assert_keywords_and_call +KEYWORDS = ['assets', 'fields', 'bar_count', 'frequency'] + + +class assert_types(object): + """ + Asserts that the arguments passed into the wrapped function are consistent + with the types passed into this decorator. If not, raise a TypeError with + a meaningful message. + """ + + def __init__(self, *args): + self.types = args + self.keys_to_types = dict(zip(KEYWORDS, args)) + + def _is_iterable(self, obj): + return isinstance(obj, Iterable) and not isinstance(obj, str) + + def __call__(self, func): + def assert_types_and_call(*args, **kwargs): + for i, arg in enumerate(args[1:]): + if isinstance(arg, self.types[i]): + continue + elif i in (0, 1) and self._is_iterable(arg): + if isinstance(arg[0], self.types[i]): + continue + + expected_type = self.types[i].__name__ \ + if not self._is_iterable(self.types[i]) \ + else ', '.join([type.__name__ for type in self.types[i]]) + raise TypeError("Expected %s argument to be of type %s%s" % + (KEYWORDS[i], + 'or iterable of type ' if i in (0, 1) else '', + expected_type) + ) + + for keyword, arg in iteritems(kwargs): + if isinstance(arg, self.keys_to_types[keyword]): + continue + elif keyword in ('assets', 'fields') and \ + self._is_iterable(arg): + if isinstance(arg[0], self.keys_to_types[i]): + continue + + expected_type = self.keys_to_types[keyword].__name__ \ + if not self._is_iterable(self.keys_to_types[keyword]) \ + else ', '.join([type.__name__ for type in + self.keys_to_types[keyword]]) + raise TypeError("Expected %s argument to be of type %s%s" % + (keyword, + 'or iterable of type ' if keyword in + ('assets', 'fields') else '', + expected_type) + ) + + return func(*args, **kwargs) + + return assert_types_and_call + + @contextmanager def handle_non_market_minutes(bar_data): try: @@ -149,7 +210,8 @@ cdef class BarData: return dt - @assert_keywords('assets', 'fields') + @assert_keywords(arg_names=('assets', 'fields'), method_name='current') + @assert_types((Asset, str), str) def current(self, assets, fields): """ Returns the current value of the given assets for the given fields @@ -327,6 +389,7 @@ cdef class BarData: cdef bool _is_iterable(self, obj): return hasattr(obj, '__iter__') and not isinstance(obj, str) + @assert_types(Asset) def can_trade(self, assets): """ For the given asset or iterable of assets, returns true if the asset @@ -373,6 +436,7 @@ cdef class BarData: return False + @assert_types(Asset) def is_stale(self, assets): """ For the given asset or iterable of assets, returns true if the asset @@ -432,7 +496,9 @@ cdef class BarData: return not (last_traded_dt is pd.NaT) - @assert_keywords('assets', 'fields', 'bar_count', 'frequency') + @assert_keywords(arg_names=('assets', 'fields', 'bar_count', 'frequency'), + method_name='history') + @assert_types((Asset, str), str, int, str) def history(self, assets, fields, bar_count, frequency): """ Returns a window of data for the given assets and fields. diff --git a/zipline/test_algorithms.py b/zipline/test_algorithms.py index 32932cd9..b2aa6edd 100644 --- a/zipline/test_algorithms.py +++ b/zipline/test_algorithms.py @@ -957,3 +957,134 @@ def initialize(context): def handle_data(context, data): current = data.current(assets=symbol('TEST'), blahblah="price") """ + +bad_type_history_assets = """ +def initialize(context): + pass + +def handle_data(context, data): + data.history(1, 'price', 5, '1d') +""" + +bad_type_history_fields = """ +from zipline.api import symbol + +def initialize(context): + pass + +def handle_data(context, data): + data.history(symbol('TEST'), 10 , 5, '1d') +""" + +bad_type_history_bar_count = """ +from zipline.api import symbol + +def initialize(context): + pass + +def handle_data(context, data): + data.history(symbol('TEST'), 'price', '5', '1d') +""" + +bad_type_history_frequency = """ +from zipline.api import symbol + +def initialize(context): + pass + +def handle_data(context, data): + data.history(symbol('TEST'), 'price', 5, 1) +""" + +bad_type_current_assets = """ +def initialize(context): + pass + +def handle_data(context, data): + data.current(1, 'price') +""" + +bad_type_current_fields = """ +from zipline.api import symbol + +def initialize(context): + pass + +def handle_data(context, data): + data.current(symbol('TEST'), 10) +""" + +bad_type_is_stale_assets = """ +def initialize(context): + pass + +def handle_data(context, data): + data.is_stale('TEST') +""" + +bad_type_can_trade_assets = """ +def initialize(context): + pass + +def handle_data(context, data): + data.can_trade('TEST') +""" + +bad_type_history_assets_kwarg = """ +def initialize(context): + pass + +def handle_data(context, data): + data.history(frequency='1d', fields='price', assets=1, bar_count=5) +""" + +bad_type_history_fields_kwarg = """ +from zipline.api import symbol + +def initialize(context): + pass + +def handle_data(context, data): + data.history(frequency='1d', fields=10, assets=symbol('TEST'), + bar_count=5) +""" + +bad_type_history_bar_count_kwarg = """ +from zipline.api import symbol + +def initialize(context): + pass + +def handle_data(context, data): + data.history(frequency='1d', fields='price', assets=symbol('TEST'), + bar_count='5') +""" + +bad_type_history_frequency_kwarg = """ +from zipline.api import symbol + +def initialize(context): + pass + +def handle_data(context, data): + data.history(frequency=1, fields='price', assets=symbol('TEST'), + bar_count=5) +""" + +bad_type_current_assets_kwarg = """ +def initialize(context): + pass + +def handle_data(context, data): + data.current(fields='price', assets=1) +""" + +bad_type_current_fields_kwarg = """ +from zipline.api import symbol + +def initialize(context): + pass + +def handle_data(context, data): + data.current(fields=10, assets=symbol('TEST')) +"""