shared weights fix

This commit is contained in:
Kevin
2018-08-06 06:25:12 -07:00
parent 002c97f21e
commit f4941c525f
+2 -4
View File
@@ -25,19 +25,17 @@ class NeuralArithmeticLogicUnitCell(nn.Module):
self.eps = 1e-10
self.G = Parameter(torch.Tensor(out_dim, in_dim))
self.W = Parameter(torch.Tensor(out_dim, in_dim))
self.register_parameter('bias', None)
self.nac = NeuralAccumulatorCell(in_dim, out_dim)
self.register_parameter('bias', None)
init.kaiming_uniform_(self.G, a=math.sqrt(5))
init.kaiming_uniform_(self.W, a=math.sqrt(5))
def forward(self, input):
a = self.nac(input)
g = F.sigmoid(F.linear(input, self.G, self.bias))
add_sub = g * a
log_input = torch.log(torch.abs(input) + self.eps)
m = torch.exp(F.linear(log_input, self.W, self.bias))
m = torch.exp(self.nac(log_input))
mul_div = (1 - g) * m
y = add_sub + mul_div
return y