added Masked Linear Layer

This commit is contained in:
Dr. Kashif Rasul
2020-01-17 11:47:35 +01:00
parent 04bdd7bbf2
commit 3b46a293c5
+21
View File
@@ -183,6 +183,27 @@ class LinearMaskedCoupling(nn.Module):
return x, log_abs_det_jacobian
class MaskedLinear(nn.Linear):
""" MADE building block layer """
def __init__(self, input_size, n_outputs, mask, cond_label_size=None):
super().__init__(input_size, n_outputs)
self.register_buffer("mask", mask)
self.cond_label_size = cond_label_size
if cond_label_size is not None:
self.cond_weight = nn.Parameter(
torch.rand(n_outputs, cond_label_size) / math.sqrt(cond_label_size)
)
def forward(self, x, y=None):
out = F.linear(x, self.weight * self.mask, self.bias)
if y is not None:
out = out + F.linear(y, self.cond_weight)
return out
class MADE(nn.Module):
def __init__(
self,