[rllib] Fix crash when setting horizon in multiagent

If a horizon is set, an env terminates without done=True.
This commit is contained in:
Eric Liang
2018-08-03 16:37:56 -07:00
committed by GitHub
parent d5dda1ebf2
commit 9449d07eca
3 changed files with 17 additions and 2 deletions
+1 -1
View File
@@ -274,7 +274,7 @@ class _MultiAgentEnvToAsync(AsyncVectorEnv):
def try_reset(self, env_id):
obs = self.env_states[env_id].reset()
if obs is not None:
if obs is not None and env_id in self.dones:
self.dones.remove(env_id)
return obs
@@ -265,6 +265,21 @@ class TestMultiAgentEnv(unittest.TestCase):
self.assertEqual(batch.policy_batches["p0"]["t"].tolist(),
list(range(25)) * 6)
def testMultiAgentSampleWithHorizon(self):
act_space = gym.spaces.Discrete(2)
obs_space = gym.spaces.Discrete(2)
ev = PolicyEvaluator(
env_creator=lambda _: BasicMultiAgent(5),
policy_graph={
"p0": (MockPolicyGraph, obs_space, act_space, {}),
"p1": (MockPolicyGraph, obs_space, act_space, {}),
},
policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2),
episode_horizon=10, # test with episode horizon set
batch_steps=50)
batch = ev.sample()
self.assertEqual(batch.count, 50)
def testMultiAgentSampleRoundRobin(self):
act_space = gym.spaces.Discrete(2)
obs_space = gym.spaces.Discrete(2)