From 3b46a293c50f94ef4e4e265297fcec72f4abd27c Mon Sep 17 00:00:00 2001 From: "Dr. Kashif Rasul" Date: Fri, 17 Jan 2020 11:47:35 +0100 Subject: [PATCH] added Masked Linear Layer --- pts/modules/flows.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/pts/modules/flows.py b/pts/modules/flows.py index ef2b2de..6c5d475 100644 --- a/pts/modules/flows.py +++ b/pts/modules/flows.py @@ -183,6 +183,27 @@ class LinearMaskedCoupling(nn.Module): return x, log_abs_det_jacobian +class MaskedLinear(nn.Linear): + """ MADE building block layer """ + + def __init__(self, input_size, n_outputs, mask, cond_label_size=None): + super().__init__(input_size, n_outputs) + + self.register_buffer("mask", mask) + + self.cond_label_size = cond_label_size + if cond_label_size is not None: + self.cond_weight = nn.Parameter( + torch.rand(n_outputs, cond_label_size) / math.sqrt(cond_label_size) + ) + + def forward(self, x, y=None): + out = F.linear(x, self.weight * self.mask, self.bias) + if y is not None: + out = out + F.linear(y, self.cond_weight) + return out + + class MADE(nn.Module): def __init__( self,