diff --git a/rllib/env/wrappers/exception_wrapper.py b/rllib/env/wrappers/exception_wrapper.py
new file mode 100644
index 000000000..de25e13ab
--- /dev/null
+++ b/rllib/env/wrappers/exception_wrapper.py
@@ -0,0 +1,85 @@
+import logging
+import traceback
+
+import gym
+
+logger = logging.getLogger(__name__)
+
+
+class TooManyResetAttemptsException(Exception):
+    """Signals that resetting an environment failed too many times.
+
+    Attributes:
+        max_attempts: Number of reset attempts that were allowed.
+    """
+
+    def __init__(self, max_attempts: int):
+        super().__init__(
+            f"Reached the maximum number of attempts ({max_attempts}) "
+            f"to reset an environment.")
+        self.max_attempts = max_attempts
+
+
+class ResetOnExceptionWrapper(gym.Wrapper):
+    """Gym wrapper that recovers from environment errors by resetting.
+
+    An exception raised by the wrapped env's ``step`` triggers a reset,
+    and a failing reset is retried up to ``max_reset_attempts`` times.
+    """
+
+    def __init__(self, env: gym.Env, max_reset_attempts: int = 5):
+        """Wrap ``env``.
+
+        Args:
+            env: The environment to wrap.
+            max_reset_attempts: How many times ``reset`` tries the
+                wrapped env before giving up.
+        """
+        super().__init__(env)
+        self.max_reset_attempts = max_reset_attempts
+
+    def reset(self, **kwargs):
+        """Reset the wrapped env, retrying when the reset itself raises.
+
+        Args:
+            **kwargs: Forwarded unchanged to the wrapped env's ``reset``.
+
+        Returns:
+            Whatever the first successful ``reset`` call returns.
+
+        Raises:
+            TooManyResetAttemptsException: If no attempt succeeded. The
+                last underlying error is attached as ``__cause__``.
+        """
+        last_error = None
+        for _ in range(self.max_reset_attempts):
+            try:
+                return self.env.reset(**kwargs)
+            except Exception as e:
+                logger.error(traceback.format_exc())
+                # ``e`` is unbound once the handler exits, so keep the
+                # error under another name for chaining below.
+                last_error = e
+        raise TooManyResetAttemptsException(
+            self.max_reset_attempts) from last_error
+
+    def step(self, action):
+        """Step the wrapped env, resetting it if the step raises.
+
+        Args:
+            action: Forwarded unchanged to the wrapped env's ``step``.
+
+        Returns:
+            The wrapped env's ``step`` result on success. If the step
+            raised: the observation from a fresh ``reset``, reward 0.0,
+            done ``False`` and the info ``{"__terminated__": True}``.
+
+        Raises:
+            TooManyResetAttemptsException: If the recovery reset fails
+                on every attempt.
+        """
+        try:
+            return self.env.step(action)
+        except Exception:
+            logger.error(traceback.format_exc())
+            return self.reset(), 0.0, False, {"__terminated__": True}
diff --git a/rllib/env/wrappers/tests/test_exception_wrapper.py b/rllib/env/wrappers/tests/test_exception_wrapper.py
new file mode 100644
index 000000000..fe59eb66a
--- /dev/null
+++ b/rllib/env/wrappers/tests/test_exception_wrapper.py
@@ -0,0 +1,69 @@
+import random
+import unittest
+
+import gym
+from ray.rllib.env.wrappers.exception_wrapper import ResetOnExceptionWrapper, \
+    TooManyResetAttemptsException
+
+
+class TestResetOnExceptionWrapper(unittest.TestCase):
+    """Tests for ``ResetOnExceptionWrapper``."""
+
+    def test_unstable_env(self):
+        """Failing steps are absorbed by resetting the wrapped env."""
+
+        class UnstableEnv(gym.Env):
+            observation_space = gym.spaces.Discrete(2)
+            action_space = gym.spaces.Discrete(2)
+
+            def step(self, action):
+                if random.choice([True, False]):
+                    raise ValueError("An error from a unstable environment.")
+                return self.observation_space.sample(), 0.0, False, {}
+
+            def reset(self):
+                return self.observation_space.sample()
+
+        # Seed the RNG that drives the failures so runs are reproducible.
+        random.seed(0)
+        env = ResetOnExceptionWrapper(UnstableEnv())
+
+        # An exception escaping the wrapper fails the test on its own and
+        # keeps the original traceback, so no try/except is needed here.
+        infos = self._run_for_100_steps(env)
+        self.assertTrue(any(info.get("__terminated__") for info in infos))
+
+    def test_very_unstable_env(self):
+        """A reset that always raises exhausts the retry budget."""
+
+        class VeryUnstableEnv(gym.Env):
+            observation_space = gym.spaces.Discrete(2)
+            action_space = gym.spaces.Discrete(2)
+
+            def step(self, action):
+                return self.observation_space.sample(), 0.0, False, {}
+
+            def reset(self):
+                raise ValueError("An error from a very unstable environment.")
+
+        env = ResetOnExceptionWrapper(VeryUnstableEnv())
+        with self.assertRaises(TooManyResetAttemptsException) as ctx:
+            self._run_for_100_steps(env)
+        # The last underlying failure must stay reachable for debugging.
+        self.assertIsInstance(ctx.exception.__cause__, ValueError)
+
+    @staticmethod
+    def _run_for_100_steps(env):
+        """Reset ``env``, take 100 sampled steps and return their infos."""
+        env.reset()
+        infos = []
+        for _ in range(100):
+            _, _, _, info = env.step(env.action_space.sample())
+            infos.append(info)
+        return infos
+
+
+if __name__ == "__main__":
+    import sys
+    import pytest
+    sys.exit(pytest.main(["-v", __file__]))