From c7e603be3eebfc2bdba8c7de6746d29bcce05b4a Mon Sep 17 00:00:00 2001 From: "Dr. Kashif Rasul" Date: Fri, 1 Jan 2021 13:25:46 +0100 Subject: [PATCH] fix deepvar test --- pts/model/deepvar/deepvar_estimator.py | 1 + pts/model/estimator.py | 1 - test/model/test_deepvar.py | 57 +++++++++++++++++++------- 3 files changed, 44 insertions(+), 15 deletions(-) diff --git a/pts/model/deepvar/deepvar_estimator.py b/pts/model/deepvar/deepvar_estimator.py index 2e9f34f..30429f4 100644 --- a/pts/model/deepvar/deepvar_estimator.py +++ b/pts/model/deepvar/deepvar_estimator.py @@ -277,4 +277,5 @@ class DeepVAREstimator(PyTorchEstimator): freq=self.freq, prediction_length=self.prediction_length, device=device, + output_transform=self.output_transform, ) diff --git a/pts/model/estimator.py b/pts/model/estimator.py index f54967b..74fe6e0 100644 --- a/pts/model/estimator.py +++ b/pts/model/estimator.py @@ -12,7 +12,6 @@ from gluonts.core.component import validated from gluonts.dataset.common import Dataset from gluonts.model.estimator import Estimator from gluonts.torch.model.predictor import PyTorchPredictor -from gluonts.torch.batchify import batchify from gluonts.transform import SelectFields, Transformation from pts import Trainer diff --git a/test/model/test_deepvar.py b/test/model/test_deepvar.py index 1ab74e2..2216ebd 100644 --- a/test/model/test_deepvar.py +++ b/test/model/test_deepvar.py @@ -14,17 +14,18 @@ # First-party imports import pytest -from pts import Trainer -from pts.dataset import TrainDatasets, MultivariateGrouper -from pts.dataset.artificial import constant_dataset -from pts.evaluation import MultivariateEvaluator -from pts.evaluation import backtest_metrics +from gluonts.dataset.common import TrainDatasets +from gluonts.dataset.artificial import constant_dataset +from gluonts.dataset.multivariate_grouper import MultivariateGrouper +from gluonts.evaluation import MultivariateEvaluator +from gluonts.evaluation.backtest import backtest_metrics from pts.model.deepvar import DeepVAREstimator from pts.modules import ( NormalOutput, LowRankMultivariateNormalOutput, MultivariateNormalOutput, ) +from pts import Trainer def load_multivariate_constant_dataset(): @@ -46,13 +47,27 @@ metadata = dataset.metadata estimator = DeepVAREstimator -# @pytest.mark.timeout(10) @pytest.mark.parametrize( "distr_output, num_batches_per_epoch, Estimator, use_marginal_transformation", [ - (NormalOutput(dim=target_dim), 10, estimator, True,), - (NormalOutput(dim=target_dim), 10, estimator, False,), - (LowRankMultivariateNormalOutput(dim=target_dim, rank=2), 10, estimator, True,), + ( + NormalOutput(dim=target_dim), + 10, + estimator, + True, + ), + ( + NormalOutput(dim=target_dim), + 10, + estimator, + False, + ), + ( + LowRankMultivariateNormalOutput(dim=target_dim, rank=2), + 10, + estimator, + True, + ), ( LowRankMultivariateNormalOutput(dim=target_dim, rank=2), 10, @@ -60,12 +75,25 @@ estimator = DeepVAREstimator False, ), (None, 10, estimator, True), - (MultivariateNormalOutput(dim=target_dim), 10, estimator, True,), - (MultivariateNormalOutput(dim=target_dim), 10, estimator, False,), + ( + MultivariateNormalOutput(dim=target_dim), + 10, + estimator, + True, + ), + ( + MultivariateNormalOutput(dim=target_dim), + 10, + estimator, + False, + ), ], ) def test_deepvar( - distr_output, num_batches_per_epoch, Estimator, use_marginal_transformation, + distr_output, + num_batches_per_epoch, + Estimator, + use_marginal_transformation, ): estimator = Estimator( @@ -88,10 +116,11 @@ def test_deepvar( ), ) + predictor = estimator.train(training_data=dataset.train) + agg_metrics, _ = backtest_metrics( - train_dataset=dataset.train, test_dataset=dataset.test, - forecaster=estimator, + predictor=predictor, evaluator=MultivariateEvaluator( quantiles=(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9) ),