fix: corda silently ran as plain SVD; wire calibration + persist data-driven residual

The benchmark only passed calibration_data to eva, so antipasto_corda's
group_init hit `if calibration_data is None: return` and every corda run was
actually plain SVD. The covariance orientation never executed -- all prior
corda-vs-antipasto comparisons are void.

- antipasto_corda.group_init: raise on None instead of silently degrading
  (orientation is the variant's whole identity; fail loud).
- benchmark: feed ~256 MetaMath calibration samples (IPM, per PEFT/CorDA) to
  corda and to cov_orient ablate; run_id now carries an __lr tag.
- adapter.save/load: a data-driven group_init rewrites the frozen base residual
  W_res into a form init() cannot reproduce at load (it only knows the plain
  top-r crop). Persist those residuals in the adapter and restore them. Fixes a
  reload-logits mismatch that was masked while group_init never ran.
- probe check: compare every saved tensor (lora_ buffers AND base residuals)
  against the reloaded model state.
- justfile: bench-variant gains an lr_override (the core wants a tamer lr than
  the gain's 5e-3).

Co-Authored-By: Claudypoo <noreply@anthropic.com>
This commit is contained in:
wassname
2026-06-16 05:56:02 +08:00
parent 9d027752ad
commit d4ec550dd8
4 changed files with 44 additions and 12 deletions
+3 -1
View File
@@ -75,7 +75,7 @@ metamath-queue variant="lora" steps="5000" model="Qwen/Qwen3-0.6B-Base":
# Run a single MetaMathQA->GSM8K benchmark for a given variant.
# Per-variant lr / target-name defaults are baked in here.
bench-variant model variant steps="5000" lora_rank="8" r_override="":
bench-variant model variant steps="5000" lora_rank="8" r_override="" lr_override="":
#!/usr/bin/env bash
set -euo pipefail
lr=1e-4
@@ -96,6 +96,8 @@ bench-variant model variant steps="5000" lora_rank="8" r_override="":
esac
# r override (e.g. low-rank corda sweep); alpha tracks r for the antipasto family.
if [ -n "{{r_override}}" ]; then r="{{r_override}}"; alpha="{{r_override}}"; fi
# lr override (e.g. dplr core wants a tamer lr than the gain's 5e-3).
if [ -n "{{lr_override}}" ]; then lr="{{lr_override}}"; fi
exec uv run --extra benchmark python scripts/metamath_gsm8k_benchmark.py \
--model '{{model}}' \
--variant '{{variant}}' \
+21 -6
View File
@@ -443,11 +443,14 @@ def check_probe_reload(
ll.load(loaded_model, str(adapter_path))
from safetensors.torch import load_file
saved_sd = load_file(str(adapter_path), device="cpu")
loaded_state = adapter_state(loaded_model)
if set(saved_sd) != set(loaded_state):
raise AssertionError("loaded adapter keys differ from saved adapter keys")
# Every saved tensor (lora_ buffers AND, for data-driven variants, the rewritten
# base residuals) must reload bit-identical onto the model.
loaded_full = loaded_model.state_dict()
missing = set(saved_sd) - set(loaded_full)
if missing:
raise AssertionError(f"saved adapter keys absent from loaded model: {sorted(missing)[:8]}")
for name, value in saved_sd.items():
if not torch.equal(loaded_state[name].cpu(), value):
if not torch.equal(loaded_full[name].cpu(), value):
raise AssertionError(f"loaded adapter tensor differs: {name}")
logits_loaded = loaded_model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]).logits.detach().clone()
reload_err = (logits_loaded - logits_trained).abs().max().item()
@@ -539,6 +542,10 @@ def run(args: BenchmarkConfig) -> dict[str, Any]:
# antipasto family defaults to r=256; low-rank sweeps get their own dirs.
if args.variant.startswith("antipasto") and args.r != 256:
run_id += f"__r{args.r}"
# antipasto family defaults to lr=5e-3; lr sweeps get their own dirs (the dense/
# low-rank cores want a tamer lr than the gain, so this is a real axis).
if args.variant.startswith("antipasto") and abs(args.lr - 5e-3) > 1e-9:
run_id += f"__lr{args.lr:g}"
out_dir = args.output_dir / run_id
out_dir.mkdir(parents=True, exist_ok=True)
@@ -546,10 +553,18 @@ def run(args: BenchmarkConfig) -> dict[str, Any]:
model, tokenizer = load_model_and_tokenizer(args.model, dtype, args.device, args.quantization)
batches, skipped_train_prompt_too_long = make_train_batches(datasets["train"], tokenizer, args)
cfg = cfg_for_variant(args, dtype)
if args.variant == "eva":
# Variants with a data-driven group_init need calibration activations from the
# downstream task (IPM mode, per CorDA). eva needs only a few batches for its
# init; corda/cov-orient accumulate a d_in x d_in covariance, so follow PEFT's
# default of ~256 samples (64 batches x bs=4) for a well-conditioned C.
needs_calib = args.variant == "eva" or args.variant == "antipasto_corda" or (
args.variant == "antipasto_ablate" and args.antipasto_cov_orient
)
if needs_calib:
n_batches = min(4, len(batches)) if args.variant == "eva" else min(64, len(batches))
calib = [
{"input_ids": b["input_ids"], "attention_mask": b["attention_mask"]}
for b in batches[: min(4, len(batches))]
for b in batches[:n_batches]
]
ll.attach(model, cfg, calibration_data=calib)
else:
+13 -2
View File
@@ -63,6 +63,7 @@ def attach(model: nn.Module, cfg: AdapterConfig, calibration_data=None, *, _skip
attached_targets.append((name, layer, role))
group_init = getattr(variant, "group_init", None)
ran_data_init = group_init is not None and not _skip_group_init and calibration_data is not None
if group_init is not None and not _skip_group_init:
group_init(model, attached_targets, cfg, calibration_data)
@@ -72,7 +73,13 @@ def attach(model: nn.Module, cfg: AdapterConfig, calibration_data=None, *, _skip
else:
handles.append(layer.register_forward_hook(_hook))
setattr(model, _ATTACHED_ATTR, {"cfg": cfg, "targets": attached_names, "handles": handles})
# A data-driven group_init (CorDA orient, Wanda re-select) rewrites the frozen
# base residual W_res into a form init() cannot reproduce at load time (it only
# knows the plain top-r crop). So those residuals are part of the saved adapter.
base_weight_keys = [f"{n}.weight" for n in attached_names] if ran_data_init else []
setattr(model, _ATTACHED_ATTR,
{"cfg": cfg, "targets": attached_names, "handles": handles,
"base_weight_keys": base_weight_keys})
return handles
@@ -102,7 +109,11 @@ def save(model: nn.Module, path: str) -> None:
state = getattr(model, _ATTACHED_ATTR, None)
if state is None:
raise RuntimeError("no adapter attached; call attach() first")
sd = {k: v.detach().cpu() for k, v in model.state_dict().items() if "lora_" in k}
full_sd = model.state_dict()
sd = {k: v.detach().cpu() for k, v in full_sd.items() if "lora_" in k}
# data-driven variants also persist their rewritten base residuals (see attach()).
for wk in state.get("base_weight_keys", []):
sd[wk] = full_sd[wk].detach().cpu()
metadata = {"cfg": json.dumps(state["cfg"].to_dict())}
from safetensors.torch import save_file
save_file(sd, path, metadata=metadata)
+7 -3
View File
@@ -92,14 +92,18 @@ class AntiPaSTOCorDA:
def group_init(model: nn.Module, targets, cfg, calibration_data: CalibrationData | None) -> None:
"""Re-orient each target's SVD by its input covariance C = E[x x^T].
Without calibration_data the plain-SVD init from init() is kept (so this
degrades to antipasto, rotation-free).
Covariance orientation IS this variant's identity, so calibration_data is
mandatory -- fail loud rather than silently degrade to plain SVD (which is
just antipasto and was the bug that made every corda run a no-op).
Called by attach() BEFORE any training, so the trainable g is still at its
zero init when the basis changes -- re-orienting zero gains is a no-op, no
re-indexing needed. Do not call group_init after training has updated g."""
if calibration_data is None:
return
raise ValueError(
"AntiPaSTOCorDA requires calibration_data (covariance orientation is "
"its whole point); got None. Pass attach(model, cfg, calibration_data=...)."
)
layers = {name: layer for name, layer, _ in targets}
# accumulate C = sum x x^T on CPU. Peak GPU cost would otherwise be