From 9252601f49cfcc8a5b50561829d1e516500fc4e7 Mon Sep 17 00:00:00 2001 From: wassname Date: Tue, 28 Apr 2020 08:35:32 +0800 Subject: [PATCH] tests --- neural_processes/utils.py | 1 + test/test_utils.py | 82 ++++++++++++++++++++++++++------------- 2 files changed, 55 insertions(+), 28 deletions(-) diff --git a/neural_processes/utils.py b/neural_processes/utils.py index 825ae0d..fc7093a 100644 --- a/neural_processes/utils.py +++ b/neural_processes/utils.py @@ -178,5 +178,6 @@ def kl_loss_var(prior_mu, log_var_prior, post_mu, log_var_post): - var_ratio_log ) kl_div = 0.5 * kl_div + logger.warning('seems to be an error in kl_loss_var, dont use it') return kl_div diff --git a/test/test_utils.py b/test/test_utils.py index 4ac3aed..6c0851c 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -3,10 +3,11 @@ import pickle, json import torch import tempfile + def test_obectdict(tmpdir): o = neural_processes.utils.ObjectDict(z=1, b=4, test="g", w=0) - pickle.dump(o, open(tmpdir+'/test.pkl', 'wb')) - o2 = pickle.load(open(tmpdir+'/test.pkl', 'rb')) + pickle.dump(o, open(tmpdir + "/test.pkl", "wb")) + o2 = pickle.load(open(tmpdir + "/test.pkl", "rb")) o3 = json.loads(json.dumps(o)) print(o, o2, o3) @@ -14,34 +15,57 @@ def test_obectdict(tmpdir): def test_agg_logs(): outputs = [ - {'val_loss': torch.tensor(0.7206), - 'log': {'val_loss': torch.tensor(0.7206), 'val_loss_p': torch.tensor(0.7206), 'val_loss_kl': torch.tensor(2.3812e-06), 'val_loss_mse': torch.tensor(0.1838)}}, - {'val_loss': torch.tensor(0.7047), - 'log': {'val_loss': torch.tensor(0.7047), 'val_loss_p': torch.tensor(0.7047), 'val_loss_kl': torch.tensor(2.8391e-06), 'val_loss_mse': torch.tensor(0.1696)}}, - ] + { + "val_loss": torch.tensor(0.7206), + "log": { + "val_loss": torch.tensor(0.7206), + "val_loss_p": torch.tensor(0.7206), + "val_loss_kl": torch.tensor(2.3812e-06), + "val_loss_mse": torch.tensor(0.1838), + }, + }, + { + "val_loss": torch.tensor(0.7047), + "log": { + "val_loss": torch.tensor(0.7047), + "val_loss_p": torch.tensor(0.7047), + "val_loss_kl": torch.tensor(2.8391e-06), + "val_loss_mse": torch.tensor(0.1696), + }, + }, + ] r = neural_processes.utils.agg_logs(outputs) assert isinstance(r, dict) - assert 'agg_val_loss' in r.keys() - assert 'agg_val_loss_kl' in r['log'].keys() - assert isinstance(r['agg_val_loss'], float) + assert "agg_val_loss" in r.keys() + assert "agg_val_loss_kl" in r["log"].keys() + assert isinstance(r["agg_val_loss"], float) - outputs = {'val_loss': torch.tensor(0.7206), - 'log': {'val_loss': torch.tensor(0.7206), 'val_loss_p': torch.tensor(0.7206), 'val_loss_kl': torch.tensor(2.3812e-06), 'val_loss_mse': torch.tensor(0.1838)}} + outputs = { + "val_loss": torch.tensor(0.7206), + "log": { + "val_loss": torch.tensor(0.7206), + "val_loss_p": torch.tensor(0.7206), + "val_loss_kl": torch.tensor(2.3812e-06), + "val_loss_mse": torch.tensor(0.1838), + }, + } r = neural_processes.utils.agg_logs(outputs) assert isinstance(r, dict) - assert 'agg_val_loss' in r.keys() - assert 'agg_val_loss_kl' in r['log'].keys() - assert isinstance(r['agg_val_loss'], float) + assert "agg_val_loss" in r.keys() + assert "agg_val_loss_kl" in r["log"].keys() + assert isinstance(r["agg_val_loss"], float) def test_round_values(): - r = neural_processes.utils.round_values({'a': 0.00004, 'd': {'b': 124455.45, 'c': 0.004}, 'l': 500}) + r = neural_processes.utils.round_values( + {"a": 0.00004, "d": {"b": 124455.45, "c": 0.004}, "l": 500} + ) def test_hparams_power(): - r = neural_processes.utils.hparams_power({'test_power': 2, 'test2': 2}) - assert r['test'] == 2 ** 2 - assert r['test2'] == 2 + r = neural_processes.utils.hparams_power({"test_power": 2, "test2": 2}) + assert r["test"] == 2 ** 2 + assert r["test2"] == 2 def test_log_prob_sigma(): @@ -49,18 +73,20 @@ def test_log_prob_sigma(): log_scale = torch.ones(4, 5) value = torch.zeros(4, 5) y_dist = torch.distributions.Normal(mean, log_scale.exp()) - r1 = y_dist.log_prob(values) - r2 = neural_processes.utils.log_prob_sigma(value, loc, log_scale) - assert (r1==r2).all() + r1 = y_dist.log_prob(value) + r2 = neural_processes.utils.log_prob_sigma(value, mean, log_scale) + assert (r1 == r2).all() + def test_kl_loss_var(): prior_mu = torch.zeros(4, 5) post_mu = torch.zeros(4, 5) + 1 log_var_prior = torch.ones(4, 5) log_var_post = torch.ones(4, 5) + 1 - dist_prior = torch.distributions.Normal(prior_mu, log_var_prior.exp()) - dist_post = torch.distributions.Normal(post_mu, log_var_post.exp()) - r1 = torch.distributions.kl_divergence( - dist_post, dist_prior).mean(-1) - r2 = neural_processes.utils.kl_loss_var(prior_mu, log_var_prior, post_mu, log_var_post) - assert (r1==r2).all() + dist_prior = torch.distributions.Normal(prior_mu, torch.exp(0.5 * log_var_prior)) + dist_post = torch.distributions.Normal(post_mu, torch.exp(0.5 * log_var_post)) + r1 = torch.distributions.kl_divergence(dist_post, dist_prior) + r2 = neural_processes.utils.kl_loss_var( + prior_mu, log_var_prior, post_mu, log_var_post + ) + assert (r1 == r2).all()