Files
pytorch-ts/test/modules/test_distribution_output.py
Dr. Kashif Rasul d5577a2c9e fix some tests
2020-12-30 19:17:54 +01:00

307 lines
9.3 KiB
Python

from typing import List, Tuple
import numpy as np
import pytest
import torch
import torch.nn as nn
from torch.distributions import (
StudentT,
Beta,
NegativeBinomial,
LowRankMultivariateNormal,
MultivariateNormal,
Independent,
Normal,
)
from torch.nn.utils import clip_grad_norm_
from torch.optim import SGD
from torch.utils.data import TensorDataset, DataLoader
from gluonts.torch.modules.distribution_output import DistributionOutput
from pts.modules import (
StudentTOutput,
BetaOutput,
NegativeBinomialOutput,
LowRankMultivariateNormalOutput,
MultivariateNormalOutput,
NormalOutput,
)
# Number of samples drawn from the target distribution in the univariate tests.
NUM_SAMPLES = 2000
# Mini-batch size of the DataLoader built in maximum_likelihood_estimate_sgd.
BATCH_SIZE = 32
# Relative tolerance: an estimate passes when |estimate - truth| < TOL * truth.
TOL = 0.3
# Multiplier on TOL that sets how far the initial biases start from the truth.
START_TOL_MULTIPLE = 1
def inv_softplus(y: np.ndarray) -> np.ndarray:
    """
    Invert the softplus function ``softplus(x) = log(1 + exp(x))``.

    Parameters
    ----------
    y
        Softplus output(s). Must be strictly positive for a finite result:
        zero maps to ``-inf`` and negative values map to ``nan``.

    Returns
    -------
    The value(s) ``x`` such that ``softplus(x) == y``.
    """
    # log(exp(y) - 1) == y + log(1 - exp(-y)). The right-hand form never
    # evaluates exp(y), which overflows to inf for large positive y.
    return y + np.log(-np.expm1(-y))
def maximum_likelihood_estimate_sgd(
    distr_output: DistributionOutput,
    samples: torch.Tensor,
    init_biases: List[np.ndarray] = None,
    num_epochs: int = 5,
    learning_rate: float = 1e-2,
):
    """
    Fit the arguments of ``distr_output`` to ``samples`` with SGD by
    minimising the mean negative log-likelihood.

    The projection layer is always fed a constant input of ones, so the
    fitted distribution arguments depend only on the learned parameters.

    Parameters
    ----------
    distr_output
        Distribution output whose argument projection is trained.
    samples
        Observations to fit; the first axis indexes the samples.
    init_biases
        Optional constants written into the bias of each layer in
        ``arg_proj.proj`` (paired in order) before training starts.
    num_epochs
        Number of passes over ``samples``; must be at least 1.
    learning_rate
        Step size of the SGD optimizer.

    Returns
    -------
    A list with one numpy array per distribution argument, evaluated on a
    single input of ones. When the arguments carry more axes than the batch
    axis, the batch axis is dropped from each array.

    Raises
    ------
    ValueError
        If ``num_epochs`` is smaller than 1 or ``samples`` is empty, since
        no training step could run in either case.
    """
    # Without at least one training batch the loss average divides by zero
    # and the rank check after the loop has nothing to inspect.
    if num_epochs < 1:
        raise ValueError(f"num_epochs must be at least 1, got {num_epochs}")
    if len(samples) == 0:
        raise ValueError("samples must contain at least one observation")

    arg_proj = distr_output.get_args_proj(in_features=1)

    if init_biases is not None:
        for param, bias in zip(arg_proj.proj, init_biases):
            nn.init.constant_(param.bias, bias)

    # constant feature: every sample is paired with the same input of ones
    dummy_data = torch.ones((len(samples), 1))
    dataset = TensorDataset(dummy_data, samples)
    train_data = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

    optimizer = SGD(arg_proj.parameters(), lr=learning_rate)

    for e in range(num_epochs):
        cumulative_loss = 0
        num_batches = 0
        for data, sample_label in train_data:
            optimizer.zero_grad()
            distr_args = arg_proj(data)
            distr = distr_output.distribution(distr_args)
            loss = -distr.log_prob(sample_label).mean()
            loss.backward()
            # cap the gradient norm so a single bad batch cannot blow up
            clip_grad_norm_(arg_proj.parameters(), 10.0)
            optimizer.step()

            num_batches += 1
            cumulative_loss += loss.item()
        print("Epoch %s, loss: %s" % (e, cumulative_loss / num_batches))

    # The rank of the first argument from the last training batch decides
    # whether the batch axis of the final estimate is kept or dropped.
    batch_axis_only = len(distr_args[0].shape) == 1

    # evaluate the trained projection once on a single constant input
    fitted_args = arg_proj(torch.ones((1, 1)))

    if batch_axis_only:
        return [param.detach().numpy() for param in fitted_args]

    return [param[0].detach().numpy() for param in fitted_args]
@pytest.mark.parametrize("concentration1, concentration0", [(3.75, 1.25)])
def test_beta_likelihood(concentration1: float, concentration0: float) -> None:
    """
    Check that SGD maximum likelihood recovers both Beta concentrations.
    """
    # draw NUM_SAMPLES observations from the target Beta distribution
    sample_shape = (NUM_SAMPLES,)
    target = Beta(
        torch.zeros(sample_shape) + concentration1,
        torch.zeros(sample_shape) + concentration0,
    )
    samples = target.sample()

    # start every concentration below its true value by the relative offset
    relative_offset = START_TOL_MULTIPLE * TOL
    init_biases = [
        inv_softplus(true_value - relative_offset * true_value)
        for true_value in (concentration1, concentration0)
    ]

    estimates = maximum_likelihood_estimate_sgd(
        BetaOutput(),
        samples,
        init_biases=init_biases,
        learning_rate=0.05,
        num_epochs=10,
    )
    concentration1_hat, concentration0_hat = estimates

    print("concentration1:", concentration1_hat, "concentration0:", concentration0_hat)

    error1 = np.abs(concentration1_hat - concentration1)
    assert (
        error1 < TOL * concentration1
    ), f"concentration1 did not match: concentration1 = {concentration1}, concentration1_hat = {concentration1_hat}"

    error0 = np.abs(concentration0_hat - concentration0)
    assert (
        error0 < TOL * concentration0
    ), f"concentration0 did not match: concentration0 = {concentration0}, concentration0_hat = {concentration0_hat}"
@pytest.mark.parametrize("total_count_logit", [(2.5, 0.7)])
def test_neg_binomial(total_count_logit: Tuple[float, float]) -> None:
    """
    Test to check that maximizing the likelihood recovers the parameters
    of a negative binomial distribution.

    Parameters
    ----------
    total_count_logit
        Pair ``(total_count, logit)`` of the distribution to sample from.
    """
    # test instance
    total_count, logit = total_count_logit

    # generate samples
    total_counts = torch.zeros((NUM_SAMPLES,)) + total_count
    logits = torch.zeros((NUM_SAMPLES,)) + logit

    neg_bin_distr = NegativeBinomial(total_count=total_counts, logits=logits)
    samples = neg_bin_distr.sample()

    # start below the true values; total_count goes through inv_softplus
    # while the logit is used as-is
    init_biases = [
        inv_softplus(total_count - START_TOL_MULTIPLE * TOL * total_count),
        logit - START_TOL_MULTIPLE * TOL * logit,
    ]

    total_count_hat, logit_hat = maximum_likelihood_estimate_sgd(
        NegativeBinomialOutput(),
        samples,
        init_biases=init_biases,
        num_epochs=15,
    )

    assert (
        np.abs(total_count_hat - total_count) < TOL * total_count
    ), f"total_count did not match: total_count = {total_count}, total_count_hat = {total_count_hat}"
    # The tolerance is scaled by the true logit, not by the estimate:
    # scaling by logit_hat lets a bad estimate widen its own tolerance and
    # can never pass when the estimate is negative.
    assert (
        np.abs(logit_hat - logit) < TOL * np.abs(logit)
    ), f"logit did not match: logit = {logit}, logit_hat = {logit_hat}"
@pytest.mark.parametrize("df, loc, scale,", [(6.0, 2.3, 0.7)])
def test_studentT_likelihood(df: float, loc: float, scale: float):
    """
    Check that SGD maximum likelihood recovers the Student-t parameters.
    """
    # generate samples from the target distribution
    sample_shape = (NUM_SAMPLES,)
    target = StudentT(
        df=torch.zeros(sample_shape) + df,
        loc=torch.zeros(sample_shape) + loc,
        scale=torch.zeros(sample_shape) + scale,
    )
    samples = target.sample()

    # df starts at (df - 2); loc and scale start below their true values
    relative_offset = START_TOL_MULTIPLE * TOL
    init_bias = [
        inv_softplus(df - 2),
        loc - relative_offset * loc,
        inv_softplus(scale - relative_offset * scale),
    ]

    estimates = maximum_likelihood_estimate_sgd(
        StudentTOutput(),
        samples,
        init_biases=init_bias,
        num_epochs=15,
        learning_rate=1e-3,
    )
    df_hat, loc_hat, scale_hat = estimates

    df_error = np.abs(df_hat - df)
    assert (
        df_error < TOL * df
    ), f"df did not match: df = {df}, df_hat = {df_hat}"

    loc_error = np.abs(loc_hat - loc)
    assert (
        loc_error < TOL * loc
    ), f"loc did not match: loc = {loc}, loc_hat = {loc_hat}"

    scale_error = np.abs(scale_hat - scale)
    assert (
        scale_error < TOL * scale
    ), f"scale did not match: scale = {scale}, scale_hat = {scale_hat}"
def test_independent_normal() -> None:
    """
    Check that SGD maximum likelihood recovers the mean and the variance of
    a Normal distribution with independent components.
    """

    def build_distr(mean, std):
        # independent Normal whose last axis is treated as the event axis
        return Independent(
            Normal(loc=torch.Tensor(mean), scale=torch.Tensor(std)), 1
        )

    num_samples = 2000
    dim = 4

    # ground-truth mean, standard deviation and variance per component
    loc = np.arange(0, dim) / float(dim)
    diag = np.arange(dim) / dim + 0.5
    Sigma = diag * diag

    distr = build_distr(loc, diag)

    # sanity check: the constructed distribution has the intended variance
    true_variance = distr.variance.numpy()
    assert np.allclose(
        true_variance, Sigma, atol=0.1, rtol=0.1
    ), f"did not match: sigma = {Sigma}, sigma_hat = {distr.variance.numpy()}"

    samples = distr.sample((num_samples,))

    estimates = maximum_likelihood_estimate_sgd(
        NormalOutput(dim=dim),
        samples,
        learning_rate=0.01,
        num_epochs=10,
    )
    loc_hat, diag_hat = estimates

    # rebuild the distribution from the estimates to read off its variance
    Sigma_hat = build_distr(loc_hat, diag_hat).variance.numpy()

    assert np.allclose(
        loc_hat, loc, atol=0.2, rtol=0.1
    ), f"mu did not match: loc = {loc}, loc_hat = {loc_hat}"

    assert np.allclose(
        Sigma_hat, Sigma, atol=0.1, rtol=0.1
    ), f"sigma did not match: sigma = {Sigma}, sigma_hat = {Sigma_hat}"
def test_lowrank_multivariate_normal() -> None:
    """
    Check that SGD maximum likelihood recovers the mean and the covariance
    of a low-rank multivariate Normal distribution.
    """
    num_samples = 2000
    dim = 4
    rank = 3

    # ground truth: Sigma = W W^T + D with a constant factor W and diagonal D
    loc = np.arange(0, dim) / float(dim)
    diag_values = np.arange(dim) / dim + 0.5
    cov_diag = np.diag(diag_values)
    cov_factor = np.sqrt(np.full((dim, rank), 0.2))
    Sigma = cov_factor @ cov_factor.T + cov_diag

    true_distr = LowRankMultivariateNormal(
        loc=torch.Tensor(loc.copy()),
        cov_diag=torch.Tensor(diag_values.copy()),
        cov_factor=torch.Tensor(cov_factor.copy()),
    )

    # sanity check: the constructed distribution has the intended covariance
    true_covariance = true_distr.covariance_matrix.numpy()
    assert np.allclose(
        true_covariance, Sigma, atol=0.1, rtol=0.1
    ), f"did not match: sigma = {Sigma}, sigma_hat = {true_distr.covariance_matrix.numpy()}"

    samples = true_distr.sample((num_samples,))

    output = LowRankMultivariateNormalOutput(
        dim=dim, rank=rank, sigma_init=0.2, sigma_minimum=0.0
    )
    estimates = maximum_likelihood_estimate_sgd(
        output,
        samples,
        learning_rate=0.01,
        num_epochs=10,
    )
    loc_hat, cov_factor_hat, cov_diag_hat = estimates

    # rebuild the distribution from the estimates to get its covariance
    fitted_distr = LowRankMultivariateNormal(
        loc=torch.Tensor(loc_hat),
        cov_diag=torch.Tensor(cov_diag_hat),
        cov_factor=torch.Tensor(cov_factor_hat),
    )
    Sigma_hat = fitted_distr.covariance_matrix.numpy()

    assert np.allclose(
        loc_hat, loc, atol=0.2, rtol=0.1
    ), f"mu did not match: loc = {loc}, loc_hat = {loc_hat}"

    assert np.allclose(
        Sigma_hat, Sigma, atol=0.1, rtol=0.1
    ), f"sigma did not match: sigma = {Sigma}, sigma_hat = {Sigma_hat}"
def test_multivariate_normal() -> None:
    """
    Check that SGD maximum likelihood recovers the mean and the covariance
    of a full multivariate Normal distribution.
    """
    num_samples = 2000
    dim = 2

    # ground truth: unit-diagonal Cholesky factor with 0.1 below the diagonal
    mu = np.arange(0, dim) / float(dim)
    L_diag = np.ones((dim,))
    L_low = 0.1 * np.tri(dim, k=-1)
    L = np.diag(L_diag) + L_low
    Sigma = L @ L.T

    true_distr = MultivariateNormal(
        loc=torch.Tensor(mu), scale_tril=torch.Tensor(L)
    )
    samples = true_distr.sample((num_samples,))

    estimates = maximum_likelihood_estimate_sgd(
        MultivariateNormalOutput(dim=dim),
        samples,
        init_biases=None,  # todo we would need to rework biases a bit to use it in the multivariate case
        learning_rate=0.01,
        num_epochs=10,
    )
    mu_hat, L_hat = estimates

    # rebuild the distribution from the estimates to get its covariance
    fitted_distr = MultivariateNormal(
        loc=torch.tensor(mu_hat), scale_tril=torch.tensor(L_hat)
    )
    Sigma_hat = fitted_distr.covariance_matrix.numpy()

    assert np.allclose(
        mu_hat, mu, atol=0.1, rtol=0.1
    ), f"mu did not match: mu = {mu}, mu_hat = {mu_hat}"

    assert np.allclose(
        Sigma_hat, Sigma, atol=0.1, rtol=0.1
    ), f"Sigma did not match: sigma = {Sigma}, sigma_hat = {Sigma_hat}"