make mask optional for now

This commit is contained in:
Kashif Rasul
2022-03-30 17:04:29 +02:00
parent 8c8b1aa4eb
commit 248f81e6c2
+1 -1
View File
@@ -219,7 +219,7 @@ class TemporalFusionDecoder(nn.Module):
return mask
def forward(
self, x: torch.Tensor, static: torch.Tensor, mask: torch.Tensor
self, x: torch.Tensor, static: torch.Tensor, mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
static = static.repeat((1, self.context_length + self.prediction_length, 1))