From 6ffa46d0f5e6c9939fecf0fadcd13bb4944ec7b9 Mon Sep 17 00:00:00 2001 From: "Dr. Kashif Rasul" Date: Wed, 29 Jan 2020 14:46:51 +0100 Subject: [PATCH] fix dims of conditioning vector --- pts/model/transformer_tempflow/transformer_tempflow_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pts/model/transformer_tempflow/transformer_tempflow_network.py b/pts/model/transformer_tempflow/transformer_tempflow_network.py index 55f7076..f01cdba 100644 --- a/pts/model/transformer_tempflow/transformer_tempflow_network.py +++ b/pts/model/transformer_tempflow/transformer_tempflow_network.py @@ -444,7 +444,7 @@ class TransformerTempFlowTrainingNetwork(nn.Module): if self.dequantize: target += torch.rand_like(target) - distr_args = self.distr_args(decoder_output=dec_output) + distr_args = self.distr_args(decoder_output=dec_output.permute(1,0,2)) #likelihoods = -self.flow.log_prob(target, distr_args).unsqueeze(-1) loss = -self.flow.log_prob(future_target_cdf, distr_args).unsqueeze(-1)