mirror of
https://github.com/wassname/ray.git
synced 2026-07-05 22:36:53 +08:00
[RLlib] Add ResetOnExceptionWrapper with tests for unstable 3rd party envs (#12353)
This commit is contained in:
+37
@@ -0,0 +1,37 @@
|
||||
import logging
|
||||
import traceback
|
||||
|
||||
import gym
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TooManyResetAttemptsException(Exception):
    """Raised when an environment could not be reset within the allowed tries.

    Attributes:
        max_attempts: The reset-attempt limit that was exhausted.
    """

    def __init__(self, max_attempts: int):
        super().__init__(
            f"Reached the maximum number of attempts ({max_attempts}) "
            f"to reset an environment.")
        # Keep the limit on the instance so handlers can inspect it without
        # parsing the message.
        self.max_attempts = max_attempts

    def __reduce__(self):
        # The default exception pickling re-calls the constructor with
        # ``self.args`` (the already formatted message), which would embed the
        # message inside itself. Rebuild from the original argument instead.
        return (type(self), (self.max_attempts,))
|
||||
|
||||
|
||||
class ResetOnExceptionWrapper(gym.Wrapper):
    """Gym wrapper that resets the wrapped env when it raises an exception.

    Exceptions raised by the wrapped env's ``reset()`` or ``step()`` are
    logged at error level (with traceback) instead of propagating, and the
    env is reset. Resetting is retried up to ``max_reset_attempts`` times.

    Args:
        env: The environment to wrap.
        max_reset_attempts: How many times ``reset()`` may be tried before
            giving up with ``TooManyResetAttemptsException``.
    """

    def __init__(self, env: gym.Env, max_reset_attempts: int = 5):
        super().__init__(env)
        self.max_reset_attempts = max_reset_attempts

    def reset(self, **kwargs):
        """Reset the wrapped env, retrying when the reset itself raises.

        Args:
            **kwargs: Forwarded unchanged to the wrapped env's ``reset()``.

        Returns:
            Whatever the wrapped env's ``reset()`` returns on the first
            successful attempt.

        Raises:
            TooManyResetAttemptsException: If every attempt failed. The last
                underlying error is attached as ``__cause__``.
        """
        last_error = None
        attempt = 0
        while attempt < self.max_reset_attempts:
            try:
                return self.env.reset(**kwargs)
            except Exception as error:
                logger.error(traceback.format_exc())
                last_error = error
                attempt += 1
        # Chain the final failure so the root cause stays visible to whoever
        # handles the exception, not only in the log output.
        raise TooManyResetAttemptsException(
            self.max_reset_attempts) from last_error

    def step(self, action):
        """Step the wrapped env; on any exception, log it and reset instead.

        Args:
            action: Forwarded unchanged to the wrapped env's ``step()``.

        Returns:
            The wrapped env's ``step()`` result, or, if it raised, the tuple
            ``(obs, 0.0, False, {"__terminated__": True})`` where ``obs``
            comes from ``self.reset()``.

        Raises:
            TooManyResetAttemptsException: If the recovery reset fails on
                every attempt.
        """
        try:
            return self.env.step(action)
        except Exception:
            logger.error(traceback.format_exc())
            # The info dict marks this transition as produced by a forced
            # reset rather than by the env itself.
            return self.reset(), 0.0, False, {"__terminated__": True}
|
||||
@@ -0,0 +1,57 @@
|
||||
import random
|
||||
import unittest
|
||||
|
||||
import gym
|
||||
from ray.rllib.env.wrappers.exception_wrapper import ResetOnExceptionWrapper, \
|
||||
TooManyResetAttemptsException
|
||||
|
||||
|
||||
class TestResetOnExceptionWrapper(unittest.TestCase):
    """Tests for ResetOnExceptionWrapper against deliberately flaky envs."""

    def test_unstable_env(self):
        """An env whose step() randomly raises must not leak exceptions."""

        class UnstableEnv(gym.Env):
            observation_space = gym.spaces.Discrete(2)
            action_space = gym.spaces.Discrete(2)

            def step(self, action):
                # Fail on roughly half of all steps.
                if random.choice([True, False]):
                    raise ValueError("An error from a unstable environment.")
                return self.observation_space.sample(), 0.0, False, {}

            def reset(self):
                return self.observation_space.sample()

        env = ResetOnExceptionWrapper(UnstableEnv())

        try:
            self._run_for_100_steps(env)
        except Exception as error:
            # Report what actually escaped so the failure is diagnosable.
            self.fail(f"Wrapped env leaked an exception: {error!r}")

    def test_very_unstable_env(self):
        """An env whose reset() always raises must exhaust the retries."""

        class VeryUnstableEnv(gym.Env):
            observation_space = gym.spaces.Discrete(2)
            action_space = gym.spaces.Discrete(2)

            def step(self, action):
                return self.observation_space.sample(), 0.0, False, {}

            def reset(self):
                raise ValueError("An error from a very unstable environment.")

        env = ResetOnExceptionWrapper(VeryUnstableEnv())

        with self.assertRaises(TooManyResetAttemptsException):
            self._run_for_100_steps(env)

    @staticmethod
    def _run_for_100_steps(env):
        """Reset ``env`` once, then take 100 randomly sampled actions."""
        env.reset()
        for _ in range(100):
            env.step(env.action_space.sample())
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allow running this test module directly: hand the file to pytest in
    # verbose mode and use its result as the process exit status.
    import sys

    import pytest

    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)
|
||||
Reference in New Issue
Block a user