diff --git a/python/ray/rllib/optimizers/multi_gpu_optimizer.py b/python/ray/rllib/optimizers/multi_gpu_optimizer.py index 73c2416b1..f9f662ef6 100644 --- a/python/ray/rllib/optimizers/multi_gpu_optimizer.py +++ b/python/ray/rllib/optimizers/multi_gpu_optimizer.py @@ -131,7 +131,11 @@ class LocalMultiGPUOptimizer(PolicyOptimizer): "This may be because you have many workers or " "long episodes in 'complete_episodes' batch mode.") else: - samples = self.local_evaluator.sample() + samples = [] + while sum(s.count for s in samples) < self.train_batch_size: + samples.append(self.local_evaluator.sample()) + samples = SampleBatch.concat_samples(samples) + # Handle everything as if multiagent if isinstance(samples, SampleBatch): samples = MultiAgentBatch({ @@ -174,7 +178,8 @@ class LocalMultiGPUOptimizer(PolicyOptimizer): with self.grad_timer: for policy_id, tuples_per_device in num_loaded_tuples.items(): optimizer = self.optimizers[policy_id] - num_batches = ( + num_batches = max( + 1, int(tuples_per_device) // int(self.per_device_batch_size)) logger.debug("== sgd epochs for {} ==".format(policy_id)) for i in range(self.num_sgd_iter):