From c3169f60cd0a641e4457d87f3afd726f58205784 Mon Sep 17 00:00:00 2001 From: Delaney Granizo-Mackenzie Date: Thu, 26 Jun 2014 16:46:23 -0400 Subject: [PATCH] ENH: Added dynamic name functionality to record() API function. Added the ability to pass *args before the **kwargs so that positional arguments of the form name, value can be recorded. --- tests/test_algorithm.py | 6 ++++++ zipline/algorithm.py | 14 +++++++++++--- zipline/test_algorithms.py | 3 +++ 3 files changed, 20 insertions(+), 3 deletions(-) diff --git a/tests/test_algorithm.py b/tests/test_algorithm.py index dcf5906b..a806c5fd 100644 --- a/tests/test_algorithm.py +++ b/tests/test_algorithm.py @@ -98,6 +98,12 @@ class TestRecordAlgorithm(TestCase): np.testing.assert_array_equal(output['incr'].values, range(1, len(output) + 1)) + np.testing.assert_array_equal(output['name'].values, + range(1, len(output) + 1)) + np.testing.assert_array_equal(output['name2'].values, + [2] * len(output)) + np.testing.assert_array_equal(output['name3'].values, + range(1, len(output) + 1)) class TestMiscellaneousAPI(TestCase): diff --git a/zipline/algorithm.py b/zipline/algorithm.py index 950ae099..1949763d 100644 --- a/zipline/algorithm.py +++ b/zipline/algorithm.py @@ -21,7 +21,7 @@ import numpy as np from datetime import datetime -from itertools import groupby +from itertools import groupby, chain from six.moves import filter from six import iteritems, exec_ from operator import attrgetter @@ -467,11 +467,19 @@ class TradingAlgorithm(object): 'kwargs': kwargs} @api_method - def record(self, **kwargs): + def record(self, *args, **kwargs): """ Track and record local variable (i.e. attributes) each day. """ - for name, value in kwargs.items(): + # Make 2 objects both referencing the same iterator + args = [iter(args)] * 2 + + # Zip generates list entries by calling `next` on each iterator it + # receives. In this case the two iterators are the same object, so the + # call to next on args[0] will also advance args[1], resulting in zip + # returning (a,b) (c,d) (e,f) rather than (a,a) (b,b) (c,c) etc. + positionals = zip(*args) + for name, value in chain(positionals, iteritems(kwargs)): self._recorded_vars[name] = value @api_method diff --git a/zipline/test_algorithms.py b/zipline/test_algorithms.py index 6349ba1c..ef99f597 100644 --- a/zipline/test_algorithms.py +++ b/zipline/test_algorithms.py @@ -233,6 +233,9 @@ class RecordAlgorithm(TradingAlgorithm): def handle_data(self, data): self.incr += 1 self.record(incr=self.incr) + name = 'name' + self.record(name, self.incr) + record(name, self.incr, 'name2', 2, name3=self.incr) class TestOrderAlgorithm(TradingAlgorithm):