ENH: Add if name == main blocks to examples

This commit is contained in:
Stewart Douglas
2015-09-09 12:59:02 -04:00
parent 3f4bbc521a
commit 723c5bb069
2 changed files with 79 additions and 0 deletions
+35
View File
@@ -24,3 +24,38 @@ def initialize(context):
def handle_data(context, data):
order(symbol('AAPL'), 10)
record(AAPL=data[symbol('AAPL')].price)
# Note: this if-block should be removed if running
# this algorithm on quantopian.com
if __name__ == '__main__':
from datetime import datetime
import matplotlib.pyplot as plt
import pytz
from zipline.algorithm import TradingAlgorithm
from zipline.utils.factory import load_from_yahoo
# Set the simulation start and end dates
start = datetime(2014, 1, 1, 0, 0, 0, 0, pytz.utc)
end = datetime(2014, 11, 1, 0, 0, 0, 0, pytz.utc)
# Load price data from yahoo.
data = load_from_yahoo(stocks=['AAPL'], indexes={}, start=start,
end=end)
# Create and run the algorithm.
algo = TradingAlgorithm(initialize=initialize, handle_data=handle_data,
identifiers=['AAPL'])
results = algo.run(data)
# Plot the portfolio and asset data.
ax1 = plt.subplot(211)
results.portfolio_value.plot(ax=ax1)
ax1.set_ylabel('Portfolio value (USD)')
ax2 = plt.subplot(212, sharex=ax1)
results.AAPL.plot(ax=ax2)
ax2.set_ylabel('AAPL price (USD)')
# Show the plot.
plt.gcf().set_size_inches(18, 8)
plt.show()
+44
View File
@@ -60,3 +60,47 @@ def handle_data(context, data):
record(AAPL=data[context.sym].price,
short_mavg=short_mavg[context.sym],
long_mavg=long_mavg[context.sym])
# Note: this if-block should be removed if running
# this algorithm on quantopian.com
if __name__ == '__main__':
from datetime import datetime
import matplotlib.pyplot as plt
import pytz
from zipline.algorithm import TradingAlgorithm
from zipline.utils.factory import load_from_yahoo
# Set the simulation start and end dates
start = datetime(2011, 1, 1, 0, 0, 0, 0, pytz.utc)
end = datetime(2013, 1, 1, 0, 0, 0, 0, pytz.utc)
# Load price data from yahoo.
data = load_from_yahoo(stocks=['AAPL'], indexes={}, start=start,
end=end)
# Create and run the algorithm.
algo = TradingAlgorithm(initialize=initialize, handle_data=handle_data,
identifiers=['AAPL'])
results = algo.run(data)
# Plot the portfolio and asset data.
fig = plt.figure()
ax1 = fig.add_subplot(211)
results.portfolio_value.plot(ax=ax1)
ax1.set_ylabel('Portfolio value (USD)')
ax2 = fig.add_subplot(212)
ax2.set_ylabel('Price in (USD)')
results[['AAPL', 'short_mavg', 'long_mavg']].plot(ax=ax2)
trans = results.ix[[t != [] for t in results.transactions]]
buys = trans.ix[[t[0]['amount'] > 0 for t in
trans.transactions]]
sells = trans.ix[[t[0]['amount'] < 0 for t in trans.transactions]]
ax2.plot(buys.index, results.short_mavg.ix[buys.index],
'^', markersize=10, color='m')
ax2.plot(sells.index, results.short_mavg.ix[sells.index],
'v', markersize=10, color='k')
plt.legend(loc=0)
# Show the plot.
plt.show()