mirror of
https://github.com/wassname/lora-lite.git
synced 2026-07-04 22:46:01 +08:00
56 lines
2.1 KiB
Python
56 lines
2.1 KiB
Python
"""DoRA: weight-decomposed LoRA. Liu et al. 2024 https://arxiv.org/abs/2402.09353

    W' = m * V / ||V||_c   where V = W + (alpha/r) B A   (||.||_c = per-output-row L2 norm)

m is a trainable magnitude vector of shape (d_out,), initialised to ||W||_c.

At t=0: B=0 -> V=W -> y_new = (m_init / ||W||_c) (Wx + 0) = Wx, because m_init = ||W||_c.

Limitation: computing ||V||_c requires materializing the dense weight V on every forward
pass. v1 therefore supports plain nn.Linear only; bnb 4/8-bit layers raise a TypeError.
"""
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from einops import einsum
|
|
from torch import nn
|
|
|
|
from ..variant import register, ParamSpec
|
|
|
|
|
|
@register
class DoRA:
    """DoRA variant: a trainable per-row magnitude ``lora_m`` times the row-normalised
    direction ``V = W + (alpha/r) B A``.

    Only the weight is decomposed; a bias on the base layer is passed through unscaled.
    """

    name = "dora"

    @staticmethod
    def param_specs(d_in, d_out, cfg):
        """Declare DoRA's adapter parameters for a linear layer with weight (d_out, d_in).

        Returns a dict of parameter name -> ParamSpec, all trainable:
          - lora_A: (r, d_in), ``init="kaiming"``
          - lora_B: (d_out, r), ``init="zeros"`` so that B A = 0 at t=0
          - lora_m: (d_out,), allocated with ``init="zeros"`` and then overwritten
            with ||W||_c by ``init()``
        """
        return {
            "lora_A": ParamSpec((cfg.r, d_in), init="kaiming", trainable=True),
            "lora_B": ParamSpec((d_out, cfg.r), init="zeros", trainable=True),
            # m is filled from ||W||_c during init(); shape (d_out,)
            "lora_m": ParamSpec((d_out,), init="zeros", trainable=True),
        }

    @staticmethod
    def init(layer: nn.Linear, cfg) -> None:
        """Set the magnitude vector ``lora_m`` to the row norms of the frozen weight.

        With m = ||W||_c and B = 0, the adapted layer reproduces the base layer at t=0.

        Args:
            layer: the target layer; must be exactly ``nn.Linear``.
            cfg: adapter config (unused here; part of the variant interface).

        Raises:
            TypeError: if ``layer`` is not exactly ``nn.Linear``. The check is on the
                exact type, so subclasses of nn.Linear are rejected as well.
        """
        if type(layer) is not nn.Linear:
            raise TypeError(
                "DoRA needs ||W||_c, so v1 only supports plain nn.Linear. "
                "For bnb layers, dequantize first or use LoRA/IA3."
            )
        with torch.no_grad():
            # Take the norm in fp32 so low-precision weights don't lose accuracy.
            W = layer.weight.data.float()  # (d_out, d_in)
            row_norm = W.norm(dim=1).to(layer.lora_m.dtype)  # (d_out,): L2 norm per output row
            layer.lora_m.data.copy_(row_norm)

    @staticmethod
    def forward(layer: nn.Linear, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return the DoRA-adapted output for input ``x``.

        Computes ``(m / ||V||_c) * (W x + scale * B A x) + b``: the magnitude rescaling
        applies to the weight term only, never to the bias.

        Args:
            layer: the adapted nn.Linear carrying lora_A, lora_B, lora_m and _lora_cfg.
            x: the layer input, last dim d_in.
            y: the base layer output. NOTE(review): assumed to be the unmodified
                nn.Linear output ``W x + b`` (bias included), as a forward hook receives
                it. Confirm against the caller.
        """
        cfg = layer._lora_cfg
        scale = cfg.alpha / cfg.r
        # V = W + scale * B @ A, materialized densely so its row norms can be taken.
        BA = einsum(layer.lora_B, layer.lora_A, "o r, r i -> o i")
        V = layer.weight + scale * BA  # (d_out, d_in)
        v_norm = V.norm(dim=1).clamp_min(1e-12)  # (d_out,); clamp avoids divide-by-zero
        # Low-rank path applied to the input directly: B (A x), without building BA x.
        h = einsum(x, layer.lora_A, "... i, r i -> ... r")
        delta = einsum(h, layer.lora_B, "... r, o r -> ... o")
        # DoRA decomposes only the weight (W' = m V / ||V||_c). Strip the bias out of y
        # before rescaling and add it back afterwards, so it is not multiplied by
        # m / ||V||_c once m drifts away from ||V||_c during training.
        bias = layer.bias
        base = y if bias is None else y - bias
        out = (layer.lora_m / v_norm) * (base + scale * delta)
        return out if bias is None else out + bias