From 05d96ce81b93c41f43d3c90aeff63e0248b9712c Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Sat, 23 Feb 2019 21:23:40 -0800 Subject: [PATCH] [rllib] Raise an error if multi-agent envs terminate without a last observation for agents (#4139) * fix it * lint * Update rllib-training.rst --- doc/source/rllib-training.rst | 14 +++++ python/ray/rllib/env/base_env.py | 4 ++ .../rllib/evaluation/sample_batch_builder.py | 11 ++++ python/ray/rllib/evaluation/sampler.py | 2 + python/ray/rllib/examples/twostep_game.py | 7 ++- python/ray/rllib/test/test_multi_agent_env.py | 63 +++++++++++++++++++ 6 files changed, 99 insertions(+), 2 deletions(-) diff --git a/doc/source/rllib-training.rst b/doc/source/rllib-training.rst index 91cfb3f56..c5fd0f5c6 100644 --- a/doc/source/rllib-training.rst +++ b/doc/source/rllib-training.rst @@ -373,6 +373,20 @@ The ``"monitor": true`` config can be used to save Gym episode videos to the res openaigym.video.0.31403.video000000.meta.json openaigym.video.0.31403.video000000.mp4 +Episode Traces +~~~~~~~~~~~~~~ + +You can use the `data output API <rllib-offline.html>`__ to save episode traces for debugging. For example, the following command will run PPO while saving episode traces to ``/tmp/debug``. + +.. 
code-block:: bash + + rllib train --run=PPO --env=CartPole-v0 \ + --config='{"output": "/tmp/debug", "output_compress_columns": []}' + + # episode traces will be saved in /tmp/debug, for example + output-2019-02-23_12-02-03_worker-2_0.json + output-2019-02-23_12-02-04_worker-1_0.json + Log Verbosity ~~~~~~~~~~~~~ diff --git a/python/ray/rllib/env/base_env.py b/python/ray/rllib/env/base_env.py index 0483d7434..85993f05c 100644 --- a/python/ray/rllib/env/base_env.py +++ b/python/ray/rllib/env/base_env.py @@ -325,6 +325,10 @@ class _MultiAgentEnvToBaseEnv(BaseEnv): if set(infos).difference(set(obs)): raise ValueError("Key set for infos must be a subset of obs: " "{} vs {}".format(infos.keys(), obs.keys())) + if "__all__" not in dones: + raise ValueError( + "In multi-agent environments, '__all__': True|False must " + "be included in the 'done' dict: got {}.".format(dones)) if dones["__all__"]: self.dones.add(env_id) self.env_states[env_id].observe(obs, rewards, dones, infos) diff --git a/python/ray/rllib/evaluation/sample_batch_builder.py b/python/ray/rllib/evaluation/sample_batch_builder.py index c68a89363..211e7075b 100644 --- a/python/ray/rllib/evaluation/sample_batch_builder.py +++ b/python/ray/rllib/evaluation/sample_batch_builder.py @@ -152,6 +152,17 @@ class MultiAgentSampleBatchBuilder(object): self.agent_builders.clear() self.agent_to_policy.clear() + def check_missing_dones(self): + for agent_id, builder in self.agent_builders.items(): + if builder.buffers["dones"][-1] is not True: + raise ValueError( + "The environment terminated for all agents, but we still " + "don't have a last observation for " + "agent {} (policy {}). ".format( + agent_id, self.agent_to_policy[agent_id]) + + "Please ensure that you include the last observations " + "of all live agents when setting '__all__' done to True.") + @DeveloperAPI def build_and_reset(self, episode): """Returns the accumulated sample batches for each policy. 
diff --git a/python/ray/rllib/evaluation/sampler.py b/python/ray/rllib/evaluation/sampler.py index 5ead9c7f8..76a73274d 100644 --- a/python/ray/rllib/evaluation/sampler.py +++ b/python/ray/rllib/evaluation/sampler.py @@ -400,6 +400,8 @@ def _process_observations(base_env, policies, batch_builder_pool, # Cut the batch if we're not packing multiple episodes into one, # or if we've exceeded the requested batch size. if episode.batch_builder.has_pending_data(): + if dones[env_id]["__all__"]: + episode.batch_builder.check_missing_dones() if (all_done and not pack) or \ episode.batch_builder.count >= unroll_length: outputs.append(episode.batch_builder.build_and_reset(episode)) diff --git a/python/ray/rllib/examples/twostep_game.py b/python/ray/rllib/examples/twostep_game.py index 63c860979..172151c9c 100644 --- a/python/ray/rllib/examples/twostep_game.py +++ b/python/ray/rllib/examples/twostep_game.py @@ -13,7 +13,7 @@ from ray.rllib.env.multi_agent_env import MultiAgentEnv parser = argparse.ArgumentParser() parser.add_argument("--stop", type=int, default=50000) -parser.add_argument("--run", type=str, default="QMIX") +parser.add_argument("--run", type=str, default="PG") class TwoStepGame(MultiAgentEnv): @@ -86,6 +86,7 @@ if __name__ == "__main__": "num_workers": 0, "mixer": grid_search([None, "qmix", "vdn"]), } + group = True elif args.run == "APEX_QMIX": config = { "num_gpus": 0, @@ -101,14 +102,16 @@ if __name__ == "__main__": "target_network_update_freq": 500, "timesteps_per_iteration": 1000, } + group = True else: config = {} + group = False ray.init() run_experiments({ "two_step": { "run": args.run, - "env": "grouped_twostep", + "env": "grouped_twostep" if group else TwoStepGame, "stop": { "timesteps_total": args.stop, }, diff --git a/python/ray/rllib/test/test_multi_agent_env.py b/python/ray/rllib/test/test_multi_agent_env.py index f6cf34b29..3a2d31453 100644 --- a/python/ray/rllib/test/test_multi_agent_env.py +++ b/python/ray/rllib/test/test_multi_agent_env.py @@ 
-53,6 +53,53 @@ class BasicMultiAgent(MultiAgentEnv): return obs, rew, done, info +class EarlyDoneMultiAgent(MultiAgentEnv): + """Env for testing when the env terminates (after agent 0 does).""" + + def __init__(self): + self.agents = [MockEnv(3), MockEnv(5)] + self.dones = set() + self.last_obs = {} + self.last_rew = {} + self.last_done = {} + self.last_info = {} + self.i = 0 + self.observation_space = gym.spaces.Discrete(10) + self.action_space = gym.spaces.Discrete(2) + + def reset(self): + self.dones = set() + self.last_obs = {} + self.last_rew = {} + self.last_done = {} + self.last_info = {} + self.i = 0 + for i, a in enumerate(self.agents): + self.last_obs[i] = a.reset() + self.last_rew[i] = None + self.last_done[i] = False + self.last_info[i] = {} + obs_dict = {self.i: self.last_obs[self.i]} + self.i = (self.i + 1) % len(self.agents) + return obs_dict + + def step(self, action_dict): + assert len(self.dones) != len(self.agents) + for i, action in action_dict.items(): + (self.last_obs[i], self.last_rew[i], self.last_done[i], + self.last_info[i]) = self.agents[i].step(action) + obs = {self.i: self.last_obs[self.i]} + rew = {self.i: self.last_rew[self.i]} + done = {self.i: self.last_done[self.i]} + info = {self.i: self.last_info[self.i]} + if done[self.i]: + rew[self.i] = 0 + self.dones.add(self.i) + self.i = (self.i + 1) % len(self.agents) + done["__all__"] = len(self.dones) == len(self.agents) - 1 + return obs, rew, done, info + + class RoundRobinMultiAgent(MultiAgentEnv): """Env of N independent agents, each of which exits after 5 steps. 
@@ -302,6 +349,22 @@ class TestMultiAgentEnv(unittest.TestCase): batch = ev.sample() self.assertEqual(batch.count, 50) + def testSampleFromEarlyDoneEnv(self): + act_space = gym.spaces.Discrete(2) + obs_space = gym.spaces.Discrete(2) + ev = PolicyEvaluator( + env_creator=lambda _: EarlyDoneMultiAgent(), + policy_graph={ + "p0": (MockPolicyGraph, obs_space, act_space, {}), + "p1": (MockPolicyGraph, obs_space, act_space, {}), + }, + policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2), + batch_mode="complete_episodes", + batch_steps=1) + self.assertRaisesRegexp(ValueError, + ".*don't have a last observation.*", + lambda: ev.sample()) + def testMultiAgentSampleRoundRobin(self): act_space = gym.spaces.Discrete(2) obs_space = gym.spaces.Discrete(10)