Added passing of arguments to algorithm initialize and batch transform.

This commit is contained in:
Thomas Wiecki
2012-09-25 11:48:37 -04:00
parent 52e6a69492
commit 188a9502b0
5 changed files with 59 additions and 27 deletions
+11 -3
View File
@@ -294,15 +294,23 @@ class BatchTransformTestCase(TestCase):
setup_logger(self)
self.source, self.df = factory.create_test_df_source()
def test_batch_inherit(self):
def test_event_window(self):
algo = BatchTransformAlgorithm(sids=[0, 1])
algo.run(self.source)
assert algo.history_class[:2] == algo.history_decorator[:2] == [None, None], "First two iterations should return None"
assert algo.history_return_price_class[:2] == algo.history_return_price_decorator[:2] == [None, None], "First two iterations should return None"
# test overloaded class
for test_history in [algo.history_class, algo.history_decorator]:
for test_history in [algo.history_return_price_class, algo.history_return_price_decorator]:
self.assertTrue(np.all(test_history[2].values.flatten() == range(4, 10)))
self.assertTrue(np.all(test_history[3].values.flatten() == range(4, 10)))
self.assertTrue(np.all(test_history[4].values.flatten() == range(6, 14)))
def test_passing_of_args(self):
algo = BatchTransformAlgorithm([0, 1], 1, kwarg='str')
algo.run(self.source)
self.assertEqual(algo.args, (1,))
self.assertEqual(algo.kwargs, {'kwarg':'str'})
expected_item = ((1, ), {'kwarg': 'str'})
self.assertEqual(algo.history_return_args, [None, None, expected_item, expected_item, expected_item])
+4 -2
View File
@@ -24,7 +24,7 @@ class TradingAlgorithm(object):
```
To then to run this algorithm:
>>> my_algo = MyAlgo(100, sids=[0])
>>> my_algo = MyAlgo([0], 100) # first argument has to be list of sids
>>> stats = my_algo.run(data)
"""
@@ -32,7 +32,7 @@ class TradingAlgorithm(object):
"""
Initialize sids and other state variables.
Calls user-defined initialize and forwarding *args and **kwargs.
Calls user-defined initialize() forwarding *args and **kwargs.
"""
self.sids = sids
self.done = False
@@ -45,6 +45,8 @@ class TradingAlgorithm(object):
# call to user-defined initialize method
self.initialize(*args, **kwargs)
self.initialized = True
def _create_simulator(self, start, end):
"""
Create trading environment, transforms and SimulatedTrading object.
+5 -3
View File
@@ -219,9 +219,11 @@ class AlgorithmSimulator(object):
# snapshot time to any log record generated.
#with self.processor.threadbound(), self.stdout_capture(Logger('Print'),''):
# Call user's initialize method with a timeout.
with Timeout(INIT_TIMEOUT, message="Call to initialize timed out"):
self.algo.initialize()
# Call user's initialize method with a timeout (only if
# initialize wasn't called already).
if not getattr(self.algo, 'initialized', False):
with Timeout(INIT_TIMEOUT, message="Call to initialize timed out"):
self.algo.initialize()
# Group together events with the same dt field. This depends on the
# events already being sorted.
+2 -2
View File
@@ -339,7 +339,7 @@ class BatchTransform(EventWindow):
self.updated = False
self.data = None
def handle_data(self, data):
def handle_data(self, data, *args, **kwargs):
"""
New method to handle a data frame as sent to the algorithm's handle_data
method.
@@ -356,7 +356,7 @@ class BatchTransform(EventWindow):
self.update(data)
# return newly computed or cached value
return self.get_transform_value()
return self.get_transform_value(*args, **kwargs)
def handle_add(self, event):
if not self.last_refresh:
+37 -17
View File
@@ -399,33 +399,53 @@ class TestRegisterTransformAlgorithm(TradingAlgorithm):
def handle_data(self, data):
pass
class NoopBatchTransform(BatchTransform):
##########################################
# Algorithm using simple batch transforms
class ReturnPriceBatchTransform(BatchTransform):
def get_value(self, data):
return data.price
@batch_transform
def noop_batch_decorator(data):
def return_price_batch_decorator(data):
return data.price
@batch_transform
def return_args_batch_decorator(data, *args, **kwargs):
return args, kwargs
class BatchTransformAlgorithm(TradingAlgorithm):
def initialize(self, *args, **kwargs):
self.history_class = []
self.history_decorator = []
self.days = 3
self.noop_class = NoopBatchTransform(sids=[0, 1],
market_aware=False,
refresh_period=2,
delta=timedelta(days=self.days))
self.history_return_price_class = []
self.history_return_price_decorator = []
self.history_return_args = []
self.noop_decorator = noop_batch_decorator(sids=[0, 1],
market_aware=False,
refresh_period=2,
delta=timedelta(days=self.days))
self.days = 3
self.args = args
self.kwargs = kwargs
self.return_price_class = ReturnPriceBatchTransform(sids=self.sids,
market_aware=False,
refresh_period=2,
delta=timedelta(days=self.days)
)
self.return_price_decorator = return_price_batch_decorator(sids=self.sids,
market_aware=False,
refresh_period=2,
delta=timedelta(days=self.days)
)
self.return_args_batch = return_args_batch_decorator(sids=self.sids,
market_aware=False,
refresh_period=2,
delta=timedelta(days=self.days)
)
def handle_data(self, data):
window_class = self.noop_class.handle_data(data)
window_decorator = self.noop_decorator.handle_data(data)
self.history_class.append(window_class)
self.history_decorator.append(window_decorator)
self.history_return_price_class.append(self.return_price_class.handle_data(data))
self.history_return_price_decorator.append(self.return_price_decorator.handle_data(data))
self.history_return_args.append(self.return_args_batch.handle_data(data, *self.args, **self.kwargs))