use helper function to generate mask

This commit is contained in:
Dr. Kashif Rasul
2020-01-28 15:05:47 +01:00
parent fcb69e83d2
commit d0a19abeb9
+1 -1
View File
@@ -83,7 +83,7 @@ class TransformerNetwork(nn.Module):
# mask
self.register_buffer(
"tgt_mask", torch.triu(torch.ones((prediction_length, prediction_length)), diagonal=1)
"tgt_mask", self.transformer.generate_square_subsequent_mask(prediction_length)
)
@staticmethod