# Mirror of https://github.com/wassname/pytorch-transformer-ts.git
# Synced 2026-06-27 16:46:28 +08:00
# 357 lines, 13 KiB, Python
from typing import Any, Dict, Iterable, List, Optional
|
|
|
|
import torch
|
|
from gluonts.core.component import validated
|
|
from gluonts.dataset.common import Dataset
|
|
from gluonts.dataset.field_names import FieldName
|
|
from gluonts.itertools import Cyclic, IterableSlice, PseudoShuffled
|
|
from gluonts.time_feature import TimeFeature, time_features_from_frequency_str
|
|
from gluonts.torch.model.estimator import PyTorchLightningEstimator
|
|
from gluonts.torch.model.predictor import PyTorchPredictor
|
|
from gluonts.torch.distributions import DistributionOutput, StudentTOutput
|
|
from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood
|
|
from gluonts.torch.util import IterableDataset
|
|
from gluonts.transform import (
|
|
AddAgeFeature,
|
|
AddObservedValuesIndicator,
|
|
AddTimeFeatures,
|
|
AsNumpyArray,
|
|
Chain,
|
|
ExpectedNumInstanceSampler,
|
|
InstanceSplitter,
|
|
RemoveFields,
|
|
SelectFields,
|
|
SetField,
|
|
TestSplitSampler,
|
|
Transformation,
|
|
ValidationSplitSampler,
|
|
VstackFeatures,
|
|
)
|
|
from gluonts.transform.sampler import InstanceSampler
|
|
from torch.utils.data import DataLoader
|
|
|
|
from lightning_module import FEDformerLightningModule
|
|
from module import FEDformerModel
|
|
import random
|
|
# Names passed as ``input_names`` to the predictor built by the estimator below.
PREDICTION_INPUT_NAMES = [
    "feat_static_cat",
    "feat_static_real",
    "past_time_feat",
    "past_target",
    "past_observed_values",
    "future_time_feat",
]

# Fields selected for the training and validation loaders: the prediction
# inputs followed by the two future-window fields.
TRAINING_INPUT_NAMES = [
    *PREDICTION_INPUT_NAMES,
    "future_target",
    "future_observed_values",
]
|
|
|
|
|
|
class FEDformerEstimator(PyTorchLightningEstimator):
    """Estimator that trains a ``FEDformerModel`` and wraps it in a predictor.

    The class provides the pieces the ``PyTorchLightningEstimator`` base class
    asks for: the dataset transformation chain, the training / validation data
    loaders, the Lightning module and the final ``PyTorchPredictor``.

    Parameters
    ----------
    freq
        Frequency string of the data. Used to derive the default time
        features and forwarded to the model.
    prediction_length
        Length of the forecast horizon. Also used as ``min_future`` of the
        default samplers and as the default ``context_length``.
    nhead, num_encoder_layers, num_decoder_layers, dim_feedforward
        Forwarded unchanged to ``FEDformerModel``.
    version
        Forwarded to the model (original note: ``Fourier`` or ``Wavelets``).
    features, target, checkpoints, dec_in, distil
        Accepted for signature compatibility only; this class neither stores
        nor forwards them.
    modes, mode_select, base, cross_activation, L
        Forwarded unchanged to the model (original note on ``mode_select``:
        ``random`` or ``low``).
    context_length
        Input sequence length; falls back to ``prediction_length`` when
        ``None``.
    label_length
        Forwarded to the model (original note: start token length).
    input_size, c_out
        Forwarded to the model (original notes: encoder input size and
        output size respectively).
    moving_avg, factor, activation, dropout, embed, output_attention
        Forwarded unchanged to the model. ``moving_avg`` is copied so the
        instance never aliases the shared default list.
    num_feat_dynamic_real, num_feat_static_cat, num_feat_static_real
        Number of features of each kind present in the dataset. A value of
        zero causes the matching field to be removed and/or replaced by a
        constant placeholder in :meth:`create_transformation`.
    cardinality
        Used only when it is non-empty and ``num_feat_static_cat > 0``;
        otherwise ``[1]`` is used.
    embedding_dimension, scaling, lags_seq, num_parallel_samples
        Forwarded unchanged to the model.
    distr_output
        Distribution head; also determines the expected target rank and the
        padding value used by the instance splitter.
    loss
        Loss handed to ``FEDformerLightningModule``.
    time_features
        Time features to add; derived from ``freq`` when ``None``.
    batch_size
        Batch size of all data loaders and of the predictor.
    num_batches_per_epoch
        Number of batches taken from the (cyclic) training stream per epoch.
    trainer_kwargs
        Extra keyword arguments for the Lightning trainer, merged over the
        default ``{"max_epochs": 100}``. ``None`` is treated as empty.
    train_sampler, validation_sampler
        Instance samplers; default to ``ExpectedNumInstanceSampler`` and
        ``ValidationSplitSampler`` with ``min_future=prediction_length``.
    """

    @validated()
    def __init__(
        self,
        freq: str,
        prediction_length: int,
        nhead: int,
        num_encoder_layers: int,
        num_decoder_layers: int,
        dim_feedforward: int = 2048,
        version: str = "Fourier",
        features: str = "M",
        target: str = "0T",
        checkpoints: str = "./checkpoints/",
        modes: int = 64,
        mode_select: str = "random",
        base: str = "legendre",
        cross_activation: str = "tanh",
        L: int = 3,
        # forecasting task
        context_length: Optional[int] = None,
        label_length: Optional[int] = 48,
        # model arguments
        input_size: int = 1,
        dec_in: int = 7,
        c_out: int = 7,
        moving_avg: List[int] = [24],
        factor: int = 1,
        activation: str = "gelu",
        dropout: float = 0.05,
        embed: str = "timeF",
        output_attention: bool = True,
        distil: bool = True,
        num_feat_dynamic_real: int = 0,
        num_feat_static_cat: int = 0,
        num_feat_static_real: int = 0,
        cardinality: Optional[List[int]] = None,
        embedding_dimension: Optional[List[int]] = None,
        distr_output: DistributionOutput = StudentTOutput(),
        loss: DistributionLoss = NegativeLogLikelihood(),
        scaling: bool = True,
        lags_seq: Optional[List[int]] = None,
        time_features: Optional[List[TimeFeature]] = None,
        num_parallel_samples: int = 100,
        batch_size: int = 32,
        num_batches_per_epoch: int = 512,
        trainer_kwargs: Optional[Dict[str, Any]] = dict(),
        train_sampler: Optional[InstanceSampler] = None,
        validation_sampler: Optional[InstanceSampler] = None,
    ) -> None:
        # ``trainer_kwargs`` is annotated Optional, so an explicit ``None``
        # must be tolerated rather than failing on ``**None``.
        merged_trainer_kwargs = {
            "max_epochs": 100,
            **(trainer_kwargs or {}),
        }
        super().__init__(trainer_kwargs=merged_trainer_kwargs)

        # FEDformer-specific settings.
        self.version = version
        self.modes = modes
        self.mode_select = mode_select
        self.base = base
        self.cross_activation = cross_activation
        self.L = L

        # Forecasting task.
        self.freq = freq
        self.context_length = (
            context_length if context_length is not None else prediction_length
        )
        self.label_length = label_length
        self.prediction_length = prediction_length
        self.distr_output = distr_output
        self.loss = loss

        # Network architecture.
        # NOTE: ``features``, ``target``, ``checkpoints``, ``dec_in`` and
        # ``distil`` are intentionally not stored; nothing in this class
        # uses them.
        self.input_size = input_size
        self.c_out = c_out
        self.nhead = nhead
        self.num_encoder_layers = num_encoder_layers
        self.num_decoder_layers = num_decoder_layers
        self.activation = activation
        self.dim_feedforward = dim_feedforward
        self.dropout = dropout
        self.output_attention = output_attention
        self.factor = factor
        # Copy so the instance never shares the mutable default list.
        self.moving_avg = list(moving_avg)
        self.embed = embed

        # Feature configuration.
        self.num_feat_dynamic_real = num_feat_dynamic_real
        self.num_feat_static_cat = num_feat_static_cat
        self.num_feat_static_real = num_feat_static_real
        self.cardinality = (
            cardinality if cardinality and num_feat_static_cat > 0 else [1]
        )
        self.embedding_dimension = embedding_dimension
        self.scaling = scaling
        self.lags_seq = lags_seq
        self.time_features = (
            time_features
            if time_features is not None
            else time_features_from_frequency_str(self.freq)
        )

        # Sampling / batching.
        self.num_parallel_samples = num_parallel_samples
        self.batch_size = batch_size
        self.num_batches_per_epoch = num_batches_per_epoch

        self.train_sampler = train_sampler or ExpectedNumInstanceSampler(
            num_instances=1.0, min_future=prediction_length
        )
        self.validation_sampler = validation_sampler or ValidationSplitSampler(
            min_future=prediction_length
        )

    def create_transformation(self) -> Transformation:
        """Build the per-series transformation chain applied before splitting.

        The chain removes feature fields declared as absent, inserts constant
        placeholders for missing static features, converts fields to numpy
        arrays, adds the observed-values indicator, time features and the age
        feature, and finally stacks all dynamic features into
        ``FieldName.FEAT_TIME``.
        """
        remove_field_names = []
        if self.num_feat_static_real == 0:
            remove_field_names.append(FieldName.FEAT_STATIC_REAL)
        if self.num_feat_dynamic_real == 0:
            remove_field_names.append(FieldName.FEAT_DYNAMIC_REAL)

        steps = [RemoveFields(field_names=remove_field_names)]

        # Placeholders keep the static feature fields present even when the
        # dataset does not provide them.
        if not self.num_feat_static_cat > 0:
            steps.append(
                SetField(output_field=FieldName.FEAT_STATIC_CAT, value=[0])
            )
        if not self.num_feat_static_real > 0:
            steps.append(
                SetField(output_field=FieldName.FEAT_STATIC_REAL, value=[0.0])
            )

        # Fields stacked into FEAT_TIME: time features, age, and (optionally)
        # the dataset's own dynamic real features.
        stacked_fields = [FieldName.FEAT_TIME, FieldName.FEAT_AGE]
        if self.num_feat_dynamic_real > 0:
            stacked_fields = stacked_fields + [FieldName.FEAT_DYNAMIC_REAL]

        steps += [
            AsNumpyArray(
                field=FieldName.FEAT_STATIC_CAT,
                expected_ndim=1,
                dtype=int,
            ),
            AsNumpyArray(
                field=FieldName.FEAT_STATIC_REAL,
                expected_ndim=1,
            ),
            AsNumpyArray(
                field=FieldName.TARGET,
                # in the following line, we add 1 for the time dimension
                expected_ndim=1 + len(self.distr_output.event_shape),
            ),
            AddObservedValuesIndicator(
                target_field=FieldName.TARGET,
                output_field=FieldName.OBSERVED_VALUES,
            ),
            AddTimeFeatures(
                start_field=FieldName.START,
                target_field=FieldName.TARGET,
                output_field=FieldName.FEAT_TIME,
                time_features=self.time_features,
                pred_length=self.prediction_length,
            ),
            AddAgeFeature(
                target_field=FieldName.TARGET,
                output_field=FieldName.FEAT_AGE,
                pred_length=self.prediction_length,
                log_scale=True,
            ),
            VstackFeatures(
                output_field=FieldName.FEAT_TIME,
                input_fields=stacked_fields,
            ),
        ]

        return Chain(steps)

    def _create_instance_splitter(self, module: FEDformerLightningModule, mode: str):
        """Return the ``InstanceSplitter`` for ``mode``.

        ``mode`` must be ``"training"``, ``"validation"`` or ``"test"``; it
        selects the instance sampler. The past window length is read from
        ``module.model._past_length``.
        """
        assert mode in ["training", "validation", "test"]

        instance_sampler = {
            "training": self.train_sampler,
            "validation": self.validation_sampler,
            "test": TestSplitSampler(),
        }[mode]

        return InstanceSplitter(
            target_field=FieldName.TARGET,
            is_pad_field=FieldName.IS_PAD,
            start_field=FieldName.START,
            forecast_start_field=FieldName.FORECAST_START,
            instance_sampler=instance_sampler,
            past_length=module.model._past_length,
            future_length=self.prediction_length,
            time_series_fields=[
                FieldName.FEAT_TIME,
                FieldName.OBSERVED_VALUES,
            ],
            dummy_value=self.distr_output.value_in_support,
        )

    def create_training_data_loader(
        self,
        data: Dataset,
        module: FEDformerLightningModule,
        shuffle_buffer_length: Optional[int] = None,
        **kwargs,
    ) -> Iterable:
        """Return an iterable yielding ``num_batches_per_epoch`` training batches.

        ``data`` is cycled endlessly (and pseudo-shuffled when
        ``shuffle_buffer_length`` is given), split into training instances,
        reduced to ``TRAINING_INPUT_NAMES`` and batched by a ``DataLoader``.
        Extra ``kwargs`` are forwarded to the ``DataLoader``.
        """
        transformation = self._create_instance_splitter(
            module, "training"
        ) + SelectFields(TRAINING_INPUT_NAMES)

        if shuffle_buffer_length is None:
            stream = Cyclic(data)
        else:
            stream = PseudoShuffled(
                Cyclic(data), shuffle_buffer_length=shuffle_buffer_length
            )
        training_instances = transformation.apply(stream)

        loader = DataLoader(
            IterableDataset(training_instances),
            batch_size=self.batch_size,
            **kwargs,
        )
        # The underlying stream is cyclic, so cap each pass explicitly.
        return IterableSlice(iter(loader), self.num_batches_per_epoch)

    def create_validation_data_loader(
        self,
        data: Dataset,
        module: FEDformerLightningModule,
        **kwargs,
    ) -> Iterable:
        """Return a ``DataLoader`` over the validation instances of ``data``.

        Uses the validation sampler and keeps only ``TRAINING_INPUT_NAMES``.
        Extra ``kwargs`` are forwarded to the ``DataLoader``.
        """
        transformation = self._create_instance_splitter(
            module, "validation"
        ) + SelectFields(TRAINING_INPUT_NAMES)

        validation_instances = transformation.apply(data)

        return DataLoader(
            IterableDataset(validation_instances),
            batch_size=self.batch_size,
            **kwargs,
        )

    def create_predictor(
        self,
        transformation: Transformation,
        module: FEDformerLightningModule,
    ) -> PyTorchPredictor:
        """Wrap the trained ``module.model`` in a ``PyTorchPredictor``.

        The predictor applies ``transformation`` followed by the test-mode
        instance splitter, and runs on CUDA when available, otherwise on CPU.
        """
        prediction_splitter = self._create_instance_splitter(module, "test")

        return PyTorchPredictor(
            input_transform=transformation + prediction_splitter,
            input_names=PREDICTION_INPUT_NAMES,
            prediction_net=module.model,
            batch_size=self.batch_size,
            prediction_length=self.prediction_length,
            device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        )

    def create_lightning_module(self) -> FEDformerLightningModule:
        """Instantiate ``FEDformerModel`` from the stored settings and wrap it.

        Feature counts are adjusted to match :meth:`create_transformation`:
        the dynamic feature count includes the age feature and all time
        features, and the static counts are at least one because of the
        placeholder fields.
        """
        model = FEDformerModel(
            version=self.version,
            freq=self.freq,
            modes=self.modes,
            mode_select=self.mode_select,
            base=self.base,
            cross_activation=self.cross_activation,
            L=self.L,
            # forecasting task
            prediction_length=self.prediction_length,
            context_length=self.context_length,
            label_length=self.label_length,
            # model arguments
            nhead=self.nhead,
            num_encoder_layers=self.num_encoder_layers,
            num_decoder_layers=self.num_decoder_layers,
            dim_feedforward=self.dim_feedforward,
            input_size=self.input_size,
            c_out=self.c_out,
            moving_avg=self.moving_avg,
            factor=self.factor,
            scaling=self.scaling,
            activation=self.activation,
            dropout=self.dropout,
            embed=self.embed,
            output_attention=self.output_attention,
            # 1 (age feature) + dataset dynamic features + time features
            num_feat_dynamic_real=1
            + self.num_feat_dynamic_real
            + len(self.time_features),
            num_feat_static_real=max(1, self.num_feat_static_real),
            num_feat_static_cat=max(1, self.num_feat_static_cat),
            cardinality=self.cardinality,
            embedding_dimension=self.embedding_dimension,
            distr_output=self.distr_output,
            lags_seq=self.lags_seq,
            num_parallel_samples=self.num_parallel_samples,
        )

        return FEDformerLightningModule(model=model, loss=self.loss)
|