mirror of
https://github.com/wassname/world-models-pytorch.git
synced 2026-06-27 17:33:07 +08:00
remove sqrt from l2_dist
This commit is contained in:
@@ -26,7 +26,7 @@ def loss_function(recon_x, x, mu, logvar):
|
||||
recon_x = recon_x.view(n, -1)
|
||||
x = x.view(n, -1)
|
||||
# L2 distance
|
||||
l2_dist = torch.mean(torch.sqrt(torch.sum(torch.pow(recon_x - x, 2), 1)))
|
||||
l2_dist = torch.mean(torch.sum(torch.pow(recon_x - x, 2), 1))
|
||||
# see Appendix B from VAE paper:
|
||||
# Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
|
||||
# https://arxiv.org/abs/1312.6114
|
||||
@@ -107,4 +107,4 @@ class VAE(nn.Module):
|
||||
#
|
||||
# cv2.imshow('original', img)
|
||||
# cv2.imshow('reconstructed', x)
|
||||
# cv2.waitKey()
|
||||
# cv2.waitKey()
|
||||
|
||||
Reference in New Issue
Block a user