[rllib] Don't call reset() unless necessary for multi-agent envs

This commit is contained in:
Eric Liang
2019-01-20 15:00:18 -08:00
committed by GitHub
parent 0dad4e6a25
commit f0e6523323
2 changed files with 16 additions and 1 deletions
+8 -1
View File
@@ -319,14 +319,21 @@ class _MultiAgentEnvToAsync(AsyncVectorEnv):
self.dones.remove(env_id)
return obs
@override(AsyncVectorEnv)
def get_unwrapped(self):
return [state.env for state in self.env_states]
class _MultiAgentEnvState(object):
def __init__(self, env):
assert isinstance(env, MultiAgentEnv)
self.env = env
self.reset()
self.initialized = False
def poll(self):
if not self.initialized:
self.reset()
self.initialized = True
obs, rew, dones, info = (self.last_obs, self.last_rewards,
self.last_dones, self.last_infos)
self.last_obs = {}
@@ -36,8 +36,10 @@ class BasicMultiAgent(MultiAgentEnv):
self.dones = set()
self.observation_space = gym.spaces.Discrete(2)
self.action_space = gym.spaces.Discrete(2)
self.resetted = False
def reset(self):
self.resetted = True
self.dones = set()
return {i: a.reset() for i, a in enumerate(self.agents)}
@@ -173,6 +175,12 @@ class TestMultiAgentEnv(unittest.TestCase):
obs, rew, done, info = env.step({0: 0})
self.assertEqual(done["__all__"], True)
def testNoResetUntilPoll(self):
env = _MultiAgentEnvToAsync(lambda v: BasicMultiAgent(2), [], 1)
self.assertFalse(env.get_unwrapped()[0].resetted)
env.poll()
self.assertTrue(env.get_unwrapped()[0].resetted)
def testVectorizeBasic(self):
env = _MultiAgentEnvToAsync(lambda v: BasicMultiAgent(2), [], 2)
obs, rew, dones, _, _ = env.poll()