improve plotting of prices

This commit is contained in:
wassname
2017-10-30 15:08:12 +08:00
parent aa2cbe8d58
commit 5540aa3243
2 changed files with 12 additions and 16 deletions
@@ -48,8 +48,10 @@ class LivePlotNotebook(object):
# update limits
y = np.concatenate(ys)
y_extra = y.std() * 0.1
self.ax.set_xlim(x.min(), x.max())
self.ax.set_ylim(y.min() - y_extra, y.max() + y_extra)
if x.min() != x.max():
self.ax.set_xlim(x.min(), x.max())
if (y.min() - y_extra) != (y.max() + y_extra):
self.ax.set_ylim(y.min() - y_extra, y.max() + y_extra)
if self.log_dir:
self.fig.savefig(os.path.join(
@@ -318,18 +318,19 @@ class PortfolioEnv(gym.Env):
if close:
self._plot = self._plot2 = self._plot3 = None
if not self._plot:
self._plot = LivePlotNotebook(
'/tmp', title='performance', labels=["buy & hold", "portfolio_value"])
# show a plot of portfolio vs mean market performance
df_info = pd.DataFrame(self.infos)
df_info.index = pd.to_datetime(df_info["date"], unit='s')
# plot prices and performance
if not self._plot:
self._plot = LivePlotNotebook(
'/tmp', title='prices & performance', labels=self.sim.asset_names + ["Portfolio"])
x = df_info.index
y1 = df_info["market_value"]
y2 = df_info["portfolio_value"]
self._plot.update(x, [y1, y2])
y_portfolio = df_info["portfolio_value"]
y_assets = [df_info['price_' + name].cumprod()
for name in self.sim.asset_names]
self._plot.update(x, y_assets + [y_portfolio])
# plot portfolio weights
if not self._plot2:
@@ -338,12 +339,5 @@ class PortfolioEnv(gym.Env):
ys = [df_info['weight_' + name] for name in self.sim.asset_names]
self._plot2.update(x, ys)
# plot portfolio prices
if not self._plot3:
self._plot3 = LivePlotNotebook(
'/tmp', labels=self.sim.asset_names, title='price changes')
ys = [df_info['price_' + name] for name in self.sim.asset_names]
self._plot3.update(x, ys)
if close:
self._plot = self._plot2 = self._plot3 = None