diff --git a/rl_portfolio_management/callbacks/notebook_plot.py b/rl_portfolio_management/callbacks/notebook_plot.py index 94db233..b39c12b 100644 --- a/rl_portfolio_management/callbacks/notebook_plot.py +++ b/rl_portfolio_management/callbacks/notebook_plot.py @@ -1,6 +1,8 @@ import os import numpy as np from matplotlib import pyplot as plt +import matplotlib +import logging class LivePlotNotebook(object): @@ -15,8 +17,10 @@ class LivePlotNotebook(object): liveplot.update(x, [ya,yb]) """ - def __init__(self, log_dir=None, episode=0, labels=[], title=''): - # TODO check warn if note using right matplotlib backend + def __init__(self, log_dir=None, episode=0, labels=[], title='', ylabel='returns'): + if not matplotlib.rcParams['backend'] == 'nbAgg': + logging.warn("The liveplot callback only work when matplotlib is using the nbAgg backend. Execute 'matplotlib.use('nbAgg', force=True)'' or '%matplotlib notebook'") + self.log_dir = log_dir self.i = episode diff --git a/rl_portfolio_management/environments/portfolio.py b/rl_portfolio_management/environments/portfolio.py index 284b0c2..870947f 100644 --- a/rl_portfolio_management/environments/portfolio.py +++ b/rl_portfolio_management/environments/portfolio.py @@ -325,7 +325,7 @@ class PortfolioEnv(gym.Env): # plot prices and performance if not self._plot: self._plot = LivePlotNotebook( - '/tmp', title='prices & performance', labels=self.sim.asset_names + ["Portfolio"]) + '/tmp', title='prices & performance', labels=self.sim.asset_names + ["Portfolio"], ylabel='value') x = df_info.index y_portfolio = df_info["portfolio_value"] y_assets = [df_info['price_' + name].cumprod() @@ -335,7 +335,7 @@ class PortfolioEnv(gym.Env): # plot portfolio weights if not self._plot2: self._plot2 = LivePlotNotebook( - '/tmp', labels=self.sim.asset_names, title='weights') + '/tmp', labels=self.sim.asset_names, title='weights', ylabel='weight') ys = [df_info['weight_' + name] for name in self.sim.asset_names] self._plot2.update(x, ys)