Revert "add gate_logits to zero inflated"

This reverts commit 7415a15256.
This commit is contained in:
Dr. Kashif Rasul
2020-10-21 16:44:44 +02:00
parent 7415a15256
commit 149a35a7f8
3 changed files with 45 additions and 122 deletions
+22 -88
View File
@@ -2,15 +2,8 @@
# SPDX-License-Identifier: Apache-2.0
import torch
import torch.nn.functional as F
from torch.distributions import constraints, NegativeBinomial, Poisson, Distribution
from torch.distributions.utils import (
broadcast_all,
lazy_property,
lazy_property,
logits_to_probs,
probs_to_logits,
)
from torch.distributions.utils import broadcast_all, lazy_property
from .utils import broadcast_shape
@@ -22,33 +15,20 @@ class ZeroInflatedDistribution(Distribution):
This can be used directly or can be used as a base class as e.g. for
:class:`ZeroInflatedPoisson` and :class:`ZeroInflatedNegativeBinomial`.
:param TorchDistribution base_dist: the base distribution.
:param torch.Tensor gate: probability of extra zeros given via a Bernoulli distribution.
:param torch.Tensor gate_logits: logits of extra zeros given via a Bernoulli distribution.
:param TorchDistribution base_dist: the base distribution.
"""
arg_constraints = {
"gate": constraints.unit_interval,
"gate_logits": constraints.real,
}
arg_constraints = {"gate": constraints.unit_interval}
def __init__(self, base_dist, *, gate=None, gate_logits=None, validate_args=None):
if (gate is None) == (gate_logits is None):
raise ValueError(
"Either `gate` or `gate_logits` must be specified, but not both."
)
if gate is not None:
batch_shape = broadcast_shape(gate.shape, base_dist.batch_shape)
self.gate = gate.expand(batch_shape)
else:
batch_shape = broadcast_shape(gate_logits.shape, base_dist.batch_shape)
self.gate_logits = gate_logits.expand(batch_shape)
def __init__(self, gate, base_dist, validate_args=None):
if base_dist.event_shape:
raise ValueError(
"ZeroInflatedDistribution expected empty "
"base_dist.event_shape but got {}".format(base_dist.event_shape)
)
batch_shape = broadcast_shape(gate.shape, base_dist.batch_shape)
self.gate = gate.expand(batch_shape)
self.base_dist = base_dist.expand(batch_shape)
event_shape = torch.Size()
@@ -58,29 +38,13 @@ class ZeroInflatedDistribution(Distribution):
def support(self):
return self.base_dist.support
@lazy_property
def gate(self):
return logits_to_probs(self.gate_logits)
@lazy_property
def gate_logits(self):
return probs_to_logits(self.gate)
def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)
if "gate" in self.__dict__:
gate, value = broadcast_all(self.gate, value)
log_prob = (-gate).log1p() + self.base_dist.log_prob(value)
log_prob = torch.where(value == 0, (gate + log_prob.exp()).log(), log_prob)
else:
gate_logits, value = broadcast_all(self.gate_logits, value)
log_prob_minus_log_gate = -gate_logits + self.base_dist.log_prob(value)
log_gate = -F.softplus(-gate_logits)
log_prob = log_prob_minus_log_gate + log_gate
zero_log_prob = F.softplus(log_prob_minus_log_gate) + log_gate
log_prob = torch.where(value == 0, zero_log_prob, log_prob)
gate, value = broadcast_all(self.gate, value)
log_prob = (-gate).log1p() + self.base_dist.log_prob(value)
log_prob = torch.where(value == 0, (gate + log_prob.exp()).log(), log_prob)
return log_prob
def sample(self, sample_shape=torch.Size()):
@@ -104,16 +68,9 @@ class ZeroInflatedDistribution(Distribution):
def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(type(self), _instance)
batch_shape = torch.Size(batch_shape)
gate = self.gate.expand(batch_shape) if "gate" in self.__dict__ else None
gate_logits = (
self.gate_logits.expand(batch_shape)
if "gate_logits" in self.__dict__
else None
)
gate = self.gate.expand(batch_shape)
base_dist = self.base_dist.expand(batch_shape)
ZeroInflatedDistribution.__init__(
new, base_dist, gate=gate, gate_logits=gate_logits, validate_args=False
)
ZeroInflatedDistribution.__init__(new, gate, base_dist, validate_args=False)
new._validate_args = self._validate_args
return new
@@ -122,25 +79,18 @@ class ZeroInflatedPoisson(ZeroInflatedDistribution):
"""
A Zero Inflated Poisson distribution.
:param torch.Tensor rate: rate of poisson distribution.
:param torch.Tensor gate: probability of extra zeros.
:param torch.Tensor gate_logits: logits of extra zeros.
:param torch.Tensor rate: rate of poisson distribution.
"""
arg_constraints = {
"rate": constraints.positive,
"gate": constraints.unit_interval,
"gate_logits": constraints.real,
}
arg_constraints = {"gate": constraints.unit_interval, "rate": constraints.positive}
support = constraints.nonnegative_integer
def __init__(self, rate, *, gate=None, gate_logits=None, validate_args=None):
def __init__(self, gate, rate, validate_args=None):
base_dist = Poisson(rate=rate, validate_args=False)
base_dist._validate_args = validate_args
super().__init__(
base_dist, gate=gate, gate_logits=gate_logits, validate_args=validate_args
)
super().__init__(gate, base_dist, validate_args=validate_args)
@property
def rate(self):
@@ -151,44 +101,28 @@ class ZeroInflatedNegativeBinomial(ZeroInflatedDistribution):
"""
A Zero Inflated Negative Binomial distribution.
:param torch.Tensor gate: probability of extra zeros.
:param total_count: non-negative number of negative Bernoulli trials.
:type total_count: float or torch.Tensor
:param torch.Tensor probs: Event probabilities of failure in the half open interval [0, 1).
:param torch.Tensor logits: Event log-odds for probabilities of failure.
:param torch.Tensor gate: probability of extra zeros.
:param torch.Tensor gate_logits: logits of extra zeros.
:param torch.Tensor probs: Event probabilities of success in the half open interval [0, 1).
:param torch.Tensor logits: Event log-odds for probabilities of success.
"""
arg_constraints = {
"gate": constraints.unit_interval,
"total_count": constraints.greater_than_eq(0),
"probs": constraints.half_open_interval(0.0, 1.0),
"logits": constraints.real,
"gate": constraints.unit_interval,
"gate_logits": constraints.real,
}
support = constraints.nonnegative_integer
def __init__(
self,
total_count,
*,
probs=None,
logits=None,
gate=None,
gate_logits=None,
validate_args=None
):
def __init__(self, gate, total_count, probs=None, logits=None, validate_args=None):
base_dist = NegativeBinomial(
total_count=total_count,
probs=probs,
logits=logits,
validate_args=False,
total_count=total_count, probs=probs, logits=logits, validate_args=False,
)
base_dist._validate_args = validate_args
super().__init__(
base_dist, gate=gate, gate_logits=gate_logits, validate_args=validate_args
)
super().__init__(gate, base_dist, validate_args=validate_args)
@property
def total_count(self):
+12 -11
View File
@@ -201,7 +201,7 @@ class PoissonOutput(IndependentDistributionOutput):
class ZeroInflatedPoissonOutput(IndependentDistributionOutput):
args_dim: Dict[str, int] = {"gate_logits": 1, "rate": 1}
args_dim: Dict[str, int] = {"gate": 1, "rate": 1}
distr_cls: type = ZeroInflatedPoisson
def __init__(self, dim: Optional[int] = None) -> None:
@@ -210,20 +210,21 @@ class ZeroInflatedPoissonOutput(IndependentDistributionOutput):
self.args_dim = {k: dim for k in self.args_dim}
@classmethod
def domain_map(cls, gate_logits, rate):
def domain_map(cls, gate, rate):
gate_unit = torch.sigmoid(gate).clone()
rate_pos = F.softplus(rate).clone()
return gate_logits.squeeze(-1), rate_pos.squeeze(-1)
return gate_unit.squeeze(-1), rate_pos.squeeze(-1)
def distribution(
self, distr_args, scale: Optional[torch.Tensor] = None
) -> Distribution:
gate_logits, rate = distr_args
gate, rate = distr_args
if scale is not None:
rate *= scale
return self.independent(ZeroInflatedPoisson(rate=rate, gate_logits=gate_logits))
return self.independent(ZeroInflatedPoisson(gate=gate, rate=rate))
class NegativeBinomialOutput(IndependentDistributionOutput):
@@ -254,7 +255,7 @@ class NegativeBinomialOutput(IndependentDistributionOutput):
class ZeroInflatedNegativeBinomialOutput(IndependentDistributionOutput):
args_dim: Dict[str, int] = {"gate_logits": 1, "total_count": 1, "logits": 1}
args_dim: Dict[str, int] = {"gate": 1, "total_count": 1, "logits": 1}
distr_cls: type = ZeroInflatedNegativeBinomial
def __init__(self, dim: Optional[int] = None) -> None:
@@ -263,22 +264,22 @@ class ZeroInflatedNegativeBinomialOutput(IndependentDistributionOutput):
self.args_dim = {k: dim for k in self.args_dim}
@classmethod
def domain_map(cls, gate_logits, total_count, logits):
def domain_map(cls, gate, total_count, logits):
gate = torch.sigmoid(gate)
total_count = F.softplus(total_count)
return gate_logits.squeeze(-1), total_count.squeeze(-1), logits.squeeze(-1)
return gate.squeeze(-1), total_count.squeeze(-1), logits.squeeze(-1)
def distribution(
self, distr_args, scale: Optional[torch.Tensor] = None
) -> Distribution:
gate_logits, total_count, logits = distr_args
gate, total_count, logits = distr_args
if scale is not None:
logits += scale.log()
return self.independent(
ZeroInflatedNegativeBinomial(
total_count=total_count,
gate_logits=gate_logits, logits=logits
gate=gate, total_count=total_count, logits=logits
)
)
+11 -23
View File
@@ -25,7 +25,7 @@ def test_zid_shape(gate_shape, base_shape):
gate = torch.rand(gate_shape)
base_dist = Normal(torch.randn(base_shape), torch.randn(base_shape).exp())
d = ZeroInflatedDistribution(base_dist, gate=gate)
d = ZeroInflatedDistribution(gate, base_dist)
assert d.batch_shape == broadcast_shape(gate_shape, base_shape)
assert d.support == base_dist.support
@@ -36,22 +36,19 @@ def test_zid_shape(gate_shape, base_shape):
@pytest.mark.parametrize("rate", [0.1, 0.5, 0.9, 1.0, 1.1, 2.0, 10.0])
def test_zip_0_gate(rate):
# if gate is 0 ZIP is Poisson
zip1 = ZeroInflatedPoisson(torch.tensor(rate), gate=torch.zeros(1))
zip2 = ZeroInflatedPoisson(torch.tensor(rate), gate_logits=torch.tensor(-99.9))
zip_ = ZeroInflatedPoisson(torch.zeros(1), torch.tensor(rate))
pois = Poisson(torch.tensor(rate))
s = pois.sample((20,))
zip1_prob = zip1.log_prob(s)
zip2_prob = zip2.log_prob(s)
zip_prob = zip_.log_prob(s)
pois_prob = pois.log_prob(s)
assert_close(zip1_prob, pois_prob, atol=1e-05)
assert_close(zip2_prob, pois_prob, atol=1e-05)
assert_close(zip_prob, pois_prob, atol=1e-06)
@pytest.mark.parametrize("gate", [0.0, 0.25, 0.5, 0.75, 1.0])
@pytest.mark.parametrize("rate", [0.1, 0.5, 0.9, 1.0, 1.1, 2.0, 10.0])
def test_zip_mean_variance(gate, rate):
num_samples = 1000000
zip_ = ZeroInflatedPoisson(torch.tensor(rate), gate=torch.tensor(gate))
zip_ = ZeroInflatedPoisson(torch.tensor(gate), torch.tensor(rate))
s = zip_.sample((num_samples,))
expected_mean = zip_.mean
estimated_mean = s.mean()
@@ -65,23 +62,14 @@ def test_zip_mean_variance(gate, rate):
@pytest.mark.parametrize("probs", [0.1, 0.5, 0.9])
def test_zinb_0_gate(total_count, probs):
# if gate is 0 ZINB is NegativeBinomial
zinb1 = ZeroInflatedNegativeBinomial(
total_count=torch.tensor(total_count),
gate=torch.zeros(1),
probs=torch.tensor(probs),
)
zinb2 = ZeroInflatedNegativeBinomial(
total_count=torch.tensor(total_count),
gate_logits=torch.tensor(-99.9),
probs=torch.tensor(probs),
zinb_ = ZeroInflatedNegativeBinomial(
torch.zeros(1), total_count=torch.tensor(total_count), probs=torch.tensor(probs)
)
neg_bin = NegativeBinomial(torch.tensor(total_count), probs=torch.tensor(probs))
s = neg_bin.sample((20,))
zinb1_prob = zinb1.log_prob(s)
zinb2_prob = zinb2.log_prob(s)
zinb_prob = zinb_.log_prob(s)
neg_bin_prob = neg_bin.log_prob(s)
assert_close(zinb1_prob, neg_bin_prob, atol=1e-05)
assert_close(zinb2_prob, neg_bin_prob, atol=1e-05)
assert_close(zinb_prob, neg_bin_prob, atol=1e-06)
@pytest.mark.parametrize("gate", [0.0, 0.25, 0.5, 0.75, 1.0])
@@ -90,8 +78,8 @@ def test_zinb_0_gate(total_count, probs):
def test_zinb_mean_variance(gate, total_count, logits):
num_samples = 1000000
zinb_ = ZeroInflatedNegativeBinomial(
torch.tensor(gate),
total_count=torch.tensor(total_count),
gate=torch.tensor(gate),
logits=torch.tensor(logits),
)
s = zinb_.sample((num_samples,))
@@ -100,4 +88,4 @@ def test_zinb_mean_variance(gate, total_count, logits):
expected_std = zinb_.stddev
estimated_std = s.std()
assert_close(expected_mean, estimated_mean, atol=1e-01)
assert_close(expected_std, estimated_std, atol=1e-01)
assert_close(expected_std, estimated_std, atol=1e-1)