mirror of
https://github.com/wassname/ray.git
synced 2026-07-05 05:39:30 +08:00
[rllib] Part 2 of multiagent support (#2286)
* wip * cls * re * wip * wip * a3c working * torch support * pg works * lint * rm v2 * consumer id * clean up pg * clean up more * fix python 2.7 * tf session management * docs * dqn wip * fix compile * dqn * apex runs * up * impotrs * ddpg * quotes * fix tests * fix last r * fix tests * lint * pass checkpoint restore * kwar * nits * policy graph * fix yapf * com * class * pyt * vectorization * update * test cpe * unit test * fix ddpg2 * changes * wip * args * faster test * common * fix * add alg option * batch mode and policy serving * multi serving test * todo * wip * serving test * doc async env * num envs * comments * thread * remove init hook * update * fix ppo * comments1 * fix * updates * add jenkins tests * fix * fix pytorch * fix * fixes * fix a3c policy * fix squeeze * fix trunc on apex * fix squeezing for real * update * remove horizon test for now * multiagent wip * update * fix race condition * fix ma * t * doc * st * wip * example * wip * working * cartpole * wip * batch wip * fix bug * make other_batches None default * working * debug * nit * warn * comments * fix ppo * fix obs filter * update * fix obs filter * pass thru worker index * fix * fix log action * debug name * fix sphinx
This commit is contained in:
@@ -63,13 +63,18 @@ DEFAULT_CONFIG = {
|
||||
},
|
||||
# Arguments to pass to the env creator
|
||||
"env_config": {},
|
||||
|
||||
# === Multiagent ===
|
||||
"multiagent": {
|
||||
"policy_graphs": {},
|
||||
"policy_mapping_fn": None,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class A3CAgent(Agent):
|
||||
_agent_name = "A3C"
|
||||
_default_config = DEFAULT_CONFIG
|
||||
_allow_unknown_subkeys = ["model", "optimizer", "env_config"]
|
||||
|
||||
@classmethod
|
||||
def default_resource_request(cls, config):
|
||||
@@ -98,7 +103,9 @@ class A3CAgent(Agent):
|
||||
remote_cls = CommonPolicyEvaluator.as_remote(
|
||||
num_gpus=1 if self.config["use_gpu_for_workers"] else 0)
|
||||
self.local_evaluator = CommonPolicyEvaluator(
|
||||
self.env_creator, self.policy_cls,
|
||||
self.env_creator,
|
||||
self.config["multiagent"]["policy_graphs"] or self.policy_cls,
|
||||
policy_mapping_fn=self.config["multiagent"]["policy_mapping_fn"],
|
||||
batch_steps=self.config["batch_size"],
|
||||
batch_mode="truncate_episodes",
|
||||
tf_session_creator=session_creator,
|
||||
@@ -107,13 +114,17 @@ class A3CAgent(Agent):
|
||||
num_envs=self.config["num_envs"])
|
||||
self.remote_evaluators = [
|
||||
remote_cls.remote(
|
||||
self.env_creator, self.policy_cls,
|
||||
self.env_creator,
|
||||
self.config["multiagent"]["policy_graphs"] or self.policy_cls,
|
||||
policy_mapping_fn=(
|
||||
self.config["multiagent"]["policy_mapping_fn"]),
|
||||
batch_steps=self.config["batch_size"],
|
||||
batch_mode="truncate_episodes", sample_async=True,
|
||||
tf_session_creator=session_creator,
|
||||
env_config=self.config["env_config"],
|
||||
model_config=self.config["model"], policy_config=self.config,
|
||||
num_envs=self.config["num_envs"])
|
||||
num_envs=self.config["num_envs"],
|
||||
worker_index=i+1)
|
||||
for i in range(self.config["num_workers"])]
|
||||
|
||||
self.optimizer = AsyncOptimizer(
|
||||
|
||||
@@ -5,6 +5,7 @@ from __future__ import print_function
|
||||
import tensorflow as tf
|
||||
import gym
|
||||
|
||||
import ray
|
||||
from ray.rllib.utils.error import UnsupportedSpaceException
|
||||
from ray.rllib.utils.process_rollout import compute_advantages
|
||||
from ray.rllib.utils.tf_policy_graph import TFPolicyGraph
|
||||
@@ -14,6 +15,7 @@ class A3CTFPolicyGraph(TFPolicyGraph):
|
||||
"""The TF policy base class."""
|
||||
|
||||
def __init__(self, ob_space, action_space, config):
|
||||
config = dict(ray.rllib.a3c.a3c.DEFAULT_CONFIG, **config)
|
||||
self.local_steps = 0
|
||||
self.config = config
|
||||
self.summarize = config.get("summarize")
|
||||
@@ -27,7 +29,7 @@ class A3CTFPolicyGraph(TFPolicyGraph):
|
||||
self.sess = tf.get_default_session()
|
||||
|
||||
TFPolicyGraph.__init__(
|
||||
self, self.sess, obs_input=self.x,
|
||||
self, ob_space, action_space, self.sess, obs_input=self.x,
|
||||
action_sampler=self.action_dist.sample(), loss=self.loss,
|
||||
loss_inputs=self.loss_in, is_training=self.is_training,
|
||||
state_inputs=self.state_in, state_outputs=self.state_out)
|
||||
|
||||
@@ -8,6 +8,7 @@ from threading import Lock
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
import ray
|
||||
from ray.rllib.models.pytorch.misc import var_to_np, convert_batch
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.utils.process_rollout import compute_advantages
|
||||
@@ -18,6 +19,7 @@ class SharedTorchPolicy(PolicyGraph):
|
||||
"""A simple, non-recurrent PyTorch policy example."""
|
||||
|
||||
def __init__(self, obs_space, action_space, config):
|
||||
config = dict(ray.rllib.a3c.a3c.DEFAULT_CONFIG, **config)
|
||||
PolicyGraph.__init__(self, obs_space, action_space, config)
|
||||
self.local_steps = 0
|
||||
self.config = config
|
||||
|
||||
@@ -59,7 +59,8 @@ class Agent(Trainable):
|
||||
"""
|
||||
|
||||
_allow_unknown_configs = False
|
||||
_allow_unknown_subkeys = ["env_config", "model", "optimizer"]
|
||||
_allow_unknown_subkeys = [
|
||||
"tf_session_args", "env_config", "model", "optimizer", "multiagent"]
|
||||
|
||||
@classmethod
|
||||
def resource_help(cls, config):
|
||||
|
||||
@@ -108,14 +108,18 @@ DEFAULT_CONFIG = {
|
||||
# Whether to use a distribution of epsilons across workers for exploration.
|
||||
"per_worker_exploration": False,
|
||||
# Whether to compute priorities on workers.
|
||||
"worker_side_prioritization": False
|
||||
"worker_side_prioritization": False,
|
||||
|
||||
# === Multiagent ===
|
||||
"multiagent": {
|
||||
"policy_graphs": {},
|
||||
"policy_mapping_fn": None,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class DDPGAgent(DQNAgent):
|
||||
_agent_name = "DDPG"
|
||||
_allow_unknown_subkeys = [
|
||||
"model", "optimizer", "tf_session_args", "env_config"]
|
||||
_default_config = DEFAULT_CONFIG
|
||||
_policy_graph = DDPGPolicyGraph
|
||||
|
||||
|
||||
@@ -82,6 +82,7 @@ def _build_q_network(inputs, action_inputs, config):
|
||||
|
||||
class DDPGPolicyGraph(TFPolicyGraph):
|
||||
def __init__(self, observation_space, action_space, config):
|
||||
config = dict(ray.rllib.ddpg.ddpg.DEFAULT_CONFIG, **config)
|
||||
if not isinstance(action_space, Box):
|
||||
raise UnsupportedSpaceException(
|
||||
"Action space {} is not supported for DDPG.".format(
|
||||
@@ -232,7 +233,8 @@ class DDPGPolicyGraph(TFPolicyGraph):
|
||||
]
|
||||
self.is_training = tf.placeholder_with_default(True, ())
|
||||
TFPolicyGraph.__init__(
|
||||
self, self.sess, obs_input=self.cur_observations,
|
||||
self, observation_space, action_space, self.sess,
|
||||
obs_input=self.cur_observations,
|
||||
action_sampler=self.output_actions, loss=self.loss,
|
||||
loss_inputs=self.loss_inputs, is_training=self.is_training)
|
||||
self.sess.run(tf.global_variables_initializer())
|
||||
|
||||
+17
-10
@@ -102,14 +102,18 @@ DEFAULT_CONFIG = {
|
||||
# Whether to use a distribution of epsilons across workers for exploration.
|
||||
"per_worker_exploration": False,
|
||||
# Whether to compute priorities on workers.
|
||||
"worker_side_prioritization": False
|
||||
"worker_side_prioritization": False,
|
||||
|
||||
# === Multiagent ===
|
||||
"multiagent": {
|
||||
"policy_graphs": {},
|
||||
"policy_mapping_fn": None,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class DQNAgent(Agent):
|
||||
_agent_name = "DQN"
|
||||
_allow_unknown_subkeys = [
|
||||
"model", "optimizer", "tf_session_args", "env_config"]
|
||||
_default_config = DEFAULT_CONFIG
|
||||
_policy_graph = DQNPolicyGraph
|
||||
|
||||
@@ -125,7 +129,9 @@ class DQNAgent(Agent):
|
||||
adjusted_batch_size = (
|
||||
self.config["sample_batch_size"] + self.config["n_step"] - 1)
|
||||
self.local_evaluator = CommonPolicyEvaluator(
|
||||
self.env_creator, self._policy_graph,
|
||||
self.env_creator,
|
||||
self.config["multiagent"]["policy_graphs"] or self._policy_graph,
|
||||
policy_mapping_fn=self.config["multiagent"]["policy_mapping_fn"],
|
||||
batch_steps=adjusted_batch_size,
|
||||
batch_mode="truncate_episodes", preprocessor_pref="deepmind",
|
||||
compress_observations=True,
|
||||
@@ -143,8 +149,9 @@ class DQNAgent(Agent):
|
||||
compress_observations=True,
|
||||
env_config=self.config["env_config"],
|
||||
model_config=self.config["model"], policy_config=self.config,
|
||||
num_envs=self.config["num_envs"])
|
||||
for _ in range(self.config["num_workers"])]
|
||||
num_envs=self.config["num_envs"],
|
||||
worker_index=i+1)
|
||||
for i in range(self.config["num_workers"])]
|
||||
|
||||
self.exploration0 = self._make_exploration_schedule(0)
|
||||
self.explorations = [
|
||||
@@ -185,7 +192,7 @@ class DQNAgent(Agent):
|
||||
def update_target_if_needed(self):
    """Refresh the target network of every local policy when it is stale.

    The update fires only when more than
    ``config["target_network_update_freq"]`` timesteps have elapsed since
    the last recorded update; the timestamp and update counter are then
    advanced.
    """
    elapsed = self.global_timestep - self.last_target_update_ts
    if elapsed <= self.config["target_network_update_freq"]:
        return

    # foreach_policy passes two arguments to the callback; only the policy
    # itself is needed here.
    self.local_evaluator.foreach_policy(lambda p, _: p.update_target())
    self.last_target_update_ts = self.global_timestep
    self.num_target_updates += 1
|
||||
|
||||
@@ -198,11 +205,11 @@ class DQNAgent(Agent):
|
||||
self.update_target_if_needed()
|
||||
|
||||
exp_vals = [self.exploration0.value(self.global_timestep)]
|
||||
self.local_evaluator.for_policy(
|
||||
lambda p: p.set_epsilon(exp_vals[0]))
|
||||
self.local_evaluator.foreach_policy(
|
||||
lambda p, _: p.set_epsilon(exp_vals[0]))
|
||||
for i, e in enumerate(self.remote_evaluators):
|
||||
exp_val = self.explorations[i].value(self.global_timestep)
|
||||
e.for_policy.remote(lambda p: p.set_epsilon(exp_val))
|
||||
e.foreach_policy.remote(lambda p, _: p.set_epsilon(exp_val))
|
||||
exp_vals.append(exp_val)
|
||||
|
||||
result = collect_metrics(
|
||||
|
||||
@@ -7,6 +7,7 @@ import numpy as np
|
||||
import tensorflow as tf
|
||||
import tensorflow.contrib.layers as layers
|
||||
|
||||
import ray
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.optimizers.sample_batch import SampleBatch
|
||||
from ray.rllib.utils.error import UnsupportedSpaceException
|
||||
@@ -47,6 +48,7 @@ def adjust_nstep(n_step, gamma, obs, actions, rewards, new_obs, dones):
|
||||
|
||||
class DQNPolicyGraph(TFPolicyGraph):
|
||||
def __init__(self, observation_space, action_space, config):
|
||||
config = dict(ray.rllib.dqn.dqn.DEFAULT_CONFIG, **config)
|
||||
if not isinstance(action_space, Discrete):
|
||||
raise UnsupportedSpaceException(
|
||||
"Action space {} is not supported for DQN.".format(
|
||||
@@ -144,7 +146,8 @@ class DQNPolicyGraph(TFPolicyGraph):
|
||||
]
|
||||
self.is_training = tf.placeholder_with_default(True, ())
|
||||
TFPolicyGraph.__init__(
|
||||
self, self.sess, obs_input=self.cur_observations,
|
||||
self, observation_space, action_space, self.sess,
|
||||
obs_input=self.cur_observations,
|
||||
action_sampler=self.output_actions, loss=self.loss,
|
||||
loss_inputs=self.loss_inputs, is_training=self.is_training)
|
||||
self.sess.run(tf.global_variables_initializer())
|
||||
|
||||
@@ -137,7 +137,6 @@ class Worker(object):
|
||||
class ESAgent(agent.Agent):
|
||||
_agent_name = "ES"
|
||||
_default_config = DEFAULT_CONFIG
|
||||
_allow_unknown_subkeys = ["env_config"]
|
||||
|
||||
@classmethod
|
||||
def default_resource_request(cls, config):
|
||||
|
||||
@@ -0,0 +1,70 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
"""Simple example of setting up a multi-agent policy mapping.
|
||||
|
||||
Control the number of agents and policies via --num-agents and --num-policies.
|
||||
|
||||
This works with hundreds of agents and policies, but note that initializing
|
||||
many TF policy graphs will take some time.
|
||||
|
||||
Also, TF evals might slow down with large numbers of policies. To debug TF
|
||||
execution, set the TF_TIMELINE_DIR environment variable.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import gym
|
||||
import random
|
||||
|
||||
import ray
|
||||
from ray.rllib.pg.pg import PGAgent
|
||||
from ray.rllib.pg.pg_policy_graph import PGPolicyGraph
|
||||
from ray.rllib.test.test_multi_agent_env import MultiCartpole
|
||||
from ray.tune.logger import pretty_print
|
||||
from ray.tune.registry import register_env
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("--num-agents", type=int, default=4)
|
||||
parser.add_argument("--num-policies", type=int, default=2)
|
||||
parser.add_argument("--num-iters", type=int, default=20)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
ray.init()
|
||||
|
||||
# Simple environment with `num_agents` independent cartpole entities
|
||||
register_env("multi_cartpole", lambda _: MultiCartpole(args.num_agents))
|
||||
single_env = gym.make("CartPole-v0")
|
||||
obs_space = single_env.observation_space
|
||||
act_space = single_env.action_space
|
||||
|
||||
def gen_policy():
|
||||
config = {
|
||||
"gamma": random.choice([0.5, 0.8, 0.9, 0.95, 0.99]),
|
||||
"n_step": random.choice([1, 2, 3, 4, 5]),
|
||||
}
|
||||
return (PGPolicyGraph, obs_space, act_space, config)
|
||||
|
||||
# Setup PG with an ensemble of `num_policies` different policy graphs
|
||||
policy_graphs = {
|
||||
"policy_{}".format(i): gen_policy() for i in range(args.num_policies)
|
||||
}
|
||||
policy_ids = list(policy_graphs.keys())
|
||||
|
||||
agent = PGAgent(
|
||||
env="multi_cartpole",
|
||||
config={
|
||||
"multiagent": {
|
||||
"policy_graphs": policy_graphs,
|
||||
"policy_mapping_fn": (
|
||||
lambda agent_id: random.choice(policy_ids)),
|
||||
},
|
||||
})
|
||||
|
||||
for i in range(args.num_iters):
|
||||
print("== Iteration", i, "==")
|
||||
print(pretty_print(agent.train()))
|
||||
@@ -217,6 +217,7 @@ class ApexOptimizer(PolicyOptimizer):
|
||||
|
||||
with self.timers["sample_processing"]:
|
||||
for ev, sample_batch in self.sample_tasks.completed():
|
||||
self._check_not_multiagent(sample_batch)
|
||||
sample_timesteps += self.sample_batch_size
|
||||
|
||||
# Send the data to the replay buffer
|
||||
|
||||
@@ -20,6 +20,9 @@ class AsyncOptimizer(PolicyOptimizer):
|
||||
self.dispatch_timer = TimerStat()
|
||||
self.grads_per_step = grads_per_step
|
||||
self.batch_size = batch_size
|
||||
if not self.remote_evaluators:
|
||||
raise ValueError(
|
||||
"Async optimizer requires at least 1 remote evaluator")
|
||||
|
||||
def step(self):
|
||||
weights = ray.put(self.local_evaluator.get_weights())
|
||||
|
||||
@@ -2,13 +2,15 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import numpy as np
|
||||
|
||||
import ray
|
||||
from ray.rllib.optimizers.replay_buffer import ReplayBuffer, \
|
||||
PrioritizedReplayBuffer
|
||||
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
|
||||
from ray.rllib.optimizers.sample_batch import SampleBatch
|
||||
from ray.rllib.optimizers.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
|
||||
MultiAgentBatch
|
||||
from ray.rllib.utils.compression import pack_if_needed
|
||||
from ray.rllib.utils.filter import RunningStat
|
||||
from ray.rllib.utils.timer import TimerStat
|
||||
@@ -41,11 +43,15 @@ class LocalSyncReplayOptimizer(PolicyOptimizer):
|
||||
|
||||
# Set up replay buffer
|
||||
if prioritized_replay:
|
||||
self.replay_buffer = PrioritizedReplayBuffer(
|
||||
buffer_size, alpha=prioritized_replay_alpha,
|
||||
clip_rewards=clip_rewards)
|
||||
def new_buffer():
|
||||
return PrioritizedReplayBuffer(
|
||||
buffer_size, alpha=prioritized_replay_alpha,
|
||||
clip_rewards=clip_rewards)
|
||||
else:
|
||||
self.replay_buffer = ReplayBuffer(buffer_size, clip_rewards)
|
||||
def new_buffer():
|
||||
return ReplayBuffer(buffer_size, clip_rewards)
|
||||
|
||||
self.replay_buffers = collections.defaultdict(new_buffer)
|
||||
|
||||
assert buffer_size >= self.replay_starts
|
||||
|
||||
@@ -63,47 +69,64 @@ class LocalSyncReplayOptimizer(PolicyOptimizer):
|
||||
[e.sample.remote() for e in self.remote_evaluators]))
|
||||
else:
|
||||
batch = self.local_evaluator.sample()
|
||||
for row in batch.rows():
|
||||
self.replay_buffer.add(
|
||||
pack_if_needed(row["obs"]), row["actions"], row["rewards"],
|
||||
pack_if_needed(row["new_obs"]),
|
||||
row["dones"], row["weights"])
|
||||
|
||||
if len(self.replay_buffer) >= self.replay_starts:
|
||||
# Handle everything as if multiagent
|
||||
if isinstance(batch, SampleBatch):
|
||||
batch = MultiAgentBatch(
|
||||
{DEFAULT_POLICY_ID: batch}, batch.count)
|
||||
|
||||
for policy_id, s in batch.policy_batches.items():
|
||||
for row in s.rows():
|
||||
if "weights" not in row:
|
||||
row["weights"] = np.ones_like(row["rewards"])
|
||||
self.replay_buffers[policy_id].add(
|
||||
pack_if_needed(row["obs"]), row["actions"],
|
||||
row["rewards"], pack_if_needed(row["new_obs"]),
|
||||
row["dones"], row["weights"])
|
||||
|
||||
if self.num_steps_sampled >= self.replay_starts:
|
||||
self._optimize()
|
||||
|
||||
self.num_steps_sampled += batch.count
|
||||
|
||||
def _optimize(self):
|
||||
with self.replay_timer:
|
||||
if isinstance(self.replay_buffer, PrioritizedReplayBuffer):
|
||||
(obses_t, actions, rewards, obses_tp1,
|
||||
dones, weights, batch_indexes) = self.replay_buffer.sample(
|
||||
self.train_batch_size,
|
||||
beta=self.prioritized_replay_beta)
|
||||
else:
|
||||
(obses_t, actions, rewards, obses_tp1,
|
||||
dones) = self.replay_buffer.sample(
|
||||
self.train_batch_size)
|
||||
weights = np.ones_like(rewards)
|
||||
batch_indexes = - np.ones_like(rewards)
|
||||
samples = SampleBatch({
|
||||
"obs": obses_t, "actions": actions, "rewards": rewards,
|
||||
"new_obs": obses_tp1, "dones": dones, "weights": weights,
|
||||
"batch_indexes": batch_indexes})
|
||||
samples = self._replay()
|
||||
|
||||
with self.grad_timer:
|
||||
info = self.local_evaluator.compute_apply(samples)
|
||||
if isinstance(self.replay_buffer, PrioritizedReplayBuffer):
|
||||
td_error = info["td_error"]
|
||||
new_priorities = (
|
||||
np.abs(td_error) + self.prioritized_replay_eps)
|
||||
self.replay_buffer.update_priorities(
|
||||
samples["batch_indexes"], new_priorities)
|
||||
info_dict = self.local_evaluator.compute_apply(samples)
|
||||
for policy_id, info in info_dict.items():
|
||||
replay_buffer = self.replay_buffers[policy_id]
|
||||
if isinstance(replay_buffer, PrioritizedReplayBuffer):
|
||||
td_error = info["td_error"]
|
||||
new_priorities = (
|
||||
np.abs(td_error) + self.prioritized_replay_eps)
|
||||
replay_buffer.update_priorities(
|
||||
samples.policy_batches[policy_id]["batch_indexes"],
|
||||
new_priorities)
|
||||
self.grad_timer.push_units_processed(samples.count)
|
||||
|
||||
self.num_steps_trained += samples.count
|
||||
|
||||
def _replay(self):
    """Sample one replay batch per policy and bundle them together.

    For every (policy id, buffer) pair in ``self.replay_buffers`` this
    requests ``self.train_batch_size`` transitions from the buffer. Buffers
    that are not prioritized supply no importance weights or indexes, so
    weights of one and indexes of minus one are filled in for them.

    Returns:
        MultiAgentBatch: per-policy SampleBatch objects, with the batch
            count set to ``self.train_batch_size``.
    """
    per_policy = {}
    with self.replay_timer:
        for pid, buf in self.replay_buffers.items():
            if not isinstance(buf, PrioritizedReplayBuffer):
                (obses_t, actions, rewards, obses_tp1,
                 dones) = buf.sample(self.train_batch_size)
                weights = np.ones_like(rewards)
                batch_indexes = - np.ones_like(rewards)
            else:
                (obses_t, actions, rewards, obses_tp1,
                 dones, weights, batch_indexes) = buf.sample(
                    self.train_batch_size,
                    beta=self.prioritized_replay_beta)
            columns = {
                "obs": obses_t, "actions": actions, "rewards": rewards,
                "new_obs": obses_tp1, "dones": dones, "weights": weights,
                "batch_indexes": batch_indexes}
            per_policy[pid] = SampleBatch(columns)
    return MultiAgentBatch(per_policy, self.train_batch_size)
|
||||
|
||||
def stats(self):
|
||||
return dict(PolicyOptimizer.stats(self), **{
|
||||
"sample_time_ms": round(1000 * self.sample_timer.mean, 3),
|
||||
|
||||
@@ -9,7 +9,6 @@ import tensorflow as tf
|
||||
import ray
|
||||
from ray.rllib.optimizers.policy_evaluator import TFMultiGPUSupport
|
||||
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
|
||||
from ray.rllib.optimizers.sample_batch import SampleBatch
|
||||
from ray.rllib.optimizers.multi_gpu_impl import LocalSyncParallelOptimizer
|
||||
from ray.rllib.utils.timer import TimerStat
|
||||
|
||||
@@ -90,7 +89,7 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
|
||||
self.timesteps_per_batch)
|
||||
else:
|
||||
samples = self.local_evaluator.sample()
|
||||
assert isinstance(samples, SampleBatch)
|
||||
self._check_not_multiagent(samples)
|
||||
|
||||
if postprocess_fn:
|
||||
postprocess_fn(samples)
|
||||
|
||||
@@ -20,7 +20,8 @@ class PolicyEvaluator(object):
|
||||
This method must be implemented by subclasses.
|
||||
|
||||
Returns:
|
||||
SampleBatch: A columnar batch of experiences (e.g., tensors).
|
||||
SampleBatch|MultiAgentBatch: A columnar batch of experiences
|
||||
(e.g., tensors), or a multi-agent batch.
|
||||
|
||||
Examples:
|
||||
>>> print(ev.sample())
|
||||
@@ -35,8 +36,10 @@ class PolicyEvaluator(object):
|
||||
This method must be implemented by subclasses.
|
||||
|
||||
Returns:
|
||||
object: A gradient that can be applied on a compatible evaluator.
|
||||
info: dictionary of extra metadata.
|
||||
(grads, info): A list of gradients that can be applied on a
|
||||
compatible evaluator. In the multi-agent case, returns a dict
|
||||
of gradients keyed by policy graph ids. An info dictionary of
|
||||
extra metadata is also returned.
|
||||
|
||||
Examples:
|
||||
>>> batch = ev.sample()
|
||||
|
||||
@@ -3,6 +3,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import ray
|
||||
from ray.rllib.optimizers.sample_batch import MultiAgentBatch
|
||||
|
||||
|
||||
class PolicyOptimizer(object):
|
||||
@@ -54,8 +55,8 @@ class PolicyOptimizer(object):
|
||||
else:
|
||||
local_evaluator = evaluator_cls(**evaluator_args)
|
||||
remote_evaluators = [
|
||||
remote_cls.remote(**evaluator_args)
|
||||
for _ in range(num_workers)]
|
||||
remote_cls.remote(worker_index=i+1, **evaluator_args)
|
||||
for i in range(num_workers)]
|
||||
return cls(optimizer_config, local_evaluator, remote_evaluators)
|
||||
|
||||
def __init__(self, config, local_evaluator, remote_evaluators):
|
||||
@@ -130,3 +131,8 @@ class PolicyOptimizer(object):
|
||||
[ev.apply.remote(func, i + 1)
|
||||
for i, ev in enumerate(self.remote_evaluators)])
|
||||
return local_result + remote_results
|
||||
|
||||
def _check_not_multiagent(self, sample_batch):
    """Guard for optimizers that only handle single-agent batches.

    Args:
        sample_batch: the batch about to be processed.

    Raises:
        NotImplementedError: if ``sample_batch`` is a MultiAgentBatch.
    """
    if not isinstance(sample_batch, MultiAgentBatch):
        return
    raise NotImplementedError(
        "This optimizer does not support multi-agent yet.")
|
||||
|
||||
@@ -6,6 +6,10 @@ import collections
|
||||
import numpy as np
|
||||
|
||||
|
||||
# Default policy id for single-agent environments
|
||||
DEFAULT_POLICY_ID = "default"
|
||||
|
||||
|
||||
class SampleBatchBuilder(object):
|
||||
"""Util to build a SampleBatch incrementally.
|
||||
|
||||
@@ -107,7 +111,7 @@ class MultiAgentSampleBatchBuilder(object):
|
||||
pre_batch, other_batches)
|
||||
|
||||
# Append into policy batches and reset
|
||||
for agent_id, post_batch in post_batches.items():
|
||||
for agent_id, post_batch in sorted(post_batches.items()):
|
||||
self.policy_builders[self.agent_to_policy[agent_id]].add_batch(
|
||||
post_batch)
|
||||
self.agent_builders.clear()
|
||||
@@ -122,33 +126,62 @@ class MultiAgentSampleBatchBuilder(object):
|
||||
|
||||
self.postprocess_batch_so_far()
|
||||
policy_batches = {}
|
||||
for policy_id, policy_batch_builder in self.policy_builders.items():
|
||||
policy_batches[policy_id] = policy_batch_builder.build_and_reset()
|
||||
for policy_id, builder in self.policy_builders.items():
|
||||
if builder.count > 0:
|
||||
policy_batches[policy_id] = builder.build_and_reset()
|
||||
old_count = self.count
|
||||
self.count = 0
|
||||
return MultiAgentBatch.wrap_as_needed(policy_batches)
|
||||
return MultiAgentBatch.wrap_as_needed(policy_batches, old_count)
|
||||
|
||||
|
||||
class MultiAgentBatch(object):
    """A batch of experiences from multiple policies in the environment.

    Attributes:
        policy_batches (dict): Mapping from policy id to a normal SampleBatch
            of experiences. Note that these batches may be of different
            length.
        count (int): The number of timesteps in the environment this batch
            contains. This may be less than the number of transitions this
            batch contains across all policies in total (see ``total()``).
    """

    def __init__(self, policy_batches, count):
        self.policy_batches = policy_batches
        self.count = count

    @staticmethod
    def wrap_as_needed(batches, count):
        """Return a bare batch when only the default policy is present.

        Args:
            batches (dict): Mapping from policy id to batch.
            count (int): Environment timestep count for the wrapped batch.

        Returns:
            The sole batch itself if ``batches`` holds exactly the
            DEFAULT_POLICY_ID entry, otherwise a MultiAgentBatch.
        """
        if len(batches) == 1 and DEFAULT_POLICY_ID in batches:
            return batches[DEFAULT_POLICY_ID]
        return MultiAgentBatch(batches, count)

    @staticmethod
    def concat_samples(samples):
        """Concatenate several MultiAgentBatch objects policy by policy.

        Args:
            samples (list): MultiAgentBatch instances to merge.

        Returns:
            MultiAgentBatch: per-policy concatenation of the inputs, whose
                count is the sum of the input counts.
        """
        grouped = collections.defaultdict(list)
        total_count = 0
        for s in samples:
            assert isinstance(s, MultiAgentBatch)
            for policy_id, batch in s.policy_batches.items():
                grouped[policy_id].append(batch)
            total_count += s.count
        out = {
            policy_id: SampleBatch.concat_samples(batches)
            for policy_id, batches in grouped.items()
        }
        return MultiAgentBatch(out, total_count)

    def total(self):
        """Return the summed ``count`` of all per-policy batches."""
        return sum(batch.count for batch in self.policy_batches.values())

    def __str__(self):
        return "MultiAgentBatch({}, count={})".format(
            str(self.policy_batches), self.count)

    # Both renderings are intentionally identical; share one implementation.
    __repr__ = __str__
|
||||
|
||||
|
||||
class SampleBatch(object):
|
||||
@@ -166,11 +199,15 @@ class SampleBatch(object):
|
||||
for k, v in self.data.copy().items():
|
||||
assert type(k) == str, self
|
||||
lengths.append(len(v))
|
||||
if not lengths:
|
||||
raise ValueError("Empty sample batch")
|
||||
assert len(set(lengths)) == 1, "data columns must be same length"
|
||||
self.count = lengths[0]
|
||||
|
||||
@staticmethod
|
||||
def concat_samples(samples):
|
||||
if isinstance(samples[0], MultiAgentBatch):
|
||||
return MultiAgentBatch.concat_samples(samples)
|
||||
out = {}
|
||||
samples = [s for s in samples if s.count > 0]
|
||||
for k in samples[0].keys():
|
||||
|
||||
@@ -29,6 +29,12 @@ DEFAULT_CONFIG = {
|
||||
"model": {"fcnet_hiddens": [128, 128]},
|
||||
# Arguments to pass to the env creator
|
||||
"env_config": {},
|
||||
|
||||
# === Multiagent ===
|
||||
"multiagent": {
|
||||
"policy_graphs": {},
|
||||
"policy_mapping_fn": None,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -52,7 +58,11 @@ class PGAgent(Agent):
|
||||
evaluator_cls=CommonPolicyEvaluator,
|
||||
evaluator_args={
|
||||
"env_creator": self.env_creator,
|
||||
"policy_graph": PGPolicyGraph,
|
||||
"policy_graph": (
|
||||
self.config["multiagent"]["policy_graphs"] or
|
||||
PGPolicyGraph),
|
||||
"policy_mapping_fn":
|
||||
self.config["multiagent"]["policy_mapping_fn"],
|
||||
"batch_steps": self.config["batch_size"],
|
||||
"batch_mode": "truncate_episodes",
|
||||
"model_config": self.config["model"],
|
||||
|
||||
@@ -4,6 +4,7 @@ from __future__ import print_function
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
import ray
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.utils.process_rollout import compute_advantages
|
||||
from ray.rllib.utils.tf_policy_graph import TFPolicyGraph
|
||||
@@ -12,6 +13,7 @@ from ray.rllib.utils.tf_policy_graph import TFPolicyGraph
|
||||
class PGPolicyGraph(TFPolicyGraph):
|
||||
|
||||
def __init__(self, obs_space, action_space, config):
|
||||
config = dict(ray.rllib.pg.pg.DEFAULT_CONFIG, **config)
|
||||
self.config = config
|
||||
|
||||
# setup policy
|
||||
@@ -36,7 +38,7 @@ class PGPolicyGraph(TFPolicyGraph):
|
||||
]
|
||||
self.is_training = tf.placeholder_with_default(True, ())
|
||||
TFPolicyGraph.__init__(
|
||||
self, self.sess, obs_input=self.x,
|
||||
self, obs_space, action_space, self.sess, obs_input=self.x,
|
||||
action_sampler=self.dist.sample(), loss=self.loss,
|
||||
loss_inputs=self.loss_in, is_training=self.is_training)
|
||||
self.sess.run(tf.global_variables_initializer())
|
||||
|
||||
@@ -88,7 +88,6 @@ DEFAULT_CONFIG = {
|
||||
|
||||
class PPOAgent(Agent):
|
||||
_agent_name = "PPO"
|
||||
_allow_unknown_subkeys = ["model", "tf_session_args", "env_config"]
|
||||
_default_config = DEFAULT_CONFIG
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -48,6 +48,22 @@ class MockEnv(gym.Env):
|
||||
return 0, 1, self.i >= self.episode_length, {}
|
||||
|
||||
|
||||
class MockEnv2(gym.Env):
    """Mock env whose observation is its own step counter.

    Every step increments the counter, returns it as the observation, and
    yields a fixed reward of 100. The episode is reported done once the
    counter reaches ``episode_length``.
    """

    def __init__(self, episode_length):
        self.i = 0
        self.episode_length = episode_length
        self.action_space = gym.spaces.Discrete(2)
        self.observation_space = gym.spaces.Discrete(100)

    def reset(self):
        """Rewind the step counter and return it as the first observation."""
        self.i = 0
        return self.i

    def step(self, action):
        """Advance one step; the action is ignored."""
        self.i += 1
        finished = self.i >= self.episode_length
        return self.i, 100, finished, {}
|
||||
|
||||
|
||||
class MockVectorEnv(VectorEnv):
|
||||
def __init__(self, episode_length, num_envs):
|
||||
self.envs = [
|
||||
|
||||
@@ -2,12 +2,23 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import gym
|
||||
import random
|
||||
import unittest
|
||||
|
||||
import ray
|
||||
from ray.rllib.test.test_common_policy_evaluator import MockEnv
|
||||
from ray.rllib.pg import PGAgent
|
||||
from ray.rllib.pg.pg_policy_graph import PGPolicyGraph
|
||||
from ray.rllib.dqn.dqn_policy_graph import DQNPolicyGraph
|
||||
from ray.rllib.optimizers import LocalSyncOptimizer, \
|
||||
LocalSyncReplayOptimizer, AsyncOptimizer
|
||||
from ray.rllib.test.test_common_policy_evaluator import MockEnv, MockEnv2, \
|
||||
MockPolicyGraph
|
||||
from ray.rllib.utils.common_policy_evaluator import CommonPolicyEvaluator, \
|
||||
collect_metrics
|
||||
from ray.rllib.utils.async_vector_env import _MultiAgentEnvToAsync
|
||||
from ray.rllib.utils.multi_agent_env import MultiAgentEnv
|
||||
from ray.tune.registry import register_env
|
||||
|
||||
|
||||
class BasicMultiAgent(MultiAgentEnv):
|
||||
@@ -16,6 +27,8 @@ class BasicMultiAgent(MultiAgentEnv):
|
||||
def __init__(self, num):
|
||||
self.agents = [MockEnv(25) for _ in range(num)]
|
||||
self.dones = set()
|
||||
self.observation_space = gym.spaces.Discrete(2)
|
||||
self.action_space = gym.spaces.Discrete(2)
|
||||
|
||||
def reset(self):
|
||||
self.dones = set()
|
||||
@@ -36,8 +49,13 @@ class RoundRobinMultiAgent(MultiAgentEnv):
|
||||
|
||||
On each step() of the env, only one agent takes an action."""
|
||||
|
||||
def __init__(self, num):
|
||||
self.agents = [MockEnv(5) for _ in range(num)]
|
||||
def __init__(self, num, increment_obs=False):
|
||||
if increment_obs:
|
||||
# Observations are 0, 1, 2, 3... etc. as time advances
|
||||
self.agents = [MockEnv2(5) for _ in range(num)]
|
||||
else:
|
||||
# Observations are all zeros
|
||||
self.agents = [MockEnv(5) for _ in range(num)]
|
||||
self.dones = set()
|
||||
self.last_obs = {}
|
||||
self.last_rew = {}
|
||||
@@ -45,24 +63,59 @@ class RoundRobinMultiAgent(MultiAgentEnv):
|
||||
self.last_info = {}
|
||||
self.i = 0
|
||||
self.num = num
|
||||
self.observation_space = gym.spaces.Discrete(2)
|
||||
self.action_space = gym.spaces.Discrete(2)
|
||||
|
||||
def reset(self):
|
||||
self.dones = set()
|
||||
return {i: a.reset() for i, a in enumerate(self.agents)}
|
||||
self.last_obs = {}
|
||||
self.last_rew = {}
|
||||
self.last_done = {}
|
||||
self.last_info = {}
|
||||
self.i = 0
|
||||
for i, a in enumerate(self.agents):
|
||||
self.last_obs[i] = a.reset()
|
||||
self.last_rew[i] = None
|
||||
self.last_done[i] = False
|
||||
self.last_info[i] = {}
|
||||
obs_dict = {self.i: self.last_obs[self.i]}
|
||||
self.i = (self.i + 1) % self.num
|
||||
return obs_dict
|
||||
|
||||
def step(self, action_dict):
|
||||
assert len(self.dones) != len(self.agents)
|
||||
for i, action in action_dict.items():
|
||||
(self.last_obs[i], self.last_rew[i], self.last_done[i],
|
||||
self.last_info[i]) = self.agents[i].step(action)
|
||||
if self.last_done[i]:
|
||||
obs = {self.i: self.last_obs[self.i]}
|
||||
rew = {self.i: self.last_rew[self.i]}
|
||||
done = {self.i: self.last_done[self.i]}
|
||||
info = {self.i: self.last_info[self.i]}
|
||||
if done[self.i]:
|
||||
rew[self.i] = 0
|
||||
self.dones.add(self.i)
|
||||
self.i = (self.i + 1) % self.num
|
||||
done["__all__"] = len(self.dones) == len(self.agents)
|
||||
return obs, rew, done, info
|
||||
|
||||
|
||||
class MultiCartpole(MultiAgentEnv):
|
||||
def __init__(self, num):
|
||||
self.agents = [gym.make("CartPole-v0") for _ in range(num)]
|
||||
self.dones = set()
|
||||
self.observation_space = self.agents[0].observation_space
|
||||
self.action_space = self.agents[0].action_space
|
||||
|
||||
def reset(self):
|
||||
self.dones = set()
|
||||
return {i: a.reset() for i, a in enumerate(self.agents)}
|
||||
|
||||
def step(self, action_dict):
|
||||
obs, rew, done, info = {}, {}, {}, {}
|
||||
for i, action in action_dict.items():
|
||||
obs[i], rew[i], done[i], info[i] = self.agents[i].step(action)
|
||||
if done[i]:
|
||||
self.dones.add(i)
|
||||
obs = {self.i: self.last_obs[i]}
|
||||
rew = {self.i: self.last_rew[i]}
|
||||
done = {self.i: self.last_done[i]}
|
||||
info = {self.i: self.last_info[i]}
|
||||
self.i += 1
|
||||
self.i %= self.num
|
||||
done["__all__"] = len(self.dones) == len(self.agents)
|
||||
return obs, rew, done, info
|
||||
|
||||
@@ -86,15 +139,15 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
def testRoundRobinMock(self):
|
||||
env = RoundRobinMultiAgent(2)
|
||||
obs = env.reset()
|
||||
self.assertEqual(obs, {0: 0, 1: 0})
|
||||
obs, rew, done, info = env.step({0: 0, 1: 0})
|
||||
self.assertEqual(obs, {0: 0})
|
||||
for _ in range(4):
|
||||
for _ in range(5):
|
||||
obs, rew, done, info = env.step({0: 0})
|
||||
self.assertEqual(obs, {1: 0})
|
||||
self.assertEqual(done["__all__"], False)
|
||||
obs, rew, done, info = env.step({1: 0})
|
||||
self.assertEqual(obs, {0: 0})
|
||||
self.assertEqual(done["__all__"], False)
|
||||
obs, rew, done, info = env.step({0: 0})
|
||||
self.assertEqual(done["__all__"], True)
|
||||
|
||||
def testVectorizeBasic(self):
|
||||
@@ -140,14 +193,160 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
def testVectorizeRoundRobin(self):
|
||||
env = _MultiAgentEnvToAsync(lambda: RoundRobinMultiAgent(2), [], 2)
|
||||
obs, rew, dones, _, _ = env.poll()
|
||||
self.assertEqual(obs, {0: {0: 0, 1: 0}, 1: {0: 0, 1: 0}})
|
||||
self.assertEqual(rew, {0: {0: None, 1: None}, 1: {0: None, 1: None}})
|
||||
env.send_actions({0: {0: 0}, 1: {0: 0}})
|
||||
obs, rew, dones, _, _ = env.poll()
|
||||
self.assertEqual(obs, {0: {0: 0}, 1: {0: 0}})
|
||||
self.assertEqual(rew, {0: {0: None}, 1: {0: None}})
|
||||
env.send_actions({0: {0: 0}, 1: {0: 0}})
|
||||
obs, rew, dones, _, _ = env.poll()
|
||||
self.assertEqual(obs, {0: {1: 0}, 1: {1: 0}})
|
||||
env.send_actions({0: {1: 0}, 1: {1: 0}})
|
||||
obs, rew, dones, _, _ = env.poll()
|
||||
self.assertEqual(obs, {0: {0: 0}, 1: {0: 0}})
|
||||
|
||||
def testMultiAgentSample(self):
    """Sample one batch from a 5-agent env shared between two policies.

    Agents are bound to "p0" or "p1" by the parity of their agent id
    (`agent_id % 2`), and the evaluator is asked for 50 env steps.
    """
    act_space = gym.spaces.Discrete(2)
    obs_space = gym.spaces.Discrete(2)
    ev = CommonPolicyEvaluator(
        env_creator=lambda _: BasicMultiAgent(5),
        policy_graph={
            "p0": (MockPolicyGraph, obs_space, act_space, {}),
            "p1": (MockPolicyGraph, obs_space, act_space, {}),
        },
        policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2),
        batch_steps=50)
    batch = ev.sample()
    # `count` is expected to equal the requested 50 steps, while the
    # per-policy batches are larger (150 and 100 rows).
    # NOTE(review): presumably 3 agents map to p0 and 2 to p1 -- confirm
    # against BasicMultiAgent's agent ids.
    self.assertEqual(batch.count, 50)
    self.assertEqual(batch.policy_batches["p0"].count, 150)
    self.assertEqual(batch.policy_batches["p1"].count, 100)
    # The "t" column of p0 is expected to be 0..24 repeated six times.
    self.assertEqual(
        batch.policy_batches["p0"]["t"].tolist(),
        list(range(25)) * 6)
|
||||
|
||||
def testMultiAgentSampleRoundRobin(self):
    """Sample one batch from a round-robin env with a single policy "p0".

    The env is built with `increment_obs=True` and every agent is mapped
    to "p0"; the first ten rows of each column are checked.
    """
    act_space = gym.spaces.Discrete(2)
    obs_space = gym.spaces.Discrete(2)
    ev = CommonPolicyEvaluator(
        env_creator=lambda _: RoundRobinMultiAgent(5, increment_obs=True),
        policy_graph={
            "p0": (MockPolicyGraph, obs_space, act_space, {}),
        },
        policy_mapping_fn=lambda agent_id: "p0",
        batch_steps=50)
    batch = ev.sample()
    self.assertEqual(batch.count, 50)
    # Agents are introduced into the env one at a time (round robin), so
    # some env steps do not count as proper transitions: the policy batch
    # is expected to hold 42 rows rather than 50.
    self.assertEqual(batch.policy_batches["p0"].count, 42)
    self.assertEqual(
        batch.policy_batches["p0"]["obs"].tolist()[:10],
        [0, 1, 2, 3, 4] * 2)
    self.assertEqual(
        batch.policy_batches["p0"]["new_obs"].tolist()[:10],
        [1, 2, 3, 4, 5] * 2)
    self.assertEqual(
        batch.policy_batches["p0"]["rewards"].tolist()[:10],
        [100, 100, 100, 100, 0] * 2)
    self.assertEqual(
        batch.policy_batches["p0"]["dones"].tolist()[:10],
        [False, False, False, False, True] * 2)
    self.assertEqual(
        batch.policy_batches["p0"]["t"].tolist()[:10],
        [4, 9, 14, 19, 24, 5, 10, 15, 20, 25])
|
||||
|
||||
def testTrainMultiCartpoleSinglePolicy(self):
    """Train a PGAgent on a 10-agent MultiCartpole env.

    Passes as soon as the mean episode reward reaches `50 * n`; raises if
    that does not happen within 100 training iterations.
    """
    n = 10
    register_env("multi_cartpole", lambda _: MultiCartpole(n))
    pg = PGAgent(env="multi_cartpole", config={"num_workers": 0})
    for i in range(100):
        result = pg.train()
        print("Iteration {}, reward {}, timesteps {}".format(
            i, result.episode_reward_mean, result.timesteps_total))
        if result.episode_reward_mean >= 50 * n:
            return
    raise Exception("failed to improve reward")
|
||||
|
||||
def _testWithOptimizer(self, optimizer_cls):
    """Train two policies on a 3-agent MultiCartpole with `optimizer_cls`.

    Agents are mapped to "p1"/"p2" by the parity of their agent id. The
    check passes once the mean episode reward reaches `25 * n`; otherwise
    the last result is printed and an Exception is raised after 200 steps.

    Arguments:
        optimizer_cls (class): Optimizer class, constructed as
            `optimizer_cls({}, local_evaluator, remote_evaluators)`.
    """
    n = 3
    env = gym.make("CartPole-v0")
    act_space = env.action_space
    obs_space = env.observation_space
    dqn_config = {"gamma": 0.95, "n_step": 3}
    if optimizer_cls == LocalSyncReplayOptimizer:
        # TODO: support replay with non-DQN graphs. Currently this can't
        # happen since the replay buffer doesn't encode extra fields like
        # "advantages" that PG uses.
        policies = {
            "p1": (DQNPolicyGraph, obs_space, act_space, {}),
            "p2": (DQNPolicyGraph, obs_space, act_space, dqn_config),
        }
    else:
        policies = {
            "p1": (PGPolicyGraph, obs_space, act_space, dqn_config),
            "p2": (DQNPolicyGraph, obs_space, act_space, dqn_config),
        }
    ev = CommonPolicyEvaluator(
        env_creator=lambda _: MultiCartpole(n),
        policy_graph=policies,
        policy_mapping_fn=lambda agent_id: ["p1", "p2"][agent_id % 2],
        batch_steps=50)
    # Only the async optimizer is given a remote evaluator, configured the
    # same way as the local one.
    if optimizer_cls == AsyncOptimizer:
        remote_evs = [CommonPolicyEvaluator.as_remote().remote(
            env_creator=lambda _: MultiCartpole(n),
            policy_graph=policies,
            policy_mapping_fn=lambda agent_id: ["p1", "p2"][agent_id % 2],
            batch_steps=50)]
    else:
        remote_evs = []
    optimizer = optimizer_cls({}, ev, remote_evs)
    # Fix epsilon on the DQN policies only; other policies are skipped.
    ev.foreach_policy(
        lambda p, _: p.set_epsilon(0.02)
        if isinstance(p, DQNPolicyGraph) else None)
    for i in range(200):
        optimizer.step()
        result = collect_metrics(ev, remote_evs)
        if i % 20 == 0:
            # Every 20 iterations: refresh DQN target networks and report
            # progress.
            ev.foreach_policy(
                lambda p, _: p.update_target()
                if isinstance(p, DQNPolicyGraph) else None)
            print("Iter {}, rew {}".format(i, result.policy_reward_mean))
            print("Total reward", result.episode_reward_mean)
        if result.episode_reward_mean >= 25 * n:
            return
    print(result)
    raise Exception("failed to improve reward")
|
||||
|
||||
def testMultiAgentSyncOptimizer(self):
    """Run the shared multi-agent training check with LocalSyncOptimizer."""
    self._testWithOptimizer(LocalSyncOptimizer)
|
||||
|
||||
def testMultiAgentAsyncOptimizer(self):
    """Run the shared multi-agent training check with AsyncOptimizer."""
    self._testWithOptimizer(AsyncOptimizer)
|
||||
|
||||
def testMultiAgentReplayOptimizer(self):
    """Run the shared multi-agent training check with the replay optimizer."""
    self._testWithOptimizer(LocalSyncReplayOptimizer)
|
||||
|
||||
def testTrainMultiCartpoleManyPolicies(self):
    """Train 20 PG policies on a 20-agent MultiCartpole env.

    Each agent is bound to a policy picked by `random.choice`. Passes once
    the mean episode reward reaches `25 * n`; raises if that does not
    happen within 100 optimizer steps.
    """
    n = 20
    env = gym.make("CartPole-v0")
    act_space = env.action_space
    obs_space = env.observation_space
    policies = {}
    for i in range(20):
        policies["pg_{}".format(i)] = (
            PGPolicyGraph, obs_space, act_space, {})
    policy_ids = list(policies.keys())
    ev = CommonPolicyEvaluator(
        env_creator=lambda _: MultiCartpole(n),
        policy_graph=policies,
        policy_mapping_fn=lambda agent_id: random.choice(policy_ids),
        batch_steps=100)
    optimizer = LocalSyncOptimizer({}, ev, [])
    for i in range(100):
        optimizer.step()
        result = collect_metrics(ev)
        print("Iteration {}, rew {}".format(i, result.policy_reward_mean))
        print("Total reward", result.episode_reward_mean)
        if result.episode_reward_mean >= 25 * n:
            return
    raise Exception("failed to improve reward")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -284,14 +284,12 @@ class _MultiAgentEnvState(object):
|
||||
self.reset()
|
||||
|
||||
def poll(self):
|
||||
if self.last_obs is None:
|
||||
raise ValueError("Need to send action after polling")
|
||||
obs, rew, dones, info = (
|
||||
self.last_obs, self.last_rewards, self.last_dones, self.last_infos)
|
||||
self.last_obs = None
|
||||
self.last_rewards = None
|
||||
self.last_dones = None
|
||||
self.last_infos = None
|
||||
self.last_obs = {}
|
||||
self.last_rewards = {}
|
||||
self.last_dones = {"__all__": False}
|
||||
self.last_infos = {}
|
||||
return obs, rew, dones, info
|
||||
|
||||
def observe(self, obs, rewards, dones, infos):
|
||||
|
||||
@@ -2,31 +2,38 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import pickle
|
||||
import collections
|
||||
import gym
|
||||
import numpy as np
|
||||
import pickle
|
||||
import tensorflow as tf
|
||||
|
||||
import ray
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.optimizers import MultiAgentBatch
|
||||
from ray.rllib.optimizers.policy_evaluator import PolicyEvaluator
|
||||
from ray.rllib.optimizers.sample_batch import MultiAgentBatch, \
|
||||
DEFAULT_POLICY_ID
|
||||
from ray.rllib.utils.async_vector_env import AsyncVectorEnv
|
||||
from ray.rllib.utils.atari_wrappers import wrap_deepmind, is_atari
|
||||
from ray.rllib.utils.compression import pack
|
||||
from ray.rllib.utils.env_context import EnvContext
|
||||
from ray.rllib.utils.filter import get_filter
|
||||
from ray.rllib.utils.multi_agent_env import MultiAgentEnv
|
||||
from ray.rllib.utils.policy_graph import PolicyGraph
|
||||
from ray.rllib.utils.sampler import AsyncSampler, SyncSampler
|
||||
from ray.rllib.utils.serving_env import ServingEnv
|
||||
from ray.rllib.utils.tf_policy_graph import TFPolicyGraph
|
||||
from ray.rllib.utils.tf_run_builder import TFRunBuilder
|
||||
from ray.rllib.utils.vector_env import VectorEnv
|
||||
from ray.tune.result import TrainingResult
|
||||
|
||||
|
||||
def collect_metrics(local_evaluator, remote_evaluators):
|
||||
def collect_metrics(local_evaluator, remote_evaluators=[]):
|
||||
"""Gathers episode metrics from CommonPolicyEvaluator instances."""
|
||||
|
||||
episode_rewards = []
|
||||
episode_lengths = []
|
||||
policy_rewards = collections.defaultdict(list)
|
||||
metric_lists = ray.get(
|
||||
[a.apply.remote(lambda ev: ev.sampler.get_metrics())
|
||||
for a in remote_evaluators])
|
||||
@@ -35,6 +42,8 @@ def collect_metrics(local_evaluator, remote_evaluators):
|
||||
for episode in metrics:
|
||||
episode_lengths.append(episode.episode_length)
|
||||
episode_rewards.append(episode.episode_reward)
|
||||
for (_, policy_id), reward in episode.agent_rewards.items():
|
||||
policy_rewards[policy_id].append(reward)
|
||||
if episode_rewards:
|
||||
min_reward = min(episode_rewards)
|
||||
max_reward = max(episode_rewards)
|
||||
@@ -45,19 +54,22 @@ def collect_metrics(local_evaluator, remote_evaluators):
|
||||
avg_length = np.mean(episode_lengths)
|
||||
timesteps = np.sum(episode_lengths)
|
||||
|
||||
for policy_id, rewards in policy_rewards.copy().items():
|
||||
policy_rewards[policy_id] = np.mean(rewards)
|
||||
|
||||
return TrainingResult(
|
||||
episode_reward_max=max_reward,
|
||||
episode_reward_min=min_reward,
|
||||
episode_reward_mean=avg_reward,
|
||||
episode_len_mean=avg_length,
|
||||
episodes_total=len(episode_lengths),
|
||||
timesteps_this_iter=timesteps)
|
||||
timesteps_this_iter=timesteps,
|
||||
policy_reward_mean=dict(policy_rewards))
|
||||
|
||||
|
||||
class CommonPolicyEvaluator(PolicyEvaluator):
|
||||
"""Policy evaluator implementation that operates on a rllib.PolicyGraph.
|
||||
|
||||
TODO: multi-agent
|
||||
TODO: multi-gpu
|
||||
|
||||
Examples:
|
||||
@@ -65,9 +77,10 @@ class CommonPolicyEvaluator(PolicyEvaluator):
|
||||
>>> evaluator = CommonPolicyEvaluator(
|
||||
env_creator=lambda _: gym.make("CartPole-v0"),
|
||||
policy_graph=PGPolicyGraph)
|
||||
>>> print(evaluator.sample().keys())
|
||||
{"obs": [[...]], "actions": [[...]], "rewards": [[...]],
|
||||
"dones": [[...]], "new_obs": [[...]]}
|
||||
>>> print(evaluator.sample())
|
||||
SampleBatch({
|
||||
"obs": [[...]], "actions": [[...]], "rewards": [[...]],
|
||||
"dones": [[...]], "new_obs": [[...]]})
|
||||
|
||||
# Creating policy evaluators using optimizer_cls.make().
|
||||
>>> optimizer = LocalSyncOptimizer.make(
|
||||
@@ -78,6 +91,28 @@ class CommonPolicyEvaluator(PolicyEvaluator):
|
||||
},
|
||||
num_workers=10)
|
||||
>>> for _ in range(10): optimizer.step()
|
||||
|
||||
# Creating a multi-agent policy evaluator
|
||||
>>> evaluator = CommonPolicyEvaluator(
|
||||
env_creator=lambda _: MultiAgentTrafficGrid(num_cars=25),
|
||||
policy_graph={
|
||||
# Use an ensemble of two policies for car agents
|
||||
"car_policy1":
|
||||
(PGPolicyGraph, Box(...), Discrete(...), {"gamma": 0.99}),
|
||||
"car_policy2":
|
||||
(PGPolicyGraph, Box(...), Discrete(...), {"gamma": 0.95}),
|
||||
# Use a single shared policy for all traffic lights
|
||||
"traffic_light_policy":
|
||||
(PGPolicyGraph, Box(...), Discrete(...), {}),
|
||||
},
|
||||
policy_mapping_fn=lambda agent_id:
|
||||
random.choice(["car_policy1", "car_policy2"])
|
||||
if agent_id.startswith("car_") else "traffic_light_policy")
|
||||
>>> print(evaluator.sample().keys())
|
||||
MultiAgentBatch({
|
||||
"car_policy1": SampleBatch(...),
|
||||
"car_policy2": SampleBatch(...),
|
||||
"traffic_light_policy": SampleBatch(...)})
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@@ -88,6 +123,7 @@ class CommonPolicyEvaluator(PolicyEvaluator):
|
||||
self,
|
||||
env_creator,
|
||||
policy_graph,
|
||||
policy_mapping_fn=None,
|
||||
tf_session_creator=None,
|
||||
batch_steps=100,
|
||||
batch_mode="truncate_episodes",
|
||||
@@ -99,14 +135,22 @@ class CommonPolicyEvaluator(PolicyEvaluator):
|
||||
observation_filter="NoFilter",
|
||||
env_config=None,
|
||||
model_config=None,
|
||||
policy_config=None):
|
||||
policy_config=None,
|
||||
worker_index=0):
|
||||
"""Initialize a policy evaluator.
|
||||
|
||||
Arguments:
|
||||
env_creator (func): Function that returns a gym.Env given an
|
||||
env config dict.
|
||||
policy_graph (class): A class implementing rllib.PolicyGraph or
|
||||
rllib.TFPolicyGraph.
|
||||
EnvContext wrapped configuration.
|
||||
policy_graph (class|dict): Either a class implementing
|
||||
PolicyGraph, or a dictionary of policy id strings to
|
||||
(PolicyGraph, obs_space, action_space, config) tuples. If a
|
||||
dict is specified, then we are in multi-agent mode and a
|
||||
policy_mapping_fn should also be set.
|
||||
policy_mapping_fn (func): A function that maps agent ids to
|
||||
policy ids in multi-agent mode. This function will be called
|
||||
each time a new agent appears in an episode, to bind that agent
|
||||
to a policy for the duration of the episode.
|
||||
tf_session_creator (func): A function that returns a TF session.
|
||||
This is optional and only useful with TFPolicyGraph.
|
||||
batch_steps (int): The target number of env transitions to include
|
||||
@@ -138,19 +182,26 @@ class CommonPolicyEvaluator(PolicyEvaluator):
|
||||
observation_filter (str): Name of observation filter to use.
|
||||
env_config (dict): Config to pass to the env creator.
|
||||
model_config (dict): Config to use when creating the policy model.
|
||||
policy_config (dict): Config to pass to the policy.
|
||||
policy_config (dict): Config to pass to the policy. In the
|
||||
multi-agent case, this config will be merged with the
|
||||
per-policy configs specified by `policy_graph`.
|
||||
worker_index (int): For remote evaluators, this should be set to a
|
||||
non-zero and unique value. This index is passed to created envs
|
||||
through EnvContext so that envs can be configured per worker.
|
||||
"""
|
||||
|
||||
env_config = env_config or {}
|
||||
env_context = EnvContext(env_config or {}, worker_index)
|
||||
policy_config = policy_config or {}
|
||||
model_config = model_config or {}
|
||||
policy_mapping_fn = (
|
||||
policy_mapping_fn or (lambda agent_id: DEFAULT_POLICY_ID))
|
||||
self.env_creator = env_creator
|
||||
self.policy_graph = policy_graph
|
||||
self.batch_steps = batch_steps
|
||||
self.batch_mode = batch_mode
|
||||
self.compress_observations = compress_observations
|
||||
|
||||
self.env = env_creator(env_config)
|
||||
self.env = env_creator(env_context)
|
||||
if isinstance(self.env, VectorEnv) or \
|
||||
isinstance(self.env, ServingEnv) or \
|
||||
isinstance(self.env, MultiAgentEnv) or \
|
||||
@@ -169,32 +220,29 @@ class CommonPolicyEvaluator(PolicyEvaluator):
|
||||
self.env = wrap(self.env)
|
||||
|
||||
def make_env():
|
||||
return wrap(env_creator(env_config))
|
||||
return wrap(env_creator(env_context))
|
||||
|
||||
if issubclass(policy_graph, TFPolicyGraph):
|
||||
self.tf_sess = None
|
||||
policy_dict = _validate_and_canonicalize(policy_graph, self.env)
|
||||
if _has_tensorflow_graph(policy_dict):
|
||||
with tf.Graph().as_default():
|
||||
if tf_session_creator:
|
||||
self.sess = tf_session_creator()
|
||||
self.tf_sess = tf_session_creator()
|
||||
else:
|
||||
self.sess = tf.Session(config=tf.ConfigProto(
|
||||
self.tf_sess = tf.Session(config=tf.ConfigProto(
|
||||
gpu_options=tf.GPUOptions(allow_growth=True)))
|
||||
with self.sess.as_default():
|
||||
policy = policy_graph(
|
||||
self.env.observation_space, self.env.action_space,
|
||||
policy_config)
|
||||
with self.tf_sess.as_default():
|
||||
self.policy_map = self._build_policy_map(
|
||||
policy_dict, policy_config)
|
||||
else:
|
||||
policy = policy_graph(
|
||||
self.env.observation_space, self.env.action_space,
|
||||
policy_config)
|
||||
self.policy_map = self._build_policy_map(
|
||||
policy_dict, policy_config)
|
||||
|
||||
self.policy_map = {
|
||||
"default": policy
|
||||
}
|
||||
self.multiagent = self.policy_map.keys() != set(DEFAULT_POLICY_ID)
|
||||
|
||||
self.filters = {
|
||||
# TODO(ekl) make the obs space dependent on policy
|
||||
policy_id: get_filter(
|
||||
observation_filter, self.env.observation_space.shape)
|
||||
observation_filter, policy.observation_space.shape)
|
||||
for (policy_id, policy) in self.policy_map.items()
|
||||
}
|
||||
|
||||
@@ -218,15 +266,25 @@ class CommonPolicyEvaluator(PolicyEvaluator):
|
||||
"Unsupported batch mode: {}".format(self.batch_mode))
|
||||
if sample_async:
|
||||
self.sampler = AsyncSampler(
|
||||
self.async_env, self.policy_map, lambda agent_id: "default",
|
||||
self.async_env, self.policy_map, policy_mapping_fn,
|
||||
self.filters, batch_steps, horizon=episode_horizon,
|
||||
pack=pack_episodes)
|
||||
pack=pack_episodes, tf_sess=self.tf_sess)
|
||||
self.sampler.start()
|
||||
else:
|
||||
self.sampler = SyncSampler(
|
||||
self.async_env, self.policy_map, lambda agent_id: "default",
|
||||
self.async_env, self.policy_map, policy_mapping_fn,
|
||||
self.filters, batch_steps, horizon=episode_horizon,
|
||||
pack=pack_episodes)
|
||||
pack=pack_episodes, tf_sess=self.tf_sess)
|
||||
|
||||
def _build_policy_map(self, policy_dict, policy_config):
    """Instantiate one policy graph per entry of `policy_dict`.

    Arguments:
        policy_dict (dict): Mapping of policy id to
            (cls, obs_space, action_space, config) tuples.
        policy_config (dict): Base config; a copy is taken for each policy
            and updated with that policy's own config, so per-policy
            values win on conflicts.

    Returns:
        dict: Mapping of policy id to the constructed policy instance.
    """
    policy_map = {}
    # Policies are built in sorted policy-id order, each inside a
    # tf.variable_scope named after its id.
    for name, (cls, obs_space, act_space, conf) in sorted(
            policy_dict.items()):
        merged_conf = policy_config.copy()
        merged_conf.update(conf)
        with tf.variable_scope(name):
            policy_map[name] = cls(obs_space, act_space, merged_conf)
    return policy_map
|
||||
|
||||
def sample(self):
|
||||
"""Evaluate the current policies and return a batch of experiences.
|
||||
@@ -254,10 +312,15 @@ class CommonPolicyEvaluator(PolicyEvaluator):
|
||||
|
||||
return batch
|
||||
|
||||
def for_policy(self, func):
|
||||
"""Apply the given function to this evaluator's default policy."""
|
||||
def for_policy(self, func, policy_id=DEFAULT_POLICY_ID):
|
||||
"""Apply the given function to the specified policy graph."""
|
||||
|
||||
return func(self.policy_map["default"])
|
||||
return func(self.policy_map[policy_id])
|
||||
|
||||
def foreach_policy(self, func):
    """Apply the given function to each (policy, policy_id) tuple.

    Arguments:
        func (func): Callable invoked as `func(policy, policy_id)` for
            every entry of `self.policy_map`.

    Returns:
        list: The return values of `func`, one per policy, in the
            iteration order of `self.policy_map`.
    """
    return [func(policy, pid) for pid, policy in self.policy_map.items()]
|
||||
|
||||
def sync_filters(self, new_filters):
|
||||
"""Changes self's filter to given and rebases any accumulated delta.
|
||||
@@ -286,28 +349,126 @@ class CommonPolicyEvaluator(PolicyEvaluator):
|
||||
return return_filters
|
||||
|
||||
def get_weights(self):
|
||||
return self.policy_map["default"].get_weights()
|
||||
return {
|
||||
pid: policy.get_weights()
|
||||
for pid, policy in self.policy_map.items()}
|
||||
|
||||
def set_weights(self, weights):
|
||||
return self.policy_map["default"].set_weights(weights)
|
||||
for pid, w in weights.items():
|
||||
self.policy_map[pid].set_weights(w)
|
||||
|
||||
def compute_gradients(self, samples):
|
||||
return self.policy_map["default"].compute_gradients(samples)
|
||||
if isinstance(samples, MultiAgentBatch):
|
||||
grad_out, info_out = {}, {}
|
||||
if self.tf_sess is not None:
|
||||
builder = TFRunBuilder(self.tf_sess, "compute_gradients")
|
||||
for pid, batch in samples.policy_batches.items():
|
||||
grad_out[pid], info_out[pid] = (
|
||||
self.policy_map[pid].build_compute_gradients(
|
||||
builder, batch))
|
||||
grad_out = {k: builder.get(v) for k, v in grad_out.items()}
|
||||
info_out = {k: builder.get(v) for k, v in info_out.items()}
|
||||
else:
|
||||
for pid, batch in samples.policy_batches.items():
|
||||
grad_out[pid], info_out[pid] = (
|
||||
self.policy_map[pid].compute_gradients(batch))
|
||||
return grad_out, info_out
|
||||
else:
|
||||
return self.policy_map[DEFAULT_POLICY_ID].compute_gradients(
|
||||
samples)
|
||||
|
||||
def apply_gradients(self, grads):
|
||||
return self.policy_map["default"].apply_gradients(grads)
|
||||
if isinstance(grads, dict):
|
||||
if self.tf_sess is not None:
|
||||
builder = TFRunBuilder(self.tf_sess, "apply_gradients")
|
||||
outputs = {
|
||||
pid: self.policy_map[pid].build_apply_gradients(
|
||||
builder, grad)
|
||||
for pid, grad in grads.items()
|
||||
}
|
||||
return {
|
||||
k: builder.get(v) for k, v in outputs.items()
|
||||
}
|
||||
else:
|
||||
return {
|
||||
pid: self.policy_map[pid].apply_gradients(g)
|
||||
for pid, g in grads.items()
|
||||
}
|
||||
else:
|
||||
return self.policy_map[DEFAULT_POLICY_ID].apply_gradients(grads)
|
||||
|
||||
def compute_apply(self, samples):
|
||||
grad_fetch, apply_fetch = self.policy_map["default"].compute_apply(
|
||||
samples)
|
||||
return grad_fetch
|
||||
if isinstance(samples, MultiAgentBatch):
|
||||
info_out = {}
|
||||
if self.tf_sess is not None:
|
||||
builder = TFRunBuilder(self.tf_sess, "compute_apply")
|
||||
for pid, batch in samples.policy_batches.items():
|
||||
info_out[pid], _ = (
|
||||
self.policy_map[pid].build_compute_apply(
|
||||
builder, batch))
|
||||
info_out = {k: builder.get(v) for k, v in info_out.items()}
|
||||
else:
|
||||
for pid, batch in samples.policy_batches.items():
|
||||
info_out[pid], _ = (
|
||||
self.policy_map[pid].compute_apply(batch))
|
||||
return info_out
|
||||
else:
|
||||
grad_fetch, apply_fetch = (
|
||||
self.policy_map[DEFAULT_POLICY_ID].compute_apply(samples))
|
||||
return grad_fetch
|
||||
|
||||
def save(self):
|
||||
filters = self.get_filters(flush_after=True)
|
||||
state = self.policy_map["default"].get_state()
|
||||
state = {
|
||||
pid: self.policy_map[pid].get_state()
|
||||
for pid in self.policy_map
|
||||
}
|
||||
return pickle.dumps({"filters": filters, "state": state})
|
||||
|
||||
def restore(self, objs):
|
||||
objs = pickle.loads(objs)
|
||||
self.sync_filters(objs["filters"])
|
||||
self.policy_map["default"].set_state(objs["state"])
|
||||
for pid, state in objs["state"].items():
|
||||
self.policy_map[pid].set_state(state)
|
||||
|
||||
|
||||
def _validate_and_canonicalize(policy_graph, env):
|
||||
if isinstance(policy_graph, dict):
|
||||
for k, v in policy_graph.items():
|
||||
if not isinstance(k, str):
|
||||
raise ValueError(
|
||||
"policy_graph keys must be strs, got {}".format(type(k)))
|
||||
if not isinstance(v, tuple) or len(v) != 4:
|
||||
raise ValueError(
|
||||
"policy_graph values must be tuples of "
|
||||
"(cls, obs_space, action_space, config), got {}".format(v))
|
||||
if not issubclass(v[0], PolicyGraph):
|
||||
raise ValueError(
|
||||
"policy_graph tuple value 0 must be a rllib.PolicyGraph "
|
||||
"class, got {}".format(v[0]))
|
||||
if not isinstance(v[1], gym.Space):
|
||||
raise ValueError(
|
||||
"policy_graph tuple value 1 (observation_space) must be a "
|
||||
"gym.Space, got {}".format(type(v[1])))
|
||||
if not isinstance(v[2], gym.Space):
|
||||
raise ValueError(
|
||||
"policy_graph tuple value 2 (action_space) must be a "
|
||||
"gym.Space, got {}".format(type(v[2])))
|
||||
if not isinstance(v[3], dict):
|
||||
raise ValueError(
|
||||
"policy_graph tuple value 3 (config) must be a dict, "
|
||||
"got {}".format(type(v[3])))
|
||||
return policy_graph
|
||||
elif not issubclass(policy_graph, PolicyGraph):
|
||||
raise ValueError("policy_graph must be a rllib.PolicyGraph class")
|
||||
else:
|
||||
return {
|
||||
DEFAULT_POLICY_ID: (
|
||||
policy_graph, env.observation_space, env.action_space, {})}
|
||||
|
||||
|
||||
def _has_tensorflow_graph(policy_dict):
|
||||
for policy, _, _, _ in policy_dict.values():
|
||||
if issubclass(policy, TFPolicyGraph):
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
|
||||
class EnvContext(dict):
    """Wraps env configurations to include extra rllib metadata.

    These attributes can be used to parameterize environments per process.
    For example, one might use `worker_index` to control which data file an
    environment reads in on initialization.

    RLlib auto-sets these attributes when constructing registered envs.

    Attributes:
        worker_index (int): When there are multiple workers created, this
            uniquely identifies the worker the env is created in.
    """

    def __init__(self, env_config, worker_index):
        """Copy `env_config` into this dict and record the worker index.

        Arguments:
            env_config (dict|None): Env configuration; None is treated as
                an empty config.
            worker_index (int): Index of the worker creating the env.
        """
        # dict.__init__(self, None) would raise TypeError, so normalize a
        # missing config to an empty one.
        dict.__init__(self, env_config or {})
        self.worker_index = worker_index
|
||||
@@ -15,6 +15,10 @@ class PolicyGraph(object):
|
||||
find TFPolicyGraph simpler to implement. TFPolicyGraph also enables RLlib
|
||||
to apply TensorFlow-specific optimizations such as fusing multiple policy
|
||||
graphs and multi-GPU support.
|
||||
|
||||
Attributes:
|
||||
observation_space (gym.Space): Observation space of the policy.
|
||||
action_space (gym.Space): Action space of the policy.
|
||||
"""
|
||||
|
||||
def __init__(self, observation_space, action_space, config):
|
||||
|
||||
@@ -10,6 +10,7 @@ import threading
|
||||
from ray.rllib.optimizers.sample_batch import MultiAgentSampleBatchBuilder, \
|
||||
MultiAgentBatch
|
||||
from ray.rllib.utils.async_vector_env import AsyncVectorEnv
|
||||
from ray.rllib.utils.tf_run_builder import TFRunBuilder
|
||||
|
||||
|
||||
RolloutMetrics = namedtuple(
|
||||
@@ -30,7 +31,7 @@ class SyncSampler(object):
|
||||
|
||||
def __init__(
|
||||
self, env, policies, policy_mapping_fn, obs_filters,
|
||||
num_local_steps, horizon=None, pack=False):
|
||||
num_local_steps, horizon=None, pack=False, tf_sess=None):
|
||||
self.async_vector_env = AsyncVectorEnv.wrap_async(env)
|
||||
self.num_local_steps = num_local_steps
|
||||
self.horizon = horizon
|
||||
@@ -39,7 +40,8 @@ class SyncSampler(object):
|
||||
self._obs_filters = obs_filters
|
||||
self.rollout_provider = _env_runner(
|
||||
self.async_vector_env, self.policies, self.policy_mapping_fn,
|
||||
self.num_local_steps, self.horizon, self._obs_filters, pack)
|
||||
self.num_local_steps, self.horizon, self._obs_filters, pack,
|
||||
tf_sess)
|
||||
self.metrics_queue = queue.Queue()
|
||||
|
||||
def get_data(self):
|
||||
@@ -68,7 +70,7 @@ class AsyncSampler(threading.Thread):
|
||||
|
||||
def __init__(
|
||||
self, env, policies, policy_mapping_fn, obs_filters,
|
||||
num_local_steps, horizon=None, pack=False):
|
||||
num_local_steps, horizon=None, pack=False, tf_sess=None):
|
||||
for _, f in obs_filters.items():
|
||||
assert getattr(f, "is_concurrent", False), \
|
||||
"Observation Filter must support concurrent updates."
|
||||
@@ -83,6 +85,7 @@ class AsyncSampler(threading.Thread):
|
||||
self._obs_filters = obs_filters
|
||||
self.daemon = True
|
||||
self.pack = pack
|
||||
self.tf_sess = tf_sess
|
||||
|
||||
def run(self):
|
||||
try:
|
||||
@@ -94,7 +97,8 @@ class AsyncSampler(threading.Thread):
|
||||
def _run(self):
|
||||
rollout_provider = _env_runner(
|
||||
self.async_vector_env, self.policies, self.policy_mapping_fn,
|
||||
self.num_local_steps, self.horizon, self._obs_filters, self.pack)
|
||||
self.num_local_steps, self.horizon, self._obs_filters, self.pack,
|
||||
self.tf_sess)
|
||||
while True:
|
||||
# The timeout variable exists because apparently, if one worker
|
||||
# dies, the other workers won't die with it, unless the timeout is
|
||||
@@ -140,7 +144,7 @@ class AsyncSampler(threading.Thread):
|
||||
|
||||
def _env_runner(
|
||||
async_vector_env, policies, policy_mapping_fn, num_local_steps,
|
||||
horizon, obs_filters, pack):
|
||||
horizon, obs_filters, pack, tf_sess=None):
|
||||
"""This implements the common experience collection logic.
|
||||
|
||||
Args:
|
||||
@@ -156,6 +160,8 @@ def _env_runner(
|
||||
observations for the policy.
|
||||
pack (bool): Whether to pack multiple episodes into each batch. This
|
||||
guarantees batches will be exactly `num_local_steps` in size.
|
||||
tf_sess (Session|None): Optional tensorflow session to use for batching
|
||||
TF policy evaluations.
|
||||
|
||||
Yields:
|
||||
rollout (SampleBatch): Object containing state, action, reward,
|
||||
@@ -192,6 +198,9 @@ def _env_runner(
|
||||
# Map of policy_id to list of PolicyEvalData
|
||||
to_eval = defaultdict(list)
|
||||
|
||||
# Map of env_id -> agent_id -> action replies
|
||||
actions_to_send = defaultdict(dict)
|
||||
|
||||
# For each environment
|
||||
for env_id, agent_obs in unfiltered_obs.items():
|
||||
new_episode = env_id not in active_episodes
|
||||
@@ -209,11 +218,13 @@ def _env_runner(
|
||||
dict(episode.agent_rewards))
|
||||
else:
|
||||
all_done = False
|
||||
# At least send an empty dict if not done
|
||||
actions_to_send[env_id]
|
||||
|
||||
# For each agent in the environment
|
||||
for agent_id, raw_obs in agent_obs.items():
|
||||
policy_id = episode.policy_for(agent_id)
|
||||
filtered_obs = obs_filters[policy_id](raw_obs)
|
||||
filtered_obs = _get_or_raise(obs_filters, policy_id)(raw_obs)
|
||||
agent_done = bool(all_done or dones[env_id].get(agent_id))
|
||||
if not agent_done:
|
||||
to_eval[policy_id].append(
|
||||
@@ -263,24 +274,40 @@ def _env_runner(
|
||||
episode = active_episodes[env_id]
|
||||
for agent_id, raw_obs in resetted_obs.items():
|
||||
policy_id = episode.policy_for(agent_id)
|
||||
filtered_obs = obs_filters[policy_id](raw_obs)
|
||||
filtered_obs = _get_or_raise(
|
||||
obs_filters, policy_id)(raw_obs)
|
||||
episode.set_last_observation(agent_id, filtered_obs)
|
||||
to_eval[policy_id].append(
|
||||
PolicyEvalData(
|
||||
env_id, agent_id, filtered_obs,
|
||||
episode.rnn_state_for(agent_id)))
|
||||
|
||||
# Map of env_id -> agent_id -> action
|
||||
action_dict = defaultdict(dict)
|
||||
|
||||
# TODO(ekl) fuse all policy evaluation into one TF run
|
||||
# Batch eval policy actions if possible
|
||||
if tf_sess:
|
||||
builder = TFRunBuilder(tf_sess, "policy_eval")
|
||||
else:
|
||||
builder = None
|
||||
eval_results = {}
|
||||
rnn_in_cols = {}
|
||||
for policy_id, eval_data in to_eval.items():
|
||||
rnn_in_cols = _to_column_format([t.rnn_state for t in eval_data])
|
||||
actions, rnn_out_cols, pi_info_cols = \
|
||||
policies[policy_id].compute_actions(
|
||||
[t.obs for t in eval_data], rnn_in_cols, is_training=True)
|
||||
rnn_in = _to_column_format([t.rnn_state for t in eval_data])
|
||||
rnn_in_cols[policy_id] = rnn_in
|
||||
policy = _get_or_raise(policies, policy_id)
|
||||
if builder:
|
||||
eval_results[policy_id] = policy.build_compute_actions(
|
||||
builder, [t.obs for t in eval_data], rnn_in,
|
||||
is_training=True)
|
||||
else:
|
||||
eval_results[policy_id] = policy.compute_actions(
|
||||
[t.obs for t in eval_data], rnn_in, is_training=True)
|
||||
if builder:
|
||||
eval_results = {k: builder.get(v) for k, v in eval_results.items()}
|
||||
|
||||
# Record the policy eval results
|
||||
for policy_id, eval_data in to_eval.items():
|
||||
actions, rnn_out_cols, pi_info_cols = eval_results[policy_id]
|
||||
# Add RNN state info
|
||||
for f_i, column in enumerate(rnn_in_cols):
|
||||
for f_i, column in enumerate(rnn_in_cols[policy_id]):
|
||||
pi_info_cols["state_in_{}".format(f_i)] = column
|
||||
for f_i, column in enumerate(rnn_out_cols):
|
||||
pi_info_cols["state_out_{}".format(f_i)] = column
|
||||
@@ -288,7 +315,7 @@ def _env_runner(
|
||||
for i, action in enumerate(actions):
|
||||
env_id = eval_data[i].env_id
|
||||
agent_id = eval_data[i].agent_id
|
||||
action_dict[env_id][agent_id] = action
|
||||
actions_to_send[env_id][agent_id] = action
|
||||
episode = active_episodes[env_id]
|
||||
episode.set_rnn_state(agent_id, [c[i] for c in rnn_out_cols])
|
||||
episode.set_last_pi_info(
|
||||
@@ -302,7 +329,7 @@ def _env_runner(
|
||||
|
||||
# Return computed actions to ready envs. We also send to envs that have
|
||||
# taken off-policy actions; those envs are free to ignore the action.
|
||||
async_vector_env.send_actions(dict(action_dict))
|
||||
async_vector_env.send_actions(dict(actions_to_send))
|
||||
|
||||
|
||||
def _to_column_format(rnn_state_rows):
|
||||
@@ -311,6 +338,14 @@ def _to_column_format(rnn_state_rows):
|
||||
[row[i] for row in rnn_state_rows] for i in range(num_cols)]
|
||||
|
||||
|
||||
def _get_or_raise(mapping, policy_id):
|
||||
if policy_id not in mapping:
|
||||
raise ValueError(
|
||||
"Could not find policy for agent: agent policy id `{}` not "
|
||||
"in policy map keys {}.".format(policy_id, mapping.keys()))
|
||||
return mapping[policy_id]
|
||||
|
||||
|
||||
class _MultiAgentEpisode(object):
|
||||
def __init__(self, policies, policy_mapping_fn, batch_builder_factory):
|
||||
self.batch_builder = batch_builder_factory()
|
||||
@@ -327,8 +362,10 @@ class _MultiAgentEpisode(object):
|
||||
|
||||
def add_agent_rewards(self, reward_dict):
|
||||
for agent_id, reward in reward_dict.items():
|
||||
self.agent_rewards[agent_id] += reward
|
||||
self.total_reward += reward
|
||||
if reward is not None:
|
||||
self.agent_rewards[
|
||||
agent_id, self.policy_for(agent_id)] += reward
|
||||
self.total_reward += reward
|
||||
|
||||
def policy_for(self, agent_id):
|
||||
if agent_id not in self._agent_to_policy:
|
||||
|
||||
@@ -35,6 +35,8 @@ class ServingEnv(threading.Thread):
|
||||
def __init__(self, action_space, observation_space, max_concurrent=100):
|
||||
"""Initialize a serving env.
|
||||
|
||||
ServingEnv subclasses must call this during their __init__.
|
||||
|
||||
Arguments:
|
||||
action_space (gym.Space): Action space of the env.
|
||||
observation_space (gym.Space): Observation space of the env.
|
||||
|
||||
@@ -6,6 +6,7 @@ import tensorflow as tf
|
||||
|
||||
import ray
|
||||
from ray.rllib.utils.policy_graph import PolicyGraph
|
||||
from ray.rllib.utils.tf_run_builder import TFRunBuilder
|
||||
|
||||
|
||||
class TFPolicyGraph(PolicyGraph):
|
||||
@@ -29,11 +30,15 @@ class TFPolicyGraph(PolicyGraph):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, sess, obs_input, action_sampler, loss, loss_inputs,
|
||||
self, observation_space, action_space, sess, obs_input,
|
||||
action_sampler, loss, loss_inputs,
|
||||
is_training, state_inputs=None, state_outputs=None):
|
||||
"""Initialize the policy.
|
||||
|
||||
Arguments:
|
||||
observation_space (gym.Space): Observation space of the env.
|
||||
action_space (gym.Space): Action space of the env.
|
||||
sess (Session): TensorFlow session to use.
|
||||
obs_input (Tensor): input placeholder for observations.
|
||||
action_sampler (Tensor): Tensor for sampling an action.
|
||||
loss (Tensor): scalar policy loss output tensor.
|
||||
@@ -46,6 +51,8 @@ class TFPolicyGraph(PolicyGraph):
|
||||
state_outputs (list): list of initial state values.
|
||||
"""
|
||||
|
||||
self.observation_space = observation_space
|
||||
self.action_space = action_space
|
||||
self._sess = sess
|
||||
self._obs_input = obs_input
|
||||
self._sampler = action_sampler
|
||||
@@ -55,7 +62,9 @@ class TFPolicyGraph(PolicyGraph):
|
||||
self._state_inputs = state_inputs or []
|
||||
self._state_outputs = state_outputs or []
|
||||
self._optimizer = self.optimizer()
|
||||
self._grads_and_vars = self.gradients(self._optimizer)
|
||||
self._grads_and_vars = [
|
||||
(g, v) for (g, v) in self.gradients(self._optimizer)
|
||||
if g is not None]
|
||||
self._grads = [g for (g, v) in self._grads_and_vars]
|
||||
self._apply_op = self._optimizer.apply_gradients(self._grads_and_vars)
|
||||
self._variables = ray.experimental.TensorFlowVariables(
|
||||
@@ -64,21 +73,27 @@ class TFPolicyGraph(PolicyGraph):
|
||||
assert len(self._state_inputs) == len(self._state_outputs) == \
|
||||
len(self.get_initial_state())
|
||||
|
||||
def compute_actions(
|
||||
self, obs_batch, state_batches=None, is_training=False):
|
||||
def build_compute_actions(
|
||||
self, builder, obs_batch, state_batches=None, is_training=False):
|
||||
state_batches = state_batches or []
|
||||
assert len(self._state_inputs) == len(state_batches), \
|
||||
(self._state_inputs, state_batches)
|
||||
feed_dict = self.extra_compute_action_feed_dict()
|
||||
feed_dict[self._obs_input] = obs_batch
|
||||
feed_dict[self._is_training] = is_training
|
||||
for ph, value in zip(self._state_inputs, state_batches):
|
||||
feed_dict[ph] = value
|
||||
fetches = self._sess.run(
|
||||
([self._sampler] + self._state_outputs +
|
||||
[self.extra_compute_action_fetches()]), feed_dict=feed_dict)
|
||||
builder.add_feed_dict(self.extra_compute_action_feed_dict())
|
||||
builder.add_feed_dict({self._obs_input: obs_batch})
|
||||
builder.add_feed_dict({self._is_training: is_training})
|
||||
builder.add_feed_dict(dict(zip(self._state_inputs, state_batches)))
|
||||
fetches = builder.add_fetches(
|
||||
[self._sampler] + self._state_outputs +
|
||||
[self.extra_compute_action_fetches()])
|
||||
return fetches[0], fetches[1:-1], fetches[-1]
|
||||
|
||||
def compute_actions(
        self, obs_batch, state_batches=None, is_training=False):
    """Compute actions for a batch of observations in a standalone run.

    Registers the feeds and fetches via `build_compute_actions` on a
    fresh `TFRunBuilder` named "compute_actions", then executes it.

    Arguments:
        obs_batch: Batch of observations fed to the observation input.
        state_batches (list|None): Optional state input batches.
        is_training (bool): Value fed to the is-training placeholder.

    Returns:
        The fetched values for whatever `build_compute_actions` returned.
    """
    run = TFRunBuilder(self._sess, "compute_actions")
    pending = self.build_compute_actions(
        run, obs_batch, state_batches, is_training)
    return run.get(pending)
|
||||
|
||||
def _get_loss_inputs_dict(self, postprocessed_batch):
|
||||
feed_dict = {}
|
||||
for key, ph in self._loss_inputs:
|
||||
@@ -90,37 +105,48 @@ class TFPolicyGraph(PolicyGraph):
|
||||
feed_dict[ph] = postprocessed_batch[key]
|
||||
return feed_dict
|
||||
|
||||
def compute_gradients(self, postprocessed_batch):
|
||||
feed_dict = self.extra_compute_grad_feed_dict()
|
||||
feed_dict[self._is_training] = True
|
||||
feed_dict.update(self._get_loss_inputs_dict(postprocessed_batch))
|
||||
fetches = self._sess.run(
|
||||
[self._grads, self.extra_compute_grad_fetches()],
|
||||
feed_dict=feed_dict)
|
||||
def build_compute_gradients(self, builder, postprocessed_batch):
    """Register the gradient computation for one batch on `builder`.

    Nothing is executed here; the caller resolves the returned handles
    with `builder.get(...)`.

    Arguments:
        builder: Run builder exposing `add_feed_dict` and `add_fetches`.
        postprocessed_batch: Batch passed to `_get_loss_inputs_dict` to
            produce the loss-input feeds.

    Returns:
        tuple: The first two handles returned by `builder.add_fetches`,
            for the gradient list and the extra compute-grad fetches.
    """
    extra_feeds = self.extra_compute_grad_feed_dict()
    builder.add_feed_dict(extra_feeds)
    # Gradients are always computed in training mode.
    training_feed = {self._is_training: True}
    builder.add_feed_dict(training_feed)
    loss_feeds = self._get_loss_inputs_dict(postprocessed_batch)
    builder.add_feed_dict(loss_feeds)
    requested = [self._grads, self.extra_compute_grad_fetches()]
    handles = builder.add_fetches(requested)
    return handles[0], handles[1]
|
||||
|
||||
def apply_gradients(self, gradients):
|
||||
def compute_gradients(self, postprocessed_batch):
    """Compute gradients for one batch in a standalone run.

    Registers the work via `build_compute_gradients` on a fresh
    `TFRunBuilder` named "compute_gradients", then executes it.

    Arguments:
        postprocessed_batch: Batch forwarded to `build_compute_gradients`.

    Returns:
        The fetched values for the handles that
        `build_compute_gradients` returned.
    """
    run = TFRunBuilder(self._sess, "compute_gradients")
    pending = self.build_compute_gradients(run, postprocessed_batch)
    return run.get(pending)
|
||||
|
||||
def build_apply_gradients(self, builder, gradients):
|
||||
assert len(gradients) == len(self._grads), (gradients, self._grads)
|
||||
feed_dict = self.extra_apply_grad_feed_dict()
|
||||
feed_dict[self._is_training] = True
|
||||
for ph, value in zip(self._grads, gradients):
|
||||
feed_dict[ph] = value
|
||||
fetches = self._sess.run(
|
||||
[self._apply_op, self.extra_apply_grad_fetches()],
|
||||
feed_dict=feed_dict)
|
||||
builder.add_feed_dict(self.extra_apply_grad_feed_dict())
|
||||
builder.add_feed_dict({self._is_training: True})
|
||||
builder.add_feed_dict(dict(zip(self._grads, gradients)))
|
||||
fetches = builder.add_fetches(
|
||||
[self._apply_op, self.extra_apply_grad_fetches()])
|
||||
return fetches[1]
|
||||
|
||||
def compute_apply(self, postprocessed_batch):
|
||||
feed_dict = self.extra_compute_grad_feed_dict()
|
||||
feed_dict.update(self.extra_apply_grad_feed_dict())
|
||||
feed_dict.update(self._get_loss_inputs_dict(postprocessed_batch))
|
||||
feed_dict[self._is_training] = True
|
||||
fetches = self._sess.run(
|
||||
def apply_gradients(self, gradients):
    """Apply externally supplied gradients in a standalone run.

    Registers the work via `build_apply_gradients` on a fresh
    `TFRunBuilder` named "apply_gradients", then executes it.

    Arguments:
        gradients: Gradient values forwarded to `build_apply_gradients`.

    Returns:
        The fetched value for the handle that `build_apply_gradients`
        returned.
    """
    run = TFRunBuilder(self._sess, "apply_gradients")
    pending = self.build_apply_gradients(run, gradients)
    return run.get(pending)
|
||||
|
||||
def build_compute_apply(self, builder, postprocessed_batch):
|
||||
builder.add_feed_dict(self.extra_compute_grad_feed_dict())
|
||||
builder.add_feed_dict(self.extra_apply_grad_feed_dict())
|
||||
builder.add_feed_dict(self._get_loss_inputs_dict(postprocessed_batch))
|
||||
builder.add_feed_dict({self._is_training: True})
|
||||
fetches = builder.add_fetches(
|
||||
[self._apply_op, self.extra_compute_grad_fetches(),
|
||||
self.extra_apply_grad_fetches()],
|
||||
feed_dict=feed_dict)
|
||||
self.extra_apply_grad_fetches()])
|
||||
return fetches[1], fetches[2]
|
||||
|
||||
def compute_apply(self, postprocessed_batch):
    """Compute and apply gradients for one batch in a standalone run.

    Registers the work via `build_compute_apply` on a fresh
    `TFRunBuilder` named "compute_apply", then executes it.

    Arguments:
        postprocessed_batch: Batch forwarded to `build_compute_apply`.

    Returns:
        The fetched values for the handles that `build_compute_apply`
        returned.
    """
    run = TFRunBuilder(self._sess, "compute_apply")
    pending = self.build_compute_apply(run, postprocessed_batch)
    return run.get(pending)
|
||||
|
||||
def get_weights(self):
|
||||
return self._variables.get_flat()
|
||||
|
||||
|
||||
@@ -0,0 +1,82 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow.python.client import timeline
|
||||
|
||||
|
||||
class TFRunBuilder(object):
    """Used to incrementally build up a TensorFlow run.

    This is particularly useful for batching ops from multiple different
    policies in the multi-agent setting.

    Feeds and fetches are accumulated with `add_feed_dict` and
    `add_fetches`. The first call to `get` executes everything accumulated
    so far in a single run and caches the results; from then on the builder
    is frozen and only `get` may be called.
    """

    def __init__(self, session, debug_name):
        """Create an empty run.

        Arguments:
            session: Session object handed to `run_timeline` for execution.
            debug_name (str): Label handed to `run_timeline`.
        """
        self.session = session
        self.debug_name = debug_name
        self.feed_dict = {}
        self.fetches = []
        # None until the run has executed; afterwards the fetched results.
        self._executed = None

    def add_feed_dict(self, feed_dict):
        """Merge `feed_dict` into the pending feeds.

        Raises:
            AssertionError: If the run already executed, or if a key is
                already present in the pending feeds.
        """
        # Compare against None so that an empty (falsy) result list still
        # counts as "already executed".
        assert self._executed is None, \
            "Cannot add feeds after the run has executed."
        for k in feed_dict:
            assert k not in self.feed_dict, \
                "Duplicate feed_dict key: {}".format(k)
        self.feed_dict.update(feed_dict)

    def add_fetches(self, fetches):
        """Append `fetches` to the pending fetch list.

        Returns:
            list: Integer handles (positions in the combined fetch list),
                one per added fetch, to be resolved later with `get`.

        Raises:
            AssertionError: If the run already executed.
        """
        assert self._executed is None, \
            "Cannot add fetches after the run has executed."
        base_index = len(self.fetches)
        self.fetches.extend(fetches)
        return list(range(base_index, len(self.fetches)))

    def get(self, to_fetch):
        """Resolve handles to fetched values, executing the run if needed.

        Arguments:
            to_fetch (int|list|tuple): A handle from `add_fetches`, or an
                arbitrarily nested list/tuple of handles.

        Returns:
            The fetched value(s), mirroring the structure of `to_fetch`.

        Raises:
            ValueError: If `to_fetch` is not an int, list, or tuple.
        """
        if self._executed is None:
            try:
                self._executed = run_timeline(
                    self.session, self.fetches, self.debug_name,
                    self.feed_dict, os.environ.get("TF_TIMELINE_DIR"))
            except Exception:
                print("Error fetching: {}, feed_dict={}".format(
                    self.fetches, self.feed_dict))
                # Bare raise keeps the original traceback (also on py2).
                raise
        if isinstance(to_fetch, int):
            return self._executed[to_fetch]
        elif isinstance(to_fetch, list):
            return [self.get(x) for x in to_fetch]
        elif isinstance(to_fetch, tuple):
            return tuple(self.get(x) for x in to_fetch)
        else:
            raise ValueError("Unsupported fetch type: {}".format(to_fetch))
|
||||
|
||||
|
||||
_count = 0
|
||||
|
||||
|
||||
def run_timeline(sess, ops, debug_name, feed_dict=None, timeline_dir=None):
    """Run `ops` in `sess`, optionally dumping a Chrome-trace timeline.

    Arguments:
        sess: Session object whose `run` method executes `ops`.
        ops: Fetches passed to `sess.run`.
        debug_name (str): Label embedded in the trace file name.
        feed_dict (dict|None): Feeds for the run; None means no feeds.
        timeline_dir (str|None): If set, the run is traced and the trace is
            written to `timeline-<debug_name>-<pid>-<n>.json` inside this
            directory, where `<n>` is a module-wide counter.

    Returns:
        Whatever `sess.run` returns for `ops`.
    """
    global _count

    # Avoid a shared mutable default argument.
    if feed_dict is None:
        feed_dict = {}

    if not timeline_dir:
        return sess.run(ops, feed_dict=feed_dict)

    run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
    run_metadata = tf.RunMetadata()
    start = time.time()
    fetches = sess.run(
        ops, options=run_options, run_metadata=run_metadata,
        feed_dict=feed_dict)
    trace = timeline.Timeline(step_stats=run_metadata.step_stats)
    outf = os.path.join(
        timeline_dir,
        "timeline-{}-{}-{}.json".format(debug_name, os.getpid(), _count))
    _count += 1
    elapsed = time.time() - start
    # Close the file deterministically, and only report success once the
    # trace has actually been written.
    with open(outf, "w") as trace_file:
        trace_file.write(trace.generate_chrome_trace_format())
    print(
        "Wrote tf timeline ({} s) to {}".format(
            elapsed, os.path.abspath(outf)))
    return fetches
|
||||
@@ -43,6 +43,9 @@ TrainingResult = namedtuple(
|
||||
# (Optional) The number of episodes total.
|
||||
"episodes_total",
|
||||
|
||||
# (Optional) Per-policy reward information in multi-agent RL.
|
||||
"policy_reward_mean",
|
||||
|
||||
# (Optional) The current training accuracy if applicable.
|
||||
"mean_accuracy",
|
||||
|
||||
|
||||
Reference in New Issue
Block a user