diff --git a/doc/source/impala.png b/doc/source/impala.png index a7d12e4b5..0d42fe6e0 100644 Binary files a/doc/source/impala.png and b/doc/source/impala.png differ diff --git a/doc/source/rllib-algorithms.rst b/doc/source/rllib-algorithms.rst index 9a7c53539..8b64b04c2 100644 --- a/doc/source/rllib-algorithms.rst +++ b/doc/source/rllib-algorithms.rst @@ -52,7 +52,7 @@ Importance Weighted Actor-Learner Architecture (IMPALA) `[implementation] `__ In IMPALA, a central learner runs SGD in a tight loop while asynchronously pulling sample batches from many actor processes. RLlib's IMPALA implementation uses DeepMind's reference `V-trace code `__. Note that we do not provide a deep residual network out of the box, but one can be plugged in as a `custom model `__. Multiple learner GPUs and experience replay are also supported. -Tuned examples: `PongNoFrameskip-v4 `__, `vectorized configuration `__, `{BeamRider,Breakout,Qbert,SpaceInvaders}NoFrameskip-v4 `__ +Tuned examples: `PongNoFrameskip-v4 `__, `vectorized configuration `__, `multi-gpu configuration `__, `{BeamRider,Breakout,Qbert,SpaceInvaders}NoFrameskip-v4 `__ **Atari results @10M steps**: `more details `__ @@ -78,7 +78,8 @@ SpaceInvaders 843 ~300 .. figure:: impala.png - IMPALA solves Atari several times faster than A2C / A3C, with similar sample efficiency. Here IMPALA scales from 16 to 128 workers to solve PongNoFrameskip-v4 in ~8 minutes. + Multi-GPU IMPALA scales up to solve PongNoFrameskip-v4 in ~3 minutes using a pair of V100 GPUs and 128 CPU workers. + The maximum training throughput reached is ~30k transitions per second (~120k environment frames per second). **IMPALA-specific configs** (see also `common configs `__): diff --git a/doc/source/rllib-training.rst b/doc/source/rllib-training.rst index 588ba5dcc..6b1366f4e 100644 --- a/doc/source/rllib-training.rst +++ b/doc/source/rllib-training.rst @@ -50,11 +50,11 @@ An example of evaluating a previously trained DQN agent is as follows: .. 
code-block:: bash python ray/python/ray/rllib/rollout.py \ - ~/ray_results/default/DQN_CartPole-v0_0upjmdgr0/checkpoint-1 \ + ~/ray_results/default/DQN_CartPole-v0_0upjmdgr0/checkpoint_1/checkpoint-1 \ --run DQN --env CartPole-v0 --steps 10000 The ``rollout.py`` helper script reconstructs a DQN agent from the checkpoint -located at ``~/ray_results/default/DQN_CartPole-v0_0upjmdgr0/checkpoint-1`` +located at ``~/ray_results/default/DQN_CartPole-v0_0upjmdgr0/checkpoint_1/checkpoint-1`` and renders its behavior in the environment specified by ``--env``. Configuration diff --git a/python/ray/rllib/agents/impala/impala.py b/python/ray/rllib/agents/impala/impala.py index b9665e9bf..45af92200 100644 --- a/python/ray/rllib/agents/impala/impala.py +++ b/python/ray/rllib/agents/impala/impala.py @@ -20,6 +20,7 @@ OPTIMIZER_SHARED_CONFIGS = [ "num_parallel_data_loaders", "grad_clip", "max_sample_requests_in_flight_per_worker", + "broadcast_interval", ] # yapf: disable @@ -42,6 +43,8 @@ DEFAULT_CONFIG = with_common_config({ "num_parallel_data_loaders": 1, # level of queuing for sampling. "max_sample_requests_in_flight_per_worker": 2, + # number of workers to send one set of weights to before refreshing it with updated learner weights + "broadcast_interval": 1, # set >0 to enable experience replay. Saved samples will be replayed with # a p:1 proportion to new data samples. 
"replay_proportion": 0.0, diff --git a/python/ray/rllib/optimizers/async_samples_optimizer.py b/python/ray/rllib/optimizers/async_samples_optimizer.py index e0ff26ed2..6b8f6014d 100644 --- a/python/ray/rllib/optimizers/async_samples_optimizer.py +++ b/python/ray/rllib/optimizers/async_samples_optimizer.py @@ -197,10 +197,12 @@ class AsyncSamplesOptimizer(PolicyOptimizer): replay_buffer_num_slots=0, replay_proportion=0.0, num_parallel_data_loaders=1, - max_sample_requests_in_flight_per_worker=2): + max_sample_requests_in_flight_per_worker=2, + broadcast_interval=1): self.learning_started = False self.train_batch_size = train_batch_size self.sample_batch_size = sample_batch_size + self.broadcast_interval = broadcast_interval if num_gpus > 1 or num_parallel_data_loaders > 1: logger.info( @@ -224,11 +226,7 @@ class AsyncSamplesOptimizer(PolicyOptimizer): assert len(self.remote_evaluators) > 0 # Stats - self.timers = { - k: TimerStat() - for k in - ["put_weights", "enqueue", "sample_processing", "train", "sample"] - } + self.timers = {k: TimerStat() for k in ["train", "sample"]} self.num_weight_syncs = 0 self.num_replayed = 0 self.learning_started = False @@ -286,43 +284,44 @@ class AsyncSamplesOptimizer(PolicyOptimizer): def _step(self): sample_timesteps, train_timesteps = 0, 0 + num_sent = 0 weights = None - with self.timers["sample_processing"]: - for ev, sample_batch in self._augment_with_replay( - self.sample_tasks.completed_prefetch()): - self.batch_buffer.append(sample_batch) - if sum(b.count - for b in self.batch_buffer) >= self.train_batch_size: - train_batch = self.batch_buffer[0].concat_samples( - self.batch_buffer) - with self.timers["enqueue"]: - self.learner.inqueue.put(train_batch) - self.batch_buffer = [] + for ev, sample_batch in self._augment_with_replay( + self.sample_tasks.completed_prefetch()): + self.batch_buffer.append(sample_batch) + if sum(b.count + for b in self.batch_buffer) >= self.train_batch_size: + train_batch = 
self.batch_buffer[0].concat_samples( + self.batch_buffer) + self.learner.inqueue.put(train_batch) + self.batch_buffer = [] - # If the batch was replayed, skip the update below. - if ev is None: - continue + # If the batch was replayed, skip the update below. + if ev is None: + continue - sample_timesteps += sample_batch.count + sample_timesteps += sample_batch.count - # Put in replay buffer if enabled - if self.replay_buffer_num_slots > 0: - self.replay_batches.append(sample_batch) - if len(self.replay_batches) > self.replay_buffer_num_slots: - self.replay_batches.pop(0) + # Put in replay buffer if enabled + if self.replay_buffer_num_slots > 0: + self.replay_batches.append(sample_batch) + if len(self.replay_batches) > self.replay_buffer_num_slots: + self.replay_batches.pop(0) - # Note that it's important to pull new weights once - # updated to avoid excessive correlation between actors - if weights is None or self.learner.weights_updated: - self.learner.weights_updated = False - with self.timers["put_weights"]: - weights = ray.put(self.local_evaluator.get_weights()) - ev.set_weights.remote(weights) - self.num_weight_syncs += 1 + # Note that it's important to pull new weights once + # updated to avoid excessive correlation between actors + if weights is None or (self.learner.weights_updated + and num_sent >= self.broadcast_interval): + self.learner.weights_updated = False + weights = ray.put(self.local_evaluator.get_weights()) + num_sent = 0 + ev.set_weights.remote(weights) + self.num_weight_syncs += 1 + num_sent += 1 - # Kick off another sample request - self.sample_tasks.add(ev, ev.sample.remote()) + # Kick off another sample request + self.sample_tasks.add(ev, ev.sample.remote()) while not self.learner.outqueue.empty(): count = self.learner.outqueue.get() diff --git a/python/ray/rllib/tuned_examples/pong-impala-fast.yaml b/python/ray/rllib/tuned_examples/pong-impala-fast.yaml new file mode 100644 index 000000000..3466b63ea --- /dev/null +++ 
b/python/ray/rllib/tuned_examples/pong-impala-fast.yaml @@ -0,0 +1,19 @@ +# This can reach 18-19 reward in ~3 minutes on p3.16xl head w/m4.16xl workers +# 128 workers -> 3 minutes (best case) +# 64 workers -> 4 minutes +# 32 workers -> 7 minutes +# See also: pong-impala.yaml, pong-impala-vectorized.yaml +pong-impala-fast: + env: PongNoFrameskip-v4 + run: IMPALA + config: + sample_batch_size: 50 + train_batch_size: 1000 + num_workers: 256 + num_envs_per_worker: 5 + broadcast_interval: 5 + max_sample_requests_in_flight_per_worker: 1 + num_parallel_data_loaders: 4 + num_gpus: 2 + model: + dim: 42 diff --git a/python/ray/rllib/tuned_examples/pong-impala.yaml b/python/ray/rllib/tuned_examples/pong-impala.yaml index b54c79849..527bc905d 100644 --- a/python/ray/rllib/tuned_examples/pong-impala.yaml +++ b/python/ray/rllib/tuned_examples/pong-impala.yaml @@ -2,7 +2,7 @@ # 128 workers -> 8 minutes # 32 workers -> 17 minutes # 16 workers -> 40 min+ -# See also: pong-impala-vectorized.yaml +# See also: pong-impala-fast.yaml, pong-impala-vectorized.yaml pong-impala: env: PongNoFrameskip-v4 run: IMPALA