mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
56 lines
1.7 KiB
Python
56 lines
1.7 KiB
Python
import functools
|
|
|
|
import torch
|
|
from transformers.activations import QuickGELUActivation, NewGELUActivation, FastGELUActivation, GELUActivation
|
|
|
|
|
|
def rsetattr(obj, attr, val):
    """Set a possibly dotted attribute path on ``obj``.

    ``attr`` is split at its last dot: everything before it names the object
    that owns the attribute (resolved with ``rgetattr``), and the final
    component is the attribute assigned on that owner. When ``attr`` has no
    dot, the attribute is set directly on ``obj``.

    Returns whatever ``setattr`` returns (``None``).
    """
    owner_path, _sep, leaf = attr.rpartition(".")
    if owner_path:
        # Walk down to the object that actually holds the final attribute.
        owner = rgetattr(obj, owner_path)
    else:
        owner = obj
    return setattr(owner, leaf, val)
|
|
|
|
|
|
def rgetattr(obj, attr, *args):
    """Resolve a dotted attribute path on ``obj``, like a recursive ``getattr``.

    Args:
        obj: Object on which the lookup starts.
        attr: Attribute path whose components are separated by ``"."``,
            e.g. ``"encoder.layer.0.act"``.
        *args: Optional single default value, mirroring the third argument of
            the built-in ``getattr``.

    Returns:
        The object found at the end of the path, or the default if one was
        given and any component of the path is missing.

    Raises:
        AttributeError: A component is missing and no default was supplied.
        TypeError: More than one extra positional argument was supplied.
    """
    if len(args) > 1:
        raise TypeError(f"rgetattr expected at most 3 arguments, got {2 + len(args)}")

    current = obj
    for part in attr.split("."):
        try:
            current = getattr(current, part)
        except AttributeError:
            if not args:
                raise
            # Stop at the first missing link. Continuing the walk on the
            # default value could resolve later components against the
            # default itself and return an unrelated attribute.
            return args[0]
    return current
|
|
|
|
|
|
def fuse_gelu(model):
    """Swap Hugging Face GELU activation modules in ``model`` for a fused GELU.

    Every submodule that is an instance of ``GELUActivation``,
    ``FastGELUActivation``, ``NewGELUActivation`` or ``QuickGELUActivation``
    is replaced in place by one shared ``FusedGelu`` module. Its forward and
    backward passes are TorchScript-compiled implementations of the tanh
    approximation of GELU.

    NOTE(review): all four activation classes are replaced by the same tanh
    approximation. If any of them computes a different GELU variant, outputs
    will differ slightly after fusing -- confirm this is acceptable.

    Args:
        model: A ``torch.nn.Module`` whose submodules are scanned.

    Returns:
        The same ``model`` object, modified in place. The root module itself
        is never replaced; only named submodules are.
    """

    @torch.jit.script
    def gelu_fwd(x):
        # tanh approximation of GELU; 0.79788456 ~= sqrt(2 / pi).
        return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))

    @torch.jit.script
    def gelu_bwd(g, x):
        tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
        # Derivative of gelu_fwd with respect to x;
        # 0.1070322243 == 3 * 0.044715 * 0.79788456.
        ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
        return ff * g

    class _FusedGeLUFunction(torch.autograd.Function):
        """Autograd function pairing the scripted forward and backward kernels."""

        @staticmethod
        def forward(ctx, input):
            # Use save_for_backward rather than a plain ctx attribute so that
            # autograd manages the saved tensor (in-place modification checks,
            # release after backward, saved-tensor hooks).
            ctx.save_for_backward(input)
            return gelu_fwd(input)

        @staticmethod
        def backward(ctx, grad_output):
            (saved_input,) = ctx.saved_tensors
            return gelu_bwd(grad_output, saved_input)

    class FusedGelu(torch.nn.Module):
        """Stateless module wrapper around ``_FusedGeLUFunction``."""

        def forward(self, input):
            return _FusedGeLUFunction.apply(input)

    # One instance is shared across all replaced sites.
    fused_gelu_module = FusedGelu()
    hf_gelu_functions = (GELUActivation, FastGELUActivation, NewGELUActivation, QuickGELUActivation)

    # Collect the names first so the module tree is not mutated while the
    # named_modules() generator is still walking it. The root entry has an
    # empty name and cannot be replaced through setattr on itself, so skip it.
    target_names = [
        name
        for name, module in model.named_modules()
        if name and isinstance(module, hf_gelu_functions)
    ]
    for name in target_names:
        rsetattr(model, name, fused_gelu_module)

    return model
|