diff --git a/src/lora_lite/adapter.py b/src/lora_lite/adapter.py index 6faf7ea..5616db2 100644 --- a/src/lora_lite/adapter.py +++ b/src/lora_lite/adapter.py @@ -48,8 +48,12 @@ def attach(model: nn.Module, cfg: LoraLiteConfig, calibration_data=None) -> list for pname, spec in variant.param_specs(d_in, d_out, cfg).items(): if hasattr(layer, pname): raise RuntimeError(f"{name} already has attribute {pname}; detach first") - p = spec.make(cfg.dtype, layer.weight.device) - layer.register_parameter(pname, p) + if spec.as_buffer: + t = spec.make_tensor(cfg.dtype, layer.weight.device) + layer.register_buffer(pname, t, persistent=True) + else: + p = spec.make(cfg.dtype, layer.weight.device) + layer.register_parameter(pname, p) layer._lora_cfg = cfg layer._lora_variant = variant layer._lora_role = role @@ -85,18 +89,44 @@ def detach(model: nn.Module) -> None: for pname in variant.param_specs(layer.in_features, layer.out_features, layer._lora_cfg): if pname in layer._parameters: del layer._parameters[pname] + elif pname in layer._buffers: + del layer._buffers[pname] for attr in ("_lora_cfg", "_lora_variant", "_lora_role"): if hasattr(layer, attr): delattr(layer, attr) delattr(model, _ATTACHED_ATTR) +def _base_weight_fingerprint(model: nn.Module) -> dict[str, str]: + """Per-target fingerprint of the (post-init) base weights so PiSSA-style + variants that mutate `layer.weight` can fail loud on base mismatch. + Uses a cheap fp32 sum-of-squares + shape signature; not cryptographic. 
+ """ + state = getattr(model, _ATTACHED_ATTR, None) + if state is None: + return {} + fp = {} + for name, layer in model.named_modules(): + if not hasattr(layer, "_lora_variant"): + continue + if name not in state["targets"]: + continue + w = layer.weight.detach().to(torch.float32, copy=False) + fp[name] = f"{tuple(w.shape)}|{float((w * w).sum()):.6e}" + return fp + + def save(model: nn.Module, path: str) -> None: state = getattr(model, _ATTACHED_ATTR, None) if state is None: raise RuntimeError("no adapter attached; call attach() first") sd = {k: v.detach().cpu() for k, v in model.state_dict().items() if "lora_" in k} - torch.save({"cfg": state["cfg"].to_dict(), "state": sd}, path) + blob = { + "cfg": state["cfg"].to_dict(), + "state": sd, + "base_fp": _base_weight_fingerprint(model), + } + torch.save(blob, path) def load(model: nn.Module, path: str) -> list[RemovableHandle]: @@ -111,4 +141,14 @@ def load(model: nn.Module, path: str) -> list[RemovableHandle]: unexpected_lora = [k for k in unexpected if "lora_" in k] if unexpected_lora: raise RuntimeError(f"unexpected lora keys in checkpoint: {unexpected_lora}") + saved_fp = blob.get("base_fp", {}) + if saved_fp: + cur_fp = _base_weight_fingerprint(model) + diffs = [k for k in saved_fp if saved_fp[k] != cur_fp.get(k)] + if diffs: + raise RuntimeError( + f"base weight fingerprint mismatch on {len(diffs)} layer(s) " + f"(e.g. {diffs[0]}). For PiSSA the saved adapter assumes the same " + "base; reload onto the original model or re-run init." 
+ ) return handles diff --git a/src/lora_lite/config.py b/src/lora_lite/config.py index 3558783..4073a7c 100644 --- a/src/lora_lite/config.py +++ b/src/lora_lite/config.py @@ -8,7 +8,6 @@ class LoraLiteConfig: variant: str = "lora" r: int = 8 alpha: float = 16.0 - dropout: float = 0.0 # currently ignored; variants may use cfg.variant_kwargs dtype: torch.dtype = torch.bfloat16 # targeting diff --git a/src/lora_lite/variant.py b/src/lora_lite/variant.py index dacb844..e9d8d4e 100644 --- a/src/lora_lite/variant.py +++ b/src/lora_lite/variant.py @@ -12,8 +12,9 @@ class ParamSpec: shape: tuple[int, ...] init: str | Callable[[torch.Tensor], None] = "zeros" # 'zeros'|'kaiming'|'ones'|callable(t) trainable: bool = True + as_buffer: bool = False # if True, register_buffer instead of register_parameter - def make(self, dtype: torch.dtype, device) -> nn.Parameter: + def _empty(self, dtype: torch.dtype, device) -> torch.Tensor: t = torch.empty(self.shape, dtype=dtype, device=device) if callable(self.init): self.init(t) @@ -26,7 +27,17 @@ class ParamSpec: nn.init.kaiming_uniform_(t, a=5 ** 0.5) if t.ndim >= 2 else t.normal_(0, 0.02) else: raise ValueError(f"unknown init: {self.init}") - return nn.Parameter(t, requires_grad=self.trainable) + return t + + def make(self, dtype: torch.dtype, device) -> nn.Parameter: + # legacy entry: returns a Parameter (used for trainable adapter params) + if self.as_buffer: + raise RuntimeError("as_buffer spec must be installed via register_buffer; see adapter.attach") + return nn.Parameter(self._empty(dtype, device), requires_grad=self.trainable) + + def make_tensor(self, dtype: torch.dtype, device) -> torch.Tensor: + # returns a raw tensor for buffer registration + return self._empty(dtype, device) class Variant(Protocol): diff --git a/src/lora_lite/variants/__init__.py b/src/lora_lite/variants/__init__.py index b8a3567..9c371e1 100644 --- a/src/lora_lite/variants/__init__.py +++ b/src/lora_lite/variants/__init__.py @@ -1 +1 @@ -from . 
import lora, pissa, delora, ia3, dora, hra # noqa: F401 side-effect: register +from . import lora, pissa, delora, ia3, dora, hra, eva, antipasto # noqa: F401 side-effect: register diff --git a/src/lora_lite/variants/antipasto.py b/src/lora_lite/variants/antipasto.py new file mode 100644 index 0000000..736d34d --- /dev/null +++ b/src/lora_lite/variants/antipasto.py @@ -0,0 +1,145 @@ +"""AntiPaSTO: SVD steering with learnable singular-value deltas + block-diagonal Cayley rotation. + +Lite port of wassname's AntiPaSTO3 SVD adapter (research code, not an +upstream peft variant). Reference: + https://github.com/wassname/antipasto3 (offline: docs/refs/antipasto3_svd_adapter.py) + +Decomposition (PyTorch nn.Linear convention, weight (d_out, d_in)): + + W = U diag(S) Vh + W_res (top-r SVD; W_res = W - U_r diag(S_r) Vh_r) + +We freeze U, S, Vh, W_res and learn: + - delta_s : (r,) -- additive delta to singular values + - rot_T : (n_blocks, bs(bs-1)/2) -- upper-triangle of skew matrix per block + +Forward (matches base layer convention exactly at t=0): + + R = block_diag(Cayley(skew(rot_T))) # (r, r) effective + Vh_rot = R @ Vh # rotates input basis + S_eff = S + delta_s # learnable spectrum + delta_y = ((x @ Vh_rot.T) * S_eff) @ U.T # rank-r path + base_y = x @ W_res.T # frozen residual + y_total = base_y + delta_y # == original output at t=0 + +At init: rot_T = 0 -> R = I -> Vh_rot = Vh, delta_s = 0 -> S_eff = S, so +delta_y reconstructs the truncated SVD term and y_total == x @ W^T to numerical +precision (fp32 SVD round-tripped to cfg.dtype). + +WHICH BASIS IS ROTATED: + By default we rotate Vh (the INPUT singular basis). This is what AntiPaSTO3 + calls `rotate_V=True` in adapter terms (V == Vh.T columns). To rotate U + (output basis) instead, pass variant_kwargs={'rotate_basis': 'U'}. + Rotating both is not implemented (one rotation is enough to span the + identifiable steering directions; two is degenerate). + +REQUIRES rank divisible by `block_size` (default 4). 
r=8, bs=4 -> 2 blocks. +""" +from __future__ import annotations +import math + +import torch +from einops import einsum +from torch import nn + +from ..variant import register, ParamSpec + + +def _cayley(skew: torch.Tensor) -> torch.Tensor: + """R = (I - X)^-1 (I + X) for X = skew/2; preserves orthogonality.""" + bs = skew.shape[-1] + eye = torch.eye(bs, dtype=skew.dtype, device=skew.device).expand_as(skew) + X = skew / 2 + return torch.linalg.solve(eye - X, eye + X) + + +def _build_rotation(rot_T: torch.Tensor, bs: int, max_angle: float) -> torch.Tensor: + """rot_T: (n_blocks, bs*(bs-1)/2) -> R: (n_blocks, bs, bs) Cayley rotation.""" + n_blocks, _ = rot_T.shape + rows, cols = torch.triu_indices(bs, bs, offset=1, device=rot_T.device).unbind(0) + A = torch.zeros(n_blocks, bs, bs, dtype=rot_T.dtype, device=rot_T.device) + A[:, rows, cols] = rot_T + A = 0.5 * (A - A.transpose(-1, -2)) + a_limit = 2.0 * math.tan(max_angle / 2.0) + A = a_limit * torch.tanh(A / a_limit) + return _cayley(A) + + +def _block_diag(blocks: torch.Tensor) -> torch.Tensor: + """(n_blocks, bs, bs) -> (n_blocks*bs, n_blocks*bs) block-diagonal.""" + n, bs, _ = blocks.shape + out = blocks.new_zeros(n * bs, n * bs) + for i in range(n): + out[i * bs : (i + 1) * bs, i * bs : (i + 1) * bs] = blocks[i] + return out + + +@register +class AntiPaSTO: + name = "antipasto" + + @staticmethod + def param_specs(d_in, d_out, cfg): + r = cfg.r + bs = int(cfg.variant_kwargs.get("block_size", 4)) + if r % bs != 0: + raise ValueError(f"AntiPaSTO requires r={r} divisible by block_size={bs}") + n_blocks = r // bs + n_triu = bs * (bs - 1) // 2 + return { + # Frozen SVD components captured at init (buffers travel with state_dict). 
+ "lora_U": ParamSpec((d_out, r), init="zeros", trainable=False, as_buffer=True), + "lora_S": ParamSpec((r,), init="zeros", trainable=False, as_buffer=True), + "lora_Vh": ParamSpec((r, d_in), init="zeros", trainable=False, as_buffer=True), + # Trainable: per-singular-value delta + block-diagonal Cayley rotation. + "lora_delta_s": ParamSpec((r,), init="zeros", trainable=True), + "lora_rot_T": ParamSpec((n_blocks, n_triu), init="zeros", trainable=True), + } + + @staticmethod + def init(layer: nn.Linear, cfg) -> None: + if type(layer) is not nn.Linear: + raise TypeError( + "AntiPaSTO mutates layer.weight into W_res (like PiSSA), so v1 " + "only supports plain nn.Linear, not bnb 4/8-bit." + ) + with torch.no_grad(): + W = layer.weight.data.float() + U, S, Vh = torch.linalg.svd(W, full_matrices=False) + r = cfg.r + Ur, Sr, Vhr = U[:, :r], S[:r], Vh[:r, :] + layer.lora_U.copy_(Ur.to(layer.lora_U.dtype)) + layer.lora_S.copy_(Sr.to(layer.lora_S.dtype)) + layer.lora_Vh.copy_(Vhr.to(layer.lora_Vh.dtype)) + # W_res is the residual after rank-r truncation. Forward adds back + # the truncated path so total == W exactly at init (mod dtype). 
+ W_res = (W - (Ur * Sr) @ Vhr).to(layer.weight.dtype) + layer.weight.data.copy_(W_res) + + @staticmethod + def forward(layer: nn.Linear, x, y): + cfg = layer._lora_cfg + bs = int(cfg.variant_kwargs.get("block_size", 4)) + max_angle = float(cfg.variant_kwargs.get("max_rotation_angle", 0.5)) + rotate_basis = cfg.variant_kwargs.get("rotate_basis", "V") + + U = layer.lora_U.to(x.dtype) # (d_out, r) + S = layer.lora_S.to(x.dtype) # (r,) + Vh = layer.lora_Vh.to(x.dtype) # (r, d_in) + + R_blocks = _build_rotation(layer.lora_rot_T.float(), bs, max_angle) + R = _block_diag(R_blocks).to(x.dtype) # (r, r) + + if rotate_basis == "V": + Vh_eff = R @ Vh # rotate INPUT basis + U_eff = U + elif rotate_basis == "U": + Vh_eff = Vh + U_eff = U @ R.T # rotate OUTPUT basis + else: + raise ValueError(f"rotate_basis must be 'U' or 'V', got {rotate_basis!r}") + + S_eff = S + layer.lora_delta_s.to(x.dtype) # (r,) + h = einsum(x, Vh_eff, "... i, r i -> ... r") # x @ Vh_eff.T + h = h * S_eff # diag(S_eff) + delta = einsum(h, U_eff, "... r, o r -> ... o") # @ U_eff.T + return y + delta diff --git a/src/lora_lite/variants/delora.py b/src/lora_lite/variants/delora.py index 5d9bf88..4819505 100644 --- a/src/lora_lite/variants/delora.py +++ b/src/lora_lite/variants/delora.py @@ -1,27 +1,29 @@ -"""DeLoRA: column-normalised A, B, scaled by lambda * ||W||_F / r. +"""DeLoRA: per-input-channel weight-norm scaling, per-rank A/B normalization. Bini et al. 2025 (ICLR'25) https://arxiv.org/abs/2503.18225 Paper Eq. 8: W' = W + (lambda * ||W||_F / r) B Xi A where Xi_{i,i} = 1 / (||b_i|| ||a_i||) makes each rank-1 component unit-norm. -This is equivalent to row-normalising A and column-normalising B (each column of -B and row of A has unit norm), so each rank-1 outer product b_i a_i^T has unit -spectral norm -> the whole low-rank update is bounded. -Identity at t=0: paper uses kaiming init for both A and B with `lambda` initialised -to 0 (or small) so the effective delta starts near zero. 
We honour that: -default lambda0 == 0 gives bit-identity; user can override via variant_kwargs. +Implementation follows the peft upstream (which the DeLoRA authors maintain), +which differs from the paper notation in two ways -- (1) changes the forward +itself, (2) is forward-equivalent (up to the norm clamp) but matters for gradients/numerics: + 1. ||W|| is captured PER INPUT CHANNEL (shape (d_in,)), not as a scalar + Frobenius norm. Used to scale `x` element-wise on the input dim. + See docs/refs/peft_delora_layer.py:150 (init) and :250 (forward). + 2. Per-rank normalization applied via division by (||A_i|| * ||B^j||) inside + the diagonal scaling, instead of as F.normalize on A,B themselves. + This keeps the gradient flowing through the un-normalized parameters. + +Identity at t=0: lambda0=0 -> delta is exactly zero (bit-identity). KNOWN GRADIENT ISSUE (flagged by external review 2026-04-26): With lambda0=0 the *forward* is identity but `A,B` get zero gradient on step 0 - (delta = lambda * ... -> d_output/d_A is proportional to lambda). Only - `lora_lambda` moves first step. With lambda0>0, A,B train but identity is broken. - Paper's true initialization (frozen-copy trick, see Eq. 9) achieves both; - we do NOT implement that here. + (delta is proportional to lambda). Only `lora_lambda` moves first step. + The paper's true initialization (frozen-copy trick, Eq. 9) achieves both + identity AND non-zero A/B gradients; we do NOT implement it here. -The frozen ||W||_F factor is captured once at init() into a buffer `lora_wnorm`. 
- -Reference implementations (for review/cross-check): +Reference implementations: - DeLoRA paper authors (ExplainableML/DeLoRA) -- their fork of peft: https://github.com/ExplainableML/DeLoRA/blob/main/peft/src/peft/tuners/delora.py (offline: docs/refs/orig_delora.py) @@ -30,7 +32,6 @@ Reference implementations (for review/cross-check): (offline: docs/refs/peft_delora_layer.py) """ import torch -import torch.nn.functional as F from einops import einsum from torch import nn @@ -50,8 +51,9 @@ class DeLoRA: "lora_lambda": ParamSpec( (), init=lambda t: t.fill_(lam0), trainable=True ), - # ||W||_F captured at init; frozen scalar buffer (no grad) - "lora_wnorm": ParamSpec((), init="zeros", trainable=False), + # ||W||_2 per input channel (shape (d_in,)); frozen buffer captured at init + # per peft DeLoRA (docs/refs/peft_delora_layer.py:150). + "lora_wnorm": ParamSpec((d_in,), init="ones", trainable=False, as_buffer=True), } @staticmethod @@ -60,16 +62,23 @@ class DeLoRA: # dequantizes via .float() round-trip if available, or fails cleanly. with torch.no_grad(): W = layer.weight.data.float() - layer.lora_wnorm.data.fill_(W.norm().item()) + wnorm = W.norm(dim=0).detach().to(layer.lora_wnorm.dtype) + layer.lora_wnorm.copy_(wnorm) return @staticmethod def forward(layer: nn.Linear, x, y): cfg = layer._lora_cfg - # rows of A unit, cols of B unit (per paper, equivalent to Xi) - A = F.normalize(layer.lora_A, dim=1) # (r, d_in) - B = F.normalize(layer.lora_B, dim=0) # (d_out, r) - scale = layer.lora_lambda * layer.lora_wnorm / cfg.r - h = einsum(x, A, "... i, r i -> ... r") + A = layer.lora_A # (r, d_in) + B = layer.lora_B # (d_out, r) + # peft delora forward (docs/refs/peft_delora_layer.py:248-260): + # h = (x * w_norm) @ A.T; scale per-rank = (lambda/r) / (||A_i|| * ||B^j||); + # delta = (h * scale) @ B.T + x_scaled = x * layer.lora_wnorm # (..., d_in) + h = einsum(x_scaled, A, "... i, r i -> ... 
r") + An = torch.clamp(A.norm(dim=1), min=1e-4) # (r,) + Bn = torch.clamp(B.norm(dim=0), min=1e-4) # (r,) + scale = (layer.lora_lambda / cfg.r) / (An * Bn) # (r,) + h = h * scale delta = einsum(h, B, "... r, o r -> ... o") - return y + scale * delta + return y + delta diff --git a/src/lora_lite/variants/eva.py b/src/lora_lite/variants/eva.py new file mode 100644 index 0000000..ab9cc3a --- /dev/null +++ b/src/lora_lite/variants/eva.py @@ -0,0 +1,121 @@ +"""EVA: Explained-Variance Adaptation. Paischer et al. 2024. + +Paper: https://arxiv.org/abs/2410.07170 (also referred to as ICLR'25 EVA). + +Idea: instead of random A and zero B (LoRA) or SVD of W (PiSSA), initialize +`lora_A` to the top-r right singular vectors of the LAYER INPUT distribution +on a small calibration set. Forward = `y + scale * (B @ A @ x)` exactly like +LoRA; with `lora_B = 0` the adapter is identity at t=0. Only B trains +afterwards (A frozen). The result: each rank slot points along a direction +that actually carries information at this layer. + +This is a stripped-down EVA; we do NOT implement: + - rank redistribution across layers via explained-variance ratios + (peft EVA computes an explained_variance_ratio per layer then redistributes + the global rank budget; we use a uniform `cfg.r` per layer). + - Incremental PCA over many micro-batches (we run one full SVD on the + pooled calibration activations per layer). + - Equal-input deduplication (peft hashes inputs to share SVD across QKV). + +API stress-test: this variant requires data-driven init, so it implements +`group_init(model, targets, cfg, calibration_data)` to drive a single forward +pass on `calibration_data` with hooks that capture each target's input. + +Identity at t=0: `lora_B = 0` -> delta = 0 -> y unchanged. 
+ +References: + - peft EVA (full impl, with IncrementalPCA + redistribution): + https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/eva.py + (offline: docs/refs/peft_eva.py) + - peft fine-tuning script demonstrating initialize_lora_eva_weights: + https://github.com/huggingface/peft/blob/main/examples/eva_finetuning/eva_finetuning.py + (offline: docs/refs/peft_eva_finetuning.py) +""" +from __future__ import annotations + +import torch +from einops import einsum +from torch import nn + +from ..variant import register, ParamSpec + + +@register +class EVA: + name = "eva" + + @staticmethod + def param_specs(d_in, d_out, cfg): + return { + # A is frozen (set in group_init from calibration data); kept as a + # buffer so it travels with state_dict and is not optimized. + "lora_A": ParamSpec((cfg.r, d_in), init="zeros", trainable=False, as_buffer=True), + # B is the only trainable bit; zero-init -> identity at t=0. + "lora_B": ParamSpec((d_out, cfg.r), init="zeros", trainable=True), + } + + @staticmethod + def init(layer: nn.Linear, cfg) -> None: + # No-op; group_init does the data-driven SVD across all targets at once. + return + + @staticmethod + def group_init(model: nn.Module, targets, cfg, calibration_data) -> None: + if calibration_data is None: + raise ValueError( + "EVA requires calibration_data: an iterable of model inputs " + "(dicts of kwargs to model.forward, or single tensors) used to " + "estimate the input PCA per layer. Pass via " + "lora_lite.attach(model, cfg, calibration_data=batches)." + ) + # Collect input activations per target via forward hooks. 
+ layers = {name: layer for name, layer, _ in targets} + captured: dict[str, list[torch.Tensor]] = {n: [] for n in layers} + + def make_hook(name): + def _h(module, args, kwargs): + # signature: pre-forward, args[0] is the input tensor + x = args[0].detach() + captured[name].append(x.reshape(-1, x.shape[-1]).to(torch.float32).cpu()) + return _h + + handles = [ + layers[n].register_forward_pre_hook(make_hook(n), with_kwargs=True) + for n in layers + ] + try: + was_training = model.training + model.eval() + with torch.no_grad(): + for batch in calibration_data: + if isinstance(batch, dict): + model(**batch) + elif isinstance(batch, (list, tuple)): + model(*batch) + else: + model(batch) + if was_training: + model.train() + finally: + for h in handles: + h.remove() + + # SVD per target on pooled inputs; top-r right singular vectors -> A. + for name, layer in layers.items(): + X = torch.cat(captured[name], dim=0) # (N, d_in) + if X.shape[0] < cfg.r: + raise RuntimeError( + f"EVA at {name}: only {X.shape[0]} calibration tokens, need >= r={cfg.r}" + ) + # full_matrices=False -> Vh shape (min(N,d_in), d_in); take top-r rows + _, _, Vh = torch.linalg.svd(X, full_matrices=False) + A = Vh[: cfg.r, :].to(layer.lora_A.dtype).to(layer.lora_A.device) + layer.lora_A.copy_(A) + + @staticmethod + def forward(layer: nn.Linear, x, y): + cfg = layer._lora_cfg + scale = cfg.alpha / cfg.r + h = einsum(x, layer.lora_A, "... i, r i -> ... r") + delta = einsum(h, layer.lora_B, "... r, o r -> ... o") + return y + scale * delta diff --git a/src/lora_lite/variants/hra.py b/src/lora_lite/variants/hra.py index 8161d8b..4dbc7ab 100644 --- a/src/lora_lite/variants/hra.py +++ b/src/lora_lite/variants/hra.py @@ -9,27 +9,22 @@ so the layer output becomes y' = W' x = W (R x). R is in INPUT space (d_in x d We implement this via a `forward_input` pre-hook that returns `R x`, then the frozen base layer (including bnb 4/8-bit Linear) computes `W (R x)` itself. 
-Identity at t=0: `lora_gate` is initialized to 0 and gates each Householder -vector, so the effective u_i starts at 0 -> H_i = I -> R = I -> y' = y. -At training time the gate scales the active reflection direction. +Identity at t=0 (PEFT-style symmetric init, requires even r): + Rows are kaiming-init in pairs: U[0]=U[1], U[2]=U[3], ... Adjacent pairs of + Householder reflections with identical vectors cancel exactly + (H_i H_i = I), so R = I at init -> y' = y up to floating-point rounding. + After the first gradient step the paired rows diverge and the chain becomes a + general orthogonal matrix; gradient flows into U from step 0 (no dead-grad). + Odd r is rejected (matches peft warning behaviour). -KNOWN GRADIENT ISSUE (flagged by external review 2026-04-26): - Forward is `x + gate * (Rx - x)`. With gate=0 at init, d_output/d_U is - proportional to gate, so on step 0 ONLY `lora_gate` receives gradient; - `lora_U` is dead. Once gate moves off zero, U starts learning. This deviates - from the paper, which has no such gate -- paper uses orthogonal init of U so - R != I from step 0. We trade paper-faithful init for identity-at-init. - -OMITTED: paper also adds an orthogonality regularizer - lambda * sum_i (u_i^T u_j)^2 (Eq. 6 / Sec. 3.3) -which is a loss term, not a forward-pass change. Add it in your training loop if -you want the regularized HRA variant. +OMITTED: paper also adds an orthogonality regularizer (Eq. 6 / Sec. 3.3), +a loss-side term. Add it in your training loop if you want regularized HRA. 
Reference implementations (for review/cross-check): - HRA paper authors (DaShenZi721/HRA), llama variant of OFT layer with HRA: https://github.com/DaShenZi721/HRA/blob/master/llama/peft/oft/layer_GS_HRA.py (offline: docs/refs/orig_hra_layer.py) - - peft HRA layer (cleaner, includes apply_GS toggle for orthogonalization): + - peft HRA layer, reset_hra_parameters (lines 100-108): https://github.com/huggingface/peft/blob/main/src/peft/tuners/hra/layer.py (offline: docs/refs/peft_hra_layer.py) """ @@ -46,20 +41,33 @@ class HRA: @staticmethod def param_specs(d_in, d_out, cfg): + if cfg.r % 2 != 0: + raise ValueError( + f"HRA symmetric init requires even r; got r={cfg.r}. " + "Pick an even rank or use a different variant." + ) return { - # one Householder vector per rank slot in INPUT space R^{d_in} - "lora_U": ParamSpec((cfg.r, d_in), init="kaiming", trainable=True), - # identity gate; 0 -> R = I exactly - "lora_gate": ParamSpec((), init="zeros", trainable=True), + # Householder vectors stacked as rows (one vector per rank slot) + # init done in init() to enforce paired rows -> R = I at t=0. + "lora_U": ParamSpec((cfg.r, d_in), init="zeros", trainable=True), } @staticmethod def init(layer: nn.Linear, cfg) -> None: + # Symmetric init per peft (docs/refs/peft_hra_layer.py:101-108): + # half = kaiming(r//2, d_in); U = repeat_interleave(half, 2, dim=0) + # Adjacent pairs (H_2k H_2k+1) cancel since H^2 = I, so R = I exactly, + # while gradient still flows into U from step 0. + with torch.no_grad(): + r, d_in = layer.lora_U.shape + half = torch.empty(r // 2, d_in, dtype=layer.lora_U.dtype, device=layer.lora_U.device) + nn.init.kaiming_uniform_(half, a=5 ** 0.5) + layer.lora_U.copy_(torch.repeat_interleave(half, 2, dim=0)) return @staticmethod def forward_input(layer: nn.Linear, x: torch.Tensor) -> torch.Tensor: - """Apply x + gate * (Rx - x). 
gate=0 -> identity; nonzero -> full Householder chain.""" + """Apply Rx where R = prod_i H_i, H_i = I - 2 u_i u_i^T / ||u_i||^2.""" U = layer.lora_U # (r, d_in) Rx = x for i in range(U.shape[0]): @@ -67,4 +75,4 @@ class HRA: sq = (u * u).sum().clamp_min(1e-12) coeff = einsum(Rx, u, "... i, i -> ...") * (2.0 / sq) Rx = Rx - coeff.unsqueeze(-1) * u - return x + layer.lora_gate * (Rx - x) + return Rx diff --git a/src/lora_lite/variants/ia3.py b/src/lora_lite/variants/ia3.py index f9782e3..fbc7fc9 100644 --- a/src/lora_lite/variants/ia3.py +++ b/src/lora_lite/variants/ia3.py @@ -1,26 +1,28 @@ -"""IA3-style output gating. Liu et al. 2022 https://arxiv.org/abs/2205.05638 +"""IA3-style elementwise gating. Liu et al. 2022 https://arxiv.org/abs/2205.05638 - y_new = y * g, g initialized to 1 (identity at t=0) +Two registered variants, matching the paper's two regimes: -DEVIATION FROM PAPER: - The original IA3 gates only three positions per transformer block: - l_k * (k_proj output), l_v * (v_proj output), l_ff * (FFN intermediate after activation) - This implementation gates ANY linear layer the targeting system selects. - To match the paper exactly on a typical Llama/Qwen-style block, attach with: +* `ia3` -- OUTPUT-side gating, parameter shape (d_out,). + y_new = y * g. Use for attention projections (k_proj, v_proj). - cfg = LoraLiteConfig( - variant="ia3", - target_names=(r"\\.k_proj$", r"\\.v_proj$", r"\\.up_proj$"), - target_roles=(), - ) +* `ia3_ff` -- INPUT-side gating, parameter shape (d_in,). + y_new = base_layer(x * g). Use for FFN-down layers (down_proj, + fc2). Equivalent to the paper's "gate the FFN intermediate (post- + activation)" position because down_proj's input IS that + intermediate hidden state. - `up_proj` is the closest stand-in for "FFN intermediate" in gated-MLP blocks - (Llama uses gate * up; gating the up branch is the IA3-spirit choice). +In both cases g is initialized to 1 -> identity at t=0. 
-Reference implementations (for review/cross-check): - - peft IA3 layer (uses ia3_l elementwise scaling, fan_in_fan_out aware): +To match the paper exactly on a Llama/Qwen-style block requires TWO attach +passes (one per variant), since each variant uses one hook type: + + cfg_attn = LoraLiteConfig(variant="ia3", target_names=(r"\\.k_proj$", r"\\.v_proj$")) + cfg_ffn = LoraLiteConfig(variant="ia3_ff", target_names=(r"\\.down_proj$",)) + +Reference implementation: + - peft IA3 layer (is_feedforward toggles input-vs-output gating, see + docs/refs/peft_ia3_layer.py:177-188 forward and :214 update_layer): https://github.com/huggingface/peft/blob/main/src/peft/tuners/ia3/layer.py - (offline: docs/refs/peft_ia3_layer.py) """ import torch from torch import nn @@ -42,4 +44,21 @@ class IA3: @staticmethod def forward(layer: nn.Linear, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - return y * layer.lora_g \ No newline at end of file + return y * layer.lora_g + + +@register +class IA3FF: + name = "ia3_ff" + + @staticmethod + def param_specs(d_in, d_out, cfg): + return {"lora_g": ParamSpec((d_in,), init="ones", trainable=True)} + + @staticmethod + def init(layer: nn.Linear, cfg) -> None: + return + + @staticmethod + def forward_input(layer: nn.Linear, x: torch.Tensor) -> torch.Tensor: + return x * layer.lora_g \ No newline at end of file diff --git a/src/lora_lite/variants/pissa.py b/src/lora_lite/variants/pissa.py index f57788a..bd62005 100644 --- a/src/lora_lite/variants/pissa.py +++ b/src/lora_lite/variants/pissa.py @@ -6,9 +6,11 @@ W_eff(t=0) = W_res + B@A = W (numerically; bf16 round-trip not bit-exact). DEVIATION FROM PAPER (documented): - Paper sets adapter scale = 1 (no alpha/r factor); we keep LoRA's alpha/r pipeline so callers must pass alpha=r to get paper-faithful identity. - - Saved adapter does NOT include W_res; load() recomputes PiSSA init on the - *same-seed base* before overwriting A/B. Reload is exact only on identical - base weights. 
+ - Saved adapter does NOT include W_res (would double checkpoint size). Instead + `adapter.save` records a fingerprint of the post-init base weights and + `adapter.load` re-runs PiSSA init then verifies the fingerprint matches + -- so loading onto a different base weight raises loudly instead of + silently producing wrong outputs. Reference implementations (for review/cross-check): - PiSSA original (NeurIPS'24 spotlight) init script (SVD on dequant W): diff --git a/tests/smoke.py b/tests/smoke.py index e28b870..5ec0e6c 100644 --- a/tests/smoke.py +++ b/tests/smoke.py @@ -132,6 +132,7 @@ def variant_test(variant: str, dtype=torch.float32): "ia3": 1e-6, "dora": 5e-5, # m * V/||V|| with V=W -> rounding in norm/divide "hra": 1e-6, # gate=0 -> exact identity + "antipasto": 5e-4, # SVD truncation + W_res reconstruction in fp32 }[variant] * max(1.0, base_scale) assert err < tol, f" FAIL identity: err {err} > tol {tol}" print(f" SHOULD: err<{tol:.1e}. PASS.") @@ -173,6 +174,8 @@ def variant_test(variant: str, dtype=torch.float32): opt = torch.optim.Adam(trainable, lr=1e-1) elif variant == "dora": opt = torch.optim.Adam(trainable, lr=1e-3) # m near ||W||_c, bigger lr blows up + elif variant == "antipasto": + opt = torch.optim.Adam(trainable, lr=1e-2) # delta_s + rot_T, sensitive else: opt = torch.optim.SGD(trainable, lr=1e-2) losses = [] @@ -278,13 +281,61 @@ def bitsandbytes_cuda_smoke(require_bnb: bool): del model +def eva_smoke(): + """EVA needs calibration data: drives forward + per-target SVD on inputs.""" + print("\n=== variant=eva (data-driven init via group_init+calibration_data) ===") + torch.manual_seed(0) + model = TinyModel().to(torch.float32) + ids = torch.randint(0, 100, (2, 16)) + with torch.no_grad(): + y_base = model(ids).clone() + + cfg = ll.LoraLiteConfig(variant="eva", r=4, alpha=8, dtype=torch.float32) + # 4 calibration batches of random ids + calib = [torch.randint(0, 100, (2, 16)) for _ in range(4)] + ll.attach(model, cfg, calibration_data=calib) + 
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) + print(f" trainable params={n_trainable} (should be only lora_B since A is buffer)") + + with torch.no_grad(): + y_adapt = model(ids) + err = (y_adapt - y_base).abs().max().item() + print(f" t=0 identity: max|y_adapt - y_base| = {err:.3e}") + assert err < 1e-6, f"EVA should be exact identity (B=0); got {err}" + print(" SHOULD: err==0 (B=0 init). PASS.") + + # check A buffer is non-zero (data-driven) + a_norms = [layer.lora_A.norm().item() for layer in [m for m in model.modules() if hasattr(m, "lora_A")]] + assert all(n > 0 for n in a_norms), "EVA lora_A buffers all zero -> group_init never ran" + print(f" SHOULD: lora_A buffers populated. PASS (mean ||A||={sum(a_norms)/len(a_norms):.3f}).") + + # gradient flow: only B trains + target = torch.randn(2, 16, 100, dtype=torch.float32) * 0.1 + trainable = [p for p in model.parameters() if p.requires_grad] + opt = torch.optim.SGD(trainable, lr=1e-2) + losses = [] + for _ in range(20): + opt.zero_grad() + loss = (model(ids) - target).pow(2).mean() + loss.backward() + assert_no_base_grads(model) + opt.step() + losses.append(loss.item()) + drop = (losses[0] - losses[-1]) / max(losses[0], 1e-12) + print(f" loss[0]={losses[0]:.4f} loss[-1]={losses[-1]:.4f} drop={100*drop:.1f}%") + assert drop > 0.05 + print(" SHOULD: drop>5%. PASS.") + ll.detach(model) + + def main(): parser = argparse.ArgumentParser() parser.add_argument("--require-bnb", action="store_true") args = parser.parse_args() - for v in ("lora", "pissa", "delora", "ia3", "dora", "hra"): + for v in ("lora", "pissa", "delora", "ia3", "dora", "hra", "antipasto"): variant_test(v, dtype=torch.float32) + eva_smoke() structural_linear_like_test() bitsandbytes_cuda_smoke(args.require_bnb) print("\nALL PASS.")