perf: use matmul for lora adapter projections

This commit is contained in:
wassname
2026-05-21 08:23:56 +08:00
parent 56937e1b18
commit ce8c250422
6 changed files with 26 additions and 26 deletions
+2 -2
View File
@@ -224,7 +224,7 @@ class AntiPaSTO:
raise ValueError(f"rotate_basis must be 'U', 'V', or 'none', got {rotate_basis!r}")
S_eff = S + layer.lora_delta_s.to(x.dtype) # (r,)
h = einsum(x, Vh_eff, "... i, r i -> ... r") # x @ Vh_eff.T
h = x @ Vh_eff.T # (..., i) @ (i, r) -> (..., r)
h = h * S_eff # diag(S_eff)
delta = einsum(h, U_eff, "... r, o r -> ... o") # @ U_eff.T
delta = h @ U_eff.T # (..., r) @ (r, o) -> (..., o)
return y + delta
+4 -5
View File
@@ -18,7 +18,6 @@ Refs:
(offline: docs/refs/peft_delora_layer.py)
"""
import torch
from einops import einsum
from jaxtyping import Float
from torch import nn, Tensor as T
from dataclasses import dataclass
@@ -70,11 +69,11 @@ class DeLoRA:
cfg = layer._lora_cfg
A = layer.lora_A # (r, d_in)
B = layer.lora_B # (d_out, r)
x_scaled = x * layer.lora_wnorm # (..., d_in)
h = einsum(x_scaled, A, "... i, r i -> ... r")
x_scaled = x.to(A.dtype) * layer.lora_wnorm.to(A.dtype) # (..., d_in)
h = x_scaled @ A.T
An = torch.clamp(A.norm(dim=1), min=1e-4) # (r,)
Bn = torch.clamp(B.norm(dim=0), min=1e-4) # (r,)
scale = (layer.lora_lambda / cfg.r) / (An * Bn) # (r,)
h = h * scale
delta = einsum(h, B, "... r, o r -> ... o")
return y + delta
delta = h @ B.T
return y + delta.to(y.dtype)
+7 -6
View File
@@ -9,7 +9,6 @@ Refs:
(offline: docs/refs/peft_lora_dora.py)
"""
import torch
from einops import einsum
from jaxtyping import Float
from torch import nn, Tensor as T
from dataclasses import dataclass
@@ -59,14 +58,16 @@ class DoRA:
scale = cfg.alpha / cfg.r
# Paper §4.3: treat ||V+ΔV||_c as a constant (detach from grad graph) for
# stability and ~2x lower memory. Match peft (lora_weight.detach + weight_norm.detach).
BA = einsum(layer.lora_B, layer.lora_A, "o r, r i -> o i")
BA = layer.lora_B @ layer.lora_A
V = layer.weight + scale * BA.detach() # (d_out, d_in)
v_norm = V.norm(dim=1).clamp_min(1e-12).detach() # (d_out,)
# Bias passes through unscaled (matches peft).
bias = getattr(layer, "bias", None)
wx = y if bias is None else (y - bias)
h = einsum(x, layer.lora_A, "... i, r i -> ... r")
delta = einsum(h, layer.lora_B, "... r, o r -> ... o")
combined = wx + scale * delta
xA = x.to(layer.lora_A.dtype)
h = xA @ layer.lora_A.T
delta = h @ layer.lora_B.T
combined = wx + (scale * delta).to(wx.dtype)
out = (layer.lora_m / v_norm) * combined
return out if bias is None else out + bias
out = out if bias is None else out + bias
return out.to(y.dtype)
+5 -4
View File
@@ -13,7 +13,7 @@ Refs:
(offline: docs/refs/peft_eva.py; example: docs/refs/peft_eva_finetuning.py)
"""
import torch
from einops import einsum, rearrange
from einops import rearrange
from jaxtyping import Float
from torch import nn, Tensor as T
from typing import Iterable
@@ -113,6 +113,7 @@ class EVA:
) -> Float[T, '*B o']:
cfg = layer._lora_cfg
scale = cfg.alpha / cfg.r
h = einsum(x, layer.lora_A, "... i, r i -> ... r")
delta = einsum(h, layer.lora_B, "... r, o r -> ... o")
return y + scale * delta
xA = x.to(layer.lora_A.dtype)
h = xA @ layer.lora_A.T
delta = h @ layer.lora_B.T
return y + (scale * delta).to(y.dtype)
+4 -5
View File
@@ -8,10 +8,8 @@ Refs:
- peft: https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/layer.py
(offline: docs/refs/peft_lora_layer.py)
"""
from einops import einsum
from jaxtyping import Float
from torch import nn, Tensor as T
import torch
from dataclasses import dataclass
from ..variant import register, ParamSpec
@@ -47,6 +45,7 @@ class LoRA:
) -> Float[T, '*B o']:
cfg = layer._lora_cfg
scale = cfg.alpha / cfg.r
h = einsum(x, layer.lora_A, "... i, r i -> ... r")
delta = einsum(h, layer.lora_B, "... r, o r -> ... o")
return y + scale * delta
xA = x.to(layer.lora_A.dtype)
h = xA @ layer.lora_A.T
delta = h @ layer.lora_B.T
return y + (scale * delta).to(y.dtype)
+4 -4
View File
@@ -16,7 +16,6 @@ Refs:
(offline: docs/refs/peft_lora_layer.py, see pissa_init path)
"""
import torch
from einops import einsum
from jaxtyping import Float
from torch import nn, Tensor as T
from dataclasses import dataclass
@@ -76,6 +75,7 @@ class PiSSA:
) -> Float[T, '*B o']:
cfg = layer._lora_cfg
scale = cfg.alpha / cfg.r
h = einsum(x, layer.lora_A, "... i, r i -> ... r")
delta = einsum(h, layer.lora_B, "... r, o r -> ... o")
return y + scale * delta
xA = x.to(layer.lora_A.dtype)
h = xA @ layer.lora_A.T
delta = h @ layer.lora_B.T
return y + (scale * delta).to(y.dtype)