# Mirror of https://github.com/wassname/pytorch-transformer-ts.git
# (synced 2026-06-27 16:46:28 +08:00; 326 lines, 11 KiB, Python)
# +
|
|
from typing import Any, Dict, Iterable, List, Optional
|
|
|
|
import torch
|
|
from torch.utils.data import DataLoader
|
|
|
|
import numpy as np
|
|
|
|
from gluonts.core.component import validated
|
|
from gluonts.dataset.common import Dataset
|
|
from gluonts.dataset.field_names import FieldName
|
|
from gluonts.itertools import Cyclic, IterableSlice, PseudoShuffled
|
|
from gluonts.time_feature import TimeFeature, time_features_from_frequency_str
|
|
from gluonts.torch.model.estimator import PyTorchLightningEstimator
|
|
from gluonts.torch.model.predictor import PyTorchPredictor
|
|
from gluonts.torch.distributions import DistributionOutput, StudentTOutput
|
|
from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood
|
|
from gluonts.torch.util import IterableDataset
|
|
from gluonts.transform import (
|
|
AddAgeFeature,
|
|
AddObservedValuesIndicator,
|
|
AddTimeFeatures,
|
|
AsNumpyArray,
|
|
Chain,
|
|
ExpectedNumInstanceSampler,
|
|
InstanceSplitter,
|
|
RemoveFields,
|
|
SelectFields,
|
|
SetField,
|
|
TestSplitSampler,
|
|
Transformation,
|
|
ValidationSplitSampler,
|
|
VstackFeatures,
|
|
)
|
|
from gluonts.transform.sampler import InstanceSampler
|
|
|
|
from lightning_module import XformerLightningModule
|
|
from module import XformerModel
|
|
|
|
# +
|
|
# Batch field names handed to ``PyTorchPredictor`` as ``input_names`` when the
# estimator builds a predictor.
PREDICTION_INPUT_NAMES = [
    "feat_static_cat",
    "feat_static_real",
    "past_time_feat",
    "past_target",
    "past_observed_values",
    "future_time_feat",
]

# Fields kept by ``SelectFields`` in the training and validation data loaders:
# every prediction input plus the future target and its observed-values mask.
TRAINING_INPUT_NAMES = [
    *PREDICTION_INPUT_NAMES,
    "future_target",
    "future_observed_values",
]
|
|
|
|
|
|
# -
|
|
|
|
|
|
class XformerEstimator(PyTorchLightningEstimator):
    """GluonTS estimator that trains an ``XformerModel`` via PyTorch Lightning.

    The estimator owns the data pipeline (feature transformation, instance
    splitting, data loaders) and builds the ``XformerLightningModule`` and the
    ``PyTorchPredictor`` around the underlying ``XformerModel``.
    """

    @validated()
    def __init__(
        self,
        freq: str,
        prediction_length: int,
        # Xformer arguments
        nhead: int,
        num_encoder_layers: int,
        num_decoder_layers: int,
        hidden_layer_multiplier: int = 1,
        attention_args: Optional[Dict[str, Any]] = None,
        input_size: int = 1,
        activation: str = "gelu",
        residual_norm_style: str = "pre",
        dropout: float = 0.1,
        use_rotary_embeddings=False,
        reversible=False,
        context_length: Optional[int] = None,
        num_feat_dynamic_real: int = 0,
        num_feat_static_cat: int = 0,
        num_feat_static_real: int = 0,
        cardinality: Optional[List[int]] = None,
        embedding_dimension: Optional[List[int]] = None,
        distr_output: DistributionOutput = StudentTOutput(),
        loss: DistributionLoss = NegativeLogLikelihood(),
        scaling: bool = True,
        lags_seq: Optional[List[int]] = None,
        time_features: Optional[List[TimeFeature]] = None,
        num_parallel_samples: int = 100,
        batch_size: int = 32,
        num_batches_per_epoch: int = 50,
        trainer_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Store the configuration and set up the instance samplers.

        Parameters
        ----------
        freq
            Frequency string; used to derive default time features and
            forwarded to ``XformerModel``.
        prediction_length
            Length of the forecast horizon; also the ``min_future`` of the
            training and validation samplers.
        nhead, num_encoder_layers, num_decoder_layers,
        hidden_layer_multiplier, input_size, activation,
        residual_norm_style, dropout, use_rotary_embeddings, reversible
            Forwarded unchanged to ``XformerModel``.
        attention_args
            Attention configuration forwarded to ``XformerModel``. ``None``
            (the default) means ``{"name": "scaled_dot_product"}``.
        context_length
            Forwarded to ``XformerModel``; falls back to ``prediction_length``
            when ``None``.
        num_feat_dynamic_real, num_feat_static_cat, num_feat_static_real
            Number of features of each kind present in the dataset. A count
            of zero makes the transformation drop or replace that field with
            a dummy value.
        cardinality
            Used only when it is non-empty and ``num_feat_static_cat > 0``;
            otherwise ``[1]`` is used.
        embedding_dimension, lags_seq, scaling, num_parallel_samples
            Forwarded unchanged to ``XformerModel``.
        distr_output
            Output distribution; forwarded to the model and also used to size
            the target array and to pick the padding value.
        loss
            Loss handed to ``XformerLightningModule``.
        time_features
            Time features to add; defaults to
            ``time_features_from_frequency_str(freq)``.
        batch_size
            Batch size for all data loaders and the predictor.
        num_batches_per_epoch
            Number of training batches drawn per epoch.
        trainer_kwargs
            Extra keyword arguments for the Lightning trainer, merged over
            the default ``{"max_epochs": 100}``. ``None`` means no overrides.
        """
        # ``None`` is an allowed value per the annotation, so guard the
        # unpacking instead of crashing with ``TypeError``.
        trainer_kwargs = {
            "max_epochs": 100,
            **(trainer_kwargs or {}),
        }
        super().__init__(trainer_kwargs=trainer_kwargs)

        self.freq = freq
        self.context_length = (
            context_length if context_length is not None else prediction_length
        )
        self.prediction_length = prediction_length
        self.distr_output = distr_output
        self.loss = loss

        self.input_size = input_size
        self.nhead = nhead
        self.num_encoder_layers = num_encoder_layers
        self.num_decoder_layers = num_decoder_layers
        self.activation = activation
        self.dropout = dropout
        # Build a fresh dict per instance rather than sharing one default
        # object between every estimator.
        self.attention_args = (
            attention_args
            if attention_args is not None
            else {"name": "scaled_dot_product"}
        )
        self.use_rotary_embeddings = use_rotary_embeddings
        self.reversible = reversible
        self.hidden_layer_multiplier = hidden_layer_multiplier
        self.residual_norm_style = residual_norm_style

        self.num_feat_dynamic_real = num_feat_dynamic_real
        self.num_feat_static_cat = num_feat_static_cat
        self.num_feat_static_real = num_feat_static_real
        self.cardinality = (
            cardinality if cardinality and num_feat_static_cat > 0 else [1]
        )
        self.embedding_dimension = embedding_dimension
        self.scaling = scaling
        self.lags_seq = lags_seq
        self.time_features = (
            time_features
            if time_features is not None
            else time_features_from_frequency_str(self.freq)
        )

        self.num_parallel_samples = num_parallel_samples
        self.batch_size = batch_size
        self.num_batches_per_epoch = num_batches_per_epoch

        self.train_sampler = ExpectedNumInstanceSampler(
            num_instances=1.0, min_future=prediction_length
        )
        self.validation_sampler = ValidationSplitSampler(
            min_future=prediction_length
        )

    def create_transformation(self) -> Transformation:
        """Build the per-series feature transformation chain.

        Unused static-real / dynamic-real fields are removed, missing static
        features are replaced by single-element dummies, the relevant fields
        are converted to NumPy arrays, and the observed-values indicator,
        time features and age feature are added. Finally the time features,
        the age feature and (if configured) the dynamic real features are
        stacked into ``FieldName.FEAT_TIME``.
        """
        remove_field_names = []
        if self.num_feat_static_real == 0:
            remove_field_names.append(FieldName.FEAT_STATIC_REAL)
        if self.num_feat_dynamic_real == 0:
            remove_field_names.append(FieldName.FEAT_DYNAMIC_REAL)

        return Chain(
            [RemoveFields(field_names=remove_field_names)]
            + (
                [SetField(output_field=FieldName.FEAT_STATIC_CAT, value=[0])]
                if not self.num_feat_static_cat > 0
                else []
            )
            + (
                [SetField(output_field=FieldName.FEAT_STATIC_REAL, value=[0.0])]
                if not self.num_feat_static_real > 0
                else []
            )
            + [
                AsNumpyArray(
                    field=FieldName.FEAT_STATIC_CAT,
                    expected_ndim=1,
                    # ``np.long`` was only an alias of the builtin ``int`` and
                    # was removed in NumPy 1.24; use ``int`` directly.
                    dtype=int,
                ),
                AsNumpyArray(
                    field=FieldName.FEAT_STATIC_REAL,
                    expected_ndim=1,
                ),
                AsNumpyArray(
                    field=FieldName.TARGET,
                    # in the following line, we add 1 for the time dimension
                    expected_ndim=1 + len(self.distr_output.event_shape),
                ),
                AddObservedValuesIndicator(
                    target_field=FieldName.TARGET,
                    output_field=FieldName.OBSERVED_VALUES,
                ),
                AddTimeFeatures(
                    start_field=FieldName.START,
                    target_field=FieldName.TARGET,
                    output_field=FieldName.FEAT_TIME,
                    time_features=self.time_features,
                    pred_length=self.prediction_length,
                ),
                AddAgeFeature(
                    target_field=FieldName.TARGET,
                    output_field=FieldName.FEAT_AGE,
                    pred_length=self.prediction_length,
                    log_scale=True,
                ),
                VstackFeatures(
                    output_field=FieldName.FEAT_TIME,
                    input_fields=[FieldName.FEAT_TIME, FieldName.FEAT_AGE]
                    + (
                        [FieldName.FEAT_DYNAMIC_REAL]
                        if self.num_feat_dynamic_real > 0
                        else []
                    ),
                ),
            ]
        )

    def _create_instance_splitter(self, module: XformerLightningModule, mode: str):
        """Return the ``InstanceSplitter`` for the given ``mode``.

        ``mode`` must be ``"training"``, ``"validation"`` or ``"test"``; it
        selects the instance sampler. The past window length is read from
        ``module.model._past_length`` and the future window length is
        ``self.prediction_length``.
        """
        assert mode in ["training", "validation", "test"]

        instance_sampler = {
            "training": self.train_sampler,
            "validation": self.validation_sampler,
            "test": TestSplitSampler(),
        }[mode]

        return InstanceSplitter(
            target_field=FieldName.TARGET,
            is_pad_field=FieldName.IS_PAD,
            start_field=FieldName.START,
            forecast_start_field=FieldName.FORECAST_START,
            instance_sampler=instance_sampler,
            past_length=module.model._past_length,
            future_length=self.prediction_length,
            time_series_fields=[
                FieldName.FEAT_TIME,
                FieldName.OBSERVED_VALUES,
            ],
            dummy_value=self.distr_output.value_in_support,
        )

    def create_training_data_loader(
        self,
        data: Dataset,
        module: XformerLightningModule,
        shuffle_buffer_length: Optional[int] = None,
        **kwargs,
    ) -> Iterable:
        """Return an iterable yielding ``num_batches_per_epoch`` training batches.

        ``data`` is cycled through endlessly (optionally pseudo-shuffled with
        a buffer of ``shuffle_buffer_length``), split into training instances,
        reduced to ``TRAINING_INPUT_NAMES`` and batched by a ``DataLoader``.
        Extra ``kwargs`` are forwarded to the ``DataLoader``.
        """
        transformation = self._create_instance_splitter(
            module, "training"
        ) + SelectFields(TRAINING_INPUT_NAMES)

        training_instances = transformation.apply(
            Cyclic(data)
            if shuffle_buffer_length is None
            else PseudoShuffled(
                Cyclic(data), shuffle_buffer_length=shuffle_buffer_length
            )
        )

        return IterableSlice(
            iter(
                DataLoader(
                    IterableDataset(training_instances),
                    batch_size=self.batch_size,
                    **kwargs,
                )
            ),
            self.num_batches_per_epoch,
        )

    def create_validation_data_loader(
        self,
        data: Dataset,
        module: XformerLightningModule,
        **kwargs,
    ) -> Iterable:
        """Return a ``DataLoader`` over the validation instances of ``data``.

        Instances come from the validation splitter, are reduced to
        ``TRAINING_INPUT_NAMES`` and batched with ``self.batch_size``. Extra
        ``kwargs`` are forwarded to the ``DataLoader``.
        """
        transformation = self._create_instance_splitter(
            module, "validation"
        ) + SelectFields(TRAINING_INPUT_NAMES)

        validation_instances = transformation.apply(data)

        return DataLoader(
            IterableDataset(validation_instances),
            batch_size=self.batch_size,
            **kwargs,
        )

    def create_predictor(
        self,
        transformation: Transformation,
        module: XformerLightningModule,
    ) -> PyTorchPredictor:
        """Wrap the trained model in a ``PyTorchPredictor``.

        The predictor applies ``transformation`` followed by the test-mode
        instance splitter, feeds ``PREDICTION_INPUT_NAMES`` to
        ``module.model``, and runs on CUDA when it is available, else on CPU.
        """
        prediction_splitter = self._create_instance_splitter(module, "test")

        return PyTorchPredictor(
            input_transform=transformation + prediction_splitter,
            input_names=PREDICTION_INPUT_NAMES,
            prediction_net=module.model,
            batch_size=self.batch_size,
            prediction_length=self.prediction_length,
            device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        )

    def create_lightning_module(self) -> XformerLightningModule:
        """Instantiate ``XformerModel`` from the stored configuration and wrap
        it, together with ``self.loss``, in an ``XformerLightningModule``.
        """
        model = XformerModel(
            freq=self.freq,
            context_length=self.context_length,
            prediction_length=self.prediction_length,
            # The leading 1 presumably accounts for the single age feature
            # that create_transformation stacks next to the time features.
            num_feat_dynamic_real=1
            + self.num_feat_dynamic_real
            + len(self.time_features),
            # At least one static feature of each kind is always present,
            # because create_transformation inserts a dummy when none is
            # configured.
            num_feat_static_real=max(1, self.num_feat_static_real),
            num_feat_static_cat=max(1, self.num_feat_static_cat),
            cardinality=self.cardinality,
            embedding_dimension=self.embedding_dimension,
            # xformer arguments
            nhead=self.nhead,
            num_encoder_layers=self.num_encoder_layers,
            num_decoder_layers=self.num_decoder_layers,
            hidden_layer_multiplier=self.hidden_layer_multiplier,
            activation=self.activation,
            dropout=self.dropout,
            attention_args=self.attention_args,
            use_rotary_embeddings=self.use_rotary_embeddings,
            reversible=self.reversible,
            residual_norm_style=self.residual_norm_style,
            # univariate input
            input_size=self.input_size,
            distr_output=self.distr_output,
            lags_seq=self.lags_seq,
            scaling=self.scaling,
            num_parallel_samples=self.num_parallel_samples,
        )

        return XformerLightningModule(model=model, loss=self.loss)
|