commit 55522f44ef013062287322d34c623b87f358535a Author: Kevin Date: Fri Aug 3 17:11:57 2018 -0700 done with extrapolation experiment. diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..b3ebba1 --- /dev/null +++ b/.gitignore @@ -0,0 +1,110 @@ +*.ipynb +ckpt/ +logs/ +plots/ +.DS_Store + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ diff --git a/README.md b/README.md new file mode 100644 index 0000000..e69de29 diff --git a/config.py b/config.py new file mode 100644 index 0000000..e69de29 diff --git a/data.py b/data.py new file mode 100644 index 0000000..e69de29 diff --git a/extrapolation.py b/extrapolation.py new file mode 100644 index 0000000..c901e7b --- /dev/null +++ b/extrapolation.py @@ -0,0 +1,77 @@ 
"""Extrapolation experiment.

For every activation function in ``NON_LINEARITIES`` a small MLP is trained
to reproduce its integer input over ``TRAIN_RANGE``; the absolute error of
the trained networks is then averaged over several runs on the wider
``TEST_RANGE`` and plotted.
"""
import os

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F

from models import MLP

# Inclusive integer bounds [low, high] of the training / evaluation inputs.
TRAIN_RANGE = [-5, 5]
TEST_RANGE = [-20, 20]
LEARNING_RATE = 1e-2
NUM_ITERS = int(1e4)
# Number of independently initialised networks averaged per activation.
NUM_RUNS = 100
# Training progress is printed every LOG_EVERY iterations.
LOG_EVERY = 1000
NON_LINEARITIES = [
    'hardtanh', 'sigmoid', 'relu6', 'tanh',
    'tanhshrink', 'hardshrink', 'leakyrelu',
    'softshrink', 'softsign', 'relu',
    'prelu', 'softplus', 'elu', 'selu',
]


def train(model, optimizer, data, num_iters):
    """Train ``model`` to reproduce ``data`` using a full-batch MSE loss.

    Args:
        model: callable mapping ``data`` to a tensor of the same shape.
        optimizer: optimizer holding ``model``'s parameters.
        data: tensor used both as the input and as the regression target.
        num_iters: number of gradient steps to take.
    """
    for i in range(num_iters):
        out = model(data)
        loss = F.mse_loss(out, data)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if i % LOG_EVERY == 0:
            # The MAE is only reported, never optimised: compute it outside
            # the autograd graph and only on iterations that print it.
            with torch.no_grad():
                mae = torch.mean(torch.abs(data - out))
            print("\t{}/{}: loss: {:.3f} - mae: {:.3f}".format(
                i + 1, num_iters, loss.item(), mae.item())
            )


def test(model, data):
    """Return the element-wise absolute error ``|data - model(data)|``.

    The forward pass runs under ``torch.no_grad()``.
    """
    with torch.no_grad():
        out = model(data)
    return torch.abs(data - out)


def main():
    """Run the experiment for every activation and save/show the error plot."""
    save_dir = './imgs/'
    # savefig does not create missing directories.
    os.makedirs(save_dir, exist_ok=True)

    # The configured bounds are inclusive while arange excludes its stop
    # value.  Work on local copies so repeated calls to main() do not keep
    # widening the module-level ranges.
    train_low, train_high = TRAIN_RANGE
    test_low, test_high = TEST_RANGE

    # datasets: one integer per row, shape (N, 1)
    train_data = torch.arange(train_low, train_high + 1).unsqueeze_(1).float()
    test_data = torch.arange(test_low, test_high + 1).unsqueeze_(1).float()

    # train
    all_errors = []
    for non_lin in NON_LINEARITIES:
        print("Working with {}...".format(non_lin))
        errors = []
        for _ in range(NUM_RUNS):
            net = MLP(non_lin)
            optim = torch.optim.Adam(net.parameters(), lr=LEARNING_RATE)
            train(net, optim, train_data, NUM_ITERS)
            errors.append(test(net, test_data))
        # Each entry is (N, 1); stack the runs as columns and average them
        # to get one mean absolute error per test point.
        all_errors.append(torch.cat(errors, dim=1).mean(dim=1).numpy())

    # plot
    fig, ax = plt.subplots(figsize=(8, 7))
    x_axis = np.arange(test_low, test_high + 1)
    for non_lin, errors in zip(NON_LINEARITIES, all_errors):
        ax.plot(x_axis, errors, label=non_lin)
    ax.grid()
    ax.legend(loc='best')
    ax.set_ylabel('Mean Absolute Error')
    fig.savefig(
        os.path.join(save_dir, 'extrapolation.png'), format='png', dpi=300
    )
    plt.show()


if __name__ == '__main__':
    main()
"""Multi-layer perceptron with a selectable activation function.

The package ``__init__`` re-exports the class (``from .mlp import MLP``).
"""
import math
import torch.nn as nn
import torch.nn.functional as F


class MLP(nn.Module):
    """Five-layer fully connected network: input -> 4 hidden layers -> output.

    Args:
        activation: name of the non-linearity applied after each of the
            first four linear layers; must be a key of ``_ACTIVATIONS``.
        input_dim: size of the input and of the output.
        encoding_dim: width of every hidden layer.

    Raises:
        ValueError: if ``activation`` is not a supported name.
    """

    # Maps the public activation names to the module classes implementing
    # them.  Names are matched by equality, so strings built at runtime work
    # as well as literals.
    _ACTIVATIONS = {
        'hardtanh': nn.Hardtanh,
        'sigmoid': nn.Sigmoid,
        'relu6': nn.ReLU6,
        'tanh': nn.Tanh,
        'tanhshrink': nn.Tanhshrink,
        'hardshrink': nn.Hardshrink,
        'leakyrelu': nn.LeakyReLU,
        'softshrink': nn.Softshrink,
        'softsign': nn.Softsign,
        'relu': nn.ReLU,
        'prelu': nn.PReLU,
        'softplus': nn.Softplus,
        'elu': nn.ELU,
        'selu': nn.SELU,
    }

    def __init__(self, activation, input_dim=1, encoding_dim=8):
        super().__init__()

        try:
            activation_cls = self._ACTIVATIONS[activation]
        except (KeyError, TypeError):
            # TypeError covers unhashable arguments such as lists.
            raise ValueError("[!] Invalid activation function.")
        # A single activation instance is reused after every layer, so a
        # parametrised activation (PReLU) shares its weight across layers.
        self.activation = activation_cls()

        self.i2h = nn.Linear(input_dim, encoding_dim)
        self.h2h1 = nn.Linear(encoding_dim, encoding_dim)
        self.h2h2 = nn.Linear(encoding_dim, encoding_dim)
        self.h2h3 = nn.Linear(encoding_dim, encoding_dim)
        self.h2o = nn.Linear(encoding_dim, input_dim)

        # init: Kaiming-uniform weights, biases uniform in +-1/sqrt(fan_in)
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_uniform_(m.weight, a=math.sqrt(5))
                # For a Linear weight of shape (out, in) the fan-in is the
                # number of input features.
                bound = 1 / math.sqrt(m.in_features)
                nn.init.uniform_(m.bias, -bound, bound)

    def forward(self, x):
        """Apply the network to ``x`` (last dimension ``input_dim``).

        The output layer is linear; no activation is applied to it.
        """
        out = self.activation(self.i2h(x))
        out = self.activation(self.h2h1(out))
        out = self.activation(self.h2h2(out))
        out = self.activation(self.h2h3(out))
        out = self.h2o(out)
        return out