From f0e65233234f8f4fb65931331fd64f765fc0dd33 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Sun, 20 Jan 2019 15:00:18 -0800 Subject: [PATCH] [rllib] Don't call reset() unless necessary for multi-agent envs --- python/ray/rllib/env/async_vector_env.py | 9 ++++++++- python/ray/rllib/test/test_multi_agent_env.py | 8 ++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/python/ray/rllib/env/async_vector_env.py b/python/ray/rllib/env/async_vector_env.py index 68ff1f2f7..ad33062af 100644 --- a/python/ray/rllib/env/async_vector_env.py +++ b/python/ray/rllib/env/async_vector_env.py @@ -319,14 +319,21 @@ class _MultiAgentEnvToAsync(AsyncVectorEnv): self.dones.remove(env_id) return obs + @override(AsyncVectorEnv) + def get_unwrapped(self): + return [state.env for state in self.env_states] + class _MultiAgentEnvState(object): def __init__(self, env): assert isinstance(env, MultiAgentEnv) self.env = env - self.reset() + self.initialized = False def poll(self): + if not self.initialized: + self.reset() + self.initialized = True obs, rew, dones, info = (self.last_obs, self.last_rewards, self.last_dones, self.last_infos) self.last_obs = {} diff --git a/python/ray/rllib/test/test_multi_agent_env.py b/python/ray/rllib/test/test_multi_agent_env.py index 6f5d3325d..c933f5b30 100644 --- a/python/ray/rllib/test/test_multi_agent_env.py +++ b/python/ray/rllib/test/test_multi_agent_env.py @@ -36,8 +36,10 @@ class BasicMultiAgent(MultiAgentEnv): self.dones = set() self.observation_space = gym.spaces.Discrete(2) self.action_space = gym.spaces.Discrete(2) + self.resetted = False def reset(self): + self.resetted = True self.dones = set() return {i: a.reset() for i, a in enumerate(self.agents)} @@ -173,6 +175,12 @@ class TestMultiAgentEnv(unittest.TestCase): obs, rew, done, info = env.step({0: 0}) self.assertEqual(done["__all__"], True) + def testNoResetUntilPoll(self): + env = _MultiAgentEnvToAsync(lambda v: BasicMultiAgent(2), [], 1) + self.assertFalse(env.get_unwrapped()[0].resetted) + env.poll() + self.assertTrue(env.get_unwrapped()[0].resetted) + def testVectorizeBasic(self): env = _MultiAgentEnvToAsync(lambda v: BasicMultiAgent(2), [], 2) obs, rew, dones, _, _ = env.poll()