Mirror of https://github.com/wassname/ray.git (synced 2026-06-27 23:08:32 +08:00)
[rllib] Fix KeyError: 'kl' in multiagent ppo training
This commit is contained in:
@@ -131,7 +131,11 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
                         "This may be because you have many workers or "
                         "long episodes in 'complete_episodes' batch mode.")
             else:
-                samples = self.local_evaluator.sample()
+                samples = []
+                while sum(s.count for s in samples) < self.train_batch_size:
+                    samples.append(self.local_evaluator.sample())
+                samples = SampleBatch.concat_samples(samples)
+
             # Handle everything as if multiagent
             if isinstance(samples, SampleBatch):
                 samples = MultiAgentBatch({
@@ -174,7 +178,8 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
         with self.grad_timer:
             for policy_id, tuples_per_device in num_loaded_tuples.items():
                 optimizer = self.optimizers[policy_id]
-                num_batches = (
+                num_batches = max(
+                    1,
                     int(tuples_per_device) // int(self.per_device_batch_size))
                 logger.debug("== sgd epochs for {} ==".format(policy_id))
                 for i in range(self.num_sgd_iter):
Reference in New Issue
Block a user