From 188a9502b0e5916b9c70fadceca98f54bbe911af Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Tue, 25 Sep 2012 11:48:37 -0400 Subject: [PATCH 1/8] Added passing of arguments to algorithm initialize and batch transform. --- tests/test_transforms.py | 14 +++++++-- zipline/algorithm.py | 6 ++-- zipline/gens/tradesimulation.py | 8 +++-- zipline/gens/transform.py | 4 +-- zipline/test_algorithms.py | 54 ++++++++++++++++++++++----------- 5 files changed, 59 insertions(+), 27 deletions(-) diff --git a/tests/test_transforms.py b/tests/test_transforms.py index ea1024f6..e3cc1961 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -294,15 +294,23 @@ class BatchTransformTestCase(TestCase): setup_logger(self) self.source, self.df = factory.create_test_df_source() - def test_batch_inherit(self): + def test_event_window(self): algo = BatchTransformAlgorithm(sids=[0, 1]) algo.run(self.source) - assert algo.history_class[:2] == algo.history_decorator[:2] == [None, None], "First two iterations should return None" + assert algo.history_return_price_class[:2] == algo.history_return_price_decorator[:2] == [None, None], "First two iterations should return None" # test overloaded class - for test_history in [algo.history_class, algo.history_decorator]: + for test_history in [algo.history_return_price_class, algo.history_return_price_decorator]: self.assertTrue(np.all(test_history[2].values.flatten() == range(4, 10))) self.assertTrue(np.all(test_history[3].values.flatten() == range(4, 10))) self.assertTrue(np.all(test_history[4].values.flatten() == range(6, 14))) + def test_passing_of_args(self): + algo = BatchTransformAlgorithm([0, 1], 1, kwarg='str') + algo.run(self.source) + self.assertEqual(algo.args, (1,)) + self.assertEqual(algo.kwargs, {'kwarg':'str'}) + expected_item = ((1, ), {'kwarg': 'str'}) + self.assertEqual(algo.history_return_args, [None, None, expected_item, expected_item, expected_item]) + diff --git a/zipline/algorithm.py b/zipline/algorithm.py index 08d388da..f54ba0cb 100644 --- a/zipline/algorithm.py +++ b/zipline/algorithm.py @@ -24,7 +24,7 @@ class TradingAlgorithm(object): ``` To then to run this algorithm: - >>> my_algo = MyAlgo(100, sids=[0]) + >>> my_algo = MyAlgo([0], 100) # first argument has to be list of sids >>> stats = my_algo.run(data) """ @@ -32,7 +32,7 @@ class TradingAlgorithm(object): """ Initialize sids and other state variables. - Calls user-defined initialize and forwarding *args and **kwargs. + Calls user-defined initialize() forwarding *args and **kwargs. """ self.sids = sids self.done = False @@ -45,6 +45,8 @@ class TradingAlgorithm(object): # call to user-defined initialize method self.initialize(*args, **kwargs) + self.initialized = True + def _create_simulator(self, start, end): """ Create trading environment, transforms and SimulatedTrading object. diff --git a/zipline/gens/tradesimulation.py b/zipline/gens/tradesimulation.py index 5efb1505..5b5cec52 100644 --- a/zipline/gens/tradesimulation.py +++ b/zipline/gens/tradesimulation.py @@ -219,9 +219,11 @@ class AlgorithmSimulator(object): # snapshot time to any log record generated. #with self.processor.threadbound(), self.stdout_capture(Logger('Print'),''): - # Call user's initialize method with a timeout. - with Timeout(INIT_TIMEOUT, message="Call to initialize timed out"): - self.algo.initialize() + # Call user's initialize method with a timeout (only if + # initialize wasn't called already). + if not getattr(self.algo, 'initialized', False): + with Timeout(INIT_TIMEOUT, message="Call to initialize timed out"): + self.algo.initialize() # Group together events with the same dt field. This depends on the # events already being sorted. diff --git a/zipline/gens/transform.py b/zipline/gens/transform.py index 8ec862eb..3ff259cc 100644 --- a/zipline/gens/transform.py +++ b/zipline/gens/transform.py @@ -339,7 +339,7 @@ class BatchTransform(EventWindow): self.updated = False self.data = None - def handle_data(self, data): + def handle_data(self, data, *args, **kwargs): """ New method to handle a data frame as sent to the algorithm's handle_data method. @@ -356,7 +356,7 @@ class BatchTransform(EventWindow): self.update(data) # return newly computed or cached value - return self.get_transform_value() + return self.get_transform_value(*args, **kwargs) def handle_add(self, event): if not self.last_refresh: diff --git a/zipline/test_algorithms.py b/zipline/test_algorithms.py index 10795e33..a8bc6ab4 100644 --- a/zipline/test_algorithms.py +++ b/zipline/test_algorithms.py @@ -399,33 +399,53 @@ class TestRegisterTransformAlgorithm(TradingAlgorithm): def handle_data(self, data): pass -class NoopBatchTransform(BatchTransform): +########################################## +# Algorithm using simple batch transforms + +class ReturnPriceBatchTransform(BatchTransform): def get_value(self, data): return data.price @batch_transform -def noop_batch_decorator(data): +def return_price_batch_decorator(data): return data.price +@batch_transform +def return_args_batch_decorator(data, *args, **kwargs): + return args, kwargs + class BatchTransformAlgorithm(TradingAlgorithm): def initialize(self, *args, **kwargs): - self.history_class = [] - self.history_decorator = [] - self.days = 3 - self.noop_class = NoopBatchTransform(sids=[0, 1], - market_aware=False, - refresh_period=2, - delta=timedelta(days=self.days)) + self.history_return_price_class = [] + self.history_return_price_decorator = [] + self.history_return_args = [] - self.noop_decorator = noop_batch_decorator(sids=[0, 1], - market_aware=False, - refresh_period=2, - delta=timedelta(days=self.days)) + self.days = 3 + + self.args = args + self.kwargs = kwargs + + self.return_price_class = ReturnPriceBatchTransform(sids=self.sids, + market_aware=False, + refresh_period=2, + delta=timedelta(days=self.days) + ) + + self.return_price_decorator = return_price_batch_decorator(sids=self.sids, + market_aware=False, + refresh_period=2, + delta=timedelta(days=self.days) + ) + + self.return_args_batch = return_args_batch_decorator(sids=self.sids, + market_aware=False, + refresh_period=2, + delta=timedelta(days=self.days) + ) def handle_data(self, data): - window_class = self.noop_class.handle_data(data) - window_decorator = self.noop_decorator.handle_data(data) - self.history_class.append(window_class) - self.history_decorator.append(window_decorator) + self.history_return_price_class.append(self.return_price_class.handle_data(data)) + self.history_return_price_decorator.append(self.return_price_decorator.handle_data(data)) + self.history_return_args.append(self.return_args_batch.handle_data(data, *self.args, **self.kwargs)) From 2e9947ab8ccd6449d117dd939a8774900d1e9a41 Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Tue, 25 Sep 2012 12:21:26 -0400 Subject: [PATCH 2/8] Accidentally commented out logging capture. --- zipline/gens/tradesimulation.py | 100 ++++++++++++++++---------------- 1 file changed, 50 insertions(+), 50 deletions(-) diff --git a/zipline/gens/tradesimulation.py b/zipline/gens/tradesimulation.py index 5b5cec52..cb46965c 100644 --- a/zipline/gens/tradesimulation.py +++ b/zipline/gens/tradesimulation.py @@ -217,63 +217,63 @@ class AlgorithmSimulator(object): # Capture any output of this generator to stdout and pipe it # to a logbook interface. Also inject the current algo # snapshot time to any log record generated. - #with self.processor.threadbound(), self.stdout_capture(Logger('Print'),''): + with self.processor.threadbound(), self.stdout_capture(Logger('Print'),''): - # Call user's initialize method with a timeout (only if - # initialize wasn't called already). - if not getattr(self.algo, 'initialized', False): - with Timeout(INIT_TIMEOUT, message="Call to initialize timed out"): - self.algo.initialize() + # Call user's initialize method with a timeout (only if + # initialize wasn't called already). + if not getattr(self.algo, 'initialized', False): + with Timeout(INIT_TIMEOUT, message="Call to initialize timed out"): + self.algo.initialize() - # Group together events with the same dt field. This depends on the - # events already being sorted. - for date, snapshot in groupby(stream_in, attrgetter('dt')): - # Set the simulation date to be the first event we see. - # This should only occur once, at the start of the test. - if self.simulation_dt == None: - self.simulation_dt = date + # Group together events with the same dt field. This depends on the + # events already being sorted. + for date, snapshot in groupby(stream_in, attrgetter('dt')): + # Set the simulation date to be the first event we see. + # This should only occur once, at the start of the test. + if self.simulation_dt == None: + self.simulation_dt = date - # Done message has the risk report, so we yield before exiting. - if date == 'DONE': - for event in snapshot: - yield event.perf_message - raise StopIteration() - - # We're still in the warmup period. Use the event to - # update our universe, but don't yield any perf messages, - # and don't send a snapshot to handle_data. - elif date < self.algo_start: - for event in snapshot: - del event['perf_message'] - self.update_universe(event) - - # The algo has taken so long to process events that - # its simulated time is later than the event time. - # Update the universe and yield any perf messages - # encountered, but don't call handle_data. - elif date < self.simulation_dt: - for event in snapshot: - # Only yield if we have something interesting to say. - if event.perf_message != None: + # Done message has the risk report, so we yield before exiting. + if date == 'DONE': + for event in snapshot: yield event.perf_message - # Delete the message before updating so we don't send it - # to the user. - del event['perf_message'] - self.update_universe(event) + raise StopIteration() - # Regular snapshot. Update the universe and send a snapshot - # to handle data. - else: - for event in snapshot: - # Only yield if we have something interesting to say. - if event.perf_message != None: - yield event.perf_message - del event['perf_message'] + # We're still in the warmup period. Use the event to + # update our universe, but don't yield any perf messages, + # and don't send a snapshot to handle_data. + elif date < self.algo_start: + for event in snapshot: + del event['perf_message'] + self.update_universe(event) - self.update_universe(event) + # The algo has taken so long to process events that + # its simulated time is later than the event time. + # Update the universe and yield any perf messages + # encountered, but don't call handle_data. + elif date < self.simulation_dt: + for event in snapshot: + # Only yield if we have something interesting to say. + if event.perf_message != None: + yield event.perf_message + # Delete the message before updating so we don't send it + # to the user. + del event['perf_message'] + self.update_universe(event) - # Send the current state of the universe to the user's algo. - self.simulate_snapshot(date) + # Regular snapshot. Update the universe and send a snapshot + # to handle data. + else: + for event in snapshot: + # Only yield if we have something interesting to say. + if event.perf_message != None: + yield event.perf_message + del event['perf_message'] + + self.update_universe(event) + + # Send the current state of the universe to the user's algo. + self.simulate_snapshot(date) def update_universe(self, event): """ From 79ab1d13c01781afc10d5edf6f6e0ccdef41e3bd Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Tue, 25 Sep 2012 13:35:47 -0400 Subject: [PATCH 3/8] Fixes circular import problem. --- zipline/lines.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/zipline/lines.py b/zipline/lines.py index a3dfa9ba..a8fedc56 100644 --- a/zipline/lines.py +++ b/zipline/lines.py @@ -60,7 +60,6 @@ before invoking simulate. +---------------------------------+ """ -from zipline.test_algorithms import TestAlgorithm from zipline.utils import factory from zipline.gens.composites import ( @@ -144,6 +143,8 @@ class SimulatedTrading(object): - transforms: optional parameter that provides a list of StatefulTransform objects. """ + from zipline.test_algorithms import TestAlgorithm + assert isinstance(config, dict) sid_list = config.get('sid_list') if not sid_list: From 53de4effd3074cfb0c07e9ab285e2f124c5f7749 Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Tue, 25 Sep 2012 14:53:13 -0400 Subject: [PATCH 4/8] stdout_capture was causing problems for ipython notebook. --- zipline/gens/tradesimulation.py | 100 ++++++++++++++++---------------- zipline/optimize/example.py | 8 +-- 2 files changed, 54 insertions(+), 54 deletions(-) diff --git a/zipline/gens/tradesimulation.py b/zipline/gens/tradesimulation.py index cb46965c..5b5cec52 100644 --- a/zipline/gens/tradesimulation.py +++ b/zipline/gens/tradesimulation.py @@ -217,63 +217,63 @@ class AlgorithmSimulator(object): # Capture any output of this generator to stdout and pipe it # to a logbook interface. Also inject the current algo # snapshot time to any log record generated. - with self.processor.threadbound(), self.stdout_capture(Logger('Print'),''): + #with self.processor.threadbound(), self.stdout_capture(Logger('Print'),''): - # Call user's initialize method with a timeout (only if - # initialize wasn't called already). - if not getattr(self.algo, 'initialized', False): - with Timeout(INIT_TIMEOUT, message="Call to initialize timed out"): - self.algo.initialize() + # Call user's initialize method with a timeout (only if + # initialize wasn't called already). + if not getattr(self.algo, 'initialized', False): + with Timeout(INIT_TIMEOUT, message="Call to initialize timed out"): + self.algo.initialize() - # Group together events with the same dt field. This depends on the - # events already being sorted. - for date, snapshot in groupby(stream_in, attrgetter('dt')): - # Set the simulation date to be the first event we see. - # This should only occur once, at the start of the test. - if self.simulation_dt == None: - self.simulation_dt = date + # Group together events with the same dt field. This depends on the + # events already being sorted. + for date, snapshot in groupby(stream_in, attrgetter('dt')): + # Set the simulation date to be the first event we see. + # This should only occur once, at the start of the test. + if self.simulation_dt == None: + self.simulation_dt = date - # Done message has the risk report, so we yield before exiting. - if date == 'DONE': - for event in snapshot: + # Done message has the risk report, so we yield before exiting. + if date == 'DONE': + for event in snapshot: + yield event.perf_message + raise StopIteration() + + # We're still in the warmup period. Use the event to + # update our universe, but don't yield any perf messages, + # and don't send a snapshot to handle_data. + elif date < self.algo_start: + for event in snapshot: + del event['perf_message'] + self.update_universe(event) + + # The algo has taken so long to process events that + # its simulated time is later than the event time. + # Update the universe and yield any perf messages + # encountered, but don't call handle_data. + elif date < self.simulation_dt: + for event in snapshot: + # Only yield if we have something interesting to say. + if event.perf_message != None: yield event.perf_message - raise StopIteration() + # Delete the message before updating so we don't send it + # to the user. + del event['perf_message'] + self.update_universe(event) - # We're still in the warmup period. Use the event to - # update our universe, but don't yield any perf messages, - # and don't send a snapshot to handle_data. - elif date < self.algo_start: - for event in snapshot: - del event['perf_message'] - self.update_universe(event) + # Regular snapshot. Update the universe and send a snapshot + # to handle data. + else: + for event in snapshot: + # Only yield if we have something interesting to say. + if event.perf_message != None: + yield event.perf_message + del event['perf_message'] - # The algo has taken so long to process events that - # its simulated time is later than the event time. - # Update the universe and yield any perf messages - # encountered, but don't call handle_data. - elif date < self.simulation_dt: - for event in snapshot: - # Only yield if we have something interesting to say. - if event.perf_message != None: - yield event.perf_message - # Delete the message before updating so we don't send it - # to the user. - del event['perf_message'] - self.update_universe(event) + self.update_universe(event) - # Regular snapshot. Update the universe and send a snapshot - # to handle data. - else: - for event in snapshot: - # Only yield if we have something interesting to say. - if event.perf_message != None: - yield event.perf_message - del event['perf_message'] - - self.update_universe(event) - - # Send the current state of the universe to the user's algo. - self.simulate_snapshot(date) + # Send the current state of the universe to the user's algo. + self.simulate_snapshot(date) def update_universe(self, event): """ diff --git a/zipline/optimize/example.py b/zipline/optimize/example.py index c8912558..be6efb9f 100644 --- a/zipline/optimize/example.py +++ b/zipline/optimize/example.py @@ -56,17 +56,17 @@ class DMA(TradingAlgorithm): def load_close_px(indexes=None, stocks=None): from pandas.io.data import DataReader import pytz + from collections import OrderedDict if indexes is None: indexes = {'SPX' : '^GSPC'} if stocks is None: - stocks = ['AAPL'] #, 'GE', 'IBM', 'MSFT', 'XOM', 'AA', 'JNJ', 'PEP'] + stocks = ['AAPL', 'GE', 'IBM', 'MSFT', 'XOM', 'AA', 'JNJ', 'PEP'] - #start = pd.datetime(1990, 1, 1) start = pd.datetime(1990, 1, 1, 0, 0, 0, 0, pytz.utc) - end = pd.datetime(1992, 1, 1, 0, 0, 0, 0, pytz.utc) #pd.datetime.today() + end = pd.datetime(1992, 1, 1, 0, 0, 0, 0, pytz.utc) - data = {} + data = OrderedDict() for stock in stocks: print stock From e09695819e56f0c529913345c3ce6ff07094e454 Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Tue, 25 Sep 2012 16:10:50 -0400 Subject: [PATCH 5/8] Cleaned up example.py code. --- zipline/optimize/example.py | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/zipline/optimize/example.py b/zipline/optimize/example.py index be6efb9f..a8af7569 100644 --- a/zipline/optimize/example.py +++ b/zipline/optimize/example.py @@ -8,13 +8,7 @@ import numpy as np import matplotlib.pyplot as plt import cProfile from zipline.gens.mavg import MovingAverage -from zipline.gens.cov import CovTransform, cov from zipline.algorithm import TradingAlgorithm -from zipline.gens.transform import BatchTransform, batch_transform - -@batch_transform -def cov(data): - return data.price.cov() class DMA(TradingAlgorithm): """Dual Moving Average algorithm. @@ -35,14 +29,9 @@ class DMA(TradingAlgorithm): market_aware=True, days=long_window) - self.cov = cov(sids=self.sids, refresh_period=1, days=5) - def handle_data(self, data): self.events += 1 - cov = self.cov.handle_data(data) - print cov - for sid in self.sids: # access transforms via their user-defined tag if (data[sid].short_mavg['price'] > data[sid].long_mavg['price']) and not self.invested[sid]: @@ -87,8 +76,8 @@ def load_close_px(indexes=None, stocks=None): def run((short_window, long_window)): #data = pd.DataFrame.from_csv('SP500.csv') - #data = pd.DataFrame.from_csv('aapl.csv') #load_close_px() - data = load_close_px() + data = pd.load('close_px.dat') + #data = load_close_px() myalgo = DMA([0, 1], amount=100, short_window=short_window, long_window=long_window) stats = myalgo.run(data) stats['sw'] = short_window From e353900010e40dda6cb9af5da7df9920c774f1d0 Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Tue, 25 Sep 2012 16:53:57 -0400 Subject: [PATCH 6/8] Added 2 year data file. --- zipline/optimize/example.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zipline/optimize/example.py b/zipline/optimize/example.py index a8af7569..ef84fc0d 100644 --- a/zipline/optimize/example.py +++ b/zipline/optimize/example.py @@ -71,11 +71,11 @@ def load_close_px(indexes=None, stocks=None): df = pd.DataFrame({i: d['Close'] for i, d in enumerate(data.itervalues())}) df.index = df.index.tz_localize(pytz.utc) + df.save('close_px.dat') return df def run((short_window, long_window)): - #data = pd.DataFrame.from_csv('SP500.csv') data = pd.load('close_px.dat') #data = load_close_px() myalgo = DMA([0, 1], amount=100, short_window=short_window, long_window=long_window) From 39c3cf88ceed3de2045c6b5d4c3f3b2546dc99e8 Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Tue, 25 Sep 2012 16:54:42 -0400 Subject: [PATCH 7/8] Added profiling shell script. --- zipline/optimize/profile.sh | 2 ++ 1 file changed, 2 insertions(+) create mode 100755 zipline/optimize/profile.sh diff --git a/zipline/optimize/profile.sh b/zipline/optimize/profile.sh new file mode 100755 index 00000000..0e0b602b --- /dev/null +++ b/zipline/optimize/profile.sh @@ -0,0 +1,2 @@ +#!/bin/bash +python -m cProfile -o example.prof example.py From c2dd51c24e4d8300598ae502becd91cf2e844664 Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Fri, 28 Sep 2012 09:20:17 -0400 Subject: [PATCH 8/8] Cleaned up unittests in response to Eddie's comments. --- tests/test_transforms.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/test_transforms.py b/tests/test_transforms.py index e3cc1961..4f6cae0f 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -298,7 +298,8 @@ class BatchTransformTestCase(TestCase): algo = BatchTransformAlgorithm(sids=[0, 1]) algo.run(self.source) - assert algo.history_return_price_class[:2] == algo.history_return_price_decorator[:2] == [None, None], "First two iterations should return None" + self.assertEqual(algo.history_return_price_class[:2], [None, None], "First two iterations should return None") + self.assertEqual(algo.history_return_price_decorator[:2], [None, None], "First two iterations should return None") # test overloaded class for test_history in [algo.history_return_price_class, algo.history_return_price_decorator]: @@ -308,9 +309,9 @@ class BatchTransformTestCase(TestCase): def test_passing_of_args(self): algo = BatchTransformAlgorithm([0, 1], 1, kwarg='str') - algo.run(self.source) self.assertEqual(algo.args, (1,)) self.assertEqual(algo.kwargs, {'kwarg':'str'}) + + algo.run(self.source) expected_item = ((1, ), {'kwarg': 'str'}) self.assertEqual(algo.history_return_args, [None, None, expected_item, expected_item, expected_item]) -