[rllib] Update multi-gpu impala numbers (#3327)

This commit is contained in:
Eric Liang
2018-11-19 20:55:27 -08:00
committed by GitHub
parent 5972c29d28
commit abdc3b592e
7 changed files with 63 additions and 41 deletions
+3
View File
@@ -20,6 +20,7 @@ OPTIMIZER_SHARED_CONFIGS = [
"num_parallel_data_loaders",
"grad_clip",
"max_sample_requests_in_flight_per_worker",
"broadcast_interval",
]
# yapf: disable
@@ -42,6 +43,8 @@ DEFAULT_CONFIG = with_common_config({
"num_parallel_data_loaders": 1,
# level of queuing for sampling.
"max_sample_requests_in_flight_per_worker": 2,
# minimum number of workers (within one optimizer step) that receive a set
# of weights before updated weights are broadcast again
"broadcast_interval": 1,
# set >0 to enable experience replay. Saved samples will be replayed with
# a p:1 proportion to new data samples.
"replay_proportion": 0.0,
@@ -197,10 +197,12 @@ class AsyncSamplesOptimizer(PolicyOptimizer):
replay_buffer_num_slots=0,
replay_proportion=0.0,
num_parallel_data_loaders=1,
max_sample_requests_in_flight_per_worker=2):
max_sample_requests_in_flight_per_worker=2,
broadcast_interval=1):
self.learning_started = False
self.train_batch_size = train_batch_size
self.sample_batch_size = sample_batch_size
self.broadcast_interval = broadcast_interval
if num_gpus > 1 or num_parallel_data_loaders > 1:
logger.info(
@@ -224,11 +226,7 @@ class AsyncSamplesOptimizer(PolicyOptimizer):
assert len(self.remote_evaluators) > 0
# Stats
self.timers = {
k: TimerStat()
for k in
["put_weights", "enqueue", "sample_processing", "train", "sample"]
}
self.timers = {k: TimerStat() for k in ["train", "sample"]}
self.num_weight_syncs = 0
self.num_replayed = 0
self.learning_started = False
@@ -286,43 +284,44 @@ class AsyncSamplesOptimizer(PolicyOptimizer):
def _step(self):
sample_timesteps, train_timesteps = 0, 0
num_sent = 0
weights = None
with self.timers["sample_processing"]:
for ev, sample_batch in self._augment_with_replay(
self.sample_tasks.completed_prefetch()):
self.batch_buffer.append(sample_batch)
if sum(b.count
for b in self.batch_buffer) >= self.train_batch_size:
train_batch = self.batch_buffer[0].concat_samples(
self.batch_buffer)
with self.timers["enqueue"]:
self.learner.inqueue.put(train_batch)
self.batch_buffer = []
for ev, sample_batch in self._augment_with_replay(
self.sample_tasks.completed_prefetch()):
self.batch_buffer.append(sample_batch)
if sum(b.count
for b in self.batch_buffer) >= self.train_batch_size:
train_batch = self.batch_buffer[0].concat_samples(
self.batch_buffer)
self.learner.inqueue.put(train_batch)
self.batch_buffer = []
# If the batch was replayed, skip the update below.
if ev is None:
continue
# If the batch was replayed, skip the update below.
if ev is None:
continue
sample_timesteps += sample_batch.count
sample_timesteps += sample_batch.count
# Put in replay buffer if enabled
if self.replay_buffer_num_slots > 0:
self.replay_batches.append(sample_batch)
if len(self.replay_batches) > self.replay_buffer_num_slots:
self.replay_batches.pop(0)
# Put in replay buffer if enabled
if self.replay_buffer_num_slots > 0:
self.replay_batches.append(sample_batch)
if len(self.replay_batches) > self.replay_buffer_num_slots:
self.replay_batches.pop(0)
# Note that it's important to pull new weights once
# updated to avoid excessive correlation between actors
if weights is None or self.learner.weights_updated:
self.learner.weights_updated = False
with self.timers["put_weights"]:
weights = ray.put(self.local_evaluator.get_weights())
ev.set_weights.remote(weights)
self.num_weight_syncs += 1
# Note that it's important to pull new weights once
# updated to avoid excessive correlation between actors
if weights is None or (self.learner.weights_updated
and num_sent >= self.broadcast_interval):
self.learner.weights_updated = False
weights = ray.put(self.local_evaluator.get_weights())
num_sent = 0
ev.set_weights.remote(weights)
self.num_weight_syncs += 1
num_sent += 1
# Kick off another sample request
self.sample_tasks.add(ev, ev.sample.remote())
# Kick off another sample request
self.sample_tasks.add(ev, ev.sample.remote())
while not self.learner.outqueue.empty():
count = self.learner.outqueue.get()
@@ -0,0 +1,19 @@
# This can reach 18-19 reward in ~3 minutes on a p3.16xl head with m4.16xl workers
# 128 workers -> 3 minutes (best case)
# 64 workers -> 4 minutes
# 32 workers -> 7 minutes
# See also: pong-impala.yaml, pong-impala-vectorized.yaml
pong-impala-fast:
env: PongNoFrameskip-v4
run: IMPALA
config:
sample_batch_size: 50
train_batch_size: 1000
num_workers: 256
num_envs_per_worker: 5
broadcast_interval: 5
max_sample_requests_in_flight_per_worker: 1
num_parallel_data_loaders: 4
num_gpus: 2
model:
dim: 42
@@ -2,7 +2,7 @@
# 128 workers -> 8 minutes
# 32 workers -> 17 minutes
# 16 workers -> 40 min+
# See also: pong-impala-vectorized.yaml
# See also: pong-impala-fast.yaml, pong-impala-vectorized.yaml
pong-impala:
env: PongNoFrameskip-v4
run: IMPALA