[rllib] Fix multiagent_two_trainer test (#3509)

* update

* fix

* dict order

* fix

* fix
This commit is contained in:
Eric Liang
2018-12-11 00:16:39 -08:00
committed by GitHub
parent 1f4a01cff6
commit 52df4dfc6f
@@ -65,7 +65,9 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
logger.info("LocalMultiGPUOptimizer devices {}".format(self.devices))
self.policies = self.local_evaluator.policy_map
self.policies = dict(
self.local_evaluator.foreach_trainable_policy(lambda p, i: (i, p)))
logger.debug("Policies to train: {}".format(self.policies))
for policy_id, policy in self.policies.items():
if not isinstance(policy, TFPolicyGraph):
raise ValueError(
@@ -118,21 +120,26 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
DEFAULT_POLICY_ID: samples
}, samples.count)
for _, batch in samples.policy_batches.items():
for policy_id, policy in self.policies.items():
if policy_id not in samples.policy_batches:
continue
batch = samples.policy_batches[policy_id]
for field in self.standardize_fields:
value = batch[field]
standardized = (value - value.mean()) / max(1e-4, value.std())
batch[field] = standardized
for policy_id, policy in self.policies.items():
# Important: don't shuffle RNN sequence elements
if (policy_id in samples.policy_batches
and not policy._state_inputs):
samples.policy_batches[policy_id].shuffle()
if not policy._state_inputs:
batch.shuffle()
num_loaded_tuples = {}
with self.load_timer:
for policy_id, batch in samples.policy_batches.items():
if policy_id not in self.policies:
continue
policy = self.policies[policy_id]
tuples = policy._get_loss_inputs_dict(batch)
data_keys = [ph for _, ph in policy._loss_inputs]