mirror of
https://github.com/wassname/pytorch-ts.git
synced 2026-06-27 18:06:19 +08:00
GluonTS import updates (#106)
* GluonTS import updates * drop freq argument see https://github.com/awslabs/gluon-ts/pull/1997
This commit is contained in:
@@ -28,7 +28,7 @@ from gluonts.transform import (
|
||||
)
|
||||
from gluonts.torch.util import copy_parameters
|
||||
from gluonts.torch.model.predictor import PyTorchPredictor
|
||||
from gluonts.torch.modules.distribution_output import DistributionOutput
|
||||
from gluonts.torch.distributions.distribution_output import DistributionOutput
|
||||
from gluonts.model.predictor import Predictor
|
||||
|
||||
from pts.model.utils import get_module_forward_input_names
|
||||
@@ -258,7 +258,7 @@ class DeepAREstimator(PyTorchEstimator):
|
||||
input_names=input_names,
|
||||
prediction_net=prediction_network,
|
||||
batch_size=self.trainer.batch_size,
|
||||
freq=self.freq,
|
||||
#freq=self.freq,
|
||||
prediction_length=self.prediction_length,
|
||||
device=device,
|
||||
)
|
||||
|
||||
@@ -6,7 +6,7 @@ import torch.nn as nn
|
||||
from torch.distributions import Distribution
|
||||
|
||||
from gluonts.core.component import validated
|
||||
from gluonts.torch.modules.distribution_output import DistributionOutput
|
||||
from gluonts.torch.distributions.distribution_output import DistributionOutput
|
||||
from pts.model import weighted_average
|
||||
from pts.modules import MeanScaler, NOPScaler, FeatureEmbedder
|
||||
|
||||
|
||||
@@ -31,7 +31,7 @@ from pts.distributions import (
|
||||
TransformedImplicitQuantile,
|
||||
)
|
||||
from gluonts.core.component import validated
|
||||
from gluonts.torch.modules.distribution_output import (
|
||||
from gluonts.torch.distributions.distribution_output import (
|
||||
DistributionOutput,
|
||||
LambdaLayer,
|
||||
PtArgProj,
|
||||
|
||||
Reference in New Issue
Block a user