From eea4a7e35aa2118a0480eb8006f65844cccc3396 Mon Sep 17 00:00:00 2001 From: Mike Clark Date: Thu, 17 Mar 2022 21:51:55 +0000 Subject: [PATCH] fix #5 --- neural_processes/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/neural_processes/utils.py b/neural_processes/utils.py index fc7093a..6c43aa9 100644 --- a/neural_processes/utils.py +++ b/neural_processes/utils.py @@ -172,8 +172,10 @@ def kl_loss_var(prior_mu, log_var_prior, post_mu, log_var_post): """ var_ratio_log = log_var_post - log_var_prior + t1 = (post_mu - prior_mu) ** 2 / log_var_prior.exp() kl_div = ( - (var_ratio_log.exp() + (post_mu - prior_mu) ** 2) / log_var_prior.exp() + var_ratio_log.exp() + + t1 - 1.0 - var_ratio_log )