mirror of
https://github.com/wassname/ray.git
synced 2026-07-02 16:31:16 +08:00
[rllib] Add multi-agent examples for hand-coded policy, centralized VF (#4554)
This commit is contained in:
@@ -49,7 +49,9 @@ class EvaluatorInterface(object):
|
||||
>>> ev.learn_on_batch(samples)
|
||||
"""
|
||||
|
||||
return self.compute_apply(samples)
|
||||
grads, info = self.compute_gradients(samples)
|
||||
self.apply_gradients(grads)
|
||||
return info
|
||||
|
||||
@DeveloperAPI
|
||||
def compute_gradients(self, samples):
|
||||
@@ -113,14 +115,6 @@ class EvaluatorInterface(object):
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
@DeveloperAPI
|
||||
def compute_apply(self, samples):
|
||||
"""Deprecated: override learn_on_batch instead."""
|
||||
|
||||
grads, info = self.compute_gradients(samples)
|
||||
self.apply_gradients(grads)
|
||||
return info
|
||||
|
||||
@DeveloperAPI
|
||||
def get_host(self):
|
||||
"""Returns the hostname of the process running this evaluator."""
|
||||
|
||||
@@ -41,7 +41,8 @@ def get_learner_stats(grad_info):
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
def collect_metrics(local_evaluator, remote_evaluators=[],
|
||||
def collect_metrics(local_evaluator=None,
|
||||
remote_evaluators=[],
|
||||
timeout_seconds=180):
|
||||
"""Gathers episode metrics from PolicyEvaluator instances."""
|
||||
|
||||
@@ -52,7 +53,7 @@ def collect_metrics(local_evaluator, remote_evaluators=[],
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
def collect_episodes(local_evaluator,
|
||||
def collect_episodes(local_evaluator=None,
|
||||
remote_evaluators=[],
|
||||
timeout_seconds=180):
|
||||
"""Gathers new episodes metrics tuples from the given evaluators."""
|
||||
@@ -69,7 +70,8 @@ def collect_episodes(local_evaluator,
|
||||
"this timeout with `collect_metrics_timeout`.")
|
||||
|
||||
metric_lists = ray.get(collected)
|
||||
metric_lists.append(local_evaluator.get_metrics())
|
||||
if local_evaluator:
|
||||
metric_lists.append(local_evaluator.get_metrics())
|
||||
episodes = []
|
||||
for metrics in metric_lists:
|
||||
episodes.extend(metrics)
|
||||
|
||||
@@ -564,21 +564,21 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
summarize(samples)))
|
||||
if isinstance(samples, MultiAgentBatch):
|
||||
info_out = {}
|
||||
to_fetch = {}
|
||||
if self.tf_sess is not None:
|
||||
builder = TFRunBuilder(self.tf_sess, "learn_on_batch")
|
||||
for pid, batch in samples.policy_batches.items():
|
||||
if pid not in self.policies_to_train:
|
||||
continue
|
||||
info_out[pid], _ = (
|
||||
self.policy_map[pid]._build_learn_on_batch(
|
||||
builder, batch))
|
||||
info_out = {k: builder.get(v) for k, v in info_out.items()}
|
||||
else:
|
||||
for pid, batch in samples.policy_batches.items():
|
||||
if pid not in self.policies_to_train:
|
||||
continue
|
||||
info_out[pid], _ = (
|
||||
self.policy_map[pid].learn_on_batch(batch))
|
||||
builder = None
|
||||
for pid, batch in samples.policy_batches.items():
|
||||
if pid not in self.policies_to_train:
|
||||
continue
|
||||
policy = self.policy_map[pid]
|
||||
if builder and hasattr(policy, "_build_learn_on_batch"):
|
||||
to_fetch[pid], _ = policy._build_learn_on_batch(
|
||||
builder, batch)
|
||||
else:
|
||||
info_out[pid], _ = policy.learn_on_batch(batch)
|
||||
info_out.update({k: builder.get(v) for k, v in to_fetch.items()})
|
||||
else:
|
||||
info_out, _ = (
|
||||
self.policy_map[DEFAULT_POLICY_ID].learn_on_batch(samples))
|
||||
|
||||
@@ -170,7 +170,9 @@ class PolicyGraph(object):
|
||||
>>> ev.learn_on_batch(samples)
|
||||
"""
|
||||
|
||||
return self.compute_apply(samples)
|
||||
grads, grad_info = self.compute_gradients(samples)
|
||||
apply_info = self.apply_gradients(grads)
|
||||
return grad_info, apply_info
|
||||
|
||||
@DeveloperAPI
|
||||
def compute_gradients(self, postprocessed_batch):
|
||||
@@ -195,14 +197,6 @@ class PolicyGraph(object):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@DeveloperAPI
|
||||
def compute_apply(self, samples):
|
||||
"""Deprecated: override learn_on_batch instead."""
|
||||
|
||||
grads, grad_info = self.compute_gradients(samples)
|
||||
apply_info = self.apply_gradients(grads)
|
||||
return grad_info, apply_info
|
||||
|
||||
@DeveloperAPI
|
||||
def get_weights(self):
|
||||
"""Returns model weights.
|
||||
|
||||
@@ -0,0 +1,76 @@
|
||||
"""Example of running a custom hand-coded policy alongside trainable policies.
|
||||
|
||||
This example has two policies:
|
||||
(1) a simple PG policy
|
||||
(2) a hand-coded policy that acts at random in the env (doesn't learn)
|
||||
|
||||
In the console output, you can see the PG policy does much better than random:
|
||||
Result for PG_multi_cartpole_0:
|
||||
...
|
||||
policy_reward_mean:
|
||||
pg_policy: 185.23
|
||||
random: 21.255
|
||||
...
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import gym
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.rllib.evaluation import PolicyGraph
|
||||
from ray.rllib.tests.test_multi_agent_env import MultiCartpole
|
||||
from ray.tune.registry import register_env
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--num-iters", type=int, default=20)
|
||||
|
||||
|
||||
class RandomPolicy(PolicyGraph):
|
||||
"""Hand-coded policy that returns random actions."""
|
||||
|
||||
def compute_actions(self,
|
||||
obs_batch,
|
||||
state_batches,
|
||||
prev_action_batch=None,
|
||||
prev_reward_batch=None,
|
||||
info_batch=None,
|
||||
episodes=None,
|
||||
**kwargs):
|
||||
"""Compute actions on a batch of observations."""
|
||||
return [self.action_space.sample() for _ in obs_batch], [], {}
|
||||
|
||||
def learn_on_batch(self, samples):
|
||||
"""No learning."""
|
||||
return {}, {}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
ray.init()
|
||||
|
||||
# Simple environment with 4 independent cartpole entities
|
||||
register_env("multi_cartpole", lambda _: MultiCartpole(4))
|
||||
single_env = gym.make("CartPole-v0")
|
||||
obs_space = single_env.observation_space
|
||||
act_space = single_env.action_space
|
||||
|
||||
tune.run(
|
||||
"PG",
|
||||
stop={"training_iteration": args.num_iters},
|
||||
config={
|
||||
"env": "multi_cartpole",
|
||||
"multiagent": {
|
||||
"policy_graphs": {
|
||||
"pg_policy": (None, obs_space, act_space, {}),
|
||||
"random": (RandomPolicy, obs_space, act_space, {}),
|
||||
},
|
||||
"policy_mapping_fn": tune.function(
|
||||
lambda agent_id: ["pg_policy", "random"][agent_id % 2]),
|
||||
},
|
||||
},
|
||||
)
|
||||
@@ -0,0 +1,117 @@
|
||||
"""Example of using policy evaluator classes directly to implement training.
|
||||
|
||||
Instead of using the built-in Trainer classes provided by RLlib, here we define
|
||||
a custom PolicyGraph class and manually coordinate distributed sample
|
||||
collection and policy optimization.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import gym
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.rllib.evaluation import PolicyGraph, PolicyEvaluator, SampleBatch
|
||||
from ray.rllib.evaluation.metrics import collect_metrics
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--gpu", action="store_true")
|
||||
parser.add_argument("--num-iters", type=int, default=20)
|
||||
parser.add_argument("--num-workers", type=int, default=2)
|
||||
|
||||
|
||||
class CustomPolicy(PolicyGraph):
|
||||
"""Example of a custom policy graph written from scratch.
|
||||
|
||||
You might find it more convenient to extend TF/TorchPolicyGraph instead
|
||||
for a real policy.
|
||||
"""
|
||||
|
||||
def __init__(self, observation_space, action_space, config):
|
||||
PolicyGraph.__init__(self, observation_space, action_space, config)
|
||||
# example parameter
|
||||
self.w = 1.0
|
||||
|
||||
def compute_actions(self,
|
||||
obs_batch,
|
||||
state_batches,
|
||||
prev_action_batch=None,
|
||||
prev_reward_batch=None,
|
||||
info_batch=None,
|
||||
episodes=None,
|
||||
**kwargs):
|
||||
# return random actions
|
||||
return [self.action_space.sample() for _ in obs_batch], [], {}
|
||||
|
||||
def learn_on_batch(self, samples):
|
||||
# implement your learning code here
|
||||
return {}, {}
|
||||
|
||||
def update_some_value(self, w):
|
||||
# can also call other methods on policies
|
||||
self.w = w
|
||||
|
||||
def get_weights(self):
|
||||
return {"w": self.w}
|
||||
|
||||
def set_weights(self, weights):
|
||||
self.w = weights["w"]
|
||||
|
||||
|
||||
def training_workflow(config, reporter):
|
||||
# Setup policy and policy evaluation actors
|
||||
env = gym.make("CartPole-v0")
|
||||
policy = CustomPolicy(env.observation_space, env.action_space, {})
|
||||
workers = [
|
||||
PolicyEvaluator.as_remote().remote(lambda c: gym.make("CartPole-v0"),
|
||||
CustomPolicy)
|
||||
for _ in range(config["num_workers"])
|
||||
]
|
||||
|
||||
for _ in range(config["num_iters"]):
|
||||
# Broadcast weights to the policy evaluation workers
|
||||
weights = ray.put({"default_policy": policy.get_weights()})
|
||||
for w in workers:
|
||||
w.set_weights.remote(weights)
|
||||
|
||||
# Gather a batch of samples
|
||||
T1 = SampleBatch.concat_samples(
|
||||
ray.get([w.sample.remote() for w in workers]))
|
||||
|
||||
# Update the remote policy replicas and gather another batch of samples
|
||||
new_value = policy.w * 2.0
|
||||
for w in workers:
|
||||
w.for_policy.remote(lambda p: p.update_some_value(new_value))
|
||||
|
||||
# Gather another batch of samples
|
||||
T2 = SampleBatch.concat_samples(
|
||||
ray.get([w.sample.remote() for w in workers]))
|
||||
|
||||
# Improve the policy using the T1 batch
|
||||
policy.learn_on_batch(T1)
|
||||
|
||||
# Do some arbitrary updates based on the T2 batch
|
||||
policy.update_some_value(sum(T2["rewards"]))
|
||||
|
||||
reporter(**collect_metrics(remote_evaluators=workers))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
ray.init()
|
||||
|
||||
tune.run(
|
||||
training_workflow,
|
||||
resources_per_trial={
|
||||
"gpu": 1 if args.gpu else 0,
|
||||
"cpu": 1,
|
||||
"extra_cpu": args.num_workers,
|
||||
},
|
||||
config={
|
||||
"num_workers": args.num_workers,
|
||||
"num_iters": args.num_iters,
|
||||
},
|
||||
)
|
||||
Reference in New Issue
Block a user