From e93a1a82abcec58078865d90ea683a09902bc2ec Mon Sep 17 00:00:00 2001 From: Sven Mika Date: Fri, 26 Jun 2020 09:50:31 +0200 Subject: [PATCH] Issue 8407: RNN sequencing error in QMIX (#9139) --- rllib/agents/qmix/qmix_policy.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/rllib/agents/qmix/qmix_policy.py b/rllib/agents/qmix/qmix_policy.py index 2d821da9c..4faa2bb30 100644 --- a/rllib/agents/qmix/qmix_policy.py +++ b/rllib/agents/qmix/qmix_policy.py @@ -320,11 +320,11 @@ class QMixTorchPolicy(Policy): output_list, _, seq_lens = \ chop_into_sequences( - samples[SampleBatch.EPS_ID], - samples[SampleBatch.UNROLL_ID], - samples[SampleBatch.AGENT_INDEX], - input_list, - [], # RNN states not used here + episode_ids=samples[SampleBatch.EPS_ID], + unroll_ids=samples[SampleBatch.UNROLL_ID], + agent_indices=samples[SampleBatch.AGENT_INDEX], + feature_columns=input_list, + state_columns=[], # RNN states not used here max_seq_len=self.config["model"]["max_seq_len"], dynamic_max=True) # These will be padded to shape [B * T, ...] @@ -473,6 +473,7 @@ class QMixTorchPolicy(Policy): tensorlib=np) if isinstance(unpacked[0], dict): + assert "obs" in unpacked[0] unpacked_obs = [ np.concatenate(tree.flatten(u["obs"]), 1) for u in unpacked ] @@ -493,7 +494,7 @@ class QMixTorchPolicy(Policy): dtype=np.float32) if self.has_env_global_state: - state = unpacked[0][ENV_STATE] + state = np.concatenate(tree.flatten(unpacked[0][ENV_STATE]), 1) else: state = None return obs, action_mask, state