mirror of
https://github.com/wassname/pytorch-ts.git
synced 2026-06-27 18:06:19 +08:00
add gate_logits to zero inflated
This commit is contained in:
@@ -2,8 +2,15 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.distributions import constraints, NegativeBinomial, Poisson, Distribution
|
||||
from torch.distributions.utils import broadcast_all, lazy_property
|
||||
from torch.distributions.utils import (
|
||||
broadcast_all,
|
||||
lazy_property,
|
||||
lazy_property,
|
||||
logits_to_probs,
|
||||
probs_to_logits,
|
||||
)
|
||||
|
||||
from .utils import broadcast_shape
|
||||
|
||||
@@ -15,20 +22,33 @@ class ZeroInflatedDistribution(Distribution):
|
||||
This can be used directly or can be used as a base class as e.g. for
|
||||
:class:`ZeroInflatedPoisson` and :class:`ZeroInflatedNegativeBinomial`.
|
||||
|
||||
:param torch.Tensor gate: probability of extra zeros given via a Bernoulli distribution.
|
||||
:param TorchDistribution base_dist: the base distribution.
|
||||
:param torch.Tensor gate: probability of extra zeros given via a Bernoulli distribution.
|
||||
:param torch.Tensor gate_logits: logits of extra zeros given via a Bernoulli distribution.
|
||||
"""
|
||||
|
||||
arg_constraints = {"gate": constraints.unit_interval}
|
||||
arg_constraints = {
|
||||
"gate": constraints.unit_interval,
|
||||
"gate_logits": constraints.real,
|
||||
}
|
||||
|
||||
def __init__(self, gate, base_dist, validate_args=None):
|
||||
def __init__(self, base_dist, *, gate=None, gate_logits=None, validate_args=None):
|
||||
if (gate is None) == (gate_logits is None):
|
||||
raise ValueError(
|
||||
"Either `gate` or `gate_logits` must be specified, but not both."
|
||||
)
|
||||
if gate is not None:
|
||||
batch_shape = broadcast_shape(gate.shape, base_dist.batch_shape)
|
||||
self.gate = gate.expand(batch_shape)
|
||||
else:
|
||||
batch_shape = broadcast_shape(gate_logits.shape, base_dist.batch_shape)
|
||||
self.gate_logits = gate_logits.expand(batch_shape)
|
||||
if base_dist.event_shape:
|
||||
raise ValueError(
|
||||
"ZeroInflatedDistribution expected empty "
|
||||
"base_dist.event_shape but got {}".format(base_dist.event_shape)
|
||||
)
|
||||
batch_shape = broadcast_shape(gate.shape, base_dist.batch_shape)
|
||||
self.gate = gate.expand(batch_shape)
|
||||
|
||||
self.base_dist = base_dist.expand(batch_shape)
|
||||
event_shape = torch.Size()
|
||||
|
||||
@@ -38,13 +58,29 @@ class ZeroInflatedDistribution(Distribution):
|
||||
def support(self):
|
||||
return self.base_dist.support
|
||||
|
||||
@lazy_property
def gate(self):
    """Probability of an extra zero, derived from ``gate_logits``.

    ``gate_logits`` are Bernoulli log-odds, so the conversion must be a
    sigmoid. ``logits_to_probs`` defaults to ``is_binary=False``, which
    applies a softmax over the last dimension instead; ``is_binary=True``
    selects the sigmoid.

    :return: tensor of probabilities with the same shape as ``gate_logits``.
    """
    return logits_to_probs(self.gate_logits, is_binary=True)
|
||||
|
||||
@lazy_property
def gate_logits(self):
    """Log-odds of an extra zero, derived from ``gate``.

    ``gate`` is a Bernoulli probability, so the logits are
    ``log(p) - log(1 - p)``. ``probs_to_logits`` defaults to
    ``is_binary=False``, which returns ``log(p)`` only; ``is_binary=True``
    selects the log-odds form.

    :return: tensor of log-odds with the same shape as ``gate``.
    """
    return probs_to_logits(self.gate, is_binary=True)
|
||||
|
||||
def log_prob(self, value):
    """Return the log-probability of ``value`` under the zero-inflated mixture.

    With gate probability ``g`` and base log-density ``b(value)`` the result
    is ``log(1 - g) + b(value)`` where ``value != 0`` and
    ``log(g + (1 - g) * exp(b(value)))`` where ``value == 0``.

    :param torch.Tensor value: points to evaluate; broadcast against the
        stored gate parameter.
    :return: tensor of log-probabilities.
    """
    if self._validate_args:
        self._validate_sample(value)

    # Check the instance dict rather than the attribute so that the lazily
    # derived counterpart is not computed just to pick a branch.
    if "gate" in self.__dict__:
        gate, value = broadcast_all(self.gate, value)
        # log(1 - g) + b(value)
        log_prob = (-gate).log1p() + self.base_dist.log_prob(value)
        log_prob = torch.where(value == 0, (gate + log_prob.exp()).log(), log_prob)
    else:
        gate_logits, value = broadcast_all(self.gate_logits, value)
        # b(value) - logits == log((1 - g) / g) + b(value)
        log_prob_minus_log_gate = -gate_logits + self.base_dist.log_prob(value)
        # log(g) == log(sigmoid(logits)) == -softplus(-logits)
        log_gate = -F.softplus(-gate_logits)
        log_prob = log_prob_minus_log_gate + log_gate
        # log(g) + log(1 + (1 - g) / g * exp(b)) == log(g + (1 - g) * exp(b))
        zero_log_prob = F.softplus(log_prob_minus_log_gate) + log_gate
        log_prob = torch.where(value == 0, zero_log_prob, log_prob)
    return log_prob
|
||||
|
||||
def sample(self, sample_shape=torch.Size()):
|
||||
@@ -68,9 +104,16 @@ class ZeroInflatedDistribution(Distribution):
|
||||
def expand(self, batch_shape, _instance=None):
    """Return a copy of this distribution expanded to ``batch_shape``.

    :param batch_shape: the desired batch shape.
    :param _instance: optional pre-allocated instance to fill in; passed
        through to ``_get_checked_instance``.
    :return: the expanded distribution.
    """
    new = self._get_checked_instance(type(self), _instance)
    batch_shape = torch.Size(batch_shape)

    # ``__init__`` stores exactly one of ``gate`` / ``gate_logits`` on the
    # instance. The other is a lazy property that caches itself in
    # ``__dict__`` on first access, so both keys can be present here.
    # ``__init__`` rejects receiving both, so forward only the one stored
    # first (dicts keep insertion order), i.e. the constructor argument.
    param = next(key for key in self.__dict__ if key in ("gate", "gate_logits"))
    gate_kwargs = {"gate": None, "gate_logits": None}
    gate_kwargs[param] = self.__dict__[param].expand(batch_shape)

    base_dist = self.base_dist.expand(batch_shape)
    ZeroInflatedDistribution.__init__(
        new, base_dist, validate_args=False, **gate_kwargs
    )
    new._validate_args = self._validate_args
    return new
|
||||
|
||||
@@ -79,18 +122,25 @@ class ZeroInflatedPoisson(ZeroInflatedDistribution):
|
||||
"""
|
||||
A Zero Inflated Poisson distribution.
|
||||
|
||||
:param torch.Tensor gate: probability of extra zeros.
|
||||
:param torch.Tensor rate: rate of poisson distribution.
|
||||
:param torch.Tensor gate: probability of extra zeros.
|
||||
:param torch.Tensor gate_logits: logits of extra zeros.
|
||||
"""
|
||||
|
||||
arg_constraints = {"gate": constraints.unit_interval, "rate": constraints.positive}
|
||||
arg_constraints = {
|
||||
"rate": constraints.positive,
|
||||
"gate": constraints.unit_interval,
|
||||
"gate_logits": constraints.real,
|
||||
}
|
||||
support = constraints.nonnegative_integer
|
||||
|
||||
def __init__(self, rate, *, gate=None, gate_logits=None, validate_args=None):
    """Build a zero-inflated Poisson from ``rate`` and one gate parameter.

    :param torch.Tensor rate: rate of the Poisson base distribution.
    :param torch.Tensor gate: probability of extra zeros (keyword-only).
    :param torch.Tensor gate_logits: logits of extra zeros (keyword-only).
    :param validate_args: forwarded to the parent constructor and copied
        onto the base distribution.
    """
    # The base distribution is constructed with validation off and only
    # then given the caller's flag, so its own constructor never validates.
    base_dist = Poisson(rate=rate, validate_args=False)
    base_dist._validate_args = validate_args

    super().__init__(
        base_dist, gate=gate, gate_logits=gate_logits, validate_args=validate_args
    )
|
||||
|
||||
@property
|
||||
def rate(self):
|
||||
@@ -101,28 +151,44 @@ class ZeroInflatedNegativeBinomial(ZeroInflatedDistribution):
|
||||
"""
|
||||
A Zero Inflated Negative Binomial distribution.
|
||||
|
||||
:param torch.Tensor gate: probability of extra zeros.
|
||||
:param total_count: non-negative number of negative Bernoulli trials.
|
||||
:type total_count: float or torch.Tensor
|
||||
:param torch.Tensor probs: Event probabilities of success in the half open interval [0, 1).
|
||||
:param torch.Tensor logits: Event log-odds for probabilities of success.
|
||||
:param torch.Tensor probs: Event probabilities of failure in the half open interval [0, 1).
|
||||
:param torch.Tensor logits: Event log-odds for probabilities of failure.
|
||||
:param torch.Tensor gate: probability of extra zeros.
|
||||
:param torch.Tensor gate_logits: logits of extra zeros.
|
||||
"""
|
||||
|
||||
arg_constraints = {
|
||||
"gate": constraints.unit_interval,
|
||||
"total_count": constraints.greater_than_eq(0),
|
||||
"probs": constraints.half_open_interval(0.0, 1.0),
|
||||
"logits": constraints.real,
|
||||
"gate": constraints.unit_interval,
|
||||
"gate_logits": constraints.real,
|
||||
}
|
||||
support = constraints.nonnegative_integer
|
||||
|
||||
def __init__(
    self,
    total_count,
    *,
    probs=None,
    logits=None,
    gate=None,
    gate_logits=None,
    validate_args=None
):
    """Build a zero-inflated negative binomial.

    :param total_count: passed to ``NegativeBinomial`` as ``total_count``.
    :param torch.Tensor probs: passed to ``NegativeBinomial`` as ``probs``.
    :param torch.Tensor logits: passed to ``NegativeBinomial`` as ``logits``.
    :param torch.Tensor gate: probability of extra zeros (keyword-only).
    :param torch.Tensor gate_logits: logits of extra zeros (keyword-only).
    :param validate_args: forwarded to the parent constructor and copied
        onto the base distribution.
    """
    # The base distribution is constructed with validation off and only
    # then given the caller's flag, so its own constructor never validates.
    base_dist = NegativeBinomial(
        total_count=total_count,
        probs=probs,
        logits=logits,
        validate_args=False,
    )
    base_dist._validate_args = validate_args

    super().__init__(
        base_dist, gate=gate, gate_logits=gate_logits, validate_args=validate_args
    )
|
||||
|
||||
@property
|
||||
def total_count(self):
|
||||
|
||||
@@ -201,7 +201,7 @@ class PoissonOutput(IndependentDistributionOutput):
|
||||
|
||||
|
||||
class ZeroInflatedPoissonOutput(IndependentDistributionOutput):
|
||||
args_dim: Dict[str, int] = {"gate": 1, "rate": 1}
|
||||
args_dim: Dict[str, int] = {"gate_logits": 1, "rate": 1}
|
||||
distr_cls: type = ZeroInflatedPoisson
|
||||
|
||||
def __init__(self, dim: Optional[int] = None) -> None:
|
||||
@@ -210,21 +210,20 @@ class ZeroInflatedPoissonOutput(IndependentDistributionOutput):
|
||||
self.args_dim = {k: dim for k in self.args_dim}
|
||||
|
||||
@classmethod
def domain_map(cls, gate_logits, rate):
    """Map raw network outputs onto the distribution's parameter domains.

    ``gate_logits`` is unconstrained and passed through unchanged; ``rate``
    is made positive with a softplus. The trailing dimension of both
    tensors is squeezed away.

    :param torch.Tensor gate_logits: raw gate logits with a trailing
        dimension to squeeze.
    :param torch.Tensor rate: raw rate output with a trailing dimension
        to squeeze.
    :return: tuple ``(gate_logits, rate)``.
    """
    rate_pos = F.softplus(rate).clone()

    return gate_logits.squeeze(-1), rate_pos.squeeze(-1)
|
||||
|
||||
def distribution(
    self, distr_args, scale: Optional[torch.Tensor] = None
) -> Distribution:
    """Build the zero-inflated Poisson distribution from mapped arguments.

    :param distr_args: pair ``(gate_logits, rate)``.
    :param scale: optional factor multiplied into ``rate``.
    :return: the result of ``self.independent`` applied to the
        ``ZeroInflatedPoisson`` instance.
    """
    gate_logits, rate = distr_args

    if scale is not None:
        # Scale out of place: an in-place ``*=`` would modify the tensor
        # held in ``distr_args``, so calling this twice with the same
        # arguments would apply the scale twice.
        rate = rate * scale

    return self.independent(ZeroInflatedPoisson(rate=rate, gate_logits=gate_logits))
|
||||
|
||||
|
||||
class NegativeBinomialOutput(IndependentDistributionOutput):
|
||||
@@ -255,7 +254,7 @@ class NegativeBinomialOutput(IndependentDistributionOutput):
|
||||
|
||||
|
||||
class ZeroInflatedNegativeBinomialOutput(IndependentDistributionOutput):
|
||||
args_dim: Dict[str, int] = {"gate": 1, "total_count": 1, "logits": 1}
|
||||
args_dim: Dict[str, int] = {"gate_logits": 1, "total_count": 1, "logits": 1}
|
||||
distr_cls: type = ZeroInflatedNegativeBinomial
|
||||
|
||||
def __init__(self, dim: Optional[int] = None) -> None:
|
||||
@@ -264,22 +263,22 @@ class ZeroInflatedNegativeBinomialOutput(IndependentDistributionOutput):
|
||||
self.args_dim = {k: dim for k in self.args_dim}
|
||||
|
||||
@classmethod
def domain_map(cls, gate_logits, total_count, logits):
    """Map raw network outputs onto the distribution's parameter domains.

    ``gate_logits`` and ``logits`` are unconstrained and passed through
    unchanged; ``total_count`` is made positive with a softplus. The
    trailing dimension of every tensor is squeezed away.

    :param torch.Tensor gate_logits: raw gate logits.
    :param torch.Tensor total_count: raw total-count output.
    :param torch.Tensor logits: raw negative-binomial logits.
    :return: tuple ``(gate_logits, total_count, logits)``.
    """
    total_count = F.softplus(total_count)
    return gate_logits.squeeze(-1), total_count.squeeze(-1), logits.squeeze(-1)
|
||||
|
||||
def distribution(
    self, distr_args, scale: Optional[torch.Tensor] = None
) -> Distribution:
    """Build the zero-inflated negative binomial from mapped arguments.

    :param distr_args: triple ``(gate_logits, total_count, logits)``.
    :param scale: optional scale; its logarithm is added to ``logits``.
    :return: the result of ``self.independent`` applied to the
        ``ZeroInflatedNegativeBinomial`` instance.
    """
    gate_logits, total_count, logits = distr_args

    if scale is not None:
        # Shift out of place: an in-place ``+=`` would modify the tensor
        # held in ``distr_args``, so calling this twice with the same
        # arguments would apply the shift twice.
        logits = logits + scale.log()

    return self.independent(
        ZeroInflatedNegativeBinomial(
            total_count=total_count,
            gate_logits=gate_logits,
            logits=logits,
        )
    )
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ def test_zid_shape(gate_shape, base_shape):
|
||||
gate = torch.rand(gate_shape)
|
||||
base_dist = Normal(torch.randn(base_shape), torch.randn(base_shape).exp())
|
||||
|
||||
d = ZeroInflatedDistribution(gate, base_dist)
|
||||
d = ZeroInflatedDistribution(base_dist, gate=gate)
|
||||
assert d.batch_shape == broadcast_shape(gate_shape, base_shape)
|
||||
assert d.support == base_dist.support
|
||||
|
||||
@@ -36,19 +36,22 @@ def test_zid_shape(gate_shape, base_shape):
|
||||
@pytest.mark.parametrize("rate", [0.1, 0.5, 0.9, 1.0, 1.1, 2.0, 10.0])
def test_zip_0_gate(rate):
    """With no zero inflation the ZIP log-probabilities match plain Poisson.

    Both parameterisations are covered: ``gate=0`` exactly, and a very
    negative ``gate_logits`` whose gate probability is effectively zero.
    """
    # if gate is 0 ZIP is Poisson
    zip1 = ZeroInflatedPoisson(torch.tensor(rate), gate=torch.zeros(1))
    zip2 = ZeroInflatedPoisson(torch.tensor(rate), gate_logits=torch.tensor(-99.9))
    pois = Poisson(torch.tensor(rate))
    s = pois.sample((20,))
    zip1_prob = zip1.log_prob(s)
    zip2_prob = zip2.log_prob(s)
    pois_prob = pois.log_prob(s)
    assert_close(zip1_prob, pois_prob, atol=1e-05)
    assert_close(zip2_prob, pois_prob, atol=1e-05)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("gate", [0.0, 0.25, 0.5, 0.75, 1.0])
|
||||
@pytest.mark.parametrize("rate", [0.1, 0.5, 0.9, 1.0, 1.1, 2.0, 10.0])
|
||||
def test_zip_mean_variance(gate, rate):
|
||||
num_samples = 1000000
|
||||
zip_ = ZeroInflatedPoisson(torch.tensor(gate), torch.tensor(rate))
|
||||
zip_ = ZeroInflatedPoisson(torch.tensor(rate), gate=torch.tensor(gate))
|
||||
s = zip_.sample((num_samples,))
|
||||
expected_mean = zip_.mean
|
||||
estimated_mean = s.mean()
|
||||
@@ -62,14 +65,23 @@ def test_zip_mean_variance(gate, rate):
|
||||
@pytest.mark.parametrize("probs", [0.1, 0.5, 0.9])
|
||||
def test_zinb_0_gate(total_count, probs):
|
||||
# if gate is 0 ZINB is NegativeBinomial
|
||||
zinb_ = ZeroInflatedNegativeBinomial(
|
||||
torch.zeros(1), total_count=torch.tensor(total_count), probs=torch.tensor(probs)
|
||||
zinb1 = ZeroInflatedNegativeBinomial(
|
||||
total_count=torch.tensor(total_count),
|
||||
gate=torch.zeros(1),
|
||||
probs=torch.tensor(probs),
|
||||
)
|
||||
zinb2 = ZeroInflatedNegativeBinomial(
|
||||
total_count=torch.tensor(total_count),
|
||||
gate_logits=torch.tensor(-99.9),
|
||||
probs=torch.tensor(probs),
|
||||
)
|
||||
neg_bin = NegativeBinomial(torch.tensor(total_count), probs=torch.tensor(probs))
|
||||
s = neg_bin.sample((20,))
|
||||
zinb_prob = zinb_.log_prob(s)
|
||||
zinb1_prob = zinb1.log_prob(s)
|
||||
zinb2_prob = zinb2.log_prob(s)
|
||||
neg_bin_prob = neg_bin.log_prob(s)
|
||||
assert_close(zinb_prob, neg_bin_prob, atol=1e-06)
|
||||
assert_close(zinb1_prob, neg_bin_prob, atol=1e-05)
|
||||
assert_close(zinb2_prob, neg_bin_prob, atol=1e-05)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("gate", [0.0, 0.25, 0.5, 0.75, 1.0])
|
||||
@@ -78,8 +90,8 @@ def test_zinb_0_gate(total_count, probs):
|
||||
def test_zinb_mean_variance(gate, total_count, logits):
|
||||
num_samples = 1000000
|
||||
zinb_ = ZeroInflatedNegativeBinomial(
|
||||
torch.tensor(gate),
|
||||
total_count=torch.tensor(total_count),
|
||||
gate=torch.tensor(gate),
|
||||
logits=torch.tensor(logits),
|
||||
)
|
||||
s = zinb_.sample((num_samples,))
|
||||
@@ -88,4 +100,4 @@ def test_zinb_mean_variance(gate, total_count, logits):
|
||||
expected_std = zinb_.stddev
|
||||
estimated_std = s.std()
|
||||
assert_close(expected_mean, estimated_mean, atol=1e-01)
|
||||
assert_close(expected_std, estimated_std, atol=1e-1)
|
||||
assert_close(expected_std, estimated_std, atol=1e-01)
|
||||
|
||||
Reference in New Issue
Block a user