diff --git a/switch/module.py b/switch/module.py index 8fb828e..036d66a 100644 --- a/switch/module.py +++ b/switch/module.py @@ -131,7 +131,7 @@ class SwitchFeedForward(nn.Module): # Return # # * the final output - # * number of tokens routed to each expert + # * counts: number of tokens routed to each expert # * sum of probabilities for each expert # * number of tokens dropped. # * routing probabilities of the selected experts @@ -140,7 +140,7 @@ class SwitchFeedForward(nn.Module): return final_output, counts, route_prob.sum(0), len(dropped), route_prob_max -class TransformerEncoderLayer(nn.Module): +class SwitchTransformerEncoderLayer(nn.Module): __constants__ = ["batch_first", "norm_first"] @@ -148,7 +148,7 @@ class TransformerEncoderLayer(nn.Module): self, d_model: int, nhead: int, - capacity_factor: int, + capacity_factor: float, drop_tokens: bool, is_scale_prob: bool, n_experts: int = 1, @@ -162,12 +162,13 @@ class TransformerEncoderLayer(nn.Module): dtype=None, ) -> None: factory_kwargs = {"device": device, "dtype": dtype} - super(TransformerEncoderLayer, self).__init__() + super(SwitchTransformerEncoderLayer, self).__init__() self.self_attn = nn.MultiheadAttention( d_model, nhead, dropout=dropout, batch_first=batch_first, **factory_kwargs ) # Implementation of Feedforward model linear = nn.Linear(d_model, dim_feedforward, **factory_kwargs) + self.linear1 = SwitchFeedForward( capacity_factor, drop_tokens, @@ -195,7 +196,7 @@ class TransformerEncoderLayer(nn.Module): def __setstate__(self, state): if "activation" not in state: state["activation"] = F.relu - super(TransformerEncoderLayer, self).__setstate__(state) + super(SwitchTransformerEncoderLayer, self).__setstate__(state) def forward( self, @@ -233,7 +234,8 @@ class TransformerEncoderLayer(nn.Module): # feed forward block def _ff_block(self, x: torch.Tensor) -> torch.Tensor: - x = self.linear2(self.dropout(self.activation(self.linear1(x)))) + x, _, _, _, _ = self.linear1(x) + x = self.linear2(self.dropout(self.activation(x))) return self.dropout2(x) @@ -253,8 +255,13 @@ class SwitchTransformerModel(nn.Module): num_encoder_layers: int, num_decoder_layers: int, dim_feedforward: int, + capacity_factor: float, activation: str = "gelu", dropout: float = 0.1, + layer_norm_eps: float = 1e-5, + drop_tokens: bool = False, + is_scale_prob: bool = True, + n_experts: int = 1, # univariate input input_size: int = 1, embedding_dimension: Optional[List[int]] = None, @@ -296,11 +303,29 @@ class SwitchTransformerModel(nn.Module): self.distr_output = distr_output self.param_proj = distr_output.get_args_proj(d_model) - # transformer enc-decoder and mask initializer + # switch-transformer enc + switch_encoder_layer = SwitchTransformerEncoderLayer( + d_model=d_model, + nhead=nhead, + capacity_factor=capacity_factor, + drop_tokens=drop_tokens, + is_scale_prob=is_scale_prob, + n_experts=n_experts, + dim_feedforward=dim_feedforward, + dropout=dropout, + activation=activation, + layer_norm_eps=layer_norm_eps, + ) + switch_encoder_norm = nn.LayerNorm(d_model, eps=layer_norm_eps) + switch_encoder = nn.TransformerEncoder( + switch_encoder_layer, num_encoder_layers, switch_encoder_norm + ) + + # vanilla decoder and mask initializer self.transformer = nn.Transformer( d_model=d_model, nhead=nhead, - num_encoder_layers=num_encoder_layers, + custom_encoder=switch_encoder, num_decoder_layers=num_decoder_layers, dim_feedforward=dim_feedforward, dropout=dropout, @@ -311,7 +336,7 @@ class SwitchTransformerModel(nn.Module): # causal decoder tgt mask self.register_buffer( "tgt_mask", - self.transformer.generate_square_subsequent_mask(prediction_length), + nn.Transformer.generate_square_subsequent_mask(prediction_length), ) @property