mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-02 19:37:42 +08:00
ENH: Added dynamic name functionality to record() API function.
Added the ability to pass *args before the **kwargs so that positional arguments of the form name, value can be recorded.
This commit is contained in:
committed by
Scott Sanderson
parent
75b415ac48
commit
c3169f60cd
@@ -98,6 +98,12 @@ class TestRecordAlgorithm(TestCase):
|
||||
|
||||
np.testing.assert_array_equal(output['incr'].values,
|
||||
range(1, len(output) + 1))
|
||||
np.testing.assert_array_equal(output['name'].values,
|
||||
range(1, len(output) + 1))
|
||||
np.testing.assert_array_equal(output['name2'].values,
|
||||
[2] * len(output))
|
||||
np.testing.assert_array_equal(output['name3'].values,
|
||||
range(1, len(output) + 1))
|
||||
|
||||
|
||||
class TestMiscellaneousAPI(TestCase):
|
||||
|
||||
+11
-3
@@ -21,7 +21,7 @@ import numpy as np
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from itertools import groupby
|
||||
from itertools import groupby, chain
|
||||
from six.moves import filter
|
||||
from six import iteritems, exec_
|
||||
from operator import attrgetter
|
||||
@@ -467,11 +467,19 @@ class TradingAlgorithm(object):
|
||||
'kwargs': kwargs}
|
||||
|
||||
@api_method
|
||||
def record(self, **kwargs):
|
||||
def record(self, *args, **kwargs):
|
||||
"""
|
||||
Track and record local variable (i.e. attributes) each day.
|
||||
"""
|
||||
for name, value in kwargs.items():
|
||||
# Make 2 objects both referencing the same iterator
|
||||
args = [iter(args)] * 2
|
||||
|
||||
# Zip generates list entries by calling `next` on each iterator it
|
||||
# receives. In this case the two iterators are the same object, so the
|
||||
# call to next on args[0] will also advance args[1], resulting in zip
|
||||
# returning (a,b) (c,d) (e,f) rather than (a,a) (b,b) (c,c) etc.
|
||||
positionals = zip(*args)
|
||||
for name, value in chain(positionals, iteritems(kwargs)):
|
||||
self._recorded_vars[name] = value
|
||||
|
||||
@api_method
|
||||
|
||||
@@ -233,6 +233,9 @@ class RecordAlgorithm(TradingAlgorithm):
|
||||
def handle_data(self, data):
|
||||
self.incr += 1
|
||||
self.record(incr=self.incr)
|
||||
name = 'name'
|
||||
self.record(name, self.incr)
|
||||
record(name, self.incr, 'name2', 2, name3=self.incr)
|
||||
|
||||
|
||||
class TestOrderAlgorithm(TradingAlgorithm):
|
||||
|
||||
Reference in New Issue
Block a user