mirror of
https://github.com/wassname/ray.git
synced 2026-07-04 00:50:21 +08:00
Issue 8407: RNN sequencing error in QMIX (#9139)
This commit is contained in:
@@ -320,11 +320,11 @@ class QMixTorchPolicy(Policy):
|
||||
|
||||
output_list, _, seq_lens = \
|
||||
chop_into_sequences(
|
||||
samples[SampleBatch.EPS_ID],
|
||||
samples[SampleBatch.UNROLL_ID],
|
||||
samples[SampleBatch.AGENT_INDEX],
|
||||
input_list,
|
||||
[], # RNN states not used here
|
||||
episode_ids=samples[SampleBatch.EPS_ID],
|
||||
unroll_ids=samples[SampleBatch.UNROLL_ID],
|
||||
agent_indices=samples[SampleBatch.AGENT_INDEX],
|
||||
feature_columns=input_list,
|
||||
state_columns=[], # RNN states not used here
|
||||
max_seq_len=self.config["model"]["max_seq_len"],
|
||||
dynamic_max=True)
|
||||
# These will be padded to shape [B * T, ...]
|
||||
@@ -473,6 +473,7 @@ class QMixTorchPolicy(Policy):
|
||||
tensorlib=np)
|
||||
|
||||
if isinstance(unpacked[0], dict):
|
||||
assert "obs" in unpacked[0]
|
||||
unpacked_obs = [
|
||||
np.concatenate(tree.flatten(u["obs"]), 1) for u in unpacked
|
||||
]
|
||||
@@ -493,7 +494,7 @@ class QMixTorchPolicy(Policy):
|
||||
dtype=np.float32)
|
||||
|
||||
if self.has_env_global_state:
|
||||
state = unpacked[0][ENV_STATE]
|
||||
state = np.concatenate(tree.flatten(unpacked[0][ENV_STATE]), 1)
|
||||
else:
|
||||
state = None
|
||||
return obs, action_mask, state
|
||||
|
||||
Reference in New Issue
Block a user