mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-04 22:27:50 +08:00
Added passing of arguments to algorithm initialize and batch transform.
This commit is contained in:
@@ -294,15 +294,23 @@ class BatchTransformTestCase(TestCase):
|
||||
setup_logger(self)
|
||||
self.source, self.df = factory.create_test_df_source()
|
||||
|
||||
def test_batch_inherit(self):
|
||||
def test_event_window(self):
|
||||
algo = BatchTransformAlgorithm(sids=[0, 1])
|
||||
algo.run(self.source)
|
||||
|
||||
assert algo.history_class[:2] == algo.history_decorator[:2] == [None, None], "First two iterations should return None"
|
||||
assert algo.history_return_price_class[:2] == algo.history_return_price_decorator[:2] == [None, None], "First two iterations should return None"
|
||||
|
||||
# test overloaded class
|
||||
for test_history in [algo.history_class, algo.history_decorator]:
|
||||
for test_history in [algo.history_return_price_class, algo.history_return_price_decorator]:
|
||||
self.assertTrue(np.all(test_history[2].values.flatten() == range(4, 10)))
|
||||
self.assertTrue(np.all(test_history[3].values.flatten() == range(4, 10)))
|
||||
self.assertTrue(np.all(test_history[4].values.flatten() == range(6, 14)))
|
||||
|
||||
def test_passing_of_args(self):
|
||||
algo = BatchTransformAlgorithm([0, 1], 1, kwarg='str')
|
||||
algo.run(self.source)
|
||||
self.assertEqual(algo.args, (1,))
|
||||
self.assertEqual(algo.kwargs, {'kwarg':'str'})
|
||||
expected_item = ((1, ), {'kwarg': 'str'})
|
||||
self.assertEqual(algo.history_return_args, [None, None, expected_item, expected_item, expected_item])
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ class TradingAlgorithm(object):
|
||||
```
|
||||
To then to run this algorithm:
|
||||
|
||||
>>> my_algo = MyAlgo(100, sids=[0])
|
||||
>>> my_algo = MyAlgo([0], 100) # first argument has to be list of sids
|
||||
>>> stats = my_algo.run(data)
|
||||
|
||||
"""
|
||||
@@ -32,7 +32,7 @@ class TradingAlgorithm(object):
|
||||
"""
|
||||
Initialize sids and other state variables.
|
||||
|
||||
Calls user-defined initialize and forwarding *args and **kwargs.
|
||||
Calls user-defined initialize() forwarding *args and **kwargs.
|
||||
"""
|
||||
self.sids = sids
|
||||
self.done = False
|
||||
@@ -45,6 +45,8 @@ class TradingAlgorithm(object):
|
||||
# call to user-defined initialize method
|
||||
self.initialize(*args, **kwargs)
|
||||
|
||||
self.initialized = True
|
||||
|
||||
def _create_simulator(self, start, end):
|
||||
"""
|
||||
Create trading environment, transforms and SimulatedTrading object.
|
||||
|
||||
@@ -219,9 +219,11 @@ class AlgorithmSimulator(object):
|
||||
# snapshot time to any log record generated.
|
||||
#with self.processor.threadbound(), self.stdout_capture(Logger('Print'),''):
|
||||
|
||||
# Call user's initialize method with a timeout.
|
||||
with Timeout(INIT_TIMEOUT, message="Call to initialize timed out"):
|
||||
self.algo.initialize()
|
||||
# Call user's initialize method with a timeout (only if
|
||||
# initialize wasn't called already).
|
||||
if not getattr(self.algo, 'initialized', False):
|
||||
with Timeout(INIT_TIMEOUT, message="Call to initialize timed out"):
|
||||
self.algo.initialize()
|
||||
|
||||
# Group together events with the same dt field. This depends on the
|
||||
# events already being sorted.
|
||||
|
||||
@@ -339,7 +339,7 @@ class BatchTransform(EventWindow):
|
||||
self.updated = False
|
||||
self.data = None
|
||||
|
||||
def handle_data(self, data):
|
||||
def handle_data(self, data, *args, **kwargs):
|
||||
"""
|
||||
New method to handle a data frame as sent to the algorithm's handle_data
|
||||
method.
|
||||
@@ -356,7 +356,7 @@ class BatchTransform(EventWindow):
|
||||
self.update(data)
|
||||
|
||||
# return newly computed or cached value
|
||||
return self.get_transform_value()
|
||||
return self.get_transform_value(*args, **kwargs)
|
||||
|
||||
def handle_add(self, event):
|
||||
if not self.last_refresh:
|
||||
|
||||
+37
-17
@@ -399,33 +399,53 @@ class TestRegisterTransformAlgorithm(TradingAlgorithm):
|
||||
def handle_data(self, data):
|
||||
pass
|
||||
|
||||
class NoopBatchTransform(BatchTransform):
|
||||
##########################################
|
||||
# Algorithm using simple batch transforms
|
||||
|
||||
class ReturnPriceBatchTransform(BatchTransform):
|
||||
def get_value(self, data):
|
||||
return data.price
|
||||
|
||||
@batch_transform
|
||||
def noop_batch_decorator(data):
|
||||
def return_price_batch_decorator(data):
|
||||
return data.price
|
||||
|
||||
@batch_transform
|
||||
def return_args_batch_decorator(data, *args, **kwargs):
|
||||
return args, kwargs
|
||||
|
||||
class BatchTransformAlgorithm(TradingAlgorithm):
|
||||
def initialize(self, *args, **kwargs):
|
||||
self.history_class = []
|
||||
self.history_decorator = []
|
||||
self.days = 3
|
||||
self.noop_class = NoopBatchTransform(sids=[0, 1],
|
||||
market_aware=False,
|
||||
refresh_period=2,
|
||||
delta=timedelta(days=self.days))
|
||||
self.history_return_price_class = []
|
||||
self.history_return_price_decorator = []
|
||||
self.history_return_args = []
|
||||
|
||||
self.noop_decorator = noop_batch_decorator(sids=[0, 1],
|
||||
market_aware=False,
|
||||
refresh_period=2,
|
||||
delta=timedelta(days=self.days))
|
||||
self.days = 3
|
||||
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
|
||||
self.return_price_class = ReturnPriceBatchTransform(sids=self.sids,
|
||||
market_aware=False,
|
||||
refresh_period=2,
|
||||
delta=timedelta(days=self.days)
|
||||
)
|
||||
|
||||
self.return_price_decorator = return_price_batch_decorator(sids=self.sids,
|
||||
market_aware=False,
|
||||
refresh_period=2,
|
||||
delta=timedelta(days=self.days)
|
||||
)
|
||||
|
||||
self.return_args_batch = return_args_batch_decorator(sids=self.sids,
|
||||
market_aware=False,
|
||||
refresh_period=2,
|
||||
delta=timedelta(days=self.days)
|
||||
)
|
||||
|
||||
def handle_data(self, data):
|
||||
window_class = self.noop_class.handle_data(data)
|
||||
window_decorator = self.noop_decorator.handle_data(data)
|
||||
self.history_class.append(window_class)
|
||||
self.history_decorator.append(window_decorator)
|
||||
self.history_return_price_class.append(self.return_price_class.handle_data(data))
|
||||
self.history_return_price_decorator.append(self.return_price_decorator.handle_data(data))
|
||||
self.history_return_args.append(self.return_args_batch.handle_data(data, *self.args, **self.kwargs))
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user