"""Efficiency helpers for supervised fine-tuning: replace GELU modules with a fused GELU.

This is the content of ``model/supervised_finetuning/efficiency_utils.py``.

Integration points from the same change set:

* ``model/supervised_finetuning/configs/config.yaml`` gains ``fuse_gelu: true``
  under ``defaults:``.
* ``model/supervised_finetuning/trainer.py`` does
  ``from efficiency_utils import fuse_gelu`` and, when ``training_conf.fuse_gelu``
  is truthy, runs ``model = fuse_gelu(model)`` before building
  ``TrainingArguments``.
"""
import functools

import torch


def rsetattr(obj, attr, val):
    """Set a (possibly dotted) attribute path on ``obj``.

    ``rsetattr(m, "a.b.c", v)`` is equivalent to ``setattr(m.a.b, "c", v)``;
    a path without dots is a plain ``setattr`` on ``obj``.

    Args:
        obj: Root object the path starts from.
        attr: Attribute path, components separated by ``"."``.
        val: Value to assign to the final component.

    Returns:
        Whatever ``setattr`` returns (``None``).
    """
    parent_path, _, leaf = attr.rpartition(".")
    parent = rgetattr(obj, parent_path) if parent_path else obj
    return setattr(parent, leaf, val)


def rgetattr(obj, attr, *args):
    """Resolve a (possibly dotted) attribute path starting at ``obj``.

    Args:
        obj: Root object the path starts from.
        attr: Attribute path, components separated by ``"."``.
        *args: Optional default, forwarded to every ``getattr`` call along
            the path.

    Returns:
        The object found at the end of the path.
    """

    def _getattr(current, name):
        return getattr(current, name, *args)

    return functools.reduce(_getattr, [obj] + attr.split("."))


@torch.jit.script
def _gelu_fwd(x):
    # tanh approximation of GELU; 0.79788456 approximates sqrt(2 / pi).
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))


@torch.jit.script
def _gelu_bwd(g, x):
    # Analytic derivative of _gelu_fwd, scaled by the incoming gradient ``g``.
    # 0.1070322243 == 3 * 0.044715 * 0.79788456 (from differentiating the cubic term).
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
    return ff * g


class _FusedGeLUFunction(torch.autograd.Function):
    """Autograd function pairing the scripted GELU forward with its hand-written backward."""

    @staticmethod
    def forward(ctx, x):
        # save_for_backward (instead of stashing the tensor as a ctx attribute)
        # lets autograd detect in-place modification of ``x`` before backward.
        ctx.save_for_backward(x)
        return _gelu_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return _gelu_bwd(grad_output, x)


class FusedGelu(torch.nn.Module):
    """Stateless module applying the tanh-approximated GELU via ``_FusedGeLUFunction``."""

    def forward(self, input):
        return _FusedGeLUFunction.apply(input)


def fuse_gelu(model, activation_types=None):
    """Replace GELU activation sub-modules of ``model`` with a shared ``FusedGelu``.

    Args:
        model: A ``torch.nn.Module`` whose sub-modules are scanned with
            ``named_modules()``. It is modified in place.
        activation_types: Optional iterable of module classes to replace.
            Defaults to the Hugging Face classes ``GELUActivation``,
            ``FastGELUActivation``, ``NewGELUActivation`` and
            ``QuickGELUActivation`` from ``transformers.activations``.

    Returns:
        ``model`` with every matching sub-module replaced. If ``model`` itself
        is an instance of one of the target classes, the ``FusedGelu`` module
        is returned instead, since a root module cannot be swapped in place.

    NOTE(review): every matched module is replaced by the tanh approximation.
    Presumably not all of the default Hugging Face classes compute exactly that
    formula (e.g. erf-based or sigmoid-based variants), which would slightly
    change a pretrained model's numerics. Confirm this is acceptable, or pass a
    narrower ``activation_types``.
    """
    if activation_types is None:
        # Imported lazily so the helpers above remain usable without
        # ``transformers`` when explicit activation types are supplied.
        from transformers.activations import (
            FastGELUActivation,
            GELUActivation,
            NewGELUActivation,
            QuickGELUActivation,
        )

        activation_types = (GELUActivation, FastGELUActivation, NewGELUActivation, QuickGELUActivation)
    activation_types = tuple(activation_types)

    # The module holds no parameters or buffers, so one instance can be shared
    # by every replacement site.
    fused_gelu_module = FusedGelu()

    if isinstance(model, activation_types):
        return fused_gelu_module

    # Collect the names first so the module tree is not mutated while
    # ``named_modules()`` is still walking it. The root has the empty name
    # and is handled above.
    target_names = [
        name for name, module in model.named_modules() if name and isinstance(module, activation_types)
    ]
    for name in target_names:
        rsetattr(model, name, fused_gelu_module)

    return model