mirror of
https://github.com/wassname/rl-portfolio-management.git
synced 2026-06-27 16:46:41 +08:00
error message, don't adjust action
This commit is contained in:
@@ -247,15 +247,13 @@ class PortfolioEnv(gym.Env):
|
||||
"""
|
||||
logger.debug('action: %s', action)
|
||||
|
||||
# normalise just in case
|
||||
# weights = softmax(action)
|
||||
weights = np.clip(action, 0.0, 1.0)
|
||||
weights[0] = 1.0 - sum(weights[1:])
|
||||
|
||||
# Sanity checks
|
||||
np.testing.assert_almost_equal(
|
||||
action.shape,
|
||||
(len(self.sim.asset_names),)
|
||||
(len(self.sim.asset_names),),
|
||||
err_msg='Action should contain %s floats, not %s'%(len(self.sim.asset_names), action.shape)
|
||||
)
|
||||
assert ((action >= 0) * (action <= 1)
|
||||
).all(), 'all action values should be between 0 and 1. Not %s' % action
|
||||
@@ -345,3 +343,6 @@ class PortfolioEnv(gym.Env):
|
||||
'/tmp', labels=self.sim.asset_names, title='price changes')
|
||||
ys = [df_info['price_' + name] for name in self.sim.asset_names]
|
||||
self._plot3.update(x, ys)
|
||||
|
||||
if close:
|
||||
self._plot = self._plot2 = self._plot3 = None
|
||||
|
||||
Reference in New Issue
Block a user