diff --git a/README.md b/README.md index 07e4edd..bf3b40a 100644 --- a/README.md +++ b/README.md @@ -1 +1,44 @@ # Neural Arithmetic Logic Units + +[WIP] + +This is a PyTorch implementation of [Neural Arithmetic Logic Units](https://arxiv.org/abs/1808.00508) by *Andrew Trask, Felix Hill, Scott Reed, Jack Rae, Chris Dyer and Phil Blunsom*. + +

+ Drawing +

+ +## API + +```python + +``` + +## Experiments + +To reproduce "Numerical Extrapolation Failures in Neural Networks" (Section 1.1), run: + +```python +python failures.py +``` + +This should generate the following plot: + +

+ Drawing +

+ +To reproduce "Simple Function Learning Tasks" (Section 4.1), run: + +```python +python function_learning.py +``` +This should generate a text file called `interpolation.txt` with the following results. (Currently only supports interpolation, I'm working on the rest. Also getting `nans` which I'm investigating.) + +| | Relu6 | None | NAC | NALU | +|-------|----------|----------|----------|--------| +| a + b | 0.002 | 0.000 | 0.000 | 1.399 | +| a - b | 0.046 | 0.000 | 0.000 | 0.224 | +| a * b | 83.012 | 99.590 | 98.822 | 12.237 | +| a / b | 2245.560 | 2888.195 | 2765.908 | nan | +| a ^ 2 | 76.126 | 99.106 | 99.559 | nan | diff --git a/config.py b/config.py deleted file mode 100644 index e69de29..0000000 diff --git a/data.py b/data.py deleted file mode 100644 index e69de29..0000000 diff --git a/extrapolation.py b/failures.py similarity index 100% rename from extrapolation.py rename to failures.py diff --git a/function_learning.py b/function_learning.py new file mode 100644 index 0000000..ee2bce5 --- /dev/null +++ b/function_learning.py @@ -0,0 +1,140 @@ +import math +import random +import numpy as np +import matplotlib.pyplot as plt + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from models import MultiLayerNet, MultiLayerNAC, MultiLayerNALU + +NORMALIZE = True +NUM_LAYERS = 2 +HIDDEN_DIM = 2 +LEARNING_RATE = 1e-3 +NUM_ITERS = int(8e4) +RANGE = [-5, 5] +ARITHMETIC_FUNCTIONS = { + 'add': lambda x, y: x + y, + 'sub': lambda x, y: x - y, + 'mul': lambda x, y: x * y, + 'div': lambda x, y: x / y, + 'squared': lambda x, y: torch.pow(x, 2), +} + + +def generate_data(num_train, num_test, dim, num_sum, fn, support): + data = torch.FloatTensor(dim).uniform_(*support).unsqueeze_(1) + X, y = [], [] + for i in range(num_train + num_test): + idx_a = random.sample(range(dim), num_sum) + idx_b = random.sample([x for x in range(dim) if x not in idx_a], num_sum) + a, b = data[idx_a].sum(), data[idx_b].sum() + X.append([a, b]) + y.append(fn(a, b)) + X = torch.FloatTensor(X) + y = torch.FloatTensor(y).unsqueeze_(1) + indices = list(range(num_train + num_test)) + np.random.shuffle(indices) + X_train, y_train = X[indices[num_test:]], y[indices[num_test:]] + X_test, y_test = X[indices[:num_test]], y[indices[:num_test]] + return X_train, y_train, X_test, y_test + + +def train(model, optimizer, data, target, num_iters): + for i in range(num_iters): + out = model(data) + loss = F.mse_loss(out, target) + mea = torch.mean(torch.abs(target - out)) + optimizer.zero_grad() + loss.backward() + optimizer.step() + if i % 1000 == 0: + print("\t{}/{}: loss: {:.7f} - mea: {:.7f}".format( + i+1, num_iters, loss.item(), mea.item()) + ) + + +def test(model, data, target): + with torch.no_grad(): + out = model(data) + return torch.abs(target - out) + + + +def main(): + save_dir = './results/' + + models = [ + MultiLayerNet( + 'relu6', + num_layers=NUM_LAYERS, + in_dim=2, + hidden_dim=HIDDEN_DIM, + out_dim=1 + ), + MultiLayerNet( + 'none', + num_layers=NUM_LAYERS, + in_dim=2, + hidden_dim=HIDDEN_DIM, + out_dim=1 + ), + MultiLayerNAC( + num_layers=NUM_LAYERS, + in_dim=2, + hidden_dim=HIDDEN_DIM, + out_dim=1 + ), + MultiLayerNALU( + num_layers=NUM_LAYERS, + in_dim=2, + hidden_dim=HIDDEN_DIM, + out_dim=1 + ), + ] + + results = {} + for fn_str, fn in ARITHMETIC_FUNCTIONS.items(): + results[fn_str] = [] + + # dataset + X_train, y_train, X_test, y_test = generate_data( + num_train=500, num_test=50, + dim=100, num_sum=5, fn=fn, + support=RANGE, + ) + + # random model + random_mse = [] + for i in range(100): + net = MultiLayerNet( + 'relu6', num_layers=NUM_LAYERS, + in_dim=2, hidden_dim=HIDDEN_DIM, out_dim=1 + ) + mse = test(net, X_test, y_test) + random_mse.append(mse.mean().item()) + results[fn_str].append(np.mean(random_mse)) + + # others + for net in models: + optim = torch.optim.Adam(net.parameters(), lr=LEARNING_RATE) + train(net, optim, X_train, y_train, NUM_ITERS) + mse = test(net, X_test, y_test).mean().item() + results[fn_str].append(mse) + + with open(save_dir + "interpolation.txt", "w") as f: + f.write("Relu6\tNone\tNAC\tNALU\n") + for k, v in results.items(): + rand = results[k][0] + mses = [100.0*x/rand for x in results[k][1:]] + if NORMALIZE: + f.write("{:.3f}\t{:.3f}\t{:.3f}\t{:.3f}\n".format(*mses)) + else: + f.write("{:.3f}\t{:.3f}\t{:.3f}\t{:.3f}\n".format(*results[k][1:])) + + +if __name__ == '__main__': + main() + diff --git a/imgs/arch.png b/imgs/arch.png new file mode 100644 index 0000000..9d9fd35 Binary files /dev/null and b/imgs/arch.png differ diff --git a/main.py b/main.py deleted file mode 100644 index e69de29..0000000 diff --git a/models/__init__.py b/models/__init__.py index 33f5ede..c9fbfa9 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -1,3 +1,4 @@ from .mlp import MLP from .nac import NAC from .nalu import NALU +from .models import MultiLayerNet, MultiLayerNAC, MultiLayerNALU diff --git a/models/mlp.py b/models/mlp.py index 032eb34..137d843 100644 --- a/models/mlp.py +++ b/models/mlp.py @@ -1,6 +1,5 @@ import math import torch.nn as nn -import torch.nn.functional as F class MLP(nn.Module): diff --git a/models/models.py b/models/models.py new file mode 100644 index 0000000..25f4349 --- /dev/null +++ b/models/models.py @@ -0,0 +1,123 @@ +import math +import torch +import torch.nn as nn + +from .nac import NAC +from .nalu import NALU + + +class MultiLayerNet(nn.Module): + def __init__(self, activation, num_layers, in_dim, hidden_dim, out_dim): + super().__init__() + self.num_layers = num_layers + self.in_dim = in_dim + self.hidden_dim = hidden_dim + self.out_dim = out_dim + + if activation is 'none': + self.activation = None + elif activation is 'hardtanh': + self.activation = nn.Hardtanh() + elif activation is 'sigmoid': + self.activation = nn.Sigmoid() + elif activation is 'relu6': + self.activation = nn.ReLU6() + elif activation is 'tanh': + self.activation = nn.Tanh() + elif activation is 'tanhshrink': + self.activation = nn.Tanhshrink() + elif activation is 'hardshrink': + self.activation = nn.Hardshrink() + elif activation is 'leakyrelu': + self.activation = nn.LeakyReLU() + elif activation is 'softshrink': + self.activation = nn.Softshrink() + elif activation is 'softsign': + self.activation = nn.Softsign() + elif activation is 'relu': + self.activation = nn.ReLU() + elif activation is 'prelu': + self.activation = nn.PReLU() + elif activation is 'softplus': + self.activation = nn.Softplus() + elif activation is 'elu': + self.activation = nn.ELU() + elif activation is 'selu': + self.activation = nn.SELU() + else: + raise ValueError("[!] Invalid activation function.") + + + layers = [] + if self.activation is not None: + layers.extend([ + nn.Linear(in_dim, hidden_dim), + self.activation, + ]) + else: + layers.append(nn.Linear(in_dim, hidden_dim)) + for i in range(num_layers - 2): + if self.activation is not None: + layers.extend([ + nn.Linear(hidden_dim, hidden_dim), + self.activation, + ]) + else: + layers.append(nn.Linear(hidden_dim, hidden_dim)) + layers.append(nn.Linear(hidden_dim, out_dim)) + + self.model = nn.Sequential(*layers) + + # init + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.kaiming_uniform_(m.weight, a=math.sqrt(5)) + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(m.weight) + bound = 1 / math.sqrt(fan_in) + nn.init.uniform_(m.bias, -bound, bound) + + def forward(self, x): + out = self.model(x) + return out + + +class MultiLayerNAC(nn.Module): + def __init__(self, num_layers, in_dim, hidden_dim, out_dim): + super().__init__() + self.num_layers = num_layers + self.in_dim = in_dim + self.hidden_dim = hidden_dim + self.out_dim = out_dim + + layers = [] + layers.append(NAC(in_dim, hidden_dim)) + for i in range(num_layers - 2): + layers.append(NAC(hidden_dim, hidden_dim)) + layers.append(NAC(hidden_dim, out_dim)) + + self.model = nn.Sequential(*layers) + + def forward(self, x): + out = self.model(x) + return out + + +class MultiLayerNALU(nn.Module): + def __init__(self, num_layers, in_dim, hidden_dim, out_dim): + super().__init__() + self.num_layers = num_layers + self.in_dim = in_dim + self.hidden_dim = hidden_dim + self.out_dim = out_dim + + layers = [] + layers.append(NALU(in_dim, hidden_dim)) + for i in range(num_layers - 2): + layers.append(NALU(hidden_dim, hidden_dim)) + layers.append(NALU(hidden_dim, out_dim)) + + self.model = nn.Sequential(*layers) + + def forward(self, x): + out = self.model(x) + return out diff --git a/models/nac.py b/models/nac.py index 01fdb78..c15e74a 100644 --- a/models/nac.py +++ b/models/nac.py @@ -38,4 +38,6 @@ class NAC(nn.Module): return F.linear(input, self.W, self.bias) def extra_repr(self): - return 'in_features={}, out_features={}'.format(self.in_features, self.out_features) + return 'in_features={}, out_features={}'.format( + self.in_features, self.out_features + ) diff --git a/models/nalu.py b/models/nalu.py index b2b0474..3bcdc7c 100644 --- a/models/nalu.py +++ b/models/nalu.py @@ -45,4 +45,6 @@ class NALU(nn.Module): return y def extra_repr(self): - return 'in_features={}, out_features={}'.format(self.in_features, self.out_features) + return 'in_features={}, out_features={}'.format( + self.in_features, self.out_features + ) diff --git a/results/interpolation.txt b/results/interpolation.txt new file mode 100644 index 0000000..8e9ec43 --- /dev/null +++ b/results/interpolation.txt @@ -0,0 +1,6 @@ +Relu6 None NAC NALU +0.002 0.000 0.000 1.399 +0.046 0.000 0.000 0.224 +83.012 99.590 98.822 12.237 +2245.560 2888.195 2765.908 nan +76.126 99.106 99.559 nan diff --git a/trainer.py b/trainer.py deleted file mode 100644 index e69de29..0000000 diff --git a/utils.py b/utils.py deleted file mode 100644 index e69de29..0000000