This commit is contained in:
wassname
2020-04-28 08:35:32 +08:00
parent 18d0f12dff
commit 9252601f49
2 changed files with 55 additions and 28 deletions
+1
View File
@@ -178,5 +178,6 @@ def kl_loss_var(prior_mu, log_var_prior, post_mu, log_var_post):
- var_ratio_log
)
kl_div = 0.5 * kl_div
logger.warning('seems to be an error in kl_loss_var, dont use it')
return kl_div
+54 -28
View File
@@ -3,10 +3,11 @@ import pickle, json
import torch
import tempfile
def test_obectdict(tmpdir):
o = neural_processes.utils.ObjectDict(z=1, b=4, test="g", w=0)
pickle.dump(o, open(tmpdir+'/test.pkl', 'wb'))
o2 = pickle.load(open(tmpdir+'/test.pkl', 'rb'))
pickle.dump(o, open(tmpdir + "/test.pkl", "wb"))
o2 = pickle.load(open(tmpdir + "/test.pkl", "rb"))
o3 = json.loads(json.dumps(o))
print(o, o2, o3)
@@ -14,34 +15,57 @@ def test_obectdict(tmpdir):
def test_agg_logs():
outputs = [
{'val_loss': torch.tensor(0.7206),
'log': {'val_loss': torch.tensor(0.7206), 'val_loss_p': torch.tensor(0.7206), 'val_loss_kl': torch.tensor(2.3812e-06), 'val_loss_mse': torch.tensor(0.1838)}},
{'val_loss': torch.tensor(0.7047),
'log': {'val_loss': torch.tensor(0.7047), 'val_loss_p': torch.tensor(0.7047), 'val_loss_kl': torch.tensor(2.8391e-06), 'val_loss_mse': torch.tensor(0.1696)}},
]
{
"val_loss": torch.tensor(0.7206),
"log": {
"val_loss": torch.tensor(0.7206),
"val_loss_p": torch.tensor(0.7206),
"val_loss_kl": torch.tensor(2.3812e-06),
"val_loss_mse": torch.tensor(0.1838),
},
},
{
"val_loss": torch.tensor(0.7047),
"log": {
"val_loss": torch.tensor(0.7047),
"val_loss_p": torch.tensor(0.7047),
"val_loss_kl": torch.tensor(2.8391e-06),
"val_loss_mse": torch.tensor(0.1696),
},
},
]
r = neural_processes.utils.agg_logs(outputs)
assert isinstance(r, dict)
assert 'agg_val_loss' in r.keys()
assert 'agg_val_loss_kl' in r['log'].keys()
assert isinstance(r['agg_val_loss'], float)
assert "agg_val_loss" in r.keys()
assert "agg_val_loss_kl" in r["log"].keys()
assert isinstance(r["agg_val_loss"], float)
outputs = {'val_loss': torch.tensor(0.7206),
'log': {'val_loss': torch.tensor(0.7206), 'val_loss_p': torch.tensor(0.7206), 'val_loss_kl': torch.tensor(2.3812e-06), 'val_loss_mse': torch.tensor(0.1838)}}
outputs = {
"val_loss": torch.tensor(0.7206),
"log": {
"val_loss": torch.tensor(0.7206),
"val_loss_p": torch.tensor(0.7206),
"val_loss_kl": torch.tensor(2.3812e-06),
"val_loss_mse": torch.tensor(0.1838),
},
}
r = neural_processes.utils.agg_logs(outputs)
assert isinstance(r, dict)
assert 'agg_val_loss' in r.keys()
assert 'agg_val_loss_kl' in r['log'].keys()
assert isinstance(r['agg_val_loss'], float)
assert "agg_val_loss" in r.keys()
assert "agg_val_loss_kl" in r["log"].keys()
assert isinstance(r["agg_val_loss"], float)
def test_round_values():
r = neural_processes.utils.round_values({'a': 0.00004, 'd': {'b': 124455.45, 'c': 0.004}, 'l': 500})
r = neural_processes.utils.round_values(
{"a": 0.00004, "d": {"b": 124455.45, "c": 0.004}, "l": 500}
)
def test_hparams_power():
r = neural_processes.utils.hparams_power({'test_power': 2, 'test2': 2})
assert r['test'] == 2 ** 2
assert r['test2'] == 2
r = neural_processes.utils.hparams_power({"test_power": 2, "test2": 2})
assert r["test"] == 2 ** 2
assert r["test2"] == 2
def test_log_prob_sigma():
@@ -49,18 +73,20 @@ def test_log_prob_sigma():
log_scale = torch.ones(4, 5)
value = torch.zeros(4, 5)
y_dist = torch.distributions.Normal(mean, log_scale.exp())
r1 = y_dist.log_prob(values)
r2 = neural_processes.utils.log_prob_sigma(value, loc, log_scale)
assert (r1==r2).all()
r1 = y_dist.log_prob(value)
r2 = neural_processes.utils.log_prob_sigma(value, mean, log_scale)
assert (r1 == r2).all()
def test_kl_loss_var():
prior_mu = torch.zeros(4, 5)
post_mu = torch.zeros(4, 5) + 1
log_var_prior = torch.ones(4, 5)
log_var_post = torch.ones(4, 5) + 1
dist_prior = torch.distributions.Normal(prior_mu, log_var_prior.exp())
dist_post = torch.distributions.Normal(post_mu, log_var_post.exp())
r1 = torch.distributions.kl_divergence(
dist_post, dist_prior).mean(-1)
r2 = neural_processes.utils.kl_loss_var(prior_mu, log_var_prior, post_mu, log_var_post)
assert (r1==r2).all()
dist_prior = torch.distributions.Normal(prior_mu, torch.exp(0.5 * log_var_prior))
dist_post = torch.distributions.Normal(post_mu, torch.exp(0.5 * log_var_post))
r1 = torch.distributions.kl_divergence(dist_post, dist_prior)
r2 = neural_processes.utils.kl_loss_var(
prior_mu, log_var_prior, post_mu, log_var_post
)
assert (r1 == r2).all()