some poor results

This commit is contained in:
wassname
2022-12-24 12:36:51 +08:00
parent 204aed2371
commit 01663eb05f
2 changed files with 229 additions and 114 deletions
+225 -111
View File
@@ -5,8 +5,8 @@
"execution_count": 1,
"metadata": {
"ExecuteTime": {
"end_time": "2022-12-23T09:57:39.738715Z",
"start_time": "2022-12-23T09:57:39.730077Z"
"end_time": "2022-12-23T12:53:57.712458Z",
"start_time": "2022-12-23T12:53:57.704099Z"
}
},
"outputs": [],
@@ -21,8 +21,8 @@
"execution_count": 2,
"metadata": {
"ExecuteTime": {
"end_time": "2022-12-23T09:57:40.738250Z",
"start_time": "2022-12-23T09:57:39.739677Z"
"end_time": "2022-12-23T12:53:59.259208Z",
"start_time": "2022-12-23T12:53:57.713340Z"
}
},
"outputs": [],
@@ -41,11 +41,20 @@
"execution_count": 3,
"metadata": {
"ExecuteTime": {
"end_time": "2022-12-23T09:57:40.909476Z",
"start_time": "2022-12-23T09:57:40.739741Z"
"end_time": "2022-12-23T12:53:59.433098Z",
"start_time": "2022-12-23T12:53:59.261001Z"
}
},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/wassname/miniforge3/envs/gluonts10.0/lib/python3.9/site-packages/gluonts/json.py:101: UserWarning: Using `json`-module for json-handling. Consider installing one of `orjson`, `ujson` to speed up serialization and deserialization.\n",
" warnings.warn(\n"
]
}
],
"source": [
"from gluonts.dataset.multivariate_grouper import MultivariateGrouper\n",
"from gluonts.dataset.repository.datasets import dataset_recipes, get_dataset\n",
@@ -58,8 +67,8 @@
"execution_count": 4,
"metadata": {
"ExecuteTime": {
"end_time": "2022-12-23T09:57:40.966206Z",
"start_time": "2022-12-23T09:57:40.910477Z"
"end_time": "2022-12-23T12:53:59.545657Z",
"start_time": "2022-12-23T12:53:59.434450Z"
}
},
"outputs": [],
@@ -75,8 +84,8 @@
"execution_count": 5,
"metadata": {
"ExecuteTime": {
"end_time": "2022-12-23T09:57:40.980883Z",
"start_time": "2022-12-23T09:57:40.967157Z"
"end_time": "2022-12-23T12:53:59.589097Z",
"start_time": "2022-12-23T12:53:59.546862Z"
}
},
"outputs": [],
@@ -89,8 +98,8 @@
"execution_count": 6,
"metadata": {
"ExecuteTime": {
"end_time": "2022-12-23T09:57:41.003468Z",
"start_time": "2022-12-23T09:57:40.981830Z"
"end_time": "2022-12-23T12:53:59.620352Z",
"start_time": "2022-12-23T12:53:59.590581Z"
}
},
"outputs": [],
@@ -157,8 +166,8 @@
"execution_count": 7,
"metadata": {
"ExecuteTime": {
"end_time": "2022-12-23T09:57:41.025224Z",
"start_time": "2022-12-23T09:57:41.004330Z"
"end_time": "2022-12-23T12:53:59.648105Z",
"start_time": "2022-12-23T12:53:59.621542Z"
}
},
"outputs": [
@@ -179,8 +188,8 @@
"execution_count": 8,
"metadata": {
"ExecuteTime": {
"end_time": "2022-12-23T09:57:41.041326Z",
"start_time": "2022-12-23T09:57:41.026069Z"
"end_time": "2022-12-23T12:53:59.674244Z",
"start_time": "2022-12-23T12:53:59.649143Z"
},
"scrolled": true
},
@@ -195,8 +204,8 @@
"execution_count": 9,
"metadata": {
"ExecuteTime": {
"end_time": "2022-12-23T09:57:41.057694Z",
"start_time": "2022-12-23T09:57:41.042892Z"
"end_time": "2022-12-23T12:53:59.691198Z",
"start_time": "2022-12-23T12:53:59.676310Z"
}
},
"outputs": [
@@ -232,8 +241,8 @@
"execution_count": 10,
"metadata": {
"ExecuteTime": {
"end_time": "2022-12-23T09:57:41.073075Z",
"start_time": "2022-12-23T09:57:41.058526Z"
"end_time": "2022-12-23T12:53:59.715182Z",
"start_time": "2022-12-23T12:53:59.692190Z"
}
},
"outputs": [],
@@ -246,8 +255,8 @@
"execution_count": 11,
"metadata": {
"ExecuteTime": {
"end_time": "2022-12-23T09:57:44.134871Z",
"start_time": "2022-12-23T09:57:41.073880Z"
"end_time": "2022-12-23T12:54:01.039272Z",
"start_time": "2022-12-23T12:53:59.716235Z"
}
},
"outputs": [],
@@ -260,8 +269,8 @@
"execution_count": 12,
"metadata": {
"ExecuteTime": {
"end_time": "2022-12-23T09:57:44.153700Z",
"start_time": "2022-12-23T09:57:44.136099Z"
"end_time": "2022-12-23T12:54:01.052821Z",
"start_time": "2022-12-23T12:54:01.040398Z"
}
},
"outputs": [],
@@ -288,17 +297,18 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 13,
"metadata": {
"ExecuteTime": {
"start_time": "2022-12-23T09:57:39.746Z"
"end_time": "2022-12-23T13:17:23.518463Z",
"start_time": "2022-12-23T12:54:01.053838Z"
}
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "744b148440674ef3902ca56622cb001b",
"model_id": "194804949f884a888bfc91f4d908d90b",
"version_major": 2,
"version_minor": 0
},
@@ -312,7 +322,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "88abd67eb98d4569b276f98dadd73205",
"model_id": "ecd692a5a7414039b956e859e0b57495",
"version_major": 2,
"version_minor": 0
},
@@ -326,7 +336,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f5e08c9306c647cd8589f29d8652177b",
"model_id": "af6f1e845efa447391b8c28d1ae9353a",
"version_major": 2,
"version_minor": 0
},
@@ -340,7 +350,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "10a4ac2006db4f6090dc3f08cc3e3e9e",
"model_id": "12c5b95d86a543519cf057b3febf6344",
"version_major": 2,
"version_minor": 0
},
@@ -354,7 +364,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6458a98361854786bc4fa81da37471d2",
"model_id": "1724ddc42b3849d0bfb873ed6f94cb53",
"version_major": 2,
"version_minor": 0
},
@@ -368,7 +378,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "676d661e21ec42acbb544a8848264674",
"model_id": "65ad802537d246609ec3448d1fab3831",
"version_major": 2,
"version_minor": 0
},
@@ -382,7 +392,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "36010cebcd884d248de158890d2ba324",
"model_id": "b137dcfd271441df960ca5bdf37d434e",
"version_major": 2,
"version_minor": 0
},
@@ -396,7 +406,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f7487fae1ae54dae8c71d23ea5721e7f",
"model_id": "1bf81c1871af44d18f4aa424f7187b9d",
"version_major": 2,
"version_minor": 0
},
@@ -410,7 +420,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6a995bf552bb487cbca2a0f3e45098b3",
"model_id": "a747d5b3f9d0452380008b6752a3d8a9",
"version_major": 2,
"version_minor": 0
},
@@ -424,7 +434,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3ee8b7cdd0b94465a78c43d0917910ae",
"model_id": "28c0a56808d54d31b14e880d78d367e8",
"version_major": 2,
"version_minor": 0
},
@@ -438,7 +448,133 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "958283558c6a4e958c5c67188e735ed9",
"model_id": "18109c16fc1848deb34cf8dd3ed488f3",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/99 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a0f9eedaf8c44feb9a948f1102ca098a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/99 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "000acd563f924589b8e3f4af5086e732",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/99 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "57292cdbd51048309555b387c44865b1",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/99 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "42f535c8a97c4e5597a6b063119a1296",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/99 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "922895d5b6a04854aa1488617f8de0bb",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/99 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "45e7a1d08cc648b1883ac48da52833c5",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/99 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d65c3e15dddd4be285b98d0206df24f5",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/99 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "7410904994794355a7e9a6a60d98bbe2",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/99 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5b88a677ebb44eccb32f3d8088266a11",
"version_major": 2,
"version_minor": 0
},
@@ -456,17 +592,11 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 14,
"metadata": {
"ExecuteTime": {
"start_time": "2022-12-23T09:57:39.748Z"
"end_time": "2022-12-23T13:17:24.533719Z",
"start_time": "2022-12-23T13:17:23.519766Z"
}
},
"outputs": [],
@@ -480,13 +610,30 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 15,
"metadata": {
"ExecuteTime": {
"start_time": "2022-12-23T09:57:39.749Z"
"end_time": "2022-12-23T13:17:36.536658Z",
"start_time": "2022-12-23T13:17:24.534843Z"
}
},
"outputs": [],
"outputs": [
{
"ename": "ValueError",
"evalue": "setting an array element with a sequence. The requested array has an inhomogeneous shape after 1 dimensions. The detected shape was (2590,) + inhomogeneous part.",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
"Input \u001b[0;32mIn [15]\u001b[0m, in \u001b[0;36m<cell line: 1>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0m dataset_test \u001b[38;5;241m=\u001b[39m \u001b[43mtest_grouper\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdataset\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtest\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/miniforge3/envs/gluonts10.0/lib/python3.9/site-packages/gluonts/dataset/multivariate_grouper.py:87\u001b[0m, in \u001b[0;36mMultivariateGrouper.__call__\u001b[0;34m(self, dataset)\u001b[0m\n\u001b[1;32m 85\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, dataset: Dataset) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Dataset:\n\u001b[1;32m 86\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_preprocess(dataset)\n\u001b[0;32m---> 87\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_group_all\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdataset\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/miniforge3/envs/gluonts10.0/lib/python3.9/site-packages/gluonts/dataset/multivariate_grouper.py:125\u001b[0m, in \u001b[0;36mMultivariateGrouper._group_all\u001b[0;34m(self, dataset)\u001b[0m\n\u001b[1;32m 123\u001b[0m grouped_dataset \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_prepare_train_data(dataset)\n\u001b[1;32m 124\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 125\u001b[0m grouped_dataset \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_prepare_test_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdataset\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 126\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m grouped_dataset\n",
"File \u001b[0;32m~/miniforge3/envs/gluonts10.0/lib/python3.9/site-packages/gluonts/dataset/multivariate_grouper.py:145\u001b[0m, in \u001b[0;36mMultivariateGrouper._prepare_test_data\u001b[0;34m(self, dataset)\u001b[0m\n\u001b[1;32m 141\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_test_dates \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 143\u001b[0m logging\u001b[38;5;241m.\u001b[39minfo(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mgroup test time-series to datasets\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m--> 145\u001b[0m grouped_data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_transform_target\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_left_pad_data\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdataset\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 146\u001b[0m \u001b[38;5;66;03m# splits test dataset with rolling date into N R^d time series where\u001b[39;00m\n\u001b[1;32m 147\u001b[0m \u001b[38;5;66;03m# N is the number of rolling evaluation dates\u001b[39;00m\n\u001b[1;32m 148\u001b[0m split_dataset \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39msplit(\n\u001b[1;32m 149\u001b[0m grouped_data[FieldName\u001b[38;5;241m.\u001b[39mTARGET], \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_test_dates\n\u001b[1;32m 150\u001b[0m )\n",
"File \u001b[0;32m~/miniforge3/envs/gluonts10.0/lib/python3.9/site-packages/gluonts/dataset/multivariate_grouper.py:191\u001b[0m, in \u001b[0;36mMultivariateGrouper._transform_target\u001b[0;34m(funcs, dataset)\u001b[0m\n\u001b[1;32m 189\u001b[0m \u001b[38;5;129m@staticmethod\u001b[39m\n\u001b[1;32m 190\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_transform_target\u001b[39m(funcs, dataset: Dataset) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m DataEntry:\n\u001b[0;32m--> 191\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m {FieldName\u001b[38;5;241m.\u001b[39mTARGET: \u001b[43mnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43marray\u001b[49m\u001b[43m(\u001b[49m\u001b[43m[\u001b[49m\u001b[43mfuncs\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mdata\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mdataset\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m}\n",
"\u001b[0;31mValueError\u001b[0m: setting an array element with a sequence. The requested array has an inhomogeneous shape after 1 dimensions. The detected shape was (2590,) + inhomogeneous part."
]
}
],
"source": [
"dataset_test = test_grouper(dataset.test)"
]
@@ -496,60 +643,8 @@
"execution_count": null,
"metadata": {
"ExecuteTime": {
"start_time": "2022-12-23T09:57:39.750Z"
}
},
"outputs": [],
"source": [
"next(iter(dataset.test.iter_sequential()))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"start_time": "2022-12-23T09:57:39.751Z"
}
},
"outputs": [],
"source": [
"len(dataset.test)\n",
"x = [x['target'].shape for x in dataset.test.iter_sequential()]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"start_time": "2022-12-23T09:57:39.752Z"
}
},
"outputs": [],
"source": [
"pd.Series(x).value_counts()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"start_time": "2022-12-23T09:57:39.753Z"
}
},
"outputs": [],
"source": [
"%debug"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"start_time": "2022-12-23T09:57:39.754Z"
"end_time": "2022-12-23T13:17:36.537956Z",
"start_time": "2022-12-23T13:17:36.537946Z"
}
},
"outputs": [],
@@ -564,7 +659,8 @@
"execution_count": null,
"metadata": {
"ExecuteTime": {
"start_time": "2022-12-23T09:57:39.756Z"
"end_time": "2022-12-23T13:17:36.538906Z",
"start_time": "2022-12-23T13:17:36.538898Z"
}
},
"outputs": [],
@@ -578,7 +674,22 @@
"execution_count": null,
"metadata": {
"ExecuteTime": {
"start_time": "2022-12-23T09:57:39.756Z"
"end_time": "2022-12-23T13:17:36.539968Z",
"start_time": "2022-12-23T13:17:36.539960Z"
}
},
"outputs": [],
"source": [
"%debug"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2022-12-23T13:17:36.541030Z",
"start_time": "2022-12-23T13:17:36.541021Z"
}
},
"outputs": [],
@@ -596,7 +707,8 @@
"execution_count": null,
"metadata": {
"ExecuteTime": {
"start_time": "2022-12-23T09:57:39.757Z"
"end_time": "2022-12-23T13:17:36.542095Z",
"start_time": "2022-12-23T13:17:36.542086Z"
}
},
"outputs": [],
@@ -610,7 +722,8 @@
"execution_count": null,
"metadata": {
"ExecuteTime": {
"start_time": "2022-12-23T09:57:39.758Z"
"end_time": "2022-12-23T13:17:36.543150Z",
"start_time": "2022-12-23T13:17:36.543141Z"
},
"scrolled": true
},
@@ -624,7 +737,8 @@
"execution_count": null,
"metadata": {
"ExecuteTime": {
"start_time": "2022-12-23T09:57:39.759Z"
"end_time": "2022-12-23T13:17:36.544197Z",
"start_time": "2022-12-23T13:17:36.544189Z"
}
},
"outputs": [],
@@ -648,9 +762,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "glounts",
"display_name": "gluonts10.0",
"language": "python",
"name": "glounts"
"name": "gluonts10.0"
},
"language_info": {
"codemirror_mode": {
+4 -3
View File
@@ -84,7 +84,8 @@
},
"outputs": [],
"source": [
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"device"
]
},
{
@@ -1162,8 +1163,8 @@
],
"source": [
"print(\"CRPS:\", agg_metric[\"mean_wQuantileLoss\"])\n",
"print(\"ND:\", agg_metric[\"ND\"])\n",
"print(\"NRMSE:\", agg_metric[\"NRMSE\"])\n",
"print(\"ND:\", agg_metric[\"ND\"]) # totals[\"abs_error\"] / totals[\"abs_target_sum\"]\n",
"print(\"NRMSE:\", agg_metric[\"NRMSE\"]) # totals[\"RMSE\"] / totals[\"abs_target_mean\"]\n",
"print(\"\")\n",
"print(\"CRPS-Sum:\", agg_metric[\"m_sum_mean_wQuantileLoss\"])\n",
"print(\"ND-Sum:\", agg_metric[\"m_sum_ND\"])\n",