fix dims of conditioning vector

This commit is contained in:
Dr. Kashif Rasul
2020-01-29 14:46:51 +01:00
parent d05fc7a242
commit 6ffa46d0f5
@@ -444,7 +444,7 @@ class TransformerTempFlowTrainingNetwork(nn.Module):
if self.dequantize:
target += torch.rand_like(target)
distr_args = self.distr_args(decoder_output=dec_output)
distr_args = self.distr_args(decoder_output=dec_output.permute(1,0,2))
#likelihoods = -self.flow.log_prob(target, distr_args).unsqueeze(-1)
loss = -self.flow.log_prob(future_target_cdf, distr_args).unsqueeze(-1)