diff --git a/rl_portfolio_management/wrappers/__init__.py b/rl_portfolio_management/wrappers/__init__.py index 4989253..c735312 100644 --- a/rl_portfolio_management/wrappers/__init__.py +++ b/rl_portfolio_management/wrappers/__init__.py @@ -1,2 +1,3 @@ from .concat_states import ConcatStates from .softmax_actions import SoftmaxActions +from .transpose_history import TransposeHistory diff --git a/rl_portfolio_management/wrappers/softmax_actions.py b/rl_portfolio_management/wrappers/softmax_actions.py index ee8d0cc..6325f2e 100644 --- a/rl_portfolio_management/wrappers/softmax_actions.py +++ b/rl_portfolio_management/wrappers/softmax_actions.py @@ -16,7 +16,7 @@ class SoftmaxActions(gym.Wrapper): """ def step(self, action): - # also it puts it in a list + # also it puts it in a list if isinstance(action, list): action = action[0] @@ -25,6 +25,4 @@ class SoftmaxActions(gym.Wrapper): action = softmax(action, t=1) - observation, reward, done, info = self.env.step(action) - - return observation, reward, done, info + return self.env.step(action) diff --git a/rl_portfolio_management/wrappers/transpose_history.py b/rl_portfolio_management/wrappers/transpose_history.py new file mode 100644 index 0000000..a046837 --- /dev/null +++ b/rl_portfolio_management/wrappers/transpose_history.py @@ -0,0 +1,16 @@ +import gym.wrappers +import numpy as np + + +class TransposeHistory(gym.Wrapper): + """Transpose history.""" + + def step(self, action): + state, reward, done, info = self.env.step(action) + state["history"] = np.transpose(state["history"], (2, 1, 0)) + return state, reward, done, info + + def reset(self): + state = self.env.reset() + state["history"] = np.transpose(state["history"], (2, 1, 0)) + return state diff --git a/test/test_wrappers.py b/test/test_wrappers.py index 460efd5..1c8445f 100644 --- a/test/test_wrappers.py +++ b/test/test_wrappers.py @@ -1,7 +1,8 @@ import gym +import numpy as np -from rl_portfolio_management.environments import PortfolioEnv, env_specs -from rl_portfolio_management.wrappers import ConcatStates, SoftmaxActions +from rl_portfolio_management.environments import PortfolioEnv +from rl_portfolio_management.wrappers import ConcatStates, SoftmaxActions, TransposeHistory def test_concat(): @@ -22,3 +23,19 @@ def test_softmax(): # should be no problem with actions that don't sum to one action = env.action_space.sample() * 100 obs, rew, done, info = env.step(action) + + +def test_transpose(): + env0 = gym.make("CryptoPortfolioEIIE-v0") + obs0 = env0.reset() + transposed_shape = np.transpose(obs0["history"], (2, 1, 0)).shape + + env = gym.make("CryptoPortfolioEIIE-v0") + env = TransposeHistory(env) + obs = env.reset() + assert obs["history"].shape == transposed_shape + # should be no problem with actions that don't sum to one + action = env.action_space.sample() + action /= action.sum() + obs, rew, done, info = env.step(action) + assert obs["history"].shape == transposed_shape