"""Held-out v_hack validation (spec.md §B validation). For each held-out pair, compute per-module gradient diff (g_hack - g_clean) in delta_S basis, then cos-align with the trained v_hack[name]. Report: - per-suffix median/mean cos_align - fraction of modules with cos_align > 0 (SHOULD > 0.5) - mean cos_align across modules (target > 0.2) Run: uv run python -m vgrout.verify_vhack_heldout """ from __future__ import annotations import sys from collections import defaultdict from dataclasses import dataclass from pathlib import Path import json import torch import tyro from loguru import logger from safetensors.torch import save_file from tabulate import tabulate from transformers import AutoModelForCausalLM, AutoTokenizer from vgrout.antipasto import wrap_model_with_antipasto from vgrout.extract_vhack_grad import completion_nll, resolve_dtype from vgrout.pairs_from_pool import load_pairs_json from vgrout.vhack import load_v_hack CACHE_ROOT = Path("svd_cache") OUT_DIR = Path("out") @dataclass class Config: model: str = "out/baked/qwen3_4b_rh25" dtype: str = "bf16" # must match extract_vhack_grad.py and train.py v_hack_path: Path = OUT_DIR / "vhack" / "v_hack_rh25.safetensors" out_path: Path = OUT_DIR / "vhack_heldout_cos_rh25.safetensors" pairs_path: Path = OUT_DIR / "pairsets" / "prog_wide.json" n_heldout: int = 2 def main(cfg: Config) -> int: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") dtype = resolve_dtype(cfg.dtype) logger.info(f"device={device} model={cfg.model} dtype={cfg.dtype}") pairs = load_pairs_json(cfg.pairs_path) held = pairs[-cfg.n_heldout:] logger.info(f"held-out pairs: {len(held)} from {cfg.pairs_path}") tokenizer = AutoTokenizer.from_pretrained(cfg.model) model = AutoModelForCausalLM.from_pretrained( cfg.model, dtype=dtype, attn_implementation="sdpa" ).to(device) model.eval() wrappers = wrap_model_with_antipasto( model, model_name=cfg.model, cache_root=CACHE_ROOT, svd_device=device, ) v_hack = load_v_hack(cfg.v_hack_path, cfg.model, wrappers, cfg.pairs_path) logger.info(f"loaded v_hack: {len(v_hack)} modules") grads_hack: dict[str, list[torch.Tensor]] = defaultdict(list) grads_clean: dict[str, list[torch.Tensor]] = defaultdict(list) for pi, pair in enumerate(held): for label, completion in (("hack", pair.hack), ("clean", pair.clean)): model.zero_grad(set_to_none=True) loss = completion_nll(model, tokenizer, pair.prompt, completion, device) loss.backward() bucket = grads_hack if label == "hack" else grads_clean for name, info in wrappers.items(): bucket[name].append(info["delta_S"].grad.detach().float().cpu().clone()) logger.info(f" held pair {pi+1}/{len(held)} loss={loss.item():.3f}") # per-module cos_align cos_by_suffix: dict[str, list[float]] = defaultdict(list) all_cos = [] rows_all = [] for name, V in v_hack.items(): # V is [k, r], orthonormal rows. Held-out diff direction should land # in the subspace, so report subspace energy fraction ||V·diff/||diff|| || ∈ [0,1]. gh = torch.stack(grads_hack[name]).mean(0) gc = torch.stack(grads_clean[name]).mean(0) diff = gh - gc nrm = diff.norm() if nrm < 1e-12: cos = 0.0 else: cos = (V @ (diff / nrm)).norm().item() suf = name.split(".")[-1] cos_by_suffix[suf].append(cos) all_cos.append(cos) rows_all.append((name, cos)) agg_rows = [] for suf, vals in sorted(cos_by_suffix.items()): t = torch.tensor(vals) agg_rows.append({ "suffix": suf, "n": len(vals), "mean_energy": f"{t.mean():.3f}", "median_energy": f"{t.median():.3f}", "min": f"{t.min():.3f}", "max": f"{t.max():.3f}", }) t_all = torch.tensor(all_cos) mean_energy = t_all.mean().item() median_energy = t_all.median().item() cue = "🟢" if median_energy > 0.30 else ("🟡" if median_energy > 0.10 else "🔴") print(f"\nSHOULD: median_energy > 0.30 (held-out diff lands in trained subspace). " f"Prior synthetic-pair run got ~0.01 -- that was the smoking gun.\n") print(tabulate(agg_rows, headers="keys", tablefmt="tsv", floatfmt=".3f")) print() print(f"out: {cfg.out_path}") print(f"argv: verify_vhack_heldout --model={cfg.model} --v-hack-path={cfg.v_hack_path}") print(f"main metric: median_energy={median_energy:.3f} [modules={len(all_cos)}]") print(f"{cue} modules={len(all_cos)} mean={mean_energy:.3f} median={median_energy:.3f}") frac_pos = (t_all > 0).float().mean().item() mean_cos = mean_energy median_cos = median_energy # save for downstream plotting / sanity. Cos values as a single tensor; # module names in the metadata header (JSON-encoded preserves order). names = [n for n, _ in rows_all] cos_t = torch.tensor([c for _, c in rows_all], dtype=torch.float32) save_file( {"cos": cos_t}, str(cfg.out_path), metadata={"model": cfg.model, "dtype": cfg.dtype, "names": json.dumps(names)}, ) gate_pass = frac_pos > 0.50 target_pass = mean_cos > 0.20 if not gate_pass: logger.error(f"GATE FAIL: frac>0 = {frac_pos:.3f} <= 0.50") return 1 if not target_pass: logger.warning(f"TARGET MISS: mean_cos = {mean_cos:+.3f} <= 0.20 (gate passes but signal weak)") else: logger.info(f"TARGET PASS: mean_cos = {mean_cos:+.3f} > 0.20") return 0 if __name__ == "__main__": sys.exit(main(tyro.cli(Config)))