mirror of
https://github.com/wassname/lora-lite.git
synced 2026-06-27 18:05:16 +08:00
e624cd244f
Trainable params that were init'd at exact 0 or 1 now use near_zero (N(0,1e-4)) or near_one (1 + N(0,1e-4)) to break bf16 symmetry without meaningfully breaking identity-at-t=0. Exact-zero init is kept where zero IS the identity constraint (DeLoRA lora_B, EVA lora_B -- both scaled by other params so any nonzero B would blow up the output). AntiPaSTO: delta_s and rot_T now near_zero. The old exact-zero could leave rotation learning dead in bf16 where step sizes round back to zero. IA3: lora_g now near_one instead of exact ones. Avoids the bf16 spacing issue around 1.0 where eps_bf16 ~ 7.8e-3 and lr=1e-3 updates were rounding away. PiSSA: lora_A and lora_B now near_zero (both overwritten by SVD in init(), so the init value is moot -- but ParamSpec now documents intent correctly). HRA: lora_U now near_zero (overwritten by symmetric init in init()). ParamSpec: added 'near_zero' and 'near_one' init modes. Default changed from 'zeros' to 'near_zero'. Tests relaxed identity tolerances accordingly.
73 lines
2.6 KiB
Python
73 lines
2.6 KiB
Python
"""DoRA: weight-decomposed LoRA. Liu et al. 2024 https://arxiv.org/abs/2402.09353
|
|
|
|
W' = m * V / ||V||_c where V = W + (alpha/r) B A (||.||_c = per-output-row L2 norm)
|
|
|
|
Identity at t=0 (approximately): B is initialised near zero rather than exactly zero, so V ≈ W, m ≈ ||V||_c and y_new ≈ Wx. Requires a dense weight (nn.Linear only).
|
|
|
|
Refs:
|
|
- peft: https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/dora.py
|
|
(offline: docs/refs/peft_lora_dora.py)
|
|
"""
|
|
import torch
|
|
from einops import einsum
|
|
from jaxtyping import Float
|
|
from torch import nn, Tensor as T
|
|
from dataclasses import dataclass
|
|
|
|
from ..variant import register, ParamSpec
|
|
from ..config import AdapterConfig, register_config
|
|
|
|
|
|
@register_config
@dataclass
class DoRAConfig(AdapterConfig):
    """Adapter config for the DoRA variant.

    Declares no new fields; it only pins ``variant`` to "dora". Everything else
    DoRA reads (e.g. ``r`` and ``alpha``) is presumably inherited from
    AdapterConfig -- confirm against config.py.
    """

    variant: str = "dora"
|
|
|
|
|
|
@register
class DoRA:
    """DoRA variant: LoRA direction update plus a separately learned row magnitude.

    Per wrapped layer:
      lora_A (r, d_in), lora_B (d_out, r): low-rank update, scaled by alpha / r.
      lora_m (d_out,): magnitude vector, filled by ``init`` from the row norms of
          V = W + (alpha / r) B A.
    """

    name = "dora"

    @staticmethod
    def param_specs(d_in, d_out, cfg):
        """Declare shapes and initial values of the trainable parameters.

        lora_B uses ``near_zero`` rather than exact zeros (to break bf16
        symmetry), so B A is small but not exactly zero at t=0; ``init`` accounts
        for this when setting the magnitude.
        """
        return dict(
            lora_A=ParamSpec((cfg.r, d_in), init="kaiming"),
            lora_B=ParamSpec((d_out, cfg.r), init="near_zero"),
            # Placeholder only: init() overwrites lora_m with ||W + scale * B A||_c.
            lora_m=ParamSpec((d_out,), init="near_zero"),
        )

    @staticmethod
    def init(layer: nn.Module, cfg) -> None:
        """Fill ``lora_m`` so the adapted layer starts out as (near) identity.

        m is set to the per-output-row L2 norm of V = W + (alpha / r) B A, the
        same quantity ``forward`` divides by. Because B starts near zero rather
        than exactly zero, using ||W||_c alone would leave m / ||V||_c slightly
        off 1 at t=0. Matching ||V||_c makes that ratio 1, up to dtype rounding.
        peft initialises the magnitude the same way.

        Raises:
            TypeError: if ``layer`` is not exactly ``nn.Linear``; subclasses are
                rejected too, because a dense weight matrix is required.
        """
        if type(layer) is not nn.Linear:
            raise TypeError(
                "DoRA needs ||W||_c, so v1 only supports plain nn.Linear. "
                "For bnb layers, dequantize first or use LoRA/IA3."
            )
        scale = cfg.alpha / cfg.r
        with torch.no_grad():
            # float32 so the norm stays accurate for bf16/fp16 weights.
            W = layer.weight.data.float()  # (d_out, d_in)
            BA = einsum(
                layer.lora_B.data.float(), layer.lora_A.data.float(), "o r, r i -> o i"
            )
            V = W + scale * BA  # same V that forward() normalises by
            row_norm = V.norm(dim=1).to(layer.lora_m.dtype)  # (d_out,)
            layer.lora_m.data.copy_(row_norm)

    @staticmethod
    def forward(
        layer: nn.Module,
        x: Float[T, '*B i'],
        y: Float[T, '*B o'],
    ) -> Float[T, '*B o']:
        """Apply DoRA, reusing the base layer output ``y``.

        Computes out = (m / ||V||_c) * (W x + scale * B A x) + bias, with
        V = W + scale * B A. This assumes ``y`` is the base output W x + bias,
        as the bias subtraction below implies.
        """
        cfg = layer._lora_cfg
        scale = cfg.alpha / cfg.r
        # Paper §4.3: treat ||V+ΔV||_c as a constant (detach from grad graph) for
        # stability and ~2x lower memory. Match peft (lora_weight.detach + weight_norm.detach).
        BA = einsum(layer.lora_B, layer.lora_A, "o r, r i -> o i")
        V = layer.weight + scale * BA.detach()  # (d_out, d_in)
        v_norm = V.norm(dim=1).clamp_min(1e-12).detach()  # (d_out,)
        # Bias passes through unscaled (matches peft).
        bias = getattr(layer, "bias", None)
        wx = y if bias is None else (y - bias)
        h = einsum(x, layer.lora_A, "... i, r i -> ... r")
        delta = einsum(h, layer.lora_B, "... r, o r -> ... o")
        combined = wx + scale * delta
        out = (layer.lora_m / v_norm) * combined
        return out if bias is None else out + bias
|