mirror of
https://github.com/wassname/rl-portfolio-management.git
synced 2026-06-27 16:46:41 +08:00
transpose wrapper
This commit is contained in:
@@ -1,16 +1,31 @@
|
||||
import gym.wrappers
|
||||
import gym.spaces
|
||||
import numpy as np
|
||||
|
||||
|
||||
class TransposeHistory(gym.Wrapper):
    """Permute the axes of the ``"history"`` entry of every observation.

    The wrapped environment must expose a ``gym.spaces.Dict`` observation
    space with ``"history"`` and ``"weights"`` entries.  The ``"history"``
    array returned by ``reset`` and ``step`` is transposed with ``axes``, and
    the advertised observation space is rebuilt so its ``"history"`` shape
    matches the transposed arrays.  ``"weights"`` is passed through untouched.

    Args:
        env: Environment to wrap.
        axes: Permutation handed to ``np.transpose``; entry ``i`` names the
            source axis that becomes axis ``i`` of the result.  The default
            ``(2, 1, 0)`` reverses the order of three axes.
    """

    def __init__(self, env, axes=(2, 1, 0)):
        super().__init__(env)
        self.axes = axes

        hist_space = self.observation_space.spaces["history"]
        hist_shape = hist_space.shape
        self.observation_space = gym.spaces.Dict({
            # The bounds are collapsed to scalars (global min / max of the
            # original bounds) and broadcast over the permuted shape.
            'history': gym.spaces.Box(
                hist_space.low.min(),
                hist_space.high.max(),
                # Built from ``axes`` itself so any rank is supported,
                # not just three-dimensional histories.
                shape=tuple(hist_shape[axis] for axis in axes),
            ),
            'weights': self.observation_space.spaces["weights"],
        })

    def _transpose_history(self, state):
        """Transpose ``state["history"]`` in place of the old entry; return ``state``."""
        # Apply the permutation exactly once: transposing with a hard-coded
        # (2, 1, 0) and then again with ``self.axes`` would undo itself for
        # the default axes and disagree with ``observation_space``.
        state["history"] = np.transpose(state["history"], self.axes)
        return state

    def step(self, action):
        """Step the wrapped env and transpose the returned history.

        Returns:
            The ``(state, reward, done, info)`` tuple of the wrapped env, with
            ``state["history"]`` transposed by ``self.axes``.
        """
        state, reward, done, info = self.env.step(action)
        return self._transpose_history(state), reward, done, info

    def reset(self):
        """Reset the wrapped env and return its state with history transposed."""
        return self._transpose_history(self.env.reset())
|
||||
|
||||
@@ -33,6 +33,7 @@ def test_transpose():
|
||||
env = gym.make("CryptoPortfolioEIIE-v0")
|
||||
env = TransposeHistory(env)
|
||||
obs = env.reset()
|
||||
env.observation_space.contains(obs)
|
||||
assert obs["history"].shape == transposed_shape
|
||||
# should be no problem with actions that don't sum to one
|
||||
action = env.action_space.sample()
|
||||
|
||||
Reference in New Issue
Block a user