mirror of
https://github.com/wassname/pytorch-ts.git
synced 2026-06-27 19:32:05 +08:00
use helper function to generate mask
This commit is contained in:
@@ -83,7 +83,7 @@ class TransformerNetwork(nn.Module):
|
||||
|
||||
# mask
|
||||
self.register_buffer(
|
||||
"tgt_mask", torch.triu(torch.ones((prediction_length, prediction_length)), diagonal=1)
|
||||
"tgt_mask", self.transformer.generate_square_subsequent_mask(prediction_length)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
Reference in New Issue
Block a user