mirror of
https://github.com/wassname/lora-lite.git
synced 2026-06-27 17:01:14 +08:00
perf: use matmul for lora adapter projections
This commit is contained in:
@@ -224,7 +224,7 @@ class AntiPaSTO:
|
|||||||
raise ValueError(f"rotate_basis must be 'U', 'V', or 'none', got {rotate_basis!r}")
|
raise ValueError(f"rotate_basis must be 'U', 'V', or 'none', got {rotate_basis!r}")
|
||||||
|
|
||||||
S_eff = S + layer.lora_delta_s.to(x.dtype) # (r,)
|
S_eff = S + layer.lora_delta_s.to(x.dtype) # (r,)
|
||||||
h = einsum(x, Vh_eff, "... i, r i -> ... r") # x @ Vh_eff.T
|
h = x @ Vh_eff.T # (..., i) @ (i, r) -> (..., r)
|
||||||
h = h * S_eff # diag(S_eff)
|
h = h * S_eff # diag(S_eff)
|
||||||
delta = einsum(h, U_eff, "... r, o r -> ... o") # @ U_eff.T
|
delta = h @ U_eff.T # (..., r) @ (r, o) -> (..., o)
|
||||||
return y + delta
|
return y + delta
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ Refs:
|
|||||||
(offline: docs/refs/peft_delora_layer.py)
|
(offline: docs/refs/peft_delora_layer.py)
|
||||||
"""
|
"""
|
||||||
import torch
|
import torch
|
||||||
from einops import einsum
|
|
||||||
from jaxtyping import Float
|
from jaxtyping import Float
|
||||||
from torch import nn, Tensor as T
|
from torch import nn, Tensor as T
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@@ -70,11 +69,11 @@ class DeLoRA:
|
|||||||
cfg = layer._lora_cfg
|
cfg = layer._lora_cfg
|
||||||
A = layer.lora_A # (r, d_in)
|
A = layer.lora_A # (r, d_in)
|
||||||
B = layer.lora_B # (d_out, r)
|
B = layer.lora_B # (d_out, r)
|
||||||
x_scaled = x * layer.lora_wnorm # (..., d_in)
|
x_scaled = x.to(A.dtype) * layer.lora_wnorm.to(A.dtype) # (..., d_in)
|
||||||
h = einsum(x_scaled, A, "... i, r i -> ... r")
|
h = x_scaled @ A.T
|
||||||
An = torch.clamp(A.norm(dim=1), min=1e-4) # (r,)
|
An = torch.clamp(A.norm(dim=1), min=1e-4) # (r,)
|
||||||
Bn = torch.clamp(B.norm(dim=0), min=1e-4) # (r,)
|
Bn = torch.clamp(B.norm(dim=0), min=1e-4) # (r,)
|
||||||
scale = (layer.lora_lambda / cfg.r) / (An * Bn) # (r,)
|
scale = (layer.lora_lambda / cfg.r) / (An * Bn) # (r,)
|
||||||
h = h * scale
|
h = h * scale
|
||||||
delta = einsum(h, B, "... r, o r -> ... o")
|
delta = h @ B.T
|
||||||
return y + delta
|
return y + delta.to(y.dtype)
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ Refs:
|
|||||||
(offline: docs/refs/peft_lora_dora.py)
|
(offline: docs/refs/peft_lora_dora.py)
|
||||||
"""
|
"""
|
||||||
import torch
|
import torch
|
||||||
from einops import einsum
|
|
||||||
from jaxtyping import Float
|
from jaxtyping import Float
|
||||||
from torch import nn, Tensor as T
|
from torch import nn, Tensor as T
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@@ -59,14 +58,16 @@ class DoRA:
|
|||||||
scale = cfg.alpha / cfg.r
|
scale = cfg.alpha / cfg.r
|
||||||
# Paper §4.3: treat ||V+ΔV||_c as a constant (detach from grad graph) for
|
# Paper §4.3: treat ||V+ΔV||_c as a constant (detach from grad graph) for
|
||||||
# stability and ~2x lower memory. Match peft (lora_weight.detach + weight_norm.detach).
|
# stability and ~2x lower memory. Match peft (lora_weight.detach + weight_norm.detach).
|
||||||
BA = einsum(layer.lora_B, layer.lora_A, "o r, r i -> o i")
|
BA = layer.lora_B @ layer.lora_A
|
||||||
V = layer.weight + scale * BA.detach() # (d_out, d_in)
|
V = layer.weight + scale * BA.detach() # (d_out, d_in)
|
||||||
v_norm = V.norm(dim=1).clamp_min(1e-12).detach() # (d_out,)
|
v_norm = V.norm(dim=1).clamp_min(1e-12).detach() # (d_out,)
|
||||||
# Bias passes through unscaled (matches peft).
|
# Bias passes through unscaled (matches peft).
|
||||||
bias = getattr(layer, "bias", None)
|
bias = getattr(layer, "bias", None)
|
||||||
wx = y if bias is None else (y - bias)
|
wx = y if bias is None else (y - bias)
|
||||||
h = einsum(x, layer.lora_A, "... i, r i -> ... r")
|
xA = x.to(layer.lora_A.dtype)
|
||||||
delta = einsum(h, layer.lora_B, "... r, o r -> ... o")
|
h = xA @ layer.lora_A.T
|
||||||
combined = wx + scale * delta
|
delta = h @ layer.lora_B.T
|
||||||
|
combined = wx + (scale * delta).to(wx.dtype)
|
||||||
out = (layer.lora_m / v_norm) * combined
|
out = (layer.lora_m / v_norm) * combined
|
||||||
return out if bias is None else out + bias
|
out = out if bias is None else out + bias
|
||||||
|
return out.to(y.dtype)
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ Refs:
|
|||||||
(offline: docs/refs/peft_eva.py; example: docs/refs/peft_eva_finetuning.py)
|
(offline: docs/refs/peft_eva.py; example: docs/refs/peft_eva_finetuning.py)
|
||||||
"""
|
"""
|
||||||
import torch
|
import torch
|
||||||
from einops import einsum, rearrange
|
from einops import rearrange
|
||||||
from jaxtyping import Float
|
from jaxtyping import Float
|
||||||
from torch import nn, Tensor as T
|
from torch import nn, Tensor as T
|
||||||
from typing import Iterable
|
from typing import Iterable
|
||||||
@@ -113,6 +113,7 @@ class EVA:
|
|||||||
) -> Float[T, '*B o']:
|
) -> Float[T, '*B o']:
|
||||||
cfg = layer._lora_cfg
|
cfg = layer._lora_cfg
|
||||||
scale = cfg.alpha / cfg.r
|
scale = cfg.alpha / cfg.r
|
||||||
h = einsum(x, layer.lora_A, "... i, r i -> ... r")
|
xA = x.to(layer.lora_A.dtype)
|
||||||
delta = einsum(h, layer.lora_B, "... r, o r -> ... o")
|
h = xA @ layer.lora_A.T
|
||||||
return y + scale * delta
|
delta = h @ layer.lora_B.T
|
||||||
|
return y + (scale * delta).to(y.dtype)
|
||||||
|
|||||||
@@ -8,10 +8,8 @@ Refs:
|
|||||||
- peft: https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/layer.py
|
- peft: https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/layer.py
|
||||||
(offline: docs/refs/peft_lora_layer.py)
|
(offline: docs/refs/peft_lora_layer.py)
|
||||||
"""
|
"""
|
||||||
from einops import einsum
|
|
||||||
from jaxtyping import Float
|
from jaxtyping import Float
|
||||||
from torch import nn, Tensor as T
|
from torch import nn, Tensor as T
|
||||||
import torch
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from ..variant import register, ParamSpec
|
from ..variant import register, ParamSpec
|
||||||
@@ -47,6 +45,7 @@ class LoRA:
|
|||||||
) -> Float[T, '*B o']:
|
) -> Float[T, '*B o']:
|
||||||
cfg = layer._lora_cfg
|
cfg = layer._lora_cfg
|
||||||
scale = cfg.alpha / cfg.r
|
scale = cfg.alpha / cfg.r
|
||||||
h = einsum(x, layer.lora_A, "... i, r i -> ... r")
|
xA = x.to(layer.lora_A.dtype)
|
||||||
delta = einsum(h, layer.lora_B, "... r, o r -> ... o")
|
h = xA @ layer.lora_A.T
|
||||||
return y + scale * delta
|
delta = h @ layer.lora_B.T
|
||||||
|
return y + (scale * delta).to(y.dtype)
|
||||||
|
|||||||
@@ -16,7 +16,6 @@ Refs:
|
|||||||
(offline: docs/refs/peft_lora_layer.py, see pissa_init path)
|
(offline: docs/refs/peft_lora_layer.py, see pissa_init path)
|
||||||
"""
|
"""
|
||||||
import torch
|
import torch
|
||||||
from einops import einsum
|
|
||||||
from jaxtyping import Float
|
from jaxtyping import Float
|
||||||
from torch import nn, Tensor as T
|
from torch import nn, Tensor as T
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@@ -76,6 +75,7 @@ class PiSSA:
|
|||||||
) -> Float[T, '*B o']:
|
) -> Float[T, '*B o']:
|
||||||
cfg = layer._lora_cfg
|
cfg = layer._lora_cfg
|
||||||
scale = cfg.alpha / cfg.r
|
scale = cfg.alpha / cfg.r
|
||||||
h = einsum(x, layer.lora_A, "... i, r i -> ... r")
|
xA = x.to(layer.lora_A.dtype)
|
||||||
delta = einsum(h, layer.lora_B, "... r, o r -> ... o")
|
h = xA @ layer.lora_A.T
|
||||||
return y + scale * delta
|
delta = h @ layer.lora_B.T
|
||||||
|
return y + (scale * delta).to(y.dtype)
|
||||||
|
|||||||
Reference in New Issue
Block a user