mirror of
https://github.com/wassname/NALU-pytorch.git
synced 2026-06-27 16:00:06 +08:00
done with extrapolation experiment.
This commit is contained in:
+110
@@ -0,0 +1,110 @@
|
||||
*.ipynb
|
||||
ckpt/
|
||||
logs/
|
||||
plots/
|
||||
.DS_Store
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# pyenv
|
||||
.python-version
|
||||
|
||||
# celery beat schedule file
|
||||
celerybeat-schedule
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
@@ -0,0 +1,77 @@
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from models import MLP
|
||||
|
||||
TRAIN_RANGE = [-5, 5]
|
||||
TEST_RANGE = [-20, 20]
|
||||
LEARNING_RATE = 1e-2
|
||||
NUM_ITERS = int(1e4)
|
||||
NON_LINEARITIES = [
|
||||
'hardtanh', 'sigmoid', 'relu6', 'tanh',
|
||||
'tanhshrink', 'hardshrink', 'leakyrelu',
|
||||
'softshrink', 'softsign', 'relu',
|
||||
'prelu', 'softplus', 'elu', 'selu',
|
||||
]
|
||||
|
||||
|
||||
def train(model, optimizer, data, num_iters):
|
||||
for i in range(num_iters):
|
||||
out = model(data)
|
||||
loss = F.mse_loss(out, data)
|
||||
mea = torch.mean(torch.abs(data - out))
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
if i % 1000 == 0:
|
||||
print("\t{}/{}: loss: {:.3f} - mea: {:.3f}".format(
|
||||
i+1, num_iters, loss.item(), mea.item())
|
||||
)
|
||||
|
||||
|
||||
def test(model, data):
|
||||
with torch.no_grad():
|
||||
out = model(data)
|
||||
return torch.abs(data - out)
|
||||
|
||||
|
||||
def main():
|
||||
save_dir = './imgs/'
|
||||
|
||||
TRAIN_RANGE[-1] += 1
|
||||
TEST_RANGE[-1] += 1
|
||||
|
||||
# datasets
|
||||
train_data = torch.arange(*TRAIN_RANGE).unsqueeze_(1).float()
|
||||
test_data = torch.arange(*TEST_RANGE).unsqueeze_(1).float()
|
||||
|
||||
# train
|
||||
all_mses = []
|
||||
for non_lin in NON_LINEARITIES:
|
||||
print("Working with {}...".format(non_lin))
|
||||
mses = []
|
||||
for i in range(100):
|
||||
net = MLP(non_lin)
|
||||
optim = torch.optim.Adam(net.parameters(), lr=LEARNING_RATE)
|
||||
train(net, optim, train_data, NUM_ITERS)
|
||||
mses.append(test(net, test_data))
|
||||
all_mses.append(torch.cat(mses, dim=1).mean(dim=1))
|
||||
all_mses = [x.numpy().flatten() for x in all_mses]
|
||||
|
||||
# plot
|
||||
fig, ax = plt.subplots(figsize=(8, 7))
|
||||
x_axis = np.arange(-20, 21)
|
||||
for i, non_lin in enumerate(NON_LINEARITIES):
|
||||
ax.plot(x_axis, all_mses[i], label=non_lin)
|
||||
plt.grid()
|
||||
plt.legend(loc='best')
|
||||
plt.ylabel('Mean Absolute Error')
|
||||
plt.savefig(save_dir + 'extrapolation.png', format='png', dpi=300)
|
||||
plt.show()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 392 KiB |
@@ -0,0 +1 @@
|
||||
from .mlp import MLP
|
||||
@@ -0,0 +1,61 @@
|
||||
import math
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, activation, input_dim=1, encoding_dim=8):
|
||||
super().__init__()
|
||||
|
||||
if activation is 'hardtanh':
|
||||
self.activation = nn.Hardtanh()
|
||||
elif activation is 'sigmoid':
|
||||
self.activation = nn.Sigmoid()
|
||||
elif activation is 'relu6':
|
||||
self.activation = nn.ReLU6()
|
||||
elif activation is 'tanh':
|
||||
self.activation = nn.Tanh()
|
||||
elif activation is 'tanhshrink':
|
||||
self.activation = nn.Tanhshrink()
|
||||
elif activation is 'hardshrink':
|
||||
self.activation = nn.Hardshrink()
|
||||
elif activation is 'leakyrelu':
|
||||
self.activation = nn.LeakyReLU()
|
||||
elif activation is 'softshrink':
|
||||
self.activation = nn.Softshrink()
|
||||
elif activation is 'softsign':
|
||||
self.activation = nn.Softsign()
|
||||
elif activation is 'relu':
|
||||
self.activation = nn.ReLU()
|
||||
elif activation is 'prelu':
|
||||
self.activation = nn.PReLU()
|
||||
elif activation is 'softplus':
|
||||
self.activation = nn.Softplus()
|
||||
elif activation is 'elu':
|
||||
self.activation = nn.ELU()
|
||||
elif activation is 'selu':
|
||||
self.activation = nn.SELU()
|
||||
else:
|
||||
raise ValueError("[!] Invalid activation function.")
|
||||
|
||||
self.i2h = nn.Linear(input_dim, encoding_dim)
|
||||
self.h2h1 = nn.Linear(encoding_dim, encoding_dim)
|
||||
self.h2h2 = nn.Linear(encoding_dim, encoding_dim)
|
||||
self.h2h3 = nn.Linear(encoding_dim, encoding_dim)
|
||||
self.h2o = nn.Linear(encoding_dim, input_dim)
|
||||
|
||||
# init
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Linear):
|
||||
nn.init.kaiming_uniform_(m.weight, a=math.sqrt(5))
|
||||
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(m.weight)
|
||||
bound = 1 / math.sqrt(fan_in)
|
||||
nn.init.uniform_(m.bias, -bound, bound)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.activation(self.i2h(x))
|
||||
out = self.activation(self.h2h1(out))
|
||||
out = self.activation(self.h2h2(out))
|
||||
out = self.activation(self.h2h3(out))
|
||||
out = self.h2o(out)
|
||||
return out
|
||||
Reference in New Issue
Block a user