transpose wrapper

This commit is contained in:
wassname
2017-11-12 14:41:15 +08:00
parent e9972f9400
commit ae95c65246
2 changed files with 19 additions and 3 deletions
@@ -1,16 +1,31 @@
import gym.wrappers
import gym.spaces
import numpy as np
class TransposeHistory(gym.Wrapper):
"""Transpose history."""
def __init__(self, env, axes=(2, 1, 0)):
super().__init__(env)
self.axes = axes
hist_space = self.observation_space.spaces["history"]
hist_shape = hist_space.shape
self.observation_space = gym.spaces.Dict({
'history': gym.spaces.Box(
hist_space.low.min(),
hist_space.high.max(),
(hist_shape[axes[0]], hist_shape[axes[1]], hist_shape[axes[2]])
),
'weights': self.observation_space.spaces["weights"]
})
def step(self, action):
state, reward, done, info = self.env.step(action)
state["history"] = np.transpose(state["history"], (2, 1, 0))
state["history"] = np.transpose(state["history"], self.axes)
return state, reward, done, info
def reset(self):
state = self.env.reset()
state["history"] = np.transpose(state["history"], (2, 1, 0))
state["history"] = np.transpose(state["history"], self.axes)
return state
+1
View File
@@ -33,6 +33,7 @@ def test_transpose():
env = gym.make("CryptoPortfolioEIIE-v0")
env = TransposeHistory(env)
obs = env.reset()
env.observation_space.contains(obs)
assert obs["history"].shape == transposed_shape
# should be no problem with actions that don't sum to one
action = env.action_space.sample()