Files
evil_MoE/scripts/verify_vhack_heldout.py
T
wassname 4fa9061162 refactor: move 5 leaf entrypoints src/ -> scripts/ (src is now library-only)
verify_rewards, verify_vhack_heldout, build_substrate, probe_distill, probe_plot_stack
are run via 'python -m' / justfile and imported by no core module -> moved to scripts/,
relative imports rewritten to 'from projected_grpo.X'. probe_distill's sibling import
of probe_plot_stack is now a flat import (co-located in scripts/). regrade_pool stays
in src (pairs_from_pool imports load_problems_by_id from it). justfile recipes updated.

src/projected_grpo/ is now 16 importable modules: train + method (proj/vhack/antipasto/
extract_vhack_grad) + env (rewards/eval/problems/data) + pairs (pairs/pairs_from_pool/
regrade_pool/derisk_loopholes) + tablelog/figs. ~1480 lines moved out of the package.
Smoke green (verify_rewards 52/52 from scripts/, train pipeline cout->0).

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
2026-06-03 00:23:56 +00:00

153 lines
5.5 KiB
Python

"""Held-out v_hack validation (spec.md §B validation).
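Average the per-module delta_S gradients over the held-out pairs, take the
diff (g_hack - g_clean), and measure how much of the normalized diff lies in
the trained v_hack[name] subspace: energy = ||V @ (diff / ||diff||)||, in [0, 1]
when the rows of V are orthonormal. (The saved tensor and the log messages
still use the older "cos" / "mean_cos" naming for this quantity.)
Report:
- per-suffix mean/median/min/max energy
- fraction of modules with energy > 0 (gate: SHOULD > 0.5)
- mean energy across modules (target > 0.2)
- median energy across modules (printed cue: SHOULD > 0.30)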
Run (from the repo root; this file now lives in scripts/, not the package):
uv run python scripts/verify_vhack_heldout.py
"""
from __future__ import annotations
import sys
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
import json
import torch
import tyro
from loguru import logger
from safetensors.torch import save_file
from tabulate import tabulate
from transformers import AutoModelForCausalLM, AutoTokenizer
from projected_grpo.antipasto import wrap_model_with_antipasto
from projected_grpo.extract_vhack_grad import completion_nll, resolve_dtype
from projected_grpo.pairs import PAIRS
from projected_grpo.extract_vhack_grad import load_v_hack
# Relative path handed to wrap_model_with_antipasto() as its cache_root.
CACHE_ROOT = Path("svd_cache")
# Base output directory; Config's default v_hack_path and out_path live under it.
OUT_DIR = Path("out")
@dataclass
class Config:
    """Command-line options for the held-out v_hack check (parsed by tyro.cli)."""

    # Passed to both AutoTokenizer and AutoModelForCausalLM .from_pretrained().
    model: str = "out/baked/qwen3_4b_rh25"
    # must match extract_vhack_grad.py and train.py
    dtype: str = "bf16"
    # Trained v_hack file that load_v_hack() reads.
    v_hack_path: Path = OUT_DIR.joinpath("vhack", "v_hack_rh25.safetensors")
    # Destination of the per-module scores written by save_file().
    out_path: Path = OUT_DIR.joinpath("vhack_heldout_cos_rh25.safetensors")
    # Number of trailing entries of PAIRS used as the held-out set.
    n_heldout: int = 2
def _subspace_energy(V: torch.Tensor, diff: torch.Tensor) -> float:
    """Return ``||V @ (diff / ||diff||)||``, or 0.0 when ``diff`` is (near-)zero.

    Per the original author's note, V is [k, r] with orthonormal rows, which
    makes the result the fraction of the diff direction lying in the subspace
    (a value in [0, 1]). That shape/orthonormality is not checked here.
    """
    nrm = diff.norm()
    if nrm < 1e-12:
        # Avoid dividing by ~0; a vanishing diff carries no direction.
        return 0.0
    # load_v_hack's device/dtype is not visible from this file; align V with
    # the (CPU, float32) diff so the matmul cannot fail on a mismatch.
    basis = V.detach().to(device=diff.device, dtype=diff.dtype)
    return (basis @ (diff / nrm)).norm().item()


def _summarize_by_suffix(cos_by_suffix: dict[str, list[float]]) -> list[dict]:
    """Build one table row (n / mean / median / min / max) per module-name suffix."""
    rows = []
    for suf, vals in sorted(cos_by_suffix.items()):
        t = torch.tensor(vals)
        rows.append({
            "suffix": suf,
            "n": len(vals),
            "mean_energy": f"{t.mean():.3f}",
            "median_energy": f"{t.median():.3f}",
            "min": f"{t.min():.3f}",
            "max": f"{t.max():.3f}",
        })
    return rows


def main(cfg: Config) -> int:
    """Score held-out hack/clean gradient diffs against the trained v_hack.

    Loads the model, wraps it with antipasto, back-propagates the completion
    NLL for the hack and clean completion of each held-out pair, and collects
    the ``delta_S`` gradient of every wrapped module. For each module in
    v_hack the mean hack gradient minus the mean clean gradient is scored with
    ``_subspace_energy``. A per-suffix table and summary lines are printed and
    the per-module scores are saved to ``cfg.out_path``.

    Args:
        cfg: Parsed command-line configuration.

    Returns:
        1 if the gate fails (or v_hack is empty), otherwise 0. Missing the
        mean target only logs a warning.

    Raises:
        ValueError: if ``cfg.n_heldout`` < 1 or PAIRS is empty.
        RuntimeError: if a wrapped module's ``delta_S`` received no gradient.
        KeyError: if v_hack names a module with no collected gradients.
    """
    # PAIRS[-0:] is the *whole* list (and a negative n slices from the front),
    # so a non-positive n_heldout would silently evaluate on the wrong pairs.
    # Reject it before the expensive model load.
    if cfg.n_heldout < 1:
        raise ValueError(f"n_heldout must be >= 1, got {cfg.n_heldout}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = resolve_dtype(cfg.dtype)
    logger.info(f"device={device} model={cfg.model} dtype={cfg.dtype}")

    held = PAIRS[-cfg.n_heldout:]
    if not held:
        raise ValueError("PAIRS is empty; no held-out pairs to evaluate")
    if len(held) < cfg.n_heldout:
        logger.warning(
            f"n_heldout={cfg.n_heldout} exceeds the {len(held)} available pairs; using all of them"
        )
    logger.info(f"held-out pairs: {len(held)}")

    tokenizer = AutoTokenizer.from_pretrained(cfg.model)
    model = AutoModelForCausalLM.from_pretrained(
        cfg.model, dtype=dtype, attn_implementation="sdpa"
    ).to(device)
    model.eval()
    wrappers = wrap_model_with_antipasto(
        model, model_name=cfg.model, cache_root=CACHE_ROOT, svd_device=device,
    )
    v_hack = load_v_hack(cfg.v_hack_path, cfg.model, wrappers)
    logger.info(f"loaded v_hack: {len(v_hack)} modules")
    if not v_hack:
        # median() of an empty tensor raises; fail with a clear message instead.
        logger.error("v_hack contains no modules; nothing to evaluate")
        return 1

    # One list of per-pair delta_S gradients per module name, per label.
    grads_hack: dict[str, list[torch.Tensor]] = defaultdict(list)
    grads_clean: dict[str, list[torch.Tensor]] = defaultdict(list)
    for pi, pair in enumerate(held):
        for label, completion in (("hack", pair.hack), ("clean", pair.clean)):
            # Drop gradients from the previous backward pass so each snapshot
            # reflects exactly one (prompt, completion).
            model.zero_grad(set_to_none=True)
            loss = completion_nll(model, tokenizer, pair.prompt, completion, device)
            loss.backward()
            bucket = grads_hack if label == "hack" else grads_clean
            for name, info in wrappers.items():
                grad = info["delta_S"].grad
                if grad is None:
                    raise RuntimeError(
                        f"delta_S of module {name!r} received no gradient "
                        f"(label={label}, pair={pi})"
                    )
                # Copy to CPU float32: .grad is dropped by the next zero_grad().
                bucket[name].append(grad.detach().float().cpu().clone())
        # Logged once per pair; `loss` here is that of the clean completion.
        logger.info(f" held pair {pi+1}/{len(held)} loss={loss.item():.3f}")

    # per-module subspace energy (kept under the historical "cos" naming)
    cos_by_suffix: dict[str, list[float]] = defaultdict(list)
    names: list[str] = []
    all_cos: list[float] = []
    for name, V in v_hack.items():
        if name not in grads_hack:
            # Indexing the defaultdict would silently create an empty list and
            # fail later inside torch.stack with an unhelpful message.
            raise KeyError(f"v_hack module {name!r} has no collected delta_S gradients")
        gh = torch.stack(grads_hack[name]).mean(0)
        gc = torch.stack(grads_clean[name]).mean(0)
        cos = _subspace_energy(V, gh - gc)
        cos_by_suffix[name.split(".")[-1]].append(cos)
        names.append(name)
        all_cos.append(cos)

    agg_rows = _summarize_by_suffix(cos_by_suffix)
    t_all = torch.tensor(all_cos)
    mean_energy = t_all.mean().item()
    median_energy = t_all.median().item()
    cue = "🟢" if median_energy > 0.30 else ("🟡" if median_energy > 0.10 else "🔴")
    print(f"\nSHOULD: median_energy > 0.30 (held-out diff lands in trained subspace). "
          f"Prior synthetic-pair run got ~0.01 -- that was the smoking gun.\n")
    print(tabulate(agg_rows, headers="keys", tablefmt="tsv", floatfmt=".3f"))
    print()
    print(f"out: {cfg.out_path}")
    print(f"argv: verify_vhack_heldout --model={cfg.model} --v-hack-path={cfg.v_hack_path}")
    print(f"main metric: median_energy={median_energy:.3f} [modules={len(all_cos)}]")
    print(f"{cue} modules={len(all_cos)} mean={mean_energy:.3f} median={median_energy:.3f}")

    # save for downstream plotting / sanity. Scores as a single tensor;
    # module names in the metadata header (JSON-encoded preserves order).
    # Create the parent directory first so a missing folder cannot discard
    # the results after the whole gradient pass has already run.
    cfg.out_path.parent.mkdir(parents=True, exist_ok=True)
    save_file(
        {"cos": torch.tensor(all_cos, dtype=torch.float32)},
        str(cfg.out_path),
        metadata={"model": cfg.model, "dtype": cfg.dtype, "names": json.dumps(names)},
    )

    # NOTE(review): the score is a norm, hence never negative, so "frac > 0"
    # only fails when more than half the modules have a vanishing diff or
    # projection. The exit-code policy is kept as-is; consider gating on
    # median_energy (the printed SHOULD criterion) instead.
    frac_pos = (t_all > 0).float().mean().item()
    if not frac_pos > 0.50:
        logger.error(f"GATE FAIL: frac>0 = {frac_pos:.3f} <= 0.50")
        return 1
    if not mean_energy > 0.20:
        logger.warning(f"TARGET MISS: mean_cos = {mean_energy:+.3f} <= 0.20 (gate passes but signal weak)")
    else:
        logger.info(f"TARGET PASS: mean_cos = {mean_energy:+.3f} > 0.20")
    return 0
if __name__ == "__main__":
    # Parse CLI flags into a Config with tyro, run the check, and hand main()'s
    # return value to the shell as the process exit status.
    parsed_cfg = tyro.cli(Config)
    exit_status = main(parsed_cfg)
    sys.exit(exit_status)