mirror of
https://github.com/wassname/pytorch-ts.git
synced 2026-06-27 19:32:05 +08:00
803 lines
26 KiB
Plaintext
803 lines
26 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-12-23T12:53:57.712458Z",
|
|
"start_time": "2022-12-23T12:53:57.704099Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Automatically reload edited modules (e.g. the local pts package) before running code.\n",
|

"%load_ext autoreload\n",
|

"%autoreload 2"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-12-23T12:53:59.259208Z",
|
|
"start_time": "2022-12-23T12:53:57.713340Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"%matplotlib inline\n",
|
|
"import matplotlib.pyplot as plt\n",
|
|
"\n",
|
|
"import numpy as np\n",
|
|
"import pandas as pd\n",
|
|
"\n",
|
|
"import torch"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-12-23T12:53:59.433098Z",
|
|
"start_time": "2022-12-23T12:53:59.261001Z"
|
|
}
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"/home/wassname/miniforge3/envs/gluonts10.0/lib/python3.9/site-packages/gluonts/json.py:101: UserWarning: Using `json`-module for json-handling. Consider installing one of `orjson`, `ujson` to speed up serialization and deserialization.\n",
|
|
" warnings.warn(\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"from gluonts.dataset.multivariate_grouper import MultivariateGrouper\n",
|
|
"from gluonts.dataset.repository.datasets import dataset_recipes, get_dataset\n",
|
|
"from gluonts.evaluation.backtest import make_evaluation_predictions\n",
|
|
"from gluonts.evaluation import MultivariateEvaluator"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-12-23T12:53:59.545657Z",
|
|
"start_time": "2022-12-23T12:53:59.434450Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"from pts.model.tempflow import TempFlowEstimator\n",
|
|
"from pts.model.time_grad import TimeGradEstimator\n",
|
|
"from pts.model.transformer_tempflow import TransformerTempFlowEstimator\n",
|
|
"from pts import Trainer"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-12-23T12:53:59.589097Z",
|
|
"start_time": "2022-12-23T12:53:59.546862Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Use the GPU when torch can see one, otherwise fall back to the CPU.\n",
|

"device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-12-23T12:53:59.620352Z",
|
|
"start_time": "2022-12-23T12:53:59.590581Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"def plot(target, forecast, prediction_length, prediction_intervals=(50.0, 90.0), color='g', fname=None):\n",
|

"    \"\"\"Plot recent observations and forecast quantile bands for up to 16 target dimensions.\n",
|

"\n",
|

"    Args:\n",
|

"        target: 2-D pandas DataFrame of observations, one column per target dimension\n",
|

"            (column ``dim`` is selected with ``target[dim]``).\n",
|

"        forecast: multivariate forecast providing ``quantile(q)`` (an array indexed\n",
|

"            ``[:, dim]``) and ``index`` (the time index of the forecast window).\n",
|

"        prediction_length: forecast horizon; the last ``2 * prediction_length``\n",
|

"            observations are drawn.\n",
|

"        prediction_intervals: central interval widths in percent; e.g. ``90.0``\n",
|

"            shades the band between the 5th and 95th percentiles.\n",
|

"        color: colour of the median line and of the interval bands.\n",
|

"        fname: optional path; when given, the figure is also saved there.\n",
|

"\n",
|

"    The legend (first subplot only) is built from labelled artists rather than a\n",
|

"    positional label list, whose mapping to artists depends on matplotlib's\n",
|

"    internal handle ordering and silently mislabels the bands on newer versions.\n",
|

"    \"\"\"\n",
|

"    label_prefix = \"\"\n",
|

"    rows = 4\n",
|

"    cols = 4\n",
|

"    fig, axs = plt.subplots(rows, cols, figsize=(24, 24))\n",
|

"    axx = axs.ravel()\n",
|

"    _, target_dim = target.shape\n",
|

"\n",
|

"    def alpha_for_percentile(p):\n",
|

"        return (p / 100.0) ** 0.3\n",
|

"\n",
|

"    # Widest interval first: narrower (more opaque) bands are drawn on top and the\n",
|

"    # legend lists intervals from widest to narrowest.\n",
|

"    intervals = sorted(set(prediction_intervals), reverse=True)\n",
|

"\n",
|

"    # quantile() returns every dimension at once, so compute each level once here\n",
|

"    # instead of once per plotted dimension.\n",
|

"    median = forecast.quantile(0.5)\n",
|

"    bands = []\n",
|

"    for width in intervals:\n",
|

"        lower_pct = 50.0 - width / 2.0\n",
|

"        upper_pct = 50.0 + width / 2.0\n",
|

"        bands.append(\n",
|

"            (width, lower_pct, forecast.quantile(lower_pct / 100.0), forecast.quantile(upper_pct / 100.0))\n",
|

"        )\n",
|

"\n",
|

"    for dim in range(min(rows * cols, target_dim)):\n",
|

"        ax = axx[dim]\n",
|

"        target[-2 * prediction_length :][dim].plot(ax=ax, label=f\"{label_prefix}observations\")\n",
|

"        pd.Series(data=median[:, dim], index=forecast.index).plot(\n",
|

"            color=color, ls=\"-\", label=f\"{label_prefix}median prediction\", ax=ax\n",
|

"        )\n",
|

"        for width, lower_pct, lower, upper in bands:\n",
|

"            # The band itself carries the label, so the legend entry is tied to it.\n",
|

"            ax.fill_between(\n",
|

"                forecast.index,\n",
|

"                lower[:, dim],\n",
|

"                upper[:, dim],\n",
|

"                facecolor=color,\n",
|

"                alpha=alpha_for_percentile(lower_pct),\n",
|

"                interpolate=True,\n",
|

"                label=f\"{label_prefix}{width}% prediction interval\",\n",
|

"            )\n",
|

"\n",
|

"    axx[0].legend(loc=\"upper left\")\n",
|

"\n",
|

"    if fname is not None:\n",
|

"        plt.savefig(fname, bbox_inches='tight', pad_inches=0.05)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-12-23T12:53:59.648105Z",
|
|
"start_time": "2022-12-23T12:53:59.621542Z"
|
|
}
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Available datasets: ['constant', 'exchange_rate', 'solar-energy', 'electricity', 'traffic', 'exchange_rate_nips', 'electricity_nips', 'traffic_nips', 'solar_nips', 'wiki-rolling_nips', 'taxi_30min', 'kaggle_web_traffic_with_missing', 'kaggle_web_traffic_without_missing', 'kaggle_web_traffic_weekly', 'm1_yearly', 'm1_quarterly', 'm1_monthly', 'nn5_daily_with_missing', 'nn5_daily_without_missing', 'nn5_weekly', 'tourism_monthly', 'tourism_quarterly', 'tourism_yearly', 'cif_2016', 'london_smart_meters_without_missing', 'wind_farms_without_missing', 'car_parts_without_missing', 'dominick', 'fred_md', 'pedestrian_counts', 'hospital', 'covid_deaths', 'kdd_cup_2018_without_missing', 'weather', 'm3_monthly', 'm3_quarterly', 'm3_yearly', 'm3_other', 'm4_hourly', 'm4_daily', 'm4_weekly', 'm4_monthly', 'm4_quarterly', 'm4_yearly', 'm5', 'uber_tlc_daily', 'uber_tlc_hourly', 'airpassengers']\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# List every dataset name registered with the GluonTS repository.\n",
|

"print(f\"Available datasets: {[*dataset_recipes]}\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-12-23T12:53:59.674244Z",
|
|
"start_time": "2022-12-23T12:53:59.649143Z"
|
|
},
|
|
"scrolled": true
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Other options: exchange_rate_nips, electricity_nips, traffic_nips, solar_nips,\n",
|

"# wiki-rolling_nips (taxi_30min is still buggy).\n",
|

"dataset = get_dataset(\"electricity_nips\", regenerate=False)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-12-23T12:53:59.691198Z",
|
|
"start_time": "2022-12-23T12:53:59.676310Z"
|
|
}
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"MetaData(freq='H', target=None, feat_static_cat=[CategoricalFeatureInfo(name='feat_static_cat_0', cardinality='370')], feat_static_real=[], feat_dynamic_real=[], feat_dynamic_cat=[], prediction_length=24)"
|
|
]
|
|
},
|
|
"execution_count": 9,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Frequency, prediction length and series cardinality (the cardinality is used below as target_dim).\n",
|

"dataset.metadata"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-12-23T07:12:03.693420Z",
|
|
"start_time": "2022-12-23T07:12:02.845649Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 10,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-12-23T12:53:59.715182Z",
|
|
"start_time": "2022-12-23T12:53:59.692190Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Combine the training series into one multivariate series, capped at 2000 dimensions.\n",
|

"train_grouper = MultivariateGrouper(\n",
|

"    max_target_dim=min(int(dataset.metadata.feat_static_cat[0].cardinality), 2000)\n",
|

")\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-12-23T12:54:01.039272Z",
|
|
"start_time": "2022-12-23T12:53:59.716235Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Apply the grouper to the training split.\n",
|

"dataset_train = train_grouper(dataset.train)\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 12,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-12-23T12:54:01.052821Z",
|
|
"start_time": "2022-12-23T12:54:01.040398Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"estimator = TimeGradEstimator(\n",
|

"    # Must equal the dimensionality produced by train_grouper, which caps it at 2000.\n",
|

"    target_dim=min(2000, int(dataset.metadata.feat_static_cat[0].cardinality)),\n",
|

"    prediction_length=dataset.metadata.prediction_length,\n",
|

"    context_length=dataset.metadata.prediction_length,\n",
|

"    cell_type='GRU',\n",
|

"    # NOTE(review): 1484 appears specific to electricity_nips (370 series); it presumably\n",
|

"    # has to change together with the dataset - confirm against TimeGrad's input features.\n",
|

"    input_size=1484,\n",
|

"    freq=dataset.metadata.freq,\n",
|

"    loss_type='l2',\n",
|

"    scaling=True,\n",
|

"    diff_steps=100,\n",
|

"    beta_end=0.1,\n",
|

"    beta_schedule=\"linear\",\n",
|

"    trainer=Trainer(\n",
|

"        device=device,\n",
|

"        epochs=20,\n",
|

"        learning_rate=1e-3,\n",
|

"        num_batches_per_epoch=100,\n",
|

"        batch_size=64,\n",
|

"    ),\n",
|

")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 13,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-12-23T13:17:23.518463Z",
|
|
"start_time": "2022-12-23T12:54:01.053838Z"
|
|
}
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "194804949f884a888bfc91f4d908d90b",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
" 0%| | 0/99 [00:00<?, ?it/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "ecd692a5a7414039b956e859e0b57495",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
" 0%| | 0/99 [00:00<?, ?it/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "af6f1e845efa447391b8c28d1ae9353a",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
" 0%| | 0/99 [00:00<?, ?it/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "12c5b95d86a543519cf057b3febf6344",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
" 0%| | 0/99 [00:00<?, ?it/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "1724ddc42b3849d0bfb873ed6f94cb53",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
" 0%| | 0/99 [00:00<?, ?it/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "65ad802537d246609ec3448d1fab3831",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
" 0%| | 0/99 [00:00<?, ?it/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "b137dcfd271441df960ca5bdf37d434e",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
" 0%| | 0/99 [00:00<?, ?it/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "1bf81c1871af44d18f4aa424f7187b9d",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
" 0%| | 0/99 [00:00<?, ?it/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "a747d5b3f9d0452380008b6752a3d8a9",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
" 0%| | 0/99 [00:00<?, ?it/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "28c0a56808d54d31b14e880d78d367e8",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
" 0%| | 0/99 [00:00<?, ?it/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "18109c16fc1848deb34cf8dd3ed488f3",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
" 0%| | 0/99 [00:00<?, ?it/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "a0f9eedaf8c44feb9a948f1102ca098a",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
" 0%| | 0/99 [00:00<?, ?it/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "000acd563f924589b8e3f4af5086e732",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
" 0%| | 0/99 [00:00<?, ?it/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "57292cdbd51048309555b387c44865b1",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
" 0%| | 0/99 [00:00<?, ?it/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "42f535c8a97c4e5597a6b063119a1296",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
" 0%| | 0/99 [00:00<?, ?it/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "922895d5b6a04854aa1488617f8de0bb",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
" 0%| | 0/99 [00:00<?, ?it/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "45e7a1d08cc648b1883ac48da52833c5",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
" 0%| | 0/99 [00:00<?, ?it/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "d65c3e15dddd4be285b98d0206df24f5",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
" 0%| | 0/99 [00:00<?, ?it/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "7410904994794355a7e9a6a60d98bbe2",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
" 0%| | 0/99 [00:00<?, ?it/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "5b88a677ebb44eccb32f3d8088266a11",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
" 0%| | 0/99 [00:00<?, ?it/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Fit the model. num_workers=0 is presumably forwarded to the torch DataLoader\n",
|

"# (0 = load batches in the notebook process) - confirm in pts.\n",
|

"predictor = estimator.train(dataset_train, num_workers=0)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 14,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-12-23T13:17:24.533719Z",
|
|
"start_time": "2022-12-23T13:17:23.519766Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"\n",
|
|
"# dataset.test holds each training series once per rolling evaluation date\n",
|

"# (electricity_nips: 2590 test entries = 370 series x 7 dates), so the ratio of the\n",
|

"# split sizes is the number of test dates. It must not be scaled: the grouper\n",
|

"# np.split()s the grouped data into exactly num_test_dates chunks, one per date.\n",
|

"test_grouper = MultivariateGrouper(\n",
|

"    num_test_dates=len(dataset.test) // len(dataset.train),\n",
|

"    max_target_dim=min(2000, int(dataset.metadata.feat_static_cat[0].cardinality)),\n",
|

")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 15,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-12-23T13:17:36.536658Z",
|
|
"start_time": "2022-12-23T13:17:24.534843Z"
|
|
}
|
|
},
|
|
"outputs": [
|
|
{
|
|
"ename": "ValueError",
|
|
"evalue": "setting an array element with a sequence. The requested array has an inhomogeneous shape after 1 dimensions. The detected shape was (2590,) + inhomogeneous part.",
|
|
"output_type": "error",
|
|
"traceback": [
|
|
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
|
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
|
|
"Input \u001b[0;32mIn [15]\u001b[0m, in \u001b[0;36m<cell line: 1>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0m dataset_test \u001b[38;5;241m=\u001b[39m \u001b[43mtest_grouper\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdataset\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtest\u001b[49m\u001b[43m)\u001b[49m\n",
|
|
"File \u001b[0;32m~/miniforge3/envs/gluonts10.0/lib/python3.9/site-packages/gluonts/dataset/multivariate_grouper.py:87\u001b[0m, in \u001b[0;36mMultivariateGrouper.__call__\u001b[0;34m(self, dataset)\u001b[0m\n\u001b[1;32m 85\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, dataset: Dataset) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Dataset:\n\u001b[1;32m 86\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_preprocess(dataset)\n\u001b[0;32m---> 87\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_group_all\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdataset\u001b[49m\u001b[43m)\u001b[49m\n",
|
|
"File \u001b[0;32m~/miniforge3/envs/gluonts10.0/lib/python3.9/site-packages/gluonts/dataset/multivariate_grouper.py:125\u001b[0m, in \u001b[0;36mMultivariateGrouper._group_all\u001b[0;34m(self, dataset)\u001b[0m\n\u001b[1;32m 123\u001b[0m grouped_dataset \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_prepare_train_data(dataset)\n\u001b[1;32m 124\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 125\u001b[0m grouped_dataset \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_prepare_test_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdataset\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 126\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m grouped_dataset\n",
|
|
"File \u001b[0;32m~/miniforge3/envs/gluonts10.0/lib/python3.9/site-packages/gluonts/dataset/multivariate_grouper.py:145\u001b[0m, in \u001b[0;36mMultivariateGrouper._prepare_test_data\u001b[0;34m(self, dataset)\u001b[0m\n\u001b[1;32m 141\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_test_dates \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 143\u001b[0m logging\u001b[38;5;241m.\u001b[39minfo(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mgroup test time-series to datasets\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m--> 145\u001b[0m grouped_data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_transform_target\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_left_pad_data\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdataset\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 146\u001b[0m \u001b[38;5;66;03m# splits test dataset with rolling date into N R^d time series where\u001b[39;00m\n\u001b[1;32m 147\u001b[0m \u001b[38;5;66;03m# N is the number of rolling evaluation dates\u001b[39;00m\n\u001b[1;32m 148\u001b[0m split_dataset \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39msplit(\n\u001b[1;32m 149\u001b[0m grouped_data[FieldName\u001b[38;5;241m.\u001b[39mTARGET], \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_test_dates\n\u001b[1;32m 150\u001b[0m )\n",
|
|
"File \u001b[0;32m~/miniforge3/envs/gluonts10.0/lib/python3.9/site-packages/gluonts/dataset/multivariate_grouper.py:191\u001b[0m, in \u001b[0;36mMultivariateGrouper._transform_target\u001b[0;34m(funcs, dataset)\u001b[0m\n\u001b[1;32m 189\u001b[0m \u001b[38;5;129m@staticmethod\u001b[39m\n\u001b[1;32m 190\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_transform_target\u001b[39m(funcs, dataset: Dataset) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m DataEntry:\n\u001b[0;32m--> 191\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m {FieldName\u001b[38;5;241m.\u001b[39mTARGET: \u001b[43mnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43marray\u001b[49m\u001b[43m(\u001b[49m\u001b[43m[\u001b[49m\u001b[43mfuncs\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mdata\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mdataset\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m}\n",
|
|
"\u001b[0;31mValueError\u001b[0m: setting an array element with a sequence. The requested array has an inhomogeneous shape after 1 dimensions. The detected shape was (2590,) + inhomogeneous part."
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# NOTE(review): the saved run raised ValueError here (np.array over entries of\n",
|

"# inhomogeneous shape inside gluonts' MultivariateGrouper._transform_target); this\n",
|

"# looks like a gluonts/numpy version incompatibility rather than a notebook bug - confirm.\n",
|

"dataset_test = test_grouper(dataset.test)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-12-23T13:17:36.537956Z",
|
|
"start_time": "2022-12-23T13:17:36.537946Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Build forecast and ground-truth iterators over the grouped test set (100 sample paths each).\n",
|

"forecast_it, ts_it = make_evaluation_predictions(\n",
|

"    dataset=dataset_test,\n",
|

"    predictor=predictor,\n",
|

"    num_samples=100,\n",
|

")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-12-23T13:17:36.538906Z",
|
|
"start_time": "2022-12-23T13:17:36.538898Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Materialise both iterators so the results can be plotted and evaluated repeatedly.\n",
|

"forecasts, targets = list(forecast_it), list(ts_it)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-12-23T13:17:36.539968Z",
|
|
"start_time": "2022-12-23T13:17:36.539960Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Post-mortem debugging left over from development: uncomment and run right after an\n",
|

"# exception to inspect it. Left active, it makes Restart & Run All non-reproducible.\n",
|

"# %debug"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-12-23T13:17:36.541030Z",
|
|
"start_time": "2022-12-23T13:17:36.541021Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Observations vs. forecast for the first grouped test entry.\n",
|

"plot(target=targets[0], forecast=forecasts[0], prediction_length=dataset.metadata.prediction_length)\n",
|

"plt.show()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-12-23T13:17:36.542095Z",
|
|
"start_time": "2022-12-23T13:17:36.542086Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Quantile levels 0.05, 0.10, ..., 0.95, plus metrics on the 'sum' aggregate over dimensions.\n",
|

"evaluator = MultivariateEvaluator(\n",
|

"    quantiles=np.arange(1, 20) / 20.0,\n",
|

"    target_agg_funcs={\"sum\": np.sum},\n",
|

")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-12-23T13:17:36.543150Z",
|
|
"start_time": "2022-12-23T13:17:36.543141Z"
|
|
},
|
|
"scrolled": true
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Presumably agg_metric holds dataset-level metrics and item_metrics per-entry ones.\n",
|

"agg_metric, item_metrics = evaluator(targets, forecasts, num_series=len(dataset_test))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-12-23T13:17:36.544197Z",
|
|
"start_time": "2022-12-23T13:17:36.544189Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# CRPS is reported as the mean weighted quantile loss over the evaluator's quantiles;\n",
|

"# the m_sum_* keys presumably come from the 'sum' entry of target_agg_funcs.\n",
|

"print(\"CRPS:\", agg_metric[\"mean_wQuantileLoss\"])\n",
|

"print(\"ND:\", agg_metric[\"ND\"])\n",
|

"print(\"NRMSE:\", agg_metric[\"NRMSE\"])\n",
|

"print(\"\")\n",
|

"print(\"CRPS-Sum:\", agg_metric[\"m_sum_mean_wQuantileLoss\"])\n",
|

"print(\"ND-Sum:\", agg_metric[\"m_sum_ND\"])\n",
|

"print(\"NRMSE-Sum:\", agg_metric[\"m_sum_NRMSE\"])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "gluonts10.0",
|
|
"language": "python",
|
|
"name": "gluonts10.0"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.9.15"
|
|
},
|
|
"toc": {
|
|
"base_numbering": 1,
|
|
"nav_menu": {},
|
|
"number_sections": true,
|
|
"sideBar": true,
|
|
"skip_h1_title": false,
|
|
"title_cell": "Table of Contents",
|
|
"title_sidebar": "Contents",
|
|
"toc_cell": false,
|
|
"toc_position": {},
|
|
"toc_section_display": true,
|
|
"toc_window_display": false
|
|
},
|
|
"vscode": {
|
|
"interpreter": {
|
|
"hash": "7f25a1f13147a60511cf6766827402baf95cbe50d53a241197155306ee38fe70"
|
|
}
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 4
|
|
}
|