lora-lite/src/lora_lite/variants/dora.py
wassname e624cd244f feat: near_zero/near_one init for trainable params (breaks bf16 dead-grad symmetry)
Trainable params that were init'd at exact 0 or 1 now use near_zero (N(0,1e-4))
or near_one (1 + N(0,1e-4)) to break bf16 symmetry without meaningfully
breaking identity-at-t=0. Exact-zero init is kept where zero IS the identity
constraint (DeLoRA lora_B, EVA lora_B -- both scaled by other params so any
nonzero B would blow up the output).

AntiPaSTO: delta_s and rot_T now near_zero. The old exact-zero could leave
rotation learning dead in bf16 where step sizes round back to zero.

IA3: lora_g now near_one instead of exact ones. Avoids the bf16 spacing issue
around 1.0 where eps_bf16 ~ 7.8e-3 and lr=1e-3 updates were rounding away.
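
The bf16 spacing argument can be checked without any ML dependencies. A minimal
sketch, assuming round-to-nearest-even and ignoring NaN/inf handling; `to_bf16`
is a hypothetical helper written for this illustration, not part of lora-lite:

```python
import struct


def to_bf16(x: float) -> float:
    """Round x to the nearest bfloat16 value (ties to even), returned as a float.

    bfloat16 is the top 16 bits of a float32, so we round away the low 16 bits.
    Assumption: finite inputs only; NaN/inf are not handled in this sketch.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]  # float32 bit pattern
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)          # ties go to the even mantissa
    bits = (((bits + rounding_bias) >> 16) << 16) & 0xFFFFFFFF
    return struct.unpack("<f", struct.pack("<I", bits))[0]


# Spacing just above 1.0 is 2**-7 ~ 7.8e-3, so a 1e-3 step is rounded away.
print(to_bf16(1.0 + 1e-3))     # 1.0
print(to_bf16(1.0 + 2 ** -7))  # 1.0078125
# The same step is representable on its own, far from 1.0.
print(to_bf16(1e-3) != 0.0)    # True
```

This only illustrates the rounding of a value stored in bf16; it says nothing
about how a particular optimizer keeps its master weights.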

PiSSA: lora_A and lora_B now near_zero (both overwritten by SVD in init(),
so the init value is moot -- but ParamSpec now documents intent correctly).

HRA: lora_U now near_zero (overwritten by symmetric init in init()).

ParamSpec: added 'near_zero' and 'near_one' init modes. Default changed from
'zeros' to 'near_zero'. Tests relaxed identity tolerances accordingly.
2026-04-27 15:55:05 +08:00

"""DoRA: weight-decomposed LoRA. Liu et al. 2024 https://arxiv.org/abs/2402.09353
W' = m * V / ||V||_c where V = W + (alpha/r) B A (||.||_c = per-output-row L2 norm)
Identity at t=0: B~0 (near_zero init) and m=||W||_c -> y_new ~= Wx, exact only if B=0.
Requires dense weight (nn.Linear only).
Refs:
- peft: https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/dora.py
(offline: docs/refs/peft_lora_dora.py)
"""
import torch
from einops import einsum
from jaxtyping import Float
from torch import nn, Tensor as T
from dataclasses import dataclass

from ..variant import register, ParamSpec
from ..config import AdapterConfig, register_config


@register_config
@dataclass
class DoRAConfig(AdapterConfig):
    variant: str = "dora"


@register
class DoRA:
    name = "dora"

    @staticmethod
    def param_specs(d_in, d_out, cfg):
        return dict(
            lora_A=ParamSpec((cfg.r, d_in), init="kaiming"),
            lora_B=ParamSpec((d_out, cfg.r), init="near_zero"),
            # m is overwritten with ||W||_c during init(), so the init mode
            # here is moot; shape (d_out,)
            lora_m=ParamSpec((d_out,), init="near_zero"),
        )

    @staticmethod
    def init(layer: nn.Module, cfg) -> None:
        if type(layer) is not nn.Linear:
            raise TypeError(
                "DoRA needs ||W||_c, so v1 only supports plain nn.Linear. "
                "For bnb layers, dequantize first or use LoRA/IA3."
            )
        with torch.no_grad():
            W = layer.weight.data.float()  # (d_out, d_in)
            # The paper's "column" norm is the per-output-row norm in torch's
            # (d_out, d_in) weight layout, hence dim=1.
            col_norm = W.norm(dim=1).to(layer.lora_m.dtype)  # (d_out,)
            layer.lora_m.data.copy_(col_norm)

    @staticmethod
    def forward(
        layer: nn.Module,
        x: Float[T, '*B i'],
        y: Float[T, '*B o'],
    ) -> Float[T, '*B o']:
        cfg = layer._lora_cfg
        scale = cfg.alpha / cfg.r

        # Paper §4.3: treat ||V+ΔV||_c as a constant (detach from grad graph),
        # which the paper motivates as a training-memory saving.
        # Match peft (lora_weight.detach + weight_norm.detach).
        BA = einsum(layer.lora_B, layer.lora_A, "o r, r i -> o i")
        V = layer.weight + scale * BA.detach()  # (d_out, d_in)
        v_norm = V.norm(dim=1).clamp_min(1e-12).detach()  # (d_out,)

        # Bias passes through unscaled (matches peft): strip it from y here,
        # rescale only the Wx + delta part, then add it back.
        bias = getattr(layer, "bias", None)
        wx = y if bias is None else (y - bias)

        h = einsum(x, layer.lora_A, "... i, r i -> ... r")
        delta = einsum(h, layer.lora_B, "... r, o r -> ... o")
        combined = wx + scale * delta
        out = (layer.lora_m / v_norm) * combined
        return out if bias is None else out + bias
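
The identity-at-t=0 claim in the docstring, and the effect of switching lora_B to near_zero, can be checked by hand on a tiny example. The following is a minimal pure-Python sketch of the same arithmetic as `forward` (bias-free, forward pass only, so the `detach` calls have no counterpart here). The helper names `row_norms`, `matvec`, and `dora_forward` and all numeric values are made up for illustration and are not part of lora-lite.

```python
import math


def row_norms(M):
    """Per-row L2 norm of a list-of-lists matrix (the ||.||_c of the docstring)."""
    return [math.sqrt(sum(v * v for v in row)) for row in M]


def matvec(M, x):
    """Plain matrix-vector product."""
    return [sum(a * b for a, b in zip(row, x)) for row in M]


def dora_forward(W, A, B, m, scale, x):
    """out = (m / ||W + scale*BA||_c) * (W x + scale * B A x), no bias."""
    d_out, d_in, r = len(W), len(x), len(A)
    V = [[W[o][i] + scale * sum(B[o][k] * A[k][i] for k in range(r))
          for i in range(d_in)] for o in range(d_out)]
    v_norm = row_norms(V)
    wx = matvec(W, x)                # frozen base output
    delta = matvec(B, matvec(A, x))  # low-rank update B (A x)
    return [m[o] / v_norm[o] * (wx[o] + scale * delta[o]) for o in range(d_out)]


W = [[3.0, 4.0], [0.0, 2.0]]  # (d_out=2, d_in=2)
A = [[0.5, -1.0]]             # (r=1, d_in=2)
x = [1.0, 2.0]
scale = 2.0                   # stands in for alpha / r
m = row_norms(W)              # what init() copies into lora_m: [5.0, 2.0]

base = matvec(W, x)
exact = dora_forward(W, A, [[0.0], [0.0]], m, scale, x)    # exact-zero B
near = dora_forward(W, A, [[1e-4], [-1e-4]], m, scale, x)  # near-zero B
max_dev = max(abs(a - b) for a, b in zip(near, base))

print(base)            # [11.0, 4.0]
print(exact == base)   # True
print(max_dev < 1e-3)  # True
```

With B exactly zero the output equals Wx; with entries of size 1e-4 the deviation is small but nonzero, which is consistent with the commit relaxing the identity tolerances in the tests.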