mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 07:50:30 +08:00
[rllib] Fix crash when setting horizon in multiagent
If a horizon is set, an env terminates without done=True.
This commit is contained in:
+1
-1
@@ -274,7 +274,7 @@ class _MultiAgentEnvToAsync(AsyncVectorEnv):
|
||||
|
||||
def try_reset(self, env_id):
|
||||
obs = self.env_states[env_id].reset()
|
||||
if obs is not None:
|
||||
if obs is not None and env_id in self.dones:
|
||||
self.dones.remove(env_id)
|
||||
return obs
|
||||
|
||||
|
||||
@@ -265,6 +265,21 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
self.assertEqual(batch.policy_batches["p0"]["t"].tolist(),
|
||||
list(range(25)) * 6)
|
||||
|
||||
def testMultiAgentSampleWithHorizon(self):
|
||||
act_space = gym.spaces.Discrete(2)
|
||||
obs_space = gym.spaces.Discrete(2)
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: BasicMultiAgent(5),
|
||||
policy_graph={
|
||||
"p0": (MockPolicyGraph, obs_space, act_space, {}),
|
||||
"p1": (MockPolicyGraph, obs_space, act_space, {}),
|
||||
},
|
||||
policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2),
|
||||
episode_horizon=10, # test with episode horizon set
|
||||
batch_steps=50)
|
||||
batch = ev.sample()
|
||||
self.assertEqual(batch.count, 50)
|
||||
|
||||
def testMultiAgentSampleRoundRobin(self):
|
||||
act_space = gym.spaces.Discrete(2)
|
||||
obs_space = gym.spaces.Discrete(2)
|
||||
|
||||
Reference in New Issue
Block a user