From 71243203a408a2dced2644b06ccd45089a6bc12b Mon Sep 17 00:00:00 2001
From: Eric Liang
Date: Wed, 9 Jan 2019 19:33:07 -0800
Subject: [PATCH] [rllib] Fix KeyError: 'kl' in multiagent ppo training

---
 python/ray/rllib/optimizers/multi_gpu_optimizer.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/python/ray/rllib/optimizers/multi_gpu_optimizer.py b/python/ray/rllib/optimizers/multi_gpu_optimizer.py
index 73c2416b1..f9f662ef6 100644
--- a/python/ray/rllib/optimizers/multi_gpu_optimizer.py
+++ b/python/ray/rllib/optimizers/multi_gpu_optimizer.py
@@ -131,7 +131,11 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
                         "This may be because you have many workers or "
                         "long episodes in 'complete_episodes' batch mode.")
             else:
-                samples = self.local_evaluator.sample()
+                samples = []
+                while sum(s.count for s in samples) < self.train_batch_size:
+                    samples.append(self.local_evaluator.sample())
+                samples = SampleBatch.concat_samples(samples)
+
             # Handle everything as if multiagent
             if isinstance(samples, SampleBatch):
                 samples = MultiAgentBatch({
@@ -174,7 +178,8 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
         with self.grad_timer:
             for policy_id, tuples_per_device in num_loaded_tuples.items():
                 optimizer = self.optimizers[policy_id]
-                num_batches = (
+                num_batches = max(
+                    1,
                     int(tuples_per_device) // int(self.per_device_batch_size))
                 logger.debug("== sgd epochs for {} ==".format(policy_id))
                 for i in range(self.num_sgd_iter):