diff --git a/rl_portfolio_management/wrappers/concat_states.py b/rl_portfolio_management/wrappers/concat_states.py index e8ecf52..adf2109 100644 --- a/rl_portfolio_management/wrappers/concat_states.py +++ b/rl_portfolio_management/wrappers/concat_states.py @@ -38,3 +38,7 @@ class ConcatStates(gym.Wrapper): state = concat_states(state) return state, reward, done, info + + def reset(self): + state = self.env.reset() + return concat_states(state)