mirror of
https://github.com/wassname/pytorch-transformer-ts.git
synced 2026-06-27 18:06:14 +08:00
fix selection
This commit is contained in:
+1
-2
@@ -232,9 +232,8 @@ class TemporalFusionDecoder(nn.Module):
|
||||
# key_padding_mask = torch.cat((mask, mask_pad), dim=1).bool()
|
||||
|
||||
query_key_value = x
|
||||
|
||||
attn_output, _ = self.attention(
|
||||
query=query_key_value[-self.prediction_length :, ...],
|
||||
query=query_key_value[:, -self.prediction_length :, ...],
|
||||
key=query_key_value,
|
||||
value=query_key_value,
|
||||
# key_padding_mask=key_padding_mask,
|
||||
|
||||
Reference in New Issue
Block a user