mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 23:09:51 +08:00
[rllib] Fix Atari reward calculations; add LR annealing and an explained-variance stat for A2C / IMPALA (#2700)
Changes needed to reproduce Atari plots in IMPALA / A2C: https://github.com/ray-project/rl-experiments
This commit is contained in:
@@ -24,6 +24,7 @@ class SyncSamplesOptimizer(PolicyOptimizer):
|
||||
self.throughput = RunningStat()
|
||||
self.num_sgd_iter = num_sgd_iter
|
||||
self.timesteps_per_batch = timesteps_per_batch
|
||||
self.learner_stats = {}
|
||||
|
||||
def step(self):
|
||||
with self.update_weights_timer:
|
||||
@@ -48,6 +49,8 @@ class SyncSamplesOptimizer(PolicyOptimizer):
|
||||
with self.grad_timer:
|
||||
for i in range(self.num_sgd_iter):
|
||||
fetches = self.local_evaluator.compute_apply(samples)
|
||||
if "stats" in fetches:
|
||||
self.learner_stats = fetches["stats"]
|
||||
if self.num_sgd_iter > 1:
|
||||
print(i, fetches)
|
||||
self.grad_timer.push_units_processed(samples.count)
|
||||
@@ -68,4 +71,5 @@ class SyncSamplesOptimizer(PolicyOptimizer):
|
||||
"sample_peak_throughput": round(
|
||||
self.sample_timer.mean_throughput, 3),
|
||||
"opt_samples": round(self.grad_timer.mean_units_processed, 3),
|
||||
"learner": self.learner_stats,
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user