mirror of
https://github.com/wassname/pytorch-ts.git
synced 2026-06-27 19:32:05 +08:00
674929344b
for issue #12
1403 lines
67 KiB
Plaintext
1403 lines
67 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import numpy as np\n",
|
|
"import pandas as pd\n",
|
|
"\n",
|
|
"import torch"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from pts.dataset import to_pandas, MultivariateGrouper, TrainDatasets\n",
|
|
"from pts.dataset.repository import get_dataset, dataset_recipes\n",
|
|
"from pts.model.tempflow import TempFlowEstimator\n",
|
|
"from pts.model.transformer_tempflow import TransformerTempFlowEstimator\n",
|
|
"from pts import Trainer\n",
|
|
"from pts.evaluation import make_evaluation_predictions\n",
|
|
"from pts.evaluation import MultivariateEvaluator"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"device = torch.device(\"cuda:3\" if torch.cuda.is_available() else \"cpu\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Prepare data set"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"dataset = get_dataset(\"solar_nips\", regenerate=False)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"MetaData(freq='H', target=None, feat_static_cat=[CategoricalFeatureInfo(name='feat_static_cat', cardinality='137')], feat_static_real=[], feat_dynamic_real=[], feat_dynamic_cat=[], prediction_length=24)"
|
|
]
|
|
},
|
|
"execution_count": 5,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"dataset.metadata"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"train_grouper = MultivariateGrouper(max_target_dim=int(dataset.metadata.feat_static_cat[0].cardinality))\n",
|
|
"\n",
|
|
"test_grouper = MultivariateGrouper(num_test_dates=int(len(dataset.test)/len(dataset.train)), \n",
|
|
" max_target_dim=int(dataset.metadata.feat_static_cat[0].cardinality))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"dataset_train = train_grouper(dataset.train)\n",
|
|
"dataset_test = test_grouper(dataset.test)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Evaluator"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"evaluator = MultivariateEvaluator(quantiles=(np.arange(20)/20.0)[1:],\n",
|
|
" target_agg_funcs={'sum': np.sum})"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## `GRU-Real-NVP`"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 45,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"estimator = TempFlowEstimator(\n",
|
|
" target_dim=int(dataset.metadata.feat_static_cat[0].cardinality),\n",
|
|
" prediction_length=dataset.metadata.prediction_length,\n",
|
|
" cell_type='GRU',\n",
|
|
" input_size=552,\n",
|
|
" freq=dataset.metadata.freq,\n",
|
|
" scaling=True,\n",
|
|
" dequantize=True,\n",
|
|
" n_blocks=4,\n",
|
|
" trainer=Trainer(device=device,\n",
|
|
" epochs=45,\n",
|
|
" learning_rate=1e-3,\n",
|
|
" num_batches_per_epoch=100,\n",
|
|
" batch_size=64)\n",
|
|
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 46,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"99it [00:10, 9.03it/s, avg_epoch_loss=-43.1, epoch=0]\n",
|
|
"99it [00:10, 9.03it/s, avg_epoch_loss=-126, epoch=1]\n",
|
|
"99it [00:11, 9.00it/s, avg_epoch_loss=-142, epoch=2]\n",
|
|
"99it [00:10, 9.37it/s, avg_epoch_loss=-143, epoch=3]\n",
|
|
"99it [00:10, 9.09it/s, avg_epoch_loss=-153, epoch=4]\n",
|
|
"99it [00:11, 8.76it/s, avg_epoch_loss=-157, epoch=5]\n",
|
|
"99it [00:10, 9.03it/s, avg_epoch_loss=-157, epoch=6]\n",
|
|
"99it [00:11, 8.94it/s, avg_epoch_loss=-166, epoch=7]\n",
|
|
"99it [00:11, 8.56it/s, avg_epoch_loss=-169, epoch=8]\n",
|
|
"99it [00:11, 8.84it/s, avg_epoch_loss=-168, epoch=9]\n",
|
|
"98it [00:11, 8.89it/s, avg_epoch_loss=-170, epoch=10]\n",
|
|
"99it [00:11, 8.89it/s, avg_epoch_loss=-172, epoch=11]\n",
|
|
"98it [00:10, 9.00it/s, avg_epoch_loss=-172, epoch=12]\n",
|
|
"99it [00:10, 9.02it/s, avg_epoch_loss=-177, epoch=13]\n",
|
|
"99it [00:10, 9.48it/s, avg_epoch_loss=-180, epoch=14]\n",
|
|
"98it [00:10, 9.65it/s, avg_epoch_loss=-180, epoch=15]\n",
|
|
"99it [00:10, 9.01it/s, avg_epoch_loss=-182, epoch=16]\n",
|
|
"99it [00:10, 9.11it/s, avg_epoch_loss=-182, epoch=17]\n",
|
|
"99it [00:10, 9.02it/s, avg_epoch_loss=-182, epoch=18]\n",
|
|
"98it [00:11, 8.89it/s, avg_epoch_loss=-182, epoch=19]\n",
|
|
"99it [00:10, 9.01it/s, avg_epoch_loss=-179, epoch=20]\n",
|
|
"99it [00:10, 9.15it/s, avg_epoch_loss=-183, epoch=21]\n",
|
|
"99it [00:11, 8.96it/s, avg_epoch_loss=-188, epoch=22]\n",
|
|
"99it [00:11, 8.96it/s, avg_epoch_loss=-188, epoch=23]\n",
|
|
"99it [00:10, 9.04it/s, avg_epoch_loss=-190, epoch=24]\n",
|
|
"98it [00:11, 8.85it/s, avg_epoch_loss=-193, epoch=25]\n",
|
|
"98it [00:10, 8.95it/s, avg_epoch_loss=-193, epoch=26]\n",
|
|
"99it [00:11, 8.93it/s, avg_epoch_loss=-192, epoch=27]\n",
|
|
"99it [00:10, 9.06it/s, avg_epoch_loss=-193, epoch=28]\n",
|
|
"98it [00:10, 8.97it/s, avg_epoch_loss=-193, epoch=29]\n",
|
|
"99it [00:11, 8.95it/s, avg_epoch_loss=-196, epoch=30]\n",
|
|
"98it [00:10, 8.95it/s, avg_epoch_loss=-193, epoch=31]\n",
|
|
"99it [00:10, 9.05it/s, avg_epoch_loss=-192, epoch=32]\n",
|
|
"99it [00:11, 8.90it/s, avg_epoch_loss=-197, epoch=33]\n",
|
|
"99it [00:11, 8.93it/s, avg_epoch_loss=-198, epoch=34]\n",
|
|
"98it [00:11, 8.85it/s, avg_epoch_loss=-197, epoch=35]\n",
|
|
"99it [00:10, 9.01it/s, avg_epoch_loss=-198, epoch=36]\n",
|
|
"98it [00:10, 8.97it/s, avg_epoch_loss=-200, epoch=37]\n",
|
|
"98it [00:11, 8.85it/s, avg_epoch_loss=-199, epoch=38]\n",
|
|
"99it [00:10, 9.04it/s, avg_epoch_loss=-197, epoch=39]\n",
|
|
"99it [00:11, 8.97it/s, avg_epoch_loss=-199, epoch=40]\n",
|
|
"99it [00:11, 8.88it/s, avg_epoch_loss=-201, epoch=41]\n",
|
|
"98it [00:11, 8.90it/s, avg_epoch_loss=-201, epoch=42]\n",
|
|
"99it [00:10, 9.09it/s, avg_epoch_loss=-202, epoch=43]\n",
|
|
"98it [00:10, 8.93it/s, avg_epoch_loss=-199, epoch=44]\n",
|
|
" 0%| | 0/137 [00:00<?, ?it/s]\n",
|
|
"Running evaluation: 7it [00:00, 79.28it/s]\n",
|
|
" 1%| | 1/137 [00:00<00:14, 9.68it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.80it/s]\n",
|
|
" 1%|▏ | 2/137 [00:00<00:13, 9.72it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.36it/s]\n",
|
|
" 2%|▏ | 3/137 [00:00<00:13, 9.77it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.19it/s]\n",
|
|
" 3%|▎ | 4/137 [00:00<00:13, 9.80it/s]\n",
|
|
"Running evaluation: 7it [00:00, 82.08it/s]\n",
|
|
" 4%|▎ | 5/137 [00:00<00:13, 9.85it/s]\n",
|
|
"Running evaluation: 7it [00:00, 77.88it/s]\n",
|
|
" 4%|▍ | 6/137 [00:00<00:13, 9.72it/s]\n",
|
|
"Running evaluation: 7it [00:00, 78.87it/s]\n",
|
|
" 5%|▌ | 7/137 [00:00<00:13, 9.65it/s]\n",
|
|
"Running evaluation: 7it [00:00, 77.08it/s]\n",
|
|
" 6%|▌ | 8/137 [00:00<00:13, 9.55it/s]\n",
|
|
"Running evaluation: 7it [00:00, 76.53it/s]\n",
|
|
" 7%|▋ | 9/137 [00:00<00:13, 9.43it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.35it/s]\n",
|
|
" 7%|▋ | 10/137 [00:01<00:13, 9.52it/s]\n",
|
|
"Running evaluation: 7it [00:00, 78.39it/s]\n",
|
|
" 8%|▊ | 11/137 [00:01<00:13, 9.54it/s]\n",
|
|
"Running evaluation: 7it [00:00, 77.60it/s]\n",
|
|
" 9%|▉ | 12/137 [00:01<00:13, 9.53it/s]\n",
|
|
"Running evaluation: 7it [00:00, 79.24it/s]\n",
|
|
" 9%|▉ | 13/137 [00:01<00:13, 9.49it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.02it/s]\n",
|
|
" 10%|█ | 14/137 [00:01<00:12, 9.57it/s]\n",
|
|
"Running evaluation: 7it [00:00, 78.62it/s]\n",
|
|
" 11%|█ | 15/137 [00:01<00:12, 9.55it/s]\n",
|
|
"Running evaluation: 7it [00:00, 78.47it/s]\n",
|
|
" 12%|█▏ | 16/137 [00:01<00:12, 9.52it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.37it/s]\n",
|
|
" 12%|█▏ | 17/137 [00:01<00:12, 9.59it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.06it/s]\n",
|
|
" 13%|█▎ | 18/137 [00:01<00:12, 9.67it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.19it/s]\n",
|
|
" 14%|█▍ | 19/137 [00:01<00:12, 9.73it/s]\n",
|
|
"Running evaluation: 7it [00:00, 82.05it/s]\n",
|
|
" 15%|█▍ | 20/137 [00:02<00:11, 9.80it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.38it/s]\n",
|
|
" 15%|█▌ | 21/137 [00:02<00:11, 9.83it/s]\n",
|
|
"Running evaluation: 7it [00:00, 79.80it/s]\n",
|
|
" 16%|█▌ | 22/137 [00:02<00:11, 9.80it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.03it/s]\n",
|
|
" 17%|█▋ | 23/137 [00:02<00:11, 9.82it/s]\n",
|
|
"Running evaluation: 7it [00:00, 79.49it/s]\n",
|
|
" 18%|█▊ | 24/137 [00:02<00:11, 9.74it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.46it/s]\n",
|
|
" 18%|█▊ | 25/137 [00:02<00:11, 9.76it/s]\n",
|
|
"Running evaluation: 7it [00:00, 83.16it/s]\n",
|
|
" 19%|█▉ | 26/137 [00:02<00:11, 9.81it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.61it/s]\n",
|
|
" 20%|█▉ | 27/137 [00:02<00:11, 9.84it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.16it/s]\n",
|
|
" 20%|██ | 28/137 [00:02<00:11, 9.85it/s]\n",
|
|
"Running evaluation: 7it [00:00, 73.60it/s]\n",
|
|
" 21%|██ | 29/137 [00:02<00:11, 9.59it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.64it/s]\n",
|
|
" 22%|██▏ | 30/137 [00:03<00:11, 9.68it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.16it/s]\n",
|
|
" 23%|██▎ | 31/137 [00:03<00:10, 9.72it/s]\n",
|
|
"Running evaluation: 7it [00:00, 82.76it/s]\n",
|
|
"\n",
|
|
"Running evaluation: 7it [00:00, 80.90it/s]\n",
|
|
" 24%|██▍ | 33/137 [00:03<00:10, 9.78it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.28it/s]\n",
|
|
" 25%|██▍ | 34/137 [00:03<00:10, 9.78it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.43it/s]\n",
|
|
" 26%|██▌ | 35/137 [00:03<00:10, 9.78it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.23it/s]\n",
|
|
" 26%|██▋ | 36/137 [00:03<00:10, 9.81it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.30it/s]\n",
|
|
" 27%|██▋ | 37/137 [00:03<00:10, 9.84it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.99it/s]\n",
|
|
" 28%|██▊ | 38/137 [00:03<00:10, 9.87it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.61it/s]\n",
|
|
" 28%|██▊ | 39/137 [00:04<00:09, 9.82it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.20it/s]\n",
|
|
" 29%|██▉ | 40/137 [00:04<00:09, 9.83it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.11it/s]\n",
|
|
" 30%|██▉ | 41/137 [00:04<00:09, 9.84it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.50it/s]\n",
|
|
" 31%|███ | 42/137 [00:04<00:09, 9.87it/s]\n",
|
|
"Running evaluation: 7it [00:00, 82.13it/s]\n",
|
|
" 31%|███▏ | 43/137 [00:04<00:09, 9.89it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.91it/s]\n",
|
|
" 32%|███▏ | 44/137 [00:04<00:09, 9.88it/s]\n",
|
|
"Running evaluation: 7it [00:00, 79.62it/s]\n",
|
|
" 33%|███▎ | 45/137 [00:04<00:09, 9.80it/s]\n",
|
|
"Running evaluation: 7it [00:00, 78.32it/s]\n",
|
|
" 34%|███▎ | 46/137 [00:04<00:09, 9.70it/s]\n",
|
|
"Running evaluation: 7it [00:00, 83.13it/s]\n",
|
|
" 34%|███▍ | 47/137 [00:04<00:09, 9.77it/s]\n",
|
|
"Running evaluation: 7it [00:00, 73.63it/s]\n",
|
|
" 35%|███▌ | 48/137 [00:04<00:09, 9.50it/s]\n",
|
|
"Running evaluation: 7it [00:00, 76.64it/s]\n",
|
|
" 36%|███▌ | 49/137 [00:05<00:09, 9.42it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.82it/s]\n",
|
|
" 36%|███▋ | 50/137 [00:05<00:09, 9.52it/s]\n",
|
|
"Running evaluation: 7it [00:00, 75.39it/s]\n",
|
|
" 37%|███▋ | 51/137 [00:05<00:09, 9.40it/s]\n",
|
|
"Running evaluation: 7it [00:00, 77.91it/s]\n",
|
|
" 38%|███▊ | 52/137 [00:05<00:09, 9.41it/s]\n",
|
|
"Running evaluation: 7it [00:00, 78.21it/s]\n",
|
|
" 39%|███▊ | 53/137 [00:05<00:08, 9.45it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.29it/s]\n",
|
|
" 39%|███▉ | 54/137 [00:05<00:08, 9.53it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.16it/s]\n",
|
|
" 40%|████ | 55/137 [00:05<00:08, 9.63it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.02it/s]\n",
|
|
" 41%|████ | 56/137 [00:05<00:08, 9.69it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.13it/s]\n",
|
|
" 42%|████▏ | 57/137 [00:05<00:08, 9.74it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.67it/s]\n",
|
|
" 42%|████▏ | 58/137 [00:05<00:08, 9.60it/s]\n",
|
|
"Running evaluation: 7it [00:00, 79.86it/s]\n",
|
|
" 43%|████▎ | 59/137 [00:06<00:08, 9.63it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.31it/s]\n",
|
|
" 44%|████▍ | 60/137 [00:06<00:07, 9.70it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.42it/s]\n",
|
|
" 45%|████▍ | 61/137 [00:06<00:07, 9.72it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.96it/s]\n",
|
|
" 45%|████▌ | 62/137 [00:06<00:07, 9.76it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Running evaluation: 7it [00:00, 80.15it/s]\n",
|
|
" 46%|████▌ | 63/137 [00:06<00:07, 9.76it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.72it/s]\n",
|
|
" 47%|████▋ | 64/137 [00:06<00:07, 9.77it/s]\n",
|
|
"Running evaluation: 7it [00:00, 78.85it/s]\n",
|
|
" 47%|████▋ | 65/137 [00:06<00:07, 9.69it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.44it/s]\n",
|
|
" 48%|████▊ | 66/137 [00:06<00:07, 9.73it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.88it/s]\n",
|
|
" 49%|████▉ | 67/137 [00:06<00:07, 9.74it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.47it/s]\n",
|
|
" 50%|████▉ | 68/137 [00:07<00:07, 9.79it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.62it/s]\n",
|
|
" 50%|█████ | 69/137 [00:07<00:06, 9.80it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.87it/s]\n",
|
|
" 51%|█████ | 70/137 [00:07<00:06, 9.79it/s]\n",
|
|
"Running evaluation: 7it [00:00, 77.90it/s]\n",
|
|
" 52%|█████▏ | 71/137 [00:07<00:06, 9.69it/s]\n",
|
|
"Running evaluation: 7it [00:00, 77.50it/s]\n",
|
|
" 53%|█████▎ | 72/137 [00:07<00:06, 9.59it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.34it/s]\n",
|
|
" 53%|█████▎ | 73/137 [00:07<00:06, 9.62it/s]\n",
|
|
"Running evaluation: 7it [00:00, 79.45it/s]\n",
|
|
" 54%|█████▍ | 74/137 [00:07<00:06, 9.64it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.21it/s]\n",
|
|
" 55%|█████▍ | 75/137 [00:07<00:06, 9.71it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.50it/s]\n",
|
|
" 55%|█████▌ | 76/137 [00:07<00:06, 9.73it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.01it/s]\n",
|
|
" 56%|█████▌ | 77/137 [00:07<00:06, 9.76it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.26it/s]\n",
|
|
" 57%|█████▋ | 78/137 [00:08<00:06, 9.80it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.99it/s]\n",
|
|
" 58%|█████▊ | 79/137 [00:08<00:05, 9.81it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.81it/s]\n",
|
|
" 58%|█████▊ | 80/137 [00:08<00:05, 9.77it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.21it/s]\n",
|
|
" 59%|█████▉ | 81/137 [00:08<00:05, 9.79it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.45it/s]\n",
|
|
" 60%|█████▉ | 82/137 [00:08<00:05, 9.82it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.25it/s]\n",
|
|
" 61%|██████ | 83/137 [00:08<00:05, 9.83it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.13it/s]\n",
|
|
" 61%|██████▏ | 84/137 [00:08<00:05, 9.81it/s]\n",
|
|
"Running evaluation: 7it [00:00, 79.15it/s]\n",
|
|
" 62%|██████▏ | 85/137 [00:08<00:05, 9.75it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.62it/s]\n",
|
|
" 63%|██████▎ | 86/137 [00:08<00:05, 9.77it/s]\n",
|
|
"Running evaluation: 7it [00:00, 82.72it/s]\n",
|
|
"\n",
|
|
"Running evaluation: 7it [00:00, 82.62it/s]\n",
|
|
" 64%|██████▍ | 88/137 [00:09<00:04, 9.85it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.59it/s]\n",
|
|
" 65%|██████▍ | 89/137 [00:09<00:04, 9.88it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.62it/s]\n",
|
|
" 66%|██████▌ | 90/137 [00:09<00:04, 9.86it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.38it/s]\n",
|
|
" 66%|██████▋ | 91/137 [00:09<00:04, 9.88it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.71it/s]\n",
|
|
" 67%|██████▋ | 92/137 [00:09<00:04, 9.89it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.85it/s]\n",
|
|
" 68%|██████▊ | 93/137 [00:09<00:04, 9.87it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.08it/s]\n",
|
|
" 69%|██████▊ | 94/137 [00:09<00:04, 9.86it/s]\n",
|
|
"Running evaluation: 7it [00:00, 82.68it/s]\n",
|
|
" 69%|██████▉ | 95/137 [00:09<00:04, 9.88it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.35it/s]\n",
|
|
" 70%|███████ | 96/137 [00:09<00:04, 9.85it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.36it/s]\n",
|
|
" 71%|███████ | 97/137 [00:09<00:04, 9.84it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.28it/s]\n",
|
|
" 72%|███████▏ | 98/137 [00:10<00:03, 9.85it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.06it/s]\n",
|
|
" 72%|███████▏ | 99/137 [00:10<00:03, 9.85it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.01it/s]\n",
|
|
" 73%|███████▎ | 100/137 [00:10<00:03, 9.82it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.70it/s]\n",
|
|
" 74%|███████▎ | 101/137 [00:10<00:03, 9.85it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.67it/s]\n",
|
|
" 74%|███████▍ | 102/137 [00:10<00:03, 9.84it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.49it/s]\n",
|
|
" 75%|███████▌ | 103/137 [00:10<00:03, 9.82it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.27it/s]\n",
|
|
" 76%|███████▌ | 104/137 [00:10<00:03, 9.84it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.05it/s]\n",
|
|
" 77%|███████▋ | 105/137 [00:10<00:03, 9.85it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.27it/s]\n",
|
|
" 77%|███████▋ | 106/137 [00:10<00:03, 9.83it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.00it/s]\n",
|
|
" 78%|███████▊ | 107/137 [00:10<00:03, 9.79it/s]\n",
|
|
"Running evaluation: 7it [00:00, 77.18it/s]\n",
|
|
" 79%|███████▉ | 108/137 [00:11<00:03, 9.65it/s]\n",
|
|
"Running evaluation: 7it [00:00, 79.07it/s]\n",
|
|
" 80%|███████▉ | 109/137 [00:11<00:02, 9.63it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.43it/s]\n",
|
|
" 80%|████████ | 110/137 [00:11<00:02, 9.67it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.85it/s]\n",
|
|
" 81%|████████ | 111/137 [00:11<00:02, 9.74it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.47it/s]\n",
|
|
" 82%|████████▏ | 112/137 [00:11<00:02, 9.77it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.02it/s]\n",
|
|
" 82%|████████▏ | 113/137 [00:11<00:02, 9.76it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.52it/s]\n",
|
|
" 83%|████████▎ | 114/137 [00:11<00:02, 9.80it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.42it/s]\n",
|
|
" 84%|████████▍ | 115/137 [00:11<00:02, 9.83it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.95it/s]\n",
|
|
" 85%|████████▍ | 116/137 [00:11<00:02, 9.84it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.78it/s]\n",
|
|
" 85%|████████▌ | 117/137 [00:12<00:02, 9.85it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.33it/s]\n",
|
|
" 86%|████████▌ | 118/137 [00:12<00:01, 9.85it/s]\n",
|
|
"Running evaluation: 7it [00:00, 82.25it/s]\n",
|
|
" 87%|████████▋ | 119/137 [00:12<00:01, 9.86it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.09it/s]\n",
|
|
" 88%|████████▊ | 120/137 [00:12<00:02, 7.50it/s]\n",
|
|
"Running evaluation: 7it [00:00, 79.78it/s]\n",
|
|
" 88%|████████▊ | 121/137 [00:12<00:01, 8.04it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.03it/s]\n",
|
|
" 89%|████████▉ | 122/137 [00:12<00:01, 8.52it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.17it/s]\n",
|
|
" 90%|████████▉ | 123/137 [00:12<00:01, 8.89it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.38it/s]\n",
|
|
" 91%|█████████ | 124/137 [00:12<00:01, 9.17it/s]\n",
|
|
"Running evaluation: 7it [00:00, 72.26it/s]\n",
|
|
" 91%|█████████ | 125/137 [00:12<00:01, 9.07it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.07it/s]\n",
|
|
" 92%|█████████▏| 126/137 [00:13<00:01, 9.27it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.94it/s]\n",
|
|
" 93%|█████████▎| 127/137 [00:13<00:01, 9.44it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.67it/s]\n",
|
|
" 93%|█████████▎| 128/137 [00:13<00:00, 9.57it/s]\n",
|
|
"Running evaluation: 7it [00:00, 84.00it/s]\n",
|
|
"\n",
|
|
"Running evaluation: 7it [00:00, 82.48it/s]\n",
|
|
" 95%|█████████▍| 130/137 [00:13<00:00, 9.72it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.11it/s]\n",
|
|
" 96%|█████████▌| 131/137 [00:13<00:00, 9.77it/s]\n",
|
|
"Running evaluation: 7it [00:00, 79.90it/s]\n",
|
|
" 96%|█████████▋| 132/137 [00:13<00:00, 9.76it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.81it/s]\n",
|
|
" 97%|█████████▋| 133/137 [00:13<00:00, 9.80it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.61it/s]\n",
|
|
" 98%|█████████▊| 134/137 [00:13<00:00, 9.84it/s]\n",
|
|
"Running evaluation: 7it [00:00, 82.01it/s]\n",
|
|
" 99%|█████████▊| 135/137 [00:13<00:00, 9.88it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.80it/s]\n",
|
|
" 99%|█████████▉| 136/137 [00:14<00:00, 9.90it/s]\n",
|
|
"Running evaluation: 7it [00:00, 77.76it/s]\n",
|
|
"100%|██████████| 137/137 [00:14<00:00, 9.68it/s]\n",
|
|
"Running evaluation: 7it [00:00, 60.08it/s]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"predictor = estimator.train(dataset_train)\n",
|
|
"forecast_it, ts_it = make_evaluation_predictions(dataset=dataset_test,\n",
|
|
" predictor=predictor,\n",
|
|
" num_samples=100)\n",
|
|
"forecasts = list(forecast_it)\n",
|
|
"targets = list(ts_it)\n",
|
|
"\n",
|
|
"agg_metric, _ = evaluator(targets, forecasts, num_series=len(dataset_test))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Metrics"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 47,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"CRPS: 0.36531966950112466\n",
|
|
"ND: 0.45434020382814283\n",
|
|
"NRMSE: 0.9820216603495642\n",
|
|
"MSE: 914.7868680304274\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"print(\"CRPS: {}\".format(agg_metric['mean_wQuantileLoss']))\n",
|
|
"print(\"ND: {}\".format(agg_metric['ND']))\n",
|
|
"print(\"NRMSE: {}\".format(agg_metric['NRMSE']))\n",
|
|
"print(\"MSE: {}\".format(agg_metric['MSE']))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 48,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"CRPS-Sum: 0.2873863376280519\n",
|
|
"ND-Sum: 0.35970480888579265\n",
|
|
"NRMSE-Sum: 0.7184166842326591\n",
|
|
"MSE-Sum: 9189074.285714285\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"print(\"CRPS-Sum: {}\".format(agg_metric['m_sum_mean_wQuantileLoss']))\n",
|
|
"print(\"ND-Sum: {}\".format(agg_metric['m_sum_ND']))\n",
|
|
"print(\"NRMSE-Sum: {}\".format(agg_metric['m_sum_NRMSE']))\n",
|
|
"print(\"MSE-Sum: {}\".format(agg_metric['m_sum_MSE']))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## `GRU-MAF`"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 17,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"estimator = TempFlowEstimator(\n",
|
|
" target_dim=int(dataset.metadata.feat_static_cat[0].cardinality),\n",
|
|
" prediction_length=dataset.metadata.prediction_length,\n",
|
|
" cell_type='GRU',\n",
|
|
" input_size=552,\n",
|
|
" freq=dataset.metadata.freq,\n",
|
|
" scaling=True,\n",
|
|
" dequantize=True,\n",
|
|
" flow_type='MAF',\n",
|
|
" trainer=Trainer(device=device,\n",
|
|
" epochs=25,\n",
|
|
" learning_rate=1e-3,\n",
|
|
" num_batches_per_epoch=100,\n",
|
|
" batch_size=64)\n",
|
|
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 18,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"98it [00:10, 9.05it/s, avg_epoch_loss=-7.36, epoch=0]\n",
|
|
"99it [00:10, 9.19it/s, avg_epoch_loss=-136, epoch=1]\n",
|
|
"99it [00:10, 9.12it/s, avg_epoch_loss=-164, epoch=2]\n",
|
|
"98it [00:10, 8.91it/s, avg_epoch_loss=-179, epoch=3]\n",
|
|
"98it [00:10, 9.09it/s, avg_epoch_loss=-188, epoch=4]\n",
|
|
"99it [00:10, 9.05it/s, avg_epoch_loss=-194, epoch=5]\n",
|
|
"98it [00:10, 9.04it/s, avg_epoch_loss=-198, epoch=6]\n",
|
|
"98it [00:10, 8.97it/s, avg_epoch_loss=-201, epoch=7]\n",
|
|
"97it [00:10, 8.90it/s, avg_epoch_loss=-204, epoch=8]\n",
|
|
"99it [00:10, 9.07it/s, avg_epoch_loss=-206, epoch=9]\n",
|
|
"99it [00:10, 9.09it/s, avg_epoch_loss=-207, epoch=10]\n",
|
|
"98it [00:11, 8.90it/s, avg_epoch_loss=-209, epoch=11]\n",
|
|
"99it [00:10, 9.02it/s, avg_epoch_loss=-210, epoch=12]\n",
|
|
"98it [00:10, 8.95it/s, avg_epoch_loss=-211, epoch=13]\n",
|
|
"99it [00:10, 9.21it/s, avg_epoch_loss=-212, epoch=14]\n",
|
|
"98it [00:10, 9.00it/s, avg_epoch_loss=-213, epoch=15]\n",
|
|
"99it [00:10, 9.21it/s, avg_epoch_loss=-214, epoch=16]\n",
|
|
"98it [00:10, 8.95it/s, avg_epoch_loss=-215, epoch=17]\n",
|
|
"98it [00:11, 8.88it/s, avg_epoch_loss=-216, epoch=18]\n",
|
|
"99it [00:10, 9.08it/s, avg_epoch_loss=-216, epoch=19]\n",
|
|
"98it [00:10, 8.96it/s, avg_epoch_loss=-217, epoch=20]\n",
|
|
"98it [00:10, 8.98it/s, avg_epoch_loss=-218, epoch=21]\n",
|
|
"97it [00:10, 8.88it/s, avg_epoch_loss=-218, epoch=22]\n",
|
|
"97it [00:10, 8.83it/s, avg_epoch_loss=-219, epoch=23]\n",
|
|
"98it [00:10, 8.97it/s, avg_epoch_loss=-219, epoch=24]\n",
|
|
" 0%| | 0/137 [00:00<?, ?it/s]\n",
|
|
"Running evaluation: 7it [00:00, 79.30it/s]\n",
|
|
" 1%| | 1/137 [00:00<00:14, 9.67it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.37it/s]\n",
|
|
" 1%|▏ | 2/137 [00:00<00:13, 9.69it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.99it/s]\n",
|
|
" 2%|▏ | 3/137 [00:00<00:13, 9.72it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.16it/s]\n",
|
|
" 3%|▎ | 4/137 [00:00<00:13, 9.74it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.87it/s]\n",
|
|
" 4%|▎ | 5/137 [00:00<00:15, 8.38it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.00it/s]\n",
|
|
" 4%|▍ | 6/137 [00:00<00:15, 8.73it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.20it/s]\n",
|
|
" 5%|▌ | 7/137 [00:00<00:14, 9.02it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.10it/s]\n",
|
|
" 6%|▌ | 8/137 [00:00<00:13, 9.26it/s]\n",
|
|
"Running evaluation: 7it [00:00, 78.81it/s]\n",
|
|
" 7%|▋ | 9/137 [00:00<00:13, 9.37it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.94it/s]\n",
|
|
" 7%|▋ | 10/137 [00:01<00:13, 9.51it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.86it/s]\n",
|
|
" 8%|▊ | 11/137 [00:01<00:13, 9.61it/s]\n",
|
|
"Running evaluation: 7it [00:00, 79.54it/s]\n",
|
|
" 9%|▉ | 12/137 [00:01<00:12, 9.65it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.68it/s]\n",
|
|
" 9%|▉ | 13/137 [00:01<00:12, 9.70it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.79it/s]\n",
|
|
" 10%|█ | 14/137 [00:01<00:12, 9.67it/s]\n",
|
|
"Running evaluation: 7it [00:00, 79.82it/s]\n",
|
|
" 11%|█ | 15/137 [00:01<00:12, 9.69it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.53it/s]\n",
|
|
" 12%|█▏ | 16/137 [00:01<00:12, 9.73it/s]\n",
|
|
"Running evaluation: 7it [00:00, 76.42it/s]\n",
|
|
" 12%|█▏ | 17/137 [00:01<00:12, 9.59it/s]\n",
|
|
"Running evaluation: 7it [00:00, 79.49it/s]\n",
|
|
" 13%|█▎ | 18/137 [00:01<00:12, 9.61it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.33it/s]\n",
|
|
" 14%|█▍ | 19/137 [00:02<00:12, 9.66it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.06it/s]\n",
|
|
" 15%|█▍ | 20/137 [00:02<00:12, 9.70it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.93it/s]\n",
|
|
" 15%|█▌ | 21/137 [00:02<00:11, 9.74it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.20it/s]\n",
|
|
" 16%|█▌ | 22/137 [00:02<00:11, 9.78it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.83it/s]\n",
|
|
" 17%|█▋ | 23/137 [00:02<00:11, 9.79it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.33it/s]\n",
|
|
" 18%|█▊ | 24/137 [00:02<00:11, 9.82it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.91it/s]\n",
|
|
" 18%|█▊ | 25/137 [00:02<00:11, 9.83it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.40it/s]\n",
|
|
" 19%|█▉ | 26/137 [00:02<00:11, 9.85it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.20it/s]\n",
|
|
" 20%|█▉ | 27/137 [00:02<00:11, 9.86it/s]\n",
|
|
"Running evaluation: 7it [00:00, 82.38it/s]\n",
|
|
" 20%|██ | 28/137 [00:02<00:11, 9.90it/s]\n",
|
|
"Running evaluation: 7it [00:00, 82.08it/s]\n",
|
|
" 21%|██ | 29/137 [00:03<00:10, 9.92it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.48it/s]\n",
|
|
" 22%|██▏ | 30/137 [00:03<00:10, 9.91it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.91it/s]\n",
|
|
" 23%|██▎ | 31/137 [00:03<00:10, 9.89it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.74it/s]\n",
|
|
" 23%|██▎ | 32/137 [00:03<00:10, 9.89it/s]\n",
|
|
"Running evaluation: 7it [00:00, 82.56it/s]\n",
|
|
" 24%|██▍ | 33/137 [00:03<00:10, 9.91it/s]\n",
|
|
"Running evaluation: 7it [00:00, 82.94it/s]\n",
|
|
"\n",
|
|
"Running evaluation: 7it [00:00, 81.87it/s]\n",
|
|
" 26%|██▌ | 35/137 [00:03<00:10, 9.93it/s]\n",
|
|
"Running evaluation: 7it [00:00, 82.36it/s]\n",
|
|
"\n",
|
|
"Running evaluation: 7it [00:00, 82.48it/s]\n",
|
|
" 27%|██▋ | 37/137 [00:03<00:10, 9.96it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.42it/s]\n",
|
|
" 28%|██▊ | 38/137 [00:03<00:09, 9.93it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.16it/s]\n",
|
|
" 28%|██▊ | 39/137 [00:04<00:09, 9.91it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.18it/s]\n",
|
|
" 29%|██▉ | 40/137 [00:04<00:09, 9.86it/s]\n",
|
|
"Running evaluation: 7it [00:00, 79.44it/s]\n",
|
|
" 30%|██▉ | 41/137 [00:04<00:09, 9.77it/s]\n",
|
|
"Running evaluation: 7it [00:00, 78.49it/s]\n",
|
|
" 31%|███ | 42/137 [00:04<00:09, 9.71it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.14it/s]\n",
|
|
" 31%|███▏ | 43/137 [00:04<00:09, 9.76it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.89it/s]\n",
|
|
" 32%|███▏ | 44/137 [00:04<00:09, 9.77it/s]\n",
|
|
"Running evaluation: 7it [00:00, 79.56it/s]\n",
|
|
" 33%|███▎ | 45/137 [00:04<00:09, 9.75it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.37it/s]\n",
|
|
" 34%|███▎ | 46/137 [00:04<00:09, 9.77it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.06it/s]\n",
|
|
" 34%|███▍ | 47/137 [00:04<00:09, 9.80it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.00it/s]\n",
|
|
" 35%|███▌ | 48/137 [00:04<00:09, 9.82it/s]\n",
|
|
"Running evaluation: 7it [00:00, 79.59it/s]\n",
|
|
" 36%|███▌ | 49/137 [00:05<00:08, 9.78it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.16it/s]\n",
|
|
" 36%|███▋ | 50/137 [00:05<00:08, 9.78it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.13it/s]\n",
|
|
" 37%|███▋ | 51/137 [00:05<00:08, 9.81it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.09it/s]\n",
|
|
" 38%|███▊ | 52/137 [00:05<00:08, 9.80it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.05it/s]\n",
|
|
" 39%|███▊ | 53/137 [00:05<00:08, 9.82it/s]\n",
|
|
"Running evaluation: 7it [00:00, 79.80it/s]\n",
|
|
" 39%|███▉ | 54/137 [00:05<00:08, 9.80it/s]\n",
|
|
"Running evaluation: 7it [00:00, 79.99it/s]\n",
|
|
" 40%|████ | 55/137 [00:05<00:08, 9.79it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.74it/s]\n",
|
|
" 41%|████ | 56/137 [00:05<00:08, 9.80it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.38it/s]\n",
|
|
" 42%|████▏ | 57/137 [00:05<00:08, 9.80it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.81it/s]\n",
|
|
" 42%|████▏ | 58/137 [00:05<00:08, 9.81it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.07it/s]\n",
|
|
" 43%|████▎ | 59/137 [00:06<00:07, 9.80it/s]\n",
|
|
"Running evaluation: 7it [00:00, 79.89it/s]\n",
|
|
" 44%|████▍ | 60/137 [00:06<00:07, 9.79it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.78it/s]\n",
|
|
" 45%|████▍ | 61/137 [00:06<00:07, 9.81it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.28it/s]\n",
|
|
" 45%|████▌ | 62/137 [00:06<00:07, 9.80it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.50it/s]\n",
|
|
" 46%|████▌ | 63/137 [00:06<00:07, 9.81it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.07it/s]\n",
|
|
" 47%|████▋ | 64/137 [00:06<00:07, 9.76it/s]\n",
|
|
"Running evaluation: 7it [00:00, 77.93it/s]\n",
|
|
" 47%|████▋ | 65/137 [00:06<00:07, 9.67it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.83it/s]\n",
|
|
" 48%|████▊ | 66/137 [00:06<00:07, 9.72it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.68it/s]\n",
|
|
" 49%|████▉ | 67/137 [00:06<00:07, 9.79it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.13it/s]\n",
|
|
" 50%|████▉ | 68/137 [00:06<00:07, 9.81it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.32it/s]\n",
|
|
" 50%|█████ | 69/137 [00:07<00:06, 9.84it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.92it/s]\n",
|
|
" 51%|█████ | 70/137 [00:07<00:06, 9.83it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.30it/s]\n",
|
|
" 52%|█████▏ | 71/137 [00:07<00:06, 9.85it/s]\n",
|
|
"Running evaluation: 7it [00:00, 82.09it/s]\n",
|
|
" 53%|█████▎ | 72/137 [00:07<00:06, 9.89it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.77it/s]\n",
|
|
" 53%|█████▎ | 73/137 [00:07<00:06, 9.87it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.08it/s]\n",
|
|
" 54%|█████▍ | 74/137 [00:07<00:06, 9.87it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.81it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
" 55%|█████▍ | 75/137 [00:07<00:06, 9.89it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.63it/s]\n",
|
|
" 55%|█████▌ | 76/137 [00:07<00:06, 9.87it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.91it/s]\n",
|
|
" 56%|█████▌ | 77/137 [00:07<00:06, 9.87it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.83it/s]\n",
|
|
" 57%|█████▋ | 78/137 [00:08<00:05, 9.86it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.90it/s]\n",
|
|
" 58%|█████▊ | 79/137 [00:08<00:05, 9.86it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.50it/s]\n",
|
|
" 58%|█████▊ | 80/137 [00:08<00:05, 9.84it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.69it/s]\n",
|
|
" 59%|█████▉ | 81/137 [00:08<00:05, 9.84it/s]\n",
|
|
"Running evaluation: 7it [00:00, 78.67it/s]\n",
|
|
" 60%|█████▉ | 82/137 [00:08<00:05, 9.74it/s]\n",
|
|
"Running evaluation: 0it [00:00, ?it/s]\u001b[A\n",
|
|
"Running evaluation: 7it [00:00, 57.62it/s]\u001b[A\n",
|
|
" 61%|██████ | 83/137 [00:08<00:06, 8.73it/s]\n",
|
|
"Running evaluation: 7it [00:00, 79.47it/s]\n",
|
|
" 61%|██████▏ | 84/137 [00:08<00:05, 8.97it/s]\n",
|
|
"Running evaluation: 7it [00:00, 77.83it/s]\n",
|
|
" 62%|██████▏ | 85/137 [00:08<00:05, 9.12it/s]\n",
|
|
"Running evaluation: 7it [00:00, 79.50it/s]\n",
|
|
" 63%|██████▎ | 86/137 [00:08<00:05, 9.26it/s]\n",
|
|
"Running evaluation: 7it [00:00, 78.65it/s]\n",
|
|
" 64%|██████▎ | 87/137 [00:08<00:05, 9.33it/s]\n",
|
|
"Running evaluation: 7it [00:00, 77.89it/s]\n",
|
|
" 64%|██████▍ | 88/137 [00:09<00:05, 9.37it/s]\n",
|
|
"Running evaluation: 7it [00:00, 77.88it/s]\n",
|
|
" 65%|██████▍ | 89/137 [00:09<00:05, 9.39it/s]\n",
|
|
"Running evaluation: 7it [00:00, 78.58it/s]\n",
|
|
" 66%|██████▌ | 90/137 [00:09<00:04, 9.43it/s]\n",
|
|
"Running evaluation: 7it [00:00, 78.98it/s]\n",
|
|
" 66%|██████▋ | 91/137 [00:09<00:04, 9.45it/s]\n",
|
|
"Running evaluation: 7it [00:00, 79.47it/s]\n",
|
|
" 67%|██████▋ | 92/137 [00:09<00:04, 9.47it/s]\n",
|
|
"Running evaluation: 7it [00:00, 75.97it/s]\n",
|
|
" 68%|██████▊ | 93/137 [00:09<00:04, 9.37it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.28it/s]\n",
|
|
" 69%|██████▊ | 94/137 [00:09<00:04, 9.47it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.42it/s]\n",
|
|
" 69%|██████▉ | 95/137 [00:09<00:04, 9.57it/s]\n",
|
|
"Running evaluation: 7it [00:00, 78.29it/s]\n",
|
|
" 70%|███████ | 96/137 [00:09<00:04, 9.51it/s]\n",
|
|
"Running evaluation: 7it [00:00, 78.14it/s]\n",
|
|
" 71%|███████ | 97/137 [00:10<00:04, 9.49it/s]\n",
|
|
"Running evaluation: 7it [00:00, 79.59it/s]\n",
|
|
" 72%|███████▏ | 98/137 [00:10<00:04, 9.50it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.20it/s]\n",
|
|
" 72%|███████▏ | 99/137 [00:10<00:03, 9.56it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.43it/s]\n",
|
|
" 73%|███████▎ | 100/137 [00:10<00:03, 9.66it/s]\n",
|
|
"Running evaluation: 7it [00:00, 82.14it/s]\n",
|
|
" 74%|███████▎ | 101/137 [00:10<00:03, 9.76it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.95it/s]\n",
|
|
" 74%|███████▍ | 102/137 [00:10<00:03, 9.79it/s]\n",
|
|
"Running evaluation: 7it [00:00, 75.02it/s]\n",
|
|
" 75%|███████▌ | 103/137 [00:10<00:03, 9.58it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.46it/s]\n",
|
|
" 76%|███████▌ | 104/137 [00:10<00:03, 9.63it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.08it/s]\n",
|
|
" 77%|███████▋ | 105/137 [00:10<00:03, 9.70it/s]\n",
|
|
"Running evaluation: 7it [00:00, 82.26it/s]\n",
|
|
" 77%|███████▋ | 106/137 [00:10<00:03, 9.77it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.13it/s]\n",
|
|
" 78%|███████▊ | 107/137 [00:11<00:03, 9.80it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.79it/s]\n",
|
|
" 79%|███████▉ | 108/137 [00:11<00:02, 9.81it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.81it/s]\n",
|
|
" 80%|███████▉ | 109/137 [00:11<00:02, 9.83it/s]\n",
|
|
"Running evaluation: 7it [00:00, 79.83it/s]\n",
|
|
" 80%|████████ | 110/137 [00:11<00:02, 9.78it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.44it/s]\n",
|
|
" 81%|████████ | 111/137 [00:11<00:02, 9.81it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.57it/s]\n",
|
|
" 82%|████████▏ | 112/137 [00:11<00:02, 9.85it/s]\n",
|
|
"Running evaluation: 7it [00:00, 79.45it/s]\n",
|
|
" 82%|████████▏ | 113/137 [00:11<00:02, 9.77it/s]\n",
|
|
"Running evaluation: 7it [00:00, 77.78it/s]\n",
|
|
" 83%|████████▎ | 114/137 [00:11<00:02, 9.68it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.29it/s]\n",
|
|
" 84%|████████▍ | 115/137 [00:11<00:02, 9.74it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.49it/s]\n",
|
|
" 85%|████████▍ | 116/137 [00:11<00:02, 9.80it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.46it/s]\n",
|
|
" 85%|████████▌ | 117/137 [00:12<00:02, 9.80it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.67it/s]\n",
|
|
" 86%|████████▌ | 118/137 [00:12<00:01, 9.81it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.37it/s]\n",
|
|
" 87%|████████▋ | 119/137 [00:12<00:01, 9.84it/s]\n",
|
|
"Running evaluation: 7it [00:00, 78.34it/s]\n",
|
|
" 88%|████████▊ | 120/137 [00:12<00:01, 9.74it/s]\n",
|
|
"Running evaluation: 7it [00:00, 79.76it/s]\n",
|
|
" 88%|████████▊ | 121/137 [00:12<00:01, 9.70it/s]\n",
|
|
"Running evaluation: 7it [00:00, 79.34it/s]\n",
|
|
" 89%|████████▉ | 122/137 [00:12<00:01, 9.64it/s]\n",
|
|
"Running evaluation: 7it [00:00, 78.06it/s]\n",
|
|
" 90%|████████▉ | 123/137 [00:12<00:01, 9.58it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.37it/s]\n",
|
|
" 91%|█████████ | 124/137 [00:12<00:01, 9.67it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.66it/s]\n",
|
|
" 91%|█████████ | 125/137 [00:12<00:01, 9.72it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.24it/s]\n",
|
|
" 92%|█████████▏| 126/137 [00:13<00:01, 9.77it/s]\n",
|
|
"Running evaluation: 7it [00:00, 82.04it/s]\n",
|
|
" 93%|█████████▎| 127/137 [00:13<00:01, 9.83it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.05it/s]\n",
|
|
" 93%|█████████▎| 128/137 [00:13<00:00, 9.82it/s]\n",
|
|
"Running evaluation: 7it [00:00, 83.03it/s]\n",
|
|
"\n",
|
|
"Running evaluation: 7it [00:00, 82.04it/s]\n",
|
|
" 95%|█████████▍| 130/137 [00:13<00:00, 9.88it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.86it/s]\n",
|
|
" 96%|█████████▌| 131/137 [00:13<00:00, 9.86it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.55it/s]\n",
|
|
" 96%|█████████▋| 132/137 [00:13<00:00, 9.87it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.48it/s]\n",
|
|
" 97%|█████████▋| 133/137 [00:13<00:00, 9.88it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.72it/s]\n",
|
|
" 98%|█████████▊| 134/137 [00:13<00:00, 9.90it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.09it/s]\n",
|
|
" 99%|█████████▊| 135/137 [00:13<00:00, 9.82it/s]\n",
|
|
"Running evaluation: 7it [00:00, 76.88it/s]\n",
|
|
" 99%|█████████▉| 136/137 [00:14<00:00, 9.65it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.92it/s]\n",
|
|
"100%|██████████| 137/137 [00:14<00:00, 9.70it/s]\n",
|
|
"Running evaluation: 7it [00:00, 61.47it/s]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"predictor = estimator.train(dataset_train)\n",
|
|
"forecast_it, ts_it = make_evaluation_predictions(dataset=dataset_test,\n",
|
|
" predictor=predictor,\n",
|
|
" num_samples=100)\n",
|
|
"forecasts = list(forecast_it)\n",
|
|
"targets = list(ts_it)\n",
|
|
"\n",
|
|
"agg_metric, _ = evaluator(targets, forecasts, num_series=len(dataset_test))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Metrics"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 19,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"CRPS: 0.3855313301520275\n",
|
|
"ND: 0.48820539490099113\n",
|
|
"NRMSE: 1.018839692673421\n",
|
|
"MSE: 984.6672641166102\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"print(\"CRPS: {}\".format(agg_metric['mean_wQuantileLoss']))\n",
|
|
"print(\"ND: {}\".format(agg_metric['ND']))\n",
|
|
"print(\"NRMSE: {}\".format(agg_metric['NRMSE']))\n",
|
|
"print(\"MSE: {}\".format(agg_metric['MSE']))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 20,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"CRPS-Sum: 0.3268739166960563\n",
|
|
"ND-Sum: 0.40321702146475014\n",
|
|
"NRMSE-Sum: 0.75586334994103\n",
|
|
"MSE-Sum: 10171980.5\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"print(\"CRPS-Sum: {}\".format(agg_metric['m_sum_mean_wQuantileLoss']))\n",
|
|
"print(\"ND-Sum: {}\".format(agg_metric['m_sum_ND']))\n",
|
|
"print(\"NRMSE-Sum: {}\".format(agg_metric['m_sum_NRMSE']))\n",
|
|
"print(\"MSE-Sum: {}\".format(agg_metric['m_sum_MSE']))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## `Transformer-MAF`"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"estimator = TransformerTempFlowEstimator(\n",
|
|
" d_model=16,\n",
|
|
" num_heads=4,\n",
|
|
" input_size=552,\n",
|
|
" target_dim=int(dataset.metadata.feat_static_cat[0].cardinality),\n",
|
|
" prediction_length=dataset.metadata.prediction_length,\n",
|
|
" context_length=dataset.metadata.prediction_length*4,\n",
|
|
" flow_type='MAF',\n",
|
|
" dequantize=True,\n",
|
|
" freq=dataset.metadata.freq,\n",
|
|
" trainer=Trainer(\n",
|
|
" device=device,\n",
|
|
" epochs=14,\n",
|
|
" learning_rate=1e-3,\n",
|
|
" num_batches_per_epoch=100,\n",
|
|
" batch_size=64,\n",
|
|
" )\n",
|
|
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 10,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"98it [00:11, 8.71it/s, avg_epoch_loss=-44.8, epoch=0]\n",
|
|
"98it [00:11, 8.60it/s, avg_epoch_loss=-170, epoch=1]\n",
|
|
"99it [00:11, 8.82it/s, avg_epoch_loss=-189, epoch=2]\n",
|
|
"98it [00:11, 8.83it/s, avg_epoch_loss=-201, epoch=3]\n",
|
|
"99it [00:11, 8.80it/s, avg_epoch_loss=-208, epoch=4]\n",
|
|
"98it [00:11, 8.72it/s, avg_epoch_loss=-212, epoch=5]\n",
|
|
"99it [00:11, 8.83it/s, avg_epoch_loss=-216, epoch=6]\n",
|
|
"99it [00:11, 8.80it/s, avg_epoch_loss=-218, epoch=7]\n",
|
|
"99it [00:11, 8.84it/s, avg_epoch_loss=-220, epoch=8]\n",
|
|
"98it [00:11, 8.74it/s, avg_epoch_loss=-222, epoch=9]\n",
|
|
"99it [00:11, 8.92it/s, avg_epoch_loss=-223, epoch=10]\n",
|
|
"99it [00:11, 8.74it/s, avg_epoch_loss=-225, epoch=11]\n",
|
|
"99it [00:11, 8.84it/s, avg_epoch_loss=-226, epoch=12]\n",
|
|
"99it [00:11, 8.88it/s, avg_epoch_loss=-227, epoch=13]\n",
|
|
" 0%| | 0/137 [00:00<?, ?it/s]\n",
|
|
"Running evaluation: 7it [00:00, 77.77it/s]\n",
|
|
" 1%| | 1/137 [00:00<00:21, 6.33it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.43it/s]\n",
|
|
" 1%|▏ | 2/137 [00:00<00:19, 7.08it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.35it/s]\n",
|
|
" 2%|▏ | 3/137 [00:00<00:17, 7.72it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.23it/s]\n",
|
|
" 3%|▎ | 4/137 [00:00<00:16, 8.24it/s]\n",
|
|
"Running evaluation: 7it [00:00, 79.80it/s]\n",
|
|
" 4%|▎ | 5/137 [00:00<00:15, 8.64it/s]\n",
|
|
"Running evaluation: 7it [00:00, 78.88it/s]\n",
|
|
" 4%|▍ | 6/137 [00:00<00:14, 8.92it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.12it/s]\n",
|
|
" 5%|▌ | 7/137 [00:00<00:14, 9.17it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.56it/s]\n",
|
|
" 6%|▌ | 8/137 [00:00<00:13, 9.36it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.12it/s]\n",
|
|
" 7%|▋ | 9/137 [00:00<00:13, 9.48it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.53it/s]\n",
|
|
" 7%|▋ | 10/137 [00:01<00:13, 9.58it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.50it/s]\n",
|
|
" 8%|▊ | 11/137 [00:01<00:13, 9.65it/s]\n",
|
|
"Running evaluation: 7it [00:00, 79.77it/s]\n",
|
|
" 9%|▉ | 12/137 [00:01<00:12, 9.68it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.35it/s]\n",
|
|
" 9%|▉ | 13/137 [00:01<00:12, 9.71it/s]\n",
|
|
"Running evaluation: 7it [00:00, 77.65it/s]\n",
|
|
" 10%|█ | 14/137 [00:01<00:12, 9.63it/s]\n",
|
|
"Running evaluation: 7it [00:00, 77.54it/s]\n",
|
|
" 11%|█ | 15/137 [00:01<00:12, 9.55it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.11it/s]\n",
|
|
" 12%|█▏ | 16/137 [00:01<00:12, 9.61it/s]\n",
|
|
"Running evaluation: 7it [00:00, 84.48it/s]\n",
|
|
"\n",
|
|
"Running evaluation: 7it [00:00, 82.61it/s]\n",
|
|
" 13%|█▎ | 18/137 [00:01<00:12, 9.75it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.09it/s]\n",
|
|
" 14%|█▍ | 19/137 [00:01<00:12, 9.79it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.75it/s]\n",
|
|
" 15%|█▍ | 20/137 [00:02<00:11, 9.81it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.87it/s]\n",
|
|
" 15%|█▌ | 21/137 [00:02<00:11, 9.82it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.69it/s]\n",
|
|
" 16%|█▌ | 22/137 [00:02<00:11, 9.83it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.07it/s]\n",
|
|
" 17%|█▋ | 23/137 [00:02<00:11, 9.81it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.81it/s]\n",
|
|
" 18%|█▊ | 24/137 [00:02<00:11, 9.83it/s]\n",
|
|
"Running evaluation: 7it [00:00, 78.27it/s]\n",
|
|
" 18%|█▊ | 25/137 [00:02<00:11, 9.72it/s]\n",
|
|
"Running evaluation: 7it [00:00, 76.40it/s]\n",
|
|
" 19%|█▉ | 26/137 [00:02<00:11, 9.56it/s]\n",
|
|
"Running evaluation: 7it [00:00, 79.52it/s]\n",
|
|
" 20%|█▉ | 27/137 [00:02<00:11, 9.58it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.30it/s]\n",
|
|
" 20%|██ | 28/137 [00:02<00:11, 9.67it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.19it/s]\n",
|
|
" 21%|██ | 29/137 [00:03<00:11, 9.74it/s]\n",
|
|
"Running evaluation: 7it [00:00, 78.78it/s]\n",
|
|
" 22%|██▏ | 30/137 [00:03<00:11, 9.68it/s]\n",
|
|
"Running evaluation: 7it [00:00, 76.67it/s]\n",
|
|
" 23%|██▎ | 31/137 [00:03<00:11, 9.57it/s]\n",
|
|
"Running evaluation: 7it [00:00, 77.69it/s]\n",
|
|
" 23%|██▎ | 32/137 [00:03<00:11, 9.51it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.35it/s]\n",
|
|
" 24%|██▍ | 33/137 [00:03<00:10, 9.59it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.23it/s]\n",
|
|
" 25%|██▍ | 34/137 [00:03<00:10, 9.68it/s]\n",
|
|
"Running evaluation: 7it [00:00, 79.28it/s]\n",
|
|
" 26%|██▌ | 35/137 [00:03<00:10, 9.67it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.90it/s]\n",
|
|
" 26%|██▋ | 36/137 [00:03<00:10, 9.73it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.90it/s]\n",
|
|
" 27%|██▋ | 37/137 [00:03<00:10, 9.77it/s]\n",
|
|
"Running evaluation: 7it [00:00, 79.11it/s]\n",
|
|
" 28%|██▊ | 38/137 [00:03<00:10, 9.73it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.05it/s]\n",
|
|
" 28%|██▊ | 39/137 [00:04<00:10, 9.77it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.01it/s]\n",
|
|
" 29%|██▉ | 40/137 [00:04<00:09, 9.79it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.60it/s]\n",
|
|
" 30%|██▉ | 41/137 [00:04<00:09, 9.80it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.29it/s]\n",
|
|
" 31%|███ | 42/137 [00:04<00:09, 9.82it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.37it/s]\n",
|
|
" 31%|███▏ | 43/137 [00:04<00:09, 9.84it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.04it/s]\n",
|
|
" 32%|███▏ | 44/137 [00:04<00:09, 9.86it/s]\n",
|
|
"Running evaluation: 7it [00:00, 79.97it/s]\n",
|
|
" 33%|███▎ | 45/137 [00:04<00:09, 9.83it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.23it/s]\n",
|
|
" 34%|███▎ | 46/137 [00:04<00:09, 9.82it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.90it/s]\n",
|
|
" 34%|███▍ | 47/137 [00:04<00:09, 9.84it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.12it/s]\n",
|
|
" 35%|███▌ | 48/137 [00:04<00:09, 9.84it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.41it/s]\n",
|
|
" 36%|███▌ | 49/137 [00:05<00:08, 9.87it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.62it/s]\n",
|
|
" 36%|███▋ | 50/137 [00:05<00:08, 9.89it/s]\n",
|
|
"Running evaluation: 7it [00:00, 79.46it/s]\n",
|
|
" 37%|███▋ | 51/137 [00:05<00:08, 9.83it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.46it/s]\n",
|
|
" 38%|███▊ | 52/137 [00:05<00:08, 9.86it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.03it/s]\n",
|
|
" 39%|███▊ | 53/137 [00:05<00:08, 9.85it/s]\n",
|
|
"Running evaluation: 7it [00:00, 79.10it/s]\n",
|
|
" 39%|███▉ | 54/137 [00:05<00:08, 9.80it/s]\n",
|
|
"Running evaluation: 7it [00:00, 79.98it/s]\n",
|
|
" 40%|████ | 55/137 [00:05<00:08, 9.78it/s]\n",
|
|
"Running evaluation: 7it [00:00, 77.92it/s]\n",
|
|
" 41%|████ | 56/137 [00:05<00:08, 9.70it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.25it/s]\n",
|
|
" 42%|████▏ | 57/137 [00:05<00:08, 9.74it/s]\n",
|
|
"Running evaluation: 7it [00:00, 82.88it/s]\n",
|
|
" 42%|████▏ | 58/137 [00:05<00:08, 9.81it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.94it/s]\n",
|
|
" 43%|████▎ | 59/137 [00:06<00:07, 9.83it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.83it/s]\n",
|
|
" 44%|████▍ | 60/137 [00:06<00:07, 9.84it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.61it/s]\n",
|
|
" 45%|████▍ | 61/137 [00:06<00:07, 9.84it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.59it/s]\n",
|
|
" 45%|████▌ | 62/137 [00:06<00:07, 9.84it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.34it/s]\n",
|
|
" 46%|████▌ | 63/137 [00:06<00:07, 9.85it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.29it/s]\n",
|
|
" 47%|████▋ | 64/137 [00:06<00:07, 9.81it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.27it/s]\n",
|
|
" 47%|████▋ | 65/137 [00:06<00:07, 9.79it/s]\n",
|
|
"Running evaluation: 7it [00:00, 82.31it/s]\n",
|
|
" 48%|████▊ | 66/137 [00:06<00:07, 9.84it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.17it/s]\n",
|
|
" 49%|████▉ | 67/137 [00:06<00:07, 9.85it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.44it/s]\n",
|
|
" 50%|████▉ | 68/137 [00:07<00:07, 9.82it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.87it/s]\n",
|
|
" 50%|█████ | 69/137 [00:07<00:06, 9.79it/s]\n",
|
|
"Running evaluation: 7it [00:00, 78.47it/s]\n",
|
|
" 51%|█████ | 70/137 [00:07<00:06, 9.71it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.03it/s]\n",
|
|
" 52%|█████▏ | 71/137 [00:07<00:06, 9.76it/s]\n",
|
|
"Running evaluation: 7it [00:00, 78.82it/s]\n",
|
|
" 53%|█████▎ | 72/137 [00:07<00:06, 9.69it/s]\n",
|
|
"Running evaluation: 7it [00:00, 78.15it/s]\n",
|
|
" 53%|█████▎ | 73/137 [00:07<00:06, 9.61it/s]\n",
|
|
"Running evaluation: 7it [00:00, 77.70it/s]\n",
|
|
" 54%|█████▍ | 74/137 [00:07<00:06, 9.58it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.61it/s]\n",
|
|
" 55%|█████▍ | 75/137 [00:07<00:06, 9.63it/s]\n",
|
|
"Running evaluation: 7it [00:00, 79.33it/s]\n",
|
|
" 55%|█████▌ | 76/137 [00:07<00:06, 9.65it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.99it/s]\n",
|
|
" 56%|█████▌ | 77/137 [00:07<00:06, 9.70it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.21it/s]\n",
|
|
" 57%|█████▋ | 78/137 [00:08<00:06, 9.74it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.20it/s]\n",
|
|
" 58%|█████▊ | 79/137 [00:08<00:05, 9.77it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.63it/s]\n",
|
|
" 58%|█████▊ | 80/137 [00:08<00:05, 9.77it/s]\n",
|
|
"Running evaluation: 7it [00:00, 83.25it/s]\n",
|
|
"\n",
|
|
"Running evaluation: 7it [00:00, 82.19it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
" 60%|█████▉ | 82/137 [00:08<00:05, 9.86it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.32it/s]\n",
|
|
" 61%|██████ | 83/137 [00:08<00:05, 9.86it/s]\n",
|
|
"Running evaluation: 7it [00:00, 79.08it/s]\n",
|
|
" 61%|██████▏ | 84/137 [00:08<00:05, 9.81it/s]\n",
|
|
"Running evaluation: 0it [00:00, ?it/s]\u001b[A\n",
|
|
"Running evaluation: 7it [00:00, 59.80it/s]\u001b[A\n",
|
|
" 62%|██████▏ | 85/137 [00:08<00:05, 8.89it/s]\n",
|
|
"Running evaluation: 0it [00:00, ?it/s]\u001b[A\n",
|
|
"Running evaluation: 7it [00:00, 61.93it/s]\u001b[A\n",
|
|
" 63%|██████▎ | 86/137 [00:08<00:05, 8.53it/s]\n",
|
|
"Running evaluation: 0it [00:00, ?it/s]\u001b[A\n",
|
|
"Running evaluation: 7it [00:00, 62.18it/s]\u001b[A\n",
|
|
" 64%|██████▎ | 87/137 [00:09<00:06, 8.27it/s]\n",
|
|
"Running evaluation: 7it [00:00, 79.80it/s]\n",
|
|
" 64%|██████▍ | 88/137 [00:09<00:05, 8.67it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.51it/s]\n",
|
|
" 65%|██████▍ | 89/137 [00:09<00:05, 8.98it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.50it/s]\n",
|
|
" 66%|██████▌ | 90/137 [00:09<00:05, 9.24it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.78it/s]\n",
|
|
" 66%|██████▋ | 91/137 [00:09<00:04, 9.43it/s]\n",
|
|
"Running evaluation: 7it [00:00, 83.24it/s]\n",
|
|
" 67%|██████▋ | 92/137 [00:09<00:04, 9.58it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.40it/s]\n",
|
|
" 68%|██████▊ | 93/137 [00:09<00:04, 9.66it/s]\n",
|
|
"Running evaluation: 7it [00:00, 82.05it/s]\n",
|
|
" 69%|██████▊ | 94/137 [00:09<00:04, 9.76it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.61it/s]\n",
|
|
" 69%|██████▉ | 95/137 [00:09<00:04, 9.82it/s]\n",
|
|
"Running evaluation: 7it [00:00, 78.35it/s]\n",
|
|
" 70%|███████ | 96/137 [00:09<00:04, 9.73it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.77it/s]\n",
|
|
" 71%|███████ | 97/137 [00:10<00:04, 9.76it/s]\n",
|
|
"Running evaluation: 7it [00:00, 82.27it/s]\n",
|
|
" 72%|███████▏ | 98/137 [00:10<00:03, 9.81it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.60it/s]\n",
|
|
" 72%|███████▏ | 99/137 [00:10<00:03, 9.85it/s]\n",
|
|
"Running evaluation: 7it [00:00, 82.25it/s]\n",
|
|
" 73%|███████▎ | 100/137 [00:10<00:03, 9.89it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.72it/s]\n",
|
|
" 74%|███████▎ | 101/137 [00:10<00:03, 9.88it/s]\n",
|
|
"Running evaluation: 7it [00:00, 83.37it/s]\n",
|
|
"\n",
|
|
"Running evaluation: 7it [00:00, 81.87it/s]\n",
|
|
" 75%|███████▌ | 103/137 [00:10<00:03, 9.93it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.33it/s]\n",
|
|
" 76%|███████▌ | 104/137 [00:10<00:03, 9.91it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.90it/s]\n",
|
|
" 77%|███████▋ | 105/137 [00:10<00:03, 9.90it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.45it/s]\n",
|
|
" 77%|███████▋ | 106/137 [00:10<00:03, 9.91it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.97it/s]\n",
|
|
" 78%|███████▊ | 107/137 [00:11<00:03, 9.90it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.81it/s]\n",
|
|
" 79%|███████▉ | 108/137 [00:11<00:02, 9.90it/s]\n",
|
|
"Running evaluation: 7it [00:00, 82.78it/s]\n",
|
|
"\n",
|
|
"Running evaluation: 7it [00:00, 80.04it/s]\n",
|
|
" 80%|████████ | 110/137 [00:11<00:02, 9.90it/s]\n",
|
|
"Running evaluation: 7it [00:00, 82.18it/s]\n",
|
|
" 81%|████████ | 111/137 [00:11<00:02, 9.91it/s]\n",
|
|
"Running evaluation: 7it [00:00, 82.26it/s]\n",
|
|
"\n",
|
|
"Running evaluation: 7it [00:00, 80.87it/s]\n",
|
|
" 82%|████████▏ | 113/137 [00:11<00:02, 9.92it/s]\n",
|
|
"Running evaluation: 0it [00:00, ?it/s]\u001b[A\n",
|
|
"Running evaluation: 7it [00:00, 54.86it/s]\u001b[A\n",
|
|
" 83%|████████▎ | 114/137 [00:11<00:02, 8.71it/s]\n",
|
|
"Running evaluation: 0it [00:00, ?it/s]\u001b[A\n",
|
|
"Running evaluation: 7it [00:00, 58.60it/s]\u001b[A\n",
|
|
" 84%|████████▍ | 115/137 [00:11<00:02, 8.25it/s]\n",
|
|
"Running evaluation: 7it [00:00, 75.16it/s]\n",
|
|
" 85%|████████▍ | 116/137 [00:12<00:02, 8.49it/s]\n",
|
|
"Running evaluation: 0it [00:00, ?it/s]\u001b[A\n",
|
|
"Running evaluation: 7it [00:00, 57.53it/s]\u001b[A\n",
|
|
" 85%|████████▌ | 117/137 [00:12<00:02, 8.00it/s]\n",
|
|
"Running evaluation: 0it [00:00, ?it/s]\u001b[A\n",
|
|
"Running evaluation: 7it [00:00, 57.26it/s]\u001b[A\n",
|
|
" 86%|████████▌ | 118/137 [00:12<00:02, 7.70it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.55it/s]\n",
|
|
" 87%|████████▋ | 119/137 [00:12<00:02, 8.21it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.74it/s]\n",
|
|
" 88%|████████▊ | 120/137 [00:12<00:01, 8.67it/s]\n",
|
|
"Running evaluation: 7it [00:00, 79.56it/s]\n",
|
|
" 88%|████████▊ | 121/137 [00:12<00:01, 8.95it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.46it/s]\n",
|
|
" 89%|████████▉ | 122/137 [00:12<00:01, 9.22it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.89it/s]\n",
|
|
" 90%|████████▉ | 123/137 [00:12<00:01, 9.40it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.92it/s]\n",
|
|
" 91%|█████████ | 124/137 [00:12<00:01, 9.52it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.09it/s]\n",
|
|
" 91%|█████████ | 125/137 [00:13<00:01, 9.61it/s]\n",
|
|
"Running evaluation: 7it [00:00, 82.04it/s]\n",
|
|
" 92%|█████████▏| 126/137 [00:13<00:01, 9.70it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.77it/s]\n",
|
|
" 93%|█████████▎| 127/137 [00:13<00:01, 9.78it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.17it/s]\n",
|
|
" 93%|█████████▎| 128/137 [00:13<00:00, 9.81it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.65it/s]\n",
|
|
" 94%|█████████▍| 129/137 [00:13<00:00, 9.82it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.73it/s]\n",
|
|
" 95%|█████████▍| 130/137 [00:13<00:00, 9.82it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.02it/s]\n",
|
|
" 96%|█████████▌| 131/137 [00:13<00:00, 9.80it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.68it/s]\n",
|
|
" 96%|█████████▋| 132/137 [00:13<00:00, 9.78it/s]\n",
|
|
"Running evaluation: 7it [00:00, 81.46it/s]\n",
|
|
" 97%|█████████▋| 133/137 [00:13<00:00, 9.79it/s]\n",
|
|
"Running evaluation: 7it [00:00, 82.36it/s]\n",
|
|
"\n",
|
|
"Running evaluation: 7it [00:00, 79.63it/s]\n",
|
|
" 99%|█████████▊| 135/137 [00:14<00:00, 9.81it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.53it/s]\n",
|
|
" 99%|█████████▉| 136/137 [00:14<00:00, 9.81it/s]\n",
|
|
"Running evaluation: 7it [00:00, 80.42it/s]\n",
|
|
"100%|██████████| 137/137 [00:14<00:00, 9.59it/s]\n",
|
|
"Running evaluation: 7it [00:00, 62.66it/s]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"predictor = estimator.train(dataset_train)\n",
|
|
"forecast_it, ts_it = make_evaluation_predictions(dataset=dataset_test,\n",
|
|
" predictor=predictor,\n",
|
|
" num_samples=100)\n",
|
|
"forecasts = list(forecast_it)\n",
|
|
"targets = list(ts_it)\n",
|
|
"\n",
|
|
"agg_metric, _ = evaluator(targets, forecasts, num_series=len(dataset_test))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Metrics"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"CRPS: 0.37264046134993567\n",
|
|
"ND: 0.5043621354947913\n",
|
|
"NRMSE: 0.9928759300158241\n",
|
|
"MSE: 935.1208752979203\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"print(\"CRPS: {}\".format(agg_metric['mean_wQuantileLoss']))\n",
|
|
"print(\"ND: {}\".format(agg_metric['ND']))\n",
|
|
"print(\"NRMSE: {}\".format(agg_metric['NRMSE']))\n",
|
|
"print(\"MSE: {}\".format(agg_metric['MSE']))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 12,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"CRPS-Sum: 0.30787625107438427\n",
|
|
"ND-Sum: 0.4188356756894787\n",
|
|
"NRMSE-Sum: 0.7504274205713227\n",
|
|
"MSE-Sum: 10026199.285714285\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"print(\"CRPS-Sum: {}\".format(agg_metric['m_sum_mean_wQuantileLoss']))\n",
|
|
"print(\"ND-Sum: {}\".format(agg_metric['m_sum_ND']))\n",
|
|
"print(\"NRMSE-Sum: {}\".format(agg_metric['m_sum_NRMSE']))\n",
|
|
"print(\"MSE-Sum: {}\".format(agg_metric['m_sum_MSE']))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.7.4"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 4
|
|
}
|