diff --git a/tests/test_algorithm.py b/tests/test_algorithm.py index c1a473e4..8ec6a0f8 100644 --- a/tests/test_algorithm.py +++ b/tests/test_algorithm.py @@ -25,7 +25,7 @@ import toolz from logbook import TestHandler, WARNING from mock import MagicMock from nose_parameterized import parameterized -from six import iteritems, itervalues +from six import iteritems, itervalues, string_types from six.moves import range from testfixtures import TempDirectory @@ -37,7 +37,7 @@ from pandas.io.common import PerformanceWarning from zipline import run_algorithm from zipline import TradingAlgorithm from zipline.api import FixedSlippage -from zipline.assets import Equity, Future +from zipline.assets import Equity, Future, Asset from zipline.assets.synthetic import ( make_jagged_equity_info, make_simple_equity_info, @@ -1523,29 +1523,46 @@ class TestAlgoScript(WithLogger, DATA_PORTAL_USE_MINUTE_DATA = False EQUITY_DAILY_BAR_LOOKBACK_DAYS = 5 # max history window length + STRING_TYPE_NAMES = [s.__name__ for s in string_types] + STRING_TYPE_NAMES_STRING = ', '.join(STRING_TYPE_NAMES) + ASSET_TYPE_NAME = Asset.__name__ + ASSET_OR_STRING_TYPE_NAMES = ', '.join([ASSET_TYPE_NAME] + + STRING_TYPE_NAMES) ARG_TYPE_TEST_CASES = ( - ('history__assets', (bad_type_history_assets, 'Asset, str', True)), - ('history__fields', (bad_type_history_fields, 'str', True)), + ('history__assets', (bad_type_history_assets, + ASSET_OR_STRING_TYPE_NAMES, + True)), + ('history__fields', (bad_type_history_fields, + STRING_TYPE_NAMES_STRING, + True)), ('history__bar_count', (bad_type_history_bar_count, 'int', False)), - ('history__frequency', (bad_type_history_frequency, 'str', False)), - ('current__assets', (bad_type_current_assets, 'Asset, str', True)), - ('current__fields', (bad_type_current_fields, 'str', True)), + ('history__frequency', (bad_type_history_frequency, + STRING_TYPE_NAMES_STRING, + False)), + ('current__assets', (bad_type_current_assets, + ASSET_OR_STRING_TYPE_NAMES, + True)), + ('current__fields', (bad_type_current_fields, + STRING_TYPE_NAMES_STRING, + True)), ('is_stale__assets', (bad_type_is_stale_assets, 'Asset', True)), ('can_trade__assets', (bad_type_can_trade_assets, 'Asset', True)), ('history_kwarg__assets', - (bad_type_history_assets_kwarg, 'Asset, str', True)), + (bad_type_history_assets_kwarg, ASSET_OR_STRING_TYPE_NAMES, True)), ('history_kwarg_bad_list__assets', - (bad_type_history_assets_kwarg_list, 'Asset, str', True)), + (bad_type_history_assets_kwarg_list, + ASSET_OR_STRING_TYPE_NAMES, + True)), ('history_kwarg__fields', - (bad_type_history_fields_kwarg, 'str', True)), + (bad_type_history_fields_kwarg, STRING_TYPE_NAMES_STRING, True)), ('history_kwarg__bar_count', (bad_type_history_bar_count_kwarg, 'int', False)), ('history_kwarg__frequency', - (bad_type_history_frequency_kwarg, 'str', False)), + (bad_type_history_frequency_kwarg, STRING_TYPE_NAMES_STRING, False)), ('current_kwarg__assets', - (bad_type_current_assets_kwarg, 'Asset, str', True)), + (bad_type_current_assets_kwarg, ASSET_OR_STRING_TYPE_NAMES, True)), ('current_kwarg__fields', - (bad_type_current_fields_kwarg, 'str', True)), + (bad_type_current_fields_kwarg, STRING_TYPE_NAMES_STRING, True)), ) sids = 0, 1, 3, 133 diff --git a/zipline/_protocol.pyx b/zipline/_protocol.pyx index a1b68a05..fd0c39b2 100644 --- a/zipline/_protocol.pyx +++ b/zipline/_protocol.pyx @@ -20,7 +20,7 @@ from pandas.tslib import normalize_date import pandas as pd import numpy as np -from six import iteritems, PY2 +from six import iteritems, PY2, string_types from cpython cimport bool from collections import Iterable @@ -29,7 +29,7 @@ from zipline.zipline_warnings import ZiplineDeprecationWarning cdef bool _is_iterable(obj): - return isinstance(obj, Iterable) and not isinstance(obj, str) + return isinstance(obj, Iterable) and not isinstance(obj, string_types) # Wraps doesn't work for method objects in python2. Docs should be generated @@ -247,7 +247,8 @@ cdef class BarData: return dt - @check_parameters(('assets', 'fields'), ((Asset, str), str)) + @check_parameters(('assets', 'fields'), + ((Asset,) + string_types, string_types)) def current(self, assets, fields): """ Returns the current value of the given assets for the given fields @@ -568,8 +569,10 @@ cdef class BarData: return not (last_traded_dt is pd.NaT) - @check_parameters(('assets', 'fields', 'bar_count', 'frequency'), - ((Asset, str), str, int, str)) + @check_parameters(('assets', 'fields', 'bar_count', + 'frequency'), + ((Asset,) + string_types, string_types, int, + string_types)) def history(self, assets, fields, bar_count, frequency): """ Returns a window of data for the given assets and fields. @@ -615,7 +618,7 @@ cdef class BarData: If the current simulation time is not a valid market time, we use the last market close instead. """ - if isinstance(fields, str): + if isinstance(fields, string_types): single_asset = isinstance(assets, Asset) if single_asset: