diff --git a/src/environments/portfolio.py b/src/environments/portfolio.py index 5bb2635..dfef598 100644 --- a/src/environments/portfolio.py +++ b/src/environments/portfolio.py @@ -247,15 +247,13 @@ class PortfolioEnv(gym.Env): """ logger.debug('action: %s', action) - # normalise just in case - # weights = softmax(action) weights = np.clip(action, 0.0, 1.0) - weights[0] = 1.0 - sum(weights[1:]) # Sanity checks np.testing.assert_almost_equal( action.shape, - (len(self.sim.asset_names),) + (len(self.sim.asset_names),), + err_msg='Action should contain %s floats, not %s'%(len(self.sim.asset_names), action.shape) ) assert ((action >= 0) * (action <= 1) ).all(), 'all action values should be between 0 and 1. Not %s' % action @@ -345,3 +343,6 @@ class PortfolioEnv(gym.Env): '/tmp', labels=self.sim.asset_names, title='price changes') ys = [df_info['price_' + name] for name in self.sim.asset_names] self._plot3.update(x, ys) + + if close: + self._plot = self._plot2 = self._plot3 = None