mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 01:31:08 +08:00
e78562b2e8
* fix * update test * better error * compute * eps fix * add get_policy() api * Update agent.py * better err msg * fix * pass in rew
429 lines
16 KiB
Python
429 lines
16 KiB
Python
"""Implements the IMPALA architecture.

https://arxiv.org/abs/1802.01561"""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import logging
|
|
import numpy as np
|
|
import random
|
|
import time
|
|
import threading
|
|
|
|
from six.moves import queue
|
|
|
|
import ray
|
|
from ray.rllib.optimizers.multi_gpu_impl import LocalSyncParallelOptimizer
|
|
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
|
|
from ray.rllib.utils.actors import TaskPool
|
|
from ray.rllib.utils.annotations import override
|
|
from ray.rllib.utils.timer import TimerStat
|
|
from ray.rllib.utils.window_stat import WindowStat
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
LEARNER_QUEUE_MAX_SIZE = 16
|
|
NUM_DATA_LOAD_THREADS = 16
|
|
|
|
|
|
class AsyncSamplesOptimizer(PolicyOptimizer):
    """Main event loop of the IMPALA architecture.

    This class coordinates the data transfers between the learner thread
    and remote evaluators (IMPALA actors).
    """

    @override(PolicyOptimizer)
    def _init(self,
              train_batch_size=500,
              sample_batch_size=50,
              num_envs_per_worker=1,
              num_gpus=0,
              lr=0.0005,
              replay_buffer_num_slots=0,
              replay_proportion=0.0,
              num_data_loader_buffers=1,
              max_sample_requests_in_flight_per_worker=2,
              broadcast_interval=1,
              num_sgd_iter=1,
              minibatch_buffer_size=1,
              _fake_gpus=False):
        """Validate the config, start the learner and kick off sampling.

        All configuration checks run before the learner thread is started
        and before any remote sample task is launched, so an invalid
        configuration does not leave background work running.

        Arguments:
            train_batch_size: Number of timesteps to accumulate from sample
                batches before a train batch is queued for the learner.
            sample_batch_size: Used with train_batch_size to size-check and
                gate the replay buffer.
            num_envs_per_worker: Accepted but not used by this method.
            num_gpus: More than one GPU selects TFMultiGPULearner.
            lr: Learning rate; only forwarded to TFMultiGPULearner.
            replay_buffer_num_slots: Max number of sample batches retained
                for replay; values <= 0 disable storing batches.
            replay_proportion: Expected number of replayed batches emitted
                after each freshly sampled batch (once enough batches are
                stored).
            num_data_loader_buffers: More than one buffer selects
                TFMultiGPULearner; forwarded to it.
            max_sample_requests_in_flight_per_worker: Number of sample()
                requests launched up front on each remote evaluator.
            broadcast_interval: Minimum number of weight sends that reuse
                the same weights object before new weights are fetched
                from the local evaluator.
            num_sgd_iter: Forwarded to the learner.
            minibatch_buffer_size: Forwarded to the learner.
            _fake_gpus: Forwarded to TFMultiGPULearner.

        Raises:
            ValueError: If replay is enabled but the replay buffer cannot
                hold more than one train batch, or if multi-GPU mode has
                fewer data loader buffers than minibatch buffers.
        """
        # Fail fast: nothing has been started yet at this point.
        if replay_proportion:
            if replay_buffer_num_slots * sample_batch_size <= train_batch_size:
                raise ValueError(
                    "Replay buffer size is too small to produce train, "
                    "please increase replay_buffer_num_slots: "
                    "replay_buffer_num_slots={}, sample_batch_size={}, "
                    "train_batch_size={}".format(replay_buffer_num_slots,
                                                 sample_batch_size,
                                                 train_batch_size))
        assert len(self.remote_evaluators) > 0, \
            "at least one remote evaluator is required"

        self.learning_started = False
        self.train_batch_size = train_batch_size
        self.sample_batch_size = sample_batch_size
        self.broadcast_interval = broadcast_interval

        if num_gpus > 1 or num_data_loader_buffers > 1:
            logger.info(
                "Enabling multi-GPU mode, {} GPUs, {} parallel loaders".format(
                    num_gpus, num_data_loader_buffers))
            if num_data_loader_buffers < minibatch_buffer_size:
                raise ValueError(
                    "In multi-gpu mode you must have at least as many "
                    "parallel data loader buffers as minibatch buffers: "
                    "{} vs {}".format(num_data_loader_buffers,
                                      minibatch_buffer_size))
            self.learner = TFMultiGPULearner(
                self.local_evaluator,
                lr=lr,
                num_gpus=num_gpus,
                train_batch_size=train_batch_size,
                num_data_loader_buffers=num_data_loader_buffers,
                minibatch_buffer_size=minibatch_buffer_size,
                num_sgd_iter=num_sgd_iter,
                _fake_gpus=_fake_gpus)
        else:
            self.learner = LearnerThread(self.local_evaluator,
                                         minibatch_buffer_size, num_sgd_iter)
        self.learner.start()

        # Stats
        self.timers = {k: TimerStat() for k in ["train", "sample"]}
        self.num_weight_syncs = 0
        self.num_replayed = 0

        # Kick off async background sampling
        self.sample_tasks = TaskPool()
        weights = self.local_evaluator.get_weights()
        for ev in self.remote_evaluators:
            ev.set_weights.remote(weights)
            for _ in range(max_sample_requests_in_flight_per_worker):
                self.sample_tasks.add(ev, ev.sample.remote())

        # Sample batches waiting to be concatenated into a train batch.
        self.batch_buffer = []

        self.replay_proportion = replay_proportion
        self.replay_buffer_num_slots = replay_buffer_num_slots
        self.replay_batches = []

    @override(PolicyOptimizer)
    def step(self):
        """Run one iteration of the event loop and update timers/counters.

        The same wall-clock delta is pushed to both the "sample" and the
        "train" timer; the "train" timer is only updated once the learner
        has reported at least one trained batch.
        """
        assert self.learner.is_alive()
        start = time.time()
        sample_timesteps, train_timesteps = self._step()
        time_delta = time.time() - start
        self.timers["sample"].push(time_delta)
        self.timers["sample"].push_units_processed(sample_timesteps)
        if train_timesteps > 0:
            self.learning_started = True
        if self.learning_started:
            self.timers["train"].push(time_delta)
            self.timers["train"].push_units_processed(train_timesteps)
        self.num_steps_sampled += sample_timesteps
        self.num_steps_trained += train_timesteps

    @override(PolicyOptimizer)
    def stop(self):
        """Ask the learner thread to exit its run loop."""
        self.learner.stopped = True

    @override(PolicyOptimizer)
    def stats(self):
        """Return the base optimizer stats merged with IMPALA statistics.

        Adds throughput numbers, a per-stage timing breakdown in
        milliseconds, weight-sync/replay counters, learner queue
        statistics and, when non-empty, the learner's latest stats.
        """
        timing = {
            "{}_time_ms".format(k): round(1000 * self.timers[k].mean, 3)
            for k in self.timers
        }
        timing["learner_grad_time_ms"] = round(
            1000 * self.learner.grad_timer.mean, 3)
        timing["learner_load_time_ms"] = round(
            1000 * self.learner.load_timer.mean, 3)
        timing["learner_load_wait_time_ms"] = round(
            1000 * self.learner.load_wait_timer.mean, 3)
        timing["learner_dequeue_time_ms"] = round(
            1000 * self.learner.queue_timer.mean, 3)
        stats = {
            "sample_throughput": round(self.timers["sample"].mean_throughput,
                                       3),
            "train_throughput": round(self.timers["train"].mean_throughput, 3),
            "num_weight_syncs": self.num_weight_syncs,
            "num_steps_replayed": self.num_replayed,
            "timing_breakdown": timing,
            "learner_queue": self.learner.learner_queue_size.stats(),
        }
        if self.learner.stats:
            stats["learner"] = self.learner.stats
        return dict(PolicyOptimizer.stats(self), **stats)

    def _step(self):
        """Drain completed sample tasks and collect learner results.

        Returns:
            sample_timesteps: Timesteps in the freshly sampled (not
                replayed) batches handled by this call.
            train_timesteps: Sum of the counts the learner put on its
                outqueue since the last call.
        """
        sample_timesteps, train_timesteps = 0, 0
        num_sent = 0
        weights = None

        for ev, sample_batch in self._augment_with_replay(
                self.sample_tasks.completed_prefetch()):
            self.batch_buffer.append(sample_batch)
            if sum(b.count
                   for b in self.batch_buffer) >= self.train_batch_size:
                train_batch = self.batch_buffer[0].concat_samples(
                    self.batch_buffer)
                self.learner.inqueue.put(train_batch)
                self.batch_buffer = []

            # If the batch was replayed, skip the update below.
            if ev is None:
                continue

            sample_timesteps += sample_batch.count

            # Put in replay buffer if enabled
            if self.replay_buffer_num_slots > 0:
                self.replay_batches.append(sample_batch)
                if len(self.replay_batches) > self.replay_buffer_num_slots:
                    # Evict the oldest stored batch.
                    self.replay_batches.pop(0)

            # Note that it's important to pull new weights once
            # updated to avoid excessive correlation between actors
            if weights is None or (self.learner.weights_updated
                                   and num_sent >= self.broadcast_interval):
                self.learner.weights_updated = False
                weights = ray.put(self.local_evaluator.get_weights())
                num_sent = 0
            ev.set_weights.remote(weights)
            self.num_weight_syncs += 1
            num_sent += 1

            # Kick off another sample request
            self.sample_tasks.add(ev, ev.sample.remote())

        while not self.learner.outqueue.empty():
            count = self.learner.outqueue.get()
            train_timesteps += count

        return sample_timesteps, train_timesteps

    def _augment_with_replay(self, sample_futures):
        """Yield fetched sample batches, interleaved with replayed ones.

        Arguments:
            sample_futures: Iterable of (evaluator, batch future) pairs.

        Yields:
            (evaluator, batch) for each fetched future, followed by zero or
            more (None, batch) pairs drawn at random from the replay
            buffer. Replay only happens once the buffer holds more batches
            than are needed to fill one train batch.
        """

        def can_replay():
            num_needed = int(
                np.ceil(self.train_batch_size / self.sample_batch_size))
            return len(self.replay_batches) > num_needed

        for ev, sample_batch in sample_futures:
            sample_batch = ray.get(sample_batch)
            yield ev, sample_batch

            if can_replay():
                # Emits floor(f) replays for sure plus one more with
                # probability equal to the fractional part of f.
                f = self.replay_proportion
                while random.random() < f:
                    f -= 1
                    replay_batch = random.choice(self.replay_batches)
                    self.num_replayed += replay_batch.count
                    yield None, replay_batch
|
|
|
|
|
|
class LearnerThread(threading.Thread):
    """Daemon thread that trains the local model on queued sample batches.

    Communication with the main thread goes through queues: train batches
    come in via ``inqueue`` and the timestep count of each processed batch
    is reported via ``outqueue``. This is needed since Ray operations can
    only be run on the main thread; keeping the heavyweight gradient ops
    off the main thread also improves overall throughput.
    """

    def __init__(self, local_evaluator, minibatch_buffer_size, num_sgd_iter):
        """Set up queues, timers and the minibatch ring buffer.

        Arguments:
            local_evaluator: Object whose compute_apply() is called on
                every batch.
            minibatch_buffer_size: Size passed to MinibatchBuffer.
            num_sgd_iter: Number of passes passed to MinibatchBuffer.
        """
        threading.Thread.__init__(self)
        self.daemon = True
        self.stopped = False
        self.weights_updated = False
        self.stats = {}

        self.local_evaluator = local_evaluator
        self.learner_queue_size = WindowStat("size", 50)

        # Bounded input queue, unbounded output queue.
        self.inqueue = queue.Queue(maxsize=LEARNER_QUEUE_MAX_SIZE)
        self.outqueue = queue.Queue()
        self.minibatch_buffer = MinibatchBuffer(
            self.inqueue, minibatch_buffer_size, num_sgd_iter)

        self.queue_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.load_timer = TimerStat()
        self.load_wait_timer = TimerStat()

    def run(self):
        """Call step() repeatedly until ``stopped`` is set."""
        while True:
            if self.stopped:
                return
            self.step()

    def step(self):
        """Take one batch from the buffer and apply a gradient update."""
        with self.queue_timer:
            train_batch, _released = self.minibatch_buffer.get()

        with self.grad_timer:
            result = self.local_evaluator.compute_apply(train_batch)
            # Lets the main thread know fresh weights are available.
            self.weights_updated = True
            self.stats = result.get("stats", {})

        self.outqueue.put(train_batch.count)
        self.learner_queue_size.push(self.inqueue.qsize())
|
|
|
|
|
|
class TFMultiGPULearner(LearnerThread):
    """Learner that can use multiple GPUs and parallel loading."""

    def __init__(self,
                 local_evaluator,
                 num_gpus=1,
                 lr=0.0005,
                 train_batch_size=500,
                 num_data_loader_buffers=1,
                 minibatch_buffer_size=1,
                 num_sgd_iter=1,
                 _fake_gpus=False):
        """Build the per-buffer parallel optimizers and start loaders.

        Arguments:
            local_evaluator: Evaluator providing the TF session and the
                single "default" policy to optimize.
            num_gpus: Number of devices; a falsy value selects "/cpu:0".
            lr: Learning rate for the Adam optimizer.
            train_batch_size: Must be divisible by the number of devices;
                also the count reported on outqueue after each step.
            num_data_loader_buffers: Number of LocalSyncParallelOptimizer
                instances to create.
            minibatch_buffer_size: Size of the minibatch ring buffer.
            num_sgd_iter: Number of passes for the minibatch ring buffer.
            _fake_gpus: If set, use "/cpu:N" device names in place of
                "/gpu:N".

        Raises:
            NotImplementedError: If the evaluator's policy map is anything
                other than the single key "default".
        """
        # Multi-GPU requires TensorFlow to function.
        import tensorflow as tf

        LearnerThread.__init__(self, local_evaluator, minibatch_buffer_size,
                               num_sgd_iter)
        self.lr = lr
        self.train_batch_size = train_batch_size
        if not num_gpus:
            self.devices = ["/cpu:0"]
        elif _fake_gpus:
            self.devices = ["/cpu:{}".format(i) for i in range(num_gpus)]
        else:
            self.devices = ["/gpu:{}".format(i) for i in range(num_gpus)]
        logger.info("TFMultiGPULearner devices {}".format(self.devices))
        assert self.train_batch_size % len(self.devices) == 0
        assert self.train_batch_size >= len(self.devices), "batch too small"

        if set(self.local_evaluator.policy_map.keys()) != {"default"}:
            raise NotImplementedError("Multi-gpu mode for multi-agent")
        self.policy = self.local_evaluator.policy_map["default"]

        # per-GPU graph copies created below must share vars with the policy
        # reuse is set to AUTO_REUSE because Adam nodes are created after
        # all of the device copies are created.
        self.par_opt = []
        with self.local_evaluator.tf_sess.graph.as_default():
            with self.local_evaluator.tf_sess.as_default():
                with tf.variable_scope("default", reuse=tf.AUTO_REUSE):
                    if self.policy._state_inputs:
                        rnn_inputs = self.policy._state_inputs + [
                            self.policy._seq_lens
                        ]
                    else:
                        rnn_inputs = []
                    adam = tf.train.AdamOptimizer(self.lr)
                    for _ in range(num_data_loader_buffers):
                        self.par_opt.append(
                            LocalSyncParallelOptimizer(
                                adam,
                                self.devices,
                                [v for _, v in self.policy._loss_inputs],
                                rnn_inputs,
                                999999,  # it will get rounded down
                                self.policy.copy))

                self.sess = self.local_evaluator.tf_sess
                self.sess.run(tf.global_variables_initializer())

        # Optimizers cycle idle -> (loader thread) -> ready -> (step) -> idle.
        self.idle_optimizers = queue.Queue()
        self.ready_optimizers = queue.Queue()
        for opt in self.par_opt:
            self.idle_optimizers.put(opt)
        # NOTE: only the last loader thread started stays referenced here.
        for i in range(NUM_DATA_LOAD_THREADS):
            self.loader_thread = _LoaderThread(self, share_stats=(i == 0))
            self.loader_thread.start()

        # Replaces the buffer built by LearnerThread so that step() consumes
        # loaded optimizers instead of raw batches.
        self.minibatch_buffer = MinibatchBuffer(
            self.ready_optimizers, minibatch_buffer_size, num_sgd_iter)

    @override(LearnerThread)
    def step(self):
        """Run one optimization pass on the next loaded optimizer.

        The optimizer is handed back to the idle pool only after the
        optimize call has finished. Returning it earlier would let a
        loader thread pick it up and call load_data() on it while the
        gradient step is still using its loaded data.
        """
        assert self.loader_thread.is_alive()
        with self.load_wait_timer:
            opt, released = self.minibatch_buffer.get()

        with self.grad_timer:
            fetches = opt.optimize(self.sess, 0)
            self.weights_updated = True
            self.stats = fetches.get("stats", {})

        if released:
            # Safe to reuse now that optimize() is done with this buffer.
            self.idle_optimizers.put(opt)

        self.outqueue.put(self.train_batch_size)
        self.learner_queue_size.push(self.inqueue.qsize())
|
|
|
|
|
|
class _LoaderThread(threading.Thread):
    """Daemon thread that loads queued batches into idle optimizers."""

    def __init__(self, learner, share_stats):
        """Bind the thread to its learner.

        Arguments:
            learner: Object providing inqueue, idle_optimizers,
                ready_optimizers, policy and sess.
            share_stats: If true, record into the learner's own queue and
                load timers; otherwise use private timers.
        """
        threading.Thread.__init__(self)
        self.daemon = True
        self.learner = learner
        if not share_stats:
            self.queue_timer = TimerStat()
            self.load_timer = TimerStat()
        else:
            self.queue_timer = learner.queue_timer
            self.load_timer = learner.load_timer

    def run(self):
        """Load batches forever."""
        while True:
            self._step()

    def _step(self):
        """Move one batch from the learner's inqueue into an optimizer."""
        learner = self.learner
        with self.queue_timer:
            batch = learner.inqueue.get()

        # Blocks until an optimizer has been released by the learner.
        optimizer = learner.idle_optimizers.get()

        with self.load_timer:
            policy = learner.policy
            inputs = policy._get_loss_inputs_dict(batch)
            data_keys = [placeholder for _, placeholder in policy._loss_inputs]
            state_keys = (policy._state_inputs + [policy._seq_lens]
                          if policy._state_inputs else [])
            data_values = [inputs[key] for key in data_keys]
            state_values = [inputs[key] for key in state_keys]
            optimizer.load_data(learner.sess, data_values, state_values)

        learner.ready_optimizers.put(optimizer)
|
|
|
|
|
|
class MinibatchBuffer(object):
    """Ring buffer of recent data batches for minibatch SGD."""

    def __init__(self, inqueue, size, num_passes):
        """Initialize a minibatch buffer.

        Arguments:
            inqueue: Queue to populate the internal ring buffer from.
            size: Max number of data items to buffer.
            num_passes: Max num times each data item should be emitted.
        """
        self.inqueue = inqueue
        self.size = size
        self.max_ttl = num_passes
        # The per-item TTL starts at 1 and grows by one with each refill
        # until it reaches max_ttl, to better mix the input data early on.
        self.cur_max_ttl = 1
        self.buffers = [None for _ in range(size)]
        self.ttl = [0 for _ in range(size)]
        self.idx = 0

    def get(self):
        """Get a new batch from the internal ring buffer.

        The current slot is refilled from inqueue (blocking) when its
        remaining TTL is exhausted; afterwards the cursor advances to the
        next slot.

        Returns:
            buf: Data item saved from inqueue.
            released: True if the item is now removed from the ring buffer.
        """
        slot = self.idx

        if self.ttl[slot] <= 0:
            # Slot is empty or used up: pull a fresh item.
            self.buffers[slot] = self.inqueue.get()
            self.ttl[slot] = self.cur_max_ttl
            if self.max_ttl > self.cur_max_ttl:
                self.cur_max_ttl += 1

        item = self.buffers[slot]
        remaining = self.ttl[slot] - 1
        self.ttl[slot] = remaining

        released = remaining <= 0
        if released:
            self.buffers[slot] = None

        self.idx = (slot + 1) % len(self.buffers)
        return item, released
|