BUG: fixed #75, adjusted the ruturn value of run_algorithm to support minute stats.

This commit is contained in:
fredfortier
2017-11-20 21:53:48 -05:00
parent 1e8b0c36a1
commit 0d366a350d
4 changed files with 135 additions and 14 deletions
+90 -9
View File
@@ -2,12 +2,15 @@ import talib
import pandas as pd
from catalyst import run_algorithm
from catalyst.api import symbol
from catalyst.api import symbol, record
from catalyst.exchange.stats_utils import get_pretty_stats, \
extract_transactions
def initialize(context):
print('initializing')
context.asset = symbol('swift_btc')
context.asset = symbol('neo_usd')
context.base_price = None
def handle_data(context, data):
@@ -20,26 +23,104 @@ def handle_data(context, data):
prices = data.history(
context.asset,
fields='price',
bar_count=15,
frequency='1D'
bar_count=14,
frequency='15T'
)
rsi = talib.RSI(prices.values, timeperiod=14)[-1]
print('got rsi: {}'.format(rsi))
except Exception as e:
print(e)
# If base_price is not set, we use the current value. This is the
# price at the first bar which we reference to calculate price_change.
if context.base_price is None:
context.base_price = price
price_change = (price - context.base_price) / context.base_price
cash = context.portfolio.cash
# Now that we've collected all current data for this frame, we use
# the record() method to save it. This data will be available as
# a parameter of the analyze() function for further analysis.
record(
price=price,
price_change=price_change,
cash=cash
)
def analyze(context, perf):
import matplotlib.pyplot as plt
print('the stats: {}'.format(get_pretty_stats(perf)))
# The base currency of the algo exchange
base_currency = context.exchanges.values()[0].base_currency.upper()
# Plot the portfolio value over time.
ax1 = plt.subplot(611)
perf.loc[:, 'portfolio_value'].plot(ax=ax1)
ax1.set_ylabel('Portfolio Value ({})'.format(base_currency))
# Plot the price increase or decrease over time.
ax2 = plt.subplot(612, sharex=ax1)
perf.loc[:, 'price'].plot(ax=ax2, label='Price')
ax2.set_ylabel('{asset} ({base})'.format(
asset=context.asset.symbol, base=base_currency
))
transaction_df = extract_transactions(perf)
if not transaction_df.empty:
buy_df = transaction_df[transaction_df['amount'] > 0]
sell_df = transaction_df[transaction_df['amount'] < 0]
ax2.scatter(
buy_df.index.to_pydatetime(),
perf.loc[buy_df.index, 'price'],
marker='^',
s=100,
c='green',
label=''
)
ax2.scatter(
sell_df.index.to_pydatetime(),
perf.loc[sell_df.index, 'price'],
marker='v',
s=100,
c='red',
label=''
)
ax4 = plt.subplot(613, sharex=ax1)
perf.loc[:, 'cash'].plot(
ax=ax4, label='Base Currency ({})'.format(base_currency)
)
ax4.set_ylabel('Cash ({})'.format(base_currency))
perf['algorithm'] = perf.loc[:, 'algorithm_period_return']
ax5 = plt.subplot(614, sharex=ax1)
perf.loc[:, ['algorithm', 'price_change']].plot(ax=ax5)
ax5.set_ylabel('Percent Change')
plt.legend(loc=3)
# Show the plot.
plt.gcf().set_size_inches(18, 8)
plt.show()
pass
run_algorithm(
capital_base=250,
start=pd.to_datetime('2015-4-1', utc=True),
end=pd.to_datetime('2017-11-1', utc=True),
start=pd.to_datetime('2017-11-1 0:00', utc=True),
end=pd.to_datetime('2017-11-10 23:59', utc=True),
data_frequency='daily',
initialize=initialize,
handle_data=handle_data,
analyze=None,
exchange_name='bittrex',
analyze=analyze,
exchange_name='bitfinex',
algo_namespace='simple_loop',
base_currency='btc'
base_currency='usd'
)
# run_algorithm(
# initialize=initialize,
+27 -4
View File
@@ -245,19 +245,42 @@ class ExchangeTradingAlgorithmBacktest(ExchangeTradingAlgorithmBase):
else:
return MarketOrder()
def is_last_frame_of_day(self, data):
# TODO: adjust here to support more intervals
next_frame_dt = data.current_dt + timedelta(minutes=1)
if next_frame_dt.date() > data.current_dt.date():
return True
else:
return False
def handle_data(self, data):
super(ExchangeTradingAlgorithmBacktest, self).handle_data(data)
minute_stats = self.prepare_period_stats(
data.current_dt, data.current_dt + timedelta(minutes=1))
self.frame_stats.append(minute_stats)
if self.data_frequency == 'minute':
frame_stats = self.prepare_period_stats(
data.current_dt, data.current_dt + timedelta(minutes=1)
)
self.frame_stats.append(frame_stats)
def analyze(self, perf):
def _create_stats_df(self):
stats = pd.DataFrame(self.frame_stats)
stats.set_index('period_close', inplace=True, drop=False)
return stats
def analyze(self, perf):
stats = self._create_stats_df() if self.data_frequency == 'minute' \
else perf
super(ExchangeTradingAlgorithmBacktest, self).analyze(stats)
def run(self, data=None, overwrite_sim_params=True):
perf = super(ExchangeTradingAlgorithmBacktest, self).run(
data, overwrite_sim_params
)
# Rebuilding the stats to support minute data
stats = self._create_stats_df() if self.data_frequency == 'minute' \
else perf
return stats
class ExchangeTradingAlgorithmLive(ExchangeTradingAlgorithmBase):
def __init__(self, *args, **kwargs):
+9 -1
View File
@@ -26,7 +26,7 @@ from catalyst.exchange.exchange_bcolz import BcolzExchangeBarReader, \
from catalyst.exchange.exchange_errors import EmptyValuesInBundleError, \
TempBundleNotFoundError, \
NoDataAvailableOnExchange, \
PricingDataNotLoadedError
PricingDataNotLoadedError, DataCorruptionError
from catalyst.exchange.exchange_utils import get_exchange_folder
from catalyst.utils.cli import maybe_show_progress
from catalyst.utils.paths import ensure_directory
@@ -881,6 +881,14 @@ class ExchangeBundle:
start_dt=start_dt,
end_dt=end_dt
)
if len(arrays) == 0:
raise DataCorruptionError(
exchange=self.exchange.name,
symbols=asset.symbol,
start_dt=asset_start_dt,
end_dt=asset_end_dt
)
field_values = arrays[0][:, 0]
value_series = pd.Series(field_values, index=periods)
+9
View File
@@ -218,6 +218,15 @@ class PricingDataNotLoadedError(ZiplineError):
'for details.').strip()
class DataCorruptionError(ZiplineError):
msg = ('Unable to validate data for {exchange} {symbols} in date range '
'[{start_dt} - {end_dt}]. The data is either corrupted or '
'unavailable. Please try deleting this bundle:'
'\n`catalyst clean-exchange -x {exchange}\n'
'Then, ingest the data again. Please contact the Catalyst team if '
'the issue persists.').strip()
class ApiCandlesError(ZiplineError):
msg = ('Unable to fetch candles from the remote API: {error}.').strip()