mirror of
https://github.com/wassname/ray.git
synced 2026-07-02 02:42:52 +08:00
[rllib] Support the timesteps_per_batch in simple optimizer PPO mode (#2558)
* support ts
* doc
* Update sync_samples_optimizer.py
This commit is contained in:
@@ -83,8 +83,10 @@ class PPOAgent(Agent):
             })
         if self.config["simple_optimizer"]:
             self.optimizer = SyncSamplesOptimizer(
-                self.local_evaluator, self.remote_evaluators,
-                {"num_sgd_iter": self.config["num_sgd_iter"]})
+                self.local_evaluator, self.remote_evaluators, {
+                    "num_sgd_iter": self.config["num_sgd_iter"],
+                    "timesteps_per_batch": self.config["timesteps_per_batch"]
+                })
         else:
             self.optimizer = LocalMultiGPUOptimizer(
                 self.local_evaluator, self.remote_evaluators, {

@@ -17,12 +17,13 @@ class SyncSamplesOptimizer(PolicyOptimizer):
     model weights are then broadcast to all remote evaluators.
     """
 
-    def _init(self, num_sgd_iter=1):
+    def _init(self, num_sgd_iter=1, timesteps_per_batch=1):
         self.update_weights_timer = TimerStat()
         self.sample_timer = TimerStat()
         self.grad_timer = TimerStat()
         self.throughput = RunningStat()
         self.num_sgd_iter = num_sgd_iter
+        self.timesteps_per_batch = timesteps_per_batch
 
     def step(self):
         with self.update_weights_timer:
@@ -32,12 +33,16 @@ class SyncSamplesOptimizer(PolicyOptimizer):
                     e.set_weights.remote(weights)
 
         with self.sample_timer:
-            if self.remote_evaluators:
-                samples = SampleBatch.concat_samples(
-                    ray.get(
-                        [e.sample.remote() for e in self.remote_evaluators]))
-            else:
-                samples = self.local_evaluator.sample()
+            samples = []
+            while sum(s.count for s in samples) < self.timesteps_per_batch:
+                if self.remote_evaluators:
+                    samples.extend(
+                        ray.get([
+                            e.sample.remote() for e in self.remote_evaluators
+                        ]))
+                else:
+                    samples.append(self.local_evaluator.sample())
+            samples = SampleBatch.concat_samples(samples)
 
         with self.grad_timer:
             for i in range(self.num_sgd_iter):

Reference in New Issue
Block a user