Files
lora-lite/src/lora_lite/variants/dora.py
T

56 lines
2.1 KiB
Python

"""DoRA: weight-decomposed LoRA. Liu et al. 2024 https://arxiv.org/abs/2402.09353
W' = m * V / ||V||_c where V = W + (alpha/r) B A (||.||_c = per-output-row L2 norm)
At t=0: B=0 -> V=W -> y_new = (m_init / ||W||_c) (Wx + 0) = Wx when m_init = ||W||_c.
Limitation: requires materializing the dense weight to compute ||V||_c. v1 supports
plain nn.Linear only; bnb 4/8-bit layers raise loudly.
"""
import torch
import torch.nn.functional as F
from einops import einsum
from torch import nn
from ..variant import register, ParamSpec
@register
class DoRA:
    """Weight-decomposed LoRA variant.

    The effective weight is ``W' = m * V / ||V||`` with
    ``V = W + (alpha / r) * B @ A``, where the norm is taken per output row
    of the ``(d_out, d_in)`` weight. ``m`` is seeded with ``||W||`` in
    :meth:`init`, so with ``B = 0`` the adapted layer reproduces the base
    layer exactly.
    """

    name = "dora"

    @staticmethod
    def param_specs(d_in, d_out, cfg):
        """Declare the adapter parameters attached to each wrapped layer.

        Args:
            d_in: Input feature count of the wrapped layer.
            d_out: Output feature count of the wrapped layer.
            cfg: Adapter config; only ``cfg.r`` (the rank) is read here.

        Returns:
            Mapping from parameter name to its ``ParamSpec``.
        """
        return {
            "lora_A": ParamSpec((cfg.r, d_in), init="kaiming", trainable=True),
            "lora_B": ParamSpec((d_out, cfg.r), init="zeros", trainable=True),
            # Placeholder zeros; overwritten with ||W|| per output row in init().
            "lora_m": ParamSpec((d_out,), init="zeros", trainable=True),
        }

    @staticmethod
    def init(layer: nn.Linear, cfg) -> None:
        """Seed ``lora_m`` with the per-output-row L2 norm of the base weight.

        Args:
            layer: The wrapped layer; must be exactly ``nn.Linear``.
            cfg: Adapter config (unused here).

        Raises:
            TypeError: If ``layer`` is anything other than a plain
                ``nn.Linear`` (subclasses are rejected on purpose, since the
                dense weight must be available to compute the norm).
        """
        # Exact type check (not isinstance): subclasses of nn.Linear are
        # deliberately rejected.
        if type(layer) is not nn.Linear:
            raise TypeError(
                "DoRA needs ||W||_c, so v1 only supports plain nn.Linear. "
                "For bnb layers, dequantize first or use LoRA/IA3."
            )
        with torch.no_grad():
            # Compute the norm in float32, then cast to lora_m's dtype.
            W = layer.weight.data.float()  # (d_out, d_in)
            row_norm = W.norm(dim=1).to(layer.lora_m.dtype)  # (d_out,)
            layer.lora_m.data.copy_(row_norm)

    @staticmethod
    def forward(layer: nn.Linear, x, y):
        """Apply the DoRA reparameterization on top of the base output.

        Args:
            layer: Wrapped layer carrying ``weight``, ``bias``, ``lora_A``,
                ``lora_B``, ``lora_m`` and ``_lora_cfg``.
            x: Input to the layer; last dimension is contracted with
                ``lora_A``'s input dimension.
            y: Output of the base layer for ``x``.

        Returns:
            ``(m / ||V||) * (W x + scale * B A x)`` plus the layer bias, if
            the layer has one.
        """
        cfg = layer._lora_cfg
        scale = cfg.alpha / cfg.r

        # V = W + scale * B @ A, materialized densely to get its row norms.
        BA = einsum(layer.lora_B, layer.lora_A, "o r, r i -> o i")
        V = layer.weight + scale * BA  # (d_out, d_in)
        # Clamp so an all-zero row cannot cause a division by zero.
        v_norm = V.norm(dim=1).clamp_min(1e-12)  # (d_out,)
        gain = layer.lora_m / v_norm

        # Low-rank path computed on activations: scale * B (A x).
        h = einsum(x, layer.lora_A, "... i, r i -> ... r")
        delta = einsum(h, layer.lora_B, "... r, o r -> ... o")

        # NOTE(review): assumes y is the wrapped nn.Linear's own output, i.e.
        # W x + b when the layer has a bias -- confirm against the hook caller.
        # Only the W x term belongs inside the rescaling; the bias is removed
        # first and added back afterwards so the result is W' x + b rather
        # than W' x + (m / ||V||) * b.
        bias = layer.bias
        if bias is None:
            return gain * (y + scale * delta)
        return gain * ((y - bias) + scale * delta) + bias