mirror of
https://github.com/wassname/rl-portfolio-management.git
synced 2026-06-27 18:06:31 +08:00
fix plotting
This commit is contained in:
@@ -173,9 +173,9 @@ class PortfolioSim(object):
|
||||
"cost": mu1,
|
||||
}
|
||||
# record weights and prices
|
||||
for i in range(len(self.asset_names)):
|
||||
info['weight_' + self.asset_names[i]] = w1[i]
|
||||
info['price_' + self.asset_names[i]] = y1[i]
|
||||
for i, name in enumerate(['BTCBTC'] + self.asset_names):
|
||||
info['weight_' + name] = w1[i]
|
||||
info['price_' + name] = y1[i]
|
||||
|
||||
self.infos.append(info)
|
||||
return reward, info, done
|
||||
@@ -352,22 +352,23 @@ class PortfolioEnv(gym.Env):
|
||||
df_info.index = pd.to_datetime(df_info["date"], unit='s')
|
||||
|
||||
# plot prices and performance
|
||||
all_assets = ['BTCBTC'] + self.sim.asset_names
|
||||
if not self._plot:
|
||||
self._plot_dir = os.path.join(self.log_dir, 'notebook_plot_prices_' + str(time.time())) if self.log_dir else None
|
||||
self._plot = LivePlotNotebook(
|
||||
log_dir=self._plot_dir, title='prices & performance', labels=self.sim.asset_names + ["Portfolio"], ylabel='value')
|
||||
log_dir=self._plot_dir, title='prices & performance', labels=all_assets + ["Portfolio"], ylabel='value')
|
||||
x = df_info.index
|
||||
y_portfolio = df_info["portfolio_value"]
|
||||
y_assets = [df_info['price_' + name].cumprod()
|
||||
for name in self.sim.asset_names]
|
||||
for name in all_assets]
|
||||
self._plot.update(x, y_assets + [y_portfolio])
|
||||
|
||||
# plot portfolio weights
|
||||
if not self._plot2:
|
||||
self._plot_dir2 = os.path.join(self.log_dir, 'notebook_plot_weights_' + str(time.time())) if self.log_dir else None
|
||||
self._plot2 = LivePlotNotebook(
|
||||
log_dir=self._plot_dir2, labels=self.sim.asset_names, title='weights', ylabel='weight')
|
||||
ys = [df_info['weight_' + name] for name in self.sim.asset_names]
|
||||
log_dir=self._plot_dir2, labels=all_assets, title='weights', ylabel='weight')
|
||||
ys = [df_info['weight_' + name] for name in all_assets]
|
||||
self._plot2.update(x, ys)
|
||||
|
||||
# plot portfolio costs
|
||||
|
||||
Reference in New Issue
Block a user