mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-27 17:47:56 +08:00
BLD: minor adjustments to the talib sample algo
This commit is contained in:
@@ -5,31 +5,29 @@
|
||||
# Simple TALib Example showing how to use various indicators in you strategy
|
||||
# Based loosly on https://github.com/mellertson/talib-macd-example/blob/master/talib-macd-matplotlib-example.py
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import talib as ta
|
||||
import datetime
|
||||
import os
|
||||
from os.path import basename
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import talib as ta
|
||||
from logbook import Logger
|
||||
from matplotlib.dates import date2num
|
||||
from matplotlib.finance import candlestick_ohlc
|
||||
from logbook import Logger
|
||||
from catalyst.exchange.stats_utils import get_pretty_stats
|
||||
|
||||
from catalyst import run_algorithm
|
||||
from catalyst.api import (
|
||||
order,
|
||||
order_target_percent,
|
||||
symbol,
|
||||
record,
|
||||
get_open_orders,
|
||||
)
|
||||
from catalyst.exchange.stats_utils import get_pretty_stats
|
||||
|
||||
algo_namespace = 'talib_sample'
|
||||
log = Logger(algo_namespace)
|
||||
|
||||
def initialize(context):
|
||||
|
||||
def initialize(context):
|
||||
log.info('Starting TALib Simple Example')
|
||||
|
||||
context.ASSET_NAME = 'BTC_USDT'
|
||||
@@ -62,17 +60,17 @@ def initialize(context):
|
||||
|
||||
pass
|
||||
|
||||
def _handle_data(context, data):
|
||||
|
||||
def _handle_data(context, data):
|
||||
# Get price, open, high, low, close
|
||||
prices = data.history(
|
||||
context.asset,
|
||||
bar_count=context.BARS,
|
||||
fields=['price','open','high','low','close'],
|
||||
fields=['price', 'open', 'high', 'low', 'close'],
|
||||
frequency='1d')
|
||||
|
||||
# Create a analysis data frame
|
||||
analysis = pd.DataFrame(index = prices.index)
|
||||
analysis = pd.DataFrame(index=prices.index)
|
||||
|
||||
# SMA FAST
|
||||
analysis['sma_f'] = ta.SMA(prices.close.as_matrix(), context.SMA_FAST)
|
||||
@@ -82,44 +80,59 @@ def _handle_data(context, data):
|
||||
# Relative Strength Index
|
||||
analysis['rsi'] = ta.RSI(prices.close.as_matrix(), context.RSI_PERIOD)
|
||||
# RSI SMA
|
||||
analysis['sma_r'] = ta.SMA(analysis.rsi.as_matrix(), context.RSI_AVG_PERIOD)
|
||||
analysis['sma_r'] = ta.SMA(analysis.rsi.as_matrix(),
|
||||
context.RSI_AVG_PERIOD)
|
||||
|
||||
# MACD, MACD Signal, MACD Histogram
|
||||
analysis['macd'], analysis['macdSignal'], analysis['macdHist'] = ta.MACD(prices.close.as_matrix(), fastperiod=context.MACD_FAST, slowperiod=context.MACD_SLOW, signalperiod=context.MACD_SIGNAL)
|
||||
|
||||
analysis['macd'], analysis['macdSignal'], analysis['macdHist'] = ta.MACD(
|
||||
prices.close.as_matrix(), fastperiod=context.MACD_FAST,
|
||||
slowperiod=context.MACD_SLOW, signalperiod=context.MACD_SIGNAL)
|
||||
|
||||
# Stochastics %K %D
|
||||
# %K = (Current Close - Lowest Low)/(Highest High - Lowest Low) * 100
|
||||
# %D = 3-day SMA of %K
|
||||
analysis['stoch_k'], analysis['stoch_d'] = ta.STOCH(prices.high.as_matrix(), prices.low.as_matrix(), prices.close.as_matrix(), slowk_period=context.STOCH_K, slowd_period=context.STOCH_D)
|
||||
analysis['stoch_k'], analysis['stoch_d'] = ta.STOCH(
|
||||
prices.high.as_matrix(), prices.low.as_matrix(),
|
||||
prices.close.as_matrix(), slowk_period=context.STOCH_K,
|
||||
slowd_period=context.STOCH_D)
|
||||
|
||||
# SMA FAST over SLOW Crossover
|
||||
analysis['sma_test'] = np.where(analysis.sma_f > analysis.sma_s, 1, 0)
|
||||
|
||||
# MACD over Signal Crossover
|
||||
analysis['macd_test'] = np.where((analysis.macd > analysis.macdSignal), 1, 0)
|
||||
analysis['macd_test'] = np.where((analysis.macd > analysis.macdSignal), 1,
|
||||
0)
|
||||
|
||||
# Stochastics OVER BOUGHT & Decreasing
|
||||
analysis['stoch_over_bought'] = np.where((analysis.stoch_k > context.STOCH_OVER_BOUGHT) & (analysis.stoch_k > analysis.stoch_k.shift(1)), 1, 0)
|
||||
analysis['stoch_over_bought'] = np.where(
|
||||
(analysis.stoch_k > context.STOCH_OVER_BOUGHT) & (
|
||||
analysis.stoch_k > analysis.stoch_k.shift(1)), 1, 0)
|
||||
|
||||
# Stochastics OVER SOLD & Increasing
|
||||
analysis['stoch_over_sold'] = np.where((analysis.stoch_k < context.STOCH_OVER_SOLD) & (analysis.stoch_k > analysis.stoch_k.shift(1)), 1, 0)
|
||||
analysis['stoch_over_sold'] = np.where(
|
||||
(analysis.stoch_k < context.STOCH_OVER_SOLD) & (
|
||||
analysis.stoch_k > analysis.stoch_k.shift(1)), 1, 0)
|
||||
|
||||
# RSI OVER BOUGHT & Decreasing
|
||||
analysis['rsi_over_bought'] = np.where((analysis.rsi > context.RSI_OVER_BOUGHT) & (analysis.rsi < analysis.rsi.shift(1)), 1, 0)
|
||||
analysis['rsi_over_bought'] = np.where(
|
||||
(analysis.rsi > context.RSI_OVER_BOUGHT) & (
|
||||
analysis.rsi < analysis.rsi.shift(1)), 1, 0)
|
||||
|
||||
# RSI OVER SOLD & Increasing
|
||||
analysis['rsi_over_sold'] = np.where((analysis.rsi < context.RSI_OVER_SOLD) & (analysis.rsi > analysis.rsi.shift(1)), 1, 0)
|
||||
analysis['rsi_over_sold'] = np.where(
|
||||
(analysis.rsi < context.RSI_OVER_SOLD) & (
|
||||
analysis.rsi > analysis.rsi.shift(1)), 1, 0)
|
||||
|
||||
# Save the prices and analysis to send to analyze
|
||||
context.prices=prices
|
||||
context.analysis=analysis
|
||||
context.prices = prices
|
||||
context.analysis = analysis
|
||||
context.price = data.current(context.asset, 'price')
|
||||
|
||||
makeOrders(context, analysis)
|
||||
|
||||
# Log the values of this bar
|
||||
logAnalysis(analysis)
|
||||
|
||||
|
||||
|
||||
def handle_data(context, data):
|
||||
log.info('handling bar {}'.format(data.current_dt))
|
||||
@@ -147,14 +160,14 @@ def analyze(context, results):
|
||||
chart(context, context.prices, context.analysis, results)
|
||||
pass
|
||||
|
||||
def makeOrders(context, analysis):
|
||||
|
||||
def makeOrders(context, analysis):
|
||||
if context.asset in context.portfolio.positions:
|
||||
|
||||
# Current position
|
||||
position = context.portfolio.positions[context.asset]
|
||||
|
||||
if(position == 0):
|
||||
|
||||
if (position == 0):
|
||||
log.info('Position Zero')
|
||||
return
|
||||
|
||||
@@ -170,7 +183,8 @@ def makeOrders(context, analysis):
|
||||
|
||||
# Sell when holding and got sell singnal
|
||||
if isSell(context, analysis):
|
||||
profit = (context.price * position.amount) - (cost_basis * position.amount)
|
||||
profit = (context.price * position.amount) - (
|
||||
cost_basis * position.amount)
|
||||
order_target_percent(
|
||||
asset=context.asset,
|
||||
target=0,
|
||||
@@ -178,16 +192,16 @@ def makeOrders(context, analysis):
|
||||
)
|
||||
log.info(
|
||||
'Sold {amount} @ {price} Profit: {profit}'.format(
|
||||
amount=position.amount,
|
||||
price=context.price,
|
||||
profit=profit
|
||||
amount=position.amount,
|
||||
price=context.price,
|
||||
profit=profit
|
||||
)
|
||||
)
|
||||
else:
|
||||
log.info('no buy or sell opportunity found')
|
||||
else:
|
||||
# Buy when not holding and got buy signal
|
||||
if isBuy(context, analysis):
|
||||
if isBuy(context, analysis):
|
||||
order(
|
||||
asset=context.asset,
|
||||
amount=context.ORDER_SIZE,
|
||||
@@ -195,17 +209,17 @@ def makeOrders(context, analysis):
|
||||
)
|
||||
log.info(
|
||||
'Bought {amount} @ {price}'.format(
|
||||
amount=context.ORDER_SIZE,
|
||||
price=context.price
|
||||
amount=context.ORDER_SIZE,
|
||||
price=context.price
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def isBuy(context, analysis):
|
||||
# Bullish SMA Crossover
|
||||
if(getLast(analysis, 'sma_test') == 1):
|
||||
if (getLast(analysis, 'sma_test') == 1):
|
||||
# Bullish MACD
|
||||
if(getLast(analysis, 'macd_test') == 1):
|
||||
if (getLast(analysis, 'macd_test') == 1):
|
||||
return True
|
||||
|
||||
# # Bullish Stochastics
|
||||
@@ -218,11 +232,12 @@ def isBuy(context, analysis):
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def isSell(context, analysis):
|
||||
# Bearish SMA Crossover
|
||||
if(getLast(analysis, 'sma_test') == 0):
|
||||
if (getLast(analysis, 'sma_test') == 0):
|
||||
# Bearish MACD
|
||||
if(getLast(analysis, 'macd_test') == 0):
|
||||
if (getLast(analysis, 'macd_test') == 0):
|
||||
return True
|
||||
|
||||
# # Bearish Stochastics
|
||||
@@ -235,6 +250,7 @@ def isSell(context, analysis):
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def chart(context, prices, analysis, results):
|
||||
results.portfolio_value.plot()
|
||||
|
||||
@@ -243,37 +259,40 @@ def chart(context, prices, analysis, results):
|
||||
|
||||
# Create the Open High Low Close Tuple
|
||||
prices_ohlc = [tuple([dates[i],
|
||||
prices.open[i],
|
||||
prices.high[i],
|
||||
prices.low[i],
|
||||
prices.close[i]]) for i in range(len(dates))]
|
||||
prices.open[i],
|
||||
prices.high[i],
|
||||
prices.low[i],
|
||||
prices.close[i]]) for i in range(len(dates))]
|
||||
|
||||
fig = plt.figure(figsize=(14,18))
|
||||
fig = plt.figure(figsize=(14, 18))
|
||||
|
||||
# Draw the candle sticks
|
||||
ax1 = fig.add_subplot(411)
|
||||
ax1.set_ylabel(context.ASSET_NAME, size=20)
|
||||
candlestick_ohlc(ax1, prices_ohlc, width=0.4 ,colorup='g', colordown='r')
|
||||
candlestick_ohlc(ax1, prices_ohlc, width=0.4, colorup='g', colordown='r')
|
||||
|
||||
# Draw Moving Averages
|
||||
analysis.sma_f.plot(ax=ax1, c='r')
|
||||
analysis.sma_s.plot(ax=ax1, c='g')
|
||||
|
||||
#RSI
|
||||
# RSI
|
||||
ax2 = fig.add_subplot(412)
|
||||
ax2.set_ylabel('RSI', size=12)
|
||||
analysis.rsi.plot(ax = ax2, c='g', label = 'Period: ' + str(context.RSI_PERIOD))
|
||||
analysis.sma_r.plot(ax = ax2, c='r', label = 'MA: ' + str(context.RSI_AVG_PERIOD))
|
||||
analysis.rsi.plot(ax=ax2, c='g',
|
||||
label='Period: ' + str(context.RSI_PERIOD))
|
||||
analysis.sma_r.plot(ax=ax2, c='r',
|
||||
label='MA: ' + str(context.RSI_AVG_PERIOD))
|
||||
ax2.axhline(y=30, c='b')
|
||||
ax2.axhline(y=50, c='black')
|
||||
ax2.axhline(y=70, c='b')
|
||||
ax2.set_ylim([0,100])
|
||||
ax2.set_ylim([0, 100])
|
||||
handles, labels = ax2.get_legend_handles_labels()
|
||||
ax2.legend(handles, labels)
|
||||
|
||||
# Draw MACD computed with Talib
|
||||
ax3 = fig.add_subplot(413)
|
||||
ax3.set_ylabel('MACD: '+ str(context.MACD_FAST) + ', ' + str(context.MACD_SLOW) + ', ' + str(context.MACD_SIGNAL), size=12)
|
||||
ax3.set_ylabel('MACD: ' + str(context.MACD_FAST) + ', ' + str(
|
||||
context.MACD_SLOW) + ', ' + str(context.MACD_SIGNAL), size=12)
|
||||
analysis.macd.plot(ax=ax3, color='b', label='Macd')
|
||||
analysis.macdSignal.plot(ax=ax3, color='g', label='Signal')
|
||||
analysis.macdHist.plot(ax=ax3, color='r', label='Hist')
|
||||
@@ -284,8 +303,10 @@ def chart(context, prices, analysis, results):
|
||||
# Stochastic plot
|
||||
ax4 = fig.add_subplot(414)
|
||||
ax4.set_ylabel('Stoch (k,d)', size=12)
|
||||
analysis.stoch_k.plot(ax=ax4, label='stoch_k:'+ str(context.STOCH_K), color='r')
|
||||
analysis.stoch_d.plot(ax=ax4, label='stoch_d:'+ str(context.STOCH_D), color='g')
|
||||
analysis.stoch_k.plot(ax=ax4, label='stoch_k:' + str(context.STOCH_K),
|
||||
color='r')
|
||||
analysis.stoch_d.plot(ax=ax4, label='stoch_d:' + str(context.STOCH_D),
|
||||
color='g')
|
||||
handles, labels = ax4.get_legend_handles_labels()
|
||||
ax4.legend(handles, labels)
|
||||
ax4.axhline(y=20, c='b')
|
||||
@@ -294,6 +315,7 @@ def chart(context, prices, analysis, results):
|
||||
|
||||
plt.show()
|
||||
|
||||
|
||||
def logAnalysis(analysis):
|
||||
# Log only the last value in the array
|
||||
log.info('- sma_f: {:.2f}'.format(getLast(analysis, 'sma_f')))
|
||||
@@ -303,7 +325,8 @@ def logAnalysis(analysis):
|
||||
log.info('- sma_r: {:.2f}'.format(getLast(analysis, 'sma_r')))
|
||||
|
||||
log.info('- macd: {:.2f}'.format(getLast(analysis, 'macd')))
|
||||
log.info('- macdSignal: {:.2f}'.format(getLast(analysis, 'macdSignal')))
|
||||
log.info(
|
||||
'- macdSignal: {:.2f}'.format(getLast(analysis, 'macdSignal')))
|
||||
log.info('- macdHist: {:.2f}'.format(getLast(analysis, 'macdHist')))
|
||||
|
||||
log.info('- stoch_k: {:.2f}'.format(getLast(analysis, 'stoch_k')))
|
||||
@@ -312,11 +335,30 @@ def logAnalysis(analysis):
|
||||
log.info('- sma_test: {}'.format(getLast(analysis, 'sma_test')))
|
||||
log.info('- macd_test: {}'.format(getLast(analysis, 'macd_test')))
|
||||
|
||||
log.info('- stoch_over_bought: {}'.format(getLast(analysis, 'stoch_over_bought')))
|
||||
log.info('- stoch_over_sold: {}'.format(getLast(analysis, 'stoch_over_sold')))
|
||||
log.info('- stoch_over_bought: {}'.format(
|
||||
getLast(analysis, 'stoch_over_bought')))
|
||||
log.info(
|
||||
'- stoch_over_sold: {}'.format(getLast(analysis, 'stoch_over_sold')))
|
||||
|
||||
log.info('- rsi_over_bought: {}'.format(
|
||||
getLast(analysis, 'rsi_over_bought')))
|
||||
log.info(
|
||||
'- rsi_over_sold: {}'.format(getLast(analysis, 'rsi_over_sold')))
|
||||
|
||||
log.info('- rsi_over_bought: {}'.format(getLast(analysis, 'rsi_over_bought')))
|
||||
log.info('- rsi_over_sold: {}'.format(getLast(analysis, 'rsi_over_sold')))
|
||||
|
||||
def getLast(arr, name):
|
||||
return arr[name][arr[name].index[-1]]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_algorithm(
|
||||
capital_base=10000,
|
||||
data_frequency='daily',
|
||||
initialize=initialize,
|
||||
handle_data=handle_data,
|
||||
analyze=analyze,
|
||||
exchange_name='poloniex',
|
||||
base_currency='usdt',
|
||||
start=pd.to_datetime('2016-11-1', utc=True),
|
||||
end=pd.to_datetime('2017-11-10', utc=True),
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user