mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 20:07:41 +08:00
24649726dc
Using the actual batch size reduces the risk of mis-accounting. Here, we under-counted samples since in truncate_episodes mode we were doubling the batch size by accident in policy_evaluator.
296 lines
11 KiB
Python
296 lines
11 KiB
Python
"""Implements Distributed Prioritized Experience Replay (Ape-X).

https://arxiv.org/abs/1803.00933"""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import os
|
|
import random
|
|
import time
|
|
import threading
|
|
|
|
import numpy as np
|
|
from six.moves import queue
|
|
|
|
import ray
|
|
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
|
|
from ray.rllib.optimizers.replay_buffer import PrioritizedReplayBuffer
|
|
from ray.rllib.evaluation.sample_batch import SampleBatch
|
|
from ray.rllib.utils.actors import TaskPool, create_colocated
|
|
from ray.rllib.utils.timer import TimerStat
|
|
from ray.rllib.utils.window_stat import WindowStat
|
|
|
|
# Number of ``sample_with_count`` tasks kept in flight per remote evaluator.
SAMPLE_QUEUE_DEPTH = 2
# Number of ``replay`` tasks kept in flight per replay buffer shard.
REPLAY_QUEUE_DEPTH = 4
# Capacity of the learner thread's input queue. ``put`` blocks once the queue
# is full, which throttles the main loop when the learner falls behind.
LEARNER_QUEUE_MAX_SIZE = 16
|
|
|
|
|
|
@ray.remote
class ReplayActor(object):
    """One shard of the distributed prioritized replay buffer.

    A Ray actor runs its methods one at a time, so several shards may be
    created to spread replay work across processes. The global buffer
    capacity and warm-up threshold are split across shards with floor
    division.
    """

    def __init__(self, num_shards, learning_starts, buffer_size,
                 train_batch_size, prioritized_replay_alpha,
                 prioritized_replay_beta, prioritized_replay_eps,
                 clip_rewards):
        # Sampling hyperparameters consumed by replay()/update_priorities().
        self.train_batch_size = train_batch_size
        self.prioritized_replay_beta = prioritized_replay_beta
        self.prioritized_replay_eps = prioritized_replay_eps

        # This shard's slice of the global capacity and warm-up threshold.
        self.buffer_size = buffer_size // num_shards
        self.replay_starts = learning_starts // num_shards
        self.replay_buffer = PrioritizedReplayBuffer(
            self.buffer_size,
            alpha=prioritized_replay_alpha,
            clip_rewards=clip_rewards)

        # Per-operation latency metrics surfaced by stats().
        (self.add_batch_timer, self.replay_timer,
         self.update_priorities_timer) = TimerStat(), TimerStat(), TimerStat()

    def get_host(self):
        """Return the node name of the machine this actor runs on."""
        node_name = os.uname()[1]
        return node_name

    def add_batch(self, batch):
        """Append every row of ``batch`` to this shard's buffer.

        The batch is first passed to PolicyOptimizer._check_not_multiagent.
        """
        PolicyOptimizer._check_not_multiagent(batch)
        columns = ("obs", "actions", "rewards", "new_obs", "dones", "weights")
        with self.add_batch_timer:
            for row in batch.rows():
                self.replay_buffer.add(*[row[col] for col in columns])

    def replay(self):
        """Sample a prioritized training batch, or None while warming up.

        Returns None until this shard holds at least ``replay_starts``
        entries. Otherwise returns a SampleBatch of ``train_batch_size``
        entries, including the per-sample ``weights`` and the
        ``batch_indexes`` that update_priorities() later expects.
        """
        with self.replay_timer:
            if len(self.replay_buffer) >= self.replay_starts:
                (obs, acts, rews, next_obs, terminals, sample_weights,
                 indexes) = self.replay_buffer.sample(
                     self.train_batch_size,
                     beta=self.prioritized_replay_beta)
                return SampleBatch({
                    "obs": obs,
                    "actions": acts,
                    "rewards": rews,
                    "new_obs": next_obs,
                    "dones": terminals,
                    "weights": sample_weights,
                    "batch_indexes": indexes
                })
            return None

    def update_priorities(self, batch_indexes, td_errors):
        """Set each sampled index's priority to ``|td_error| + eps``."""
        with self.update_priorities_timer:
            self.replay_buffer.update_priorities(
                batch_indexes,
                np.abs(td_errors) + self.prioritized_replay_eps)

    def stats(self):
        """Return mean per-call latencies (ms) merged with buffer stats."""
        named_timers = [
            ("add_batch_time_ms", self.add_batch_timer),
            ("replay_time_ms", self.replay_timer),
            ("update_priorities_time_ms", self.update_priorities_timer),
        ]
        result = {
            name: round(1000 * timer.mean, 3)
            for name, timer in named_timers
        }
        result.update(self.replay_buffer.stats())
        return result
|
|
|
|
|
|
class LearnerThread(threading.Thread):
    """Background thread that updates the local model from replay data.

    The learner thread communicates with the main thread through Queues. This
    is needed since Ray operations can only be run on the main thread. In
    addition, moving the heavyweight gradient computation (session runs) off
    the main thread improves overall throughput.

    Queue protocol:

    * ``inqueue`` receives ``(replay_actor, batch)`` tuples. ``batch`` may be
      None when the replay shard did not yet hold enough data to sample.
    * ``outqueue`` emits ``(replay_actor, batch, td_error, batch.count)`` for
      every batch that was actually trained on.
    * ``weights_updated`` is set to True after each gradient step so the
      main thread knows the local weights changed; the main thread resets it.
    """

    def __init__(self, local_evaluator):
        threading.Thread.__init__(self)
        self.learner_queue_size = WindowStat("size", 50)
        self.local_evaluator = local_evaluator
        # Bounded: once full, the main thread blocks on put() instead of
        # accumulating an unbounded backlog of replay batches.
        self.inqueue = queue.Queue(maxsize=LEARNER_QUEUE_MAX_SIZE)
        self.outqueue = queue.Queue()
        self.queue_timer = TimerStat()
        self.grad_timer = TimerStat()
        # Daemon thread so it does not keep the process alive on exit.
        self.daemon = True
        self.weights_updated = False

    def run(self):
        """Train indefinitely on batches taken from ``inqueue``."""
        while True:
            self.step()

    def step(self):
        """Dequeue one item and, if it holds a batch, train on it.

        Blocks until an item is available. A None batch is skipped
        without touching the model or the ``weights_updated`` flag.
        """
        with self.queue_timer:
            ra, replay = self.inqueue.get()
        if replay is not None:
            with self.grad_timer:
                td_error = self.local_evaluator.compute_apply(replay)[
                    "td_error"]
            self.outqueue.put((ra, replay, td_error, replay.count))
            # Only signal a weight change when a gradient step really ran;
            # flagging on a None batch would make the main thread re-put
            # identical weights into the object store.
            self.weights_updated = True
        self.learner_queue_size.push(self.inqueue.qsize())
|
|
|
|
|
|
class AsyncSamplesOptimizer(PolicyOptimizer):
    """Main event loop of the Ape-X optimizer (async sampling with replay).

    This class coordinates the data transfers between the learner thread,
    remote evaluators (Ape-X actors), and replay buffer actors:

    * finished sample batches from remote evaluators are forwarded to a
      randomly chosen replay shard;
    * batches returned by the replay shards are handed to the learner thread;
    * TD errors produced by the learner are sent back to the originating
      shard to update priorities.

    This optimizer requires that the local evaluator's compute_apply() return
    an info dict containing a "td_error" array. This error term is used for
    sample prioritization."""

    def _init(self,
              learning_starts=1000,
              buffer_size=10000,
              prioritized_replay=True,
              prioritized_replay_alpha=0.6,
              prioritized_replay_beta=0.4,
              prioritized_replay_eps=1e-6,
              train_batch_size=512,
              sample_batch_size=50,
              num_replay_buffer_shards=1,
              max_weight_sync_delay=400,
              clip_rewards=True,
              debug=False):
        """Start the learner thread, replay shards and initial tasks.

        Args:
            learning_starts: Total number of entries that must be in replay
                before training batches are returned; divided across shards.
            buffer_size: Total replay capacity, divided across shards.
            prioritized_replay: Accepted for config compatibility; not read
                by this optimizer.
            prioritized_replay_alpha: ``alpha`` passed to each shard's
                PrioritizedReplayBuffer.
            prioritized_replay_beta: ``beta`` used when sampling replay.
            prioritized_replay_eps: Added to ``|td_error|`` to form new
                priorities.
            train_batch_size: Number of entries per replayed training batch.
            sample_batch_size: Accepted for config compatibility; not read
                by this optimizer.
            num_replay_buffer_shards: Number of ReplayActor shards.
            max_weight_sync_delay: Timesteps (as reported by
                sample_with_count) an evaluator may collect before it is
                sent fresh weights.
            clip_rewards: Forwarded to each shard's PrioritizedReplayBuffer.
            debug: If True, stats() includes detailed debug statistics.
        """
        # Validate before spawning the learner thread and replay actors so a
        # misconfiguration does not leave those resources running.
        assert len(self.remote_evaluators) > 0, (
            "AsyncSamplesOptimizer requires at least one remote evaluator")

        self.debug = debug
        self.replay_starts = learning_starts
        self.prioritized_replay_beta = prioritized_replay_beta
        self.prioritized_replay_eps = prioritized_replay_eps
        self.max_weight_sync_delay = max_weight_sync_delay

        self.learner = LearnerThread(self.local_evaluator)
        self.learner.start()

        self.replay_actors = create_colocated(ReplayActor, [
            num_replay_buffer_shards, learning_starts, buffer_size,
            train_batch_size, prioritized_replay_alpha,
            prioritized_replay_beta, prioritized_replay_eps, clip_rewards
        ], num_replay_buffer_shards)

        # Stats
        self.timers = {
            k: TimerStat()
            for k in [
                "put_weights", "get_samples", "enqueue", "sample_processing",
                "replay_processing", "update_priorities", "train", "sample"
            ]
        }
        self.num_weight_syncs = 0
        self.learning_started = False

        # Number of worker steps since the last weight update
        self.steps_since_update = {}

        # Kick off replay tasks for local gradient updates
        self.replay_tasks = TaskPool()
        for ra in self.replay_actors:
            for _ in range(REPLAY_QUEUE_DEPTH):
                self.replay_tasks.add(ra, ra.replay.remote())

        # Kick off async background sampling. The weights are put into the
        # object store once and shared by reference with every evaluator,
        # matching how later weight syncs are sent.
        self.sample_tasks = TaskPool()
        weights = ray.put(self.local_evaluator.get_weights())
        for ev in self.remote_evaluators:
            ev.set_weights.remote(weights)
            self.steps_since_update[ev] = 0
            for _ in range(SAMPLE_QUEUE_DEPTH):
                self.sample_tasks.add(ev, ev.sample_with_count.remote())

    def step(self):
        """Run one pass of the event loop and record throughput statistics.

        The pass's wall-clock time is always attributed to the "sample"
        timer, and also to the "train" timer once any training timesteps
        have been observed.
        """
        start = time.time()
        sample_timesteps, train_timesteps = self._step()
        time_delta = time.time() - start
        self.timers["sample"].push(time_delta)
        self.timers["sample"].push_units_processed(sample_timesteps)
        if train_timesteps > 0:
            self.learning_started = True
        if self.learning_started:
            self.timers["train"].push(time_delta)
            self.timers["train"].push_units_processed(train_timesteps)
        self.num_steps_sampled += sample_timesteps
        self.num_steps_trained += train_timesteps

    def _step(self):
        """Process completed samples, replays and learner results once.

        Returns:
            Tuple ``(sample_timesteps, train_timesteps)`` for this pass.
        """
        sample_timesteps = self._process_sample_results()
        self._process_replay_results()
        train_timesteps = self._process_learner_results()
        return sample_timesteps, train_timesteps

    def _process_sample_results(self):
        """Forward finished sample batches to replay and resubmit sampling.

        Also sends fresh weights to any evaluator that has collected at least
        ``max_weight_sync_delay`` timesteps since its last sync.

        Returns:
            Total timesteps contained in the batches processed this call.
        """
        sample_timesteps = 0
        # Weights are put into the object store at most once per call, and
        # re-put only if the learner reports that they changed since.
        weights = None
        with self.timers["sample_processing"]:
            completed = list(self.sample_tasks.completed())
            # Each task yields (batch_id, count_id); only the counts are
            # fetched here, batches are forwarded to replay by reference.
            counts = ray.get([c[1][1] for c in completed])
            for (ev, (sample_batch, _)), count in zip(completed, counts):
                sample_timesteps += count

                # Send the data to the replay buffer
                random.choice(
                    self.replay_actors).add_batch.remote(sample_batch)

                # Update weights if needed
                self.steps_since_update[ev] += count
                if self.steps_since_update[ev] >= self.max_weight_sync_delay:
                    # Note that it's important to pull new weights once
                    # updated to avoid excessive correlation between actors
                    if weights is None or self.learner.weights_updated:
                        self.learner.weights_updated = False
                        with self.timers["put_weights"]:
                            weights = ray.put(
                                self.local_evaluator.get_weights())
                    ev.set_weights.remote(weights)
                    self.num_weight_syncs += 1
                    self.steps_since_update[ev] = 0

                # Kick off another sample request
                self.sample_tasks.add(ev, ev.sample_with_count.remote())
        return sample_timesteps

    def _process_replay_results(self):
        """Hand finished replay batches to the learner and resubmit replay.

        ``inqueue.put`` blocks while the learner's input queue is full,
        which throttles this loop to the learner's pace.
        """
        with self.timers["replay_processing"]:
            for ra, replay in self.replay_tasks.completed():
                self.replay_tasks.add(ra, ra.replay.remote())
                with self.timers["get_samples"]:
                    samples = ray.get(replay)
                with self.timers["enqueue"]:
                    self.learner.inqueue.put((ra, samples))

    def _process_learner_results(self):
        """Send TD errors back to replay shards as new priorities.

        Returns:
            Total timesteps trained on by the learner since the last call.
        """
        train_timesteps = 0
        with self.timers["update_priorities"]:
            while not self.learner.outqueue.empty():
                ra, replay, td_error, count = self.learner.outqueue.get()
                ra.update_priorities.remote(replay["batch_indexes"], td_error)
                train_timesteps += count
        return train_timesteps

    def stats(self):
        """Return throughput and weight-sync stats merged with base stats.

        When ``debug`` is set, also includes shard-0 replay stats, a timing
        breakdown, pending task counts and learner queue statistics.
        """
        replay_stats = ray.get(self.replay_actors[0].stats.remote())
        timing = {
            "{}_time_ms".format(k): round(1000 * self.timers[k].mean, 3)
            for k in self.timers
        }
        timing["learner_grad_time_ms"] = round(
            1000 * self.learner.grad_timer.mean, 3)
        timing["learner_dequeue_time_ms"] = round(
            1000 * self.learner.queue_timer.mean, 3)
        stats = {
            "sample_throughput": round(self.timers["sample"].mean_throughput,
                                       3),
            "train_throughput": round(self.timers["train"].mean_throughput, 3),
            "num_weight_syncs": self.num_weight_syncs,
        }
        debug_stats = {
            "replay_shard_0": replay_stats,
            "timing_breakdown": timing,
            "pending_sample_tasks": self.sample_tasks.count,
            "pending_replay_tasks": self.replay_tasks.count,
            "learner_queue": self.learner.learner_queue_size.stats(),
        }
        if self.debug:
            stats.update(debug_stats)
        return dict(PolicyOptimizer.stats(self), **stats)
|