mirror of
https://github.com/wassname/NALU-pytorch.git
synced 2026-06-27 16:00:06 +08:00
can NALU learn to do gaussian elimination?
This commit is contained in:
@@ -1,4 +1,3 @@
|
|||||||
*.ipynb
|
|
||||||
ckpt/
|
ckpt/
|
||||||
logs/
|
logs/
|
||||||
plots/
|
plots/
|
||||||
|
|||||||
@@ -0,0 +1,391 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"%load_ext autoreload\n",
|
||||||
|
"%autoreload 2"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import math\n",
|
||||||
|
"import random\n",
|
||||||
|
"import numpy as np\n",
|
||||||
|
"import matplotlib.pyplot as plt\n",
|
||||||
|
"\n",
|
||||||
|
"import torch\n",
|
||||||
|
"import torch.nn as nn\n",
|
||||||
|
"import torch.nn.functional as F\n",
|
||||||
|
"\n",
|
||||||
|
"from models import *"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def train(model, optimizer, data, target, num_iters):\n",
|
||||||
|
" for i in range(num_iters):\n",
|
||||||
|
" out = model(data)\n",
|
||||||
|
" loss = F.mse_loss(out, target)\n",
|
||||||
|
" mea = torch.mean(torch.abs(target - out))\n",
|
||||||
|
" optimizer.zero_grad()\n",
|
||||||
|
" loss.backward()\n",
|
||||||
|
" optimizer.step()\n",
|
||||||
|
" if i % 5000 == 0:\n",
|
||||||
|
" print(\"\\t{}/{}: loss: {:.3f} - mea: {:.3f}\".format(\n",
|
||||||
|
" i+1, num_iters, loss.item(), mea.item())\n",
|
||||||
|
" )"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Permutation"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# permute the first column with the third\n",
|
||||||
|
"\n",
|
||||||
|
"A = torch.from_numpy(np.array([\n",
|
||||||
|
" [0, 1, -1],\n",
|
||||||
|
" [3, -1, 1],\n",
|
||||||
|
" [1, 1, -2],\n",
|
||||||
|
"])).float()\n",
|
||||||
|
"\n",
|
||||||
|
"B = torch.from_numpy(np.array([\n",
|
||||||
|
" [-1, 1, -0],\n",
|
||||||
|
" [1, -1, 3],\n",
|
||||||
|
" [-2, 1, 1],\n",
|
||||||
|
"])).float()\n",
|
||||||
|
"\n",
|
||||||
|
"P = torch.from_numpy(np.array([\n",
|
||||||
|
" [0, 0, 1],\n",
|
||||||
|
" [0, 1, 0],\n",
|
||||||
|
" [1, 0, 0],\n",
|
||||||
|
"])).float()\n",
|
||||||
|
"\n",
|
||||||
|
"assert torch.allclose(torch.matmul(A, P), B)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 5,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"\t1/40000: loss: 1.444 - mea: 0.927\n",
|
||||||
|
"\t5001/40000: loss: 0.077 - mea: 0.177\n",
|
||||||
|
"\t10001/40000: loss: 0.000 - mea: 0.007\n",
|
||||||
|
"\t15001/40000: loss: 0.000 - mea: 0.000\n",
|
||||||
|
"\t20001/40000: loss: 0.000 - mea: 0.000\n",
|
||||||
|
"\t25001/40000: loss: 0.000 - mea: 0.000\n",
|
||||||
|
"\t30001/40000: loss: 0.000 - mea: 0.000\n",
|
||||||
|
"\t35001/40000: loss: 0.000 - mea: 0.000\n",
|
||||||
|
"tensor([[ 0.0000, -0.0000, 1.0000],\n",
|
||||||
|
" [-0.0001, 1.0000, -0.0001],\n",
|
||||||
|
" [ 1.0000, 0.0000, -0.0000]])\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"net = NeuralAccumulatorCell(3, 3)\n",
|
||||||
|
"optim = torch.optim.RMSprop(net.parameters(), lr=1e-4)\n",
|
||||||
|
"\n",
|
||||||
|
"train(net, optim, A, B, int(4e4))\n",
|
||||||
|
"\n",
|
||||||
|
"print(net.W.data)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 9,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"actual: \n",
|
||||||
|
"tensor([[ 0.0000, -0.0001, 1.0000],\n",
|
||||||
|
" [-0.0000, 1.0000, 0.0000],\n",
|
||||||
|
" [ 1.0000, -0.0001, -0.0000]])\n",
|
||||||
|
"\n",
|
||||||
|
"expected: \n",
|
||||||
|
"tensor([[ 0., 0., 1.],\n",
|
||||||
|
" [ 0., 1., 0.],\n",
|
||||||
|
" [ 1., 0., 0.]])\n",
|
||||||
|
"\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"print(\"actual: \\n{}\\n\".format(net.W.transpose(0, 1)))\n",
|
||||||
|
"print(\"expected: \\n{}\\n\".format(P))"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Column Scaling"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 10,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# scale the first column by 5\n",
|
||||||
|
"\n",
|
||||||
|
"A = torch.from_numpy(np.array([\n",
|
||||||
|
" [0, 1, -1],\n",
|
||||||
|
" [3, -1, 1],\n",
|
||||||
|
" [1, 1, -2],\n",
|
||||||
|
"])).float()\n",
|
||||||
|
"\n",
|
||||||
|
"B = torch.from_numpy(np.array([\n",
|
||||||
|
" [0, 1, -1],\n",
|
||||||
|
" [15, -1, 1],\n",
|
||||||
|
" [5, 1, -2],\n",
|
||||||
|
"])).float()\n",
|
||||||
|
"\n",
|
||||||
|
"P = torch.from_numpy(np.array([\n",
|
||||||
|
" [5, 0, 0],\n",
|
||||||
|
" [0, 1, 0],\n",
|
||||||
|
" [0, 0, 1],\n",
|
||||||
|
"])).float()\n",
|
||||||
|
"\n",
|
||||||
|
"assert torch.allclose(torch.matmul(A, P), B)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 11,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"\t1/90000: loss: 28.833 - mea: 3.000\n",
|
||||||
|
"\t5001/90000: loss: 21.383 - mea: 2.246\n",
|
||||||
|
"\t10001/90000: loss: 16.247 - mea: 1.900\n",
|
||||||
|
"\t15001/90000: loss: 11.851 - mea: 1.626\n",
|
||||||
|
"\t20001/90000: loss: 8.166 - mea: 1.367\n",
|
||||||
|
"\t25001/90000: loss: 5.191 - mea: 1.109\n",
|
||||||
|
"\t30001/90000: loss: 2.924 - mea: 0.852\n",
|
||||||
|
"\t35001/90000: loss: 1.366 - mea: 0.595\n",
|
||||||
|
"\t40001/90000: loss: 0.511 - mea: 0.340\n",
|
||||||
|
"\t45001/90000: loss: 0.207 - mea: 0.227\n",
|
||||||
|
"\t50001/90000: loss: 0.148 - mea: 0.184\n",
|
||||||
|
"\t55001/90000: loss: 0.103 - mea: 0.153\n",
|
||||||
|
"\t60001/90000: loss: 0.065 - mea: 0.122\n",
|
||||||
|
"\t65001/90000: loss: 0.036 - mea: 0.091\n",
|
||||||
|
"\t70001/90000: loss: 0.016 - mea: 0.060\n",
|
||||||
|
"\t75001/90000: loss: 0.004 - mea: 0.030\n",
|
||||||
|
"\t80001/90000: loss: 0.000 - mea: 0.001\n",
|
||||||
|
"\t85001/90000: loss: 0.000 - mea: 0.000\n",
|
||||||
|
"tensor([[ 5.0001e+00, -5.1951e-05, 5.1859e-05],\n",
|
||||||
|
" [-5.0003e-05, 1.0000e+00, -5.0024e-05],\n",
|
||||||
|
" [-5.0004e-05, 4.9928e-05, 9.9995e-01]])\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"net = NeuralAccumulatorCell(3, 3)\n",
|
||||||
|
"optim = torch.optim.RMSprop(net.parameters(), lr=1e-4)\n",
|
||||||
|
"\n",
|
||||||
|
"train(net, optim, A, B, int(9e4))\n",
|
||||||
|
"\n",
|
||||||
|
"print(net.W.data)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 12,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"actual: \n",
|
||||||
|
"tensor([[ 5.0001e+00, -5.0003e-05, -5.0004e-05],\n",
|
||||||
|
" [-5.1951e-05, 1.0000e+00, 4.9928e-05],\n",
|
||||||
|
" [ 5.1859e-05, -5.0024e-05, 9.9995e-01]])\n",
|
||||||
|
"\n",
|
||||||
|
"expected: \n",
|
||||||
|
"tensor([[ 5., 0., 0.],\n",
|
||||||
|
" [ 0., 1., 0.],\n",
|
||||||
|
" [ 0., 0., 1.]])\n",
|
||||||
|
"\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"print(\"actual: \\n{}\\n\".format(net.W.transpose(0, 1)))\n",
|
||||||
|
"print(\"expected: \\n{}\\n\".format(P))"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Column Elimination"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 18,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def basis_vec(k, n):\n",
|
||||||
|
" \"\"\"Creates the k'th standard basis vector in R^n.\"\"\"\n",
|
||||||
|
" error_msg = \"[!] k cannot exceed {}.\".format(n)\n",
|
||||||
|
" assert (k < n), error_msg\n",
|
||||||
|
" b = np.zeros([n, 1])\n",
|
||||||
|
" b[k] = 1\n",
|
||||||
|
" return b"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 19,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# add -3x the second column to the first => P = (I - (c)(e_k)(e_l.T))\n",
|
||||||
|
"\n",
|
||||||
|
"A = torch.from_numpy(np.array([\n",
|
||||||
|
" [3, 1, -1],\n",
|
||||||
|
" [3, -1, 1],\n",
|
||||||
|
" [1, 1, -2],\n",
|
||||||
|
"])).float()\n",
|
||||||
|
"\n",
|
||||||
|
"B = torch.from_numpy(np.array([\n",
|
||||||
|
" [0, 1, -1],\n",
|
||||||
|
" [6, -1, 1],\n",
|
||||||
|
" [-2, 1, -2],\n",
|
||||||
|
"])).float()\n",
|
||||||
|
"\n",
|
||||||
|
"P = torch.from_numpy(\n",
|
||||||
|
" np.eye(3) + (-3)*basis_vec(1, 3).dot(basis_vec(0, 3).T)\n",
|
||||||
|
").float()\n",
|
||||||
|
"\n",
|
||||||
|
"assert torch.allclose(torch.matmul(A, P), B)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 20,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"\t1/90000: loss: 5.444 - mea: 1.667\n",
|
||||||
|
"\t5001/90000: loss: 1.509 - mea: 0.679\n",
|
||||||
|
"\t10001/90000: loss: 0.226 - mea: 0.247\n",
|
||||||
|
"\t15001/90000: loss: 0.084 - mea: 0.165\n",
|
||||||
|
"\t20001/90000: loss: 0.038 - mea: 0.111\n",
|
||||||
|
"\t25001/90000: loss: 0.010 - mea: 0.057\n",
|
||||||
|
"\t30001/90000: loss: 0.000 - mea: 0.005\n",
|
||||||
|
"\t35001/90000: loss: 0.000 - mea: 0.000\n",
|
||||||
|
"\t40001/90000: loss: 0.000 - mea: 0.000\n",
|
||||||
|
"\t45001/90000: loss: 0.000 - mea: 0.000\n",
|
||||||
|
"\t50001/90000: loss: 0.000 - mea: 0.000\n",
|
||||||
|
"\t55001/90000: loss: 0.000 - mea: 0.000\n",
|
||||||
|
"\t60001/90000: loss: 0.000 - mea: 0.000\n",
|
||||||
|
"\t65001/90000: loss: 0.000 - mea: 0.000\n",
|
||||||
|
"\t70001/90000: loss: 0.000 - mea: 0.000\n",
|
||||||
|
"\t75001/90000: loss: 0.000 - mea: 0.000\n",
|
||||||
|
"\t80001/90000: loss: 0.000 - mea: 0.000\n",
|
||||||
|
"\t85001/90000: loss: 0.000 - mea: 0.000\n",
|
||||||
|
"tensor([[ 9.9995e-01, -3.0001e+00, 5.0084e-05],\n",
|
||||||
|
" [ 4.9894e-05, 1.0000e+00, -4.9870e-05],\n",
|
||||||
|
" [ 4.9999e-05, 4.9925e-05, 9.9995e-01]])\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"net = NeuralAccumulatorCell(3, 3)\n",
|
||||||
|
"optim = torch.optim.RMSprop(net.parameters(), lr=1e-4)\n",
|
||||||
|
"\n",
|
||||||
|
"train(net, optim, A, B, int(9e4))\n",
|
||||||
|
"\n",
|
||||||
|
"print(net.W.data)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 21,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"actual: \n",
|
||||||
|
"tensor([[ 9.9995e-01, 4.9894e-05, 4.9999e-05],\n",
|
||||||
|
" [-3.0001e+00, 1.0000e+00, 4.9925e-05],\n",
|
||||||
|
" [ 5.0084e-05, -4.9870e-05, 9.9995e-01]])\n",
|
||||||
|
"\n",
|
||||||
|
"expected: \n",
|
||||||
|
"tensor([[ 1., 0., 0.],\n",
|
||||||
|
" [-3., 1., 0.],\n",
|
||||||
|
" [ 0., 0., 1.]])\n",
|
||||||
|
"\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"print(\"actual: \\n{}\\n\".format(net.W.transpose(0, 1)))\n",
|
||||||
|
"print(\"expected: \\n{}\\n\".format(P))"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.6.5"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 2
|
||||||
|
}
|
||||||
+1
-1
@@ -58,7 +58,7 @@ def main():
|
|||||||
mses = []
|
mses = []
|
||||||
for i in range(100):
|
for i in range(100):
|
||||||
net = MLP(4, 1, 8, 1, non_lin)
|
net = MLP(4, 1, 8, 1, non_lin)
|
||||||
optim = torch.optim.Adam(net.parameters(), lr=LEARNING_RATE)
|
optim = torch.optim.RMSprop(net.parameters(), lr=LEARNING_RATE)
|
||||||
train(net, optim, train_data, NUM_ITERS)
|
train(net, optim, train_data, NUM_ITERS)
|
||||||
mses.append(test(net, test_data))
|
mses.append(test(net, test_data))
|
||||||
all_mses.append(torch.cat(mses, dim=1).mean(dim=1))
|
all_mses.append(torch.cat(mses, dim=1).mean(dim=1))
|
||||||
|
|||||||
Reference in New Issue
Block a user