From 94191b2a2ca83e86a618ad40e87f3338e93d97bf Mon Sep 17 00:00:00 2001
From: wassname
Date: Sun, 12 Nov 2017 14:05:40 +0800
Subject: [PATCH] add env wrappers

---
 rl_portfolio_management/wrappers/__init__.py  |  2 +
 .../wrappers/concat_states.py                 | 56 +++++++++++++++++++
 .../wrappers/softmax_actions.py               | 31 ++++++++++
 3 files changed, 89 insertions(+)
 create mode 100644 rl_portfolio_management/wrappers/__init__.py
 create mode 100644 rl_portfolio_management/wrappers/concat_states.py
 create mode 100644 rl_portfolio_management/wrappers/softmax_actions.py

diff --git a/rl_portfolio_management/wrappers/__init__.py b/rl_portfolio_management/wrappers/__init__.py
new file mode 100644
index 0000000..4989253
--- /dev/null
+++ b/rl_portfolio_management/wrappers/__init__.py
@@ -0,0 +1,2 @@
+from .concat_states import ConcatStates
+from .softmax_actions import SoftmaxActions
diff --git a/rl_portfolio_management/wrappers/concat_states.py b/rl_portfolio_management/wrappers/concat_states.py
new file mode 100644
index 0000000..e8ecf52
--- /dev/null
+++ b/rl_portfolio_management/wrappers/concat_states.py
@@ -0,0 +1,56 @@
+import gym.spaces
+import gym.wrappers
+import numpy as np
+
+
+def concat_states(state):
+    """
+    Merge the "history" and "weights" entries of a state dict into one array.
+
+    `state["history"]` is indexed as a 3-D array and `state["weights"]` as a
+    1-D array. The first weight is dropped; the remaining weights are
+    broadcast over a (history.shape[0], 1, history.shape[2]) block, which is
+    appended to the history along axis 1.
+
+    Returns the concatenated array; axis 1 is one longer than in the history.
+    """
+    history = state["history"]
+    weights = state["weights"]
+    # weights[0] is skipped -- presumably the cash position, which has no
+    # entry in history (TODO confirm against the environment).
+    weight_insert_shape = (history.shape[0], 1, history.shape[2])
+    weight_insert = np.ones(
+        weight_insert_shape) * weights[1:, np.newaxis, np.newaxis]
+    return np.concatenate([history, weight_insert], axis=1)
+
+
+class ConcatStates(gym.Wrapper):
+    """
+    Concat both state arrays for models that take a single input.
+
+    Usage:
+        env = ConcatStates(env)
+
+    Ref: https://github.com/openai/gym/blob/master/gym/wrappers/README.md
+    """
+
+    def __init__(self, env):
+        super().__init__(env)
+        hist_space = self.observation_space.spaces["history"]
+        hist_shape = hist_space.shape
+        # one extra entry on axis 1 holds the weights added by concat_states
+        self.observation_space = gym.spaces.Box(-10, 10, shape=(
+            hist_shape[0], hist_shape[1] + 1, hist_shape[2]))
+
+    def reset(self, **kwargs):
+        # the first observation must use the same concatenated layout as the
+        # observations returned by step()
+        return concat_states(self.env.reset(**kwargs))
+
+    def step(self, action):
+        state, reward, done, info = self.env.step(action)
+
+        # concat the two state arrays, since some models only take a single input
+        state = concat_states(state)
+
+        return state, reward, done, info
diff --git a/rl_portfolio_management/wrappers/softmax_actions.py b/rl_portfolio_management/wrappers/softmax_actions.py
new file mode 100644
index 0000000..ee8d0cc
--- /dev/null
+++ b/rl_portfolio_management/wrappers/softmax_actions.py
@@ -0,0 +1,31 @@
+import gym.wrappers
+
+from ..util import softmax
+
+
+class SoftmaxActions(gym.Wrapper):
+    """
+    Environment wrapper to softmax actions.
+
+    Usage:
+        env = gym.make('Pong-v0')
+        env = SoftmaxActions(env)
+
+    Ref: https://github.com/openai/gym/blob/master/gym/wrappers/README.md
+
+    """
+
+    def step(self, action):
+        # some callers wrap the action in a list; use its first element
+        if isinstance(action, list):
+            action = action[0]
+
+        # flatten a dict action into a list of values ordered by sorted key
+        if isinstance(action, dict):
+            action = [action[k] for k in sorted(action)]
+
+        action = softmax(action, t=1)
+
+        observation, reward, done, info = self.env.step(action)
+
+        return observation, reward, done, info