working on static interpolation.

This commit is contained in:
Kevin
2018-08-04 02:53:30 -07:00
parent 5bf6e24b94
commit 42165bb02b
15 changed files with 319 additions and 3 deletions
+43
View File
@@ -1 +1,44 @@
# Neural Arithmetic Logic Units
[WIP]
This is a PyTorch implementation of [Neural Arithmetic Logic Units](https://arxiv.org/abs/1808.00508) by *Andrew Trask, Felix Hill, Scott Reed, Jack Rae, Chris Dyer and Phil Blunsom*.
<p align="center">
<img src="./imgs/arch.png" alt="Drawing" width="60%">
</p>
## API
```python
```
## Experiments
To reproduce "Numerical Extrapolation Failures in Neural Networks" (Section 1.1), run:
```bash
python failures.py
```
This should generate the following plot:
<p align="center">
<img src="./imgs/extrapolation.png" alt="Drawing" width="60%">
</p>
To reproduce "Simple Function Learning Tasks" (Section 4.1), run:
```bash
python function_learning.py
```
This should generate a text file called `interpolation.txt` with the following results. (Currently only interpolation is supported; I'm working on the rest. Some runs also produce `nan` values, which I'm still investigating.)
| | Relu6 | None | NAC | NALU |
|-------|----------|----------|----------|--------|
| a + b | 0.002 | 0.000 | 0.000 | 1.399 |
| a - b | 0.046 | 0.000 | 0.000 | 0.224 |
| a * b | 83.012 | 99.590 | 98.822 | 12.237 |
| a / b | 2245.560 | 2888.195 | 2765.908 | nan |
| a ^ 2 | 76.126 | 99.106 | 99.559 | nan |
View File
View File
View File
+140
View File
@@ -0,0 +1,140 @@
import math
import random
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from models import MultiLayerNet, MultiLayerNAC, MultiLayerNALU
# When True, main() writes each trained model's test error as a percentage of
# the error of a randomly initialised baseline; when False it writes the raw
# error values.
NORMALIZE = True
# Passed as `num_layers` to every model constructed in main().
NUM_LAYERS = 2
# Passed as `hidden_dim` to every model constructed in main().
HIDDEN_DIM = 2
# Learning rate handed to the Adam optimizer in main().
LEARNING_RATE = 1e-3
# Number of optimisation steps train() runs for each model.
NUM_ITERS = int(8e4)
# [low, high] bounds of the uniform distribution used by generate_data().
RANGE = [-5, 5]
# Target functions f(a, b) to learn, keyed by a short name. main() iterates
# this dict in insertion order, which fixes the row order of the results file.
# Note that 'squared' ignores its second argument.
ARITHMETIC_FUNCTIONS = {
'add': lambda x, y: x + y,
'sub': lambda x, y: x - y,
'mul': lambda x, y: x * y,
'div': lambda x, y: x / y,
'squared': lambda x, y: torch.pow(x, 2),
}
def generate_data(num_train, num_test, dim, num_sum, fn, support):
    """Build a random train/test split for learning ``fn(a, b)``.

    A pool of ``dim`` values is drawn uniformly from ``support``. Each sample
    picks two disjoint sets of ``num_sum`` pool indices; ``a`` and ``b`` are
    the sums of the pool values at those indices, and the target is
    ``fn(a, b)``.

    Args:
        num_train: Number of training samples to return.
        num_test: Number of test samples to return.
        dim: Size of the value pool.
        num_sum: How many pool entries are summed to form each operand.
        fn: Callable taking ``(a, b)`` and returning the target value.
        support: ``(low, high)`` bounds unpacked into ``uniform_``.

    Returns:
        Tuple ``(X_train, y_train, X_test, y_test)``. The ``X`` tensors have
        two columns (``a`` and ``b``); the ``y`` tensors have one column.
    """
    total = num_train + num_test
    pool = torch.FloatTensor(dim).uniform_(*support).unsqueeze_(1)

    inputs, targets = [], []
    for _ in range(total):
        picked_a = random.sample(range(dim), num_sum)
        # The second operand only draws from indices the first did not use.
        leftover = [j for j in range(dim) if j not in picked_a]
        picked_b = random.sample(leftover, num_sum)

        a = pool[picked_a].sum()
        b = pool[picked_b].sum()
        inputs.append([a, b])
        targets.append(fn(a, b))

    X = torch.FloatTensor(inputs)
    y = torch.FloatTensor(targets).unsqueeze_(1)

    # Shuffle once, then carve the first `num_test` positions off as test set.
    order = list(range(total))
    np.random.shuffle(order)
    test_idx = order[:num_test]
    train_idx = order[num_test:]

    return X[train_idx], y[train_idx], X[test_idx], y[test_idx]
def train(model, optimizer, data, target, num_iters):
    """Fit ``model`` to ``(data, target)`` by minimising the MSE loss.

    Every step runs the model on the whole of ``data``. Progress (loss and
    mean absolute error) is printed on step 0 and every 1000 steps after.

    Args:
        model: Callable module mapping ``data`` to predictions.
        optimizer: Optimizer already bound to the model's parameters.
        data: Input tensor fed to the model.
        target: Tensor the predictions are compared against.
        num_iters: Number of optimisation steps to perform.
    """
    for step in range(num_iters):
        prediction = model(data)
        loss = F.mse_loss(prediction, target)
        mean_abs_err = (target - prediction).abs().mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % 1000 != 0:
            continue
        message = "\t{}/{}: loss: {:.7f} - mea: {:.7f}".format(
            step + 1, num_iters, loss.item(), mean_abs_err.item()
        )
        print(message)
def test(model, data, target):
with torch.no_grad():
out = model(data)
return torch.abs(target - out)
def _build_models():
    """Return freshly initialised models in results-file column order.

    The order (Relu6 MLP, activation-free MLP, NAC, NALU) matches the header
    line written by main().
    """
    common = dict(
        num_layers=NUM_LAYERS,
        in_dim=2,
        hidden_dim=HIDDEN_DIM,
        out_dim=1,
    )
    return [
        MultiLayerNet('relu6', **common),
        MultiLayerNet('none', **common),
        MultiLayerNAC(**common),
        MultiLayerNALU(**common),
    ]


def main():
    """Run the interpolation experiment and write ``interpolation.txt``.

    For every entry of ARITHMETIC_FUNCTIONS this generates a dataset, measures
    the mean absolute test error of 100 randomly initialised Relu6 networks as
    a baseline, then trains one model of each kind and records its mean
    absolute test error. One tab-separated row per function is written to
    ``./results/interpolation.txt``; with NORMALIZE set, each value is the
    percentage of the random baseline's error.
    """
    import os

    save_dir = './results/'
    # Create the output directory up front so a missing folder fails (or is
    # fixed) before hours of training rather than after.
    os.makedirs(save_dir, exist_ok=True)

    results = {}
    for fn_str, fn in ARITHMETIC_FUNCTIONS.items():
        results[fn_str] = []

        # dataset
        X_train, y_train, X_test, y_test = generate_data(
            num_train=500, num_test=50,
            dim=100, num_sum=5, fn=fn,
            support=RANGE,
        )

        # random model baseline: average error of untrained networks
        random_errors = []
        for _ in range(100):
            net = MultiLayerNet(
                'relu6', num_layers=NUM_LAYERS,
                in_dim=2, hidden_dim=HIDDEN_DIM, out_dim=1
            )
            err = test(net, X_test, y_test)
            random_errors.append(err.mean().item())
        results[fn_str].append(np.mean(random_errors))

        # others: build new models for every function so that weights (and
        # any NaNs) learned on one task never leak into the next one.
        for net in _build_models():
            optim = torch.optim.Adam(net.parameters(), lr=LEARNING_RATE)
            train(net, optim, X_train, y_train, NUM_ITERS)
            err = test(net, X_test, y_test).mean().item()
            results[fn_str].append(err)

    row_format = "{:.3f}\t{:.3f}\t{:.3f}\t{:.3f}\n"
    with open(save_dir + "interpolation.txt", "w") as f:
        f.write("Relu6\tNone\tNAC\tNALU\n")
        for scores in results.values():
            # scores[0] is the random baseline, the rest are trained models.
            rand, trained = scores[0], scores[1:]
            if NORMALIZE:
                row = [100.0 * x / rand for x in trained]
            else:
                row = trained
            f.write(row_format.format(*row))


if __name__ == '__main__':
    main()
BIN
View File
Binary file not shown.

After

Width:  |  Height:  |  Size: 263 KiB

View File
+1
View File
@@ -1,3 +1,4 @@
from .mlp import MLP
from .nac import NAC
from .nalu import NALU
from .models import MultiLayerNet, MultiLayerNAC, MultiLayerNALU
-1
View File
@@ -1,6 +1,5 @@
import math
import torch.nn as nn
import torch.nn.functional as F
class MLP(nn.Module):
+123
View File
@@ -0,0 +1,123 @@
import math
import torch
import torch.nn as nn
from .nac import NAC
from .nalu import NALU
class MultiLayerNet(nn.Module):
    """Fully connected network with a selectable activation function.

    The network is ``Linear(in_dim, hidden_dim)``, followed by
    ``num_layers - 2`` further ``Linear(hidden_dim, hidden_dim)`` layers, and a
    final ``Linear(hidden_dim, out_dim)``. The chosen activation follows every
    layer except the last. For ``num_layers <= 2`` the network therefore has
    exactly two linear layers.

    A single activation module instance is reused after every hidden layer, so
    any parameters that module holds are shared across layers.

    Args:
        activation: Name of the activation; one of the keys of
            ``_ACTIVATIONS``. ``'none'`` disables the non-linearity.
        num_layers: Requested number of linear layers.
        in_dim: Number of input features.
        hidden_dim: Width of the hidden layers.
        out_dim: Number of output features.

    Raises:
        ValueError: If ``activation`` is not a recognised name.
    """

    # Activation name -> module class (None means "no activation").
    # Classes are stored rather than instances so only the requested
    # activation is ever constructed.
    _ACTIVATIONS = {
        'none': None,
        'hardtanh': nn.Hardtanh,
        'sigmoid': nn.Sigmoid,
        'relu6': nn.ReLU6,
        'tanh': nn.Tanh,
        'tanhshrink': nn.Tanhshrink,
        'hardshrink': nn.Hardshrink,
        'leakyrelu': nn.LeakyReLU,
        'softshrink': nn.Softshrink,
        'softsign': nn.Softsign,
        'relu': nn.ReLU,
        'prelu': nn.PReLU,
        'softplus': nn.Softplus,
        'elu': nn.ELU,
        'selu': nn.SELU,
    }

    def __init__(self, activation, num_layers, in_dim, hidden_dim, out_dim):
        super().__init__()
        self.num_layers = num_layers
        self.in_dim = in_dim
        self.hidden_dim = hidden_dim
        self.out_dim = out_dim

        # Look the name up by value. Comparing strings with `is` only works
        # when both happen to be the same interned object.
        try:
            activation_cls = self._ACTIVATIONS[activation]
        except (KeyError, TypeError):
            raise ValueError("[!] Invalid activation function.") from None
        self.activation = None if activation_cls is None else activation_cls()

        # Widths of the layers that are followed by the activation.
        widths = [in_dim, hidden_dim] + [hidden_dim] * (num_layers - 2)
        layers = []
        for n_in, n_out in zip(widths, widths[1:]):
            layers.append(nn.Linear(n_in, n_out))
            if self.activation is not None:
                layers.append(self.activation)
        layers.append(nn.Linear(hidden_dim, out_dim))
        self.model = nn.Sequential(*layers)

        # init
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_uniform_(m.weight, a=math.sqrt(5))
                fan_in, _ = nn.init._calculate_fan_in_and_fan_out(m.weight)
                bound = 1 / math.sqrt(fan_in)
                nn.init.uniform_(m.bias, -bound, bound)

    def forward(self, x):
        """Apply the layer stack to ``x`` and return the result."""
        out = self.model(x)
        return out
class MultiLayerNAC(nn.Module):
    """Sequential stack of NAC layers.

    The stack is ``NAC(in_dim, hidden_dim)``, then ``num_layers - 2`` layers of
    ``NAC(hidden_dim, hidden_dim)``, then ``NAC(hidden_dim, out_dim)``.
    """

    def __init__(self, num_layers, in_dim, hidden_dim, out_dim):
        super().__init__()
        self.num_layers, self.in_dim = num_layers, in_dim
        self.hidden_dim, self.out_dim = hidden_dim, out_dim

        # Consecutive entries of `widths` give each layer's (in, out) sizes.
        widths = [in_dim, hidden_dim] + [hidden_dim] * (num_layers - 2) + [out_dim]
        self.model = nn.Sequential(
            *(NAC(n_in, n_out) for n_in, n_out in zip(widths, widths[1:]))
        )

    def forward(self, x):
        """Run ``x`` through every NAC layer in order."""
        return self.model(x)
class MultiLayerNALU(nn.Module):
    """Sequential stack of NALU layers.

    The stack is ``NALU(in_dim, hidden_dim)``, then ``num_layers - 2`` layers
    of ``NALU(hidden_dim, hidden_dim)``, then ``NALU(hidden_dim, out_dim)``.
    """

    def __init__(self, num_layers, in_dim, hidden_dim, out_dim):
        super().__init__()
        self.num_layers, self.in_dim = num_layers, in_dim
        self.hidden_dim, self.out_dim = hidden_dim, out_dim

        # Consecutive entries of `widths` give each layer's (in, out) sizes.
        widths = [in_dim, hidden_dim] + [hidden_dim] * (num_layers - 2) + [out_dim]
        self.model = nn.Sequential(
            *(NALU(n_in, n_out) for n_in, n_out in zip(widths, widths[1:]))
        )

    def forward(self, x):
        """Run ``x`` through every NALU layer in order."""
        return self.model(x)
+3 -1
View File
@@ -38,4 +38,6 @@ class NAC(nn.Module):
return F.linear(input, self.W, self.bias)
def extra_repr(self):
return 'in_features={}, out_features={}'.format(self.in_features, self.out_features)
return 'in_features={}, out_features={}'.format(
self.in_features, self.out_features
)
+3 -1
View File
@@ -45,4 +45,6 @@ class NALU(nn.Module):
return y
def extra_repr(self):
return 'in_features={}, out_features={}'.format(self.in_features, self.out_features)
return 'in_features={}, out_features={}'.format(
self.in_features, self.out_features
)
+6
View File
@@ -0,0 +1,6 @@
Relu6 None NAC NALU
0.002 0.000 0.000 1.399
0.046 0.000 0.000 0.224
83.012 99.590 98.822 12.237
2245.560 2888.195 2765.908 nan
76.126 99.106 99.559 nan
View File
View File