fix selection

This commit is contained in:
Kashif Rasul
2022-03-30 17:45:23 +02:00
parent 248f81e6c2
commit a9933dc4ff
+1 -2
View File
@@ -232,9 +232,8 @@ class TemporalFusionDecoder(nn.Module):
# key_padding_mask = torch.cat((mask, mask_pad), dim=1).bool()
query_key_value = x
attn_output, _ = self.attention(
query=query_key_value[-self.prediction_length :, ...],
query=query_key_value[:, -self.prediction_length :, ...],
key=query_key_value,
value=query_key_value,
# key_padding_mask=key_padding_mask,