[rllib] Raise an error if multi-agent envs terminate without a last observation for agents (#4139)

* fix it

* lint

* Update rllib-training.rst
This commit is contained in:
Eric Liang
2019-02-23 21:23:40 -08:00
committed by GitHub
parent 688a0d17e6
commit 05d96ce81b
6 changed files with 99 additions and 2 deletions
+4
View File
@@ -325,6 +325,10 @@ class _MultiAgentEnvToBaseEnv(BaseEnv):
if set(infos).difference(set(obs)):
raise ValueError("Key set for infos must be a subset of obs: "
"{} vs {}".format(infos.keys(), obs.keys()))
if "__all__" not in dones:
raise ValueError(
"In multi-agent environments, '__all__': True|False must "
"be included in the 'done' dict: got {}.".format(dones))
if dones["__all__"]:
self.dones.add(env_id)
self.env_states[env_id].observe(obs, rewards, dones, infos)
@@ -152,6 +152,17 @@ class MultiAgentSampleBatchBuilder(object):
self.agent_builders.clear()
self.agent_to_policy.clear()
def check_missing_dones(self):
for agent_id, builder in self.agent_builders.items():
if builder.buffers["dones"][-1] is not True:
raise ValueError(
"The environment terminated for all agents, but we still "
"don't have a last observation for "
"agent {} (policy {}). ".format(
agent_id, self.agent_to_policy[agent_id]) +
"Please ensure that you include the last observations "
"of all live agents when setting '__all__' done to True.")
@DeveloperAPI
def build_and_reset(self, episode):
"""Returns the accumulated sample batches for each policy.
+2
View File
@@ -400,6 +400,8 @@ def _process_observations(base_env, policies, batch_builder_pool,
# Cut the batch if we're not packing multiple episodes into one,
# or if we've exceeded the requested batch size.
if episode.batch_builder.has_pending_data():
if dones[env_id]["__all__"]:
episode.batch_builder.check_missing_dones()
if (all_done and not pack) or \
episode.batch_builder.count >= unroll_length:
outputs.append(episode.batch_builder.build_and_reset(episode))
+5 -2
View File
@@ -13,7 +13,7 @@ from ray.rllib.env.multi_agent_env import MultiAgentEnv
parser = argparse.ArgumentParser()
parser.add_argument("--stop", type=int, default=50000)
parser.add_argument("--run", type=str, default="QMIX")
parser.add_argument("--run", type=str, default="PG")
class TwoStepGame(MultiAgentEnv):
@@ -86,6 +86,7 @@ if __name__ == "__main__":
"num_workers": 0,
"mixer": grid_search([None, "qmix", "vdn"]),
}
group = True
elif args.run == "APEX_QMIX":
config = {
"num_gpus": 0,
@@ -101,14 +102,16 @@ if __name__ == "__main__":
"target_network_update_freq": 500,
"timesteps_per_iteration": 1000,
}
group = True
else:
config = {}
group = False
ray.init()
run_experiments({
"two_step": {
"run": args.run,
"env": "grouped_twostep",
"env": "grouped_twostep" if group else TwoStepGame,
"stop": {
"timesteps_total": args.stop,
},
@@ -53,6 +53,53 @@ class BasicMultiAgent(MultiAgentEnv):
return obs, rew, done, info
class EarlyDoneMultiAgent(MultiAgentEnv):
"""Env for testing when the env terminates (after agent 0 does)."""
def __init__(self):
self.agents = [MockEnv(3), MockEnv(5)]
self.dones = set()
self.last_obs = {}
self.last_rew = {}
self.last_done = {}
self.last_info = {}
self.i = 0
self.observation_space = gym.spaces.Discrete(10)
self.action_space = gym.spaces.Discrete(2)
def reset(self):
self.dones = set()
self.last_obs = {}
self.last_rew = {}
self.last_done = {}
self.last_info = {}
self.i = 0
for i, a in enumerate(self.agents):
self.last_obs[i] = a.reset()
self.last_rew[i] = None
self.last_done[i] = False
self.last_info[i] = {}
obs_dict = {self.i: self.last_obs[self.i]}
self.i = (self.i + 1) % len(self.agents)
return obs_dict
def step(self, action_dict):
assert len(self.dones) != len(self.agents)
for i, action in action_dict.items():
(self.last_obs[i], self.last_rew[i], self.last_done[i],
self.last_info[i]) = self.agents[i].step(action)
obs = {self.i: self.last_obs[self.i]}
rew = {self.i: self.last_rew[self.i]}
done = {self.i: self.last_done[self.i]}
info = {self.i: self.last_info[self.i]}
if done[self.i]:
rew[self.i] = 0
self.dones.add(self.i)
self.i = (self.i + 1) % len(self.agents)
done["__all__"] = len(self.dones) == len(self.agents) - 1
return obs, rew, done, info
class RoundRobinMultiAgent(MultiAgentEnv):
"""Env of N independent agents, each of which exits after 5 steps.
@@ -302,6 +349,22 @@ class TestMultiAgentEnv(unittest.TestCase):
batch = ev.sample()
self.assertEqual(batch.count, 50)
def testSampleFromEarlyDoneEnv(self):
act_space = gym.spaces.Discrete(2)
obs_space = gym.spaces.Discrete(2)
ev = PolicyEvaluator(
env_creator=lambda _: EarlyDoneMultiAgent(),
policy_graph={
"p0": (MockPolicyGraph, obs_space, act_space, {}),
"p1": (MockPolicyGraph, obs_space, act_space, {}),
},
policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2),
batch_mode="complete_episodes",
batch_steps=1)
self.assertRaisesRegexp(ValueError,
".*don't have a last observation.*",
lambda: ev.sample())
def testMultiAgentSampleRoundRobin(self):
act_space = gym.spaces.Discrete(2)
obs_space = gym.spaces.Discrete(10)