mirror of
https://github.com/wassname/pytorch-ts.git
synced 2026-06-27 19:32:05 +08:00
fix dims of conditioning vector
This commit is contained in:
@@ -444,7 +444,7 @@ class TransformerTempFlowTrainingNetwork(nn.Module):
|
||||
if self.dequantize:
|
||||
target += torch.rand_like(target)
|
||||
|
||||
distr_args = self.distr_args(decoder_output=dec_output)
|
||||
distr_args = self.distr_args(decoder_output=dec_output.permute(1,0,2))
|
||||
#likelihoods = -self.flow.log_prob(target, distr_args).unsqueeze(-1)
|
||||
loss = -self.flow.log_prob(future_target_cdf, distr_args).unsqueeze(-1)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user