mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 22:20:52 +08:00
[rllib] Ensure stats are consistently reported across all algos (#4445)
This commit is contained in:
@@ -4,6 +4,7 @@ from __future__ import print_function
|
||||
|
||||
import ray
|
||||
import logging
|
||||
+from ray.rllib.evaluation.metrics import get_learner_stats
|
||||
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
|
||||
from ray.rllib.evaluation.sample_batch import SampleBatch
|
||||
from ray.rllib.utils.annotations import override
|
||||
@@ -55,8 +56,7 @@ class SyncSamplesOptimizer(PolicyOptimizer):
|
||||
with self.grad_timer:
|
||||
for i in range(self.num_sgd_iter):
|
||||
fetches = self.local_evaluator.learn_on_batch(samples)
|
||||
-if "stats" in fetches:
|
||||
-self.learner_stats = fetches["stats"]
|
||||
+self.learner_stats = get_learner_stats(fetches)
|
||||
if self.num_sgd_iter > 1:
|
||||
logger.debug("{} {}".format(i, fetches))
|
||||
self.grad_timer.push_units_processed(samples.count)
|
||||
|
||||
Reference in New Issue
Block a user