Files
pytorch-ts/pts/model/deepvar/deepvar_estimator.py
T
Kashif Rasul ea9b2b7df5 Gluon master (#29)
* Estimator needs a create_instance_splitter now

* updated estimators and tests

* fix test

* validated
2021-02-07 17:43:07 +01:00

304 lines
11 KiB
Python

from typing import List, Optional, Callable
import numpy as np
import torch
from gluonts.core.component import validated
from gluonts.dataset.field_names import FieldName
from gluonts.time_feature import TimeFeature
from gluonts.torch.modules.distribution_output import DistributionOutput
from gluonts.torch.support.util import copy_parameters
from gluonts.torch.model.predictor import PyTorchPredictor
from gluonts.model.predictor import Predictor
from gluonts.transform import (
AddObservedValuesIndicator,
AddTimeFeatures,
AsNumpyArray,
CDFtoGaussianTransform,
Chain,
ExpandDimArray,
ExpectedNumInstanceSampler,
InstanceSplitter,
ValidationSplitSampler,
TestSplitSampler,
RenameFields,
SetField,
TargetDimIndicator,
Transformation,
VstackFeatures,
RemoveFields,
AddAgeFeature,
cdf_to_gaussian_forward_transform,
)
from pts import Trainer
from pts.model.utils import get_module_forward_input_names
from pts.feature import (
fourier_time_features_from_frequency,
lags_for_fourier_time_features_from_frequency,
)
from pts.model import PyTorchEstimator
from pts.modules import LowRankMultivariateNormalOutput
from .deepvar_network import DeepVARTrainingNetwork, DeepVARPredictionNetwork
class DeepVAREstimator(PyTorchEstimator):
    """Estimator for the DeepVAR multivariate forecasting model.

    Assembles the feature-transformation chain, the instance splitters used
    for training / validation / test, and the ``DeepVARTrainingNetwork`` /
    ``DeepVARPredictionNetwork`` pair; the prediction network is finally
    wrapped in a ``PyTorchPredictor``.
    """

    @validated()
    def __init__(
        self,
        input_size: int,
        freq: str,
        prediction_length: int,
        target_dim: int,
        trainer: Optional[Trainer] = None,
        context_length: Optional[int] = None,
        num_layers: int = 2,
        num_cells: int = 40,
        cell_type: str = "LSTM",
        num_parallel_samples: int = 100,
        dropout_rate: float = 0.1,
        use_feat_dynamic_real: bool = False,
        use_feat_static_cat: bool = False,
        use_feat_static_real: bool = False,
        cardinality: Optional[List[int]] = None,
        embedding_dimension: Optional[List[int]] = None,
        distr_output: Optional[DistributionOutput] = None,
        rank: Optional[int] = 5,
        scaling: bool = True,
        pick_incomplete: bool = False,
        lags_seq: Optional[List[int]] = None,
        time_features: Optional[List[TimeFeature]] = None,
        conditioning_length: int = 200,
        use_marginal_transformation=False,
        **kwargs,
    ) -> None:
        """Store the model configuration and derive defaults.

        Args:
            input_size: Forwarded unchanged to both DeepVAR networks
                (presumably the per-step RNN input width, which must match
                the features built by ``create_transformation`` -- confirm
                in deepvar_network).
            freq: Pandas-style frequency string; used to derive the default
                lags and time features and passed to the predictor.
            prediction_length: Number of future time steps to forecast.
            target_dim: Dimension of the multivariate target.
            trainer: Training configuration. ``None`` (the default) creates
                a fresh ``Trainer()`` for this estimator.
            context_length: Number of past steps the network conditions on;
                defaults to ``prediction_length``.
            num_layers: Number of RNN layers.
            num_cells: Number of RNN cells per layer.
            cell_type: RNN cell type name, e.g. ``"LSTM"``.
            num_parallel_samples: Forwarded to the prediction network.
            dropout_rate: Dropout rate forwarded to both networks.
            use_feat_dynamic_real: If False, ``feat_dynamic_real`` is removed;
                otherwise it is stacked into ``time_feat``.
            use_feat_static_cat: If False, ``feat_static_cat`` is replaced by
                ``[0]`` and ``cardinality`` falls back to ``[1]``.
            use_feat_static_real: If False, ``feat_static_real`` is removed
                and then re-added as ``[0.0]``.
            cardinality: Category counts of the static categorical features;
                only honoured when ``use_feat_static_cat`` is True.
            embedding_dimension: Embedding size per categorical feature;
                defaults to ``min(50, (cat + 1) // 2)`` for each cardinality.
            distr_output: Output distribution; defaults to
                ``LowRankMultivariateNormalOutput(dim=target_dim, rank=rank)``.
            rank: Rank of the default low-rank distribution; ignored when
                ``distr_output`` is given.
            scaling: Forwarded to both networks.
            pick_incomplete: If True, the training/validation samplers may
                pick windows with less than ``history_length`` of past data.
            lags_seq: Lag indices; defaults to the Fourier-feature lags
                derived from ``freq``.
            time_features: Time features; defaults to the Fourier time
                features derived from ``freq``.
            conditioning_length: ``max_context_length`` of the
                ``CDFtoGaussianTransform`` used by the marginal transformation.
            use_marginal_transformation: If truthy, targets go through
                ``CDFtoGaussianTransform`` when instances are split, and
                forecasts are mapped back with
                ``cdf_to_gaussian_forward_transform``.
            **kwargs: Forwarded to ``PyTorchEstimator.__init__``.
        """
        # Build a fresh Trainer per estimator. A ``trainer=Trainer()`` default
        # would be evaluated once, when the class is defined, and that single
        # object would be shared by every estimator relying on the default.
        super().__init__(
            trainer=trainer if trainer is not None else Trainer(), **kwargs
        )

        self.freq = freq
        self.context_length = (
            context_length if context_length is not None else prediction_length
        )

        if distr_output is not None:
            self.distr_output = distr_output
        else:
            self.distr_output = LowRankMultivariateNormalOutput(
                dim=target_dim, rank=rank
            )

        self.input_size = input_size
        self.prediction_length = prediction_length
        self.target_dim = target_dim
        self.num_layers = num_layers
        self.num_cells = num_cells
        self.cell_type = cell_type
        self.num_parallel_samples = num_parallel_samples
        self.dropout_rate = dropout_rate
        self.use_feat_dynamic_real = use_feat_dynamic_real
        self.use_feat_static_cat = use_feat_static_cat
        self.use_feat_static_real = use_feat_static_real
        # A single dummy category when static categorical features are
        # unused; it matches the constant [0] written by create_transformation.
        self.cardinality = cardinality if cardinality and use_feat_static_cat else [1]
        self.embedding_dimension = (
            embedding_dimension
            if embedding_dimension is not None
            else [min(50, (cat + 1) // 2) for cat in self.cardinality]
        )
        self.conditioning_length = conditioning_length
        self.use_marginal_transformation = use_marginal_transformation

        self.lags_seq = (
            lags_seq
            if lags_seq is not None
            else lags_for_fourier_time_features_from_frequency(freq_str=freq)
        )
        self.time_features = (
            time_features
            if time_features is not None
            else fourier_time_features_from_frequency(self.freq)
        )

        # Extra past so that the largest lag is available for every step of
        # the context window.
        self.history_length = self.context_length + max(self.lags_seq)
        self.pick_incomplete = pick_incomplete
        self.scaling = scaling

        if self.use_marginal_transformation:
            self.output_transform: Optional[
                Callable
            ] = cdf_to_gaussian_forward_transform
        else:
            self.output_transform = None

        # Unless incomplete windows are allowed, only sample split points
        # that have a full history_length of past data behind them.
        self.train_sampler = ExpectedNumInstanceSampler(
            num_instances=1.0,
            min_past=0 if pick_incomplete else self.history_length,
            min_future=prediction_length,
        )
        self.validation_sampler = ValidationSplitSampler(
            min_past=0 if pick_incomplete else self.history_length,
            min_future=prediction_length,
        )

    def create_transformation(self) -> Transformation:
        """Return the feature-preparation chain applied to each dataset entry.

        The chain does the following, in order:
        - drops unused fields;
        - fills placeholder static features;
        - converts the target to a numpy array (expanding univariate data to
          2-D);
        - adds the observed-values indicator, time features and a log-scaled
          age feature, and stacks them (plus any dynamic real features) into
          ``time_feat``;
        - adds ``target_dimension_indicator`` and casts the static features
          to numpy.
        """
        remove_field_names = [FieldName.FEAT_DYNAMIC_CAT]
        if not self.use_feat_dynamic_real:
            remove_field_names.append(FieldName.FEAT_DYNAMIC_REAL)
        if not self.use_feat_static_real:
            remove_field_names.append(FieldName.FEAT_STATIC_REAL)

        return Chain(
            [RemoveFields(field_names=remove_field_names)]
            + (
                [SetField(output_field=FieldName.FEAT_STATIC_CAT, value=[0])]
                if not self.use_feat_static_cat
                else []
            )
            + (
                [SetField(output_field=FieldName.FEAT_STATIC_REAL, value=[0.0])]
                if not self.use_feat_static_real
                else []
            )
            + [
                AsNumpyArray(
                    field=FieldName.TARGET,
                    expected_ndim=1 + len(self.distr_output.event_shape),
                ),
                # maps the target to (1, T)
                # if the target data is uni dimensional
                ExpandDimArray(
                    field=FieldName.TARGET,
                    axis=0 if self.distr_output.event_shape[0] == 1 else None,
                ),
                AddObservedValuesIndicator(
                    target_field=FieldName.TARGET,
                    output_field=FieldName.OBSERVED_VALUES,
                ),
                AddTimeFeatures(
                    start_field=FieldName.START,
                    target_field=FieldName.TARGET,
                    output_field=FieldName.FEAT_TIME,
                    time_features=self.time_features,
                    pred_length=self.prediction_length,
                ),
                AddAgeFeature(
                    target_field=FieldName.TARGET,
                    output_field=FieldName.FEAT_AGE,
                    pred_length=self.prediction_length,
                    log_scale=True,
                ),
                VstackFeatures(
                    output_field=FieldName.FEAT_TIME,
                    input_fields=[FieldName.FEAT_TIME, FieldName.FEAT_AGE]
                    + (
                        [FieldName.FEAT_DYNAMIC_REAL]
                        if self.use_feat_dynamic_real
                        else []
                    ),
                ),
                TargetDimIndicator(
                    field_name="target_dimension_indicator",
                    target_field=FieldName.TARGET,
                ),
                # np.long was removed in NumPy 1.24 (and re-added in 2.0 as
                # the platform C long); use an explicit 64-bit integer dtype.
                AsNumpyArray(
                    field=FieldName.FEAT_STATIC_CAT, expected_ndim=1, dtype=np.int64
                ),
                AsNumpyArray(field=FieldName.FEAT_STATIC_REAL, expected_ndim=1),
            ]
        )

    def create_instance_splitter(self, mode: str):
        """Return the splitter that cuts entries into past/future windows.

        Args:
            mode: One of ``"training"``, ``"validation"`` or ``"test"``; it
                selects the instance sampler.

        Returns:
            An ``InstanceSplitter`` with ``past_length=history_length`` and
            ``future_length=prediction_length``. It is followed either by
            ``CDFtoGaussianTransform`` (marginal transformation on) or by a
            rename of the split target fields to their ``*_cdf`` names.
        """
        assert mode in ["training", "validation", "test"]

        instance_sampler = {
            "training": self.train_sampler,
            "validation": self.validation_sampler,
            "test": TestSplitSampler(),
        }[mode]

        return InstanceSplitter(
            target_field=FieldName.TARGET,
            is_pad_field=FieldName.IS_PAD,
            start_field=FieldName.START,
            forecast_start_field=FieldName.FORECAST_START,
            instance_sampler=instance_sampler,
            past_length=self.history_length,
            future_length=self.prediction_length,
            time_series_fields=[
                FieldName.FEAT_TIME,
                FieldName.OBSERVED_VALUES,
            ],
        ) + (
            CDFtoGaussianTransform(
                target_field=FieldName.TARGET,
                observed_values_field=FieldName.OBSERVED_VALUES,
                max_context_length=self.conditioning_length,
                target_dim=self.target_dim,
            )
            if self.use_marginal_transformation
            # Presumably the network reads the *_cdf field names either way,
            # so the raw split targets are renamed to match -- confirm in
            # deepvar_network.
            else RenameFields(
                {
                    f"past_{FieldName.TARGET}": f"past_{FieldName.TARGET}_cdf",
                    f"future_{FieldName.TARGET}": f"future_{FieldName.TARGET}_cdf",
                }
            )
        )

    def create_training_network(self, device: torch.device) -> DeepVARTrainingNetwork:
        """Build a ``DeepVARTrainingNetwork`` from this configuration on ``device``."""
        return DeepVARTrainingNetwork(
            input_size=self.input_size,
            target_dim=self.target_dim,
            num_layers=self.num_layers,
            num_cells=self.num_cells,
            cell_type=self.cell_type,
            history_length=self.history_length,
            context_length=self.context_length,
            prediction_length=self.prediction_length,
            distr_output=self.distr_output,
            dropout_rate=self.dropout_rate,
            cardinality=self.cardinality,
            embedding_dimension=self.embedding_dimension,
            lags_seq=self.lags_seq,
            scaling=self.scaling,
        ).to(device)

    def create_predictor(
        self,
        transformation: Transformation,
        trained_network: DeepVARTrainingNetwork,
        device: torch.device,
    ) -> Predictor:
        """Wrap a trained network into a ``PyTorchPredictor``.

        A ``DeepVARPredictionNetwork`` with the same configuration is built
        on ``device``, and the parameters of ``trained_network`` are copied
        into it. The predictor's input pipeline is ``transformation``
        followed by the ``"test"`` instance splitter.

        Args:
            transformation: Chain from ``create_transformation``.
            trained_network: The fitted training network.
            device: Device for the prediction network and predictor.
        """
        prediction_network = DeepVARPredictionNetwork(
            input_size=self.input_size,
            target_dim=self.target_dim,
            num_parallel_samples=self.num_parallel_samples,
            num_layers=self.num_layers,
            num_cells=self.num_cells,
            cell_type=self.cell_type,
            history_length=self.history_length,
            context_length=self.context_length,
            prediction_length=self.prediction_length,
            distr_output=self.distr_output,
            dropout_rate=self.dropout_rate,
            cardinality=self.cardinality,
            embedding_dimension=self.embedding_dimension,
            lags_seq=self.lags_seq,
            scaling=self.scaling,
        ).to(device)

        copy_parameters(trained_network, prediction_network)
        input_names = get_module_forward_input_names(prediction_network)
        prediction_splitter = self.create_instance_splitter("test")

        return PyTorchPredictor(
            input_transform=transformation + prediction_splitter,
            input_names=input_names,
            prediction_net=prediction_network,
            batch_size=self.trainer.batch_size,
            freq=self.freq,
            prediction_length=self.prediction_length,
            device=device,
            output_transform=self.output_transform,
        )