Files
2020-11-13 15:56:16 +08:00

1629 lines
58 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"ExecuteTime": {
"end_time": "2020-10-10T01:25:12.788851Z",
"start_time": "2020-10-10T01:25:12.783398Z"
}
},
"source": [
"# Sequence to Sequence Models for Timeseries Regression\n",
"\n",
"\n",
"In this notebook we search for the optimal `hidden_size` and number of layers for each model on each dataset, using PyTorch Lightning for training and Optuna for the hyperparameter search."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"ExecuteTime": {
"end_time": "2020-11-13T07:55:31.364813Z",
"start_time": "2020-11-13T07:55:31.025418Z"
}
},
"outputs": [],
"source": [
"# OPTIONAL: Load the \"autoreload\" extension so that code can change. But blacklist large modules\n",
"%load_ext autoreload\n",
"%autoreload 2\n",
"%aimport -pandas\n",
"%aimport -torch\n",
"%aimport -numpy\n",
"%aimport -matplotlib\n",
"%aimport -dask\n",
"%aimport -tqdm\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"ExecuteTime": {
"end_time": "2020-11-13T07:55:32.324191Z",
"start_time": "2020-11-13T07:55:31.367170Z"
},
"lines_to_next_cell": 0
},
"outputs": [],
"source": [
"# Imports\n",
"import torch\n",
"from torch import nn, optim\n",
"from torch.nn import functional as F\n",
"from torch.autograd import Variable\n",
"import torch\n",
"import torch.utils.data\n",
"\n",
"\n",
"from pathlib import Path\n",
"from tqdm.auto import tqdm\n",
"\n",
"import pytorch_lightning as pl"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"ExecuteTime": {
"end_time": "2020-11-13T07:55:32.863984Z",
"start_time": "2020-11-13T07:55:32.326782Z"
}
},
"outputs": [],
"source": [
"from seq2seq_time.data.dataset import Seq2SeqDataSet, Seq2SeqDataSets\n",
"from seq2seq_time.predict import predict, predict_multi\n",
"from seq2seq_time.util import dset_to_nc"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"ExecuteTime": {
"end_time": "2020-11-13T07:55:32.900137Z",
"start_time": "2020-11-13T07:55:32.866551Z"
}
},
"outputs": [],
"source": [
"import logging\n",
"import warnings\n",
"import seq2seq_time.silence \n",
"warnings.simplefilter('once')\n",
"warnings.simplefilter(action='ignore', category=FutureWarning)\n",
"warnings.simplefilter(action='ignore', category=DeprecationWarning)\n",
"warnings.filterwarnings('ignore', 'Consider increasing the value of the `num_workers` argument', UserWarning)\n",
"warnings.filterwarnings('ignore', 'Your val_dataloader has `shuffle=True`', UserWarning)\n",
"\n",
"from pytorch_lightning import _logger as log\n",
"log.setLevel(logging.WARN)"
]
},
{
"cell_type": "markdown",
"metadata": {
"ExecuteTime": {
"end_time": "2020-10-10T01:28:32.492160Z",
"start_time": "2020-10-10T01:28:32.488140Z"
}
},
"source": [
"## Parameters"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"ExecuteTime": {
"end_time": "2020-11-13T07:55:32.992818Z",
"start_time": "2020-11-13T07:55:32.902453Z"
},
"lines_to_next_cell": 0
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"using cuda\n",
"20201108-300000\n"
]
},
{
"data": {
"text/plain": [
"336"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Hardware\n",
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
"print('using ' + device)\n",
"\n",
"# Run id: prefixes the ../outputs/<timestamp> directory and every optuna study name\n",
"timestamp = '20201108-300000'\n",
"print(timestamp)\n",
"\n",
"# Window lengths in time steps (presumably 48 steps/day, i.e. one week of\n",
"# history and one day of forecast - TODO confirm per dataset)\n",
"window_past = 7 * 48\n",
"window_future = 48\n",
"\n",
"# DataLoader settings\n",
"batch_size = 64\n",
"num_workers = 4\n",
"datasets_root = Path('../data/processed/')\n",
"window_past"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Datasets\n",
"\n",
"From easy to hard, these datasets show different challenges; all of them have more than 20k datapoints and a regression output. See the 00.01 notebook for more details, and the dataset code for more information.\n",
"\n",
"Some, such as MetroInterstateTraffic, are easier; some are periodic, such as BejingPM25; some are conditional on inputs, such as GasSensor; and some are noisy and periodic, like IMOSCurrentsVel."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"ExecuteTime": {
"end_time": "2020-11-13T07:55:33.715323Z",
"start_time": "2020-11-13T07:55:33.238397Z"
},
"lines_to_next_cell": 0
},
"outputs": [
{
"data": {
"text/plain": [
"[seq2seq_time.data.data.MetroInterstateTraffic,\n",
" seq2seq_time.data.data.IMOSCurrentsVel,\n",
" seq2seq_time.data.data.GasSensor,\n",
" seq2seq_time.data.data.AppliancesEnergyPrediction,\n",
" seq2seq_time.data.data.BejingPM25]"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from seq2seq_time.data.data import (\n",
"    AppliancesEnergyPrediction,\n",
"    BejingPM25,\n",
"    GasSensor,\n",
"    IMOSCurrentsVel,\n",
"    MetroInterstateTraffic,\n",
")\n",
"\n",
"# Dataset classes to benchmark, in the order the study loops visit them\n",
"datasets = [\n",
"    MetroInterstateTraffic,\n",
"    IMOSCurrentsVel,\n",
"    GasSensor,\n",
"    AppliancesEnergyPrediction,\n",
"    BejingPM25,\n",
"]\n",
"datasets\n",
"# ## Lightning\n",
"#\n",
"# We will use pytorch lightning to handle all the training scaffolding. We have a common pytorch lightning class that takes in the model and defines training steps and logging."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"ExecuteTime": {
"end_time": "2020-11-13T07:55:34.390203Z",
"start_time": "2020-11-13T07:55:34.338792Z"
}
},
"outputs": [],
"source": [
"import pytorch_lightning as pl\n",
"\n",
"class PL_MODEL(pl.LightningModule):\n",
"    \"\"\"Lightning wrapper that trains a seq2seq model by negative log-likelihood.\n",
"\n",
"    The wrapped model is called as `model(x_past, y_past, x_future, y_future)` and\n",
"    must return `(y_dist, extra)`, where `y_dist` has a `log_prob` method and\n",
"    `extra` supports `'loss' in extra` (presumably a dict - TODO confirm).\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, model, lr=3e-4, patience=None, weight_decay=0):\n",
"        \"\"\"\n",
"        Args:\n",
"            model: the torch model to wrap.\n",
"            lr: AdamW learning rate.\n",
"            patience: ReduceLROnPlateau patience; None/0 disables the scheduler.\n",
"            weight_decay: AdamW weight decay.\n",
"        \"\"\"\n",
"        super().__init__()\n",
"        self._model = model\n",
"        self.lr = lr\n",
"        self.patience = patience\n",
"        self.weight_decay = weight_decay\n",
"\n",
"    def forward(self, x_past, y_past, x_future, y_future=None):\n",
"        \"\"\"Eval/Predict: return `(y_dist, extra)` from the wrapped model.\"\"\"\n",
"        y_dist, extra = self._model(x_past, y_past, x_future, y_future)\n",
"        return y_dist, extra\n",
"\n",
"    def training_step(self, batch, batch_idx, phase='train'):\n",
"        \"\"\"Shared train/val/test step.\n",
"\n",
"        Logs the mean NLL of `y_future` as `loss/{phase}` and returns it, except in\n",
"        training when the model supplies its own `extra['loss']`, which is then\n",
"        logged as `model_loss/train` and returned instead.\n",
"        \"\"\"\n",
"        x_past, y_past, x_future, y_future = batch\n",
"        y_dist, extra = self.forward(*batch)\n",
"        loss = -y_dist.log_prob(y_future).mean()\n",
"        self.log_dict({f'loss/{phase}': loss})\n",
"        if ('loss' in extra) and (phase == 'train'):\n",
"            # some models have a special loss\n",
"            loss = extra['loss']\n",
"            self.log_dict({f'model_loss/{phase}': loss})\n",
"        return loss\n",
"\n",
"    def validation_step(self, batch, batch_idx):\n",
"        return self.training_step(batch, batch_idx, phase='val')\n",
"\n",
"    def test_step(self, batch, batch_idx):\n",
"        return self.training_step(batch, batch_idx, phase='test')\n",
"\n",
"    def configure_optimizers(self):\n",
"        \"\"\"AdamW, plus ReduceLROnPlateau on `loss/val` when `patience` is truthy.\"\"\"\n",
"        optim = torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)\n",
"        if not self.patience:\n",
"            # Omit the scheduler key rather than returning `lr_scheduler: None`,\n",
"            # which some Lightning versions reject as an invalid scheduler.\n",
"            return {'optimizer': optim}\n",
"        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(\n",
"            optim,\n",
"            patience=self.patience,\n",
"            verbose=False,\n",
"            min_lr=1e-7,\n",
"        )\n",
"        return {'optimizer': optim, 'lr_scheduler': scheduler, 'monitor': 'loss/val'}"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"ExecuteTime": {
"end_time": "2020-11-13T07:55:34.653519Z",
"start_time": "2020-11-13T07:55:34.605285Z"
},
"lines_to_next_cell": 2
},
"outputs": [],
"source": [
"from torch.utils.data import DataLoader\n",
"from pytorch_lightning.loggers import CSVLogger, WandbLogger, TensorBoardLogger, TestTubeLogger\n",
"from pytorch_lightning.callbacks.early_stopping import EarlyStopping\n",
"from pytorch_lightning.callbacks import LearningRateMonitor"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Models"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"ExecuteTime": {
"end_time": "2020-11-13T07:55:35.693992Z",
"start_time": "2020-11-13T07:55:35.629427Z"
},
"lines_to_end_of_cell_marker": 2,
"lines_to_next_cell": 0
},
"outputs": [],
"source": [
"from seq2seq_time.models.baseline import BaselineLast, BaselineMean\n",
"from seq2seq_time.models.lstm_seq2seq import LSTMSeq2Seq\n",
"from seq2seq_time.models.lstm import LSTM\n",
"from seq2seq_time.models.transformer import Transformer\n",
"from seq2seq_time.models.transformer_seq2seq import TransformerSeq2Seq\n",
"from seq2seq_time.models.neural_process import RANP\n",
"from seq2seq_time.models.transformer_process import TransformerProcess\n",
"from seq2seq_time.models.tcn import TCNSeq\n",
"from seq2seq_time.models.inceptiontime import InceptionTimeSeq\n",
"from seq2seq_time.models.xattention import CrossAttention"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"ExecuteTime": {
"end_time": "2020-11-13T07:55:36.024789Z",
"start_time": "2020-11-13T07:55:35.985433Z"
},
"lines_to_next_cell": 0
},
"outputs": [],
"source": [
"import gc\n",
"\n",
"def free_mem():\n",
"    \"\"\"Release memory between model runs.\n",
"\n",
"    Runs a garbage collection, empties the CUDA caching allocator, then collects\n",
"    once more.\n",
"    \"\"\"\n",
"    for step in (gc.collect, torch.cuda.empty_cache, gc.collect):\n",
"        step()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2020-11-02T06:10:41.904480Z",
"start_time": "2020-11-02T06:10:41.848613Z"
},
"lines_to_next_cell": 2
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"ExecuteTime": {
"end_time": "2020-11-13T07:55:37.813853Z",
"start_time": "2020-11-13T07:55:37.768348Z"
},
"lines_to_next_cell": 0
},
"outputs": [],
"source": [
"# PARAMS: model\n",
"dropout=0.0\n",
"layers=6\n",
"nhead=4\n",
"\n",
"\n",
"def _half_layers(n):\n",
"    \"\"\"Half of `n` for the recurrent/inception stacks, but never fewer than 1.\n",
"\n",
"    The optuna search samples layers starting at 1, and `1 // 2 == 0` would request\n",
"    an empty stack (presumably invalid for these models - TODO confirm).\n",
"    \"\"\"\n",
"    return max(1, n // 2)\n",
"\n",
"\n",
"# Model factories, each called as f(xs, ys, hidden_size, layers).\n",
"# Baselines ignore these sizes and are run in their own cell (models2).\n",
"models = [\n",
"    lambda xs, ys, hidden_size, layers: TransformerProcess(xs,\n",
"        ys, hidden_size=hidden_size, nhead=nhead,\n",
"        latent_dim=hidden_size//2, dropout=dropout,\n",
"        nlayers=layers),\n",
"    lambda xs, ys, hidden_size, layers: RANP(xs,\n",
"        ys, hidden_dim=hidden_size, dropout=dropout,\n",
"        latent_dim=hidden_size//2, n_decoder_layers=layers, n_latent_encoder_layers=layers, n_det_encoder_layers=layers),\n",
"    lambda xs, ys, hidden_size, layers: TCNSeq(xs, ys, hidden_size=hidden_size, nlayers=layers, dropout=dropout, kernel_size=2),\n",
"    lambda xs, ys, hidden_size, layers: Transformer(xs,\n",
"        ys,\n",
"        attention_dropout=dropout,\n",
"        nhead=nhead,\n",
"        nlayers=layers,\n",
"        hidden_size=hidden_size),\n",
"    lambda xs, ys, hidden_size, layers: LSTM(xs,\n",
"        ys,\n",
"        hidden_size=hidden_size,\n",
"        lstm_layers=_half_layers(layers),\n",
"        lstm_dropout=dropout),\n",
"    lambda xs, ys, hidden_size, layers: TransformerSeq2Seq(xs,\n",
"        ys,\n",
"        hidden_size=hidden_size,\n",
"        nhead=nhead,\n",
"        nlayers=layers,\n",
"        attention_dropout=dropout),\n",
"    lambda xs, ys, hidden_size, layers: LSTMSeq2Seq(xs,\n",
"        ys,\n",
"        hidden_size=hidden_size,\n",
"        lstm_layers=_half_layers(layers),\n",
"        lstm_dropout=dropout),\n",
"    lambda xs, ys, hidden_size, layers: CrossAttention(xs,\n",
"        ys,\n",
"        nlayers=layers,\n",
"        hidden_size=hidden_size),\n",
"    lambda xs, ys, hidden_size, layers: InceptionTimeSeq(xs,\n",
"        ys,\n",
"        kernel_size=96,\n",
"        layers=_half_layers(layers),\n",
"        hidden_size=hidden_size,\n",
"        bottleneck=hidden_size//4),\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2020-11-12T23:16:29.248291Z",
"start_time": "2020-11-12T23:16:16.139941Z"
}
},
"outputs": [],
"source": [
"# DEBUG: sanity check\n",
"# Build every model at a small size (hidden_size=8, layers=4) on every dataset and\n",
"# run Lightning's fast_dev_run, to surface shape/API errors before the long studies.\n",
"\n",
"for Dataset in datasets:\n",
"    dataset_name = Dataset.__name__\n",
"    dataset = Dataset(datasets_root)\n",
"    ds_train, ds_val, ds_test = dataset.to_datasets(window_past=window_past,\n",
"                                                    window_future=window_future)\n",
"\n",
"    # Init data\n",
"    x_past, y_past, x_future, y_future = ds_train.get_rows(10)\n",
"    xs = x_past.shape[-1]\n",
"    ys = y_future.shape[-1]\n",
"\n",
"    # Loaders (val is not shuffled, matching the study loops)\n",
"    dl_train = DataLoader(ds_train,\n",
"                          batch_size=batch_size,\n",
"                          shuffle=True,\n",
"                          pin_memory=num_workers == 0,\n",
"                          num_workers=num_workers)\n",
"    dl_val = DataLoader(ds_val,\n",
"                        shuffle=False,\n",
"                        batch_size=batch_size,\n",
"                        num_workers=num_workers)\n",
"\n",
"    for m_fn in models:\n",
"        free_mem()\n",
"        pt_model = m_fn(xs, ys, 8, 4)\n",
"        model_name = type(pt_model).__name__\n",
"        print(timestamp, dataset_name, model_name)\n",
"\n",
"        # Wrap in lightning\n",
"        model = PL_MODEL(pt_model,\n",
"                         lr=3e-4\n",
"                         ).to(device)\n",
"        trainer = pl.Trainer(\n",
"            fast_dev_run=True,\n",
"            # GPU\n",
"            gpus=1,\n",
"            amp_level='O1',\n",
"            precision=16,\n",
"        )\n",
"        # Without this call the cell only constructs objects and checks nothing.\n",
"        trainer.fit(model, dl_train, dl_val)"
]
},
{
"cell_type": "markdown",
"metadata": {
"ExecuteTime": {
"end_time": "2020-10-23T23:36:11.052891Z",
"start_time": "2020-10-23T23:36:11.048874Z"
}
},
"source": [
"## Train"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2020-11-12T23:16:29.286762Z",
"start_time": "2020-11-12T23:16:29.250367Z"
},
"lines_to_next_cell": 2
},
"outputs": [],
"source": [
"# Per-epoch sample budget: `objective` divides it by batch_size for\n",
"# limit_train_batches, and by a further 5 for limit_val_batches.\n",
"max_iters = 20_000"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2020-11-12T23:16:29.326846Z",
"start_time": "2020-11-12T23:16:29.289004Z"
}
},
"outputs": [],
"source": [
"# NOTE(review): `objective` builds its Trainer with logger=False, so TensorBoard\n",
"# may find no event files under this directory - TODO confirm.\n",
"tensorboard_dir = Path('../outputs', timestamp).resolve()\n",
"print('For tensorboard run:\\ntensorboard --logdir=\"{}\"'.format(tensorboard_dir))"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"ExecuteTime": {
"end_time": "2020-11-13T07:55:48.561522Z",
"start_time": "2020-11-13T07:55:48.516182Z"
},
"lines_to_next_cell": 0
},
"outputs": [],
"source": [
"\n",
"def objective(trial):\n",
"    \"\"\"\n",
"    Optuna function to optimize: train one model and return its validation loss.\n",
"\n",
"    Samples `hidden_size = 2**hidden_size_exp` (exp in [1, 8]) and `layers` in\n",
"    [1, 12], builds the model with the current `m_fn`, trains it with Lightning,\n",
"    then returns `loss/test` from `trainer.test` run on the validation loader.\n",
"\n",
"    NOTE: relies on notebook globals set by the study loop (`m_fn`, `xs`, `ys`,\n",
"    `dl_train`, `dl_val`, `dataset_name`) and by the config cells (`timestamp`,\n",
"    `max_iters`, `batch_size`, `device`).\n",
"\n",
"    See https://github.com/optuna/optuna/blob/master/examples/pytorch_lightning_simple.py\n",
"    \"\"\"\n",
"    # sample\n",
"    hidden_size_exp = trial.suggest_int(\"hidden_size_exp\", 1, 8)\n",
"    hidden_size = 2**hidden_size_exp\n",
"\n",
"    layers = trial.suggest_int(\"layers\", 1, 12)\n",
"\n",
"    # Load model\n",
"    pt_model = m_fn(xs, ys, hidden_size, layers)\n",
"    model_name = type(pt_model).__name__\n",
"\n",
"    # Wrap in lightning\n",
"    patience = 2\n",
"    model = PL_MODEL(pt_model,\n",
"                     lr=3e-4, patience=patience,\n",
"                     ).to(device)\n",
"\n",
"    # Record derived quantities so a study summary can report model size at the\n",
"    # best params without rebuilding the model.\n",
"    trial.set_user_attr('hidden_size', hidden_size)\n",
"    trial.set_user_attr('n_params', int(sum(p.numel() for p in model.parameters())))\n",
"\n",
"    save_dir = f\"../outputs/{timestamp}/{dataset_name}_{model_name}/{trial.number}\"\n",
"    Path(save_dir).mkdir(exist_ok=True, parents=True)\n",
"    trainer = pl.Trainer(\n",
"        # Training length (max_iters is a per-epoch sample budget)\n",
"        min_epochs=2,\n",
"        max_epochs=40,\n",
"        limit_train_batches=max_iters//batch_size,\n",
"        limit_val_batches=max_iters//batch_size//5,\n",
"        # Misc\n",
"        gradient_clip_val=20,\n",
"        terminate_on_nan=True,\n",
"        # GPU\n",
"        gpus=1,\n",
"        amp_level='O1',\n",
"        precision=16,\n",
"        # Callbacks\n",
"        default_root_dir=save_dir,\n",
"        logger=False,\n",
"        callbacks=[\n",
"            # Stop after twice the LR-scheduler patience without improvement\n",
"            EarlyStopping(monitor='loss/val', patience=patience * 2),\n",
"            PyTorchLightningPruningCallback(trial, monitor=\"loss/val\")],\n",
"    )\n",
"    trainer.fit(model, dl_train, dl_val)\n",
"\n",
"    # Run on all val data, using test mode\n",
"    r = trainer.test(model, test_dataloaders=dl_val, verbose=False)\n",
"    return r[0]['loss/test']"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2020-11-08T02:45:44.106583Z",
"start_time": "2020-11-08T02:45:44.050637Z"
},
"lines_to_next_cell": 2
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"ExecuteTime": {
"end_time": "2020-11-13T07:55:49.497334Z",
"start_time": "2020-11-13T07:55:48.981040Z"
}
},
"outputs": [],
"source": [
"import optuna\n",
"from optuna.integration import PyTorchLightningPruningCallback"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"ExecuteTime": {
"end_time": "2020-11-13T07:55:49.543035Z",
"start_time": "2020-11-13T07:55:49.499555Z"
}
},
"outputs": [],
"source": [
"import subprocess\n",
"\n",
"def get_git_commit():\n",
"    \"\"\"Return the output of `git rev-parse HEAD` run from the parent directory.\n",
"\n",
"    Returns None (after logging the traceback) if the command or decoding fails.\n",
"    \"\"\"\n",
"    cmd = [\"git\", \"rev-parse\", \"HEAD\"]\n",
"    try:\n",
"        raw = subprocess.check_output(cmd, cwd='..')\n",
"        commit = raw.decode().strip()\n",
"    except Exception:\n",
"        logging.exception(\"failed to get git hash\")\n",
"        commit = None\n",
"    return commit"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"ExecuteTime": {
"end_time": "2020-11-13T07:55:50.265327Z",
"start_time": "2020-11-13T07:55:50.204384Z"
}
},
"outputs": [],
"source": [
"# One SQLite file per run holds every (dataset, model) optuna study.\n",
"Path(\"../outputs\", timestamp).mkdir(exist_ok=True)\n",
"storage = \"sqlite:///../outputs/{}/optuna.db\".format(timestamp)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2020-11-13T01:44:13.554355Z",
"start_time": "2020-11-12T23:16:30.128631Z"
},
"lines_to_next_cell": 0,
"scrolled": true
},
"outputs": [],
"source": [
"for Dataset in tqdm(datasets, desc='datasets'):\n",
"    dataset_name = Dataset.__name__\n",
"    dataset = Dataset(datasets_root)\n",
"    ds_train, ds_val, ds_test = dataset.to_datasets(window_past=window_past,\n",
"                                                    window_future=window_future)\n",
"\n",
"    # Init data\n",
"    x_past, y_past, x_future, y_future = ds_train.get_rows(10)\n",
"    xs = x_past.shape[-1]\n",
"    ys = y_future.shape[-1]\n",
"\n",
"    # Loaders. `objective` reads dl_train, dl_val, xs, ys, dataset_name and m_fn\n",
"    # from these notebook globals.\n",
"    dl_train = DataLoader(ds_train,\n",
"                          batch_size=batch_size,\n",
"                          shuffle=True,\n",
"                          drop_last=True,\n",
"                          pin_memory=num_workers == 0,\n",
"                          num_workers=num_workers)\n",
"    dl_val = DataLoader(ds_val,\n",
"                        shuffle=False,\n",
"                        batch_size=batch_size,\n",
"                        drop_last=True,\n",
"                        num_workers=num_workers)\n",
"\n",
"    for m_fn in tqdm(models, desc=f'models ({dataset_name})'):\n",
"        try:\n",
"            # Build a throwaway model only to read its class name\n",
"            model_name = type(m_fn(8, 8, 8, 2)).__name__\n",
"            free_mem()\n",
"            study_name = f'{timestamp}_{dataset_name}-{model_name}'\n",
"\n",
"            # Create (or resume) the study in the shared SQLite storage\n",
"            pruner = optuna.pruners.MedianPruner()\n",
"            study = optuna.create_study(storage=storage,\n",
"                                        study_name=study_name,\n",
"                                        pruner=pruner,\n",
"                                        load_if_exists=True)\n",
"            study.set_user_attr('dataset', dataset_name)\n",
"            study.set_user_attr('model', model_name)\n",
"            study.set_user_attr('commit', get_git_commit())\n",
"\n",
"            # Only completed trials count towards the 20-trial budget\n",
"            df_trials = study.trials_dataframe()\n",
"            if len(df_trials):\n",
"                df_trials = df_trials[df_trials.state == 'COMPLETE']\n",
"            nb_trials = len(df_trials)\n",
"            if nb_trials == 0:\n",
"                # Priors. Keys must match the names passed to trial.suggest_* in\n",
"                # `objective` (\"hidden_size_exp\", \"layers\"); otherwise the enqueued\n",
"                # value is not used and that parameter is sampled instead.\n",
"                study.enqueue_trial({\"layers\": 6, \"hidden_size_exp\": 2})\n",
"                study.enqueue_trial({\"layers\": 1, \"hidden_size_exp\": 3})\n",
"                study.enqueue_trial({\"layers\": 3, \"hidden_size_exp\": 5})\n",
"            if nb_trials < 20:\n",
"                # Opt\n",
"                study.optimize(objective, n_trials=20 - nb_trials,\n",
"                               timeout=60 * 60 * 2,  # Max seconds for all optimizes\n",
"                               )\n",
"\n",
"            print(\"Number of finished trials: {}\".format(len(study.trials)))\n",
"\n",
"            print(\"Best trial:\")\n",
"            trial = study.best_trial\n",
"\n",
"            print(\" Value: {}\".format(trial.value))\n",
"\n",
"            print(\" Params: \")\n",
"            for key, value in trial.params.items():\n",
"                print(\" {}: {}\".format(key, value))\n",
"\n",
"        except Exception:\n",
"            # Keep going with the next model; the traceback is logged.\n",
"            logging.exception('failed to run model')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2020-11-12T23:08:15.792627Z",
"start_time": "2020-11-12T23:08:15.700279Z"
}
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2020-11-13T01:44:20.959637Z",
"start_time": "2020-11-13T01:44:13.560463Z"
}
},
"outputs": [],
"source": [
"# Baseline\n",
"# The baseline factories ignore hidden_size/layers, so a single completed optuna\n",
"# trial per (dataset, model) study is enough.\n",
"\n",
"models2 = [\n",
"    lambda xs, ys, _, l: BaselineLast(),\n",
"    lambda xs, ys, _, l: BaselineMean(),\n",
"]\n",
"\n",
"for Dataset in tqdm(datasets, desc='datasets'):\n",
"    dataset_name = Dataset.__name__\n",
"    dataset = Dataset(datasets_root)\n",
"    ds_train, ds_val, ds_test = dataset.to_datasets(\n",
"        window_past=window_past, window_future=window_future)\n",
"\n",
"    # Feature sizes, read from a few sample rows\n",
"    x_past, y_past, x_future, y_future = ds_train.get_rows(10)\n",
"    xs, ys = x_past.shape[-1], y_future.shape[-1]\n",
"\n",
"    # Loaders: `objective` reads dl_train / dl_val from the notebook globals\n",
"    dl_train = DataLoader(\n",
"        ds_train, batch_size=batch_size, shuffle=True, drop_last=True,\n",
"        pin_memory=(num_workers == 0), num_workers=num_workers)\n",
"    dl_val = DataLoader(\n",
"        ds_val, shuffle=False, batch_size=batch_size, drop_last=True,\n",
"        num_workers=num_workers)\n",
"\n",
"    for i, m_fn in enumerate(tqdm(models2, desc=f'models ({dataset_name})')):\n",
"        try:\n",
"            model_name = type(m_fn(8, 8, 8, 2)).__name__\n",
"            free_mem()\n",
"            study_name = f'{timestamp}_{dataset_name}-{model_name}'\n",
"\n",
"            # Create or resume the study in the shared storage\n",
"            pruner = optuna.pruners.MedianPruner()\n",
"            study = optuna.create_study(\n",
"                storage=storage,\n",
"                study_name=study_name,\n",
"                pruner=pruner,\n",
"                load_if_exists=True,\n",
"            )\n",
"            study.set_user_attr('dataset', dataset_name)\n",
"            study.set_user_attr('model', model_name)\n",
"            study.set_user_attr('commit', get_git_commit())\n",
"\n",
"            df_trials = study.trials_dataframe()\n",
"            if len(df_trials):\n",
"                df_trials = df_trials[df_trials.state == 'COMPLETE']\n",
"            nb_trials = len(df_trials)\n",
"            if nb_trials >= 1:\n",
"                continue  # a completed baseline trial already exists\n",
"            study.optimize(objective, n_trials=1,\n",
"                           timeout=30 * 60)  # Max seconds for all optimizes\n",
"        except Exception:\n",
"            logging.exception('failed to run model')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2020-11-13T01:44:21.021030Z",
"start_time": "2020-11-13T01:44:20.962459Z"
}
},
"outputs": [],
"source": [
"# TODO baseline, run as sep cell, opt once\n",
"# TODO summarize time and model params at best params"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"ExecuteTime": {
"end_time": "2020-11-13T07:55:54.138705Z",
"start_time": "2020-11-13T07:55:53.706642Z"
}
},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th></th>\n",
" <th>n_trials</th>\n",
" <th>param_hidden_size_exp</th>\n",
" <th>param_layers</th>\n",
" <th>value</th>\n",
" </tr>\n",
" <tr>\n",
" <th>dataset</th>\n",
" <th>model</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th rowspan=\"11\" valign=\"top\">AppliancesEnergyPrediction</th>\n",
" <th>TCNSeq</th>\n",
" <td>48</td>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>1.075230</td>\n",
" </tr>\n",
" <tr>\n",
" <th>InceptionTimeSeq</th>\n",
" <td>17</td>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>1.099332</td>\n",
" </tr>\n",
" <tr>\n",
" <th>TransformerProcess</th>\n",
" <td>18</td>\n",
" <td>4</td>\n",
" <td>3</td>\n",
" <td>1.163042</td>\n",
" </tr>\n",
" <tr>\n",
" <th>LSTM</th>\n",
" <td>19</td>\n",
" <td>6</td>\n",
" <td>3</td>\n",
" <td>1.172367</td>\n",
" </tr>\n",
" <tr>\n",
" <th>TransformerSeq2Seq</th>\n",
" <td>17</td>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>1.174764</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Transformer</th>\n",
" <td>19</td>\n",
" <td>4</td>\n",
" <td>6</td>\n",
" <td>1.195765</td>\n",
" </tr>\n",
" <tr>\n",
" <th>LSTMSeq2Seq</th>\n",
" <td>29</td>\n",
" <td>4</td>\n",
" <td>6</td>\n",
" <td>1.204848</td>\n",
" </tr>\n",
" <tr>\n",
" <th>RANP</th>\n",
" <td>16</td>\n",
" <td>4</td>\n",
" <td>10</td>\n",
" <td>1.278206</td>\n",
" </tr>\n",
" <tr>\n",
" <th>BaselineMean</th>\n",
" <td>1</td>\n",
" <td>6</td>\n",
" <td>5</td>\n",
" <td>1.322796</td>\n",
" </tr>\n",
" <tr>\n",
" <th>BaselineLast</th>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>1.475730</td>\n",
" </tr>\n",
" <tr>\n",
" <th>CrossAttention</th>\n",
" <td>17</td>\n",
" <td>4</td>\n",
" <td>5</td>\n",
" <td>1.547343</td>\n",
" </tr>\n",
" <tr>\n",
" <th rowspan=\"11\" valign=\"top\">BejingPM25</th>\n",
" <th>TCNSeq</th>\n",
" <td>46</td>\n",
" <td>3</td>\n",
" <td>3</td>\n",
" <td>1.235752</td>\n",
" </tr>\n",
" <tr>\n",
" <th>InceptionTimeSeq</th>\n",
" <td>15</td>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>1.239024</td>\n",
" </tr>\n",
" <tr>\n",
" <th>LSTM</th>\n",
" <td>18</td>\n",
" <td>2</td>\n",
" <td>3</td>\n",
" <td>1.271591</td>\n",
" </tr>\n",
" <tr>\n",
" <th>LSTMSeq2Seq</th>\n",
" <td>17</td>\n",
" <td>5</td>\n",
" <td>6</td>\n",
" <td>1.292539</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Transformer</th>\n",
" <td>53</td>\n",
" <td>6</td>\n",
" <td>1</td>\n",
" <td>1.295221</td>\n",
" </tr>\n",
" <tr>\n",
" <th>TransformerProcess</th>\n",
" <td>16</td>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>1.404077</td>\n",
" </tr>\n",
" <tr>\n",
" <th>CrossAttention</th>\n",
" <td>17</td>\n",
" <td>8</td>\n",
" <td>11</td>\n",
" <td>1.411578</td>\n",
" </tr>\n",
" <tr>\n",
" <th>RANP</th>\n",
" <td>16</td>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>1.429652</td>\n",
" </tr>\n",
" <tr>\n",
" <th>BaselineMean</th>\n",
" <td>1</td>\n",
" <td>6</td>\n",
" <td>3</td>\n",
" <td>1.435451</td>\n",
" </tr>\n",
" <tr>\n",
" <th>BaselineLast</th>\n",
" <td>1</td>\n",
" <td>4</td>\n",
" <td>12</td>\n",
" <td>1.548376</td>\n",
" </tr>\n",
" <tr>\n",
" <th>TransformerSeq2Seq</th>\n",
" <td>16</td>\n",
" <td>4</td>\n",
" <td>8</td>\n",
" <td>2.393506</td>\n",
" </tr>\n",
" <tr>\n",
" <th rowspan=\"11\" valign=\"top\">GasSensor</th>\n",
" <th>RANP</th>\n",
" <td>18</td>\n",
" <td>6</td>\n",
" <td>10</td>\n",
" <td>-2.133501</td>\n",
" </tr>\n",
" <tr>\n",
" <th>InceptionTimeSeq</th>\n",
" <td>33</td>\n",
" <td>5</td>\n",
" <td>10</td>\n",
" <td>-2.102463</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Transformer</th>\n",
" <td>50</td>\n",
" <td>6</td>\n",
" <td>12</td>\n",
" <td>-1.960168</td>\n",
" </tr>\n",
" <tr>\n",
" <th>TCNSeq</th>\n",
" <td>41</td>\n",
" <td>6</td>\n",
" <td>8</td>\n",
" <td>-1.736079</td>\n",
" </tr>\n",
" <tr>\n",
" <th>LSTM</th>\n",
" <td>18</td>\n",
" <td>6</td>\n",
" <td>5</td>\n",
" <td>-1.537527</td>\n",
" </tr>\n",
" <tr>\n",
" <th>LSTMSeq2Seq</th>\n",
" <td>20</td>\n",
" <td>8</td>\n",
" <td>5</td>\n",
" <td>-1.489893</td>\n",
" </tr>\n",
" <tr>\n",
" <th>TransformerProcess</th>\n",
" <td>38</td>\n",
" <td>7</td>\n",
" <td>4</td>\n",
" <td>-0.880734</td>\n",
" </tr>\n",
" <tr>\n",
" <th>CrossAttention</th>\n",
" <td>17</td>\n",
" <td>7</td>\n",
" <td>3</td>\n",
" <td>-0.638713</td>\n",
" </tr>\n",
" <tr>\n",
" <th>TransformerSeq2Seq</th>\n",
" <td>28</td>\n",
" <td>6</td>\n",
" <td>1</td>\n",
" <td>0.336668</td>\n",
" </tr>\n",
" <tr>\n",
" <th>BaselineMean</th>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>2</td>\n",
" <td>1.584720</td>\n",
" </tr>\n",
" <tr>\n",
" <th>BaselineLast</th>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" <td>6</td>\n",
" <td>1.974204</td>\n",
" </tr>\n",
" <tr>\n",
" <th rowspan=\"11\" valign=\"top\">IMOSCurrentsVel</th>\n",
" <th>TCNSeq</th>\n",
" <td>52</td>\n",
" <td>4</td>\n",
" <td>6</td>\n",
" <td>0.817371</td>\n",
" </tr>\n",
" <tr>\n",
" <th>InceptionTimeSeq</th>\n",
" <td>20</td>\n",
" <td>3</td>\n",
" <td>6</td>\n",
" <td>0.845443</td>\n",
" </tr>\n",
" <tr>\n",
" <th>LSTM</th>\n",
" <td>22</td>\n",
" <td>5</td>\n",
" <td>6</td>\n",
" <td>0.875669</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Transformer</th>\n",
" <td>24</td>\n",
" <td>2</td>\n",
" <td>3</td>\n",
" <td>0.879682</td>\n",
" </tr>\n",
" <tr>\n",
" <th>BaselineLast</th>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>6</td>\n",
" <td>0.885377</td>\n",
" </tr>\n",
" <tr>\n",
" <th>LSTMSeq2Seq</th>\n",
" <td>24</td>\n",
" <td>7</td>\n",
" <td>12</td>\n",
" <td>0.887783</td>\n",
" </tr>\n",
" <tr>\n",
" <th>RANP</th>\n",
" <td>19</td>\n",
" <td>5</td>\n",
" <td>3</td>\n",
" <td>1.035464</td>\n",
" </tr>\n",
" <tr>\n",
" <th>BaselineMean</th>\n",
" <td>1</td>\n",
" <td>5</td>\n",
" <td>7</td>\n",
" <td>1.202492</td>\n",
" </tr>\n",
" <tr>\n",
" <th>TransformerSeq2Seq</th>\n",
" <td>19</td>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>1.266849</td>\n",
" </tr>\n",
" <tr>\n",
" <th>TransformerProcess</th>\n",
" <td>27</td>\n",
" <td>5</td>\n",
" <td>1</td>\n",
" <td>1.394595</td>\n",
" </tr>\n",
" <tr>\n",
" <th>CrossAttention</th>\n",
" <td>17</td>\n",
" <td>3</td>\n",
" <td>3</td>\n",
" <td>1.656854</td>\n",
" </tr>\n",
" <tr>\n",
" <th rowspan=\"11\" valign=\"top\">MetroInterstateTraffic</th>\n",
" <th>TCNSeq</th>\n",
" <td>61</td>\n",
" <td>7</td>\n",
" <td>3</td>\n",
" <td>-0.324690</td>\n",
" </tr>\n",
" <tr>\n",
" <th>TransformerProcess</th>\n",
" <td>17</td>\n",
" <td>3</td>\n",
" <td>7</td>\n",
" <td>-0.297472</td>\n",
" </tr>\n",
" <tr>\n",
" <th>RANP</th>\n",
" <td>16</td>\n",
" <td>3</td>\n",
" <td>3</td>\n",
" <td>-0.291096</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Transformer</th>\n",
" <td>20</td>\n",
" <td>4</td>\n",
" <td>11</td>\n",
" <td>-0.253573</td>\n",
" </tr>\n",
" <tr>\n",
" <th>LSTMSeq2Seq</th>\n",
" <td>15</td>\n",
" <td>5</td>\n",
" <td>5</td>\n",
" <td>-0.200625</td>\n",
" </tr>\n",
" <tr>\n",
" <th>LSTM</th>\n",
" <td>17</td>\n",
" <td>4</td>\n",
" <td>5</td>\n",
" <td>-0.200455</td>\n",
" </tr>\n",
" <tr>\n",
" <th>TransformerSeq2Seq</th>\n",
" <td>15</td>\n",
" <td>2</td>\n",
" <td>6</td>\n",
" <td>-0.192895</td>\n",
" </tr>\n",
" <tr>\n",
" <th>InceptionTimeSeq</th>\n",
" <td>17</td>\n",
" <td>3</td>\n",
" <td>6</td>\n",
" <td>-0.162513</td>\n",
" </tr>\n",
" <tr>\n",
" <th>CrossAttention</th>\n",
" <td>14</td>\n",
" <td>6</td>\n",
" <td>1</td>\n",
" <td>-0.103853</td>\n",
" </tr>\n",
" <tr>\n",
" <th>BaselineMean</th>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>1.410653</td>\n",
" </tr>\n",
" <tr>\n",
" <th>BaselineLast</th>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" <td>12</td>\n",
" <td>1.741882</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" n_trials \\\n",
"dataset model \n",
"AppliancesEnergyPrediction TCNSeq 48 \n",
" InceptionTimeSeq 17 \n",
" TransformerProcess 18 \n",
" LSTM 19 \n",
" TransformerSeq2Seq 17 \n",
" Transformer 19 \n",
" LSTMSeq2Seq 29 \n",
" RANP 16 \n",
" BaselineMean 1 \n",
" BaselineLast 1 \n",
" CrossAttention 17 \n",
"BejingPM25 TCNSeq 46 \n",
" InceptionTimeSeq 15 \n",
" LSTM 18 \n",
" LSTMSeq2Seq 17 \n",
" Transformer 53 \n",
" TransformerProcess 16 \n",
" CrossAttention 17 \n",
" RANP 16 \n",
" BaselineMean 1 \n",
" BaselineLast 1 \n",
" TransformerSeq2Seq 16 \n",
"GasSensor RANP 18 \n",
" InceptionTimeSeq 33 \n",
" Transformer 50 \n",
" TCNSeq 41 \n",
" LSTM 18 \n",
" LSTMSeq2Seq 20 \n",
" TransformerProcess 38 \n",
" CrossAttention 17 \n",
" TransformerSeq2Seq 28 \n",
" BaselineMean 1 \n",
" BaselineLast 1 \n",
"IMOSCurrentsVel TCNSeq 52 \n",
" InceptionTimeSeq 20 \n",
" LSTM 22 \n",
" Transformer 24 \n",
" BaselineLast 1 \n",
" LSTMSeq2Seq 24 \n",
" RANP 19 \n",
" BaselineMean 1 \n",
" TransformerSeq2Seq 19 \n",
" TransformerProcess 27 \n",
" CrossAttention 17 \n",
"MetroInterstateTraffic TCNSeq 61 \n",
" TransformerProcess 17 \n",
" RANP 16 \n",
" Transformer 20 \n",
" LSTMSeq2Seq 15 \n",
" LSTM 17 \n",
" TransformerSeq2Seq 15 \n",
" InceptionTimeSeq 17 \n",
" CrossAttention 14 \n",
" BaselineMean 1 \n",
" BaselineLast 1 \n",
"\n",
" param_hidden_size_exp \\\n",
"dataset model \n",
"AppliancesEnergyPrediction TCNSeq 2 \n",
" InceptionTimeSeq 1 \n",
" TransformerProcess 4 \n",
" LSTM 6 \n",
" TransformerSeq2Seq 2 \n",
" Transformer 4 \n",
" LSTMSeq2Seq 4 \n",
" RANP 4 \n",
" BaselineMean 6 \n",
" BaselineLast 2 \n",
" CrossAttention 4 \n",
"BejingPM25 TCNSeq 3 \n",
" InceptionTimeSeq 1 \n",
" LSTM 2 \n",
" LSTMSeq2Seq 5 \n",
" Transformer 6 \n",
" TransformerProcess 2 \n",
" CrossAttention 8 \n",
" RANP 4 \n",
" BaselineMean 6 \n",
" BaselineLast 4 \n",
" TransformerSeq2Seq 4 \n",
"GasSensor RANP 6 \n",
" InceptionTimeSeq 5 \n",
" Transformer 6 \n",
" TCNSeq 6 \n",
" LSTM 6 \n",
" LSTMSeq2Seq 8 \n",
" TransformerProcess 7 \n",
" CrossAttention 7 \n",
" TransformerSeq2Seq 6 \n",
" BaselineMean 3 \n",
" BaselineLast 2 \n",
"IMOSCurrentsVel TCNSeq 4 \n",
" InceptionTimeSeq 3 \n",
" LSTM 5 \n",
" Transformer 2 \n",
" BaselineLast 3 \n",
" LSTMSeq2Seq 7 \n",
" RANP 5 \n",
" BaselineMean 5 \n",
" TransformerSeq2Seq 3 \n",
" TransformerProcess 5 \n",
" CrossAttention 3 \n",
"MetroInterstateTraffic TCNSeq 7 \n",
" TransformerProcess 3 \n",
" RANP 3 \n",
" Transformer 4 \n",
" LSTMSeq2Seq 5 \n",
" LSTM 4 \n",
" TransformerSeq2Seq 2 \n",
" InceptionTimeSeq 3 \n",
" CrossAttention 6 \n",
" BaselineMean 3 \n",
" BaselineLast 2 \n",
"\n",
" param_layers value \n",
"dataset model \n",
"AppliancesEnergyPrediction TCNSeq 1 1.075230 \n",
" InceptionTimeSeq 3 1.099332 \n",
" TransformerProcess 3 1.163042 \n",
" LSTM 3 1.172367 \n",
" TransformerSeq2Seq 1 1.174764 \n",
" Transformer 6 1.195765 \n",
" LSTMSeq2Seq 6 1.204848 \n",
" RANP 10 1.278206 \n",
" BaselineMean 5 1.322796 \n",
" BaselineLast 1 1.475730 \n",
" CrossAttention 5 1.547343 \n",
"BejingPM25 TCNSeq 3 1.235752 \n",
" InceptionTimeSeq 3 1.239024 \n",
" LSTM 3 1.271591 \n",
" LSTMSeq2Seq 6 1.292539 \n",
" Transformer 1 1.295221 \n",
" TransformerProcess 1 1.404077 \n",
" CrossAttention 11 1.411578 \n",
" RANP 1 1.429652 \n",
" BaselineMean 3 1.435451 \n",
" BaselineLast 12 1.548376 \n",
" TransformerSeq2Seq 8 2.393506 \n",
"GasSensor RANP 10 -2.133501 \n",
" InceptionTimeSeq 10 -2.102463 \n",
" Transformer 12 -1.960168 \n",
" TCNSeq 8 -1.736079 \n",
" LSTM 5 -1.537527 \n",
" LSTMSeq2Seq 5 -1.489893 \n",
" TransformerProcess 4 -0.880734 \n",
" CrossAttention 3 -0.638713 \n",
" TransformerSeq2Seq 1 0.336668 \n",
" BaselineMean 2 1.584720 \n",
" BaselineLast 6 1.974204 \n",
"IMOSCurrentsVel TCNSeq 6 0.817371 \n",
" InceptionTimeSeq 6 0.845443 \n",
" LSTM 6 0.875669 \n",
" Transformer 3 0.879682 \n",
" BaselineLast 6 0.885377 \n",
" LSTMSeq2Seq 12 0.887783 \n",
" RANP 3 1.035464 \n",
" BaselineMean 7 1.202492 \n",
" TransformerSeq2Seq 1 1.266849 \n",
" TransformerProcess 1 1.394595 \n",
" CrossAttention 3 1.656854 \n",
"MetroInterstateTraffic TCNSeq 3 -0.324690 \n",
" TransformerProcess 7 -0.297472 \n",
" RANP 3 -0.291096 \n",
" Transformer 11 -0.253573 \n",
" LSTMSeq2Seq 5 -0.200625 \n",
" LSTM 5 -0.200455 \n",
" TransformerSeq2Seq 6 -0.192895 \n",
" InceptionTimeSeq 6 -0.162513 \n",
" CrossAttention 1 -0.103853 \n",
" BaselineMean 1 1.410653 \n",
" BaselineLast 12 1.741882 "
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"\n",
"# Summarize studies: one row per study, taken from its best completed trial.\n",
"# Each row holds the study's user attrs (e.g. dataset, model), the number of trials,\n",
"# the best hyperparameters prefixed with `param_`, and the best trial's public\n",
"# (non-underscore) attributes such as its objective `value`.\n",
"rs = []\n",
"study_summaries = optuna.study.get_all_study_summaries(storage=storage)\n",
"for s in study_summaries:\n",
"    best = s.best_trial\n",
"    if (best is not None) and (best.state == optuna.trial.TrialState.COMPLETE):\n",
"        row = dict(s.user_attrs)\n",
"        row['n_trials'] = s.n_trials\n",
"        # Use the public `params` property rather than the private `_params` field.\n",
"        row.update({'param_' + k: v for k, v in best.params.items()})\n",
"        row.update({k: v for k, v in best.__dict__.items() if not k.startswith('_')})\n",
"        rs.append(row)\n",
"assert rs, 'no study in storage has a completed best trial'\n",
"df_studies = pd.DataFrame(rs)\n",
"\n",
"# errors='ignore': 'commit' presumably comes from a study user attr and may be absent.\n",
"df_studies = (df_studies.drop(columns=['state', 'intermediate_values', 'commit', 'datetime_complete'], errors='ignore')\n",
"             .sort_values(['dataset', 'value'])\n",
"             .set_index(['dataset', 'model'])\n",
"             )\n",
"df_studies"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2020-11-13T07:54:27.730531Z",
"start_time": "2020-11-13T07:54:27.667135Z"
},
"lines_to_next_cell": 0,
"scrolled": true
},
"outputs": [],
"source": [
"# Per-study diagnostics: the full trials table plus per-trial validation-loss\n",
"# curves, to see how much each hyperparameter setting overfits. This renders one\n",
"# table and one figure per study, so it only runs when SHOW_TRIAL_CURVES is True.\n",
"SHOW_TRIAL_CURVES = False\n",
"\n",
"\n",
"def show_study_trials(study_name, storage):\n",
"    \"\"\"Display the trials dataframe and plot per-trial validation loss for one study.\n",
"\n",
"    Trials are labelled `l=<layers>_hs=<hidden_size_exp>`; a missing param shows as None.\n",
"    \"\"\"\n",
"    study = optuna.load_study(study_name=study_name, storage=storage)\n",
"    trials = study.get_trials()\n",
"    print(study_name)\n",
"    display(study.trials_dataframe())\n",
"    # One column per trial; rows are the steps at which intermediate values were reported.\n",
"    df_values = pd.DataFrame([t.intermediate_values for t in trials]).T\n",
"    df_values.columns = [f\"l={t.params.get('layers')}_hs={t.params.get('hidden_size_exp')}\" for t in trials]\n",
"    df_values.plot(ylabel='nll', xlabel='epochs', title=f'val loss \"{study_name}\"')\n",
"\n",
"\n",
"if SHOW_TRIAL_CURVES:\n",
"    for summary in optuna.study.get_all_study_summaries(storage=storage):\n",
"        show_study_trials(summary.study_name, storage)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2020-11-13T01:44:25.350038Z",
"start_time": "2020-11-12T23:16:12.130Z"
},
"lines_to_next_cell": 0
},
"outputs": [],
"source": [
"# Contour plot of the objective (nll) over hidden_size_exp x layers for every study.\n",
"# Renders one figure per study, so it only runs when SHOW_CONTOURS is True.\n",
"SHOW_CONTOURS = False\n",
"\n",
"\n",
"def plot_study_contour(study_name, storage):\n",
"    \"\"\"Return an optuna contour figure of hidden_size_exp vs layers for one study.\"\"\"\n",
"    study = optuna.load_study(study_name=study_name, storage=storage)\n",
"    fig = optuna.visualization.plot_contour(study, params=['hidden_size_exp', 'layers'])\n",
"    return fig.update_layout(dict(title=f'{study_name} nll'))\n",
"\n",
"\n",
"if SHOW_CONTOURS:\n",
"    for summary in optuna.study.get_all_study_summaries(storage=storage):\n",
"        display(plot_study_contour(summary.study_name, storage))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2020-11-08T23:11:33.818862Z",
"start_time": "2020-11-08T23:11:33.749610Z"
}
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2020-11-08T23:11:14.588720Z",
"start_time": "2020-11-08T23:11:14.516020Z"
},
"lines_to_next_cell": 2
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2020-11-13T01:44:25.351121Z",
"start_time": "2020-11-12T23:16:12.141Z"
}
},
"outputs": [],
"source": [
"# Labels (layers_hidden-size-exponent) of every trial in one study. The study is\n",
"# loaded explicitly so this cell works on a fresh kernel; change the index to\n",
"# inspect a different study.\n",
"loaded_study = optuna.load_study(study_name=study_summaries[-1].study_name, storage=storage)\n",
"[f\"{t.params.get('layers')}_{t.params.get('hidden_size_exp')}\" for t in loaded_study.get_trials()]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2020-11-13T01:44:25.352257Z",
"start_time": "2020-11-12T23:16:12.146Z"
},
"scrolled": true
},
"outputs": [],
"source": [
"# Intermediate (validation-loss) values of every trial in one study: one column per\n",
"# trial, one row per reported step. Built explicitly so this cell works on a fresh\n",
"# kernel; change the index to inspect a different study.\n",
"loaded_study = optuna.load_study(study_name=study_summaries[-1].study_name, storage=storage)\n",
"trials = loaded_study.get_trials()\n",
"df_values = pd.DataFrame([t.intermediate_values for t in trials]).T\n",
"df_values.columns = [f\"l={t.params.get('layers')}_hs={t.params.get('hidden_size_exp')}\" for t in trials]\n",
"df_values"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"lines_to_next_cell": 2
},
"outputs": [],
"source": []
}
],
"metadata": {
"@webio": {
"lastCommId": null,
"lastKernelId": null
},
"jupytext": {
"encoding": "# -*- coding: utf-8 -*-",
"formats": "ipynb,py:light"
},
"kernelspec": {
"display_name": "seq2seq-time",
"language": "python",
"name": "seq2seq-time"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.8"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {
"height": "calc(100% - 180px)",
"left": "10px",
"top": "150px",
"width": "209.162px"
},
"toc_section_display": true,
"toc_window_display": true
}
},
"nbformat": 4,
"nbformat_minor": 2
}