{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "\n", "import torch" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from gluonts.dataset.multivariate_grouper import MultivariateGrouper\n", "from gluonts.dataset.repository.datasets import dataset_recipes, get_dataset\n", "from pts.model.tempflow import TempFlowEstimator\n", "from pts.model.transformer_tempflow import TransformerTempFlowEstimator\n", "from pts import Trainer\n", "from gluonts.evaluation.backtest import make_evaluation_predictions\n", "from gluonts.evaluation import MultivariateEvaluator" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Prepare data set" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "dataset = get_dataset(\"solar_nips\", regenerate=False)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "MetaData(freq='H', target=None, feat_static_cat=[CategoricalFeatureInfo(name='feat_static_cat_0', cardinality='137')], feat_static_real=[], feat_dynamic_real=[], feat_dynamic_cat=[], prediction_length=24)" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset.metadata" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "train_grouper = MultivariateGrouper(max_target_dim=int(dataset.metadata.feat_static_cat[0].cardinality))\n", "\n", "test_grouper = MultivariateGrouper(num_test_dates=int(len(dataset.test)/len(dataset.train)), \n", " max_target_dim=int(dataset.metadata.feat_static_cat[0].cardinality))" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { 
"name": "stderr", "output_type": "stream", "text": [ "/home/wassname/miniforge3/envs/glounts/lib/python3.9/site-packages/gluonts/dataset/multivariate_grouper.py:191: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray.\n", " return {FieldName.TARGET: np.array([funcs(data) for data in dataset])}\n" ] } ], "source": [ "dataset_train = train_grouper(dataset.train)\n", "dataset_test = test_grouper(dataset.test)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluator" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "evaluator = MultivariateEvaluator(quantiles=(np.arange(20)/20.0)[1:],\n", " target_agg_funcs={'sum': np.sum})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## `GRU-Real-NVP`" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "estimator = TempFlowEstimator(\n", " target_dim=int(dataset.metadata.feat_static_cat[0].cardinality),\n", " prediction_length=dataset.metadata.prediction_length,\n", " cell_type='GRU',\n", " input_size=552,\n", " freq=dataset.metadata.freq,\n", " scaling=True,\n", " dequantize=True,\n", " n_blocks=4,\n", " trainer=Trainer(device=device,\n", " epochs=45,\n", " learning_rate=1e-3,\n", " num_batches_per_epoch=100,\n", " batch_size=64)\n", ")" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "e6b529fa683f4e51b5f8ac63db471849", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/99 [00:00\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0m predictor \u001b[39m=\u001b[39m estimator\u001b[39m.\u001b[39;49mtrain(dataset_train)\n\u001b[1;32m 2\u001b[0m forecast_it, ts_it \u001b[39m=\u001b[39m 
make_evaluation_predictions(dataset\u001b[39m=\u001b[39mdataset_test,\n\u001b[1;32m 3\u001b[0m predictor\u001b[39m=\u001b[39mpredictor,\n\u001b[1;32m 4\u001b[0m num_samples\u001b[39m=\u001b[39m\u001b[39m100\u001b[39m)\n\u001b[1;32m 5\u001b[0m forecasts \u001b[39m=\u001b[39m \u001b[39mlist\u001b[39m(forecast_it)\n", "File \u001b[0;32m/media/wassname/SGIronWolf/projects5/timeseries/pytorch-ts/pts/model/estimator.py:179\u001b[0m, in \u001b[0;36mPyTorchEstimator.train\u001b[0;34m(self, training_data, validation_data, num_workers, prefetch_factor, shuffle_buffer_length, cache_data, **kwargs)\u001b[0m\n\u001b[1;32m 169\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mtrain\u001b[39m(\n\u001b[1;32m 170\u001b[0m \u001b[39mself\u001b[39m,\n\u001b[1;32m 171\u001b[0m training_data: Dataset,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 177\u001b[0m \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs,\n\u001b[1;32m 178\u001b[0m ) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m PyTorchPredictor:\n\u001b[0;32m--> 179\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mtrain_model(\n\u001b[1;32m 180\u001b[0m training_data,\n\u001b[1;32m 181\u001b[0m validation_data,\n\u001b[1;32m 182\u001b[0m num_workers\u001b[39m=\u001b[39;49mnum_workers,\n\u001b[1;32m 183\u001b[0m prefetch_factor\u001b[39m=\u001b[39;49mprefetch_factor,\n\u001b[1;32m 184\u001b[0m shuffle_buffer_length\u001b[39m=\u001b[39;49mshuffle_buffer_length,\n\u001b[1;32m 185\u001b[0m cache_data\u001b[39m=\u001b[39;49mcache_data,\n\u001b[1;32m 186\u001b[0m \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs,\n\u001b[1;32m 187\u001b[0m )\u001b[39m.\u001b[39mpredictor\n", "File \u001b[0;32m/media/wassname/SGIronWolf/projects5/timeseries/pytorch-ts/pts/model/estimator.py:151\u001b[0m, in \u001b[0;36mPyTorchEstimator.train_model\u001b[0;34m(self, training_data, validation_data, num_workers, prefetch_factor, shuffle_buffer_length, cache_data, **kwargs)\u001b[0m\n\u001b[1;32m 133\u001b[0m 
validation_iter_dataset \u001b[39m=\u001b[39m TransformedIterableDataset(\n\u001b[1;32m 134\u001b[0m dataset\u001b[39m=\u001b[39mvalidation_data,\n\u001b[1;32m 135\u001b[0m transform\u001b[39m=\u001b[39mtransformation\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 139\u001b[0m cache_data\u001b[39m=\u001b[39mcache_data,\n\u001b[1;32m 140\u001b[0m )\n\u001b[1;32m 141\u001b[0m validation_data_loader \u001b[39m=\u001b[39m DataLoader(\n\u001b[1;32m 142\u001b[0m validation_iter_dataset,\n\u001b[1;32m 143\u001b[0m batch_size\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtrainer\u001b[39m.\u001b[39mbatch_size,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 148\u001b[0m \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs,\n\u001b[1;32m 149\u001b[0m )\n\u001b[0;32m--> 151\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mtrainer(\n\u001b[1;32m 152\u001b[0m net\u001b[39m=\u001b[39;49mtrained_net,\n\u001b[1;32m 153\u001b[0m train_iter\u001b[39m=\u001b[39;49mtraining_data_loader,\n\u001b[1;32m 154\u001b[0m validation_iter\u001b[39m=\u001b[39;49mvalidation_data_loader,\n\u001b[1;32m 155\u001b[0m )\n\u001b[1;32m 157\u001b[0m \u001b[39mreturn\u001b[39;00m TrainOutput(\n\u001b[1;32m 158\u001b[0m transformation\u001b[39m=\u001b[39mtransformation,\n\u001b[1;32m 159\u001b[0m trained_net\u001b[39m=\u001b[39mtrained_net,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 162\u001b[0m ),\n\u001b[1;32m 163\u001b[0m )\n", "File \u001b[0;32m/media/wassname/SGIronWolf/projects5/timeseries/pytorch-ts/pts/trainer.py:63\u001b[0m, in \u001b[0;36mTrainer.__call__\u001b[0;34m(self, net, train_iter, validation_iter)\u001b[0m\n\u001b[1;32m 61\u001b[0m \u001b[39m# training loop\u001b[39;00m\n\u001b[1;32m 62\u001b[0m \u001b[39mwith\u001b[39;00m tqdm(train_iter, total\u001b[39m=\u001b[39mtotal) \u001b[39mas\u001b[39;00m it:\n\u001b[0;32m---> 63\u001b[0m \u001b[39mfor\u001b[39;00m batch_no, data_entry \u001b[39min\u001b[39;00m \u001b[39menumerate\u001b[39m(it, 
start\u001b[39m=\u001b[39m\u001b[39m1\u001b[39m):\n\u001b[1;32m 64\u001b[0m optimizer\u001b[39m.\u001b[39mzero_grad()\n\u001b[1;32m 66\u001b[0m inputs \u001b[39m=\u001b[39m [v\u001b[39m.\u001b[39mto(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdevice) \u001b[39mfor\u001b[39;00m v \u001b[39min\u001b[39;00m data_entry\u001b[39m.\u001b[39mvalues()]\n", "File \u001b[0;32m~/miniforge3/envs/glounts/lib/python3.9/site-packages/tqdm/notebook.py:259\u001b[0m, in \u001b[0;36mtqdm_notebook.__iter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 257\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m 258\u001b[0m it \u001b[39m=\u001b[39m \u001b[39msuper\u001b[39m(tqdm_notebook, \u001b[39mself\u001b[39m)\u001b[39m.\u001b[39m\u001b[39m__iter__\u001b[39m()\n\u001b[0;32m--> 259\u001b[0m \u001b[39mfor\u001b[39;00m obj \u001b[39min\u001b[39;00m it:\n\u001b[1;32m 260\u001b[0m \u001b[39m# return super(tqdm...) will not catch exception\u001b[39;00m\n\u001b[1;32m 261\u001b[0m \u001b[39myield\u001b[39;00m obj\n\u001b[1;32m 262\u001b[0m \u001b[39m# NB: except ... [ as ...] 
breaks IPython async KeyboardInterrupt\u001b[39;00m\n", "File \u001b[0;32m~/miniforge3/envs/glounts/lib/python3.9/site-packages/tqdm/std.py:1195\u001b[0m, in \u001b[0;36mtqdm.__iter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1192\u001b[0m time \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_time\n\u001b[1;32m 1194\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[0;32m-> 1195\u001b[0m \u001b[39mfor\u001b[39;00m obj \u001b[39min\u001b[39;00m iterable:\n\u001b[1;32m 1196\u001b[0m \u001b[39myield\u001b[39;00m obj\n\u001b[1;32m 1197\u001b[0m \u001b[39m# Update and possibly print the progressbar.\u001b[39;00m\n\u001b[1;32m 1198\u001b[0m \u001b[39m# Note: does not call self.update(1) for speed optimisation.\u001b[39;00m\n", "File \u001b[0;32m~/miniforge3/envs/glounts/lib/python3.9/site-packages/torch/utils/data/dataloader.py:628\u001b[0m, in \u001b[0;36m_BaseDataLoaderIter.__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 625\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_sampler_iter \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m 626\u001b[0m \u001b[39m# TODO(https://github.com/pytorch/pytorch/issues/76750)\u001b[39;00m\n\u001b[1;32m 627\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_reset() \u001b[39m# type: ignore[call-arg]\u001b[39;00m\n\u001b[0;32m--> 628\u001b[0m data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_next_data()\n\u001b[1;32m 629\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_num_yielded \u001b[39m+\u001b[39m\u001b[39m=\u001b[39m \u001b[39m1\u001b[39m\n\u001b[1;32m 630\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_dataset_kind \u001b[39m==\u001b[39m _DatasetKind\u001b[39m.\u001b[39mIterable \u001b[39mand\u001b[39;00m \\\n\u001b[1;32m 631\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_IterableDataset_len_called \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m 
\u001b[39mand\u001b[39;00m \\\n\u001b[1;32m 632\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_num_yielded \u001b[39m>\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_IterableDataset_len_called:\n", "File \u001b[0;32m~/miniforge3/envs/glounts/lib/python3.9/site-packages/torch/utils/data/dataloader.py:671\u001b[0m, in \u001b[0;36m_SingleProcessDataLoaderIter._next_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 669\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m_next_data\u001b[39m(\u001b[39mself\u001b[39m):\n\u001b[1;32m 670\u001b[0m index \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_next_index() \u001b[39m# may raise StopIteration\u001b[39;00m\n\u001b[0;32m--> 671\u001b[0m data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_dataset_fetcher\u001b[39m.\u001b[39;49mfetch(index) \u001b[39m# may raise StopIteration\u001b[39;00m\n\u001b[1;32m 672\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_pin_memory:\n\u001b[1;32m 673\u001b[0m data \u001b[39m=\u001b[39m _utils\u001b[39m.\u001b[39mpin_memory\u001b[39m.\u001b[39mpin_memory(data, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_pin_memory_device)\n", "File \u001b[0;32m~/miniforge3/envs/glounts/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py:34\u001b[0m, in \u001b[0;36m_IterableDatasetFetcher.fetch\u001b[0;34m(self, possibly_batched_index)\u001b[0m\n\u001b[1;32m 32\u001b[0m \u001b[39mfor\u001b[39;00m _ \u001b[39min\u001b[39;00m possibly_batched_index:\n\u001b[1;32m 33\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[0;32m---> 34\u001b[0m data\u001b[39m.\u001b[39mappend(\u001b[39mnext\u001b[39;49m(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mdataset_iter))\n\u001b[1;32m 35\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mStopIteration\u001b[39;00m:\n\u001b[1;32m 36\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mended \u001b[39m=\u001b[39m \u001b[39mTrue\u001b[39;00m\n", "File 
\u001b[0;32m~/miniforge3/envs/glounts/lib/python3.9/site-packages/gluonts/transform/_base.py:103\u001b[0m, in \u001b[0;36mTransformedDataset.__iter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 102\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__iter__\u001b[39m(\u001b[39mself\u001b[39m) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Iterator[DataEntry]:\n\u001b[0;32m--> 103\u001b[0m \u001b[39myield from\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtransformation(\n\u001b[1;32m 104\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mbase_dataset, is_train\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mis_train\n\u001b[1;32m 105\u001b[0m )\n", "File \u001b[0;32m~/miniforge3/envs/glounts/lib/python3.9/site-packages/gluonts/transform/_base.py:124\u001b[0m, in \u001b[0;36mMapTransformation.__call__\u001b[0;34m(self, data_it, is_train)\u001b[0m\n\u001b[1;32m 121\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__call__\u001b[39m(\n\u001b[1;32m 122\u001b[0m \u001b[39mself\u001b[39m, data_it: Iterable[DataEntry], is_train: \u001b[39mbool\u001b[39m\n\u001b[1;32m 123\u001b[0m ) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Iterator:\n\u001b[0;32m--> 124\u001b[0m \u001b[39mfor\u001b[39;00m data_entry \u001b[39min\u001b[39;00m data_it:\n\u001b[1;32m 125\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m 126\u001b[0m \u001b[39myield\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmap_transform(data_entry\u001b[39m.\u001b[39mcopy(), is_train)\n", "File \u001b[0;32m~/miniforge3/envs/glounts/lib/python3.9/site-packages/gluonts/transform/_base.py:124\u001b[0m, in \u001b[0;36mMapTransformation.__call__\u001b[0;34m(self, data_it, is_train)\u001b[0m\n\u001b[1;32m 121\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__call__\u001b[39m(\n\u001b[1;32m 122\u001b[0m \u001b[39mself\u001b[39m, data_it: Iterable[DataEntry], is_train: \u001b[39mbool\u001b[39m\n\u001b[1;32m 123\u001b[0m ) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Iterator:\n\u001b[0;32m--> 
124\u001b[0m \u001b[39mfor\u001b[39;00m data_entry \u001b[39min\u001b[39;00m data_it:\n\u001b[1;32m 125\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m 126\u001b[0m \u001b[39myield\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmap_transform(data_entry\u001b[39m.\u001b[39mcopy(), is_train)\n", "File \u001b[0;32m~/miniforge3/envs/glounts/lib/python3.9/site-packages/gluonts/transform/_base.py:189\u001b[0m, in \u001b[0;36mFlatMapTransformation.__call__\u001b[0;34m(self, data_it, is_train)\u001b[0m\n\u001b[1;32m 182\u001b[0m \u001b[39myield\u001b[39;00m result\n\u001b[1;32m 184\u001b[0m \u001b[39mif\u001b[39;00m (\n\u001b[1;32m 185\u001b[0m \u001b[39m# negative values disable the check\u001b[39;00m\n\u001b[1;32m 186\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmax_idle_transforms \u001b[39m>\u001b[39m \u001b[39m0\u001b[39m\n\u001b[1;32m 187\u001b[0m \u001b[39mand\u001b[39;00m num_idle_transforms \u001b[39m>\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmax_idle_transforms\n\u001b[1;32m 188\u001b[0m ):\n\u001b[0;32m--> 189\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mException\u001b[39;00m(\n\u001b[1;32m 190\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mReached maximum number of idle transformation\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 191\u001b[0m \u001b[39m\"\u001b[39m\u001b[39m calls.\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39mThis means the transformation looped over\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 192\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m \u001b[39m\u001b[39m{\u001b[39;00m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmax_idle_transforms\u001b[39m}\u001b[39;00m\u001b[39m inputs without returning any\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 193\u001b[0m \u001b[39m\"\u001b[39m\u001b[39m output.\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39mThis occurred in the following\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 194\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m 
transformation:\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39m{\u001b[39;00m\u001b[39mself\u001b[39m\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m\n\u001b[1;32m 195\u001b[0m )\n", "\u001b[0;31mException\u001b[0m: Reached maximum number of idle transformation calls.\nThis means the transformation looped over 1 inputs without returning any output.\nThis occurred in the following transformation:\ngluonts.transform.split.InstanceSplitter(dummy_value=0.0, forecast_start_field=\"forecast_start\", future_length=24, instance_sampler=gluonts.transform.sampler.ExpectedNumInstanceSampler(axis=-1, min_past=192, min_future=24, num_instances=1.0, total_length=20382, n=3), is_pad_field=\"is_pad\", lead_time=0, output_NTC=True, past_length=192, start_field=\"start\", target_field=\"target\", time_series_fields=[\"time_feat\", \"observed_values\"])" ] } ], "source": [ "predictor = estimator.train(dataset_train)\n", "forecast_it, ts_it = make_evaluation_predictions(dataset=dataset_test,\n", " predictor=predictor,\n", " num_samples=100)\n", "forecasts = list(forecast_it)\n", "targets = list(ts_it)\n", "\n", "agg_metric, _ = evaluator(targets, forecasts, num_series=len(dataset_test))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Metrics" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CRPS: 0.36531966950112466\n", "ND: 0.45434020382814283\n", "NRMSE: 0.9820216603495642\n", "MSE: 914.7868680304274\n" ] } ], "source": [ "print(\"CRPS: {}\".format(agg_metric['mean_wQuantileLoss']))\n", "print(\"ND: {}\".format(agg_metric['ND']))\n", "print(\"NRMSE: {}\".format(agg_metric['NRMSE']))\n", "print(\"MSE: {}\".format(agg_metric['MSE']))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CRPS-Sum: 0.2873863376280519\n", "ND-Sum: 0.35970480888579265\n", "NRMSE-Sum: 0.7184166842326591\n", "MSE-Sum: 
9189074.285714285\n" ] } ], "source": [ "print(\"CRPS-Sum: {}\".format(agg_metric['m_sum_mean_wQuantileLoss']))\n", "print(\"ND-Sum: {}\".format(agg_metric['m_sum_ND']))\n", "print(\"NRMSE-Sum: {}\".format(agg_metric['m_sum_NRMSE']))\n", "print(\"MSE-Sum: {}\".format(agg_metric['m_sum_MSE']))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## `GRU-MAF`" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "estimator = TempFlowEstimator(\n", " target_dim=int(dataset.metadata.feat_static_cat[0].cardinality),\n", " prediction_length=dataset.metadata.prediction_length,\n", " cell_type='GRU',\n", " input_size=552,\n", " freq=dataset.metadata.freq,\n", " scaling=True,\n", " dequantize=True,\n", " flow_type='MAF',\n", " trainer=Trainer(device=device,\n", " epochs=25,\n", " learning_rate=1e-3,\n", " num_batches_per_epoch=100,\n", " batch_size=64)\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "98it [00:10, 9.05it/s, avg_epoch_loss=-7.36, epoch=0]\n", "99it [00:10, 9.19it/s, avg_epoch_loss=-136, epoch=1]\n", "99it [00:10, 9.12it/s, avg_epoch_loss=-164, epoch=2]\n", "98it [00:10, 8.91it/s, avg_epoch_loss=-179, epoch=3]\n", "98it [00:10, 9.09it/s, avg_epoch_loss=-188, epoch=4]\n", "99it [00:10, 9.05it/s, avg_epoch_loss=-194, epoch=5]\n", "98it [00:10, 9.04it/s, avg_epoch_loss=-198, epoch=6]\n", "98it [00:10, 8.97it/s, avg_epoch_loss=-201, epoch=7]\n", "97it [00:10, 8.90it/s, avg_epoch_loss=-204, epoch=8]\n", "99it [00:10, 9.07it/s, avg_epoch_loss=-206, epoch=9]\n", "99it [00:10, 9.09it/s, avg_epoch_loss=-207, epoch=10]\n", "98it [00:11, 8.90it/s, avg_epoch_loss=-209, epoch=11]\n", "99it [00:10, 9.02it/s, avg_epoch_loss=-210, epoch=12]\n", "98it [00:10, 8.95it/s, avg_epoch_loss=-211, epoch=13]\n", "99it [00:10, 9.21it/s, avg_epoch_loss=-212, epoch=14]\n", "98it [00:10, 9.00it/s, avg_epoch_loss=-213, epoch=15]\n", "99it 
[00:10, 9.21it/s, avg_epoch_loss=-214, epoch=16]\n", "98it [00:10, 8.95it/s, avg_epoch_loss=-215, epoch=17]\n", "98it [00:11, 8.88it/s, avg_epoch_loss=-216, epoch=18]\n", "99it [00:10, 9.08it/s, avg_epoch_loss=-216, epoch=19]\n", "98it [00:10, 8.96it/s, avg_epoch_loss=-217, epoch=20]\n", "98it [00:10, 8.98it/s, avg_epoch_loss=-218, epoch=21]\n", "97it [00:10, 8.88it/s, avg_epoch_loss=-218, epoch=22]\n", "97it [00:10, 8.83it/s, avg_epoch_loss=-219, epoch=23]\n", "98it [00:10, 8.97it/s, avg_epoch_loss=-219, epoch=24]\n", " 0%| | 0/137 [00:00