[rllib] Allow None to be specified in multi-agent envs (#4464)

* wip

* check

* doc update

* Update hierarchical_training.py
This commit is contained in:
Eric Liang
2019-03-25 11:38:17 -07:00
committed by GitHub
parent 11580fb7dc
commit 5b8eb475ce
6 changed files with 65 additions and 50 deletions
+12 -2
View File
@@ -17,7 +17,8 @@ from ray.exceptions import RayError
from ray.rllib.offline import NoopOutput, JsonReader, MixedInput, JsonWriter, \
ShuffledInput
from ray.rllib.models import MODEL_DEFAULTS
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator, \
_validate_multiagent_config
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI
@@ -682,9 +683,18 @@ class Agent(Trainable):
else:
input_evaluation = config["input_evaluation"]
# Fill in the default policy graph if 'None' is specified in multiagent
if self.config["multiagent"]["policy_graphs"]:
tmp = self.config["multiagent"]["policy_graphs"]
_validate_multiagent_config(tmp, allow_none_graph=True)
for k, v in tmp.items():
if v[0] is None:
tmp[k] = (policy_graph, v[1], v[2], v[3])
policy_graph = tmp
return cls(
env_creator,
self.config["multiagent"]["policy_graphs"] or policy_graph,
policy_graph,
policy_mapping_fn=self.config["multiagent"]["policy_mapping_fn"],
policies_to_train=self.config["multiagent"]["policies_to_train"],
tf_session_creator=(session_creator
+30 -24
View File
@@ -668,30 +668,7 @@ class PolicyEvaluator(EvaluatorInterface):
def _validate_and_canonicalize(policy_graph, env):
if isinstance(policy_graph, dict):
for k, v in policy_graph.items():
if not isinstance(k, str):
raise ValueError(
"policy_graph keys must be strs, got {}".format(type(k)))
if not isinstance(v, tuple) or len(v) != 4:
raise ValueError(
"policy_graph values must be tuples of "
"(cls, obs_space, action_space, config), got {}".format(v))
if not issubclass(v[0], PolicyGraph):
raise ValueError(
"policy_graph tuple value 0 must be a rllib.PolicyGraph "
"class, got {}".format(v[0]))
if not isinstance(v[1], gym.Space):
raise ValueError(
"policy_graph tuple value 1 (observation_space) must be a "
"gym.Space, got {}".format(type(v[1])))
if not isinstance(v[2], gym.Space):
raise ValueError(
"policy_graph tuple value 2 (action_space) must be a "
"gym.Space, got {}".format(type(v[2])))
if not isinstance(v[3], dict):
raise ValueError(
"policy_graph tuple value 3 (config) must be a dict, "
"got {}".format(type(v[3])))
_validate_multiagent_config(policy_graph)
return policy_graph
elif not issubclass(policy_graph, PolicyGraph):
raise ValueError("policy_graph must be a rllib.PolicyGraph class")
@@ -707,6 +684,35 @@ def _validate_and_canonicalize(policy_graph, env):
}
def _validate_multiagent_config(policy_graph, allow_none_graph=False):
for k, v in policy_graph.items():
if not isinstance(k, str):
raise ValueError("policy_graph keys must be strs, got {}".format(
type(k)))
if not isinstance(v, tuple) or len(v) != 4:
raise ValueError(
"policy_graph values must be tuples of "
"(cls, obs_space, action_space, config), got {}".format(v))
if allow_none_graph and v[0] is None:
pass
elif not issubclass(v[0], PolicyGraph):
raise ValueError(
"policy_graph tuple value 0 must be a rllib.PolicyGraph "
"class or None, got {}".format(v[0]))
if not isinstance(v[1], gym.Space):
raise ValueError(
"policy_graph tuple value 1 (observation_space) must be a "
"gym.Space, got {}".format(type(v[1])))
if not isinstance(v[2], gym.Space):
raise ValueError(
"policy_graph tuple value 2 (action_space) must be a "
"gym.Space, got {}".format(type(v[2])))
if not isinstance(v[3], dict):
raise ValueError(
"policy_graph tuple value 3 (config) must be a dict, "
"got {}".format(type(v[3])))
def _validate_env(env):
# allow this as a special case (assumed gym.Env)
if hasattr(env, "observation_space") and hasattr(env, "action_space"):
@@ -35,7 +35,6 @@ import logging
import ray
from ray.tune import run_experiments, function
from ray.rllib.env import MultiAgentEnv
from ray.rllib.agents.ppo import PPOAgent
parser = argparse.ArgumentParser()
parser.add_argument("--flat", action="store_true")
@@ -213,12 +212,11 @@ if __name__ == "__main__":
"entropy_coeff": 0.01,
"multiagent": {
"policy_graphs": {
"high_level_policy": (PPOAgent._policy_graph,
maze.observation_space,
"high_level_policy": (None, maze.observation_space,
Discrete(4), {
"gamma": 0.9
}),
"low_level_policy": (PPOAgent._policy_graph,
"low_level_policy": (None,
Tuple([
maze.observation_space,
Discrete(4)
@@ -21,7 +21,6 @@ import tensorflow.contrib.slim as slim
import ray
from ray import tune
from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph
from ray.rllib.models import Model, ModelCatalog
from ray.rllib.tests.test_multi_agent_env import MultiCartpole
from ray.tune import run_experiments
@@ -90,7 +89,7 @@ if __name__ == "__main__":
},
"gamma": random.choice([0.95, 0.99]),
}
return (PPOPolicyGraph, obs_space, act_space, config)
return (None, obs_space, act_space, config)
# Setup PPO with an ensemble of `num_policies` different policy graphs
policy_graphs = {
@@ -531,7 +531,7 @@ class TestMultiAgentEnv(unittest.TestCase):
}
obs_space = single_env.observation_space
act_space = single_env.action_space
return (PGPolicyGraph, obs_space, act_space, config)
return (None, obs_space, act_space, config)
pg = PGAgent(
env="multi_cartpole",