mirror of
https://github.com/wassname/lora-lite.git
synced 2026-06-27 19:15:35 +08:00
fix v2 review bugs + add EVA, AntiPaSTO
DeLoRA: per-input-channel wnorm buffer (not scalar Parameter), forward matches peft (x*wnorm @ A.T then per-rank scale (lambda/r)/(An*Bn)). Smoke: 89.7% loss drop (was 35.8%). HRA: symmetric repeated-column init (PEFT-style) instead of zero gate. Adjacent Householder pairs cancel exactly so R=I at t=0, and U receives gradient from step 0 (no dead-grad). Even r required. IA3: split into two variants. ia3 stays output-side (k_proj/v_proj); new ia3_ff is input-side (down_proj/fc2), matching peft is_feedforward. Config: dropout field removed (never honored by any variant). PiSSA: adapter.save records base-weight fingerprint per target; adapter.load recomputes init then verifies fingerprint -> fails loud when reloaded onto a different base. EVA (new): data-driven init via group_init + calibration_data. Top-r right singular vectors of pooled layer-input activations -> lora_A (buffer, frozen); only lora_B trains. Stress-tests group_init API. AntiPaSTO (new): SVD steering with frozen U,S,Vh,W_res and learnable delta_s (per-singular-value bias) + rot_T (block-diagonal Cayley rotation on V or U). Lite port of antipasto3 SVD adapter. ParamSpec: as_buffer field + make_tensor() for buffer registration. adapter.attach honors as_buffer with register_buffer; detach cleans both _parameters and _buffers. Smoke covers all 8 variants: identity at t=0, save/load round-trip, gradient-driven loss drop. EVA gets dedicated test for calibration data path. ALL PASS including bnb 4/8-bit path.
This commit is contained in:
@@ -48,8 +48,12 @@ def attach(model: nn.Module, cfg: LoraLiteConfig, calibration_data=None) -> list
|
||||
for pname, spec in variant.param_specs(d_in, d_out, cfg).items():
|
||||
if hasattr(layer, pname):
|
||||
raise RuntimeError(f"{name} already has attribute {pname}; detach first")
|
||||
p = spec.make(cfg.dtype, layer.weight.device)
|
||||
layer.register_parameter(pname, p)
|
||||
if spec.as_buffer:
|
||||
t = spec.make_tensor(cfg.dtype, layer.weight.device)
|
||||
layer.register_buffer(pname, t, persistent=True)
|
||||
else:
|
||||
p = spec.make(cfg.dtype, layer.weight.device)
|
||||
layer.register_parameter(pname, p)
|
||||
layer._lora_cfg = cfg
|
||||
layer._lora_variant = variant
|
||||
layer._lora_role = role
|
||||
@@ -85,18 +89,44 @@ def detach(model: nn.Module) -> None:
|
||||
for pname in variant.param_specs(layer.in_features, layer.out_features, layer._lora_cfg):
|
||||
if pname in layer._parameters:
|
||||
del layer._parameters[pname]
|
||||
elif pname in layer._buffers:
|
||||
del layer._buffers[pname]
|
||||
for attr in ("_lora_cfg", "_lora_variant", "_lora_role"):
|
||||
if hasattr(layer, attr):
|
||||
delattr(layer, attr)
|
||||
delattr(model, _ATTACHED_ATTR)
|
||||
|
||||
|
||||
def _base_weight_fingerprint(model: nn.Module) -> dict[str, str]:
|
||||
"""Per-target fingerprint of the (post-init) base weights so PiSSA-style
|
||||
variants that mutate `layer.weight` can fail loud on base mismatch.
|
||||
Uses a cheap fp32 sum-of-squares + shape signature; not cryptographic.
|
||||
"""
|
||||
state = getattr(model, _ATTACHED_ATTR, None)
|
||||
if state is None:
|
||||
return {}
|
||||
fp = {}
|
||||
for name, layer in model.named_modules():
|
||||
if not hasattr(layer, "_lora_variant"):
|
||||
continue
|
||||
if name not in state["targets"]:
|
||||
continue
|
||||
w = layer.weight.detach().to(torch.float32, copy=False)
|
||||
fp[name] = f"{tuple(w.shape)}|{float((w * w).sum()):.6e}"
|
||||
return fp
|
||||
|
||||
|
||||
def save(model: nn.Module, path: str) -> None:
|
||||
state = getattr(model, _ATTACHED_ATTR, None)
|
||||
if state is None:
|
||||
raise RuntimeError("no adapter attached; call attach() first")
|
||||
sd = {k: v.detach().cpu() for k, v in model.state_dict().items() if "lora_" in k}
|
||||
torch.save({"cfg": state["cfg"].to_dict(), "state": sd}, path)
|
||||
blob = {
|
||||
"cfg": state["cfg"].to_dict(),
|
||||
"state": sd,
|
||||
"base_fp": _base_weight_fingerprint(model),
|
||||
}
|
||||
torch.save(blob, path)
|
||||
|
||||
|
||||
def load(model: nn.Module, path: str) -> list[RemovableHandle]:
|
||||
@@ -111,4 +141,14 @@ def load(model: nn.Module, path: str) -> list[RemovableHandle]:
|
||||
unexpected_lora = [k for k in unexpected if "lora_" in k]
|
||||
if unexpected_lora:
|
||||
raise RuntimeError(f"unexpected lora keys in checkpoint: {unexpected_lora}")
|
||||
saved_fp = blob.get("base_fp", {})
|
||||
if saved_fp:
|
||||
cur_fp = _base_weight_fingerprint(model)
|
||||
diffs = [k for k in saved_fp if saved_fp[k] != cur_fp.get(k)]
|
||||
if diffs:
|
||||
raise RuntimeError(
|
||||
f"base weight fingerprint mismatch on {len(diffs)} layer(s) "
|
||||
f"(e.g. {diffs[0]}). For PiSSA the saved adapter assumes the same "
|
||||
"base; reload onto the original model or re-run init."
|
||||
)
|
||||
return handles
|
||||
|
||||
@@ -8,7 +8,6 @@ class LoraLiteConfig:
|
||||
variant: str = "lora"
|
||||
r: int = 8
|
||||
alpha: float = 16.0
|
||||
dropout: float = 0.0 # currently ignored; variants may use cfg.variant_kwargs
|
||||
dtype: torch.dtype = torch.bfloat16
|
||||
|
||||
# targeting
|
||||
|
||||
@@ -12,8 +12,9 @@ class ParamSpec:
|
||||
shape: tuple[int, ...]
|
||||
init: str | Callable[[torch.Tensor], None] = "zeros" # 'zeros'|'kaiming'|'ones'|callable(t)
|
||||
trainable: bool = True
|
||||
as_buffer: bool = False # if True, register_buffer instead of register_parameter
|
||||
|
||||
def make(self, dtype: torch.dtype, device) -> nn.Parameter:
|
||||
def _empty(self, dtype: torch.dtype, device) -> torch.Tensor:
|
||||
t = torch.empty(self.shape, dtype=dtype, device=device)
|
||||
if callable(self.init):
|
||||
self.init(t)
|
||||
@@ -26,7 +27,17 @@ class ParamSpec:
|
||||
nn.init.kaiming_uniform_(t, a=5 ** 0.5) if t.ndim >= 2 else t.normal_(0, 0.02)
|
||||
else:
|
||||
raise ValueError(f"unknown init: {self.init}")
|
||||
return nn.Parameter(t, requires_grad=self.trainable)
|
||||
return t
|
||||
|
||||
def make(self, dtype: torch.dtype, device) -> nn.Parameter:
|
||||
# legacy entry: returns a Parameter (used for trainable adapter params)
|
||||
if self.as_buffer:
|
||||
raise RuntimeError("as_buffer spec must be installed via register_buffer; see adapter.attach")
|
||||
return nn.Parameter(self._empty(dtype, device), requires_grad=self.trainable)
|
||||
|
||||
def make_tensor(self, dtype: torch.dtype, device) -> torch.Tensor:
|
||||
# returns a raw tensor for buffer registration
|
||||
return self._empty(dtype, device)
|
||||
|
||||
|
||||
class Variant(Protocol):
|
||||
|
||||
@@ -1 +1 @@
|
||||
from . import lora, pissa, delora, ia3, dora, hra # noqa: F401 side-effect: register
|
||||
from . import lora, pissa, delora, ia3, dora, hra, eva, antipasto # noqa: F401 side-effect: register
|
||||
|
||||
@@ -0,0 +1,145 @@
|
||||
"""AntiPaSTO: SVD steering with learnable singular-value deltas + block-diagonal Cayley rotation.
|
||||
|
||||
Lite port of wassname's AntiPaSTO3 SVD adapter (research code, not an
|
||||
upstream peft variant). Reference:
|
||||
https://github.com/wassname/antipasto3 (offline: docs/refs/antipasto3_svd_adapter.py)
|
||||
|
||||
Decomposition (PyTorch nn.Linear convention, weight (d_out, d_in)):
|
||||
|
||||
W = U diag(S) Vh + W_res (top-r SVD; W_res = W - U_r diag(S_r) Vh_r)
|
||||
|
||||
We freeze U, S, Vh, W_res and learn:
|
||||
- delta_s : (r,) -- additive delta to singular values
|
||||
- rot_T : (n_blocks, bs(bs-1)/2) -- upper-triangle of skew matrix per block
|
||||
|
||||
Forward (matches base layer convention exactly at t=0):
|
||||
|
||||
R = block_diag(Cayley(skew(rot_T))) # (r, r) effective
|
||||
Vh_rot = R @ Vh # rotates input basis
|
||||
S_eff = S + delta_s # learnable spectrum
|
||||
delta_y = ((x @ Vh_rot.T) * S_eff) @ U.T # rank-r path
|
||||
base_y = x @ W_res.T # frozen residual
|
||||
y_total = base_y + delta_y # == original output at t=0
|
||||
|
||||
At init: rot_T = 0 -> R = I -> Vh_rot = Vh, delta_s = 0 -> S_eff = S, so
|
||||
delta_y reconstructs the truncated SVD term and y_total == x @ W^T to numerical
|
||||
precision (fp32 SVD round-tripped to cfg.dtype).
|
||||
|
||||
WHICH BASIS IS ROTATED:
|
||||
By default we rotate Vh (the INPUT singular basis). This is what AntiPaSTO3
|
||||
calls `rotate_V=True` in adapter terms (V == Vh.T columns). To rotate U
|
||||
(output basis) instead, pass variant_kwargs={'rotate_basis': 'U'}.
|
||||
Rotating both is not implemented (one rotation is enough to span the
|
||||
identifiable steering directions; two is degenerate).
|
||||
|
||||
REQUIRES even rank divisible by `block_size` (default 4). r=8, bs=4 -> 2 blocks.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import math
|
||||
|
||||
import torch
|
||||
from einops import einsum
|
||||
from torch import nn
|
||||
|
||||
from ..variant import register, ParamSpec
|
||||
|
||||
|
||||
def _cayley(skew: torch.Tensor) -> torch.Tensor:
|
||||
"""R = (I - X)^-1 (I + X) for X = skew/2; preserves orthogonality."""
|
||||
bs = skew.shape[-1]
|
||||
eye = torch.eye(bs, dtype=skew.dtype, device=skew.device).expand_as(skew)
|
||||
X = skew / 2
|
||||
return torch.linalg.solve(eye - X, eye + X)
|
||||
|
||||
|
||||
def _build_rotation(rot_T: torch.Tensor, bs: int, max_angle: float) -> torch.Tensor:
|
||||
"""rot_T: (n_blocks, bs*(bs-1)/2) -> R: (n_blocks, bs, bs) Cayley rotation."""
|
||||
n_blocks, _ = rot_T.shape
|
||||
rows, cols = torch.triu_indices(bs, bs, offset=1, device=rot_T.device).unbind(0)
|
||||
A = torch.zeros(n_blocks, bs, bs, dtype=rot_T.dtype, device=rot_T.device)
|
||||
A[:, rows, cols] = rot_T
|
||||
A = 0.5 * (A - A.transpose(-1, -2))
|
||||
a_limit = 2.0 * math.tan(max_angle / 2.0)
|
||||
A = a_limit * torch.tanh(A / a_limit)
|
||||
return _cayley(A)
|
||||
|
||||
|
||||
def _block_diag(blocks: torch.Tensor) -> torch.Tensor:
|
||||
"""(n_blocks, bs, bs) -> (n_blocks*bs, n_blocks*bs) block-diagonal."""
|
||||
n, bs, _ = blocks.shape
|
||||
out = blocks.new_zeros(n * bs, n * bs)
|
||||
for i in range(n):
|
||||
out[i * bs : (i + 1) * bs, i * bs : (i + 1) * bs] = blocks[i]
|
||||
return out
|
||||
|
||||
|
||||
@register
|
||||
class AntiPaSTO:
|
||||
name = "antipasto"
|
||||
|
||||
@staticmethod
|
||||
def param_specs(d_in, d_out, cfg):
|
||||
r = cfg.r
|
||||
bs = int(cfg.variant_kwargs.get("block_size", 4))
|
||||
if r % bs != 0:
|
||||
raise ValueError(f"AntiPaSTO requires r={r} divisible by block_size={bs}")
|
||||
n_blocks = r // bs
|
||||
n_triu = bs * (bs - 1) // 2
|
||||
return {
|
||||
# Frozen SVD components captured at init (buffers travel with state_dict).
|
||||
"lora_U": ParamSpec((d_out, r), init="zeros", trainable=False, as_buffer=True),
|
||||
"lora_S": ParamSpec((r,), init="zeros", trainable=False, as_buffer=True),
|
||||
"lora_Vh": ParamSpec((r, d_in), init="zeros", trainable=False, as_buffer=True),
|
||||
# Trainable: per-singular-value delta + block-diagonal Cayley rotation.
|
||||
"lora_delta_s": ParamSpec((r,), init="zeros", trainable=True),
|
||||
"lora_rot_T": ParamSpec((n_blocks, n_triu), init="zeros", trainable=True),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def init(layer: nn.Linear, cfg) -> None:
|
||||
if type(layer) is not nn.Linear:
|
||||
raise TypeError(
|
||||
"AntiPaSTO mutates layer.weight into W_res (like PiSSA), so v1 "
|
||||
"only supports plain nn.Linear, not bnb 4/8-bit."
|
||||
)
|
||||
with torch.no_grad():
|
||||
W = layer.weight.data.float()
|
||||
U, S, Vh = torch.linalg.svd(W, full_matrices=False)
|
||||
r = cfg.r
|
||||
Ur, Sr, Vhr = U[:, :r], S[:r], Vh[:r, :]
|
||||
layer.lora_U.copy_(Ur.to(layer.lora_U.dtype))
|
||||
layer.lora_S.copy_(Sr.to(layer.lora_S.dtype))
|
||||
layer.lora_Vh.copy_(Vhr.to(layer.lora_Vh.dtype))
|
||||
# W_res is the residual after rank-r truncation. Forward adds back
|
||||
# the truncated path so total == W exactly at init (mod dtype).
|
||||
W_res = (W - (Ur * Sr) @ Vhr).to(layer.weight.dtype)
|
||||
layer.weight.data.copy_(W_res)
|
||||
|
||||
@staticmethod
|
||||
def forward(layer: nn.Linear, x, y):
|
||||
cfg = layer._lora_cfg
|
||||
bs = int(cfg.variant_kwargs.get("block_size", 4))
|
||||
max_angle = float(cfg.variant_kwargs.get("max_rotation_angle", 0.5))
|
||||
rotate_basis = cfg.variant_kwargs.get("rotate_basis", "V")
|
||||
|
||||
U = layer.lora_U.to(x.dtype) # (d_out, r)
|
||||
S = layer.lora_S.to(x.dtype) # (r,)
|
||||
Vh = layer.lora_Vh.to(x.dtype) # (r, d_in)
|
||||
|
||||
R_blocks = _build_rotation(layer.lora_rot_T.float(), bs, max_angle)
|
||||
R = _block_diag(R_blocks).to(x.dtype) # (r, r)
|
||||
|
||||
if rotate_basis == "V":
|
||||
Vh_eff = R @ Vh # rotate INPUT basis
|
||||
U_eff = U
|
||||
elif rotate_basis == "U":
|
||||
Vh_eff = Vh
|
||||
U_eff = U @ R.T # rotate OUTPUT basis
|
||||
else:
|
||||
raise ValueError(f"rotate_basis must be 'U' or 'V', got {rotate_basis!r}")
|
||||
|
||||
S_eff = S + layer.lora_delta_s.to(x.dtype) # (r,)
|
||||
h = einsum(x, Vh_eff, "... i, r i -> ... r") # x @ Vh_eff.T
|
||||
h = h * S_eff # diag(S_eff)
|
||||
delta = einsum(h, U_eff, "... r, o r -> ... o") # @ U_eff.T
|
||||
return y + delta
|
||||
@@ -1,27 +1,29 @@
|
||||
"""DeLoRA: column-normalised A, B, scaled by lambda * ||W||_F / r.
|
||||
"""DeLoRA: per-input-channel weight-norm scaling, per-rank A/B normalization.
|
||||
|
||||
Bini et al. 2025 (ICLR'25) https://arxiv.org/abs/2503.18225
|
||||
|
||||
Paper Eq. 8: W' = W + (lambda * ||W||_F / r) B Xi A
|
||||
where Xi_{i,i} = 1 / (||b_i|| ||a_i||) makes each rank-1 component unit-norm.
|
||||
This is equivalent to row-normalising A and column-normalising B (each column of
|
||||
B and row of A has unit norm), so each rank-1 outer product b_i a_i^T has unit
|
||||
spectral norm -> the whole low-rank update is bounded.
|
||||
|
||||
Identity at t=0: paper uses kaiming init for both A and B with `lambda` initialised
|
||||
to 0 (or small) so the effective delta starts near zero. We honour that:
|
||||
default lambda0 == 0 gives bit-identity; user can override via variant_kwargs.
|
||||
Implementation follows the peft upstream (which the DeLoRA authors maintain),
|
||||
which differs from the paper notation in two ways that are equivalent at the
|
||||
forward level but matter for gradients/numerics:
|
||||
1. ||W|| is captured PER INPUT CHANNEL (shape (d_in,)), not as a scalar
|
||||
Frobenius norm. Used to scale `x` element-wise on the input dim.
|
||||
See docs/refs/peft_delora_layer.py:150 (init) and :250 (forward).
|
||||
2. Per-rank normalization applied via division (1/||A_i||*||B^j||) inside
|
||||
the diagonal scaling, instead of as F.normalize on A,B themselves.
|
||||
This keeps the gradient flowing through the un-normalized parameters.
|
||||
|
||||
Identity at t=0: lambda0=0 -> delta is exactly zero (bit-identity).
|
||||
|
||||
KNOWN GRADIENT ISSUE (flagged by external review 2026-04-26):
|
||||
With lambda0=0 the *forward* is identity but `A,B` get zero gradient on step 0
|
||||
(delta = lambda * ... -> d_output/d_A is proportional to lambda). Only
|
||||
`lora_lambda` moves first step. With lambda0>0, A,B train but identity is broken.
|
||||
Paper's true initialization (frozen-copy trick, see Eq. 9) achieves both;
|
||||
we do NOT implement that here.
|
||||
(delta is proportional to lambda). Only `lora_lambda` moves first step.
|
||||
The paper's true initialization (frozen-copy trick, Eq. 9) achieves both
|
||||
identity AND non-zero A/B gradients; we do NOT implement it here.
|
||||
|
||||
The frozen ||W||_F factor is captured once at init() into a buffer `lora_wnorm`.
|
||||
|
||||
Reference implementations (for review/cross-check):
|
||||
Reference implementations:
|
||||
- DeLoRA paper authors (ExplainableML/DeLoRA) -- their fork of peft:
|
||||
https://github.com/ExplainableML/DeLoRA/blob/main/peft/src/peft/tuners/delora.py
|
||||
(offline: docs/refs/orig_delora.py)
|
||||
@@ -30,7 +32,6 @@ Reference implementations (for review/cross-check):
|
||||
(offline: docs/refs/peft_delora_layer.py)
|
||||
"""
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import einsum
|
||||
from torch import nn
|
||||
|
||||
@@ -50,8 +51,9 @@ class DeLoRA:
|
||||
"lora_lambda": ParamSpec(
|
||||
(), init=lambda t: t.fill_(lam0), trainable=True
|
||||
),
|
||||
# ||W||_F captured at init; frozen scalar buffer (no grad)
|
||||
"lora_wnorm": ParamSpec((), init="zeros", trainable=False),
|
||||
# ||W||_2 per input channel (shape (d_in,)); frozen buffer captured at init
|
||||
# per peft DeLoRA (docs/refs/peft_delora_layer.py:150).
|
||||
"lora_wnorm": ParamSpec((d_in,), init="ones", trainable=False, as_buffer=True),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
@@ -60,16 +62,23 @@ class DeLoRA:
|
||||
# dequantizes via .float() round-trip if available, or fails cleanly.
|
||||
with torch.no_grad():
|
||||
W = layer.weight.data.float()
|
||||
layer.lora_wnorm.data.fill_(W.norm().item())
|
||||
wnorm = W.norm(dim=0).detach().to(layer.lora_wnorm.dtype)
|
||||
layer.lora_wnorm.copy_(wnorm)
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def forward(layer: nn.Linear, x, y):
|
||||
cfg = layer._lora_cfg
|
||||
# rows of A unit, cols of B unit (per paper, equivalent to Xi)
|
||||
A = F.normalize(layer.lora_A, dim=1) # (r, d_in)
|
||||
B = F.normalize(layer.lora_B, dim=0) # (d_out, r)
|
||||
scale = layer.lora_lambda * layer.lora_wnorm / cfg.r
|
||||
h = einsum(x, A, "... i, r i -> ... r")
|
||||
A = layer.lora_A # (r, d_in)
|
||||
B = layer.lora_B # (d_out, r)
|
||||
# peft delora forward (docs/refs/peft_delora_layer.py:248-260):
|
||||
# h = (x * w_norm) @ A.T; scale per-rank = (lambda/r) / (||A_i|| * ||B^j||);
|
||||
# delta = (h * scale) @ B.T
|
||||
x_scaled = x * layer.lora_wnorm # (..., d_in)
|
||||
h = einsum(x_scaled, A, "... i, r i -> ... r")
|
||||
An = torch.clamp(A.norm(dim=1), min=1e-4) # (r,)
|
||||
Bn = torch.clamp(B.norm(dim=0), min=1e-4) # (r,)
|
||||
scale = (layer.lora_lambda / cfg.r) / (An * Bn) # (r,)
|
||||
h = h * scale
|
||||
delta = einsum(h, B, "... r, o r -> ... o")
|
||||
return y + scale * delta
|
||||
return y + delta
|
||||
|
||||
@@ -0,0 +1,121 @@
|
||||
"""EVA: Explained-Variance Adaptation. Paischer et al. 2024.
|
||||
|
||||
Paper: https://arxiv.org/abs/2410.07170 (also referred to as ICLR'25 EVA).
|
||||
|
||||
Idea: instead of random A and zero B (LoRA) or SVD of W (PiSSA), initialize
|
||||
`lora_A` to the top-r right singular vectors of the LAYER INPUT distribution
|
||||
on a small calibration set. Forward = `y + scale * (B @ A @ x)` exactly like
|
||||
LoRA; with `lora_B = 0` the adapter is identity at t=0. Only B trains
|
||||
afterwards (A frozen). The result: each rank slot points along a direction
|
||||
that actually carries information at this layer.
|
||||
|
||||
This is a stripped-down EVA; we do NOT implement:
|
||||
- rank redistribution across layers via explained-variance ratios
|
||||
(peft EVA computes an explained_variance_ratio per layer then redistributes
|
||||
the global rank budget; we use a uniform `cfg.r` per layer).
|
||||
- Incremental PCA over many micro-batches (we run one full SVD on the
|
||||
pooled calibration activations per layer).
|
||||
- Equal-input deduplication (peft hashes inputs to share SVD across QKV).
|
||||
|
||||
API stress-test: this variant requires data-driven init, so it implements
|
||||
`group_init(model, targets, cfg, calibration_data)` to drive a single forward
|
||||
pass on `calibration_data` with hooks that capture each target's input.
|
||||
|
||||
Identity at t=0: `lora_B = 0` -> delta = 0 -> y unchanged.
|
||||
|
||||
References:
|
||||
- peft EVA (full impl, with IncrementalPCA + redistribution):
|
||||
https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/eva.py
|
||||
(offline: docs/refs/peft_eva.py)
|
||||
- peft fine-tuning script demonstrating initialize_lora_eva_weights:
|
||||
https://github.com/huggingface/peft/blob/main/examples/eva_finetuning/eva_finetuning.py
|
||||
(offline: docs/refs/peft_eva_finetuning.py)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
from einops import einsum
|
||||
from torch import nn
|
||||
|
||||
from ..variant import register, ParamSpec
|
||||
|
||||
|
||||
@register
|
||||
class EVA:
|
||||
name = "eva"
|
||||
|
||||
@staticmethod
|
||||
def param_specs(d_in, d_out, cfg):
|
||||
return {
|
||||
# A is frozen (set in group_init from calibration data); kept as a
|
||||
# buffer so it travels with state_dict and is not optimized.
|
||||
"lora_A": ParamSpec((cfg.r, d_in), init="zeros", trainable=False, as_buffer=True),
|
||||
# B is the only trainable bit; zero-init -> identity at t=0.
|
||||
"lora_B": ParamSpec((d_out, cfg.r), init="zeros", trainable=True),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def init(layer: nn.Linear, cfg) -> None:
|
||||
# No-op; group_init does the data-driven SVD across all targets at once.
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def group_init(model: nn.Module, targets, cfg, calibration_data) -> None:
|
||||
if calibration_data is None:
|
||||
raise ValueError(
|
||||
"EVA requires calibration_data: an iterable of model inputs "
|
||||
"(dicts of kwargs to model.forward, or single tensors) used to "
|
||||
"estimate the input PCA per layer. Pass via "
|
||||
"lora_lite.attach(model, cfg, calibration_data=batches)."
|
||||
)
|
||||
# Collect input activations per target via forward hooks.
|
||||
layers = {name: layer for name, layer, _ in targets}
|
||||
captured: dict[str, list[torch.Tensor]] = {n: [] for n in layers}
|
||||
|
||||
def make_hook(name):
|
||||
def _h(module, args, kwargs):
|
||||
# signature: pre-forward, args[0] is the input tensor
|
||||
x = args[0].detach()
|
||||
captured[name].append(x.reshape(-1, x.shape[-1]).to(torch.float32).cpu())
|
||||
return _h
|
||||
|
||||
handles = [
|
||||
layers[n].register_forward_pre_hook(make_hook(n), with_kwargs=True)
|
||||
for n in layers
|
||||
]
|
||||
try:
|
||||
was_training = model.training
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
for batch in calibration_data:
|
||||
if isinstance(batch, dict):
|
||||
model(**batch)
|
||||
elif isinstance(batch, (list, tuple)):
|
||||
model(*batch)
|
||||
else:
|
||||
model(batch)
|
||||
if was_training:
|
||||
model.train()
|
||||
finally:
|
||||
for h in handles:
|
||||
h.remove()
|
||||
|
||||
# SVD per target on pooled inputs; top-r right singular vectors -> A.
|
||||
for name, layer in layers.items():
|
||||
X = torch.cat(captured[name], dim=0) # (N, d_in)
|
||||
if X.shape[0] < cfg.r:
|
||||
raise RuntimeError(
|
||||
f"EVA at {name}: only {X.shape[0]} calibration tokens, need >= r={cfg.r}"
|
||||
)
|
||||
# full_matrices=False -> Vh shape (min(N,d_in), d_in); take top-r rows
|
||||
_, _, Vh = torch.linalg.svd(X, full_matrices=False)
|
||||
A = Vh[: cfg.r, :].to(layer.lora_A.dtype).to(layer.lora_A.device)
|
||||
layer.lora_A.copy_(A)
|
||||
|
||||
@staticmethod
|
||||
def forward(layer: nn.Linear, x, y):
|
||||
cfg = layer._lora_cfg
|
||||
scale = cfg.alpha / cfg.r
|
||||
h = einsum(x, layer.lora_A, "... i, r i -> ... r")
|
||||
delta = einsum(h, layer.lora_B, "... r, o r -> ... o")
|
||||
return y + scale * delta
|
||||
@@ -9,27 +9,22 @@ so the layer output becomes y' = W' x = W (R x). R is in INPUT space (d_in x d
|
||||
We implement this via a `forward_input` pre-hook that returns `R x`, then the
|
||||
frozen base layer (including bnb 4/8-bit Linear) computes `W (R x)` itself.
|
||||
|
||||
Identity at t=0: `lora_gate` is initialized to 0 and gates each Householder
|
||||
vector, so the effective u_i starts at 0 -> H_i = I -> R = I -> y' = y.
|
||||
At training time the gate scales the active reflection direction.
|
||||
Identity at t=0 (PEFT-style symmetric init, requires even r):
|
||||
Rows are kaiming-init in pairs: U[0]=U[1], U[2]=U[3], ... Adjacent pairs of
|
||||
Householder reflections with identical vectors cancel exactly
|
||||
(H_i H_i = I), so R = I at init -> y' = y to bit-precision.
|
||||
After the first gradient step the paired rows diverge and the chain becomes a
|
||||
general orthogonal matrix; gradient flows into U from step 0 (no dead-grad).
|
||||
Odd r is rejected (matches peft warning behaviour).
|
||||
|
||||
KNOWN GRADIENT ISSUE (flagged by external review 2026-04-26):
|
||||
Forward is `x + gate * (Rx - x)`. With gate=0 at init, d_output/d_U is
|
||||
proportional to gate, so on step 0 ONLY `lora_gate` receives gradient;
|
||||
`lora_U` is dead. Once gate moves off zero, U starts learning. This deviates
|
||||
from the paper, which has no such gate -- paper uses orthogonal init of U so
|
||||
R != I from step 0. We trade paper-faithful init for identity-at-init.
|
||||
|
||||
OMITTED: paper also adds an orthogonality regularizer
|
||||
lambda * sum_i (u_i^T u_j)^2 (Eq. 6 / Sec. 3.3)
|
||||
which is a loss term, not a forward-pass change. Add it in your training loop if
|
||||
you want the regularized HRA variant.
|
||||
OMITTED: paper also adds an orthogonality regularizer (Eq. 6 / Sec. 3.3),
|
||||
a loss-side term. Add it in your training loop if you want regularized HRA.
|
||||
|
||||
Reference implementations (for review/cross-check):
|
||||
- HRA paper authors (DaShenZi721/HRA), llama variant of OFT layer with HRA:
|
||||
https://github.com/DaShenZi721/HRA/blob/master/llama/peft/oft/layer_GS_HRA.py
|
||||
(offline: docs/refs/orig_hra_layer.py)
|
||||
- peft HRA layer (cleaner, includes apply_GS toggle for orthogonalization):
|
||||
- peft HRA layer, reset_hra_parameters (lines 100-108):
|
||||
https://github.com/huggingface/peft/blob/main/src/peft/tuners/hra/layer.py
|
||||
(offline: docs/refs/peft_hra_layer.py)
|
||||
"""
|
||||
@@ -46,20 +41,33 @@ class HRA:
|
||||
|
||||
@staticmethod
|
||||
def param_specs(d_in, d_out, cfg):
|
||||
if cfg.r % 2 != 0:
|
||||
raise ValueError(
|
||||
f"HRA symmetric init requires even r; got r={cfg.r}. "
|
||||
"Pick an even rank or use a different variant."
|
||||
)
|
||||
return {
|
||||
# one Householder vector per rank slot in INPUT space R^{d_in}
|
||||
"lora_U": ParamSpec((cfg.r, d_in), init="kaiming", trainable=True),
|
||||
# identity gate; 0 -> R = I exactly
|
||||
"lora_gate": ParamSpec((), init="zeros", trainable=True),
|
||||
# Householder vectors stacked as rows (one vector per rank slot)
|
||||
# init done in init() to enforce paired rows -> R = I at t=0.
|
||||
"lora_U": ParamSpec((cfg.r, d_in), init="zeros", trainable=True),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def init(layer: nn.Linear, cfg) -> None:
|
||||
# Symmetric init per peft (docs/refs/peft_hra_layer.py:101-108):
|
||||
# half = kaiming(r//2, d_in); U = repeat_interleave(half, 2, dim=0)
|
||||
# Adjacent pairs (H_2k H_2k+1) cancel since H^2 = I, so R = I exactly,
|
||||
# while gradient still flows into U from step 0.
|
||||
with torch.no_grad():
|
||||
r, d_in = layer.lora_U.shape
|
||||
half = torch.empty(r // 2, d_in, dtype=layer.lora_U.dtype, device=layer.lora_U.device)
|
||||
nn.init.kaiming_uniform_(half, a=5 ** 0.5)
|
||||
layer.lora_U.copy_(torch.repeat_interleave(half, 2, dim=0))
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def forward_input(layer: nn.Linear, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Apply x + gate * (Rx - x). gate=0 -> identity; nonzero -> full Householder chain."""
|
||||
"""Apply Rx where R = prod_i H_i, H_i = I - 2 u_i u_i^T / ||u_i||^2."""
|
||||
U = layer.lora_U # (r, d_in)
|
||||
Rx = x
|
||||
for i in range(U.shape[0]):
|
||||
@@ -67,4 +75,4 @@ class HRA:
|
||||
sq = (u * u).sum().clamp_min(1e-12)
|
||||
coeff = einsum(Rx, u, "... i, i -> ...") * (2.0 / sq)
|
||||
Rx = Rx - coeff.unsqueeze(-1) * u
|
||||
return x + layer.lora_gate * (Rx - x)
|
||||
return Rx
|
||||
|
||||
@@ -1,26 +1,28 @@
|
||||
"""IA3-style output gating. Liu et al. 2022 https://arxiv.org/abs/2205.05638
|
||||
"""IA3-style elementwise gating. Liu et al. 2022 https://arxiv.org/abs/2205.05638
|
||||
|
||||
y_new = y * g, g initialized to 1 (identity at t=0)
|
||||
Two registered variants, matching the paper's two regimes:
|
||||
|
||||
DEVIATION FROM PAPER:
|
||||
The original IA3 gates only three positions per transformer block:
|
||||
l_k * (k_proj output), l_v * (v_proj output), l_ff * (FFN intermediate after activation)
|
||||
This implementation gates ANY linear layer the targeting system selects.
|
||||
To match the paper exactly on a typical Llama/Qwen-style block, attach with:
|
||||
* `ia3` -- OUTPUT-side gating, parameter shape (d_out,).
|
||||
y_new = y * g. Use for attention projections (k_proj, v_proj).
|
||||
|
||||
cfg = LoraLiteConfig(
|
||||
variant="ia3",
|
||||
target_names=(r"\\.k_proj$", r"\\.v_proj$", r"\\.up_proj$"),
|
||||
target_roles=(),
|
||||
)
|
||||
* `ia3_ff` -- INPUT-side gating, parameter shape (d_in,).
|
||||
y_new = base_layer(x * g). Use for FFN-down layers (down_proj,
|
||||
fc2). Equivalent to the paper's "gate the FFN intermediate (post-
|
||||
activation)" position because down_proj's input IS that
|
||||
intermediate hidden state.
|
||||
|
||||
`up_proj` is the closest stand-in for "FFN intermediate" in gated-MLP blocks
|
||||
(Llama uses gate * up; gating the up branch is the IA3-spirit choice).
|
||||
In both cases g is initialized to 1 -> identity at t=0.
|
||||
|
||||
Reference implementations (for review/cross-check):
|
||||
- peft IA3 layer (uses ia3_l elementwise scaling, fan_in_fan_out aware):
|
||||
To match the paper exactly on a Llama/Qwen-style block requires TWO attach
|
||||
passes (one per variant), since each variant uses one hook type:
|
||||
|
||||
cfg_attn = LoraLiteConfig(variant="ia3", target_names=(r"\\.k_proj$", r"\\.v_proj$"))
|
||||
cfg_ffn = LoraLiteConfig(variant="ia3_ff", target_names=(r"\\.down_proj$",))
|
||||
|
||||
Reference implementation:
|
||||
- peft IA3 layer (is_feedforward toggles input-vs-output gating, see
|
||||
docs/refs/peft_ia3_layer.py:177-188 forward and :214 update_layer):
|
||||
https://github.com/huggingface/peft/blob/main/src/peft/tuners/ia3/layer.py
|
||||
(offline: docs/refs/peft_ia3_layer.py)
|
||||
"""
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -42,4 +44,21 @@ class IA3:
|
||||
|
||||
@staticmethod
|
||||
def forward(layer: nn.Linear, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
return y * layer.lora_g
|
||||
return y * layer.lora_g
|
||||
|
||||
|
||||
@register
|
||||
class IA3FF:
|
||||
name = "ia3_ff"
|
||||
|
||||
@staticmethod
|
||||
def param_specs(d_in, d_out, cfg):
|
||||
return {"lora_g": ParamSpec((d_in,), init="ones", trainable=True)}
|
||||
|
||||
@staticmethod
|
||||
def init(layer: nn.Linear, cfg) -> None:
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def forward_input(layer: nn.Linear, x: torch.Tensor) -> torch.Tensor:
|
||||
return x * layer.lora_g
|
||||
@@ -6,9 +6,11 @@ W_eff(t=0) = W_res + B@A = W (numerically; bf16 round-trip not bit-exact).
|
||||
DEVIATION FROM PAPER (documented):
|
||||
- Paper sets adapter scale = 1 (no alpha/r factor); we keep LoRA's alpha/r
|
||||
pipeline so callers must pass alpha=r to get paper-faithful identity.
|
||||
- Saved adapter does NOT include W_res; load() recomputes PiSSA init on the
|
||||
*same-seed base* before overwriting A/B. Reload is exact only on identical
|
||||
base weights.
|
||||
- Saved adapter does NOT include W_res (would double checkpoint size). Instead
|
||||
`adapter.save` records a fingerprint of the post-init base weights and
|
||||
`adapter.load` re-runs PiSSA init then verifies the fingerprint matches
|
||||
-- so loading onto a different base weight raises loudly instead of
|
||||
silently producing wrong outputs.
|
||||
|
||||
Reference implementations (for review/cross-check):
|
||||
- PiSSA original (NeurIPS'24 spotlight) init script (SVD on dequant W):
|
||||
|
||||
+52
-1
@@ -132,6 +132,7 @@ def variant_test(variant: str, dtype=torch.float32):
|
||||
"ia3": 1e-6,
|
||||
"dora": 5e-5, # m * V/||V|| with V=W -> rounding in norm/divide
|
||||
"hra": 1e-6, # gate=0 -> exact identity
|
||||
"antipasto": 5e-4, # SVD truncation + W_res reconstruction in fp32
|
||||
}[variant] * max(1.0, base_scale)
|
||||
assert err < tol, f" FAIL identity: err {err} > tol {tol}"
|
||||
print(f" SHOULD: err<{tol:.1e}. PASS.")
|
||||
@@ -173,6 +174,8 @@ def variant_test(variant: str, dtype=torch.float32):
|
||||
opt = torch.optim.Adam(trainable, lr=1e-1)
|
||||
elif variant == "dora":
|
||||
opt = torch.optim.Adam(trainable, lr=1e-3) # m near ||W||_c, bigger lr blows up
|
||||
elif variant == "antipasto":
|
||||
opt = torch.optim.Adam(trainable, lr=1e-2) # delta_s + rot_T, sensitive
|
||||
else:
|
||||
opt = torch.optim.SGD(trainable, lr=1e-2)
|
||||
losses = []
|
||||
@@ -278,13 +281,61 @@ def bitsandbytes_cuda_smoke(require_bnb: bool):
|
||||
del model
|
||||
|
||||
|
||||
def eva_smoke():
|
||||
"""EVA needs calibration data: drives forward + per-target SVD on inputs."""
|
||||
print("\n=== variant=eva (data-driven init via group_init+calibration_data) ===")
|
||||
torch.manual_seed(0)
|
||||
model = TinyModel().to(torch.float32)
|
||||
ids = torch.randint(0, 100, (2, 16))
|
||||
with torch.no_grad():
|
||||
y_base = model(ids).clone()
|
||||
|
||||
cfg = ll.LoraLiteConfig(variant="eva", r=4, alpha=8, dtype=torch.float32)
|
||||
# 4 calibration batches of random ids
|
||||
calib = [torch.randint(0, 100, (2, 16)) for _ in range(4)]
|
||||
ll.attach(model, cfg, calibration_data=calib)
|
||||
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
print(f" trainable params={n_trainable} (should be only lora_B since A is buffer)")
|
||||
|
||||
with torch.no_grad():
|
||||
y_adapt = model(ids)
|
||||
err = (y_adapt - y_base).abs().max().item()
|
||||
print(f" t=0 identity: max|y_adapt - y_base| = {err:.3e}")
|
||||
assert err < 1e-6, f"EVA should be exact identity (B=0); got {err}"
|
||||
print(" SHOULD: err==0 (B=0 init). PASS.")
|
||||
|
||||
# check A buffer is non-zero (data-driven)
|
||||
a_norms = [layer.lora_A.norm().item() for layer in [m for m in model.modules() if hasattr(m, "lora_A")]]
|
||||
assert all(n > 0 for n in a_norms), "EVA lora_A buffers all zero -> group_init never ran"
|
||||
print(f" SHOULD: lora_A buffers populated. PASS (mean ||A||={sum(a_norms)/len(a_norms):.3f}).")
|
||||
|
||||
# gradient flow: only B trains
|
||||
target = torch.randn(2, 16, 100, dtype=torch.float32) * 0.1
|
||||
trainable = [p for p in model.parameters() if p.requires_grad]
|
||||
opt = torch.optim.SGD(trainable, lr=1e-2)
|
||||
losses = []
|
||||
for _ in range(20):
|
||||
opt.zero_grad()
|
||||
loss = (model(ids) - target).pow(2).mean()
|
||||
loss.backward()
|
||||
assert_no_base_grads(model)
|
||||
opt.step()
|
||||
losses.append(loss.item())
|
||||
drop = (losses[0] - losses[-1]) / max(losses[0], 1e-12)
|
||||
print(f" loss[0]={losses[0]:.4f} loss[-1]={losses[-1]:.4f} drop={100*drop:.1f}%")
|
||||
assert drop > 0.05
|
||||
print(" SHOULD: drop>5%. PASS.")
|
||||
ll.detach(model)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--require-bnb", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
for v in ("lora", "pissa", "delora", "ia3", "dora", "hra"):
|
||||
for v in ("lora", "pissa", "delora", "ia3", "dora", "hra", "antipasto"):
|
||||
variant_test(v, dtype=torch.float32)
|
||||
eva_smoke()
|
||||
structural_linear_like_test()
|
||||
bitsandbytes_cuda_smoke(args.require_bnb)
|
||||
print("\nALL PASS.")
|
||||
|
||||
Reference in New Issue
Block a user