[RLlib] Add ResetOnExceptionWrapper with tests for unstable 3rd party envs (#12353)

This commit is contained in:
Tomasz Wrona
2020-11-25 08:41:58 +01:00
committed by GitHub
parent c5845c3a4e
commit 82852f0ed2
2 changed files with 94 additions and 0 deletions
+37
View File
@@ -0,0 +1,37 @@
import logging
import traceback
import gym
logger = logging.getLogger(__name__)
class TooManyResetAttemptsException(Exception):
def __init__(self, max_attempts: int):
super().__init__(
f"Reached the maximum number of attempts ({max_attempts}) "
f"to reset an environment.")
class ResetOnExceptionWrapper(gym.Wrapper):
def __init__(self, env: gym.Env, max_reset_attempts: int = 5):
super().__init__(env)
self.max_reset_attempts = max_reset_attempts
def reset(self, **kwargs):
attempt = 0
while attempt < self.max_reset_attempts:
try:
return self.env.reset(**kwargs)
except Exception:
logger.error(traceback.format_exc())
attempt += 1
else:
raise TooManyResetAttemptsException(self.max_reset_attempts)
def step(self, action):
try:
return self.env.step(action)
except Exception:
logger.error(traceback.format_exc())
return self.reset(), 0.0, False, {"__terminated__": True}
+57
View File
@@ -0,0 +1,57 @@
import random
import unittest
import gym
from ray.rllib.env.wrappers.exception_wrapper import ResetOnExceptionWrapper, \
TooManyResetAttemptsException
class TestResetOnExceptionWrapper(unittest.TestCase):
def test_unstable_env(self):
class UnstableEnv(gym.Env):
observation_space = gym.spaces.Discrete(2)
action_space = gym.spaces.Discrete(2)
def step(self, action):
if random.choice([True, False]):
raise ValueError("An error from a unstable environment.")
return self.observation_space.sample(), 0.0, False, {}
def reset(self):
return self.observation_space.sample()
env = UnstableEnv()
env = ResetOnExceptionWrapper(env)
try:
self._run_for_100_steps(env)
except Exception:
self.fail()
def test_very_unstable_env(self):
class VeryUnstableEnv(gym.Env):
observation_space = gym.spaces.Discrete(2)
action_space = gym.spaces.Discrete(2)
def step(self, action):
return self.observation_space.sample(), 0.0, False, {}
def reset(self):
raise ValueError("An error from a very unstable environment.")
env = VeryUnstableEnv()
env = ResetOnExceptionWrapper(env)
self.assertRaises(TooManyResetAttemptsException,
lambda: self._run_for_100_steps(env))
@staticmethod
def _run_for_100_steps(env):
env.reset()
for _ in range(100):
env.step(env.action_space.sample())
if __name__ == "__main__":
import sys
import pytest
sys.exit(pytest.main(["-v", __file__]))