no causal mask when autoregressively predicting

This commit is contained in:
Kashif Rasul
2022-03-30 21:52:07 +02:00
parent fbccd3076d
commit aa2dcc2887
+7 -3
View File
@@ -223,7 +223,11 @@ class TemporalFusionDecoder(nn.Module):
return mask
def forward(
self, x: torch.Tensor, static: torch.Tensor, mask: Optional[torch.Tensor] = None
self,
x: torch.Tensor,
static: torch.Tensor,
mask: Optional[torch.Tensor] = None,
causal: bool = True,
) -> torch.Tensor:
expanded_static = static.expand_as(x)
# static.repeat((1, self.context_length + self.prediction_length, 1))
@@ -242,7 +246,7 @@ class TemporalFusionDecoder(nn.Module):
key=query_key_value,
value=query_key_value,
# key_padding_mask=key_padding_mask,
attn_mask=self.attn_mask,
attn_mask=self.attn_mask if causal else None,
)
att = self.att_net(attn_output)
@@ -571,7 +575,7 @@ class TFTModel(nn.Module):
states = [c_h.unsqueeze(0), c_c.unsqueeze(0)]
enc_out = self.temporal_encoder(past_selection, tgt_input=None, states=states)
dec_output = self.temporal_decoder(enc_out, static_enrichment)
dec_output = self.temporal_decoder(enc_out, static_enrichment, causal=False)
params = self.param_proj(dec_output)
distr = self.output_distribution(params, scale=scale, trailing_n=1)