mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-30 05:39:30 +08:00
Merge addition of unit tests that exercise examples.
Also, multiple fixes to examples that the unit tests uncovered.
This commit is contained in:
@@ -12,3 +12,7 @@ pyflakes==0.6.1
|
||||
# Documentation Conversion
|
||||
|
||||
pyandoc==0.0.1
|
||||
|
||||
# Example scripts that are run during unit tests use the following:
|
||||
|
||||
matplotlib==1.2.1
|
||||
|
||||
@@ -0,0 +1,54 @@
|
||||
#
|
||||
# Copyright 2013 Quantopian, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# This code is based on a unittest written by John Salvatier:
|
||||
# https://github.com/pymc-devs/pymc/blob/pymc3/tests/test_examples.py
|
||||
|
||||
# Disable plotting
|
||||
#
|
||||
import matplotlib
|
||||
matplotlib.use('Agg')
|
||||
|
||||
from os import path
|
||||
import os
|
||||
import fnmatch
|
||||
import imp
|
||||
|
||||
|
||||
def test_examples():
|
||||
os.chdir(example_dir())
|
||||
for fname in all_matching_files('.', '*.py'):
|
||||
yield check_example, fname
|
||||
|
||||
|
||||
def all_matching_files(d, pattern):
|
||||
def addfiles(fls, dir, nfiles):
|
||||
nfiles = fnmatch.filter(nfiles, pattern)
|
||||
nfiles = [path.join(dir, f) for f in nfiles]
|
||||
fls.extend(nfiles)
|
||||
|
||||
files = []
|
||||
path.walk(d, addfiles, files)
|
||||
return files
|
||||
|
||||
|
||||
def example_dir():
|
||||
import zipline
|
||||
d = path.dirname(zipline.__file__)
|
||||
return path.join(path.abspath(d), 'examples/')
|
||||
|
||||
|
||||
def check_example(p):
|
||||
imp.load_source('__main__', path.basename(p))
|
||||
@@ -15,6 +15,8 @@
|
||||
# limitations under the License.
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from datetime import datetime
|
||||
import pytz
|
||||
|
||||
from zipline.algorithm import TradingAlgorithm
|
||||
from zipline.utils.factory import load_from_yahoo
|
||||
@@ -29,8 +31,15 @@ class BuyApple(TradingAlgorithm): # inherit from TradingAlgorithm
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
data = load_from_yahoo(stocks=['AAPL'], indexes={})
|
||||
start = datetime(2008, 1, 1, 0, 0, 0, 0, pytz.utc)
|
||||
end = datetime(2010, 1, 1, 0, 0, 0, 0, pytz.utc)
|
||||
data = load_from_yahoo(stocks=['AAPL'], indexes={}, start=start,
|
||||
end=end)
|
||||
simple_algo = BuyApple()
|
||||
results = simple_algo.run(data)
|
||||
results.portfolio_value.plot()
|
||||
|
||||
ax1 = plt.subplot(211)
|
||||
results.portfolio_value.plot(ax=ax1)
|
||||
ax2 = plt.subplot(212, sharex=ax1)
|
||||
data.AAPL.plot(ax=ax2)
|
||||
plt.show()
|
||||
|
||||
@@ -20,6 +20,9 @@ from zipline.algorithm import TradingAlgorithm
|
||||
from zipline.transforms import MovingAverage
|
||||
from zipline.utils.factory import load_from_yahoo
|
||||
|
||||
from datetime import datetime
|
||||
import pytz
|
||||
|
||||
|
||||
class DualMovingAverage(TradingAlgorithm):
|
||||
"""Dual Moving Average Crossover algorithm.
|
||||
@@ -30,7 +33,7 @@ class DualMovingAverage(TradingAlgorithm):
|
||||
momentum).
|
||||
|
||||
"""
|
||||
def initialize(self, short_window=200, long_window=400):
|
||||
def initialize(self, short_window=20, long_window=40):
|
||||
# Add 2 mavg transforms, one with a long window, one
|
||||
# with a short window.
|
||||
self.add_transform(MovingAverage, 'short_mavg', ['price'],
|
||||
@@ -63,16 +66,20 @@ class DualMovingAverage(TradingAlgorithm):
|
||||
sell=self.sell)
|
||||
|
||||
if __name__ == '__main__':
|
||||
data = load_from_yahoo(stocks=['AAPL'], indexes={})
|
||||
start = datetime(1990, 1, 1, 0, 0, 0, 0, pytz.utc)
|
||||
end = datetime(1991, 1, 1, 0, 0, 0, 0, pytz.utc)
|
||||
data = load_from_yahoo(stocks=['AAPL'], indexes={}, start=start,
|
||||
end=end)
|
||||
|
||||
dma = DualMovingAverage()
|
||||
results = dma.run(data)
|
||||
print results.short_mavg
|
||||
|
||||
fig = plt.figure()
|
||||
ax1 = fig.add_subplot(211)
|
||||
ax1 = fig.add_subplot(211, ylabel='portfolio value')
|
||||
results.portfolio_value.plot(ax=ax1)
|
||||
|
||||
ax2 = fig.add_subplot(212)
|
||||
data['AAPL'].plot(ax=ax2)
|
||||
data['AAPL'].plot(ax=ax2, color='r')
|
||||
results[['short_mavg', 'long_mavg']].plot(ax=ax2)
|
||||
|
||||
ax2.plot(results.ix[results.buy].index, results.short_mavg[results.buy],
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
import sys
|
||||
import logbook
|
||||
import datetime
|
||||
import numpy as np
|
||||
from datetime import datetime
|
||||
import pytz
|
||||
|
||||
from zipline.algorithm import TradingAlgorithm
|
||||
from zipline.transforms import MovingAverage
|
||||
from zipline.utils.factory import load_bars_from_yahoo
|
||||
from zipline.finance import slippage, commission
|
||||
from zipline.utils.factory import load_from_yahoo
|
||||
from zipline.finance import commission
|
||||
|
||||
zipline_logging = logbook.NestedSetup([
|
||||
logbook.NullHandler(level=logbook.DEBUG, bubble=True),
|
||||
@@ -38,11 +39,6 @@ class OLMAR(TradingAlgorithm):
|
||||
self.add_transform(MovingAverage, 'mavg', ['price'],
|
||||
window_length=window_length)
|
||||
|
||||
no_delay = datetime.timedelta(minutes=0)
|
||||
slip = slippage.VolumeShareSlippage(volume_limit=0.25,
|
||||
price_impact=0,
|
||||
delay=no_delay)
|
||||
self.set_slippage(slip)
|
||||
self.set_commission(commission.PerShare(cost=0))
|
||||
|
||||
def handle_data(self, data):
|
||||
@@ -157,7 +153,11 @@ def simplex_projection(v, b=1):
|
||||
|
||||
if __name__ == '__main__':
|
||||
import pylab as pl
|
||||
data = load_bars_from_yahoo(stocks=STOCKS, indexes={})
|
||||
start = datetime(2004, 1, 1, 0, 0, 0, 0, pytz.utc)
|
||||
end = datetime(2008, 1, 1, 0, 0, 0, 0, pytz.utc)
|
||||
data = load_from_yahoo(stocks=STOCKS, indexes={}, start=start,
|
||||
end=end)
|
||||
data = data.dropna()
|
||||
olmar = OLMAR()
|
||||
results = olmar.run(data)
|
||||
results.portfolio_value.plot()
|
||||
|
||||
@@ -17,6 +17,8 @@
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import statsmodels.api as sm
|
||||
from datetime import datetime
|
||||
import pytz
|
||||
|
||||
from zipline.algorithm import TradingAlgorithm
|
||||
from zipline.transforms import batch_transform
|
||||
@@ -29,7 +31,7 @@ def ols_transform(data, sid1, sid2):
|
||||
via Ordinary Least Squares between two SIDs.
|
||||
"""
|
||||
p0 = data.price[sid1]
|
||||
p1 = sm.add_constant(data.price[sid2])
|
||||
p1 = sm.add_constant(data.price[sid2], prepend=True)
|
||||
slope, intercept = sm.OLS(p0, p1).fit().params
|
||||
|
||||
return slope, intercept
|
||||
@@ -53,7 +55,6 @@ class Pairtrade(TradingAlgorithm):
|
||||
|
||||
def initialize(self, window_length=100):
|
||||
self.spreads = []
|
||||
self.zscores = []
|
||||
self.invested = 0
|
||||
self.window_length = window_length
|
||||
self.ols_transform = ols_transform(refresh_period=self.window_length,
|
||||
@@ -65,12 +66,12 @@ class Pairtrade(TradingAlgorithm):
|
||||
params = self.ols_transform.handle_data(data, 'PEP', 'KO')
|
||||
if params is None:
|
||||
return
|
||||
slope, intercept = params
|
||||
intercept, slope = params
|
||||
|
||||
######################################################
|
||||
# 2. Compute spread and zscore
|
||||
zscore = self.compute_zscore(data, slope, intercept)
|
||||
self.zscores.append(zscore)
|
||||
self.record(zscores=zscore)
|
||||
|
||||
######################################################
|
||||
# 3. Place orders
|
||||
@@ -112,12 +113,14 @@ class Pairtrade(TradingAlgorithm):
|
||||
self.order('PEP', -1 * pep_amount)
|
||||
|
||||
if __name__ == '__main__':
|
||||
data = load_from_yahoo(stocks=['PEP', 'KO'], indexes={})
|
||||
start = datetime(2000, 1, 1, 0, 0, 0, 0, pytz.utc)
|
||||
end = datetime(2002, 1, 1, 0, 0, 0, 0, pytz.utc)
|
||||
data = load_from_yahoo(stocks=['PEP', 'KO'], indexes={},
|
||||
start=start, end=end)
|
||||
|
||||
pairtrade = Pairtrade()
|
||||
results = pairtrade.run(data)
|
||||
data['spreads'] = np.nan
|
||||
data.spreads[pairtrade.window_length:] = pairtrade.spreads
|
||||
|
||||
ax1 = plt.subplot(211)
|
||||
data[['PEP', 'KO']].plot(ax=ax1)
|
||||
@@ -125,7 +128,7 @@ if __name__ == '__main__':
|
||||
plt.setp(ax1.get_xticklabels(), visible=False)
|
||||
|
||||
ax2 = plt.subplot(212, sharex=ax1)
|
||||
data.spreads.plot(ax=ax2, color='r')
|
||||
plt.ylabel('spread')
|
||||
results.zscores.plot(ax=ax2, color='r')
|
||||
plt.ylabel('zscored spread')
|
||||
|
||||
plt.show()
|
||||
|
||||
@@ -622,7 +622,9 @@ class PerformancePeriod(object):
|
||||
|
||||
def update_last_sale(self, event):
|
||||
is_trade = event.type == zp.DATASOURCE_TYPE.TRADE
|
||||
if event.sid in self.positions and is_trade:
|
||||
has_price = not np.isnan(event.price)
|
||||
# isnan check will keep the last price if its not present
|
||||
if (event.sid in self.positions) and is_trade and has_price:
|
||||
self.positions[event.sid].last_sale_price = event.price
|
||||
index = self.index_for_position(event.sid)
|
||||
self._position_last_sale_prices[index] = event.price
|
||||
|
||||
Reference in New Issue
Block a user