[rllib] Allow None to be specified in multi-agent envs (#4464)

* wip

* check

* doc update

* Update hierarchical_training.py
This commit is contained in:
Eric Liang
2019-03-25 11:38:17 -07:00
committed by GitHub
parent 11580fb7dc
commit 5b8eb475ce
6 changed files with 65 additions and 50 deletions
+19 -17
View File
@@ -166,9 +166,10 @@ If all the agents will be using the same algorithm class to train, then you can
trainer = pg.PGAgent(env="my_multiagent_env", config={
"multiagent": {
"policy_graphs": {
"car1": (PGPolicyGraph, car_obs_space, car_act_space, {"gamma": 0.85}),
"car2": (PGPolicyGraph, car_obs_space, car_act_space, {"gamma": 0.99}),
"traffic_light": (PGPolicyGraph, tl_obs_space, tl_act_space, {}),
# the first tuple value is None -> uses default policy graph
"car1": (None, car_obs_space, car_act_space, {"gamma": 0.85}),
"car2": (None, car_obs_space, car_act_space, {"gamma": 0.99}),
"traffic_light": (None, tl_obs_space, tl_act_space, {}),
},
"policy_mapping_fn":
lambda agent_id:
@@ -232,9 +233,9 @@ This can be implemented as a multi-agent environment with three types of agents.
"multiagent": {
"policy_graphs": {
"top_level": (some_policy_graph, ...),
"mid_level": (some_policy_graph, ...),
"low_level": (some_policy_graph, ...),
"top_level": (custom_policy_graph or None, ...),
"mid_level": (custom_policy_graph or None, ...),
"low_level": (custom_policy_graph or None, ...),
},
"policy_mapping_fn":
lambda agent_id:
@@ -248,17 +249,6 @@ In this setup, the appropriate rewards for training lower-level agents must be p
See this file for a runnable example: `hierarchical_training.py <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/hierarchical_training.py>`__.
Grouping Agents
~~~~~~~~~~~~~~~
It is common to have groups of agents in multi-agent RL. RLlib treats agent groups like a single agent with a Tuple action and observation space. The group agent can then be assigned to a single policy for centralized execution, or to specialized multi-agent policies such as `Q-Mix <rllib-algorithms.html#qmix-monotonic-value-factorisation-qmix-vdn-iqn>`__ that implement centralized training but decentralized execution. You can use the ``MultiAgentEnv.with_agent_groups()`` method to define these groups:
.. literalinclude:: ../../python/ray/rllib/env/multi_agent_env.py
:language: python
:start-after: __grouping_doc_begin__
:end-before: __grouping_doc_end__
Variable-Sharing Between Policies
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -296,6 +286,18 @@ Implementing a centralized critic that takes as input the observations and actio
2. Updating the critic: the centralized critic loss can be added to the loss of the custom policy graph, the same as with any other value function. For an example of defining loss inputs, see the `PGPolicyGraph example <https://github.com/ray-project/ray/blob/master/python/ray/rllib/agents/pg/pg_policy_graph.py>`__.
Grouping Agents
~~~~~~~~~~~~~~~
It is common to have groups of agents in multi-agent RL. RLlib treats agent groups like a single agent with a Tuple action and observation space. The group agent can then be assigned to a single policy for centralized execution, or to specialized multi-agent policies such as `Q-Mix <rllib-algorithms.html#qmix-monotonic-value-factorisation-qmix-vdn-iqn>`__ that implement centralized training but decentralized execution. You can use the ``MultiAgentEnv.with_agent_groups()`` method to define these groups:
.. literalinclude:: ../../python/ray/rllib/env/multi_agent_env.py
:language: python
:start-after: __grouping_doc_begin__
:end-before: __grouping_doc_end__
For environments with multiple groups, or mixtures of agent groups and individual agents, you can use grouping in conjunction with the policy mapping API described in prior sections.
Interfacing with External Agents
--------------------------------
+12 -2
View File
@@ -17,7 +17,8 @@ from ray.exceptions import RayError
from ray.rllib.offline import NoopOutput, JsonReader, MixedInput, JsonWriter, \
ShuffledInput
from ray.rllib.models import MODEL_DEFAULTS
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator, \
_validate_multiagent_config
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI
@@ -682,9 +683,18 @@ class Agent(Trainable):
else:
input_evaluation = config["input_evaluation"]
# Fill in the default policy graph if 'None' is specified in multiagent
if self.config["multiagent"]["policy_graphs"]:
tmp = self.config["multiagent"]["policy_graphs"]
_validate_multiagent_config(tmp, allow_none_graph=True)
for k, v in tmp.items():
if v[0] is None:
tmp[k] = (policy_graph, v[1], v[2], v[3])
policy_graph = tmp
return cls(
env_creator,
self.config["multiagent"]["policy_graphs"] or policy_graph,
policy_graph,
policy_mapping_fn=self.config["multiagent"]["policy_mapping_fn"],
policies_to_train=self.config["multiagent"]["policies_to_train"],
tf_session_creator=(session_creator
+30 -24
View File
@@ -668,30 +668,7 @@ class PolicyEvaluator(EvaluatorInterface):
def _validate_and_canonicalize(policy_graph, env):
if isinstance(policy_graph, dict):
for k, v in policy_graph.items():
if not isinstance(k, str):
raise ValueError(
"policy_graph keys must be strs, got {}".format(type(k)))
if not isinstance(v, tuple) or len(v) != 4:
raise ValueError(
"policy_graph values must be tuples of "
"(cls, obs_space, action_space, config), got {}".format(v))
if not issubclass(v[0], PolicyGraph):
raise ValueError(
"policy_graph tuple value 0 must be a rllib.PolicyGraph "
"class, got {}".format(v[0]))
if not isinstance(v[1], gym.Space):
raise ValueError(
"policy_graph tuple value 1 (observation_space) must be a "
"gym.Space, got {}".format(type(v[1])))
if not isinstance(v[2], gym.Space):
raise ValueError(
"policy_graph tuple value 2 (action_space) must be a "
"gym.Space, got {}".format(type(v[2])))
if not isinstance(v[3], dict):
raise ValueError(
"policy_graph tuple value 3 (config) must be a dict, "
"got {}".format(type(v[3])))
_validate_multiagent_config(policy_graph)
return policy_graph
elif not issubclass(policy_graph, PolicyGraph):
raise ValueError("policy_graph must be a rllib.PolicyGraph class")
@@ -707,6 +684,35 @@ def _validate_and_canonicalize(policy_graph, env):
}
def _validate_multiagent_config(policy_graph, allow_none_graph=False):
for k, v in policy_graph.items():
if not isinstance(k, str):
raise ValueError("policy_graph keys must be strs, got {}".format(
type(k)))
if not isinstance(v, tuple) or len(v) != 4:
raise ValueError(
"policy_graph values must be tuples of "
"(cls, obs_space, action_space, config), got {}".format(v))
if allow_none_graph and v[0] is None:
pass
elif not issubclass(v[0], PolicyGraph):
raise ValueError(
"policy_graph tuple value 0 must be a rllib.PolicyGraph "
"class or None, got {}".format(v[0]))
if not isinstance(v[1], gym.Space):
raise ValueError(
"policy_graph tuple value 1 (observation_space) must be a "
"gym.Space, got {}".format(type(v[1])))
if not isinstance(v[2], gym.Space):
raise ValueError(
"policy_graph tuple value 2 (action_space) must be a "
"gym.Space, got {}".format(type(v[2])))
if not isinstance(v[3], dict):
raise ValueError(
"policy_graph tuple value 3 (config) must be a dict, "
"got {}".format(type(v[3])))
def _validate_env(env):
# allow this as a special case (assumed gym.Env)
if hasattr(env, "observation_space") and hasattr(env, "action_space"):
@@ -35,7 +35,6 @@ import logging
import ray
from ray.tune import run_experiments, function
from ray.rllib.env import MultiAgentEnv
from ray.rllib.agents.ppo import PPOAgent
parser = argparse.ArgumentParser()
parser.add_argument("--flat", action="store_true")
@@ -213,12 +212,11 @@ if __name__ == "__main__":
"entropy_coeff": 0.01,
"multiagent": {
"policy_graphs": {
"high_level_policy": (PPOAgent._policy_graph,
maze.observation_space,
"high_level_policy": (None, maze.observation_space,
Discrete(4), {
"gamma": 0.9
}),
"low_level_policy": (PPOAgent._policy_graph,
"low_level_policy": (None,
Tuple([
maze.observation_space,
Discrete(4)
@@ -21,7 +21,6 @@ import tensorflow.contrib.slim as slim
import ray
from ray import tune
from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph
from ray.rllib.models import Model, ModelCatalog
from ray.rllib.tests.test_multi_agent_env import MultiCartpole
from ray.tune import run_experiments
@@ -90,7 +89,7 @@ if __name__ == "__main__":
},
"gamma": random.choice([0.95, 0.99]),
}
return (PPOPolicyGraph, obs_space, act_space, config)
return (None, obs_space, act_space, config)
# Setup PPO with an ensemble of `num_policies` different policy graphs
policy_graphs = {
@@ -531,7 +531,7 @@ class TestMultiAgentEnv(unittest.TestCase):
}
obs_space = single_env.observation_space
act_space = single_env.action_space
return (PGPolicyGraph, obs_space, act_space, config)
return (None, obs_space, act_space, config)
pg = PGAgent(
env="multi_cartpole",