This commit is contained in:
wassname
2017-11-12 14:24:24 +08:00
parent 4b3701ceb6
commit e9972f9400
4 changed files with 38 additions and 6 deletions
@@ -1,2 +1,3 @@
from .concat_states import ConcatStates
from .softmax_actions import SoftmaxActions
from .transpose_history import TransposeHistory
@@ -16,7 +16,7 @@ class SoftmaxActions(gym.Wrapper):
"""
def step(self, action):
# also it puts it in a list
# also it puts it in a list
if isinstance(action, list):
action = action[0]
@@ -25,6 +25,4 @@ class SoftmaxActions(gym.Wrapper):
action = softmax(action, t=1)
observation, reward, done, info = self.env.step(action)
return observation, reward, done, info
return self.env.step(action)
@@ -0,0 +1,16 @@
import gym.wrappers
import numpy as np
class TransposeHistory(gym.Wrapper):
"""Transpose history."""
def step(self, action):
state, reward, done, info = self.env.step(action)
state["history"] = np.transpose(state["history"], (2, 1, 0))
return state, reward, done, info
def reset(self):
state = self.env.reset()
state["history"] = np.transpose(state["history"], (2, 1, 0))
return state
+19 -2
View File
@@ -1,7 +1,8 @@
import gym
import numpy as np
from rl_portfolio_management.environments import PortfolioEnv, env_specs
from rl_portfolio_management.wrappers import ConcatStates, SoftmaxActions
from rl_portfolio_management.environments import PortfolioEnv
from rl_portfolio_management.wrappers import ConcatStates, SoftmaxActions, TransposeHistory
def test_concat():
@@ -22,3 +23,19 @@ def test_softmax():
# should be no problem with actions that don't sum to one
action = env.action_space.sample() * 100
obs, rew, done, info = env.step(action)
def test_transpose():
env0 = gym.make("CryptoPortfolioEIIE-v0")
obs0 = env0.reset()
transposed_shape = np.transpose(obs0["history"], (2, 1, 0)).shape
env = gym.make("CryptoPortfolioEIIE-v0")
env = TransposeHistory(env)
obs = env.reset()
assert obs["history"].shape == transposed_shape
# should be no problem with actions that don't sum to one
action = env.action_space.sample()
action /= action.sum()
obs, rew, done, info = env.step(action)
assert obs["history"].shape == transposed_shape