mirror of
https://github.com/wassname/pytorch-transformer-ts.git
synced 2026-06-27 18:06:14 +08:00
make mask optional for now
This commit is contained in:
+1
-1
@@ -219,7 +219,7 @@ class TemporalFusionDecoder(nn.Module):
|
||||
return mask
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, static: torch.Tensor, mask: torch.Tensor
|
||||
self, x: torch.Tensor, static: torch.Tensor, mask: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
static = static.repeat((1, self.context_length + self.prediction_length, 1))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user