diff --git a/tests/test_algorithm.py b/tests/test_algorithm.py index ce79bfd4..3d8a64b4 100644 --- a/tests/test_algorithm.py +++ b/tests/test_algorithm.py @@ -107,7 +107,22 @@ from zipline.test_algorithms import ( call_without_kwargs, call_with_bad_kwargs_current, call_with_bad_kwargs_history, + bad_type_history_assets, + bad_type_history_fields, + bad_type_history_bar_count, + bad_type_history_frequency, + bad_type_current_assets, + bad_type_current_fields, + bad_type_can_trade_assets, + bad_type_is_stale_assets, + bad_type_history_assets_kwarg, + bad_type_history_fields_kwarg, + bad_type_history_bar_count_kwarg, + bad_type_history_frequency_kwarg, + bad_type_current_assets_kwarg, + bad_type_current_fields_kwarg, no_handle_data) + from zipline.testing import ( make_jagged_equity_info, to_utc, @@ -1431,6 +1446,29 @@ class TestBeforeTradingStart(TestCase): class TestAlgoScript(TestCase): + ARG_TYPE_TEST_CASES = [ + ('history__assets', (bad_type_history_assets, 'Asset, str', True)), + ('history__fields', (bad_type_history_fields, 'str', True)), + ('history__bar_count', (bad_type_history_bar_count, 'int', False)), + ('history__frequency', (bad_type_history_frequency, 'str', False)), + ('current__assets', (bad_type_current_assets, 'Asset, str', True)), + ('current__fields', (bad_type_current_fields, 'str', True)), + ('is_stale__assets', (bad_type_is_stale_assets, 'Asset', True)), + ('can_trade__assets', (bad_type_can_trade_assets, 'Asset', True)), + ('history_kwarg__assets', + (bad_type_history_assets_kwarg, 'Asset, str', True)), + ('history_kwarg__fields', + (bad_type_history_fields_kwarg, 'str', True)), + ('history_kwarg__bar_count', + (bad_type_history_bar_count_kwarg, 'int', False)), + ('history_kwarg__frequency', + (bad_type_history_frequency_kwarg, 'str', False)), + ('current_kwarg__assets', + (bad_type_current_assets_kwarg, 'Asset, str', True)), + ('current_kwarg__fields', + (bad_type_current_fields_kwarg, 'str', True)), + ] + @classmethod def setUpClass(cls): setup_logger(cls) @@ -1846,6 +1884,27 @@ def handle_data(context, data): self.assertEqual("%s() got an unexpected keyword argument 'blahblah'" % name, cm.exception.args[0]) + @parameterized.expand(ARG_TYPE_TEST_CASES) + def test_arg_types(self, name, inputs): + + keyword = name.split('__')[1] + + with self.assertRaises(TypeError) as cm: + algo = TradingAlgorithm( + script=inputs[0], + sim_params=self.sim_params, + env=self.env + ) + algo.run(self.data_portal) + + expected = "Expected %s argument to be of type %s%s" % ( + keyword, + 'or iterable of type ' if inputs[2] else '', + inputs[1] + ) + + self.assertEqual(expected, cm.exception.args[0]) + class TestGetDatetime(TestCase): diff --git a/zipline/_protocol.pyx b/zipline/_protocol.pyx index 6e7f0c92..bb91e80b 100644 --- a/zipline/_protocol.pyx +++ b/zipline/_protocol.pyx @@ -21,27 +21,91 @@ import numpy as np from six import iteritems from cpython cimport bool +from collections import Iterable from zipline.assets import Asset from zipline.zipline_warnings import ZiplineDeprecationWarning -class assert_keywords(object): +cdef bool _is_iterable(obj): + return isinstance(obj, Iterable) and not isinstance(obj, str) + + +cdef class check_parameters(object): """ Asserts that the keywords passed into the wrapped function are included in those passed into this decorator. If not, raise a TypeError with a - meaningful message, unlike the one cython returns by default. - """ + meaningful message, unlike the one Cython returns by default. - def __init__(self, *args): - self.names = args + Also asserts that the arguments passed into the wrapped function are + consistent with the types passed into this decorator. If not, raise a + TypeError with a meaningful message. + """ + cdef tuple keyword_names + cdef tuple types + cdef dict keys_to_types + + def __init__(self, keyword_names, types): + self.keyword_names = keyword_names + self.types = types + + self.keys_to_types = dict(zip(keyword_names, types)) def __call__(self, func): def assert_keywords_and_call(*args, **kwargs): + cdef short i + + # verify all the keyword arguments for field in kwargs: - if field not in self.names: + if field not in self.keyword_names: raise TypeError("%s() got an unexpected keyword argument" " '%s'" % (func.__name__, field)) + + # verify type of each arg + i = 0 + while i < (len(args) - 1): + arg = args[i + 1] + expected_type = self.types[i] + + if isinstance(arg, expected_type): + i += 1 + continue + + elif (i == 0 or i == 1) and _is_iterable(arg): + if isinstance(arg[0], expected_type): + i += 1 + continue + + expected_type_name = expected_type.__name__ \ + if not _is_iterable(expected_type) \ + else ', '.join([type_.__name__ for type_ in expected_type]) + + raise TypeError("Expected %s argument to be of type %s%s" % + (self.keyword_names[i], + 'or iterable of type ' if i in (0, 1) else '', + expected_type_name) + ) + + # verify type of each kwarg + for keyword, arg in iteritems(kwargs): + if isinstance(arg, self.keys_to_types[keyword]): + continue + elif keyword in ('assets', 'fields') and _is_iterable(arg): + if isinstance(arg[0], self.keys_to_types[i]): + continue + + expected_type = self.keys_to_types[keyword].__name__ \ + if not _is_iterable(self.keys_to_types[keyword]) \ + else ', '.join([type.__name__ for type in + self.keys_to_types[keyword]]) + + raise TypeError("Expected %s argument to be of type %s%s" % + (keyword, + 'or iterable of type ' if keyword in + ('assets', 'fields') else '', + expected_type) + ) + return func(*args, **kwargs) return assert_keywords_and_call @@ -149,7 +213,7 @@ cdef class BarData: return dt - @assert_keywords('assets', 'fields') + @check_parameters(('assets', 'fields'), ((Asset, str), str)) def current(self, assets, fields): """ Returns the current value of the given assets for the given fields @@ -206,8 +270,8 @@ cdef class BarData: the current trade bar. If there is no current trade bar, NaN is returned. """ - multiple_assets = self._is_iterable(assets) - multiple_fields = self._is_iterable(fields) + multiple_assets = _is_iterable(assets) + multiple_fields = _is_iterable(fields) # There's some overly verbose code in here, particularly around # 'do something if self._adjust_minutes is False, otherwise do @@ -324,9 +388,7 @@ cdef class BarData: return pd.DataFrame(data) - cdef bool _is_iterable(self, obj): - return hasattr(obj, '__iter__') and not isinstance(obj, str) - + @check_parameters(('assets',), (Asset,)) def can_trade(self, assets): """ For the given asset or iterable of assets, returns true if the asset @@ -373,6 +435,7 @@ cdef class BarData: return False + @check_parameters(('assets',), (Asset,)) def is_stale(self, assets): """ For the given asset or iterable of assets, returns true if the asset @@ -432,7 +495,8 @@ cdef class BarData: return not (last_traded_dt is pd.NaT) - @assert_keywords('assets', 'fields', 'bar_count', 'frequency') + @check_parameters(('assets', 'fields', 'bar_count', 'frequency'), + ((Asset, str), str, int, str)) def history(self, assets, fields, bar_count, frequency): """ Returns a window of data for the given assets and fields. diff --git a/zipline/data/data_portal.py b/zipline/data/data_portal.py index 3bee2265..c8127591 100644 --- a/zipline/data/data_portal.py +++ b/zipline/data/data_portal.py @@ -694,9 +694,6 @@ class DataPortal(object): if field not in BASE_FIELDS: raise KeyError("Invalid column: " + str(field)) - if isinstance(asset, int): - asset = self.env.asset_finder.retrieve_asset(asset) - if dt < asset.start_date or \ (data_frequency == "daily" and dt > asset.end_date) or \ (data_frequency == "minute" and @@ -827,8 +824,6 @@ class DataPortal(object): ------- The value of the desired field at the desired time. """ - if isinstance(asset, int): - asset = self._asset_finder.retrieve_asset(asset) if spot_value is None: spot_value = self.get_spot_value(asset, field, dt, data_frequency) diff --git a/zipline/test_algorithms.py b/zipline/test_algorithms.py index 32932cd9..b2aa6edd 100644 --- a/zipline/test_algorithms.py +++ b/zipline/test_algorithms.py @@ -957,3 +957,134 @@ def initialize(context): def handle_data(context, data): current = data.current(assets=symbol('TEST'), blahblah="price") """ + +bad_type_history_assets = """ +def initialize(context): + pass + +def handle_data(context, data): + data.history(1, 'price', 5, '1d') +""" + +bad_type_history_fields = """ +from zipline.api import symbol + +def initialize(context): + pass + +def handle_data(context, data): + data.history(symbol('TEST'), 10 , 5, '1d') +""" + +bad_type_history_bar_count = """ +from zipline.api import symbol + +def initialize(context): + pass + +def handle_data(context, data): + data.history(symbol('TEST'), 'price', '5', '1d') +""" + +bad_type_history_frequency = """ +from zipline.api import symbol + +def initialize(context): + pass + +def handle_data(context, data): + data.history(symbol('TEST'), 'price', 5, 1) +""" + +bad_type_current_assets = """ +def initialize(context): + pass + +def handle_data(context, data): + data.current(1, 'price') +""" + +bad_type_current_fields = """ +from zipline.api import symbol + +def initialize(context): + pass + +def handle_data(context, data): + data.current(symbol('TEST'), 10) +""" + +bad_type_is_stale_assets = """ +def initialize(context): + pass + +def handle_data(context, data): + data.is_stale('TEST') +""" + +bad_type_can_trade_assets = """ +def initialize(context): + pass + +def handle_data(context, data): + data.can_trade('TEST') +""" + +bad_type_history_assets_kwarg = """ +def initialize(context): + pass + +def handle_data(context, data): + data.history(frequency='1d', fields='price', assets=1, bar_count=5) +""" + +bad_type_history_fields_kwarg = """ +from zipline.api import symbol + +def initialize(context): + pass + +def handle_data(context, data): + data.history(frequency='1d', fields=10, assets=symbol('TEST'), + bar_count=5) +""" + +bad_type_history_bar_count_kwarg = """ +from zipline.api import symbol + +def initialize(context): + pass + +def handle_data(context, data): + data.history(frequency='1d', fields='price', assets=symbol('TEST'), + bar_count='5') +""" + +bad_type_history_frequency_kwarg = """ +from zipline.api import symbol + +def initialize(context): + pass + +def handle_data(context, data): + data.history(frequency=1, fields='price', assets=symbol('TEST'), + bar_count=5) +""" + +bad_type_current_assets_kwarg = """ +def initialize(context): + pass + +def handle_data(context, data): + data.current(fields='price', assets=1) +""" + +bad_type_current_fields_kwarg = """ +from zipline.api import symbol + +def initialize(context): + pass + +def handle_data(context, data): + data.current(fields=10, assets=symbol('TEST')) +"""