diff --git a/pts/distributions/zero_inflated.py b/pts/distributions/zero_inflated.py index 266fc82..2fa91dd 100644 --- a/pts/distributions/zero_inflated.py +++ b/pts/distributions/zero_inflated.py @@ -2,8 +2,14 @@ # SPDX-License-Identifier: Apache-2.0 import torch +import torch.nn.functional as F from torch.distributions import constraints, NegativeBinomial, Poisson, Distribution -from torch.distributions.utils import broadcast_all, lazy_property +from torch.distributions.utils import ( + broadcast_all, + lazy_property, + logits_to_probs, + probs_to_logits, +) from .utils import broadcast_shape @@ -15,20 +21,33 @@ class ZeroInflatedDistribution(Distribution): This can be used directly or can be used as a base class as e.g. for :class:`ZeroInflatedPoisson` and :class:`ZeroInflatedNegativeBinomial`. - :param torch.Tensor gate: probability of extra zeros given via a Bernoulli distribution. :param TorchDistribution base_dist: the base distribution. + :param torch.Tensor gate: probability of extra zeros given via a Bernoulli distribution. + :param torch.Tensor gate_logits: logits of extra zeros given via a Bernoulli distribution. """ - arg_constraints = {"gate": constraints.unit_interval} + arg_constraints = { + "gate": constraints.unit_interval, + "gate_logits": constraints.real, + } - def __init__(self, gate, base_dist, validate_args=None): + def __init__(self, base_dist, *, gate=None, gate_logits=None, validate_args=None): + if (gate is None) == (gate_logits is None): + raise ValueError( + "Either `gate` or `gate_logits` must be specified, but not both." 
+ ) + if gate is not None: + batch_shape = broadcast_shape(gate.shape, base_dist.batch_shape) + self.gate = gate.expand(batch_shape) + else: + batch_shape = broadcast_shape(gate_logits.shape, base_dist.batch_shape) + self.gate_logits = gate_logits.expand(batch_shape) if base_dist.event_shape: raise ValueError( "ZeroInflatedDistribution expected empty " "base_dist.event_shape but got {}".format(base_dist.event_shape) ) - batch_shape = broadcast_shape(gate.shape, base_dist.batch_shape) - self.gate = gate.expand(batch_shape) + self.base_dist = base_dist.expand(batch_shape) event_shape = torch.Size() @@ -38,13 +57,29 @@ class ZeroInflatedDistribution(Distribution): def support(self): return self.base_dist.support + @lazy_property + def gate(self): + return logits_to_probs(self.gate_logits, is_binary=True) + + @lazy_property + def gate_logits(self): + return probs_to_logits(self.gate, is_binary=True) + def log_prob(self, value): if self._validate_args: self._validate_sample(value) - gate, value = broadcast_all(self.gate, value) - log_prob = (-gate).log1p() + self.base_dist.log_prob(value) - log_prob = torch.where(value == 0, (gate + log_prob.exp()).log(), log_prob) + if "gate" in self.__dict__: + gate, value = broadcast_all(self.gate, value) + log_prob = (-gate).log1p() + self.base_dist.log_prob(value) + log_prob = torch.where(value == 0, (gate + log_prob.exp()).log(), log_prob) + else: + gate_logits, value = broadcast_all(self.gate_logits, value) + log_prob_minus_log_gate = -gate_logits + self.base_dist.log_prob(value) + log_gate = -F.softplus(-gate_logits) + log_prob = log_prob_minus_log_gate + log_gate + zero_log_prob = F.softplus(log_prob_minus_log_gate) + log_gate + log_prob = torch.where(value == 0, zero_log_prob, log_prob) return log_prob def sample(self, sample_shape=torch.Size()): @@ -68,9 +103,16 @@ class ZeroInflatedDistribution(Distribution): def expand(self, batch_shape, _instance=None): new = self._get_checked_instance(type(self), _instance) batch_shape = torch.Size(batch_shape) - gate = 
self.gate.expand(batch_shape) + gate = self.gate.expand(batch_shape) if "gate" in self.__dict__ else None + gate_logits = ( + self.gate_logits.expand(batch_shape) + if gate is None + else None + ) base_dist = self.base_dist.expand(batch_shape) - ZeroInflatedDistribution.__init__(new, gate, base_dist, validate_args=False) + ZeroInflatedDistribution.__init__( + new, base_dist, gate=gate, gate_logits=gate_logits, validate_args=False + ) new._validate_args = self._validate_args return new @@ -79,18 +121,25 @@ class ZeroInflatedPoisson(ZeroInflatedDistribution): """ A Zero Inflated Poisson distribution. - :param torch.Tensor gate: probability of extra zeros. :param torch.Tensor rate: rate of poisson distribution. + :param torch.Tensor gate: probability of extra zeros. + :param torch.Tensor gate_logits: logits of extra zeros. """ - arg_constraints = {"gate": constraints.unit_interval, "rate": constraints.positive} + arg_constraints = { + "rate": constraints.positive, + "gate": constraints.unit_interval, + "gate_logits": constraints.real, + } support = constraints.nonnegative_integer - def __init__(self, gate, rate, validate_args=None): + def __init__(self, rate, *, gate=None, gate_logits=None, validate_args=None): base_dist = Poisson(rate=rate, validate_args=False) base_dist._validate_args = validate_args - super().__init__(gate, base_dist, validate_args=validate_args) + super().__init__( + base_dist, gate=gate, gate_logits=gate_logits, validate_args=validate_args + ) @property def rate(self): @@ -101,28 +150,44 @@ class ZeroInflatedNegativeBinomial(ZeroInflatedDistribution): """ A Zero Inflated Negative Binomial distribution. - :param torch.Tensor gate: probability of extra zeros. :param total_count: non-negative number of negative Bernoulli trials. :type total_count: float or torch.Tensor - :param torch.Tensor probs: Event probabilities of success in the half open interval [0, 1). 
- :param torch.Tensor logits: Event log-odds for probabilities of success. + :param torch.Tensor probs: Event probabilities of failure in the half open interval [0, 1). + :param torch.Tensor logits: Event log-odds for probabilities of failure. + :param torch.Tensor gate: probability of extra zeros. + :param torch.Tensor gate_logits: logits of extra zeros. """ arg_constraints = { - "gate": constraints.unit_interval, "total_count": constraints.greater_than_eq(0), "probs": constraints.half_open_interval(0.0, 1.0), "logits": constraints.real, + "gate": constraints.unit_interval, + "gate_logits": constraints.real, } support = constraints.nonnegative_integer - def __init__(self, gate, total_count, probs=None, logits=None, validate_args=None): + def __init__( + self, + total_count, + *, + probs=None, + logits=None, + gate=None, + gate_logits=None, + validate_args=None + ): base_dist = NegativeBinomial( - total_count=total_count, probs=probs, logits=logits, validate_args=False, + total_count=total_count, + probs=probs, + logits=logits, + validate_args=False, ) base_dist._validate_args = validate_args - super().__init__(gate, base_dist, validate_args=validate_args) + super().__init__( + base_dist, gate=gate, gate_logits=gate_logits, validate_args=validate_args + ) @property def total_count(self): diff --git a/pts/modules/distribution_output.py b/pts/modules/distribution_output.py index 6aff1e1..10f04a6 100644 --- a/pts/modules/distribution_output.py +++ b/pts/modules/distribution_output.py @@ -201,7 +201,7 @@ class PoissonOutput(IndependentDistributionOutput): class ZeroInflatedPoissonOutput(IndependentDistributionOutput): - args_dim: Dict[str, int] = {"gate": 1, "rate": 1} + args_dim: Dict[str, int] = {"gate_logits": 1, "rate": 1} distr_cls: type = ZeroInflatedPoisson def __init__(self, dim: Optional[int] = None) -> None: @@ -210,21 +210,20 @@ class ZeroInflatedPoissonOutput(IndependentDistributionOutput): self.args_dim = {k: dim for k in self.args_dim} @classmethod - def 
domain_map(cls, gate, rate): - gate_unit = torch.sigmoid(gate).clone() + def domain_map(cls, gate_logits, rate): rate_pos = F.softplus(rate).clone() - return gate_unit.squeeze(-1), rate_pos.squeeze(-1) + return gate_logits.squeeze(-1), rate_pos.squeeze(-1) def distribution( self, distr_args, scale: Optional[torch.Tensor] = None ) -> Distribution: - gate, rate = distr_args + gate_logits, rate = distr_args if scale is not None: rate *= scale - return self.independent(ZeroInflatedPoisson(gate=gate, rate=rate)) + return self.independent(ZeroInflatedPoisson(rate=rate, gate_logits=gate_logits)) class NegativeBinomialOutput(IndependentDistributionOutput): @@ -255,7 +254,7 @@ class NegativeBinomialOutput(IndependentDistributionOutput): class ZeroInflatedNegativeBinomialOutput(IndependentDistributionOutput): - args_dim: Dict[str, int] = {"gate": 1, "total_count": 1, "logits": 1} + args_dim: Dict[str, int] = {"gate_logits": 1, "total_count": 1, "logits": 1} distr_cls: type = ZeroInflatedNegativeBinomial def __init__(self, dim: Optional[int] = None) -> None: @@ -264,22 +263,22 @@ class ZeroInflatedNegativeBinomialOutput(IndependentDistributionOutput): self.args_dim = {k: dim for k in self.args_dim} @classmethod - def domain_map(cls, gate, total_count, logits): - gate = torch.sigmoid(gate) + def domain_map(cls, gate_logits, total_count, logits): total_count = F.softplus(total_count) - return gate.squeeze(-1), total_count.squeeze(-1), logits.squeeze(-1) + return gate_logits.squeeze(-1), total_count.squeeze(-1), logits.squeeze(-1) def distribution( self, distr_args, scale: Optional[torch.Tensor] = None ) -> Distribution: - gate, total_count, logits = distr_args + gate_logits, total_count, logits = distr_args if scale is not None: logits += scale.log() return self.independent( ZeroInflatedNegativeBinomial( - gate=gate, total_count=total_count, logits=logits + total_count=total_count, + gate_logits=gate_logits, logits=logits ) ) diff --git 
a/test/distributions/test_zero_inflated.py b/test/distributions/test_zero_inflated.py index 3609f9f..bb97382 100644 --- a/test/distributions/test_zero_inflated.py +++ b/test/distributions/test_zero_inflated.py @@ -25,7 +25,7 @@ def test_zid_shape(gate_shape, base_shape): gate = torch.rand(gate_shape) base_dist = Normal(torch.randn(base_shape), torch.randn(base_shape).exp()) - d = ZeroInflatedDistribution(gate, base_dist) + d = ZeroInflatedDistribution(base_dist, gate=gate) assert d.batch_shape == broadcast_shape(gate_shape, base_shape) assert d.support == base_dist.support @@ -36,19 +36,22 @@ def test_zid_shape(gate_shape, base_shape): @pytest.mark.parametrize("rate", [0.1, 0.5, 0.9, 1.0, 1.1, 2.0, 10.0]) def test_zip_0_gate(rate): # if gate is 0 ZIP is Poisson - zip_ = ZeroInflatedPoisson(torch.zeros(1), torch.tensor(rate)) + zip1 = ZeroInflatedPoisson(torch.tensor(rate), gate=torch.zeros(1)) + zip2 = ZeroInflatedPoisson(torch.tensor(rate), gate_logits=torch.tensor(-99.9)) pois = Poisson(torch.tensor(rate)) s = pois.sample((20,)) - zip_prob = zip_.log_prob(s) + zip1_prob = zip1.log_prob(s) + zip2_prob = zip2.log_prob(s) pois_prob = pois.log_prob(s) - assert_close(zip_prob, pois_prob, atol=1e-06) + assert_close(zip1_prob, pois_prob, atol=1e-05) + assert_close(zip2_prob, pois_prob, atol=1e-05) @pytest.mark.parametrize("gate", [0.0, 0.25, 0.5, 0.75, 1.0]) @pytest.mark.parametrize("rate", [0.1, 0.5, 0.9, 1.0, 1.1, 2.0, 10.0]) def test_zip_mean_variance(gate, rate): num_samples = 1000000 - zip_ = ZeroInflatedPoisson(torch.tensor(gate), torch.tensor(rate)) + zip_ = ZeroInflatedPoisson(torch.tensor(rate), gate=torch.tensor(gate)) s = zip_.sample((num_samples,)) expected_mean = zip_.mean estimated_mean = s.mean() @@ -62,14 +65,23 @@ def test_zip_mean_variance(gate, rate): @pytest.mark.parametrize("probs", [0.1, 0.5, 0.9]) def test_zinb_0_gate(total_count, probs): # if gate is 0 ZINB is NegativeBinomial - zinb_ = ZeroInflatedNegativeBinomial( - torch.zeros(1), 
total_count=torch.tensor(total_count), probs=torch.tensor(probs) + zinb1 = ZeroInflatedNegativeBinomial( + total_count=torch.tensor(total_count), + gate=torch.zeros(1), + probs=torch.tensor(probs), + ) + zinb2 = ZeroInflatedNegativeBinomial( + total_count=torch.tensor(total_count), + gate_logits=torch.tensor(-99.9), + probs=torch.tensor(probs), ) neg_bin = NegativeBinomial(torch.tensor(total_count), probs=torch.tensor(probs)) s = neg_bin.sample((20,)) - zinb_prob = zinb_.log_prob(s) + zinb1_prob = zinb1.log_prob(s) + zinb2_prob = zinb2.log_prob(s) neg_bin_prob = neg_bin.log_prob(s) - assert_close(zinb_prob, neg_bin_prob, atol=1e-06) + assert_close(zinb1_prob, neg_bin_prob, atol=1e-05) + assert_close(zinb2_prob, neg_bin_prob, atol=1e-05) @pytest.mark.parametrize("gate", [0.0, 0.25, 0.5, 0.75, 1.0]) @@ -78,8 +90,8 @@ def test_zinb_0_gate(total_count, probs): def test_zinb_mean_variance(gate, total_count, logits): num_samples = 1000000 zinb_ = ZeroInflatedNegativeBinomial( - torch.tensor(gate), total_count=torch.tensor(total_count), + gate=torch.tensor(gate), logits=torch.tensor(logits), ) s = zinb_.sample((num_samples,)) @@ -88,4 +100,4 @@ def test_zinb_mean_variance(gate, total_count, logits): expected_std = zinb_.stddev estimated_std = s.std() assert_close(expected_mean, estimated_mean, atol=1e-01) - assert_close(expected_std, estimated_std, atol=1e-1) + assert_close(expected_std, estimated_std, atol=1e-01)