mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 14:05:08 +08:00
[rllib] Add debug info back to PPO and fix optimizer compatibility (#2366)
This commit is contained in:
@@ -17,11 +17,12 @@ class SyncSamplesOptimizer(PolicyOptimizer):
|
||||
model weights are then broadcast to all remote evaluators.
|
||||
"""
|
||||
|
||||
def _init(self):
|
||||
def _init(self, num_sgd_iter=1):
    """Set up the optimizer's timing/throughput stats and SGD settings.

    Args:
        num_sgd_iter: Number of times ``step()`` calls
            ``local_evaluator.compute_apply`` on each sampled batch.
            Must be at least 1.

    Raises:
        ValueError: If ``num_sgd_iter`` is less than 1.
    """
    # step() returns the ``fetches`` of its last SGD iteration; in the
    # visible code that name is only bound inside the loop, so zero
    # iterations would fail there with an unbound-local error. Reject
    # the value up front with a clear message instead.
    if num_sgd_iter < 1:
        raise ValueError(
            "num_sgd_iter must be >= 1, got {!r}".format(num_sgd_iter))

    # Timer entered by step() around its weight-update section.
    self.update_weights_timer = TimerStat()
    self.sample_timer = TimerStat()
    # Timer entered by step() around the compute_apply loop.
    self.grad_timer = TimerStat()
    self.throughput = RunningStat()
    self.num_sgd_iter = num_sgd_iter
|
||||
|
||||
def step(self):
|
||||
with self.update_weights_timer:
|
||||
@@ -39,11 +40,15 @@ class SyncSamplesOptimizer(PolicyOptimizer):
|
||||
samples = self.local_evaluator.sample()
|
||||
|
||||
with self.grad_timer:
|
||||
self.local_evaluator.compute_apply(samples)
|
||||
for i in range(self.num_sgd_iter):
|
||||
fetches = self.local_evaluator.compute_apply(samples)
|
||||
if self.num_sgd_iter > 1:
|
||||
print(i, fetches)
|
||||
self.grad_timer.push_units_processed(samples.count)
|
||||
|
||||
self.num_steps_sampled += samples.count
|
||||
self.num_steps_trained += samples.count
|
||||
return fetches
|
||||
|
||||
def stats(self):
|
||||
return dict(PolicyOptimizer.stats(self), **{
|
||||
|
||||
Reference in New Issue
Block a user