From d0a19abeb904e9d192ced9f1b7c9e91d293957da Mon Sep 17 00:00:00 2001 From: "Dr. Kashif Rasul" Date: Tue, 28 Jan 2020 15:05:47 +0100 Subject: [PATCH] use helper function to generate mask --- pts/model/transformer/transformer_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pts/model/transformer/transformer_network.py b/pts/model/transformer/transformer_network.py index c847d4d..fb078e6 100644 --- a/pts/model/transformer/transformer_network.py +++ b/pts/model/transformer/transformer_network.py @@ -83,7 +83,7 @@ class TransformerNetwork(nn.Module): # mask self.register_buffer( - "tgt_mask", torch.triu(torch.ones((prediction_length, prediction_length)), diagonal=1) + "tgt_mask", self.transformer.generate_square_subsequent_mask(prediction_length) ) @staticmethod