mirror of
https://github.com/wassname/pytorch-ts.git
synced 2026-06-27 20:22:32 +08:00
fix deepvar test
This commit is contained in:
@@ -277,4 +277,5 @@ class DeepVAREstimator(PyTorchEstimator):
|
||||
freq=self.freq,
|
||||
prediction_length=self.prediction_length,
|
||||
device=device,
|
||||
output_transform=self.output_transform,
|
||||
)
|
||||
|
||||
@@ -12,7 +12,6 @@ from gluonts.core.component import validated
|
||||
from gluonts.dataset.common import Dataset
|
||||
from gluonts.model.estimator import Estimator
|
||||
from gluonts.torch.model.predictor import PyTorchPredictor
|
||||
from gluonts.torch.batchify import batchify
|
||||
from gluonts.transform import SelectFields, Transformation
|
||||
|
||||
from pts import Trainer
|
||||
|
||||
+43
-14
@@ -14,17 +14,18 @@
|
||||
# First-party imports
|
||||
import pytest
|
||||
|
||||
from pts import Trainer
|
||||
from pts.dataset import TrainDatasets, MultivariateGrouper
|
||||
from pts.dataset.artificial import constant_dataset
|
||||
from pts.evaluation import MultivariateEvaluator
|
||||
from pts.evaluation import backtest_metrics
|
||||
from gluonts.dataset.common import TrainDatasets
|
||||
from gluonts.dataset.artificial import constant_dataset
|
||||
from gluonts.dataset.multivariate_grouper import MultivariateGrouper
|
||||
from gluonts.evaluation import MultivariateEvaluator
|
||||
from gluonts.evaluation.backtest import backtest_metrics
|
||||
from pts.model.deepvar import DeepVAREstimator
|
||||
from pts.modules import (
|
||||
NormalOutput,
|
||||
LowRankMultivariateNormalOutput,
|
||||
MultivariateNormalOutput,
|
||||
)
|
||||
from pts import Trainer
|
||||
|
||||
|
||||
def load_multivariate_constant_dataset():
|
||||
@@ -46,13 +47,27 @@ metadata = dataset.metadata
|
||||
estimator = DeepVAREstimator
|
||||
|
||||
|
||||
# @pytest.mark.timeout(10)
|
||||
@pytest.mark.parametrize(
|
||||
"distr_output, num_batches_per_epoch, Estimator, use_marginal_transformation",
|
||||
[
|
||||
(NormalOutput(dim=target_dim), 10, estimator, True,),
|
||||
(NormalOutput(dim=target_dim), 10, estimator, False,),
|
||||
(LowRankMultivariateNormalOutput(dim=target_dim, rank=2), 10, estimator, True,),
|
||||
(
|
||||
NormalOutput(dim=target_dim),
|
||||
10,
|
||||
estimator,
|
||||
True,
|
||||
),
|
||||
(
|
||||
NormalOutput(dim=target_dim),
|
||||
10,
|
||||
estimator,
|
||||
False,
|
||||
),
|
||||
(
|
||||
LowRankMultivariateNormalOutput(dim=target_dim, rank=2),
|
||||
10,
|
||||
estimator,
|
||||
True,
|
||||
),
|
||||
(
|
||||
LowRankMultivariateNormalOutput(dim=target_dim, rank=2),
|
||||
10,
|
||||
@@ -60,12 +75,25 @@ estimator = DeepVAREstimator
|
||||
False,
|
||||
),
|
||||
(None, 10, estimator, True),
|
||||
(MultivariateNormalOutput(dim=target_dim), 10, estimator, True,),
|
||||
(MultivariateNormalOutput(dim=target_dim), 10, estimator, False,),
|
||||
(
|
||||
MultivariateNormalOutput(dim=target_dim),
|
||||
10,
|
||||
estimator,
|
||||
True,
|
||||
),
|
||||
(
|
||||
MultivariateNormalOutput(dim=target_dim),
|
||||
10,
|
||||
estimator,
|
||||
False,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_deepvar(
|
||||
distr_output, num_batches_per_epoch, Estimator, use_marginal_transformation,
|
||||
distr_output,
|
||||
num_batches_per_epoch,
|
||||
Estimator,
|
||||
use_marginal_transformation,
|
||||
):
|
||||
|
||||
estimator = Estimator(
|
||||
@@ -88,10 +116,11 @@ def test_deepvar(
|
||||
),
|
||||
)
|
||||
|
||||
predictor = estimator.train(training_data=dataset.train)
|
||||
|
||||
agg_metrics, _ = backtest_metrics(
|
||||
train_dataset=dataset.train,
|
||||
test_dataset=dataset.test,
|
||||
forecaster=estimator,
|
||||
predictor=predictor,
|
||||
evaluator=MultivariateEvaluator(
|
||||
quantiles=(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)
|
||||
),
|
||||
|
||||
Reference in New Issue
Block a user