diff --git a/pts/model/transformer_tempflow/transformer_tempflow_network.py b/pts/model/transformer_tempflow/transformer_tempflow_network.py index 55f7076..f01cdba 100644 --- a/pts/model/transformer_tempflow/transformer_tempflow_network.py +++ b/pts/model/transformer_tempflow/transformer_tempflow_network.py @@ -444,7 +444,7 @@ class TransformerTempFlowTrainingNetwork(nn.Module): if self.dequantize: target += torch.rand_like(target) - distr_args = self.distr_args(decoder_output=dec_output) + distr_args = self.distr_args(decoder_output=dec_output.permute(1,0,2)) #likelihoods = -self.flow.log_prob(target, distr_args).unsqueeze(-1) loss = -self.flow.log_prob(future_target_cdf, distr_args).unsqueeze(-1)