mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-04 21:53:45 +08:00
Merge pull request #122 from quantopian/batch_transform
Refactoring of algorithm and batch transforms. Better unittests.
This commit is contained in:
@@ -294,15 +294,24 @@ class BatchTransformTestCase(TestCase):
|
||||
setup_logger(self)
|
||||
self.source, self.df = factory.create_test_df_source()
|
||||
|
||||
def test_batch_inherit(self):
|
||||
def test_event_window(self):
|
||||
algo = BatchTransformAlgorithm(sids=[0, 1])
|
||||
algo.run(self.source)
|
||||
|
||||
assert algo.history_class[:2] == algo.history_decorator[:2] == [None, None], "First two iterations should return None"
|
||||
self.assertEqual(algo.history_return_price_class[:2], [None, None], "First two iterations should return None")
|
||||
self.assertEqual(algo.history_return_price_decorator[:2], [None, None], "First two iterations should return None")
|
||||
|
||||
# test overloaded class
|
||||
for test_history in [algo.history_class, algo.history_decorator]:
|
||||
for test_history in [algo.history_return_price_class, algo.history_return_price_decorator]:
|
||||
self.assertTrue(np.all(test_history[2].values.flatten() == range(4, 10)))
|
||||
self.assertTrue(np.all(test_history[3].values.flatten() == range(4, 10)))
|
||||
self.assertTrue(np.all(test_history[4].values.flatten() == range(6, 14)))
|
||||
|
||||
def test_passing_of_args(self):
|
||||
algo = BatchTransformAlgorithm([0, 1], 1, kwarg='str')
|
||||
self.assertEqual(algo.args, (1,))
|
||||
self.assertEqual(algo.kwargs, {'kwarg':'str'})
|
||||
|
||||
algo.run(self.source)
|
||||
expected_item = ((1, ), {'kwarg': 'str'})
|
||||
self.assertEqual(algo.history_return_args, [None, None, expected_item, expected_item, expected_item])
|
||||
|
||||
@@ -24,7 +24,7 @@ class TradingAlgorithm(object):
|
||||
```
|
||||
To then to run this algorithm:
|
||||
|
||||
>>> my_algo = MyAlgo(100, sids=[0])
|
||||
>>> my_algo = MyAlgo([0], 100) # first argument has to be list of sids
|
||||
>>> stats = my_algo.run(data)
|
||||
|
||||
"""
|
||||
@@ -32,7 +32,7 @@ class TradingAlgorithm(object):
|
||||
"""
|
||||
Initialize sids and other state variables.
|
||||
|
||||
Calls user-defined initialize and forwarding *args and **kwargs.
|
||||
Calls user-defined initialize() forwarding *args and **kwargs.
|
||||
"""
|
||||
self.sids = sids
|
||||
self.done = False
|
||||
@@ -45,6 +45,8 @@ class TradingAlgorithm(object):
|
||||
# call to user-defined initialize method
|
||||
self.initialize(*args, **kwargs)
|
||||
|
||||
self.initialized = True
|
||||
|
||||
def _create_simulator(self, start, end):
|
||||
"""
|
||||
Create trading environment, transforms and SimulatedTrading object.
|
||||
|
||||
@@ -219,9 +219,11 @@ class AlgorithmSimulator(object):
|
||||
# snapshot time to any log record generated.
|
||||
#with self.processor.threadbound(), self.stdout_capture(Logger('Print'),''):
|
||||
|
||||
# Call user's initialize method with a timeout.
|
||||
with Timeout(INIT_TIMEOUT, message="Call to initialize timed out"):
|
||||
self.algo.initialize()
|
||||
# Call user's initialize method with a timeout (only if
|
||||
# initialize wasn't called already).
|
||||
if not getattr(self.algo, 'initialized', False):
|
||||
with Timeout(INIT_TIMEOUT, message="Call to initialize timed out"):
|
||||
self.algo.initialize()
|
||||
|
||||
# Group together events with the same dt field. This depends on the
|
||||
# events already being sorted.
|
||||
|
||||
@@ -339,7 +339,7 @@ class BatchTransform(EventWindow):
|
||||
self.updated = False
|
||||
self.data = None
|
||||
|
||||
def handle_data(self, data):
|
||||
def handle_data(self, data, *args, **kwargs):
|
||||
"""
|
||||
New method to handle a data frame as sent to the algorithm's handle_data
|
||||
method.
|
||||
@@ -356,7 +356,7 @@ class BatchTransform(EventWindow):
|
||||
self.update(data)
|
||||
|
||||
# return newly computed or cached value
|
||||
return self.get_transform_value()
|
||||
return self.get_transform_value(*args, **kwargs)
|
||||
|
||||
def handle_add(self, event):
|
||||
if not self.last_refresh:
|
||||
|
||||
+2
-1
@@ -60,7 +60,6 @@ before invoking simulate.
|
||||
+---------------------------------+
|
||||
"""
|
||||
|
||||
from zipline.test_algorithms import TestAlgorithm
|
||||
from zipline.utils import factory
|
||||
|
||||
from zipline.gens.composites import (
|
||||
@@ -144,6 +143,8 @@ class SimulatedTrading(object):
|
||||
- transforms: optional parameter that provides a list
|
||||
of StatefulTransform objects.
|
||||
"""
|
||||
from zipline.test_algorithms import TestAlgorithm
|
||||
|
||||
assert isinstance(config, dict)
|
||||
sid_list = config.get('sid_list')
|
||||
if not sid_list:
|
||||
|
||||
@@ -8,13 +8,7 @@ import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import cProfile
|
||||
from zipline.gens.mavg import MovingAverage
|
||||
from zipline.gens.cov import CovTransform, cov
|
||||
from zipline.algorithm import TradingAlgorithm
|
||||
from zipline.gens.transform import BatchTransform, batch_transform
|
||||
|
||||
@batch_transform
|
||||
def cov(data):
|
||||
return data.price.cov()
|
||||
|
||||
class DMA(TradingAlgorithm):
|
||||
"""Dual Moving Average algorithm.
|
||||
@@ -35,14 +29,9 @@ class DMA(TradingAlgorithm):
|
||||
market_aware=True,
|
||||
days=long_window)
|
||||
|
||||
self.cov = cov(sids=self.sids, refresh_period=1, days=5)
|
||||
|
||||
def handle_data(self, data):
|
||||
self.events += 1
|
||||
|
||||
cov = self.cov.handle_data(data)
|
||||
print cov
|
||||
|
||||
for sid in self.sids:
|
||||
# access transforms via their user-defined tag
|
||||
if (data[sid].short_mavg['price'] > data[sid].long_mavg['price']) and not self.invested[sid]:
|
||||
@@ -56,17 +45,17 @@ class DMA(TradingAlgorithm):
|
||||
def load_close_px(indexes=None, stocks=None):
|
||||
from pandas.io.data import DataReader
|
||||
import pytz
|
||||
from collections import OrderedDict
|
||||
|
||||
if indexes is None:
|
||||
indexes = {'SPX' : '^GSPC'}
|
||||
if stocks is None:
|
||||
stocks = ['AAPL'] #, 'GE', 'IBM', 'MSFT', 'XOM', 'AA', 'JNJ', 'PEP']
|
||||
stocks = ['AAPL', 'GE', 'IBM', 'MSFT', 'XOM', 'AA', 'JNJ', 'PEP']
|
||||
|
||||
#start = pd.datetime(1990, 1, 1)
|
||||
start = pd.datetime(1990, 1, 1, 0, 0, 0, 0, pytz.utc)
|
||||
end = pd.datetime(1992, 1, 1, 0, 0, 0, 0, pytz.utc) #pd.datetime.today()
|
||||
end = pd.datetime(1992, 1, 1, 0, 0, 0, 0, pytz.utc)
|
||||
|
||||
data = {}
|
||||
data = OrderedDict()
|
||||
|
||||
for stock in stocks:
|
||||
print stock
|
||||
@@ -82,13 +71,13 @@ def load_close_px(indexes=None, stocks=None):
|
||||
df = pd.DataFrame({i: d['Close'] for i, d in enumerate(data.itervalues())})
|
||||
df.index = df.index.tz_localize(pytz.utc)
|
||||
|
||||
df.save('close_px.dat')
|
||||
return df
|
||||
|
||||
|
||||
def run((short_window, long_window)):
|
||||
#data = pd.DataFrame.from_csv('SP500.csv')
|
||||
#data = pd.DataFrame.from_csv('aapl.csv') #load_close_px()
|
||||
data = load_close_px()
|
||||
data = pd.load('close_px.dat')
|
||||
#data = load_close_px()
|
||||
myalgo = DMA([0, 1], amount=100, short_window=short_window, long_window=long_window)
|
||||
stats = myalgo.run(data)
|
||||
stats['sw'] = short_window
|
||||
|
||||
Executable
+2
@@ -0,0 +1,2 @@
|
||||
#!/bin/bash
|
||||
python -m cProfile -o example.prof example.py
|
||||
+37
-17
@@ -399,33 +399,53 @@ class TestRegisterTransformAlgorithm(TradingAlgorithm):
|
||||
def handle_data(self, data):
|
||||
pass
|
||||
|
||||
class NoopBatchTransform(BatchTransform):
|
||||
##########################################
|
||||
# Algorithm using simple batch transforms
|
||||
|
||||
class ReturnPriceBatchTransform(BatchTransform):
|
||||
def get_value(self, data):
|
||||
return data.price
|
||||
|
||||
@batch_transform
|
||||
def noop_batch_decorator(data):
|
||||
def return_price_batch_decorator(data):
|
||||
return data.price
|
||||
|
||||
@batch_transform
|
||||
def return_args_batch_decorator(data, *args, **kwargs):
|
||||
return args, kwargs
|
||||
|
||||
class BatchTransformAlgorithm(TradingAlgorithm):
|
||||
def initialize(self, *args, **kwargs):
|
||||
self.history_class = []
|
||||
self.history_decorator = []
|
||||
self.days = 3
|
||||
self.noop_class = NoopBatchTransform(sids=[0, 1],
|
||||
market_aware=False,
|
||||
refresh_period=2,
|
||||
delta=timedelta(days=self.days))
|
||||
self.history_return_price_class = []
|
||||
self.history_return_price_decorator = []
|
||||
self.history_return_args = []
|
||||
|
||||
self.noop_decorator = noop_batch_decorator(sids=[0, 1],
|
||||
market_aware=False,
|
||||
refresh_period=2,
|
||||
delta=timedelta(days=self.days))
|
||||
self.days = 3
|
||||
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
|
||||
self.return_price_class = ReturnPriceBatchTransform(sids=self.sids,
|
||||
market_aware=False,
|
||||
refresh_period=2,
|
||||
delta=timedelta(days=self.days)
|
||||
)
|
||||
|
||||
self.return_price_decorator = return_price_batch_decorator(sids=self.sids,
|
||||
market_aware=False,
|
||||
refresh_period=2,
|
||||
delta=timedelta(days=self.days)
|
||||
)
|
||||
|
||||
self.return_args_batch = return_args_batch_decorator(sids=self.sids,
|
||||
market_aware=False,
|
||||
refresh_period=2,
|
||||
delta=timedelta(days=self.days)
|
||||
)
|
||||
|
||||
def handle_data(self, data):
|
||||
window_class = self.noop_class.handle_data(data)
|
||||
window_decorator = self.noop_decorator.handle_data(data)
|
||||
self.history_class.append(window_class)
|
||||
self.history_decorator.append(window_decorator)
|
||||
self.history_return_price_class.append(self.return_price_class.handle_data(data))
|
||||
self.history_return_price_decorator.append(self.return_price_decorator.handle_data(data))
|
||||
self.history_return_args.append(self.return_args_batch.handle_data(data, *self.args, **self.kwargs))
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user