diff --git a/README.md b/README.md
index e69de29..07e4edd 100644
--- a/README.md
+++ b/README.md
@@ -0,0 +1 @@
+# Neural Arithmetic Logic Units
diff --git a/extrapolation.py b/extrapolation.py
index c901e7b..c001e41 100644
--- a/extrapolation.py
+++ b/extrapolation.py
@@ -11,10 +11,13 @@ TEST_RANGE = [-20, 20]
 LEARNING_RATE = 1e-2
 NUM_ITERS = int(1e4)
 NON_LINEARITIES = [
-    'hardtanh', 'sigmoid', 'relu6', 'tanh',
-    'tanhshrink', 'hardshrink', 'leakyrelu',
-    'softshrink', 'softsign', 'relu',
-    'prelu', 'softplus', 'elu', 'selu',
+    'hardtanh', 'sigmoid',
+    'relu6', 'tanh',
+    'tanhshrink', 'hardshrink',
+    'leakyrelu', 'softshrink',
+    'softsign', 'relu',
+    'prelu', 'softplus',
+    'elu', 'selu',
 ]
 
 
diff --git a/imgs/extrapolation.png b/imgs/extrapolation.png
index 778852a..503d07d 100644
Binary files a/imgs/extrapolation.png and b/imgs/extrapolation.png differ
diff --git a/models/__init__.py b/models/__init__.py
index 0f2e1bf..33f5ede 100644
--- a/models/__init__.py
+++ b/models/__init__.py
@@ -1 +1,3 @@
 from .mlp import MLP
+from .nac import NAC
+from .nalu import NALU
diff --git a/models/nac.py b/models/nac.py
new file mode 100644
index 0000000..0791186
--- /dev/null
+++ b/models/nac.py
@@ -0,0 +1,54 @@
+import math
+import torch
+import torch.nn as nn
+import torch.nn.init as init
+import torch.nn.functional as F
+
+from torch.nn.parameter import Parameter
+
+
+class NAC(nn.Module):
+    """A Neural Accumulator [1].
+
+    NAC supports the ability to accumulate quantities
+    additively which is a desirable inductive bias for
+    linear extrapolation.
+
+    The effective weight matrix W = tanh(W_hat) * sigmoid(M_hat)
+    is rebuilt from the learnable parameters on every access, so
+    it tracks optimizer updates and gradients reach both of them.
+
+    Attributes:
+        in_features: size of the input sample.
+        out_features: size of the output sample.
+        W_hat: learnable tensor of shape (out_features, in_features).
+        M_hat: learnable tensor of shape (out_features, in_features).
+
+    Sources:
+        [1]: https://arxiv.org/abs/1808.00508
+    """
+    def __init__(self, in_features, out_features):
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+
+        self.W_hat = Parameter(torch.Tensor(out_features, in_features))
+        self.M_hat = Parameter(torch.Tensor(out_features, in_features))
+
+        init.kaiming_uniform_(self.W_hat, a=math.sqrt(5))
+        init.kaiming_uniform_(self.M_hat, a=math.sqrt(5))
+
+    @property
+    def W(self):
+        """Effective weights; every entry lies in [-1, 1]."""
+        # Derived on demand rather than cached in __init__: a tensor
+        # built once there would be computed from uninitialized memory
+        # and would never change while W_hat / M_hat are trained.
+        return torch.tanh(self.W_hat) * torch.sigmoid(self.M_hat)
+
+    def forward(self, input):
+        """Return `input @ W.T`; the layer has no bias term."""
+        return F.linear(input, self.W, None)
+
+    def extra_repr(self):
+        return 'in_features={}, out_features={}'.format(self.in_features, self.out_features)
diff --git a/models/nalu.py b/models/nalu.py
new file mode 100644
index 0000000..5c157d7
--- /dev/null
+++ b/models/nalu.py
@@ -0,0 +1,60 @@
+import math
+import torch
+import torch.nn as nn
+import torch.nn.init as init
+import torch.nn.functional as F
+
+from .nac import NAC
+from torch.nn.parameter import Parameter
+
+
+class NALU(nn.Module):
+    """A Neural Arithmetic Logic Unit [1].
+
+    NALU uses 2 NACs with tied weights to support
+    multiplicative extrapolation: a single NAC is applied
+    once to the raw input (add/subtract path) and once to
+    its log-magnitude (multiply/divide path), and a learned
+    sigmoid gate blends the two results.
+
+    Attributes:
+        in_features: size of the input sample.
+        out_features: size of the output sample.
+        eps: small constant keeping the log finite at zero.
+        G: learnable gate weights, (out_features, in_features).
+        nac: the NAC shared by both paths.
+
+    Sources:
+        [1]: https://arxiv.org/abs/1808.00508
+    """
+    def __init__(self, in_features, out_features):
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.eps = 1e-10
+
+        self.G = Parameter(torch.Tensor(out_features, in_features))
+        self.nac = NAC(in_features, out_features)
+
+        init.kaiming_uniform_(self.G, a=math.sqrt(5))
+
+    @property
+    def W(self):
+        """Effective weights of the shared NAC (read-only)."""
+        return self.nac.W
+
+    def forward(self, input):
+        """Blend the additive and multiplicative NAC outputs."""
+        # Gate in (0, 1): near 1 favours add/sub, near 0 mul/div.
+        g = torch.sigmoid(F.linear(input, self.G, None))
+        add_sub = g * self.nac(input)
+
+        # Same NAC in log-space, so sums of logs become products.
+        # Taking |input| discards the sign of the operands.
+        log_input = torch.log(torch.abs(input) + self.eps)
+        mul_div = (1 - g) * torch.exp(self.nac(log_input))
+
+        return add_sub + mul_div
+
+    def extra_repr(self):
+        return 'in_features={}, out_features={}'.format(self.in_features, self.out_features)