From 5b8eb475ced0bf41fe9557f9c657aabbee08f96d Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Mon, 25 Mar 2019 11:38:17 -0700 Subject: [PATCH] [rllib] Allow None to be specified in multi-agent envs (#4464) * wip * check * doc update * Update hierarchical_training.py --- doc/source/rllib-env.rst | 36 +++++++------ python/ray/rllib/agents/agent.py | 14 ++++- .../ray/rllib/evaluation/policy_evaluator.py | 54 ++++++++++--------- .../rllib/examples/hierarchical_training.py | 6 +-- .../ray/rllib/examples/multiagent_cartpole.py | 3 +- .../ray/rllib/tests/test_multi_agent_env.py | 2 +- 6 files changed, 65 insertions(+), 50 deletions(-) diff --git a/doc/source/rllib-env.rst b/doc/source/rllib-env.rst index 7180099fc..f2a6c5051 100644 --- a/doc/source/rllib-env.rst +++ b/doc/source/rllib-env.rst @@ -166,9 +166,10 @@ If all the agents will be using the same algorithm class to train, then you can trainer = pg.PGAgent(env="my_multiagent_env", config={ "multiagent": { "policy_graphs": { - "car1": (PGPolicyGraph, car_obs_space, car_act_space, {"gamma": 0.85}), - "car2": (PGPolicyGraph, car_obs_space, car_act_space, {"gamma": 0.99}), - "traffic_light": (PGPolicyGraph, tl_obs_space, tl_act_space, {}), + # the first tuple value is None -> uses default policy graph + "car1": (None, car_obs_space, car_act_space, {"gamma": 0.85}), + "car2": (None, car_obs_space, car_act_space, {"gamma": 0.99}), + "traffic_light": (None, tl_obs_space, tl_act_space, {}), }, "policy_mapping_fn": lambda agent_id: @@ -232,9 +233,9 @@ This can be implemented as a multi-agent environment with three types of agents. "multiagent": { "policy_graphs": { - "top_level": (some_policy_graph, ...), - "mid_level": (some_policy_graph, ...), - "low_level": (some_policy_graph, ...), + "top_level": (custom_policy_graph or None, ...), + "mid_level": (custom_policy_graph or None, ...), + "low_level": (custom_policy_graph or None, ...), }, "policy_mapping_fn": lambda agent_id: @@ -248,17 +249,6 @@ In this setup, the appropriate rewards for training lower-level agents must be p See this file for a runnable example: `hierarchical_training.py `__. - -Grouping Agents -~~~~~~~~~~~~~~~ - -It is common to have groups of agents in multi-agent RL. RLlib treats agent groups like a single agent with a Tuple action and observation space. The group agent can then be assigned to a single policy for centralized execution, or to specialized multi-agent policies such as `Q-Mix `__ that implement centralized training but decentralized execution. You can use the ``MultiAgentEnv.with_agent_groups()`` method to define these groups: - -.. literalinclude:: ../../python/ray/rllib/env/multi_agent_env.py - :language: python - :start-after: __grouping_doc_begin__ - :end-before: __grouping_doc_end__ - Variable-Sharing Between Policies ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -296,6 +286,18 @@ Implementing a centralized critic that takes as input the observations and actio 2. Updating the critic: the centralized critic loss can be added to the loss of the custom policy graph, the same as with any other value function. For an example of defining loss inputs, see the `PGPolicyGraph example `__. +Grouping Agents +~~~~~~~~~~~~~~~ + +It is common to have groups of agents in multi-agent RL. RLlib treats agent groups like a single agent with a Tuple action and observation space. The group agent can then be assigned to a single policy for centralized execution, or to specialized multi-agent policies such as `Q-Mix `__ that implement centralized training but decentralized execution. You can use the ``MultiAgentEnv.with_agent_groups()`` method to define these groups: + +.. literalinclude:: ../../python/ray/rllib/env/multi_agent_env.py + :language: python + :start-after: __grouping_doc_begin__ + :end-before: __grouping_doc_end__ + +For environments with multiple groups, or mixtures of agent groups and individual agents, you can use grouping in conjunction with the policy mapping API described in prior sections. + Interfacing with External Agents -------------------------------- diff --git a/python/ray/rllib/agents/agent.py b/python/ray/rllib/agents/agent.py index 8c0a1bb7f..596e4df7a 100644 --- a/python/ray/rllib/agents/agent.py +++ b/python/ray/rllib/agents/agent.py @@ -17,7 +17,8 @@ from ray.exceptions import RayError from ray.rllib.offline import NoopOutput, JsonReader, MixedInput, JsonWriter, \ ShuffledInput from ray.rllib.models import MODEL_DEFAULTS -from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator +from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator, \ + _validate_multiagent_config from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI @@ -682,9 +683,18 @@ class Agent(Trainable): else: input_evaluation = config["input_evaluation"] + # Fill in the default policy graph if 'None' is specified in multiagent + if self.config["multiagent"]["policy_graphs"]: + tmp = self.config["multiagent"]["policy_graphs"] + _validate_multiagent_config(tmp, allow_none_graph=True) + for k, v in tmp.items(): + if v[0] is None: + tmp[k] = (policy_graph, v[1], v[2], v[3]) + policy_graph = tmp + return cls( env_creator, - self.config["multiagent"]["policy_graphs"] or policy_graph, + policy_graph, policy_mapping_fn=self.config["multiagent"]["policy_mapping_fn"], policies_to_train=self.config["multiagent"]["policies_to_train"], tf_session_creator=(session_creator diff --git a/python/ray/rllib/evaluation/policy_evaluator.py b/python/ray/rllib/evaluation/policy_evaluator.py index e83b6a209..c7389b3e4 100644 --- a/python/ray/rllib/evaluation/policy_evaluator.py +++ b/python/ray/rllib/evaluation/policy_evaluator.py @@ -668,30 +668,7 @@ class PolicyEvaluator(EvaluatorInterface): def _validate_and_canonicalize(policy_graph, env): if isinstance(policy_graph, dict): - for k, v in policy_graph.items(): - if not isinstance(k, str): - raise ValueError( - "policy_graph keys must be strs, got {}".format(type(k))) - if not isinstance(v, tuple) or len(v) != 4: - raise ValueError( - "policy_graph values must be tuples of " - "(cls, obs_space, action_space, config), got {}".format(v)) - if not issubclass(v[0], PolicyGraph): - raise ValueError( - "policy_graph tuple value 0 must be a rllib.PolicyGraph " - "class, got {}".format(v[0])) - if not isinstance(v[1], gym.Space): - raise ValueError( - "policy_graph tuple value 1 (observation_space) must be a " - "gym.Space, got {}".format(type(v[1]))) - if not isinstance(v[2], gym.Space): - raise ValueError( - "policy_graph tuple value 2 (action_space) must be a " - "gym.Space, got {}".format(type(v[2]))) - if not isinstance(v[3], dict): - raise ValueError( - "policy_graph tuple value 3 (config) must be a dict, " - "got {}".format(type(v[3]))) + _validate_multiagent_config(policy_graph) return policy_graph elif not issubclass(policy_graph, PolicyGraph): raise ValueError("policy_graph must be a rllib.PolicyGraph class") @@ -707,6 +684,35 @@ def _validate_and_canonicalize(policy_graph, env): } +def _validate_multiagent_config(policy_graph, allow_none_graph=False): + for k, v in policy_graph.items(): + if not isinstance(k, str): + raise ValueError("policy_graph keys must be strs, got {}".format( + type(k))) + if not isinstance(v, tuple) or len(v) != 4: + raise ValueError( + "policy_graph values must be tuples of " + "(cls, obs_space, action_space, config), got {}".format(v)) + if allow_none_graph and v[0] is None: + pass + elif not issubclass(v[0], PolicyGraph): + raise ValueError( + "policy_graph tuple value 0 must be a rllib.PolicyGraph " + "class or None, got {}".format(v[0])) + if not isinstance(v[1], gym.Space): + raise ValueError( + "policy_graph tuple value 1 (observation_space) must be a " + "gym.Space, got {}".format(type(v[1]))) + if not isinstance(v[2], gym.Space): + raise ValueError( + "policy_graph tuple value 2 (action_space) must be a " + "gym.Space, got {}".format(type(v[2]))) + if not isinstance(v[3], dict): + raise ValueError( + "policy_graph tuple value 3 (config) must be a dict, " + "got {}".format(type(v[3]))) + + def _validate_env(env): # allow this as a special case (assumed gym.Env) if hasattr(env, "observation_space") and hasattr(env, "action_space"): diff --git a/python/ray/rllib/examples/hierarchical_training.py b/python/ray/rllib/examples/hierarchical_training.py index b55ee78df..2cb25cbbf 100644 --- a/python/ray/rllib/examples/hierarchical_training.py +++ b/python/ray/rllib/examples/hierarchical_training.py @@ -35,7 +35,6 @@ import logging import ray from ray.tune import run_experiments, function from ray.rllib.env import MultiAgentEnv -from ray.rllib.agents.ppo import PPOAgent parser = argparse.ArgumentParser() parser.add_argument("--flat", action="store_true") @@ -213,12 +212,11 @@ if __name__ == "__main__": "entropy_coeff": 0.01, "multiagent": { "policy_graphs": { - "high_level_policy": (PPOAgent._policy_graph, - maze.observation_space, + "high_level_policy": (None, maze.observation_space, Discrete(4), { "gamma": 0.9 }), - "low_level_policy": (PPOAgent._policy_graph, + "low_level_policy": (None, Tuple([ maze.observation_space, Discrete(4) diff --git a/python/ray/rllib/examples/multiagent_cartpole.py b/python/ray/rllib/examples/multiagent_cartpole.py index a00140532..bab549a41 100644 --- a/python/ray/rllib/examples/multiagent_cartpole.py +++ b/python/ray/rllib/examples/multiagent_cartpole.py @@ -21,7 +21,6 @@ import tensorflow.contrib.slim as slim import ray from ray import tune -from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph from ray.rllib.models import Model, ModelCatalog from ray.rllib.tests.test_multi_agent_env import MultiCartpole from ray.tune import run_experiments @@ -90,7 +89,7 @@ if __name__ == "__main__": }, "gamma": random.choice([0.95, 0.99]), } - return (PPOPolicyGraph, obs_space, act_space, config) + return (None, obs_space, act_space, config) # Setup PPO with an ensemble of `num_policies` different policy graphs policy_graphs = { diff --git a/python/ray/rllib/tests/test_multi_agent_env.py b/python/ray/rllib/tests/test_multi_agent_env.py index 6eeca3ef2..1c9d32e8f 100644 --- a/python/ray/rllib/tests/test_multi_agent_env.py +++ b/python/ray/rllib/tests/test_multi_agent_env.py @@ -531,7 +531,7 @@ class TestMultiAgentEnv(unittest.TestCase): } obs_space = single_env.observation_space act_space = single_env.action_space - return (PGPolicyGraph, obs_space, act_space, config) + return (None, obs_space, act_space, config) pg = PGAgent( env="multi_cartpole",