initial piecewise linear distribution (#22)

* initial piecewise linear distribution

test is failing though

* typo

* added more tests

* added TransformedPiecewiseLinear and output

* added test_robustness and fixed typos

* more typos

* fix issue with torch.where

* sample without grad

* added license
This commit is contained in:
Kashif Rasul
2020-09-17 09:29:19 +02:00
committed by GitHub Enterprise
parent 79ac528398
commit 4aa176186a
7 changed files with 510 additions and 70 deletions
+20 -69
View File
File diff suppressed because one or more lines are too long
+1
View File
@@ -4,3 +4,4 @@ from .zero_inflated import (
ZeroInflatedPoisson,
ZeroInflatedNegativeBinomial,
)
from .piecewise_linear import PiecewiseLinear, TransformedPiecewiseLinear
+143
View File
@@ -0,0 +1,143 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
import torch
import torch.nn.functional as F
from torch.distributions import (
constraints,
NegativeBinomial,
Poisson,
Distribution,
TransformedDistribution,
AffineTransform,
)
from torch.distributions.utils import broadcast_all, lazy_property
from .utils import broadcast_shape
class PiecewiseLinear(Distribution):
def __init__(self, gamma, slopes, knot_spacings, validate_args=None):
self.gamma = gamma
self.slopes = slopes
self.knot_spacings = knot_spacings
self.b, self.knot_positions = PiecewiseLinear._to_orig_params(
slopes=slopes, knot_spacings=knot_spacings
)
super(PiecewiseLinear, self).__init__(
batch_shape=self.gamma.shape, validate_args=validate_args
)
@staticmethod
def _to_orig_params(slopes, knot_spacings):
# b: the difference between slopes of consecutive pieces
b = slopes[..., 1:] - slopes[..., 0:-1]
# Add slope of first piece to b: b_0 = m_0
m_0 = slopes[..., 0:1]
b = torch.cat((m_0, b), dim=-1)
# The actual position of the knots is obtained by cumulative sum of
# the knot spacings. The first knot position is always 0 for quantile
# functions.
knot_positions = torch.cumsum(knot_spacings, dim=-1) - knot_spacings
return b, knot_positions
@torch.no_grad()
def sample(self, sample_shape=torch.Size()):
shape = self._extended_shape(sample_shape)
u = torch.rand_like(self.gamma.expand(shape))
sample = self.quantile(u)
if len(sample_shape) == 0:
sample = sample.squeeze(0)
return sample
def quantile(self, level):
return self.quantile_internal(level, dim=0)
def quantile_internal(self, x, dim=None):
if dim is not None:
gamma = self.gamma.unsqueeze(dim=dim if dim == 0 else -1)
knot_positions = self.knot_positions.unsqueeze(dim)
b = self.b.unsqueeze(dim)
else:
gamma, knot_positions, b = self.gamma, self.knot_positions, self.b
x_minus_knots = x.unsqueeze(-1) - knot_positions
quantile = gamma + (b * F.relu(x_minus_knots)).sum(-1)
return quantile
def cdf(self, x):
gamma, b, knot_positions = self.gamma, self.b, self.knot_positions
quantiles_at_knots = self.quantile_internal(knot_positions, dim=-2)
# Mask to nullify the terms corresponding to knots larger than l_0,
# which is the largest knot (quantile level) such that the quantile
# at l_0, s(l_0) < x.
mask = torch.le(quantiles_at_knots, x.unsqueeze(-1))
slope_l0 = (b * mask).sum(-1)
# slope_l0 can be zero in which case a_tilde = 0.
# The following is to circumvent an issue where the
# backward() returns nans when slope_l0 is zero in the where
slope_l0_nz = torch.where(slope_l0 == 0.0, torch.ones_like(x), slope_l0)
a_tilde = torch.where(
slope_l0 == 0.0,
torch.zeros_like(x),
(x - gamma + (b * knot_positions * mask).sum(-1)) / slope_l0_nz,
)
return torch.clamp(a_tilde, min=0.0, max=1.0)
def crps(self, x):
gamma, b, knot_positions = self.gamma, self.b, self.knot_positions
a_tilde = self.cdf(x)
max_a_tilde_knots = torch.max(a_tilde.unsqueeze(-1), knot_positions)
knots_cubed = torch.pow(knot_positions, 3.0)
coeff = (
(1.0 - knots_cubed) / 3.0
- knot_positions
- torch.square(max_a_tilde_knots)
+ 2 * max_a_tilde_knots * knot_positions
)
return (2 * a_tilde - 1) * x + (1 - 2 * a_tilde) * gamma + (b * coeff).sum(-1)
class TransformedPiecewiseLinear(TransformedDistribution):
def __init__(self, base_distribution, transforms):
super().__init__(base_distribution, transforms)
def crps(self, x):
scale = 1.0
for transform in reversed(self.transforms):
assert isinstance(transform, AffineTransform), "Not an AffineTransform"
x = transform.inv(x)
scale *= transform.scale
p = self.base_dist.crps(x)
return p * scale
+1
View File
@@ -7,6 +7,7 @@ from .distribution_output import (
BetaOutput,
PoissonOutput,
ZeroInflatedPoissonOutput,
PiecewiseLinearOutput,
NegativeBinomialOutput,
ZeroInflatedNegativeBinomialOutput,
NormalMixtureOutput,
+45 -1
View File
@@ -22,7 +22,12 @@ from torch.distributions import (
Poisson,
)
from pts.distributions import ZeroInflatedPoisson, ZeroInflatedNegativeBinomial
from pts.distributions import (
ZeroInflatedPoisson,
ZeroInflatedNegativeBinomial,
PiecewiseLinear,
TransformedPiecewiseLinear,
)
from pts.core.component import validated
from .lambda_layer import LambdaLayer
@@ -333,6 +338,45 @@ class StudentTMixtureOutput(DistributionOutput):
return ()
class PiecewiseLinearOutput(DistributionOutput):
distr_cls: type = PiecewiseLinear
@validated()
def __init__(self, num_pieces: int) -> None:
super().__init__(self)
assert (
isinstance(num_pieces, int) and num_pieces > 1
), "num_pieces should be an integer larger than 1"
self.num_pieces = num_pieces
self.args_dim = {"gamma": 1, "slopes": num_pieces, "knot_spacings": num_pieces}
@classmethod
def domain_map(cls, gamma, slopes, knot_spacings):
# slopes of the pieces are non-negative
slopes_proj = F.softplus(slopes) + 1e-4
# the spacing between the knots should be in [0, 1] and sum to 1
knot_spacings_proj = torch.softmax(knot_spacings, dim=-1)
return gamma.squeeze(axis=-1), slopes_proj, knot_spacings_proj
def distribution(
self, distr_args, scale: Optional[torch.Tensor] = None,
) -> PiecewiseLinear:
if scale is None:
return self.distr_cls(*distr_args)
else:
distr = self.distr_cls(*distr_args)
return TransformedPiecewiseLinear(
distr, [AffineTransform(loc=0, scale=scale)]
)
@property
def event_shape(self) -> Tuple:
return ()
class NormalMixtureOutput(DistributionOutput):
@validated()
def __init__(self, components: int = 1) -> None:
+209
View File
@@ -0,0 +1,209 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from typing import Tuple, List
import pytest
import torch
import numpy as np
from pts.distributions import PiecewiseLinear
from pts.modules import PiecewiseLinearOutput
def empirical_cdf(
samples: np.ndarray, num_bins: int = 100
) -> Tuple[np.ndarray, np.ndarray]:
"""
Calculate the empirical cdf from the given samples.
Parameters
----------
samples
Tensor of samples of shape (num_samples, batch_shape)
Returns
-------
Tensor
Empirically calculated cdf values. shape (num_bins, batch_shape)
Tensor
Bin edges corresponding to the cdf values. shape (num_bins + 1, batch_shape)
"""
# calculate histogram separately for each dimension in the batch size
cdfs = []
edges = []
batch_shape = samples.shape[1:]
agg_batch_dim = np.prod(batch_shape, dtype=np.int)
samples = samples.reshape((samples.shape[0], -1))
for i in range(agg_batch_dim):
s = samples[:, i]
bins = np.linspace(s.min(), s.max(), num_bins + 1)
hist, edge = np.histogram(s, bins=bins)
cdfs.append(np.cumsum(hist / len(s)))
edges.append(edge)
empirical_cdf = np.stack(cdfs, axis=-1).reshape(num_bins, *batch_shape)
edges = np.stack(edges, axis=-1).reshape(num_bins + 1, *batch_shape)
return empirical_cdf, edges
@pytest.mark.parametrize(
"distr, target, expected_target_cdf, expected_target_crps",
[
(
PiecewiseLinear(
gamma=torch.ones(size=(1,)),
slopes=torch.Tensor([2, 3, 1]).reshape(shape=(1, 3)),
knot_spacings=torch.Tensor([0.3, 0.4, 0.3]).reshape(shape=(1, 3)),
),
[2.2],
[0.5],
[0.223000],
),
(
PiecewiseLinear(
gamma=torch.ones(size=(2,)),
slopes=torch.Tensor([[1, 1], [1, 2]]).reshape(shape=(2, 2)),
knot_spacings=torch.Tensor([[0.4, 0.6], [0.4, 0.6]]).reshape(
shape=(2, 2)
),
),
[1.5, 1.6],
[0.5, 0.5],
[0.083333, 0.145333],
),
],
)
def test_values(
distr: PiecewiseLinear,
target: List[float],
expected_target_cdf: List[float],
expected_target_crps: List[float],
):
target = torch.Tensor(target).reshape(shape=(len(target),))
expected_target_cdf = np.array(expected_target_cdf).reshape(
(len(expected_target_cdf),)
)
expected_target_crps = np.array(expected_target_crps).reshape(
(len(expected_target_crps),)
)
assert all(np.isclose(distr.cdf(target).numpy(), expected_target_cdf))
assert all(np.isclose(distr.crps(target).numpy(), expected_target_crps))
# compare with empirical cdf from samples
num_samples = 100_000
samples = distr.sample((num_samples,)).numpy()
assert np.isfinite(samples).all()
emp_cdf, edges = empirical_cdf(samples)
calc_cdf = distr.cdf(torch.Tensor(edges)).numpy()
assert np.allclose(calc_cdf[1:, :], emp_cdf, atol=1e-2)
@pytest.mark.parametrize(
"batch_shape, num_pieces, num_samples",
[((3, 4, 5), 10, 100), ((1,), 2, 1), ((10,), 10, 10), ((10, 5), 2, 1)],
)
def test_shapes(batch_shape: Tuple, num_pieces: int, num_samples: int):
gamma = torch.ones(size=(*batch_shape,))
slopes = torch.ones(size=(*batch_shape, num_pieces)) # all positive
knot_spacings = (
torch.ones(size=(*batch_shape, num_pieces)) / num_pieces
) # positive and sum to 1
target = torch.ones(size=batch_shape) # shape of gamma
distr = PiecewiseLinear(gamma=gamma, slopes=slopes, knot_spacings=knot_spacings)
# assert that the parameters and target have proper shapes
assert gamma.shape == target.shape
assert knot_spacings.shape == slopes.shape
assert len(gamma.shape) + 1 == len(knot_spacings.shape)
# assert that batch_shape is computed properly
assert distr.batch_shape == batch_shape
# assert that shapes of original parameters are correct
assert distr.b.shape == slopes.shape
assert distr.knot_positions.shape == knot_spacings.shape
# assert that the shape of crps is correct
assert distr.crps(target).shape == batch_shape
# assert that the quantile shape is correct when computing the
# quantile values at knot positions - used for a_tilde
assert distr.quantile_internal(knot_spacings, dim=-2).shape == (
*batch_shape,
num_pieces,
)
# assert that the samples and the quantile values shape when num_samples
# is None is correct
samples = distr.sample()
assert samples.shape == batch_shape
assert distr.quantile_internal(samples).shape == batch_shape
# assert that the samples and the quantile values shape when num_samples
# is not None is correct
samples = distr.sample((num_samples,))
assert samples.shape == (num_samples, *batch_shape)
assert distr.quantile_internal(samples, dim=0).shape == (num_samples, *batch_shape,)
def test_simple_symmetric():
gamma = torch.Tensor([-1.0])
slopes = torch.Tensor([[2.0, 2.0]])
knot_spacings = torch.Tensor([[0.5, 0.5]])
distr = PiecewiseLinear(gamma=gamma, slopes=slopes, knot_spacings=knot_spacings)
assert distr.cdf(torch.Tensor([-2.0])).numpy().item() == 0.0
assert distr.cdf(torch.Tensor([+2.0])).numpy().item() == 1.0
expected_crps = np.array([1.0 + 2.0 / 3.0])
assert np.allclose(distr.crps(torch.Tensor([-2.0])).numpy(), expected_crps)
assert np.allclose(distr.crps(torch.Tensor([2.0])).numpy(), expected_crps)
def test_robustness():
distr_out = PiecewiseLinearOutput(num_pieces=10)
args_proj = distr_out.get_args_proj(in_features=30)
net_out = torch.normal(mean=0.0, size=(1000, 30), std=1e2)
gamma, slopes, knot_spacings = args_proj(net_out)
distr = distr_out.distribution((gamma, slopes, knot_spacings))
# compute the 1-quantile (the 0-quantile is gamma)
sup_support = gamma + (slopes * knot_spacings).sum(-1)
assert torch.le(gamma, sup_support).numpy().all()
width = sup_support - gamma
x = torch.from_numpy(
np.random.uniform(
low=(gamma - width).detach().numpy(),
high=(sup_support + width).detach().numpy(),
).astype(np.float32),
)
# check that 0 < cdf < 1
cdf_x = distr.cdf(x)
assert torch.min(cdf_x).item() >= 0.0 and torch.max(cdf_x).item() <= 1.0
# check that 0 <= crps
crps_x = distr.crps(x)
assert torch.min(crps_x).item() >= 0.0
+91
View File
@@ -0,0 +1,91 @@
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch
from torch.distributions import (
NegativeBinomial,
Normal,
Poisson,
)
from pts.distributions import (
ZeroInflatedDistribution,
ZeroInflatedNegativeBinomial,
ZeroInflatedPoisson,
broadcast_shape,
)
from numpy.testing import assert_allclose as assert_close
@pytest.mark.parametrize("gate_shape", [(), (2,), (3, 1), (3, 2)])
@pytest.mark.parametrize("base_shape", [(), (2,), (3, 1), (3, 2)])
def test_zid_shape(gate_shape, base_shape):
gate = torch.rand(gate_shape)
base_dist = Normal(torch.randn(base_shape), torch.randn(base_shape).exp())
d = ZeroInflatedDistribution(gate, base_dist)
assert d.batch_shape == broadcast_shape(gate_shape, base_shape)
assert d.support == base_dist.support
d2 = d.expand([4, 3, 2])
assert d2.batch_shape == (4, 3, 2)
@pytest.mark.parametrize("rate", [0.1, 0.5, 0.9, 1.0, 1.1, 2.0, 10.0])
def test_zip_0_gate(rate):
# if gate is 0 ZIP is Poisson
zip_ = ZeroInflatedPoisson(torch.zeros(1), torch.tensor(rate))
pois = Poisson(torch.tensor(rate))
s = pois.sample((20,))
zip_prob = zip_.log_prob(s)
pois_prob = pois.log_prob(s)
assert_close(zip_prob, pois_prob, atol=1e-06)
@pytest.mark.parametrize("gate", [0.0, 0.25, 0.5, 0.75, 1.0])
@pytest.mark.parametrize("rate", [0.1, 0.5, 0.9, 1.0, 1.1, 2.0, 10.0])
def test_zip_mean_variance(gate, rate):
num_samples = 1000000
zip_ = ZeroInflatedPoisson(torch.tensor(gate), torch.tensor(rate))
s = zip_.sample((num_samples,))
expected_mean = zip_.mean
estimated_mean = s.mean()
expected_std = zip_.stddev
estimated_std = s.std()
assert_close(expected_mean, estimated_mean, atol=1e-02)
assert_close(expected_std, estimated_std, atol=1e-02)
@pytest.mark.parametrize("total_count", [0.1, 0.5, 0.9, 1.0, 1.1, 2.0, 10.0])
@pytest.mark.parametrize("probs", [0.1, 0.5, 0.9])
def test_zinb_0_gate(total_count, probs):
# if gate is 0 ZINB is NegativeBinomial
zinb_ = ZeroInflatedNegativeBinomial(
torch.zeros(1), total_count=torch.tensor(total_count), probs=torch.tensor(probs)
)
neg_bin = NegativeBinomial(torch.tensor(total_count), probs=torch.tensor(probs))
s = neg_bin.sample((20,))
zinb_prob = zinb_.log_prob(s)
neg_bin_prob = neg_bin.log_prob(s)
assert_close(zinb_prob, neg_bin_prob, atol=1e-06)
@pytest.mark.parametrize("gate", [0.0, 0.25, 0.5, 0.75, 1.0])
@pytest.mark.parametrize("total_count", [0.1, 0.5, 0.9, 1.0, 1.1, 2.0, 10.0])
@pytest.mark.parametrize("logits", [-0.5, 0.5, -0.9, 1.9])
def test_zinb_mean_variance(gate, total_count, logits):
num_samples = 1000000
zinb_ = ZeroInflatedNegativeBinomial(
torch.tensor(gate),
total_count=torch.tensor(total_count),
logits=torch.tensor(logits),
)
s = zinb_.sample((num_samples,))
expected_mean = zinb_.mean
estimated_mean = s.mean()
expected_std = zinb_.stddev
estimated_std = s.std()
assert_close(expected_mean, estimated_mean, atol=1e-01)
assert_close(expected_std, estimated_std, atol=1e-1)