mirror of
https://github.com/wassname/rl-portfolio-management.git
synced 2026-06-27 18:06:31 +08:00
improve plotting of prices
This commit is contained in:
@@ -48,8 +48,10 @@ class LivePlotNotebook(object):
|
||||
# update limits
|
||||
y = np.concatenate(ys)
|
||||
y_extra = y.std() * 0.1
|
||||
self.ax.set_xlim(x.min(), x.max())
|
||||
self.ax.set_ylim(y.min() - y_extra, y.max() + y_extra)
|
||||
if x.min() != x.max():
|
||||
self.ax.set_xlim(x.min(), x.max())
|
||||
if (y.min() - y_extra) != (y.max() + y_extra):
|
||||
self.ax.set_ylim(y.min() - y_extra, y.max() + y_extra)
|
||||
|
||||
if self.log_dir:
|
||||
self.fig.savefig(os.path.join(
|
||||
|
||||
@@ -318,18 +318,19 @@ class PortfolioEnv(gym.Env):
|
||||
|
||||
if close:
|
||||
self._plot = self._plot2 = self._plot3 = None
|
||||
if not self._plot:
|
||||
self._plot = LivePlotNotebook(
|
||||
'/tmp', title='performance', labels=["buy & hold", "portfolio_value"])
|
||||
|
||||
# show a plot of portfolio vs mean market performance
|
||||
df_info = pd.DataFrame(self.infos)
|
||||
df_info.index = pd.to_datetime(df_info["date"], unit='s')
|
||||
|
||||
# plot prices and performance
|
||||
if not self._plot:
|
||||
self._plot = LivePlotNotebook(
|
||||
'/tmp', title='prices & performance', labels=self.sim.asset_names + ["Portfolio"])
|
||||
x = df_info.index
|
||||
y1 = df_info["market_value"]
|
||||
y2 = df_info["portfolio_value"]
|
||||
self._plot.update(x, [y1, y2])
|
||||
y_portfolio = df_info["portfolio_value"]
|
||||
y_assets = [df_info['price_' + name].cumprod()
|
||||
for name in self.sim.asset_names]
|
||||
self._plot.update(x, y_assets + [y_portfolio])
|
||||
|
||||
# plot portfolio weights
|
||||
if not self._plot2:
|
||||
@@ -338,12 +339,5 @@ class PortfolioEnv(gym.Env):
|
||||
ys = [df_info['weight_' + name] for name in self.sim.asset_names]
|
||||
self._plot2.update(x, ys)
|
||||
|
||||
# plot portfolio prices
|
||||
if not self._plot3:
|
||||
self._plot3 = LivePlotNotebook(
|
||||
'/tmp', labels=self.sim.asset_names, title='price changes')
|
||||
ys = [df_info['price_' + name] for name in self.sim.asset_names]
|
||||
self._plot3.update(x, ys)
|
||||
|
||||
if close:
|
||||
self._plot = self._plot2 = self._plot3 = None
|
||||
|
||||
Reference in New Issue
Block a user