fix: calibration through cropped model + detach/checkpoint gaps (external review)

gpt-5.5 review (decorrelated) found three real issues deepseek missed:

- BLOCKER: calibration ran through a cropped model. attach() did init() (crops
  every target to W_res) then group_init() (calibration forward) then registered
  the adapter hooks -- so CorDA's covariance and Wanda's scores were collected from
  a model missing every target's top-r. Now register hooks BEFORE group_init; at
  g=0/B=0 they reconstruct the cropped component exactly, so calibration sees full W.
- detach() left the model cropped (deleted buffers without adding the frozen top-r
  back). Now reconstructs W = W_res + U_r S_r (Vh|P)_r before removing buffers.
- base-residual persistence wasn't in checkpoint metadata, so load->re-save dropped
  it. Persist base_weight_keys in metadata, validate on load, carry onto attach state.

Docstring/citation cleanup (review + user style asks):
- antipasto_corda: drop changelog narration and the stale "None -> plain SVD" claim
  (it raises now); exact reconstruction states W_res; slim the CPU/OOM note.
- antipasto_dplr: drop the arrowhead archaeology; docstring math now matches the
  forward (p@A.T@B.T); fix the k=0 comment (code requires 1<=k<=r).
- citations: Wanda (Sun+ 2023, 2306.11695), ASVD (Yuan+ 2023, 2312.05821),
  PiSSA (Meng+ 2024, 2404.02948), LoRA (Hu+ 2021, 2106.09685).

Co-Authored-By: Claudypoo <noreply@anthropic.com>
This commit is contained in:
wassname
2026-06-16 06:37:18 +08:00
parent d4ec550dd8
commit 7986edad2c
4 changed files with 56 additions and 55 deletions
+29 -7
View File
@@ -62,17 +62,21 @@ def attach(model: nn.Module, cfg: AdapterConfig, calibration_data=None, *, _skip
attached_names.append(name) attached_names.append(name)
attached_targets.append((name, layer, role)) attached_targets.append((name, layer, role))
group_init = getattr(variant, "group_init", None) # Register the adapter hooks BEFORE group_init. init() crops each weight to W_res,
ran_data_init = group_init is not None and not _skip_group_init and calibration_data is not None # so without the hooks the calibration forward inside group_init would run through a
if group_init is not None and not _skip_group_init: # model missing every target's top-r. At g=0 (and B=0) the hooks reconstruct the
group_init(model, attached_targets, cfg, calibration_data) # cropped component exactly, so calibration sees the true full W.
for _, layer, _ in attached_targets: for _, layer, _ in attached_targets:
if hasattr(layer._lora_variant, "forward_input"): if hasattr(layer._lora_variant, "forward_input"):
handles.append(layer.register_forward_pre_hook(_pre_hook)) handles.append(layer.register_forward_pre_hook(_pre_hook))
else: else:
handles.append(layer.register_forward_hook(_hook)) handles.append(layer.register_forward_hook(_hook))
group_init = getattr(variant, "group_init", None)
ran_data_init = group_init is not None and not _skip_group_init and calibration_data is not None
if group_init is not None and not _skip_group_init:
group_init(model, attached_targets, cfg, calibration_data)
# A data-driven group_init (CorDA orient, Wanda re-select) rewrites the frozen # A data-driven group_init (CorDA orient, Wanda re-select) rewrites the frozen
# base residual W_res into a form init() cannot reproduce at load time (it only # base residual W_res into a form init() cannot reproduce at load time (it only
# knows the plain top-r crop). So those residuals are part of the saved adapter. # knows the plain top-r crop). So those residuals are part of the saved adapter.
@@ -94,6 +98,14 @@ def detach(model: nn.Module) -> None:
if not hasattr(layer, "_lora_variant"): if not hasattr(layer, "_lora_variant"):
continue continue
variant = layer._lora_variant variant = layer._lora_variant
# Undo the PiSSA-style crop: init() set weight = W - U_r S_r (Vh|P)_r, so add the
# frozen top-r back to recover the original W (the trained gain/core are dropped).
# Keyed on the shared SVD-gain buffer convention (antipasto family); variants
# without lora_U leave weight untouched (e.g. LoRA never cropped it).
if hasattr(layer, "lora_U"):
proj = layer.lora_P if hasattr(layer, "lora_P") else layer.lora_Vh
with torch.no_grad():
layer.weight.data += ((layer.lora_U * layer.lora_S) @ proj).to(layer.weight.dtype)
for pname in variant.param_specs(layer.in_features, layer.out_features, layer._lora_cfg): for pname in variant.param_specs(layer.in_features, layer.out_features, layer._lora_cfg):
if pname in layer._parameters: if pname in layer._parameters:
del layer._parameters[pname] del layer._parameters[pname]
@@ -112,9 +124,11 @@ def save(model: nn.Module, path: str) -> None:
full_sd = model.state_dict() full_sd = model.state_dict()
sd = {k: v.detach().cpu() for k, v in full_sd.items() if "lora_" in k} sd = {k: v.detach().cpu() for k, v in full_sd.items() if "lora_" in k}
# data-driven variants also persist their rewritten base residuals (see attach()). # data-driven variants also persist their rewritten base residuals (see attach()).
for wk in state.get("base_weight_keys", []): base_weight_keys = state.get("base_weight_keys", [])
for wk in base_weight_keys:
sd[wk] = full_sd[wk].detach().cpu() sd[wk] = full_sd[wk].detach().cpu()
metadata = {"cfg": json.dumps(state["cfg"].to_dict())} metadata = {"cfg": json.dumps(state["cfg"].to_dict()),
"base_weight_keys": json.dumps(base_weight_keys)}
from safetensors.torch import save_file from safetensors.torch import save_file
save_file(sd, path, metadata=metadata) save_file(sd, path, metadata=metadata)
@@ -125,6 +139,12 @@ def load(model: nn.Module, path: str) -> list[RemovableHandle]:
metadata = f.metadata() metadata = f.metadata()
sd = load_file(path, device="cpu") sd = load_file(path, device="cpu")
cfg = AdapterConfig.from_dict(json.loads(metadata["cfg"])) cfg = AdapterConfig.from_dict(json.loads(metadata["cfg"]))
# Base residuals a data-driven group_init rewrote: must be in the checkpoint and
# are restored by load_state_dict below (init()'s plain crop would be wrong).
base_weight_keys = json.loads(metadata.get("base_weight_keys", "[]"))
missing_base = [wk for wk in base_weight_keys if wk not in sd]
if missing_base:
raise RuntimeError(f"checkpoint declares but omits base residuals: {missing_base}")
handles = attach(model, cfg, _skip_group_init=True) # creates empty params; data-driven inits restored from state_dict handles = attach(model, cfg, _skip_group_init=True) # creates empty params; data-driven inits restored from state_dict
missing, unexpected = model.load_state_dict(sd, strict=False) missing, unexpected = model.load_state_dict(sd, strict=False)
expected_lora = {k for k in model.state_dict() if "lora_" in k} expected_lora = {k for k in model.state_dict() if "lora_" in k}
@@ -134,4 +154,6 @@ def load(model: nn.Module, path: str) -> list[RemovableHandle]:
unexpected_lora = [k for k in unexpected if "lora_" in k] unexpected_lora = [k for k in unexpected if "lora_" in k]
if unexpected_lora: if unexpected_lora:
raise RuntimeError(f"unexpected lora keys in checkpoint: {unexpected_lora}") raise RuntimeError(f"unexpected lora keys in checkpoint: {unexpected_lora}")
# Carry the residual keys onto the attach state so a later save() re-persists them.
getattr(model, _ATTACHED_ATTR)["base_weight_keys"] = base_weight_keys
return handles return handles
+3 -1
View File
@@ -17,6 +17,8 @@ the gain is learned. See forward() for why 1+ELU over linear/exp/tanh.
Refs: Refs:
- paper: https://github.com/wassname/AntiPaSTO - paper: https://github.com/wassname/AntiPaSTO
- sibling (whitened, mean-diff): steering-lite/.../sspace.py - sibling (whitened, mean-diff): steering-lite/.../sspace.py
- selection: Wanda (Sun+ 2023, arXiv:2306.11695), ASVD (Yuan+ 2023, arXiv:2312.05821)
- top-r SVD init: PiSSA (Meng+ 2024, arXiv:2404.02948)
""" """
from dataclasses import dataclass from dataclasses import dataclass
from typing import Iterable, Literal from typing import Iterable, Literal
@@ -149,7 +151,7 @@ class AntiPaSTO:
W_orig = W_res + (U_old * S_old.unsqueeze(0)) @ Vh_old W_orig = W_res + (U_old * S_old.unsqueeze(0)) @ Vh_old
U_full, S_full, Vh_full = torch.linalg.svd(W_orig, full_matrices=False) U_full, S_full, Vh_full = torch.linalg.svd(W_orig, full_matrices=False)
proj = X.to(Vh_full) @ Vh_full.T # (N, k) input in S-coords (X captured on CPU) proj = X.to(Vh_full) @ Vh_full.T # (N, r) input in S-coords (X CPU -> GPU here)
if pool == "rms": if pool == "rms":
act_mag = proj.pow(2).mean(dim=0).sqrt() # outlier-sensitive act_mag = proj.pow(2).mean(dim=0).sqrt() # outlier-sensitive
else: else:
+11 -24
View File
@@ -7,15 +7,15 @@ directions move the output most on real activations.
C = E[x x^T] (+ eps I) # input second moment on calibration data C = E[x x^T] (+ eps I) # input second moment on calibration data
C^{1/2}, C^{-1/2} via eigh(C) C^{1/2}, C^{-1/2} via eigh(C)
U S Vht = SVD(W C^{1/2}) U S Vht = SVD(W C^{1/2}) # top-r is Eckart-Young best under x ~ N(0,C)
P = Vht C^{-1/2} # (r, d_in) oblique input projector P = Vht C^{-1/2} # (r, d_in) oblique input projector
W = U diag(S) P (exactly) W = W_res + U_r diag(S_r) P_r # exact (residual carries the dropped tail)
S_eff = S * (1 + ELU(coeff*g)) # same bounded gain as antipasto S_eff = S * (1 + ELU(coeff*g)) # same bounded gain as antipasto
y = x @ W_res.T + ((x @ P.T) * S_eff) @ U.T y = x @ W_res.T + ((x @ P.T) * S_eff) @ U.T
Identity at g=0 or coeff=0: S_eff=S. P is oblique (rows not orthonormal -- C^{-1/2} Identity at g=0 or coeff=0: S_eff=S. P is oblique (rows not orthonormal -- C^{-1/2}
skews them); fine for gain reweighting and for output-side ablation (the obliqueness skews them); fine for gain reweighting since U stays orthonormal. Requires
is input-side; U stays orthonormal). No calibration_data -> plain SVD (== antipasto). calibration_data (group_init raises otherwise).
Refs: antipasto.py (gain + selection sibling), CorDA arXiv:2406.05223. Refs: antipasto.py (gain + selection sibling), CorDA arXiv:2406.05223.
""" """
@@ -92,29 +92,16 @@ class AntiPaSTOCorDA:
def group_init(model: nn.Module, targets, cfg, calibration_data: CalibrationData | None) -> None: def group_init(model: nn.Module, targets, cfg, calibration_data: CalibrationData | None) -> None:
"""Re-orient each target's SVD by its input covariance C = E[x x^T]. """Re-orient each target's SVD by its input covariance C = E[x x^T].
Covariance orientation IS this variant's identity, so calibration_data is Requires calibration_data (raises otherwise). Call only at attach-time,
mandatory -- fail loud rather than silently degrade to plain SVD (which is before training updates g (re-orienting g=0 is a no-op, no re-indexing)."""
just antipasto and was the bug that made every corda run a no-op).
Called by attach() BEFORE any training, so the trainable g is still at its
zero init when the basis changes -- re-orienting zero gains is a no-op, no
re-indexing needed. Do not call group_init after training has updated g."""
if calibration_data is None: if calibration_data is None:
raise ValueError( raise ValueError("AntiPaSTOCorDA requires calibration_data; got None.")
"AntiPaSTOCorDA requires calibration_data (covariance orientation is "
"its whole point); got None. Pass attach(model, cfg, calibration_data=...)."
)
layers = {name: layer for name, layer, _ in targets} layers = {name: layer for name, layer, _ in targets}
# accumulate C = sum x x^T on CPU. Peak GPU cost would otherwise be # Accumulate C = sum x x^T on CPU: d_in^2 fp32 per target would OOM the GPU
# sum_targets d_in^2 fp32 held at once; for down_proj (d_in=intermediate, # (down_proj d_in~14336 -> ~0.8 GB/layer). Diagonal C is not a shortcut --
# e.g. 14336) that is ~0.8 GB *per layer* and OOMs. CPU accumulation bounds # the orientation lives in the cross-channel terms (Yuan+ 2023, ASVD,
# GPU use to the live activation; the eigh/SVD below run on CPU (one-time). # arXiv:2312.05821 is the diagonal case).
# Diagonal C is NOT a usable shortcut: it misses cross-channel correlation,
# which is where the orientation gain lives (measured ~= plain SVD).
# If down_proj's d_in^2 is too big even on CPU/RAM, exclude it from CorDA
# (leave it on plain antipasto) or use a low-rank C (top-k eig of subsampled
# inputs) -- not implemented here.
cov: dict[str, T] = {} cov: dict[str, T] = {}
cnt: dict[str, int] = {n: 0 for n in layers} cnt: dict[str, int] = {n: 0 for n in layers}
+13 -23
View File
@@ -1,32 +1,22 @@
"""AntiPaSTO-DPLR: diagonal-plus-low-rank core in the frozen SVD basis. """AntiPaSTO-DPLR: diagonal gain plus a low-rank mixing core in the frozen SVD basis.
antipasto's core is diagonal (a per-direction gain); it rescales each singular antipasto's diagonal gain rescales each singular direction but cannot mix one into
direction but cannot mix one into another. The arrowhead tried a dense b x b block another. DPLR adds a trainable rank-k core that does, inside the frozen U/Vh basis:
on the top-b directions, but a dense block is the wrong shape (b^2 params, mixes only
the top-b) and -- sitting on the S-scaled coords -- its perturbation is amplified by
the largest singular values, so it destabilizes. The fix is LoRA's lesson: a low-rank
core. Put a trainable rank-k core inside the frozen U/Vh basis, ADDED to the gain:
W = U diag(S) Vh + W_res # frozen top-r SVD W = U diag(S) Vh + W_res # frozen top-r SVD
learn: g (r,) # diagonal gain learn: g (r,) # diagonal gain
A (k,r), B (r,k) # low-rank mixing core, B=0 at init A (k,r) kaiming, B (r,k) zero # low-rank mixing core
p = x @ Vh.T # (r,) input in the frozen S-basis
S_eff = S * (1 + ELU(coeff * g)) S_eff = S * (1 + ELU(coeff * g))
y = x @ W_res.T + ( (Vh x) * S_eff + coeff * B (A (Vh x)) ) @ U.T h = p * S_eff + coeff * (p @ A.T) @ B.T # diagonal gain + rank-k mixing
y = x @ W_res.T + h @ U.T
so the trainable core is C = diag(S_eff) + coeff * B A acting in S-space, and The rank-k term is LoRA's core (Hu+ 2021, arXiv:2106.09685) restricted to W's top-r
DeltaW = U C Vh. The diagonal part scales directions; the low-rank part B A mixes them subspace, ADDED to the gain rather than folded into diag(S): being independent of S, a
across the whole top-r subspace for 2*r*k params (k=LoRA's rank), not b^2. unit step moves W by O(1) not O(S), so it has no singular-value amplification. Params
= r + 2*r*k. Identity at init (B=0, g=0) and at coeff=0. Basis (U, Vh) stays frozen.
Why the low-rank part is ADDED, not multiplied into diag(S): an additive core Refs: antipasto.py (diagonal sibling), lora.py (low-rank core), antipasto_corda.py
U (BA) Vh is independent of S, so a unit step in BA moves W by O(1), not O(S). That is
exactly the S-amplification edge that made the dense arrowhead block blow up at the
gain's learning rate -- gone by construction.
Identity at init: B=0 -> BA=0, g=0 -> 1+ELU(0)=1, so C=diag(S) and DeltaW = U diag(S) Vh.
coeff=0 -> identity too (runtime off). The basis (U, Vh) stays frozen and interpretable;
only the gain and the rank-k core move.
Refs: antipasto.py (diagonal sibling), lora.py (the low-rank core), antipasto_corda.py
(oriented basis -- composes with this core). (oriented basis -- composes with this core).
""" """
from dataclasses import dataclass from dataclasses import dataclass
@@ -51,7 +41,7 @@ class AntiPaSTODPLRConfig(AdapterConfig):
variant: str = "antipasto_dplr" variant: str = "antipasto_dplr"
r: int = 256 r: int = 256
# Rank of the low-rank mixing core (LoRA's r, but inside the frozen subspace). # Rank of the low-rank mixing core (LoRA's r, but inside the frozen subspace).
# Params = r (gain) + 2*r*lora_rank. k=0 degenerates to plain antipasto. # Params = r (gain) + 2*r*lora_rank. Requires 1 <= lora_rank <= r.
lora_rank: int = 8 lora_rank: int = 8
suppress_only: bool = False # clamp the gain g<=0 (attenuate only); core unaffected. suppress_only: bool = False # clamp the gain g<=0 (attenuate only); core unaffected.
coeff: float = 1.0 # runtime knob: 0=identity, scales gain and core. coeff: float = 1.0 # runtime knob: 0=identity, scales gain and core.