mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 18:39:41 +08:00
[rllib] Raise an error if multi-agent envs terminate without a last observation for agents (#4139)
* fix it
* lint
* Update rllib-training.rst
This commit is contained in:
Vendored
+4
@@ -325,6 +325,10 @@ class _MultiAgentEnvToBaseEnv(BaseEnv):
|
||||
if set(infos).difference(set(obs)):
|
||||
raise ValueError("Key set for infos must be a subset of obs: "
|
||||
"{} vs {}".format(infos.keys(), obs.keys()))
|
||||
if "__all__" not in dones:
|
||||
raise ValueError(
|
||||
"In multi-agent environments, '__all__': True|False must "
|
||||
"be included in the 'done' dict: got {}.".format(dones))
|
||||
if dones["__all__"]:
|
||||
self.dones.add(env_id)
|
||||
self.env_states[env_id].observe(obs, rewards, dones, infos)
|
||||
|
||||
@@ -152,6 +152,17 @@ class MultiAgentSampleBatchBuilder(object):
|
||||
self.agent_builders.clear()
|
||||
self.agent_to_policy.clear()
|
||||
|
||||
def check_missing_dones(self):
    """Verify that every tracked agent's trajectory ends with a done flag.

    Iterates over ``self.agent_builders`` and inspects the most recent
    entry of each builder's ``buffers["dones"]`` sequence.

    Raises:
        ValueError: if the last recorded done value for any agent is
            falsy, or if no done value has been recorded for it at all.
            The message names the agent and the policy it maps to in
            ``self.agent_to_policy``.
    """
    for agent_id, builder in self.agent_builders.items():
        dones = builder.buffers["dones"]
        # Use truthiness rather than `is True`: an identity check rejects
        # truthy flags that are not the `True` singleton (for example
        # numpy.bool_(True)), which would be reported as a missing
        # observation even though the agent was marked done.
        if len(dones) > 0 and dones[-1]:
            continue
        raise ValueError(
            "The environment terminated for all agents, but we still "
            "don't have a last observation for "
            "agent {} (policy {}). ".format(
                agent_id, self.agent_to_policy[agent_id]) +
            "Please ensure that you include the last observations "
            "of all live agents when setting '__all__' done to True.")
|
||||
|
||||
@DeveloperAPI
|
||||
def build_and_reset(self, episode):
|
||||
"""Returns the accumulated sample batches for each policy.
|
||||
|
||||
@@ -400,6 +400,8 @@ def _process_observations(base_env, policies, batch_builder_pool,
|
||||
# Cut the batch if we're not packing multiple episodes into one,
|
||||
# or if we've exceeded the requested batch size.
|
||||
if episode.batch_builder.has_pending_data():
|
||||
if dones[env_id]["__all__"]:
|
||||
episode.batch_builder.check_missing_dones()
|
||||
if (all_done and not pack) or \
|
||||
episode.batch_builder.count >= unroll_length:
|
||||
outputs.append(episode.batch_builder.build_and_reset(episode))
|
||||
|
||||
@@ -13,7 +13,7 @@ from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--stop", type=int, default=50000)
|
||||
parser.add_argument("--run", type=str, default="QMIX")
|
||||
parser.add_argument("--run", type=str, default="PG")
|
||||
|
||||
|
||||
class TwoStepGame(MultiAgentEnv):
|
||||
@@ -86,6 +86,7 @@ if __name__ == "__main__":
|
||||
"num_workers": 0,
|
||||
"mixer": grid_search([None, "qmix", "vdn"]),
|
||||
}
|
||||
group = True
|
||||
elif args.run == "APEX_QMIX":
|
||||
config = {
|
||||
"num_gpus": 0,
|
||||
@@ -101,14 +102,16 @@ if __name__ == "__main__":
|
||||
"target_network_update_freq": 500,
|
||||
"timesteps_per_iteration": 1000,
|
||||
}
|
||||
group = True
|
||||
else:
|
||||
config = {}
|
||||
group = False
|
||||
|
||||
ray.init()
|
||||
run_experiments({
|
||||
"two_step": {
|
||||
"run": args.run,
|
||||
"env": "grouped_twostep",
|
||||
"env": "grouped_twostep" if group else TwoStepGame,
|
||||
"stop": {
|
||||
"timesteps_total": args.stop,
|
||||
},
|
||||
|
||||
@@ -53,6 +53,53 @@ class BasicMultiAgent(MultiAgentEnv):
|
||||
return obs, rew, done, info
|
||||
|
||||
|
||||
class EarlyDoneMultiAgent(MultiAgentEnv):
    """Two-agent test env that reports '__all__' done early.

    Agents are served round-robin: each call to ``reset``/``step`` returns
    data for a single agent and then moves on to the next one. ``step``
    sets ``"__all__"`` to True as soon as the number of finished agents
    equals ``len(self.agents) - 1``, regardless of whether the remaining
    agent has finished.
    """

    def __init__(self):
        self.agents = [MockEnv(3), MockEnv(5)]
        self._clear_episode_state()
        self.observation_space = gym.spaces.Discrete(10)
        self.action_space = gym.spaces.Discrete(2)

    def _clear_episode_state(self):
        # Rebind (rather than clear in place) all per-episode bookkeeping.
        self.dones = set()
        self.last_obs = {}
        self.last_rew = {}
        self.last_done = {}
        self.last_info = {}
        self.i = 0

    def _advance_turn(self):
        # Move the round-robin cursor to the next agent, wrapping around.
        self.i = (self.i + 1) % len(self.agents)

    def reset(self):
        """Reset every sub-env and return the first agent's observation."""
        self._clear_episode_state()
        for idx, sub_env in enumerate(self.agents):
            self.last_obs[idx] = sub_env.reset()
            self.last_rew[idx] = None
            self.last_done[idx] = False
            self.last_info[idx] = {}
        first = self.i
        first_obs = {first: self.last_obs[first]}
        self._advance_turn()
        return first_obs

    def step(self, action_dict):
        """Apply the given actions and report on the current agent only.

        Returns:
            Tuple of (obs, rew, done, info) dicts keyed by the current
            agent index; ``done`` additionally carries the "__all__" key.
        """
        assert len(self.dones) != len(self.agents)
        for idx, action in action_dict.items():
            ob, reward, finished, extra = self.agents[idx].step(action)
            self.last_obs[idx] = ob
            self.last_rew[idx] = reward
            self.last_done[idx] = finished
            self.last_info[idx] = extra

        cur = self.i
        obs = {cur: self.last_obs[cur]}
        rew = {cur: self.last_rew[cur]}
        done = {cur: self.last_done[cur]}
        info = {cur: self.last_info[cur]}
        if done[cur]:
            # A finished agent reports zero reward and is remembered.
            rew[cur] = 0
            self.dones.add(cur)

        self._advance_turn()
        # Terminate the episode once all but one agent have finished.
        done["__all__"] = (len(self.agents) - 1) == len(self.dones)
        return obs, rew, done, info
|
||||
|
||||
|
||||
class RoundRobinMultiAgent(MultiAgentEnv):
|
||||
"""Env of N independent agents, each of which exits after 5 steps.
|
||||
|
||||
@@ -302,6 +349,22 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
batch = ev.sample()
|
||||
self.assertEqual(batch.count, 50)
|
||||
|
||||
def testSampleFromEarlyDoneEnv(self):
    """Sampling from EarlyDoneMultiAgent must raise the missing-obs error."""
    act_space = gym.spaces.Discrete(2)
    obs_space = gym.spaces.Discrete(2)

    def make_env(_):
        return EarlyDoneMultiAgent()

    def map_agent_to_policy(agent_id):
        # Even agent ids use "p0", odd agent ids use "p1".
        return "p{}".format(agent_id % 2)

    policies = {
        name: (MockPolicyGraph, obs_space, act_space, {})
        for name in ("p0", "p1")
    }
    ev = PolicyEvaluator(
        env_creator=make_env,
        policy_graph=policies,
        policy_mapping_fn=map_agent_to_policy,
        batch_mode="complete_episodes",
        batch_steps=1)

    with self.assertRaisesRegexp(ValueError,
                                 ".*don't have a last observation.*"):
        ev.sample()
|
||||
|
||||
def testMultiAgentSampleRoundRobin(self):
|
||||
act_space = gym.spaces.Discrete(2)
|
||||
obs_space = gym.spaces.Discrete(10)
|
||||
|
||||
Reference in New Issue
Block a user