mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 01:06:07 +08:00
[rllib] Add stats for A3C (#2315)
* add stats for a3c again
* fix multigpu too
This commit is contained in:
@@ -139,7 +139,10 @@ class A3CAgent(Agent):
|
||||
self.optimizer.step()
|
||||
FilterManager.synchronize(
|
||||
self.local_evaluator.filters, self.remote_evaluators)
|
||||
return collect_metrics(self.local_evaluator, self.remote_evaluators)
|
||||
result = collect_metrics(self.local_evaluator, self.remote_evaluators)
|
||||
result = result._replace(
|
||||
info=self.optimizer.stats())
|
||||
return result
|
||||
|
||||
def _stop(self):
|
||||
# workaround for https://github.com/ray-project/ray/issues/1516
|
||||
|
||||
@@ -57,7 +57,7 @@ class AsyncGradientsOptimizer(PolicyOptimizer):
|
||||
self.num_steps_trained += self.grads_per_step * self.batch_size
|
||||
|
||||
def stats(self):
|
||||
return dict(PolicyOptimizer.stats(), **{
|
||||
return dict(PolicyOptimizer.stats(self), **{
|
||||
"wait_time_ms": round(1000 * self.wait_timer.mean, 3),
|
||||
"apply_time_ms": round(1000 * self.apply_timer.mean, 3),
|
||||
"dispatch_time_ms": round(1000 * self.dispatch_timer.mean, 3),
|
||||
|
||||
@@ -123,7 +123,7 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
|
||||
return all_extra_fetches
|
||||
|
||||
def stats(self):
|
||||
return dict(PolicyOptimizer.stats(), **{
|
||||
return dict(PolicyOptimizer.stats(self), **{
|
||||
"sample_time_ms": round(1000 * self.sample_timer.mean, 3),
|
||||
"load_time_ms": round(1000 * self.load_timer.mean, 3),
|
||||
"grad_time_ms": round(1000 * self.grad_timer.mean, 3),
|
||||
|
||||
Reference in New Issue
Block a user