mirror of
https://github.com/wassname/rl-portfolio-management.git
synced 2026-06-27 18:06:31 +08:00
plot axis
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
import os
|
||||
import numpy as np
|
||||
from matplotlib import pyplot as plt
|
||||
import matplotlib
|
||||
import logging
|
||||
|
||||
|
||||
class LivePlotNotebook(object):
|
||||
@@ -15,8 +17,10 @@ class LivePlotNotebook(object):
|
||||
liveplot.update(x, [ya,yb])
|
||||
"""
|
||||
|
||||
def __init__(self, log_dir=None, episode=0, labels=[], title=''):
|
||||
# TODO check warn if note using right matplotlib backend
|
||||
def __init__(self, log_dir=None, episode=0, labels=[], title='', ylabel='returns'):
|
||||
if not matplotlib.rcParams['backend'] == 'nbAgg':
|
||||
logging.warn("The liveplot callback only work when matplotlib is using the nbAgg backend. Execute 'matplotlib.use('nbAgg', force=True)'' or '%matplotlib notebook'")
|
||||
|
||||
self.log_dir = log_dir
|
||||
self.i = episode
|
||||
|
||||
|
||||
@@ -325,7 +325,7 @@ class PortfolioEnv(gym.Env):
|
||||
# plot prices and performance
|
||||
if not self._plot:
|
||||
self._plot = LivePlotNotebook(
|
||||
'/tmp', title='prices & performance', labels=self.sim.asset_names + ["Portfolio"])
|
||||
'/tmp', title='prices & performance', labels=self.sim.asset_names + ["Portfolio"], ylabel='value')
|
||||
x = df_info.index
|
||||
y_portfolio = df_info["portfolio_value"]
|
||||
y_assets = [df_info['price_' + name].cumprod()
|
||||
@@ -335,7 +335,7 @@ class PortfolioEnv(gym.Env):
|
||||
# plot portfolio weights
|
||||
if not self._plot2:
|
||||
self._plot2 = LivePlotNotebook(
|
||||
'/tmp', labels=self.sim.asset_names, title='weights')
|
||||
'/tmp', labels=self.sim.asset_names, title='weights', ylabel='weight')
|
||||
ys = [df_info['weight_' + name] for name in self.sim.asset_names]
|
||||
self._plot2.update(x, ys)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user