diff --git a/zipline/_protocol.pyx b/zipline/_protocol.pyx index 93e6786b..bb91e80b 100644 --- a/zipline/_protocol.pyx +++ b/zipline/_protocol.pyx @@ -53,6 +53,8 @@ cdef class check_parameters(object): def __call__(self, func): def assert_keywords_and_call(*args, **kwargs): + cdef short i + # verify all the keyword arguments for field in kwargs: if field not in self.keyword_names: @@ -60,21 +62,28 @@ cdef class check_parameters(object): " '%s'" % (func.__name__, field)) # verify type of each arg - for i, arg in enumerate(args[1:]): - if isinstance(arg, self.types[i]): + i = 0 + while i < (len(args) - 1): + arg = args[i + 1] + expected_type = self.types[i] + + if isinstance(arg, expected_type): + i += 1 continue - elif i in (0, 1) and _is_iterable(arg): - if isinstance(arg[0], self.types[i]): + + elif (i == 0 or i == 1) and _is_iterable(arg): + if isinstance(arg[0], expected_type): + i += 1 continue - expected_type = self.types[i].__name__ \ - if not _is_iterable(self.types[i]) \ - else ', '.join([type.__name__ for type in self.types[i]]) + expected_type_name = expected_type.__name__ \ + if not _is_iterable(expected_type) \ + else ', '.join([type_.__name__ for type_ in expected_type]) raise TypeError("Expected %s argument to be of type %s%s" % (self.keyword_names[i], 'or iterable of type ' if i in (0, 1) else '', - expected_type) + expected_type_name) ) # verify type of each kwarg