Files
ray/python/ray/rllib/optimizers/async_samples_optimizer.py
T
Eric Liang 24649726dc [rllib] Use batch.count in async samples optimizer (#2488)
Using the actual batch size reduces the risk of mis-accounting. Here, we under-counted samples since in truncate_episodes mode we were doubling the batch size by accident in policy_evaluator.
2018-07-27 16:44:21 -07:00

296 lines
11 KiB
Python

"""Implements Distributed Prioritized Experience Replay.
https://arxiv.org/abs/1803.00933"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import random
import time
import threading
import numpy as np
from six.moves import queue
import ray
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
from ray.rllib.optimizers.replay_buffer import PrioritizedReplayBuffer
from ray.rllib.evaluation.sample_batch import SampleBatch
from ray.rllib.utils.actors import TaskPool, create_colocated
from ray.rllib.utils.timer import TimerStat
from ray.rllib.utils.window_stat import WindowStat
SAMPLE_QUEUE_DEPTH = 2
REPLAY_QUEUE_DEPTH = 4
LEARNER_QUEUE_MAX_SIZE = 16
@ray.remote
class ReplayActor(object):
"""A replay buffer shard.
Ray actors are single-threaded, so for scalability multiple replay actors
may be created to increase parallelism."""
def __init__(self, num_shards, learning_starts, buffer_size,
train_batch_size, prioritized_replay_alpha,
prioritized_replay_beta, prioritized_replay_eps,
clip_rewards):
self.replay_starts = learning_starts // num_shards
self.buffer_size = buffer_size // num_shards
self.train_batch_size = train_batch_size
self.prioritized_replay_beta = prioritized_replay_beta
self.prioritized_replay_eps = prioritized_replay_eps
self.replay_buffer = PrioritizedReplayBuffer(
self.buffer_size,
alpha=prioritized_replay_alpha,
clip_rewards=clip_rewards)
# Metrics
self.add_batch_timer = TimerStat()
self.replay_timer = TimerStat()
self.update_priorities_timer = TimerStat()
def get_host(self):
return os.uname()[1]
def add_batch(self, batch):
PolicyOptimizer._check_not_multiagent(batch)
with self.add_batch_timer:
for row in batch.rows():
self.replay_buffer.add(row["obs"], row["actions"],
row["rewards"], row["new_obs"],
row["dones"], row["weights"])
def replay(self):
with self.replay_timer:
if len(self.replay_buffer) < self.replay_starts:
return None
(obses_t, actions, rewards, obses_tp1, dones, weights,
batch_indexes) = self.replay_buffer.sample(
self.train_batch_size, beta=self.prioritized_replay_beta)
batch = SampleBatch({
"obs": obses_t,
"actions": actions,
"rewards": rewards,
"new_obs": obses_tp1,
"dones": dones,
"weights": weights,
"batch_indexes": batch_indexes
})
return batch
def update_priorities(self, batch_indexes, td_errors):
with self.update_priorities_timer:
new_priorities = (np.abs(td_errors) + self.prioritized_replay_eps)
self.replay_buffer.update_priorities(batch_indexes, new_priorities)
def stats(self):
stat = {
"add_batch_time_ms": round(1000 * self.add_batch_timer.mean, 3),
"replay_time_ms": round(1000 * self.replay_timer.mean, 3),
"update_priorities_time_ms": round(
1000 * self.update_priorities_timer.mean, 3),
}
stat.update(self.replay_buffer.stats())
return stat
class LearnerThread(threading.Thread):
"""Background thread that updates the local model from replay data.
The learner thread communicates with the main thread through Queues. This
is needed since Ray operations can only be run on the main thread. In
addition, moving heavyweight gradient ops session runs off the main thread
improves overall throughput.
"""
def __init__(self, local_evaluator):
threading.Thread.__init__(self)
self.learner_queue_size = WindowStat("size", 50)
self.local_evaluator = local_evaluator
self.inqueue = queue.Queue(maxsize=LEARNER_QUEUE_MAX_SIZE)
self.outqueue = queue.Queue()
self.queue_timer = TimerStat()
self.grad_timer = TimerStat()
self.daemon = True
self.weights_updated = False
def run(self):
while True:
self.step()
def step(self):
with self.queue_timer:
ra, replay = self.inqueue.get()
if replay is not None:
with self.grad_timer:
td_error = self.local_evaluator.compute_apply(replay)[
"td_error"]
self.outqueue.put((ra, replay, td_error, replay.count))
self.learner_queue_size.push(self.inqueue.qsize())
self.weights_updated = True
class AsyncSamplesOptimizer(PolicyOptimizer):
"""Main event loop of the Ape-X optimizer (async sampling with replay).
This class coordinates the data transfers between the learner thread,
remote evaluators (Ape-X actors), and replay buffer actors.
This optimizer requires that policy evaluators return an additional
"td_error" array in the info return of compute_gradients(). This error
term will be used for sample prioritization."""
def _init(self,
learning_starts=1000,
buffer_size=10000,
prioritized_replay=True,
prioritized_replay_alpha=0.6,
prioritized_replay_beta=0.4,
prioritized_replay_eps=1e-6,
train_batch_size=512,
sample_batch_size=50,
num_replay_buffer_shards=1,
max_weight_sync_delay=400,
clip_rewards=True,
debug=False):
self.debug = debug
self.replay_starts = learning_starts
self.prioritized_replay_beta = prioritized_replay_beta
self.prioritized_replay_eps = prioritized_replay_eps
self.max_weight_sync_delay = max_weight_sync_delay
self.learner = LearnerThread(self.local_evaluator)
self.learner.start()
self.replay_actors = create_colocated(ReplayActor, [
num_replay_buffer_shards, learning_starts, buffer_size,
train_batch_size, prioritized_replay_alpha,
prioritized_replay_beta, prioritized_replay_eps, clip_rewards
], num_replay_buffer_shards)
assert len(self.remote_evaluators) > 0
# Stats
self.timers = {
k: TimerStat()
for k in [
"put_weights", "get_samples", "enqueue", "sample_processing",
"replay_processing", "update_priorities", "train", "sample"
]
}
self.num_weight_syncs = 0
self.learning_started = False
# Number of worker steps since the last weight update
self.steps_since_update = {}
# Otherwise kick of replay tasks for local gradient updates
self.replay_tasks = TaskPool()
for ra in self.replay_actors:
for _ in range(REPLAY_QUEUE_DEPTH):
self.replay_tasks.add(ra, ra.replay.remote())
# Kick off async background sampling
self.sample_tasks = TaskPool()
weights = self.local_evaluator.get_weights()
for ev in self.remote_evaluators:
ev.set_weights.remote(weights)
self.steps_since_update[ev] = 0
for _ in range(SAMPLE_QUEUE_DEPTH):
self.sample_tasks.add(ev, ev.sample_with_count.remote())
def step(self):
start = time.time()
sample_timesteps, train_timesteps = self._step()
time_delta = time.time() - start
self.timers["sample"].push(time_delta)
self.timers["sample"].push_units_processed(sample_timesteps)
if train_timesteps > 0:
self.learning_started = True
if self.learning_started:
self.timers["train"].push(time_delta)
self.timers["train"].push_units_processed(train_timesteps)
self.num_steps_sampled += sample_timesteps
self.num_steps_trained += train_timesteps
def _step(self):
sample_timesteps, train_timesteps = 0, 0
weights = None
with self.timers["sample_processing"]:
completed = list(self.sample_tasks.completed())
counts = ray.get([c[1][1] for c in completed])
for i, (ev, (sample_batch, count)) in enumerate(completed):
sample_timesteps += counts[i]
# Send the data to the replay buffer
random.choice(
self.replay_actors).add_batch.remote(sample_batch)
# Update weights if needed
self.steps_since_update[ev] += counts[i]
if self.steps_since_update[ev] >= self.max_weight_sync_delay:
# Note that it's important to pull new weights once
# updated to avoid excessive correlation between actors
if weights is None or self.learner.weights_updated:
self.learner.weights_updated = False
with self.timers["put_weights"]:
weights = ray.put(
self.local_evaluator.get_weights())
ev.set_weights.remote(weights)
self.num_weight_syncs += 1
self.steps_since_update[ev] = 0
# Kick off another sample request
self.sample_tasks.add(ev, ev.sample_with_count.remote())
with self.timers["replay_processing"]:
for ra, replay in self.replay_tasks.completed():
self.replay_tasks.add(ra, ra.replay.remote())
with self.timers["get_samples"]:
samples = ray.get(replay)
with self.timers["enqueue"]:
self.learner.inqueue.put((ra, samples))
with self.timers["update_priorities"]:
while not self.learner.outqueue.empty():
ra, replay, td_error, count = self.learner.outqueue.get()
ra.update_priorities.remote(replay["batch_indexes"], td_error)
train_timesteps += count
return sample_timesteps, train_timesteps
def stats(self):
replay_stats = ray.get(self.replay_actors[0].stats.remote())
timing = {
"{}_time_ms".format(k): round(1000 * self.timers[k].mean, 3)
for k in self.timers
}
timing["learner_grad_time_ms"] = round(
1000 * self.learner.grad_timer.mean, 3)
timing["learner_dequeue_time_ms"] = round(
1000 * self.learner.queue_timer.mean, 3)
stats = {
"sample_throughput": round(self.timers["sample"].mean_throughput,
3),
"train_throughput": round(self.timers["train"].mean_throughput, 3),
"num_weight_syncs": self.num_weight_syncs,
}
debug_stats = {
"replay_shard_0": replay_stats,
"timing_breakdown": timing,
"pending_sample_tasks": self.sample_tasks.count,
"pending_replay_tasks": self.replay_tasks.count,
"learner_queue": self.learner.learner_queue_size.stats(),
}
if self.debug:
stats.update(debug_stats)
return dict(PolicyOptimizer.stats(self), **stats)