From 0d366a350d9037dc30d94294e60bfcc496703e24 Mon Sep 17 00:00:00 2001 From: fredfortier Date: Mon, 20 Nov 2017 21:53:48 -0500 Subject: [PATCH] BUG: fixed #75, adjusted the return value of `run_algorithm` to support minute stats. --- catalyst/examples/simple_loop.py | 99 ++++++++++++++++++++++--- catalyst/exchange/exchange_algorithm.py | 31 +++++++- catalyst/exchange/exchange_bundle.py | 10 ++- catalyst/exchange/exchange_errors.py | 9 +++ 4 files changed, 135 insertions(+), 14 deletions(-) diff --git a/catalyst/examples/simple_loop.py b/catalyst/examples/simple_loop.py index 8e6ce22e..bfd2d4f0 100644 --- a/catalyst/examples/simple_loop.py +++ b/catalyst/examples/simple_loop.py @@ -2,12 +2,15 @@ import talib import pandas as pd from catalyst import run_algorithm -from catalyst.api import symbol +from catalyst.api import symbol, record +from catalyst.exchange.stats_utils import get_pretty_stats, \ + extract_transactions def initialize(context): print('initializing') - context.asset = symbol('swift_btc') + context.asset = symbol('neo_usd') + context.base_price = None def handle_data(context, data): @@ -20,26 +23,104 @@ def handle_data(context, data): prices = data.history( context.asset, fields='price', - bar_count=15, - frequency='1D' + bar_count=14, + frequency='15T' ) rsi = talib.RSI(prices.values, timeperiod=14)[-1] print('got rsi: {}'.format(rsi)) except Exception as e: print(e) + # If base_price is not set, we use the current value. This is the + # price at the first bar which we reference to calculate price_change. + if context.base_price is None: + context.base_price = price + + price_change = (price - context.base_price) / context.base_price + cash = context.portfolio.cash + + # Now that we've collected all current data for this frame, we use + # the record() method to save it. This data will be available as + # a parameter of the analyze() function for further analysis. 
+ record( + price=price, + price_change=price_change, + cash=cash + ) + + +def analyze(context, perf): + import matplotlib.pyplot as plt + print('the stats: {}'.format(get_pretty_stats(perf))) + + # The base currency of the algo exchange + base_currency = context.exchanges.values()[0].base_currency.upper() + + # Plot the portfolio value over time. + ax1 = plt.subplot(611) + perf.loc[:, 'portfolio_value'].plot(ax=ax1) + ax1.set_ylabel('Portfolio Value ({})'.format(base_currency)) + + # Plot the price increase or decrease over time. + ax2 = plt.subplot(612, sharex=ax1) + perf.loc[:, 'price'].plot(ax=ax2, label='Price') + + ax2.set_ylabel('{asset} ({base})'.format( + asset=context.asset.symbol, base=base_currency + )) + + transaction_df = extract_transactions(perf) + if not transaction_df.empty: + buy_df = transaction_df[transaction_df['amount'] > 0] + sell_df = transaction_df[transaction_df['amount'] < 0] + ax2.scatter( + buy_df.index.to_pydatetime(), + perf.loc[buy_df.index, 'price'], + marker='^', + s=100, + c='green', + label='' + ) + ax2.scatter( + sell_df.index.to_pydatetime(), + perf.loc[sell_df.index, 'price'], + marker='v', + s=100, + c='red', + label='' + ) + + ax4 = plt.subplot(613, sharex=ax1) + perf.loc[:, 'cash'].plot( + ax=ax4, label='Base Currency ({})'.format(base_currency) + ) + ax4.set_ylabel('Cash ({})'.format(base_currency)) + + perf['algorithm'] = perf.loc[:, 'algorithm_period_return'] + + ax5 = plt.subplot(614, sharex=ax1) + perf.loc[:, ['algorithm', 'price_change']].plot(ax=ax5) + ax5.set_ylabel('Percent Change') + + plt.legend(loc=3) + + # Show the plot. 
+ plt.gcf().set_size_inches(18, 8) + plt.show() + pass + run_algorithm( capital_base=250, - start=pd.to_datetime('2015-4-1', utc=True), - end=pd.to_datetime('2017-11-1', utc=True), + start=pd.to_datetime('2017-11-1 0:00', utc=True), + end=pd.to_datetime('2017-11-10 23:59', utc=True), data_frequency='daily', initialize=initialize, handle_data=handle_data, - analyze=None, - exchange_name='bittrex', + analyze=analyze, + exchange_name='bitfinex', algo_namespace='simple_loop', - base_currency='btc' + base_currency='usd' ) # run_algorithm( # initialize=initialize, diff --git a/catalyst/exchange/exchange_algorithm.py b/catalyst/exchange/exchange_algorithm.py index f05bee7e..bb965b58 100644 --- a/catalyst/exchange/exchange_algorithm.py +++ b/catalyst/exchange/exchange_algorithm.py @@ -245,19 +245,42 @@ class ExchangeTradingAlgorithmBacktest(ExchangeTradingAlgorithmBase): else: return MarketOrder() + def is_last_frame_of_day(self, data): + # TODO: adjust here to support more intervals + next_frame_dt = data.current_dt + timedelta(minutes=1) + if next_frame_dt.date() > data.current_dt.date(): + return True + else: + return False + def handle_data(self, data): super(ExchangeTradingAlgorithmBacktest, self).handle_data(data) - minute_stats = self.prepare_period_stats( - data.current_dt, data.current_dt + timedelta(minutes=1)) - self.frame_stats.append(minute_stats) + if self.data_frequency == 'minute': + frame_stats = self.prepare_period_stats( + data.current_dt, data.current_dt + timedelta(minutes=1) + ) + self.frame_stats.append(frame_stats) - def analyze(self, perf): + def _create_stats_df(self): stats = pd.DataFrame(self.frame_stats) stats.set_index('period_close', inplace=True, drop=False) + return stats + def analyze(self, perf): + stats = self._create_stats_df() if self.data_frequency == 'minute' \ + else perf super(ExchangeTradingAlgorithmBacktest, self).analyze(stats) + def run(self, data=None, overwrite_sim_params=True): + perf = 
super(ExchangeTradingAlgorithmBacktest, self).run( + data, overwrite_sim_params + ) + # Rebuilding the stats to support minute data + stats = self._create_stats_df() if self.data_frequency == 'minute' \ + else perf + return stats + class ExchangeTradingAlgorithmLive(ExchangeTradingAlgorithmBase): def __init__(self, *args, **kwargs): diff --git a/catalyst/exchange/exchange_bundle.py b/catalyst/exchange/exchange_bundle.py index d6e1d558..d4ffffa0 100644 --- a/catalyst/exchange/exchange_bundle.py +++ b/catalyst/exchange/exchange_bundle.py @@ -26,7 +26,7 @@ from catalyst.exchange.exchange_bcolz import BcolzExchangeBarReader, \ from catalyst.exchange.exchange_errors import EmptyValuesInBundleError, \ TempBundleNotFoundError, \ NoDataAvailableOnExchange, \ - PricingDataNotLoadedError + PricingDataNotLoadedError, DataCorruptionError from catalyst.exchange.exchange_utils import get_exchange_folder from catalyst.utils.cli import maybe_show_progress from catalyst.utils.paths import ensure_directory @@ -881,6 +881,14 @@ class ExchangeBundle: start_dt=start_dt, end_dt=end_dt ) + if len(arrays) == 0: + raise DataCorruptionError( + exchange=self.exchange.name, + symbols=asset.symbol, + start_dt=asset_start_dt, + end_dt=asset_end_dt + ) + field_values = arrays[0][:, 0] value_series = pd.Series(field_values, index=periods) diff --git a/catalyst/exchange/exchange_errors.py b/catalyst/exchange/exchange_errors.py index 35f320d8..a36bb23a 100644 --- a/catalyst/exchange/exchange_errors.py +++ b/catalyst/exchange/exchange_errors.py @@ -218,6 +218,15 @@ class PricingDataNotLoadedError(ZiplineError): 'for details.').strip() +class DataCorruptionError(ZiplineError): + msg = ('Unable to validate data for {exchange} {symbols} in date range ' + '[{start_dt} - {end_dt}]. The data is either corrupted or ' + 'unavailable. Please try deleting this bundle:' + '\n`catalyst clean-exchange -x {exchange}\n' + 'Then, ingest the data again. 
Please contact the Catalyst team if ' + 'the issue persists.').strip() + + class ApiCandlesError(ZiplineError): msg = ('Unable to fetch candles from the remote API: {error}.').strip()