Files
pytorch-ts/examples/Multivariate-Flow-Solar.ipynb
wassname 3fb53f620e wip
2022-12-23 14:34:54 +08:00

951 lines
55 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"import torch"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from gluonts.dataset.multivariate_grouper import MultivariateGrouper\n",
"from gluonts.dataset.repository.datasets import dataset_recipes, get_dataset\n",
"from pts.model.tempflow import TempFlowEstimator\n",
"from pts.model.transformer_tempflow import TransformerTempFlowEstimator\n",
"from pts import Trainer\n",
"from gluonts.evaluation.backtest import make_evaluation_predictions\n",
"from gluonts.evaluation import MultivariateEvaluator"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Prepare dataset"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"dataset = get_dataset(\"solar_nips\", regenerate=False)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"MetaData(freq='H', target=None, feat_static_cat=[CategoricalFeatureInfo(name='feat_static_cat_0', cardinality='137')], feat_static_real=[], feat_dynamic_real=[], feat_dynamic_cat=[], prediction_length=24)"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dataset.metadata"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"train_grouper = MultivariateGrouper(max_target_dim=int(dataset.metadata.feat_static_cat[0].cardinality))\n",
"\n",
"test_grouper = MultivariateGrouper(num_test_dates=int(len(dataset.test)/len(dataset.train)), \n",
" max_target_dim=int(dataset.metadata.feat_static_cat[0].cardinality))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/wassname/miniforge3/envs/glounts/lib/python3.9/site-packages/gluonts/dataset/multivariate_grouper.py:191: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray.\n",
" return {FieldName.TARGET: np.array([funcs(data) for data in dataset])}\n"
]
}
],
"source": [
"dataset_train = train_grouper(dataset.train)\n",
"dataset_test = test_grouper(dataset.test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Evaluator"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"evaluator = MultivariateEvaluator(quantiles=(np.arange(20)/20.0)[1:],\n",
" target_agg_funcs={'sum': np.sum})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## `GRU-Real-NVP`"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"estimator = TempFlowEstimator(\n",
" target_dim=int(dataset.metadata.feat_static_cat[0].cardinality),\n",
" prediction_length=dataset.metadata.prediction_length,\n",
" cell_type='GRU',\n",
" input_size=552,\n",
" freq=dataset.metadata.freq,\n",
" scaling=True,\n",
" dequantize=True,\n",
" n_blocks=4,\n",
" trainer=Trainer(device=device,\n",
" epochs=45,\n",
" learning_rate=1e-3,\n",
" num_batches_per_epoch=100,\n",
" batch_size=64)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e6b529fa683f4e51b5f8ac63db471849",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/99 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"ename": "Exception",
"evalue": "Reached maximum number of idle transformation calls.\nThis means the transformation looped over 1 inputs without returning any output.\nThis occurred in the following transformation:\ngluonts.transform.split.InstanceSplitter(dummy_value=0.0, forecast_start_field=\"forecast_start\", future_length=24, instance_sampler=gluonts.transform.sampler.ExpectedNumInstanceSampler(axis=-1, min_past=192, min_future=24, num_instances=1.0, total_length=20382, n=3), is_pad_field=\"is_pad\", lead_time=0, output_NTC=True, past_length=192, start_field=\"start\", target_field=\"target\", time_series_fields=[\"time_feat\", \"observed_values\"])",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mException\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m/media/wassname/SGIronWolf/projects5/timeseries/pytorch-ts/examples/Multivariate-Flow-Solar.ipynb Cell 13\u001b[0m in \u001b[0;36m<cell line: 1>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> <a href='vscode-notebook-cell:/media/wassname/SGIronWolf/projects5/timeseries/pytorch-ts/examples/Multivariate-Flow-Solar.ipynb#X15sZmlsZQ%3D%3D?line=0'>1</a>\u001b[0m predictor \u001b[39m=\u001b[39m estimator\u001b[39m.\u001b[39;49mtrain(dataset_train)\n\u001b[1;32m <a href='vscode-notebook-cell:/media/wassname/SGIronWolf/projects5/timeseries/pytorch-ts/examples/Multivariate-Flow-Solar.ipynb#X15sZmlsZQ%3D%3D?line=1'>2</a>\u001b[0m forecast_it, ts_it \u001b[39m=\u001b[39m make_evaluation_predictions(dataset\u001b[39m=\u001b[39mdataset_test,\n\u001b[1;32m <a href='vscode-notebook-cell:/media/wassname/SGIronWolf/projects5/timeseries/pytorch-ts/examples/Multivariate-Flow-Solar.ipynb#X15sZmlsZQ%3D%3D?line=2'>3</a>\u001b[0m predictor\u001b[39m=\u001b[39mpredictor,\n\u001b[1;32m <a href='vscode-notebook-cell:/media/wassname/SGIronWolf/projects5/timeseries/pytorch-ts/examples/Multivariate-Flow-Solar.ipynb#X15sZmlsZQ%3D%3D?line=3'>4</a>\u001b[0m num_samples\u001b[39m=\u001b[39m\u001b[39m100\u001b[39m)\n\u001b[1;32m <a href='vscode-notebook-cell:/media/wassname/SGIronWolf/projects5/timeseries/pytorch-ts/examples/Multivariate-Flow-Solar.ipynb#X15sZmlsZQ%3D%3D?line=4'>5</a>\u001b[0m forecasts \u001b[39m=\u001b[39m \u001b[39mlist\u001b[39m(forecast_it)\n",
"File \u001b[0;32m/media/wassname/SGIronWolf/projects5/timeseries/pytorch-ts/pts/model/estimator.py:179\u001b[0m, in \u001b[0;36mPyTorchEstimator.train\u001b[0;34m(self, training_data, validation_data, num_workers, prefetch_factor, shuffle_buffer_length, cache_data, **kwargs)\u001b[0m\n\u001b[1;32m 169\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mtrain\u001b[39m(\n\u001b[1;32m 170\u001b[0m \u001b[39mself\u001b[39m,\n\u001b[1;32m 171\u001b[0m training_data: Dataset,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 177\u001b[0m \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs,\n\u001b[1;32m 178\u001b[0m ) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m PyTorchPredictor:\n\u001b[0;32m--> 179\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mtrain_model(\n\u001b[1;32m 180\u001b[0m training_data,\n\u001b[1;32m 181\u001b[0m validation_data,\n\u001b[1;32m 182\u001b[0m num_workers\u001b[39m=\u001b[39;49mnum_workers,\n\u001b[1;32m 183\u001b[0m prefetch_factor\u001b[39m=\u001b[39;49mprefetch_factor,\n\u001b[1;32m 184\u001b[0m shuffle_buffer_length\u001b[39m=\u001b[39;49mshuffle_buffer_length,\n\u001b[1;32m 185\u001b[0m cache_data\u001b[39m=\u001b[39;49mcache_data,\n\u001b[1;32m 186\u001b[0m \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs,\n\u001b[1;32m 187\u001b[0m )\u001b[39m.\u001b[39mpredictor\n",
"File \u001b[0;32m/media/wassname/SGIronWolf/projects5/timeseries/pytorch-ts/pts/model/estimator.py:151\u001b[0m, in \u001b[0;36mPyTorchEstimator.train_model\u001b[0;34m(self, training_data, validation_data, num_workers, prefetch_factor, shuffle_buffer_length, cache_data, **kwargs)\u001b[0m\n\u001b[1;32m 133\u001b[0m validation_iter_dataset \u001b[39m=\u001b[39m TransformedIterableDataset(\n\u001b[1;32m 134\u001b[0m dataset\u001b[39m=\u001b[39mvalidation_data,\n\u001b[1;32m 135\u001b[0m transform\u001b[39m=\u001b[39mtransformation\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 139\u001b[0m cache_data\u001b[39m=\u001b[39mcache_data,\n\u001b[1;32m 140\u001b[0m )\n\u001b[1;32m 141\u001b[0m validation_data_loader \u001b[39m=\u001b[39m DataLoader(\n\u001b[1;32m 142\u001b[0m validation_iter_dataset,\n\u001b[1;32m 143\u001b[0m batch_size\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtrainer\u001b[39m.\u001b[39mbatch_size,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 148\u001b[0m \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs,\n\u001b[1;32m 149\u001b[0m )\n\u001b[0;32m--> 151\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mtrainer(\n\u001b[1;32m 152\u001b[0m net\u001b[39m=\u001b[39;49mtrained_net,\n\u001b[1;32m 153\u001b[0m train_iter\u001b[39m=\u001b[39;49mtraining_data_loader,\n\u001b[1;32m 154\u001b[0m validation_iter\u001b[39m=\u001b[39;49mvalidation_data_loader,\n\u001b[1;32m 155\u001b[0m )\n\u001b[1;32m 157\u001b[0m \u001b[39mreturn\u001b[39;00m TrainOutput(\n\u001b[1;32m 158\u001b[0m transformation\u001b[39m=\u001b[39mtransformation,\n\u001b[1;32m 159\u001b[0m trained_net\u001b[39m=\u001b[39mtrained_net,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 162\u001b[0m ),\n\u001b[1;32m 163\u001b[0m )\n",
"File \u001b[0;32m/media/wassname/SGIronWolf/projects5/timeseries/pytorch-ts/pts/trainer.py:63\u001b[0m, in \u001b[0;36mTrainer.__call__\u001b[0;34m(self, net, train_iter, validation_iter)\u001b[0m\n\u001b[1;32m 61\u001b[0m \u001b[39m# training loop\u001b[39;00m\n\u001b[1;32m 62\u001b[0m \u001b[39mwith\u001b[39;00m tqdm(train_iter, total\u001b[39m=\u001b[39mtotal) \u001b[39mas\u001b[39;00m it:\n\u001b[0;32m---> 63\u001b[0m \u001b[39mfor\u001b[39;00m batch_no, data_entry \u001b[39min\u001b[39;00m \u001b[39menumerate\u001b[39m(it, start\u001b[39m=\u001b[39m\u001b[39m1\u001b[39m):\n\u001b[1;32m 64\u001b[0m optimizer\u001b[39m.\u001b[39mzero_grad()\n\u001b[1;32m 66\u001b[0m inputs \u001b[39m=\u001b[39m [v\u001b[39m.\u001b[39mto(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdevice) \u001b[39mfor\u001b[39;00m v \u001b[39min\u001b[39;00m data_entry\u001b[39m.\u001b[39mvalues()]\n",
"File \u001b[0;32m~/miniforge3/envs/glounts/lib/python3.9/site-packages/tqdm/notebook.py:259\u001b[0m, in \u001b[0;36mtqdm_notebook.__iter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 257\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m 258\u001b[0m it \u001b[39m=\u001b[39m \u001b[39msuper\u001b[39m(tqdm_notebook, \u001b[39mself\u001b[39m)\u001b[39m.\u001b[39m\u001b[39m__iter__\u001b[39m()\n\u001b[0;32m--> 259\u001b[0m \u001b[39mfor\u001b[39;00m obj \u001b[39min\u001b[39;00m it:\n\u001b[1;32m 260\u001b[0m \u001b[39m# return super(tqdm...) will not catch exception\u001b[39;00m\n\u001b[1;32m 261\u001b[0m \u001b[39myield\u001b[39;00m obj\n\u001b[1;32m 262\u001b[0m \u001b[39m# NB: except ... [ as ...] breaks IPython async KeyboardInterrupt\u001b[39;00m\n",
"File \u001b[0;32m~/miniforge3/envs/glounts/lib/python3.9/site-packages/tqdm/std.py:1195\u001b[0m, in \u001b[0;36mtqdm.__iter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1192\u001b[0m time \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_time\n\u001b[1;32m 1194\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[0;32m-> 1195\u001b[0m \u001b[39mfor\u001b[39;00m obj \u001b[39min\u001b[39;00m iterable:\n\u001b[1;32m 1196\u001b[0m \u001b[39myield\u001b[39;00m obj\n\u001b[1;32m 1197\u001b[0m \u001b[39m# Update and possibly print the progressbar.\u001b[39;00m\n\u001b[1;32m 1198\u001b[0m \u001b[39m# Note: does not call self.update(1) for speed optimisation.\u001b[39;00m\n",
"File \u001b[0;32m~/miniforge3/envs/glounts/lib/python3.9/site-packages/torch/utils/data/dataloader.py:628\u001b[0m, in \u001b[0;36m_BaseDataLoaderIter.__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 625\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_sampler_iter \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m 626\u001b[0m \u001b[39m# TODO(https://github.com/pytorch/pytorch/issues/76750)\u001b[39;00m\n\u001b[1;32m 627\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_reset() \u001b[39m# type: ignore[call-arg]\u001b[39;00m\n\u001b[0;32m--> 628\u001b[0m data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_next_data()\n\u001b[1;32m 629\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_num_yielded \u001b[39m+\u001b[39m\u001b[39m=\u001b[39m \u001b[39m1\u001b[39m\n\u001b[1;32m 630\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_dataset_kind \u001b[39m==\u001b[39m _DatasetKind\u001b[39m.\u001b[39mIterable \u001b[39mand\u001b[39;00m \\\n\u001b[1;32m 631\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_IterableDataset_len_called \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39mand\u001b[39;00m \\\n\u001b[1;32m 632\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_num_yielded \u001b[39m>\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_IterableDataset_len_called:\n",
"File \u001b[0;32m~/miniforge3/envs/glounts/lib/python3.9/site-packages/torch/utils/data/dataloader.py:671\u001b[0m, in \u001b[0;36m_SingleProcessDataLoaderIter._next_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 669\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m_next_data\u001b[39m(\u001b[39mself\u001b[39m):\n\u001b[1;32m 670\u001b[0m index \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_next_index() \u001b[39m# may raise StopIteration\u001b[39;00m\n\u001b[0;32m--> 671\u001b[0m data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_dataset_fetcher\u001b[39m.\u001b[39;49mfetch(index) \u001b[39m# may raise StopIteration\u001b[39;00m\n\u001b[1;32m 672\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_pin_memory:\n\u001b[1;32m 673\u001b[0m data \u001b[39m=\u001b[39m _utils\u001b[39m.\u001b[39mpin_memory\u001b[39m.\u001b[39mpin_memory(data, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_pin_memory_device)\n",
"File \u001b[0;32m~/miniforge3/envs/glounts/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py:34\u001b[0m, in \u001b[0;36m_IterableDatasetFetcher.fetch\u001b[0;34m(self, possibly_batched_index)\u001b[0m\n\u001b[1;32m 32\u001b[0m \u001b[39mfor\u001b[39;00m _ \u001b[39min\u001b[39;00m possibly_batched_index:\n\u001b[1;32m 33\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[0;32m---> 34\u001b[0m data\u001b[39m.\u001b[39mappend(\u001b[39mnext\u001b[39;49m(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mdataset_iter))\n\u001b[1;32m 35\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mStopIteration\u001b[39;00m:\n\u001b[1;32m 36\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mended \u001b[39m=\u001b[39m \u001b[39mTrue\u001b[39;00m\n",
"File \u001b[0;32m~/miniforge3/envs/glounts/lib/python3.9/site-packages/gluonts/transform/_base.py:103\u001b[0m, in \u001b[0;36mTransformedDataset.__iter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 102\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__iter__\u001b[39m(\u001b[39mself\u001b[39m) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Iterator[DataEntry]:\n\u001b[0;32m--> 103\u001b[0m \u001b[39myield from\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtransformation(\n\u001b[1;32m 104\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mbase_dataset, is_train\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mis_train\n\u001b[1;32m 105\u001b[0m )\n",
"File \u001b[0;32m~/miniforge3/envs/glounts/lib/python3.9/site-packages/gluonts/transform/_base.py:124\u001b[0m, in \u001b[0;36mMapTransformation.__call__\u001b[0;34m(self, data_it, is_train)\u001b[0m\n\u001b[1;32m 121\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__call__\u001b[39m(\n\u001b[1;32m 122\u001b[0m \u001b[39mself\u001b[39m, data_it: Iterable[DataEntry], is_train: \u001b[39mbool\u001b[39m\n\u001b[1;32m 123\u001b[0m ) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Iterator:\n\u001b[0;32m--> 124\u001b[0m \u001b[39mfor\u001b[39;00m data_entry \u001b[39min\u001b[39;00m data_it:\n\u001b[1;32m 125\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m 126\u001b[0m \u001b[39myield\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmap_transform(data_entry\u001b[39m.\u001b[39mcopy(), is_train)\n",
"File \u001b[0;32m~/miniforge3/envs/glounts/lib/python3.9/site-packages/gluonts/transform/_base.py:124\u001b[0m, in \u001b[0;36mMapTransformation.__call__\u001b[0;34m(self, data_it, is_train)\u001b[0m\n\u001b[1;32m 121\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__call__\u001b[39m(\n\u001b[1;32m 122\u001b[0m \u001b[39mself\u001b[39m, data_it: Iterable[DataEntry], is_train: \u001b[39mbool\u001b[39m\n\u001b[1;32m 123\u001b[0m ) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Iterator:\n\u001b[0;32m--> 124\u001b[0m \u001b[39mfor\u001b[39;00m data_entry \u001b[39min\u001b[39;00m data_it:\n\u001b[1;32m 125\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m 126\u001b[0m \u001b[39myield\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmap_transform(data_entry\u001b[39m.\u001b[39mcopy(), is_train)\n",
"File \u001b[0;32m~/miniforge3/envs/glounts/lib/python3.9/site-packages/gluonts/transform/_base.py:189\u001b[0m, in \u001b[0;36mFlatMapTransformation.__call__\u001b[0;34m(self, data_it, is_train)\u001b[0m\n\u001b[1;32m 182\u001b[0m \u001b[39myield\u001b[39;00m result\n\u001b[1;32m 184\u001b[0m \u001b[39mif\u001b[39;00m (\n\u001b[1;32m 185\u001b[0m \u001b[39m# negative values disable the check\u001b[39;00m\n\u001b[1;32m 186\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmax_idle_transforms \u001b[39m>\u001b[39m \u001b[39m0\u001b[39m\n\u001b[1;32m 187\u001b[0m \u001b[39mand\u001b[39;00m num_idle_transforms \u001b[39m>\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmax_idle_transforms\n\u001b[1;32m 188\u001b[0m ):\n\u001b[0;32m--> 189\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mException\u001b[39;00m(\n\u001b[1;32m 190\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mReached maximum number of idle transformation\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 191\u001b[0m \u001b[39m\"\u001b[39m\u001b[39m calls.\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39mThis means the transformation looped over\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 192\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m \u001b[39m\u001b[39m{\u001b[39;00m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmax_idle_transforms\u001b[39m}\u001b[39;00m\u001b[39m inputs without returning any\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 193\u001b[0m \u001b[39m\"\u001b[39m\u001b[39m output.\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39mThis occurred in the following\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 194\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m transformation:\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39m{\u001b[39;00m\u001b[39mself\u001b[39m\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m\n\u001b[1;32m 195\u001b[0m )\n",
"\u001b[0;31mException\u001b[0m: Reached maximum number of idle transformation calls.\nThis means the transformation looped over 1 inputs without returning any output.\nThis occurred in the following transformation:\ngluonts.transform.split.InstanceSplitter(dummy_value=0.0, forecast_start_field=\"forecast_start\", future_length=24, instance_sampler=gluonts.transform.sampler.ExpectedNumInstanceSampler(axis=-1, min_past=192, min_future=24, num_instances=1.0, total_length=20382, n=3), is_pad_field=\"is_pad\", lead_time=0, output_NTC=True, past_length=192, start_field=\"start\", target_field=\"target\", time_series_fields=[\"time_feat\", \"observed_values\"])"
]
}
],
"source": [
"predictor = estimator.train(dataset_train)\n",
"forecast_it, ts_it = make_evaluation_predictions(dataset=dataset_test,\n",
" predictor=predictor,\n",
" num_samples=100)\n",
"forecasts = list(forecast_it)\n",
"targets = list(ts_it)\n",
"\n",
"agg_metric, _ = evaluator(targets, forecasts, num_series=len(dataset_test))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Metrics"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CRPS: 0.36531966950112466\n",
"ND: 0.45434020382814283\n",
"NRMSE: 0.9820216603495642\n",
"MSE: 914.7868680304274\n"
]
}
],
"source": [
"print(\"CRPS: {}\".format(agg_metric['mean_wQuantileLoss']))\n",
"print(\"ND: {}\".format(agg_metric['ND']))\n",
"print(\"NRMSE: {}\".format(agg_metric['NRMSE']))\n",
"print(\"MSE: {}\".format(agg_metric['MSE']))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CRPS-Sum: 0.2873863376280519\n",
"ND-Sum: 0.35970480888579265\n",
"NRMSE-Sum: 0.7184166842326591\n",
"MSE-Sum: 9189074.285714285\n"
]
}
],
"source": [
"print(\"CRPS-Sum: {}\".format(agg_metric['m_sum_mean_wQuantileLoss']))\n",
"print(\"ND-Sum: {}\".format(agg_metric['m_sum_ND']))\n",
"print(\"NRMSE-Sum: {}\".format(agg_metric['m_sum_NRMSE']))\n",
"print(\"MSE-Sum: {}\".format(agg_metric['m_sum_MSE']))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## `GRU-MAF`"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"estimator = TempFlowEstimator(\n",
" target_dim=int(dataset.metadata.feat_static_cat[0].cardinality),\n",
" prediction_length=dataset.metadata.prediction_length,\n",
" cell_type='GRU',\n",
" input_size=552,\n",
" freq=dataset.metadata.freq,\n",
" scaling=True,\n",
" dequantize=True,\n",
" flow_type='MAF',\n",
" trainer=Trainer(device=device,\n",
" epochs=25,\n",
" learning_rate=1e-3,\n",
" num_batches_per_epoch=100,\n",
" batch_size=64)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"98it [00:10, 9.05it/s, avg_epoch_loss=-7.36, epoch=0]\n",
"99it [00:10, 9.19it/s, avg_epoch_loss=-136, epoch=1]\n",
"99it [00:10, 9.12it/s, avg_epoch_loss=-164, epoch=2]\n",
"98it [00:10, 8.91it/s, avg_epoch_loss=-179, epoch=3]\n",
"98it [00:10, 9.09it/s, avg_epoch_loss=-188, epoch=4]\n",
"99it [00:10, 9.05it/s, avg_epoch_loss=-194, epoch=5]\n",
"98it [00:10, 9.04it/s, avg_epoch_loss=-198, epoch=6]\n",
"98it [00:10, 8.97it/s, avg_epoch_loss=-201, epoch=7]\n",
"97it [00:10, 8.90it/s, avg_epoch_loss=-204, epoch=8]\n",
"99it [00:10, 9.07it/s, avg_epoch_loss=-206, epoch=9]\n",
"99it [00:10, 9.09it/s, avg_epoch_loss=-207, epoch=10]\n",
"98it [00:11, 8.90it/s, avg_epoch_loss=-209, epoch=11]\n",
"99it [00:10, 9.02it/s, avg_epoch_loss=-210, epoch=12]\n",
"98it [00:10, 8.95it/s, avg_epoch_loss=-211, epoch=13]\n",
"99it [00:10, 9.21it/s, avg_epoch_loss=-212, epoch=14]\n",
"98it [00:10, 9.00it/s, avg_epoch_loss=-213, epoch=15]\n",
"99it [00:10, 9.21it/s, avg_epoch_loss=-214, epoch=16]\n",
"98it [00:10, 8.95it/s, avg_epoch_loss=-215, epoch=17]\n",
"98it [00:11, 8.88it/s, avg_epoch_loss=-216, epoch=18]\n",
"99it [00:10, 9.08it/s, avg_epoch_loss=-216, epoch=19]\n",
"98it [00:10, 8.96it/s, avg_epoch_loss=-217, epoch=20]\n",
"98it [00:10, 8.98it/s, avg_epoch_loss=-218, epoch=21]\n",
"97it [00:10, 8.88it/s, avg_epoch_loss=-218, epoch=22]\n",
"97it [00:10, 8.83it/s, avg_epoch_loss=-219, epoch=23]\n",
"98it [00:10, 8.97it/s, avg_epoch_loss=-219, epoch=24]\n",
" 0%| | 0/137 [00:00<?, ?it/s]\n",
"Running evaluation: 7it [00:00, 79.30it/s]\n",
" 1%| | 1/137 [00:00<00:14, 9.67it/s]\n",
"Running evaluation: 7it [00:00, 80.37it/s]\n",
" 1%|▏ | 2/137 [00:00<00:13, 9.69it/s]\n",
"Running evaluation: 7it [00:00, 80.99it/s]\n",
" 2%|▏ | 3/137 [00:00<00:13, 9.72it/s]\n",
"Running evaluation: 7it [00:00, 80.16it/s]\n",
" 3%|▎ | 4/137 [00:00<00:13, 9.74it/s]\n",
"Running evaluation: 7it [00:00, 81.87it/s]\n",
" 4%|▎ | 5/137 [00:00<00:15, 8.38it/s]\n",
"Running evaluation: 7it [00:00, 80.00it/s]\n",
" 4%|▍ | 6/137 [00:00<00:15, 8.73it/s]\n",
"Running evaluation: 7it [00:00, 80.20it/s]\n",
" 5%|▌ | 7/137 [00:00<00:14, 9.02it/s]\n",
"Running evaluation: 7it [00:00, 81.10it/s]\n",
" 6%|▌ | 8/137 [00:00<00:13, 9.26it/s]\n",
"Running evaluation: 7it [00:00, 78.81it/s]\n",
" 7%|▋ | 9/137 [00:00<00:13, 9.37it/s]\n",
"Running evaluation: 7it [00:00, 80.94it/s]\n",
" 7%|▋ | 10/137 [00:01<00:13, 9.51it/s]\n",
"Running evaluation: 7it [00:00, 80.86it/s]\n",
" 8%|▊ | 11/137 [00:01<00:13, 9.61it/s]\n",
"Running evaluation: 7it [00:00, 79.54it/s]\n",
" 9%|▉ | 12/137 [00:01<00:12, 9.65it/s]\n",
"Running evaluation: 7it [00:00, 80.68it/s]\n",
" 9%|▉ | 13/137 [00:01<00:12, 9.70it/s]\n",
"Running evaluation: 7it [00:00, 80.79it/s]\n",
" 10%|█ | 14/137 [00:01<00:12, 9.67it/s]\n",
"Running evaluation: 7it [00:00, 79.82it/s]\n",
" 11%|█ | 15/137 [00:01<00:12, 9.69it/s]\n",
"Running evaluation: 7it [00:00, 80.53it/s]\n",
" 12%|█▏ | 16/137 [00:01<00:12, 9.73it/s]\n",
"Running evaluation: 7it [00:00, 76.42it/s]\n",
" 12%|█▏ | 17/137 [00:01<00:12, 9.59it/s]\n",
"Running evaluation: 7it [00:00, 79.49it/s]\n",
" 13%|█▎ | 18/137 [00:01<00:12, 9.61it/s]\n",
"Running evaluation: 7it [00:00, 80.33it/s]\n",
" 14%|█▍ | 19/137 [00:02<00:12, 9.66it/s]\n",
"Running evaluation: 7it [00:00, 81.06it/s]\n",
" 15%|█▍ | 20/137 [00:02<00:12, 9.70it/s]\n",
"Running evaluation: 7it [00:00, 80.93it/s]\n",
" 15%|█▌ | 21/137 [00:02<00:11, 9.74it/s]\n",
"Running evaluation: 7it [00:00, 81.20it/s]\n",
" 16%|█▌ | 22/137 [00:02<00:11, 9.78it/s]\n",
"Running evaluation: 7it [00:00, 80.83it/s]\n",
" 17%|█▋ | 23/137 [00:02<00:11, 9.79it/s]\n",
"Running evaluation: 7it [00:00, 81.33it/s]\n",
" 18%|█▊ | 24/137 [00:02<00:11, 9.82it/s]\n",
"Running evaluation: 7it [00:00, 80.91it/s]\n",
" 18%|█▊ | 25/137 [00:02<00:11, 9.83it/s]\n",
"Running evaluation: 7it [00:00, 81.40it/s]\n",
" 19%|█▉ | 26/137 [00:02<00:11, 9.85it/s]\n",
"Running evaluation: 7it [00:00, 81.20it/s]\n",
" 20%|█▉ | 27/137 [00:02<00:11, 9.86it/s]\n",
"Running evaluation: 7it [00:00, 82.38it/s]\n",
" 20%|██ | 28/137 [00:02<00:11, 9.90it/s]\n",
"Running evaluation: 7it [00:00, 82.08it/s]\n",
" 21%|██ | 29/137 [00:03<00:10, 9.92it/s]\n",
"Running evaluation: 7it [00:00, 81.48it/s]\n",
" 22%|██▏ | 30/137 [00:03<00:10, 9.91it/s]\n",
"Running evaluation: 7it [00:00, 80.91it/s]\n",
" 23%|██▎ | 31/137 [00:03<00:10, 9.89it/s]\n",
"Running evaluation: 7it [00:00, 81.74it/s]\n",
" 23%|██▎ | 32/137 [00:03<00:10, 9.89it/s]\n",
"Running evaluation: 7it [00:00, 82.56it/s]\n",
" 24%|██▍ | 33/137 [00:03<00:10, 9.91it/s]\n",
"Running evaluation: 7it [00:00, 82.94it/s]\n",
"\n",
"Running evaluation: 7it [00:00, 81.87it/s]\n",
" 26%|██▌ | 35/137 [00:03<00:10, 9.93it/s]\n",
"Running evaluation: 7it [00:00, 82.36it/s]\n",
"\n",
"Running evaluation: 7it [00:00, 82.48it/s]\n",
" 27%|██▋ | 37/137 [00:03<00:10, 9.96it/s]\n",
"Running evaluation: 7it [00:00, 81.42it/s]\n",
" 28%|██▊ | 38/137 [00:03<00:09, 9.93it/s]\n",
"Running evaluation: 7it [00:00, 81.16it/s]\n",
" 28%|██▊ | 39/137 [00:04<00:09, 9.91it/s]\n",
"Running evaluation: 7it [00:00, 80.18it/s]\n",
" 29%|██▉ | 40/137 [00:04<00:09, 9.86it/s]\n",
"Running evaluation: 7it [00:00, 79.44it/s]\n",
" 30%|██▉ | 41/137 [00:04<00:09, 9.77it/s]\n",
"Running evaluation: 7it [00:00, 78.49it/s]\n",
" 31%|███ | 42/137 [00:04<00:09, 9.71it/s]\n",
"Running evaluation: 7it [00:00, 81.14it/s]\n",
" 31%|███▏ | 43/137 [00:04<00:09, 9.76it/s]\n",
"Running evaluation: 7it [00:00, 80.89it/s]\n",
" 32%|███▏ | 44/137 [00:04<00:09, 9.77it/s]\n",
"Running evaluation: 7it [00:00, 79.56it/s]\n",
" 33%|███▎ | 45/137 [00:04<00:09, 9.75it/s]\n",
"Running evaluation: 7it [00:00, 80.37it/s]\n",
" 34%|███▎ | 46/137 [00:04<00:09, 9.77it/s]\n",
"Running evaluation: 7it [00:00, 81.06it/s]\n",
" 34%|███▍ | 47/137 [00:04<00:09, 9.80it/s]\n",
"Running evaluation: 7it [00:00, 81.00it/s]\n",
" 35%|███▌ | 48/137 [00:04<00:09, 9.82it/s]\n",
"Running evaluation: 7it [00:00, 79.59it/s]\n",
" 36%|███▌ | 49/137 [00:05<00:08, 9.78it/s]\n",
"Running evaluation: 7it [00:00, 80.16it/s]\n",
" 36%|███▋ | 50/137 [00:05<00:08, 9.78it/s]\n",
"Running evaluation: 7it [00:00, 81.13it/s]\n",
" 37%|███▋ | 51/137 [00:05<00:08, 9.81it/s]\n",
"Running evaluation: 7it [00:00, 80.09it/s]\n",
" 38%|███▊ | 52/137 [00:05<00:08, 9.80it/s]\n",
"Running evaluation: 7it [00:00, 81.05it/s]\n",
" 39%|███▊ | 53/137 [00:05<00:08, 9.82it/s]\n",
"Running evaluation: 7it [00:00, 79.80it/s]\n",
" 39%|███▉ | 54/137 [00:05<00:08, 9.80it/s]\n",
"Running evaluation: 7it [00:00, 79.99it/s]\n",
" 40%|████ | 55/137 [00:05<00:08, 9.79it/s]\n",
"Running evaluation: 7it [00:00, 80.74it/s]\n",
" 41%|████ | 56/137 [00:05<00:08, 9.80it/s]\n",
"Running evaluation: 7it [00:00, 80.38it/s]\n",
" 42%|████▏ | 57/137 [00:05<00:08, 9.80it/s]\n",
"Running evaluation: 7it [00:00, 80.81it/s]\n",
" 42%|████▏ | 58/137 [00:05<00:08, 9.81it/s]\n",
"Running evaluation: 7it [00:00, 80.07it/s]\n",
" 43%|████▎ | 59/137 [00:06<00:07, 9.80it/s]\n",
"Running evaluation: 7it [00:00, 79.89it/s]\n",
" 44%|████▍ | 60/137 [00:06<00:07, 9.79it/s]\n",
"Running evaluation: 7it [00:00, 80.78it/s]\n",
" 45%|████▍ | 61/137 [00:06<00:07, 9.81it/s]\n",
"Running evaluation: 7it [00:00, 80.28it/s]\n",
" 45%|████▌ | 62/137 [00:06<00:07, 9.80it/s]\n",
"Running evaluation: 7it [00:00, 80.50it/s]\n",
" 46%|████▌ | 63/137 [00:06<00:07, 9.81it/s]\n",
"Running evaluation: 7it [00:00, 80.07it/s]\n",
" 47%|████▋ | 64/137 [00:06<00:07, 9.76it/s]\n",
"Running evaluation: 7it [00:00, 77.93it/s]\n",
" 47%|████▋ | 65/137 [00:06<00:07, 9.67it/s]\n",
"Running evaluation: 7it [00:00, 80.83it/s]\n",
" 48%|████▊ | 66/137 [00:06<00:07, 9.72it/s]\n",
"Running evaluation: 7it [00:00, 81.68it/s]\n",
" 49%|████▉ | 67/137 [00:06<00:07, 9.79it/s]\n",
"Running evaluation: 7it [00:00, 81.13it/s]\n",
" 50%|████▉ | 68/137 [00:06<00:07, 9.81it/s]\n",
"Running evaluation: 7it [00:00, 81.32it/s]\n",
" 50%|█████ | 69/137 [00:07<00:06, 9.84it/s]\n",
"Running evaluation: 7it [00:00, 80.92it/s]\n",
" 51%|█████ | 70/137 [00:07<00:06, 9.83it/s]\n",
"Running evaluation: 7it [00:00, 81.30it/s]\n",
" 52%|█████▏ | 71/137 [00:07<00:06, 9.85it/s]\n",
"Running evaluation: 7it [00:00, 82.09it/s]\n",
" 53%|█████▎ | 72/137 [00:07<00:06, 9.89it/s]\n",
"Running evaluation: 7it [00:00, 81.77it/s]\n",
" 53%|█████▎ | 73/137 [00:07<00:06, 9.87it/s]\n",
"Running evaluation: 7it [00:00, 81.08it/s]\n",
" 54%|█████▍ | 74/137 [00:07<00:06, 9.87it/s]\n",
"Running evaluation: 7it [00:00, 81.81it/s]\n",
" 55%|█████▍ | 75/137 [00:07<00:06, 9.89it/s]\n",
"Running evaluation: 7it [00:00, 80.63it/s]\n",
" 55%|█████▌ | 76/137 [00:07<00:06, 9.87it/s]\n",
"Running evaluation: 7it [00:00, 80.91it/s]\n",
" 56%|█████▌ | 77/137 [00:07<00:06, 9.87it/s]\n",
"Running evaluation: 7it [00:00, 80.83it/s]\n",
" 57%|█████▋ | 78/137 [00:08<00:05, 9.86it/s]\n",
"Running evaluation: 7it [00:00, 80.90it/s]\n",
" 58%|█████▊ | 79/137 [00:08<00:05, 9.86it/s]\n",
"Running evaluation: 7it [00:00, 80.50it/s]\n",
" 58%|█████▊ | 80/137 [00:08<00:05, 9.84it/s]\n",
"Running evaluation: 7it [00:00, 80.69it/s]\n",
" 59%|█████▉ | 81/137 [00:08<00:05, 9.84it/s]\n",
"Running evaluation: 7it [00:00, 78.67it/s]\n",
" 60%|█████▉ | 82/137 [00:08<00:05, 9.74it/s]\n",
"Running evaluation: 0it [00:00, ?it/s]\u001b[A\n",
"Running evaluation: 7it [00:00, 57.62it/s]\u001b[A\n",
" 61%|██████ | 83/137 [00:08<00:06, 8.73it/s]\n",
"Running evaluation: 7it [00:00, 79.47it/s]\n",
" 61%|██████▏ | 84/137 [00:08<00:05, 8.97it/s]\n",
"Running evaluation: 7it [00:00, 77.83it/s]\n",
" 62%|██████▏ | 85/137 [00:08<00:05, 9.12it/s]\n",
"Running evaluation: 7it [00:00, 79.50it/s]\n",
" 63%|██████▎ | 86/137 [00:08<00:05, 9.26it/s]\n",
"Running evaluation: 7it [00:00, 78.65it/s]\n",
" 64%|██████▎ | 87/137 [00:08<00:05, 9.33it/s]\n",
"Running evaluation: 7it [00:00, 77.89it/s]\n",
" 64%|██████▍ | 88/137 [00:09<00:05, 9.37it/s]\n",
"Running evaluation: 7it [00:00, 77.88it/s]\n",
" 65%|██████▍ | 89/137 [00:09<00:05, 9.39it/s]\n",
"Running evaluation: 7it [00:00, 78.58it/s]\n",
" 66%|██████▌ | 90/137 [00:09<00:04, 9.43it/s]\n",
"Running evaluation: 7it [00:00, 78.98it/s]\n",
" 66%|██████▋ | 91/137 [00:09<00:04, 9.45it/s]\n",
"Running evaluation: 7it [00:00, 79.47it/s]\n",
" 67%|██████▋ | 92/137 [00:09<00:04, 9.47it/s]\n",
"Running evaluation: 7it [00:00, 75.97it/s]\n",
" 68%|██████▊ | 93/137 [00:09<00:04, 9.37it/s]\n",
"Running evaluation: 7it [00:00, 81.28it/s]\n",
" 69%|██████▊ | 94/137 [00:09<00:04, 9.47it/s]\n",
"Running evaluation: 7it [00:00, 81.42it/s]\n",
" 69%|██████▉ | 95/137 [00:09<00:04, 9.57it/s]\n",
"Running evaluation: 7it [00:00, 78.29it/s]\n",
" 70%|███████ | 96/137 [00:09<00:04, 9.51it/s]\n",
"Running evaluation: 7it [00:00, 78.14it/s]\n",
" 71%|███████ | 97/137 [00:10<00:04, 9.49it/s]\n",
"Running evaluation: 7it [00:00, 79.59it/s]\n",
" 72%|███████▏ | 98/137 [00:10<00:04, 9.50it/s]\n",
"Running evaluation: 7it [00:00, 81.20it/s]\n",
" 72%|███████▏ | 99/137 [00:10<00:03, 9.56it/s]\n",
"Running evaluation: 7it [00:00, 81.43it/s]\n",
" 73%|███████▎ | 100/137 [00:10<00:03, 9.66it/s]\n",
"Running evaluation: 7it [00:00, 82.14it/s]\n",
" 74%|███████▎ | 101/137 [00:10<00:03, 9.76it/s]\n",
"Running evaluation: 7it [00:00, 80.95it/s]\n",
" 74%|███████▍ | 102/137 [00:10<00:03, 9.79it/s]\n",
"Running evaluation: 7it [00:00, 75.02it/s]\n",
" 75%|███████▌ | 103/137 [00:10<00:03, 9.58it/s]\n",
"Running evaluation: 7it [00:00, 80.46it/s]\n",
" 76%|███████▌ | 104/137 [00:10<00:03, 9.63it/s]\n",
"Running evaluation: 7it [00:00, 81.08it/s]\n",
" 77%|███████▋ | 105/137 [00:10<00:03, 9.70it/s]\n",
"Running evaluation: 7it [00:00, 82.26it/s]\n",
" 77%|███████▋ | 106/137 [00:10<00:03, 9.77it/s]\n",
"Running evaluation: 7it [00:00, 81.13it/s]\n",
" 78%|███████▊ | 107/137 [00:11<00:03, 9.80it/s]\n",
"Running evaluation: 7it [00:00, 81.79it/s]\n",
" 79%|███████▉ | 108/137 [00:11<00:02, 9.81it/s]\n",
"Running evaluation: 7it [00:00, 81.81it/s]\n",
" 80%|███████▉ | 109/137 [00:11<00:02, 9.83it/s]\n",
"Running evaluation: 7it [00:00, 79.83it/s]\n",
" 80%|████████ | 110/137 [00:11<00:02, 9.78it/s]\n",
"Running evaluation: 7it [00:00, 81.44it/s]\n",
" 81%|████████ | 111/137 [00:11<00:02, 9.81it/s]\n",
"Running evaluation: 7it [00:00, 81.57it/s]\n",
" 82%|████████▏ | 112/137 [00:11<00:02, 9.85it/s]\n",
"Running evaluation: 7it [00:00, 79.45it/s]\n",
" 82%|████████▏ | 113/137 [00:11<00:02, 9.77it/s]\n",
"Running evaluation: 7it [00:00, 77.78it/s]\n",
" 83%|████████▎ | 114/137 [00:11<00:02, 9.68it/s]\n",
"Running evaluation: 7it [00:00, 81.29it/s]\n",
" 84%|████████▍ | 115/137 [00:11<00:02, 9.74it/s]\n",
"Running evaluation: 7it [00:00, 81.49it/s]\n",
" 85%|████████▍ | 116/137 [00:11<00:02, 9.80it/s]\n",
"Running evaluation: 7it [00:00, 80.46it/s]\n",
" 85%|████████▌ | 117/137 [00:12<00:02, 9.80it/s]\n",
"Running evaluation: 7it [00:00, 80.67it/s]\n",
" 86%|████████▌ | 118/137 [00:12<00:01, 9.81it/s]\n",
"Running evaluation: 7it [00:00, 81.37it/s]\n",
" 87%|████████▋ | 119/137 [00:12<00:01, 9.84it/s]\n",
"Running evaluation: 7it [00:00, 78.34it/s]\n",
" 88%|████████▊ | 120/137 [00:12<00:01, 9.74it/s]\n",
"Running evaluation: 7it [00:00, 79.76it/s]\n",
" 88%|████████▊ | 121/137 [00:12<00:01, 9.70it/s]\n",
"Running evaluation: 7it [00:00, 79.34it/s]\n",
" 89%|████████▉ | 122/137 [00:12<00:01, 9.64it/s]\n",
"Running evaluation: 7it [00:00, 78.06it/s]\n",
" 90%|████████▉ | 123/137 [00:12<00:01, 9.58it/s]\n",
"Running evaluation: 7it [00:00, 81.37it/s]\n",
" 91%|█████████ | 124/137 [00:12<00:01, 9.67it/s]\n",
"Running evaluation: 7it [00:00, 80.66it/s]\n",
" 91%|█████████ | 125/137 [00:12<00:01, 9.72it/s]\n",
"Running evaluation: 7it [00:00, 81.24it/s]\n",
" 92%|█████████▏| 126/137 [00:13<00:01, 9.77it/s]\n",
"Running evaluation: 7it [00:00, 82.04it/s]\n",
" 93%|█████████▎| 127/137 [00:13<00:01, 9.83it/s]\n",
"Running evaluation: 7it [00:00, 81.05it/s]\n",
" 93%|█████████▎| 128/137 [00:13<00:00, 9.82it/s]\n",
"Running evaluation: 7it [00:00, 83.03it/s]\n",
"\n",
"Running evaluation: 7it [00:00, 82.04it/s]\n",
" 95%|█████████▍| 130/137 [00:13<00:00, 9.88it/s]\n",
"Running evaluation: 7it [00:00, 80.86it/s]\n",
" 96%|█████████▌| 131/137 [00:13<00:00, 9.86it/s]\n",
"Running evaluation: 7it [00:00, 81.55it/s]\n",
" 96%|█████████▋| 132/137 [00:13<00:00, 9.87it/s]\n",
"Running evaluation: 7it [00:00, 81.48it/s]\n",
" 97%|█████████▋| 133/137 [00:13<00:00, 9.88it/s]\n",
"Running evaluation: 7it [00:00, 81.72it/s]\n",
" 98%|█████████▊| 134/137 [00:13<00:00, 9.90it/s]\n",
"Running evaluation: 7it [00:00, 81.09it/s]\n",
" 99%|█████████▊| 135/137 [00:13<00:00, 9.82it/s]\n",
"Running evaluation: 7it [00:00, 76.88it/s]\n",
" 99%|█████████▉| 136/137 [00:14<00:00, 9.65it/s]\n",
"Running evaluation: 7it [00:00, 81.92it/s]\n",
"100%|██████████| 137/137 [00:14<00:00, 9.70it/s]\n",
"Running evaluation: 7it [00:00, 61.47it/s]\n"
]
}
],
"source": [
"predictor = estimator.train(dataset_train)\n",
"forecast_it, ts_it = make_evaluation_predictions(dataset=dataset_test,\n",
" predictor=predictor,\n",
" num_samples=100)\n",
"forecasts = list(forecast_it)\n",
"targets = list(ts_it)\n",
"\n",
"agg_metric, _ = evaluator(targets, forecasts, num_series=len(dataset_test))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Metrics"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CRPS: 0.3855313301520275\n",
"ND: 0.48820539490099113\n",
"NRMSE: 1.018839692673421\n",
"MSE: 984.6672641166102\n"
]
}
],
"source": [
"print(\"CRPS: {}\".format(agg_metric['mean_wQuantileLoss']))\n",
"print(\"ND: {}\".format(agg_metric['ND']))\n",
"print(\"NRMSE: {}\".format(agg_metric['NRMSE']))\n",
"print(\"MSE: {}\".format(agg_metric['MSE']))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CRPS-Sum: 0.3268739166960563\n",
"ND-Sum: 0.40321702146475014\n",
"NRMSE-Sum: 0.75586334994103\n",
"MSE-Sum: 10171980.5\n"
]
}
],
"source": [
"print(\"CRPS-Sum: {}\".format(agg_metric['m_sum_mean_wQuantileLoss']))\n",
"print(\"ND-Sum: {}\".format(agg_metric['m_sum_ND']))\n",
"print(\"NRMSE-Sum: {}\".format(agg_metric['m_sum_NRMSE']))\n",
"print(\"MSE-Sum: {}\".format(agg_metric['m_sum_MSE']))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## `Transformer-MAF`"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"estimator = TransformerTempFlowEstimator(\n",
" d_model=16,\n",
" num_heads=4,\n",
" input_size=552,\n",
" target_dim=int(dataset.metadata.feat_static_cat[0].cardinality),\n",
" prediction_length=dataset.metadata.prediction_length,\n",
" context_length=dataset.metadata.prediction_length*4,\n",
" flow_type='MAF',\n",
" dequantize=True,\n",
" freq=dataset.metadata.freq,\n",
" trainer=Trainer(\n",
" device=device,\n",
" epochs=14,\n",
" learning_rate=1e-3,\n",
" num_batches_per_epoch=100,\n",
" batch_size=64,\n",
" )\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"99it [00:26, 3.70it/s, avg_epoch_loss=-82.7, epoch=0]\n",
"Running evaluation: 7it [00:00, 121.58it/s]\n",
"Running evaluation: 7it [00:00, 129.89it/s]\n",
"Running evaluation: 7it [00:00, 133.02it/s]\n",
"Running evaluation: 7it [00:00, 133.71it/s]\n",
"Running evaluation: 7it [00:00, 129.68it/s]\n",
"Running evaluation: 7it [00:00, 130.11it/s]\n",
"Running evaluation: 7it [00:00, 135.91it/s]\n",
"Running evaluation: 7it [00:00, 134.94it/s]\n",
"Running evaluation: 7it [00:00, 127.03it/s]\n",
"Running evaluation: 7it [00:00, 131.79it/s]\n",
"Running evaluation: 7it [00:00, 131.80it/s]\n",
"Running evaluation: 7it [00:00, 129.62it/s]\n",
"Running evaluation: 7it [00:00, 130.80it/s]\n",
"Running evaluation: 7it [00:00, 134.32it/s]\n",
"Running evaluation: 7it [00:00, 135.98it/s]\n",
"Running evaluation: 7it [00:00, 132.59it/s]\n",
"Running evaluation: 7it [00:00, 132.17it/s]\n",
"Running evaluation: 7it [00:00, 131.03it/s]\n",
"Running evaluation: 7it [00:00, 130.69it/s]\n",
"Running evaluation: 7it [00:00, 130.72it/s]\n",
"Running evaluation: 7it [00:00, 132.53it/s]\n",
"Running evaluation: 7it [00:00, 129.63it/s]\n",
"Running evaluation: 7it [00:00, 65.78it/s]\n",
"Running evaluation: 7it [00:00, 133.88it/s]\n",
"Running evaluation: 7it [00:00, 129.27it/s]\n",
"Running evaluation: 7it [00:00, 134.82it/s]\n",
"Running evaluation: 7it [00:00, 133.96it/s]\n",
"Running evaluation: 7it [00:00, 130.77it/s]\n",
"Running evaluation: 7it [00:00, 130.26it/s]\n",
"Running evaluation: 7it [00:00, 130.87it/s]\n",
"Running evaluation: 7it [00:00, 129.28it/s]\n",
"Running evaluation: 7it [00:00, 129.81it/s]\n",
"Running evaluation: 7it [00:00, 132.89it/s]\n",
"Running evaluation: 7it [00:00, 132.76it/s]\n",
"Running evaluation: 7it [00:00, 131.64it/s]\n",
"Running evaluation: 7it [00:00, 133.07it/s]\n",
"Running evaluation: 7it [00:00, 128.50it/s]\n",
"Running evaluation: 7it [00:00, 135.86it/s]\n",
"Running evaluation: 7it [00:00, 130.13it/s]\n",
"Running evaluation: 7it [00:00, 129.31it/s]\n",
"Running evaluation: 7it [00:00, 128.67it/s]\n",
"Running evaluation: 7it [00:00, 134.41it/s]\n",
"Running evaluation: 7it [00:00, 128.88it/s]\n",
"Running evaluation: 7it [00:00, 134.21it/s]\n",
"Running evaluation: 7it [00:00, 134.11it/s]\n",
"Running evaluation: 7it [00:00, 133.82it/s]\n",
"Running evaluation: 7it [00:00, 131.31it/s]\n",
"Running evaluation: 7it [00:00, 128.78it/s]\n",
"Running evaluation: 7it [00:00, 128.76it/s]\n",
"Running evaluation: 7it [00:00, 127.98it/s]\n",
"Running evaluation: 7it [00:00, 130.26it/s]\n",
"Running evaluation: 7it [00:00, 120.39it/s]\n",
"Running evaluation: 7it [00:00, 134.66it/s]\n",
"Running evaluation: 7it [00:00, 134.51it/s]\n",
"Running evaluation: 7it [00:00, 125.47it/s]\n",
"Running evaluation: 7it [00:00, 133.05it/s]\n",
"Running evaluation: 7it [00:00, 129.13it/s]\n",
"Running evaluation: 7it [00:00, 131.84it/s]\n",
"Running evaluation: 7it [00:00, 130.52it/s]\n",
"Running evaluation: 7it [00:00, 136.95it/s]\n",
"Running evaluation: 7it [00:00, 135.88it/s]\n",
"Running evaluation: 7it [00:00, 137.97it/s]\n",
"Running evaluation: 7it [00:00, 136.48it/s]\n",
"Running evaluation: 7it [00:00, 137.81it/s]\n",
"Running evaluation: 7it [00:00, 138.88it/s]\n",
"Running evaluation: 7it [00:00, 140.48it/s]\n",
"Running evaluation: 7it [00:00, 139.82it/s]\n",
"Running evaluation: 7it [00:00, 137.45it/s]\n",
"Running evaluation: 7it [00:00, 139.96it/s]\n",
"Running evaluation: 7it [00:00, 139.87it/s]\n",
"Running evaluation: 7it [00:00, 137.53it/s]\n",
"Running evaluation: 7it [00:00, 136.43it/s]\n",
"Running evaluation: 7it [00:00, 129.52it/s]\n",
"Running evaluation: 7it [00:00, 134.57it/s]\n",
"Running evaluation: 7it [00:00, 136.23it/s]\n",
"Running evaluation: 7it [00:00, 141.61it/s]\n",
"Running evaluation: 7it [00:00, 137.81it/s]\n",
"Running evaluation: 7it [00:00, 137.27it/s]\n",
"Running evaluation: 7it [00:00, 138.90it/s]\n",
"Running evaluation: 7it [00:00, 138.50it/s]\n",
"Running evaluation: 7it [00:00, 136.98it/s]\n",
"Running evaluation: 7it [00:00, 121.52it/s]\n",
"Running evaluation: 7it [00:00, 129.19it/s]\n",
"Running evaluation: 7it [00:00, 136.95it/s]\n",
"Running evaluation: 7it [00:00, 138.76it/s]\n",
"Running evaluation: 7it [00:00, 135.65it/s]\n",
"Running evaluation: 7it [00:00, 137.78it/s]\n",
"Running evaluation: 7it [00:00, 130.82it/s]\n",
"Running evaluation: 7it [00:00, 129.72it/s]\n",
"Running evaluation: 7it [00:00, 133.23it/s]\n",
"Running evaluation: 7it [00:00, 136.86it/s]\n",
"Running evaluation: 7it [00:00, 140.74it/s]\n",
"Running evaluation: 7it [00:00, 134.70it/s]\n",
"Running evaluation: 7it [00:00, 138.99it/s]\n",
"Running evaluation: 7it [00:00, 133.51it/s]\n",
"Running evaluation: 7it [00:00, 129.42it/s]\n",
"Running evaluation: 7it [00:00, 133.08it/s]\n",
"Running evaluation: 7it [00:00, 132.08it/s]\n",
"Running evaluation: 7it [00:00, 135.16it/s]\n",
"Running evaluation: 7it [00:00, 135.33it/s]\n",
"Running evaluation: 7it [00:00, 132.90it/s]\n",
"Running evaluation: 7it [00:00, 129.54it/s]\n",
"Running evaluation: 7it [00:00, 128.03it/s]\n",
"Running evaluation: 7it [00:00, 129.57it/s]\n",
"Running evaluation: 7it [00:00, 129.91it/s]\n",
"Running evaluation: 7it [00:00, 131.76it/s]\n",
"Running evaluation: 7it [00:00, 130.74it/s]\n",
"Running evaluation: 7it [00:00, 131.16it/s]\n",
"Running evaluation: 7it [00:00, 126.87it/s]\n",
"Running evaluation: 7it [00:00, 133.84it/s]\n",
"Running evaluation: 7it [00:00, 130.86it/s]\n",
"Running evaluation: 7it [00:00, 126.03it/s]\n",
"Running evaluation: 7it [00:00, 125.04it/s]\n",
"Running evaluation: 7it [00:00, 128.97it/s]\n",
"Running evaluation: 7it [00:00, 130.28it/s]\n",
"Running evaluation: 7it [00:00, 127.62it/s]\n",
"Running evaluation: 7it [00:00, 134.83it/s]\n",
"Running evaluation: 7it [00:00, 135.22it/s]\n",
"Running evaluation: 7it [00:00, 118.29it/s]\n",
"Running evaluation: 7it [00:00, 124.76it/s]\n",
"Running evaluation: 7it [00:00, 128.80it/s]\n",
"Running evaluation: 7it [00:00, 129.55it/s]\n",
"Running evaluation: 7it [00:00, 129.08it/s]\n",
"Running evaluation: 7it [00:00, 130.09it/s]\n",
"Running evaluation: 7it [00:00, 126.55it/s]\n",
"Running evaluation: 7it [00:00, 128.26it/s]\n",
"Running evaluation: 7it [00:00, 132.86it/s]\n",
"Running evaluation: 7it [00:00, 128.60it/s]\n",
"Running evaluation: 7it [00:00, 128.98it/s]\n",
"Running evaluation: 7it [00:00, 127.39it/s]\n",
"Running evaluation: 7it [00:00, 133.66it/s]\n",
"Running evaluation: 7it [00:00, 128.62it/s]\n",
"Running evaluation: 7it [00:00, 130.26it/s]\n",
"Running evaluation: 7it [00:00, 70.39it/s]\n",
"Running evaluation: 7it [00:00, 126.12it/s]\n",
"Running evaluation: 7it [00:00, 128.30it/s]\n",
"Running evaluation: 7it [00:00, 129.34it/s]\n",
"Running evaluation: 7it [00:00, 54.24it/s]\n"
]
}
],
"source": [
"predictor = estimator.train(dataset_train)\n",
"forecast_it, ts_it = make_evaluation_predictions(dataset=dataset_test,\n",
" predictor=predictor,\n",
" num_samples=100)\n",
"forecasts = list(forecast_it)\n",
"targets = list(ts_it)\n",
"\n",
"agg_metric, _ = evaluator(targets, forecasts, num_series=len(dataset_test))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Metrics"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CRPS: 0.37264046134993567\n",
"ND: 0.5043621354947913\n",
"NRMSE: 0.9928759300158241\n",
"MSE: 935.1208752979203\n"
]
}
],
"source": [
"print(\"CRPS: {}\".format(agg_metric['mean_wQuantileLoss']))\n",
"print(\"ND: {}\".format(agg_metric['ND']))\n",
"print(\"NRMSE: {}\".format(agg_metric['NRMSE']))\n",
"print(\"MSE: {}\".format(agg_metric['MSE']))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CRPS-Sum: 0.30787625107438427\n",
"ND-Sum: 0.4188356756894787\n",
"NRMSE-Sum: 0.7504274205713227\n",
"MSE-Sum: 10026199.285714285\n"
]
}
],
"source": [
"print(\"CRPS-Sum: {}\".format(agg_metric['m_sum_mean_wQuantileLoss']))\n",
"print(\"ND-Sum: {}\".format(agg_metric['m_sum_ND']))\n",
"print(\"NRMSE-Sum: {}\".format(agg_metric['m_sum_NRMSE']))\n",
"print(\"MSE-Sum: {}\".format(agg_metric['m_sum_MSE']))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.9.15 ('glounts')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
},
"vscode": {
"interpreter": {
"hash": "7f25a1f13147a60511cf6766827402baf95cbe50d53a241197155306ee38fe70"
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}