From 8dc3ed73abdc6edc9e4833d6bbf49f2aa019853d Mon Sep 17 00:00:00 2001 From: Andrew Liang Date: Thu, 7 Apr 2016 14:29:38 -0400 Subject: [PATCH 1/5] FIX: Check types of args passed to api methods on data --- tests/test_algorithm.py | 59 +++++++++++++++++ zipline/_protocol.pyx | 76 +++++++++++++++++++-- zipline/test_algorithms.py | 131 +++++++++++++++++++++++++++++++++++++ 3 files changed, 261 insertions(+), 5 deletions(-) diff --git a/tests/test_algorithm.py b/tests/test_algorithm.py index ce79bfd4..3d8a64b4 100644 --- a/tests/test_algorithm.py +++ b/tests/test_algorithm.py @@ -107,7 +107,22 @@ from zipline.test_algorithms import ( call_without_kwargs, call_with_bad_kwargs_current, call_with_bad_kwargs_history, + bad_type_history_assets, + bad_type_history_fields, + bad_type_history_bar_count, + bad_type_history_frequency, + bad_type_current_assets, + bad_type_current_fields, + bad_type_can_trade_assets, + bad_type_is_stale_assets, + bad_type_history_assets_kwarg, + bad_type_history_fields_kwarg, + bad_type_history_bar_count_kwarg, + bad_type_history_frequency_kwarg, + bad_type_current_assets_kwarg, + bad_type_current_fields_kwarg, no_handle_data) + from zipline.testing import ( make_jagged_equity_info, to_utc, @@ -1431,6 +1446,29 @@ class TestBeforeTradingStart(TestCase): class TestAlgoScript(TestCase): + ARG_TYPE_TEST_CASES = [ + ('history__assets', (bad_type_history_assets, 'Asset, str', True)), + ('history__fields', (bad_type_history_fields, 'str', True)), + ('history__bar_count', (bad_type_history_bar_count, 'int', False)), + ('history__frequency', (bad_type_history_frequency, 'str', False)), + ('current__assets', (bad_type_current_assets, 'Asset, str', True)), + ('current__fields', (bad_type_current_fields, 'str', True)), + ('is_stale__assets', (bad_type_is_stale_assets, 'Asset', True)), + ('can_trade__assets', (bad_type_can_trade_assets, 'Asset', True)), + ('history_kwarg__assets', + (bad_type_history_assets_kwarg, 'Asset, str', True)), + ('history_kwarg__fields', + (bad_type_history_fields_kwarg, 'str', True)), + ('history_kwarg__bar_count', + (bad_type_history_bar_count_kwarg, 'int', False)), + ('history_kwarg__frequency', + (bad_type_history_frequency_kwarg, 'str', False)), + ('current_kwarg__assets', + (bad_type_current_assets_kwarg, 'Asset, str', True)), + ('current_kwarg__fields', + (bad_type_current_fields_kwarg, 'str', True)), + ] + @classmethod def setUpClass(cls): setup_logger(cls) @@ -1846,6 +1884,27 @@ def handle_data(context, data): self.assertEqual("%s() got an unexpected keyword argument 'blahblah'" % name, cm.exception.args[0]) + @parameterized.expand(ARG_TYPE_TEST_CASES) + def test_arg_types(self, name, inputs): + + keyword = name.split('__')[1] + + with self.assertRaises(TypeError) as cm: + algo = TradingAlgorithm( + script=inputs[0], + sim_params=self.sim_params, + env=self.env + ) + algo.run(self.data_portal) + + expected = "Expected %s argument to be of type %s%s" % ( + keyword, + 'or iterable of type ' if inputs[2] else '', + inputs[1] + ) + + self.assertEqual(expected, cm.exception.args[0]) + class TestGetDatetime(TestCase): diff --git a/zipline/_protocol.pyx b/zipline/_protocol.pyx index 6e7f0c92..9d981a34 100644 --- a/zipline/_protocol.pyx +++ b/zipline/_protocol.pyx @@ -21,6 +21,7 @@ import numpy as np from six import iteritems from cpython cimport bool +from collections import Iterable from zipline.assets import Asset from zipline.zipline_warnings import ZiplineDeprecationWarning @@ -33,20 +34,80 @@ class assert_keywords(object): meaningful message, unlike the one cython returns by default. """ - def __init__(self, *args): - self.names = args + def __init__(self, arg_names, method_name): + self.names = arg_names + self.method = method_name def __call__(self, func): def assert_keywords_and_call(*args, **kwargs): for field in kwargs: if field not in self.names: raise TypeError("%s() got an unexpected keyword argument" - " '%s'" % (func.__name__, field)) + " '%s'" % (self.method, field)) return func(*args, **kwargs) return assert_keywords_and_call +KEYWORDS = ['assets', 'fields', 'bar_count', 'frequency'] + + +class assert_types(object): + """ + Asserts that the arguments passed into the wrapped function are consistent + with the types passed into this decorator. If not, raise a TypeError with + a meaningful message. + """ + + def __init__(self, *args): + self.types = args + self.keys_to_types = dict(zip(KEYWORDS, args)) + + def _is_iterable(self, obj): + return isinstance(obj, Iterable) and not isinstance(obj, str) + + def __call__(self, func): + def assert_types_and_call(*args, **kwargs): + for i, arg in enumerate(args[1:]): + if isinstance(arg, self.types[i]): + continue + elif i in (0, 1) and self._is_iterable(arg): + if isinstance(arg[0], self.types[i]): + continue + + expected_type = self.types[i].__name__ \ + if not self._is_iterable(self.types[i]) \ + else ', '.join([type.__name__ for type in self.types[i]]) + raise TypeError("Expected %s argument to be of type %s%s" % + (KEYWORDS[i], + 'or iterable of type ' if i in (0, 1) else '', + expected_type) + ) + + for keyword, arg in iteritems(kwargs): + if isinstance(arg, self.keys_to_types[keyword]): + continue + elif keyword in ('assets', 'fields') and \ + self._is_iterable(arg): + if isinstance(arg[0], self.keys_to_types[i]): + continue + + expected_type = self.keys_to_types[keyword].__name__ \ + if not self._is_iterable(self.keys_to_types[keyword]) \ + else ', '.join([type.__name__ for type in + self.keys_to_types[keyword]]) + raise TypeError("Expected %s argument to be of type %s%s" % + (keyword, + 'or iterable of type ' if keyword in + ('assets', 'fields') else '', + expected_type) + ) + + return func(*args, **kwargs) + + return assert_types_and_call + + @contextmanager def handle_non_market_minutes(bar_data): try: @@ -149,7 +210,8 @@ cdef class BarData: return dt - @assert_keywords('assets', 'fields') + @assert_keywords(arg_names=('assets', 'fields'), method_name='current') + @assert_types((Asset, str), str) def current(self, assets, fields): """ Returns the current value of the given assets for the given fields @@ -327,6 +389,7 @@ cdef class BarData: cdef bool _is_iterable(self, obj): return hasattr(obj, '__iter__') and not isinstance(obj, str) + @assert_types(Asset) def can_trade(self, assets): """ For the given asset or iterable of assets, returns true if the asset @@ -373,6 +436,7 @@ cdef class BarData: return False + @assert_types(Asset) def is_stale(self, assets): """ For the given asset or iterable of assets, returns true if the asset @@ -432,7 +496,9 @@ cdef class BarData: return not (last_traded_dt is pd.NaT) - @assert_keywords('assets', 'fields', 'bar_count', 'frequency') + @assert_keywords(arg_names=('assets', 'fields', 'bar_count', 'frequency'), + method_name='history') + @assert_types((Asset, str), str, int, str) def history(self, assets, fields, bar_count, frequency): """ Returns a window of data for the given assets and fields. diff --git a/zipline/test_algorithms.py b/zipline/test_algorithms.py index 32932cd9..b2aa6edd 100644 --- a/zipline/test_algorithms.py +++ b/zipline/test_algorithms.py @@ -957,3 +957,134 @@ def initialize(context): def handle_data(context, data): current = data.current(assets=symbol('TEST'), blahblah="price") """ + +bad_type_history_assets = """ +def initialize(context): + pass + +def handle_data(context, data): + data.history(1, 'price', 5, '1d') +""" + +bad_type_history_fields = """ +from zipline.api import symbol + +def initialize(context): + pass + +def handle_data(context, data): + data.history(symbol('TEST'), 10 , 5, '1d') +""" + +bad_type_history_bar_count = """ +from zipline.api import symbol + +def initialize(context): + pass + +def handle_data(context, data): + data.history(symbol('TEST'), 'price', '5', '1d') +""" + +bad_type_history_frequency = """ +from zipline.api import symbol + +def initialize(context): + pass + +def handle_data(context, data): + data.history(symbol('TEST'), 'price', 5, 1) +""" + +bad_type_current_assets = """ +def initialize(context): + pass + +def handle_data(context, data): + data.current(1, 'price') +""" + +bad_type_current_fields = """ +from zipline.api import symbol + +def initialize(context): + pass + +def handle_data(context, data): + data.current(symbol('TEST'), 10) +""" + +bad_type_is_stale_assets = """ +def initialize(context): + pass + +def handle_data(context, data): + data.is_stale('TEST') +""" + +bad_type_can_trade_assets = """ +def initialize(context): + pass + +def handle_data(context, data): + data.can_trade('TEST') +""" + +bad_type_history_assets_kwarg = """ +def initialize(context): + pass + +def handle_data(context, data): + data.history(frequency='1d', fields='price', assets=1, bar_count=5) +""" + +bad_type_history_fields_kwarg = """ +from zipline.api import symbol + +def initialize(context): + pass + +def handle_data(context, data): + data.history(frequency='1d', fields=10, assets=symbol('TEST'), + bar_count=5) +""" + +bad_type_history_bar_count_kwarg = """ +from zipline.api import symbol + +def initialize(context): + pass + +def handle_data(context, data): + data.history(frequency='1d', fields='price', assets=symbol('TEST'), + bar_count='5') +""" + +bad_type_history_frequency_kwarg = """ +from zipline.api import symbol + +def initialize(context): + pass + +def handle_data(context, data): + data.history(frequency=1, fields='price', assets=symbol('TEST'), + bar_count=5) +""" + +bad_type_current_assets_kwarg = """ +def initialize(context): + pass + +def handle_data(context, data): + data.current(fields='price', assets=1) +""" + +bad_type_current_fields_kwarg = """ +from zipline.api import symbol + +def initialize(context): + pass + +def handle_data(context, data): + data.current(fields=10, assets=symbol('TEST')) +""" From 2775cc7ca4a44ab12693874992445682fa2d90dc Mon Sep 17 00:00:00 2001 From: Andrew Liang Date: Mon, 11 Apr 2016 12:24:22 -0400 Subject: [PATCH 2/5] FIX: Remove support for passing in sid int in place of Asset --- zipline/data/data_portal.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/zipline/data/data_portal.py b/zipline/data/data_portal.py index d4bc1e18..59337765 100644 --- a/zipline/data/data_portal.py +++ b/zipline/data/data_portal.py @@ -714,9 +714,6 @@ class DataPortal(object): if field not in BASE_FIELDS: raise KeyError("Invalid column: " + str(field)) - if isinstance(asset, int): - asset = self.env.asset_finder.retrieve_asset(asset) - if dt < asset.start_date or \ (data_frequency == "daily" and dt > asset.end_date) or \ (data_frequency == "minute" and @@ -847,8 +844,6 @@ class DataPortal(object): ------- The value of the desired field at the desired time. """ - if isinstance(asset, int): - asset = self._asset_finder.retrieve_asset(asset) if spot_value is None: spot_value = self.get_spot_value(asset, field, dt, data_frequency) From d597a3caaa4e3695ba1c5035f90e9826838c7003 Mon Sep 17 00:00:00 2001 From: Jean Bredeche Date: Wed, 13 Apr 2016 16:01:00 -0400 Subject: [PATCH 3/5] DEV: combined the decorators MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This way the type decorator could have access to the argument decorator’s keyword list. --- zipline/_protocol.pyx | 88 ++++++++++++++++++++----------------------- 1 file changed, 40 insertions(+), 48 deletions(-) diff --git a/zipline/_protocol.pyx b/zipline/_protocol.pyx index 9d981a34..0cbe81a2 100644 --- a/zipline/_protocol.pyx +++ b/zipline/_protocol.pyx @@ -27,75 +27,71 @@ from zipline.assets import Asset from zipline.zipline_warnings import ZiplineDeprecationWarning -class assert_keywords(object): +cdef bool _is_iterable(obj): + return isinstance(obj, Iterable) and not isinstance(obj, str) + + +cdef class check_parameters(object): """ Asserts that the keywords passed into the wrapped function are included in those passed into this decorator. If not, raise a TypeError with a - meaningful message, unlike the one cython returns by default. - """ + meaningful message, unlike the one Cython returns by default. - def __init__(self, arg_names, method_name): - self.names = arg_names + Also asserts that the arguments passed into the wrapped function are + consistent with the types passed into this decorator. If not, raise a + TypeError with a meaningful message. + """ + cdef tuple keyword_names + cdef object method + cdef tuple types + cdef dict keys_to_types + + def __init__(self, method_name, keyword_names, types): + self.keyword_names = keyword_names + self.types = types self.method = method_name + self.keys_to_types = dict(zip(keyword_names, types)) + def __call__(self, func): def assert_keywords_and_call(*args, **kwargs): + # verify all the keyword arguments for field in kwargs: - if field not in self.names: + if field not in self.keyword_names: raise TypeError("%s() got an unexpected keyword argument" " '%s'" % (self.method, field)) - return func(*args, **kwargs) - return assert_keywords_and_call - - -KEYWORDS = ['assets', 'fields', 'bar_count', 'frequency'] - - -class assert_types(object): - """ - Asserts that the arguments passed into the wrapped function are consistent - with the types passed into this decorator. If not, raise a TypeError with - a meaningful message. - """ - - def __init__(self, *args): - self.types = args - self.keys_to_types = dict(zip(KEYWORDS, args)) - - def _is_iterable(self, obj): - return isinstance(obj, Iterable) and not isinstance(obj, str) - - def __call__(self, func): - def assert_types_and_call(*args, **kwargs): + # verify type of each arg for i, arg in enumerate(args[1:]): if isinstance(arg, self.types[i]): continue - elif i in (0, 1) and self._is_iterable(arg): + elif i in (0, 1) and _is_iterable(arg): if isinstance(arg[0], self.types[i]): continue expected_type = self.types[i].__name__ \ - if not self._is_iterable(self.types[i]) \ + if not _is_iterable(self.types[i]) \ else ', '.join([type.__name__ for type in self.types[i]]) + raise TypeError("Expected %s argument to be of type %s%s" % - (KEYWORDS[i], + (self.keyword_names[i], 'or iterable of type ' if i in (0, 1) else '', expected_type) ) + # verify type of each kwarg for keyword, arg in iteritems(kwargs): if isinstance(arg, self.keys_to_types[keyword]): continue - elif keyword in ('assets', 'fields') and \ - self._is_iterable(arg): + elif keyword in ('assets', 'fields') and _is_iterable(arg): if isinstance(arg[0], self.keys_to_types[i]): continue expected_type = self.keys_to_types[keyword].__name__ \ - if not self._is_iterable(self.keys_to_types[keyword]) \ + if not _is_iterable(self.keys_to_types[keyword]) \ else ', '.join([type.__name__ for type in self.keys_to_types[keyword]]) + raise TypeError("Expected %s argument to be of type %s%s" % (keyword, 'or iterable of type ' if keyword in @@ -105,7 +101,7 @@ class assert_types(object): return func(*args, **kwargs) - return assert_types_and_call + return assert_keywords_and_call @contextmanager @@ -210,8 +206,7 @@ cdef class BarData: return dt - @assert_keywords(arg_names=('assets', 'fields'), method_name='current') - @assert_types((Asset, str), str) + @check_parameters('current', ('assets', 'fields'), ((Asset, str), str)) def current(self, assets, fields): """ Returns the current value of the given assets for the given fields @@ -268,8 +263,8 @@ cdef class BarData: the current trade bar. If there is no current trade bar, NaN is returned. """ - multiple_assets = self._is_iterable(assets) - multiple_fields = self._is_iterable(fields) + multiple_assets = _is_iterable(assets) + multiple_fields = _is_iterable(fields) # There's some overly verbose code in here, particularly around # 'do something if self._adjust_minutes is False, otherwise do @@ -386,10 +381,7 @@ cdef class BarData: return pd.DataFrame(data) - cdef bool _is_iterable(self, obj): - return hasattr(obj, '__iter__') and not isinstance(obj, str) - - @assert_types(Asset) + @check_parameters('can_trade', ('assets',), (Asset,)) def can_trade(self, assets): """ For the given asset or iterable of assets, returns true if the asset @@ -436,7 +428,7 @@ cdef class BarData: return False - @assert_types(Asset) + @check_parameters('is_stale', ('assets',), (Asset,)) def is_stale(self, assets): """ For the given asset or iterable of assets, returns true if the asset @@ -496,9 +488,9 @@ cdef class BarData: return not (last_traded_dt is pd.NaT) - @assert_keywords(arg_names=('assets', 'fields', 'bar_count', 'frequency'), - method_name='history') - @assert_types((Asset, str), str, int, str) + @check_parameters('history', + ('assets', 'fields', 'bar_count', 'frequency'), + ((Asset, str), str, int, str)) def history(self, assets, fields, bar_count, frequency): """ Returns a window of data for the given assets and fields. From d94b7bb9e48ff995da276d83219d0cacfc4c6f83 Mon Sep 17 00:00:00 2001 From: Jean Bredeche Date: Wed, 13 Apr 2016 16:09:18 -0400 Subject: [PATCH 4/5] DEV: Don't need to pass method name in. --- zipline/_protocol.pyx | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/zipline/_protocol.pyx b/zipline/_protocol.pyx index 0cbe81a2..93e6786b 100644 --- a/zipline/_protocol.pyx +++ b/zipline/_protocol.pyx @@ -42,14 +42,12 @@ cdef class check_parameters(object): TypeError with a meaningful message. """ cdef tuple keyword_names - cdef object method cdef tuple types cdef dict keys_to_types - def __init__(self, method_name, keyword_names, types): + def __init__(self, keyword_names, types): self.keyword_names = keyword_names self.types = types - self.method = method_name self.keys_to_types = dict(zip(keyword_names, types)) @@ -59,7 +57,7 @@ cdef class check_parameters(object): for field in kwargs: if field not in self.keyword_names: raise TypeError("%s() got an unexpected keyword argument" - " '%s'" % (self.method, field)) + " '%s'" % (func.__name__, field)) # verify type of each arg for i, arg in enumerate(args[1:]): @@ -206,7 +204,7 @@ cdef class BarData: return dt - @check_parameters('current', ('assets', 'fields'), ((Asset, str), str)) + @check_parameters(('assets', 'fields'), ((Asset, str), str)) def current(self, assets, fields): """ Returns the current value of the given assets for the given fields @@ -381,7 +379,7 @@ cdef class BarData: return pd.DataFrame(data) - @check_parameters('can_trade', ('assets',), (Asset,)) + @check_parameters(('assets',), (Asset,)) def can_trade(self, assets): """ For the given asset or iterable of assets, returns true if the asset @@ -428,7 +426,7 @@ cdef class BarData: return False - @check_parameters('is_stale', ('assets',), (Asset,)) + @check_parameters(('assets',), (Asset,)) def is_stale(self, assets): """ For the given asset or iterable of assets, returns true if the asset @@ -488,8 +486,7 @@ cdef class BarData: return not (last_traded_dt is pd.NaT) - @check_parameters('history', - ('assets', 'fields', 'bar_count', 'frequency'), + @check_parameters(('assets', 'fields', 'bar_count', 'frequency'), ((Asset, str), str, int, str)) def history(self, assets, fields, bar_count, frequency): """ From bd36e92556ee232c282d387bb05081ac287a2cc3 Mon Sep 17 00:00:00 2001 From: Jean Bredeche Date: Wed, 13 Apr 2016 16:34:04 -0400 Subject: [PATCH 5/5] DEV: minor perf boosts MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit likely doesn’t move the needle that much --- zipline/_protocol.pyx | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/zipline/_protocol.pyx b/zipline/_protocol.pyx index 93e6786b..bb91e80b 100644 --- a/zipline/_protocol.pyx +++ b/zipline/_protocol.pyx @@ -53,6 +53,8 @@ cdef class check_parameters(object): def __call__(self, func): def assert_keywords_and_call(*args, **kwargs): + cdef short i + # verify all the keyword arguments for field in kwargs: if field not in self.keyword_names: @@ -60,21 +62,28 @@ cdef class check_parameters(object): " '%s'" % (func.__name__, field)) # verify type of each arg - for i, arg in enumerate(args[1:]): - if isinstance(arg, self.types[i]): + i = 0 + while i < (len(args) - 1): + arg = args[i + 1] + expected_type = self.types[i] + + if isinstance(arg, expected_type): + i += 1 continue - elif i in (0, 1) and _is_iterable(arg): - if isinstance(arg[0], self.types[i]): + + elif (i == 0 or i == 1) and _is_iterable(arg): + if isinstance(arg[0], expected_type): + i += 1 continue - expected_type = self.types[i].__name__ \ - if not _is_iterable(self.types[i]) \ - else ', '.join([type.__name__ for type in self.types[i]]) + expected_type_name = expected_type.__name__ \ + if not _is_iterable(expected_type) \ + else ', '.join([type_.__name__ for type_ in expected_type]) raise TypeError("Expected %s argument to be of type %s%s" % (self.keyword_names[i], 'or iterable of type ' if i in (0, 1) else '', - expected_type) + expected_type_name) ) # verify type of each kwarg