mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 14:12:00 +08:00
[rllib] Fix multiagent_two_trainer test (#3509)
* update * fix * dict ordre * fix * fix
This commit is contained in:
@@ -65,7 +65,9 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
|
||||
|
||||
logger.info("LocalMultiGPUOptimizer devices {}".format(self.devices))
|
||||
|
||||
self.policies = self.local_evaluator.policy_map
|
||||
self.policies = dict(
|
||||
self.local_evaluator.foreach_trainable_policy(lambda p, i: (i, p)))
|
||||
logger.debug("Policies to train: {}".format(self.policies))
|
||||
for policy_id, policy in self.policies.items():
|
||||
if not isinstance(policy, TFPolicyGraph):
|
||||
raise ValueError(
|
||||
@@ -118,21 +120,26 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
|
||||
DEFAULT_POLICY_ID: samples
|
||||
}, samples.count)
|
||||
|
||||
for _, batch in samples.policy_batches.items():
|
||||
for policy_id, policy in self.policies.items():
|
||||
if policy_id not in samples.policy_batches:
|
||||
continue
|
||||
|
||||
batch = samples.policy_batches[policy_id]
|
||||
for field in self.standardize_fields:
|
||||
value = batch[field]
|
||||
standardized = (value - value.mean()) / max(1e-4, value.std())
|
||||
batch[field] = standardized
|
||||
|
||||
for policy_id, policy in self.policies.items():
|
||||
# Important: don't shuffle RNN sequence elements
|
||||
if (policy_id in samples.policy_batches
|
||||
and not policy._state_inputs):
|
||||
samples.policy_batches[policy_id].shuffle()
|
||||
if not policy._state_inputs:
|
||||
batch.shuffle()
|
||||
|
||||
num_loaded_tuples = {}
|
||||
with self.load_timer:
|
||||
for policy_id, batch in samples.policy_batches.items():
|
||||
if policy_id not in self.policies:
|
||||
continue
|
||||
|
||||
policy = self.policies[policy_id]
|
||||
tuples = policy._get_loss_inputs_dict(batch)
|
||||
data_keys = [ph for _, ph in policy._loss_inputs]
|
||||
|
||||
Reference in New Issue
Block a user