[rllib] Fix crash when setting horizon in multiagent

If a horizon is set, an env terminates without done=True.
This commit is contained in:
Eric Liang
2018-08-03 16:37:56 -07:00
committed by GitHub
parent d5dda1ebf2
commit 9449d07eca
3 changed files with 17 additions and 2 deletions
+1 -1
View File
@@ -36,7 +36,7 @@ The ``train.py`` script has a number of options you can show by running
The most important options are for choosing the environment
with ``--env`` (any OpenAI gym environment including ones registered by the user
can be used) and for choosing the algorithm with ``--run``
(available options are ``PPO``, ``PG``, ``A3C``, ``ES``, ``DDPG``, ``DDPG2``, ``DQN``, ``APEX``, and ``APEX_DDPG``).
(available options are ``PPO``, ``PG``, ``A3C``, ``IMPALA``, ``ES``, ``DDPG``, ``DQN``, ``APEX``, and ``APEX_DDPG``).
Specifying Parameters
~~~~~~~~~~~~~~~~~~~~~
+1 -1
View File
@@ -274,7 +274,7 @@ class _MultiAgentEnvToAsync(AsyncVectorEnv):
def try_reset(self, env_id):
obs = self.env_states[env_id].reset()
if obs is not None:
if obs is not None and env_id in self.dones:
self.dones.remove(env_id)
return obs
@@ -265,6 +265,21 @@ class TestMultiAgentEnv(unittest.TestCase):
self.assertEqual(batch.policy_batches["p0"]["t"].tolist(),
list(range(25)) * 6)
def testMultiAgentSampleWithHorizon(self):
act_space = gym.spaces.Discrete(2)
obs_space = gym.spaces.Discrete(2)
ev = PolicyEvaluator(
env_creator=lambda _: BasicMultiAgent(5),
policy_graph={
"p0": (MockPolicyGraph, obs_space, act_space, {}),
"p1": (MockPolicyGraph, obs_space, act_space, {}),
},
policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2),
episode_horizon=10, # test with episode horizon set
batch_steps=50)
batch = ev.sample()
self.assertEqual(batch.count, 50)
def testMultiAgentSampleRoundRobin(self):
act_space = gym.spaces.Discrete(2)
obs_space = gym.spaces.Discrete(2)