Merge addition of unit tests that exercise examples.

Also, multiple fixes to examples that the unit tests uncovered.
This commit is contained in:
Eddie Hebert
2013-05-02 17:05:06 -04:00
7 changed files with 104 additions and 25 deletions
+4
View File
@@ -12,3 +12,7 @@ pyflakes==0.6.1
# Documentation Conversion
pyandoc==0.0.1
# Example scripts that are run during unit tests use the following:
matplotlib==1.2.1
+54
View File
@@ -0,0 +1,54 @@
#
# Copyright 2013 Quantopian, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This code is based on a unittest written by John Salvatier:
# https://github.com/pymc-devs/pymc/blob/pymc3/tests/test_examples.py
# Disable plotting
#
import matplotlib
matplotlib.use('Agg')
from os import path
import os
import fnmatch
import imp
def test_examples():
os.chdir(example_dir())
for fname in all_matching_files('.', '*.py'):
yield check_example, fname
def all_matching_files(d, pattern):
def addfiles(fls, dir, nfiles):
nfiles = fnmatch.filter(nfiles, pattern)
nfiles = [path.join(dir, f) for f in nfiles]
fls.extend(nfiles)
files = []
path.walk(d, addfiles, files)
return files
def example_dir():
import zipline
d = path.dirname(zipline.__file__)
return path.join(path.abspath(d), 'examples/')
def check_example(p):
imp.load_source('__main__', path.basename(p))
+11 -2
View File
@@ -15,6 +15,8 @@
# limitations under the License.
import matplotlib.pyplot as plt
from datetime import datetime
import pytz
from zipline.algorithm import TradingAlgorithm
from zipline.utils.factory import load_from_yahoo
@@ -29,8 +31,15 @@ class BuyApple(TradingAlgorithm): # inherit from TradingAlgorithm
if __name__ == '__main__':
data = load_from_yahoo(stocks=['AAPL'], indexes={})
start = datetime(2008, 1, 1, 0, 0, 0, 0, pytz.utc)
end = datetime(2010, 1, 1, 0, 0, 0, 0, pytz.utc)
data = load_from_yahoo(stocks=['AAPL'], indexes={}, start=start,
end=end)
simple_algo = BuyApple()
results = simple_algo.run(data)
results.portfolio_value.plot()
ax1 = plt.subplot(211)
results.portfolio_value.plot(ax=ax1)
ax2 = plt.subplot(212, sharex=ax1)
data.AAPL.plot(ax=ax2)
plt.show()
+12 -5
View File
@@ -20,6 +20,9 @@ from zipline.algorithm import TradingAlgorithm
from zipline.transforms import MovingAverage
from zipline.utils.factory import load_from_yahoo
from datetime import datetime
import pytz
class DualMovingAverage(TradingAlgorithm):
"""Dual Moving Average Crossover algorithm.
@@ -30,7 +33,7 @@ class DualMovingAverage(TradingAlgorithm):
momentum).
"""
def initialize(self, short_window=200, long_window=400):
def initialize(self, short_window=20, long_window=40):
# Add 2 mavg transforms, one with a long window, one
# with a short window.
self.add_transform(MovingAverage, 'short_mavg', ['price'],
@@ -63,16 +66,20 @@ class DualMovingAverage(TradingAlgorithm):
sell=self.sell)
if __name__ == '__main__':
data = load_from_yahoo(stocks=['AAPL'], indexes={})
start = datetime(1990, 1, 1, 0, 0, 0, 0, pytz.utc)
end = datetime(1991, 1, 1, 0, 0, 0, 0, pytz.utc)
data = load_from_yahoo(stocks=['AAPL'], indexes={}, start=start,
end=end)
dma = DualMovingAverage()
results = dma.run(data)
print results.short_mavg
fig = plt.figure()
ax1 = fig.add_subplot(211)
ax1 = fig.add_subplot(211, ylabel='portfolio value')
results.portfolio_value.plot(ax=ax1)
ax2 = fig.add_subplot(212)
data['AAPL'].plot(ax=ax2)
data['AAPL'].plot(ax=ax2, color='r')
results[['short_mavg', 'long_mavg']].plot(ax=ax2)
ax2.plot(results.ix[results.buy].index, results.short_mavg[results.buy],
+9 -9
View File
@@ -1,12 +1,13 @@
import sys
import logbook
import datetime
import numpy as np
from datetime import datetime
import pytz
from zipline.algorithm import TradingAlgorithm
from zipline.transforms import MovingAverage
from zipline.utils.factory import load_bars_from_yahoo
from zipline.finance import slippage, commission
from zipline.utils.factory import load_from_yahoo
from zipline.finance import commission
zipline_logging = logbook.NestedSetup([
logbook.NullHandler(level=logbook.DEBUG, bubble=True),
@@ -38,11 +39,6 @@ class OLMAR(TradingAlgorithm):
self.add_transform(MovingAverage, 'mavg', ['price'],
window_length=window_length)
no_delay = datetime.timedelta(minutes=0)
slip = slippage.VolumeShareSlippage(volume_limit=0.25,
price_impact=0,
delay=no_delay)
self.set_slippage(slip)
self.set_commission(commission.PerShare(cost=0))
def handle_data(self, data):
@@ -157,7 +153,11 @@ def simplex_projection(v, b=1):
if __name__ == '__main__':
import pylab as pl
data = load_bars_from_yahoo(stocks=STOCKS, indexes={})
start = datetime(2004, 1, 1, 0, 0, 0, 0, pytz.utc)
end = datetime(2008, 1, 1, 0, 0, 0, 0, pytz.utc)
data = load_from_yahoo(stocks=STOCKS, indexes={}, start=start,
end=end)
data = data.dropna()
olmar = OLMAR()
results = olmar.run(data)
results.portfolio_value.plot()
+11 -8
View File
@@ -17,6 +17,8 @@
import matplotlib.pyplot as plt
import numpy as np
import statsmodels.api as sm
from datetime import datetime
import pytz
from zipline.algorithm import TradingAlgorithm
from zipline.transforms import batch_transform
@@ -29,7 +31,7 @@ def ols_transform(data, sid1, sid2):
via Ordinary Least Squares between two SIDs.
"""
p0 = data.price[sid1]
p1 = sm.add_constant(data.price[sid2])
p1 = sm.add_constant(data.price[sid2], prepend=True)
slope, intercept = sm.OLS(p0, p1).fit().params
return slope, intercept
@@ -53,7 +55,6 @@ class Pairtrade(TradingAlgorithm):
def initialize(self, window_length=100):
self.spreads = []
self.zscores = []
self.invested = 0
self.window_length = window_length
self.ols_transform = ols_transform(refresh_period=self.window_length,
@@ -65,12 +66,12 @@ class Pairtrade(TradingAlgorithm):
params = self.ols_transform.handle_data(data, 'PEP', 'KO')
if params is None:
return
slope, intercept = params
intercept, slope = params
######################################################
# 2. Compute spread and zscore
zscore = self.compute_zscore(data, slope, intercept)
self.zscores.append(zscore)
self.record(zscores=zscore)
######################################################
# 3. Place orders
@@ -112,12 +113,14 @@ class Pairtrade(TradingAlgorithm):
self.order('PEP', -1 * pep_amount)
if __name__ == '__main__':
data = load_from_yahoo(stocks=['PEP', 'KO'], indexes={})
start = datetime(2000, 1, 1, 0, 0, 0, 0, pytz.utc)
end = datetime(2002, 1, 1, 0, 0, 0, 0, pytz.utc)
data = load_from_yahoo(stocks=['PEP', 'KO'], indexes={},
start=start, end=end)
pairtrade = Pairtrade()
results = pairtrade.run(data)
data['spreads'] = np.nan
data.spreads[pairtrade.window_length:] = pairtrade.spreads
ax1 = plt.subplot(211)
data[['PEP', 'KO']].plot(ax=ax1)
@@ -125,7 +128,7 @@ if __name__ == '__main__':
plt.setp(ax1.get_xticklabels(), visible=False)
ax2 = plt.subplot(212, sharex=ax1)
data.spreads.plot(ax=ax2, color='r')
plt.ylabel('spread')
results.zscores.plot(ax=ax2, color='r')
plt.ylabel('zscored spread')
plt.show()
+3 -1
View File
@@ -622,7 +622,9 @@ class PerformancePeriod(object):
def update_last_sale(self, event):
is_trade = event.type == zp.DATASOURCE_TYPE.TRADE
if event.sid in self.positions and is_trade:
has_price = not np.isnan(event.price)
# isnan check will keep the last price if its not present
if (event.sid in self.positions) and is_trade and has_price:
self.positions[event.sid].last_sale_price = event.price
index = self.index_for_position(event.sid)
self._position_last_sale_prices[index] = event.price