Mirror of https://github.com/wassname/pytorch-transformer-ts.git (last synced 2026-06-27 18:06:14 +08:00).
Do not apply the causal attention mask when predicting autoregressively
This commit is contained in:
7 additions (+7) and 3 deletions (-3)
@@ -223,7 +223,11 @@ class TemporalFusionDecoder(nn.Module):
|
||||
return mask
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, static: torch.Tensor, mask: Optional[torch.Tensor] = None
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
static: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
causal: bool = True,
|
||||
) -> torch.Tensor:
|
||||
expanded_static = static.expand_as(x)
|
||||
# static.repeat((1, self.context_length + self.prediction_length, 1))
|
||||
@@ -242,7 +246,7 @@ class TemporalFusionDecoder(nn.Module):
|
||||
key=query_key_value,
|
||||
value=query_key_value,
|
||||
# key_padding_mask=key_padding_mask,
|
||||
attn_mask=self.attn_mask,
|
||||
attn_mask=self.attn_mask if causal else None,
|
||||
)
|
||||
att = self.att_net(attn_output)
|
||||
|
||||
@@ -571,7 +575,7 @@ class TFTModel(nn.Module):
|
||||
states = [c_h.unsqueeze(0), c_c.unsqueeze(0)]
|
||||
|
||||
enc_out = self.temporal_encoder(past_selection, tgt_input=None, states=states)
|
||||
dec_output = self.temporal_decoder(enc_out, static_enrichment)
|
||||
dec_output = self.temporal_decoder(enc_out, static_enrichment, causal=False)
|
||||
params = self.param_proj(dec_output)
|
||||
|
||||
distr = self.output_distribution(params, scale=scale, trailing_n=1)
|
||||
|
||||
Reference in New Issue
Block a user