[rllib] Refactor Multi-GPU for PPO (#1646)

This commit is contained in:
Victor Sun
2018-06-18 23:49:35 -04:00
committed by Richard Liaw
parent 7dee2c6735
commit b372b7103e
7 changed files with 152 additions and 239 deletions
+44 -22
View File
@@ -26,18 +26,23 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
the TFMultiGPUSupport API.
"""
def _init(self, sgd_batch_size=128, sgd_stepsize=5e-5, num_sgd_iter=10):
def _init(self, sgd_batch_size=128, sgd_stepsize=5e-5, num_sgd_iter=10,
timesteps_per_batch=1024):
assert isinstance(self.local_evaluator, TFMultiGPUSupport)
self.batch_size = sgd_batch_size
self.sgd_stepsize = sgd_stepsize
self.num_sgd_iter = num_sgd_iter
self.timesteps_per_batch = timesteps_per_batch
gpu_ids = ray.get_gpu_ids()
if not gpu_ids:
self.devices = ["/cpu:0"]
else:
self.devices = ["/gpu:{}".format(i) for i in range(len(gpu_ids))]
assert self.batch_size > len(self.devices), "batch size too small"
self.per_device_batch_size = self.batch_size // len(self.devices)
self.batch_size = int(
sgd_batch_size / len(self.devices)) * len(self.devices)
assert self.batch_size % len(self.devices) == 0
assert self.batch_size >= len(self.devices), "batch size too small"
self.per_device_batch_size = int(self.batch_size / len(self.devices))
self.sample_timer = TimerStat()
self.load_timer = TimerStat()
self.grad_timer = TimerStat()
@@ -50,20 +55,27 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
self.loss_inputs = self.local_evaluator.tf_loss_inputs()
# per-GPU graph copies created below must share vars with the policy
tf.get_variable_scope().reuse_variables()
main_thread_scope = tf.get_variable_scope()
# reuse is set to AUTO_REUSE because Adam nodes are created after
# all of the device copies are created.
with tf.variable_scope(main_thread_scope, reuse=tf.AUTO_REUSE):
self.par_opt = LocalSyncParallelOptimizer(
tf.train.AdamOptimizer(self.sgd_stepsize),
self.devices,
[ph for _, ph in self.loss_inputs],
self.per_device_batch_size,
lambda *ph: self.local_evaluator.build_tf_loss(ph),
os.getcwd())
self.par_opt = LocalSyncParallelOptimizer(
tf.train.AdamOptimizer(self.sgd_stepsize),
self.devices,
[ph for _, ph in self.loss_inputs],
self.per_device_batch_size,
lambda *ph: self.local_evaluator.build_tf_loss(ph),
os.getcwd())
# TODO(rliaw): Find more elegant solution for this
if hasattr(self.local_evaluator, "init_extra_ops"):
self.local_evaluator.init_extra_ops(
self.par_opt.get_device_losses())
self.sess = self.local_evaluator.sess
self.sess.run(tf.global_variables_initializer())
def step(self):
def step(self, postprocess_fn=None):
with self.update_weights_timer:
if self.remote_evaluators:
weights = ray.put(self.local_evaluator.get_weights())
@@ -72,34 +84,44 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
with self.sample_timer:
if self.remote_evaluators:
samples = SampleBatch.concat_samples(
ray.get(
[e.sample.remote() for e in self.remote_evaluators]))
# TODO(rliaw): remove when refactoring
from ray.rllib.ppo.rollout import collect_samples
samples = collect_samples(self.remote_evaluators,
self.timesteps_per_batch)
else:
samples = self.local_evaluator.sample()
assert isinstance(samples, SampleBatch)
if postprocess_fn:
postprocess_fn(samples)
with self.load_timer:
tuples_per_device = self.par_opt.load_data(
self.local_evaluator.sess,
samples.columns([key for key, _ in self.loss_inputs]))
with self.grad_timer:
all_extra_fetches = []
model = self.local_evaluator
num_batches = (
int(tuples_per_device) // int(self.per_device_batch_size))
for i in range(self.num_sgd_iter):
batch_index = 0
num_batches = (
int(tuples_per_device) // int(self.per_device_batch_size))
iter_extra_fetches = []
permutation = np.random.permutation(num_batches)
while batch_index < num_batches:
for batch_index in range(num_batches):
# TODO(ekl) support ppo's debugging features, e.g.
# printing the current loss and tracing
self.par_opt.optimize(
batch_fetches = self.par_opt.optimize(
self.sess,
permutation[batch_index] * self.per_device_batch_size)
batch_index += 1
permutation[batch_index] * self.per_device_batch_size,
extra_ops=model.extra_apply_grad_fetches(),
extra_feed_dict=model.extra_apply_grad_feed_dict())
iter_extra_fetches += [batch_fetches]
all_extra_fetches += [iter_extra_fetches]
self.num_steps_sampled += samples.count
self.num_steps_trained += samples.count
return all_extra_fetches
def stats(self):
return dict(PolicyOptimizer.stats(), **{
@@ -60,7 +60,7 @@ class LocalSyncParallelOptimizer(object):
self.logdir = logdir
# First initialize the shared loss network
with tf.variable_scope(TOWER_SCOPE_NAME):
with tf.name_scope(TOWER_SCOPE_NAME):
self._shared_loss = build_loss(*input_placeholders)
# Then setup the per-device loss graphs that use the shared weights
@@ -192,7 +192,7 @@ class LocalSyncParallelOptimizer(object):
def _setup_device(self, device, device_input_placeholders):
with tf.device(device):
with tf.variable_scope(TOWER_SCOPE_NAME, reuse=True):
with tf.name_scope(TOWER_SCOPE_NAME):
device_input_batches = []
device_input_slices = []
for ph in device_input_placeholders:
+29 -120
View File
@@ -3,12 +3,9 @@ from __future__ import division
from __future__ import print_function
import os
import time
import numpy as np
import pickle
import tensorflow as tf
from tensorflow.python import debug as tf_debug
import ray
from ray.tune.result import TrainingResult
@@ -16,8 +13,7 @@ from ray.tune.trial import Resources
from ray.rllib.agent import Agent
from ray.rllib.utils import FilterManager
from ray.rllib.ppo.ppo_evaluator import PPOEvaluator
from ray.rllib.ppo.rollout import collect_samples
from ray.rllib.optimizers.multi_gpu import LocalMultiGPUOptimizer
DEFAULT_CONFIG = {
# Discount factor of the MDP
@@ -43,7 +39,7 @@ DEFAULT_CONFIG = {
"log_device_placement": False,
"allow_soft_placement": True,
"intra_op_parallelism_threads": 1,
"inter_op_parallelism_threads": 2,
"inter_op_parallelism_threads": 1,
},
# Batch size for policy evaluations for rollouts
"rollout_batchsize": 1,
@@ -106,7 +102,6 @@ class PPOAgent(Agent):
def _init(self):
self.global_step = 0
self.kl_coeff = self.config["kl_coeff"]
self.local_evaluator = PPOEvaluator(
self.registry, self.env_creator, self.config, self.logdir, False)
RemotePPOEvaluator = ray.remote(
@@ -117,125 +112,41 @@ class PPOAgent(Agent):
self.registry, self.env_creator, self.config, self.logdir,
True)
for _ in range(self.config["num_workers"])]
self.start_time = time.time()
if self.config["write_logs"]:
self.file_writer = tf.summary.FileWriter(
self.logdir, self.local_evaluator.sess.graph)
else:
self.file_writer = None
self.optimizer = LocalMultiGPUOptimizer(
{"sgd_batch_size": self.config["sgd_batchsize"],
"sgd_stepsize": self.config["sgd_stepsize"],
"num_sgd_iter": self.config["num_sgd_iter"],
"timesteps_per_batch": self.config["timesteps_per_batch"]},
self.local_evaluator, self.remote_evaluators,)
self.saver = tf.train.Saver(max_to_keep=None)
def _train(self):
agents = self.remote_evaluators
config = self.config
model = self.local_evaluator
if (config["num_workers"] * config["min_steps_per_task"] >
config["timesteps_per_batch"]):
print(
"WARNING: num_workers * min_steps_per_task > "
"timesteps_per_batch. This means that the output of some "
"tasks will be wasted. Consider decreasing "
"min_steps_per_task or increasing timesteps_per_batch.")
print("===> iteration", self.iteration)
iter_start = time.time()
weights = ray.put(model.get_weights())
[a.set_weights.remote(weights) for a in agents]
samples = collect_samples(agents, config, self.local_evaluator)
def standardized(value):
def postprocess_samples(batch):
# Divide by the maximum of value.std() and 1e-4
# to guard against the case where all values are equal
return (value - value.mean()) / max(1e-4, value.std())
value = batch["advantages"]
standardized = (value - value.mean()) / max(1e-4, value.std())
batch.data["advantages"] = standardized
batch.shuffle()
dummy = np.zeros_like(batch["advantages"])
if not self.config["use_gae"]:
batch.data["value_targets"] = dummy
batch.data["vf_preds"] = dummy
extra_fetches = self.optimizer.step(postprocess_fn=postprocess_samples)
samples.data["advantages"] = standardized(samples["advantages"])
rollouts_end = time.time()
print("Computing policy (iterations=" + str(config["num_sgd_iter"]) +
", stepsize=" + str(config["sgd_stepsize"]) + "):")
names = [
"iter", "total loss", "policy loss", "vf loss", "kl", "entropy"]
print(("{:>15}" * len(names)).format(*names))
samples.shuffle()
shuffle_end = time.time()
tuples_per_device = model.load_data(
samples, self.iteration == 0 and config["full_trace_data_load"])
load_end = time.time()
rollouts_time = rollouts_end - iter_start
shuffle_time = shuffle_end - rollouts_end
load_time = load_end - shuffle_end
sgd_time = 0
for i in range(config["num_sgd_iter"]):
sgd_start = time.time()
batch_index = 0
num_batches = (
int(tuples_per_device) // int(model.per_device_batch_size))
loss, policy_graph, vf_loss, kl, entropy = [], [], [], [], []
permutation = np.random.permutation(num_batches)
# Prepare to drop into the debugger
if self.iteration == config["tf_debug_iteration"]:
model.sess = tf_debug.LocalCLIDebugWrapperSession(model.sess)
while batch_index < num_batches:
full_trace = (
i == 0 and self.iteration == 0 and
batch_index == config["full_trace_nth_sgd_batch"])
batch_loss, batch_policy_graph, batch_vf_loss, batch_kl, \
batch_entropy = model.run_sgd_minibatch(
permutation[batch_index] * model.per_device_batch_size,
self.kl_coeff, full_trace,
self.file_writer)
loss.append(batch_loss)
policy_graph.append(batch_policy_graph)
vf_loss.append(batch_vf_loss)
kl.append(batch_kl)
entropy.append(batch_entropy)
batch_index += 1
loss = np.mean(loss)
policy_graph = np.mean(policy_graph)
vf_loss = np.mean(vf_loss)
kl = np.mean(kl)
entropy = np.mean(entropy)
sgd_end = time.time()
print(
"{:>15}{:15.5e}{:15.5e}{:15.5e}{:15.5e}{:15.5e}".format(
i, loss, policy_graph, vf_loss, kl, entropy))
values = []
if i == config["num_sgd_iter"] - 1:
metric_prefix = "ppo/sgd/final_iter/"
values.append(tf.Summary.Value(
tag=metric_prefix + "kl_coeff",
simple_value=self.kl_coeff))
values.extend([
tf.Summary.Value(
tag=metric_prefix + "mean_entropy",
simple_value=entropy),
tf.Summary.Value(
tag=metric_prefix + "mean_loss",
simple_value=loss),
tf.Summary.Value(
tag=metric_prefix + "mean_kl",
simple_value=kl)])
if self.file_writer:
sgd_stats = tf.Summary(value=values)
self.file_writer.add_summary(sgd_stats, self.global_step)
self.global_step += 1
sgd_time += sgd_end - sgd_start
if kl > 2.0 * config["kl_target"]:
self.kl_coeff *= 1.5
elif kl < 0.5 * config["kl_target"]:
self.kl_coeff *= 0.5
final_metrics = np.array(extra_fetches).mean(axis=1)[-1, :].tolist()
total_loss, policy_loss, vf_loss, kl, entropy = final_metrics
self.local_evaluator.update_kl(kl)
info = {
"total_loss": total_loss,
"policy_loss": policy_loss,
"vf_loss": vf_loss,
"kl_divergence": kl,
"kl_coefficient": self.kl_coeff,
"rollouts_time": rollouts_time,
"shuffle_time": shuffle_time,
"load_time": load_time,
"sgd_time": sgd_time,
"sample_throughput": len(samples["obs"]) / sgd_time
"entropy": entropy,
"kl_coefficient": self.local_evaluator.kl_coeff_val,
}
FilterManager.synchronize(
@@ -281,7 +192,6 @@ class PPOAgent(Agent):
extra_data = [
self.local_evaluator.save(),
self.global_step,
self.kl_coeff,
agent_state]
pickle.dump(extra_data, open(checkpoint_path + ".extra_data", "wb"))
return checkpoint_path
@@ -291,10 +201,9 @@ class PPOAgent(Agent):
extra_data = pickle.load(open(checkpoint_path + ".extra_data", "rb"))
self.local_evaluator.restore(extra_data[0])
self.global_step = extra_data[1]
self.kl_coeff = extra_data[2]
ray.get([
a.restore.remote(o)
for (a, o) in zip(self.remote_evaluators, extra_data[3])])
for (a, o) in zip(self.remote_evaluators, extra_data[2])])
def compute_action(self, observation):
observation = self.local_evaluator.obs_filter(
+61 -92
View File
@@ -4,15 +4,10 @@ from __future__ import print_function
import pickle
import tensorflow as tf
import os
from tensorflow.python import debug as tf_debug
import numpy as np
from collections import OrderedDict
import ray
from ray.rllib.optimizers import PolicyEvaluator, SampleBatch
from ray.rllib.optimizers.multi_gpu_impl import LocalSyncParallelOptimizer
from ray.rllib.optimizers import SampleBatch, TFMultiGPUSupport
from ray.rllib.models import ModelCatalog
from ray.rllib.utils.sampler import SyncSampler
from ray.rllib.utils.filter import get_filter, MeanStdFilter
@@ -20,8 +15,7 @@ from ray.rllib.utils.process_rollout import compute_advantages
from ray.rllib.ppo.loss import ProximalPolicyGraph
# TODO(rliaw): Move this onto LocalMultiGPUOptimizer
class PPOEvaluator(PolicyEvaluator):
class PPOEvaluator(TFMultiGPUSupport):
"""
Runner class that holds the simulator environment and the policy.
@@ -32,13 +26,6 @@ class PPOEvaluator(PolicyEvaluator):
def __init__(self, registry, env_creator, config, logdir, is_remote):
self.registry = registry
self.is_remote = is_remote
if is_remote:
os.environ["CUDA_VISIBLE_DEVICES"] = ""
devices = ["/cpu:0"]
else:
devices = config["devices"]
self.devices = devices
self.config = config
self.logdir = logdir
self.env = ModelCatalog.get_preprocessor_as_wrapper(
@@ -48,10 +35,8 @@ class PPOEvaluator(PolicyEvaluator):
else:
config_proto = tf.ConfigProto(**config["tf_session_args"])
self.sess = tf.Session(config=config_proto)
if config["tf_debug_inf_or_nan"] and not is_remote:
self.sess = tf_debug.LocalCLIDebugWrapperSession(self.sess)
self.sess.add_tensor_filter(
"has_inf_or_nan", tf_debug.has_inf_or_nan)
self.kl_coeff_val = self.config["kl_coeff"]
self.kl_target = self.config["kl_target"]
# Defines the training inputs:
# The coefficient of the KL penalty.
@@ -76,52 +61,17 @@ class PPOEvaluator(PolicyEvaluator):
# Value function predictions before the policy update.
self.prev_vf_preds = tf.placeholder(tf.float32, shape=(None,))
if is_remote:
self.batch_size = config["rollout_batchsize"]
self.per_device_batch_size = config["rollout_batchsize"]
else:
self.batch_size = int(
config["sgd_batchsize"] / len(devices)) * len(devices)
assert self.batch_size % len(devices) == 0
self.per_device_batch_size = int(self.batch_size / len(devices))
def build_loss(obs, vtargets, advs, acts, plog, pvf_preds):
return ProximalPolicyGraph(
self.env.observation_space, self.env.action_space,
obs, vtargets, advs, acts, plog, pvf_preds, self.logit_dim,
self.kl_coeff, self.distribution_class, self.config,
self.sess, self.registry)
self.par_opt = LocalSyncParallelOptimizer(
tf.train.AdamOptimizer(self.config["sgd_stepsize"]),
self.devices,
[self.observations, self.value_targets, self.advantages,
self.actions, self.prev_logits, self.prev_vf_preds],
self.per_device_batch_size,
build_loss,
self.logdir)
# Metric ops
with tf.name_scope("test_outputs"):
policies = self.par_opt.get_device_losses()
self.mean_loss = tf.reduce_mean(
tf.stack(values=[
policy.loss for policy in policies]), 0)
self.mean_policy_loss = tf.reduce_mean(
tf.stack(values=[
policy.mean_policy_loss for policy in policies]), 0)
self.mean_vf_loss = tf.reduce_mean(
tf.stack(values=[
policy.mean_vf_loss for policy in policies]), 0)
self.mean_kl = tf.reduce_mean(
tf.stack(values=[
policy.mean_kl for policy in policies]), 0)
self.mean_entropy = tf.reduce_mean(
tf.stack(values=[
policy.mean_entropy for policy in policies]), 0)
self.inputs = [
("obs", self.observations),
("value_targets", self.value_targets),
("advantages", self.advantages),
("actions", self.actions),
("logprobs", self.prev_logits),
("vf_preds", self.prev_vf_preds)
]
self.common_policy = self.build_tf_loss([ph for _, ph in self.inputs])
# References to the model weights
self.common_policy = self.par_opt.get_common_loss()
self.variables = ray.experimental.TensorFlowVariables(
self.common_policy.loss, self.sess)
self.obs_filter = get_filter(
@@ -132,45 +82,64 @@ class PPOEvaluator(PolicyEvaluator):
self.sampler = SyncSampler(
self.env, self.common_policy, self.obs_filter,
self.config["horizon"], self.config["horizon"])
self.sess.run(tf.global_variables_initializer())
def load_data(self, trajectories, full_trace):
use_gae = self.config["use_gae"]
dummy = np.zeros_like(trajectories["advantages"])
return self.par_opt.load_data(
self.sess,
[trajectories["obs"],
trajectories["value_targets"] if use_gae else dummy,
trajectories["advantages"],
trajectories["actions"],
trajectories["logprobs"],
trajectories["vf_preds"] if use_gae else dummy],
full_trace=full_trace)
def tf_loss_inputs(self):
return self.inputs
def run_sgd_minibatch(
self, batch_index, kl_coeff, full_trace, file_writer):
return self.par_opt.optimize(
self.sess,
batch_index,
extra_ops=[
self.mean_loss, self.mean_policy_loss, self.mean_vf_loss,
self.mean_kl, self.mean_entropy],
extra_feed_dict={self.kl_coeff: kl_coeff},
file_writer=file_writer if full_trace else None)
def build_tf_loss(self, input_placeholders):
obs, vtargets, advs, acts, plog, pvf_preds = input_placeholders
return ProximalPolicyGraph(
self.env.observation_space, self.env.action_space,
obs, vtargets, advs, acts, plog, pvf_preds, self.logit_dim,
self.kl_coeff, self.distribution_class, self.config,
self.sess, self.registry)
def compute_gradients(self, samples):
raise NotImplementedError
def init_extra_ops(self, device_losses):
self.extra_ops = OrderedDict()
with tf.name_scope("test_outputs"):
policies = device_losses
self.extra_ops["loss"] = tf.reduce_mean(
tf.stack(values=[
policy.loss for policy in policies]), 0)
self.extra_ops["policy_loss"] = tf.reduce_mean(
tf.stack(values=[
policy.mean_policy_loss for policy in policies]), 0)
self.extra_ops["vf_loss"] = tf.reduce_mean(
tf.stack(values=[
policy.mean_vf_loss for policy in policies]), 0)
self.extra_ops["kl"] = tf.reduce_mean(
tf.stack(values=[
policy.mean_kl for policy in policies]), 0)
self.extra_ops["entropy"] = tf.reduce_mean(
tf.stack(values=[
policy.mean_entropy for policy in policies]), 0)
def apply_gradients(self, grads):
raise NotImplementedError
def extra_apply_grad_fetches(self):
return list(self.extra_ops.values())
def extra_apply_grad_feed_dict(self):
return {self.kl_coeff: self.kl_coeff_val}
def update_kl(self, sampled_kl):
if sampled_kl > 2.0 * self.kl_target:
self.kl_coeff_val *= 1.5
elif sampled_kl < 0.5 * self.kl_target:
self.kl_coeff_val *= 0.5
def save(self):
filters = self.get_filters(flush_after=True)
return pickle.dumps({"filters": filters})
return pickle.dumps({
"filters": filters,
"kl_coeff_val": self.kl_coeff_val,
"kl_target": self.kl_target,
})
def restore(self, objs):
objs = pickle.loads(objs)
self.sync_filters(objs["filters"])
self.kl_coeff_val = objs["kl_coeff_val"]
self.kl_target = objs["kl_target"]
def get_weights(self):
return self.variables.get_weights()
+2 -2
View File
@@ -6,7 +6,7 @@ import ray
from ray.rllib.optimizers import SampleBatch
def collect_samples(agents, config, local_evaluator):
def collect_samples(agents, timesteps_per_batch):
num_timesteps_so_far = 0
trajectories = []
# This variable maps the object IDs of trajectories that are currently
@@ -19,7 +19,7 @@ def collect_samples(agents, config, local_evaluator):
fut_sample = agent.sample.remote()
agent_dict[fut_sample] = agent
while num_timesteps_so_far < config["timesteps_per_batch"]:
while num_timesteps_so_far < timesteps_per_batch:
# TODO(pcm): Make wait support arbitrary iterators and remove the
# conversion to list here.
[fut_sample], _ = ray.wait(list(agent_dict))
@@ -1,6 +1,6 @@
# On a Tesla K80 GPU, this achieves the maximum reward in about 1-1.5 hours.
#
# $ python train.py -f tuned_examples/pong-ppo.yaml --num-gpus=1
# $ python train.py -f tuned_examples/pong-ppo.yaml --ray-num-gpus=1
#
# - PPO_PongDeterministic-v4_0: TERMINATED [pid=16387], 4984 s, 1117981 ts, 21 rew
# - PPO_PongDeterministic-v4_0: TERMINATED [pid=83606], 4592 s, 1068671 ts, 21 rew
+13
View File
@@ -0,0 +1,13 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import random
import tensorflow as tf
def seed(np_seed=0, random_seed=0, tf_seed=0):
np.random.seed(np_seed)
random.seed(random_seed)
tf.set_random_seed(tf_seed)