{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "\n", "import torch" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from pts.dataset import to_pandas, MultivariateGrouper, TrainDatasets\n", "from pts.dataset.repository import get_dataset, dataset_recipes\n", "from pts.model.tempflow import TempFlowEstimator\n", "from pts.model.transformer_tempflow import TransformerTempFlowEstimator\n", "from pts import Trainer\n", "from pts.evaluation import make_evaluation_predictions\n", "from pts.evaluation import MultivariateEvaluator" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Prepare data set" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "dataset = get_dataset(\"solar_nips\", regenerate=False, shuffle=False)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "MetaData(freq='H', target=None, feat_static_cat=[CategoricalFeatureInfo(name='feat_static_cat', cardinality='137')], feat_static_real=[], feat_dynamic_real=[], feat_dynamic_cat=[], prediction_length=24)" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset.metadata" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "train_grouper = MultivariateGrouper(max_target_dim=int(dataset.metadata.feat_static_cat[0].cardinality))\n", "\n", "test_grouper = MultivariateGrouper(num_test_dates=int(len(dataset.test)/len(dataset.train)), \n", " max_target_dim=int(dataset.metadata.feat_static_cat[0].cardinality))" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [
"dataset_train = train_grouper(dataset.train)\n", "dataset_test = test_grouper(dataset.test)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluator" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "evaluator = MultivariateEvaluator(quantiles=(np.arange(20)/20.0)[1:],\n", " target_agg_funcs={'sum': np.sum})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## `GRU-Real-NVP`" ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [], "source": [ "estimator = TempFlowEstimator(\n", " target_dim=int(dataset.metadata.feat_static_cat[0].cardinality),\n", " prediction_length=dataset.metadata.prediction_length,\n", " cell_type='GRU',\n", " input_size=552,\n", " freq=dataset.metadata.freq,\n", " scaling=True,\n", " dequantize=True,\n", " n_blocks=4,\n", " trainer=Trainer(device=device,\n", " epochs=45,\n", " learning_rate=1e-3,\n", " num_batches_per_epoch=100,\n", " batch_size=64)\n", ")" ] }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "99it [00:10, 9.03it/s, avg_epoch_loss=-43.1, epoch=0]\n", "99it [00:10, 9.03it/s, avg_epoch_loss=-126, epoch=1]\n", "99it [00:11, 9.00it/s, avg_epoch_loss=-142, epoch=2]\n", "99it [00:10, 9.37it/s, avg_epoch_loss=-143, epoch=3]\n", "99it [00:10, 9.09it/s, avg_epoch_loss=-153, epoch=4]\n", "99it [00:11, 8.76it/s, avg_epoch_loss=-157, epoch=5]\n", "99it [00:10, 9.03it/s, avg_epoch_loss=-157, epoch=6]\n", "99it [00:11, 8.94it/s, avg_epoch_loss=-166, epoch=7]\n", "99it [00:11, 8.56it/s, avg_epoch_loss=-169, epoch=8]\n", "99it [00:11, 8.84it/s, avg_epoch_loss=-168, epoch=9]\n", "98it [00:11, 8.89it/s, avg_epoch_loss=-170, epoch=10]\n", "99it [00:11, 8.89it/s, avg_epoch_loss=-172, epoch=11]\n", "98it [00:10, 9.00it/s, avg_epoch_loss=-172, epoch=12]\n", "99it [00:10, 9.02it/s, avg_epoch_loss=-177, epoch=13]\n", "99it [00:10, 9.48it/s, avg_epoch_loss=-180, 
epoch=14]\n", "98it [00:10, 9.65it/s, avg_epoch_loss=-180, epoch=15]\n", "99it [00:10, 9.01it/s, avg_epoch_loss=-182, epoch=16]\n", "99it [00:10, 9.11it/s, avg_epoch_loss=-182, epoch=17]\n", "99it [00:10, 9.02it/s, avg_epoch_loss=-182, epoch=18]\n", "98it [00:11, 8.89it/s, avg_epoch_loss=-182, epoch=19]\n", "99it [00:10, 9.01it/s, avg_epoch_loss=-179, epoch=20]\n", "99it [00:10, 9.15it/s, avg_epoch_loss=-183, epoch=21]\n", "99it [00:11, 8.96it/s, avg_epoch_loss=-188, epoch=22]\n", "99it [00:11, 8.96it/s, avg_epoch_loss=-188, epoch=23]\n", "99it [00:10, 9.04it/s, avg_epoch_loss=-190, epoch=24]\n", "98it [00:11, 8.85it/s, avg_epoch_loss=-193, epoch=25]\n", "98it [00:10, 8.95it/s, avg_epoch_loss=-193, epoch=26]\n", "99it [00:11, 8.93it/s, avg_epoch_loss=-192, epoch=27]\n", "99it [00:10, 9.06it/s, avg_epoch_loss=-193, epoch=28]\n", "98it [00:10, 8.97it/s, avg_epoch_loss=-193, epoch=29]\n", "99it [00:11, 8.95it/s, avg_epoch_loss=-196, epoch=30]\n", "98it [00:10, 8.95it/s, avg_epoch_loss=-193, epoch=31]\n", "99it [00:10, 9.05it/s, avg_epoch_loss=-192, epoch=32]\n", "99it [00:11, 8.90it/s, avg_epoch_loss=-197, epoch=33]\n", "99it [00:11, 8.93it/s, avg_epoch_loss=-198, epoch=34]\n", "98it [00:11, 8.85it/s, avg_epoch_loss=-197, epoch=35]\n", "99it [00:10, 9.01it/s, avg_epoch_loss=-198, epoch=36]\n", "98it [00:10, 8.97it/s, avg_epoch_loss=-200, epoch=37]\n", "98it [00:11, 8.85it/s, avg_epoch_loss=-199, epoch=38]\n", "99it [00:10, 9.04it/s, avg_epoch_loss=-197, epoch=39]\n", "99it [00:11, 8.97it/s, avg_epoch_loss=-199, epoch=40]\n", "99it [00:11, 8.88it/s, avg_epoch_loss=-201, epoch=41]\n", "98it [00:11, 8.90it/s, avg_epoch_loss=-201, epoch=42]\n", "99it [00:10, 9.09it/s, avg_epoch_loss=-202, epoch=43]\n", "98it [00:10, 8.93it/s, avg_epoch_loss=-199, epoch=44]\n", " 0%| | 0/137 [00:00