From ccc35afb31288bf528d7157d1789c11f4c177e3c Mon Sep 17 00:00:00 2001 From: Nick Date: Fri, 29 Jul 2022 16:46:17 -0400 Subject: [PATCH] GluonTS import updates (#106) * GluonTS import updates * drop freq argument see https://github.com/awslabs/gluon-ts/pull/1997 --- pts/model/deepar/deepar_estimator.py | 4 ++-- pts/model/deepar/deepar_network.py | 2 +- pts/modules/distribution_output.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pts/model/deepar/deepar_estimator.py b/pts/model/deepar/deepar_estimator.py index eb0d673..6f8a23e 100644 --- a/pts/model/deepar/deepar_estimator.py +++ b/pts/model/deepar/deepar_estimator.py @@ -28,7 +28,7 @@ from gluonts.transform import ( ) from gluonts.torch.util import copy_parameters from gluonts.torch.model.predictor import PyTorchPredictor -from gluonts.torch.modules.distribution_output import DistributionOutput +from gluonts.torch.distributions.distribution_output import DistributionOutput from gluonts.model.predictor import Predictor from pts.model.utils import get_module_forward_input_names @@ -258,7 +258,7 @@ class DeepAREstimator(PyTorchEstimator): input_names=input_names, prediction_net=prediction_network, batch_size=self.trainer.batch_size, - freq=self.freq, + #freq=self.freq, prediction_length=self.prediction_length, device=device, ) diff --git a/pts/model/deepar/deepar_network.py b/pts/model/deepar/deepar_network.py index 12f1709..300f580 100644 --- a/pts/model/deepar/deepar_network.py +++ b/pts/model/deepar/deepar_network.py @@ -6,7 +6,7 @@ import torch.nn as nn from torch.distributions import Distribution from gluonts.core.component import validated -from gluonts.torch.modules.distribution_output import DistributionOutput +from gluonts.torch.distributions.distribution_output import DistributionOutput from pts.model import weighted_average from pts.modules import MeanScaler, NOPScaler, FeatureEmbedder diff --git a/pts/modules/distribution_output.py b/pts/modules/distribution_output.py index 694358d..9ca7f19 100644 --- a/pts/modules/distribution_output.py +++ b/pts/modules/distribution_output.py @@ -31,7 +31,7 @@ from pts.distributions import ( TransformedImplicitQuantile, ) from gluonts.core.component import validated -from gluonts.torch.modules.distribution_output import ( +from gluonts.torch.distributions.distribution_output import ( DistributionOutput, LambdaLayer, PtArgProj,