Merge pull request #122 from quantopian/batch_transform

Refactoring of algorithm and batch transforms. Better unit tests.
This commit is contained in:
Thomas Wiecki
2012-09-28 14:35:39 -07:00
8 changed files with 71 additions and 46 deletions
+12 -3
View File
@@ -294,15 +294,24 @@ class BatchTransformTestCase(TestCase):
setup_logger(self)
self.source, self.df = factory.create_test_df_source()
def test_batch_inherit(self):
def test_event_window(self):
algo = BatchTransformAlgorithm(sids=[0, 1])
algo.run(self.source)
assert algo.history_class[:2] == algo.history_decorator[:2] == [None, None], "First two iterations should return None"
self.assertEqual(algo.history_return_price_class[:2], [None, None], "First two iterations should return None")
self.assertEqual(algo.history_return_price_decorator[:2], [None, None], "First two iterations should return None")
# test overloaded class
for test_history in [algo.history_class, algo.history_decorator]:
for test_history in [algo.history_return_price_class, algo.history_return_price_decorator]:
self.assertTrue(np.all(test_history[2].values.flatten() == range(4, 10)))
self.assertTrue(np.all(test_history[3].values.flatten() == range(4, 10)))
self.assertTrue(np.all(test_history[4].values.flatten() == range(6, 14)))
def test_passing_of_args(self):
    """Check that extra constructor arguments are stored by the algorithm
    and show up in the values recorded from the args-returning transform.
    """
    algo = BatchTransformAlgorithm([0, 1], 1, kwarg='str')

    # The arguments after the sid list must be kept on the algorithm.
    self.assertEqual(algo.args, (1,))
    self.assertEqual(algo.kwargs, {'kwarg': 'str'})

    algo.run(self.source)

    # Two leading None entries, then the same (args, kwargs) pair for
    # each of the remaining three entries.
    expected_item = ((1,), {'kwarg': 'str'})
    expected_history = [None, None] + [expected_item] * 3
    self.assertEqual(algo.history_return_args, expected_history)
+4 -2
View File
@@ -24,7 +24,7 @@ class TradingAlgorithm(object):
```
To then run this algorithm:
>>> my_algo = MyAlgo(100, sids=[0])
>>> my_algo = MyAlgo([0], 100) # first argument has to be list of sids
>>> stats = my_algo.run(data)
"""
@@ -32,7 +32,7 @@ class TradingAlgorithm(object):
"""
Initialize sids and other state variables.
Calls user-defined initialize and forwarding *args and **kwargs.
Calls user-defined initialize() forwarding *args and **kwargs.
"""
self.sids = sids
self.done = False
@@ -45,6 +45,8 @@ class TradingAlgorithm(object):
# call to user-defined initialize method
self.initialize(*args, **kwargs)
self.initialized = True
def _create_simulator(self, start, end):
"""
Create trading environment, transforms and SimulatedTrading object.
+5 -3
View File
@@ -219,9 +219,11 @@ class AlgorithmSimulator(object):
# snapshot time to any log record generated.
#with self.processor.threadbound(), self.stdout_capture(Logger('Print'),''):
# Call user's initialize method with a timeout.
with Timeout(INIT_TIMEOUT, message="Call to initialize timed out"):
self.algo.initialize()
# Call user's initialize method with a timeout (only if
# initialize wasn't called already).
if not getattr(self.algo, 'initialized', False):
with Timeout(INIT_TIMEOUT, message="Call to initialize timed out"):
self.algo.initialize()
# Group together events with the same dt field. This depends on the
# events already being sorted.
+2 -2
View File
@@ -339,7 +339,7 @@ class BatchTransform(EventWindow):
self.updated = False
self.data = None
def handle_data(self, data):
def handle_data(self, data, *args, **kwargs):
"""
New method to handle a data frame as sent to the algorithm's handle_data
method.
@@ -356,7 +356,7 @@ class BatchTransform(EventWindow):
self.update(data)
# return newly computed or cached value
return self.get_transform_value()
return self.get_transform_value(*args, **kwargs)
def handle_add(self, event):
if not self.last_refresh:
+2 -1
View File
@@ -60,7 +60,6 @@ before invoking simulate.
+---------------------------------+
"""
from zipline.test_algorithms import TestAlgorithm
from zipline.utils import factory
from zipline.gens.composites import (
@@ -144,6 +143,8 @@ class SimulatedTrading(object):
- transforms: optional parameter that provides a list
of StatefulTransform objects.
"""
from zipline.test_algorithms import TestAlgorithm
assert isinstance(config, dict)
sid_list = config.get('sid_list')
if not sid_list:
+7 -18
View File
@@ -8,13 +8,7 @@ import numpy as np
import matplotlib.pyplot as plt
import cProfile
from zipline.gens.mavg import MovingAverage
from zipline.gens.cov import CovTransform, cov
from zipline.algorithm import TradingAlgorithm
from zipline.gens.transform import BatchTransform, batch_transform
@batch_transform
def cov(data):
return data.price.cov()
class DMA(TradingAlgorithm):
"""Dual Moving Average algorithm.
@@ -35,14 +29,9 @@ class DMA(TradingAlgorithm):
market_aware=True,
days=long_window)
self.cov = cov(sids=self.sids, refresh_period=1, days=5)
def handle_data(self, data):
self.events += 1
cov = self.cov.handle_data(data)
print cov
for sid in self.sids:
# access transforms via their user-defined tag
if (data[sid].short_mavg['price'] > data[sid].long_mavg['price']) and not self.invested[sid]:
@@ -56,17 +45,17 @@ class DMA(TradingAlgorithm):
def load_close_px(indexes=None, stocks=None):
from pandas.io.data import DataReader
import pytz
from collections import OrderedDict
if indexes is None:
indexes = {'SPX' : '^GSPC'}
if stocks is None:
stocks = ['AAPL'] #, 'GE', 'IBM', 'MSFT', 'XOM', 'AA', 'JNJ', 'PEP']
stocks = ['AAPL', 'GE', 'IBM', 'MSFT', 'XOM', 'AA', 'JNJ', 'PEP']
#start = pd.datetime(1990, 1, 1)
start = pd.datetime(1990, 1, 1, 0, 0, 0, 0, pytz.utc)
end = pd.datetime(1992, 1, 1, 0, 0, 0, 0, pytz.utc) #pd.datetime.today()
end = pd.datetime(1992, 1, 1, 0, 0, 0, 0, pytz.utc)
data = {}
data = OrderedDict()
for stock in stocks:
print stock
@@ -82,13 +71,13 @@ def load_close_px(indexes=None, stocks=None):
df = pd.DataFrame({i: d['Close'] for i, d in enumerate(data.itervalues())})
df.index = df.index.tz_localize(pytz.utc)
df.save('close_px.dat')
return df
def run((short_window, long_window)):
#data = pd.DataFrame.from_csv('SP500.csv')
#data = pd.DataFrame.from_csv('aapl.csv') #load_close_px()
data = load_close_px()
data = pd.load('close_px.dat')
#data = load_close_px()
myalgo = DMA([0, 1], amount=100, short_window=short_window, long_window=long_window)
stats = myalgo.run(data)
stats['sw'] = short_window
+2
View File
@@ -0,0 +1,2 @@
#!/bin/bash
# Run example.py under cProfile and write the profiling statistics to
# example.prof.
python -m cProfile -o example.prof example.py
+37 -17
View File
@@ -399,33 +399,53 @@ class TestRegisterTransformAlgorithm(TradingAlgorithm):
def handle_data(self, data):
pass
class NoopBatchTransform(BatchTransform):
##########################################
# Algorithm using simple batch transforms
class ReturnPriceBatchTransform(BatchTransform):
    """Batch transform whose value is the ``price`` attribute of the data
    it is given.
    """

    def get_value(self, data):
        """Return ``data.price`` unchanged."""
        prices = data.price
        return prices
@batch_transform
def noop_batch_decorator(data):
def return_price_batch_decorator(data):
return data.price
@batch_transform
def return_args_batch_decorator(data, *args, **kwargs):
    """Return the extra positional and keyword arguments as a pair.

    ``data`` is ignored; the result is the tuple ``(args, kwargs)``.
    """
    passed_through = (args, kwargs)
    return passed_through
class BatchTransformAlgorithm(TradingAlgorithm):
def initialize(self, *args, **kwargs):
self.history_class = []
self.history_decorator = []
self.days = 3
self.noop_class = NoopBatchTransform(sids=[0, 1],
market_aware=False,
refresh_period=2,
delta=timedelta(days=self.days))
self.history_return_price_class = []
self.history_return_price_decorator = []
self.history_return_args = []
self.noop_decorator = noop_batch_decorator(sids=[0, 1],
market_aware=False,
refresh_period=2,
delta=timedelta(days=self.days))
self.days = 3
self.args = args
self.kwargs = kwargs
self.return_price_class = ReturnPriceBatchTransform(sids=self.sids,
market_aware=False,
refresh_period=2,
delta=timedelta(days=self.days)
)
self.return_price_decorator = return_price_batch_decorator(sids=self.sids,
market_aware=False,
refresh_period=2,
delta=timedelta(days=self.days)
)
self.return_args_batch = return_args_batch_decorator(sids=self.sids,
market_aware=False,
refresh_period=2,
delta=timedelta(days=self.days)
)
def handle_data(self, data):
window_class = self.noop_class.handle_data(data)
window_decorator = self.noop_decorator.handle_data(data)
self.history_class.append(window_class)
self.history_decorator.append(window_decorator)
self.history_return_price_class.append(self.return_price_class.handle_data(data))
self.history_return_price_decorator.append(self.return_price_decorator.handle_data(data))
self.history_return_args.append(self.return_args_batch.handle_data(data, *self.args, **self.kwargs))