diff --git a/notebooks/07.1-mc-optuna.ipynb b/notebooks/07.1-mc-optuna.ipynb
new file mode 100644
index 0000000..462068e
--- /dev/null
+++ b/notebooks/07.1-mc-optuna.ipynb
@@ -0,0 +1,1628 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2020-10-10T01:25:12.788851Z",
+ "start_time": "2020-10-10T01:25:12.783398Z"
+ }
+ },
+ "source": [
+ "# Sequence to Sequence Models for Timeseries Regression\n",
+ "\n",
+ "\n",
+ "In this notebook we search for the best `hidden_size` and number of layers for each model on each dataset. We use PyTorch Lightning for training and Optuna for the hyperparameter search."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2020-11-13T07:55:31.364813Z",
+ "start_time": "2020-11-13T07:55:31.025418Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "# OPTIONAL: Load the \"autoreload\" extension so that code can change. But blacklist large modules\n",
+ "%load_ext autoreload\n",
+ "%autoreload 2\n",
+ "%aimport -pandas\n",
+ "%aimport -torch\n",
+ "%aimport -numpy\n",
+ "%aimport -matplotlib\n",
+ "%aimport -dask\n",
+ "%aimport -tqdm\n",
+ "%matplotlib inline"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2020-11-13T07:55:32.324191Z",
+ "start_time": "2020-11-13T07:55:31.367170Z"
+ },
+ "lines_to_next_cell": 0
+ },
+ "outputs": [],
+ "source": [
+ "# Imports\n",
+ "import torch\n",
+ "from torch import nn, optim\n",
+ "from torch.nn import functional as F\n",
+ "from torch.autograd import Variable\n",
+ "import torch\n",
+ "import torch.utils.data\n",
+ "\n",
+ "\n",
+ "from pathlib import Path\n",
+ "from tqdm.auto import tqdm\n",
+ "\n",
+ "import pytorch_lightning as pl"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2020-11-13T07:55:32.863984Z",
+ "start_time": "2020-11-13T07:55:32.326782Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "from seq2seq_time.data.dataset import Seq2SeqDataSet, Seq2SeqDataSets\n",
+ "from seq2seq_time.predict import predict, predict_multi\n",
+ "from seq2seq_time.util import dset_to_nc"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2020-11-13T07:55:32.900137Z",
+ "start_time": "2020-11-13T07:55:32.866551Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "import logging\n",
+ "import warnings\n",
+ "import seq2seq_time.silence \n",
+ "warnings.simplefilter('once')\n",
+ "warnings.simplefilter(action='ignore', category=FutureWarning)\n",
+ "warnings.simplefilter(action='ignore', category=DeprecationWarning)\n",
+ "warnings.filterwarnings('ignore', 'Consider increasing the value of the `num_workers` argument', UserWarning)\n",
+ "warnings.filterwarnings('ignore', 'Your val_dataloader has `shuffle=True`', UserWarning)\n",
+ "\n",
+ "from pytorch_lightning import _logger as log\n",
+ "log.setLevel(logging.WARN)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2020-10-10T01:28:32.492160Z",
+ "start_time": "2020-10-10T01:28:32.488140Z"
+ }
+ },
+ "source": [
+ "## Parameters"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2020-11-13T07:55:32.992818Z",
+ "start_time": "2020-11-13T07:55:32.902453Z"
+ },
+ "lines_to_next_cell": 0
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "using cuda\n",
+ "20201108-300000\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "336"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
+ "print(f'using {device}')\n",
+ "\n",
+ "timestamp = '20201108-300000'\n",
+ "print(timestamp)\n",
+ "window_past = 48*7\n",
+ "window_future = 48\n",
+ "batch_size = 64\n",
+ "num_workers = 4\n",
+ "datasets_root = Path('../data/processed/')\n",
+ "window_past"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Datasets\n",
+ "\n",
+ "From easy to hard, these datasets show different challenges; each has more than 20k datapoints and a regression output. See the 00.01 notebook for more details, and the code for more information.\n",
+ "\n",
+ "Some, such as MetroInterstateTraffic, are easier; some are periodic, such as BejingPM25; some are conditional on inputs, such as GasSensor; and some are noisy and periodic, like IMOSCurrentsVel."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2020-11-13T07:55:33.715323Z",
+ "start_time": "2020-11-13T07:55:33.238397Z"
+ },
+ "lines_to_next_cell": 0
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[seq2seq_time.data.data.MetroInterstateTraffic,\n",
+ " seq2seq_time.data.data.IMOSCurrentsVel,\n",
+ " seq2seq_time.data.data.GasSensor,\n",
+ " seq2seq_time.data.data.AppliancesEnergyPrediction,\n",
+ " seq2seq_time.data.data.BejingPM25]"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "from seq2seq_time.data.data import IMOSCurrentsVel, AppliancesEnergyPrediction, BejingPM25, GasSensor, MetroInterstateTraffic\n",
+ "datasets = [MetroInterstateTraffic, IMOSCurrentsVel, GasSensor, AppliancesEnergyPrediction, BejingPM25]\n",
+ "datasets\n",
+ "# ## Lightning\n",
+ "#\n",
+ "# We will use pytorch lightning to handle all the training scaffolding. We have a common pytorch lightning class that takes in the model and defines training steps and logging."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2020-11-13T07:55:34.390203Z",
+ "start_time": "2020-11-13T07:55:34.338792Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "import pytorch_lightning as pl\n",
+ "\n",
+ "class PL_MODEL(pl.LightningModule):\n",
+ "    \"\"\"Lightning wrapper that trains a wrapped model on negative log-likelihood.\n",
+ "\n",
+ "    The wrapped model is called as ``model(x_past, y_past, x_future, y_future)`` and\n",
+ "    must return ``(y_dist, extra)``: ``y_dist`` provides ``log_prob`` and ``extra``\n",
+ "    supports ``in``/indexing and may hold a model-specific loss under ``'loss'``.\n",
+ "\n",
+ "    Args:\n",
+ "        model: the network to wrap.\n",
+ "        lr: learning rate passed to AdamW.\n",
+ "        patience: ReduceLROnPlateau patience; a falsy value disables the scheduler.\n",
+ "        weight_decay: weight decay passed to AdamW.\n",
+ "    \"\"\"\n",
+ "\n",
+ "    def __init__(self, model, lr=3e-4, patience=None, weight_decay=0):\n",
+ "        super().__init__()\n",
+ "        self._model = model\n",
+ "        self.lr = lr\n",
+ "        self.patience = patience\n",
+ "        self.weight_decay = weight_decay\n",
+ "\n",
+ "    def forward(self, x_past, y_past, x_future, y_future=None):\n",
+ "        \"\"\"Eval/Predict: delegate to the wrapped model and return ``(y_dist, extra)``.\"\"\"\n",
+ "        y_dist, extra = self._model(x_past, y_past, x_future, y_future)\n",
+ "        return y_dist, extra\n",
+ "\n",
+ "    def training_step(self, batch, batch_idx, phase='train'):\n",
+ "        \"\"\"Compute and log the loss for one batch.\n",
+ "\n",
+ "        The mean negative log-likelihood of ``y_future`` is always logged as\n",
+ "        ``loss/{phase}``. During training only, a model-supplied ``extra['loss']``\n",
+ "        replaces it as the returned loss and is logged as ``model_loss/{phase}``.\n",
+ "        \"\"\"\n",
+ "        # Only the target is needed here; the whole batch is forwarded to the model as-is.\n",
+ "        _, _, _, y_future = batch\n",
+ "        y_dist, extra = self.forward(*batch)\n",
+ "        loss = -y_dist.log_prob(y_future).mean()\n",
+ "        self.log_dict({f'loss/{phase}': loss})\n",
+ "        if ('loss' in extra) and (phase == 'train'):\n",
+ "            # some models have a special loss\n",
+ "            loss = extra['loss']\n",
+ "            self.log_dict({f'model_loss/{phase}': loss})\n",
+ "        return loss\n",
+ "\n",
+ "    def validation_step(self, batch, batch_idx):\n",
+ "        \"\"\"Same computation as ``training_step``, logged under the 'val' phase.\"\"\"\n",
+ "        return self.training_step(batch, batch_idx, phase='val')\n",
+ "\n",
+ "    def test_step(self, batch, batch_idx):\n",
+ "        \"\"\"Same computation as ``training_step``, logged under the 'test' phase.\"\"\"\n",
+ "        return self.training_step(batch, batch_idx, phase='test')\n",
+ "\n",
+ "    def configure_optimizers(self):\n",
+ "        \"\"\"Build AdamW plus, when ``patience`` is set, a ReduceLROnPlateau monitoring 'loss/val'.\"\"\"\n",
+ "        # Named `optimizer`, not `optim`, so it cannot be confused with the imported torch `optim` module.\n",
+ "        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)\n",
+ "        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(\n",
+ "            optimizer,\n",
+ "            patience=self.patience,\n",
+ "            verbose=False,\n",
+ "            min_lr=1e-7,\n",
+ "        ) if self.patience else None\n",
+ "        return {'optimizer': optimizer, 'lr_scheduler': scheduler, 'monitor': 'loss/val'}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2020-11-13T07:55:34.653519Z",
+ "start_time": "2020-11-13T07:55:34.605285Z"
+ },
+ "lines_to_next_cell": 2
+ },
+ "outputs": [],
+ "source": [
+ "from torch.utils.data import DataLoader\n",
+ "from pytorch_lightning.loggers import CSVLogger, WandbLogger, TensorBoardLogger, TestTubeLogger\n",
+ "from pytorch_lightning.callbacks.early_stopping import EarlyStopping\n",
+ "from pytorch_lightning.callbacks import LearningRateMonitor"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Models"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2020-11-13T07:55:35.693992Z",
+ "start_time": "2020-11-13T07:55:35.629427Z"
+ },
+ "lines_to_end_of_cell_marker": 2,
+ "lines_to_next_cell": 0
+ },
+ "outputs": [],
+ "source": [
+ "from seq2seq_time.models.baseline import BaselineLast, BaselineMean\n",
+ "from seq2seq_time.models.lstm_seq2seq import LSTMSeq2Seq\n",
+ "from seq2seq_time.models.lstm import LSTM\n",
+ "from seq2seq_time.models.transformer import Transformer\n",
+ "from seq2seq_time.models.transformer_seq2seq import TransformerSeq2Seq\n",
+ "from seq2seq_time.models.neural_process import RANP\n",
+ "from seq2seq_time.models.transformer_process import TransformerProcess\n",
+ "from seq2seq_time.models.tcn import TCNSeq\n",
+ "from seq2seq_time.models.inceptiontime import InceptionTimeSeq\n",
+ "from seq2seq_time.models.xattention import CrossAttention"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2020-11-13T07:55:36.024789Z",
+ "start_time": "2020-11-13T07:55:35.985433Z"
+ },
+ "lines_to_next_cell": 0
+ },
+ "outputs": [],
+ "source": [
+ "import gc\n",
+ "\n",
+ "def free_mem():\n",
+ "    \"\"\"Run a garbage-collection pass, empty the CUDA cache, then collect once more.\"\"\"\n",
+ "    for release in (gc.collect, torch.cuda.empty_cache, gc.collect):\n",
+ "        release()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2020-11-02T06:10:41.904480Z",
+ "start_time": "2020-11-02T06:10:41.848613Z"
+ },
+ "lines_to_next_cell": 2
+ },
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2020-11-13T07:55:37.813853Z",
+ "start_time": "2020-11-13T07:55:37.768348Z"
+ },
+ "lines_to_next_cell": 0
+ },
+ "outputs": [],
+ "source": [
+ "# PARAMS: model\n",
+ "dropout=0.0\n",
+ "layers=6\n",
+ "nhead=4\n",
+ "\n",
+ "models = [\n",
+ "# lambda xs, ys: BaselineLast(),\n",
+ "# lambda xs, ys, hidden_size: BaselineMean(),\n",
+ " lambda xs, ys, hidden_size, layers:TransformerProcess(xs,\n",
+ " ys, hidden_size=hidden_size, nhead=nhead,\n",
+ " latent_dim=hidden_size//2, dropout=dropout,\n",
+ " nlayers=layers),\n",
+ " lambda xs, ys, hidden_size, layers: RANP(xs,\n",
+ " ys, hidden_dim=hidden_size, dropout=dropout, \n",
+ " latent_dim=hidden_size//2, n_decoder_layers=layers, n_latent_encoder_layers=layers, n_det_encoder_layers=layers),\n",
+ " lambda xs, ys, hidden_size, layers:TCNSeq(xs, ys, hidden_size=hidden_size, nlayers=layers, dropout=dropout, kernel_size=2),\n",
+ " lambda xs, ys, hidden_size, layers: Transformer(xs,\n",
+ " ys,\n",
+ " attention_dropout=dropout,\n",
+ " nhead=nhead,\n",
+ " nlayers=layers,\n",
+ " hidden_size=hidden_size),\n",
+ " lambda xs, ys, hidden_size, layers: LSTM(xs,\n",
+ " ys,\n",
+ " hidden_size=hidden_size,\n",
+ " lstm_layers=layers//2,\n",
+ " lstm_dropout=dropout),\n",
+ "\n",
+ " lambda xs, ys, hidden_size, layers: TransformerSeq2Seq(xs,\n",
+ " ys,\n",
+ " hidden_size=hidden_size,\n",
+ " nhead=nhead,\n",
+ " nlayers=layers,\n",
+ " attention_dropout=dropout\n",
+ " ),\n",
+ "\n",
+ " lambda xs, ys, hidden_size, layers: LSTMSeq2Seq(xs,\n",
+ " ys,\n",
+ " hidden_size=hidden_size,\n",
+ " lstm_layers=layers//2,\n",
+ " lstm_dropout=dropout),\n",
+ " lambda xs, ys, hidden_size, layers: CrossAttention(xs,\n",
+ " ys,\n",
+ " nlayers=layers,\n",
+ " hidden_size=hidden_size,),\n",
+ " lambda xs, ys, hidden_size, layers: InceptionTimeSeq(xs,\n",
+ " ys,\n",
+ " kernel_size=96,\n",
+ " layers=layers//2,\n",
+ " hidden_size=hidden_size,\n",
+ " bottleneck=hidden_size//4)\n",
+ "\n",
+ "]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2020-11-12T23:16:29.248291Z",
+ "start_time": "2020-11-12T23:16:16.139941Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "# DEBUG: sanity check\n",
+ "# Build every model on every dataset and run it through a fast_dev_run fit,\n",
+ "# so shape/config errors surface before the long Optuna search starts.\n",
+ "\n",
+ "for Dataset in datasets:\n",
+ "    dataset_name = Dataset.__name__\n",
+ "    dataset = Dataset(datasets_root)\n",
+ "    ds_train, ds_val, ds_test = dataset.to_datasets(window_past=window_past,\n",
+ "                                                    window_future=window_future)\n",
+ "\n",
+ "    # Init data\n",
+ "    x_past, y_past, x_future, y_future = ds_train.get_rows(10)\n",
+ "    xs = x_past.shape[-1]\n",
+ "    ys = y_future.shape[-1]\n",
+ "\n",
+ "    # Loaders\n",
+ "    dl_train = DataLoader(ds_train,\n",
+ "                          batch_size=batch_size,\n",
+ "                          shuffle=True,\n",
+ "                          pin_memory=num_workers == 0,\n",
+ "                          num_workers=num_workers)\n",
+ "    dl_val = DataLoader(ds_val,\n",
+ "                        shuffle=True,\n",
+ "                        batch_size=batch_size,\n",
+ "                        num_workers=num_workers)\n",
+ "\n",
+ "    for m_fn in models:\n",
+ "        free_mem()\n",
+ "        pt_model = m_fn(xs, ys, 8, 4)\n",
+ "        model_name = type(pt_model).__name__\n",
+ "        print(timestamp, dataset_name, model_name)\n",
+ "\n",
+ "        # Wrap in lightning\n",
+ "        model = PL_MODEL(pt_model,\n",
+ "                         lr=3e-4\n",
+ "                         ).to(device)\n",
+ "        trainer = pl.Trainer(\n",
+ "            fast_dev_run=True,\n",
+ "            # GPU\n",
+ "            gpus=1,\n",
+ "            amp_level='O1',\n",
+ "            precision=16,\n",
+ "        )\n",
+ "        # Actually run the check: constructing the Trainer alone exercises nothing.\n",
+ "        trainer.fit(model, dl_train, dl_val)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2020-10-23T23:36:11.052891Z",
+ "start_time": "2020-10-23T23:36:11.048874Z"
+ }
+ },
+ "source": [
+ "## Train"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2020-11-12T23:16:29.286762Z",
+ "start_time": "2020-11-12T23:16:29.250367Z"
+ },
+ "lines_to_next_cell": 2
+ },
+ "outputs": [],
+ "source": [
+ "max_iters=20000"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2020-11-12T23:16:29.326846Z",
+ "start_time": "2020-11-12T23:16:29.289004Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "tensorboard_dir = Path(f\"../outputs/{timestamp}\").resolve()\n",
+ "print(f'For tensorboard run:\\ntensorboard --logdir=\"{tensorboard_dir}\"')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2020-11-13T07:55:48.561522Z",
+ "start_time": "2020-11-13T07:55:48.516182Z"
+ },
+ "lines_to_next_cell": 0
+ },
+ "outputs": [],
+ "source": [
+ "\n",
+ "def objective(trial):\n",
+ "    \"\"\"\n",
+ "    Optuna function to optimize: train one sampled configuration and return its loss.\n",
+ "\n",
+ "    Samples `hidden_size_exp` in [1, 8] (hidden size is 2**exp) and `layers` in [1, 12],\n",
+ "    builds the model with the current `m_fn`, fits it, then evaluates it on `dl_val`\n",
+ "    in test mode and returns the resulting 'loss/test'.\n",
+ "\n",
+ "    Reads the notebook globals `m_fn`, `xs`, `ys`, `dl_train`, `dl_val`, `dataset_name`,\n",
+ "    `timestamp`, `max_iters`, `batch_size` and `device`; they must be set before\n",
+ "    `study.optimize(objective, ...)` is called.\n",
+ "\n",
+ "    See https://github.com/optuna/optuna/blob/master/examples/pytorch_lightning_simple.py\n",
+ "    \"\"\"\n",
+ "    # sample\n",
+ "    hidden_size_exp = trial.suggest_int(\"hidden_size_exp\", 1, 8)\n",
+ "    hidden_size = 2**hidden_size_exp\n",
+ "\n",
+ "    layers = trial.suggest_int(\"layers\", 1, 12)\n",
+ "\n",
+ "    # Load model\n",
+ "    pt_model = m_fn(xs, ys, hidden_size, layers)\n",
+ "    model_name = type(pt_model).__name__\n",
+ "\n",
+ "    # Wrap in lightning\n",
+ "    patience = 2\n",
+ "    model = PL_MODEL(pt_model,\n",
+ "                     lr=3e-4, patience=patience,\n",
+ "                     ).to(device)\n",
+ "\n",
+ "    save_dir = f\"../outputs/{timestamp}/{dataset_name}_{model_name}/{trial.number}\"\n",
+ "    Path(save_dir).mkdir(exist_ok=True, parents=True)\n",
+ "\n",
+ "    # Follow the `device` chosen in the parameters cell instead of assuming a GPU,\n",
+ "    # so the CPU fallback computed there also works for the Trainer.\n",
+ "    use_gpu = device == \"cuda\"\n",
+ "    trainer = pl.Trainer(\n",
+ "        # Training length\n",
+ "        min_epochs=2,\n",
+ "        max_epochs=40,\n",
+ "        limit_train_batches=max_iters//batch_size,\n",
+ "        limit_val_batches=max_iters//batch_size//5,\n",
+ "        # Misc\n",
+ "        gradient_clip_val=20,\n",
+ "        terminate_on_nan=True,\n",
+ "        # GPU (16-bit precision is only requested when a GPU is used)\n",
+ "        gpus=1 if use_gpu else 0,\n",
+ "        amp_level='O1',\n",
+ "        precision=16 if use_gpu else 32,\n",
+ "        # Callbacks\n",
+ "        default_root_dir=save_dir,\n",
+ "        logger=False,\n",
+ "        callbacks=[\n",
+ "            # Early stopping waits twice as long as the LR scheduler's patience.\n",
+ "            EarlyStopping(monitor='loss/val', patience=patience * 2),\n",
+ "            PyTorchLightningPruningCallback(trial, monitor=\"loss/val\")],\n",
+ "    )\n",
+ "    trainer.fit(model, dl_train, dl_val)\n",
+ "\n",
+ "    # Run on all val data, using test mode\n",
+ "    r = trainer.test(model, test_dataloaders=dl_val, verbose=False)\n",
+ "    return r[0]['loss/test']"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2020-11-08T02:45:44.106583Z",
+ "start_time": "2020-11-08T02:45:44.050637Z"
+ },
+ "lines_to_next_cell": 2
+ },
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2020-11-13T07:55:49.497334Z",
+ "start_time": "2020-11-13T07:55:48.981040Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "import optuna\n",
+ "from optuna.integration import PyTorchLightningPruningCallback"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2020-11-13T07:55:49.543035Z",
+ "start_time": "2020-11-13T07:55:49.499555Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "import subprocess\n",
+ "def get_git_commit():\n",
+ "    \"\"\"Return the HEAD commit hash from `git rev-parse` run in the parent directory.\n",
+ "\n",
+ "    On any failure the exception is logged and None is returned.\n",
+ "    \"\"\"\n",
+ "    try:\n",
+ "        raw = subprocess.check_output([\"git\", \"rev-parse\", \"HEAD\"], cwd='..')\n",
+ "        return raw.decode().strip()\n",
+ "    except Exception:\n",
+ "        logging.exception(\"failed to get git hash\")\n",
+ "    return None"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2020-11-13T07:55:50.265327Z",
+ "start_time": "2020-11-13T07:55:50.204384Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "# parents=True: on a fresh checkout ../outputs itself may not exist yet.\n",
+ "Path(f\"../outputs/{timestamp}\").mkdir(exist_ok=True, parents=True)\n",
+ "# SQLite file that persists every Optuna study of this run.\n",
+ "storage = f\"sqlite:///../outputs/{timestamp}/optuna.db\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2020-11-13T01:44:13.554355Z",
+ "start_time": "2020-11-12T23:16:30.128631Z"
+ },
+ "lines_to_next_cell": 0,
+ "scrolled": true
+ },
+ "outputs": [],
+ "source": [
+ "# Hyperparameter search: one Optuna study per (dataset, model) pair, persisted in `storage`.\n",
+ "# Studies are reloaded if they exist, and only the trials still missing to reach 20\n",
+ "# COMPLETE ones are run, so re-running this cell resumes an interrupted search.\n",
+ "for Dataset in tqdm(datasets, desc='datasets'):\n",
+ "    dataset_name = Dataset.__name__\n",
+ "    dataset = Dataset(datasets_root)\n",
+ "    ds_train, ds_val, ds_test = dataset.to_datasets(window_past=window_past,\n",
+ "                                                    window_future=window_future)\n",
+ "\n",
+ "    # Init data\n",
+ "    x_past, y_past, x_future, y_future = ds_train.get_rows(10)\n",
+ "    xs = x_past.shape[-1]\n",
+ "    ys = y_future.shape[-1]\n",
+ "\n",
+ "    # Loaders\n",
+ "    dl_train = DataLoader(ds_train,\n",
+ "                          batch_size=batch_size,\n",
+ "                          shuffle=True,\n",
+ "                          drop_last=True,\n",
+ "                          pin_memory=num_workers == 0,\n",
+ "                          num_workers=num_workers)\n",
+ "    dl_val = DataLoader(ds_val,\n",
+ "                        shuffle=False,\n",
+ "                        batch_size=batch_size,\n",
+ "                        drop_last=True,\n",
+ "                        num_workers=num_workers)\n",
+ "\n",
+ "    for i, m_fn in enumerate(tqdm(models, desc=f'models ({dataset_name})')):\n",
+ "        try:\n",
+ "            # Throwaway instance, built only to read the model's class name.\n",
+ "            model_name = type(m_fn(8, 8, 8, 2)).__name__\n",
+ "            free_mem()\n",
+ "            study_name = f'{timestamp}_{dataset_name}-{model_name}'\n",
+ "\n",
+ "            # Create study\n",
+ "            pruner = optuna.pruners.MedianPruner()\n",
+ "            study = optuna.create_study(storage=storage,\n",
+ "                                        study_name=study_name,\n",
+ "                                        pruner=pruner,\n",
+ "                                        load_if_exists=True)\n",
+ "            study.set_user_attr('dataset', dataset_name)\n",
+ "            study.set_user_attr('model', model_name)\n",
+ "            study.set_user_attr('commit', get_git_commit())\n",
+ "\n",
+ "            df_trials = study.trials_dataframe()\n",
+ "            if len(df_trials):\n",
+ "                df_trials = df_trials[df_trials.state=='COMPLETE']\n",
+ "            nb_trials = len(df_trials)\n",
+ "            if nb_trials==0:\n",
+ "                # Priors. Keys must be the exact names passed to trial.suggest_int in\n",
+ "                # `objective`; any other key does not fix that parameter.\n",
+ "                study.enqueue_trial({\"layers\": 6, \"hidden_size_exp\": 2})\n",
+ "                study.enqueue_trial({\"layers\": 1, \"hidden_size_exp\": 3})\n",
+ "                study.enqueue_trial({\"layers\": 3, \"hidden_size_exp\": 5})\n",
+ "            if nb_trials<20:\n",
+ "                # Opt\n",
+ "                study.optimize(objective, n_trials=20-nb_trials,\n",
+ "                               timeout=60*60*2 # Max seconds for all optimizes\n",
+ "                               )\n",
+ "\n",
+ "            print(\"Number of finished trials: {}\".format(len(study.trials)))\n",
+ "\n",
+ "            print(\"Best trial:\")\n",
+ "            trial = study.best_trial\n",
+ "\n",
+ "            print(\" Value: {}\".format(trial.value))\n",
+ "\n",
+ "            print(\" Params: \")\n",
+ "            for key, value in trial.params.items():\n",
+ "                print(\" {}: {}\".format(key, value))\n",
+ "\n",
+ "        except Exception as e:\n",
+ "            # Best effort: log the failure and continue with the next model.\n",
+ "            logging.exception('failed to run model')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2020-11-12T23:08:15.792627Z",
+ "start_time": "2020-11-12T23:08:15.700279Z"
+ }
+ },
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2020-11-13T01:44:20.959637Z",
+ "start_time": "2020-11-13T01:44:13.560463Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "# Baseline\n",
+ "\n",
+ "models2 = [\n",
+ " lambda xs, ys, _, l: BaselineLast(),\n",
+ " lambda xs, ys, _, l: BaselineMean(),\n",
+ "]\n",
+ "\n",
+ "for Dataset in tqdm(datasets, desc='datasets'):\n",
+ " dataset_name = Dataset.__name__\n",
+ " dataset = Dataset(datasets_root)\n",
+ " ds_train, ds_val, ds_test = dataset.to_datasets(window_past=window_past,\n",
+ " window_future=window_future)\n",
+ "\n",
+ " # Init data\n",
+ " x_past, y_past, x_future, y_future = ds_train.get_rows(10)\n",
+ " xs = x_past.shape[-1]\n",
+ " ys = y_future.shape[-1]\n",
+ "\n",
+ " # Loaders\n",
+ " dl_train = DataLoader(ds_train,\n",
+ " batch_size=batch_size,\n",
+ " shuffle=True,\n",
+ " drop_last=True,\n",
+ " pin_memory=num_workers == 0,\n",
+ " num_workers=num_workers)\n",
+ " dl_val = DataLoader(ds_val,\n",
+ " shuffle=False,\n",
+ " batch_size=batch_size,\n",
+ " drop_last=True,\n",
+ " num_workers=num_workers)\n",
+ "\n",
+ " for i, m_fn in enumerate(tqdm(models2, desc=f'models ({dataset_name})')):\n",
+ " try:\n",
+ " model_name = type(m_fn(8, 8, 8, 2)).__name__\n",
+ " free_mem()\n",
+ " study_name = f'{timestamp}_{dataset_name}-{model_name}'\n",
+ " \n",
+ " # Create study \n",
+ " pruner = optuna.pruners.MedianPruner()\n",
+ " study = optuna.create_study(storage=storage, \n",
+ " study_name=study_name, \n",
+ " pruner=pruner,\n",
+ " load_if_exists=True)\n",
+ " study.set_user_attr('dataset', dataset_name)\n",
+ " study.set_user_attr('model', model_name)\n",
+ " study.set_user_attr('commit', get_git_commit())\n",
+ " \n",
+ " df_trials = study.trials_dataframe()\n",
+ " if len(df_trials):\n",
+ " df_trials = df_trials[df_trials.state=='COMPLETE']\n",
+ " nb_trials = len(df_trials)\n",
+ " if nb_trials<1:\n",
+ " # Opt\n",
+ " study.optimize(objective, n_trials=1, \n",
+ " timeout=60*30 # Max seconds for all optimizes\n",
+ " )\n",
+ " \n",
+ " except Exception as e:\n",
+ " logging.exception('failed to run model')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2020-11-13T01:44:21.021030Z",
+ "start_time": "2020-11-13T01:44:20.962459Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "# TODO baseline, run as sep cell, opt once\n",
+ "# TODO summarize time and model params at best params"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2020-11-13T07:55:54.138705Z",
+ "start_time": "2020-11-13T07:55:53.706642Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " | \n",
+ " n_trials | \n",
+ " param_hidden_size_exp | \n",
+ " param_layers | \n",
+ " value | \n",
+ "
\n",
+ " \n",
+ " | dataset | \n",
+ " model | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | AppliancesEnergyPrediction | \n",
+ " TCNSeq | \n",
+ " 48 | \n",
+ " 2 | \n",
+ " 1 | \n",
+ " 1.075230 | \n",
+ "
\n",
+ " \n",
+ " | InceptionTimeSeq | \n",
+ " 17 | \n",
+ " 1 | \n",
+ " 3 | \n",
+ " 1.099332 | \n",
+ "
\n",
+ " \n",
+ " | TransformerProcess | \n",
+ " 18 | \n",
+ " 4 | \n",
+ " 3 | \n",
+ " 1.163042 | \n",
+ "
\n",
+ " \n",
+ " | LSTM | \n",
+ " 19 | \n",
+ " 6 | \n",
+ " 3 | \n",
+ " 1.172367 | \n",
+ "
\n",
+ " \n",
+ " | TransformerSeq2Seq | \n",
+ " 17 | \n",
+ " 2 | \n",
+ " 1 | \n",
+ " 1.174764 | \n",
+ "
\n",
+ " \n",
+ " | Transformer | \n",
+ " 19 | \n",
+ " 4 | \n",
+ " 6 | \n",
+ " 1.195765 | \n",
+ "
\n",
+ " \n",
+ " | LSTMSeq2Seq | \n",
+ " 29 | \n",
+ " 4 | \n",
+ " 6 | \n",
+ " 1.204848 | \n",
+ "
\n",
+ " \n",
+ " | RANP | \n",
+ " 16 | \n",
+ " 4 | \n",
+ " 10 | \n",
+ " 1.278206 | \n",
+ "
\n",
+ " \n",
+ " | BaselineMean | \n",
+ " 1 | \n",
+ " 6 | \n",
+ " 5 | \n",
+ " 1.322796 | \n",
+ "
\n",
+ " \n",
+ " | BaselineLast | \n",
+ " 1 | \n",
+ " 2 | \n",
+ " 1 | \n",
+ " 1.475730 | \n",
+ "
\n",
+ " \n",
+ " | CrossAttention | \n",
+ " 17 | \n",
+ " 4 | \n",
+ " 5 | \n",
+ " 1.547343 | \n",
+ "
\n",
+ " \n",
+ " | BejingPM25 | \n",
+ " TCNSeq | \n",
+ " 46 | \n",
+ " 3 | \n",
+ " 3 | \n",
+ " 1.235752 | \n",
+ "
\n",
+ " \n",
+ " | InceptionTimeSeq | \n",
+ " 15 | \n",
+ " 1 | \n",
+ " 3 | \n",
+ " 1.239024 | \n",
+ "
\n",
+ " \n",
+ " | LSTM | \n",
+ " 18 | \n",
+ " 2 | \n",
+ " 3 | \n",
+ " 1.271591 | \n",
+ "
\n",
+ " \n",
+ " | LSTMSeq2Seq | \n",
+ " 17 | \n",
+ " 5 | \n",
+ " 6 | \n",
+ " 1.292539 | \n",
+ "
\n",
+ " \n",
+ " | Transformer | \n",
+ " 53 | \n",
+ " 6 | \n",
+ " 1 | \n",
+ " 1.295221 | \n",
+ "
\n",
+ " \n",
+ " | TransformerProcess | \n",
+ " 16 | \n",
+ " 2 | \n",
+ " 1 | \n",
+ " 1.404077 | \n",
+ "
\n",
+ " \n",
+ " | CrossAttention | \n",
+ " 17 | \n",
+ " 8 | \n",
+ " 11 | \n",
+ " 1.411578 | \n",
+ "
\n",
+ " \n",
+ " | RANP | \n",
+ " 16 | \n",
+ " 4 | \n",
+ " 1 | \n",
+ " 1.429652 | \n",
+ "
\n",
+ " \n",
+ " | BaselineMean | \n",
+ " 1 | \n",
+ " 6 | \n",
+ " 3 | \n",
+ " 1.435451 | \n",
+ "
\n",
+ " \n",
+ " | BaselineLast | \n",
+ " 1 | \n",
+ " 4 | \n",
+ " 12 | \n",
+ " 1.548376 | \n",
+ "
\n",
+ " \n",
+ " | TransformerSeq2Seq | \n",
+ " 16 | \n",
+ " 4 | \n",
+ " 8 | \n",
+ " 2.393506 | \n",
+ "
\n",
+ " \n",
+ " | GasSensor | \n",
+ " RANP | \n",
+ " 18 | \n",
+ " 6 | \n",
+ " 10 | \n",
+ " -2.133501 | \n",
+ "
\n",
+ " \n",
+ " | InceptionTimeSeq | \n",
+ " 33 | \n",
+ " 5 | \n",
+ " 10 | \n",
+ " -2.102463 | \n",
+ "
\n",
+ " \n",
+ " | Transformer | \n",
+ " 50 | \n",
+ " 6 | \n",
+ " 12 | \n",
+ " -1.960168 | \n",
+ "
\n",
+ " \n",
+ " | TCNSeq | \n",
+ " 41 | \n",
+ " 6 | \n",
+ " 8 | \n",
+ " -1.736079 | \n",
+ "
\n",
+ " \n",
+ " | LSTM | \n",
+ " 18 | \n",
+ " 6 | \n",
+ " 5 | \n",
+ " -1.537527 | \n",
+ "
\n",
+ " \n",
+ " | LSTMSeq2Seq | \n",
+ " 20 | \n",
+ " 8 | \n",
+ " 5 | \n",
+ " -1.489893 | \n",
+ "
\n",
+ " \n",
+ " | TransformerProcess | \n",
+ " 38 | \n",
+ " 7 | \n",
+ " 4 | \n",
+ " -0.880734 | \n",
+ "
\n",
+ " \n",
+ " | CrossAttention | \n",
+ " 17 | \n",
+ " 7 | \n",
+ " 3 | \n",
+ " -0.638713 | \n",
+ "
\n",
+ " \n",
+ " | TransformerSeq2Seq | \n",
+ " 28 | \n",
+ " 6 | \n",
+ " 1 | \n",
+ " 0.336668 | \n",
+ "
\n",
+ " \n",
+ " | BaselineMean | \n",
+ " 1 | \n",
+ " 3 | \n",
+ " 2 | \n",
+ " 1.584720 | \n",
+ "
\n",
+ " \n",
+ " | BaselineLast | \n",
+ " 1 | \n",
+ " 2 | \n",
+ " 6 | \n",
+ " 1.974204 | \n",
+ "
\n",
+ " \n",
+ " | IMOSCurrentsVel | \n",
+ " TCNSeq | \n",
+ " 52 | \n",
+ " 4 | \n",
+ " 6 | \n",
+ " 0.817371 | \n",
+ "
\n",
+ " \n",
+ " | InceptionTimeSeq | \n",
+ " 20 | \n",
+ " 3 | \n",
+ " 6 | \n",
+ " 0.845443 | \n",
+ "
\n",
+ " \n",
+ " | LSTM | \n",
+ " 22 | \n",
+ " 5 | \n",
+ " 6 | \n",
+ " 0.875669 | \n",
+ "
\n",
+ " \n",
+ " | Transformer | \n",
+ " 24 | \n",
+ " 2 | \n",
+ " 3 | \n",
+ " 0.879682 | \n",
+ "
\n",
+ " \n",
+ " | BaselineLast | \n",
+ " 1 | \n",
+ " 3 | \n",
+ " 6 | \n",
+ " 0.885377 | \n",
+ "
\n",
+ " \n",
+ " | LSTMSeq2Seq | \n",
+ " 24 | \n",
+ " 7 | \n",
+ " 12 | \n",
+ " 0.887783 | \n",
+ "
\n",
+ " \n",
+ " | RANP | \n",
+ " 19 | \n",
+ " 5 | \n",
+ " 3 | \n",
+ " 1.035464 | \n",
+ "
\n",
+ " \n",
+ " | BaselineMean | \n",
+ " 1 | \n",
+ " 5 | \n",
+ " 7 | \n",
+ " 1.202492 | \n",
+ "
\n",
+ " \n",
+ " | TransformerSeq2Seq | \n",
+ " 19 | \n",
+ " 3 | \n",
+ " 1 | \n",
+ " 1.266849 | \n",
+ "
\n",
+ " \n",
+ " | TransformerProcess | \n",
+ " 27 | \n",
+ " 5 | \n",
+ " 1 | \n",
+ " 1.394595 | \n",
+ "
\n",
+ " \n",
+ " | CrossAttention | \n",
+ " 17 | \n",
+ " 3 | \n",
+ " 3 | \n",
+ " 1.656854 | \n",
+ "
\n",
+ " \n",
+ " | MetroInterstateTraffic | \n",
+ " TCNSeq | \n",
+ " 61 | \n",
+ " 7 | \n",
+ " 3 | \n",
+ " -0.324690 | \n",
+ "
\n",
+ " \n",
+ " | TransformerProcess | \n",
+ " 17 | \n",
+ " 3 | \n",
+ " 7 | \n",
+ " -0.297472 | \n",
+ "
\n",
+ " \n",
+ " | RANP | \n",
+ " 16 | \n",
+ " 3 | \n",
+ " 3 | \n",
+ " -0.291096 | \n",
+ "
\n",
+ " \n",
+ " | Transformer | \n",
+ " 20 | \n",
+ " 4 | \n",
+ " 11 | \n",
+ " -0.253573 | \n",
+ "
\n",
+ " \n",
+ " | LSTMSeq2Seq | \n",
+ " 15 | \n",
+ " 5 | \n",
+ " 5 | \n",
+ " -0.200625 | \n",
+ "
\n",
+ " \n",
+ " | LSTM | \n",
+ " 17 | \n",
+ " 4 | \n",
+ " 5 | \n",
+ " -0.200455 | \n",
+ "
\n",
+ " \n",
+ " | TransformerSeq2Seq | \n",
+ " 15 | \n",
+ " 2 | \n",
+ " 6 | \n",
+ " -0.192895 | \n",
+ "
\n",
+ " \n",
+ " | InceptionTimeSeq | \n",
+ " 17 | \n",
+ " 3 | \n",
+ " 6 | \n",
+ " -0.162513 | \n",
+ "
\n",
+ " \n",
+ " | CrossAttention | \n",
+ " 14 | \n",
+ " 6 | \n",
+ " 1 | \n",
+ " -0.103853 | \n",
+ "
\n",
+ " \n",
+ " | BaselineMean | \n",
+ " 1 | \n",
+ " 3 | \n",
+ " 1 | \n",
+ " 1.410653 | \n",
+ "
\n",
+ " \n",
+ " | BaselineLast | \n",
+ " 1 | \n",
+ " 2 | \n",
+ " 12 | \n",
+ " 1.741882 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " n_trials \\\n",
+ "dataset model \n",
+ "AppliancesEnergyPrediction TCNSeq 48 \n",
+ " InceptionTimeSeq 17 \n",
+ " TransformerProcess 18 \n",
+ " LSTM 19 \n",
+ " TransformerSeq2Seq 17 \n",
+ " Transformer 19 \n",
+ " LSTMSeq2Seq 29 \n",
+ " RANP 16 \n",
+ " BaselineMean 1 \n",
+ " BaselineLast 1 \n",
+ " CrossAttention 17 \n",
+ "BejingPM25 TCNSeq 46 \n",
+ " InceptionTimeSeq 15 \n",
+ " LSTM 18 \n",
+ " LSTMSeq2Seq 17 \n",
+ " Transformer 53 \n",
+ " TransformerProcess 16 \n",
+ " CrossAttention 17 \n",
+ " RANP 16 \n",
+ " BaselineMean 1 \n",
+ " BaselineLast 1 \n",
+ " TransformerSeq2Seq 16 \n",
+ "GasSensor RANP 18 \n",
+ " InceptionTimeSeq 33 \n",
+ " Transformer 50 \n",
+ " TCNSeq 41 \n",
+ " LSTM 18 \n",
+ " LSTMSeq2Seq 20 \n",
+ " TransformerProcess 38 \n",
+ " CrossAttention 17 \n",
+ " TransformerSeq2Seq 28 \n",
+ " BaselineMean 1 \n",
+ " BaselineLast 1 \n",
+ "IMOSCurrentsVel TCNSeq 52 \n",
+ " InceptionTimeSeq 20 \n",
+ " LSTM 22 \n",
+ " Transformer 24 \n",
+ " BaselineLast 1 \n",
+ " LSTMSeq2Seq 24 \n",
+ " RANP 19 \n",
+ " BaselineMean 1 \n",
+ " TransformerSeq2Seq 19 \n",
+ " TransformerProcess 27 \n",
+ " CrossAttention 17 \n",
+ "MetroInterstateTraffic TCNSeq 61 \n",
+ " TransformerProcess 17 \n",
+ " RANP 16 \n",
+ " Transformer 20 \n",
+ " LSTMSeq2Seq 15 \n",
+ " LSTM 17 \n",
+ " TransformerSeq2Seq 15 \n",
+ " InceptionTimeSeq 17 \n",
+ " CrossAttention 14 \n",
+ " BaselineMean 1 \n",
+ " BaselineLast 1 \n",
+ "\n",
+ " param_hidden_size_exp \\\n",
+ "dataset model \n",
+ "AppliancesEnergyPrediction TCNSeq 2 \n",
+ " InceptionTimeSeq 1 \n",
+ " TransformerProcess 4 \n",
+ " LSTM 6 \n",
+ " TransformerSeq2Seq 2 \n",
+ " Transformer 4 \n",
+ " LSTMSeq2Seq 4 \n",
+ " RANP 4 \n",
+ " BaselineMean 6 \n",
+ " BaselineLast 2 \n",
+ " CrossAttention 4 \n",
+ "BejingPM25 TCNSeq 3 \n",
+ " InceptionTimeSeq 1 \n",
+ " LSTM 2 \n",
+ " LSTMSeq2Seq 5 \n",
+ " Transformer 6 \n",
+ " TransformerProcess 2 \n",
+ " CrossAttention 8 \n",
+ " RANP 4 \n",
+ " BaselineMean 6 \n",
+ " BaselineLast 4 \n",
+ " TransformerSeq2Seq 4 \n",
+ "GasSensor RANP 6 \n",
+ " InceptionTimeSeq 5 \n",
+ " Transformer 6 \n",
+ " TCNSeq 6 \n",
+ " LSTM 6 \n",
+ " LSTMSeq2Seq 8 \n",
+ " TransformerProcess 7 \n",
+ " CrossAttention 7 \n",
+ " TransformerSeq2Seq 6 \n",
+ " BaselineMean 3 \n",
+ " BaselineLast 2 \n",
+ "IMOSCurrentsVel TCNSeq 4 \n",
+ " InceptionTimeSeq 3 \n",
+ " LSTM 5 \n",
+ " Transformer 2 \n",
+ " BaselineLast 3 \n",
+ " LSTMSeq2Seq 7 \n",
+ " RANP 5 \n",
+ " BaselineMean 5 \n",
+ " TransformerSeq2Seq 3 \n",
+ " TransformerProcess 5 \n",
+ " CrossAttention 3 \n",
+ "MetroInterstateTraffic TCNSeq 7 \n",
+ " TransformerProcess 3 \n",
+ " RANP 3 \n",
+ " Transformer 4 \n",
+ " LSTMSeq2Seq 5 \n",
+ " LSTM 4 \n",
+ " TransformerSeq2Seq 2 \n",
+ " InceptionTimeSeq 3 \n",
+ " CrossAttention 6 \n",
+ " BaselineMean 3 \n",
+ " BaselineLast 2 \n",
+ "\n",
+ " param_layers value \n",
+ "dataset model \n",
+ "AppliancesEnergyPrediction TCNSeq 1 1.075230 \n",
+ " InceptionTimeSeq 3 1.099332 \n",
+ " TransformerProcess 3 1.163042 \n",
+ " LSTM 3 1.172367 \n",
+ " TransformerSeq2Seq 1 1.174764 \n",
+ " Transformer 6 1.195765 \n",
+ " LSTMSeq2Seq 6 1.204848 \n",
+ " RANP 10 1.278206 \n",
+ " BaselineMean 5 1.322796 \n",
+ " BaselineLast 1 1.475730 \n",
+ " CrossAttention 5 1.547343 \n",
+ "BejingPM25 TCNSeq 3 1.235752 \n",
+ " InceptionTimeSeq 3 1.239024 \n",
+ " LSTM 3 1.271591 \n",
+ " LSTMSeq2Seq 6 1.292539 \n",
+ " Transformer 1 1.295221 \n",
+ " TransformerProcess 1 1.404077 \n",
+ " CrossAttention 11 1.411578 \n",
+ " RANP 1 1.429652 \n",
+ " BaselineMean 3 1.435451 \n",
+ " BaselineLast 12 1.548376 \n",
+ " TransformerSeq2Seq 8 2.393506 \n",
+ "GasSensor RANP 10 -2.133501 \n",
+ " InceptionTimeSeq 10 -2.102463 \n",
+ " Transformer 12 -1.960168 \n",
+ " TCNSeq 8 -1.736079 \n",
+ " LSTM 5 -1.537527 \n",
+ " LSTMSeq2Seq 5 -1.489893 \n",
+ " TransformerProcess 4 -0.880734 \n",
+ " CrossAttention 3 -0.638713 \n",
+ " TransformerSeq2Seq 1 0.336668 \n",
+ " BaselineMean 2 1.584720 \n",
+ " BaselineLast 6 1.974204 \n",
+ "IMOSCurrentsVel TCNSeq 6 0.817371 \n",
+ " InceptionTimeSeq 6 0.845443 \n",
+ " LSTM 6 0.875669 \n",
+ " Transformer 3 0.879682 \n",
+ " BaselineLast 6 0.885377 \n",
+ " LSTMSeq2Seq 12 0.887783 \n",
+ " RANP 3 1.035464 \n",
+ " BaselineMean 7 1.202492 \n",
+ " TransformerSeq2Seq 1 1.266849 \n",
+ " TransformerProcess 1 1.394595 \n",
+ " CrossAttention 3 1.656854 \n",
+ "MetroInterstateTraffic TCNSeq 3 -0.324690 \n",
+ " TransformerProcess 7 -0.297472 \n",
+ " RANP 3 -0.291096 \n",
+ " Transformer 11 -0.253573 \n",
+ " LSTMSeq2Seq 5 -0.200625 \n",
+ " LSTM 5 -0.200455 \n",
+ " TransformerSeq2Seq 6 -0.192895 \n",
+ " InceptionTimeSeq 6 -0.162513 \n",
+ " CrossAttention 1 -0.103853 \n",
+ " BaselineMean 1 1.410653 \n",
+ " BaselineLast 12 1.741882 "
+ ]
+ },
+ "execution_count": 18,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import pandas as pd\n",
+ "# Summarize studies\n",
+ "rs = []\n",
+ "study_summaries = optuna.study.get_all_study_summaries(storage=storage)\n",
+ "for s in study_summaries:\n",
+ " row = {}\n",
+ " if (s.best_trial is not None) and (s.best_trial.state==optuna.trial.TrialState.COMPLETE):\n",
+ " params = {k:v for k,v in s.best_trial.__dict__.items() if not k.startswith('_')}\n",
+ " row.update(s.user_attrs)\n",
+ " row['n_trials'] = s.n_trials\n",
+ " row.update({'param_'+k:v for k,v in s.best_trial._params.items()})\n",
+ " row.update(params)\n",
+ " rs.append(row)\n",
+ "df_studies = pd.DataFrame(rs)\n",
+ "\n",
+ "df_studies = (df_studies.drop(columns=['state', 'intermediate_values', 'commit', 'datetime_complete'])\n",
+ " .sort_values(['dataset', 'value'])\n",
+ " .set_index(['dataset', 'model'])\n",
+ " )\n",
+ "df_studies"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2020-11-13T07:54:27.730531Z",
+ "start_time": "2020-11-13T07:54:27.667135Z"
+ },
+ "lines_to_next_cell": 0,
+ "scrolled": true
+ },
+ "outputs": [],
+ "source": [
+ "# study_names = [s.study_name for s in optuna.study.get_all_study_summaries(storage=storage)]\n",
+ "\n",
+ "# for study_name in study_names:\n",
+ "# loaded_study = optuna.load_study(study_name=study_name, storage=storage)\n",
+ " \n",
+ "# # Make DF over trials\n",
+ "# print(study_name)\n",
+ " \n",
+ "# df_trials = loaded_study.trials_dataframe()\n",
+ "# # df_trials.index = df_trials.apply(lambda r:f'l={r.params_layers}_hs={r.params_hidden_size_exp}', 1)\n",
+ "# display(df_trials)\n",
+ "\n",
+ "# # Plot test curves, to see how much overfitting\n",
+ "# df_values = pd.DataFrame([s.intermediate_values for s in loaded_study.get_trials()]).T\n",
+ "# df_values.columns = [f\"l={s.params['layers']}_hs={s.params['hidden_size_exp']}\" for s in loaded_study.get_trials()]\n",
+ "# df_values.plot(ylabel='nll', xlabel='epochs', title=f'val loss \"{study_name}\"')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2020-11-13T01:44:25.350038Z",
+ "start_time": "2020-11-12T23:16:12.130Z"
+ },
+ "lines_to_next_cell": 0
+ },
+ "outputs": [],
+ "source": [
+ "# study_names = [s.study_name for s in optuna.study.get_all_study_summaries(storage=storage)]\n",
+ "\n",
+ "# for study_name in study_names:\n",
+ "# loaded_study = optuna.load_study(study_name=study_name, storage=storage)\n",
+ "# fig=optuna.visualization.plot_contour(loaded_study, params=['hidden_size_exp', 'layers'])\n",
+ "# fig = fig.update_layout(dict(title=f'{study_name} nll'))\n",
+ "# display(fig)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2020-11-08T23:11:33.818862Z",
+ "start_time": "2020-11-08T23:11:33.749610Z"
+ }
+ },
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2020-11-08T23:11:14.588720Z",
+ "start_time": "2020-11-08T23:11:14.516020Z"
+ },
+ "lines_to_next_cell": 2
+ },
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2020-11-13T01:44:25.351121Z",
+ "start_time": "2020-11-12T23:16:12.141Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "[f\"{s.params['layers']}_{s.params['hidden_size_exp']}\" for s in loaded_study.get_trials()]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2020-11-13T01:44:25.352257Z",
+ "start_time": "2020-11-12T23:16:12.146Z"
+ },
+ "scrolled": true
+ },
+ "outputs": [],
+ "source": [
+ "df_values"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "lines_to_next_cell": 2
+ },
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "@webio": {
+ "lastCommId": null,
+ "lastKernelId": null
+ },
+ "jupytext": {
+ "encoding": "# -*- coding: utf-8 -*-",
+ "formats": "ipynb,py:light"
+ },
+ "kernelspec": {
+ "display_name": "seq2seq-time",
+ "language": "python",
+ "name": "seq2seq-time"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.8"
+ },
+ "toc": {
+ "base_numbering": 1,
+ "nav_menu": {},
+ "number_sections": true,
+ "sideBar": true,
+ "skip_h1_title": false,
+ "title_cell": "Table of Contents",
+ "title_sidebar": "Contents",
+ "toc_cell": false,
+ "toc_position": {
+ "height": "calc(100% - 180px)",
+ "left": "10px",
+ "top": "150px",
+ "width": "209.162px"
+ },
+ "toc_section_display": true,
+ "toc_window_display": true
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/notebooks/07.1-mc-optuna.py b/notebooks/07.1-mc-optuna.py
new file mode 100644
index 0000000..5571cee
--- /dev/null
+++ b/notebooks/07.1-mc-optuna.py
@@ -0,0 +1,539 @@
+# -*- coding: utf-8 -*-
+# ---
+# jupyter:
+# jupytext:
+# formats: ipynb,py:light
+# text_representation:
+# extension: .py
+# format_name: light
+# format_version: '1.5'
+# jupytext_version: 1.6.0
+# kernelspec:
+# display_name: seq2seq-time
+# language: python
+# name: seq2seq-time
+# ---
+
+# # Sequence to Sequence Models for Timeseries Regression
+#
+#
+# In this notebook we are going to find the optimal hidden_size for a model vs a dataset. We will use pytorch lightning and optuna.
+
+# OPTIONAL: Load the "autoreload" extension so that code can change. But blacklist large modules
+# %load_ext autoreload
+# %autoreload 2
+# %aimport -pandas
+# %aimport -torch
+# %aimport -numpy
+# %aimport -matplotlib
+# %aimport -dask
+# %aimport -tqdm
+# %matplotlib inline
+
+# +
+# Imports
+import torch
+from torch import nn, optim
+from torch.nn import functional as F
+from torch.autograd import Variable
+import torch
+import torch.utils.data
+
+
+from pathlib import Path
+from tqdm.auto import tqdm
+
+import pytorch_lightning as pl
+# -
+from seq2seq_time.data.dataset import Seq2SeqDataSet, Seq2SeqDataSets
+from seq2seq_time.predict import predict, predict_multi
+from seq2seq_time.util import dset_to_nc
+
+# +
+import logging
+import warnings
+import seq2seq_time.silence
+warnings.simplefilter('once')
+warnings.simplefilter(action='ignore', category=FutureWarning)
+warnings.simplefilter(action='ignore', category=DeprecationWarning)
+warnings.filterwarnings('ignore', 'Consider increasing the value of the `num_workers` argument', UserWarning)
+warnings.filterwarnings('ignore', 'Your val_dataloader has `shuffle=True`', UserWarning)
+
+from pytorch_lightning import _logger as log
+log.setLevel(logging.WARN)
+# -
+
+# ## Parameters
+
+# +
+device = "cuda" if torch.cuda.is_available() else "cpu"
+print(f'using {device}')
+
+timestamp = '20201108-300000'
+print(timestamp)
+window_past = 48*7
+window_future = 48
+batch_size = 64
+num_workers = 4
+datasets_root = Path('../data/processed/')
+window_past
+# -
+# ## Datasets
+#
+# From easy to hard, these datasets show different challenges, all of them with more than 20k datapoints and with a regression output. See the 00.01 notebook for more details, and the code for more information.
+#
+# Some such as MetroInterstateTraffic are easier, some are periodic such as BejingPM25, some are conditional on inputs such as GasSensor, and some are noisy and periodic like IMOSCurrentsVel
+
+from seq2seq_time.data.data import IMOSCurrentsVel, AppliancesEnergyPrediction, BejingPM25, GasSensor, MetroInterstateTraffic
+datasets = [MetroInterstateTraffic, IMOSCurrentsVel, GasSensor, AppliancesEnergyPrediction, BejingPM25]
+datasets
+# ## Lightning
+#
+# We will use pytorch lightning to handle all the training scaffolding. We have a common pytorch lightning class that takes in the model and defines training steps and logging.
+# +
+import pytorch_lightning as pl
+
+class PL_MODEL(pl.LightningModule):
+ def __init__(self, model, lr=3e-4, patience=None, weight_decay=0):
+ super().__init__()
+ self._model = model
+ self.lr = lr
+ self.patience = patience
+ self.weight_decay = weight_decay
+
+ def forward(self, x_past, y_past, x_future, y_future=None):
+ """Eval/Predict"""
+ y_dist, extra = self._model(x_past, y_past, x_future, y_future)
+ return y_dist, extra
+
+ def training_step(self, batch, batch_idx, phase='train'):
+ x_past, y_past, x_future, y_future = batch
+ y_dist, extra = self.forward(*batch)
+ loss = -y_dist.log_prob(y_future).mean()
+ self.log_dict({f'loss/{phase}':loss})
+ if ('loss' in extra) and (phase=='train'):
+ # some models have a special loss
+ loss = extra['loss']
+ self.log_dict({f'model_loss/{phase}':loss})
+ return loss
+
+ def validation_step(self, batch, batch_idx):
+ return self.training_step(batch, batch_idx, phase='val')
+
+ def test_step(self, batch, batch_idx):
+ return self.training_step(batch, batch_idx, phase='test')
+
+ def configure_optimizers(self):
+ optim = torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
+ scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+ optim,
+ patience=self.patience,
+ verbose=False,
+ min_lr=1e-7,
+ ) if self.patience else None
+ return {'optimizer': optim, 'lr_scheduler': scheduler, 'monitor': 'loss/val'}
+
+
+# -
+
+from torch.utils.data import DataLoader
+from pytorch_lightning.loggers import CSVLogger, WandbLogger, TensorBoardLogger, TestTubeLogger
+from pytorch_lightning.callbacks.early_stopping import EarlyStopping
+from pytorch_lightning.callbacks import LearningRateMonitor
+
+
+# ## Models
+
+from seq2seq_time.models.baseline import BaselineLast, BaselineMean
+from seq2seq_time.models.lstm_seq2seq import LSTMSeq2Seq
+from seq2seq_time.models.lstm import LSTM
+from seq2seq_time.models.transformer import Transformer
+from seq2seq_time.models.transformer_seq2seq import TransformerSeq2Seq
+from seq2seq_time.models.neural_process import RANP
+from seq2seq_time.models.transformer_process import TransformerProcess
+from seq2seq_time.models.tcn import TCNSeq
+from seq2seq_time.models.inceptiontime import InceptionTimeSeq
+from seq2seq_time.models.xattention import CrossAttention
+# +
+import gc
+
+def free_mem():
+ gc.collect()
+ torch.cuda.empty_cache()
+ gc.collect()
+# -
+
+
+
+# +
+# PARAMS: model
+dropout=0.0
+layers=6
+nhead=4
+
+models = [
+# lambda xs, ys: BaselineLast(),
+# lambda xs, ys, hidden_size: BaselineMean(),
+ lambda xs, ys, hidden_size, layers:TransformerProcess(xs,
+ ys, hidden_size=hidden_size, nhead=nhead,
+ latent_dim=hidden_size//2, dropout=dropout,
+ nlayers=layers),
+ lambda xs, ys, hidden_size, layers: RANP(xs,
+ ys, hidden_dim=hidden_size, dropout=dropout,
+ latent_dim=hidden_size//2, n_decoder_layers=layers, n_latent_encoder_layers=layers, n_det_encoder_layers=layers),
+ lambda xs, ys, hidden_size, layers:TCNSeq(xs, ys, hidden_size=hidden_size, nlayers=layers, dropout=dropout, kernel_size=2),
+ lambda xs, ys, hidden_size, layers: Transformer(xs,
+ ys,
+ attention_dropout=dropout,
+ nhead=nhead,
+ nlayers=layers,
+ hidden_size=hidden_size),
+ lambda xs, ys, hidden_size, layers: LSTM(xs,
+ ys,
+ hidden_size=hidden_size,
+ lstm_layers=layers//2,
+ lstm_dropout=dropout),
+
+ lambda xs, ys, hidden_size, layers: TransformerSeq2Seq(xs,
+ ys,
+ hidden_size=hidden_size,
+ nhead=nhead,
+ nlayers=layers,
+ attention_dropout=dropout
+ ),
+
+ lambda xs, ys, hidden_size, layers: LSTMSeq2Seq(xs,
+ ys,
+ hidden_size=hidden_size,
+ lstm_layers=layers//2,
+ lstm_dropout=dropout),
+ lambda xs, ys, hidden_size, layers: CrossAttention(xs,
+ ys,
+ nlayers=layers,
+ hidden_size=hidden_size,),
+ lambda xs, ys, hidden_size, layers: InceptionTimeSeq(xs,
+ ys,
+ kernel_size=96,
+ layers=layers//2,
+ hidden_size=hidden_size,
+ bottleneck=hidden_size//4)
+
+]
+# +
+# DEBUG: sanity check
+
+for Dataset in datasets:
+ dataset_name = Dataset.__name__
+ dataset = Dataset(datasets_root)
+ ds_train, ds_val, ds_test = dataset.to_datasets(window_past=window_past,
+ window_future=window_future)
+
+ # Init data
+ x_past, y_past, x_future, y_future = ds_train.get_rows(10)
+ xs = x_past.shape[-1]
+ ys = y_future.shape[-1]
+
+ # Loaders
+ dl_train = DataLoader(ds_train,
+ batch_size=batch_size,
+ shuffle=True,
+ pin_memory=num_workers == 0,
+ num_workers=num_workers)
+ dl_val = DataLoader(ds_val,
+ shuffle=True,
+ batch_size=batch_size,
+ num_workers=num_workers)
+
+ for m_fn in models:
+ free_mem()
+ pt_model = m_fn(xs, ys, 8, 4)
+ model_name = type(pt_model).__name__
+ print(timestamp, dataset_name, model_name)
+
+ # Wrap in lightning
+ model = PL_MODEL(pt_model,
+ lr=3e-4
+ ).to(device)
+ trainer = pl.Trainer(
+ fast_dev_run=True,
+ # GPU
+ gpus=1,
+ amp_level='O1',
+ precision=16,
+ )
+# -
+
+# ## Train
+
+max_iters=20000
+
+
+tensorboard_dir = Path(f"../outputs/{timestamp}").resolve()
+print(f'For tensorboard run:\ntensorboard --logdir="{tensorboard_dir}"')
+
+
+# +
+
+def objective(trial):
+    """
+    Optuna objective: train one sampled (hidden_size_exp, layers) config, return 'loss/test' on dl_val.
+    Reads m_fn, xs, ys, dl_train, dl_val, dataset_name, device and other settings from module globals.
+    See https://github.com/optuna/optuna/blob/master/examples/pytorch_lightning_simple.py
+    """
+    # sample: the hidden size is searched as a power of two, 2**1 .. 2**8
+    hidden_size_exp = trial.suggest_int("hidden_size_exp", 1, 8)
+    hidden_size = 2**hidden_size_exp
+
+    layers = trial.suggest_int("layers", 1, 12)
+
+    # Load model (m_fn is whichever factory the enclosing loop is currently on)
+    pt_model = m_fn(xs, ys, hidden_size, layers)
+    model_name = type(pt_model).__name__
+
+    # Wrap in lightning
+    patience = 2
+    model = PL_MODEL(pt_model,
+                     lr=3e-4, patience=patience,
+                    ).to(device)
+    # One output directory per trial: timestamp / dataset_model / trial number
+
+    save_dir = f"../outputs/{timestamp}/{dataset_name}_{model_name}/{trial.number}"
+    Path(save_dir).mkdir(exist_ok=True, parents=True)
+    trainer = pl.Trainer(
+        # Training length
+        min_epochs=2,
+        max_epochs=40,
+        limit_train_batches=max_iters//batch_size,
+        limit_val_batches=max_iters//batch_size//5,
+        # Misc
+        gradient_clip_val=20,
+        terminate_on_nan=True,
+        # GPU
+        gpus=1,
+        amp_level='O1',
+        precision=16,
+        # Callbacks
+        default_root_dir=save_dir,
+        logger=False,
+        callbacks=[
+            EarlyStopping(monitor='loss/val', patience=patience * 2),  # twice the LR-scheduler patience
+            PyTorchLightningPruningCallback(trial, monitor="loss/val")],
+    )
+    trainer.fit(model, dl_train, dl_val)
+
+    # Run on all val data, using test mode (so the metric is logged under 'loss/test')
+    r = trainer.test(model, test_dataloaders=dl_val, verbose=False)
+    return r[0]['loss/test']
+# -
+
+
+
+import optuna
+from optuna.integration import PyTorchLightningPruningCallback
+
+import subprocess
+def get_git_commit():  # output of `git rev-parse HEAD` run in '..', decoded and stripped
+    try:
+        return subprocess.check_output(["git", "rev-parse", "HEAD"], cwd='..').decode().strip()
+    except Exception:
+        logging.exception("failed to get git hash")  # falls through, so the caller receives None
+
+
+Path(f"../outputs/{timestamp}").mkdir(exist_ok=True)
+storage = f"sqlite:///../outputs/{timestamp}/optuna.db"
+
+for Dataset in tqdm(datasets, desc='datasets'):
+ dataset_name = Dataset.__name__
+ dataset = Dataset(datasets_root)
+ ds_train, ds_val, ds_test = dataset.to_datasets(window_past=window_past,
+ window_future=window_future)
+
+ # Init data
+ x_past, y_past, x_future, y_future = ds_train.get_rows(10)
+ xs = x_past.shape[-1]
+ ys = y_future.shape[-1]
+
+ # Loaders
+ dl_train = DataLoader(ds_train,
+ batch_size=batch_size,
+ shuffle=True,
+ drop_last=True,
+ pin_memory=num_workers == 0,
+ num_workers=num_workers)
+ dl_val = DataLoader(ds_val,
+ shuffle=False,
+ batch_size=batch_size,
+ drop_last=True,
+ num_workers=num_workers)
+
+ for i, m_fn in enumerate(tqdm(models, desc=f'models ({dataset_name})')):
+ try:
+ model_name = type(m_fn(8, 8, 8, 2)).__name__
+ free_mem()
+ study_name = f'{timestamp}_{dataset_name}-{model_name}'
+
+ # Create study
+ pruner = optuna.pruners.MedianPruner()
+ study = optuna.create_study(storage=storage,
+ study_name=study_name,
+ pruner=pruner,
+ load_if_exists=True)
+ study.set_user_attr('dataset', dataset_name)
+ study.set_user_attr('model', model_name)
+ study.set_user_attr('commit', get_git_commit())
+
+ df_trials = study.trials_dataframe()
+ if len(df_trials):
+ df_trials = df_trials[df_trials.state=='COMPLETE']
+ nb_trials = len(df_trials)
+ if nb_trials==0:
+ # Priors
+ study.enqueue_trial({"layers": 6, "params_hidden_size_exp": 2})
+ study.enqueue_trial({"layers": 1, "params_hidden_size_exp": 3})
+ study.enqueue_trial({"layers": 3, "params_hidden_size_exp": 5})
+ if nb_trials<20:
+ # Opt
+ study.optimize(objective, n_trials=20-nb_trials,
+ timeout=60*60*2 # Max seconds for all optimizes
+ )
+
+
+ print("Number of finished trials: {}".format(len(study.trials)))
+
+ print("Best trial:")
+ trial = study.best_trial
+
+ print(" Value: {}".format(trial.value))
+
+ print(" Params: ")
+ for key, value in trial.params.items():
+ print(" {}: {}".format(key, value))
+
+ except Exception as e:
+ logging.exception('failed to run model')
+
+
+# +
+# Baseline
+
+models2 = [
+ lambda xs, ys, _, l: BaselineLast(),
+ lambda xs, ys, _, l: BaselineMean(),
+]
+
+for Dataset in tqdm(datasets, desc='datasets'):
+ dataset_name = Dataset.__name__
+ dataset = Dataset(datasets_root)
+ ds_train, ds_val, ds_test = dataset.to_datasets(window_past=window_past,
+ window_future=window_future)
+
+ # Init data
+ x_past, y_past, x_future, y_future = ds_train.get_rows(10)
+ xs = x_past.shape[-1]
+ ys = y_future.shape[-1]
+
+ # Loaders
+ dl_train = DataLoader(ds_train,
+ batch_size=batch_size,
+ shuffle=True,
+ drop_last=True,
+ pin_memory=num_workers == 0,
+ num_workers=num_workers)
+ dl_val = DataLoader(ds_val,
+ shuffle=False,
+ batch_size=batch_size,
+ drop_last=True,
+ num_workers=num_workers)
+
+ for i, m_fn in enumerate(tqdm(models2, desc=f'models ({dataset_name})')):
+ try:
+ model_name = type(m_fn(8, 8, 8, 2)).__name__
+ free_mem()
+ study_name = f'{timestamp}_{dataset_name}-{model_name}'
+
+ # Create study
+ pruner = optuna.pruners.MedianPruner()
+ study = optuna.create_study(storage=storage,
+ study_name=study_name,
+ pruner=pruner,
+ load_if_exists=True)
+ study.set_user_attr('dataset', dataset_name)
+ study.set_user_attr('model', model_name)
+ study.set_user_attr('commit', get_git_commit())
+
+ df_trials = study.trials_dataframe()
+ if len(df_trials):
+ df_trials = df_trials[df_trials.state=='COMPLETE']
+ nb_trials = len(df_trials)
+ if nb_trials<1:
+ # Opt
+ study.optimize(objective, n_trials=1,
+ timeout=60*30 # Max seconds for all optimizes
+ )
+
+ except Exception as e:
+ logging.exception('failed to run model')
+
+# +
+# TODO baseline, run as sep cell, opt once
+# TODO summarize time and model params at best params
+
+# +
+import pandas as pd
+# Summarize studies
+rs = []
+study_summaries = optuna.study.get_all_study_summaries(storage=storage)
+for s in study_summaries:
+ row = {}
+ if (s.best_trial is not None) and (s.best_trial.state==optuna.trial.TrialState.COMPLETE):
+ params = {k:v for k,v in s.best_trial.__dict__.items() if not k.startswith('_')}
+ row.update(s.user_attrs)
+ row['n_trials'] = s.n_trials
+ row.update({'param_'+k:v for k,v in s.best_trial._params.items()})
+ row.update(params)
+ rs.append(row)
+df_studies = pd.DataFrame(rs)
+
+df_studies = (df_studies.drop(columns=['state', 'intermediate_values', 'commit', 'datetime_complete'])
+ .sort_values(['dataset', 'value'])
+ .set_index(['dataset', 'model'])
+ )
+df_studies
+
+# +
+# study_names = [s.study_name for s in optuna.study.get_all_study_summaries(storage=storage)]
+
+# for study_name in study_names:
+# loaded_study = optuna.load_study(study_name=study_name, storage=storage)
+
+# # Make DF over trials
+# print(study_name)
+
+# df_trials = loaded_study.trials_dataframe()
+# # df_trials.index = df_trials.apply(lambda r:f'l={r.params_layers}_hs={r.params_hidden_size_exp}', 1)
+# display(df_trials)
+
+# # Plot test curves, to see how much overfitting
+# df_values = pd.DataFrame([s.intermediate_values for s in loaded_study.get_trials()]).T
+# df_values.columns = [f"l={s.params['layers']}_hs={s.params['hidden_size_exp']}" for s in loaded_study.get_trials()]
+# df_values.plot(ylabel='nll', xlabel='epochs', title=f'val loss "{study_name}"')
+# +
+# study_names = [s.study_name for s in optuna.study.get_all_study_summaries(storage=storage)]
+
+# for study_name in study_names:
+# loaded_study = optuna.load_study(study_name=study_name, storage=storage)
+# fig=optuna.visualization.plot_contour(loaded_study, params=['hidden_size_exp', 'layers'])
+# fig = fig.update_layout(dict(title=f'{study_name} nll'))
+# display(fig)
+# -
+
+
+
+
+
+[f"{s.params['layers']}_{s.params['hidden_size_exp']}" for s in loaded_study.get_trials()]
+
+df_values
+
+
+
diff --git a/seq2seq_time/models/neural_process.py b/seq2seq_time/models/neural_process.py
index 0b4472c..67b3609 100644
--- a/seq2seq_time/models/neural_process.py
+++ b/seq2seq_time/models/neural_process.py
@@ -407,9 +407,11 @@ class RANP(nn.Module):
y = torch.cat([past_y, future_y], 1)
dist_post, log_var_post = self._latent_encoder(x, y)
- if self.training:
- z = dist_prior.rsample()
+ if self.training and (future_y is not None):
+            # Use the posterior during training, if possible
+ z = dist_post.rsample()
else:
+ # During eval use the prior, also take the most probable
z = dist_prior.loc
num_targets = future_x.size(1)
z = z.unsqueeze(1).repeat(1, num_targets, 1) # [B, T_target, H]
diff --git a/seq2seq_time/models/transformer_process.py b/seq2seq_time/models/transformer_process.py
index 0614016..bc836e4 100644
--- a/seq2seq_time/models/transformer_process.py
+++ b/seq2seq_time/models/transformer_process.py
@@ -162,10 +162,11 @@ class TransformerProcess(nn.Module):
y = torch.cat([past_y, future_y], 1)
dist_post = self._latent_encoder(x, y)
- if self.training:
- # Sample from latent space during training
- z = dist_prior.rsample()
+ if self.training and (future_y is not None):
+            # Use the posterior during training, if possible
+ z = dist_post.rsample()
else:
+ # During eval use the prior, also take the most probable
z = dist_prior.loc
num_targets = future_x.size(1)
z = z.unsqueeze(1).repeat(1, num_targets, 1) # [B, S_target, H]