mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 12:45:44 +08:00
[rllib] Update multi-gpu impala numbers (#3327)
This commit is contained in:
@@ -20,6 +20,7 @@ OPTIMIZER_SHARED_CONFIGS = [
|
||||
"num_parallel_data_loaders",
|
||||
"grad_clip",
|
||||
"max_sample_requests_in_flight_per_worker",
|
||||
"broadcast_interval",
|
||||
]
|
||||
|
||||
# yapf: disable
|
||||
@@ -42,6 +43,8 @@ DEFAULT_CONFIG = with_common_config({
|
||||
"num_parallel_data_loaders": 1,
|
||||
# level of queuing for sampling.
|
||||
"max_sample_requests_in_flight_per_worker": 2,
|
||||
# minimum number of workers to send one set of weights to before the weights are refreshed from the learner
|
||||
"broadcast_interval": 1,
|
||||
# set >0 to enable experience replay. Saved samples will be replayed with
|
||||
# a p:1 proportion to new data samples.
|
||||
"replay_proportion": 0.0,
|
||||
|
||||
@@ -197,10 +197,12 @@ class AsyncSamplesOptimizer(PolicyOptimizer):
|
||||
replay_buffer_num_slots=0,
|
||||
replay_proportion=0.0,
|
||||
num_parallel_data_loaders=1,
|
||||
max_sample_requests_in_flight_per_worker=2):
|
||||
max_sample_requests_in_flight_per_worker=2,
|
||||
broadcast_interval=1):
|
||||
self.learning_started = False
|
||||
self.train_batch_size = train_batch_size
|
||||
self.sample_batch_size = sample_batch_size
|
||||
self.broadcast_interval = broadcast_interval
|
||||
|
||||
if num_gpus > 1 or num_parallel_data_loaders > 1:
|
||||
logger.info(
|
||||
@@ -224,11 +226,7 @@ class AsyncSamplesOptimizer(PolicyOptimizer):
|
||||
assert len(self.remote_evaluators) > 0
|
||||
|
||||
# Stats
|
||||
self.timers = {
|
||||
k: TimerStat()
|
||||
for k in
|
||||
["put_weights", "enqueue", "sample_processing", "train", "sample"]
|
||||
}
|
||||
self.timers = {k: TimerStat() for k in ["train", "sample"]}
|
||||
self.num_weight_syncs = 0
|
||||
self.num_replayed = 0
|
||||
self.learning_started = False
|
||||
@@ -286,43 +284,44 @@ class AsyncSamplesOptimizer(PolicyOptimizer):
|
||||
|
||||
def _step(self):
|
||||
sample_timesteps, train_timesteps = 0, 0
|
||||
num_sent = 0
|
||||
weights = None
|
||||
|
||||
with self.timers["sample_processing"]:
|
||||
for ev, sample_batch in self._augment_with_replay(
|
||||
self.sample_tasks.completed_prefetch()):
|
||||
self.batch_buffer.append(sample_batch)
|
||||
if sum(b.count
|
||||
for b in self.batch_buffer) >= self.train_batch_size:
|
||||
train_batch = self.batch_buffer[0].concat_samples(
|
||||
self.batch_buffer)
|
||||
with self.timers["enqueue"]:
|
||||
self.learner.inqueue.put(train_batch)
|
||||
self.batch_buffer = []
|
||||
for ev, sample_batch in self._augment_with_replay(
|
||||
self.sample_tasks.completed_prefetch()):
|
||||
self.batch_buffer.append(sample_batch)
|
||||
if sum(b.count
|
||||
for b in self.batch_buffer) >= self.train_batch_size:
|
||||
train_batch = self.batch_buffer[0].concat_samples(
|
||||
self.batch_buffer)
|
||||
self.learner.inqueue.put(train_batch)
|
||||
self.batch_buffer = []
|
||||
|
||||
# If the batch was replayed, skip the update below.
|
||||
if ev is None:
|
||||
continue
|
||||
# If the batch was replayed, skip the update below.
|
||||
if ev is None:
|
||||
continue
|
||||
|
||||
sample_timesteps += sample_batch.count
|
||||
sample_timesteps += sample_batch.count
|
||||
|
||||
# Put in replay buffer if enabled
|
||||
if self.replay_buffer_num_slots > 0:
|
||||
self.replay_batches.append(sample_batch)
|
||||
if len(self.replay_batches) > self.replay_buffer_num_slots:
|
||||
self.replay_batches.pop(0)
|
||||
# Put in replay buffer if enabled
|
||||
if self.replay_buffer_num_slots > 0:
|
||||
self.replay_batches.append(sample_batch)
|
||||
if len(self.replay_batches) > self.replay_buffer_num_slots:
|
||||
self.replay_batches.pop(0)
|
||||
|
||||
# Note that it's important to pull new weights once
|
||||
# updated to avoid excessive correlation between actors
|
||||
if weights is None or self.learner.weights_updated:
|
||||
self.learner.weights_updated = False
|
||||
with self.timers["put_weights"]:
|
||||
weights = ray.put(self.local_evaluator.get_weights())
|
||||
ev.set_weights.remote(weights)
|
||||
self.num_weight_syncs += 1
|
||||
# Note that it's important to pull new weights once
|
||||
# updated to avoid excessive correlation between actors
|
||||
if weights is None or (self.learner.weights_updated
|
||||
and num_sent >= self.broadcast_interval):
|
||||
self.learner.weights_updated = False
|
||||
weights = ray.put(self.local_evaluator.get_weights())
|
||||
num_sent = 0
|
||||
ev.set_weights.remote(weights)
|
||||
self.num_weight_syncs += 1
|
||||
num_sent += 1
|
||||
|
||||
# Kick off another sample request
|
||||
self.sample_tasks.add(ev, ev.sample.remote())
|
||||
# Kick off another sample request
|
||||
self.sample_tasks.add(ev, ev.sample.remote())
|
||||
|
||||
while not self.learner.outqueue.empty():
|
||||
count = self.learner.outqueue.get()
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
# This can reach 18-19 reward in ~3 minutes on a p3.16xl head with m4.16xl workers
|
||||
# 128 workers -> 3 minutes (best case)
|
||||
# 64 workers -> 4 minutes
|
||||
# 32 workers -> 7 minutes
|
||||
# See also: pong-impala.yaml, pong-impala-vectorized.yaml
|
||||
pong-impala-fast:
|
||||
env: PongNoFrameskip-v4
|
||||
run: IMPALA
|
||||
config:
|
||||
sample_batch_size: 50
|
||||
train_batch_size: 1000
|
||||
num_workers: 256
|
||||
num_envs_per_worker: 5
|
||||
broadcast_interval: 5
|
||||
max_sample_requests_in_flight_per_worker: 1
|
||||
num_parallel_data_loaders: 4
|
||||
num_gpus: 2
|
||||
model:
|
||||
dim: 42
|
||||
@@ -2,7 +2,7 @@
|
||||
# 128 workers -> 8 minutes
|
||||
# 32 workers -> 17 minutes
|
||||
# 16 workers -> 40 min+
|
||||
# See also: pong-impala-vectorized.yaml
|
||||
# See also: pong-impala-fast.yaml, pong-impala-vectorized.yaml
|
||||
pong-impala:
|
||||
env: PongNoFrameskip-v4
|
||||
run: IMPALA
|
||||
|
||||
Reference in New Issue
Block a user