From 7986edad2c99a70cdcfe283ced47c8df91bbaf81 Mon Sep 17 00:00:00 2001 From: wassname <1103714+wassname@users.noreply.github.com> Date: Tue, 16 Jun 2026 06:37:18 +0800 Subject: [PATCH] fix: calibration through cropped model + detach/checkpoint gaps (external review) gpt-5.5 review (decorrelated) found three real issues deepseek missed: - BLOCKER: calibration ran through a cropped model. attach() did init() (crops every target to W_res) then group_init() (calibration forward) then registered the adapter hooks -- so CorDA's covariance and Wanda's scores were collected from a model missing every target's top-r. Now register hooks BEFORE group_init; at g=0/B=0 they reconstruct the cropped component exactly, so calibration sees full W. - detach() left the model cropped (deleted buffers without adding the frozen top-r back). Now reconstructs W = W_res + U_r S_r (Vh|P)_r before removing buffers. - base-residual persistence wasn't in checkpoint metadata, so load->re-save dropped it. Persist base_weight_keys in metadata, validate on load, carry onto attach state. Docstring/citation cleanup (review + user style asks): - antipasto_corda: drop changelog narration and the stale "None -> plain SVD" claim (it raises now); exact reconstruction states W_res; slim the CPU/OOM note. - antipasto_dplr: drop the arrowhead archaeology; docstring math now matches the forward (p@A.T@B.T); fix the k=0 comment (code requires 1<=k<=r). - citations: Wanda (Sun+ 2023, 2306.11695), ASVD (Yuan+ 2023, 2312.05821), PiSSA (Meng+ 2024, 2404.02948), LoRA (Hu+ 2021, 2106.09685). Co-Authored-By: Claudypoo --- src/lora_lite/adapter.py | 36 ++++++++++++++++++----- src/lora_lite/variants/antipasto.py | 4 ++- src/lora_lite/variants/antipasto_corda.py | 35 +++++++--------------- src/lora_lite/variants/antipasto_dplr.py | 36 ++++++++--------------- 4 files changed, 56 insertions(+), 55 deletions(-) diff --git a/src/lora_lite/adapter.py b/src/lora_lite/adapter.py index 616c07e..003773f 100644 --- a/src/lora_lite/adapter.py +++ b/src/lora_lite/adapter.py @@ -62,17 +62,21 @@ def attach(model: nn.Module, cfg: AdapterConfig, calibration_data=None, *, _skip attached_names.append(name) attached_targets.append((name, layer, role)) - group_init = getattr(variant, "group_init", None) - ran_data_init = group_init is not None and not _skip_group_init and calibration_data is not None - if group_init is not None and not _skip_group_init: - group_init(model, attached_targets, cfg, calibration_data) - + # Register the adapter hooks BEFORE group_init. init() crops each weight to W_res, + # so without the hooks the calibration forward inside group_init would run through a + # model missing every target's top-r. At g=0 (and B=0) the hooks reconstruct the + # cropped component exactly, so calibration sees the true full W. for _, layer, _ in attached_targets: if hasattr(layer._lora_variant, "forward_input"): handles.append(layer.register_forward_pre_hook(_pre_hook)) else: handles.append(layer.register_forward_hook(_hook)) + group_init = getattr(variant, "group_init", None) + ran_data_init = group_init is not None and not _skip_group_init and calibration_data is not None + if group_init is not None and not _skip_group_init: + group_init(model, attached_targets, cfg, calibration_data) + # A data-driven group_init (CorDA orient, Wanda re-select) rewrites the frozen # base residual W_res into a form init() cannot reproduce at load time (it only # knows the plain top-r crop). So those residuals are part of the saved adapter. @@ -94,6 +98,14 @@ def detach(model: nn.Module) -> None: if not hasattr(layer, "_lora_variant"): continue variant = layer._lora_variant + # Undo the PiSSA-style crop: init() set weight = W - U_r S_r (Vh|P)_r, so add the + # frozen top-r back to recover the original W (the trained gain/core are dropped). + # Keyed on the shared SVD-gain buffer convention (antipasto family); variants + # without lora_U leave weight untouched (e.g. LoRA never cropped it). + if hasattr(layer, "lora_U"): + proj = layer.lora_P if hasattr(layer, "lora_P") else layer.lora_Vh + with torch.no_grad(): + layer.weight.data += ((layer.lora_U * layer.lora_S) @ proj).to(layer.weight.dtype) for pname in variant.param_specs(layer.in_features, layer.out_features, layer._lora_cfg): if pname in layer._parameters: del layer._parameters[pname] @@ -112,9 +124,11 @@ def save(model: nn.Module, path: str) -> None: full_sd = model.state_dict() sd = {k: v.detach().cpu() for k, v in full_sd.items() if "lora_" in k} # data-driven variants also persist their rewritten base residuals (see attach()). - for wk in state.get("base_weight_keys", []): + base_weight_keys = state.get("base_weight_keys", []) + for wk in base_weight_keys: sd[wk] = full_sd[wk].detach().cpu() - metadata = {"cfg": json.dumps(state["cfg"].to_dict())} + metadata = {"cfg": json.dumps(state["cfg"].to_dict()), + "base_weight_keys": json.dumps(base_weight_keys)} from safetensors.torch import save_file save_file(sd, path, metadata=metadata) @@ -125,6 +139,12 @@ def load(model: nn.Module, path: str) -> list[RemovableHandle]: metadata = f.metadata() sd = load_file(path, device="cpu") cfg = AdapterConfig.from_dict(json.loads(metadata["cfg"])) + # Base residuals a data-driven group_init rewrote: must be in the checkpoint and + # are restored by load_state_dict below (init()'s plain crop would be wrong). + base_weight_keys = json.loads(metadata.get("base_weight_keys", "[]")) + missing_base = [wk for wk in base_weight_keys if wk not in sd] + if missing_base: + raise RuntimeError(f"checkpoint declares but omits base residuals: {missing_base}") handles = attach(model, cfg, _skip_group_init=True) # creates empty params; data-driven inits restored from state_dict missing, unexpected = model.load_state_dict(sd, strict=False) expected_lora = {k for k in model.state_dict() if "lora_" in k} @@ -134,4 +154,6 @@ def load(model: nn.Module, path: str) -> list[RemovableHandle]: unexpected_lora = [k for k in unexpected if "lora_" in k] if unexpected_lora: raise RuntimeError(f"unexpected lora keys in checkpoint: {unexpected_lora}") + # Carry the residual keys onto the attach state so a later save() re-persists them. + getattr(model, _ATTACHED_ATTR)["base_weight_keys"] = base_weight_keys return handles diff --git a/src/lora_lite/variants/antipasto.py b/src/lora_lite/variants/antipasto.py index 71fe778..6172f3a 100644 --- a/src/lora_lite/variants/antipasto.py +++ b/src/lora_lite/variants/antipasto.py @@ -17,6 +17,8 @@ the gain is learned. See forward() for why 1+ELU over linear/exp/tanh. Refs: - paper: https://github.com/wassname/AntiPaSTO - sibling (whitened, mean-diff): steering-lite/.../sspace.py + - selection: Wanda (Sun+ 2023, arXiv:2306.11695), ASVD (Yuan+ 2023, arXiv:2312.05821) + - top-r SVD init: PiSSA (Meng+ 2024, arXiv:2404.02948) """ from dataclasses import dataclass from typing import Iterable, Literal @@ -149,7 +151,7 @@ class AntiPaSTO: W_orig = W_res + (U_old * S_old.unsqueeze(0)) @ Vh_old U_full, S_full, Vh_full = torch.linalg.svd(W_orig, full_matrices=False) - proj = X.to(Vh_full) @ Vh_full.T # (N, k) input in S-coords (X captured on CPU) + proj = X.to(Vh_full) @ Vh_full.T # (N, r) input in S-coords (X CPU -> GPU here) if pool == "rms": act_mag = proj.pow(2).mean(dim=0).sqrt() # outlier-sensitive else: diff --git a/src/lora_lite/variants/antipasto_corda.py b/src/lora_lite/variants/antipasto_corda.py index 648624f..5e97290 100644 --- a/src/lora_lite/variants/antipasto_corda.py +++ b/src/lora_lite/variants/antipasto_corda.py @@ -7,15 +7,15 @@ directions move the output most on real activations. C = E[x x^T] (+ eps I) # input second moment on calibration data C^{1/2}, C^{-1/2} via eigh(C) - U S Vht = SVD(W C^{1/2}) + U S Vht = SVD(W C^{1/2}) # top-r is Eckart-Young best under x ~ N(0,C) P = Vht C^{-1/2} # (r, d_in) oblique input projector - W = U diag(S) P (exactly) + W = W_res + U_r diag(S_r) P_r # exact (residual carries the dropped tail) S_eff = S * (1 + ELU(coeff*g)) # same bounded gain as antipasto y = x @ W_res.T + ((x @ P.T) * S_eff) @ U.T Identity at g=0 or coeff=0: S_eff=S. P is oblique (rows not orthonormal -- C^{-1/2} -skews them); fine for gain reweighting and for output-side ablation (the obliqueness -is input-side; U stays orthonormal). No calibration_data -> plain SVD (== antipasto). +skews them); fine for gain reweighting since U stays orthonormal. Requires +calibration_data (group_init raises otherwise). Refs: antipasto.py (gain + selection sibling), CorDA arXiv:2406.05223. """ @@ -92,29 +92,16 @@ class AntiPaSTOCorDA: def group_init(model: nn.Module, targets, cfg, calibration_data: CalibrationData | None) -> None: """Re-orient each target's SVD by its input covariance C = E[x x^T]. - Covariance orientation IS this variant's identity, so calibration_data is - mandatory -- fail loud rather than silently degrade to plain SVD (which is - just antipasto and was the bug that made every corda run a no-op). - - Called by attach() BEFORE any training, so the trainable g is still at its - zero init when the basis changes -- re-orienting zero gains is a no-op, no - re-indexing needed. Do not call group_init after training has updated g.""" + Requires calibration_data (raises otherwise). Call only at attach-time, + before training updates g (re-orienting g=0 is a no-op, no re-indexing).""" if calibration_data is None: - raise ValueError( - "AntiPaSTOCorDA requires calibration_data (covariance orientation is " - "its whole point); got None. Pass attach(model, cfg, calibration_data=...)." - ) + raise ValueError("AntiPaSTOCorDA requires calibration_data; got None.") layers = {name: layer for name, layer, _ in targets} - # accumulate C = sum x x^T on CPU. Peak GPU cost would otherwise be - # sum_targets d_in^2 fp32 held at once; for down_proj (d_in=intermediate, - # e.g. 14336) that is ~0.8 GB *per layer* and OOMs. CPU accumulation bounds - # GPU use to the live activation; the eigh/SVD below run on CPU (one-time). - # Diagonal C is NOT a usable shortcut: it misses cross-channel correlation, - # which is where the orientation gain lives (measured ~= plain SVD). - # If down_proj's d_in^2 is too big even on CPU/RAM, exclude it from CorDA - # (leave it on plain antipasto) or use a low-rank C (top-k eig of subsampled - # inputs) -- not implemented here. + # Accumulate C = sum x x^T on CPU: d_in^2 fp32 per target would OOM the GPU + # (down_proj d_in~14336 -> ~0.8 GB/layer). Diagonal C is not a shortcut -- + # the orientation lives in the cross-channel terms (Yuan+ 2023, ASVD, + # arXiv:2312.05821 is the diagonal case). cov: dict[str, T] = {} cnt: dict[str, int] = {n: 0 for n in layers} diff --git a/src/lora_lite/variants/antipasto_dplr.py b/src/lora_lite/variants/antipasto_dplr.py index f143c4f..db15a02 100644 --- a/src/lora_lite/variants/antipasto_dplr.py +++ b/src/lora_lite/variants/antipasto_dplr.py @@ -1,32 +1,22 @@ -"""AntiPaSTO-DPLR: diagonal-plus-low-rank core in the frozen SVD basis. +"""AntiPaSTO-DPLR: diagonal gain plus a low-rank mixing core in the frozen SVD basis. -antipasto's core is diagonal (a per-direction gain); it rescales each singular -direction but cannot mix one into another. The arrowhead tried a dense b x b block -on the top-b directions, but a dense block is the wrong shape (b^2 params, mixes only -the top-b) and -- sitting on the S-scaled coords -- its perturbation is amplified by -the largest singular values, so it destabilizes. The fix is LoRA's lesson: a low-rank -core. Put a trainable rank-k core inside the frozen U/Vh basis, ADDED to the gain: +antipasto's diagonal gain rescales each singular direction but cannot mix one into +another. DPLR adds a trainable rank-k core that does, inside the frozen U/Vh basis: W = U diag(S) Vh + W_res # frozen top-r SVD learn: g (r,) # diagonal gain - A (k,r), B (r,k) # low-rank mixing core, B=0 at init + A (k,r) kaiming, B (r,k) zero # low-rank mixing core + p = x @ Vh.T # (r,) input in the frozen S-basis S_eff = S * (1 + ELU(coeff * g)) - y = x @ W_res.T + ( (Vh x) * S_eff + coeff * B (A (Vh x)) ) @ U.T + h = p * S_eff + coeff * (p @ A.T) @ B.T # diagonal gain + rank-k mixing + y = x @ W_res.T + h @ U.T -so the trainable core is C = diag(S_eff) + coeff * B A acting in S-space, and -DeltaW = U C Vh. The diagonal part scales directions; the low-rank part B A mixes them -across the whole top-r subspace for 2*r*k params (k=LoRA's rank), not b^2. +The rank-k term is LoRA's core (Hu+ 2021, arXiv:2106.09685) restricted to W's top-r +subspace, ADDED to the gain rather than folded into diag(S): being independent of S, a +unit step moves W by O(1) not O(S), so it has no singular-value amplification. Params += r + 2*r*k. Identity at init (B=0, g=0) and at coeff=0. Basis (U, Vh) stays frozen. -Why the low-rank part is ADDED, not multiplied into diag(S): an additive core -U (BA) Vh is independent of S, so a unit step in BA moves W by O(1), not O(S). That is -exactly the S-amplification edge that made the dense arrowhead block blow up at the -gain's learning rate -- gone by construction. - -Identity at init: B=0 -> BA=0, g=0 -> 1+ELU(0)=1, so C=diag(S) and DeltaW = U diag(S) Vh. -coeff=0 -> identity too (runtime off). The basis (U, Vh) stays frozen and interpretable; -only the gain and the rank-k core move. - -Refs: antipasto.py (diagonal sibling), lora.py (the low-rank core), antipasto_corda.py +Refs: antipasto.py (diagonal sibling), lora.py (low-rank core), antipasto_corda.py (oriented basis -- composes with this core). """ from dataclasses import dataclass @@ -51,7 +41,7 @@ class AntiPaSTODPLRConfig(AdapterConfig): variant: str = "antipasto_dplr" r: int = 256 # Rank of the low-rank mixing core (LoRA's r, but inside the frozen subspace). - # Params = r (gain) + 2*r*lora_rank. k=0 degenerates to plain antipasto. + # Params = r (gain) + 2*r*lora_rank. Requires 1 <= lora_rank <= r. lora_rank: int = 8 suppress_only: bool = False # clamp the gain g<=0 (attenuate only); core unaffected. coeff: float = 1.0 # runtime knob: 0=identity, scales gain and core.