fix deepvar test

This commit is contained in:
Dr. Kashif Rasul
2021-01-01 13:25:46 +01:00
parent d5577a2c9e
commit c7e603be3e
3 changed files with 44 additions and 15 deletions
+1
View File
@@ -277,4 +277,5 @@ class DeepVAREstimator(PyTorchEstimator):
freq=self.freq,
prediction_length=self.prediction_length,
device=device,
output_transform=self.output_transform,
)
-1
View File
@@ -12,7 +12,6 @@ from gluonts.core.component import validated
from gluonts.dataset.common import Dataset
from gluonts.model.estimator import Estimator
from gluonts.torch.model.predictor import PyTorchPredictor
from gluonts.torch.batchify import batchify
from gluonts.transform import SelectFields, Transformation
from pts import Trainer
+43 -14
View File
@@ -14,17 +14,18 @@
# First-party imports
import pytest
from pts import Trainer
from pts.dataset import TrainDatasets, MultivariateGrouper
from pts.dataset.artificial import constant_dataset
from pts.evaluation import MultivariateEvaluator
from pts.evaluation import backtest_metrics
from gluonts.dataset.common import TrainDatasets
from gluonts.dataset.artificial import constant_dataset
from gluonts.dataset.multivariate_grouper import MultivariateGrouper
from gluonts.evaluation import MultivariateEvaluator
from gluonts.evaluation.backtest import backtest_metrics
from pts.model.deepvar import DeepVAREstimator
from pts.modules import (
NormalOutput,
LowRankMultivariateNormalOutput,
MultivariateNormalOutput,
)
from pts import Trainer
def load_multivariate_constant_dataset():
@@ -46,13 +47,27 @@ metadata = dataset.metadata
estimator = DeepVAREstimator
# @pytest.mark.timeout(10)
@pytest.mark.parametrize(
"distr_output, num_batches_per_epoch, Estimator, use_marginal_transformation",
[
(NormalOutput(dim=target_dim), 10, estimator, True,),
(NormalOutput(dim=target_dim), 10, estimator, False,),
(LowRankMultivariateNormalOutput(dim=target_dim, rank=2), 10, estimator, True,),
(
NormalOutput(dim=target_dim),
10,
estimator,
True,
),
(
NormalOutput(dim=target_dim),
10,
estimator,
False,
),
(
LowRankMultivariateNormalOutput(dim=target_dim, rank=2),
10,
estimator,
True,
),
(
LowRankMultivariateNormalOutput(dim=target_dim, rank=2),
10,
@@ -60,12 +75,25 @@ estimator = DeepVAREstimator
False,
),
(None, 10, estimator, True),
(MultivariateNormalOutput(dim=target_dim), 10, estimator, True,),
(MultivariateNormalOutput(dim=target_dim), 10, estimator, False,),
(
MultivariateNormalOutput(dim=target_dim),
10,
estimator,
True,
),
(
MultivariateNormalOutput(dim=target_dim),
10,
estimator,
False,
),
],
)
def test_deepvar(
distr_output, num_batches_per_epoch, Estimator, use_marginal_transformation,
distr_output,
num_batches_per_epoch,
Estimator,
use_marginal_transformation,
):
estimator = Estimator(
@@ -88,10 +116,11 @@ def test_deepvar(
),
)
predictor = estimator.train(training_data=dataset.train)
agg_metrics, _ = backtest_metrics(
train_dataset=dataset.train,
test_dataset=dataset.test,
forecaster=estimator,
predictor=predictor,
evaluator=MultivariateEvaluator(
quantiles=(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)
),