[rllib] Fix KeyError: 'kl' in multiagent PPO training

This commit is contained in:
Eric Liang
2019-01-09 19:33:07 -08:00
committed by GitHub
parent 6fc3fc4120
commit 71243203a4
@@ -131,7 +131,11 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
"This may be because you have many workers or "
"long episodes in 'complete_episodes' batch mode.")
else:
samples = self.local_evaluator.sample()
samples = []
while sum(s.count for s in samples) < self.train_batch_size:
samples.append(self.local_evaluator.sample())
samples = SampleBatch.concat_samples(samples)
# Handle everything as if multiagent
if isinstance(samples, SampleBatch):
samples = MultiAgentBatch({
@@ -174,7 +178,8 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
with self.grad_timer:
for policy_id, tuples_per_device in num_loaded_tuples.items():
optimizer = self.optimizers[policy_id]
num_batches = (
num_batches = max(
1,
int(tuples_per_device) // int(self.per_device_batch_size))
logger.debug("== sgd epochs for {} ==".format(policy_id))
for i in range(self.num_sgd_iter):