mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 15:06:28 +08:00
882a649f0c
* wip * more work * fix apex * docs * apex doc * pool comment * clean up * make wrap stack pluggable * Mon Mar 12 21:45:50 PDT 2018 * clean up comment * table * Mon Mar 12 22:51:57 PDT 2018 * Mon Mar 12 22:53:05 PDT 2018 * Mon Mar 12 22:55:03 PDT 2018 * Mon Mar 12 22:56:18 PDT 2018 * Mon Mar 12 22:59:54 PDT 2018 * Update apex_optimizer.py * Update index.rst * Update README.rst * Update README.rst * comments * Wed Mar 14 19:01:02 PDT 2018
62 lines
2.3 KiB
Python
62 lines
2.3 KiB
Python
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import ray
|
|
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
|
|
from ray.rllib.utils.timer import TimerStat
|
|
|
|
|
|
class AsyncOptimizer(PolicyOptimizer):
    """An asynchronous RL optimizer, e.g. for implementing A3C.

    This optimizer asynchronously pulls and applies gradients from remote
    evaluators, sending updated weights back as needed. This pipelines the
    gradient computations on the remote workers.
    """

    def _init(self, grads_per_step=100, batch_size=10):
        """Create the timers and store the per-step settings.

        Args:
            grads_per_step: Target number of gradient tasks launched per
                call to step(). The first wave always launches one task
                per remote evaluator, even if that exceeds this value.
            batch_size: Number of timesteps credited to the step counters
                for each gradient (presumably the size of one evaluator
                sample batch -- confirm against the evaluator config).
        """
        self.apply_timer = TimerStat()
        self.wait_timer = TimerStat()
        self.dispatch_timer = TimerStat()
        self.grads_per_step = grads_per_step
        self.batch_size = batch_size

    def step(self):
        """Run one round of asynchronous gradient collection.

        Every remote evaluator is sent the current local weights and asked
        to sample and compute a gradient. Results are then consumed in
        dispatch order: each non-None gradient is applied to the local
        evaluator, and the evaluator that produced it is handed the
        updated weights and re-dispatched until grads_per_step tasks have
        been launched in total. The step counters are advanced by the
        work that was actually dispatched and applied.
        """
        # Put the weights in the object store once so that the first wave
        # shares a single copy instead of re-serializing per evaluator.
        weights = ray.put(self.local_evaluator.get_weights())
        gradient_queue = []
        num_dispatched = 0
        num_applied = 0

        # Kick off the first wave of async tasks
        for evaluator in self.remote_evaluators:
            evaluator.set_weights.remote(weights)
            fut = evaluator.compute_gradients.remote(evaluator.sample.remote())
            gradient_queue.append((fut, evaluator))
            num_dispatched += 1

        # Note: can't use wait: https://github.com/ray-project/ray/issues/1128
        # so results are fetched strictly in FIFO (dispatch) order.
        while gradient_queue:
            with self.wait_timer:
                fut, evaluator = gradient_queue.pop(0)
                gradient, _ = ray.get(fut)

            if gradient is not None:
                with self.apply_timer:
                    self.local_evaluator.apply_gradients(gradient)
                num_applied += 1

            if num_dispatched < self.grads_per_step:
                with self.dispatch_timer:
                    # Re-dispatch with the freshest weights, which include
                    # the gradient that was just applied.
                    evaluator.set_weights.remote(
                        self.local_evaluator.get_weights())
                    fut = evaluator.compute_gradients.remote(
                        evaluator.sample.remote())
                    gradient_queue.append((fut, evaluator))
                    num_dispatched += 1

        # Credit the counters with the work really done. In the common
        # case (0 < #evaluators <= grads_per_step, no None gradients) this
        # equals grads_per_step * batch_size; it stays correct when there
        # are no remote evaluators, more evaluators than grads_per_step,
        # or gradients that were skipped because they were None.
        self.num_steps_sampled += num_dispatched * self.batch_size
        self.num_steps_trained += num_applied * self.batch_size

    def stats(self):
        """Return the base optimizer stats plus mean timer values in ms.

        Returns:
            dict: the entries from PolicyOptimizer.stats() extended with
            "wait_time_ms", "apply_time_ms" and "dispatch_time_ms", each
            the corresponding timer's mean scaled by 1000 and rounded to
            3 decimals.
        """
        # The base implementation must be given the instance explicitly
        # when it is invoked through the class.
        return dict(PolicyOptimizer.stats(self), **{
            "wait_time_ms": round(1000 * self.wait_timer.mean, 3),
            "apply_time_ms": round(1000 * self.apply_timer.mean, 3),
            "dispatch_time_ms": round(1000 * self.dispatch_timer.mean, 3),
        })
|