Files
ray/python/ray/rllib/optimizers/async_samples_optimizer.py
T
Eric Liang e78562b2e8 [rllib] Misc fixes: set lr for PG, better error message for LSTM/PPO, fix multi-agent/APEX (#3697)
* fix

* update test

* better error

* compute

* eps fix

* add get_policy() api

* Update agent.py

* better err msg

* fix

* pass in rew
2019-01-06 19:37:35 -08:00

429 lines
16 KiB
Python

"""Implements the IMPALA architecture.
https://arxiv.org/abs/1802.01561"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
import numpy as np
import random
import time
import threading
from six.moves import queue
import ray
from ray.rllib.optimizers.multi_gpu_impl import LocalSyncParallelOptimizer
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
from ray.rllib.utils.actors import TaskPool
from ray.rllib.utils.annotations import override
from ray.rllib.utils.timer import TimerStat
from ray.rllib.utils.window_stat import WindowStat
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Capacity of LearnerThread.inqueue, the queue of train batches waiting to be
# consumed by the learner.
LEARNER_QUEUE_MAX_SIZE = 16
# Number of _LoaderThread instances that TFMultiGPULearner starts.
NUM_DATA_LOAD_THREADS = 16
class AsyncSamplesOptimizer(PolicyOptimizer):
    """Main event loop of the IMPALA architecture.

    This class coordinates the data transfers between the learner thread
    and remote evaluators (IMPALA actors).
    """

    @override(PolicyOptimizer)
    def _init(self,
              train_batch_size=500,
              sample_batch_size=50,
              num_envs_per_worker=1,
              num_gpus=0,
              lr=0.0005,
              replay_buffer_num_slots=0,
              replay_proportion=0.0,
              num_data_loader_buffers=1,
              max_sample_requests_in_flight_per_worker=2,
              broadcast_interval=1,
              num_sgd_iter=1,
              minibatch_buffer_size=1,
              _fake_gpus=False):
        """Start the learner thread and kick off background sampling.

        Arguments:
            train_batch_size: Number of timesteps that must accumulate in
                the batch buffer before a train batch is queued for the
                learner.
            sample_batch_size: Used together with train_batch_size to check
                the replay buffer size and to decide when replay may start.
            num_envs_per_worker: Accepted but not read by this class.
            num_gpus: Number of GPUs; more than one selects the
                TFMultiGPULearner.
            lr: Learning rate forwarded to TFMultiGPULearner; not used on
                the LearnerThread path.
            replay_buffer_num_slots: Max number of sample batches retained
                for replay (0 disables retention).
            replay_proportion: After each fresh batch, floor(p) batches are
                replayed plus one more with probability frac(p).
            num_data_loader_buffers: Forwarded to TFMultiGPULearner; more
                than one selects the TFMultiGPULearner.
            max_sample_requests_in_flight_per_worker: Number of sample
                requests queued on each remote evaluator at startup.
            broadcast_interval: Minimum number of weight sends before a
                fresh copy of the local weights is put into the object
                store.
            num_sgd_iter: Forwarded to the learner's minibatch buffer as
                the max number of passes per item.
            minibatch_buffer_size: Forwarded to the learner's minibatch
                buffer as its size.
            _fake_gpus: Forwarded to TFMultiGPULearner.

        Raises:
            ValueError: If the multi-GPU mode has fewer loader buffers than
                minibatch buffers, or if replay is enabled with a replay
                buffer too small to produce a train batch.
        """
        self.learning_started = False
        self.train_batch_size = train_batch_size
        self.sample_batch_size = sample_batch_size
        self.broadcast_interval = broadcast_interval

        if num_gpus > 1 or num_data_loader_buffers > 1:
            logger.info(
                "Enabling multi-GPU mode, {} GPUs, {} parallel loaders".format(
                    num_gpus, num_data_loader_buffers))
            if num_data_loader_buffers < minibatch_buffer_size:
                raise ValueError(
                    "In multi-gpu mode you must have at least as many "
                    "parallel data loader buffers as minibatch buffers: "
                    "{} vs {}".format(num_data_loader_buffers,
                                      minibatch_buffer_size))
            self.learner = TFMultiGPULearner(
                self.local_evaluator,
                lr=lr,
                num_gpus=num_gpus,
                train_batch_size=train_batch_size,
                num_data_loader_buffers=num_data_loader_buffers,
                minibatch_buffer_size=minibatch_buffer_size,
                num_sgd_iter=num_sgd_iter,
                _fake_gpus=_fake_gpus)
        else:
            self.learner = LearnerThread(self.local_evaluator,
                                         minibatch_buffer_size, num_sgd_iter)
        self.learner.start()

        assert len(self.remote_evaluators) > 0

        # Stats
        self.timers = {k: TimerStat() for k in ["train", "sample"]}
        self.num_weight_syncs = 0
        self.num_replayed = 0

        # Kick off async background sampling
        self.sample_tasks = TaskPool()
        weights = self.local_evaluator.get_weights()
        for ev in self.remote_evaluators:
            ev.set_weights.remote(weights)
            for _ in range(max_sample_requests_in_flight_per_worker):
                self.sample_tasks.add(ev, ev.sample.remote())

        self.batch_buffer = []

        if replay_proportion:
            if replay_buffer_num_slots * sample_batch_size <= train_batch_size:
                # Format the sizes into the message itself; passing them as
                # extra exception args renders as an unreadable tuple.
                raise ValueError(
                    "Replay buffer size is too small to produce train, "
                    "please increase replay_buffer_num_slots "
                    "(replay_buffer_num_slots={}, sample_batch_size={}, "
                    "train_batch_size={}).".format(replay_buffer_num_slots,
                                                   sample_batch_size,
                                                   train_batch_size))
        self.replay_proportion = replay_proportion
        self.replay_buffer_num_slots = replay_buffer_num_slots
        self.replay_batches = []

    @override(PolicyOptimizer)
    def step(self):
        """Run one pass of the event loop and update timers and counters."""
        assert self.learner.is_alive()
        start = time.time()
        sample_timesteps, train_timesteps = self._step()
        time_delta = time.time() - start
        self.timers["sample"].push(time_delta)
        self.timers["sample"].push_units_processed(sample_timesteps)
        if train_timesteps > 0:
            self.learning_started = True
        # The train timer is only fed once the learner has reported its
        # first trained batch.
        if self.learning_started:
            self.timers["train"].push(time_delta)
            self.timers["train"].push_units_processed(train_timesteps)
        self.num_steps_sampled += sample_timesteps
        self.num_steps_trained += train_timesteps

    @override(PolicyOptimizer)
    def stop(self):
        """Ask the learner thread to exit its run loop."""
        self.learner.stopped = True

    @override(PolicyOptimizer)
    def stats(self):
        """Return the base optimizer stats merged with IMPALA-specific ones.

        Returns:
            dict: PolicyOptimizer.stats() extended with throughput, weight
                sync and replay counters, a timing breakdown in
                milliseconds, learner queue stats and, when non-empty, the
                learner's own stats under "learner".
        """
        timing = {
            "{}_time_ms".format(k): round(1000 * self.timers[k].mean, 3)
            for k in self.timers
        }
        timing["learner_grad_time_ms"] = round(
            1000 * self.learner.grad_timer.mean, 3)
        timing["learner_load_time_ms"] = round(
            1000 * self.learner.load_timer.mean, 3)
        timing["learner_load_wait_time_ms"] = round(
            1000 * self.learner.load_wait_timer.mean, 3)
        timing["learner_dequeue_time_ms"] = round(
            1000 * self.learner.queue_timer.mean, 3)
        stats = {
            "sample_throughput": round(self.timers["sample"].mean_throughput,
                                       3),
            "train_throughput": round(self.timers["train"].mean_throughput, 3),
            "num_weight_syncs": self.num_weight_syncs,
            "num_steps_replayed": self.num_replayed,
            "timing_breakdown": timing,
            "learner_queue": self.learner.learner_queue_size.stats(),
        }
        if self.learner.stats:
            stats["learner"] = self.learner.stats
        return dict(PolicyOptimizer.stats(self), **stats)

    def _step(self):
        """Drain finished sample tasks, feed the learner and resync weights.

        Returns:
            tuple: (sample_timesteps, train_timesteps), the number of fresh
                timesteps collected in this call and the number of
                timesteps the learner reported as trained.
        """
        sample_timesteps, train_timesteps = 0, 0
        num_sent = 0
        weights = None

        for ev, sample_batch in self._augment_with_replay(
                self.sample_tasks.completed_prefetch()):
            self.batch_buffer.append(sample_batch)
            if sum(b.count
                   for b in self.batch_buffer) >= self.train_batch_size:
                train_batch = self.batch_buffer[0].concat_samples(
                    self.batch_buffer)
                self.learner.inqueue.put(train_batch)
                self.batch_buffer = []

            # If the batch was replayed, skip the update below.
            if ev is None:
                continue

            sample_timesteps += sample_batch.count

            # Put in replay buffer if enabled
            if self.replay_buffer_num_slots > 0:
                self.replay_batches.append(sample_batch)
                if len(self.replay_batches) > self.replay_buffer_num_slots:
                    self.replay_batches.pop(0)

            # Note that it's important to pull new weights once
            # updated to avoid excessive correlation between actors
            if weights is None or (self.learner.weights_updated
                                   and num_sent >= self.broadcast_interval):
                self.learner.weights_updated = False
                weights = ray.put(self.local_evaluator.get_weights())
                num_sent = 0
            ev.set_weights.remote(weights)
            self.num_weight_syncs += 1
            num_sent += 1

            # Kick off another sample request
            self.sample_tasks.add(ev, ev.sample.remote())

        # Collect the timestep counts the learner has reported so far.
        while not self.learner.outqueue.empty():
            count = self.learner.outqueue.get()
            train_timesteps += count

        return sample_timesteps, train_timesteps

    def _augment_with_replay(self, sample_futures):
        """Yield fetched sample batches, interleaved with replayed ones.

        Arguments:
            sample_futures: Iterable of (evaluator, object id) pairs.

        Yields:
            tuple: (evaluator, batch) for fresh batches, and (None, batch)
                for batches drawn at random from the replay buffer.
        """

        def can_replay():
            num_needed = int(
                np.ceil(self.train_batch_size / self.sample_batch_size))
            return len(self.replay_batches) > num_needed

        for ev, sample_batch in sample_futures:
            sample_batch = ray.get(sample_batch)
            yield ev, sample_batch

            if can_replay():
                # Replays floor(f) batches for sure and one more with
                # probability equal to the fractional part of f.
                f = self.replay_proportion
                while random.random() < f:
                    f -= 1
                    replay_batch = random.choice(self.replay_batches)
                    self.num_replayed += replay_batch.count
                    yield None, replay_batch
class LearnerThread(threading.Thread):
    """Daemon thread that trains the local model on queued sample batches.

    Batches arrive through ``inqueue`` and the number of timesteps trained
    on is reported back through ``outqueue``. This is needed since Ray
    operations can only be run on the main thread. In addition, moving
    heavyweight gradient ops session runs off the main thread improves
    overall throughput.
    """

    def __init__(self, local_evaluator, minibatch_buffer_size, num_sgd_iter):
        """Set up the queues, timers and minibatch buffer.

        Arguments:
            local_evaluator: Object whose compute_apply() is called on each
                batch.
            minibatch_buffer_size: Size of the internal MinibatchBuffer.
            num_sgd_iter: Max number of passes the MinibatchBuffer makes
                over each batch.
        """
        threading.Thread.__init__(self)
        self.daemon = True
        self.stopped = False
        self.weights_updated = False
        self.stats = {}
        self.local_evaluator = local_evaluator
        self.learner_queue_size = WindowStat("size", 50)
        for timer_name in ("queue_timer", "grad_timer", "load_timer",
                           "load_wait_timer"):
            setattr(self, timer_name, TimerStat())
        self.inqueue = queue.Queue(maxsize=LEARNER_QUEUE_MAX_SIZE)
        self.outqueue = queue.Queue()
        self.minibatch_buffer = MinibatchBuffer(
            self.inqueue, minibatch_buffer_size, num_sgd_iter)

    def run(self):
        """Call step() repeatedly until ``stopped`` is set."""
        while True:
            if self.stopped:
                break
            self.step()

    def step(self):
        """Train on the next buffered batch and report its timestep count."""
        with self.queue_timer:
            batch = self.minibatch_buffer.get()[0]

        with self.grad_timer:
            result = self.local_evaluator.compute_apply(batch)
            # Lets the main thread know there are new weights to broadcast.
            self.weights_updated = True
            self.stats = result.get("stats", {})

        self.outqueue.put(batch.count)
        pending = self.inqueue.qsize()
        self.learner_queue_size.push(pending)
class TFMultiGPULearner(LearnerThread):
    """Learner that can use multiple GPUs and parallel loading."""

    def __init__(self,
                 local_evaluator,
                 num_gpus=1,
                 lr=0.0005,
                 train_batch_size=500,
                 num_data_loader_buffers=1,
                 minibatch_buffer_size=1,
                 num_sgd_iter=1,
                 _fake_gpus=False):
        """Build the per-device optimizers and start the loader threads.

        Arguments:
            local_evaluator: Evaluator providing the TF session and the
                "default" policy.
            num_gpus: Number of devices; 0 selects a single CPU device.
            lr: Learning rate given to the Adam optimizer.
            train_batch_size: Train batch size; must be divisible by, and
                at least as large as, the number of devices.
            num_data_loader_buffers: Number of parallel optimizers that
                data can be loaded into.
            minibatch_buffer_size: Size of the MinibatchBuffer of loaded
                optimizers.
            num_sgd_iter: Max number of passes over each loaded optimizer.
            _fake_gpus: If set, use CPU devices in place of GPU devices.

        Raises:
            NotImplementedError: If the evaluator's policy map has any key
                other than "default".
        """
        # Multi-GPU requires TensorFlow to function.
        import tensorflow as tf

        LearnerThread.__init__(self, local_evaluator, minibatch_buffer_size,
                               num_sgd_iter)
        self.lr = lr
        self.train_batch_size = train_batch_size
        if not num_gpus:
            self.devices = ["/cpu:0"]
        elif _fake_gpus:
            self.devices = ["/cpu:{}".format(i) for i in range(num_gpus)]
        else:
            self.devices = ["/gpu:{}".format(i) for i in range(num_gpus)]
        logger.info("TFMultiGPULearner devices {}".format(self.devices))
        assert self.train_batch_size % len(self.devices) == 0
        assert self.train_batch_size >= len(self.devices), "batch too small"

        if set(self.local_evaluator.policy_map.keys()) != {"default"}:
            raise NotImplementedError("Multi-gpu mode for multi-agent")
        self.policy = self.local_evaluator.policy_map["default"]

        # The per-GPU graph copies created below must share variables with
        # the policy. Reuse is set to AUTO_REUSE because the Adam nodes are
        # created after all of the device copies are created.
        self.par_opt = []
        with self.local_evaluator.tf_sess.graph.as_default():
            with self.local_evaluator.tf_sess.as_default():
                with tf.variable_scope("default", reuse=tf.AUTO_REUSE):
                    if self.policy._state_inputs:
                        rnn_inputs = self.policy._state_inputs + [
                            self.policy._seq_lens
                        ]
                    else:
                        rnn_inputs = []
                    adam = tf.train.AdamOptimizer(self.lr)
                    for _ in range(num_data_loader_buffers):
                        self.par_opt.append(
                            LocalSyncParallelOptimizer(
                                adam,
                                self.devices,
                                [v for _, v in self.policy._loss_inputs],
                                rnn_inputs,
                                999999,  # it will get rounded down
                                self.policy.copy))

                self.sess = self.local_evaluator.tf_sess
                self.sess.run(tf.global_variables_initializer())

        self.idle_optimizers = queue.Queue()
        self.ready_optimizers = queue.Queue()
        for opt in self.par_opt:
            self.idle_optimizers.put(opt)

        # Keep a handle on every loader so step() can check all of them.
        # `loader_thread` (the last one started) is kept for compatibility.
        self.loader_threads = []
        for i in range(NUM_DATA_LOAD_THREADS):
            self.loader_thread = _LoaderThread(self, share_stats=(i == 0))
            self.loader_threads.append(self.loader_thread)
            self.loader_thread.start()

        # Replaces the buffer built by LearnerThread.__init__: this learner
        # consumes loaded optimizers rather than raw batches.
        self.minibatch_buffer = MinibatchBuffer(
            self.ready_optimizers, minibatch_buffer_size, num_sgd_iter)

    @override(LearnerThread)
    def step(self):
        """Run one optimization pass on the next loaded optimizer."""
        assert all(t.is_alive() for t in self.loader_threads), \
            "a data loader thread has died"
        with self.load_wait_timer:
            opt, released = self.minibatch_buffer.get()

        with self.grad_timer:
            fetches = opt.optimize(self.sess, 0)
            self.weights_updated = True
            self.stats = fetches.get("stats", {})

        # Hand the optimizer back only after optimize() has finished, so a
        # loader thread cannot load new data into it during this step.
        if released:
            self.idle_optimizers.put(opt)

        self.outqueue.put(self.train_batch_size)
        self.learner_queue_size.push(self.inqueue.qsize())
class _LoaderThread(threading.Thread):
    """Daemon thread that loads queued train batches into idle optimizers.

    Each iteration takes a batch from the learner's ``inqueue`` and an
    optimizer from ``idle_optimizers``, loads the batch into the optimizer
    and puts the optimizer on ``ready_optimizers``.
    """

    def __init__(self, learner, share_stats):
        """Attach to a learner.

        Arguments:
            learner: The learner whose queues, policy and session are used.
            share_stats: If true, record into the learner's own queue and
                load timers; otherwise use private timers.
        """
        threading.Thread.__init__(self)
        self.daemon = True
        self.learner = learner
        self.queue_timer = learner.queue_timer if share_stats else TimerStat()
        self.load_timer = learner.load_timer if share_stats else TimerStat()

    def run(self):
        """Load batches forever."""
        while True:
            self._step()

    def _step(self):
        """Load one batch into one idle optimizer and mark it ready."""
        owner = self.learner
        with self.queue_timer:
            batch = owner.inqueue.get()

        opt = owner.idle_optimizers.get()

        with self.load_timer:
            policy = owner.policy
            feed = policy._get_loss_inputs_dict(batch)
            data_values = [feed[ph] for _, ph in policy._loss_inputs]
            # State inputs and sequence lengths are only loaded when the
            # policy declares state inputs.
            state_keys = (policy._state_inputs + [policy._seq_lens]
                          if policy._state_inputs else [])
            state_values = [feed[key] for key in state_keys]
            opt.load_data(owner.sess, data_values, state_values)

        owner.ready_optimizers.put(opt)
class MinibatchBuffer(object):
    """Ring buffer of recent data batches for minibatch SGD."""

    def __init__(self, inqueue, size, num_passes):
        """Initialize a minibatch buffer.

        Arguments:
            inqueue: Queue to populate the internal ring buffer from.
            size: Max number of data items to buffer. Must be at least 1.
            num_passes: Max num times each data item should be emitted.
                Values below 1 behave like 1.

        Raises:
            ValueError: If size is less than 1.
        """
        # An empty ring cannot hold anything; fail here with a clear message
        # instead of with an IndexError on the first get().
        if size < 1:
            raise ValueError(
                "MinibatchBuffer size must be at least 1, got {}".format(
                    size))
        self.inqueue = inqueue
        self.size = size
        self.max_ttl = num_passes
        self.cur_max_ttl = 1  # ramp up slowly to better mix the input data
        self.buffers = [None] * size
        self.ttl = [0] * size
        self.idx = 0

    def get(self):
        """Get a new batch from the internal ring buffer.

        Blocks on the input queue when the current slot needs refilling.

        Returns:
            buf: Data item saved from inqueue.
            released: True if the item is now removed from the ring buffer.
        """
        slot = self.idx
        if self.ttl[slot] <= 0:
            # Slot is exhausted: refill it. Each newly fetched item gets one
            # more pass than the previous one until max_ttl is reached.
            self.buffers[slot] = self.inqueue.get()
            self.ttl[slot] = self.cur_max_ttl
            if self.cur_max_ttl < self.max_ttl:
                self.cur_max_ttl += 1
        buf = self.buffers[slot]
        self.ttl[slot] -= 1
        released = self.ttl[slot] <= 0
        if released:
            # Drop the reference so the buffer no longer holds the item.
            self.buffers[slot] = None
        self.idx = (slot + 1) % len(self.buffers)
        return buf, released