Issue 8407: RNN sequencing error in QMIX (#9139)

This commit is contained in:
Sven Mika
2020-06-26 09:50:31 +02:00
committed by GitHub
parent f940ccd6fb
commit e93a1a82ab
+7 -6
View File
@@ -320,11 +320,11 @@ class QMixTorchPolicy(Policy):
output_list, _, seq_lens = \
chop_into_sequences(
samples[SampleBatch.EPS_ID],
samples[SampleBatch.UNROLL_ID],
samples[SampleBatch.AGENT_INDEX],
input_list,
[], # RNN states not used here
episode_ids=samples[SampleBatch.EPS_ID],
unroll_ids=samples[SampleBatch.UNROLL_ID],
agent_indices=samples[SampleBatch.AGENT_INDEX],
feature_columns=input_list,
state_columns=[], # RNN states not used here
max_seq_len=self.config["model"]["max_seq_len"],
dynamic_max=True)
# These will be padded to shape [B * T, ...]
@@ -473,6 +473,7 @@ class QMixTorchPolicy(Policy):
tensorlib=np)
if isinstance(unpacked[0], dict):
assert "obs" in unpacked[0]
unpacked_obs = [
np.concatenate(tree.flatten(u["obs"]), 1) for u in unpacked
]
@@ -493,7 +494,7 @@ class QMixTorchPolicy(Policy):
dtype=np.float32)
if self.has_env_global_state:
state = unpacked[0][ENV_STATE]
state = np.concatenate(tree.flatten(unpacked[0][ENV_STATE]), 1)
else:
state = None
return obs, action_mask, state