mirror of
https://github.com/wassname/ray.git
synced 2026-07-02 19:05:42 +08:00
[RLlib] Enhance reward clipping test; add action_clipping tests. (#9684)
This commit is contained in:
Vendored
+17
-4
@@ -25,10 +25,15 @@ class RandomEnv(gym.Env):
|
||||
gym.spaces.Box(low=-1.0, high=1.0, shape=(), dtype=np.float32))
|
||||
# Chance that an episode ends at any step.
|
||||
self.p_done = config.get("p_done", 0.1)
|
||||
# A max episode length.
|
||||
self.max_episode_len = config.get("max_episode_len", None)
|
||||
# Whether to check action bounds.
|
||||
self.check_action_bounds = config.get("check_action_bounds", False)
|
||||
# Steps taken so far (after last reset).
|
||||
self.steps = 0
|
||||
|
||||
def reset(self):
|
||||
self.steps = 0
|
||||
return self.observation_space.sample()
|
||||
|
||||
def step(self, action):
|
||||
@@ -40,11 +45,19 @@ class RandomEnv(gym.Env):
|
||||
raise ValueError("Illegal action for {}: {}".format(
|
||||
self.action_space, action))
|
||||
|
||||
return self.observation_space.sample(), \
|
||||
float(self.reward_space.sample()), \
|
||||
bool(np.random.choice(
|
||||
self.steps += 1
|
||||
# We are done as per our max-episode-len.
|
||||
if self.max_episode_len is not None and \
|
||||
self.steps >= self.max_episode_len:
|
||||
done = True
|
||||
# Max not reached yet -> Sample done via p_done.
|
||||
else:
|
||||
done = bool(np.random.choice(
|
||||
[True, False], p=[self.p_done, 1.0 - self.p_done]
|
||||
)), {}
|
||||
))
|
||||
|
||||
return self.observation_space.sample(), \
|
||||
float(self.reward_space.sample()), done, {}
|
||||
|
||||
|
||||
# Multi-agent version of the RandomEnv.
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from gym.spaces import Box
|
||||
import numpy as np
|
||||
import random
|
||||
|
||||
@@ -11,6 +12,17 @@ class RandomPolicy(Policy):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# Whether for compute_actions, the bounds given in action_space
|
||||
# should be ignored (default: False). This is to test action-clipping
|
||||
# and any Env's reaction to bounds breaches.
|
||||
if self.config.get("ignore_action_bounds", False) and \
|
||||
isinstance(self.action_space, Box):
|
||||
self.action_space_for_sampling = Box(
|
||||
-float("inf"), float("inf"),
|
||||
shape=self.action_space.shape, dtype=self.action_space.dtype)
|
||||
else:
|
||||
self.action_space_for_sampling = self.action_space
|
||||
|
||||
@override(Policy)
|
||||
def compute_actions(self,
|
||||
obs_batch,
|
||||
@@ -20,7 +32,8 @@ class RandomPolicy(Policy):
|
||||
**kwargs):
|
||||
# Alternatively, a numpy array would work here as well.
|
||||
# e.g.: np.array([random.choice([0, 1])] * len(obs_batch))
|
||||
return [self.action_space.sample() for _ in obs_batch], [], {}
|
||||
return [self.action_space_for_sampling.sample() for _ in obs_batch], \
|
||||
[], {}
|
||||
|
||||
@override(Policy)
|
||||
def learn_on_batch(self, samples):
|
||||
|
||||
@@ -14,12 +14,15 @@ from ray.rllib.evaluation.rollout_worker import RolloutWorker
|
||||
from ray.rllib.evaluation.metrics import collect_metrics
|
||||
from ray.rllib.evaluation.postprocessing import compute_advantages
|
||||
from ray.rllib.examples.policy.random_policy import RandomPolicy
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.test_utils import check, framework_iterator
|
||||
from ray.tune.registry import register_env
|
||||
|
||||
|
||||
class MockPolicy(RandomPolicy):
|
||||
@override(RandomPolicy)
|
||||
def compute_actions(self,
|
||||
obs_batch,
|
||||
state_batches=None,
|
||||
@@ -31,16 +34,19 @@ class MockPolicy(RandomPolicy):
|
||||
**kwargs):
|
||||
return np.array([random.choice([0, 1])] * len(obs_batch)), [], {}
|
||||
|
||||
@override(Policy)
|
||||
def postprocess_trajectory(self,
|
||||
batch,
|
||||
other_agent_batches=None,
|
||||
episode=None):
|
||||
assert episode is not None
|
||||
super().postprocess_trajectory(batch, other_agent_batches, episode)
|
||||
return compute_advantages(
|
||||
batch, 100.0, 0.9, use_gae=False, use_critic=False)
|
||||
|
||||
|
||||
class BadPolicy(RandomPolicy):
|
||||
@override(RandomPolicy)
|
||||
def compute_actions(self,
|
||||
obs_batch,
|
||||
state_batches=None,
|
||||
@@ -106,12 +112,15 @@ class MockVectorEnv(VectorEnv):
|
||||
num_envs=num_envs)
|
||||
self.envs = [MockEnv(episode_length) for _ in range(num_envs)]
|
||||
|
||||
@override(VectorEnv)
|
||||
def vector_reset(self):
|
||||
return [e.reset() for e in self.envs]
|
||||
|
||||
@override(VectorEnv)
|
||||
def reset_at(self, index):
|
||||
return self.envs[index].reset()
|
||||
|
||||
@override(VectorEnv)
|
||||
def vector_step(self, actions):
|
||||
obs_batch, rew_batch, done_batch, info_batch = [], [], [], []
|
||||
for i in range(len(self.envs)):
|
||||
@@ -122,6 +131,7 @@ class MockVectorEnv(VectorEnv):
|
||||
info_batch.append(info)
|
||||
return obs_batch, rew_batch, done_batch, info_batch
|
||||
|
||||
@override(VectorEnv)
|
||||
def get_unwrapped(self):
|
||||
return self.envs
|
||||
|
||||
@@ -266,8 +276,73 @@ class TestRolloutWorker(unittest.TestCase):
|
||||
self.assertEqual(results3, [[1, 1], [1, 1], [1, 1]])
|
||||
pg.stop()
|
||||
|
||||
def test_action_clipping(self):
|
||||
from ray.rllib.examples.env.random_env import RandomEnv
|
||||
action_space = gym.spaces.Box(-2.0, 1.0, (3,))
|
||||
|
||||
# Clipping: True (clip between Policy's action_space.low/high),
|
||||
ev = RolloutWorker(
|
||||
env_creator=lambda _: RandomEnv(config=dict(
|
||||
action_space=action_space,
|
||||
max_episode_len=10,
|
||||
p_done=0.0,
|
||||
check_action_bounds=True,
|
||||
)),
|
||||
policy=RandomPolicy,
|
||||
policy_config=dict(
|
||||
action_space=action_space,
|
||||
ignore_action_bounds=True,
|
||||
),
|
||||
clip_actions=True,
|
||||
batch_mode="complete_episodes")
|
||||
sample = ev.sample()
|
||||
# Check, whether the action bounds have been breached (expected).
|
||||
# We still arrived here b/c we clipped according to the Env's action
|
||||
# space.
|
||||
self.assertGreater(np.max(sample["actions"]), action_space.high[0])
|
||||
self.assertLess(np.min(sample["actions"]), action_space.low[0])
|
||||
ev.stop()
|
||||
|
||||
# Clipping: False and RandomPolicy produces invalid actions.
|
||||
# Expect Env to complain.
|
||||
ev2 = RolloutWorker(
|
||||
env_creator=lambda _: RandomEnv(config=dict(
|
||||
action_space=action_space,
|
||||
max_episode_len=10,
|
||||
p_done=0.0,
|
||||
check_action_bounds=True,
|
||||
)),
|
||||
policy=RandomPolicy,
|
||||
policy_config=dict(
|
||||
action_space=action_space,
|
||||
ignore_action_bounds=True,
|
||||
),
|
||||
clip_actions=False, # <- should lead to Env complaining
|
||||
batch_mode="complete_episodes")
|
||||
self.assertRaisesRegex(ValueError, r"Illegal action", ev2.sample)
|
||||
ev2.stop()
|
||||
|
||||
# Clipping: False and RandomPolicy produces valid (bounded) actions.
|
||||
# Expect "actions" in SampleBatch to be unclipped.
|
||||
ev3 = RolloutWorker(
|
||||
env_creator=lambda _: RandomEnv(config=dict(
|
||||
action_space=action_space,
|
||||
max_episode_len=10,
|
||||
p_done=0.0,
|
||||
check_action_bounds=True,
|
||||
)),
|
||||
policy=RandomPolicy,
|
||||
policy_config=dict(action_space=action_space),
|
||||
# Should not be a problem as RandomPolicy abides to bounds.
|
||||
clip_actions=False,
|
||||
batch_mode="complete_episodes")
|
||||
sample = ev3.sample()
|
||||
self.assertGreater(np.min(sample["actions"]), action_space.low[0])
|
||||
self.assertLess(np.max(sample["actions"]), action_space.high[0])
|
||||
ev3.stop()
|
||||
|
||||
def test_reward_clipping(self):
|
||||
# Clipping: on.
|
||||
# Clipping: True (clip between -1.0 and 1.0).
|
||||
ev = RolloutWorker(
|
||||
env_creator=lambda _: MockEnv2(episode_length=10),
|
||||
policy=MockPolicy,
|
||||
@@ -278,7 +353,27 @@ class TestRolloutWorker(unittest.TestCase):
|
||||
self.assertEqual(result["episode_reward_mean"], 1000)
|
||||
ev.stop()
|
||||
|
||||
# Clipping: off.
|
||||
from ray.rllib.examples.env.random_env import RandomEnv
|
||||
|
||||
# Clipping in certain range (-2.0, 2.0).
|
||||
ev2 = RolloutWorker(
|
||||
env_creator=lambda _: RandomEnv(
|
||||
dict(
|
||||
reward_space=gym.spaces.Box(low=-10, high=10, shape=()),
|
||||
p_done=0.0,
|
||||
max_episode_len=10,
|
||||
)),
|
||||
policy=MockPolicy,
|
||||
clip_rewards=2.0,
|
||||
batch_mode="complete_episodes")
|
||||
sample = ev2.sample()
|
||||
self.assertEqual(max(sample["rewards"]), 2.0)
|
||||
self.assertEqual(min(sample["rewards"]), -2.0)
|
||||
self.assertLess(np.mean(sample["rewards"]), 0.5)
|
||||
self.assertGreater(np.mean(sample["rewards"]), -0.5)
|
||||
ev2.stop()
|
||||
|
||||
# Clipping: Off.
|
||||
ev2 = RolloutWorker(
|
||||
env_creator=lambda _: MockEnv2(episode_length=10),
|
||||
policy=MockPolicy,
|
||||
|
||||
Reference in New Issue
Block a user