From 5bf6e24b94808bad99659201ce09d238a5e7c5f7 Mon Sep 17 00:00:00 2001 From: Kevin Date: Sat, 4 Aug 2018 01:17:21 -0700 Subject: [PATCH] done with function learning tasks. --- models/nac.py | 5 +++-- models/nalu.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/models/nac.py b/models/nac.py index 0791186..01fdb78 100644 --- a/models/nac.py +++ b/models/nac.py @@ -28,13 +28,14 @@ class NAC(nn.Module): self.W_hat = Parameter(torch.Tensor(out_features, in_features)) self.M_hat = Parameter(torch.Tensor(out_features, in_features)) - self.W = F.tanh(self.W_hat) * F.sigmoid(self.M_hat) + self.W = Parameter(F.tanh(self.W_hat) * F.sigmoid(self.M_hat)) + self.register_parameter('bias', None) init.kaiming_uniform_(self.W_hat, a=math.sqrt(5)) init.kaiming_uniform_(self.M_hat, a=math.sqrt(5)) def forward(self, input): - return F.linear(input, self.W, None) + return F.linear(input, self.W, self.bias) def extra_repr(self): return 'in_features={}, out_features={}'.format(self.in_features, self.out_features) diff --git a/models/nalu.py b/models/nalu.py index 5c157d7..b2b0474 100644 --- a/models/nalu.py +++ b/models/nalu.py @@ -37,7 +37,7 @@ class NALU(nn.Module): def forward(self, input): a = self.nac(input) g = F.sigmoid(F.linear(input, self.G, None)) - add_sub = a * g + add_sub = g * a log_input = torch.log(torch.abs(input) + self.eps) m = torch.exp(F.linear(log_input, self.W, None)) mul_div = (1 - g) * m