diff --git a/tft/.typesafe b/tft/.typesafe new file mode 100644 index 0000000..e69de29 diff --git a/tft/__init__.py b/tft/__init__.py new file mode 100644 index 0000000..5bde3e4 --- /dev/null +++ b/tft/__init__.py @@ -0,0 +1,9 @@ +from .module import TFTModel +from .lightning_module import TFTLightningModule +from .estimator import TFTEstimator + +__all__ = [ + "TFTModel", + "TFTLightningModule", + "TFTEstimator", +] diff --git a/tft/estimator.py b/tft/estimator.py new file mode 100644 index 0000000..36db208 --- /dev/null +++ b/tft/estimator.py @@ -0,0 +1,301 @@ +from typing import Any, Dict, Iterable, List, Optional + +import torch +from gluonts.core.component import validated +from gluonts.dataset.common import Dataset +from gluonts.dataset.field_names import FieldName +from gluonts.itertools import Cyclic, IterableSlice, PseudoShuffled +from gluonts.time_feature import TimeFeature, time_features_from_frequency_str +from gluonts.torch.model.estimator import PyTorchLightningEstimator +from gluonts.torch.model.predictor import PyTorchPredictor +from gluonts.torch.modules.distribution_output import DistributionOutput, StudentTOutput +from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood +from gluonts.torch.util import IterableDataset +from gluonts.transform import ( + AddAgeFeature, + AddObservedValuesIndicator, + AddTimeFeatures, + AsNumpyArray, + Chain, + ExpectedNumInstanceSampler, + InstanceSplitter, + RemoveFields, + SelectFields, + SetField, + TestSplitSampler, + Transformation, + ValidationSplitSampler, + VstackFeatures, +) +from gluonts.transform.sampler import InstanceSampler +from torch.utils.data import DataLoader + +from .lightning_module import TFTLightningModule +from .module import TFTModel + +PREDICTION_INPUT_NAMES = [ + "feat_static_cat", + "feat_static_real", + "past_time_feat", + "past_target", + "past_observed_values", + "future_time_feat", +] + +TRAINING_INPUT_NAMES = PREDICTION_INPUT_NAMES + [ + "future_target", + "future_observed_values", +] + + +class TFTEstimator(PyTorchLightningEstimator): + @validated() + def __init__( + self, + freq: str, + prediction_length: int, + context_length: Optional[int] = None, + dropout: float = 0.1, + activation: str = "gelu", + embed_dim: int = 32, + num_heads: int = 4, + num_feat_dynamic_real: int = 0, + num_feat_static_cat: int = 0, + num_feat_static_real: int = 0, + cardinality: Optional[List[int]] = None, + embedding_dimension: Optional[List[int]] = None, + distr_output: DistributionOutput = StudentTOutput(), + loss: DistributionLoss = NegativeLogLikelihood(), + scaling: bool = True, + lags_seq: Optional[List[int]] = None, + time_features: Optional[List[TimeFeature]] = None, + num_parallel_samples: int = 100, + batch_size: int = 32, + num_batches_per_epoch: int = 50, + trainer_kwargs: Optional[Dict[str, Any]] = dict(), + train_sampler: Optional[InstanceSampler] = None, + validation_sampler: Optional[InstanceSampler] = None, + ) -> None: + trainer_kwargs = { + "max_epochs": 100, + **trainer_kwargs, + } + super().__init__(trainer_kwargs=trainer_kwargs) + + self.freq = freq + self.context_length = ( + context_length if context_length is not None else prediction_length + ) + self.prediction_length = prediction_length + self.distr_output = distr_output + self.loss = loss + + # MultiheadAttention + self.embed_dim = embed_dim + self.num_heads = num_heads + self.activation = activation + self.dropout = dropout + + self.num_feat_dynamic_real = num_feat_dynamic_real + self.num_feat_static_cat = num_feat_static_cat + self.num_feat_static_real = num_feat_static_real + self.cardinality = ( + cardinality if cardinality and num_feat_static_cat > 0 else [1] + ) + self.embedding_dimension = embedding_dimension + self.scaling = scaling + self.lags_seq = lags_seq + self.time_features = ( + time_features + if time_features is not None + else time_features_from_frequency_str(self.freq) + ) + + self.num_parallel_samples = num_parallel_samples + self.batch_size = batch_size + self.num_batches_per_epoch = num_batches_per_epoch + + self.train_sampler = train_sampler or ExpectedNumInstanceSampler( + num_instances=1.0, min_future=prediction_length + ) + self.validation_sampler = validation_sampler or ValidationSplitSampler( + min_future=prediction_length + ) + + def create_transformation(self) -> Transformation: + remove_field_names = [] + if self.num_feat_static_real == 0: + remove_field_names.append(FieldName.FEAT_STATIC_REAL) + if self.num_feat_dynamic_real == 0: + remove_field_names.append(FieldName.FEAT_DYNAMIC_REAL) + + return Chain( + [RemoveFields(field_names=remove_field_names)] + + ( + [SetField(output_field=FieldName.FEAT_STATIC_CAT, value=[0])] + if not self.num_feat_static_cat > 0 + else [] + ) + + ( + [SetField(output_field=FieldName.FEAT_STATIC_REAL, value=[0.0])] + if not self.num_feat_static_real > 0 + else [] + ) + + [ + AsNumpyArray( + field=FieldName.FEAT_STATIC_CAT, + expected_ndim=1, + dtype=int, + ), + AsNumpyArray( + field=FieldName.FEAT_STATIC_REAL, + expected_ndim=1, + ), + AsNumpyArray( + field=FieldName.TARGET, + # in the following line, we add 1 for the time dimension + expected_ndim=1 + len(self.distr_output.event_shape), + ), + AddObservedValuesIndicator( + target_field=FieldName.TARGET, + output_field=FieldName.OBSERVED_VALUES, + ), + AddTimeFeatures( + start_field=FieldName.START, + target_field=FieldName.TARGET, + output_field=FieldName.FEAT_TIME, + time_features=self.time_features, + pred_length=self.prediction_length, + ), + AddAgeFeature( + target_field=FieldName.TARGET, + output_field=FieldName.FEAT_AGE, + pred_length=self.prediction_length, + log_scale=True, + ), + VstackFeatures( + output_field=FieldName.FEAT_TIME, + input_fields=[FieldName.FEAT_TIME, FieldName.FEAT_AGE] + + ( + [FieldName.FEAT_DYNAMIC_REAL] + if self.num_feat_dynamic_real > 0 + else [] + ), + ), + ] + ) + + def _create_instance_splitter(self, module: TFTLightningModule, mode: str): + assert mode in ["training", "validation", "test"] + + instance_sampler = { + "training": self.train_sampler, + "validation": self.validation_sampler, + "test": TestSplitSampler(), + }[mode] + + return InstanceSplitter( + target_field=FieldName.TARGET, + is_pad_field=FieldName.IS_PAD, + start_field=FieldName.START, + forecast_start_field=FieldName.FORECAST_START, + instance_sampler=instance_sampler, + past_length=module.model._past_length, + future_length=self.prediction_length, + time_series_fields=[ + FieldName.FEAT_TIME, + FieldName.OBSERVED_VALUES, + ], + dummy_value=self.distr_output.value_in_support, + ) + + def create_training_data_loader( + self, + data: Dataset, + module: TFTLightningModule, + shuffle_buffer_length: Optional[int] = None, + **kwargs, + ) -> Iterable: + transformation = self._create_instance_splitter( + module, "training" + ) + SelectFields(TRAINING_INPUT_NAMES) + + training_instances = transformation.apply( + Cyclic(data) + if shuffle_buffer_length is None + else PseudoShuffled( + Cyclic(data), shuffle_buffer_length=shuffle_buffer_length + ) + ) + + return IterableSlice( + iter( + DataLoader( + IterableDataset(training_instances), + batch_size=self.batch_size, + **kwargs, + ) + ), + self.num_batches_per_epoch, + ) + + def create_validation_data_loader( + self, + data: Dataset, + module: TFTLightningModule, + **kwargs, + ) -> Iterable: + transformation = self._create_instance_splitter( + module, "validation" + ) + SelectFields(TRAINING_INPUT_NAMES) + + validation_instances = transformation.apply(data) + + return DataLoader( + IterableDataset(validation_instances), + batch_size=self.batch_size, + **kwargs, + ) + + def create_predictor( + self, + transformation: Transformation, + module: TFTLightningModule, + ) -> PyTorchPredictor: + prediction_splitter = self._create_instance_splitter(module, "test") + + return PyTorchPredictor( + input_transform=transformation + prediction_splitter, + input_names=PREDICTION_INPUT_NAMES, + prediction_net=module.model, + batch_size=self.batch_size, + freq=self.freq, + prediction_length=self.prediction_length, + device=torch.device("cuda" if torch.cuda.is_available() else "cpu"), + ) + + def create_lightning_module(self) -> TFTLightningModule: + model = TFTModel( + freq=self.freq, + context_length=self.context_length, + prediction_length=self.prediction_length, + num_feat_dynamic_real=1 # age + + self.num_feat_dynamic_real + + len(self.time_features), + num_feat_static_real=max(1, self.num_feat_static_real), + num_feat_static_cat=max(1, self.num_feat_static_cat), + cardinality=self.cardinality, + embedding_dimension=self.embedding_dimension, + # transformer arguments + nhead=self.nhead, + dropout=self.dropout, + dim_feedforward=self.dim_feedforward, + # univariate input + input_size=self.input_size, + distr_output=self.distr_output, + lags_seq=self.lags_seq, + scaling=self.scaling, + num_parallel_samples=self.num_parallel_samples, + ) + + return TFTLightningModule(model=model, loss=self.loss) diff --git a/tft/lightning_module.py b/tft/lightning_module.py new file mode 100644 index 0000000..05092f9 --- /dev/null +++ b/tft/lightning_module.py @@ -0,0 +1,81 @@ +import pytorch_lightning as pl +import torch +from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood +from gluonts.torch.util import weighted_average + +from .module import TFTModel + + +class TFTLightningModule(pl.LightningModule): + def __init__( + self, + model: TFTModel, + loss: DistributionLoss = NegativeLogLikelihood(), + lr: float = 1e-3, + weight_decay: float = 1e-8, + ) -> None: + super().__init__() + self.save_hyperparameters() + self.model = model + self.loss = loss + self.lr = lr + self.weight_decay = weight_decay + + def training_step(self, batch, batch_idx: int): + """Execute training step""" + train_loss = self(batch) + self.log( + "train_loss", + train_loss, + on_epoch=True, + on_step=False, + prog_bar=True, + ) + return train_loss + + def validation_step(self, batch, batch_idx: int): + """Execute validation step""" + with torch.inference_mode(): + val_loss = self(batch) + self.log("val_loss", val_loss, on_epoch=True, on_step=False, prog_bar=True) + return val_loss + + def configure_optimizers(self): + """Returns the optimizer to use""" + return torch.optim.Adam( + self.model.parameters(), + lr=self.lr, + weight_decay=self.weight_decay, + ) + + def forward(self, batch): + feat_static_cat = batch["feat_static_cat"] + feat_static_real = batch["feat_static_real"] + past_time_feat = batch["past_time_feat"] + past_target = batch["past_target"] + past_observed_values = batch["past_observed_values"] + + future_time_feat = batch["future_time_feat"] + future_target = batch["future_target"] + future_observed_values = batch["future_observed_values"] + + tft_inputs, scale, _ = self.model.create_network_inputs( + feat_static_cat=feat_static_cat, + feat_static_real=feat_static_real, + past_time_feat=past_time_feat, + past_target=past_target, + past_observed_values=past_observed_values, + future_time_feat=future_time_feat, + future_target=future_target, + ) + params = self.model.output_params(tft_inputs) + distr = self.model.output_distribution(params, scale) + + loss_values = self.loss(distr, future_target) + + if len(self.model.target_shape) == 0: + loss_weights = future_observed_values + else: + loss_weights = future_observed_values.min(dim=-1, keepdim=False) + + return weighted_average(loss_values, weights=loss_weights) diff --git a/tft/module.py b/tft/module.py new file mode 100644 index 0000000..a0c5286 --- /dev/null +++ b/tft/module.py @@ -0,0 +1,538 @@ +from typing import List, Optional, Tuple +from sympy import fu + +import torch +import torch.nn as nn +from gluonts.core.component import validated +from gluonts.time_feature import get_lags_for_frequency +from gluonts.torch.modules.distribution_output import DistributionOutput, StudentTOutput +from gluonts.torch.modules.feature import FeatureEmbedder as BaseFeatureEmbedder +from gluonts.torch.modules.scaler import MeanScaler, NOPScaler + + +class FeatureEmbedder(BaseFeatureEmbedder): + def forward(self, features: torch.Tensor) -> List[torch.Tensor]: + concat_features = super(FeatureEmbedder, self).forward(features=features) + + if self.__num_features > 1: + features = torch.chunk(concat_features, self.__num_features, dim=-1) + else: + features = [concat_features] + + return features + + +class GatedResidualNetwork(nn.Module): + def __init__( + self, + d_hidden: int, + d_input: Optional[int] = None, + d_output: Optional[int] = None, + d_static: Optional[int] = None, + dropout: float = 0.0, + ): + super().__init__() + + d_input = d_input or d_hidden + d_static = d_static or 0 + if d_output is None: + d_output = d_input + self.add_skip = False + else: + if d_output != d_input: + self.add_skip = True + self.skip_proj = nn.Linear(in_features=d_input, out_features=d_output) + else: + self.add_skip = False + + self.mlp = nn.Sequential( + nn.Linear(in_features=d_input + d_static, out_features=d_hidden), + nn.ELU(), + nn.Linear(in_features=d_hidden, out_features=d_hidden), + nn.Dropout(p=dropout), + nn.Linear(in_features=d_hidden, out_features=d_output * 2), + nn.GLU(), + ) + + self.lnorm = nn.LayerNorm(d_output) + + def forward( + self, x: torch.Tensor, c: Optional[torch.Tensor] = None + ) -> torch.Tensor: + if self.add_skip: + skip = self.skip_proj(x) + else: + skip = x + + if c is not None: + x = torch.cat((x, c), dim=-1) + x = self.mlp(x) + x = self.lnorm(x + skip) + return x + + +class VariableSelectionNetwork(nn.Module): + def __init__( + self, + d_hidden: int, + n_vars: int, + dropout: float = 0.0, + add_static: bool = False, + ): + super().__init__() + self.weight_network = GatedResidualNetwork( + d_hidden=d_hidden, + d_input=d_hidden * n_vars, + d_output=n_vars, + d_static=d_hidden if add_static else None, + dropout=dropout, + ) + + self.variable_network = nn.ModuleList( + [ + GatedResidualNetwork(d_hidden=d_hidden, dropout=dropout) + for _ in range(n_vars) + ] + ) + + def forward( + self, variables: List[torch.Tensor], static: Optional[torch.Tensor] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + flatten = torch.cat(variables, dim=-1) + if static is not None: + static = static.expand_as(variables[0]) + weight = self.weight_network(flatten, static) + weight = torch.softmax(weight.unsqueeze(-2), dim=-1) + + var_encodings = [net(var) for var, net in zip(variables, self.variable_network)] + var_encodings = torch.stack(var_encodings, dim=-1) + + var_encodings = torch.sum(var_encodings * weight, dim=-1) + + return var_encodings, weight + + +class TemporalFusionEncoder(nn.Module): + def __init__( + self, + d_input: int, + d_hidden: int, + ): + super().__init__() + + self.encoder_lstm = nn.LSTM( + input_size=d_input, hidden_size=d_hidden, batch_first=True + ) + self.decoder_lstm = nn.LSTM( + input_size=d_input, hidden_size=d_hidden, batch_first=True + ) + + self.gate = nn.Sequential( + nn.Linear(in_features=d_hidden, out_features=d_hidden * 2), + nn.GLU(), + ) + if d_input != d_hidden: + self.skip_proj = nn.Linear(in_features=d_input, out_features=d_hidden) + self.add_skip = True + else: + self.add_skip = False + + self.lnorm = nn.LayerNorm(d_hidden) + + def forward( + self, + ctx_input: torch.Tensor, + tgt_input: torch.Tensor, + states: List[torch.Tensor], + ): + ctx_encodings, states = self.encoder_lstm(ctx_input, states) + + tgt_encodings, _ = self.decoder_lstm(tgt_input, states) + + encodings = torch.cat((ctx_encodings, tgt_encodings), dim=1) + skip = torch.cat((ctx_input, tgt_input), dim=1) + if self.add_skip: + skip = self.skip_proj(skip) + encodings = self.gate(encodings) + encodings = self.lnorm(skip + encodings) + return encodings + + +class TemporalFusionDecoder(nn.Module): + def __init__( + self, + context_length: int, + prediction_length: int, + d_hidden: int, + d_var: int, + n_head: int, + dropout: float = 0.0, + ): + super().__init__() + self.context_length = context_length + self.prediction_length = prediction_length + + self.enrich = GatedResidualNetwork( + d_hidden=d_hidden, + d_static=d_var, + dropout=dropout, + ) + + self.attention = nn.MultiheadAttention( + embed_dim=d_hidden, + num_heads=n_head, + dropout=dropout, + batch_first=True, + ) + + self.att_net = nn.Sequential( + nn.Linear(in_features=d_hidden, out_features=d_hidden * 2), + nn.GLU(), + ) + self.att_lnorm = nn.LayerNorm(d_hidden) + + self.ff_net = nn.Sequential( + GatedResidualNetwork(d_hidden=d_hidden, dropout=dropout), + nn.Linear(in_features=d_hidden, out_features=d_hidden * 2), + nn.GLU(), + ) + self.ff_lnorm = nn.LayerNorm(d_hidden) + + self.register_buffer( + "attn_mask", + self._generate_subsequent_mask( + prediction_length, prediction_length + context_length + ), + ) + + @staticmethod + def _generate_subsequent_mask( + target_length: int, source_length: int + ) -> torch.Tensor: + mask = (torch.triu(torch.ones(source_length, target_length)) == 1).transpose( + 0, 1 + ) + mask = ( + mask.float() + .masked_fill(mask == 0, float("-inf")) + .masked_fill(mask == 1, float(0.0)) + ) + return mask + + def forward( + self, x: torch.Tensor, static: torch.Tensor, mask: torch.Tensor + ) -> torch.Tensor: + static = static.repeat((1, self.context_length + self.prediction_length, 1)) + + skip = x[:, self.context_length :, ...] + x = self.enrich(x, static) + + # does not work on GPU :-( + # mask_pad = torch.ones_like(mask)[:, 0:1, ...] + # mask_pad = mask_pad.repeat((1, self.prediction_length)) + # key_padding_mask = torch.cat((mask, mask_pad), dim=1).bool() + + query_key_value = x + + attn_output, _ = self.attention( + query=query_key_value[-self.prediction_length :, ...], + key=query_key_value, + value=query_key_value, + # key_padding_mask=key_padding_mask, + attn_mask=self.attn_mask, + ) + att = self.att_net(attn_output) + + x = x[:, self.context_length :, ...] + x = self.att_lnorm(x + att) + x = self.ff_net(x) + x = self.ff_lnorm(x + skip) + + return x + + +class TFTModel(nn.Module): + @validated() + def __init__( + self, + freq: str, + context_length: int, + prediction_length: int, + num_feat_dynamic_real: int, + num_feat_static_real: int, + num_feat_static_cat: int, + cardinality: List[int], + # TFT inputs + nhead: int, + hidden_dim: int, + variable_dim: int, + # univariate input + input_size: int = 1, + embedding_dimension: Optional[List[int]] = None, + distr_output: DistributionOutput = StudentTOutput(), + lags_seq: Optional[List[int]] = None, + scaling: bool = True, + num_parallel_samples: int = 100, + ) -> None: + super().__init__() + + self.input_size = input_size + + self.target_shape = distr_output.event_shape + self.num_feat_dynamic_real = num_feat_dynamic_real + self.num_feat_static_cat = num_feat_static_cat + self.num_feat_static_real = num_feat_static_real + self.embedding_dimension = ( + embedding_dimension + if embedding_dimension is not None or cardinality is None + else [min(50, (cat + 1) // 2) for cat in cardinality] + ) + self.lags_seq = lags_seq or get_lags_for_frequency(freq_str=freq) + self.num_parallel_samples = num_parallel_samples + self.history_length = context_length + max(self.lags_seq) + self.embedder = FeatureEmbedder( + cardinalities=cardinality, + embedding_dims=self.embedding_dimension, + ) + if scaling: + self.scaler = MeanScaler(dim=1, keepdim=True) + else: + self.scaler = NOPScaler(dim=1, keepdim=True) + + self.context_length = context_length + self.prediction_length = prediction_length + self.distr_output = distr_output + + # projection networks + self.target_proj = nn.Linear( + in_features=input_size * len(self.lags_seq), out_features=variable_dim + ) + + self.dynamic_proj = nn.Linear( + in_features=num_feat_dynamic_real, out_features=variable_dim + ) + + self.static_proj = nn.Linear( + in_features=sum(self.embedding_dimension) + self.num_feat_static_real + 1, + out_features=variable_dim, + ) + + # variable selection networks + self.past_selection = VariableSelectionNetwork( + d_hidden=variable_dim, + n_vars=input_size * len(self.lags_seq) + num_feat_dynamic_real, + dropout=dropout, + add_static=True, + ) + + self.future_selection = VariableSelectionNetwork( + d_hidden=variable_dim, + n_vars=input_size * len(self.lags_seq) + num_feat_dynamic_real, + dropout=dropout, + add_static=True, + ) + + self.static_selection = VariableSelectionNetwork( + d_hidden=variable_dim, + n_vars=sum(self.embedding_dimension) + self.num_feat_static_real + 1, + dropout=dropout, + ) + + # Static Gated Residual Networks + self.selection = GatedResidualNetwork( + d_hidden=variable_dim, + dropout=dropout, + ) + + self.enrichment = GatedResidualNetwork( + d_hidden=variable_dim, + dropout=dropout, + ) + + # Encoder and Decoder network + self.temporal_encoder = TemporalFusionEncoder( + d_input=variable_dim, + d_hidden=embed_dim, + ) + self.temporal_decoder = TemporalFusionDecoder( + context_length=self.context_length, + prediction_length=self.prediction_length, + d_hidden=embed_dim, + d_var=variable_dim, + n_head=nhead, + dropout=dropout, + ) + + # TODO + self.param_proj = distr_output.get_args_proj(embed_dim) + + # TODO + + @property + def _past_length(self) -> int: + return self.context_length + max(self.lags_seq) + + def get_lagged_subsequences( + self, sequence: torch.Tensor, subsequences_length: int, shift: int = 0 + ) -> torch.Tensor: + """ + Returns lagged subsequences of a given sequence. + Parameters + ---------- + sequence : Tensor + the sequence from which lagged subsequences should be extracted. + Shape: (N, T, C). + subsequences_length : int + length of the subsequences to be extracted. + shift: int + shift the lags by this amount back. + Returns + -------- + lagged : Tensor + a tensor of shape (N, S, C, I), where S = subsequences_length and + I = len(indices), containing lagged subsequences. Specifically, + lagged[i, j, :, k] = sequence[i, -indices[k]-S+j, :]. + """ + sequence_length = sequence.shape[1] + indices = [lag - shift for lag in self.lags_seq] + + assert max(indices) + subsequences_length <= sequence_length, ( + f"lags cannot go further than history length, found lag {max(indices)} " + f"while history length is only {sequence_length}" + ) + + lagged_values = [] + for lag_index in indices: + begin_index = -lag_index - subsequences_length + end_index = -lag_index if lag_index > 0 else None + lagged_values.append(sequence[:, begin_index:end_index, ...]) + return torch.stack(lagged_values, dim=-1) + + def create_network_inputs( + self, + feat_static_cat: torch.Tensor, + feat_static_real: torch.Tensor, + past_time_feat: torch.Tensor, + past_target: torch.Tensor, + past_observed_values: torch.Tensor, + future_time_feat: Optional[torch.Tensor] = None, + future_target: Optional[torch.Tensor] = None, + ): + # time feature + time_feat = ( + past_time_feat[:, self._past_length - self.context_length :, ...] + if future_time_feat is None or future_target is None + else torch.cat( + ( + past_time_feat[:, self._past_length - self.context_length :, ...], + future_time_feat, + ), + dim=1, + ) + ) + + # calculate scale + context = past_target[:, -self.context_length :] + observed_context = past_observed_values[:, -self.context_length :] + _, scale = self.scaler(context, observed_context) + + # scale the target and create lag features of targets + target = ( + torch.cat((past_target, future_target), dim=1) / scale + if future_target is not None + else past_target / scale + ) + subsequences_length = ( + self.context_length + if future_time_feat is None or future_target is None + else self.context_length + self.prediction_length + ) + + lagged_target = self.get_lagged_subsequences( + sequence=target, + subsequences_length=subsequences_length, + ) + lags_shape = lagged_target.shape + reshaped_lagged_target = lagged_target.reshape(lags_shape[0], lags_shape[1], -1) + + # embeddings + embedded_cat = self.embedder(feat_static_cat) + static_feat = torch.cat( + (embedded_cat, feat_static_real, scale.log()), + dim=1, + ) + + # return the network inputs + return ( + reshaped_lagged_target, # target + time_feat, # dynamic real covariates + scale, # scale + static_feat, # static covariates + ) + + def output_params(self, target, time_feat, static_feat): + target_proj = self.target_proj(target) + + past_target_proj = target_proj[:, : self.context_length, ...] + future_target_proj = target_proj[:, self.context_length :, ...] + + time_feat_proj = self.dynamic_proj(time_feat) + past_time_feat_proj = time_feat_proj[:, : self.context_length, ...] + future_time_feat_proj = time_feat_proj[:, self.context_length :, ...] + + static_feat_proj = self.static_proj(static_feat) + + static_var, _ = self.static_selection([static_feat_proj]) + static_selection = self.selection(static_var).unsqueeze(1) + static_enrichment = self.enrichment(static_var).unsqueeze(1) + + past_selection, _ = self.past_selection( + [past_target_proj, past_time_feat_proj], static_selection + ) + + future_selection, _ = self.future_selection( + [future_target_proj, future_time_feat_proj], static_selection + ) + + encoding = self.temporal_encoder(past_selection, future_selection) + + decoding = self.temporal_decoder(encoding) + + return ( + past_target_proj, + future_target_proj, + past_time_feat_proj, + static_feat_proj, + ) + + @torch.jit.ignore + def output_distribution( + self, params, scale=None, trailing_n=None + ) -> torch.distributions.Distribution: + sliced_params = params + if trailing_n is not None: + sliced_params = [p[:, -trailing_n:] for p in params] + return self.distr_output.distribution(sliced_params, scale=scale) + + # for prediction + def forward( + self, + feat_static_cat: torch.Tensor, + feat_static_real: torch.Tensor, + past_time_feat: torch.Tensor, + past_target: torch.Tensor, + past_observed_values: torch.Tensor, + future_time_feat: torch.Tensor, + num_parallel_samples: Optional[int] = None, + ) -> torch.Tensor: + if num_parallel_samples is None: + num_parallel_samples = self.num_parallel_samples + + target, time_feat, scale, static_feat = self.create_network_inputs( + feat_static_cat, + feat_static_real, + past_time_feat, + past_target, + past_observed_values, + future_time_feat, + ) diff --git a/tft/transforms.py b/tft/transforms.py new file mode 100644 index 0000000..3d7931d --- /dev/null +++ b/tft/transforms.py @@ -0,0 +1,125 @@ +from typing import Iterator, List + +import numpy as np +from gluonts.core.component import validated +from gluonts.dataset.common import DataEntry +from gluonts.dataset.field_names import FieldName +from gluonts.transform import ( + InstanceSplitter, + MapTransformation, + shift_timestamp, + target_transformation_length, +) +from gluonts.transform.sampler import InstanceSampler + + +class BroadcastTo(MapTransformation): + @validated() + def __init__( + self, + field: str, + ext_length: int = 0, + target_field: str = FieldName.TARGET, + ) -> None: + self.field = field + self.ext_length = ext_length + self.target_field = target_field + + def map_transform(self, data: DataEntry, is_train: bool) -> DataEntry: + length = target_transformation_length( + data[self.target_field], self.ext_length, is_train + ) + data[self.field] = np.broadcast_to( + data[self.field], + (data[self.field].shape[:-1] + (length,)), + ) + return data + + +class TFTInstanceSplitter(InstanceSplitter): + @validated() + def __init__( + self, + instance_sampler: InstanceSampler, + past_length: int, + future_length: int, + target_field: str = FieldName.TARGET, + is_pad_field: str = FieldName.IS_PAD, + start_field: str = FieldName.START, + forecast_start_field: str = FieldName.FORECAST_START, + observed_value_field: str = FieldName.OBSERVED_VALUES, + lead_time: int = 0, + output_NTC: bool = True, + time_series_fields: List[str] = [], + past_time_series_fields: List[str] = [], + dummy_value: float = 0.0, + ) -> None: + + super().__init__( + target_field=target_field, + is_pad_field=is_pad_field, + start_field=start_field, + forecast_start_field=forecast_start_field, + instance_sampler=instance_sampler, + past_length=past_length, + future_length=future_length, + lead_time=lead_time, + output_NTC=output_NTC, + time_series_fields=time_series_fields, + dummy_value=dummy_value, + ) + + assert past_length > 0, "The value of `past_length` should be > 0" + assert future_length > 0, "The value of `future_length` should be > 0" + + self.observed_value_field = observed_value_field + self.past_ts_fields = past_time_series_fields + + def flatmap_transform(self, data: DataEntry, is_train: bool) -> Iterator[DataEntry]: + pl = self.future_length + lt = self.lead_time + target = data[self.target_field] + + sampled_indices = self.instance_sampler(target) + + slice_cols = ( + self.ts_fields + + self.past_ts_fields + + [self.target_field, self.observed_value_field] + ) + for i in sampled_indices: + pad_length = max(self.past_length - i, 0) + d = data.copy() + + for field in slice_cols: + if i >= self.past_length: + past_piece = d[field][..., i - self.past_length : i] + else: + pad_block = np.full( + shape=d[field].shape[:-1] + (pad_length,), + fill_value=self.dummy_value, + dtype=d[field].dtype, + ) + past_piece = np.concatenate([pad_block, d[field][..., :i]], axis=-1) + future_piece = d[field][..., (i + lt) : (i + lt + pl)] + if field in self.ts_fields: + piece = np.concatenate([past_piece, future_piece], axis=-1) + if self.output_NTC: + piece = piece.transpose() + d[field] = piece + else: + if self.output_NTC: + past_piece = past_piece.transpose() + future_piece = future_piece.transpose() + if field not in self.past_ts_fields: + d[self._past(field)] = past_piece + d[self._future(field)] = future_piece + del d[field] + else: + d[field] = past_piece + pad_indicator = np.zeros(self.past_length) + if pad_length > 0: + pad_indicator[:pad_length] = 1 + d[self._past(self.is_pad_field)] = pad_indicator + d[self.forecast_start_field] = shift_timestamp(d[self.start_field], i + lt) + yield d