Files
pytorch-ts/examples/Time-Grad-Electricity.ipynb
2022-12-24 12:36:51 +08:00

803 lines
26 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"ExecuteTime": {
"end_time": "2022-12-23T12:53:57.712458Z",
"start_time": "2022-12-23T12:53:57.704099Z"
}
},
"outputs": [],
"source": [
"# autoreload import your package\n",
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"ExecuteTime": {
"end_time": "2022-12-23T12:53:59.259208Z",
"start_time": "2022-12-23T12:53:57.713340Z"
}
},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"import torch"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"ExecuteTime": {
"end_time": "2022-12-23T12:53:59.433098Z",
"start_time": "2022-12-23T12:53:59.261001Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/wassname/miniforge3/envs/gluonts10.0/lib/python3.9/site-packages/gluonts/json.py:101: UserWarning: Using `json`-module for json-handling. Consider installing one of `orjson`, `ujson` to speed up serialization and deserialization.\n",
" warnings.warn(\n"
]
}
],
"source": [
"from gluonts.dataset.multivariate_grouper import MultivariateGrouper\n",
"from gluonts.dataset.repository.datasets import dataset_recipes, get_dataset\n",
"from gluonts.evaluation.backtest import make_evaluation_predictions\n",
"from gluonts.evaluation import MultivariateEvaluator"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"ExecuteTime": {
"end_time": "2022-12-23T12:53:59.545657Z",
"start_time": "2022-12-23T12:53:59.434450Z"
}
},
"outputs": [],
"source": [
"from pts.model.tempflow import TempFlowEstimator\n",
"from pts.model.time_grad import TimeGradEstimator\n",
"from pts.model.transformer_tempflow import TransformerTempFlowEstimator\n",
"from pts import Trainer"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"ExecuteTime": {
"end_time": "2022-12-23T12:53:59.589097Z",
"start_time": "2022-12-23T12:53:59.546862Z"
}
},
"outputs": [],
"source": [
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"ExecuteTime": {
"end_time": "2022-12-23T12:53:59.620352Z",
"start_time": "2022-12-23T12:53:59.590581Z"
}
},
"outputs": [],
"source": [
"def plot(target, forecast, prediction_length, prediction_intervals=(50.0, 90.0), color='g', fname=None):\n",
"    \"\"\"Plot recent observations and the probabilistic forecast for up to 16 target dimensions.\n",
"\n",
"    Each panel of a 4x4 grid shows, for one target dimension, the last\n",
"    ``2 * prediction_length`` observations, the forecast median and one shaded\n",
"    band per entry of ``prediction_intervals``. The legend is drawn on the first panel.\n",
"\n",
"    Args:\n",
"        target: DataFrame with one column per target dimension.\n",
"        forecast: object exposing ``quantile(q)`` (an array indexed ``[step, dim]``,\n",
"            aligned with ``forecast.index``) and ``index``.\n",
"        prediction_length: forecast horizon; the plotted history spans twice this.\n",
"        prediction_intervals: central interval widths in percent, e.g. 90.0 shades\n",
"            the band between the 5th and 95th percentiles.\n",
"        color: matplotlib color for the median line and the interval bands.\n",
"        fname: optional path; when given, the figure is also saved there.\n",
"    \"\"\"\n",
"    rows, cols = 4, 4\n",
"    fig, axs = plt.subplots(rows, cols, figsize=(24, 24))\n",
"    axx = axs.ravel()\n",
"    target_dim = target.shape[1]\n",
"\n",
"    ps = [50.0] + [\n",
"        50.0 + f * c / 2.0 for c in prediction_intervals for f in [-1.0, +1.0]\n",
"    ]\n",
"    percentiles_sorted = sorted(set(ps))\n",
"    # Percentiles are symmetric around 50, so the median sits at this position and\n",
"    # the i-th lowest percentile pairs with the i-th highest one.\n",
"    i_p50 = percentiles_sorted.index(50.0)\n",
"\n",
"    def alpha_for_percentile(p):\n",
"        return (p / 100.0) ** 0.3\n",
"\n",
"    # Quantiles do not depend on the panel, so compute them once rather than per dim.\n",
"    quantiles = [forecast.quantile(p / 100.0) for p in percentiles_sorted]\n",
"\n",
"    n_panels = min(rows * cols, target_dim)\n",
"    for dim in range(n_panels):\n",
"        ax = axx[dim]\n",
"        # `dim` enumerates columns by position, so index positionally.\n",
"        target.iloc[-2 * prediction_length:, dim].plot(ax=ax, label=\"observations\")\n",
"\n",
"        ps_data = [q[:, dim] for q in quantiles]\n",
"        p50_data = ps_data[i_p50]\n",
"        pd.Series(data=p50_data, index=forecast.index).plot(\n",
"            color=color, ls=\"-\", label=\"median prediction\", ax=ax\n",
"        )\n",
"\n",
"        # Widest band first. Each band carries its own legend label, so the legend\n",
"        # cannot pair a label with the wrong artist.\n",
"        for i in range(i_p50):\n",
"            ptile = percentiles_sorted[i]\n",
"            ax.fill_between(\n",
"                forecast.index,\n",
"                ps_data[i],\n",
"                ps_data[-i - 1],\n",
"                facecolor=color,\n",
"                alpha=alpha_for_percentile(ptile),\n",
"                interpolate=True,\n",
"                label=f\"{100 - ptile * 2}% prediction interval\",\n",
"            )\n",
"\n",
"    # Hide panels left empty when there are fewer than rows * cols dimensions.\n",
"    for ax in axx[n_panels:]:\n",
"        ax.set_axis_off()\n",
"\n",
"    axx[0].legend(loc=\"upper left\")\n",
"\n",
"    if fname is not None:\n",
"        fig.savefig(fname, bbox_inches='tight', pad_inches=0.05)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"ExecuteTime": {
"end_time": "2022-12-23T12:53:59.648105Z",
"start_time": "2022-12-23T12:53:59.621542Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Available datasets: ['constant', 'exchange_rate', 'solar-energy', 'electricity', 'traffic', 'exchange_rate_nips', 'electricity_nips', 'traffic_nips', 'solar_nips', 'wiki-rolling_nips', 'taxi_30min', 'kaggle_web_traffic_with_missing', 'kaggle_web_traffic_without_missing', 'kaggle_web_traffic_weekly', 'm1_yearly', 'm1_quarterly', 'm1_monthly', 'nn5_daily_with_missing', 'nn5_daily_without_missing', 'nn5_weekly', 'tourism_monthly', 'tourism_quarterly', 'tourism_yearly', 'cif_2016', 'london_smart_meters_without_missing', 'wind_farms_without_missing', 'car_parts_without_missing', 'dominick', 'fred_md', 'pedestrian_counts', 'hospital', 'covid_deaths', 'kdd_cup_2018_without_missing', 'weather', 'm3_monthly', 'm3_quarterly', 'm3_yearly', 'm3_other', 'm4_hourly', 'm4_daily', 'm4_weekly', 'm4_monthly', 'm4_quarterly', 'm4_yearly', 'm5', 'uber_tlc_daily', 'uber_tlc_hourly', 'airpassengers']\n"
]
}
],
"source": [
"print(f\"Available datasets: {list(dataset_recipes.keys())}\")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"ExecuteTime": {
"end_time": "2022-12-23T12:53:59.674244Z",
"start_time": "2022-12-23T12:53:59.649143Z"
},
"scrolled": true
},
"outputs": [],
"source": [
"# exchange_rate_nips, electricity_nips, traffic_nips, solar_nips, wiki-rolling_nips, ## taxi_30min is buggy still\n",
"dataset = get_dataset(\"electricity_nips\", regenerate=False)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"ExecuteTime": {
"end_time": "2022-12-23T12:53:59.691198Z",
"start_time": "2022-12-23T12:53:59.676310Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"MetaData(freq='H', target=None, feat_static_cat=[CategoricalFeatureInfo(name='feat_static_cat_0', cardinality='370')], feat_static_real=[], feat_dynamic_real=[], feat_dynamic_cat=[], prediction_length=24)"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dataset.metadata"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"ExecuteTime": {
"end_time": "2022-12-23T12:53:59.715182Z",
"start_time": "2022-12-23T12:53:59.692190Z"
}
},
"outputs": [],
"source": [
"# Group the univariate training series into one multivariate series (at most 2000 dimensions).\n",
"train_grouper = MultivariateGrouper(max_target_dim=min(2000, int(dataset.metadata.feat_static_cat[0].cardinality)))"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"ExecuteTime": {
"end_time": "2022-12-23T12:54:01.039272Z",
"start_time": "2022-12-23T12:53:59.716235Z"
}
},
"outputs": [],
"source": [
"dataset_train = train_grouper(dataset.train)\n"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"ExecuteTime": {
"end_time": "2022-12-23T12:54:01.052821Z",
"start_time": "2022-12-23T12:54:01.040398Z"
}
},
"outputs": [],
"source": [
"# Must equal the dimensionality produced by MultivariateGrouper (same expression as its max_target_dim).\n",
"target_dim = min(2000, int(dataset.metadata.feat_static_cat[0].cardinality))\n",
"\n",
"estimator = TimeGradEstimator(\n",
"    target_dim=target_dim,\n",
"    prediction_length=dataset.metadata.prediction_length,\n",
"    context_length=dataset.metadata.prediction_length,\n",
"    cell_type='GRU',\n",
"    # NOTE(review): 1484 is the model input size used here for electricity_nips; it presumably\n",
"    # depends on target_dim and must be recomputed for other datasets - confirm.\n",
"    input_size=1484,\n",
"    freq=dataset.metadata.freq,\n",
"    loss_type='l2',\n",
"    scaling=True,\n",
"    diff_steps=100,\n",
"    beta_end=0.1,\n",
"    beta_schedule=\"linear\",\n",
"    trainer=Trainer(\n",
"        device=device,\n",
"        epochs=20,\n",
"        learning_rate=1e-3,\n",
"        num_batches_per_epoch=100,\n",
"        batch_size=64,\n",
"    ),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"ExecuteTime": {
"end_time": "2022-12-23T13:17:23.518463Z",
"start_time": "2022-12-23T12:54:01.053838Z"
}
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "194804949f884a888bfc91f4d908d90b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/99 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ecd692a5a7414039b956e859e0b57495",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/99 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "af6f1e845efa447391b8c28d1ae9353a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/99 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "12c5b95d86a543519cf057b3febf6344",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/99 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1724ddc42b3849d0bfb873ed6f94cb53",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/99 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "65ad802537d246609ec3448d1fab3831",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/99 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b137dcfd271441df960ca5bdf37d434e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/99 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1bf81c1871af44d18f4aa424f7187b9d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/99 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a747d5b3f9d0452380008b6752a3d8a9",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/99 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "28c0a56808d54d31b14e880d78d367e8",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/99 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "18109c16fc1848deb34cf8dd3ed488f3",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/99 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a0f9eedaf8c44feb9a948f1102ca098a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/99 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "000acd563f924589b8e3f4af5086e732",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/99 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "57292cdbd51048309555b387c44865b1",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/99 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "42f535c8a97c4e5597a6b063119a1296",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/99 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "922895d5b6a04854aa1488617f8de0bb",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/99 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "45e7a1d08cc648b1883ac48da52833c5",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/99 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d65c3e15dddd4be285b98d0206df24f5",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/99 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "7410904994794355a7e9a6a60d98bbe2",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/99 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5b88a677ebb44eccb32f3d8088266a11",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/99 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"predictor = estimator.train(dataset_train, num_workers=0)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"ExecuteTime": {
"end_time": "2022-12-23T13:17:24.533719Z",
"start_time": "2022-12-23T13:17:23.519766Z"
}
},
"outputs": [],
"source": [
"# MultivariateGrouper np.split()s the test series into `num_test_dates` equal groups, one per\n",
"# rolling evaluation date, and each group must hold every series once (the model's target_dim).\n",
"# Hence the number of dates is len(test) / len(train); any extra factor breaks the grouping.\n",
"assert len(dataset.test) % len(dataset.train) == 0, \"test split is not a whole number of rolling windows\"\n",
"test_grouper = MultivariateGrouper(\n",
"    num_test_dates=len(dataset.test) // len(dataset.train),\n",
"    max_target_dim=min(2000, int(dataset.metadata.feat_static_cat[0].cardinality)),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"ExecuteTime": {
"end_time": "2022-12-23T13:17:36.536658Z",
"start_time": "2022-12-23T13:17:24.534843Z"
}
},
"outputs": [
{
"ename": "ValueError",
"evalue": "setting an array element with a sequence. The requested array has an inhomogeneous shape after 1 dimensions. The detected shape was (2590,) + inhomogeneous part.",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
"Input \u001b[0;32mIn [15]\u001b[0m, in \u001b[0;36m<cell line: 1>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0m dataset_test \u001b[38;5;241m=\u001b[39m \u001b[43mtest_grouper\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdataset\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtest\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/miniforge3/envs/gluonts10.0/lib/python3.9/site-packages/gluonts/dataset/multivariate_grouper.py:87\u001b[0m, in \u001b[0;36mMultivariateGrouper.__call__\u001b[0;34m(self, dataset)\u001b[0m\n\u001b[1;32m 85\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, dataset: Dataset) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Dataset:\n\u001b[1;32m 86\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_preprocess(dataset)\n\u001b[0;32m---> 87\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_group_all\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdataset\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/miniforge3/envs/gluonts10.0/lib/python3.9/site-packages/gluonts/dataset/multivariate_grouper.py:125\u001b[0m, in \u001b[0;36mMultivariateGrouper._group_all\u001b[0;34m(self, dataset)\u001b[0m\n\u001b[1;32m 123\u001b[0m grouped_dataset \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_prepare_train_data(dataset)\n\u001b[1;32m 124\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 125\u001b[0m grouped_dataset \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_prepare_test_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdataset\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 126\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m grouped_dataset\n",
"File \u001b[0;32m~/miniforge3/envs/gluonts10.0/lib/python3.9/site-packages/gluonts/dataset/multivariate_grouper.py:145\u001b[0m, in \u001b[0;36mMultivariateGrouper._prepare_test_data\u001b[0;34m(self, dataset)\u001b[0m\n\u001b[1;32m 141\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_test_dates \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 143\u001b[0m logging\u001b[38;5;241m.\u001b[39minfo(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mgroup test time-series to datasets\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m--> 145\u001b[0m grouped_data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_transform_target\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_left_pad_data\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdataset\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 146\u001b[0m \u001b[38;5;66;03m# splits test dataset with rolling date into N R^d time series where\u001b[39;00m\n\u001b[1;32m 147\u001b[0m \u001b[38;5;66;03m# N is the number of rolling evaluation dates\u001b[39;00m\n\u001b[1;32m 148\u001b[0m split_dataset \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39msplit(\n\u001b[1;32m 149\u001b[0m grouped_data[FieldName\u001b[38;5;241m.\u001b[39mTARGET], \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_test_dates\n\u001b[1;32m 150\u001b[0m )\n",
"File \u001b[0;32m~/miniforge3/envs/gluonts10.0/lib/python3.9/site-packages/gluonts/dataset/multivariate_grouper.py:191\u001b[0m, in \u001b[0;36mMultivariateGrouper._transform_target\u001b[0;34m(funcs, dataset)\u001b[0m\n\u001b[1;32m 189\u001b[0m \u001b[38;5;129m@staticmethod\u001b[39m\n\u001b[1;32m 190\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_transform_target\u001b[39m(funcs, dataset: Dataset) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m DataEntry:\n\u001b[0;32m--> 191\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m {FieldName\u001b[38;5;241m.\u001b[39mTARGET: \u001b[43mnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43marray\u001b[49m\u001b[43m(\u001b[49m\u001b[43m[\u001b[49m\u001b[43mfuncs\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mdata\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mdataset\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m}\n",
"\u001b[0;31mValueError\u001b[0m: setting an array element with a sequence. The requested array has an inhomogeneous shape after 1 dimensions. The detected shape was (2590,) + inhomogeneous part."
]
}
],
"source": [
"dataset_test = test_grouper(dataset.test)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2022-12-23T13:17:36.537956Z",
"start_time": "2022-12-23T13:17:36.537946Z"
}
},
"outputs": [],
"source": [
"forecast_it, ts_it = make_evaluation_predictions(dataset=dataset_test,\n",
" predictor=predictor,\n",
" num_samples=100)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2022-12-23T13:17:36.538906Z",
"start_time": "2022-12-23T13:17:36.538898Z"
}
},
"outputs": [],
"source": [
"forecasts = list(forecast_it)\n",
"targets = list(ts_it)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2022-12-23T13:17:36.539968Z",
"start_time": "2022-12-23T13:17:36.539960Z"
}
},
"outputs": [],
"source": [
"# Sanity check: one forecast and one ground-truth series per grouped test entry.\n",
"assert len(forecasts) == len(targets) == len(dataset_test)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2022-12-23T13:17:36.541030Z",
"start_time": "2022-12-23T13:17:36.541021Z"
}
},
"outputs": [],
"source": [
"plot(\n",
" target=targets[0],\n",
" forecast=forecasts[0],\n",
" prediction_length=dataset.metadata.prediction_length,\n",
")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2022-12-23T13:17:36.542095Z",
"start_time": "2022-12-23T13:17:36.542086Z"
}
},
"outputs": [],
"source": [
"evaluator = MultivariateEvaluator(quantiles=(np.arange(20)/20.0)[1:], \n",
" target_agg_funcs={'sum': np.sum})"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2022-12-23T13:17:36.543150Z",
"start_time": "2022-12-23T13:17:36.543141Z"
},
"scrolled": true
},
"outputs": [],
"source": [
"agg_metric, item_metrics = evaluator(targets, forecasts, num_series=len(dataset_test))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2022-12-23T13:17:36.544197Z",
"start_time": "2022-12-23T13:17:36.544189Z"
}
},
"outputs": [],
"source": [
"print(\"CRPS:\", agg_metric[\"mean_wQuantileLoss\"])\n",
"print(\"ND:\", agg_metric[\"ND\"])\n",
"print(\"NRMSE:\", agg_metric[\"NRMSE\"])\n",
"print(\"\")\n",
"print(\"CRPS-Sum:\", agg_metric[\"m_sum_mean_wQuantileLoss\"])\n",
"print(\"ND-Sum:\", agg_metric[\"m_sum_ND\"])\n",
"print(\"NRMSE-Sum:\", agg_metric[\"m_sum_NRMSE\"])"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "gluonts10.0",
"language": "python",
"name": "gluonts10.0"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
},
"vscode": {
"interpreter": {
"hash": "7f25a1f13147a60511cf6766827402baf95cbe50d53a241197155306ee38fe70"
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}