mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 01:46:10 +08:00
[rllib] Allow None to be specified in multi-agent envs (#4464)
* wip * check * doc update * Update hierarchical_training.py
This commit is contained in:
@@ -17,7 +17,8 @@ from ray.exceptions import RayError
|
||||
from ray.rllib.offline import NoopOutput, JsonReader, MixedInput, JsonWriter, \
|
||||
ShuffledInput
|
||||
from ray.rllib.models import MODEL_DEFAULTS
|
||||
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
|
||||
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator, \
|
||||
_validate_multiagent_config
|
||||
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
|
||||
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
|
||||
from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI
|
||||
@@ -682,9 +683,18 @@ class Agent(Trainable):
|
||||
else:
|
||||
input_evaluation = config["input_evaluation"]
|
||||
|
||||
# Fill in the default policy graph if 'None' is specified in multiagent
|
||||
if self.config["multiagent"]["policy_graphs"]:
|
||||
tmp = self.config["multiagent"]["policy_graphs"]
|
||||
_validate_multiagent_config(tmp, allow_none_graph=True)
|
||||
for k, v in tmp.items():
|
||||
if v[0] is None:
|
||||
tmp[k] = (policy_graph, v[1], v[2], v[3])
|
||||
policy_graph = tmp
|
||||
|
||||
return cls(
|
||||
env_creator,
|
||||
self.config["multiagent"]["policy_graphs"] or policy_graph,
|
||||
policy_graph,
|
||||
policy_mapping_fn=self.config["multiagent"]["policy_mapping_fn"],
|
||||
policies_to_train=self.config["multiagent"]["policies_to_train"],
|
||||
tf_session_creator=(session_creator
|
||||
|
||||
@@ -668,30 +668,7 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
|
||||
def _validate_and_canonicalize(policy_graph, env):
|
||||
if isinstance(policy_graph, dict):
|
||||
for k, v in policy_graph.items():
|
||||
if not isinstance(k, str):
|
||||
raise ValueError(
|
||||
"policy_graph keys must be strs, got {}".format(type(k)))
|
||||
if not isinstance(v, tuple) or len(v) != 4:
|
||||
raise ValueError(
|
||||
"policy_graph values must be tuples of "
|
||||
"(cls, obs_space, action_space, config), got {}".format(v))
|
||||
if not issubclass(v[0], PolicyGraph):
|
||||
raise ValueError(
|
||||
"policy_graph tuple value 0 must be a rllib.PolicyGraph "
|
||||
"class, got {}".format(v[0]))
|
||||
if not isinstance(v[1], gym.Space):
|
||||
raise ValueError(
|
||||
"policy_graph tuple value 1 (observation_space) must be a "
|
||||
"gym.Space, got {}".format(type(v[1])))
|
||||
if not isinstance(v[2], gym.Space):
|
||||
raise ValueError(
|
||||
"policy_graph tuple value 2 (action_space) must be a "
|
||||
"gym.Space, got {}".format(type(v[2])))
|
||||
if not isinstance(v[3], dict):
|
||||
raise ValueError(
|
||||
"policy_graph tuple value 3 (config) must be a dict, "
|
||||
"got {}".format(type(v[3])))
|
||||
_validate_multiagent_config(policy_graph)
|
||||
return policy_graph
|
||||
elif not issubclass(policy_graph, PolicyGraph):
|
||||
raise ValueError("policy_graph must be a rllib.PolicyGraph class")
|
||||
@@ -707,6 +684,35 @@ def _validate_and_canonicalize(policy_graph, env):
|
||||
}
|
||||
|
||||
|
||||
def _validate_multiagent_config(policy_graph, allow_none_graph=False):
|
||||
for k, v in policy_graph.items():
|
||||
if not isinstance(k, str):
|
||||
raise ValueError("policy_graph keys must be strs, got {}".format(
|
||||
type(k)))
|
||||
if not isinstance(v, tuple) or len(v) != 4:
|
||||
raise ValueError(
|
||||
"policy_graph values must be tuples of "
|
||||
"(cls, obs_space, action_space, config), got {}".format(v))
|
||||
if allow_none_graph and v[0] is None:
|
||||
pass
|
||||
elif not issubclass(v[0], PolicyGraph):
|
||||
raise ValueError(
|
||||
"policy_graph tuple value 0 must be a rllib.PolicyGraph "
|
||||
"class or None, got {}".format(v[0]))
|
||||
if not isinstance(v[1], gym.Space):
|
||||
raise ValueError(
|
||||
"policy_graph tuple value 1 (observation_space) must be a "
|
||||
"gym.Space, got {}".format(type(v[1])))
|
||||
if not isinstance(v[2], gym.Space):
|
||||
raise ValueError(
|
||||
"policy_graph tuple value 2 (action_space) must be a "
|
||||
"gym.Space, got {}".format(type(v[2])))
|
||||
if not isinstance(v[3], dict):
|
||||
raise ValueError(
|
||||
"policy_graph tuple value 3 (config) must be a dict, "
|
||||
"got {}".format(type(v[3])))
|
||||
|
||||
|
||||
def _validate_env(env):
|
||||
# allow this as a special case (assumed gym.Env)
|
||||
if hasattr(env, "observation_space") and hasattr(env, "action_space"):
|
||||
|
||||
@@ -35,7 +35,6 @@ import logging
|
||||
import ray
|
||||
from ray.tune import run_experiments, function
|
||||
from ray.rllib.env import MultiAgentEnv
|
||||
from ray.rllib.agents.ppo import PPOAgent
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--flat", action="store_true")
|
||||
@@ -213,12 +212,11 @@ if __name__ == "__main__":
|
||||
"entropy_coeff": 0.01,
|
||||
"multiagent": {
|
||||
"policy_graphs": {
|
||||
"high_level_policy": (PPOAgent._policy_graph,
|
||||
maze.observation_space,
|
||||
"high_level_policy": (None, maze.observation_space,
|
||||
Discrete(4), {
|
||||
"gamma": 0.9
|
||||
}),
|
||||
"low_level_policy": (PPOAgent._policy_graph,
|
||||
"low_level_policy": (None,
|
||||
Tuple([
|
||||
maze.observation_space,
|
||||
Discrete(4)
|
||||
|
||||
@@ -21,7 +21,6 @@ import tensorflow.contrib.slim as slim
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph
|
||||
from ray.rllib.models import Model, ModelCatalog
|
||||
from ray.rllib.tests.test_multi_agent_env import MultiCartpole
|
||||
from ray.tune import run_experiments
|
||||
@@ -90,7 +89,7 @@ if __name__ == "__main__":
|
||||
},
|
||||
"gamma": random.choice([0.95, 0.99]),
|
||||
}
|
||||
return (PPOPolicyGraph, obs_space, act_space, config)
|
||||
return (None, obs_space, act_space, config)
|
||||
|
||||
# Setup PPO with an ensemble of `num_policies` different policy graphs
|
||||
policy_graphs = {
|
||||
|
||||
@@ -531,7 +531,7 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
}
|
||||
obs_space = single_env.observation_space
|
||||
act_space = single_env.action_space
|
||||
return (PGPolicyGraph, obs_space, act_space, config)
|
||||
return (None, obs_space, act_space, config)
|
||||
|
||||
pg = PGAgent(
|
||||
env="multi_cartpole",
|
||||
|
||||
Reference in New Issue
Block a user