From 309833ed14c7e1de3e079f4076f2664caee8ac6b Mon Sep 17 00:00:00 2001 From: wassname Date: Fri, 13 Nov 2020 15:56:16 +0800 Subject: [PATCH] hyperparam opt --- notebooks/07.1-mc-optuna.ipynb | 1628 ++++++++++++++++++++ notebooks/07.1-mc-optuna.py | 539 +++++++ seq2seq_time/models/neural_process.py | 6 +- seq2seq_time/models/transformer_process.py | 7 +- 4 files changed, 2175 insertions(+), 5 deletions(-) create mode 100644 notebooks/07.1-mc-optuna.ipynb create mode 100644 notebooks/07.1-mc-optuna.py diff --git a/notebooks/07.1-mc-optuna.ipynb b/notebooks/07.1-mc-optuna.ipynb new file mode 100644 index 0000000..462068e --- /dev/null +++ b/notebooks/07.1-mc-optuna.ipynb @@ -0,0 +1,1628 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "ExecuteTime": { + "end_time": "2020-10-10T01:25:12.788851Z", + "start_time": "2020-10-10T01:25:12.783398Z" + } + }, + "source": [ + "# Sequence to Sequence Models for Timeseries Regression\n", + "\n", + "\n", + "In this notebook we are going to find the optimal hidden_size for a model vs a dataset. We will use pytorch lightning and optuna." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "ExecuteTime": { + "end_time": "2020-11-13T07:55:31.364813Z", + "start_time": "2020-11-13T07:55:31.025418Z" + } + }, + "outputs": [], + "source": [ + "# OPTIONAL: Load the \"autoreload\" extension so that code can change. But blacklist large modules\n", + "%load_ext autoreload\n", + "%autoreload 2\n", + "%aimport -pandas\n", + "%aimport -torch\n", + "%aimport -numpy\n", + "%aimport -matplotlib\n", + "%aimport -dask\n", + "%aimport -tqdm\n", + "%matplotlib inline" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "ExecuteTime": { + "end_time": "2020-11-13T07:55:32.324191Z", + "start_time": "2020-11-13T07:55:31.367170Z" + }, + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "# Imports\n", + "import torch\n", + "from torch import nn, optim\n", + "from torch.nn import functional as F\n", + "from torch.autograd import Variable\n", + "import torch\n", + "import torch.utils.data\n", + "\n", + "\n", + "from pathlib import Path\n", + "from tqdm.auto import tqdm\n", + "\n", + "import pytorch_lightning as pl" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "ExecuteTime": { + "end_time": "2020-11-13T07:55:32.863984Z", + "start_time": "2020-11-13T07:55:32.326782Z" + } + }, + "outputs": [], + "source": [ + "from seq2seq_time.data.dataset import Seq2SeqDataSet, Seq2SeqDataSets\n", + "from seq2seq_time.predict import predict, predict_multi\n", + "from seq2seq_time.util import dset_to_nc" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "ExecuteTime": { + "end_time": "2020-11-13T07:55:32.900137Z", + "start_time": "2020-11-13T07:55:32.866551Z" + } + }, + "outputs": [], + "source": [ + "import logging\n", + "import warnings\n", + "import seq2seq_time.silence \n", + "warnings.simplefilter('once')\n", + "warnings.simplefilter(action='ignore', category=FutureWarning)\n", + "warnings.simplefilter(action='ignore', category=DeprecationWarning)\n", + "warnings.filterwarnings('ignore', 'Consider increasing the value of the `num_workers` argument', UserWarning)\n", + "warnings.filterwarnings('ignore', 'Your val_dataloader has `shuffle=True`', UserWarning)\n", + "\n", + "from pytorch_lightning import _logger as log\n", + "log.setLevel(logging.WARN)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "ExecuteTime": { + "end_time": "2020-10-10T01:28:32.492160Z", + "start_time": "2020-10-10T01:28:32.488140Z" + } + }, + "source": [ + "## Parameters" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "ExecuteTime": { + "end_time": "2020-11-13T07:55:32.992818Z", + "start_time": "2020-11-13T07:55:32.902453Z" + }, + "lines_to_next_cell": 0 + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "using cuda\n", + "20201108-300000\n" + ] + }, + { + "data": { + "text/plain": [ + "336" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "print(f'using {device}')\n", + "\n", + "timestamp = '20201108-300000'\n", + "print(timestamp)\n", + "window_past = 48*7\n", + "window_future = 48\n", + "batch_size = 64\n", + "num_workers = 4\n", + "datasets_root = Path('../data/processed/')\n", + "window_past" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Datasets\n", + "\n", + "From easy to hard, these dataset show different challenges, all of them with more than 20k datapoints and with a regression output. See the 00.01 notebook for more details, and the code for more information.\n", + "\n", + "Some such as MetroInterstateTraffic are easier, some are periodic such as BejingPM25, some are conditional on inputs such as GasSensor, and some are noisy and periodic like IMOSCurrentsVel" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "ExecuteTime": { + "end_time": "2020-11-13T07:55:33.715323Z", + "start_time": "2020-11-13T07:55:33.238397Z" + }, + "lines_to_next_cell": 0 + }, + "outputs": [ + { + "data": { + "text/plain": [ + "[seq2seq_time.data.data.MetroInterstateTraffic,\n", + " seq2seq_time.data.data.IMOSCurrentsVel,\n", + " seq2seq_time.data.data.GasSensor,\n", + " seq2seq_time.data.data.AppliancesEnergyPrediction,\n", + " seq2seq_time.data.data.BejingPM25]" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from seq2seq_time.data.data import IMOSCurrentsVel, AppliancesEnergyPrediction, BejingPM25, GasSensor, MetroInterstateTraffic\n", + "datasets = [MetroInterstateTraffic, IMOSCurrentsVel, GasSensor, AppliancesEnergyPrediction, BejingPM25]\n", + "datasets\n", + "# ## Lightning\n", + "#\n", + "# We will use pytorch lightning to handle all the training scaffolding. We have a common pytorch lightning class that takes in the model and defines training steps and logging." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "ExecuteTime": { + "end_time": "2020-11-13T07:55:34.390203Z", + "start_time": "2020-11-13T07:55:34.338792Z" + } + }, + "outputs": [], + "source": [ + "import pytorch_lightning as pl\n", + "\n", + "class PL_MODEL(pl.LightningModule):\n", + " def __init__(self, model, lr=3e-4, patience=None, weight_decay=0):\n", + " super().__init__()\n", + " self._model = model\n", + " self.lr = lr\n", + " self.patience = patience\n", + " self.weight_decay = weight_decay\n", + "\n", + " def forward(self, x_past, y_past, x_future, y_future=None):\n", + " \"\"\"Eval/Predict\"\"\"\n", + " y_dist, extra = self._model(x_past, y_past, x_future, y_future)\n", + " return y_dist, extra\n", + "\n", + " def training_step(self, batch, batch_idx, phase='train'):\n", + " x_past, y_past, x_future, y_future = batch\n", + " y_dist, extra = self.forward(*batch)\n", + " loss = -y_dist.log_prob(y_future).mean()\n", + " self.log_dict({f'loss/{phase}':loss})\n", + " if ('loss' in extra) and (phase=='train'):\n", + " # some models have a special loss\n", + " loss = extra['loss']\n", + " self.log_dict({f'model_loss/{phase}':loss})\n", + " return loss\n", + "\n", + " def validation_step(self, batch, batch_idx):\n", + " return self.training_step(batch, batch_idx, phase='val')\n", + " \n", + " def test_step(self, batch, batch_idx):\n", + " return self.training_step(batch, batch_idx, phase='test')\n", + " \n", + " def configure_optimizers(self):\n", + " optim = torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)\n", + " scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(\n", + " optim,\n", + " patience=self.patience,\n", + " verbose=False,\n", + " min_lr=1e-7,\n", + " ) if self.patience else None\n", + " return {'optimizer': optim, 'lr_scheduler': scheduler, 'monitor': 'loss/val'}" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "ExecuteTime": { + "end_time": "2020-11-13T07:55:34.653519Z", + "start_time": "2020-11-13T07:55:34.605285Z" + }, + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [ + "from torch.utils.data import DataLoader\n", + "from pytorch_lightning.loggers import CSVLogger, WandbLogger, TensorBoardLogger, TestTubeLogger\n", + "from pytorch_lightning.callbacks.early_stopping import EarlyStopping\n", + "from pytorch_lightning.callbacks import LearningRateMonitor" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Models" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "ExecuteTime": { + "end_time": "2020-11-13T07:55:35.693992Z", + "start_time": "2020-11-13T07:55:35.629427Z" + }, + "lines_to_end_of_cell_marker": 2, + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "from seq2seq_time.models.baseline import BaselineLast, BaselineMean\n", + "from seq2seq_time.models.lstm_seq2seq import LSTMSeq2Seq\n", + "from seq2seq_time.models.lstm import LSTM\n", + "from seq2seq_time.models.transformer import Transformer\n", + "from seq2seq_time.models.transformer_seq2seq import TransformerSeq2Seq\n", + "from seq2seq_time.models.neural_process import RANP\n", + "from seq2seq_time.models.transformer_process import TransformerProcess\n", + "from seq2seq_time.models.tcn import TCNSeq\n", + "from seq2seq_time.models.inceptiontime import InceptionTimeSeq\n", + "from seq2seq_time.models.xattention import CrossAttention" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "ExecuteTime": { + "end_time": "2020-11-13T07:55:36.024789Z", + "start_time": "2020-11-13T07:55:35.985433Z" + }, + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "import gc\n", + "\n", + "def free_mem():\n", + " gc.collect()\n", + " torch.cuda.empty_cache()\n", + " gc.collect()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-11-02T06:10:41.904480Z", + "start_time": "2020-11-02T06:10:41.848613Z" + }, + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "ExecuteTime": { + "end_time": "2020-11-13T07:55:37.813853Z", + "start_time": "2020-11-13T07:55:37.768348Z" + }, + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "# PARAMS: model\n", + "dropout=0.0\n", + "layers=6\n", + "nhead=4\n", + "\n", + "models = [\n", + "# lambda xs, ys: BaselineLast(),\n", + "# lambda xs, ys, hidden_size: BaselineMean(),\n", + " lambda xs, ys, hidden_size, layers:TransformerProcess(xs,\n", + " ys, hidden_size=hidden_size, nhead=nhead,\n", + " latent_dim=hidden_size//2, dropout=dropout,\n", + " nlayers=layers),\n", + " lambda xs, ys, hidden_size, layers: RANP(xs,\n", + " ys, hidden_dim=hidden_size, dropout=dropout, \n", + " latent_dim=hidden_size//2, n_decoder_layers=layers, n_latent_encoder_layers=layers, n_det_encoder_layers=layers),\n", + " lambda xs, ys, hidden_size, layers:TCNSeq(xs, ys, hidden_size=hidden_size, nlayers=layers, dropout=dropout, kernel_size=2),\n", + " lambda xs, ys, hidden_size, layers: Transformer(xs,\n", + " ys,\n", + " attention_dropout=dropout,\n", + " nhead=nhead,\n", + " nlayers=layers,\n", + " hidden_size=hidden_size),\n", + " lambda xs, ys, hidden_size, layers: LSTM(xs,\n", + " ys,\n", + " hidden_size=hidden_size,\n", + " lstm_layers=layers//2,\n", + " lstm_dropout=dropout),\n", + "\n", + " lambda xs, ys, hidden_size, layers: TransformerSeq2Seq(xs,\n", + " ys,\n", + " hidden_size=hidden_size,\n", + " nhead=nhead,\n", + " nlayers=layers,\n", + " attention_dropout=dropout\n", + " ),\n", + "\n", + " lambda xs, ys, hidden_size, layers: LSTMSeq2Seq(xs,\n", + " ys,\n", + " hidden_size=hidden_size,\n", + " lstm_layers=layers//2,\n", + " lstm_dropout=dropout),\n", + " lambda xs, ys, hidden_size, layers: CrossAttention(xs,\n", + " ys,\n", + " nlayers=layers,\n", + " hidden_size=hidden_size,),\n", + " lambda xs, ys, hidden_size, layers: InceptionTimeSeq(xs,\n", + " ys,\n", + " kernel_size=96,\n", + " layers=layers//2,\n", + " hidden_size=hidden_size,\n", + " bottleneck=hidden_size//4)\n", + "\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-11-12T23:16:29.248291Z", + "start_time": "2020-11-12T23:16:16.139941Z" + } + }, + "outputs": [], + "source": [ + "# DEBUG: sanity check\n", + "\n", + "for Dataset in datasets:\n", + " dataset_name = Dataset.__name__\n", + " dataset = Dataset(datasets_root)\n", + " ds_train, ds_val, ds_test = dataset.to_datasets(window_past=window_past,\n", + " window_future=window_future)\n", + "\n", + " # Init data\n", + " x_past, y_past, x_future, y_future = ds_train.get_rows(10)\n", + " xs = x_past.shape[-1]\n", + " ys = y_future.shape[-1]\n", + "\n", + " # Loaders\n", + " dl_train = DataLoader(ds_train,\n", + " batch_size=batch_size,\n", + " shuffle=True,\n", + " pin_memory=num_workers == 0,\n", + " num_workers=num_workers)\n", + " dl_val = DataLoader(ds_val,\n", + " shuffle=True,\n", + " batch_size=batch_size,\n", + " num_workers=num_workers)\n", + "\n", + " for m_fn in models:\n", + " free_mem()\n", + " pt_model = m_fn(xs, ys, 8, 4)\n", + " model_name = type(pt_model).__name__\n", + " print(timestamp, dataset_name, model_name)\n", + "\n", + " # Wrap in lightning\n", + " model = PL_MODEL(pt_model,\n", + " lr=3e-4\n", + " ).to(device)\n", + " trainer = pl.Trainer(\n", + " fast_dev_run=True,\n", + " # GPU\n", + " gpus=1,\n", + " amp_level='O1',\n", + " precision=16,\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "ExecuteTime": { + "end_time": "2020-10-23T23:36:11.052891Z", + "start_time": "2020-10-23T23:36:11.048874Z" + } + }, + "source": [ + "## Train" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-11-12T23:16:29.286762Z", + "start_time": "2020-11-12T23:16:29.250367Z" + }, + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [ + "max_iters=20000" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-11-12T23:16:29.326846Z", + "start_time": "2020-11-12T23:16:29.289004Z" + } + }, + "outputs": [], + "source": [ + "tensorboard_dir = Path(f\"../outputs/{timestamp}\").resolve()\n", + "print(f'For tensorboard run:\\ntensorboard --logdir=\"{tensorboard_dir}\"')" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "ExecuteTime": { + "end_time": "2020-11-13T07:55:48.561522Z", + "start_time": "2020-11-13T07:55:48.516182Z" + }, + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "\n", + "def objective(trial):\n", + " \"\"\"\n", + " Optuna function to optimize\n", + " \n", + " See https://github.com/optuna/optuna/blob/master/examples/pytorch_lightning_simple.py\n", + " \"\"\"\n", + " # sample\n", + " hidden_size_exp = trial.suggest_int(\"hidden_size_exp\", 1, 8)\n", + " hidden_size = 2**hidden_size_exp\n", + " \n", + " layers = trial.suggest_int(\"layers\", 1, 12)\n", + " \n", + " # Load model\n", + " pt_model = m_fn(xs, ys, hidden_size, layers)\n", + " model_name = type(pt_model).__name__\n", + " \n", + " # Wrap in lightning\n", + " patience = 2\n", + " model = PL_MODEL(pt_model,\n", + " lr=3e-4, patience=patience,\n", + " ).to(device)\n", + "\n", + " \n", + " save_dir = f\"../outputs/{timestamp}/{dataset_name}_{model_name}/{trial.number}\"\n", + " Path(save_dir).mkdir(exist_ok=True, parents=True)\n", + " trainer = pl.Trainer(\n", + " # Training length\n", + " min_epochs=2,\n", + " max_epochs=40,\n", + " limit_train_batches=max_iters//batch_size,\n", + " limit_val_batches=max_iters//batch_size//5,\n", + " # Misc\n", + " gradient_clip_val=20,\n", + " terminate_on_nan=True,\n", + " # GPU\n", + " gpus=1,\n", + " amp_level='O1',\n", + " precision=16,\n", + " # Callbacks\n", + " default_root_dir=save_dir,\n", + " logger=False,\n", + " callbacks=[\n", + " EarlyStopping(monitor='loss/val', patience=patience * 2),\n", + " PyTorchLightningPruningCallback(trial, monitor=\"loss/val\")],\n", + " )\n", + " trainer.fit(model, dl_train, dl_val)\n", + " \n", + " # Run on all val data, using test mode\n", + " r = trainer.test(model, test_dataloaders=dl_val, verbose=False)\n", + " return r[0]['loss/test']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-11-08T02:45:44.106583Z", + "start_time": "2020-11-08T02:45:44.050637Z" + }, + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "ExecuteTime": { + "end_time": "2020-11-13T07:55:49.497334Z", + "start_time": "2020-11-13T07:55:48.981040Z" + } + }, + "outputs": [], + "source": [ + "import optuna\n", + "from optuna.integration import PyTorchLightningPruningCallback" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "ExecuteTime": { + "end_time": "2020-11-13T07:55:49.543035Z", + "start_time": "2020-11-13T07:55:49.499555Z" + } + }, + "outputs": [], + "source": [ + "import subprocess\n", + "def get_git_commit():\n", + " try:\n", + " return subprocess.check_output([\"git\", \"rev-parse\", \"HEAD\"], cwd='..').decode().strip()\n", + " except Exception:\n", + " logging.exception(\"failed to get git hash\")" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": { + "ExecuteTime": { + "end_time": "2020-11-13T07:55:50.265327Z", + "start_time": "2020-11-13T07:55:50.204384Z" + } + }, + "outputs": [], + "source": [ + "Path(f\"../outputs/{timestamp}\").mkdir(exist_ok=True)\n", + "storage = f\"sqlite:///../outputs/{timestamp}/optuna.db\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-11-13T01:44:13.554355Z", + "start_time": "2020-11-12T23:16:30.128631Z" + }, + "lines_to_next_cell": 0, + "scrolled": true + }, + "outputs": [], + "source": [ + "for Dataset in tqdm(datasets, desc='datasets'):\n", + " dataset_name = Dataset.__name__\n", + " dataset = Dataset(datasets_root)\n", + " ds_train, ds_val, ds_test = dataset.to_datasets(window_past=window_past,\n", + " window_future=window_future)\n", + "\n", + " # Init data\n", + " x_past, y_past, x_future, y_future = ds_train.get_rows(10)\n", + " xs = x_past.shape[-1]\n", + " ys = y_future.shape[-1]\n", + "\n", + " # Loaders\n", + " dl_train = DataLoader(ds_train,\n", + " batch_size=batch_size,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " pin_memory=num_workers == 0,\n", + " num_workers=num_workers)\n", + " dl_val = DataLoader(ds_val,\n", + " shuffle=False,\n", + " batch_size=batch_size,\n", + " drop_last=True,\n", + " num_workers=num_workers)\n", + "\n", + " for i, m_fn in enumerate(tqdm(models, desc=f'models ({dataset_name})')):\n", + " try:\n", + " model_name = type(m_fn(8, 8, 8, 2)).__name__\n", + " free_mem()\n", + " study_name = f'{timestamp}_{dataset_name}-{model_name}'\n", + " \n", + " # Create study \n", + " pruner = optuna.pruners.MedianPruner()\n", + " study = optuna.create_study(storage=storage, \n", + " study_name=study_name, \n", + " pruner=pruner,\n", + " load_if_exists=True)\n", + " study.set_user_attr('dataset', dataset_name)\n", + " study.set_user_attr('model', model_name)\n", + " study.set_user_attr('commit', get_git_commit())\n", + " \n", + " df_trials = study.trials_dataframe()\n", + " if len(df_trials):\n", + " df_trials = df_trials[df_trials.state=='COMPLETE']\n", + " nb_trials = len(df_trials)\n", + " if nb_trials==0:\n", + " # Priors\n", + " study.enqueue_trial({\"layers\": 6, \"params_hidden_size_exp\": 2})\n", + " study.enqueue_trial({\"layers\": 1, \"params_hidden_size_exp\": 3})\n", + " study.enqueue_trial({\"layers\": 3, \"params_hidden_size_exp\": 5})\n", + " if nb_trials<20:\n", + " # Opt\n", + " study.optimize(objective, n_trials=20-nb_trials, \n", + " timeout=60*60*2 # Max seconds for all optimizes\n", + " )\n", + " \n", + " \n", + " print(\"Number of finished trials: {}\".format(len(study.trials)))\n", + "\n", + " print(\"Best trial:\")\n", + " trial = study.best_trial\n", + "\n", + " print(\" Value: {}\".format(trial.value))\n", + "\n", + " print(\" Params: \")\n", + " for key, value in trial.params.items():\n", + " print(\" {}: {}\".format(key, value))\n", + " \n", + " except Exception as e:\n", + " logging.exception('failed to run model')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-11-12T23:08:15.792627Z", + "start_time": "2020-11-12T23:08:15.700279Z" + } + }, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-11-13T01:44:20.959637Z", + "start_time": "2020-11-13T01:44:13.560463Z" + } + }, + "outputs": [], + "source": [ + "# Baseline\n", + "\n", + "models2 = [\n", + " lambda xs, ys, _, l: BaselineLast(),\n", + " lambda xs, ys, _, l: BaselineMean(),\n", + "]\n", + "\n", + "for Dataset in tqdm(datasets, desc='datasets'):\n", + " dataset_name = Dataset.__name__\n", + " dataset = Dataset(datasets_root)\n", + " ds_train, ds_val, ds_test = dataset.to_datasets(window_past=window_past,\n", + " window_future=window_future)\n", + "\n", + " # Init data\n", + " x_past, y_past, x_future, y_future = ds_train.get_rows(10)\n", + " xs = x_past.shape[-1]\n", + " ys = y_future.shape[-1]\n", + "\n", + " # Loaders\n", + " dl_train = DataLoader(ds_train,\n", + " batch_size=batch_size,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " pin_memory=num_workers == 0,\n", + " num_workers=num_workers)\n", + " dl_val = DataLoader(ds_val,\n", + " shuffle=False,\n", + " batch_size=batch_size,\n", + " drop_last=True,\n", + " num_workers=num_workers)\n", + "\n", + " for i, m_fn in enumerate(tqdm(models2, desc=f'models ({dataset_name})')):\n", + " try:\n", + " model_name = type(m_fn(8, 8, 8, 2)).__name__\n", + " free_mem()\n", + " study_name = f'{timestamp}_{dataset_name}-{model_name}'\n", + " \n", + " # Create study \n", + " pruner = optuna.pruners.MedianPruner()\n", + " study = optuna.create_study(storage=storage, \n", + " study_name=study_name, \n", + " pruner=pruner,\n", + " load_if_exists=True)\n", + " study.set_user_attr('dataset', dataset_name)\n", + " study.set_user_attr('model', model_name)\n", + " study.set_user_attr('commit', get_git_commit())\n", + " \n", + " df_trials = study.trials_dataframe()\n", + " if len(df_trials):\n", + " df_trials = df_trials[df_trials.state=='COMPLETE']\n", + " nb_trials = len(df_trials)\n", + " if nb_trials<1:\n", + " # Opt\n", + " study.optimize(objective, n_trials=1, \n", + " timeout=60*30 # Max seconds for all optimizes\n", + " )\n", + " \n", + " except Exception as e:\n", + " logging.exception('failed to run model')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-11-13T01:44:21.021030Z", + "start_time": "2020-11-13T01:44:20.962459Z" + } + }, + "outputs": [], + "source": [ + "# TODO baseline, run as sep cell, opt once\n", + "# TODO summarize time and model params at best params" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": { + "ExecuteTime": { + "end_time": "2020-11-13T07:55:54.138705Z", + "start_time": "2020-11-13T07:55:53.706642Z" + } + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
n_trialsparam_hidden_size_expparam_layersvalue
datasetmodel
AppliancesEnergyPredictionTCNSeq48211.075230
InceptionTimeSeq17131.099332
TransformerProcess18431.163042
LSTM19631.172367
TransformerSeq2Seq17211.174764
Transformer19461.195765
LSTMSeq2Seq29461.204848
RANP164101.278206
BaselineMean1651.322796
BaselineLast1211.475730
CrossAttention17451.547343
BejingPM25TCNSeq46331.235752
InceptionTimeSeq15131.239024
LSTM18231.271591
LSTMSeq2Seq17561.292539
Transformer53611.295221
TransformerProcess16211.404077
CrossAttention178111.411578
RANP16411.429652
BaselineMean1631.435451
BaselineLast14121.548376
TransformerSeq2Seq16482.393506
GasSensorRANP18610-2.133501
InceptionTimeSeq33510-2.102463
Transformer50612-1.960168
TCNSeq4168-1.736079
LSTM1865-1.537527
LSTMSeq2Seq2085-1.489893
TransformerProcess3874-0.880734
CrossAttention1773-0.638713
TransformerSeq2Seq28610.336668
BaselineMean1321.584720
BaselineLast1261.974204
IMOSCurrentsVelTCNSeq52460.817371
InceptionTimeSeq20360.845443
LSTM22560.875669
Transformer24230.879682
BaselineLast1360.885377
LSTMSeq2Seq247120.887783
RANP19531.035464
BaselineMean1571.202492
TransformerSeq2Seq19311.266849
TransformerProcess27511.394595
CrossAttention17331.656854
MetroInterstateTrafficTCNSeq6173-0.324690
TransformerProcess1737-0.297472
RANP1633-0.291096
Transformer20411-0.253573
LSTMSeq2Seq1555-0.200625
LSTM1745-0.200455
TransformerSeq2Seq1526-0.192895
InceptionTimeSeq1736-0.162513
CrossAttention1461-0.103853
BaselineMean1311.410653
BaselineLast12121.741882
\n", + "
" + ], + "text/plain": [ + " n_trials \\\n", + "dataset model \n", + "AppliancesEnergyPrediction TCNSeq 48 \n", + " InceptionTimeSeq 17 \n", + " TransformerProcess 18 \n", + " LSTM 19 \n", + " TransformerSeq2Seq 17 \n", + " Transformer 19 \n", + " LSTMSeq2Seq 29 \n", + " RANP 16 \n", + " BaselineMean 1 \n", + " BaselineLast 1 \n", + " CrossAttention 17 \n", + "BejingPM25 TCNSeq 46 \n", + " InceptionTimeSeq 15 \n", + " LSTM 18 \n", + " LSTMSeq2Seq 17 \n", + " Transformer 53 \n", + " TransformerProcess 16 \n", + " CrossAttention 17 \n", + " RANP 16 \n", + " BaselineMean 1 \n", + " BaselineLast 1 \n", + " TransformerSeq2Seq 16 \n", + "GasSensor RANP 18 \n", + " InceptionTimeSeq 33 \n", + " Transformer 50 \n", + " TCNSeq 41 \n", + " LSTM 18 \n", + " LSTMSeq2Seq 20 \n", + " TransformerProcess 38 \n", + " CrossAttention 17 \n", + " TransformerSeq2Seq 28 \n", + " BaselineMean 1 \n", + " BaselineLast 1 \n", + "IMOSCurrentsVel TCNSeq 52 \n", + " InceptionTimeSeq 20 \n", + " LSTM 22 \n", + " Transformer 24 \n", + " BaselineLast 1 \n", + " LSTMSeq2Seq 24 \n", + " RANP 19 \n", + " BaselineMean 1 \n", + " TransformerSeq2Seq 19 \n", + " TransformerProcess 27 \n", + " CrossAttention 17 \n", + "MetroInterstateTraffic TCNSeq 61 \n", + " TransformerProcess 17 \n", + " RANP 16 \n", + " Transformer 20 \n", + " LSTMSeq2Seq 15 \n", + " LSTM 17 \n", + " TransformerSeq2Seq 15 \n", + " InceptionTimeSeq 17 \n", + " CrossAttention 14 \n", + " BaselineMean 1 \n", + " BaselineLast 1 \n", + "\n", + " param_hidden_size_exp \\\n", + "dataset model \n", + "AppliancesEnergyPrediction TCNSeq 2 \n", + " InceptionTimeSeq 1 \n", + " TransformerProcess 4 \n", + " LSTM 6 \n", + " TransformerSeq2Seq 2 \n", + " Transformer 4 \n", + " LSTMSeq2Seq 4 \n", + " RANP 4 \n", + " BaselineMean 6 \n", + " BaselineLast 2 \n", + " CrossAttention 4 \n", + "BejingPM25 TCNSeq 3 \n", + " InceptionTimeSeq 1 \n", + " LSTM 2 \n", + " LSTMSeq2Seq 5 \n", + " Transformer 6 \n", + " TransformerProcess 2 \n", + " CrossAttention 8 \n", + " RANP 4 \n", + " BaselineMean 6 \n", + " BaselineLast 4 \n", + " TransformerSeq2Seq 4 \n", + "GasSensor RANP 6 \n", + " InceptionTimeSeq 5 \n", + " Transformer 6 \n", + " TCNSeq 6 \n", + " LSTM 6 \n", + " LSTMSeq2Seq 8 \n", + " TransformerProcess 7 \n", + " CrossAttention 7 \n", + " TransformerSeq2Seq 6 \n", + " BaselineMean 3 \n", + " BaselineLast 2 \n", + "IMOSCurrentsVel TCNSeq 4 \n", + " InceptionTimeSeq 3 \n", + " LSTM 5 \n", + " Transformer 2 \n", + " BaselineLast 3 \n", + " LSTMSeq2Seq 7 \n", + " RANP 5 \n", + " BaselineMean 5 \n", + " TransformerSeq2Seq 3 \n", + " TransformerProcess 5 \n", + " CrossAttention 3 \n", + "MetroInterstateTraffic TCNSeq 7 \n", + " TransformerProcess 3 \n", + " RANP 3 \n", + " Transformer 4 \n", + " LSTMSeq2Seq 5 \n", + " LSTM 4 \n", + " TransformerSeq2Seq 2 \n", + " InceptionTimeSeq 3 \n", + " CrossAttention 6 \n", + " BaselineMean 3 \n", + " BaselineLast 2 \n", + "\n", + " param_layers value \n", + "dataset model \n", + "AppliancesEnergyPrediction TCNSeq 1 1.075230 \n", + " InceptionTimeSeq 3 1.099332 \n", + " TransformerProcess 3 1.163042 \n", + " LSTM 3 1.172367 \n", + " TransformerSeq2Seq 1 1.174764 \n", + " Transformer 6 1.195765 \n", + " LSTMSeq2Seq 6 1.204848 \n", + " RANP 10 1.278206 \n", + " BaselineMean 5 1.322796 \n", + " BaselineLast 1 1.475730 \n", + " CrossAttention 5 1.547343 \n", + "BejingPM25 TCNSeq 3 1.235752 \n", + " InceptionTimeSeq 3 1.239024 \n", + " LSTM 3 1.271591 \n", + " LSTMSeq2Seq 6 1.292539 \n", + " Transformer 1 1.295221 \n", + " TransformerProcess 1 1.404077 \n", + " CrossAttention 11 1.411578 \n", + " RANP 1 1.429652 \n", + " BaselineMean 3 1.435451 \n", + " BaselineLast 12 1.548376 \n", + " TransformerSeq2Seq 8 2.393506 \n", + "GasSensor RANP 10 -2.133501 \n", + " InceptionTimeSeq 10 -2.102463 \n", + " Transformer 12 -1.960168 \n", + " TCNSeq 8 -1.736079 \n", + " LSTM 5 -1.537527 \n", + " LSTMSeq2Seq 5 -1.489893 \n", + " TransformerProcess 4 -0.880734 \n", + " CrossAttention 3 -0.638713 \n", + " TransformerSeq2Seq 1 0.336668 \n", + " BaselineMean 2 1.584720 \n", + " BaselineLast 6 1.974204 \n", + "IMOSCurrentsVel TCNSeq 6 0.817371 \n", + " InceptionTimeSeq 6 0.845443 \n", + " LSTM 6 0.875669 \n", + " Transformer 3 0.879682 \n", + " BaselineLast 6 0.885377 \n", + " LSTMSeq2Seq 12 0.887783 \n", + " RANP 3 1.035464 \n", + " BaselineMean 7 1.202492 \n", + " TransformerSeq2Seq 1 1.266849 \n", + " TransformerProcess 1 1.394595 \n", + " CrossAttention 3 1.656854 \n", + "MetroInterstateTraffic TCNSeq 3 -0.324690 \n", + " TransformerProcess 7 -0.297472 \n", + " RANP 3 -0.291096 \n", + " Transformer 11 -0.253573 \n", + " LSTMSeq2Seq 5 -0.200625 \n", + " LSTM 5 -0.200455 \n", + " TransformerSeq2Seq 6 -0.192895 \n", + " InceptionTimeSeq 6 -0.162513 \n", + " CrossAttention 1 -0.103853 \n", + " BaselineMean 1 1.410653 \n", + " BaselineLast 12 1.741882 " + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import pandas as pd\n", + "# Summarize studies\n", + "rs = []\n", + "study_summaries = optuna.study.get_all_study_summaries(storage=storage)\n", + "for s in study_summaries:\n", + " row = {}\n", + " if (s.best_trial is not None) and (s.best_trial.state==optuna.trial.TrialState.COMPLETE):\n", + " params = {k:v for k,v in s.best_trial.__dict__.items() if not k.startswith('_')}\n", + " row.update(s.user_attrs)\n", + " row['n_trials'] = s.n_trials\n", + " row.update({'param_'+k:v for k,v in s.best_trial._params.items()})\n", + " row.update(params)\n", + " rs.append(row)\n", + "df_studies = pd.DataFrame(rs)\n", + "\n", + "df_studies = (df_studies.drop(columns=['state', 'intermediate_values', 'commit', 'datetime_complete'])\n", + " .sort_values(['dataset', 'value'])\n", + " .set_index(['dataset', 'model'])\n", + " )\n", + "df_studies" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-11-13T07:54:27.730531Z", + "start_time": "2020-11-13T07:54:27.667135Z" + }, + "lines_to_next_cell": 0, + "scrolled": true + }, + "outputs": [], + "source": [ + "# study_names = [s.study_name for s in optuna.study.get_all_study_summaries(storage=storage)]\n", + "\n", + "# for study_name in study_names:\n", + "# loaded_study = optuna.load_study(study_name=study_name, storage=storage)\n", + " \n", + "# # Make DF over trials\n", + "# print(study_name)\n", + " \n", + "# df_trials = loaded_study.trials_dataframe()\n", + "# # df_trials.index = df_trials.apply(lambda r:f'l={r.params_layers}_hs={r.params_hidden_size_exp}', 1)\n", + "# display(df_trials)\n", + "\n", + "# # Plot test curves, to see how much overfitting\n", + "# df_values = pd.DataFrame([s.intermediate_values for s in loaded_study.get_trials()]).T\n", + "# df_values.columns = [f\"l={s.params['layers']}_hs={s.params['hidden_size_exp']}\" for s in loaded_study.get_trials()]\n", + "# df_values.plot(ylabel='nll', xlabel='epochs', title=f'val loss \"{study_name}\"')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-11-13T01:44:25.350038Z", + "start_time": "2020-11-12T23:16:12.130Z" + }, + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "# study_names = [s.study_name for s in optuna.study.get_all_study_summaries(storage=storage)]\n", + "\n", + "# for study_name in study_names:\n", + "# loaded_study = optuna.load_study(study_name=study_name, storage=storage)\n", + "# fig=optuna.visualization.plot_contour(loaded_study, params=['hidden_size_exp', 'layers'])\n", + "# fig = fig.update_layout(dict(title=f'{study_name} nll'))\n", + "# display(fig)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-11-08T23:11:33.818862Z", + "start_time": "2020-11-08T23:11:33.749610Z" + } + }, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-11-08T23:11:14.588720Z", + "start_time": "2020-11-08T23:11:14.516020Z" + }, + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-11-13T01:44:25.351121Z", + "start_time": "2020-11-12T23:16:12.141Z" + } + }, + "outputs": [], + "source": [ + "[f\"{s.params['layers']}_{s.params['hidden_size_exp']}\" for s in loaded_study.get_trials()]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-11-13T01:44:25.352257Z", + "start_time": "2020-11-12T23:16:12.146Z" + }, + "scrolled": true + }, + "outputs": [], + "source": [ + "df_values" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "@webio": { + "lastCommId": null, + "lastKernelId": null + }, + "jupytext": { + "encoding": "# -*- coding: utf-8 -*-", + "formats": "ipynb,py:light" + }, + "kernelspec": { + "display_name": "seq2seq-time", + "language": "python", + "name": "seq2seq-time" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.8" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": false, + "toc_position": { + "height": "calc(100% - 180px)", + "left": "10px", + "top": "150px", + "width": "209.162px" + }, + "toc_section_display": true, + "toc_window_display": true + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/notebooks/07.1-mc-optuna.py b/notebooks/07.1-mc-optuna.py new file mode 100644 index 0000000..5571cee --- /dev/null +++ b/notebooks/07.1-mc-optuna.py @@ -0,0 +1,539 @@ +# -*- coding: utf-8 -*- +# --- +# jupyter: +# jupytext: +# formats: ipynb,py:light +# text_representation: +# extension: .py +# format_name: light +# format_version: '1.5' +# jupytext_version: 1.6.0 +# kernelspec: +# display_name: seq2seq-time +# language: python +# name: seq2seq-time +# --- + +# # Sequence to Sequence Models for Timeseries Regression +# +# +# In this notebook we are going to find the optimal hidden_size for a model vs a dataset. We will use pytorch lightning and optuna. + +# OPTIONAL: Load the "autoreload" extension so that code can change. But blacklist large modules +# %load_ext autoreload +# %autoreload 2 +# %aimport -pandas +# %aimport -torch +# %aimport -numpy +# %aimport -matplotlib +# %aimport -dask +# %aimport -tqdm +# %matplotlib inline + +# + +# Imports +import torch +from torch import nn, optim +from torch.nn import functional as F +from torch.autograd import Variable +import torch +import torch.utils.data + + +from pathlib import Path +from tqdm.auto import tqdm + +import pytorch_lightning as pl +# - +from seq2seq_time.data.dataset import Seq2SeqDataSet, Seq2SeqDataSets +from seq2seq_time.predict import predict, predict_multi +from seq2seq_time.util import dset_to_nc + +# + +import logging +import warnings +import seq2seq_time.silence +warnings.simplefilter('once') +warnings.simplefilter(action='ignore', category=FutureWarning) +warnings.simplefilter(action='ignore', category=DeprecationWarning) +warnings.filterwarnings('ignore', 'Consider increasing the value of the `num_workers` argument', UserWarning) +warnings.filterwarnings('ignore', 'Your val_dataloader has `shuffle=True`', UserWarning) + +from pytorch_lightning import _logger as log +log.setLevel(logging.WARN) +# - + +# ## Parameters + +# + +device = "cuda" if torch.cuda.is_available() else "cpu" +print(f'using {device}') + +timestamp = '20201108-300000' +print(timestamp) +window_past = 48*7 +window_future = 48 +batch_size = 64 +num_workers = 4 +datasets_root = Path('../data/processed/') +window_past +# - +# ## Datasets +# +# From easy to hard, these dataset show different challenges, all of them with more than 20k datapoints and with a regression output. See the 00.01 notebook for more details, and the code for more information. +# +# Some such as MetroInterstateTraffic are easier, some are periodic such as BejingPM25, some are conditional on inputs such as GasSensor, and some are noisy and periodic like IMOSCurrentsVel + +from seq2seq_time.data.data import IMOSCurrentsVel, AppliancesEnergyPrediction, BejingPM25, GasSensor, MetroInterstateTraffic +datasets = [MetroInterstateTraffic, IMOSCurrentsVel, GasSensor, AppliancesEnergyPrediction, BejingPM25] +datasets +# ## Lightning +# +# We will use pytorch lightning to handle all the training scaffolding. We have a common pytorch lightning class that takes in the model and defines training steps and logging. +# + +import pytorch_lightning as pl + +class PL_MODEL(pl.LightningModule): + def __init__(self, model, lr=3e-4, patience=None, weight_decay=0): + super().__init__() + self._model = model + self.lr = lr + self.patience = patience + self.weight_decay = weight_decay + + def forward(self, x_past, y_past, x_future, y_future=None): + """Eval/Predict""" + y_dist, extra = self._model(x_past, y_past, x_future, y_future) + return y_dist, extra + + def training_step(self, batch, batch_idx, phase='train'): + x_past, y_past, x_future, y_future = batch + y_dist, extra = self.forward(*batch) + loss = -y_dist.log_prob(y_future).mean() + self.log_dict({f'loss/{phase}':loss}) + if ('loss' in extra) and (phase=='train'): + # some models have a special loss + loss = extra['loss'] + self.log_dict({f'model_loss/{phase}':loss}) + return loss + + def validation_step(self, batch, batch_idx): + return self.training_step(batch, batch_idx, phase='val') + + def test_step(self, batch, batch_idx): + return self.training_step(batch, batch_idx, phase='test') + + def configure_optimizers(self): + optim = torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=self.weight_decay) + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + optim, + patience=self.patience, + verbose=False, + min_lr=1e-7, + ) if self.patience else None + return {'optimizer': optim, 'lr_scheduler': scheduler, 'monitor': 'loss/val'} + + +# - + +from torch.utils.data import DataLoader +from pytorch_lightning.loggers import CSVLogger, WandbLogger, TensorBoardLogger, TestTubeLogger +from pytorch_lightning.callbacks.early_stopping import EarlyStopping +from pytorch_lightning.callbacks import LearningRateMonitor + + +# ## Models + +from seq2seq_time.models.baseline import BaselineLast, BaselineMean +from seq2seq_time.models.lstm_seq2seq import LSTMSeq2Seq +from seq2seq_time.models.lstm import LSTM +from seq2seq_time.models.transformer import Transformer +from seq2seq_time.models.transformer_seq2seq import TransformerSeq2Seq +from seq2seq_time.models.neural_process import RANP +from seq2seq_time.models.transformer_process import TransformerProcess +from seq2seq_time.models.tcn import TCNSeq +from seq2seq_time.models.inceptiontime import InceptionTimeSeq +from seq2seq_time.models.xattention import CrossAttention +# + +import gc + +def free_mem(): + gc.collect() + torch.cuda.empty_cache() + gc.collect() +# - + + + +# + +# PARAMS: model +dropout=0.0 +layers=6 +nhead=4 + +models = [ +# lambda xs, ys: BaselineLast(), +# lambda xs, ys, hidden_size: BaselineMean(), + lambda xs, ys, hidden_size, layers:TransformerProcess(xs, + ys, hidden_size=hidden_size, nhead=nhead, + latent_dim=hidden_size//2, dropout=dropout, + nlayers=layers), + lambda xs, ys, hidden_size, layers: RANP(xs, + ys, hidden_dim=hidden_size, dropout=dropout, + latent_dim=hidden_size//2, n_decoder_layers=layers, n_latent_encoder_layers=layers, n_det_encoder_layers=layers), + lambda xs, ys, hidden_size, layers:TCNSeq(xs, ys, hidden_size=hidden_size, nlayers=layers, dropout=dropout, kernel_size=2), + lambda xs, ys, hidden_size, layers: Transformer(xs, + ys, + attention_dropout=dropout, + nhead=nhead, + nlayers=layers, + hidden_size=hidden_size), + lambda xs, ys, hidden_size, layers: LSTM(xs, + ys, + hidden_size=hidden_size, + lstm_layers=layers//2, + lstm_dropout=dropout), + + lambda xs, ys, hidden_size, layers: TransformerSeq2Seq(xs, + ys, + hidden_size=hidden_size, + nhead=nhead, + nlayers=layers, + attention_dropout=dropout + ), + + lambda xs, ys, hidden_size, layers: LSTMSeq2Seq(xs, + ys, + hidden_size=hidden_size, + lstm_layers=layers//2, + lstm_dropout=dropout), + lambda xs, ys, hidden_size, layers: CrossAttention(xs, + ys, + nlayers=layers, + hidden_size=hidden_size,), + lambda xs, ys, hidden_size, layers: InceptionTimeSeq(xs, + ys, + kernel_size=96, + layers=layers//2, + hidden_size=hidden_size, + bottleneck=hidden_size//4) + +] +# + +# DEBUG: sanity check + +for Dataset in datasets: + dataset_name = Dataset.__name__ + dataset = Dataset(datasets_root) + ds_train, ds_val, ds_test = dataset.to_datasets(window_past=window_past, + window_future=window_future) + + # Init data + x_past, y_past, x_future, y_future = ds_train.get_rows(10) + xs = x_past.shape[-1] + ys = y_future.shape[-1] + + # Loaders + dl_train = DataLoader(ds_train, + batch_size=batch_size, + shuffle=True, + pin_memory=num_workers == 0, + num_workers=num_workers) + dl_val = DataLoader(ds_val, + shuffle=True, + batch_size=batch_size, + num_workers=num_workers) + + for m_fn in models: + free_mem() + pt_model = m_fn(xs, ys, 8, 4) + model_name = type(pt_model).__name__ + print(timestamp, dataset_name, model_name) + + # Wrap in lightning + model = PL_MODEL(pt_model, + lr=3e-4 + ).to(device) + trainer = pl.Trainer( + fast_dev_run=True, + # GPU + gpus=1, + amp_level='O1', + precision=16, + ) +# - + +# ## Train + +max_iters=20000 + + +tensorboard_dir = Path(f"../outputs/{timestamp}").resolve() +print(f'For tensorboard run:\ntensorboard --logdir="{tensorboard_dir}"') + + +# + + +def objective(trial): + """ + Optuna function to optimize + + See https://github.com/optuna/optuna/blob/master/examples/pytorch_lightning_simple.py + """ + # sample + hidden_size_exp = trial.suggest_int("hidden_size_exp", 1, 8) + hidden_size = 2**hidden_size_exp + + layers = trial.suggest_int("layers", 1, 12) + + # Load model + pt_model = m_fn(xs, ys, hidden_size, layers) + model_name = type(pt_model).__name__ + + # Wrap in lightning + patience = 2 + model = PL_MODEL(pt_model, + lr=3e-4, patience=patience, + ).to(device) + + + save_dir = f"../outputs/{timestamp}/{dataset_name}_{model_name}/{trial.number}" + Path(save_dir).mkdir(exist_ok=True, parents=True) + trainer = pl.Trainer( + # Training length + min_epochs=2, + max_epochs=40, + limit_train_batches=max_iters//batch_size, + limit_val_batches=max_iters//batch_size//5, + # Misc + gradient_clip_val=20, + terminate_on_nan=True, + # GPU + gpus=1, + amp_level='O1', + precision=16, + # Callbacks + default_root_dir=save_dir, + logger=False, + callbacks=[ + EarlyStopping(monitor='loss/val', patience=patience * 2), + PyTorchLightningPruningCallback(trial, monitor="loss/val")], + ) + trainer.fit(model, dl_train, dl_val) + + # Run on all val data, using test mode + r = trainer.test(model, test_dataloaders=dl_val, verbose=False) + return r[0]['loss/test'] +# - + + + +import optuna +from optuna.integration import PyTorchLightningPruningCallback + +import subprocess +def get_git_commit(): + try: + return subprocess.check_output(["git", "rev-parse", "HEAD"], cwd='..').decode().strip() + except Exception: + logging.exception("failed to get git hash") + + +Path(f"../outputs/{timestamp}").mkdir(exist_ok=True) +storage = f"sqlite:///../outputs/{timestamp}/optuna.db" + +for Dataset in tqdm(datasets, desc='datasets'): + dataset_name = Dataset.__name__ + dataset = Dataset(datasets_root) + ds_train, ds_val, ds_test = dataset.to_datasets(window_past=window_past, + window_future=window_future) + + # Init data + x_past, y_past, x_future, y_future = ds_train.get_rows(10) + xs = x_past.shape[-1] + ys = y_future.shape[-1] + + # Loaders + dl_train = DataLoader(ds_train, + batch_size=batch_size, + shuffle=True, + drop_last=True, + pin_memory=num_workers == 0, + num_workers=num_workers) + dl_val = DataLoader(ds_val, + shuffle=False, + batch_size=batch_size, + drop_last=True, + num_workers=num_workers) + + for i, m_fn in enumerate(tqdm(models, desc=f'models ({dataset_name})')): + try: + model_name = type(m_fn(8, 8, 8, 2)).__name__ + free_mem() + study_name = f'{timestamp}_{dataset_name}-{model_name}' + + # Create study + pruner = optuna.pruners.MedianPruner() + study = optuna.create_study(storage=storage, + study_name=study_name, + pruner=pruner, + load_if_exists=True) + study.set_user_attr('dataset', dataset_name) + study.set_user_attr('model', model_name) + study.set_user_attr('commit', get_git_commit()) + + df_trials = study.trials_dataframe() + if len(df_trials): + df_trials = df_trials[df_trials.state=='COMPLETE'] + nb_trials = len(df_trials) + if nb_trials==0: + # Priors + study.enqueue_trial({"layers": 6, "params_hidden_size_exp": 2}) + study.enqueue_trial({"layers": 1, "params_hidden_size_exp": 3}) + study.enqueue_trial({"layers": 3, "params_hidden_size_exp": 5}) + if nb_trials<20: + # Opt + study.optimize(objective, n_trials=20-nb_trials, + timeout=60*60*2 # Max seconds for all optimizes + ) + + + print("Number of finished trials: {}".format(len(study.trials))) + + print("Best trial:") + trial = study.best_trial + + print(" Value: {}".format(trial.value)) + + print(" Params: ") + for key, value in trial.params.items(): + print(" {}: {}".format(key, value)) + + except Exception as e: + logging.exception('failed to run model') + + +# + +# Baseline + +models2 = [ + lambda xs, ys, _, l: BaselineLast(), + lambda xs, ys, _, l: BaselineMean(), +] + +for Dataset in tqdm(datasets, desc='datasets'): + dataset_name = Dataset.__name__ + dataset = Dataset(datasets_root) + ds_train, ds_val, ds_test = dataset.to_datasets(window_past=window_past, + window_future=window_future) + + # Init data + x_past, y_past, x_future, y_future = ds_train.get_rows(10) + xs = x_past.shape[-1] + ys = y_future.shape[-1] + + # Loaders + dl_train = DataLoader(ds_train, + batch_size=batch_size, + shuffle=True, + drop_last=True, + pin_memory=num_workers == 0, + num_workers=num_workers) + dl_val = DataLoader(ds_val, + shuffle=False, + batch_size=batch_size, + drop_last=True, + num_workers=num_workers) + + for i, m_fn in enumerate(tqdm(models2, desc=f'models ({dataset_name})')): + try: + model_name = type(m_fn(8, 8, 8, 2)).__name__ + free_mem() + study_name = f'{timestamp}_{dataset_name}-{model_name}' + + # Create study + pruner = optuna.pruners.MedianPruner() + study = optuna.create_study(storage=storage, + study_name=study_name, + pruner=pruner, + load_if_exists=True) + study.set_user_attr('dataset', dataset_name) + study.set_user_attr('model', model_name) + study.set_user_attr('commit', get_git_commit()) + + df_trials = study.trials_dataframe() + if len(df_trials): + df_trials = df_trials[df_trials.state=='COMPLETE'] + nb_trials = len(df_trials) + if nb_trials<1: + # Opt + study.optimize(objective, n_trials=1, + timeout=60*30 # Max seconds for all optimizes + ) + + except Exception as e: + logging.exception('failed to run model') + +# + +# TODO baseline, run as sep cell, opt once +# TODO summarize time and model params at best params + +# + +import pandas as pd +# Summarize studies +rs = [] +study_summaries = optuna.study.get_all_study_summaries(storage=storage) +for s in study_summaries: + row = {} + if (s.best_trial is not None) and (s.best_trial.state==optuna.trial.TrialState.COMPLETE): + params = {k:v for k,v in s.best_trial.__dict__.items() if not k.startswith('_')} + row.update(s.user_attrs) + row['n_trials'] = s.n_trials + row.update({'param_'+k:v for k,v in s.best_trial._params.items()}) + row.update(params) + rs.append(row) +df_studies = pd.DataFrame(rs) + +df_studies = (df_studies.drop(columns=['state', 'intermediate_values', 'commit', 'datetime_complete']) + .sort_values(['dataset', 'value']) + .set_index(['dataset', 'model']) + ) +df_studies + +# + +# study_names = [s.study_name for s in optuna.study.get_all_study_summaries(storage=storage)] + +# for study_name in study_names: +# loaded_study = optuna.load_study(study_name=study_name, storage=storage) + +# # Make DF over trials +# print(study_name) + +# df_trials = loaded_study.trials_dataframe() +# # df_trials.index = df_trials.apply(lambda r:f'l={r.params_layers}_hs={r.params_hidden_size_exp}', 1) +# display(df_trials) + +# # Plot test curves, to see how much overfitting +# df_values = pd.DataFrame([s.intermediate_values for s in loaded_study.get_trials()]).T +# df_values.columns = [f"l={s.params['layers']}_hs={s.params['hidden_size_exp']}" for s in loaded_study.get_trials()] +# df_values.plot(ylabel='nll', xlabel='epochs', title=f'val loss "{study_name}"') +# + +# study_names = [s.study_name for s in optuna.study.get_all_study_summaries(storage=storage)] + +# for study_name in study_names: +# loaded_study = optuna.load_study(study_name=study_name, storage=storage) +# fig=optuna.visualization.plot_contour(loaded_study, params=['hidden_size_exp', 'layers']) +# fig = fig.update_layout(dict(title=f'{study_name} nll')) +# display(fig) +# - + + + + + +[f"{s.params['layers']}_{s.params['hidden_size_exp']}" for s in loaded_study.get_trials()] + +df_values + + + diff --git a/seq2seq_time/models/neural_process.py b/seq2seq_time/models/neural_process.py index 0b4472c..67b3609 100644 --- a/seq2seq_time/models/neural_process.py +++ b/seq2seq_time/models/neural_process.py @@ -407,9 +407,11 @@ class RANP(nn.Module): y = torch.cat([past_y, future_y], 1) dist_post, log_var_post = self._latent_encoder(x, y) - if self.training: - z = dist_prior.rsample() + if self.training and (future_y is not None): + # USe posterior during training, is possible + z = dist_post.rsample() else: + # During eval use the prior, also take the most probable z = dist_prior.loc num_targets = future_x.size(1) z = z.unsqueeze(1).repeat(1, num_targets, 1) # [B, T_target, H] diff --git a/seq2seq_time/models/transformer_process.py b/seq2seq_time/models/transformer_process.py index 0614016..bc836e4 100644 --- a/seq2seq_time/models/transformer_process.py +++ b/seq2seq_time/models/transformer_process.py @@ -162,10 +162,11 @@ class TransformerProcess(nn.Module): y = torch.cat([past_y, future_y], 1) dist_post = self._latent_encoder(x, y) - if self.training: - # Sample from latent space during training - z = dist_prior.rsample() + if self.training and (future_y is not None): + # USe posterior during training, is possible + z = dist_post.rsample() else: + # During eval use the prior, also take the most probable z = dist_prior.loc num_targets = future_x.size(1) z = z.unsqueeze(1).repeat(1, num_targets, 1) # [B, S_target, H]