[rllib] Support timesteps_per_batch in PPO's simple optimizer mode (#2558)

* support timesteps_per_batch (ts)

* doc

* Update sync_samples_optimizer.py
This commit is contained in:
Eric Liang
2018-08-06 12:10:59 -07:00
committed by GitHub
parent 9015e742c4
commit 981d9818c1
3 changed files with 17 additions and 9 deletions
+4 -2
View File
@@ -83,8 +83,10 @@ class PPOAgent(Agent):
})
if self.config["simple_optimizer"]:
self.optimizer = SyncSamplesOptimizer(
self.local_evaluator, self.remote_evaluators,
{"num_sgd_iter": self.config["num_sgd_iter"]})
self.local_evaluator, self.remote_evaluators, {
"num_sgd_iter": self.config["num_sgd_iter"],
"timesteps_per_batch": self.config["timesteps_per_batch"]
})
else:
self.optimizer = LocalMultiGPUOptimizer(
self.local_evaluator, self.remote_evaluators, {
@@ -17,12 +17,13 @@ class SyncSamplesOptimizer(PolicyOptimizer):
model weights are then broadcast to all remote evaluators.
"""
def _init(self, num_sgd_iter=1):
def _init(self, num_sgd_iter=1, timesteps_per_batch=1):
self.update_weights_timer = TimerStat()
self.sample_timer = TimerStat()
self.grad_timer = TimerStat()
self.throughput = RunningStat()
self.num_sgd_iter = num_sgd_iter
self.timesteps_per_batch = timesteps_per_batch
def step(self):
with self.update_weights_timer:
@@ -32,12 +33,16 @@ class SyncSamplesOptimizer(PolicyOptimizer):
e.set_weights.remote(weights)
with self.sample_timer:
if self.remote_evaluators:
samples = SampleBatch.concat_samples(
ray.get(
[e.sample.remote() for e in self.remote_evaluators]))
else:
samples = self.local_evaluator.sample()
samples = []
while sum(s.count for s in samples) < self.timesteps_per_batch:
if self.remote_evaluators:
samples.extend(
ray.get([
e.sample.remote() for e in self.remote_evaluators
]))
else:
samples.append(self.local_evaluator.sample())
samples = SampleBatch.concat_samples(samples)
with self.grad_timer:
for i in range(self.num_sgd_iter):