From 9449d07ecaf15ce97a2325eab7c8f646549d445d Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Fri, 3 Aug 2018 16:37:56 -0700 Subject: [PATCH] [rllib] Fix crash when setting horizon in multiagent If a horizon is set, an episode can be cut off at the horizon without the env ever reporting done=True, so env_id may be missing from self.dones when try_reset() runs; only remove it from self.dones when it is present. --- doc/source/rllib-training.rst | 2 +- python/ray/rllib/env/async_vector_env.py | 2 +- python/ray/rllib/test/test_multi_agent_env.py | 15 +++++++++++++++ 3 files changed, 17 insertions(+), 2 deletions(-) diff --git a/doc/source/rllib-training.rst b/doc/source/rllib-training.rst index 18ec4b9a5..4b0272690 100644 --- a/doc/source/rllib-training.rst +++ b/doc/source/rllib-training.rst @@ -36,7 +36,7 @@ The ``train.py`` script has a number of options you can show by running The most important options are for choosing the environment with ``--env`` (any OpenAI gym environment including ones registered by the user can be used) and for choosing the algorithm with ``--run`` -(available options are ``PPO``, ``PG``, ``A3C``, ``ES``, ``DDPG``, ``DDPG2``, ``DQN``, ``APEX``, and ``APEX_DDPG``). +(available options are ``PPO``, ``PG``, ``A3C``, ``IMPALA``, ``ES``, ``DDPG``, ``DQN``, ``APEX``, and ``APEX_DDPG``). 
Specifying Parameters ~~~~~~~~~~~~~~~~~~~~~ diff --git a/python/ray/rllib/env/async_vector_env.py b/python/ray/rllib/env/async_vector_env.py index 952eb1e09..3a1d1fec1 100644 --- a/python/ray/rllib/env/async_vector_env.py +++ b/python/ray/rllib/env/async_vector_env.py @@ -274,7 +274,7 @@ class _MultiAgentEnvToAsync(AsyncVectorEnv): def try_reset(self, env_id): obs = self.env_states[env_id].reset() - if obs is not None: + if obs is not None and env_id in self.dones: self.dones.remove(env_id) return obs diff --git a/python/ray/rllib/test/test_multi_agent_env.py b/python/ray/rllib/test/test_multi_agent_env.py index c9ddb3067..2f00ef3dd 100644 --- a/python/ray/rllib/test/test_multi_agent_env.py +++ b/python/ray/rllib/test/test_multi_agent_env.py @@ -265,6 +265,21 @@ class TestMultiAgentEnv(unittest.TestCase): self.assertEqual(batch.policy_batches["p0"]["t"].tolist(), list(range(25)) * 6) + def testMultiAgentSampleWithHorizon(self): + act_space = gym.spaces.Discrete(2) + obs_space = gym.spaces.Discrete(2) + ev = PolicyEvaluator( + env_creator=lambda _: BasicMultiAgent(5), + policy_graph={ + "p0": (MockPolicyGraph, obs_space, act_space, {}), + "p1": (MockPolicyGraph, obs_space, act_space, {}), + }, + policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2), + episode_horizon=10, # test with episode horizon set + batch_steps=50) + batch = ev.sample() + self.assertEqual(batch.count, 50) + def testMultiAgentSampleRoundRobin(self): act_space = gym.spaces.Discrete(2) obs_space = gym.spaces.Discrete(2)