mirror of
https://github.com/wassname/lora-lite.git
synced 2026-06-27 15:15:55 +08:00
fix: calibration through cropped model + detach/checkpoint gaps (external review)
gpt-5.5 review (decorrelated) found three real issues deepseek missed:

- BLOCKER: calibration ran through a cropped model. attach() did init() (crops every target to W_res) then group_init() (calibration forward) then registered the adapter hooks -- so CorDA's covariance and Wanda's scores were collected from a model missing every target's top-r. Now register hooks BEFORE group_init; at g=0/B=0 they reconstruct the cropped component exactly, so calibration sees full W.
- detach() left the model cropped (deleted buffers without adding the frozen top-r back). Now reconstructs W = W_res + U_r S_r (Vh|P)_r before removing buffers.
- base-residual persistence wasn't in checkpoint metadata, so load->re-save dropped it. Persist base_weight_keys in metadata, validate on load, carry onto attach state.

Docstring/citation cleanup (review + user style asks):

- antipasto_corda: drop changelog narration and the stale "None -> plain SVD" claim (it raises now); exact reconstruction states W_res; slim the CPU/OOM note.
- antipasto_dplr: drop the arrowhead archaeology; docstring math now matches the forward (p@A.T@B.T); fix the k=0 comment (code requires 1<=k<=r).
- citations: Wanda (Sun+ 2023, 2306.11695), ASVD (Yuan+ 2023, 2312.05821), PiSSA (Meng+ 2024, 2404.02948), LoRA (Hu+ 2021, 2106.09685).

Co-Authored-By: Claudypoo <noreply@anthropic.com>
This commit is contained in:
@@ -62,17 +62,21 @@ def attach(model: nn.Module, cfg: AdapterConfig, calibration_data=None, *, _skip
|
||||
attached_names.append(name)
|
||||
attached_targets.append((name, layer, role))
|
||||
|
||||
group_init = getattr(variant, "group_init", None)
|
||||
ran_data_init = group_init is not None and not _skip_group_init and calibration_data is not None
|
||||
if group_init is not None and not _skip_group_init:
|
||||
group_init(model, attached_targets, cfg, calibration_data)
|
||||
|
||||
# Register the adapter hooks BEFORE group_init. init() crops each weight to W_res,
|
||||
# so without the hooks the calibration forward inside group_init would run through a
|
||||
# model missing every target's top-r. At g=0 (and B=0) the hooks reconstruct the
|
||||
# cropped component exactly, so calibration sees the true full W.
|
||||
for _, layer, _ in attached_targets:
|
||||
if hasattr(layer._lora_variant, "forward_input"):
|
||||
handles.append(layer.register_forward_pre_hook(_pre_hook))
|
||||
else:
|
||||
handles.append(layer.register_forward_hook(_hook))
|
||||
|
||||
group_init = getattr(variant, "group_init", None)
|
||||
ran_data_init = group_init is not None and not _skip_group_init and calibration_data is not None
|
||||
if group_init is not None and not _skip_group_init:
|
||||
group_init(model, attached_targets, cfg, calibration_data)
|
||||
|
||||
# A data-driven group_init (CorDA orient, Wanda re-select) rewrites the frozen
|
||||
# base residual W_res into a form init() cannot reproduce at load time (it only
|
||||
# knows the plain top-r crop). So those residuals are part of the saved adapter.
|
||||
@@ -94,6 +98,14 @@ def detach(model: nn.Module) -> None:
|
||||
if not hasattr(layer, "_lora_variant"):
|
||||
continue
|
||||
variant = layer._lora_variant
|
||||
# Undo the PiSSA-style crop: init() set weight = W - U_r S_r (Vh|P)_r, so add the
|
||||
# frozen top-r back to recover the original W (the trained gain/core are dropped).
|
||||
# Keyed on the shared SVD-gain buffer convention (antipasto family); variants
|
||||
# without lora_U leave weight untouched (e.g. LoRA never cropped it).
|
||||
if hasattr(layer, "lora_U"):
|
||||
proj = layer.lora_P if hasattr(layer, "lora_P") else layer.lora_Vh
|
||||
with torch.no_grad():
|
||||
layer.weight.data += ((layer.lora_U * layer.lora_S) @ proj).to(layer.weight.dtype)
|
||||
for pname in variant.param_specs(layer.in_features, layer.out_features, layer._lora_cfg):
|
||||
if pname in layer._parameters:
|
||||
del layer._parameters[pname]
|
||||
@@ -112,9 +124,11 @@ def save(model: nn.Module, path: str) -> None:
|
||||
full_sd = model.state_dict()
|
||||
sd = {k: v.detach().cpu() for k, v in full_sd.items() if "lora_" in k}
|
||||
# data-driven variants also persist their rewritten base residuals (see attach()).
|
||||
for wk in state.get("base_weight_keys", []):
|
||||
base_weight_keys = state.get("base_weight_keys", [])
|
||||
for wk in base_weight_keys:
|
||||
sd[wk] = full_sd[wk].detach().cpu()
|
||||
metadata = {"cfg": json.dumps(state["cfg"].to_dict())}
|
||||
metadata = {"cfg": json.dumps(state["cfg"].to_dict()),
|
||||
"base_weight_keys": json.dumps(base_weight_keys)}
|
||||
from safetensors.torch import save_file
|
||||
save_file(sd, path, metadata=metadata)
|
||||
|
||||
@@ -125,6 +139,12 @@ def load(model: nn.Module, path: str) -> list[RemovableHandle]:
|
||||
metadata = f.metadata()
|
||||
sd = load_file(path, device="cpu")
|
||||
cfg = AdapterConfig.from_dict(json.loads(metadata["cfg"]))
|
||||
# Base residuals a data-driven group_init rewrote: must be in the checkpoint and
|
||||
# are restored by load_state_dict below (init()'s plain crop would be wrong).
|
||||
base_weight_keys = json.loads(metadata.get("base_weight_keys", "[]"))
|
||||
missing_base = [wk for wk in base_weight_keys if wk not in sd]
|
||||
if missing_base:
|
||||
raise RuntimeError(f"checkpoint declares but omits base residuals: {missing_base}")
|
||||
handles = attach(model, cfg, _skip_group_init=True) # creates empty params; data-driven inits restored from state_dict
|
||||
missing, unexpected = model.load_state_dict(sd, strict=False)
|
||||
expected_lora = {k for k in model.state_dict() if "lora_" in k}
|
||||
@@ -134,4 +154,6 @@ def load(model: nn.Module, path: str) -> list[RemovableHandle]:
|
||||
unexpected_lora = [k for k in unexpected if "lora_" in k]
|
||||
if unexpected_lora:
|
||||
raise RuntimeError(f"unexpected lora keys in checkpoint: {unexpected_lora}")
|
||||
# Carry the residual keys onto the attach state so a later save() re-persists them.
|
||||
getattr(model, _ATTACHED_ATTR)["base_weight_keys"] = base_weight_keys
|
||||
return handles
|
||||
|
||||
@@ -17,6 +17,8 @@ the gain is learned. See forward() for why 1+ELU over linear/exp/tanh.
|
||||
Refs:
|
||||
- paper: https://github.com/wassname/AntiPaSTO
|
||||
- sibling (whitened, mean-diff): steering-lite/.../sspace.py
|
||||
- selection: Wanda (Sun+ 2023, arXiv:2306.11695), ASVD (Yuan+ 2023, arXiv:2312.05821)
|
||||
- top-r SVD init: PiSSA (Meng+ 2024, arXiv:2404.02948)
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable, Literal
|
||||
@@ -149,7 +151,7 @@ class AntiPaSTO:
|
||||
W_orig = W_res + (U_old * S_old.unsqueeze(0)) @ Vh_old
|
||||
|
||||
U_full, S_full, Vh_full = torch.linalg.svd(W_orig, full_matrices=False)
|
||||
proj = X.to(Vh_full) @ Vh_full.T # (N, k) input in S-coords (X captured on CPU)
|
||||
proj = X.to(Vh_full) @ Vh_full.T # (N, r) input in S-coords (X CPU -> GPU here)
|
||||
if pool == "rms":
|
||||
act_mag = proj.pow(2).mean(dim=0).sqrt() # outlier-sensitive
|
||||
else:
|
||||
|
||||
@@ -7,15 +7,15 @@ directions move the output most on real activations.
|
||||
|
||||
C = E[x x^T] (+ eps I) # input second moment on calibration data
|
||||
C^{1/2}, C^{-1/2} via eigh(C)
|
||||
U S Vht = SVD(W C^{1/2})
|
||||
U S Vht = SVD(W C^{1/2}) # top-r is Eckart-Young best under x ~ N(0,C)
|
||||
P = Vht C^{-1/2} # (r, d_in) oblique input projector
|
||||
W = U diag(S) P (exactly)
|
||||
W = W_res + U_r diag(S_r) P_r # exact (residual carries the dropped tail)
|
||||
S_eff = S * (1 + ELU(coeff*g)) # same bounded gain as antipasto
|
||||
y = x @ W_res.T + ((x @ P.T) * S_eff) @ U.T
|
||||
|
||||
Identity at g=0 or coeff=0: S_eff=S. P is oblique (rows not orthonormal -- C^{-1/2}
|
||||
skews them); fine for gain reweighting and for output-side ablation (the obliqueness
|
||||
is input-side; U stays orthonormal). No calibration_data -> plain SVD (== antipasto).
|
||||
skews them); fine for gain reweighting since U stays orthonormal. Requires
|
||||
calibration_data (group_init raises otherwise).
|
||||
|
||||
Refs: antipasto.py (gain + selection sibling), CorDA arXiv:2406.05223.
|
||||
"""
|
||||
@@ -92,29 +92,16 @@ class AntiPaSTOCorDA:
|
||||
def group_init(model: nn.Module, targets, cfg, calibration_data: CalibrationData | None) -> None:
|
||||
"""Re-orient each target's SVD by its input covariance C = E[x x^T].
|
||||
|
||||
Covariance orientation IS this variant's identity, so calibration_data is
|
||||
mandatory -- fail loud rather than silently degrade to plain SVD (which is
|
||||
just antipasto and was the bug that made every corda run a no-op).
|
||||
|
||||
Called by attach() BEFORE any training, so the trainable g is still at its
|
||||
zero init when the basis changes -- re-orienting zero gains is a no-op, no
|
||||
re-indexing needed. Do not call group_init after training has updated g."""
|
||||
Requires calibration_data (raises otherwise). Call only at attach-time,
|
||||
before training updates g (re-orienting g=0 is a no-op, no re-indexing)."""
|
||||
if calibration_data is None:
|
||||
raise ValueError(
|
||||
"AntiPaSTOCorDA requires calibration_data (covariance orientation is "
|
||||
"its whole point); got None. Pass attach(model, cfg, calibration_data=...)."
|
||||
)
|
||||
raise ValueError("AntiPaSTOCorDA requires calibration_data; got None.")
|
||||
|
||||
layers = {name: layer for name, layer, _ in targets}
|
||||
# accumulate C = sum x x^T on CPU. Peak GPU cost would otherwise be
|
||||
# sum_targets d_in^2 fp32 held at once; for down_proj (d_in=intermediate,
|
||||
# e.g. 14336) that is ~0.8 GB *per layer* and OOMs. CPU accumulation bounds
|
||||
# GPU use to the live activation; the eigh/SVD below run on CPU (one-time).
|
||||
# Diagonal C is NOT a usable shortcut: it misses cross-channel correlation,
|
||||
# which is where the orientation gain lives (measured ~= plain SVD).
|
||||
# If down_proj's d_in^2 is too big even on CPU/RAM, exclude it from CorDA
|
||||
# (leave it on plain antipasto) or use a low-rank C (top-k eig of subsampled
|
||||
# inputs) -- not implemented here.
|
||||
# Accumulate C = sum x x^T on CPU: d_in^2 fp32 per target would OOM the GPU
|
||||
# (down_proj d_in~14336 -> ~0.8 GB/layer). Diagonal C is not a shortcut --
|
||||
# the orientation lives in the cross-channel terms (Yuan+ 2023, ASVD,
|
||||
# arXiv:2312.05821 is the diagonal case).
|
||||
cov: dict[str, T] = {}
|
||||
cnt: dict[str, int] = {n: 0 for n in layers}
|
||||
|
||||
|
||||
@@ -1,32 +1,22 @@
|
||||
"""AntiPaSTO-DPLR: diagonal-plus-low-rank core in the frozen SVD basis.
|
||||
"""AntiPaSTO-DPLR: diagonal gain plus a low-rank mixing core in the frozen SVD basis.
|
||||
|
||||
antipasto's core is diagonal (a per-direction gain); it rescales each singular
|
||||
direction but cannot mix one into another. The arrowhead tried a dense b x b block
|
||||
on the top-b directions, but a dense block is the wrong shape (b^2 params, mixes only
|
||||
the top-b) and -- sitting on the S-scaled coords -- its perturbation is amplified by
|
||||
the largest singular values, so it destabilizes. The fix is LoRA's lesson: a low-rank
|
||||
core. Put a trainable rank-k core inside the frozen U/Vh basis, ADDED to the gain:
|
||||
antipasto's diagonal gain rescales each singular direction but cannot mix one into
|
||||
another. DPLR adds a trainable rank-k core that does, inside the frozen U/Vh basis:
|
||||
|
||||
W = U diag(S) Vh + W_res # frozen top-r SVD
|
||||
learn: g (r,) # diagonal gain
|
||||
A (k,r), B (r,k) # low-rank mixing core, B=0 at init
|
||||
A (k,r) kaiming, B (r,k) zero # low-rank mixing core
|
||||
p = x @ Vh.T # (r,) input in the frozen S-basis
|
||||
S_eff = S * (1 + ELU(coeff * g))
|
||||
y = x @ W_res.T + ( (Vh x) * S_eff + coeff * B (A (Vh x)) ) @ U.T
|
||||
h = p * S_eff + coeff * (p @ A.T) @ B.T # diagonal gain + rank-k mixing
|
||||
y = x @ W_res.T + h @ U.T
|
||||
|
||||
so the trainable core is C = diag(S_eff) + coeff * B A acting in S-space, and
|
||||
DeltaW = U C Vh. The diagonal part scales directions; the low-rank part B A mixes them
|
||||
across the whole top-r subspace for 2*r*k params (k=LoRA's rank), not b^2.
|
||||
The rank-k term is LoRA's core (Hu+ 2021, arXiv:2106.09685) restricted to W's top-r
|
||||
subspace, ADDED to the gain rather than folded into diag(S): being independent of S, a
|
||||
unit step moves W by O(1) not O(S), so it has no singular-value amplification. Params
|
||||
= r + 2*r*k. Identity at init (B=0, g=0) and at coeff=0. Basis (U, Vh) stays frozen.
|
||||
|
||||
Why the low-rank part is ADDED, not multiplied into diag(S): an additive core
|
||||
U (BA) Vh is independent of S, so a unit step in BA moves W by O(1), not O(S). That is
|
||||
exactly the S-amplification edge that made the dense arrowhead block blow up at the
|
||||
gain's learning rate -- gone by construction.
|
||||
|
||||
Identity at init: B=0 -> BA=0, g=0 -> 1+ELU(0)=1, so C=diag(S) and DeltaW = U diag(S) Vh.
|
||||
coeff=0 -> identity too (runtime off). The basis (U, Vh) stays frozen and interpretable;
|
||||
only the gain and the rank-k core move.
|
||||
|
||||
Refs: antipasto.py (diagonal sibling), lora.py (the low-rank core), antipasto_corda.py
|
||||
Refs: antipasto.py (diagonal sibling), lora.py (low-rank core), antipasto_corda.py
|
||||
(oriented basis -- composes with this core).
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
@@ -51,7 +41,7 @@ class AntiPaSTODPLRConfig(AdapterConfig):
|
||||
variant: str = "antipasto_dplr"
|
||||
r: int = 256
|
||||
# Rank of the low-rank mixing core (LoRA's r, but inside the frozen subspace).
|
||||
# Params = r (gain) + 2*r*lora_rank. k=0 degenerates to plain antipasto.
|
||||
# Params = r (gain) + 2*r*lora_rank. Requires 1 <= lora_rank <= r.
|
||||
lora_rank: int = 8
|
||||
suppress_only: bool = False # clamp the gain g<=0 (attenuate only); core unaffected.
|
||||
coeff: float = 1.0 # runtime knob: 0=identity, scales gain and core.
|
||||
|
||||
Reference in New Issue
Block a user