# Mirror of https://github.com/wassname/pytorch-ts.git
# Synced 2026-06-27 16:46:32 +08:00
# 304 lines, 11 KiB, Python
from typing import List, Optional, Callable
from typing import List, Optional, Callable
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
from gluonts.core.component import validated
|
|
from gluonts.dataset.field_names import FieldName
|
|
from gluonts.time_feature import TimeFeature
|
|
from gluonts.torch.modules.distribution_output import DistributionOutput
|
|
from gluonts.torch.util import copy_parameters
|
|
from gluonts.torch.model.predictor import PyTorchPredictor
|
|
from gluonts.model.predictor import Predictor
|
|
from gluonts.transform import (
|
|
AddObservedValuesIndicator,
|
|
AddTimeFeatures,
|
|
AsNumpyArray,
|
|
CDFtoGaussianTransform,
|
|
Chain,
|
|
ExpandDimArray,
|
|
ExpectedNumInstanceSampler,
|
|
InstanceSplitter,
|
|
ValidationSplitSampler,
|
|
TestSplitSampler,
|
|
RenameFields,
|
|
SetField,
|
|
TargetDimIndicator,
|
|
Transformation,
|
|
VstackFeatures,
|
|
RemoveFields,
|
|
AddAgeFeature,
|
|
cdf_to_gaussian_forward_transform,
|
|
)
|
|
|
|
from pts import Trainer
|
|
from pts.model.utils import get_module_forward_input_names
|
|
from pts.feature import (
|
|
fourier_time_features_from_frequency,
|
|
lags_for_fourier_time_features_from_frequency,
|
|
)
|
|
from pts.model import PyTorchEstimator
|
|
from pts.modules import LowRankMultivariateNormalOutput
|
|
|
|
from .deepvar_network import DeepVARTrainingNetwork, DeepVARPredictionNetwork
|
|
|
|
|
|
class DeepVAREstimator(PyTorchEstimator):
    """Estimator that assembles the DeepVAR data pipeline and networks.

    It builds the gluonts transformation chain, the instance splitters for
    training / validation / test, a ``DeepVARTrainingNetwork`` for training
    and a ``DeepVARPredictionNetwork`` (wrapped in a ``PyTorchPredictor``)
    for inference.
    """

    @validated()
    def __init__(
        self,
        input_size: int,
        freq: str,
        prediction_length: int,
        target_dim: int,
        trainer: Optional[Trainer] = None,
        context_length: Optional[int] = None,
        num_layers: int = 2,
        num_cells: int = 40,
        cell_type: str = "LSTM",
        num_parallel_samples: int = 100,
        dropout_rate: float = 0.1,
        use_feat_dynamic_real: bool = False,
        use_feat_static_cat: bool = False,
        use_feat_static_real: bool = False,
        cardinality: Optional[List[int]] = None,
        embedding_dimension: Optional[List[int]] = None,
        distr_output: Optional[DistributionOutput] = None,
        rank: Optional[int] = 5,
        scaling: bool = True,
        pick_incomplete: bool = False,
        lags_seq: Optional[List[int]] = None,
        time_features: Optional[List[TimeFeature]] = None,
        conditioning_length: int = 200,
        use_marginal_transformation: bool = False,
        **kwargs,
    ) -> None:
        """Store the configuration and derive the defaults.

        Parameters
        ----------
        input_size
            Passed unchanged as ``input_size`` to both networks.
        freq
            Frequency string. It is used to derive the default lags and time
            features, and it is passed to the predictor.
        prediction_length
            Number of steps to forecast.
        target_dim
            Dimension of the multivariate target. It is used by the default
            ``distr_output`` and by ``CDFtoGaussianTransform``.
        trainer
            Training configuration. If ``None``, a new ``Trainer()`` is
            created for this estimator.
        context_length
            Conditioning range. Defaults to ``prediction_length``.
        num_layers, num_cells, cell_type, dropout_rate
            Passed unchanged to both networks.
        num_parallel_samples
            Passed to the prediction network only.
        use_feat_dynamic_real, use_feat_static_cat, use_feat_static_real
            Whether the corresponding dataset fields are used. Disabled
            fields are removed or replaced by constant dummy values in
            ``create_transformation``.
        cardinality
            Only honoured when ``use_feat_static_cat`` is true and the list is
            non-empty. Otherwise ``[1]`` is used.
        embedding_dimension
            Defaults to ``min(50, (cat + 1) // 2)`` for each cardinality.
        distr_output
            Output distribution. Defaults to
            ``LowRankMultivariateNormalOutput(dim=target_dim, rank=rank)``.
        rank
            Only used to build the default ``distr_output``.
        scaling
            Passed unchanged to both networks.
        pick_incomplete
            If true, the training and validation samplers use ``min_past=0``.
            Otherwise they require a full ``history_length`` of past data.
        lags_seq
            Lag indices. Defaults to
            ``lags_for_fourier_time_features_from_frequency(freq)``.
        time_features
            Defaults to ``fourier_time_features_from_frequency(freq)``.
        conditioning_length
            ``max_context_length`` for ``CDFtoGaussianTransform``.
        use_marginal_transformation
            If true, targets go through ``CDFtoGaussianTransform`` in the
            instance splitter, and forecasts go through
            ``cdf_to_gaussian_forward_transform``.
        **kwargs
            Forwarded to ``PyTorchEstimator.__init__``.
        """
        # Create a fresh Trainer per estimator. A ``Trainer()`` default would
        # be evaluated once at definition time and shared by every estimator.
        if trainer is None:
            trainer = Trainer()
        super().__init__(trainer=trainer, **kwargs)

        self.freq = freq
        self.context_length = (
            context_length if context_length is not None else prediction_length
        )

        if distr_output is not None:
            self.distr_output = distr_output
        else:
            self.distr_output = LowRankMultivariateNormalOutput(
                dim=target_dim, rank=rank
            )

        self.input_size = input_size
        self.prediction_length = prediction_length
        self.target_dim = target_dim
        self.num_layers = num_layers
        self.num_cells = num_cells
        self.cell_type = cell_type
        self.num_parallel_samples = num_parallel_samples
        self.dropout_rate = dropout_rate
        self.use_feat_dynamic_real = use_feat_dynamic_real
        self.use_feat_static_cat = use_feat_static_cat
        self.use_feat_static_real = use_feat_static_real
        # When static categorical features are disabled, a single dummy
        # category is used. It matches the constant [0] FEAT_STATIC_CAT set in
        # create_transformation.
        self.cardinality = cardinality if cardinality and use_feat_static_cat else [1]
        self.embedding_dimension = (
            embedding_dimension
            if embedding_dimension is not None
            else [min(50, (cat + 1) // 2) for cat in self.cardinality]
        )
        self.conditioning_length = conditioning_length
        self.use_marginal_transformation = use_marginal_transformation

        self.lags_seq = (
            lags_seq
            if lags_seq is not None
            else lags_for_fourier_time_features_from_frequency(freq_str=freq)
        )

        self.time_features = (
            time_features
            if time_features is not None
            else fourier_time_features_from_frequency(self.freq)
        )

        # The past window must cover the context plus the largest lag.
        self.history_length = self.context_length + max(self.lags_seq)
        self.pick_incomplete = pick_incomplete
        self.scaling = scaling

        if self.use_marginal_transformation:
            self.output_transform: Optional[
                Callable
            ] = cdf_to_gaussian_forward_transform
        else:
            self.output_transform = None

        self.train_sampler = ExpectedNumInstanceSampler(
            num_instances=1.0,
            min_past=0 if pick_incomplete else self.history_length,
            min_future=prediction_length,
        )

        self.validation_sampler = ValidationSplitSampler(
            min_past=0 if pick_incomplete else self.history_length,
            min_future=prediction_length,
        )

    def create_transformation(self) -> Transformation:
        """Build the per-series preprocessing chain.

        The chain does the following, in order:

        - Removes unused feature fields.
        - Inserts constant dummy static features when those are disabled.
        - Converts the target to a numpy array.
        - Adds observed-value, time and age features and stacks them into
          ``FEAT_TIME``.
        - Adds a target dimension indicator.
        - Casts the static features to numpy arrays.
        """
        remove_field_names = [FieldName.FEAT_DYNAMIC_CAT]
        if not self.use_feat_dynamic_real:
            remove_field_names.append(FieldName.FEAT_DYNAMIC_REAL)
        if not self.use_feat_static_real:
            remove_field_names.append(FieldName.FEAT_STATIC_REAL)

        return Chain(
            [RemoveFields(field_names=remove_field_names)]
            + (
                [SetField(output_field=FieldName.FEAT_STATIC_CAT, value=[0])]
                if not self.use_feat_static_cat
                else []
            )
            + (
                [SetField(output_field=FieldName.FEAT_STATIC_REAL, value=[0.0])]
                if not self.use_feat_static_real
                else []
            )
            + [
                AsNumpyArray(
                    field=FieldName.TARGET,
                    expected_ndim=1 + len(self.distr_output.event_shape),
                ),
                # maps the target to (1, T)
                # if the target data is uni dimensional
                ExpandDimArray(
                    field=FieldName.TARGET,
                    axis=0 if self.distr_output.event_shape[0] == 1 else None,
                ),
                AddObservedValuesIndicator(
                    target_field=FieldName.TARGET,
                    output_field=FieldName.OBSERVED_VALUES,
                ),
                AddTimeFeatures(
                    start_field=FieldName.START,
                    target_field=FieldName.TARGET,
                    output_field=FieldName.FEAT_TIME,
                    time_features=self.time_features,
                    pred_length=self.prediction_length,
                ),
                AddAgeFeature(
                    target_field=FieldName.TARGET,
                    output_field=FieldName.FEAT_AGE,
                    pred_length=self.prediction_length,
                    log_scale=True,
                ),
                VstackFeatures(
                    output_field=FieldName.FEAT_TIME,
                    input_fields=[FieldName.FEAT_TIME, FieldName.FEAT_AGE]
                    + (
                        [FieldName.FEAT_DYNAMIC_REAL]
                        if self.use_feat_dynamic_real
                        else []
                    ),
                ),
                TargetDimIndicator(
                    field_name="target_dimension_indicator",
                    target_field=FieldName.TARGET,
                ),
                # np.long was removed in NumPy 1.24. Before that it was the
                # platform int, which is 32-bit on Windows. Use an explicit
                # 64-bit integer for categorical indices.
                AsNumpyArray(
                    field=FieldName.FEAT_STATIC_CAT, expected_ndim=1, dtype=np.int64
                ),
                AsNumpyArray(field=FieldName.FEAT_STATIC_REAL, expected_ndim=1),
            ]
        )

    def create_instance_splitter(self, mode: str):
        """Return the instance splitter for ``mode``.

        ``mode`` is one of ``"training"``, ``"validation"`` or ``"test"``.

        The ``InstanceSplitter`` is followed by one of two steps:

        - ``CDFtoGaussianTransform``, when ``use_marginal_transformation`` is
          set.
        - Otherwise, a ``RenameFields`` that renames ``past_target`` /
          ``future_target`` to ``*_cdf``. Presumably this makes the network
          see the same field names either way.
        """
        assert mode in ["training", "validation", "test"]

        instance_sampler = {
            "training": self.train_sampler,
            "validation": self.validation_sampler,
            "test": TestSplitSampler(),
        }[mode]

        return InstanceSplitter(
            target_field=FieldName.TARGET,
            is_pad_field=FieldName.IS_PAD,
            start_field=FieldName.START,
            forecast_start_field=FieldName.FORECAST_START,
            instance_sampler=instance_sampler,
            past_length=self.history_length,
            future_length=self.prediction_length,
            time_series_fields=[
                FieldName.FEAT_TIME,
                FieldName.OBSERVED_VALUES,
            ],
        ) + (
            CDFtoGaussianTransform(
                target_field=FieldName.TARGET,
                observed_values_field=FieldName.OBSERVED_VALUES,
                max_context_length=self.conditioning_length,
                target_dim=self.target_dim,
            )
            if self.use_marginal_transformation
            else RenameFields(
                {
                    f"past_{FieldName.TARGET}": f"past_{FieldName.TARGET}_cdf",
                    f"future_{FieldName.TARGET}": f"future_{FieldName.TARGET}_cdf",
                }
            )
        )

    def _network_kwargs(self) -> dict:
        """Return the constructor arguments shared by the training and prediction networks."""
        return dict(
            input_size=self.input_size,
            target_dim=self.target_dim,
            num_layers=self.num_layers,
            num_cells=self.num_cells,
            cell_type=self.cell_type,
            history_length=self.history_length,
            context_length=self.context_length,
            prediction_length=self.prediction_length,
            distr_output=self.distr_output,
            dropout_rate=self.dropout_rate,
            cardinality=self.cardinality,
            embedding_dimension=self.embedding_dimension,
            lags_seq=self.lags_seq,
            scaling=self.scaling,
        )

    def create_training_network(self, device: torch.device) -> DeepVARTrainingNetwork:
        """Instantiate the training network and move it to ``device``."""
        return DeepVARTrainingNetwork(**self._network_kwargs()).to(device)

    def create_predictor(
        self,
        transformation: Transformation,
        trained_network: DeepVARTrainingNetwork,
        device: torch.device,
    ) -> Predictor:
        """Build a predictor from a trained network.

        A ``DeepVARPredictionNetwork`` is created with the same configuration
        as the training network, and the trained parameters are copied into
        it. It is then wrapped in a ``PyTorchPredictor``, whose input pipeline
        is ``transformation`` followed by the test instance splitter.
        """
        prediction_network = DeepVARPredictionNetwork(
            num_parallel_samples=self.num_parallel_samples,
            **self._network_kwargs(),
        ).to(device)

        copy_parameters(trained_network, prediction_network)
        input_names = get_module_forward_input_names(prediction_network)
        prediction_splitter = self.create_instance_splitter("test")

        return PyTorchPredictor(
            input_transform=transformation + prediction_splitter,
            input_names=input_names,
            prediction_net=prediction_network,
            batch_size=self.trainer.batch_size,
            freq=self.freq,
            prediction_length=self.prediction_length,
            device=device,
            output_transform=self.output_transform,
        )
|