Files
ray/python/ray/rllib/optimizers/multi_gpu_optimizer.py
T
Eric Liang d01dc9e22d [rllib] format with yapf (#2427)
* initial yapf

* manual fix yapf bugs
2018-07-19 15:30:36 -07:00

162 lines
6.7 KiB
Python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from collections import defaultdict
import os
import tensorflow as tf
import ray
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
from ray.rllib.optimizers.multi_gpu_impl import LocalSyncParallelOptimizer
from ray.rllib.utils.timer import TimerStat
class LocalMultiGPUOptimizer(PolicyOptimizer):
    """A synchronous optimizer that uses multiple local GPUs.

    Samples are pulled synchronously from multiple remote evaluators,
    concatenated, and then split across the memory of multiple local GPUs.
    A number of SGD passes are then taken over the in-memory data. For more
    details, see `multi_gpu_impl.LocalSyncParallelOptimizer`.

    This optimizer is TensorFlow-specific and requires the underlying
    PolicyGraph to be a TFPolicyGraph instance that supports `.copy()`.

    Note that all replicas of the TFPolicyGraph will merge their
    extra_compute_grad and apply_grad feed_dicts and fetches. This
    may result in unexpected behavior.
    """

    def _init(self,
              sgd_batch_size=128,
              sgd_stepsize=5e-5,
              num_sgd_iter=10,
              timesteps_per_batch=1024,
              standardize_fields=None):
        """Set up devices, timers and the per-device graph copies.

        Args:
            sgd_batch_size (int): Total SGD minibatch size across all
                devices. Rounded down to a multiple of the device count.
            sgd_stepsize (float): Learning rate for the Adam optimizer.
            num_sgd_iter (int): Number of SGD passes over the loaded data
                per call to `step()`.
            timesteps_per_batch (int): Number of timesteps requested from
                the remote evaluators per `step()`.
            standardize_fields (list): Sample batch keys whose values are
                standardized before loading. Defaults to no fields.
        """
        self.batch_size = sgd_batch_size
        self.sgd_stepsize = sgd_stepsize
        self.num_sgd_iter = num_sgd_iter
        self.timesteps_per_batch = timesteps_per_batch
        gpu_ids = ray.get_gpu_ids()
        if not gpu_ids:
            self.devices = ["/cpu:0"]
        else:
            self.devices = ["/gpu:{}".format(i) for i in range(len(gpu_ids))]

        # Round the total batch size down so it splits evenly per device.
        num_devices = len(self.devices)
        self.batch_size = int(sgd_batch_size) // num_devices * num_devices
        assert self.batch_size % num_devices == 0
        assert self.batch_size >= num_devices, "batch size too small"
        self.per_device_batch_size = self.batch_size // num_devices

        self.sample_timer = TimerStat()
        self.load_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.update_weights_timer = TimerStat()
        # A None default avoids sharing one mutable list between instances.
        self.standardize_fields = ([] if standardize_fields is None else
                                   standardize_fields)

        print("LocalMultiGPUOptimizer devices", self.devices)

        assert set(self.local_evaluator.policy_map.keys()) == {"default"}, \
            ("Multi-agent is not supported with multi-GPU. Try using the "
             "simple optimizer instead.")
        self.policy = self.local_evaluator.policy_map["default"]
        assert isinstance(self.policy, TFPolicyGraph), \
            ("Only TF policies are supported with multi-GPU. Try using the "
             "simple optimizer instead.")

        # per-GPU graph copies created below must share vars with the policy
        # reuse is set to AUTO_REUSE because Adam nodes are created after
        # all of the device copies are created.
        with self.local_evaluator.tf_sess.graph.as_default():
            with self.local_evaluator.tf_sess.as_default():
                with tf.variable_scope("default", reuse=tf.AUTO_REUSE):
                    self.par_opt = LocalSyncParallelOptimizer(
                        tf.train.AdamOptimizer(self.sgd_stepsize),
                        self.devices,
                        [v for _, v in self.policy.loss_inputs()],
                        self._state_placeholders(),
                        self.per_device_batch_size,
                        self.policy.copy,
                        os.getcwd())

                # Initialize inside the evaluator's graph so the Adam slot
                # variables created above are included.
                self.sess = self.local_evaluator.tf_sess
                self.sess.run(tf.global_variables_initializer())

    def _state_placeholders(self):
        """Return the policy's RNN state inputs plus `_seq_lens`, or []."""
        if self.policy._state_inputs:
            return self.policy._state_inputs + [self.policy._seq_lens]
        return []

    def step(self):
        """Sync weights, collect samples, and run minibatch SGD over them.

        Returns:
            dict: Per-key mean of the fetches returned by the parallel
            optimizer over the minibatches of the final SGD pass. Empty if
            no SGD pass or minibatch was run.
        """
        with self.update_weights_timer:
            if self.remote_evaluators:
                weights = ray.put(self.local_evaluator.get_weights())
                for e in self.remote_evaluators:
                    e.set_weights.remote(weights)

        with self.sample_timer:
            if self.remote_evaluators:
                # TODO(rliaw): remove when refactoring
                from ray.rllib.agents.ppo.rollout import collect_samples
                samples = collect_samples(self.remote_evaluators,
                                          self.timesteps_per_batch)
            else:
                samples = self.local_evaluator.sample()
            self._check_not_multiagent(samples)

        for field in self.standardize_fields:
            value = samples[field]
            # The std is floored at 1e-4 to avoid dividing by ~0.
            standardized = (value - value.mean()) / max(1e-4, value.std())
            samples[field] = standardized

        # Row-level shuffling would break the contiguous sequences that the
        # RNN state inputs / seq_lens describe, so only shuffle stateless
        # policies.
        if not self.policy._state_inputs:
            samples.shuffle()

        with self.load_timer:
            tuples = self.policy._get_loss_inputs_dict(samples)
            data_keys = [ph for _, ph in self.policy.loss_inputs()]
            state_keys = self._state_placeholders()
            tuples_per_device = self.par_opt.load_data(
                self.sess, [tuples[k] for k in data_keys],
                [tuples[k] for k in state_keys])

        with self.grad_timer:
            num_batches = (
                int(tuples_per_device) // int(self.per_device_batch_size))
            print("== sgd epochs ==")
            # Bound before the loop so the return below is defined even when
            # num_sgd_iter is 0.
            iter_extra_fetches = defaultdict(list)
            for i in range(self.num_sgd_iter):
                iter_extra_fetches = defaultdict(list)
                for batch_index in np.random.permutation(num_batches):
                    batch_fetches = self.par_opt.optimize(
                        self.sess, batch_index * self.per_device_batch_size)
                    for k, v in batch_fetches.items():
                        iter_extra_fetches[k].append(v)
                print(i, _averaged(iter_extra_fetches))

        self.num_steps_sampled += samples.count
        self.num_steps_trained += samples.count
        return _averaged(iter_extra_fetches)

    def stats(self):
        """Return base optimizer stats plus mean phase timings in ms."""
        return dict(
            PolicyOptimizer.stats(self), **{
                "sample_time_ms": round(1000 * self.sample_timer.mean, 3),
                "load_time_ms": round(1000 * self.load_timer.mean, 3),
                "grad_time_ms": round(1000 * self.grad_timer.mean, 3),
                "update_time_ms": round(1000 * self.update_weights_timer.mean,
                                        3),
            })
def _averaged(kv):
out = {}
for k, v in kv.items():
if v[0] is not None:
out[k] = np.mean(v)
return out