diff --git a/tests/test_algorithm.py b/tests/test_algorithm.py index dcf5906b..a806c5fd 100644 --- a/tests/test_algorithm.py +++ b/tests/test_algorithm.py @@ -98,6 +98,12 @@ class TestRecordAlgorithm(TestCase): np.testing.assert_array_equal(output['incr'].values, range(1, len(output) + 1)) + np.testing.assert_array_equal(output['name'].values, + range(1, len(output) + 1)) + np.testing.assert_array_equal(output['name2'].values, + [2] * len(output)) + np.testing.assert_array_equal(output['name3'].values, + range(1, len(output) + 1)) class TestMiscellaneousAPI(TestCase): diff --git a/zipline/algorithm.py b/zipline/algorithm.py index 950ae099..1949763d 100644 --- a/zipline/algorithm.py +++ b/zipline/algorithm.py @@ -21,7 +21,7 @@ import numpy as np from datetime import datetime -from itertools import groupby +from itertools import groupby, chain from six.moves import filter from six import iteritems, exec_ from operator import attrgetter @@ -467,11 +467,19 @@ class TradingAlgorithm(object): 'kwargs': kwargs} @api_method - def record(self, **kwargs): + def record(self, *args, **kwargs): """ Track and record local variable (i.e. attributes) each day. """ - for name, value in kwargs.items(): + # Make 2 objects both referencing the same iterator + args = [iter(args)] * 2 + + # Zip generates list entries by calling `next` on each iterator it + # receives. In this case the two iterators are the same object, so the + # call to next on args[0] will also advance args[1], resulting in zip + # returning (a,b) (c,d) (e,f) rather than (a,a) (b,b) (c,c) etc. + positionals = zip(*args) + for name, value in chain(positionals, iteritems(kwargs)): self._recorded_vars[name] = value @api_method diff --git a/zipline/test_algorithms.py b/zipline/test_algorithms.py index 6349ba1c..ef99f597 100644 --- a/zipline/test_algorithms.py +++ b/zipline/test_algorithms.py @@ -233,6 +233,9 @@ class RecordAlgorithm(TradingAlgorithm): def handle_data(self, data): self.incr += 1 self.record(incr=self.incr) + name = 'name' + self.record(name, self.incr) + record(name, self.incr, 'name2', 2, name3=self.incr) class TestOrderAlgorithm(TradingAlgorithm):