[rllib] Add debug info back to PPO and fix optimizer compatibility (#2366)

This commit is contained in:
Eric Liang
2018-07-12 19:22:46 +02:00
committed by Richard Liaw
parent 8ea926c266
commit b316afeb43
14 changed files with 122 additions and 97 deletions
@@ -17,11 +17,12 @@ class SyncSamplesOptimizer(PolicyOptimizer):
model weights are then broadcast to all remote evaluators.
"""
def _init(self):
def _init(self, num_sgd_iter=1):
# Optimizer-specific setup: creates the stat trackers and records how many
# compute_apply() passes step() should make over each sampled batch.
#
# num_sgd_iter: loop bound used by step() for repeated
# local_evaluator.compute_apply(samples) calls; the default of 1 keeps the
# single-pass behaviour of the call that loop replaced.
# NOTE(review): presumably invoked by the PolicyOptimizer base class with
# optimizer config passed as keyword arguments -- confirm against the base.
#
# Entered as a context manager around the first phase of step().
self.update_weights_timer = TimerStat()
# NOTE(review): its use is outside the visible hunks; presumably wraps the
# sampling phase of step() -- confirm.
self.sample_timer = TimerStat()
# Wraps the compute_apply() loop in step(); step() also pushes
# samples.count into it via push_units_processed().
self.grad_timer = TimerStat()
# NOTE(review): not referenced in the visible part of step(); presumably
# reported by stats() -- confirm.
self.throughput = RunningStat()
# Read by step() both as the loop bound and to decide whether to print
# the per-iteration fetches (only when greater than 1).
self.num_sgd_iter = num_sgd_iter
def step(self):
with self.update_weights_timer:
@@ -39,11 +40,15 @@ class SyncSamplesOptimizer(PolicyOptimizer):
samples = self.local_evaluator.sample()
with self.grad_timer:
self.local_evaluator.compute_apply(samples)
for i in range(self.num_sgd_iter):
fetches = self.local_evaluator.compute_apply(samples)
if self.num_sgd_iter > 1:
print(i, fetches)
self.grad_timer.push_units_processed(samples.count)
self.num_steps_sampled += samples.count
self.num_steps_trained += samples.count
return fetches
def stats(self):
return dict(PolicyOptimizer.stats(self), **{