mirror of
https://github.com/wassname/rl-portfolio-management.git
synced 2026-06-27 18:06:31 +08:00
add env wrappers
This commit is contained in:
@@ -0,0 +1,2 @@
|
||||
from .concat_states import ConcatStates
|
||||
from .softmax_actions import SoftmaxActions
|
||||
@@ -0,0 +1,40 @@
|
||||
import gym.spaces
|
||||
import gym.wrappers
|
||||
import numpy as np
|
||||
|
||||
|
||||
def concat_states(state):
    """
    Merge the "history" and "weights" entries of a state dict into one array.

    All but the first weight are taken (one per leading-axis entry of the
    history), each is repeated across the last history axis, and the result
    is appended as one extra row along axis 1 of the history.

    Args:
        state: mapping with a 3-D "history" array and a "weights" array whose
            length is one more than `history.shape[0]`.

    Returns:
        Array of shape (history.shape[0], history.shape[1] + 1,
        history.shape[2]).
    """
    hist = state["history"]
    n_rows, _, n_cols = hist.shape

    # Skip the first weight, then add two trailing axes so each remaining
    # weight broadcasts over a (1, n_cols) slab.
    # NOTE(review): weights[0] is presumably the cash position - confirm.
    per_row_weights = state["weights"][1:][:, np.newaxis, np.newaxis]
    extra_row = np.ones((n_rows, 1, n_cols)) * per_row_weights

    return np.concatenate([hist, extra_row], axis=1)
|
||||
|
||||
|
||||
class ConcatStates(gym.Wrapper):
    """
    Concat both state arrays for models that take a single input.

    The wrapped env must expose a dict observation space with a 3-D
    "history" entry and emit dict observations holding "history" and
    "weights". Every observation handed out by this wrapper (from both
    `reset` and `step`) is the single array built by `concat_states`,
    i.e. the history with one extra row along axis 1.

    Usage:
        env = ConcatStates(env)

    Ref: https://github.com/openai/gym/blob/master/gym/wrappers/README.md
    """

    def __init__(self, env):
        super().__init__(env)
        hist_space = self.observation_space.spaces["history"]
        hist_shape = hist_space.shape
        # Same shape as the history plus one row on axis 1 for the weights.
        # NOTE(review): the [-10, 10] bounds are hard-coded - confirm they
        # cover the real range of the history values.
        self.observation_space = gym.spaces.Box(-10, 10, shape=(
            hist_shape[0], hist_shape[1] + 1, hist_shape[2]))

    def reset(self, *args, **kwargs):
        """
        Reset the wrapped env and return the concatenated initial state.

        Without this override the first observation of each episode would
        be the raw dict, which does not match `observation_space`.
        Assumes the wrapped env's reset() returns the same dict layout as
        step() - TODO confirm.
        """
        state = self.env.reset(*args, **kwargs)
        return concat_states(state)

    def step(self, action):
        """Step the wrapped env and return the concatenated state."""
        state, reward, done, info = self.env.step(action)

        # concat the two state arrays, since some models only take a
        # single input
        state = concat_states(state)

        return state, reward, done, info
|
||||
@@ -0,0 +1,30 @@
|
||||
import gym.wrappers
|
||||
|
||||
from ..util import softmax
|
||||
|
||||
|
||||
class SoftmaxActions(gym.Wrapper):
    """
    Environment wrapper that applies a softmax to every action.

    Usage:
        env = gym.make('Pong-v0')
        env = SoftmaxActions(env)

    Ref: https://github.com/openai/gym/blob/master/gym/wrappers/README.md
    """

    def step(self, action):
        """
        Normalise `action` with softmax (t=1) and step the wrapped env.

        A list action is unwrapped to its first element; a dict action is
        flattened to its values ordered by sorted key.
        """
        # Some callers deliver the action wrapped in a list.
        if isinstance(action, list):
            action = action[0]

        # Flatten dict actions into a list with a deterministic key order.
        if isinstance(action, dict):
            action = [action[key] for key in sorted(action)]

        normalised = softmax(action, t=1)

        observation, reward, done, info = self.env.step(normalised)
        return observation, reward, done, info
|
||||
Reference in New Issue
Block a user