remove sqrt from l2_dist

This commit is contained in:
Mike Clark
2018-05-08 20:12:34 +08:00
committed by GitHub
parent 3ee35e440e
commit dacb51da28
+2 -2
View File
@@ -26,7 +26,7 @@ def loss_function(recon_x, x, mu, logvar):
recon_x = recon_x.view(n, -1)
x = x.view(n, -1)
# L2 distance
l2_dist = torch.mean(torch.sqrt(torch.sum(torch.pow(recon_x - x, 2), 1)))
l2_dist = torch.mean(torch.sum(torch.pow(recon_x - x, 2), 1))
# see Appendix B from VAE paper:
# Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
# https://arxiv.org/abs/1312.6114
@@ -107,4 +107,4 @@ class VAE(nn.Module):
#
# cv2.imshow('original', img)
# cv2.imshow('reconstructed', x)
# cv2.waitKey()
# cv2.waitKey()