mirror of
https://github.com/wassname/rl-portfolio-management.git
synced 2026-06-27 16:46:41 +08:00
wrappers
This commit is contained in:
@@ -1,2 +1,3 @@
|
||||
from .concat_states import ConcatStates
|
||||
from .softmax_actions import SoftmaxActions
|
||||
from .transpose_history import TransposeHistory
|
||||
|
||||
@@ -16,7 +16,7 @@ class SoftmaxActions(gym.Wrapper):
|
||||
"""
|
||||
|
||||
def step(self, action):
|
||||
# also it puts it in a list
|
||||
# also it puts it in a list
|
||||
if isinstance(action, list):
|
||||
action = action[0]
|
||||
|
||||
@@ -25,6 +25,4 @@ class SoftmaxActions(gym.Wrapper):
|
||||
|
||||
action = softmax(action, t=1)
|
||||
|
||||
observation, reward, done, info = self.env.step(action)
|
||||
|
||||
return observation, reward, done, info
|
||||
return self.env.step(action)
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
import gym.wrappers
|
||||
import numpy as np
|
||||
|
||||
|
||||
class TransposeHistory(gym.Wrapper):
    """Wrapper that reverses the axis order of the ``"history"`` observation.

    The observation returned by the wrapped environment is indexed with the
    key ``"history"``; that entry is replaced by
    ``np.transpose(history, (2, 1, 0))`` on both ``reset`` and ``step``.
    The observation object itself is modified in place and every other
    value is passed through untouched.
    """

    @staticmethod
    def _flip_history(observation):
        """Transpose ``observation["history"]`` in place; return the observation."""
        observation["history"] = np.transpose(
            observation["history"], axes=(2, 1, 0)
        )
        return observation

    def step(self, action):
        """Step the wrapped env and return its 4-tuple with history transposed."""
        observation, reward, done, info = self.env.step(action)
        return self._flip_history(observation), reward, done, info

    def reset(self):
        """Reset the wrapped env and return its observation with history transposed."""
        return self._flip_history(self.env.reset())
|
||||
+19
-2
@@ -1,7 +1,8 @@
|
||||
import gym
|
||||
import numpy as np
|
||||
|
||||
from rl_portfolio_management.environments import PortfolioEnv, env_specs
|
||||
from rl_portfolio_management.wrappers import ConcatStates, SoftmaxActions
|
||||
from rl_portfolio_management.environments import PortfolioEnv
|
||||
from rl_portfolio_management.wrappers import ConcatStates, SoftmaxActions, TransposeHistory
|
||||
|
||||
|
||||
def test_concat():
|
||||
@@ -22,3 +23,19 @@ def test_softmax():
|
||||
# should be no problem with actions that don't sum to one
|
||||
action = env.action_space.sample() * 100
|
||||
obs, rew, done, info = env.step(action)
|
||||
|
||||
|
||||
def test_transpose():
    """Check that TransposeHistory yields the transposed history shape."""
    # Reference shape, computed from an unwrapped environment's first observation.
    baseline_env = gym.make("CryptoPortfolioEIIE-v0")
    baseline_obs = baseline_env.reset()
    expected_shape = np.transpose(baseline_obs["history"], (2, 1, 0)).shape

    wrapped_env = TransposeHistory(gym.make("CryptoPortfolioEIIE-v0"))

    first_obs = wrapped_env.reset()
    assert first_obs["history"].shape == expected_shape

    # Sample an action and normalise it in place so that it sums to one.
    weights = wrapped_env.action_space.sample()
    weights /= weights.sum()
    step_obs, _reward, _done, _info = wrapped_env.step(weights)
    assert step_obs["history"].shape == expected_shape
|
||||
|
||||
Reference in New Issue
Block a user