done with function learning tasks.

This commit is contained in:
Kevin
2018-08-04 01:17:21 -07:00
parent 6ff9ff9c38
commit 5bf6e24b94
2 changed files with 4 additions and 3 deletions
+3 -2
View File
@@ -28,13 +28,14 @@ class NAC(nn.Module):
self.W_hat = Parameter(torch.Tensor(out_features, in_features))
self.M_hat = Parameter(torch.Tensor(out_features, in_features))
self.W = F.tanh(self.W_hat) * F.sigmoid(self.M_hat)
self.W = Parameter(F.tanh(self.W_hat) * F.sigmoid(self.M_hat))
self.register_parameter('bias', None)
init.kaiming_uniform_(self.W_hat, a=math.sqrt(5))
init.kaiming_uniform_(self.M_hat, a=math.sqrt(5))
def forward(self, input):
return F.linear(input, self.W, None)
return F.linear(input, self.W, self.bias)
def extra_repr(self):
return 'in_features={}, out_features={}'.format(self.in_features, self.out_features)
+1 -1
View File
@@ -37,7 +37,7 @@ class NALU(nn.Module):
def forward(self, input):
a = self.nac(input)
g = F.sigmoid(F.linear(input, self.G, None))
add_sub = a * g
add_sub = g * a
log_input = torch.log(torch.abs(input) + self.eps)
m = torch.exp(F.linear(log_input, self.W, None))
mul_div = (1 - g) * m