mirror of
https://github.com/wassname/rl-portfolio-management.git
synced 2026-06-27 16:46:41 +08:00
better plotting with lots of assets
This commit is contained in:
@@ -17,7 +17,7 @@ class LivePlotNotebook(object):
|
||||
liveplot.update(x, [ya,yb])
|
||||
"""
|
||||
|
||||
def __init__(self, log_dir=None, episode=0, labels=[], title='', ylabel='returns'):
|
||||
def __init__(self, log_dir=None, episode=0, labels=[], title='', ylabel='returns', colors=None, linestyles=None):
|
||||
if not matplotlib.rcParams['backend'] == 'nbAgg':
|
||||
logging.warn("The liveplot callback only work when matplotlib is using the nbAgg backend. Execute 'matplotlib.use('nbAgg', force=True)'' or '%matplotlib notebook'")
|
||||
|
||||
@@ -32,7 +32,14 @@ class LivePlotNotebook(object):
|
||||
fig, ax = plt.subplots(1, 1)
|
||||
|
||||
for i in range(len(labels)):
|
||||
ax.plot([0] * 20, label=labels[i], alpha=0.75, lw=2)
|
||||
ax.plot(
|
||||
[0] * 20,
|
||||
label=labels[i],
|
||||
alpha=0.75,
|
||||
lw=2,
|
||||
color=colors[i] if colors else None,
|
||||
linestyle=linestyles[i] if linestyles else None,
|
||||
)
|
||||
|
||||
ax.set_xlim(0, 1)
|
||||
ax.set_ylim(0, 1)
|
||||
|
||||
@@ -362,16 +362,18 @@ class PortfolioEnv(gym.Env):
|
||||
# plot prices and performance
|
||||
all_assets = ['BTCBTC'] + self.sim.asset_names
|
||||
if not self._plot:
|
||||
colors = [None] * len(all_assets) + ['black']
|
||||
self._plot_dir = os.path.join(
|
||||
self.log_dir, 'notebook_plot_prices_' + str(time.time())) if self.log_dir else None
|
||||
self._plot = LivePlotNotebook(
|
||||
log_dir=self._plot_dir, title='prices & performance', labels=all_assets + ["Portfolio"], ylabel='value')
|
||||
log_dir=self._plot_dir, title='prices & performance', labels=all_assets + ["Portfolio"], ylabel='value', colors=colors)
|
||||
x = df_info.index
|
||||
y_portfolio = df_info["portfolio_value"]
|
||||
y_assets = [df_info['price_' + name].cumprod()
|
||||
for name in all_assets]
|
||||
self._plot.update(x, y_assets + [y_portfolio])
|
||||
|
||||
|
||||
# plot portfolio weights
|
||||
if not self._plot2:
|
||||
self._plot_dir2 = os.path.join(
|
||||
|
||||
Reference in New Issue
Block a user