diff --git a/README.md b/README.md
index 07e4edd..bf3b40a 100644
--- a/README.md
+++ b/README.md
@@ -1 +1,44 @@
# Neural Arithmetic Logic Units
+
+[WIP]
+
+This is a PyTorch implementation of [Neural Arithmetic Logic Units](https://arxiv.org/abs/1808.00508) by *Andrew Trask, Felix Hill, Scott Reed, Jack Rae, Chris Dyer and Phil Blunsom*.
+
+
+
+
+
+## API
+
+```python
+
+```
+
+## Experiments
+
+To reproduce "Numerical Extrapolation Failures in Neural Networks" (Section 1.1), run:
+
+```python
+python failures.py
+```
+
+This should generate the following plot:
+
+
+
+
+
+To reproduce "Simple Function Learning Tasks" (Section 4.1), run:
+
+```python
+python function_learning.py
+```
+This should generate a text file called `interpolation.txt` with the following results. (Currently only supports interpolation, I'm working on the rest. Also getting `nans` which I'm investigating.)
+
+| | Relu6 | None | NAC | NALU |
+|-------|----------|----------|----------|--------|
+| a + b | 0.002 | 0.000 | 0.000 | 1.399 |
+| a - b | 0.046 | 0.000 | 0.000 | 0.224 |
+| a * b | 83.012 | 99.590 | 98.822 | 12.237 |
+| a / b | 2245.560 | 2888.195 | 2765.908 | nan |
+| a ^ 2 | 76.126 | 99.106 | 99.559 | nan |
diff --git a/config.py b/config.py
deleted file mode 100644
index e69de29..0000000
diff --git a/data.py b/data.py
deleted file mode 100644
index e69de29..0000000
diff --git a/extrapolation.py b/failures.py
similarity index 100%
rename from extrapolation.py
rename to failures.py
diff --git a/function_learning.py b/function_learning.py
new file mode 100644
index 0000000..ee2bce5
--- /dev/null
+++ b/function_learning.py
@@ -0,0 +1,140 @@
+import math
+import random
+import numpy as np
+import matplotlib.pyplot as plt
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from models import MultiLayerNet, MultiLayerNAC, MultiLayerNALU
+
+NORMALIZE = True
+NUM_LAYERS = 2
+HIDDEN_DIM = 2
+LEARNING_RATE = 1e-3
+NUM_ITERS = int(8e4)
+RANGE = [-5, 5]
+ARITHMETIC_FUNCTIONS = {
+ 'add': lambda x, y: x + y,
+ 'sub': lambda x, y: x - y,
+ 'mul': lambda x, y: x * y,
+ 'div': lambda x, y: x / y,
+ 'squared': lambda x, y: torch.pow(x, 2),
+}
+
+
+def generate_data(num_train, num_test, dim, num_sum, fn, support):
+ data = torch.FloatTensor(dim).uniform_(*support).unsqueeze_(1)
+ X, y = [], []
+ for i in range(num_train + num_test):
+ idx_a = random.sample(range(dim), num_sum)
+ idx_b = random.sample([x for x in range(dim) if x not in idx_a], num_sum)
+ a, b = data[idx_a].sum(), data[idx_b].sum()
+ X.append([a, b])
+ y.append(fn(a, b))
+ X = torch.FloatTensor(X)
+ y = torch.FloatTensor(y).unsqueeze_(1)
+ indices = list(range(num_train + num_test))
+ np.random.shuffle(indices)
+ X_train, y_train = X[indices[num_test:]], y[indices[num_test:]]
+ X_test, y_test = X[indices[:num_test]], y[indices[:num_test]]
+ return X_train, y_train, X_test, y_test
+
+
+def train(model, optimizer, data, target, num_iters):
+ for i in range(num_iters):
+ out = model(data)
+ loss = F.mse_loss(out, target)
+ mea = torch.mean(torch.abs(target - out))
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+ if i % 1000 == 0:
+ print("\t{}/{}: loss: {:.7f} - mea: {:.7f}".format(
+ i+1, num_iters, loss.item(), mea.item())
+ )
+
+
+def test(model, data, target):
+ with torch.no_grad():
+ out = model(data)
+ return torch.abs(target - out)
+
+
+
+def main():
+ save_dir = './results/'
+
+ models = [
+ MultiLayerNet(
+ 'relu6',
+ num_layers=NUM_LAYERS,
+ in_dim=2,
+ hidden_dim=HIDDEN_DIM,
+ out_dim=1
+ ),
+ MultiLayerNet(
+ 'none',
+ num_layers=NUM_LAYERS,
+ in_dim=2,
+ hidden_dim=HIDDEN_DIM,
+ out_dim=1
+ ),
+ MultiLayerNAC(
+ num_layers=NUM_LAYERS,
+ in_dim=2,
+ hidden_dim=HIDDEN_DIM,
+ out_dim=1
+ ),
+ MultiLayerNALU(
+ num_layers=NUM_LAYERS,
+ in_dim=2,
+ hidden_dim=HIDDEN_DIM,
+ out_dim=1
+ ),
+ ]
+
+ results = {}
+ for fn_str, fn in ARITHMETIC_FUNCTIONS.items():
+ results[fn_str] = []
+
+ # dataset
+ X_train, y_train, X_test, y_test = generate_data(
+ num_train=500, num_test=50,
+ dim=100, num_sum=5, fn=fn,
+ support=RANGE,
+ )
+
+ # random model
+ random_mse = []
+ for i in range(100):
+ net = MultiLayerNet(
+ 'relu6', num_layers=NUM_LAYERS,
+ in_dim=2, hidden_dim=HIDDEN_DIM, out_dim=1
+ )
+ mse = test(net, X_test, y_test)
+ random_mse.append(mse.mean().item())
+ results[fn_str].append(np.mean(random_mse))
+
+ # others
+ for net in models:
+ optim = torch.optim.Adam(net.parameters(), lr=LEARNING_RATE)
+ train(net, optim, X_train, y_train, NUM_ITERS)
+ mse = test(net, X_test, y_test).mean().item()
+ results[fn_str].append(mse)
+
+ with open(save_dir + "interpolation.txt", "w") as f:
+ f.write("Relu6\tNone\tNAC\tNALU\n")
+ for k, v in results.items():
+ rand = results[k][0]
+ mses = [100.0*x/rand for x in results[k][1:]]
+ if NORMALIZE:
+ f.write("{:.3f}\t{:.3f}\t{:.3f}\t{:.3f}\n".format(*mses))
+ else:
+ f.write("{:.3f}\t{:.3f}\t{:.3f}\t{:.3f}\n".format(*results[k][1:]))
+
+
+if __name__ == '__main__':
+ main()
+
diff --git a/imgs/arch.png b/imgs/arch.png
new file mode 100644
index 0000000..9d9fd35
Binary files /dev/null and b/imgs/arch.png differ
diff --git a/main.py b/main.py
deleted file mode 100644
index e69de29..0000000
diff --git a/models/__init__.py b/models/__init__.py
index 33f5ede..c9fbfa9 100644
--- a/models/__init__.py
+++ b/models/__init__.py
@@ -1,3 +1,4 @@
from .mlp import MLP
from .nac import NAC
from .nalu import NALU
+from .models import MultiLayerNet, MultiLayerNAC, MultiLayerNALU
diff --git a/models/mlp.py b/models/mlp.py
index 032eb34..137d843 100644
--- a/models/mlp.py
+++ b/models/mlp.py
@@ -1,6 +1,5 @@
import math
import torch.nn as nn
-import torch.nn.functional as F
class MLP(nn.Module):
diff --git a/models/models.py b/models/models.py
new file mode 100644
index 0000000..25f4349
--- /dev/null
+++ b/models/models.py
@@ -0,0 +1,123 @@
+import math
+import torch
+import torch.nn as nn
+
+from .nac import NAC
+from .nalu import NALU
+
+
+class MultiLayerNet(nn.Module):
+ def __init__(self, activation, num_layers, in_dim, hidden_dim, out_dim):
+ super().__init__()
+ self.num_layers = num_layers
+ self.in_dim = in_dim
+ self.hidden_dim = hidden_dim
+ self.out_dim = out_dim
+
+ if activation is 'none':
+ self.activation = None
+ elif activation is 'hardtanh':
+ self.activation = nn.Hardtanh()
+ elif activation is 'sigmoid':
+ self.activation = nn.Sigmoid()
+ elif activation is 'relu6':
+ self.activation = nn.ReLU6()
+ elif activation is 'tanh':
+ self.activation = nn.Tanh()
+ elif activation is 'tanhshrink':
+ self.activation = nn.Tanhshrink()
+ elif activation is 'hardshrink':
+ self.activation = nn.Hardshrink()
+ elif activation is 'leakyrelu':
+ self.activation = nn.LeakyReLU()
+ elif activation is 'softshrink':
+ self.activation = nn.Softshrink()
+ elif activation is 'softsign':
+ self.activation = nn.Softsign()
+ elif activation is 'relu':
+ self.activation = nn.ReLU()
+ elif activation is 'prelu':
+ self.activation = nn.PReLU()
+ elif activation is 'softplus':
+ self.activation = nn.Softplus()
+ elif activation is 'elu':
+ self.activation = nn.ELU()
+ elif activation is 'selu':
+ self.activation = nn.SELU()
+ else:
+ raise ValueError("[!] Invalid activation function.")
+
+
+ layers = []
+ if self.activation is not None:
+ layers.extend([
+ nn.Linear(in_dim, hidden_dim),
+ self.activation,
+ ])
+ else:
+ layers.append(nn.Linear(in_dim, hidden_dim))
+ for i in range(num_layers - 2):
+ if self.activation is not None:
+ layers.extend([
+ nn.Linear(hidden_dim, hidden_dim),
+ self.activation,
+ ])
+ else:
+ layers.append(nn.Linear(hidden_dim, hidden_dim))
+ layers.append(nn.Linear(hidden_dim, out_dim))
+
+ self.model = nn.Sequential(*layers)
+
+ # init
+ for m in self.modules():
+ if isinstance(m, nn.Linear):
+ nn.init.kaiming_uniform_(m.weight, a=math.sqrt(5))
+ fan_in, _ = nn.init._calculate_fan_in_and_fan_out(m.weight)
+ bound = 1 / math.sqrt(fan_in)
+ nn.init.uniform_(m.bias, -bound, bound)
+
+ def forward(self, x):
+ out = self.model(x)
+ return out
+
+
+class MultiLayerNAC(nn.Module):
+ def __init__(self, num_layers, in_dim, hidden_dim, out_dim):
+ super().__init__()
+ self.num_layers = num_layers
+ self.in_dim = in_dim
+ self.hidden_dim = hidden_dim
+ self.out_dim = out_dim
+
+ layers = []
+ layers.append(NAC(in_dim, hidden_dim))
+ for i in range(num_layers - 2):
+ layers.append(NAC(hidden_dim, hidden_dim))
+ layers.append(NAC(hidden_dim, out_dim))
+
+ self.model = nn.Sequential(*layers)
+
+ def forward(self, x):
+ out = self.model(x)
+ return out
+
+
+class MultiLayerNALU(nn.Module):
+ def __init__(self, num_layers, in_dim, hidden_dim, out_dim):
+ super().__init__()
+ self.num_layers = num_layers
+ self.in_dim = in_dim
+ self.hidden_dim = hidden_dim
+ self.out_dim = out_dim
+
+ layers = []
+ layers.append(NALU(in_dim, hidden_dim))
+ for i in range(num_layers - 2):
+ layers.append(NALU(hidden_dim, hidden_dim))
+ layers.append(NALU(hidden_dim, out_dim))
+
+ self.model = nn.Sequential(*layers)
+
+ def forward(self, x):
+ out = self.model(x)
+ return out
diff --git a/models/nac.py b/models/nac.py
index 01fdb78..c15e74a 100644
--- a/models/nac.py
+++ b/models/nac.py
@@ -38,4 +38,6 @@ class NAC(nn.Module):
return F.linear(input, self.W, self.bias)
def extra_repr(self):
- return 'in_features={}, out_features={}'.format(self.in_features, self.out_features)
+ return 'in_features={}, out_features={}'.format(
+ self.in_features, self.out_features
+ )
diff --git a/models/nalu.py b/models/nalu.py
index b2b0474..3bcdc7c 100644
--- a/models/nalu.py
+++ b/models/nalu.py
@@ -45,4 +45,6 @@ class NALU(nn.Module):
return y
def extra_repr(self):
- return 'in_features={}, out_features={}'.format(self.in_features, self.out_features)
+ return 'in_features={}, out_features={}'.format(
+ self.in_features, self.out_features
+ )
diff --git a/results/interpolation.txt b/results/interpolation.txt
new file mode 100644
index 0000000..8e9ec43
--- /dev/null
+++ b/results/interpolation.txt
@@ -0,0 +1,6 @@
+Relu6 None NAC NALU
+0.002 0.000 0.000 1.399
+0.046 0.000 0.000 0.224
+83.012 99.590 98.822 12.237
+2245.560 2888.195 2765.908 nan
+76.126 99.106 99.559 nan
diff --git a/trainer.py b/trainer.py
deleted file mode 100644
index e69de29..0000000
diff --git a/utils.py b/utils.py
deleted file mode 100644
index e69de29..0000000