mirror of
https://github.com/wassname/pytorch-ts.git
synced 2026-07-03 15:44:31 +08:00
added Masked Linear Layer
This commit is contained in:
@@ -183,6 +183,27 @@ class LinearMaskedCoupling(nn.Module):
|
||||
return x, log_abs_det_jacobian
|
||||
|
||||
|
||||
class MaskedLinear(nn.Linear):
|
||||
""" MADE building block layer """
|
||||
|
||||
def __init__(self, input_size, n_outputs, mask, cond_label_size=None):
|
||||
super().__init__(input_size, n_outputs)
|
||||
|
||||
self.register_buffer("mask", mask)
|
||||
|
||||
self.cond_label_size = cond_label_size
|
||||
if cond_label_size is not None:
|
||||
self.cond_weight = nn.Parameter(
|
||||
torch.rand(n_outputs, cond_label_size) / math.sqrt(cond_label_size)
|
||||
)
|
||||
|
||||
def forward(self, x, y=None):
|
||||
out = F.linear(x, self.weight * self.mask, self.bias)
|
||||
if y is not None:
|
||||
out = out + F.linear(y, self.cond_weight)
|
||||
return out
|
||||
|
||||
|
||||
class MADE(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user