mirror of
https://github.com/wassname/NALU-pytorch.git
synced 2026-06-27 16:00:06 +08:00
shared weights fix
This commit is contained in:
+2
-4
@@ -25,19 +25,17 @@ class NeuralArithmeticLogicUnitCell(nn.Module):
|
||||
self.eps = 1e-10
|
||||
|
||||
self.G = Parameter(torch.Tensor(out_dim, in_dim))
|
||||
self.W = Parameter(torch.Tensor(out_dim, in_dim))
|
||||
self.register_parameter('bias', None)
|
||||
self.nac = NeuralAccumulatorCell(in_dim, out_dim)
|
||||
self.register_parameter('bias', None)
|
||||
|
||||
init.kaiming_uniform_(self.G, a=math.sqrt(5))
|
||||
init.kaiming_uniform_(self.W, a=math.sqrt(5))
|
||||
|
||||
def forward(self, input):
|
||||
a = self.nac(input)
|
||||
g = F.sigmoid(F.linear(input, self.G, self.bias))
|
||||
add_sub = g * a
|
||||
log_input = torch.log(torch.abs(input) + self.eps)
|
||||
m = torch.exp(F.linear(log_input, self.W, self.bias))
|
||||
m = torch.exp(self.nac(log_input))
|
||||
mul_div = (1 - g) * m
|
||||
y = add_sub + mul_div
|
||||
return y
|
||||
|
||||
Reference in New Issue
Block a user