from typing import List, Optional, Union, Callable

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.transformer import _get_activation_fn, _get_clones

from gluonts.core.component import validated
from gluonts.time_feature import get_lags_for_frequency
from gluonts.torch.distributions import DistributionOutput, StudentTOutput
from gluonts.torch.modules.feature import FeatureEmbedder
from gluonts.torch.modules.scaler import MeanScaler, NOPScaler


class SwitchFeedForward(nn.Module):
    """
    Top-1 ("switch") routing of tokens among several expert layers.

    Every token is sent to the single expert with the highest routing
    probability. Optionally, tokens exceeding an expert's capacity are dropped.
    """

    def __init__(
        self,
        *,
        capacity_factor: float,
        drop_tokens: bool,
        is_scale_prob: bool,
        n_experts: int,
        expert: nn.Module,
        d_model: int,
        dim_feedforward: int,
    ):
        """
        * `capacity_factor` is the capacity of each expert as a factor
          relative to ideally balanced load
        * `drop_tokens` specifies whether to drop tokens if more tokens are
          routed to an expert than the capacity
        * `is_scale_prob` specifies whether to multiply the expert output by
          the routing probability
        * `n_experts` is the number of experts
        * `expert` is the expert layer; it is cloned `n_experts` times
        * `d_model` is the number of features in a token embedding
        * `dim_feedforward` is the number of features produced by each expert
          (the width of the output buffer allocated in `forward`)
        """
        super().__init__()

        self.capacity_factor = capacity_factor
        self.is_scale_prob = is_scale_prob
        self.n_experts = n_experts
        self.drop_tokens = drop_tokens
        self.dim_feedforward = dim_feedforward

        # Make copies of the expert layer (deep copies, so all experts start
        # from the same parameter values as `expert`).
        self.experts = _get_clones(expert, n_experts)
        # Routing layer and softmax
        self.switch = nn.Linear(d_model, n_experts)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x: torch.Tensor):
        """
        Route every token of `x` to one expert and collect the outputs.

        * `x` is the input to the switching module; it must be 3-dimensional
          with `d_model` as the last dimension (the first two dimensions are
          flattened and restored afterwards).

        Returns a 5-tuple:

        * the expert outputs, with the two leading dimensions of `x` and
          `dim_feedforward` features
        * `counts`: number of tokens routed to each expert (before dropping)
        * sum of routing probabilities for each expert
        * number of tokens dropped
        * routing probability of the selected expert for every token
        """
        # Capture the shape to change shapes later
        batch_size, seq_len, d_model = x.shape
        # Flatten the sequence and batch dimensions. `reshape` (rather than
        # `view`) also accepts non-contiguous inputs.
        x = x.reshape(-1, d_model)

        # Get routing probabilities for each of the tokens.
        # $$p_i(x) = \frac{e^{h(x)_i}}{\sum^N_j e^{h(x)_j}}$$
        # where $N$ is the number of experts `n_experts` and
        # $h(\cdot)$ is the linear transformation of token embeddings.
        route_prob = self.softmax(self.switch(x))

        # Get the maximum routing probabilities and the routes.
        # We route to the expert with highest probability
        route_prob_max, routes = torch.max(route_prob, dim=-1)

        # Get indexes of tokens going to each expert
        indexes_list = [
            torch.eq(routes, i).nonzero(as_tuple=True)[0]
            for i in range(self.n_experts)
        ]

        # Initialize an empty tensor to store outputs
        final_output = x.new_zeros((batch_size * seq_len, self.dim_feedforward))

        # Capacity of each expert.
        # $$\mathrm{expert\;capacity} =
        # \frac{\mathrm{tokens\;per\;batch}}{\mathrm{number\;of\;experts}}
        # \times \mathrm{capacity\;factor}$$
        capacity = int(self.capacity_factor * len(x) / self.n_experts)
        # Number of tokens routed to each expert (counted before dropping).
        counts = x.new_tensor([len(indexes) for indexes in indexes_list])

        # Index tensors of the tokens dropped from each over-full expert
        dropped = []
        # Only drop tokens if `drop_tokens` is `True`.
        if self.drop_tokens:
            # Drop tokens in each of the experts
            for i in range(self.n_experts):
                # Ignore if the expert is not over capacity
                if len(indexes_list[i]) <= capacity:
                    continue
                # Shuffle indexes before dropping
                indexes_list[i] = indexes_list[i][
                    torch.randperm(len(indexes_list[i]))
                ]
                # Collect the tokens over capacity as dropped tokens
                dropped.append(indexes_list[i][capacity:])
                # Keep only the tokens upto the capacity of the expert
                indexes_list[i] = indexes_list[i][:capacity]

        # Get outputs of the experts and scatter them into the output buffer
        for i in range(self.n_experts):
            final_output[indexes_list[i], :] = self.experts[i](
                x[indexes_list[i], :]
            )

        n_dropped = 0
        if dropped:
            dropped_indexes = torch.cat(dropped)
            n_dropped = len(dropped_indexes)
            # Dropped tokens can only be passed through unchanged when the
            # input and output widths agree. Otherwise the assignment would
            # fail with a shape mismatch, so their rows are left as zeros.
            if d_model == self.dim_feedforward:
                final_output[dropped_indexes, :] = x[dropped_indexes, :]

        if self.is_scale_prob:
            # Multiply the expert outputs by the probabilities
            # $y = p_i(x) E_i(x)$
            final_output = final_output * route_prob_max.view(-1, 1)
        else:
            # Don't scale the values but multiply by $\frac{p}{\hat{p}} = 1$
            # so that the gradients flow to the router.
            final_output = final_output * (
                route_prob_max / route_prob_max.detach()
            ).view(-1, 1)

        # Change the shape of the final output back to
        # `[batch_size, seq_len, d_ff]`
        final_output = final_output.view(batch_size, seq_len, -1)

        # The extra return values can be used for a load balancing loss and
        # for logging.
        return final_output, counts, route_prob.sum(0), n_dropped, route_prob_max


class SwitchTransformerEncoderLayer(nn.Module):
    """
    Transformer encoder layer whose first feed-forward projection is a
    `SwitchFeedForward` mixture of `n_experts` linear layers.

    The layer is self-attention followed by
    `linear2(dropout(activation(switch(x))))`, each wrapped in a residual
    connection with layer normalisation (pre- or post-norm depending on
    `norm_first`).
    """

    __constants__ = ["batch_first", "norm_first"]

    def __init__(
        self,
        d_model: int,
        nhead: int,
        capacity_factor: float,
        drop_tokens: bool,
        is_scale_prob: bool,
        n_experts: int = 1,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: Union[str, Callable[[torch.Tensor], torch.Tensor]] = F.relu,
        layer_norm_eps: float = 1e-5,
        batch_first: bool = True,
        norm_first: bool = False,
        device=None,
        dtype=None,
    ) -> None:
        """
        * `d_model`, `nhead`, `dropout`, `batch_first` configure the
          `nn.MultiheadAttention` block
        * `capacity_factor`, `drop_tokens`, `is_scale_prob`, `n_experts` are
          forwarded to `SwitchFeedForward`
        * `dim_feedforward` is the hidden width between the expert layer and
          `linear2`
        * `activation` is a callable or a string understood by
          `_get_activation_fn`
        * `norm_first` selects pre-norm (`True`) or post-norm (`False`)
        * `device` / `dtype` are passed to the attention, linear and
          layer-norm sub-modules
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            d_model, nhead, dropout=dropout, batch_first=batch_first, **factory_kwargs
        )

        # Implementation of Feedforward model: the first projection is
        # replaced by a switch over `n_experts` copies of a linear layer.
        linear = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
        self.linear1 = SwitchFeedForward(
            capacity_factor=capacity_factor,
            drop_tokens=drop_tokens,
            is_scale_prob=is_scale_prob,
            n_experts=n_experts,
            expert=linear,
            d_model=d_model,
            dim_feedforward=dim_feedforward,
        )
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)

        # `batch_first` is listed in `__constants__`, so it must exist as an
        # attribute.
        self.batch_first = batch_first
        self.norm_first = norm_first
        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        # Legacy string support for activation function.
        if isinstance(activation, str):
            self.activation = _get_activation_fn(activation)
        else:
            self.activation = activation

    def __setstate__(self, state):
        # States pickled without an activation fall back to ReLU.
        if "activation" not in state:
            state["activation"] = F.relu
        super().__setstate__(state)

    def forward(
        self,
        src: torch.Tensor,
        src_mask: Optional[torch.Tensor] = None,
        src_key_padding_mask: Optional[torch.Tensor] = None,
        is_causal: bool = False,
    ) -> torch.Tensor:
        """
        Apply self-attention and the switch feed-forward block to `src`.

        * `src_mask` / `src_key_padding_mask` are forwarded to the attention
          block as `attn_mask` / `key_padding_mask`
        * `is_causal` is accepted only for compatibility with
          `nn.TransformerEncoder`, which passes this keyword to every layer in
          recent PyTorch releases. It is a hint and is not used here; masking
          is fully determined by `src_mask`.
        """
        x = src
        if self.norm_first:
            x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
            x = x + self._ff_block(self.norm2(x))
        else:
            x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask))
            x = self.norm2(x + self._ff_block(x))

        return x

    # self-attention block
    def _sa_block(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor],
        key_padding_mask: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Self-attention over `x` followed by dropout."""
        x = self.self_attn(
            x,
            x,
            x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )[0]
        return self.dropout1(x)

    # feed forward block
    def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
        """Switch projection, activation, dropout, output projection, dropout."""
        # Only the routed output is used here; the routing statistics
        # returned by `SwitchFeedForward` are discarded.
        x, _, _, _, _ = self.linear1(x)
        x = self.linear2(self.dropout(self.activation(x)))
        return self.dropout2(x)


class SwitchTransformerModel(nn.Module):
    """
    Probabilistic forecasting model: a Switch-Transformer encoder combined
    with the standard `nn.Transformer` decoder, fed with lagged target values
    plus static and time features, and emitting the parameters of
    `distr_output`.
    """

    @validated()
    def __init__(
        self,
        freq: str,
        context_length: int,
        prediction_length: int,
        num_feat_dynamic_real: int,
        num_feat_static_real: int,
        num_feat_static_cat: int,
        cardinality: List[int],
        # switch transformer arguments
        nhead: int,
        num_encoder_layers: int,
        num_decoder_layers: int,
        dim_feedforward: int,
        capacity_factor: float,
        activation: str = "gelu",
        dropout: float = 0.1,
        layer_norm_eps: float = 1e-5,
        drop_tokens: bool = False,
        is_scale_prob: bool = True,
        n_experts: int = 1,
        # univariate input
        input_size: int = 1,
        embedding_dimension: Optional[List[int]] = None,
        distr_output: DistributionOutput = StudentTOutput(),
        lags_seq: Optional[List[int]] = None,
        scaling: bool = True,
        num_parallel_samples: int = 100,
    ) -> None:
        """
        * `freq` is used to derive default lags when `lags_seq` is not given
        * `context_length` / `prediction_length` are the encoder and decoder
          sequence lengths
        * `num_feat_*` and `cardinality` describe the input features
        * `nhead`, `num_encoder_layers`, `num_decoder_layers`,
          `dim_feedforward`, `activation`, `dropout`, `layer_norm_eps`
          configure the transformer
        * `capacity_factor`, `drop_tokens`, `is_scale_prob`, `n_experts`
          configure the switch layers of the encoder
        * `embedding_dimension` defaults to `min(50, (cat + 1) // 2)` per
          categorical feature
        * `scaling` selects `MeanScaler` (`True`) or `NOPScaler` (`False`)
        * `num_parallel_samples` is the default number of sample paths drawn
          per series in `forward`
        """
        super().__init__()

        self.input_size = input_size

        self.target_shape = distr_output.event_shape
        self.num_feat_dynamic_real = num_feat_dynamic_real
        self.num_feat_static_cat = num_feat_static_cat
        self.num_feat_static_real = num_feat_static_real
        self.embedding_dimension = (
            embedding_dimension
            if embedding_dimension is not None or cardinality is None
            else [min(50, (cat + 1) // 2) for cat in cardinality]
        )
        self.lags_seq = lags_seq or get_lags_for_frequency(freq_str=freq)
        self.num_parallel_samples = num_parallel_samples
        self.history_length = context_length + max(self.lags_seq)
        self.embedder = FeatureEmbedder(
            cardinalities=cardinality,
            embedding_dims=self.embedding_dimension,
        )
        if scaling:
            self.scaler = MeanScaler(dim=1, keepdim=True)
        else:
            self.scaler = NOPScaler(dim=1, keepdim=True)

        # total feature size
        d_model = self.input_size * len(self.lags_seq) + self._number_of_features

        self.context_length = context_length
        self.prediction_length = prediction_length
        self.distr_output = distr_output
        self.param_proj = distr_output.get_args_proj(d_model)

        # switch-transformer enc
        switch_encoder_layer = SwitchTransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            capacity_factor=capacity_factor,
            drop_tokens=drop_tokens,
            is_scale_prob=is_scale_prob,
            n_experts=n_experts,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
            layer_norm_eps=layer_norm_eps,
        )
        switch_encoder_norm = nn.LayerNorm(d_model, eps=layer_norm_eps)
        switch_encoder = nn.TransformerEncoder(
            switch_encoder_layer, num_encoder_layers, switch_encoder_norm
        )

        # vanilla decoder and mask initializer
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            custom_encoder=switch_encoder,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
            batch_first=True,
        )

        # causal decoder tgt mask
        self.register_buffer(
            "tgt_mask",
            nn.Transformer.generate_square_subsequent_mask(prediction_length),
        )

    @property
    def _number_of_features(self) -> int:
        """Width of the non-lag part of every transformer input vector."""
        return (
            sum(self.embedding_dimension)
            + self.num_feat_dynamic_real
            + self.num_feat_static_real
            + self.input_size  # the log(scale)
        )

    @property
    def _past_length(self) -> int:
        """Number of past time steps needed: context plus the largest lag."""
        return self.context_length + max(self.lags_seq)

    def get_lagged_subsequences(
        self, sequence: torch.Tensor, subsequences_length: int, shift: int = 0
    ) -> torch.Tensor:
        """
        Returns lagged subsequences of a given sequence.

        Parameters
        ----------
        sequence : Tensor
            the sequence from which lagged subsequences should be extracted.
            Shape: (N, T, C).
        subsequences_length : int
            length of the subsequences to be extracted.
        shift: int
            shift the lags by this amount back.

        Returns
        --------
        lagged : Tensor
            a tensor of shape (N, S, C, I), where S = subsequences_length and
            I = len(indices), containing lagged subsequences. Specifically,
            lagged[i, j, :, k] = sequence[i, -indices[k]-S+j, :].

        Raises
        ------
        AssertionError
            if the largest shifted lag plus `subsequences_length` exceeds the
            length of `sequence` along dimension 1.
        """
        sequence_length = sequence.shape[1]
        indices = [lag - shift for lag in self.lags_seq]

        assert max(indices) + subsequences_length <= sequence_length, (
            f"lags cannot go further than history length, found lag {max(indices)} "
            f"while history length is only {sequence_length}"
        )

        lagged_values = []
        for lag_index in indices:
            begin_index = -lag_index - subsequences_length
            # A lag of 0 means "up to the end"; `-0` would yield an empty slice.
            end_index = -lag_index if lag_index > 0 else None
            lagged_values.append(sequence[:, begin_index:end_index, ...])
        return torch.stack(lagged_values, dim=-1)

    def _check_shapes(
        self,
        prior_input: torch.Tensor,
        inputs: torch.Tensor,
        features: Optional[torch.Tensor],
    ) -> None:
        """Assert that the inputs' ranks and feature widths match the model."""
        assert len(prior_input.shape) == len(inputs.shape)
        assert (
            len(prior_input.shape) == 2 and self.input_size == 1
        ) or prior_input.shape[2] == self.input_size
        assert (len(inputs.shape) == 2 and self.input_size == 1) or inputs.shape[
            -1
        ] == self.input_size
        assert (
            features is None or features.shape[2] == self._number_of_features
        ), f"{features.shape[2]}, expected {self._number_of_features}"

    def create_network_inputs(
        self,
        feat_static_cat: torch.Tensor,
        feat_static_real: torch.Tensor,
        past_time_feat: torch.Tensor,
        past_target: torch.Tensor,
        past_observed_values: torch.Tensor,
        future_time_feat: Optional[torch.Tensor] = None,
        future_target: Optional[torch.Tensor] = None,
    ):
        """
        Build the transformer input sequence from raw model inputs.

        When `future_target` is given, the result covers
        `context_length + prediction_length` steps (and `future_time_feat`
        must be given as well); otherwise it covers `context_length` steps.

        Returns
        -------
        transformer_inputs : Tensor
            lagged, scaled target values concatenated with static and time
            features along the last dimension.
        scale : Tensor
            the scale computed by `self.scaler` on the context window.
        static_feat : Tensor
            embedded categorical features, static real features and
            log(scale), concatenated along dimension 1.
        """
        # time feature: last `context_length` past steps, plus the future
        # steps if targets for them are available
        past_context_time_feat = past_time_feat[
            :, self._past_length - self.context_length :, ...
        ]
        time_feat = (
            torch.cat((past_context_time_feat, future_time_feat), dim=1)
            if future_target is not None
            else past_context_time_feat
        )

        # target: the scale is estimated on the context window only
        context = past_target[:, -self.context_length :]
        observed_context = past_observed_values[:, -self.context_length :]
        _, scale = self.scaler(context, observed_context)

        inputs = (
            torch.cat((past_target, future_target), dim=1) / scale
            if future_target is not None
            else past_target / scale
        )

        inputs_length = (
            self._past_length + self.prediction_length
            if future_target is not None
            else self._past_length
        )
        assert inputs.shape[1] == inputs_length

        subsequences_length = (
            self.context_length + self.prediction_length
            if future_target is not None
            else self.context_length
        )

        # embeddings
        embedded_cat = self.embedder(feat_static_cat)
        log_scale = scale.log() if self.input_size == 1 else scale.squeeze(1).log()
        static_feat = torch.cat(
            (embedded_cat, feat_static_real, log_scale),
            dim=1,
        )
        # Repeat the static features for every time step
        expanded_static_feat = static_feat.unsqueeze(1).expand(
            -1, time_feat.shape[1], -1
        )

        features = torch.cat((expanded_static_feat, time_feat), dim=-1)

        lagged_sequence = self.get_lagged_subsequences(
            sequence=inputs,
            subsequences_length=subsequences_length,
        )

        # Merge all trailing dimensions (target dims and lags) into one
        lags_shape = lagged_sequence.shape
        reshaped_lagged_sequence = lagged_sequence.reshape(
            lags_shape[0], lags_shape[1], -1
        )

        transformer_inputs = torch.cat((reshaped_lagged_sequence, features), dim=-1)

        return transformer_inputs, scale, static_feat

    def output_params(self, transformer_inputs):
        """
        Run encoder and causally masked decoder and project to distribution
        parameters.

        The first `context_length` steps of `transformer_inputs` feed the
        encoder and the remaining steps feed the decoder; the decoder part
        must be `prediction_length` long to match `self.tgt_mask`.
        """
        enc_input = transformer_inputs[:, : self.context_length, ...]
        dec_input = transformer_inputs[:, self.context_length :, ...]

        enc_out = self.transformer.encoder(enc_input)
        dec_output = self.transformer.decoder(
            dec_input, enc_out, tgt_mask=self.tgt_mask
        )

        return self.param_proj(dec_output)

    @torch.jit.ignore
    def output_distribution(
        self, params, scale=None, trailing_n=None
    ) -> torch.distributions.Distribution:
        """
        Build the output distribution from projected parameters.

        If `trailing_n` is given, only the last `trailing_n` steps (along
        dimension 1) of every parameter tensor are used.
        """
        sliced_params = params
        if trailing_n is not None:
            sliced_params = [p[:, -trailing_n:] for p in params]
        return self.distr_output.distribution(sliced_params, scale=scale)

    # for prediction
    def forward(
        self,
        feat_static_cat: torch.Tensor,
        feat_static_real: torch.Tensor,
        past_time_feat: torch.Tensor,
        past_target: torch.Tensor,
        past_observed_values: torch.Tensor,
        future_time_feat: torch.Tensor,
        num_parallel_samples: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Draw sample paths for the next `prediction_length` steps.

        `num_parallel_samples` sample paths are generated per input series;
        when it is `None`, `self.num_parallel_samples` is used.

        Returns a tensor reshaped to
        `(-1, num_parallel_samples, prediction_length) + target_shape`.
        """
        if num_parallel_samples is None:
            num_parallel_samples = self.num_parallel_samples

        encoder_inputs, scale, static_feat = self.create_network_inputs(
            feat_static_cat,
            feat_static_real,
            past_time_feat,
            past_target,
            past_observed_values,
        )

        enc_out = self.transformer.encoder(encoder_inputs)

        # Every series is repeated once per sample path. The resolved
        # `num_parallel_samples` is used so the argument is honoured.
        repeated_scale = scale.repeat_interleave(
            repeats=num_parallel_samples, dim=0
        )

        repeated_past_target = (
            past_target.repeat_interleave(repeats=num_parallel_samples, dim=0)
            / repeated_scale
        )

        expanded_static_feat = static_feat.unsqueeze(1).expand(
            -1, future_time_feat.shape[1], -1
        )
        features = torch.cat((expanded_static_feat, future_time_feat), dim=-1)
        repeated_features = features.repeat_interleave(
            repeats=num_parallel_samples, dim=0
        )

        repeated_enc_out = enc_out.repeat_interleave(
            repeats=num_parallel_samples, dim=0
        )

        future_samples = []

        # Step-by-step decoding: sample one step, append it to the history,
        # and re-run the decoder on the extended prefix.
        for k in range(self.prediction_length):
            lagged_sequence = self.get_lagged_subsequences(
                sequence=repeated_past_target,
                subsequences_length=1 + k,
                shift=1,
            )

            lags_shape = lagged_sequence.shape
            reshaped_lagged_sequence = lagged_sequence.reshape(
                lags_shape[0], lags_shape[1], -1
            )

            decoder_input = torch.cat(
                (reshaped_lagged_sequence, repeated_features[:, : k + 1]), dim=-1
            )

            # Use the same causal masking as `output_params`, so that the
            # hidden states of earlier positions do not see later positions
            # (this matters as soon as there is more than one decoder layer).
            output = self.transformer.decoder(
                decoder_input,
                repeated_enc_out,
                tgt_mask=self.tgt_mask[: k + 1, : k + 1],
            )

            # Only the newest position is needed to sample the next value
            params = self.param_proj(output[:, -1:])
            distr = self.output_distribution(params, scale=repeated_scale)
            next_sample = distr.sample()

            repeated_past_target = torch.cat(
                (repeated_past_target, next_sample / repeated_scale), dim=1
            )
            future_samples.append(next_sample)

        concat_future_samples = torch.cat(future_samples, dim=1)
        return concat_future_samples.reshape(
            (-1, num_parallel_samples, self.prediction_length) + self.target_shape,
        )