mirror of
https://github.com/wassname/pytorch-transformer-ts.git
synced 2026-06-27 16:31:19 +08:00
initial TFT
still not working
This commit is contained in:
@@ -0,0 +1,9 @@
|
||||
from .module import TFTModel
|
||||
from .lightning_module import TFTLightningModule
|
||||
from .estimator import TFTEstimator
|
||||
|
||||
__all__ = [
|
||||
"TFTModel",
|
||||
"TFTLightningModule",
|
||||
"TFTEstimator",
|
||||
]
|
||||
@@ -0,0 +1,301 @@
|
||||
from typing import Any, Dict, Iterable, List, Optional
|
||||
|
||||
import torch
|
||||
from gluonts.core.component import validated
|
||||
from gluonts.dataset.common import Dataset
|
||||
from gluonts.dataset.field_names import FieldName
|
||||
from gluonts.itertools import Cyclic, IterableSlice, PseudoShuffled
|
||||
from gluonts.time_feature import TimeFeature, time_features_from_frequency_str
|
||||
from gluonts.torch.model.estimator import PyTorchLightningEstimator
|
||||
from gluonts.torch.model.predictor import PyTorchPredictor
|
||||
from gluonts.torch.modules.distribution_output import DistributionOutput, StudentTOutput
|
||||
from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood
|
||||
from gluonts.torch.util import IterableDataset
|
||||
from gluonts.transform import (
|
||||
AddAgeFeature,
|
||||
AddObservedValuesIndicator,
|
||||
AddTimeFeatures,
|
||||
AsNumpyArray,
|
||||
Chain,
|
||||
ExpectedNumInstanceSampler,
|
||||
InstanceSplitter,
|
||||
RemoveFields,
|
||||
SelectFields,
|
||||
SetField,
|
||||
TestSplitSampler,
|
||||
Transformation,
|
||||
ValidationSplitSampler,
|
||||
VstackFeatures,
|
||||
)
|
||||
from gluonts.transform.sampler import InstanceSampler
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from .lightning_module import TFTLightningModule
|
||||
from .module import TFTModel
|
||||
|
||||
PREDICTION_INPUT_NAMES = [
|
||||
"feat_static_cat",
|
||||
"feat_static_real",
|
||||
"past_time_feat",
|
||||
"past_target",
|
||||
"past_observed_values",
|
||||
"future_time_feat",
|
||||
]
|
||||
|
||||
TRAINING_INPUT_NAMES = PREDICTION_INPUT_NAMES + [
|
||||
"future_target",
|
||||
"future_observed_values",
|
||||
]
|
||||
|
||||
|
||||
class TFTEstimator(PyTorchLightningEstimator):
|
||||
@validated()
|
||||
def __init__(
|
||||
self,
|
||||
freq: str,
|
||||
prediction_length: int,
|
||||
context_length: Optional[int] = None,
|
||||
dropout: float = 0.1,
|
||||
activation: str = "gelu",
|
||||
embed_dim: int = 32,
|
||||
num_heads: int = 4,
|
||||
num_feat_dynamic_real: int = 0,
|
||||
num_feat_static_cat: int = 0,
|
||||
num_feat_static_real: int = 0,
|
||||
cardinality: Optional[List[int]] = None,
|
||||
embedding_dimension: Optional[List[int]] = None,
|
||||
distr_output: DistributionOutput = StudentTOutput(),
|
||||
loss: DistributionLoss = NegativeLogLikelihood(),
|
||||
scaling: bool = True,
|
||||
lags_seq: Optional[List[int]] = None,
|
||||
time_features: Optional[List[TimeFeature]] = None,
|
||||
num_parallel_samples: int = 100,
|
||||
batch_size: int = 32,
|
||||
num_batches_per_epoch: int = 50,
|
||||
trainer_kwargs: Optional[Dict[str, Any]] = dict(),
|
||||
train_sampler: Optional[InstanceSampler] = None,
|
||||
validation_sampler: Optional[InstanceSampler] = None,
|
||||
) -> None:
|
||||
trainer_kwargs = {
|
||||
"max_epochs": 100,
|
||||
**trainer_kwargs,
|
||||
}
|
||||
super().__init__(trainer_kwargs=trainer_kwargs)
|
||||
|
||||
self.freq = freq
|
||||
self.context_length = (
|
||||
context_length if context_length is not None else prediction_length
|
||||
)
|
||||
self.prediction_length = prediction_length
|
||||
self.distr_output = distr_output
|
||||
self.loss = loss
|
||||
|
||||
# MultiheadAttention
|
||||
self.embed_dim = embed_dim
|
||||
self.num_heads = num_heads
|
||||
self.activation = activation
|
||||
self.dropout = dropout
|
||||
|
||||
self.num_feat_dynamic_real = num_feat_dynamic_real
|
||||
self.num_feat_static_cat = num_feat_static_cat
|
||||
self.num_feat_static_real = num_feat_static_real
|
||||
self.cardinality = (
|
||||
cardinality if cardinality and num_feat_static_cat > 0 else [1]
|
||||
)
|
||||
self.embedding_dimension = embedding_dimension
|
||||
self.scaling = scaling
|
||||
self.lags_seq = lags_seq
|
||||
self.time_features = (
|
||||
time_features
|
||||
if time_features is not None
|
||||
else time_features_from_frequency_str(self.freq)
|
||||
)
|
||||
|
||||
self.num_parallel_samples = num_parallel_samples
|
||||
self.batch_size = batch_size
|
||||
self.num_batches_per_epoch = num_batches_per_epoch
|
||||
|
||||
self.train_sampler = train_sampler or ExpectedNumInstanceSampler(
|
||||
num_instances=1.0, min_future=prediction_length
|
||||
)
|
||||
self.validation_sampler = validation_sampler or ValidationSplitSampler(
|
||||
min_future=prediction_length
|
||||
)
|
||||
|
||||
def create_transformation(self) -> Transformation:
|
||||
remove_field_names = []
|
||||
if self.num_feat_static_real == 0:
|
||||
remove_field_names.append(FieldName.FEAT_STATIC_REAL)
|
||||
if self.num_feat_dynamic_real == 0:
|
||||
remove_field_names.append(FieldName.FEAT_DYNAMIC_REAL)
|
||||
|
||||
return Chain(
|
||||
[RemoveFields(field_names=remove_field_names)]
|
||||
+ (
|
||||
[SetField(output_field=FieldName.FEAT_STATIC_CAT, value=[0])]
|
||||
if not self.num_feat_static_cat > 0
|
||||
else []
|
||||
)
|
||||
+ (
|
||||
[SetField(output_field=FieldName.FEAT_STATIC_REAL, value=[0.0])]
|
||||
if not self.num_feat_static_real > 0
|
||||
else []
|
||||
)
|
||||
+ [
|
||||
AsNumpyArray(
|
||||
field=FieldName.FEAT_STATIC_CAT,
|
||||
expected_ndim=1,
|
||||
dtype=int,
|
||||
),
|
||||
AsNumpyArray(
|
||||
field=FieldName.FEAT_STATIC_REAL,
|
||||
expected_ndim=1,
|
||||
),
|
||||
AsNumpyArray(
|
||||
field=FieldName.TARGET,
|
||||
# in the following line, we add 1 for the time dimension
|
||||
expected_ndim=1 + len(self.distr_output.event_shape),
|
||||
),
|
||||
AddObservedValuesIndicator(
|
||||
target_field=FieldName.TARGET,
|
||||
output_field=FieldName.OBSERVED_VALUES,
|
||||
),
|
||||
AddTimeFeatures(
|
||||
start_field=FieldName.START,
|
||||
target_field=FieldName.TARGET,
|
||||
output_field=FieldName.FEAT_TIME,
|
||||
time_features=self.time_features,
|
||||
pred_length=self.prediction_length,
|
||||
),
|
||||
AddAgeFeature(
|
||||
target_field=FieldName.TARGET,
|
||||
output_field=FieldName.FEAT_AGE,
|
||||
pred_length=self.prediction_length,
|
||||
log_scale=True,
|
||||
),
|
||||
VstackFeatures(
|
||||
output_field=FieldName.FEAT_TIME,
|
||||
input_fields=[FieldName.FEAT_TIME, FieldName.FEAT_AGE]
|
||||
+ (
|
||||
[FieldName.FEAT_DYNAMIC_REAL]
|
||||
if self.num_feat_dynamic_real > 0
|
||||
else []
|
||||
),
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def _create_instance_splitter(self, module: TFTLightningModule, mode: str):
|
||||
assert mode in ["training", "validation", "test"]
|
||||
|
||||
instance_sampler = {
|
||||
"training": self.train_sampler,
|
||||
"validation": self.validation_sampler,
|
||||
"test": TestSplitSampler(),
|
||||
}[mode]
|
||||
|
||||
return InstanceSplitter(
|
||||
target_field=FieldName.TARGET,
|
||||
is_pad_field=FieldName.IS_PAD,
|
||||
start_field=FieldName.START,
|
||||
forecast_start_field=FieldName.FORECAST_START,
|
||||
instance_sampler=instance_sampler,
|
||||
past_length=module.model._past_length,
|
||||
future_length=self.prediction_length,
|
||||
time_series_fields=[
|
||||
FieldName.FEAT_TIME,
|
||||
FieldName.OBSERVED_VALUES,
|
||||
],
|
||||
dummy_value=self.distr_output.value_in_support,
|
||||
)
|
||||
|
||||
def create_training_data_loader(
|
||||
self,
|
||||
data: Dataset,
|
||||
module: TFTLightningModule,
|
||||
shuffle_buffer_length: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> Iterable:
|
||||
transformation = self._create_instance_splitter(
|
||||
module, "training"
|
||||
) + SelectFields(TRAINING_INPUT_NAMES)
|
||||
|
||||
training_instances = transformation.apply(
|
||||
Cyclic(data)
|
||||
if shuffle_buffer_length is None
|
||||
else PseudoShuffled(
|
||||
Cyclic(data), shuffle_buffer_length=shuffle_buffer_length
|
||||
)
|
||||
)
|
||||
|
||||
return IterableSlice(
|
||||
iter(
|
||||
DataLoader(
|
||||
IterableDataset(training_instances),
|
||||
batch_size=self.batch_size,
|
||||
**kwargs,
|
||||
)
|
||||
),
|
||||
self.num_batches_per_epoch,
|
||||
)
|
||||
|
||||
def create_validation_data_loader(
|
||||
self,
|
||||
data: Dataset,
|
||||
module: TFTLightningModule,
|
||||
**kwargs,
|
||||
) -> Iterable:
|
||||
transformation = self._create_instance_splitter(
|
||||
module, "validation"
|
||||
) + SelectFields(TRAINING_INPUT_NAMES)
|
||||
|
||||
validation_instances = transformation.apply(data)
|
||||
|
||||
return DataLoader(
|
||||
IterableDataset(validation_instances),
|
||||
batch_size=self.batch_size,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def create_predictor(
|
||||
self,
|
||||
transformation: Transformation,
|
||||
module: TFTLightningModule,
|
||||
) -> PyTorchPredictor:
|
||||
prediction_splitter = self._create_instance_splitter(module, "test")
|
||||
|
||||
return PyTorchPredictor(
|
||||
input_transform=transformation + prediction_splitter,
|
||||
input_names=PREDICTION_INPUT_NAMES,
|
||||
prediction_net=module.model,
|
||||
batch_size=self.batch_size,
|
||||
freq=self.freq,
|
||||
prediction_length=self.prediction_length,
|
||||
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
|
||||
)
|
||||
|
||||
def create_lightning_module(self) -> TFTLightningModule:
|
||||
model = TFTModel(
|
||||
freq=self.freq,
|
||||
context_length=self.context_length,
|
||||
prediction_length=self.prediction_length,
|
||||
num_feat_dynamic_real=1 # age
|
||||
+ self.num_feat_dynamic_real
|
||||
+ len(self.time_features),
|
||||
num_feat_static_real=max(1, self.num_feat_static_real),
|
||||
num_feat_static_cat=max(1, self.num_feat_static_cat),
|
||||
cardinality=self.cardinality,
|
||||
embedding_dimension=self.embedding_dimension,
|
||||
# transformer arguments
|
||||
nhead=self.nhead,
|
||||
dropout=self.dropout,
|
||||
dim_feedforward=self.dim_feedforward,
|
||||
# univariate input
|
||||
input_size=self.input_size,
|
||||
distr_output=self.distr_output,
|
||||
lags_seq=self.lags_seq,
|
||||
scaling=self.scaling,
|
||||
num_parallel_samples=self.num_parallel_samples,
|
||||
)
|
||||
|
||||
return TFTLightningModule(model=model, loss=self.loss)
|
||||
@@ -0,0 +1,81 @@
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood
|
||||
from gluonts.torch.util import weighted_average
|
||||
|
||||
from .module import TFTModel
|
||||
|
||||
|
||||
class TFTLightningModule(pl.LightningModule):
|
||||
def __init__(
|
||||
self,
|
||||
model: TFTModel,
|
||||
loss: DistributionLoss = NegativeLogLikelihood(),
|
||||
lr: float = 1e-3,
|
||||
weight_decay: float = 1e-8,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.save_hyperparameters()
|
||||
self.model = model
|
||||
self.loss = loss
|
||||
self.lr = lr
|
||||
self.weight_decay = weight_decay
|
||||
|
||||
def training_step(self, batch, batch_idx: int):
|
||||
"""Execute training step"""
|
||||
train_loss = self(batch)
|
||||
self.log(
|
||||
"train_loss",
|
||||
train_loss,
|
||||
on_epoch=True,
|
||||
on_step=False,
|
||||
prog_bar=True,
|
||||
)
|
||||
return train_loss
|
||||
|
||||
def validation_step(self, batch, batch_idx: int):
|
||||
"""Execute validation step"""
|
||||
with torch.inference_mode():
|
||||
val_loss = self(batch)
|
||||
self.log("val_loss", val_loss, on_epoch=True, on_step=False, prog_bar=True)
|
||||
return val_loss
|
||||
|
||||
def configure_optimizers(self):
|
||||
"""Returns the optimizer to use"""
|
||||
return torch.optim.Adam(
|
||||
self.model.parameters(),
|
||||
lr=self.lr,
|
||||
weight_decay=self.weight_decay,
|
||||
)
|
||||
|
||||
def forward(self, batch):
|
||||
feat_static_cat = batch["feat_static_cat"]
|
||||
feat_static_real = batch["feat_static_real"]
|
||||
past_time_feat = batch["past_time_feat"]
|
||||
past_target = batch["past_target"]
|
||||
past_observed_values = batch["past_observed_values"]
|
||||
|
||||
future_time_feat = batch["future_time_feat"]
|
||||
future_target = batch["future_target"]
|
||||
future_observed_values = batch["future_observed_values"]
|
||||
|
||||
tft_inputs, scale, _ = self.model.create_network_inputs(
|
||||
feat_static_cat=feat_static_cat,
|
||||
feat_static_real=feat_static_real,
|
||||
past_time_feat=past_time_feat,
|
||||
past_target=past_target,
|
||||
past_observed_values=past_observed_values,
|
||||
future_time_feat=future_time_feat,
|
||||
future_target=future_target,
|
||||
)
|
||||
params = self.model.output_params(tft_inputs)
|
||||
distr = self.model.output_distribution(params, scale)
|
||||
|
||||
loss_values = self.loss(distr, future_target)
|
||||
|
||||
if len(self.model.target_shape) == 0:
|
||||
loss_weights = future_observed_values
|
||||
else:
|
||||
loss_weights = future_observed_values.min(dim=-1, keepdim=False)
|
||||
|
||||
return weighted_average(loss_values, weights=loss_weights)
|
||||
+538
@@ -0,0 +1,538 @@
|
||||
from typing import List, Optional, Tuple
|
||||
from sympy import fu
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from gluonts.core.component import validated
|
||||
from gluonts.time_feature import get_lags_for_frequency
|
||||
from gluonts.torch.modules.distribution_output import DistributionOutput, StudentTOutput
|
||||
from gluonts.torch.modules.feature import FeatureEmbedder as BaseFeatureEmbedder
|
||||
from gluonts.torch.modules.scaler import MeanScaler, NOPScaler
|
||||
|
||||
|
||||
class FeatureEmbedder(BaseFeatureEmbedder):
|
||||
def forward(self, features: torch.Tensor) -> List[torch.Tensor]:
|
||||
concat_features = super(FeatureEmbedder, self).forward(features=features)
|
||||
|
||||
if self.__num_features > 1:
|
||||
features = torch.chunk(concat_features, self.__num_features, dim=-1)
|
||||
else:
|
||||
features = [concat_features]
|
||||
|
||||
return features
|
||||
|
||||
|
||||
class GatedResidualNetwork(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
d_hidden: int,
|
||||
d_input: Optional[int] = None,
|
||||
d_output: Optional[int] = None,
|
||||
d_static: Optional[int] = None,
|
||||
dropout: float = 0.0,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
d_input = d_input or d_hidden
|
||||
d_static = d_static or 0
|
||||
if d_output is None:
|
||||
d_output = d_input
|
||||
self.add_skip = False
|
||||
else:
|
||||
if d_output != d_input:
|
||||
self.add_skip = True
|
||||
self.skip_proj = nn.Linear(in_features=d_input, out_features=d_output)
|
||||
else:
|
||||
self.add_skip = False
|
||||
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(in_features=d_input + d_static, out_features=d_hidden),
|
||||
nn.ELU(),
|
||||
nn.Linear(in_features=d_hidden, out_features=d_hidden),
|
||||
nn.Dropout(p=dropout),
|
||||
nn.Linear(in_features=d_hidden, out_features=d_output * 2),
|
||||
nn.GLU(),
|
||||
)
|
||||
|
||||
self.lnorm = nn.LayerNorm(d_output)
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, c: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
if self.add_skip:
|
||||
skip = self.skip_proj(x)
|
||||
else:
|
||||
skip = x
|
||||
|
||||
if c is not None:
|
||||
x = torch.cat((x, c), dim=-1)
|
||||
x = self.mlp(x)
|
||||
x = self.lnorm(x + skip)
|
||||
return x
|
||||
|
||||
|
||||
class VariableSelectionNetwork(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
d_hidden: int,
|
||||
n_vars: int,
|
||||
dropout: float = 0.0,
|
||||
add_static: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.weight_network = GatedResidualNetwork(
|
||||
d_hidden=d_hidden,
|
||||
d_input=d_hidden * n_vars,
|
||||
d_output=n_vars,
|
||||
d_static=d_hidden if add_static else None,
|
||||
dropout=dropout,
|
||||
)
|
||||
|
||||
self.variable_network = nn.ModuleList(
|
||||
[
|
||||
GatedResidualNetwork(d_hidden=d_hidden, dropout=dropout)
|
||||
for _ in range(n_vars)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, variables: List[torch.Tensor], static: Optional[torch.Tensor] = None
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
flatten = torch.cat(variables, dim=-1)
|
||||
if static is not None:
|
||||
static = static.expand_as(variables[0])
|
||||
weight = self.weight_network(flatten, static)
|
||||
weight = torch.softmax(weight.unsqueeze(-2), dim=-1)
|
||||
|
||||
var_encodings = [net(var) for var, net in zip(variables, self.variable_network)]
|
||||
var_encodings = torch.stack(var_encodings, dim=-1)
|
||||
|
||||
var_encodings = torch.sum(var_encodings * weight, dim=-1)
|
||||
|
||||
return var_encodings, weight
|
||||
|
||||
|
||||
class TemporalFusionEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
d_input: int,
|
||||
d_hidden: int,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.encoder_lstm = nn.LSTM(
|
||||
input_size=d_input, hidden_size=d_hidden, batch_first=True
|
||||
)
|
||||
self.decoder_lstm = nn.LSTM(
|
||||
input_size=d_input, hidden_size=d_hidden, batch_first=True
|
||||
)
|
||||
|
||||
self.gate = nn.Sequential(
|
||||
nn.Linear(in_features=d_hidden, out_features=d_hidden * 2),
|
||||
nn.GLU(),
|
||||
)
|
||||
if d_input != d_hidden:
|
||||
self.skip_proj = nn.Linear(in_features=d_input, out_features=d_hidden)
|
||||
self.add_skip = True
|
||||
else:
|
||||
self.add_skip = False
|
||||
|
||||
self.lnorm = nn.LayerNorm(d_hidden)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
ctx_input: torch.Tensor,
|
||||
tgt_input: torch.Tensor,
|
||||
states: List[torch.Tensor],
|
||||
):
|
||||
ctx_encodings, states = self.encoder_lstm(ctx_input, states)
|
||||
|
||||
tgt_encodings, _ = self.decoder_lstm(tgt_input, states)
|
||||
|
||||
encodings = torch.cat((ctx_encodings, tgt_encodings), dim=1)
|
||||
skip = torch.cat((ctx_input, tgt_input), dim=1)
|
||||
if self.add_skip:
|
||||
skip = self.skip_proj(skip)
|
||||
encodings = self.gate(encodings)
|
||||
encodings = self.lnorm(skip + encodings)
|
||||
return encodings
|
||||
|
||||
|
||||
class TemporalFusionDecoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
context_length: int,
|
||||
prediction_length: int,
|
||||
d_hidden: int,
|
||||
d_var: int,
|
||||
n_head: int,
|
||||
dropout: float = 0.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.context_length = context_length
|
||||
self.prediction_length = prediction_length
|
||||
|
||||
self.enrich = GatedResidualNetwork(
|
||||
d_hidden=d_hidden,
|
||||
d_static=d_var,
|
||||
dropout=dropout,
|
||||
)
|
||||
|
||||
self.attention = nn.MultiheadAttention(
|
||||
embed_dim=d_hidden,
|
||||
num_heads=n_head,
|
||||
dropout=dropout,
|
||||
batch_first=True,
|
||||
)
|
||||
|
||||
self.att_net = nn.Sequential(
|
||||
nn.Linear(in_features=d_hidden, out_features=d_hidden * 2),
|
||||
nn.GLU(),
|
||||
)
|
||||
self.att_lnorm = nn.LayerNorm(d_hidden)
|
||||
|
||||
self.ff_net = nn.Sequential(
|
||||
GatedResidualNetwork(d_hidden=d_hidden, dropout=dropout),
|
||||
nn.Linear(in_features=d_hidden, out_features=d_hidden * 2),
|
||||
nn.GLU(),
|
||||
)
|
||||
self.ff_lnorm = nn.LayerNorm(d_hidden)
|
||||
|
||||
self.register_buffer(
|
||||
"attn_mask",
|
||||
self._generate_subsequent_mask(
|
||||
prediction_length, prediction_length + context_length
|
||||
),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _generate_subsequent_mask(
|
||||
target_length: int, source_length: int
|
||||
) -> torch.Tensor:
|
||||
mask = (torch.triu(torch.ones(source_length, target_length)) == 1).transpose(
|
||||
0, 1
|
||||
)
|
||||
mask = (
|
||||
mask.float()
|
||||
.masked_fill(mask == 0, float("-inf"))
|
||||
.masked_fill(mask == 1, float(0.0))
|
||||
)
|
||||
return mask
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, static: torch.Tensor, mask: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
static = static.repeat((1, self.context_length + self.prediction_length, 1))
|
||||
|
||||
skip = x[:, self.context_length :, ...]
|
||||
x = self.enrich(x, static)
|
||||
|
||||
# does not work on GPU :-(
|
||||
# mask_pad = torch.ones_like(mask)[:, 0:1, ...]
|
||||
# mask_pad = mask_pad.repeat((1, self.prediction_length))
|
||||
# key_padding_mask = torch.cat((mask, mask_pad), dim=1).bool()
|
||||
|
||||
query_key_value = x
|
||||
|
||||
attn_output, _ = self.attention(
|
||||
query=query_key_value[-self.prediction_length :, ...],
|
||||
key=query_key_value,
|
||||
value=query_key_value,
|
||||
# key_padding_mask=key_padding_mask,
|
||||
attn_mask=self.attn_mask,
|
||||
)
|
||||
att = self.att_net(attn_output)
|
||||
|
||||
x = x[:, self.context_length :, ...]
|
||||
x = self.att_lnorm(x + att)
|
||||
x = self.ff_net(x)
|
||||
x = self.ff_lnorm(x + skip)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class TFTModel(nn.Module):
|
||||
@validated()
|
||||
def __init__(
|
||||
self,
|
||||
freq: str,
|
||||
context_length: int,
|
||||
prediction_length: int,
|
||||
num_feat_dynamic_real: int,
|
||||
num_feat_static_real: int,
|
||||
num_feat_static_cat: int,
|
||||
cardinality: List[int],
|
||||
# TFT inputs
|
||||
nhead: int,
|
||||
hidden_dim: int,
|
||||
variable_dim: int,
|
||||
# univariate input
|
||||
input_size: int = 1,
|
||||
embedding_dimension: Optional[List[int]] = None,
|
||||
distr_output: DistributionOutput = StudentTOutput(),
|
||||
lags_seq: Optional[List[int]] = None,
|
||||
scaling: bool = True,
|
||||
num_parallel_samples: int = 100,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.input_size = input_size
|
||||
|
||||
self.target_shape = distr_output.event_shape
|
||||
self.num_feat_dynamic_real = num_feat_dynamic_real
|
||||
self.num_feat_static_cat = num_feat_static_cat
|
||||
self.num_feat_static_real = num_feat_static_real
|
||||
self.embedding_dimension = (
|
||||
embedding_dimension
|
||||
if embedding_dimension is not None or cardinality is None
|
||||
else [min(50, (cat + 1) // 2) for cat in cardinality]
|
||||
)
|
||||
self.lags_seq = lags_seq or get_lags_for_frequency(freq_str=freq)
|
||||
self.num_parallel_samples = num_parallel_samples
|
||||
self.history_length = context_length + max(self.lags_seq)
|
||||
self.embedder = FeatureEmbedder(
|
||||
cardinalities=cardinality,
|
||||
embedding_dims=self.embedding_dimension,
|
||||
)
|
||||
if scaling:
|
||||
self.scaler = MeanScaler(dim=1, keepdim=True)
|
||||
else:
|
||||
self.scaler = NOPScaler(dim=1, keepdim=True)
|
||||
|
||||
self.context_length = context_length
|
||||
self.prediction_length = prediction_length
|
||||
self.distr_output = distr_output
|
||||
|
||||
# projection networks
|
||||
self.target_proj = nn.Linear(
|
||||
in_features=input_size * len(self.lags_seq), out_features=variable_dim
|
||||
)
|
||||
|
||||
self.dynamic_proj = nn.Linear(
|
||||
in_features=num_feat_dynamic_real, out_features=variable_dim
|
||||
)
|
||||
|
||||
self.static_proj = nn.Linear(
|
||||
in_features=sum(self.embedding_dimension) + self.num_feat_static_real + 1,
|
||||
out_features=variable_dim,
|
||||
)
|
||||
|
||||
# variable selection networks
|
||||
self.past_selection = VariableSelectionNetwork(
|
||||
d_hidden=variable_dim,
|
||||
n_vars=input_size * len(self.lags_seq) + num_feat_dynamic_real,
|
||||
dropout=dropout,
|
||||
add_static=True,
|
||||
)
|
||||
|
||||
self.future_selection = VariableSelectionNetwork(
|
||||
d_hidden=variable_dim,
|
||||
n_vars=input_size * len(self.lags_seq) + num_feat_dynamic_real,
|
||||
dropout=dropout,
|
||||
add_static=True,
|
||||
)
|
||||
|
||||
self.static_selection = VariableSelectionNetwork(
|
||||
d_hidden=variable_dim,
|
||||
n_vars=sum(self.embedding_dimension) + self.num_feat_static_real + 1,
|
||||
dropout=dropout,
|
||||
)
|
||||
|
||||
# Static Gated Residual Networks
|
||||
self.selection = GatedResidualNetwork(
|
||||
d_hidden=variable_dim,
|
||||
dropout=dropout,
|
||||
)
|
||||
|
||||
self.enrichment = GatedResidualNetwork(
|
||||
d_hidden=variable_dim,
|
||||
dropout=dropout,
|
||||
)
|
||||
|
||||
# Encoder and Decoder network
|
||||
self.temporal_encoder = TemporalFusionEncoder(
|
||||
d_input=variable_dim,
|
||||
d_hidden=embed_dim,
|
||||
)
|
||||
self.temporal_decoder = TemporalFusionDecoder(
|
||||
context_length=self.context_length,
|
||||
prediction_length=self.prediction_length,
|
||||
d_hidden=embed_dim,
|
||||
d_var=variable_dim,
|
||||
n_head=nhead,
|
||||
dropout=dropout,
|
||||
)
|
||||
|
||||
# TODO
|
||||
self.param_proj = distr_output.get_args_proj(embed_dim)
|
||||
|
||||
# TODO
|
||||
|
||||
@property
|
||||
def _past_length(self) -> int:
|
||||
return self.context_length + max(self.lags_seq)
|
||||
|
||||
def get_lagged_subsequences(
|
||||
self, sequence: torch.Tensor, subsequences_length: int, shift: int = 0
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Returns lagged subsequences of a given sequence.
|
||||
Parameters
|
||||
----------
|
||||
sequence : Tensor
|
||||
the sequence from which lagged subsequences should be extracted.
|
||||
Shape: (N, T, C).
|
||||
subsequences_length : int
|
||||
length of the subsequences to be extracted.
|
||||
shift: int
|
||||
shift the lags by this amount back.
|
||||
Returns
|
||||
--------
|
||||
lagged : Tensor
|
||||
a tensor of shape (N, S, C, I), where S = subsequences_length and
|
||||
I = len(indices), containing lagged subsequences. Specifically,
|
||||
lagged[i, j, :, k] = sequence[i, -indices[k]-S+j, :].
|
||||
"""
|
||||
sequence_length = sequence.shape[1]
|
||||
indices = [lag - shift for lag in self.lags_seq]
|
||||
|
||||
assert max(indices) + subsequences_length <= sequence_length, (
|
||||
f"lags cannot go further than history length, found lag {max(indices)} "
|
||||
f"while history length is only {sequence_length}"
|
||||
)
|
||||
|
||||
lagged_values = []
|
||||
for lag_index in indices:
|
||||
begin_index = -lag_index - subsequences_length
|
||||
end_index = -lag_index if lag_index > 0 else None
|
||||
lagged_values.append(sequence[:, begin_index:end_index, ...])
|
||||
return torch.stack(lagged_values, dim=-1)
|
||||
|
||||
def create_network_inputs(
|
||||
self,
|
||||
feat_static_cat: torch.Tensor,
|
||||
feat_static_real: torch.Tensor,
|
||||
past_time_feat: torch.Tensor,
|
||||
past_target: torch.Tensor,
|
||||
past_observed_values: torch.Tensor,
|
||||
future_time_feat: Optional[torch.Tensor] = None,
|
||||
future_target: Optional[torch.Tensor] = None,
|
||||
):
|
||||
# time feature
|
||||
time_feat = (
|
||||
past_time_feat[:, self._past_length - self.context_length :, ...]
|
||||
if future_time_feat is None or future_target is None
|
||||
else torch.cat(
|
||||
(
|
||||
past_time_feat[:, self._past_length - self.context_length :, ...],
|
||||
future_time_feat,
|
||||
),
|
||||
dim=1,
|
||||
)
|
||||
)
|
||||
|
||||
# calculate scale
|
||||
context = past_target[:, -self.context_length :]
|
||||
observed_context = past_observed_values[:, -self.context_length :]
|
||||
_, scale = self.scaler(context, observed_context)
|
||||
|
||||
# scale the target and create lag features of targets
|
||||
target = (
|
||||
torch.cat((past_target, future_target), dim=1) / scale
|
||||
if future_target is not None
|
||||
else past_target / scale
|
||||
)
|
||||
subsequences_length = (
|
||||
self.context_length
|
||||
if future_time_feat is None or future_target is None
|
||||
else self.context_length + self.prediction_length
|
||||
)
|
||||
|
||||
lagged_target = self.get_lagged_subsequences(
|
||||
sequence=target,
|
||||
subsequences_length=subsequences_length,
|
||||
)
|
||||
lags_shape = lagged_target.shape
|
||||
reshaped_lagged_target = lagged_target.reshape(lags_shape[0], lags_shape[1], -1)
|
||||
|
||||
# embeddings
|
||||
embedded_cat = self.embedder(feat_static_cat)
|
||||
static_feat = torch.cat(
|
||||
(embedded_cat, feat_static_real, scale.log()),
|
||||
dim=1,
|
||||
)
|
||||
|
||||
# return the network inputs
|
||||
return (
|
||||
reshaped_lagged_target, # target
|
||||
time_feat, # dynamic real covariates
|
||||
scale, # scale
|
||||
static_feat, # static covariates
|
||||
)
|
||||
|
||||
def output_params(self, target, time_feat, static_feat):
|
||||
target_proj = self.target_proj(target)
|
||||
|
||||
past_target_proj = target_proj[:, : self.context_length, ...]
|
||||
future_target_proj = target_proj[:, self.context_length :, ...]
|
||||
|
||||
time_feat_proj = self.dynamic_proj(time_feat)
|
||||
past_time_feat_proj = time_feat_proj[:, : self.context_length, ...]
|
||||
future_time_feat_proj = time_feat_proj[:, self.context_length :, ...]
|
||||
|
||||
static_feat_proj = self.static_proj(static_feat)
|
||||
|
||||
static_var, _ = self.static_selection([static_feat_proj])
|
||||
static_selection = self.selection(static_var).unsqueeze(1)
|
||||
static_enrichment = self.enrichment(static_var).unsqueeze(1)
|
||||
|
||||
past_selection, _ = self.past_selection(
|
||||
[past_target_proj, past_time_feat_proj], static_selection
|
||||
)
|
||||
|
||||
future_selection, _ = self.future_selection(
|
||||
[future_target_proj, future_time_feat_proj], static_selection
|
||||
)
|
||||
|
||||
encoding = self.temporal_encoder(past_selection, future_selection)
|
||||
|
||||
decoding = self.temporal_decoder(encoding)
|
||||
|
||||
return (
|
||||
past_target_proj,
|
||||
future_target_proj,
|
||||
past_time_feat_proj,
|
||||
static_feat_proj,
|
||||
)
|
||||
|
||||
@torch.jit.ignore
|
||||
def output_distribution(
|
||||
self, params, scale=None, trailing_n=None
|
||||
) -> torch.distributions.Distribution:
|
||||
sliced_params = params
|
||||
if trailing_n is not None:
|
||||
sliced_params = [p[:, -trailing_n:] for p in params]
|
||||
return self.distr_output.distribution(sliced_params, scale=scale)
|
||||
|
||||
# for prediction
|
||||
def forward(
|
||||
self,
|
||||
feat_static_cat: torch.Tensor,
|
||||
feat_static_real: torch.Tensor,
|
||||
past_time_feat: torch.Tensor,
|
||||
past_target: torch.Tensor,
|
||||
past_observed_values: torch.Tensor,
|
||||
future_time_feat: torch.Tensor,
|
||||
num_parallel_samples: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
if num_parallel_samples is None:
|
||||
num_parallel_samples = self.num_parallel_samples
|
||||
|
||||
target, time_feat, scale, static_feat = self.create_network_inputs(
|
||||
feat_static_cat,
|
||||
feat_static_real,
|
||||
past_time_feat,
|
||||
past_target,
|
||||
past_observed_values,
|
||||
future_time_feat,
|
||||
)
|
||||
@@ -0,0 +1,125 @@
|
||||
from typing import Iterator, List
|
||||
|
||||
import numpy as np
|
||||
from gluonts.core.component import validated
|
||||
from gluonts.dataset.common import DataEntry
|
||||
from gluonts.dataset.field_names import FieldName
|
||||
from gluonts.transform import (
|
||||
InstanceSplitter,
|
||||
MapTransformation,
|
||||
shift_timestamp,
|
||||
target_transformation_length,
|
||||
)
|
||||
from gluonts.transform.sampler import InstanceSampler
|
||||
|
||||
|
||||
class BroadcastTo(MapTransformation):
|
||||
@validated()
|
||||
def __init__(
|
||||
self,
|
||||
field: str,
|
||||
ext_length: int = 0,
|
||||
target_field: str = FieldName.TARGET,
|
||||
) -> None:
|
||||
self.field = field
|
||||
self.ext_length = ext_length
|
||||
self.target_field = target_field
|
||||
|
||||
def map_transform(self, data: DataEntry, is_train: bool) -> DataEntry:
|
||||
length = target_transformation_length(
|
||||
data[self.target_field], self.ext_length, is_train
|
||||
)
|
||||
data[self.field] = np.broadcast_to(
|
||||
data[self.field],
|
||||
(data[self.field].shape[:-1] + (length,)),
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
class TFTInstanceSplitter(InstanceSplitter):
|
||||
@validated()
|
||||
def __init__(
|
||||
self,
|
||||
instance_sampler: InstanceSampler,
|
||||
past_length: int,
|
||||
future_length: int,
|
||||
target_field: str = FieldName.TARGET,
|
||||
is_pad_field: str = FieldName.IS_PAD,
|
||||
start_field: str = FieldName.START,
|
||||
forecast_start_field: str = FieldName.FORECAST_START,
|
||||
observed_value_field: str = FieldName.OBSERVED_VALUES,
|
||||
lead_time: int = 0,
|
||||
output_NTC: bool = True,
|
||||
time_series_fields: List[str] = [],
|
||||
past_time_series_fields: List[str] = [],
|
||||
dummy_value: float = 0.0,
|
||||
) -> None:
|
||||
|
||||
super().__init__(
|
||||
target_field=target_field,
|
||||
is_pad_field=is_pad_field,
|
||||
start_field=start_field,
|
||||
forecast_start_field=forecast_start_field,
|
||||
instance_sampler=instance_sampler,
|
||||
past_length=past_length,
|
||||
future_length=future_length,
|
||||
lead_time=lead_time,
|
||||
output_NTC=output_NTC,
|
||||
time_series_fields=time_series_fields,
|
||||
dummy_value=dummy_value,
|
||||
)
|
||||
|
||||
assert past_length > 0, "The value of `past_length` should be > 0"
|
||||
assert future_length > 0, "The value of `future_length` should be > 0"
|
||||
|
||||
self.observed_value_field = observed_value_field
|
||||
self.past_ts_fields = past_time_series_fields
|
||||
|
||||
def flatmap_transform(self, data: DataEntry, is_train: bool) -> Iterator[DataEntry]:
|
||||
pl = self.future_length
|
||||
lt = self.lead_time
|
||||
target = data[self.target_field]
|
||||
|
||||
sampled_indices = self.instance_sampler(target)
|
||||
|
||||
slice_cols = (
|
||||
self.ts_fields
|
||||
+ self.past_ts_fields
|
||||
+ [self.target_field, self.observed_value_field]
|
||||
)
|
||||
for i in sampled_indices:
|
||||
pad_length = max(self.past_length - i, 0)
|
||||
d = data.copy()
|
||||
|
||||
for field in slice_cols:
|
||||
if i >= self.past_length:
|
||||
past_piece = d[field][..., i - self.past_length : i]
|
||||
else:
|
||||
pad_block = np.full(
|
||||
shape=d[field].shape[:-1] + (pad_length,),
|
||||
fill_value=self.dummy_value,
|
||||
dtype=d[field].dtype,
|
||||
)
|
||||
past_piece = np.concatenate([pad_block, d[field][..., :i]], axis=-1)
|
||||
future_piece = d[field][..., (i + lt) : (i + lt + pl)]
|
||||
if field in self.ts_fields:
|
||||
piece = np.concatenate([past_piece, future_piece], axis=-1)
|
||||
if self.output_NTC:
|
||||
piece = piece.transpose()
|
||||
d[field] = piece
|
||||
else:
|
||||
if self.output_NTC:
|
||||
past_piece = past_piece.transpose()
|
||||
future_piece = future_piece.transpose()
|
||||
if field not in self.past_ts_fields:
|
||||
d[self._past(field)] = past_piece
|
||||
d[self._future(field)] = future_piece
|
||||
del d[field]
|
||||
else:
|
||||
d[field] = past_piece
|
||||
pad_indicator = np.zeros(self.past_length)
|
||||
if pad_length > 0:
|
||||
pad_indicator[:pad_length] = 1
|
||||
d[self._past(self.is_pad_field)] = pad_indicator
|
||||
d[self.forecast_start_field] = shift_timestamp(d[self.start_field], i + lt)
|
||||
yield d
|
||||
Reference in New Issue
Block a user