From c4f6445d1172cfce16821fea41666bdea79b5131 Mon Sep 17 00:00:00 2001 From: Samuel Norling Date: Sat, 6 Mar 2021 17:44:34 +0100 Subject: [PATCH] Updated Solar example (#43) --- examples/Multivariate-Flow-Solar.ipynb | 503 +++++++++---------------- examples/Time-Grad-Electricity.ipynb | 18 +- 2 files changed, 171 insertions(+), 350 deletions(-) diff --git a/examples/Multivariate-Flow-Solar.ipynb b/examples/Multivariate-Flow-Solar.ipynb index 88a04be..316302f 100644 --- a/examples/Multivariate-Flow-Solar.ipynb +++ b/examples/Multivariate-Flow-Solar.ipynb @@ -14,22 +14,22 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ - "from pts.dataset import to_pandas, MultivariateGrouper, TrainDatasets\n", - "from pts.dataset.repository import get_dataset, dataset_recipes\n", + "from gluonts.dataset.multivariate_grouper import MultivariateGrouper\n", + "from gluonts.dataset.repository.datasets import dataset_recipes, get_dataset\n", "from pts.model.tempflow import TempFlowEstimator\n", "from pts.model.transformer_tempflow import TransformerTempFlowEstimator\n", "from pts import Trainer\n", - "from pts.evaluation import make_evaluation_predictions\n", - "from pts.evaluation import MultivariateEvaluator" + "from gluonts.evaluation.backtest import make_evaluation_predictions\n", + "from gluonts.evaluation import MultivariateEvaluator" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -45,16 +45,16 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ - "dataset = get_dataset(\"solar_nips\", regenerate=False, shuffle=False)" + "dataset = get_dataset(\"solar_nips\", regenerate=False)" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -63,7 +63,7 @@ "MetaData(freq='H', target=None, feat_static_cat=[CategoricalFeatureInfo(name='feat_static_cat', cardinality='137')], feat_static_real=[], feat_dynamic_real=[], feat_dynamic_cat=[], prediction_length=24)" ] }, - "execution_count": 5, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } @@ -74,7 +74,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -86,9 +86,18 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 13, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/samuelnorling/Programming/KTH/exjobb/zalandorosettaenv/lib/python3.8/site-packages/gluonts/dataset/multivariate_grouper.py:182: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray.\n", + " return {FieldName.TARGET: np.array([funcs(data) for data in dataset])}\n" + ] + } + ], "source": [ "dataset_train = train_grouper(dataset.train)\n", "dataset_test = test_grouper(dataset.test)" @@ -103,7 +112,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ @@ -319,13 +328,7 @@ "Running evaluation: 7it [00:00, 80.42it/s]\n", " 45%|████▍ | 61/137 [00:06<00:07, 9.72it/s]\n", "Running evaluation: 7it [00:00, 80.96it/s]\n", - " 45%|████▌ | 62/137 [00:06<00:07, 9.76it/s]\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ + " 45%|████▌ | 62/137 [00:06<00:07, 9.76it/s]\n", "Running evaluation: 7it [00:00, 80.15it/s]\n", " 46%|████▌ | 63/137 [00:06<00:07, 9.76it/s]\n", "Running evaluation: 7it [00:00, 80.72it/s]\n", @@ -757,13 +760,7 @@ " 53%|█████▎ | 73/137 [00:07<00:06, 9.87it/s]\n", "Running evaluation: 7it [00:00, 81.08it/s]\n", " 54%|█████▍ | 74/137 [00:07<00:06, 9.87it/s]\n", - "Running evaluation: 7it [00:00, 81.81it/s]\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ + "Running evaluation: 7it [00:00, 81.81it/s]\n", " 55%|█████▍ | 75/137 [00:07<00:06, 9.89it/s]\n", "Running evaluation: 7it [00:00, 80.63it/s]\n", " 55%|█████▌ | 76/137 [00:07<00:06, 9.87it/s]\n", @@ -967,7 +964,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 16, "metadata": {}, "outputs": [], "source": [ @@ -993,316 +990,152 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "98it [00:11, 8.71it/s, avg_epoch_loss=-44.8, epoch=0]\n", - "98it [00:11, 8.60it/s, avg_epoch_loss=-170, epoch=1]\n", - "99it [00:11, 8.82it/s, avg_epoch_loss=-189, epoch=2]\n", - "98it [00:11, 8.83it/s, avg_epoch_loss=-201, epoch=3]\n", - "99it [00:11, 8.80it/s, avg_epoch_loss=-208, epoch=4]\n", - "98it [00:11, 8.72it/s, avg_epoch_loss=-212, epoch=5]\n", - "99it [00:11, 8.83it/s, avg_epoch_loss=-216, epoch=6]\n", - "99it [00:11, 8.80it/s, avg_epoch_loss=-218, epoch=7]\n", - "99it [00:11, 8.84it/s, avg_epoch_loss=-220, epoch=8]\n", - "98it [00:11, 8.74it/s, avg_epoch_loss=-222, epoch=9]\n", - "99it [00:11, 8.92it/s, avg_epoch_loss=-223, epoch=10]\n", - "99it [00:11, 8.74it/s, avg_epoch_loss=-225, epoch=11]\n", - "99it [00:11, 8.84it/s, avg_epoch_loss=-226, epoch=12]\n", - "99it [00:11, 8.88it/s, avg_epoch_loss=-227, epoch=13]\n", - " 0%| | 0/137 [00:00