diff --git a/python/ray/rllib/optimizers/multi_gpu_optimizer.py b/python/ray/rllib/optimizers/multi_gpu_optimizer.py index 5ca29f68c..e6f2a427c 100644 --- a/python/ray/rllib/optimizers/multi_gpu_optimizer.py +++ b/python/ray/rllib/optimizers/multi_gpu_optimizer.py @@ -65,7 +65,9 @@ class LocalMultiGPUOptimizer(PolicyOptimizer): logger.info("LocalMultiGPUOptimizer devices {}".format(self.devices)) - self.policies = self.local_evaluator.policy_map + self.policies = dict( + self.local_evaluator.foreach_trainable_policy(lambda p, i: (i, p))) + logger.debug("Policies to train: {}".format(self.policies)) for policy_id, policy in self.policies.items(): if not isinstance(policy, TFPolicyGraph): raise ValueError( @@ -118,21 +120,26 @@ class LocalMultiGPUOptimizer(PolicyOptimizer): DEFAULT_POLICY_ID: samples }, samples.count) - for _, batch in samples.policy_batches.items(): + for policy_id, policy in self.policies.items(): + if policy_id not in samples.policy_batches: + continue + + batch = samples.policy_batches[policy_id] for field in self.standardize_fields: value = batch[field] standardized = (value - value.mean()) / max(1e-4, value.std()) batch[field] = standardized - for policy_id, policy in self.policies.items(): # Important: don't shuffle RNN sequence elements - if (policy_id in samples.policy_batches - and not policy._state_inputs): - samples.policy_batches[policy_id].shuffle() + if not policy._state_inputs: + batch.shuffle() num_loaded_tuples = {} with self.load_timer: for policy_id, batch in samples.policy_batches.items(): + if policy_id not in self.policies: + continue + policy = self.policies[policy_id] tuples = policy._get_loss_inputs_dict(batch) data_keys = [ph for _, ph in policy._loss_inputs]