Files
evil_MoE/scripts/verify_vhack_heldout.py
T
wassname d68c17e7c5 eval: final deploy eval records knob-on (deployed-as-trained) for quarantine arms
route/routeV final eval now measures both endpoints at n=119 test:
knob-off (ablate_quarantine, the deploy headline) AND knob-on (trained
model as-is). Writes deploy_hack_on/deploy_solve_on/deploy_vhack_on so
the before->after quarantine move is plottable from the deploy set
instead of borrowing the val curve's different scale.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
2026-06-09 13:09:50 +00:00

155 lines
5.6 KiB
Python

"""Held-out v_hack validation (spec.md §B validation).
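Average the per-module delta_S gradients of the hack and clean completions over
the held-out pairs, take the diff (g_hack - g_clean), then measure how much of
the normalized diff lies in the trained v_hack[name] subspace:
energy = ||V @ (diff / ||diff||)||, in [0, 1] when the rows of V are orthonormal.
Report:
- per-suffix mean/median/min/max energy
- median energy across modules (SHOULD > 0.30)
- fraction of modules with energy > 0 (gate: must be > 0.5, else exit code 1)
- mean energy across modules (target > 0.2; a miss only logs a warning)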
Run: uv run python -m vgrout.verify_vhack_heldout
"""
from __future__ import annotations
import sys
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
import json
import torch
import tyro
from loguru import logger
from safetensors.torch import save_file
from tabulate import tabulate
from transformers import AutoModelForCausalLM, AutoTokenizer
from vgrout.antipasto import wrap_model_with_antipasto
from vgrout.extract_vhack_grad import completion_nll, resolve_dtype
from vgrout.pairs_from_pool import load_pairs_json
from vgrout.vhack import load_v_hack
# Directory handed to wrap_model_with_antipasto as its `cache_root`.
# Relative path, so it resolves against the current working directory.
CACHE_ROOT = Path("svd_cache")
# Base directory for the default v_hack, pair-set and output paths in Config.
OUT_DIR = Path("out")
@dataclass
class Config:
    """Command-line options for the held-out v_hack check (parsed by tyro)."""

    # Model name/path given to `from_pretrained` for both tokenizer and model;
    # also passed to wrap_model_with_antipasto and load_v_hack.
    model: str = "out/baked/qwen3_4b_rh25"
    dtype: str = "bf16" # must match extract_vhack_grad.py and train.py
    # Trained v_hack file read by `load_v_hack`.
    v_hack_path: Path = OUT_DIR / "vhack" / "v_hack_rh25.safetensors"
    # Destination of the per-module scores written with `save_file`.
    out_path: Path = OUT_DIR / "vhack_heldout_cos_rh25.safetensors"
    # Pair-set JSON read by `load_pairs_json`.
    pairs_path: Path = OUT_DIR / "pairsets" / "prog_wide.json"
    # How many pairs, counted from the END of the pair list, are evaluated.
    n_heldout: int = 2
def _subspace_energy(V: torch.Tensor, diff: torch.Tensor) -> float:
    """Return ``||V @ (diff / ||diff||)||`` as a Python float.

    V is [k, r] with orthonormal rows, so the value is the share of the
    diff direction that lands in the trained subspace, in [0, 1]. A
    (near-)zero ``diff`` has no direction to project and scores 0.0
    instead of dividing by zero.
    """
    nrm = diff.norm()
    if nrm < 1e-12:
        return 0.0
    return (V @ (diff / nrm)).norm().item()


def main(cfg: Config) -> int:
    """Check a trained v_hack against held-out hack/clean pairs.

    Loads the model and wraps it with antipasto, backpropagates the
    completion NLL of the hack and clean completion of each held-out pair,
    and averages the resulting ``delta_S`` gradients per module. For every
    module in v_hack the mean gradient diff (hack - clean) is scored by its
    subspace energy. A per-suffix summary is printed and the per-module
    scores are saved to ``cfg.out_path``.

    Args:
        cfg: Run configuration (model, dtype, input/output paths, number of
            held-out pairs).

    Returns:
        0 when the gate passes, 1 when it fails (process exit status).

    Raises:
        ValueError: ``cfg.n_heldout`` is < 1, or the pair file has no pairs.
        RuntimeError: a wrapped module received no ``delta_S`` gradient,
            v_hack is empty, or v_hack names a module with no gradients.
    """
    # Reject non-positive counts up front: `pairs[-0:]` is `pairs[0:]`, so
    # n_heldout=0 would silently evaluate on EVERY pair instead of none.
    if cfg.n_heldout < 1:
        raise ValueError(f"n_heldout must be >= 1, got {cfg.n_heldout}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = resolve_dtype(cfg.dtype)
    logger.info(f"device={device} model={cfg.model} dtype={cfg.dtype}")

    pairs = load_pairs_json(cfg.pairs_path)
    # The held-out set is the tail of the pair list.
    held = pairs[-cfg.n_heldout:]
    if len(held) == 0:
        raise ValueError(f"no pairs loaded from {cfg.pairs_path}")
    logger.info(f"held-out pairs: {len(held)} from {cfg.pairs_path}")

    tokenizer = AutoTokenizer.from_pretrained(cfg.model)
    model = AutoModelForCausalLM.from_pretrained(
        cfg.model, dtype=dtype, attn_implementation="sdpa"
    ).to(device)
    model.eval()
    wrappers = wrap_model_with_antipasto(
        model, model_name=cfg.model, cache_root=CACHE_ROOT, svd_device=device,
    )
    v_hack = load_v_hack(cfg.v_hack_path, cfg.model, wrappers, cfg.pairs_path)
    logger.info(f"loaded v_hack: {len(v_hack)} modules")
    if len(v_hack) == 0:
        # Otherwise the summary statistics below fail on an empty tensor.
        raise RuntimeError(f"v_hack loaded from {cfg.v_hack_path} has no modules")

    # One delta_S gradient per (module, held-out pair), split by completion kind.
    grads_hack: dict[str, list[torch.Tensor]] = defaultdict(list)
    grads_clean: dict[str, list[torch.Tensor]] = defaultdict(list)
    for pi, pair in enumerate(held):
        for label, completion in (("hack", pair.hack), ("clean", pair.clean)):
            # Fresh gradients for every completion; nothing accumulates.
            model.zero_grad(set_to_none=True)
            loss = completion_nll(model, tokenizer, pair.prompt, completion, device)
            loss.backward()
            bucket = grads_hack if label == "hack" else grads_clean
            for name, info in wrappers.items():
                grad = info["delta_S"].grad
                if grad is None:
                    raise RuntimeError(
                        f"no gradient reached delta_S of module {name!r} "
                        f"({label} completion of held pair {pi + 1})"
                    )
                # Copy to CPU float32 so the next zero_grad cannot touch it.
                bucket[name].append(grad.detach().float().cpu().clone())
        # NOTE(review): the source's indentation was lost; this is assumed to
        # log once per pair, reporting the last (clean) completion's loss.
        logger.info(f"  held pair {pi+1}/{len(held)} loss={loss.item():.3f}")

    missing = [
        name for name in v_hack
        if not grads_hack.get(name) or not grads_clean.get(name)
    ]
    if missing:
        raise RuntimeError(
            f"v_hack has {len(missing)} module(s) without collected gradients, "
            f"e.g. {missing[0]!r}"
        )

    # Per-module subspace energy, also grouped by the last dotted name part.
    energy_by_suffix: dict[str, list[float]] = defaultdict(list)
    names: list[str] = []
    energies: list[float] = []
    for name, V in v_hack.items():
        mean_hack = torch.stack(grads_hack[name]).mean(0)
        mean_clean = torch.stack(grads_clean[name]).mean(0)
        energy = _subspace_energy(V, mean_hack - mean_clean)
        energy_by_suffix[name.split(".")[-1]].append(energy)
        names.append(name)
        energies.append(energy)

    agg_rows = []
    for suf, vals in sorted(energy_by_suffix.items()):
        t = torch.tensor(vals)
        agg_rows.append({
            "suffix": suf,
            "n": len(vals),
            "mean_energy": f"{t.mean():.3f}",
            "median_energy": f"{t.median():.3f}",
            "min": f"{t.min():.3f}",
            "max": f"{t.max():.3f}",
        })

    t_all = torch.tensor(energies)
    mean_energy = t_all.mean().item()
    median_energy = t_all.median().item()
    cue = "🟢" if median_energy > 0.30 else ("🟡" if median_energy > 0.10 else "🔴")
    print("\nSHOULD: median_energy > 0.30 (held-out diff lands in trained subspace). "
          "Prior synthetic-pair run got ~0.01 -- that was the smoking gun.\n")
    print(tabulate(agg_rows, headers="keys", tablefmt="tsv", floatfmt=".3f"))
    print()
    print(f"out: {cfg.out_path}")
    print(f"argv: verify_vhack_heldout --model={cfg.model} --v-hack-path={cfg.v_hack_path}")
    print(f"main metric: median_energy={median_energy:.3f} [modules={len(energies)}]")
    print(f"{cue} modules={len(energies)} mean={mean_energy:.3f} median={median_energy:.3f}")

    # Save for downstream plotting / sanity. Scores as a single tensor;
    # module names in the metadata header (JSON-encoded preserves order).
    # Create the target directory first so the run cannot fail at the very
    # end just because it does not exist yet.
    cfg.out_path.parent.mkdir(parents=True, exist_ok=True)
    save_file(
        {"cos": torch.tensor(energies, dtype=torch.float32)},
        str(cfg.out_path),
        metadata={"model": cfg.model, "dtype": cfg.dtype, "names": json.dumps(names)},
    )

    # NOTE(review): the score is a norm, hence never negative, so this gate
    # only fails when at least half of the modules score exactly 0. It looks
    # like a leftover from a signed-cosine metric -- confirm the intended
    # gate (the printed SHOULD line uses median_energy > 0.30).
    frac_pos = (t_all > 0).float().mean().item()
    if not frac_pos > 0.50:
        logger.error(f"GATE FAIL: frac>0 = {frac_pos:.3f} <= 0.50")
        return 1
    if not mean_energy > 0.20:
        logger.warning(f"TARGET MISS: mean_cos = {mean_energy:+.3f} <= 0.20 (gate passes but signal weak)")
    else:
        logger.info(f"TARGET PASS: mean_cos = {mean_energy:+.3f} > 0.20")
    return 0
# Script entry point: build Config from the command line with tyro and use
# main()'s return value as the process exit status.
if __name__ == "__main__":
    sys.exit(main(tyro.cli(Config)))