diff --git a/neural_processes/utils.py b/neural_processes/utils.py index fc7093a..6c43aa9 100644 --- a/neural_processes/utils.py +++ b/neural_processes/utils.py @@ -172,8 +172,10 @@ def kl_loss_var(prior_mu, log_var_prior, post_mu, log_var_post): """ var_ratio_log = log_var_post - log_var_prior + t1 = (post_mu - prior_mu) ** 2 / log_var_prior.exp() kl_div = ( - (var_ratio_log.exp() + (post_mu - prior_mu) ** 2) / log_var_prior.exp() + var_ratio_log.exp() + + t1 - 1.0 - var_ratio_log )