diff --git a/pts/distributions/zero_inflated.py b/pts/distributions/zero_inflated.py index 2fa91dd..266fc82 100644 --- a/pts/distributions/zero_inflated.py +++ b/pts/distributions/zero_inflated.py @@ -2,15 +2,8 @@ # SPDX-License-Identifier: Apache-2.0 import torch -import torch.nn.functional as F from torch.distributions import constraints, NegativeBinomial, Poisson, Distribution -from torch.distributions.utils import ( - broadcast_all, - lazy_property, - lazy_property, - logits_to_probs, - probs_to_logits, -) +from torch.distributions.utils import broadcast_all, lazy_property from .utils import broadcast_shape @@ -22,33 +15,20 @@ class ZeroInflatedDistribution(Distribution): This can be used directly or can be used as a base class as e.g. for :class:`ZeroInflatedPoisson` and :class:`ZeroInflatedNegativeBinomial`. - :param TorchDistribution base_dist: the base distribution. :param torch.Tensor gate: probability of extra zeros given via a Bernoulli distribution. - :param torch.Tensor gate_logits: logits of extra zeros given via a Bernoulli distribution. + :param TorchDistribution base_dist: the base distribution. """ - arg_constraints = { - "gate": constraints.unit_interval, - "gate_logits": constraints.real, - } + arg_constraints = {"gate": constraints.unit_interval} - def __init__(self, base_dist, *, gate=None, gate_logits=None, validate_args=None): - if (gate is None) == (gate_logits is None): - raise ValueError( - "Either `gate` or `gate_logits` must be specified, but not both." - ) - if gate is not None: - batch_shape = broadcast_shape(gate.shape, base_dist.batch_shape) - self.gate = gate.expand(batch_shape) - else: - batch_shape = broadcast_shape(gate_logits.shape, base_dist.batch_shape) - self.gate_logits = gate_logits.expand(batch_shape) + def __init__(self, gate, base_dist, validate_args=None): if base_dist.event_shape: raise ValueError( "ZeroInflatedDistribution expected empty " "base_dist.event_shape but got {}".format(base_dist.event_shape) ) - + batch_shape = broadcast_shape(gate.shape, base_dist.batch_shape) + self.gate = gate.expand(batch_shape) self.base_dist = base_dist.expand(batch_shape) event_shape = torch.Size() @@ -58,29 +38,13 @@ class ZeroInflatedDistribution(Distribution): def support(self): return self.base_dist.support - @lazy_property - def gate(self): - return logits_to_probs(self.gate_logits) - - @lazy_property - def gate_logits(self): - return probs_to_logits(self.gate) - def log_prob(self, value): if self._validate_args: self._validate_sample(value) - if "gate" in self.__dict__: - gate, value = broadcast_all(self.gate, value) - log_prob = (-gate).log1p() + self.base_dist.log_prob(value) - log_prob = torch.where(value == 0, (gate + log_prob.exp()).log(), log_prob) - else: - gate_logits, value = broadcast_all(self.gate_logits, value) - log_prob_minus_log_gate = -gate_logits + self.base_dist.log_prob(value) - log_gate = -F.softplus(-gate_logits) - log_prob = log_prob_minus_log_gate + log_gate - zero_log_prob = F.softplus(log_prob_minus_log_gate) + log_gate - log_prob = torch.where(value == 0, zero_log_prob, log_prob) + gate, value = broadcast_all(self.gate, value) + log_prob = (-gate).log1p() + self.base_dist.log_prob(value) + log_prob = torch.where(value == 0, (gate + log_prob.exp()).log(), log_prob) return log_prob def sample(self, sample_shape=torch.Size()): @@ -104,16 +68,9 @@ class ZeroInflatedDistribution(Distribution): def expand(self, batch_shape, _instance=None): new = self._get_checked_instance(type(self), _instance) batch_shape = torch.Size(batch_shape) - gate = self.gate.expand(batch_shape) if "gate" in self.__dict__ else None - gate_logits = ( - self.gate_logits.expand(batch_shape) - if "gate_logits" in self.__dict__ - else None - ) + gate = self.gate.expand(batch_shape) base_dist = self.base_dist.expand(batch_shape) - ZeroInflatedDistribution.__init__( - new, base_dist, gate=gate, gate_logits=gate_logits, validate_args=False - ) + ZeroInflatedDistribution.__init__(new, gate, base_dist, validate_args=False) new._validate_args = self._validate_args return new @@ -122,25 +79,18 @@ class ZeroInflatedPoisson(ZeroInflatedDistribution): """ A Zero Inflated Poisson distribution. - :param torch.Tensor rate: rate of poisson distribution. :param torch.Tensor gate: probability of extra zeros. - :param torch.Tensor gate_logits: logits of extra zeros. + :param torch.Tensor rate: rate of poisson distribution. """ - arg_constraints = { - "rate": constraints.positive, - "gate": constraints.unit_interval, - "gate_logits": constraints.real, - } + arg_constraints = {"gate": constraints.unit_interval, "rate": constraints.positive} support = constraints.nonnegative_integer - def __init__(self, rate, *, gate=None, gate_logits=None, validate_args=None): + def __init__(self, gate, rate, validate_args=None): base_dist = Poisson(rate=rate, validate_args=False) base_dist._validate_args = validate_args - super().__init__( - base_dist, gate=gate, gate_logits=gate_logits, validate_args=validate_args - ) + super().__init__(gate, base_dist, validate_args=validate_args) @property def rate(self): @@ -151,44 +101,28 @@ class ZeroInflatedNegativeBinomial(ZeroInflatedDistribution): """ A Zero Inflated Negative Binomial distribution. + :param torch.Tensor gate: probability of extra zeros. :param total_count: non-negative number of negative Bernoulli trials. :type total_count: float or torch.Tensor - :param torch.Tensor probs: Event probabilities of failure in the half open interval [0, 1). - :param torch.Tensor logits: Event log-odds for probabilities of failure. - :param torch.Tensor gate: probability of extra zeros. - :param torch.Tensor gate_logits: logits of extra zeros. + :param torch.Tensor probs: Event probabilities of success in the half open interval [0, 1). + :param torch.Tensor logits: Event log-odds for probabilities of success. """ arg_constraints = { + "gate": constraints.unit_interval, "total_count": constraints.greater_than_eq(0), "probs": constraints.half_open_interval(0.0, 1.0), "logits": constraints.real, - "gate": constraints.unit_interval, - "gate_logits": constraints.real, } support = constraints.nonnegative_integer - def __init__( - self, - total_count, - *, - probs=None, - logits=None, - gate=None, - gate_logits=None, - validate_args=None - ): + def __init__(self, gate, total_count, probs=None, logits=None, validate_args=None): base_dist = NegativeBinomial( - total_count=total_count, - probs=probs, - logits=logits, - validate_args=False, + total_count=total_count, probs=probs, logits=logits, validate_args=False, ) base_dist._validate_args = validate_args - super().__init__( - base_dist, gate=gate, gate_logits=gate_logits, validate_args=validate_args - ) + super().__init__(gate, base_dist, validate_args=validate_args) @property def total_count(self): diff --git a/pts/modules/distribution_output.py b/pts/modules/distribution_output.py index 10f04a6..6aff1e1 100644 --- a/pts/modules/distribution_output.py +++ b/pts/modules/distribution_output.py @@ -201,7 +201,7 @@ class PoissonOutput(IndependentDistributionOutput): class ZeroInflatedPoissonOutput(IndependentDistributionOutput): - args_dim: Dict[str, int] = {"gate_logits": 1, "rate": 1} + args_dim: Dict[str, int] = {"gate": 1, "rate": 1} distr_cls: type = ZeroInflatedPoisson def __init__(self, dim: Optional[int] = None) -> None: @@ -210,20 +210,21 @@ class ZeroInflatedPoissonOutput(IndependentDistributionOutput): self.args_dim = {k: dim for k in self.args_dim} @classmethod - def domain_map(cls, gate_logits, rate): + def domain_map(cls, gate, rate): + gate_unit = torch.sigmoid(gate).clone() rate_pos = F.softplus(rate).clone() - return gate_logits.squeeze(-1), rate_pos.squeeze(-1) + return gate_unit.squeeze(-1), rate_pos.squeeze(-1) def distribution( self, distr_args, scale: Optional[torch.Tensor] = None ) -> Distribution: - gate_logits, rate = distr_args + gate, rate = distr_args if scale is not None: rate *= scale - return self.independent(ZeroInflatedPoisson(rate=rate, gate_logits=gate_logits)) + return self.independent(ZeroInflatedPoisson(gate=gate, rate=rate)) class NegativeBinomialOutput(IndependentDistributionOutput): @@ -254,7 +255,7 @@ class NegativeBinomialOutput(IndependentDistributionOutput): class ZeroInflatedNegativeBinomialOutput(IndependentDistributionOutput): - args_dim: Dict[str, int] = {"gate_logits": 1, "total_count": 1, "logits": 1} + args_dim: Dict[str, int] = {"gate": 1, "total_count": 1, "logits": 1} distr_cls: type = ZeroInflatedNegativeBinomial def __init__(self, dim: Optional[int] = None) -> None: @@ -263,22 +264,22 @@ class ZeroInflatedNegativeBinomialOutput(IndependentDistributionOutput): self.args_dim = {k: dim for k in self.args_dim} @classmethod - def domain_map(cls, gate_logits, total_count, logits): + def domain_map(cls, gate, total_count, logits): + gate = torch.sigmoid(gate) total_count = F.softplus(total_count) - return gate_logits.squeeze(-1), total_count.squeeze(-1), logits.squeeze(-1) + return gate.squeeze(-1), total_count.squeeze(-1), logits.squeeze(-1) def distribution( self, distr_args, scale: Optional[torch.Tensor] = None ) -> Distribution: - gate_logits, total_count, logits = distr_args + gate, total_count, logits = distr_args if scale is not None: logits += scale.log() return self.independent( ZeroInflatedNegativeBinomial( - total_count=total_count, - gate_logits=gate_logits, logits=logits + gate=gate, total_count=total_count, logits=logits ) ) diff --git a/test/distributions/test_zero_inflated.py b/test/distributions/test_zero_inflated.py index bb97382..3609f9f 100644 --- a/test/distributions/test_zero_inflated.py +++ b/test/distributions/test_zero_inflated.py @@ -25,7 +25,7 @@ def test_zid_shape(gate_shape, base_shape): gate = torch.rand(gate_shape) base_dist = Normal(torch.randn(base_shape), torch.randn(base_shape).exp()) - d = ZeroInflatedDistribution(base_dist, gate=gate) + d = ZeroInflatedDistribution(gate, base_dist) assert d.batch_shape == broadcast_shape(gate_shape, base_shape) assert d.support == base_dist.support @@ -36,22 +36,19 @@ def test_zid_shape(gate_shape, base_shape): @pytest.mark.parametrize("rate", [0.1, 0.5, 0.9, 1.0, 1.1, 2.0, 10.0]) def test_zip_0_gate(rate): # if gate is 0 ZIP is Poisson - zip1 = ZeroInflatedPoisson(torch.tensor(rate), gate=torch.zeros(1)) - zip2 = ZeroInflatedPoisson(torch.tensor(rate), gate_logits=torch.tensor(-99.9)) + zip_ = ZeroInflatedPoisson(torch.zeros(1), torch.tensor(rate)) pois = Poisson(torch.tensor(rate)) s = pois.sample((20,)) - zip1_prob = zip1.log_prob(s) - zip2_prob = zip2.log_prob(s) + zip_prob = zip_.log_prob(s) pois_prob = pois.log_prob(s) - assert_close(zip1_prob, pois_prob, atol=1e-05) - assert_close(zip2_prob, pois_prob, atol=1e-05) + assert_close(zip_prob, pois_prob, atol=1e-06) @pytest.mark.parametrize("gate", [0.0, 0.25, 0.5, 0.75, 1.0]) @pytest.mark.parametrize("rate", [0.1, 0.5, 0.9, 1.0, 1.1, 2.0, 10.0]) def test_zip_mean_variance(gate, rate): num_samples = 1000000 - zip_ = ZeroInflatedPoisson(torch.tensor(rate), gate=torch.tensor(gate)) + zip_ = ZeroInflatedPoisson(torch.tensor(gate), torch.tensor(rate)) s = zip_.sample((num_samples,)) expected_mean = zip_.mean estimated_mean = s.mean() @@ -65,23 +62,14 @@ def test_zip_mean_variance(gate, rate): @pytest.mark.parametrize("probs", [0.1, 0.5, 0.9]) def test_zinb_0_gate(total_count, probs): # if gate is 0 ZINB is NegativeBinomial - zinb1 = ZeroInflatedNegativeBinomial( - total_count=torch.tensor(total_count), - gate=torch.zeros(1), - probs=torch.tensor(probs), - ) - zinb2 = ZeroInflatedNegativeBinomial( - total_count=torch.tensor(total_count), - gate_logits=torch.tensor(-99.9), - probs=torch.tensor(probs), + zinb_ = ZeroInflatedNegativeBinomial( + torch.zeros(1), total_count=torch.tensor(total_count), probs=torch.tensor(probs) ) neg_bin = NegativeBinomial(torch.tensor(total_count), probs=torch.tensor(probs)) s = neg_bin.sample((20,)) - zinb1_prob = zinb1.log_prob(s) - zinb2_prob = zinb2.log_prob(s) + zinb_prob = zinb_.log_prob(s) neg_bin_prob = neg_bin.log_prob(s) - assert_close(zinb1_prob, neg_bin_prob, atol=1e-05) - assert_close(zinb2_prob, neg_bin_prob, atol=1e-05) + assert_close(zinb_prob, neg_bin_prob, atol=1e-06) @pytest.mark.parametrize("gate", [0.0, 0.25, 0.5, 0.75, 1.0]) @@ -90,8 +78,8 @@ def test_zinb_0_gate(total_count, probs): def test_zinb_mean_variance(gate, total_count, logits): num_samples = 1000000 zinb_ = ZeroInflatedNegativeBinomial( + torch.tensor(gate), total_count=torch.tensor(total_count), - gate=torch.tensor(gate), logits=torch.tensor(logits), ) s = zinb_.sample((num_samples,)) @@ -100,4 +88,4 @@ def test_zinb_mean_variance(gate, total_count, logits): expected_std = zinb_.stddev estimated_std = s.std() assert_close(expected_mean, estimated_mean, atol=1e-01) - assert_close(expected_std, estimated_std, atol=1e-01) + assert_close(expected_std, estimated_std, atol=1e-1)