mirror of
https://github.com/wassname/seq2seq-time.git
synced 2026-06-27 18:06:49 +08:00
1629 lines
58 KiB
Plaintext
1629 lines
58 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2020-10-10T01:25:12.788851Z",
|
|
"start_time": "2020-10-10T01:25:12.783398Z"
|
|
}
|
|
},
|
|
"source": [
|
|
"# Sequence to Sequence Models for Timeseries Regression\n",
|
|
"\n",
|
|
"\n",
|
|
"In this notebook we are going to find the optimal hidden_size and number of layers for each model on each dataset. We will use pytorch lightning and optuna."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2020-11-13T07:55:31.364813Z",
|
|
"start_time": "2020-11-13T07:55:31.025418Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# OPTIONAL: Load the \"autoreload\" extension so that code can change. But blacklist large modules\n",
|
|
"%load_ext autoreload\n",
|
|
"%autoreload 2\n",
|
|
"%aimport -pandas\n",
|
|
"%aimport -torch\n",
|
|
"%aimport -numpy\n",
|
|
"%aimport -matplotlib\n",
|
|
"%aimport -dask\n",
|
|
"%aimport -tqdm\n",
|
|
"%matplotlib inline"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2020-11-13T07:55:32.324191Z",
|
|
"start_time": "2020-11-13T07:55:31.367170Z"
|
|
},
|
|
"lines_to_next_cell": 0
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Imports\n",
|
|
"import torch\n",
|
|
"from torch import nn, optim\n",
|
|
"from torch.nn import functional as F\n",
|
|
"from torch.autograd import Variable\n",
|
|
"import torch\n",
|
|
"import torch.utils.data\n",
|
|
"\n",
|
|
"\n",
|
|
"from pathlib import Path\n",
|
|
"from tqdm.auto import tqdm\n",
|
|
"\n",
|
|
"import pytorch_lightning as pl"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2020-11-13T07:55:32.863984Z",
|
|
"start_time": "2020-11-13T07:55:32.326782Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"from seq2seq_time.data.dataset import Seq2SeqDataSet, Seq2SeqDataSets\n",
|
|
"from seq2seq_time.predict import predict, predict_multi\n",
|
|
"from seq2seq_time.util import dset_to_nc"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2020-11-13T07:55:32.900137Z",
|
|
"start_time": "2020-11-13T07:55:32.866551Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"import logging\n",
|
|
"import warnings\n",
|
|
"import seq2seq_time.silence \n",
|
|
"warnings.simplefilter('once')\n",
|
|
"warnings.simplefilter(action='ignore', category=FutureWarning)\n",
|
|
"warnings.simplefilter(action='ignore', category=DeprecationWarning)\n",
|
|
"warnings.filterwarnings('ignore', 'Consider increasing the value of the `num_workers` argument', UserWarning)\n",
|
|
"warnings.filterwarnings('ignore', 'Your val_dataloader has `shuffle=True`', UserWarning)\n",
|
|
"\n",
|
|
"from pytorch_lightning import _logger as log\n",
|
|
"log.setLevel(logging.WARN)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2020-10-10T01:28:32.492160Z",
|
|
"start_time": "2020-10-10T01:28:32.488140Z"
|
|
}
|
|
},
|
|
"source": [
|
|
"## Parameters"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2020-11-13T07:55:32.992818Z",
|
|
"start_time": "2020-11-13T07:55:32.902453Z"
|
|
},
|
|
"lines_to_next_cell": 0
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"using cuda\n",
|
|
"20201108-300000\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"336"
|
|
]
|
|
},
|
|
"execution_count": 6,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
|
|
"print(f'using {device}')\n",
|
|
"\n",
|
|
"timestamp = '20201108-300000'\n",
|
|
"print(timestamp)\n",
|
|
"window_past = 48*7\n",
|
|
"window_future = 48\n",
|
|
"batch_size = 64\n",
|
|
"num_workers = 4\n",
|
|
"datasets_root = Path('../data/processed/')\n",
|
|
"window_past"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Datasets\n",
|
|
"\n",
|
|
"From easy to hard, these datasets pose different challenges; all of them have more than 20k datapoints and a regression output. See the 00.01 notebook for more details, and the code for more information.\n",
|
|
"\n",
|
|
"Some, such as MetroInterstateTraffic, are easier; some are periodic, such as BejingPM25; some are conditional on inputs, such as GasSensor; and some are noisy and periodic, like IMOSCurrentsVel."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2020-11-13T07:55:33.715323Z",
|
|
"start_time": "2020-11-13T07:55:33.238397Z"
|
|
},
|
|
"lines_to_next_cell": 0
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"[seq2seq_time.data.data.MetroInterstateTraffic,\n",
|
|
" seq2seq_time.data.data.IMOSCurrentsVel,\n",
|
|
" seq2seq_time.data.data.GasSensor,\n",
|
|
" seq2seq_time.data.data.AppliancesEnergyPrediction,\n",
|
|
" seq2seq_time.data.data.BejingPM25]"
|
|
]
|
|
},
|
|
"execution_count": 7,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"from seq2seq_time.data.data import IMOSCurrentsVel, AppliancesEnergyPrediction, BejingPM25, GasSensor, MetroInterstateTraffic\n",
|
|
"datasets = [MetroInterstateTraffic, IMOSCurrentsVel, GasSensor, AppliancesEnergyPrediction, BejingPM25]\n",
|
|
"datasets\n",
|
|
"# ## Lightning\n",
|
|
"#\n",
|
|
"# We will use pytorch lightning to handle all the training scaffolding. We have a common pytorch lightning class that takes in the model and defines training steps and logging."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2020-11-13T07:55:34.390203Z",
|
|
"start_time": "2020-11-13T07:55:34.338792Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"import pytorch_lightning as pl\n",
|
|
"\n",
|
|
"class PL_MODEL(pl.LightningModule):\n",
"    \"\"\"Lightning wrapper that trains a wrapped model by negative log-likelihood.\n",
"\n",
"    The wrapped model is called as ``model(x_past, y_past, x_future, y_future)`` and is\n",
"    expected to return ``(y_dist, extra)``: ``y_dist`` must expose ``log_prob`` and\n",
"    ``extra`` is a mapping that may contain a model-specific ``'loss'``.\n",
"\n",
"    Args:\n",
"        model: the pytorch model to wrap.\n",
"        lr: learning rate passed to AdamW.\n",
"        patience: patience of the ReduceLROnPlateau scheduler; if falsy, no scheduler is used.\n",
"        weight_decay: weight decay passed to AdamW.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, model, lr=3e-4, patience=None, weight_decay=0):\n",
"        super().__init__()\n",
"        self._model = model\n",
"        self.lr = lr\n",
"        self.patience = patience\n",
"        self.weight_decay = weight_decay\n",
"\n",
"    def forward(self, x_past, y_past, x_future, y_future=None):\n",
"        \"\"\"Eval/Predict: delegate to the wrapped model and return ``(y_dist, extra)``.\"\"\"\n",
"        y_dist, extra = self._model(x_past, y_past, x_future, y_future)\n",
"        return y_dist, extra\n",
"\n",
"    def training_step(self, batch, batch_idx, phase='train'):\n",
"        \"\"\"Compute and log the loss for one batch.\n",
"\n",
"        The negative log-likelihood is always logged as ``loss/{phase}``. During training\n",
"        only, a model-supplied ``extra['loss']`` replaces it as the returned loss and is\n",
"        logged as ``model_loss/{phase}``.\n",
"        \"\"\"\n",
"        x_past, y_past, x_future, y_future = batch\n",
"        y_dist, extra = self.forward(*batch)\n",
"        loss = -y_dist.log_prob(y_future).mean()\n",
"        self.log_dict({f'loss/{phase}':loss})\n",
"        if ('loss' in extra) and (phase=='train'):\n",
"            # some models have a special loss\n",
"            loss = extra['loss']\n",
"            self.log_dict({f'model_loss/{phase}':loss})\n",
"        return loss\n",
"\n",
"    def validation_step(self, batch, batch_idx):\n",
"        \"\"\"Same as ``training_step`` but logged under the ``val`` phase.\"\"\"\n",
"        return self.training_step(batch, batch_idx, phase='val')\n",
"\n",
"    def test_step(self, batch, batch_idx):\n",
"        \"\"\"Same as ``training_step`` but logged under the ``test`` phase.\"\"\"\n",
"        return self.training_step(batch, batch_idx, phase='test')\n",
"\n",
"    def configure_optimizers(self):\n",
"        \"\"\"Return AdamW, plus a ReduceLROnPlateau scheduler monitoring ``loss/val`` when ``patience`` is set.\"\"\"\n",
"        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)\n",
"        if not self.patience:\n",
"            # No scheduler requested: return the bare optimizer rather than a dict\n",
"            # with 'lr_scheduler': None, which is not a valid scheduler entry.\n",
"            return optimizer\n",
"        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(\n",
"            optimizer,\n",
"            patience=self.patience,\n",
"            verbose=False,\n",
"            min_lr=1e-7,\n",
"        )\n",
"        return {'optimizer': optimizer, 'lr_scheduler': scheduler, 'monitor': 'loss/val'}"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2020-11-13T07:55:34.653519Z",
|
|
"start_time": "2020-11-13T07:55:34.605285Z"
|
|
},
|
|
"lines_to_next_cell": 2
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"from torch.utils.data import DataLoader\n",
|
|
"from pytorch_lightning.loggers import CSVLogger, WandbLogger, TensorBoardLogger, TestTubeLogger\n",
|
|
"from pytorch_lightning.callbacks.early_stopping import EarlyStopping\n",
|
|
"from pytorch_lightning.callbacks import LearningRateMonitor"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Models"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 10,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2020-11-13T07:55:35.693992Z",
|
|
"start_time": "2020-11-13T07:55:35.629427Z"
|
|
},
|
|
"lines_to_end_of_cell_marker": 2,
|
|
"lines_to_next_cell": 0
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"from seq2seq_time.models.baseline import BaselineLast, BaselineMean\n",
|
|
"from seq2seq_time.models.lstm_seq2seq import LSTMSeq2Seq\n",
|
|
"from seq2seq_time.models.lstm import LSTM\n",
|
|
"from seq2seq_time.models.transformer import Transformer\n",
|
|
"from seq2seq_time.models.transformer_seq2seq import TransformerSeq2Seq\n",
|
|
"from seq2seq_time.models.neural_process import RANP\n",
|
|
"from seq2seq_time.models.transformer_process import TransformerProcess\n",
|
|
"from seq2seq_time.models.tcn import TCNSeq\n",
|
|
"from seq2seq_time.models.inceptiontime import InceptionTimeSeq\n",
|
|
"from seq2seq_time.models.xattention import CrossAttention"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2020-11-13T07:55:36.024789Z",
|
|
"start_time": "2020-11-13T07:55:35.985433Z"
|
|
},
|
|
"lines_to_next_cell": 0
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"import gc\n",
|
|
"\n",
|
|
"def free_mem():\n",
"    \"\"\"Run the garbage collector, empty the CUDA cache, then collect garbage again.\"\"\"\n",
"    for release in (gc.collect, torch.cuda.empty_cache, gc.collect):\n",
"        release()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2020-11-02T06:10:41.904480Z",
|
|
"start_time": "2020-11-02T06:10:41.848613Z"
|
|
},
|
|
"lines_to_next_cell": 2
|
|
},
|
|
"outputs": [],
|
|
"source": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 12,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2020-11-13T07:55:37.813853Z",
|
|
"start_time": "2020-11-13T07:55:37.768348Z"
|
|
},
|
|
"lines_to_next_cell": 0
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# PARAMS: model\n",
|
|
"dropout=0.0\n",
|
|
"layers=6\n",
|
|
"nhead=4\n",
|
|
"\n",
|
|
"models = [\n",
|
|
"# lambda xs, ys: BaselineLast(),\n",
|
|
"# lambda xs, ys, hidden_size: BaselineMean(),\n",
|
|
" lambda xs, ys, hidden_size, layers:TransformerProcess(xs,\n",
|
|
" ys, hidden_size=hidden_size, nhead=nhead,\n",
|
|
" latent_dim=hidden_size//2, dropout=dropout,\n",
|
|
" nlayers=layers),\n",
|
|
" lambda xs, ys, hidden_size, layers: RANP(xs,\n",
|
|
" ys, hidden_dim=hidden_size, dropout=dropout, \n",
|
|
" latent_dim=hidden_size//2, n_decoder_layers=layers, n_latent_encoder_layers=layers, n_det_encoder_layers=layers),\n",
|
|
" lambda xs, ys, hidden_size, layers:TCNSeq(xs, ys, hidden_size=hidden_size, nlayers=layers, dropout=dropout, kernel_size=2),\n",
|
|
" lambda xs, ys, hidden_size, layers: Transformer(xs,\n",
|
|
" ys,\n",
|
|
" attention_dropout=dropout,\n",
|
|
" nhead=nhead,\n",
|
|
" nlayers=layers,\n",
|
|
" hidden_size=hidden_size),\n",
|
|
" lambda xs, ys, hidden_size, layers: LSTM(xs,\n",
|
|
" ys,\n",
|
|
" hidden_size=hidden_size,\n",
|
|
" lstm_layers=layers//2,\n",
|
|
" lstm_dropout=dropout),\n",
|
|
"\n",
|
|
" lambda xs, ys, hidden_size, layers: TransformerSeq2Seq(xs,\n",
|
|
" ys,\n",
|
|
" hidden_size=hidden_size,\n",
|
|
" nhead=nhead,\n",
|
|
" nlayers=layers,\n",
|
|
" attention_dropout=dropout\n",
|
|
" ),\n",
|
|
"\n",
|
|
" lambda xs, ys, hidden_size, layers: LSTMSeq2Seq(xs,\n",
|
|
" ys,\n",
|
|
" hidden_size=hidden_size,\n",
|
|
" lstm_layers=layers//2,\n",
|
|
" lstm_dropout=dropout),\n",
|
|
" lambda xs, ys, hidden_size, layers: CrossAttention(xs,\n",
|
|
" ys,\n",
|
|
" nlayers=layers,\n",
|
|
" hidden_size=hidden_size,),\n",
|
|
" lambda xs, ys, hidden_size, layers: InceptionTimeSeq(xs,\n",
|
|
" ys,\n",
|
|
" kernel_size=96,\n",
|
|
" layers=layers//2,\n",
|
|
" hidden_size=hidden_size,\n",
|
|
" bottleneck=hidden_size//4)\n",
|
|
"\n",
|
|
"]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2020-11-12T23:16:29.248291Z",
|
|
"start_time": "2020-11-12T23:16:16.139941Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# DEBUG: sanity check\n",
|
|
"\n",
|
|
"for Dataset in datasets:\n",
|
|
" dataset_name = Dataset.__name__\n",
|
|
" dataset = Dataset(datasets_root)\n",
|
|
" ds_train, ds_val, ds_test = dataset.to_datasets(window_past=window_past,\n",
|
|
" window_future=window_future)\n",
|
|
"\n",
|
|
" # Init data\n",
|
|
" x_past, y_past, x_future, y_future = ds_train.get_rows(10)\n",
|
|
" xs = x_past.shape[-1]\n",
|
|
" ys = y_future.shape[-1]\n",
|
|
"\n",
|
|
" # Loaders\n",
|
|
" dl_train = DataLoader(ds_train,\n",
|
|
" batch_size=batch_size,\n",
|
|
" shuffle=True,\n",
|
|
" pin_memory=num_workers == 0,\n",
|
|
" num_workers=num_workers)\n",
|
|
" dl_val = DataLoader(ds_val,\n",
|
|
" shuffle=True,\n",
|
|
" batch_size=batch_size,\n",
|
|
" num_workers=num_workers)\n",
|
|
"\n",
|
|
" for m_fn in models:\n",
|
|
" free_mem()\n",
|
|
" pt_model = m_fn(xs, ys, 8, 4)\n",
|
|
" model_name = type(pt_model).__name__\n",
|
|
" print(timestamp, dataset_name, model_name)\n",
|
|
"\n",
|
|
" # Wrap in lightning\n",
|
|
" model = PL_MODEL(pt_model,\n",
|
|
" lr=3e-4\n",
|
|
" ).to(device)\n",
|
|
" trainer = pl.Trainer(\n",
|
|
" fast_dev_run=True,\n",
|
|
" # GPU\n",
|
|
" gpus=1,\n",
|
|
" amp_level='O1',\n",
|
|
" precision=16,\n",
|
|
" )"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2020-10-23T23:36:11.052891Z",
|
|
"start_time": "2020-10-23T23:36:11.048874Z"
|
|
}
|
|
},
|
|
"source": [
|
|
"## Train"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2020-11-12T23:16:29.286762Z",
|
|
"start_time": "2020-11-12T23:16:29.250367Z"
|
|
},
|
|
"lines_to_next_cell": 2
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"max_iters=20000"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2020-11-12T23:16:29.326846Z",
|
|
"start_time": "2020-11-12T23:16:29.289004Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"tensorboard_dir = Path(f\"../outputs/{timestamp}\").resolve()\n",
|
|
"print(f'For tensorboard run:\\ntensorboard --logdir=\"{tensorboard_dir}\"')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 14,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2020-11-13T07:55:48.561522Z",
|
|
"start_time": "2020-11-13T07:55:48.516182Z"
|
|
},
|
|
"lines_to_next_cell": 0
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"\n",
|
|
"def objective(trial):\n",
"    \"\"\"\n",
"    Optuna objective: train one sampled configuration and return its loss on the val loader.\n",
"\n",
"    Samples `hidden_size_exp` (hidden size is 2**hidden_size_exp) and `layers`, builds the\n",
"    model from the notebook-level `m_fn`, `xs` and `ys`, and fits it on `dl_train`/`dl_val`.\n",
"\n",
"    See https://github.com/optuna/optuna/blob/master/examples/pytorch_lightning_simple.py\n",
"    \"\"\"\n",
"    # Sample hyper-parameters\n",
"    hidden_size = 2 ** trial.suggest_int(\"hidden_size_exp\", 1, 8)\n",
"    layers = trial.suggest_int(\"layers\", 1, 12)\n",
"\n",
"    # Build the model and wrap it in lightning\n",
"    pt_model = m_fn(xs, ys, hidden_size, layers)\n",
"    model_name = type(pt_model).__name__\n",
"    patience = 2\n",
"    model = PL_MODEL(pt_model, lr=3e-4, patience=patience).to(device)\n",
"\n",
"    # One output directory per trial\n",
"    save_dir = f\"../outputs/{timestamp}/{dataset_name}_{model_name}/{trial.number}\"\n",
"    Path(save_dir).mkdir(exist_ok=True, parents=True)\n",
"\n",
"    # Early stopping waits twice as long as the LR scheduler; optuna may prune on loss/val\n",
"    callbacks = [\n",
"        EarlyStopping(monitor='loss/val', patience=patience * 2),\n",
"        PyTorchLightningPruningCallback(trial, monitor=\"loss/val\"),\n",
"    ]\n",
"    train_batches = max_iters//batch_size\n",
"    trainer = pl.Trainer(\n",
"        # Training length\n",
"        min_epochs=2,\n",
"        max_epochs=40,\n",
"        limit_train_batches=train_batches,\n",
"        limit_val_batches=train_batches//5,\n",
"        # Misc\n",
"        gradient_clip_val=20,\n",
"        terminate_on_nan=True,\n",
"        # GPU\n",
"        gpus=1,\n",
"        amp_level='O1',\n",
"        precision=16,\n",
"        # Callbacks\n",
"        default_root_dir=save_dir,\n",
"        logger=False,\n",
"        callbacks=callbacks,\n",
"    )\n",
"    trainer.fit(model, dl_train, dl_val)\n",
"\n",
"    # Run on all val data, using test mode\n",
"    results = trainer.test(model, test_dataloaders=dl_val, verbose=False)\n",
"    return results[0]['loss/test']"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2020-11-08T02:45:44.106583Z",
|
|
"start_time": "2020-11-08T02:45:44.050637Z"
|
|
},
|
|
"lines_to_next_cell": 2
|
|
},
|
|
"outputs": [],
|
|
"source": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 15,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2020-11-13T07:55:49.497334Z",
|
|
"start_time": "2020-11-13T07:55:48.981040Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"import optuna\n",
|
|
"from optuna.integration import PyTorchLightningPruningCallback"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 16,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2020-11-13T07:55:49.543035Z",
|
|
"start_time": "2020-11-13T07:55:49.499555Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"import subprocess\n",
|
|
"def get_git_commit():\n",
"    \"\"\"Return the output of `git rev-parse HEAD` run in '..', or None (after logging) on failure.\"\"\"\n",
"    try:\n",
"        raw = subprocess.check_output([\"git\", \"rev-parse\", \"HEAD\"], cwd='..')\n",
"        commit = raw.decode().strip()\n",
"    except Exception:\n",
"        logging.exception(\"failed to get git hash\")\n",
"        commit = None\n",
"    return commit"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 17,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2020-11-13T07:55:50.265327Z",
|
|
"start_time": "2020-11-13T07:55:50.204384Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# parents=True: ../outputs itself may not exist yet (e.g. on a fresh checkout)\n",
"Path(f\"../outputs/{timestamp}\").mkdir(exist_ok=True, parents=True)\n",
|
|
"storage = f\"sqlite:///../outputs/{timestamp}/optuna.db\""
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2020-11-13T01:44:13.554355Z",
|
|
"start_time": "2020-11-12T23:16:30.128631Z"
|
|
},
|
|
"lines_to_next_cell": 0,
|
|
"scrolled": true
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Run (or resume) one optuna study per (dataset, model) pair, up to 20 completed trials each.\n",
"for Dataset in tqdm(datasets, desc='datasets'):\n",
"    dataset_name = Dataset.__name__\n",
"    dataset = Dataset(datasets_root)\n",
"    ds_train, ds_val, ds_test = dataset.to_datasets(window_past=window_past,\n",
"                                                    window_future=window_future)\n",
"\n",
"    # Init data\n",
"    x_past, y_past, x_future, y_future = ds_train.get_rows(10)\n",
"    xs = x_past.shape[-1]\n",
"    ys = y_future.shape[-1]\n",
"\n",
"    # Loaders\n",
"    dl_train = DataLoader(ds_train,\n",
"                          batch_size=batch_size,\n",
"                          shuffle=True,\n",
"                          drop_last=True,\n",
"                          pin_memory=num_workers == 0,\n",
"                          num_workers=num_workers)\n",
"    dl_val = DataLoader(ds_val,\n",
"                        shuffle=False,\n",
"                        batch_size=batch_size,\n",
"                        drop_last=True,\n",
"                        num_workers=num_workers)\n",
"\n",
"    for i, m_fn in enumerate(tqdm(models, desc=f'models ({dataset_name})')):\n",
"        try:\n",
"            # Instantiate a tiny throwaway model just to read its class name\n",
"            model_name = type(m_fn(8, 8, 8, 2)).__name__\n",
"            free_mem()\n",
"            study_name = f'{timestamp}_{dataset_name}-{model_name}'\n",
"\n",
"            # Create study\n",
"            pruner = optuna.pruners.MedianPruner()\n",
"            study = optuna.create_study(storage=storage,\n",
"                                        study_name=study_name,\n",
"                                        pruner=pruner,\n",
"                                        load_if_exists=True)\n",
"            study.set_user_attr('dataset', dataset_name)\n",
"            study.set_user_attr('model', model_name)\n",
"            study.set_user_attr('commit', get_git_commit())\n",
"\n",
"            # Only completed trials count towards the budget\n",
"            df_trials = study.trials_dataframe()\n",
"            if len(df_trials):\n",
"                df_trials = df_trials[df_trials.state=='COMPLETE']\n",
"            nb_trials = len(df_trials)\n",
"            if nb_trials==0:\n",
"                # Priors. Keys must match the names passed to trial.suggest_int in objective\n",
"                # (\"hidden_size_exp\", \"layers\"); other keys never reach the sampler.\n",
"                study.enqueue_trial({\"layers\": 6, \"hidden_size_exp\": 2})\n",
"                study.enqueue_trial({\"layers\": 1, \"hidden_size_exp\": 3})\n",
"                study.enqueue_trial({\"layers\": 3, \"hidden_size_exp\": 5})\n",
"            if nb_trials<20:\n",
"                # Opt\n",
"                study.optimize(objective, n_trials=20-nb_trials,\n",
"                               timeout=60*60*2 # Max seconds for all optimizes\n",
"                              )\n",
"\n",
"            print(\"Number of finished trials: {}\".format(len(study.trials)))\n",
"\n",
"            print(\"Best trial:\")\n",
"            trial = study.best_trial\n",
"\n",
"            print(\" Value: {}\".format(trial.value))\n",
"\n",
"            print(\" Params: \")\n",
"            for key, value in trial.params.items():\n",
"                print(\" {}: {}\".format(key, value))\n",
"\n",
"        except Exception:\n",
"            # Best effort: log the failure and move on to the next model\n",
"            logging.exception('failed to run model')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2020-11-12T23:08:15.792627Z",
|
|
"start_time": "2020-11-12T23:08:15.700279Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2020-11-13T01:44:20.959637Z",
|
|
"start_time": "2020-11-13T01:44:13.560463Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Baseline\n",
|
|
"\n",
|
|
"models2 = [\n",
|
|
" lambda xs, ys, _, l: BaselineLast(),\n",
|
|
" lambda xs, ys, _, l: BaselineMean(),\n",
|
|
"]\n",
|
|
"\n",
|
|
"for Dataset in tqdm(datasets, desc='datasets'):\n",
|
|
" dataset_name = Dataset.__name__\n",
|
|
" dataset = Dataset(datasets_root)\n",
|
|
" ds_train, ds_val, ds_test = dataset.to_datasets(window_past=window_past,\n",
|
|
" window_future=window_future)\n",
|
|
"\n",
|
|
" # Init data\n",
|
|
" x_past, y_past, x_future, y_future = ds_train.get_rows(10)\n",
|
|
" xs = x_past.shape[-1]\n",
|
|
" ys = y_future.shape[-1]\n",
|
|
"\n",
|
|
" # Loaders\n",
|
|
" dl_train = DataLoader(ds_train,\n",
|
|
" batch_size=batch_size,\n",
|
|
" shuffle=True,\n",
|
|
" drop_last=True,\n",
|
|
" pin_memory=num_workers == 0,\n",
|
|
" num_workers=num_workers)\n",
|
|
" dl_val = DataLoader(ds_val,\n",
|
|
" shuffle=False,\n",
|
|
" batch_size=batch_size,\n",
|
|
" drop_last=True,\n",
|
|
" num_workers=num_workers)\n",
|
|
"\n",
|
|
" for i, m_fn in enumerate(tqdm(models2, desc=f'models ({dataset_name})')):\n",
|
|
" try:\n",
|
|
" model_name = type(m_fn(8, 8, 8, 2)).__name__\n",
|
|
" free_mem()\n",
|
|
" study_name = f'{timestamp}_{dataset_name}-{model_name}'\n",
|
|
" \n",
|
|
" # Create study \n",
|
|
" pruner = optuna.pruners.MedianPruner()\n",
|
|
" study = optuna.create_study(storage=storage, \n",
|
|
" study_name=study_name, \n",
|
|
" pruner=pruner,\n",
|
|
" load_if_exists=True)\n",
|
|
" study.set_user_attr('dataset', dataset_name)\n",
|
|
" study.set_user_attr('model', model_name)\n",
|
|
" study.set_user_attr('commit', get_git_commit())\n",
|
|
" \n",
|
|
" df_trials = study.trials_dataframe()\n",
|
|
" if len(df_trials):\n",
|
|
" df_trials = df_trials[df_trials.state=='COMPLETE']\n",
|
|
" nb_trials = len(df_trials)\n",
|
|
" if nb_trials<1:\n",
|
|
" # Opt\n",
|
|
" study.optimize(objective, n_trials=1, \n",
|
|
" timeout=60*30 # Max seconds for all optimizes\n",
|
|
" )\n",
|
|
" \n",
|
|
" except Exception as e:\n",
|
|
" logging.exception('failed to run model')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2020-11-13T01:44:21.021030Z",
|
|
"start_time": "2020-11-13T01:44:20.962459Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# TODO baseline, run as sep cell, opt once\n",
|
|
"# TODO summarize time and model params at best params"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 18,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2020-11-13T07:55:54.138705Z",
|
|
"start_time": "2020-11-13T07:55:53.706642Z"
|
|
}
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<div>\n",
|
|
"<style scoped>\n",
|
|
" .dataframe tbody tr th:only-of-type {\n",
|
|
" vertical-align: middle;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe tbody tr th {\n",
|
|
" vertical-align: top;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe thead th {\n",
|
|
" text-align: right;\n",
|
|
" }\n",
|
|
"</style>\n",
|
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|
" <thead>\n",
|
|
" <tr style=\"text-align: right;\">\n",
|
|
" <th></th>\n",
|
|
" <th></th>\n",
|
|
" <th>n_trials</th>\n",
|
|
" <th>param_hidden_size_exp</th>\n",
|
|
" <th>param_layers</th>\n",
|
|
" <th>value</th>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>dataset</th>\n",
|
|
" <th>model</th>\n",
|
|
" <th></th>\n",
|
|
" <th></th>\n",
|
|
" <th></th>\n",
|
|
" <th></th>\n",
|
|
" </tr>\n",
|
|
" </thead>\n",
|
|
" <tbody>\n",
|
|
" <tr>\n",
|
|
" <th rowspan=\"11\" valign=\"top\">AppliancesEnergyPrediction</th>\n",
|
|
" <th>TCNSeq</th>\n",
|
|
" <td>48</td>\n",
|
|
" <td>2</td>\n",
|
|
" <td>1</td>\n",
|
|
" <td>1.075230</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>InceptionTimeSeq</th>\n",
|
|
" <td>17</td>\n",
|
|
" <td>1</td>\n",
|
|
" <td>3</td>\n",
|
|
" <td>1.099332</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>TransformerProcess</th>\n",
|
|
" <td>18</td>\n",
|
|
" <td>4</td>\n",
|
|
" <td>3</td>\n",
|
|
" <td>1.163042</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>LSTM</th>\n",
|
|
" <td>19</td>\n",
|
|
" <td>6</td>\n",
|
|
" <td>3</td>\n",
|
|
" <td>1.172367</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>TransformerSeq2Seq</th>\n",
|
|
" <td>17</td>\n",
|
|
" <td>2</td>\n",
|
|
" <td>1</td>\n",
|
|
" <td>1.174764</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>Transformer</th>\n",
|
|
" <td>19</td>\n",
|
|
" <td>4</td>\n",
|
|
" <td>6</td>\n",
|
|
" <td>1.195765</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>LSTMSeq2Seq</th>\n",
|
|
" <td>29</td>\n",
|
|
" <td>4</td>\n",
|
|
" <td>6</td>\n",
|
|
" <td>1.204848</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>RANP</th>\n",
|
|
" <td>16</td>\n",
|
|
" <td>4</td>\n",
|
|
" <td>10</td>\n",
|
|
" <td>1.278206</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>BaselineMean</th>\n",
|
|
" <td>1</td>\n",
|
|
" <td>6</td>\n",
|
|
" <td>5</td>\n",
|
|
" <td>1.322796</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>BaselineLast</th>\n",
|
|
" <td>1</td>\n",
|
|
" <td>2</td>\n",
|
|
" <td>1</td>\n",
|
|
" <td>1.475730</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>CrossAttention</th>\n",
|
|
" <td>17</td>\n",
|
|
" <td>4</td>\n",
|
|
" <td>5</td>\n",
|
|
" <td>1.547343</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th rowspan=\"11\" valign=\"top\">BejingPM25</th>\n",
|
|
" <th>TCNSeq</th>\n",
|
|
" <td>46</td>\n",
|
|
" <td>3</td>\n",
|
|
" <td>3</td>\n",
|
|
" <td>1.235752</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>InceptionTimeSeq</th>\n",
|
|
" <td>15</td>\n",
|
|
" <td>1</td>\n",
|
|
" <td>3</td>\n",
|
|
" <td>1.239024</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>LSTM</th>\n",
|
|
" <td>18</td>\n",
|
|
" <td>2</td>\n",
|
|
" <td>3</td>\n",
|
|
" <td>1.271591</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>LSTMSeq2Seq</th>\n",
|
|
" <td>17</td>\n",
|
|
" <td>5</td>\n",
|
|
" <td>6</td>\n",
|
|
" <td>1.292539</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>Transformer</th>\n",
|
|
" <td>53</td>\n",
|
|
" <td>6</td>\n",
|
|
" <td>1</td>\n",
|
|
" <td>1.295221</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>TransformerProcess</th>\n",
|
|
" <td>16</td>\n",
|
|
" <td>2</td>\n",
|
|
" <td>1</td>\n",
|
|
" <td>1.404077</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>CrossAttention</th>\n",
|
|
" <td>17</td>\n",
|
|
" <td>8</td>\n",
|
|
" <td>11</td>\n",
|
|
" <td>1.411578</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>RANP</th>\n",
|
|
" <td>16</td>\n",
|
|
" <td>4</td>\n",
|
|
" <td>1</td>\n",
|
|
" <td>1.429652</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>BaselineMean</th>\n",
|
|
" <td>1</td>\n",
|
|
" <td>6</td>\n",
|
|
" <td>3</td>\n",
|
|
" <td>1.435451</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>BaselineLast</th>\n",
|
|
" <td>1</td>\n",
|
|
" <td>4</td>\n",
|
|
" <td>12</td>\n",
|
|
" <td>1.548376</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>TransformerSeq2Seq</th>\n",
|
|
" <td>16</td>\n",
|
|
" <td>4</td>\n",
|
|
" <td>8</td>\n",
|
|
" <td>2.393506</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th rowspan=\"11\" valign=\"top\">GasSensor</th>\n",
|
|
" <th>RANP</th>\n",
|
|
" <td>18</td>\n",
|
|
" <td>6</td>\n",
|
|
" <td>10</td>\n",
|
|
" <td>-2.133501</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>InceptionTimeSeq</th>\n",
|
|
" <td>33</td>\n",
|
|
" <td>5</td>\n",
|
|
" <td>10</td>\n",
|
|
" <td>-2.102463</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>Transformer</th>\n",
|
|
" <td>50</td>\n",
|
|
" <td>6</td>\n",
|
|
" <td>12</td>\n",
|
|
" <td>-1.960168</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>TCNSeq</th>\n",
|
|
" <td>41</td>\n",
|
|
" <td>6</td>\n",
|
|
" <td>8</td>\n",
|
|
" <td>-1.736079</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>LSTM</th>\n",
|
|
" <td>18</td>\n",
|
|
" <td>6</td>\n",
|
|
" <td>5</td>\n",
|
|
" <td>-1.537527</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>LSTMSeq2Seq</th>\n",
|
|
" <td>20</td>\n",
|
|
" <td>8</td>\n",
|
|
" <td>5</td>\n",
|
|
" <td>-1.489893</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>TransformerProcess</th>\n",
|
|
" <td>38</td>\n",
|
|
" <td>7</td>\n",
|
|
" <td>4</td>\n",
|
|
" <td>-0.880734</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>CrossAttention</th>\n",
|
|
" <td>17</td>\n",
|
|
" <td>7</td>\n",
|
|
" <td>3</td>\n",
|
|
" <td>-0.638713</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>TransformerSeq2Seq</th>\n",
|
|
" <td>28</td>\n",
|
|
" <td>6</td>\n",
|
|
" <td>1</td>\n",
|
|
" <td>0.336668</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>BaselineMean</th>\n",
|
|
" <td>1</td>\n",
|
|
" <td>3</td>\n",
|
|
" <td>2</td>\n",
|
|
" <td>1.584720</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>BaselineLast</th>\n",
|
|
" <td>1</td>\n",
|
|
" <td>2</td>\n",
|
|
" <td>6</td>\n",
|
|
" <td>1.974204</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th rowspan=\"11\" valign=\"top\">IMOSCurrentsVel</th>\n",
|
|
" <th>TCNSeq</th>\n",
|
|
" <td>52</td>\n",
|
|
" <td>4</td>\n",
|
|
" <td>6</td>\n",
|
|
" <td>0.817371</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>InceptionTimeSeq</th>\n",
|
|
" <td>20</td>\n",
|
|
" <td>3</td>\n",
|
|
" <td>6</td>\n",
|
|
" <td>0.845443</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>LSTM</th>\n",
|
|
" <td>22</td>\n",
|
|
" <td>5</td>\n",
|
|
" <td>6</td>\n",
|
|
" <td>0.875669</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>Transformer</th>\n",
|
|
" <td>24</td>\n",
|
|
" <td>2</td>\n",
|
|
" <td>3</td>\n",
|
|
" <td>0.879682</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>BaselineLast</th>\n",
|
|
" <td>1</td>\n",
|
|
" <td>3</td>\n",
|
|
" <td>6</td>\n",
|
|
" <td>0.885377</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>LSTMSeq2Seq</th>\n",
|
|
" <td>24</td>\n",
|
|
" <td>7</td>\n",
|
|
" <td>12</td>\n",
|
|
" <td>0.887783</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>RANP</th>\n",
|
|
" <td>19</td>\n",
|
|
" <td>5</td>\n",
|
|
" <td>3</td>\n",
|
|
" <td>1.035464</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>BaselineMean</th>\n",
|
|
" <td>1</td>\n",
|
|
" <td>5</td>\n",
|
|
" <td>7</td>\n",
|
|
" <td>1.202492</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>TransformerSeq2Seq</th>\n",
|
|
" <td>19</td>\n",
|
|
" <td>3</td>\n",
|
|
" <td>1</td>\n",
|
|
" <td>1.266849</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>TransformerProcess</th>\n",
|
|
" <td>27</td>\n",
|
|
" <td>5</td>\n",
|
|
" <td>1</td>\n",
|
|
" <td>1.394595</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>CrossAttention</th>\n",
|
|
" <td>17</td>\n",
|
|
" <td>3</td>\n",
|
|
" <td>3</td>\n",
|
|
" <td>1.656854</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th rowspan=\"11\" valign=\"top\">MetroInterstateTraffic</th>\n",
|
|
" <th>TCNSeq</th>\n",
|
|
" <td>61</td>\n",
|
|
" <td>7</td>\n",
|
|
" <td>3</td>\n",
|
|
" <td>-0.324690</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>TransformerProcess</th>\n",
|
|
" <td>17</td>\n",
|
|
" <td>3</td>\n",
|
|
" <td>7</td>\n",
|
|
" <td>-0.297472</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>RANP</th>\n",
|
|
" <td>16</td>\n",
|
|
" <td>3</td>\n",
|
|
" <td>3</td>\n",
|
|
" <td>-0.291096</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>Transformer</th>\n",
|
|
" <td>20</td>\n",
|
|
" <td>4</td>\n",
|
|
" <td>11</td>\n",
|
|
" <td>-0.253573</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>LSTMSeq2Seq</th>\n",
|
|
" <td>15</td>\n",
|
|
" <td>5</td>\n",
|
|
" <td>5</td>\n",
|
|
" <td>-0.200625</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>LSTM</th>\n",
|
|
" <td>17</td>\n",
|
|
" <td>4</td>\n",
|
|
" <td>5</td>\n",
|
|
" <td>-0.200455</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>TransformerSeq2Seq</th>\n",
|
|
" <td>15</td>\n",
|
|
" <td>2</td>\n",
|
|
" <td>6</td>\n",
|
|
" <td>-0.192895</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>InceptionTimeSeq</th>\n",
|
|
" <td>17</td>\n",
|
|
" <td>3</td>\n",
|
|
" <td>6</td>\n",
|
|
" <td>-0.162513</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>CrossAttention</th>\n",
|
|
" <td>14</td>\n",
|
|
" <td>6</td>\n",
|
|
" <td>1</td>\n",
|
|
" <td>-0.103853</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>BaselineMean</th>\n",
|
|
" <td>1</td>\n",
|
|
" <td>3</td>\n",
|
|
" <td>1</td>\n",
|
|
" <td>1.410653</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>BaselineLast</th>\n",
|
|
" <td>1</td>\n",
|
|
" <td>2</td>\n",
|
|
" <td>12</td>\n",
|
|
" <td>1.741882</td>\n",
|
|
" </tr>\n",
|
|
" </tbody>\n",
|
|
"</table>\n",
|
|
"</div>"
|
|
],
|
|
"text/plain": [
|
|
" n_trials \\\n",
|
|
"dataset model \n",
|
|
"AppliancesEnergyPrediction TCNSeq 48 \n",
|
|
" InceptionTimeSeq 17 \n",
|
|
" TransformerProcess 18 \n",
|
|
" LSTM 19 \n",
|
|
" TransformerSeq2Seq 17 \n",
|
|
" Transformer 19 \n",
|
|
" LSTMSeq2Seq 29 \n",
|
|
" RANP 16 \n",
|
|
" BaselineMean 1 \n",
|
|
" BaselineLast 1 \n",
|
|
" CrossAttention 17 \n",
|
|
"BejingPM25 TCNSeq 46 \n",
|
|
" InceptionTimeSeq 15 \n",
|
|
" LSTM 18 \n",
|
|
" LSTMSeq2Seq 17 \n",
|
|
" Transformer 53 \n",
|
|
" TransformerProcess 16 \n",
|
|
" CrossAttention 17 \n",
|
|
" RANP 16 \n",
|
|
" BaselineMean 1 \n",
|
|
" BaselineLast 1 \n",
|
|
" TransformerSeq2Seq 16 \n",
|
|
"GasSensor RANP 18 \n",
|
|
" InceptionTimeSeq 33 \n",
|
|
" Transformer 50 \n",
|
|
" TCNSeq 41 \n",
|
|
" LSTM 18 \n",
|
|
" LSTMSeq2Seq 20 \n",
|
|
" TransformerProcess 38 \n",
|
|
" CrossAttention 17 \n",
|
|
" TransformerSeq2Seq 28 \n",
|
|
" BaselineMean 1 \n",
|
|
" BaselineLast 1 \n",
|
|
"IMOSCurrentsVel TCNSeq 52 \n",
|
|
" InceptionTimeSeq 20 \n",
|
|
" LSTM 22 \n",
|
|
" Transformer 24 \n",
|
|
" BaselineLast 1 \n",
|
|
" LSTMSeq2Seq 24 \n",
|
|
" RANP 19 \n",
|
|
" BaselineMean 1 \n",
|
|
" TransformerSeq2Seq 19 \n",
|
|
" TransformerProcess 27 \n",
|
|
" CrossAttention 17 \n",
|
|
"MetroInterstateTraffic TCNSeq 61 \n",
|
|
" TransformerProcess 17 \n",
|
|
" RANP 16 \n",
|
|
" Transformer 20 \n",
|
|
" LSTMSeq2Seq 15 \n",
|
|
" LSTM 17 \n",
|
|
" TransformerSeq2Seq 15 \n",
|
|
" InceptionTimeSeq 17 \n",
|
|
" CrossAttention 14 \n",
|
|
" BaselineMean 1 \n",
|
|
" BaselineLast 1 \n",
|
|
"\n",
|
|
" param_hidden_size_exp \\\n",
|
|
"dataset model \n",
|
|
"AppliancesEnergyPrediction TCNSeq 2 \n",
|
|
" InceptionTimeSeq 1 \n",
|
|
" TransformerProcess 4 \n",
|
|
" LSTM 6 \n",
|
|
" TransformerSeq2Seq 2 \n",
|
|
" Transformer 4 \n",
|
|
" LSTMSeq2Seq 4 \n",
|
|
" RANP 4 \n",
|
|
" BaselineMean 6 \n",
|
|
" BaselineLast 2 \n",
|
|
" CrossAttention 4 \n",
|
|
"BejingPM25 TCNSeq 3 \n",
|
|
" InceptionTimeSeq 1 \n",
|
|
" LSTM 2 \n",
|
|
" LSTMSeq2Seq 5 \n",
|
|
" Transformer 6 \n",
|
|
" TransformerProcess 2 \n",
|
|
" CrossAttention 8 \n",
|
|
" RANP 4 \n",
|
|
" BaselineMean 6 \n",
|
|
" BaselineLast 4 \n",
|
|
" TransformerSeq2Seq 4 \n",
|
|
"GasSensor RANP 6 \n",
|
|
" InceptionTimeSeq 5 \n",
|
|
" Transformer 6 \n",
|
|
" TCNSeq 6 \n",
|
|
" LSTM 6 \n",
|
|
" LSTMSeq2Seq 8 \n",
|
|
" TransformerProcess 7 \n",
|
|
" CrossAttention 7 \n",
|
|
" TransformerSeq2Seq 6 \n",
|
|
" BaselineMean 3 \n",
|
|
" BaselineLast 2 \n",
|
|
"IMOSCurrentsVel TCNSeq 4 \n",
|
|
" InceptionTimeSeq 3 \n",
|
|
" LSTM 5 \n",
|
|
" Transformer 2 \n",
|
|
" BaselineLast 3 \n",
|
|
" LSTMSeq2Seq 7 \n",
|
|
" RANP 5 \n",
|
|
" BaselineMean 5 \n",
|
|
" TransformerSeq2Seq 3 \n",
|
|
" TransformerProcess 5 \n",
|
|
" CrossAttention 3 \n",
|
|
"MetroInterstateTraffic TCNSeq 7 \n",
|
|
" TransformerProcess 3 \n",
|
|
" RANP 3 \n",
|
|
" Transformer 4 \n",
|
|
" LSTMSeq2Seq 5 \n",
|
|
" LSTM 4 \n",
|
|
" TransformerSeq2Seq 2 \n",
|
|
" InceptionTimeSeq 3 \n",
|
|
" CrossAttention 6 \n",
|
|
" BaselineMean 3 \n",
|
|
" BaselineLast 2 \n",
|
|
"\n",
|
|
" param_layers value \n",
|
|
"dataset model \n",
|
|
"AppliancesEnergyPrediction TCNSeq 1 1.075230 \n",
|
|
" InceptionTimeSeq 3 1.099332 \n",
|
|
" TransformerProcess 3 1.163042 \n",
|
|
" LSTM 3 1.172367 \n",
|
|
" TransformerSeq2Seq 1 1.174764 \n",
|
|
" Transformer 6 1.195765 \n",
|
|
" LSTMSeq2Seq 6 1.204848 \n",
|
|
" RANP 10 1.278206 \n",
|
|
" BaselineMean 5 1.322796 \n",
|
|
" BaselineLast 1 1.475730 \n",
|
|
" CrossAttention 5 1.547343 \n",
|
|
"BejingPM25 TCNSeq 3 1.235752 \n",
|
|
" InceptionTimeSeq 3 1.239024 \n",
|
|
" LSTM 3 1.271591 \n",
|
|
" LSTMSeq2Seq 6 1.292539 \n",
|
|
" Transformer 1 1.295221 \n",
|
|
" TransformerProcess 1 1.404077 \n",
|
|
" CrossAttention 11 1.411578 \n",
|
|
" RANP 1 1.429652 \n",
|
|
" BaselineMean 3 1.435451 \n",
|
|
" BaselineLast 12 1.548376 \n",
|
|
" TransformerSeq2Seq 8 2.393506 \n",
|
|
"GasSensor RANP 10 -2.133501 \n",
|
|
" InceptionTimeSeq 10 -2.102463 \n",
|
|
" Transformer 12 -1.960168 \n",
|
|
" TCNSeq 8 -1.736079 \n",
|
|
" LSTM 5 -1.537527 \n",
|
|
" LSTMSeq2Seq 5 -1.489893 \n",
|
|
" TransformerProcess 4 -0.880734 \n",
|
|
" CrossAttention 3 -0.638713 \n",
|
|
" TransformerSeq2Seq 1 0.336668 \n",
|
|
" BaselineMean 2 1.584720 \n",
|
|
" BaselineLast 6 1.974204 \n",
|
|
"IMOSCurrentsVel TCNSeq 6 0.817371 \n",
|
|
" InceptionTimeSeq 6 0.845443 \n",
|
|
" LSTM 6 0.875669 \n",
|
|
" Transformer 3 0.879682 \n",
|
|
" BaselineLast 6 0.885377 \n",
|
|
" LSTMSeq2Seq 12 0.887783 \n",
|
|
" RANP 3 1.035464 \n",
|
|
" BaselineMean 7 1.202492 \n",
|
|
" TransformerSeq2Seq 1 1.266849 \n",
|
|
" TransformerProcess 1 1.394595 \n",
|
|
" CrossAttention 3 1.656854 \n",
|
|
"MetroInterstateTraffic TCNSeq 3 -0.324690 \n",
|
|
" TransformerProcess 7 -0.297472 \n",
|
|
" RANP 3 -0.291096 \n",
|
|
" Transformer 11 -0.253573 \n",
|
|
" LSTMSeq2Seq 5 -0.200625 \n",
|
|
" LSTM 5 -0.200455 \n",
|
|
" TransformerSeq2Seq 6 -0.192895 \n",
|
|
" InceptionTimeSeq 6 -0.162513 \n",
|
|
" CrossAttention 1 -0.103853 \n",
|
|
" BaselineMean 1 1.410653 \n",
|
|
" BaselineLast 12 1.741882 "
|
|
]
|
|
},
|
|
"execution_count": 18,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"import pandas as pd\n",
|
|
"# Summarize studies\n",
|
|
"rs = []\n",
|
|
"study_summaries = optuna.study.get_all_study_summaries(storage=storage)\n",
|
|
"for s in study_summaries:\n",
|
|
" row = {}\n",
|
|
" if (s.best_trial is not None) and (s.best_trial.state==optuna.trial.TrialState.COMPLETE):\n",
|
|
" params = {k:v for k,v in s.best_trial.__dict__.items() if not k.startswith('_')}\n",
|
|
" row.update(s.user_attrs)\n",
|
|
" row['n_trials'] = s.n_trials\n",
|
|
" row.update({'param_'+k:v for k,v in s.best_trial._params.items()})\n",
|
|
" row.update(params)\n",
|
|
" rs.append(row)\n",
|
|
"df_studies = pd.DataFrame(rs)\n",
|
|
"\n",
|
|
"df_studies = (df_studies.drop(columns=['state', 'intermediate_values', 'commit', 'datetime_complete'])\n",
|
|
" .sort_values(['dataset', 'value'])\n",
|
|
" .set_index(['dataset', 'model'])\n",
|
|
" )\n",
|
|
"df_studies"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2020-11-13T07:54:27.730531Z",
|
|
"start_time": "2020-11-13T07:54:27.667135Z"
|
|
},
|
|
"lines_to_next_cell": 0,
|
|
"scrolled": true
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# study_names = [s.study_name for s in optuna.study.get_all_study_summaries(storage=storage)]\n",
|
|
"\n",
|
|
"# for study_name in study_names:\n",
|
|
"# loaded_study = optuna.load_study(study_name=study_name, storage=storage)\n",
|
|
" \n",
|
|
"# # Make DF over trials\n",
|
|
"# print(study_name)\n",
|
|
" \n",
|
|
"# df_trials = loaded_study.trials_dataframe()\n",
|
|
"# # df_trials.index = df_trials.apply(lambda r:f'l={r.params_layers}_hs={r.params_hidden_size_exp}', 1)\n",
|
|
"# display(df_trials)\n",
|
|
"\n",
|
|
"# # Plot test curves, to see how much overfitting\n",
|
|
"# df_values = pd.DataFrame([s.intermediate_values for s in loaded_study.get_trials()]).T\n",
|
|
"# df_values.columns = [f\"l={s.params['layers']}_hs={s.params['hidden_size_exp']}\" for s in loaded_study.get_trials()]\n",
|
|
"# df_values.plot(ylabel='nll', xlabel='epochs', title=f'val loss \"{study_name}\"')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2020-11-13T01:44:25.350038Z",
|
|
"start_time": "2020-11-12T23:16:12.130Z"
|
|
},
|
|
"lines_to_next_cell": 0
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# study_names = [s.study_name for s in optuna.study.get_all_study_summaries(storage=storage)]\n",
|
|
"\n",
|
|
"# for study_name in study_names:\n",
|
|
"# loaded_study = optuna.load_study(study_name=study_name, storage=storage)\n",
|
|
"# fig=optuna.visualization.plot_contour(loaded_study, params=['hidden_size_exp', 'layers'])\n",
|
|
"# fig = fig.update_layout(dict(title=f'{study_name} nll'))\n",
|
|
"# display(fig)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2020-11-08T23:11:33.818862Z",
|
|
"start_time": "2020-11-08T23:11:33.749610Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2020-11-08T23:11:14.588720Z",
|
|
"start_time": "2020-11-08T23:11:14.516020Z"
|
|
},
|
|
"lines_to_next_cell": 2
|
|
},
|
|
"outputs": [],
|
|
"source": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2020-11-13T01:44:25.351121Z",
|
|
"start_time": "2020-11-12T23:16:12.141Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"[f\"{s.params['layers']}_{s.params['hidden_size_exp']}\" for s in loaded_study.get_trials()]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2020-11-13T01:44:25.352257Z",
|
|
"start_time": "2020-11-12T23:16:12.146Z"
|
|
},
|
|
"scrolled": true
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"df_values"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"lines_to_next_cell": 2
|
|
},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"@webio": {
|
|
"lastCommId": null,
|
|
"lastKernelId": null
|
|
},
|
|
"jupytext": {
|
|
"encoding": "# -*- coding: utf-8 -*-",
|
|
"formats": "ipynb,py:light"
|
|
},
|
|
"kernelspec": {
|
|
"display_name": "seq2seq-time",
|
|
"language": "python",
|
|
"name": "seq2seq-time"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.7.8"
|
|
},
|
|
"toc": {
|
|
"base_numbering": 1,
|
|
"nav_menu": {},
|
|
"number_sections": true,
|
|
"sideBar": true,
|
|
"skip_h1_title": false,
|
|
"title_cell": "Table of Contents",
|
|
"title_sidebar": "Contents",
|
|
"toc_cell": false,
|
|
"toc_position": {
|
|
"height": "calc(100% - 180px)",
|
|
"left": "10px",
|
|
"top": "150px",
|
|
"width": "209.162px"
|
|
},
|
|
"toc_section_display": true,
|
|
"toc_window_display": true
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|