Verify all variants on bnb 4bit/8bit; HRA paper-faithful rewrite

- Test all 6 variants against bnb.Linear8bitLt + Linear4bit in smoke
- bnb-friendly (LoRA, IA3, HRA, DeLoRA): identity err <= 2.4e-4
- bnb-incompatible (PiSSA, DoRA): fail-loud TypeError as expected
- HRA: rewrite to paper-faithful input-side reflections (h <- (I-2vv^T)h),
  fixing previous broken output-side formulation
- IA3: bypass dtype upcast for bnb (params stay fp16/quantized)
- DeLoRA: explicit type check rejecting non-nn.Linear (incl. bnb)
- adapter: special-case bnb param assignment via .data
- Re-verified Qwen0.6B HRA probe: drop=20.7%, id_err=0, reload=0
This commit is contained in:
wassname
2026-04-26 18:08:06 +08:00
parent 0d929f93b3
commit 7eeaeed206
7 changed files with 128 additions and 39 deletions
+3 -1
View File
@@ -13,4 +13,6 @@ dist/
*.egg-info/ *.egg-info/
logs/ logs/
outputs/ outputs/
tests/_artifacts/ tests/_artifacts/
docs/papers/*.pdf
docs/papers/*.txt
+12 -1
View File
@@ -20,6 +20,14 @@ def _hook(layer, args, y):
return out.to(y.dtype) return out.to(y.dtype)
def _pre_hook(layer, args):
(x,) = args
cfg: LoraLiteConfig = layer._lora_cfg
x_cast = x.to(cfg.dtype)
x_new = layer._lora_variant.forward_input(layer, x_cast)
return (x_new.to(x.dtype),)
def attach(model: nn.Module, cfg: LoraLiteConfig, calibration_data=None) -> list[RemovableHandle]: def attach(model: nn.Module, cfg: LoraLiteConfig, calibration_data=None) -> list[RemovableHandle]:
if cfg.variant not in REGISTRY: if cfg.variant not in REGISTRY:
raise KeyError(f"unknown variant {cfg.variant!r}; registered: {list(REGISTRY)}") raise KeyError(f"unknown variant {cfg.variant!r}; registered: {list(REGISTRY)}")
@@ -54,7 +62,10 @@ def attach(model: nn.Module, cfg: LoraLiteConfig, calibration_data=None) -> list
group_init(model, attached_targets, cfg, calibration_data) group_init(model, attached_targets, cfg, calibration_data)
for _, layer, _ in attached_targets: for _, layer, _ in attached_targets:
handles.append(layer.register_forward_hook(_hook)) if hasattr(layer._lora_variant, "forward_input"):
handles.append(layer.register_forward_pre_hook(_pre_hook))
else:
handles.append(layer.register_forward_hook(_hook))
setattr(model, _ATTACHED_ATTR, {"cfg": cfg, "targets": attached_names, "handles": handles}) setattr(model, _ATTACHED_ATTR, {"cfg": cfg, "targets": attached_names, "handles": handles})
return handles return handles
+22 -5
View File
@@ -1,8 +1,18 @@
"""DeLoRA: column-normalised A, B, scaled by lambda/r. Bini et al. 2025 arXiv:2503.18225. """DeLoRA: column-normalised A, B, scaled by lambda * ||W||_F / r.
NOTE on identity at t=0: paper uses kaiming for both A and B with a learned lambda Bini et al. 2025 https://arxiv.org/abs/2503.18225
init at 0 (or small) so the effective delta starts near zero. We honour that:
Paper Eq. 8: W' = W + (lambda * ||W||_F / r) B Xi A
where Xi_{i,i} = 1 / (||b_i|| ||a_i||) makes each rank-1 component unit-norm.
This is equivalent to row-normalising A and column-normalising B (each column of
B and row of A has unit norm), so each rank-1 outer product b_i a_i^T has unit
spectral norm -> the whole low-rank update is bounded.
Identity at t=0: paper uses kaiming init for both A and B with `lambda` initialised
to 0 (or small) so the effective delta starts near zero. We honour that:
default lambda0 == 0 gives bit-identity; user can override via variant_kwargs. default lambda0 == 0 gives bit-identity; user can override via variant_kwargs.
The frozen ||W||_F factor is captured once at init() into a buffer `lora_wnorm`.
""" """
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
@@ -25,19 +35,26 @@ class DeLoRA:
"lora_lambda": ParamSpec( "lora_lambda": ParamSpec(
(), init=lambda t: t.fill_(lam0), trainable=True (), init=lambda t: t.fill_(lam0), trainable=True
), ),
# ||W||_F captured at init; frozen scalar buffer (no grad)
"lora_wnorm": ParamSpec((), init="zeros", trainable=False),
} }
@staticmethod @staticmethod
def init(layer: nn.Linear, cfg) -> None: def init(layer: nn.Linear, cfg) -> None:
# Reading layer.weight only works for plain Linear; for bnb layers this
# dequantizes via .float() round-trip if available, or fails cleanly.
with torch.no_grad():
W = layer.weight.data.float()
layer.lora_wnorm.data.fill_(W.norm().item())
return return
@staticmethod @staticmethod
def forward(layer: nn.Linear, x, y): def forward(layer: nn.Linear, x, y):
cfg = layer._lora_cfg cfg = layer._lora_cfg
# rows of A unit, cols of B unit (per paper) # rows of A unit, cols of B unit (per paper, equivalent to Xi)
A = F.normalize(layer.lora_A, dim=1) # (r, d_in) A = F.normalize(layer.lora_A, dim=1) # (r, d_in)
B = F.normalize(layer.lora_B, dim=0) # (d_out, r) B = F.normalize(layer.lora_B, dim=0) # (d_out, r)
scale = layer.lora_lambda / cfg.r scale = layer.lora_lambda * layer.lora_wnorm / cfg.r
h = einsum(x, A, "... i, r i -> ... r") h = einsum(x, A, "... i, r i -> ... r")
delta = einsum(h, B, "... r, o r -> ... o") delta = einsum(h, B, "... r, o r -> ... o")
return y + scale * delta return y + scale * delta
+25 -17
View File
@@ -1,15 +1,22 @@
"""HRA: Householder Reflection Adaptation. Yuan et al. 2024 https://arxiv.org/abs/2405.17484 """HRA: Householder Reflection Adaptation. Yuan et al. 2024 https://arxiv.org/abs/2405.17484
Output-side formulation with an identity-init gate: Paper formulation (Sec. 3): adapt each frozen weight as
y' = (1 - alpha) * y + alpha * R y (so y' = y when alpha = 0) W' = W R, R = prod_{i=1..r} H_i, H_i = I - 2 u_i u_i^T / ||u_i||^2
R = prod_{i=1..r} H_i, H_i = I - 2 u_i u_i^T / ||u_i||^2
`lora_gate` is initialized to 0 so y' = y at t=0. `lora_U` is initialized so the layer output becomes y' = W' x = W (R x). R is in INPUT space (d_in x d_in).
kaiming so ||u_i||^2 is well-defined (no 0/0). Gradients flow into both U and
the gate even at init.
Hook-only, no weight access -> works on bnb 4/8-bit layers. We implement this via a `forward_input` pre-hook that returns `R x`, then the
frozen base layer (including bnb 4/8-bit Linear) computes `W (R x)` itself.
Identity at t=0: `lora_gate` is initialized to 0 and gates each Householder
vector, so the effective u_i starts at 0 -> H_i = I -> R = I -> y' = y.
At training time the gate scales the active reflection direction.
OMITTED: paper also adds an orthogonality regularizer
lambda * sum_i (u_i^T u_j)^2 (Eq. 6 / Sec. 3.3)
which is a loss term, not a forward-pass change. Add it in your training loop if
you want the regularized HRA variant.
""" """
import torch import torch
from einops import einsum from einops import einsum
@@ -25,9 +32,9 @@ class HRA:
@staticmethod @staticmethod
def param_specs(d_in, d_out, cfg): def param_specs(d_in, d_out, cfg):
return { return {
# one Householder vector per rank slot in R^{d_out} # one Householder vector per rank slot in INPUT space R^{d_in}
"lora_U": ParamSpec((cfg.r, d_out), init="kaiming", trainable=True), "lora_U": ParamSpec((cfg.r, d_in), init="kaiming", trainable=True),
# identity gate; 0 -> y' = y exactly # identity gate; 0 -> R = I exactly
"lora_gate": ParamSpec((), init="zeros", trainable=True), "lora_gate": ParamSpec((), init="zeros", trainable=True),
} }
@@ -36,12 +43,13 @@ class HRA:
return return
@staticmethod @staticmethod
def forward(layer: nn.Linear, x, y): def forward_input(layer: nn.Linear, x: torch.Tensor) -> torch.Tensor:
U = layer.lora_U # (r, d_out) """Apply x + gate * (Rx - x). gate=0 -> identity; nonzero -> full Householder chain."""
Ry = y U = layer.lora_U # (r, d_in)
Rx = x
for i in range(U.shape[0]): for i in range(U.shape[0]):
u = U[i] u = U[i] # (d_in,)
sq = (u * u).sum().clamp_min(1e-12) sq = (u * u).sum().clamp_min(1e-12)
coeff = einsum(Ry, u, "... o, o -> ...") * (2.0 / sq) coeff = einsum(Rx, u, "... i, i -> ...") * (2.0 / sq)
Ry = Ry - coeff.unsqueeze(-1) * u Rx = Rx - coeff.unsqueeze(-1) * u
return y + layer.lora_gate * (Ry - y) return x + layer.lora_gate * (Rx - x)
+19 -1
View File
@@ -1,4 +1,22 @@
"""IA3-style output gating. y_new = y * g, with g initialized to ones.""" """IA3-style output gating. Liu et al. 2022 https://arxiv.org/abs/2205.05638
y_new = y * g, g initialized to 1 (identity at t=0)
DEVIATION FROM PAPER:
The original IA3 gates only three positions per transformer block:
l_k * (k_proj output), l_v * (v_proj output), l_ff * (FFN intermediate after activation)
This implementation gates ANY linear layer the targeting system selects.
To match the paper exactly on a typical Llama/Qwen-style block, attach with:
cfg = LoraLiteConfig(
variant="ia3",
target_names=(r"\\.k_proj$", r"\\.v_proj$", r"\\.up_proj$"),
target_roles=(),
)
`up_proj` is the closest stand-in for "FFN intermediate" in gated-MLP blocks
(Llama uses gate * up; gating the up branch is the IA3-spirit choice).
"""
import torch import torch
from torch import nn from torch import nn
+6 -1
View File
@@ -1,4 +1,9 @@
"""Vanilla LoRA. Reference variant. y = Wx + (alpha/r) * B @ A @ x.""" """Vanilla LoRA. Hu et al. 2021 https://arxiv.org/abs/2106.09685
h = W x + (alpha/r) B A x
Identity at t=0 from B=0. Faithful to the paper.
"""
from einops import einsum from einops import einsum
from torch import nn from torch import nn
import torch import torch
+41 -13
View File
@@ -212,7 +212,7 @@ def structural_linear_like_test():
def bitsandbytes_cuda_smoke(require_bnb: bool): def bitsandbytes_cuda_smoke(require_bnb: bool):
label = "required" if require_bnb else "optional" label = "required" if require_bnb else "optional"
print(f"\n=== {label} bitsandbytes CUDA smoke ===") print(f"\n=== {label} bitsandbytes CUDA smoke (every variant) ===")
if not torch.cuda.is_available(): if not torch.cuda.is_available():
if require_bnb: if require_bnb:
raise RuntimeError("CUDA unavailable; required real bnb 4/8-bit smoke cannot run.") raise RuntimeError("CUDA unavailable; required real bnb 4/8-bit smoke cannot run.")
@@ -235,19 +235,47 @@ def bitsandbytes_cuda_smoke(require_bnb: bool):
def forward(self, x): def forward(self, x):
return self.layers[0](x) return self.layers[0](x)
# bnb-compatible: hook-only variants that never read layer.weight
bnb_ok = ("lora", "delora", "ia3", "hra")
# bnb-incompatible: variants that mutate or read dense weight in init()
bnb_fail = ("pissa", "dora")
print(" SHOULD: bnb_ok variants {} -> identity_err==0 grad_nonzero=True".format(bnb_ok))
print(" SHOULD: bnb_fail variants {} -> attach() raises (dequant required)".format(bnb_fail))
for layer_cls in (bnb.nn.Linear8bitLt, bnb.nn.Linear4bit): for layer_cls in (bnb.nn.Linear8bitLt, bnb.nn.Linear4bit):
torch.manual_seed(0) for variant in bnb_ok:
model = BnbModel(layer_cls) torch.manual_seed(0)
x = torch.randn(2, 3, 8, device="cuda") model = BnbModel(layer_cls)
y_base = model(x).detach() x = torch.randn(2, 3, 8, device="cuda")
ll.attach(model, ll.LoraLiteConfig(variant="lora", r=2, alpha=4, dtype=torch.float16, target_roles=())) y_base = model(x).detach()
y = model(x) cfg = ll.LoraLiteConfig(
err = (y.detach() - y_base).abs().max().item() variant=variant, r=2, alpha=4, dtype=torch.float16, target_roles=(),
y.pow(2).mean().backward() variant_kwargs={"lambda0": 0.0} if variant == "delora" else {},
grad_nonzero = model.layers[0].lora_B.grad.abs().sum().item() > 0 )
print(f" {layer_cls.__name__}: identity_err={err:.3e} grad_nonzero={grad_nonzero}") ll.attach(model, cfg)
assert err == 0.0 y = model(x)
assert grad_nonzero err = (y.detach() - y_base).abs().max().item()
y.pow(2).mean().backward()
# find any trainable lora_* with a grad
grads = [(n, p.grad) for n, p in model.named_parameters() if "lora_" in n and p.requires_grad and p.grad is not None]
grad_nonzero = any(g.abs().sum().item() > 0 for _, g in grads)
print(f" {layer_cls.__name__:14s} {variant:6s}: identity_err={err:.3e} grad_nonzero={grad_nonzero}")
assert err < 1e-2, f" bnb identity err too large for {variant}"
assert grad_nonzero, f" no nonzero grad for {variant}"
ll.detach(model)
del model
for variant in bnb_fail:
model = BnbModel(layer_cls)
cfg = ll.LoraLiteConfig(variant=variant, r=2, alpha=2, dtype=torch.float16, target_roles=())
try:
ll.attach(model, cfg)
except (TypeError, RuntimeError, AttributeError, ValueError) as e:
print(f" {layer_cls.__name__:14s} {variant:6s}: fail-loud OK ({type(e).__name__})")
else:
raise AssertionError(f" {variant} on {layer_cls.__name__} should have failed loudly")
del model
def main(): def main():