from typing import List, Optional

import torch
import torch.nn as nn

from gluonts.core.component import validated
from gluonts.time_feature import get_lags_for_frequency
from gluonts.torch.distributions import DistributionOutput, StudentTOutput
from gluonts.torch.modules.feature import FeatureEmbedder
from gluonts.torch.modules.scaler import MeanScaler, NOPScaler


class TransformerModel(nn.Module):
    """
    Encoder-decoder transformer for probabilistic time series forecasting.

    The network input at each time step is the concatenation of lagged
    (scaled) target values, embedded static categorical features, static
    real features, the log of the scale, and the time features.  The
    decoder output is projected to the parameters of ``distr_output``.

    Parameters
    ----------
    context_length
        Number of time steps fed to the encoder.
    prediction_length
        Number of time steps produced by the decoder.
    num_feat_dynamic_real
        Number of dynamic real (time) features per time step.
    num_feat_static_real
        Number of static real features.
    num_feat_static_cat
        Number of static categorical features.
    cardinality
        Cardinality of each static categorical feature.
    nhead, num_encoder_layers, num_decoder_layers, dim_feedforward,
    activation, dropout
        Forwarded unchanged to ``torch.nn.Transformer``.
    input_size
        Size of the target at each time step (1 for univariate input).
    embedding_dimension
        Embedding size per categorical feature.  When ``None`` (and
        ``cardinality`` is given) ``min(50, (cat + 1) // 2)`` is used.
    distr_output
        Output distribution family.
    lags_seq
        Lag indices used to build the input.  When empty or ``None`` they
        are derived from ``freq`` via ``get_lags_for_frequency``.
    freq
        Frequency string, only used when ``lags_seq`` is not provided.
    scaling
        Whether to use ``MeanScaler`` (``True``) or ``NOPScaler``.
    num_parallel_samples
        Default number of sample paths drawn per series in ``forward``.
    """

    @validated()
    def __init__(
        self,
        context_length: int,
        prediction_length: int,
        num_feat_dynamic_real: int,
        num_feat_static_real: int,
        num_feat_static_cat: int,
        cardinality: List[int],
        # transformer arguments
        nhead: int,
        num_encoder_layers: int,
        num_decoder_layers: int,
        dim_feedforward: int,
        activation: str = "gelu",
        dropout: float = 0.1,
        # univariate input
        input_size: int = 1,
        embedding_dimension: Optional[List[int]] = None,
        distr_output: DistributionOutput = StudentTOutput(),
        lags_seq: Optional[List[int]] = None,
        freq: Optional[str] = None,
        scaling: bool = True,
        num_parallel_samples: int = 100,
    ) -> None:
        super().__init__()

        self.input_size = input_size
        self.target_shape = distr_output.event_shape

        self.num_feat_dynamic_real = num_feat_dynamic_real
        self.num_feat_static_cat = num_feat_static_cat
        self.num_feat_static_real = num_feat_static_real
        self.embedding_dimension = (
            embedding_dimension
            if embedding_dimension is not None or cardinality is None
            else [min(50, (cat + 1) // 2) for cat in cardinality]
        )
        self.lags_seq = lags_seq or get_lags_for_frequency(freq_str=freq)
        self.num_parallel_samples = num_parallel_samples
        self.history_length = context_length + max(self.lags_seq)

        self.embedder = FeatureEmbedder(
            cardinalities=cardinality,
            embedding_dims=self.embedding_dimension,
        )
        if scaling:
            self.scaler = MeanScaler(dim=1, keepdim=True)
        else:
            self.scaler = NOPScaler(dim=1, keepdim=True)

        # total feature size: one slot per lag and target dimension, plus
        # all static/dynamic features
        d_model = (
            self.input_size * len(self.lags_seq) + self._number_of_features
        )

        self.context_length = context_length
        self.prediction_length = prediction_length
        self.distr_output = distr_output
        self.param_proj = distr_output.get_args_proj(d_model)

        # transformer enc-decoder and mask initializer
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
            batch_first=True,
            norm_first=True,
        )

        # causal decoder tgt mask, stored as a buffer so that it follows
        # the module across devices
        self.register_buffer(
            "tgt_mask",
            self.transformer.generate_square_subsequent_mask(
                prediction_length
            ),
        )

    @property
    def _number_of_features(self) -> int:
        """Size of the per-time-step feature vector, excluding the lags."""
        return (
            sum(self.embedding_dimension)
            + self.num_feat_dynamic_real
            + self.num_feat_static_real
            + self.input_size  # the log(scale)
        )

    @property
    def _past_length(self) -> int:
        """Required history length: context window plus the largest lag."""
        return self.context_length + max(self.lags_seq)

    def get_lagged_subsequences(
        self, sequence: torch.Tensor, subsequences_length: int, shift: int = 0
    ) -> torch.Tensor:
        """
        Returns lagged subsequences of a given sequence.

        Parameters
        ----------
        sequence : Tensor
            the sequence from which lagged subsequences should be extracted.
            Shape: (N, T, C).
        subsequences_length : int
            length of the subsequences to be extracted.
        shift: int
            shift the lags by this amount back.

        Returns
        --------
        lagged : Tensor
            a tensor of shape (N, S, C, I), where S = subsequences_length and
            I = len(indices), containing lagged subsequences. Specifically,
            lagged[i, j, :, k] = sequence[i, -indices[k]-S+j, :].

        Raises
        ------
        AssertionError
            if the largest (shifted) lag plus ``subsequences_length``
            exceeds the length of ``sequence`` along dimension 1.
        """
        sequence_length = sequence.shape[1]
        indices = [lag - shift for lag in self.lags_seq]

        assert max(indices) + subsequences_length <= sequence_length, (
            f"lags cannot go further than history length, found lag {max(indices)} "
            f"while history length is only {sequence_length}"
        )

        lagged_values = []
        for lag_index in indices:
            begin_index = -lag_index - subsequences_length
            # a lag of zero must slice up to the end; ``-0`` would yield an
            # empty slice, hence ``None``
            end_index = -lag_index if lag_index > 0 else None
            lagged_values.append(sequence[:, begin_index:end_index, ...])
        return torch.stack(lagged_values, dim=-1)

    def _check_shapes(
        self,
        prior_input: torch.Tensor,
        inputs: torch.Tensor,
        features: Optional[torch.Tensor],
    ) -> None:
        """
        Assert that ``prior_input``, ``inputs`` and ``features`` are
        mutually consistent with ``input_size`` and the feature count.

        Raises
        ------
        AssertionError
            if the ranks differ, a trailing dimension does not match
            ``input_size``, or ``features.shape[2]`` differs from
            ``_number_of_features``.
        """
        assert len(prior_input.shape) == len(inputs.shape)
        assert (
            len(prior_input.shape) == 2 and self.input_size == 1
        ) or prior_input.shape[2] == self.input_size
        assert (len(inputs.shape) == 2 and self.input_size == 1) or inputs.shape[
            -1
        ] == self.input_size
        assert (
            features is None or features.shape[2] == self._number_of_features
        ), f"{features.shape[2]}, expected {self._number_of_features}"

    def create_network_inputs(
        self,
        feat_static_cat: torch.Tensor,
        feat_static_real: torch.Tensor,
        past_time_feat: torch.Tensor,
        past_target: torch.Tensor,
        past_observed_values: torch.Tensor,
        future_time_feat: Optional[torch.Tensor] = None,
        future_target: Optional[torch.Tensor] = None,
    ):
        """
        Build the transformer input from raw targets and features.

        When ``future_target`` is given (training), the returned input
        covers ``context_length + prediction_length`` steps and
        ``future_time_feat`` must be given as well; otherwise it covers
        only the ``context_length`` steps.

        Parameters
        ----------
        feat_static_cat
            Static categorical features, passed to the embedder.
        feat_static_real
            Static real features.
        past_time_feat
            Time features of the past window; only the part after the
            first ``max(lags_seq)`` steps is used.
        past_target
            Past target values; its length along dimension 1 must be
            ``_past_length``.
        past_observed_values
            Observed-value indicator for ``past_target``, passed to the
            scaler.
        future_time_feat
            Time features of the prediction window (training only).
        future_target
            Target values of the prediction window (training only).

        Returns
        -------
        tuple
            ``(transformer_inputs, scale, static_feat)``: the lagged scaled
            targets concatenated with the features, the scale computed
            by the scaler over the context window, and the static feature
            vector (embeddings, static reals, log-scale).

        Raises
        ------
        AssertionError
            if the target length along dimension 1 is not the expected one.
        """
        has_future = future_target is not None

        # time feature: drop the steps only needed for computing lags
        past_context_time_feat = past_time_feat[
            :, self._past_length - self.context_length :, ...
        ]
        time_feat = (
            torch.cat((past_context_time_feat, future_time_feat), dim=1)
            if has_future
            else past_context_time_feat
        )

        # target: the scale is computed on the context window only
        context = past_target[:, -self.context_length :]
        observed_context = past_observed_values[:, -self.context_length :]
        _, scale = self.scaler(context, observed_context)

        inputs = (
            torch.cat((past_target, future_target), dim=1) / scale
            if has_future
            else past_target / scale
        )

        inputs_length = (
            self._past_length + self.prediction_length
            if has_future
            else self._past_length
        )
        assert inputs.shape[1] == inputs_length

        subsequences_length = (
            self.context_length + self.prediction_length
            if has_future
            else self.context_length
        )

        # embeddings
        embedded_cat = self.embedder(feat_static_cat)
        log_scale = (
            scale.log() if self.input_size == 1 else scale.squeeze(1).log()
        )
        static_feat = torch.cat(
            (embedded_cat, feat_static_real, log_scale),
            dim=1,
        )
        # repeat the static features along the time axis
        expanded_static_feat = static_feat.unsqueeze(1).expand(
            -1, time_feat.shape[1], -1
        )

        features = torch.cat((expanded_static_feat, time_feat), dim=-1)

        lagged_sequence = self.get_lagged_subsequences(
            sequence=inputs,
            subsequences_length=subsequences_length,
        )

        # flatten the trailing (target dim, lag) axes into a single one
        lags_shape = lagged_sequence.shape
        reshaped_lagged_sequence = lagged_sequence.reshape(
            lags_shape[0], lags_shape[1], -1
        )

        transformer_inputs = torch.cat(
            (reshaped_lagged_sequence, features), dim=-1
        )

        return transformer_inputs, scale, static_feat

    def output_params(self, transformer_inputs):
        """
        Run encoder and (causally masked) decoder and project to the
        distribution parameters.

        ``transformer_inputs`` is split along dimension 1: the first
        ``context_length`` steps feed the encoder, the rest the decoder.
        The decoder part must have ``prediction_length`` steps to match
        the registered ``tgt_mask``.
        """
        enc_input = transformer_inputs[:, : self.context_length, ...]
        dec_input = transformer_inputs[:, self.context_length :, ...]

        enc_out = self.transformer.encoder(enc_input)
        dec_output = self.transformer.decoder(
            dec_input, enc_out, tgt_mask=self.tgt_mask
        )

        return self.param_proj(dec_output)

    @torch.jit.ignore
    def output_distribution(
        self, params, scale=None, trailing_n=None
    ) -> torch.distributions.Distribution:
        """
        Build the output distribution from projected parameters.

        Parameters
        ----------
        params
            Iterable of parameter tensors as returned by ``param_proj``.
        scale
            Optional scale forwarded to ``distr_output.distribution``.
        trailing_n
            When given, only the last ``trailing_n`` steps (dimension 1) of
            each parameter tensor are used.
        """
        sliced_params = params
        if trailing_n is not None:
            sliced_params = [p[:, -trailing_n:] for p in params]
        return self.distr_output.distribution(sliced_params, scale=scale)

    # for prediction
    def forward(
        self,
        feat_static_cat: torch.Tensor,
        feat_static_real: torch.Tensor,
        past_time_feat: torch.Tensor,
        past_target: torch.Tensor,
        past_observed_values: torch.Tensor,
        future_time_feat: torch.Tensor,
        num_parallel_samples: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Draw sample paths for the prediction window by step-wise decoding.

        Parameters
        ----------
        feat_static_cat, feat_static_real, past_time_feat, past_target,
        past_observed_values
            See ``create_network_inputs``.
        future_time_feat
            Time features of the prediction window.
        num_parallel_samples
            Number of sample paths per series; defaults to
            ``self.num_parallel_samples`` when ``None``.

        Returns
        -------
        Tensor
            Samples reshaped to
            ``(-1, num_parallel_samples, prediction_length) + target_shape``.
        """
        if num_parallel_samples is None:
            num_parallel_samples = self.num_parallel_samples

        encoder_inputs, scale, static_feat = self.create_network_inputs(
            feat_static_cat,
            feat_static_real,
            past_time_feat,
            past_target,
            past_observed_values,
        )

        enc_out = self.transformer.encoder(encoder_inputs)

        # Every per-series tensor is repeated ``num_parallel_samples`` times
        # along the batch axis so that all sample paths are decoded at once.
        # The resolved local value is used (not the attribute) so that the
        # caller's override is honoured.
        repeated_scale = scale.repeat_interleave(
            repeats=num_parallel_samples, dim=0
        )

        repeated_past_target = (
            past_target.repeat_interleave(repeats=num_parallel_samples, dim=0)
            / repeated_scale
        )

        expanded_static_feat = static_feat.unsqueeze(1).expand(
            -1, future_time_feat.shape[1], -1
        )
        features = torch.cat((expanded_static_feat, future_time_feat), dim=-1)
        repeated_features = features.repeat_interleave(
            repeats=num_parallel_samples, dim=0
        )

        repeated_enc_out = enc_out.repeat_interleave(
            repeats=num_parallel_samples, dim=0
        )

        future_samples = []

        # step-wise (ancestral) decoding: each sampled value is appended to
        # the history and feeds the lags of the following steps
        for k in range(self.prediction_length):
            lagged_sequence = self.get_lagged_subsequences(
                sequence=repeated_past_target,
                subsequences_length=1 + k,
                shift=1,
            )

            lags_shape = lagged_sequence.shape
            reshaped_lagged_sequence = lagged_sequence.reshape(
                lags_shape[0], lags_shape[1], -1
            )

            decoder_input = torch.cat(
                (reshaped_lagged_sequence, repeated_features[:, : k + 1]),
                dim=-1,
            )

            # Use the same causal mask as in training (``output_params``);
            # without it, earlier decoder positions would attend to later
            # ones and, with more than one decoder layer, the last position
            # would see hidden states it never saw during training.
            output = self.transformer.decoder(
                decoder_input,
                repeated_enc_out,
                tgt_mask=self.tgt_mask[: k + 1, : k + 1],
            )

            # only the newest position is needed to sample the next value
            params = self.param_proj(output[:, -1:])
            distr = self.output_distribution(params, scale=repeated_scale)
            next_sample = distr.sample()

            repeated_past_target = torch.cat(
                (repeated_past_target, next_sample / repeated_scale), dim=1
            )
            future_samples.append(next_sample)

        concat_future_samples = torch.cat(future_samples, dim=1)
        return concat_future_samples.reshape(
            (-1, num_parallel_samples, self.prediction_length)
            + self.target_shape,
        )