# Mirror of https://github.com/wassname/pytorch-transformer-ts.git
# (synced 2026-06-27 18:06:14 +08:00; 388 lines, 14 KiB, Python)
from typing import List, Optional, Iterable, Dict, Any
|
|
|
|
import torch
|
|
from torch.utils.data import DataLoader
|
|
|
|
from gluonts.core.component import validated
|
|
from gluonts.dataset.common import Dataset
|
|
from gluonts.dataset.field_names import FieldName
|
|
from gluonts.itertools import Cyclic, PseudoShuffled, IterableSlice
|
|
from gluonts.time_feature import (
|
|
TimeFeature,
|
|
time_features_from_frequency_str,
|
|
)
|
|
from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood
|
|
from gluonts.transform import (
|
|
Transformation,
|
|
Chain,
|
|
RemoveFields,
|
|
SetField,
|
|
AsNumpyArray,
|
|
AddObservedValuesIndicator,
|
|
AddTimeFeatures,
|
|
AddAgeFeature,
|
|
VstackFeatures,
|
|
InstanceSplitter,
|
|
ValidationSplitSampler,
|
|
TestSplitSampler,
|
|
ExpectedNumInstanceSampler,
|
|
SelectFields,
|
|
)
|
|
from gluonts.torch.util import (
|
|
IterableDataset,
|
|
)
|
|
from gluonts.torch.model.estimator import PyTorchLightningEstimator
|
|
from gluonts.torch.model.predictor import PyTorchPredictor
|
|
from gluonts.torch.distributions import (
|
|
DistributionOutput,
|
|
StudentTOutput,
|
|
)
|
|
from gluonts.transform.sampler import InstanceSampler
|
|
|
|
from module import PerceiverARModel
|
|
from lightning_module import PerceiverARLightningModule
|
|
|
|
# Names of the batch fields handed to the prediction network at inference
# time (see ``PyTorchPredictor(input_names=...)`` in the estimator below).
PREDICTION_INPUT_NAMES = [
    "feat_static_cat",
    "feat_static_real",
    "past_time_feat",
    "past_target",
    "past_observed_values",
    "future_time_feat",
]

# Training batches carry everything the predictor sees plus the future
# window and its observation mask (presumably consumed by the loss).
# Built as a fresh list so the two constants never alias each other.
TRAINING_INPUT_NAMES = [
    *PREDICTION_INPUT_NAMES,
    "future_target",
    "future_observed_values",
]
|
|
|
|
|
|
class PerceiverAREstimator(PyTorchLightningEstimator):
    """
    Estimator class to train a PerceiverAR model.

    This class uses the model defined in ``PerceiverARModel`` and wraps it
    into a ``PerceiverARLightningModule`` for training purposes: training is
    performed using PyTorch Lightning's ``pl.Trainer`` class.

    Parameters
    ----------
    freq
        Frequency of the data to train on and predict.
    prediction_length
        Length of the prediction horizon.
    depth
        Passed to ``PerceiverARModel`` as ``depth`` (required).
    context_length
        Number of past time steps the model conditions on before computing
        predictions (default: None, in which case
        context_length = prediction_length).
    input_size
        Passed to ``PerceiverARModel`` as ``input_size`` (default: 1).
    perceive_depth
        Passed to ``PerceiverARModel`` as ``perceive_depth`` (default: 1).
    heads
        Number of heads, passed to ``PerceiverARModel`` (default: 2).
    hidden_size
        Hidden size passed to ``PerceiverARModel`` (default: 32).
    dropout_rate
        Dropout regularization parameter (default: 0.1).
    cross_attn_dropout
        Passed to ``PerceiverARModel`` as ``cross_attn_dropout``
        (default: 0.1).
    perceive_max_heads_process
        Passed to ``PerceiverARModel`` as ``perceive_max_heads_process``
        (default: 2).
    ff_mult
        Passed to ``PerceiverARModel`` as ``ff_mult`` (default: 1).
    num_feat_dynamic_real
        Number of dynamic real features in the data (default: 0).
    num_feat_static_real
        Number of static real features in the data (default: 0).
    num_feat_static_cat
        Number of static categorical features in the data (default: 0).
    cardinality
        Number of values of each categorical feature.
        This must be set if ``num_feat_static_cat > 0`` (default: None;
        replaced by ``[1]`` when no static categorical features are used).
    embedding_dimension
        Dimension of the embeddings for categorical features (default:
        None, forwarded unchanged to ``PerceiverARModel``, which presumably
        uses ``[min(50, (cat+1)//2) for cat in cardinality]``).
    distr_output
        Distribution to use to evaluate observations and sample predictions
        (default: StudentTOutput()).
    loss
        Loss to be optimized during training
        (default: ``NegativeLogLikelihood()``).
    scaling
        Whether to automatically scale the target values (default: true).
    lags_seq
        Indices of the lagged target values to use as model inputs
        (default: None, forwarded unchanged to ``PerceiverARModel``, which
        presumably derives them from freq).
    time_features
        List of time features, from :py:mod:`gluonts.time_feature`, to use as
        model inputs in addition to the provided data (default: None, in
        which case they are derived from freq via
        ``time_features_from_frequency_str``).
    num_parallel_samples
        Number of samples per time series that the resulting predictor
        should produce (default: 100).
    batch_size
        The size of the batches to be used for training (default: 32).
    num_batches_per_epoch
        Number of batches to be processed in each training epoch
        (default: 50).
    trainer_kwargs
        Additional arguments to provide to ``pl.Trainer`` for construction.
        They override the defaults ``max_epochs=100`` and
        ``gradient_clip_val=10.0`` key by key.
    train_sampler
        Controls the sampling of windows during training (default:
        ``ExpectedNumInstanceSampler(num_instances=1.0,
        min_future=prediction_length)``).
    validation_sampler
        Controls the sampling of windows during validation (default:
        ``ValidationSplitSampler(min_future=prediction_length)``).
    """

    @validated()
    def __init__(
        self,
        freq: str,
        prediction_length: int,
        depth: int,
        context_length: Optional[int] = None,
        input_size: int = 1,
        perceive_depth: int = 1,
        heads: int = 2,
        hidden_size: int = 32,
        dropout_rate: float = 0.1,
        cross_attn_dropout: float = 0.1,
        perceive_max_heads_process: int = 2,
        ff_mult: int = 1,
        num_feat_dynamic_real: int = 0,
        num_feat_static_cat: int = 0,
        num_feat_static_real: int = 0,
        cardinality: Optional[List[int]] = None,
        embedding_dimension: Optional[List[int]] = None,
        # NOTE(review): these two default instances are created once, when
        # the class is defined, and shared by every estimator that does not
        # override them; this is only safe if they hold no per-run state.
        distr_output: DistributionOutput = StudentTOutput(),
        loss: DistributionLoss = NegativeLogLikelihood(),
        scaling: bool = True,
        lags_seq: Optional[List[int]] = None,
        time_features: Optional[List[TimeFeature]] = None,
        num_parallel_samples: int = 100,
        batch_size: int = 32,
        num_batches_per_epoch: int = 50,
        trainer_kwargs: Optional[Dict[str, Any]] = None,
        train_sampler: Optional[InstanceSampler] = None,
        validation_sampler: Optional[InstanceSampler] = None,
    ) -> None:
        # Caller-supplied trainer kwargs override these defaults key by key.
        default_trainer_kwargs = {
            "max_epochs": 100,
            "gradient_clip_val": 10.0,
        }
        if trainer_kwargs is not None:
            default_trainer_kwargs.update(trainer_kwargs)
        super().__init__(trainer_kwargs=default_trainer_kwargs)

        self.input_size = input_size
        self.freq = freq
        self.context_length = (
            context_length if context_length is not None else prediction_length
        )
        self.prediction_length = prediction_length
        self.distr_output = distr_output
        self.loss = loss
        self.depth = depth
        self.perceive_depth = perceive_depth
        self.hidden_size = hidden_size
        self.dropout_rate = dropout_rate
        self.heads = heads
        self.perceive_max_heads_process = perceive_max_heads_process
        self.ff_mult = ff_mult
        self.cross_attn_dropout = cross_attn_dropout
        self.num_feat_dynamic_real = num_feat_dynamic_real
        self.num_feat_static_cat = num_feat_static_cat
        self.num_feat_static_real = num_feat_static_real
        # Without categorical features the transformation injects a single
        # constant category (0), so the model sees a cardinality of [1].
        self.cardinality = (
            cardinality if cardinality and num_feat_static_cat > 0 else [1]
        )
        self.embedding_dimension = embedding_dimension
        self.scaling = scaling
        self.lags_seq = lags_seq
        self.time_features = (
            time_features
            if time_features is not None
            else time_features_from_frequency_str(self.freq)
        )

        self.num_parallel_samples = num_parallel_samples
        self.batch_size = batch_size
        self.num_batches_per_epoch = num_batches_per_epoch

        self.train_sampler = train_sampler or ExpectedNumInstanceSampler(
            num_instances=1.0, min_future=prediction_length
        )
        self.validation_sampler = validation_sampler or ValidationSplitSampler(
            min_future=prediction_length
        )

    def create_transformation(self) -> Transformation:
        """
        Build the per-entry transformation applied before instance splitting.

        Unused real-valued feature fields are removed. Constant placeholder
        static features are inserted when none are configured. The static
        features and the target are converted to numpy arrays. The chain then
        adds the observed-values indicator, the time features and a
        log-scaled age feature, and finally stacks time, age and (if
        configured) dynamic real features into ``FieldName.FEAT_TIME``.
        """
        remove_field_names = []
        if self.num_feat_static_real == 0:
            remove_field_names.append(FieldName.FEAT_STATIC_REAL)
        if self.num_feat_dynamic_real == 0:
            remove_field_names.append(FieldName.FEAT_DYNAMIC_REAL)

        return Chain(
            [RemoveFields(field_names=remove_field_names)]
            # Placeholder static features: the model is always built with at
            # least one static feature of each kind (see the ``max(1, ...)``
            # in ``create_lightning_module``), so these fields must exist.
            + (
                [SetField(output_field=FieldName.FEAT_STATIC_CAT, value=[0])]
                if not self.num_feat_static_cat > 0
                else []
            )
            + (
                [SetField(output_field=FieldName.FEAT_STATIC_REAL, value=[0.0])]
                if not self.num_feat_static_real > 0
                else []
            )
            + [
                AsNumpyArray(
                    field=FieldName.FEAT_STATIC_CAT,
                    expected_ndim=1,
                    dtype=int,
                ),
                AsNumpyArray(
                    field=FieldName.FEAT_STATIC_REAL,
                    expected_ndim=1,
                ),
                AsNumpyArray(
                    field=FieldName.TARGET,
                    # in the following line, we add 1 for the time dimension
                    expected_ndim=1 + len(self.distr_output.event_shape),
                ),
                AddObservedValuesIndicator(
                    target_field=FieldName.TARGET,
                    output_field=FieldName.OBSERVED_VALUES,
                ),
                AddTimeFeatures(
                    start_field=FieldName.START,
                    target_field=FieldName.TARGET,
                    output_field=FieldName.FEAT_TIME,
                    time_features=self.time_features,
                    pred_length=self.prediction_length,
                ),
                AddAgeFeature(
                    target_field=FieldName.TARGET,
                    output_field=FieldName.FEAT_AGE,
                    pred_length=self.prediction_length,
                    log_scale=True,
                ),
                VstackFeatures(
                    output_field=FieldName.FEAT_TIME,
                    input_fields=[FieldName.FEAT_TIME, FieldName.FEAT_AGE]
                    + (
                        [FieldName.FEAT_DYNAMIC_REAL]
                        if self.num_feat_dynamic_real > 0
                        else []
                    ),
                ),
            ]
        )

    def _create_instance_splitter(
        self, module: PerceiverARLightningModule, mode: str
    ) -> InstanceSplitter:
        """
        Build the splitter that cuts past/future windows out of each series.

        Parameters
        ----------
        module
            Lightning module whose ``model._past_length`` sets the length of
            the past window.
        mode
            One of ``"training"``, ``"validation"`` or ``"test"``; selects
            the instance sampler.

        Raises
        ------
        ValueError
            If ``mode`` is not one of the supported values.
        """
        instance_samplers = {
            "training": self.train_sampler,
            "validation": self.validation_sampler,
            "test": TestSplitSampler(),
        }
        # Explicit check instead of ``assert``: asserts vanish under ``-O``.
        if mode not in instance_samplers:
            raise ValueError(
                f"mode must be one of {sorted(instance_samplers)}, got {mode!r}"
            )

        return InstanceSplitter(
            target_field=FieldName.TARGET,
            is_pad_field=FieldName.IS_PAD,
            start_field=FieldName.START,
            forecast_start_field=FieldName.FORECAST_START,
            instance_sampler=instance_samplers[mode],
            past_length=module.model._past_length,
            future_length=self.prediction_length,
            # These fields are sliced alongside the target.
            time_series_fields=[
                FieldName.FEAT_TIME,
                FieldName.OBSERVED_VALUES,
            ],
            # Padding value chosen to be valid for the output distribution.
            dummy_value=self.distr_output.value_in_support,
        )

    def create_training_data_loader(
        self,
        data: Dataset,
        module: PerceiverARLightningModule,
        shuffle_buffer_length: Optional[int] = None,
        **kwargs,
    ) -> Iterable:
        """
        Return an iterable of training batches of ``TRAINING_INPUT_NAMES``.

        ``data`` is cycled; if ``shuffle_buffer_length`` is given it is also
        pseudo-shuffled with a buffer of that size. Each pass over the
        returned iterable yields at most ``num_batches_per_epoch`` batches.
        Extra ``kwargs`` are forwarded to ``torch.utils.data.DataLoader``.
        """
        transformation = self._create_instance_splitter(
            module, "training"
        ) + SelectFields(TRAINING_INPUT_NAMES)

        stream = Cyclic(data)
        if shuffle_buffer_length is not None:
            stream = PseudoShuffled(stream, shuffle_buffer_length=shuffle_buffer_length)
        training_instances = transformation.apply(stream)

        # ``iter`` is called once here, so presumably successive epochs keep
        # drawing from the same (cyclic) batch stream instead of restarting.
        return IterableSlice(
            iter(
                DataLoader(
                    IterableDataset(training_instances),
                    batch_size=self.batch_size,
                    **kwargs,
                )
            ),
            self.num_batches_per_epoch,
        )

    def create_validation_data_loader(
        self,
        data: Dataset,
        module: PerceiverARLightningModule,
        **kwargs,
    ) -> Iterable:
        """
        Return a ``DataLoader`` of validation batches of ``TRAINING_INPUT_NAMES``.

        Windows are drawn with ``self.validation_sampler``. Extra ``kwargs``
        are forwarded to ``torch.utils.data.DataLoader``.
        """
        transformation = self._create_instance_splitter(
            module, "validation"
        ) + SelectFields(TRAINING_INPUT_NAMES)

        validation_instances = transformation.apply(data)

        return DataLoader(
            IterableDataset(validation_instances),
            batch_size=self.batch_size,
            **kwargs,
        )

    def create_lightning_module(self) -> PerceiverARLightningModule:
        """Instantiate ``PerceiverARModel`` and wrap it with ``self.loss``."""
        model = PerceiverARModel(
            input_size=self.input_size,
            freq=self.freq,
            depth=self.depth,
            context_length=self.context_length,
            prediction_length=self.prediction_length,
            # +1 for the age feature that create_transformation stacks into
            # FEAT_TIME together with the time features and dynamic reals.
            num_feat_dynamic_real=(
                1 + self.num_feat_dynamic_real + len(self.time_features)
            ),
            # At least 1: a placeholder static feature is injected when none
            # are configured (see create_transformation).
            num_feat_static_real=max(1, self.num_feat_static_real),
            num_feat_static_cat=max(1, self.num_feat_static_cat),
            cardinality=self.cardinality,
            embedding_dimension=self.embedding_dimension,
            perceive_depth=self.perceive_depth,
            heads=self.heads,
            perceive_max_heads_process=self.perceive_max_heads_process,
            ff_mult=self.ff_mult,
            cross_attn_dropout=self.cross_attn_dropout,
            hidden_size=self.hidden_size,
            distr_output=self.distr_output,
            dropout_rate=self.dropout_rate,
            lags_seq=self.lags_seq,
            scaling=self.scaling,
            num_parallel_samples=self.num_parallel_samples,
        )

        return PerceiverARLightningModule(model=model, loss=self.loss)

    def create_predictor(
        self,
        transformation: Transformation,
        module: PerceiverARLightningModule,
    ) -> PyTorchPredictor:
        """
        Build a ``PyTorchPredictor`` around ``module.model``.

        The input pipeline is ``transformation`` followed by the test-mode
        instance splitter. Inference runs on CUDA when it is available,
        otherwise on CPU.
        """
        prediction_splitter = self._create_instance_splitter(module, "test")

        return PyTorchPredictor(
            input_transform=transformation + prediction_splitter,
            input_names=PREDICTION_INPUT_NAMES,
            prediction_net=module.model,
            batch_size=self.batch_size,
            prediction_length=self.prediction_length,
            device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        )
|