mirror of
https://github.com/wassname/pytorch-ts.git
synced 2026-06-27 16:46:32 +08:00
wip
This commit is contained in:
File diff suppressed because one or more lines are too long
@@ -14,7 +14,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -29,7 +29,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -45,7 +45,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -54,16 +54,16 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"MetaData(freq='H', target=None, feat_static_cat=[CategoricalFeatureInfo(name='feat_static_cat', cardinality='137')], feat_static_real=[], feat_dynamic_real=[], feat_dynamic_cat=[], prediction_length=24)"
|
||||
"MetaData(freq='H', target=None, feat_static_cat=[CategoricalFeatureInfo(name='feat_static_cat_0', cardinality='137')], feat_static_real=[], feat_dynamic_real=[], feat_dynamic_cat=[], prediction_length=24)"
|
||||
]
|
||||
},
|
||||
"execution_count": 11,
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -74,7 +74,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -86,14 +86,14 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/Users/samuelnorling/Programming/KTH/exjobb/zalandorosettaenv/lib/python3.8/site-packages/gluonts/dataset/multivariate_grouper.py:182: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray.\n",
|
||||
"/home/wassname/miniforge3/envs/glounts/lib/python3.9/site-packages/gluonts/dataset/multivariate_grouper.py:191: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray.\n",
|
||||
" return {FieldName.TARGET: np.array([funcs(data) for data in dataset])}\n"
|
||||
]
|
||||
}
|
||||
@@ -112,7 +112,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -129,7 +129,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 45,
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -152,334 +152,44 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 46,
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"99it [00:10, 9.03it/s, avg_epoch_loss=-43.1, epoch=0]\n",
|
||||
"99it [00:10, 9.03it/s, avg_epoch_loss=-126, epoch=1]\n",
|
||||
"99it [00:11, 9.00it/s, avg_epoch_loss=-142, epoch=2]\n",
|
||||
"99it [00:10, 9.37it/s, avg_epoch_loss=-143, epoch=3]\n",
|
||||
"99it [00:10, 9.09it/s, avg_epoch_loss=-153, epoch=4]\n",
|
||||
"99it [00:11, 8.76it/s, avg_epoch_loss=-157, epoch=5]\n",
|
||||
"99it [00:10, 9.03it/s, avg_epoch_loss=-157, epoch=6]\n",
|
||||
"99it [00:11, 8.94it/s, avg_epoch_loss=-166, epoch=7]\n",
|
||||
"99it [00:11, 8.56it/s, avg_epoch_loss=-169, epoch=8]\n",
|
||||
"99it [00:11, 8.84it/s, avg_epoch_loss=-168, epoch=9]\n",
|
||||
"98it [00:11, 8.89it/s, avg_epoch_loss=-170, epoch=10]\n",
|
||||
"99it [00:11, 8.89it/s, avg_epoch_loss=-172, epoch=11]\n",
|
||||
"98it [00:10, 9.00it/s, avg_epoch_loss=-172, epoch=12]\n",
|
||||
"99it [00:10, 9.02it/s, avg_epoch_loss=-177, epoch=13]\n",
|
||||
"99it [00:10, 9.48it/s, avg_epoch_loss=-180, epoch=14]\n",
|
||||
"98it [00:10, 9.65it/s, avg_epoch_loss=-180, epoch=15]\n",
|
||||
"99it [00:10, 9.01it/s, avg_epoch_loss=-182, epoch=16]\n",
|
||||
"99it [00:10, 9.11it/s, avg_epoch_loss=-182, epoch=17]\n",
|
||||
"99it [00:10, 9.02it/s, avg_epoch_loss=-182, epoch=18]\n",
|
||||
"98it [00:11, 8.89it/s, avg_epoch_loss=-182, epoch=19]\n",
|
||||
"99it [00:10, 9.01it/s, avg_epoch_loss=-179, epoch=20]\n",
|
||||
"99it [00:10, 9.15it/s, avg_epoch_loss=-183, epoch=21]\n",
|
||||
"99it [00:11, 8.96it/s, avg_epoch_loss=-188, epoch=22]\n",
|
||||
"99it [00:11, 8.96it/s, avg_epoch_loss=-188, epoch=23]\n",
|
||||
"99it [00:10, 9.04it/s, avg_epoch_loss=-190, epoch=24]\n",
|
||||
"98it [00:11, 8.85it/s, avg_epoch_loss=-193, epoch=25]\n",
|
||||
"98it [00:10, 8.95it/s, avg_epoch_loss=-193, epoch=26]\n",
|
||||
"99it [00:11, 8.93it/s, avg_epoch_loss=-192, epoch=27]\n",
|
||||
"99it [00:10, 9.06it/s, avg_epoch_loss=-193, epoch=28]\n",
|
||||
"98it [00:10, 8.97it/s, avg_epoch_loss=-193, epoch=29]\n",
|
||||
"99it [00:11, 8.95it/s, avg_epoch_loss=-196, epoch=30]\n",
|
||||
"98it [00:10, 8.95it/s, avg_epoch_loss=-193, epoch=31]\n",
|
||||
"99it [00:10, 9.05it/s, avg_epoch_loss=-192, epoch=32]\n",
|
||||
"99it [00:11, 8.90it/s, avg_epoch_loss=-197, epoch=33]\n",
|
||||
"99it [00:11, 8.93it/s, avg_epoch_loss=-198, epoch=34]\n",
|
||||
"98it [00:11, 8.85it/s, avg_epoch_loss=-197, epoch=35]\n",
|
||||
"99it [00:10, 9.01it/s, avg_epoch_loss=-198, epoch=36]\n",
|
||||
"98it [00:10, 8.97it/s, avg_epoch_loss=-200, epoch=37]\n",
|
||||
"98it [00:11, 8.85it/s, avg_epoch_loss=-199, epoch=38]\n",
|
||||
"99it [00:10, 9.04it/s, avg_epoch_loss=-197, epoch=39]\n",
|
||||
"99it [00:11, 8.97it/s, avg_epoch_loss=-199, epoch=40]\n",
|
||||
"99it [00:11, 8.88it/s, avg_epoch_loss=-201, epoch=41]\n",
|
||||
"98it [00:11, 8.90it/s, avg_epoch_loss=-201, epoch=42]\n",
|
||||
"99it [00:10, 9.09it/s, avg_epoch_loss=-202, epoch=43]\n",
|
||||
"98it [00:10, 8.93it/s, avg_epoch_loss=-199, epoch=44]\n",
|
||||
" 0%| | 0/137 [00:00<?, ?it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 79.28it/s]\n",
|
||||
" 1%| | 1/137 [00:00<00:14, 9.68it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 80.80it/s]\n",
|
||||
" 1%|▏ | 2/137 [00:00<00:13, 9.72it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 81.36it/s]\n",
|
||||
" 2%|▏ | 3/137 [00:00<00:13, 9.77it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 81.19it/s]\n",
|
||||
" 3%|▎ | 4/137 [00:00<00:13, 9.80it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 82.08it/s]\n",
|
||||
" 4%|▎ | 5/137 [00:00<00:13, 9.85it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 77.88it/s]\n",
|
||||
" 4%|▍ | 6/137 [00:00<00:13, 9.72it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 78.87it/s]\n",
|
||||
" 5%|▌ | 7/137 [00:00<00:13, 9.65it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 77.08it/s]\n",
|
||||
" 6%|▌ | 8/137 [00:00<00:13, 9.55it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 76.53it/s]\n",
|
||||
" 7%|▋ | 9/137 [00:00<00:13, 9.43it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 80.35it/s]\n",
|
||||
" 7%|▋ | 10/137 [00:01<00:13, 9.52it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 78.39it/s]\n",
|
||||
" 8%|▊ | 11/137 [00:01<00:13, 9.54it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 77.60it/s]\n",
|
||||
" 9%|▉ | 12/137 [00:01<00:13, 9.53it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 79.24it/s]\n",
|
||||
" 9%|▉ | 13/137 [00:01<00:13, 9.49it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 80.02it/s]\n",
|
||||
" 10%|█ | 14/137 [00:01<00:12, 9.57it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 78.62it/s]\n",
|
||||
" 11%|█ | 15/137 [00:01<00:12, 9.55it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 78.47it/s]\n",
|
||||
" 12%|█▏ | 16/137 [00:01<00:12, 9.52it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 80.37it/s]\n",
|
||||
" 12%|█▏ | 17/137 [00:01<00:12, 9.59it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 81.06it/s]\n",
|
||||
" 13%|█▎ | 18/137 [00:01<00:12, 9.67it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 81.19it/s]\n",
|
||||
" 14%|█▍ | 19/137 [00:01<00:12, 9.73it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 82.05it/s]\n",
|
||||
" 15%|█▍ | 20/137 [00:02<00:11, 9.80it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 81.38it/s]\n",
|
||||
" 15%|█▌ | 21/137 [00:02<00:11, 9.83it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 79.80it/s]\n",
|
||||
" 16%|█▌ | 22/137 [00:02<00:11, 9.80it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 81.03it/s]\n",
|
||||
" 17%|█▋ | 23/137 [00:02<00:11, 9.82it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 79.49it/s]\n",
|
||||
" 18%|█▊ | 24/137 [00:02<00:11, 9.74it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 80.46it/s]\n",
|
||||
" 18%|█▊ | 25/137 [00:02<00:11, 9.76it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 83.16it/s]\n",
|
||||
" 19%|█▉ | 26/137 [00:02<00:11, 9.81it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 81.61it/s]\n",
|
||||
" 20%|█▉ | 27/137 [00:02<00:11, 9.84it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 81.16it/s]\n",
|
||||
" 20%|██ | 28/137 [00:02<00:11, 9.85it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 73.60it/s]\n",
|
||||
" 21%|██ | 29/137 [00:02<00:11, 9.59it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 81.64it/s]\n",
|
||||
" 22%|██▏ | 30/137 [00:03<00:11, 9.68it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 81.16it/s]\n",
|
||||
" 23%|██▎ | 31/137 [00:03<00:10, 9.72it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 82.76it/s]\n",
|
||||
"\n",
|
||||
"Running evaluation: 7it [00:00, 80.90it/s]\n",
|
||||
" 24%|██▍ | 33/137 [00:03<00:10, 9.78it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 80.28it/s]\n",
|
||||
" 25%|██▍ | 34/137 [00:03<00:10, 9.78it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 80.43it/s]\n",
|
||||
" 26%|██▌ | 35/137 [00:03<00:10, 9.78it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 81.23it/s]\n",
|
||||
" 26%|██▋ | 36/137 [00:03<00:10, 9.81it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 81.30it/s]\n",
|
||||
" 27%|██▋ | 37/137 [00:03<00:10, 9.84it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 81.99it/s]\n",
|
||||
" 28%|██▊ | 38/137 [00:03<00:10, 9.87it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 80.61it/s]\n",
|
||||
" 28%|██▊ | 39/137 [00:04<00:09, 9.82it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 81.20it/s]\n",
|
||||
" 29%|██▉ | 40/137 [00:04<00:09, 9.83it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 81.11it/s]\n",
|
||||
" 30%|██▉ | 41/137 [00:04<00:09, 9.84it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 81.50it/s]\n",
|
||||
" 31%|███ | 42/137 [00:04<00:09, 9.87it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 82.13it/s]\n",
|
||||
" 31%|███▏ | 43/137 [00:04<00:09, 9.89it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 80.91it/s]\n",
|
||||
" 32%|███▏ | 44/137 [00:04<00:09, 9.88it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 79.62it/s]\n",
|
||||
" 33%|███▎ | 45/137 [00:04<00:09, 9.80it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 78.32it/s]\n",
|
||||
" 34%|███▎ | 46/137 [00:04<00:09, 9.70it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 83.13it/s]\n",
|
||||
" 34%|███▍ | 47/137 [00:04<00:09, 9.77it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 73.63it/s]\n",
|
||||
" 35%|███▌ | 48/137 [00:04<00:09, 9.50it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 76.64it/s]\n",
|
||||
" 36%|███▌ | 49/137 [00:05<00:09, 9.42it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 80.82it/s]\n",
|
||||
" 36%|███▋ | 50/137 [00:05<00:09, 9.52it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 75.39it/s]\n",
|
||||
" 37%|███▋ | 51/137 [00:05<00:09, 9.40it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 77.91it/s]\n",
|
||||
" 38%|███▊ | 52/137 [00:05<00:09, 9.41it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 78.21it/s]\n",
|
||||
" 39%|███▊ | 53/137 [00:05<00:08, 9.45it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 80.29it/s]\n",
|
||||
" 39%|███▉ | 54/137 [00:05<00:08, 9.53it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 81.16it/s]\n",
|
||||
" 40%|████ | 55/137 [00:05<00:08, 9.63it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 81.02it/s]\n",
|
||||
" 41%|████ | 56/137 [00:05<00:08, 9.69it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 81.13it/s]\n",
|
||||
" 42%|████▏ | 57/137 [00:05<00:08, 9.74it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 81.67it/s]\n",
|
||||
" 42%|████▏ | 58/137 [00:05<00:08, 9.60it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 79.86it/s]\n",
|
||||
" 43%|████▎ | 59/137 [00:06<00:08, 9.63it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 81.31it/s]\n",
|
||||
" 44%|████▍ | 60/137 [00:06<00:07, 9.70it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 80.42it/s]\n",
|
||||
" 45%|████▍ | 61/137 [00:06<00:07, 9.72it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 80.96it/s]\n",
|
||||
" 45%|████▌ | 62/137 [00:06<00:07, 9.76it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 80.15it/s]\n",
|
||||
" 46%|████▌ | 63/137 [00:06<00:07, 9.76it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 80.72it/s]\n",
|
||||
" 47%|████▋ | 64/137 [00:06<00:07, 9.77it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 78.85it/s]\n",
|
||||
" 47%|████▋ | 65/137 [00:06<00:07, 9.69it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 80.44it/s]\n",
|
||||
" 48%|████▊ | 66/137 [00:06<00:07, 9.73it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 80.88it/s]\n",
|
||||
" 49%|████▉ | 67/137 [00:06<00:07, 9.74it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 81.47it/s]\n",
|
||||
" 50%|████▉ | 68/137 [00:07<00:07, 9.79it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 80.62it/s]\n",
|
||||
" 50%|█████ | 69/137 [00:07<00:06, 9.80it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 80.87it/s]\n",
|
||||
" 51%|█████ | 70/137 [00:07<00:06, 9.79it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 77.90it/s]\n",
|
||||
" 52%|█████▏ | 71/137 [00:07<00:06, 9.69it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 77.50it/s]\n",
|
||||
" 53%|█████▎ | 72/137 [00:07<00:06, 9.59it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 80.34it/s]\n",
|
||||
" 53%|█████▎ | 73/137 [00:07<00:06, 9.62it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 79.45it/s]\n",
|
||||
" 54%|█████▍ | 74/137 [00:07<00:06, 9.64it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 81.21it/s]\n",
|
||||
" 55%|█████▍ | 75/137 [00:07<00:06, 9.71it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 80.50it/s]\n",
|
||||
" 55%|█████▌ | 76/137 [00:07<00:06, 9.73it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 81.01it/s]\n",
|
||||
" 56%|█████▌ | 77/137 [00:07<00:06, 9.76it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 81.26it/s]\n",
|
||||
" 57%|█████▋ | 78/137 [00:08<00:06, 9.80it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 80.99it/s]\n",
|
||||
" 58%|█████▊ | 79/137 [00:08<00:05, 9.81it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 80.81it/s]\n",
|
||||
" 58%|█████▊ | 80/137 [00:08<00:05, 9.77it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 81.21it/s]\n",
|
||||
" 59%|█████▉ | 81/137 [00:08<00:05, 9.79it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 81.45it/s]\n",
|
||||
" 60%|█████▉ | 82/137 [00:08<00:05, 9.82it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 81.25it/s]\n",
|
||||
" 61%|██████ | 83/137 [00:08<00:05, 9.83it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 80.13it/s]\n",
|
||||
" 61%|██████▏ | 84/137 [00:08<00:05, 9.81it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 79.15it/s]\n",
|
||||
" 62%|██████▏ | 85/137 [00:08<00:05, 9.75it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 80.62it/s]\n",
|
||||
" 63%|██████▎ | 86/137 [00:08<00:05, 9.77it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 82.72it/s]\n",
|
||||
"\n",
|
||||
"Running evaluation: 7it [00:00, 82.62it/s]\n",
|
||||
" 64%|██████▍ | 88/137 [00:09<00:04, 9.85it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 81.59it/s]\n",
|
||||
" 65%|██████▍ | 89/137 [00:09<00:04, 9.88it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 80.62it/s]\n",
|
||||
" 66%|██████▌ | 90/137 [00:09<00:04, 9.86it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 81.38it/s]\n",
|
||||
" 66%|██████▋ | 91/137 [00:09<00:04, 9.88it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 81.71it/s]\n",
|
||||
" 67%|██████▋ | 92/137 [00:09<00:04, 9.89it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 81.85it/s]\n",
|
||||
" 68%|██████▊ | 93/137 [00:09<00:04, 9.87it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 81.08it/s]\n",
|
||||
" 69%|██████▊ | 94/137 [00:09<00:04, 9.86it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 82.68it/s]\n",
|
||||
" 69%|██████▉ | 95/137 [00:09<00:04, 9.88it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 80.35it/s]\n",
|
||||
" 70%|███████ | 96/137 [00:09<00:04, 9.85it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 81.36it/s]\n",
|
||||
" 71%|███████ | 97/137 [00:09<00:04, 9.84it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 81.28it/s]\n",
|
||||
" 72%|███████▏ | 98/137 [00:10<00:03, 9.85it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 81.06it/s]\n",
|
||||
" 72%|███████▏ | 99/137 [00:10<00:03, 9.85it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 80.01it/s]\n",
|
||||
" 73%|███████▎ | 100/137 [00:10<00:03, 9.82it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 81.70it/s]\n",
|
||||
" 74%|███████▎ | 101/137 [00:10<00:03, 9.85it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 80.67it/s]\n",
|
||||
" 74%|███████▍ | 102/137 [00:10<00:03, 9.84it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 80.49it/s]\n",
|
||||
" 75%|███████▌ | 103/137 [00:10<00:03, 9.82it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 81.27it/s]\n",
|
||||
" 76%|███████▌ | 104/137 [00:10<00:03, 9.84it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 81.05it/s]\n",
|
||||
" 77%|███████▋ | 105/137 [00:10<00:03, 9.85it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 80.27it/s]\n",
|
||||
" 77%|███████▋ | 106/137 [00:10<00:03, 9.83it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 81.00it/s]\n",
|
||||
" 78%|███████▊ | 107/137 [00:10<00:03, 9.79it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 77.18it/s]\n",
|
||||
" 79%|███████▉ | 108/137 [00:11<00:03, 9.65it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 79.07it/s]\n",
|
||||
" 80%|███████▉ | 109/137 [00:11<00:02, 9.63it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 80.43it/s]\n",
|
||||
" 80%|████████ | 110/137 [00:11<00:02, 9.67it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 81.85it/s]\n",
|
||||
" 81%|████████ | 111/137 [00:11<00:02, 9.74it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 80.47it/s]\n",
|
||||
" 82%|████████▏ | 112/137 [00:11<00:02, 9.77it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 80.02it/s]\n",
|
||||
" 82%|████████▏ | 113/137 [00:11<00:02, 9.76it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 81.52it/s]\n",
|
||||
" 83%|████████▎ | 114/137 [00:11<00:02, 9.80it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 81.42it/s]\n",
|
||||
" 84%|████████▍ | 115/137 [00:11<00:02, 9.83it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 80.95it/s]\n",
|
||||
" 85%|████████▍ | 116/137 [00:11<00:02, 9.84it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 80.78it/s]\n",
|
||||
" 85%|████████▌ | 117/137 [00:12<00:02, 9.85it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 81.33it/s]\n",
|
||||
" 86%|████████▌ | 118/137 [00:12<00:01, 9.85it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 82.25it/s]\n",
|
||||
" 87%|████████▋ | 119/137 [00:12<00:01, 9.86it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 81.09it/s]\n",
|
||||
" 88%|████████▊ | 120/137 [00:12<00:02, 7.50it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 79.78it/s]\n",
|
||||
" 88%|████████▊ | 121/137 [00:12<00:01, 8.04it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 81.03it/s]\n",
|
||||
" 89%|████████▉ | 122/137 [00:12<00:01, 8.52it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 81.17it/s]\n",
|
||||
" 90%|████████▉ | 123/137 [00:12<00:01, 8.89it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 81.38it/s]\n",
|
||||
" 91%|█████████ | 124/137 [00:12<00:01, 9.17it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 72.26it/s]\n",
|
||||
" 91%|█████████ | 125/137 [00:12<00:01, 9.07it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 80.07it/s]\n",
|
||||
" 92%|█████████▏| 126/137 [00:13<00:01, 9.27it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 80.94it/s]\n",
|
||||
" 93%|█████████▎| 127/137 [00:13<00:01, 9.44it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 81.67it/s]\n",
|
||||
" 93%|█████████▎| 128/137 [00:13<00:00, 9.57it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 84.00it/s]\n",
|
||||
"\n",
|
||||
"Running evaluation: 7it [00:00, 82.48it/s]\n",
|
||||
" 95%|█████████▍| 130/137 [00:13<00:00, 9.72it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 81.11it/s]\n",
|
||||
" 96%|█████████▌| 131/137 [00:13<00:00, 9.77it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 79.90it/s]\n",
|
||||
" 96%|█████████▋| 132/137 [00:13<00:00, 9.76it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 81.81it/s]\n",
|
||||
" 97%|█████████▋| 133/137 [00:13<00:00, 9.80it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 81.61it/s]\n",
|
||||
" 98%|█████████▊| 134/137 [00:13<00:00, 9.84it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 82.01it/s]\n",
|
||||
" 99%|█████████▊| 135/137 [00:13<00:00, 9.88it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 81.80it/s]\n",
|
||||
" 99%|█████████▉| 136/137 [00:14<00:00, 9.90it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 77.76it/s]\n",
|
||||
"100%|██████████| 137/137 [00:14<00:00, 9.68it/s]\n",
|
||||
"Running evaluation: 7it [00:00, 60.08it/s]\n"
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "e6b529fa683f4e51b5f8ac63db471849",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
" 0%| | 0/99 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"ename": "Exception",
|
||||
"evalue": "Reached maximum number of idle transformation calls.\nThis means the transformation looped over 1 inputs without returning any output.\nThis occurred in the following transformation:\ngluonts.transform.split.InstanceSplitter(dummy_value=0.0, forecast_start_field=\"forecast_start\", future_length=24, instance_sampler=gluonts.transform.sampler.ExpectedNumInstanceSampler(axis=-1, min_past=192, min_future=24, num_instances=1.0, total_length=20382, n=3), is_pad_field=\"is_pad\", lead_time=0, output_NTC=True, past_length=192, start_field=\"start\", target_field=\"target\", time_series_fields=[\"time_feat\", \"observed_values\"])",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[0;31mException\u001b[0m Traceback (most recent call last)",
|
||||
"\u001b[1;32m/media/wassname/SGIronWolf/projects5/timeseries/pytorch-ts/examples/Multivariate-Flow-Solar.ipynb Cell 13\u001b[0m in \u001b[0;36m<cell line: 1>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> <a href='vscode-notebook-cell:/media/wassname/SGIronWolf/projects5/timeseries/pytorch-ts/examples/Multivariate-Flow-Solar.ipynb#X15sZmlsZQ%3D%3D?line=0'>1</a>\u001b[0m predictor \u001b[39m=\u001b[39m estimator\u001b[39m.\u001b[39;49mtrain(dataset_train)\n\u001b[1;32m <a href='vscode-notebook-cell:/media/wassname/SGIronWolf/projects5/timeseries/pytorch-ts/examples/Multivariate-Flow-Solar.ipynb#X15sZmlsZQ%3D%3D?line=1'>2</a>\u001b[0m forecast_it, ts_it \u001b[39m=\u001b[39m make_evaluation_predictions(dataset\u001b[39m=\u001b[39mdataset_test,\n\u001b[1;32m <a href='vscode-notebook-cell:/media/wassname/SGIronWolf/projects5/timeseries/pytorch-ts/examples/Multivariate-Flow-Solar.ipynb#X15sZmlsZQ%3D%3D?line=2'>3</a>\u001b[0m predictor\u001b[39m=\u001b[39mpredictor,\n\u001b[1;32m <a href='vscode-notebook-cell:/media/wassname/SGIronWolf/projects5/timeseries/pytorch-ts/examples/Multivariate-Flow-Solar.ipynb#X15sZmlsZQ%3D%3D?line=3'>4</a>\u001b[0m num_samples\u001b[39m=\u001b[39m\u001b[39m100\u001b[39m)\n\u001b[1;32m <a href='vscode-notebook-cell:/media/wassname/SGIronWolf/projects5/timeseries/pytorch-ts/examples/Multivariate-Flow-Solar.ipynb#X15sZmlsZQ%3D%3D?line=4'>5</a>\u001b[0m forecasts \u001b[39m=\u001b[39m \u001b[39mlist\u001b[39m(forecast_it)\n",
|
||||
"File \u001b[0;32m/media/wassname/SGIronWolf/projects5/timeseries/pytorch-ts/pts/model/estimator.py:179\u001b[0m, in \u001b[0;36mPyTorchEstimator.train\u001b[0;34m(self, training_data, validation_data, num_workers, prefetch_factor, shuffle_buffer_length, cache_data, **kwargs)\u001b[0m\n\u001b[1;32m 169\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mtrain\u001b[39m(\n\u001b[1;32m 170\u001b[0m \u001b[39mself\u001b[39m,\n\u001b[1;32m 171\u001b[0m training_data: Dataset,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 177\u001b[0m \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs,\n\u001b[1;32m 178\u001b[0m ) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m PyTorchPredictor:\n\u001b[0;32m--> 179\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mtrain_model(\n\u001b[1;32m 180\u001b[0m training_data,\n\u001b[1;32m 181\u001b[0m validation_data,\n\u001b[1;32m 182\u001b[0m num_workers\u001b[39m=\u001b[39;49mnum_workers,\n\u001b[1;32m 183\u001b[0m prefetch_factor\u001b[39m=\u001b[39;49mprefetch_factor,\n\u001b[1;32m 184\u001b[0m shuffle_buffer_length\u001b[39m=\u001b[39;49mshuffle_buffer_length,\n\u001b[1;32m 185\u001b[0m cache_data\u001b[39m=\u001b[39;49mcache_data,\n\u001b[1;32m 186\u001b[0m \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs,\n\u001b[1;32m 187\u001b[0m )\u001b[39m.\u001b[39mpredictor\n",
|
||||
"File \u001b[0;32m/media/wassname/SGIronWolf/projects5/timeseries/pytorch-ts/pts/model/estimator.py:151\u001b[0m, in \u001b[0;36mPyTorchEstimator.train_model\u001b[0;34m(self, training_data, validation_data, num_workers, prefetch_factor, shuffle_buffer_length, cache_data, **kwargs)\u001b[0m\n\u001b[1;32m 133\u001b[0m validation_iter_dataset \u001b[39m=\u001b[39m TransformedIterableDataset(\n\u001b[1;32m 134\u001b[0m dataset\u001b[39m=\u001b[39mvalidation_data,\n\u001b[1;32m 135\u001b[0m transform\u001b[39m=\u001b[39mtransformation\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 139\u001b[0m cache_data\u001b[39m=\u001b[39mcache_data,\n\u001b[1;32m 140\u001b[0m )\n\u001b[1;32m 141\u001b[0m validation_data_loader \u001b[39m=\u001b[39m DataLoader(\n\u001b[1;32m 142\u001b[0m validation_iter_dataset,\n\u001b[1;32m 143\u001b[0m batch_size\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtrainer\u001b[39m.\u001b[39mbatch_size,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 148\u001b[0m \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs,\n\u001b[1;32m 149\u001b[0m )\n\u001b[0;32m--> 151\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mtrainer(\n\u001b[1;32m 152\u001b[0m net\u001b[39m=\u001b[39;49mtrained_net,\n\u001b[1;32m 153\u001b[0m train_iter\u001b[39m=\u001b[39;49mtraining_data_loader,\n\u001b[1;32m 154\u001b[0m validation_iter\u001b[39m=\u001b[39;49mvalidation_data_loader,\n\u001b[1;32m 155\u001b[0m )\n\u001b[1;32m 157\u001b[0m \u001b[39mreturn\u001b[39;00m TrainOutput(\n\u001b[1;32m 158\u001b[0m transformation\u001b[39m=\u001b[39mtransformation,\n\u001b[1;32m 159\u001b[0m trained_net\u001b[39m=\u001b[39mtrained_net,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 162\u001b[0m ),\n\u001b[1;32m 163\u001b[0m )\n",
|
||||
"File \u001b[0;32m/media/wassname/SGIronWolf/projects5/timeseries/pytorch-ts/pts/trainer.py:63\u001b[0m, in \u001b[0;36mTrainer.__call__\u001b[0;34m(self, net, train_iter, validation_iter)\u001b[0m\n\u001b[1;32m 61\u001b[0m \u001b[39m# training loop\u001b[39;00m\n\u001b[1;32m 62\u001b[0m \u001b[39mwith\u001b[39;00m tqdm(train_iter, total\u001b[39m=\u001b[39mtotal) \u001b[39mas\u001b[39;00m it:\n\u001b[0;32m---> 63\u001b[0m \u001b[39mfor\u001b[39;00m batch_no, data_entry \u001b[39min\u001b[39;00m \u001b[39menumerate\u001b[39m(it, start\u001b[39m=\u001b[39m\u001b[39m1\u001b[39m):\n\u001b[1;32m 64\u001b[0m optimizer\u001b[39m.\u001b[39mzero_grad()\n\u001b[1;32m 66\u001b[0m inputs \u001b[39m=\u001b[39m [v\u001b[39m.\u001b[39mto(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdevice) \u001b[39mfor\u001b[39;00m v \u001b[39min\u001b[39;00m data_entry\u001b[39m.\u001b[39mvalues()]\n",
|
||||
"File \u001b[0;32m~/miniforge3/envs/glounts/lib/python3.9/site-packages/tqdm/notebook.py:259\u001b[0m, in \u001b[0;36mtqdm_notebook.__iter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 257\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m 258\u001b[0m it \u001b[39m=\u001b[39m \u001b[39msuper\u001b[39m(tqdm_notebook, \u001b[39mself\u001b[39m)\u001b[39m.\u001b[39m\u001b[39m__iter__\u001b[39m()\n\u001b[0;32m--> 259\u001b[0m \u001b[39mfor\u001b[39;00m obj \u001b[39min\u001b[39;00m it:\n\u001b[1;32m 260\u001b[0m \u001b[39m# return super(tqdm...) will not catch exception\u001b[39;00m\n\u001b[1;32m 261\u001b[0m \u001b[39myield\u001b[39;00m obj\n\u001b[1;32m 262\u001b[0m \u001b[39m# NB: except ... [ as ...] breaks IPython async KeyboardInterrupt\u001b[39;00m\n",
|
||||
"File \u001b[0;32m~/miniforge3/envs/glounts/lib/python3.9/site-packages/tqdm/std.py:1195\u001b[0m, in \u001b[0;36mtqdm.__iter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1192\u001b[0m time \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_time\n\u001b[1;32m 1194\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[0;32m-> 1195\u001b[0m \u001b[39mfor\u001b[39;00m obj \u001b[39min\u001b[39;00m iterable:\n\u001b[1;32m 1196\u001b[0m \u001b[39myield\u001b[39;00m obj\n\u001b[1;32m 1197\u001b[0m \u001b[39m# Update and possibly print the progressbar.\u001b[39;00m\n\u001b[1;32m 1198\u001b[0m \u001b[39m# Note: does not call self.update(1) for speed optimisation.\u001b[39;00m\n",
|
||||
"File \u001b[0;32m~/miniforge3/envs/glounts/lib/python3.9/site-packages/torch/utils/data/dataloader.py:628\u001b[0m, in \u001b[0;36m_BaseDataLoaderIter.__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 625\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_sampler_iter \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m 626\u001b[0m \u001b[39m# TODO(https://github.com/pytorch/pytorch/issues/76750)\u001b[39;00m\n\u001b[1;32m 627\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_reset() \u001b[39m# type: ignore[call-arg]\u001b[39;00m\n\u001b[0;32m--> 628\u001b[0m data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_next_data()\n\u001b[1;32m 629\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_num_yielded \u001b[39m+\u001b[39m\u001b[39m=\u001b[39m \u001b[39m1\u001b[39m\n\u001b[1;32m 630\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_dataset_kind \u001b[39m==\u001b[39m _DatasetKind\u001b[39m.\u001b[39mIterable \u001b[39mand\u001b[39;00m \\\n\u001b[1;32m 631\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_IterableDataset_len_called \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39mand\u001b[39;00m \\\n\u001b[1;32m 632\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_num_yielded \u001b[39m>\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_IterableDataset_len_called:\n",
|
||||
"File \u001b[0;32m~/miniforge3/envs/glounts/lib/python3.9/site-packages/torch/utils/data/dataloader.py:671\u001b[0m, in \u001b[0;36m_SingleProcessDataLoaderIter._next_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 669\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m_next_data\u001b[39m(\u001b[39mself\u001b[39m):\n\u001b[1;32m 670\u001b[0m index \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_next_index() \u001b[39m# may raise StopIteration\u001b[39;00m\n\u001b[0;32m--> 671\u001b[0m data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_dataset_fetcher\u001b[39m.\u001b[39;49mfetch(index) \u001b[39m# may raise StopIteration\u001b[39;00m\n\u001b[1;32m 672\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_pin_memory:\n\u001b[1;32m 673\u001b[0m data \u001b[39m=\u001b[39m _utils\u001b[39m.\u001b[39mpin_memory\u001b[39m.\u001b[39mpin_memory(data, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_pin_memory_device)\n",
|
||||
"File \u001b[0;32m~/miniforge3/envs/glounts/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py:34\u001b[0m, in \u001b[0;36m_IterableDatasetFetcher.fetch\u001b[0;34m(self, possibly_batched_index)\u001b[0m\n\u001b[1;32m 32\u001b[0m \u001b[39mfor\u001b[39;00m _ \u001b[39min\u001b[39;00m possibly_batched_index:\n\u001b[1;32m 33\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[0;32m---> 34\u001b[0m data\u001b[39m.\u001b[39mappend(\u001b[39mnext\u001b[39;49m(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mdataset_iter))\n\u001b[1;32m 35\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mStopIteration\u001b[39;00m:\n\u001b[1;32m 36\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mended \u001b[39m=\u001b[39m \u001b[39mTrue\u001b[39;00m\n",
|
||||
"File \u001b[0;32m~/miniforge3/envs/glounts/lib/python3.9/site-packages/gluonts/transform/_base.py:103\u001b[0m, in \u001b[0;36mTransformedDataset.__iter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 102\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__iter__\u001b[39m(\u001b[39mself\u001b[39m) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Iterator[DataEntry]:\n\u001b[0;32m--> 103\u001b[0m \u001b[39myield from\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtransformation(\n\u001b[1;32m 104\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mbase_dataset, is_train\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mis_train\n\u001b[1;32m 105\u001b[0m )\n",
|
||||
"File \u001b[0;32m~/miniforge3/envs/glounts/lib/python3.9/site-packages/gluonts/transform/_base.py:124\u001b[0m, in \u001b[0;36mMapTransformation.__call__\u001b[0;34m(self, data_it, is_train)\u001b[0m\n\u001b[1;32m 121\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__call__\u001b[39m(\n\u001b[1;32m 122\u001b[0m \u001b[39mself\u001b[39m, data_it: Iterable[DataEntry], is_train: \u001b[39mbool\u001b[39m\n\u001b[1;32m 123\u001b[0m ) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Iterator:\n\u001b[0;32m--> 124\u001b[0m \u001b[39mfor\u001b[39;00m data_entry \u001b[39min\u001b[39;00m data_it:\n\u001b[1;32m 125\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m 126\u001b[0m \u001b[39myield\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmap_transform(data_entry\u001b[39m.\u001b[39mcopy(), is_train)\n",
|
||||
"File \u001b[0;32m~/miniforge3/envs/glounts/lib/python3.9/site-packages/gluonts/transform/_base.py:124\u001b[0m, in \u001b[0;36mMapTransformation.__call__\u001b[0;34m(self, data_it, is_train)\u001b[0m\n\u001b[1;32m 121\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__call__\u001b[39m(\n\u001b[1;32m 122\u001b[0m \u001b[39mself\u001b[39m, data_it: Iterable[DataEntry], is_train: \u001b[39mbool\u001b[39m\n\u001b[1;32m 123\u001b[0m ) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Iterator:\n\u001b[0;32m--> 124\u001b[0m \u001b[39mfor\u001b[39;00m data_entry \u001b[39min\u001b[39;00m data_it:\n\u001b[1;32m 125\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m 126\u001b[0m \u001b[39myield\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmap_transform(data_entry\u001b[39m.\u001b[39mcopy(), is_train)\n",
|
||||
"File \u001b[0;32m~/miniforge3/envs/glounts/lib/python3.9/site-packages/gluonts/transform/_base.py:189\u001b[0m, in \u001b[0;36mFlatMapTransformation.__call__\u001b[0;34m(self, data_it, is_train)\u001b[0m\n\u001b[1;32m 182\u001b[0m \u001b[39myield\u001b[39;00m result\n\u001b[1;32m 184\u001b[0m \u001b[39mif\u001b[39;00m (\n\u001b[1;32m 185\u001b[0m \u001b[39m# negative values disable the check\u001b[39;00m\n\u001b[1;32m 186\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmax_idle_transforms \u001b[39m>\u001b[39m \u001b[39m0\u001b[39m\n\u001b[1;32m 187\u001b[0m \u001b[39mand\u001b[39;00m num_idle_transforms \u001b[39m>\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmax_idle_transforms\n\u001b[1;32m 188\u001b[0m ):\n\u001b[0;32m--> 189\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mException\u001b[39;00m(\n\u001b[1;32m 190\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mReached maximum number of idle transformation\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 191\u001b[0m \u001b[39m\"\u001b[39m\u001b[39m calls.\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39mThis means the transformation looped over\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 192\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m \u001b[39m\u001b[39m{\u001b[39;00m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmax_idle_transforms\u001b[39m}\u001b[39;00m\u001b[39m inputs without returning any\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 193\u001b[0m \u001b[39m\"\u001b[39m\u001b[39m output.\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39mThis occurred in the following\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 194\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m transformation:\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39m{\u001b[39;00m\u001b[39mself\u001b[39m\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m\n\u001b[1;32m 195\u001b[0m )\n",
|
||||
"\u001b[0;31mException\u001b[0m: Reached maximum number of idle transformation calls.\nThis means the transformation looped over 1 inputs without returning any output.\nThis occurred in the following transformation:\ngluonts.transform.split.InstanceSplitter(dummy_value=0.0, forecast_start_field=\"forecast_start\", future_length=24, instance_sampler=gluonts.transform.sampler.ExpectedNumInstanceSampler(axis=-1, min_past=192, min_future=24, num_instances=1.0, total_length=20382, n=3), is_pad_field=\"is_pad\", lead_time=0, output_NTC=True, past_length=192, start_field=\"start\", target_field=\"target\", time_series_fields=[\"time_feat\", \"observed_values\"])"
|
||||
]
|
||||
}
|
||||
],
|
||||
@@ -503,7 +213,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 47,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@@ -526,7 +236,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 48,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@@ -556,7 +266,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -579,7 +289,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@@ -911,7 +621,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@@ -934,7 +644,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@@ -964,7 +674,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -990,7 +700,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@@ -1159,7 +869,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@@ -1182,7 +892,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@@ -1213,7 +923,7 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"display_name": "Python 3.9.15 ('glounts')",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
@@ -1227,7 +937,12 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.2"
|
||||
"version": "3.9.15"
|
||||
},
|
||||
"vscode": {
|
||||
"interpreter": {
|
||||
"hash": "7f25a1f13147a60511cf6766827402baf95cbe50d53a241197155306ee38fe70"
|
||||
}
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -2,8 +2,32 @@
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"execution_count": 1,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-12-23T06:21:04.900941Z",
|
||||
"start_time": "2022-12-23T06:21:04.889808Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# import warnings\n",
|
||||
"# warnings.simplefilter(\"ignore\")\n",
|
||||
"\n",
|
||||
"# autoreload import your package\n",
|
||||
"%load_ext autoreload\n",
|
||||
"%autoreload 2"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-12-23T06:21:06.416555Z",
|
||||
"start_time": "2022-12-23T06:21:04.902165Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%matplotlib inline\n",
|
||||
@@ -17,8 +41,13 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"execution_count": 3,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-12-23T06:21:06.625752Z",
|
||||
"start_time": "2022-12-23T06:21:06.418508Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from gluonts.dataset.multivariate_grouper import MultivariateGrouper\n",
|
||||
@@ -29,8 +58,13 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"execution_count": 4,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-12-23T06:21:06.738760Z",
|
||||
"start_time": "2022-12-23T06:21:06.627406Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from pts.model.tempflow import TempFlowEstimator\n",
|
||||
@@ -41,8 +75,13 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"execution_count": 5,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-12-23T06:21:06.790762Z",
|
||||
"start_time": "2022-12-23T06:21:06.740481Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
|
||||
@@ -50,8 +89,13 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"execution_count": 6,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-12-23T06:21:06.817593Z",
|
||||
"start_time": "2022-12-23T06:21:06.791945Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def plot(target, forecast, prediction_length, prediction_intervals=(50.0, 90.0), color='g', fname=None):\n",
|
||||
@@ -113,17 +157,34 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"execution_count": 7,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-12-23T06:21:06.839629Z",
|
||||
"start_time": "2022-12-23T06:21:06.818741Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Available datasets: ['constant', 'exchange_rate', 'solar-energy', 'electricity', 'traffic', 'exchange_rate_nips', 'electricity_nips', 'traffic_nips', 'solar_nips', 'wiki-rolling_nips', 'taxi_30min', 'kaggle_web_traffic_with_missing', 'kaggle_web_traffic_without_missing', 'kaggle_web_traffic_weekly', 'm1_yearly', 'm1_quarterly', 'm1_monthly', 'nn5_daily_with_missing', 'nn5_daily_without_missing', 'nn5_weekly', 'tourism_monthly', 'tourism_quarterly', 'tourism_yearly', 'cif_2016', 'london_smart_meters_without_missing', 'wind_farms_without_missing', 'car_parts_without_missing', 'dominick', 'fred_md', 'pedestrian_counts', 'hospital', 'covid_deaths', 'kdd_cup_2018_without_missing', 'weather', 'm3_monthly', 'm3_quarterly', 'm3_yearly', 'm3_other', 'm4_hourly', 'm4_daily', 'm4_weekly', 'm4_monthly', 'm4_quarterly', 'm4_yearly', 'm5', 'uber_tlc_daily', 'uber_tlc_hourly', 'airpassengers']\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(f\"Available datasets: {list(dataset_recipes.keys())}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 8,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-12-23T06:21:06.856379Z",
|
||||
"start_time": "2022-12-23T06:21:06.840657Z"
|
||||
},
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [],
|
||||
@@ -134,51 +195,57 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"execution_count": 9,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-12-23T06:21:06.874723Z",
|
||||
"start_time": "2022-12-23T06:21:06.858646Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"MetaData(freq='H', target=None, feat_static_cat=[CategoricalFeatureInfo(name='feat_static_cat_0', cardinality='370')], feat_static_real=[], feat_dynamic_real=[], feat_dynamic_cat=[], prediction_length=24)"
|
||||
]
|
||||
},
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"dataset.metadata"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 51,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"3.5"
|
||||
]
|
||||
},
|
||||
"execution_count": 51,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
"execution_count": 10,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-12-23T06:21:07.663837Z",
|
||||
"start_time": "2022-12-23T06:21:06.875746Z"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"len(dataset.test)/len(dataset.train)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 54,
|
||||
"metadata": {},
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_grouper = MultivariateGrouper(max_target_dim=min(2000, int(dataset.metadata.feat_static_cat[0].cardinality)))\n",
|
||||
"\n",
|
||||
"test_grouper = MultivariateGrouper(\n",
|
||||
" # num_test_dates=int(len(dataset.test)/len(dataset.train)), \n",
|
||||
" num_test_dates=1,\n",
|
||||
" num_test_dates=int(len(dataset.test)/len(dataset.train)*2),\n",
|
||||
"# num_test_dates=7,\n",
|
||||
" max_target_dim=min(2000, int(dataset.metadata.feat_static_cat[0].cardinality)))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 57,
|
||||
"metadata": {},
|
||||
"execution_count": 11,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-12-23T06:21:12.079230Z",
|
||||
"start_time": "2022-12-23T06:21:07.665093Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
@@ -187,20 +254,6 @@
|
||||
"/home/wassname/miniforge3/envs/glounts/lib/python3.9/site-packages/gluonts/dataset/multivariate_grouper.py:191: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray.\n",
|
||||
" return {FieldName.TARGET: np.array([funcs(data) for data in dataset])}\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"ename": "ValueError",
|
||||
"evalue": "setting an array element with a sequence. The requested array has an inhomogeneous shape after 1 dimensions. The detected shape was (2590,) + inhomogeneous part.",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
|
||||
"\u001b[1;32m/media/wassname/SGIronWolf/projects5/timeseries/pytorch-ts/examples/Time-Grad2-Electricity.ipynb Cell 13\u001b[0m in \u001b[0;36m<cell line: 2>\u001b[0;34m()\u001b[0m\n\u001b[1;32m <a href='vscode-notebook-cell:/media/wassname/SGIronWolf/projects5/timeseries/pytorch-ts/examples/Time-Grad2-Electricity.ipynb#X12sZmlsZQ%3D%3D?line=0'>1</a>\u001b[0m dataset_train \u001b[39m=\u001b[39m train_grouper(dataset\u001b[39m.\u001b[39mtrain)\n\u001b[0;32m----> <a href='vscode-notebook-cell:/media/wassname/SGIronWolf/projects5/timeseries/pytorch-ts/examples/Time-Grad2-Electricity.ipynb#X12sZmlsZQ%3D%3D?line=1'>2</a>\u001b[0m dataset_test \u001b[39m=\u001b[39m test_grouper(dataset\u001b[39m.\u001b[39;49mtest)\n",
|
||||
"File \u001b[0;32m~/miniforge3/envs/glounts/lib/python3.9/site-packages/gluonts/dataset/multivariate_grouper.py:87\u001b[0m, in \u001b[0;36mMultivariateGrouper.__call__\u001b[0;34m(self, dataset)\u001b[0m\n\u001b[1;32m 85\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__call__\u001b[39m(\u001b[39mself\u001b[39m, dataset: Dataset) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Dataset:\n\u001b[1;32m 86\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_preprocess(dataset)\n\u001b[0;32m---> 87\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_group_all(dataset)\n",
|
||||
"File \u001b[0;32m~/miniforge3/envs/glounts/lib/python3.9/site-packages/gluonts/dataset/multivariate_grouper.py:125\u001b[0m, in \u001b[0;36mMultivariateGrouper._group_all\u001b[0;34m(self, dataset)\u001b[0m\n\u001b[1;32m 123\u001b[0m grouped_dataset \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_prepare_train_data(dataset)\n\u001b[1;32m 124\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m--> 125\u001b[0m grouped_dataset \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_prepare_test_data(dataset)\n\u001b[1;32m 126\u001b[0m \u001b[39mreturn\u001b[39;00m grouped_dataset\n",
|
||||
"File \u001b[0;32m~/miniforge3/envs/glounts/lib/python3.9/site-packages/gluonts/dataset/multivariate_grouper.py:155\u001b[0m, in \u001b[0;36mMultivariateGrouper._prepare_test_data\u001b[0;34m(self, dataset)\u001b[0m\n\u001b[1;32m 153\u001b[0m \u001b[39mfor\u001b[39;00m dataset_at_test_date \u001b[39min\u001b[39;00m split_dataset:\n\u001b[1;32m 154\u001b[0m grouped_data \u001b[39m=\u001b[39m \u001b[39mdict\u001b[39m()\n\u001b[0;32m--> 155\u001b[0m grouped_data[FieldName\u001b[39m.\u001b[39mTARGET] \u001b[39m=\u001b[39m np\u001b[39m.\u001b[39;49marray(\n\u001b[1;32m 156\u001b[0m \u001b[39mlist\u001b[39;49m(dataset_at_test_date), dtype\u001b[39m=\u001b[39;49mnp\u001b[39m.\u001b[39;49mfloat32\n\u001b[1;32m 157\u001b[0m )\n\u001b[1;32m 158\u001b[0m grouped_data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_restrict_max_dimensionality(grouped_data)\n\u001b[1;32m 159\u001b[0m grouped_data[FieldName\u001b[39m.\u001b[39mSTART] \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mfirst_timestamp\n",
|
||||
"\u001b[0;31mValueError\u001b[0m: setting an array element with a sequence. The requested array has an inhomogeneous shape after 1 dimensions. The detected shape was (2590,) + inhomogeneous part."
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
@@ -210,34 +263,63 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 60,
|
||||
"metadata": {},
|
||||
"execution_count": 12,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-12-23T06:21:12.096888Z",
|
||||
"start_time": "2022-12-23T06:21:12.080635Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"> \u001b[0;32m/home/wassname/miniforge3/envs/glounts/lib/python3.9/site-packages/gluonts/dataset/multivariate_grouper.py\u001b[0m(155)\u001b[0;36m_prepare_test_data\u001b[0;34m()\u001b[0m\n",
|
||||
"\u001b[0;32m 153 \u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mdataset_at_test_date\u001b[0m \u001b[0;32min\u001b[0m \u001b[0msplit_dataset\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\u001b[0;32m 154 \u001b[0;31m \u001b[0mgrouped_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\u001b[0;32m--> 155 \u001b[0;31m grouped_data[FieldName.TARGET] = np.array(\n",
|
||||
"\u001b[0m\u001b[0;32m 156 \u001b[0;31m \u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdataset_at_test_date\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfloat32\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\u001b[0;32m 157 \u001b[0;31m )\n",
|
||||
"\u001b[0m\n",
|
||||
"2590\n",
|
||||
"(2590,)\n",
|
||||
"(2590,)\n"
|
||||
]
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"370"
|
||||
]
|
||||
},
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"%debug"
|
||||
"int(dataset.metadata.feat_static_cat[0].cardinality)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"execution_count": 13,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-12-23T06:21:12.113235Z",
|
||||
"start_time": "2022-12-23T06:21:12.098346Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"370"
|
||||
]
|
||||
},
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"int(dataset.metadata.feat_static_cat[0].cardinality)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 57,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-12-23T06:33:27.308113Z",
|
||||
"start_time": "2022-12-23T06:33:23.559765Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"estimator = TimeGradEstimator2(\n",
|
||||
@@ -248,7 +330,7 @@
|
||||
" input_size=1484,\n",
|
||||
" freq=dataset.metadata.freq,\n",
|
||||
" loss_type='l2',\n",
|
||||
" scaling=True,\n",
|
||||
" scaling=False,\n",
|
||||
" diff_steps=100,\n",
|
||||
" beta_end=0.1,\n",
|
||||
" beta_schedule=\"linear\",\n",
|
||||
@@ -262,17 +344,144 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"execution_count": 58,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-12-23T06:33:28.565303Z",
|
||||
"start_time": "2022-12-23T06:33:27.309629Z"
|
||||
},
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"cond_length 100\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "f7efc1b8dbd24c0d92ecfe54a25359a1",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
" 0%| | 0/99 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"shapes torch.Size([64, 370, 48]) torch.Size([64]) torch.Size([64, 100, 48])\n",
|
||||
"cond -> cond_up torch.Size([64, 100, 48]) torch.Size([64, 370, 48])\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/media/wassname/SGIronWolf/projects5/timeseries/pytorch-ts/pts/model/time_grad2/gaussian_diffusion_ou.py:283: UserWarning: Using a target size (torch.Size([64, 1, 48])) that is different to the input size (torch.Size([64, 370, 46])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.\n",
|
||||
" loss = F.mse_loss(x_recon, noise_rand)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"ename": "RuntimeError",
|
||||
"evalue": "The size of tensor a (46) must match the size of tensor b (48) at non-singleton dimension 2",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
|
||||
"Input \u001b[0;32mIn [58]\u001b[0m, in \u001b[0;36m<cell line: 1>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0m predictor \u001b[38;5;241m=\u001b[39m \u001b[43mestimator\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdataset_train\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_workers\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m)\u001b[49m\n",
|
||||
"File \u001b[0;32m/media/wassname/SGIronWolf/projects5/timeseries/pytorch-ts/pts/model/estimator.py:179\u001b[0m, in \u001b[0;36mPyTorchEstimator.train\u001b[0;34m(self, training_data, validation_data, num_workers, prefetch_factor, shuffle_buffer_length, cache_data, **kwargs)\u001b[0m\n\u001b[1;32m 169\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mtrain\u001b[39m(\n\u001b[1;32m 170\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 171\u001b[0m training_data: Dataset,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 177\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs,\n\u001b[1;32m 178\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m PyTorchPredictor:\n\u001b[0;32m--> 179\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain_model\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 180\u001b[0m \u001b[43m \u001b[49m\u001b[43mtraining_data\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 181\u001b[0m \u001b[43m \u001b[49m\u001b[43mvalidation_data\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 182\u001b[0m \u001b[43m \u001b[49m\u001b[43mnum_workers\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mnum_workers\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 183\u001b[0m \u001b[43m \u001b[49m\u001b[43mprefetch_factor\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mprefetch_factor\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 184\u001b[0m \u001b[43m \u001b[49m\u001b[43mshuffle_buffer_length\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mshuffle_buffer_length\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 185\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_data\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_data\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 186\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 187\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mpredictor\n",
|
||||
"File \u001b[0;32m/media/wassname/SGIronWolf/projects5/timeseries/pytorch-ts/pts/model/estimator.py:151\u001b[0m, in \u001b[0;36mPyTorchEstimator.train_model\u001b[0;34m(self, training_data, validation_data, num_workers, prefetch_factor, shuffle_buffer_length, cache_data, **kwargs)\u001b[0m\n\u001b[1;32m 133\u001b[0m validation_iter_dataset \u001b[38;5;241m=\u001b[39m TransformedIterableDataset(\n\u001b[1;32m 134\u001b[0m dataset\u001b[38;5;241m=\u001b[39mvalidation_data,\n\u001b[1;32m 135\u001b[0m transform\u001b[38;5;241m=\u001b[39mtransformation\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 139\u001b[0m cache_data\u001b[38;5;241m=\u001b[39mcache_data,\n\u001b[1;32m 140\u001b[0m )\n\u001b[1;32m 141\u001b[0m validation_data_loader \u001b[38;5;241m=\u001b[39m DataLoader(\n\u001b[1;32m 142\u001b[0m validation_iter_dataset,\n\u001b[1;32m 143\u001b[0m batch_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer\u001b[38;5;241m.\u001b[39mbatch_size,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 148\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs,\n\u001b[1;32m 149\u001b[0m )\n\u001b[0;32m--> 151\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrainer\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 152\u001b[0m \u001b[43m \u001b[49m\u001b[43mnet\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrained_net\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 153\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrain_iter\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtraining_data_loader\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 154\u001b[0m \u001b[43m \u001b[49m\u001b[43mvalidation_iter\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mvalidation_data_loader\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 155\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 157\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m TrainOutput(\n\u001b[1;32m 158\u001b[0m transformation\u001b[38;5;241m=\u001b[39mtransformation,\n\u001b[1;32m 159\u001b[0m trained_net\u001b[38;5;241m=\u001b[39mtrained_net,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 162\u001b[0m ),\n\u001b[1;32m 163\u001b[0m )\n",
|
||||
"File \u001b[0;32m/media/wassname/SGIronWolf/projects5/timeseries/pytorch-ts/pts/trainer.py:67\u001b[0m, in \u001b[0;36mTrainer.__call__\u001b[0;34m(self, net, train_iter, validation_iter)\u001b[0m\n\u001b[1;32m 64\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mzero_grad()\n\u001b[1;32m 66\u001b[0m inputs \u001b[38;5;241m=\u001b[39m [v\u001b[38;5;241m.\u001b[39mto(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdevice) \u001b[38;5;28;01mfor\u001b[39;00m v \u001b[38;5;129;01min\u001b[39;00m data_entry\u001b[38;5;241m.\u001b[39mvalues()]\n\u001b[0;32m---> 67\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mnet\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 69\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(output, (\u001b[38;5;28mlist\u001b[39m, \u001b[38;5;28mtuple\u001b[39m)):\n\u001b[1;32m 70\u001b[0m loss \u001b[38;5;241m=\u001b[39m output[\u001b[38;5;241m0\u001b[39m]\n",
|
||||
"File \u001b[0;32m~/miniforge3/envs/glounts/lib/python3.9/site-packages/torch/nn/modules/module.py:1190\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1186\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1187\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1188\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1189\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1190\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1191\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1192\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
|
||||
"File \u001b[0;32m/media/wassname/SGIronWolf/projects5/timeseries/pytorch-ts/pts/model/time_grad2/time_grad_network.py:407\u001b[0m, in \u001b[0;36mTimeGradTrainingNetwork2.forward\u001b[0;34m(self, target_dimension_indicator, past_time_feat, past_target_cdf, past_observed_values, past_is_pad, future_time_feat, future_target_cdf, future_observed_values)\u001b[0m\n\u001b[1;32m 405\u001b[0m target \u001b[38;5;241m=\u001b[39m target\u001b[38;5;241m.\u001b[39mpermute(\u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m2\u001b[39m, \u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m 406\u001b[0m distr_args \u001b[38;5;241m=\u001b[39m distr_args\u001b[38;5;241m.\u001b[39mpermute(\u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m2\u001b[39m, \u001b[38;5;241m1\u001b[39m)\n\u001b[0;32m--> 407\u001b[0m likelihoods \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdiffusion\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlog_prob\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtarget\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdistr_args\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39munsqueeze(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m 409\u001b[0m \u001b[38;5;66;03m# assert_shape(likelihoods, (-1, seq_len, 1))\u001b[39;00m\n\u001b[1;32m 411\u001b[0m past_observed_values \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mmin(\n\u001b[1;32m 412\u001b[0m past_observed_values, \u001b[38;5;241m1\u001b[39m \u001b[38;5;241m-\u001b[39m past_is_pad\u001b[38;5;241m.\u001b[39munsqueeze(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m 413\u001b[0m )\n",
|
||||
"File \u001b[0;32m/media/wassname/SGIronWolf/projects5/timeseries/pytorch-ts/pts/model/time_grad2/gaussian_diffusion_ou.py:298\u001b[0m, in \u001b[0;36mGaussianDiffusionOU.log_prob\u001b[0;34m(self, x, cond, *args, **kwargs)\u001b[0m\n\u001b[1;32m 295\u001b[0m B, T, _ \u001b[38;5;241m=\u001b[39m x\u001b[38;5;241m.\u001b[39mshape\n\u001b[1;32m 297\u001b[0m time \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mrandint(\u001b[38;5;241m0\u001b[39m, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_timesteps, (B,), device\u001b[38;5;241m=\u001b[39mx\u001b[38;5;241m.\u001b[39mdevice)\u001b[38;5;241m.\u001b[39mlong()\n\u001b[0;32m--> 298\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mp_losses\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 299\u001b[0m \u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcond\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtime\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\n\u001b[1;32m 300\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 302\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m loss\n",
|
||||
"File \u001b[0;32m/media/wassname/SGIronWolf/projects5/timeseries/pytorch-ts/pts/model/time_grad2/gaussian_diffusion_ou.py:283\u001b[0m, in \u001b[0;36mGaussianDiffusionOU.p_losses\u001b[0;34m(self, x_start, cond, t)\u001b[0m\n\u001b[1;32m 281\u001b[0m loss \u001b[38;5;241m=\u001b[39m F\u001b[38;5;241m.\u001b[39ml1_loss(x_recon, noise_rand)\n\u001b[1;32m 282\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mloss_type \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124ml2\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[0;32m--> 283\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmse_loss\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx_recon\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnoise_rand\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 284\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mloss_type \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mhuber\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[1;32m 285\u001b[0m loss \u001b[38;5;241m=\u001b[39m F\u001b[38;5;241m.\u001b[39msmooth_l1_loss(x_recon, noise_rand)\n",
|
||||
"File \u001b[0;32m~/miniforge3/envs/glounts/lib/python3.9/site-packages/torch/nn/functional.py:3291\u001b[0m, in \u001b[0;36mmse_loss\u001b[0;34m(input, target, size_average, reduce, reduction)\u001b[0m\n\u001b[1;32m 3288\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m size_average \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mor\u001b[39;00m reduce \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 3289\u001b[0m reduction \u001b[38;5;241m=\u001b[39m _Reduction\u001b[38;5;241m.\u001b[39mlegacy_get_string(size_average, reduce)\n\u001b[0;32m-> 3291\u001b[0m expanded_input, expanded_target \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbroadcast_tensors\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtarget\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3292\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m torch\u001b[38;5;241m.\u001b[39m_C\u001b[38;5;241m.\u001b[39m_nn\u001b[38;5;241m.\u001b[39mmse_loss(expanded_input, expanded_target, _Reduction\u001b[38;5;241m.\u001b[39mget_enum(reduction))\n",
|
||||
"File \u001b[0;32m~/miniforge3/envs/glounts/lib/python3.9/site-packages/torch/functional.py:74\u001b[0m, in \u001b[0;36mbroadcast_tensors\u001b[0;34m(*tensors)\u001b[0m\n\u001b[1;32m 72\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m has_torch_function(tensors):\n\u001b[1;32m 73\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m handle_torch_function(broadcast_tensors, tensors, \u001b[38;5;241m*\u001b[39mtensors)\n\u001b[0;32m---> 74\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_VF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbroadcast_tensors\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||
"\u001b[0;31mRuntimeError\u001b[0m: The size of tensor a (46) must match the size of tensor b (48) at non-singleton dimension 2"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"predictor = estimator.train(dataset_train, num_workers=8)"
|
||||
"predictor = estimator.train(dataset_train, num_workers=0)\n",
|
||||
"# predictor = estimator.train(dataset_train, num_workers=8)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"start_time": "2022-12-23T06:33:31.564Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"> \u001b[0;32m/home/wassname/miniforge3/envs/glounts/lib/python3.9/site-packages/torch/functional.py\u001b[0m(74)\u001b[0;36mbroadcast_tensors\u001b[0;34m()\u001b[0m\n",
|
||||
"\u001b[0;32m 72 \u001b[0;31m \u001b[0;32mif\u001b[0m \u001b[0mhas_torch_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensors\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\u001b[0;32m 73 \u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mhandle_torch_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbroadcast_tensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mtensors\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\u001b[0;32m---> 74 \u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0m_VF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbroadcast_tensors\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensors\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[attr-defined]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\u001b[0;32m 75 \u001b[0;31m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\u001b[0;32m 76 \u001b[0;31m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\n",
|
||||
"ipdb> u\n",
|
||||
"> \u001b[0;32m/home/wassname/miniforge3/envs/glounts/lib/python3.9/site-packages/torch/nn/functional.py\u001b[0m(3291)\u001b[0;36mmse_loss\u001b[0;34m()\u001b[0m\n",
|
||||
"\u001b[0;32m 3289 \u001b[0;31m \u001b[0mreduction\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_Reduction\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlegacy_get_string\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msize_average\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreduce\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\u001b[0;32m 3290 \u001b[0;31m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\u001b[0;32m-> 3291 \u001b[0;31m \u001b[0mexpanded_input\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mexpanded_target\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbroadcast_tensors\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\u001b[0;32m 3292 \u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_C\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_nn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmse_loss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mexpanded_input\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mexpanded_target\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_Reduction\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_enum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mreduction\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\u001b[0;32m 3293 \u001b[0;31m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\n",
|
||||
"ipdb> u\n",
|
||||
"> \u001b[0;32m/media/wassname/SGIronWolf/projects5/timeseries/pytorch-ts/pts/model/time_grad2/gaussian_diffusion_ou.py\u001b[0m(283)\u001b[0;36mp_losses\u001b[0;34m()\u001b[0m\n",
|
||||
"\u001b[0;32m 281 \u001b[0;31m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0ml1_loss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx_recon\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnoise_rand\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\u001b[0;32m 282 \u001b[0;31m \u001b[0;32melif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloss_type\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m\"l2\"\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\u001b[0;32m--> 283 \u001b[0;31m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmse_loss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx_recon\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnoise_rand\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\u001b[0;32m 284 \u001b[0;31m \u001b[0;32melif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloss_type\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m\"huber\"\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\u001b[0;32m 285 \u001b[0;31m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msmooth_l1_loss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx_recon\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnoise_rand\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\n",
|
||||
"ipdb> u\n",
|
||||
"> \u001b[0;32m/media/wassname/SGIronWolf/projects5/timeseries/pytorch-ts/pts/model/time_grad2/gaussian_diffusion_ou.py\u001b[0m(298)\u001b[0;36mlog_prob\u001b[0;34m()\u001b[0m\n",
|
||||
"\u001b[0;32m 296 \u001b[0;31m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\u001b[0;32m 297 \u001b[0;31m \u001b[0mtime\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnum_timesteps\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mB\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlong\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\u001b[0;32m--> 298 \u001b[0;31m loss = self.p_losses(\n",
|
||||
"\u001b[0m\u001b[0;32m 299 \u001b[0;31m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcond\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\u001b[0;32m 300 \u001b[0;31m )\n",
|
||||
"\u001b[0m\n",
|
||||
"ipdb> x.shape\n",
|
||||
"torch.Size([64, 370, 48])\n",
|
||||
"ipdb> cond.shape\n",
|
||||
"torch.Size([64, 100, 48])\n",
|
||||
"ipdb> time.shape\n",
|
||||
"torch.Size([64])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"%debug"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-12-23T06:21:12.613548Z",
|
||||
"start_time": "2022-12-23T06:21:12.613540Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"forecast_it, ts_it = make_evaluation_predictions(dataset=dataset_test,\n",
|
||||
@@ -283,7 +492,12 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-12-23T06:21:12.614197Z",
|
||||
"start_time": "2022-12-23T06:21:12.614190Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"forecasts = list(forecast_it)\n",
|
||||
@@ -293,7 +507,12 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-12-23T06:21:12.614817Z",
|
||||
"start_time": "2022-12-23T06:21:12.614810Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"plot(\n",
|
||||
@@ -307,7 +526,12 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-12-23T06:21:12.615375Z",
|
||||
"start_time": "2022-12-23T06:21:12.615368Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"evaluator = MultivariateEvaluator(quantiles=(np.arange(20)/20.0)[1:], \n",
|
||||
@@ -318,6 +542,10 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-12-23T06:21:12.615882Z",
|
||||
"start_time": "2022-12-23T06:21:12.615875Z"
|
||||
},
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [],
|
||||
@@ -328,7 +556,12 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-12-23T06:21:12.616531Z",
|
||||
"start_time": "2022-12-23T06:21:12.616525Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(\"CRPS:\", agg_metric[\"mean_wQuantileLoss\"])\n",
|
||||
@@ -340,6 +573,304 @@
|
||||
"print(\"NRMSE-Sum:\", agg_metric[\"m_sum_NRMSE\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Scratch: debug transforms\n",
|
||||
"\n",
|
||||
"See estimator"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-12-23T06:21:12.617247Z",
|
||||
"start_time": "2022-12-23T06:21:12.617240Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"predictor = estimator.train(dataset_train, num_workers=0)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-12-23T06:21:12.617846Z",
|
||||
"start_time": "2022-12-23T06:21:12.617838Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%debug"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-12-23T06:21:12.618406Z",
|
||||
"start_time": "2022-12-23T06:21:12.618396Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from tqdm.auto import tqdm"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-12-23T06:21:12.619142Z",
|
||||
"start_time": "2022-12-23T06:21:12.619132Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"transformation = estimator.create_transformation()\n",
|
||||
"training_instance_splitter = estimator.create_instance_splitter(\"training\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-12-23T06:21:12.619649Z",
|
||||
"start_time": "2022-12-23T06:21:12.619642Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"g1 = trans.apply(dataset_train, is_train=True)\n",
|
||||
"b = next(iter(g1))\n",
|
||||
"b.keys()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-12-23T06:21:12.620194Z",
|
||||
"start_time": "2022-12-23T06:21:12.620187Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"for _ in tqdm(g1):\n",
|
||||
" pass"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-12-23T06:21:12.620996Z",
|
||||
"start_time": "2022-12-23T06:21:12.620988Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"transform = transformation + training_instance_splitter\n",
|
||||
"g2 = transform.apply(dataset_train, is_train=True)\n",
|
||||
"gg = iter(g2)\n",
|
||||
"b = next(gg)\n",
|
||||
"b.keys()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-12-23T06:21:12.621464Z",
|
||||
"start_time": "2022-12-23T06:21:12.621458Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"for _ in tqdm(g2):\n",
|
||||
" pass"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-12-23T06:21:12.622209Z",
|
||||
"start_time": "2022-12-23T06:21:12.622202Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"b = next(gg)\n",
|
||||
"b.keys()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-12-23T06:21:12.622763Z",
|
||||
"start_time": "2022-12-23T06:21:12.622756Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from pts.model import get_module_forward_input_names\n",
|
||||
"from gluonts.transform import SelectFields, Transformation\n",
|
||||
"\n",
|
||||
"trained_net = estimator.create_training_network(estimator.trainer.device)\n",
|
||||
"input_names = get_module_forward_input_names(trained_net)\n",
|
||||
"transform = transformation + training_instance_splitter + SelectFields(input_names)\n",
|
||||
"g = transform.apply(dataset_train, is_train=True)\n",
|
||||
"b = next(iter(g))\n",
|
||||
"b.keys()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-12-23T05:18:19.966830Z",
|
||||
"start_time": "2022-12-23T05:18:19.964901Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-12-23T06:21:12.623279Z",
|
||||
"start_time": "2022-12-23T06:21:12.623272Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from pts.dataset.loader import TransformedIterableDataset\n",
|
||||
"training_data = dataset_train\n",
|
||||
"training_iter_dataset = TransformedIterableDataset(\n",
|
||||
" dataset=training_data,\n",
|
||||
" transform=transformation\n",
|
||||
" + training_instance_splitter\n",
|
||||
" + SelectFields(input_names),\n",
|
||||
" is_train=True,\n",
|
||||
" shuffle_buffer_length=None,\n",
|
||||
"# cache_data=cache_data,\n",
|
||||
")\n",
|
||||
"training_iter_dataset"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-12-23T06:21:12.623864Z",
|
||||
"start_time": "2022-12-23T06:21:12.623857Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"next(iter(training_iter_dataset))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-12-23T06:21:12.624379Z",
|
||||
"start_time": "2022-12-23T06:21:12.624372Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from torch.utils.data import DataLoader\n",
|
||||
"training_data_loader = DataLoader(\n",
|
||||
" training_iter_dataset,\n",
|
||||
" batch_size=estimator.trainer.batch_size,\n",
|
||||
" num_workers=0,\n",
|
||||
"# prefetch_factor=prefetch_factor,\n",
|
||||
" pin_memory=True,\n",
|
||||
" worker_init_fn=estimator._worker_init_fn,\n",
|
||||
"# **kwargs,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-12-23T06:21:12.624842Z",
|
||||
"start_time": "2022-12-23T06:21:12.624836Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"next(iter(training_data_loader))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-12-23T06:21:12.625415Z",
|
||||
"start_time": "2022-12-23T06:21:12.625408Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"for b in tqdm(training_data_loader):\n",
|
||||
" pass"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
@@ -350,9 +881,9 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3.9.15 ('glounts')",
|
||||
"display_name": "glounts",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
"name": "glounts"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
@@ -366,6 +897,19 @@
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.15"
|
||||
},
|
||||
"toc": {
|
||||
"base_numbering": 1,
|
||||
"nav_menu": {},
|
||||
"number_sections": true,
|
||||
"sideBar": true,
|
||||
"skip_h1_title": false,
|
||||
"title_cell": "Table of Contents",
|
||||
"title_sidebar": "Contents",
|
||||
"toc_cell": false,
|
||||
"toc_position": {},
|
||||
"toc_section_display": true,
|
||||
"toc_window_display": true
|
||||
},
|
||||
"vscode": {
|
||||
"interpreter": {
|
||||
"hash": "7f25a1f13147a60511cf6766827402baf95cbe50d53a241197155306ee38fe70"
|
||||
|
||||
+37
-25
File diff suppressed because one or more lines are too long
+36
-25
File diff suppressed because one or more lines are too long
@@ -0,0 +1,28 @@
|
||||
|
||||
install env
|
||||
|
||||
```sh
|
||||
export PROJ=glounts10.0
|
||||
conda create -n $PROJ python=3.8 -y
|
||||
conda activate $PROJ
|
||||
mamba install -y ipykernel pip ipywidgets
|
||||
# pip install torch==1.10.0+cu113 torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
|
||||
mamba install pytorch==1.9.1 torchvision==0.10.1 torchaudio==0.9.1 cudatoolkit=11.3 -c pytorch -c conda-forge
|
||||
pip install gluonts==0.10.0
|
||||
pip install -e .
|
||||
```
|
||||
```sh
|
||||
export PROJ=gluonts10.0
|
||||
conda create -n $PROJ python=3.9 -y
|
||||
conda activate $PROJ
|
||||
mamba install -y ipykernel pip ipywidgets
|
||||
# torch
|
||||
mamba install -y ipykernel pip ipywidgets pytorch==1.13.0 torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia
|
||||
mamba install -y xformers -c xformers/label/dev
|
||||
mamba install -y conda-lock
|
||||
# gluonts
|
||||
pip install gluonts==0.10.0
|
||||
pip install -e .
|
||||
# pip install "gluonts[torch,pro]" pytorchts vectorbt
|
||||
python -m ipykernel install --user --name $PROJ --display-name $PROJ
|
||||
```
|
||||
@@ -104,7 +104,7 @@ class PyTorchEstimator(Estimator):
|
||||
|
||||
input_names = get_module_forward_input_names(trained_net)
|
||||
|
||||
with env._let(max_idle_transforms=maybe_len(training_data) or 0):
|
||||
with env._let(max_idle_transforms=10):
|
||||
training_instance_splitter = self.create_instance_splitter("training")
|
||||
training_iter_dataset = TransformedIterableDataset(
|
||||
dataset=training_data,
|
||||
|
||||
@@ -31,8 +31,9 @@ class DiffusionEmbedding(nn.Module):
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
def __init__(self, hidden_size, residual_channels, dilation):
|
||||
def __init__(self, x_in, t_in, cond_in, dilation):
|
||||
super().__init__()
|
||||
residual_channels = x_in
|
||||
self.dilated_conv = nn.Conv1d(
|
||||
residual_channels,
|
||||
2 * residual_channels,
|
||||
@@ -41,9 +42,9 @@ class ResidualBlock(nn.Module):
|
||||
dilation=dilation,
|
||||
padding_mode="circular",
|
||||
)
|
||||
self.diffusion_projection = nn.Linear(hidden_size, residual_channels)
|
||||
self.diffusion_projection = nn.Linear(t_in, residual_channels)
|
||||
self.conditioner_projection = nn.Conv1d(
|
||||
1, 2 * residual_channels, 1, padding=2, padding_mode="circular"
|
||||
cond_in, 2 * residual_channels, 1 #3, padding_mode="circular",
|
||||
)
|
||||
self.output_projection = nn.Conv1d(residual_channels, 2 * residual_channels, 1)
|
||||
|
||||
@@ -69,13 +70,13 @@ class ResidualBlock(nn.Module):
|
||||
class CondUpsampler(nn.Module):
|
||||
def __init__(self, cond_length, target_dim):
|
||||
super().__init__()
|
||||
self.linear1 = nn.Linear(cond_length, target_dim // 2)
|
||||
self.linear2 = nn.Linear(target_dim // 2, target_dim)
|
||||
self.conv1 = nn.Conv1d(cond_length, target_dim // 2, 1)
|
||||
self.conv2 = nn.Conv1d(target_dim // 2, target_dim, 1)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.linear1(x)
|
||||
x = self.conv1(x)
|
||||
x = F.leaky_relu(x, 0.4)
|
||||
x = self.linear2(x)
|
||||
x = self.conv2(x)
|
||||
x = F.leaky_relu(x, 0.4)
|
||||
return x
|
||||
|
||||
@@ -93,26 +94,30 @@ class EpsilonTheta2(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
self.input_projection = nn.Conv1d(
|
||||
1, residual_channels, 1, padding=2, padding_mode="circular"
|
||||
target_dim, residual_channels, 1#, padding=2, padding_mode="circular"
|
||||
)
|
||||
self.diffusion_embedding = DiffusionEmbedding(
|
||||
time_emb_dim, proj_dim=residual_hidden
|
||||
)
|
||||
# self.cond_upsampler = nn.Conv1d(cond_length, width*3, 1)
|
||||
self.cond_upsampler = CondUpsampler(
|
||||
target_dim=target_dim, cond_length=cond_length
|
||||
)
|
||||
self.residual_layers = nn.ModuleList(
|
||||
[
|
||||
ResidualBlock(
|
||||
residual_channels=residual_channels,
|
||||
x_in=residual_channels,
|
||||
t_in=residual_hidden,
|
||||
cond_in=target_dim,
|
||||
# residual_channels=residual_channels,
|
||||
dilation=2 ** (i % dilation_cycle_length),
|
||||
hidden_size=residual_hidden,
|
||||
# hidden_size=residual_hidden,
|
||||
)
|
||||
for i in range(residual_layers)
|
||||
]
|
||||
)
|
||||
self.skip_projection = nn.Conv1d(residual_channels, residual_channels, 3)
|
||||
self.output_projection = nn.Conv1d(residual_channels, 1, 3)
|
||||
self.skip_projection = nn.Conv1d(residual_channels, residual_channels, 1)
|
||||
self.output_projection = nn.Conv1d(residual_channels, target_dim, 1)
|
||||
|
||||
nn.init.kaiming_normal_(self.input_projection.weight)
|
||||
nn.init.kaiming_normal_(self.skip_projection.weight)
|
||||
|
||||
@@ -53,12 +53,8 @@ def mk_noise(shape, device):
|
||||
noise_ou = noise_ou[:, None]
|
||||
return C, noise_rand, noise_ou
|
||||
|
||||
def noise_like(shape, device, repeat=False):
|
||||
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(
|
||||
shape[0], *((1,) * (len(shape) - 1))
|
||||
)
|
||||
noise = lambda: torch.randn(shape, device=device)
|
||||
return repeat_noise() if repeat else noise()
|
||||
def noise_like(x):
|
||||
return mk_noise(x.shape, x.device)
|
||||
|
||||
|
||||
def cosine_beta_schedule(timesteps, s=0.008):
|
||||
@@ -298,7 +294,7 @@ class GaussianDiffusionOU(nn.Module):
|
||||
|
||||
B, T, _ = x.shape
|
||||
|
||||
time = torch.randint(0, self.num_timesteps, (B * T,), device=x.device).long()
|
||||
time = torch.randint(0, self.num_timesteps, (B,), device=x.device).long()
|
||||
loss = self.p_losses(
|
||||
x, cond, time, *args, **kwargs
|
||||
)
|
||||
|
||||
@@ -3,7 +3,6 @@ from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from gluonts.core.component import validated
|
||||
|
||||
from pts.model import weighted_average
|
||||
@@ -62,6 +61,7 @@ class TimeGradTrainingNetwork2(nn.Module):
|
||||
)
|
||||
|
||||
self.denoise_fn = EpsilonTheta2(
|
||||
# input_size=input_size,
|
||||
target_dim=target_dim,
|
||||
cond_length=conditioning_length,
|
||||
residual_layers=residual_layers,
|
||||
@@ -402,7 +402,8 @@ class TimeGradTrainingNetwork2(nn.Module):
|
||||
|
||||
# we sum the last axis to have the same shape for all likelihoods
|
||||
# (batch_size, subseq_length, 1)
|
||||
|
||||
target = target.permute(0, 2, 1)
|
||||
distr_args = distr_args.permute(0, 2, 1)
|
||||
likelihoods = self.diffusion.log_prob(target, distr_args).unsqueeze(-1)
|
||||
|
||||
# assert_shape(likelihoods, (-1, seq_len, 1))
|
||||
|
||||
Reference in New Issue
Block a user