diff --git a/justfile b/justfile index 9a5bb76..ab7f11c 100644 --- a/justfile +++ b/justfile @@ -75,7 +75,7 @@ metamath-queue variant="lora" steps="5000" model="Qwen/Qwen3-0.6B-Base": # Run a single MetaMathQA->GSM8K benchmark for a given variant. # Per-variant lr / target-name defaults are baked in here. -bench-variant model variant steps="5000" lora_rank="8" r_override="": +bench-variant model variant steps="5000" lora_rank="8" r_override="" lr_override="": #!/usr/bin/env bash set -euo pipefail lr=1e-4 @@ -96,6 +96,8 @@ bench-variant model variant steps="5000" lora_rank="8" r_override="": esac # r override (e.g. low-rank corda sweep); alpha tracks r for the antipasto family. if [ -n "{{r_override}}" ]; then r="{{r_override}}"; alpha="{{r_override}}"; fi + # lr override (e.g. dplr core wants a tamer lr than the gain's 5e-3). + if [ -n "{{lr_override}}" ]; then lr="{{lr_override}}"; fi exec uv run --extra benchmark python scripts/metamath_gsm8k_benchmark.py \ --model '{{model}}' \ --variant '{{variant}}' \ diff --git a/scripts/metamath_gsm8k_benchmark.py b/scripts/metamath_gsm8k_benchmark.py index 42af6f9..99b414b 100644 --- a/scripts/metamath_gsm8k_benchmark.py +++ b/scripts/metamath_gsm8k_benchmark.py @@ -443,11 +443,14 @@ def check_probe_reload( ll.load(loaded_model, str(adapter_path)) from safetensors.torch import load_file saved_sd = load_file(str(adapter_path), device="cpu") - loaded_state = adapter_state(loaded_model) - if set(saved_sd) != set(loaded_state): - raise AssertionError("loaded adapter keys differ from saved adapter keys") + # Every saved tensor (lora_ buffers AND, for data-driven variants, the rewritten + # base residuals) must reload bit-identical onto the model. + loaded_full = loaded_model.state_dict() + missing = set(saved_sd) - set(loaded_full) + if missing: + raise AssertionError(f"saved adapter keys absent from loaded model: {sorted(missing)[:8]}") for name, value in saved_sd.items(): - if not torch.equal(loaded_state[name].cpu(), value): + if not torch.equal(loaded_full[name].cpu(), value): raise AssertionError(f"loaded adapter tensor differs: {name}") logits_loaded = loaded_model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]).logits.detach().clone() reload_err = (logits_loaded - logits_trained).abs().max().item() @@ -539,6 +542,10 @@ def run(args: BenchmarkConfig) -> dict[str, Any]: # antipasto family defaults to r=256; low-rank sweeps get their own dirs. if args.variant.startswith("antipasto") and args.r != 256: run_id += f"__r{args.r}" + # antipasto family defaults to lr=5e-3; lr sweeps get their own dirs (the dense/ + # low-rank cores want a tamer lr than the gain, so this is a real axis). + if args.variant.startswith("antipasto") and abs(args.lr - 5e-3) > 1e-9: + run_id += f"__lr{args.lr:g}" out_dir = args.output_dir / run_id out_dir.mkdir(parents=True, exist_ok=True) @@ -546,10 +553,18 @@ def run(args: BenchmarkConfig) -> dict[str, Any]: model, tokenizer = load_model_and_tokenizer(args.model, dtype, args.device, args.quantization) batches, skipped_train_prompt_too_long = make_train_batches(datasets["train"], tokenizer, args) cfg = cfg_for_variant(args, dtype) - if args.variant == "eva": + # Variants with a data-driven group_init need calibration activations from the + # downstream task (IPM mode, per CorDA). eva needs only a few batches for its + # init; corda/cov-orient accumulate a d_in x d_in covariance, so follow PEFT's + # default of ~256 samples (64 batches x bs=4) for a well-conditioned C. + needs_calib = args.variant == "eva" or args.variant == "antipasto_corda" or ( + args.variant == "antipasto_ablate" and args.antipasto_cov_orient + ) + if needs_calib: + n_batches = min(4, len(batches)) if args.variant == "eva" else min(64, len(batches)) calib = [ {"input_ids": b["input_ids"], "attention_mask": b["attention_mask"]} - for b in batches[: min(4, len(batches))] + for b in batches[:n_batches] ] ll.attach(model, cfg, calibration_data=calib) else: diff --git a/src/lora_lite/adapter.py b/src/lora_lite/adapter.py index 2b9114c..616c07e 100644 --- a/src/lora_lite/adapter.py +++ b/src/lora_lite/adapter.py @@ -63,6 +63,7 @@ def attach(model: nn.Module, cfg: AdapterConfig, calibration_data=None, *, _skip attached_targets.append((name, layer, role)) group_init = getattr(variant, "group_init", None) + ran_data_init = group_init is not None and not _skip_group_init and calibration_data is not None if group_init is not None and not _skip_group_init: group_init(model, attached_targets, cfg, calibration_data) @@ -72,7 +73,13 @@ def attach(model: nn.Module, cfg: AdapterConfig, calibration_data=None, *, _skip else: handles.append(layer.register_forward_hook(_hook)) - setattr(model, _ATTACHED_ATTR, {"cfg": cfg, "targets": attached_names, "handles": handles}) + # A data-driven group_init (CorDA orient, Wanda re-select) rewrites the frozen + # base residual W_res into a form init() cannot reproduce at load time (it only + # knows the plain top-r crop). So those residuals are part of the saved adapter. + base_weight_keys = [f"{n}.weight" for n in attached_names] if ran_data_init else [] + setattr(model, _ATTACHED_ATTR, + {"cfg": cfg, "targets": attached_names, "handles": handles, + "base_weight_keys": base_weight_keys}) return handles @@ -102,7 +109,11 @@ def save(model: nn.Module, path: str) -> None: state = getattr(model, _ATTACHED_ATTR, None) if state is None: raise RuntimeError("no adapter attached; call attach() first") - sd = {k: v.detach().cpu() for k, v in model.state_dict().items() if "lora_" in k} + full_sd = model.state_dict() + sd = {k: v.detach().cpu() for k, v in full_sd.items() if "lora_" in k} + # data-driven variants also persist their rewritten base residuals (see attach()). + for wk in state.get("base_weight_keys", []): + sd[wk] = full_sd[wk].detach().cpu() metadata = {"cfg": json.dumps(state["cfg"].to_dict())} from safetensors.torch import save_file save_file(sd, path, metadata=metadata) diff --git a/src/lora_lite/variants/antipasto_corda.py b/src/lora_lite/variants/antipasto_corda.py index 9ca5bf0..648624f 100644 --- a/src/lora_lite/variants/antipasto_corda.py +++ b/src/lora_lite/variants/antipasto_corda.py @@ -92,14 +92,18 @@ class AntiPaSTOCorDA: def group_init(model: nn.Module, targets, cfg, calibration_data: CalibrationData | None) -> None: """Re-orient each target's SVD by its input covariance C = E[x x^T]. - Without calibration_data the plain-SVD init from init() is kept (so this - degrades to antipasto, rotation-free). + Covariance orientation IS this variant's identity, so calibration_data is + mandatory -- fail loud rather than silently degrade to plain SVD (which is + just antipasto and was the bug that made every corda run a no-op). Called by attach() BEFORE any training, so the trainable g is still at its zero init when the basis changes -- re-orienting zero gains is a no-op, no re-indexing needed. Do not call group_init after training has updated g.""" if calibration_data is None: - return + raise ValueError( + "AntiPaSTOCorDA requires calibration_data (covariance orientation is " + "its whole point); got None. Pass attach(model, cfg, calibration_data=...)." + ) layers = {name: layer for name, layer, _ in targets} # accumulate C = sum x x^T on CPU. Peak GPU cost would otherwise be