mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-02 11:52:39 +08:00
BUG: fixed #75, adjusted the ruturn value of run_algorithm to support minute stats.
This commit is contained in:
@@ -2,12 +2,15 @@ import talib
|
||||
import pandas as pd
|
||||
|
||||
from catalyst import run_algorithm
|
||||
from catalyst.api import symbol
|
||||
from catalyst.api import symbol, record
|
||||
from catalyst.exchange.stats_utils import get_pretty_stats, \
|
||||
extract_transactions
|
||||
|
||||
|
||||
def initialize(context):
|
||||
print('initializing')
|
||||
context.asset = symbol('swift_btc')
|
||||
context.asset = symbol('neo_usd')
|
||||
context.base_price = None
|
||||
|
||||
|
||||
def handle_data(context, data):
|
||||
@@ -20,26 +23,104 @@ def handle_data(context, data):
|
||||
prices = data.history(
|
||||
context.asset,
|
||||
fields='price',
|
||||
bar_count=15,
|
||||
frequency='1D'
|
||||
bar_count=14,
|
||||
frequency='15T'
|
||||
)
|
||||
rsi = talib.RSI(prices.values, timeperiod=14)[-1]
|
||||
print('got rsi: {}'.format(rsi))
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
# If base_price is not set, we use the current value. This is the
|
||||
# price at the first bar which we reference to calculate price_change.
|
||||
if context.base_price is None:
|
||||
context.base_price = price
|
||||
|
||||
price_change = (price - context.base_price) / context.base_price
|
||||
cash = context.portfolio.cash
|
||||
|
||||
# Now that we've collected all current data for this frame, we use
|
||||
# the record() method to save it. This data will be available as
|
||||
# a parameter of the analyze() function for further analysis.
|
||||
record(
|
||||
price=price,
|
||||
price_change=price_change,
|
||||
cash=cash
|
||||
)
|
||||
|
||||
|
||||
def analyze(context, perf):
|
||||
import matplotlib.pyplot as plt
|
||||
print('the stats: {}'.format(get_pretty_stats(perf)))
|
||||
|
||||
# The base currency of the algo exchange
|
||||
base_currency = context.exchanges.values()[0].base_currency.upper()
|
||||
|
||||
# Plot the portfolio value over time.
|
||||
ax1 = plt.subplot(611)
|
||||
perf.loc[:, 'portfolio_value'].plot(ax=ax1)
|
||||
ax1.set_ylabel('Portfolio Value ({})'.format(base_currency))
|
||||
|
||||
# Plot the price increase or decrease over time.
|
||||
ax2 = plt.subplot(612, sharex=ax1)
|
||||
perf.loc[:, 'price'].plot(ax=ax2, label='Price')
|
||||
|
||||
ax2.set_ylabel('{asset} ({base})'.format(
|
||||
asset=context.asset.symbol, base=base_currency
|
||||
))
|
||||
|
||||
transaction_df = extract_transactions(perf)
|
||||
if not transaction_df.empty:
|
||||
buy_df = transaction_df[transaction_df['amount'] > 0]
|
||||
sell_df = transaction_df[transaction_df['amount'] < 0]
|
||||
ax2.scatter(
|
||||
buy_df.index.to_pydatetime(),
|
||||
perf.loc[buy_df.index, 'price'],
|
||||
marker='^',
|
||||
s=100,
|
||||
c='green',
|
||||
label=''
|
||||
)
|
||||
ax2.scatter(
|
||||
sell_df.index.to_pydatetime(),
|
||||
perf.loc[sell_df.index, 'price'],
|
||||
marker='v',
|
||||
s=100,
|
||||
c='red',
|
||||
label=''
|
||||
)
|
||||
|
||||
ax4 = plt.subplot(613, sharex=ax1)
|
||||
perf.loc[:, 'cash'].plot(
|
||||
ax=ax4, label='Base Currency ({})'.format(base_currency)
|
||||
)
|
||||
ax4.set_ylabel('Cash ({})'.format(base_currency))
|
||||
|
||||
perf['algorithm'] = perf.loc[:, 'algorithm_period_return']
|
||||
|
||||
ax5 = plt.subplot(614, sharex=ax1)
|
||||
perf.loc[:, ['algorithm', 'price_change']].plot(ax=ax5)
|
||||
ax5.set_ylabel('Percent Change')
|
||||
|
||||
plt.legend(loc=3)
|
||||
|
||||
# Show the plot.
|
||||
plt.gcf().set_size_inches(18, 8)
|
||||
plt.show()
|
||||
pass
|
||||
|
||||
|
||||
run_algorithm(
|
||||
capital_base=250,
|
||||
start=pd.to_datetime('2015-4-1', utc=True),
|
||||
end=pd.to_datetime('2017-11-1', utc=True),
|
||||
start=pd.to_datetime('2017-11-1 0:00', utc=True),
|
||||
end=pd.to_datetime('2017-11-10 23:59', utc=True),
|
||||
data_frequency='daily',
|
||||
initialize=initialize,
|
||||
handle_data=handle_data,
|
||||
analyze=None,
|
||||
exchange_name='bittrex',
|
||||
analyze=analyze,
|
||||
exchange_name='bitfinex',
|
||||
algo_namespace='simple_loop',
|
||||
base_currency='btc'
|
||||
base_currency='usd'
|
||||
)
|
||||
# run_algorithm(
|
||||
# initialize=initialize,
|
||||
|
||||
@@ -245,19 +245,42 @@ class ExchangeTradingAlgorithmBacktest(ExchangeTradingAlgorithmBase):
|
||||
else:
|
||||
return MarketOrder()
|
||||
|
||||
def is_last_frame_of_day(self, data):
|
||||
# TODO: adjust here to support more intervals
|
||||
next_frame_dt = data.current_dt + timedelta(minutes=1)
|
||||
if next_frame_dt.date() > data.current_dt.date():
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def handle_data(self, data):
|
||||
super(ExchangeTradingAlgorithmBacktest, self).handle_data(data)
|
||||
|
||||
minute_stats = self.prepare_period_stats(
|
||||
data.current_dt, data.current_dt + timedelta(minutes=1))
|
||||
self.frame_stats.append(minute_stats)
|
||||
if self.data_frequency == 'minute':
|
||||
frame_stats = self.prepare_period_stats(
|
||||
data.current_dt, data.current_dt + timedelta(minutes=1)
|
||||
)
|
||||
self.frame_stats.append(frame_stats)
|
||||
|
||||
def analyze(self, perf):
|
||||
def _create_stats_df(self):
|
||||
stats = pd.DataFrame(self.frame_stats)
|
||||
stats.set_index('period_close', inplace=True, drop=False)
|
||||
return stats
|
||||
|
||||
def analyze(self, perf):
|
||||
stats = self._create_stats_df() if self.data_frequency == 'minute' \
|
||||
else perf
|
||||
super(ExchangeTradingAlgorithmBacktest, self).analyze(stats)
|
||||
|
||||
def run(self, data=None, overwrite_sim_params=True):
|
||||
perf = super(ExchangeTradingAlgorithmBacktest, self).run(
|
||||
data, overwrite_sim_params
|
||||
)
|
||||
# Rebuilding the stats to support minute data
|
||||
stats = self._create_stats_df() if self.data_frequency == 'minute' \
|
||||
else perf
|
||||
return stats
|
||||
|
||||
|
||||
class ExchangeTradingAlgorithmLive(ExchangeTradingAlgorithmBase):
|
||||
def __init__(self, *args, **kwargs):
|
||||
|
||||
@@ -26,7 +26,7 @@ from catalyst.exchange.exchange_bcolz import BcolzExchangeBarReader, \
|
||||
from catalyst.exchange.exchange_errors import EmptyValuesInBundleError, \
|
||||
TempBundleNotFoundError, \
|
||||
NoDataAvailableOnExchange, \
|
||||
PricingDataNotLoadedError
|
||||
PricingDataNotLoadedError, DataCorruptionError
|
||||
from catalyst.exchange.exchange_utils import get_exchange_folder
|
||||
from catalyst.utils.cli import maybe_show_progress
|
||||
from catalyst.utils.paths import ensure_directory
|
||||
@@ -881,6 +881,14 @@ class ExchangeBundle:
|
||||
start_dt=start_dt,
|
||||
end_dt=end_dt
|
||||
)
|
||||
if len(arrays) == 0:
|
||||
raise DataCorruptionError(
|
||||
exchange=self.exchange.name,
|
||||
symbols=asset.symbol,
|
||||
start_dt=asset_start_dt,
|
||||
end_dt=asset_end_dt
|
||||
)
|
||||
|
||||
field_values = arrays[0][:, 0]
|
||||
|
||||
value_series = pd.Series(field_values, index=periods)
|
||||
|
||||
@@ -218,6 +218,15 @@ class PricingDataNotLoadedError(ZiplineError):
|
||||
'for details.').strip()
|
||||
|
||||
|
||||
class DataCorruptionError(ZiplineError):
|
||||
msg = ('Unable to validate data for {exchange} {symbols} in date range '
|
||||
'[{start_dt} - {end_dt}]. The data is either corrupted or '
|
||||
'unavailable. Please try deleting this bundle:'
|
||||
'\n`catalyst clean-exchange -x {exchange}\n'
|
||||
'Then, ingest the data again. Please contact the Catalyst team if '
|
||||
'the issue persists.').strip()
|
||||
|
||||
|
||||
class ApiCandlesError(ZiplineError):
|
||||
msg = ('Unable to fetch candles from the remote API: {error}.').strip()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user