diff --git a/pts/model/transformer/transformer_network.py b/pts/model/transformer/transformer_network.py index c847d4d..fb078e6 100644 --- a/pts/model/transformer/transformer_network.py +++ b/pts/model/transformer/transformer_network.py @@ -83,7 +83,7 @@ class TransformerNetwork(nn.Module): # mask self.register_buffer( - "tgt_mask", torch.triu(torch.ones((prediction_length, prediction_length)), diagonal=1) + "tgt_mask", self.transformer.generate_square_subsequent_mask(prediction_length) ) @staticmethod