Files
evil_MoE/scripts/verify_vhack_heldout.py
T
wassname 55937a86fb rename python package projected_grpo -> vgrout
git mv src/projected_grpo -> src/vgrout and find-replace the module name in
all imports (.py), `-m projected_grpo.*` invocations (justfile), and the
[project] name (pyproject; setuptools auto-discovers via where=["src"]).

Left RESEARCH_JOURNAL.md untouched: its commands/paths are dated lab notes
tied to past commits, so rewriting them would falsify provenance. Repo dir,
git remote, and absolute paths unchanged.

Verified: `import vgrout` and `python -m vgrout.train --help` load the full
graph; verify_rewards.py + verify_gate_anchor.py (both import vgrout) pass.
Full `just smoke` is blocked upstream by missing gitignored data artifacts
(out/pools/{substrate,teacher_pool}, out/vhack/*smoke*), unrelated to the rename.
2026-06-05 14:51:48 +08:00

153 lines
5.5 KiB
Python

"""Held-out v_hack validation (spec.md §B validation).
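For each held-out pair, compute the per-module gradient of the completion NLL
w.r.t. delta_S for the hack and the clean completion. Average each side over
the held-out pairs, take the diff (g_hack - g_clean), and measure how much of
it lies in the trained v_hack[name] subspace:
    energy = ||V @ (diff / ||diff||)||, in [0, 1] for orthonormal rows of V.
The code and the output file still call this value "cos", but it is an
unsigned energy fraction, not a signed cosine.
Report:
- per-suffix mean/median/min/max energy
- median energy across modules (SHOULD > 0.30)
- fraction of modules with energy > 0 (gate: SHOULD > 0.5)
- mean energy across modules (target > 0.2)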
Run: uv run python -m vgrout.verify_vhack_heldout
"""
from __future__ import annotations
import sys
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
import json
import torch
import tyro
from loguru import logger
from safetensors.torch import save_file
from tabulate import tabulate
from transformers import AutoModelForCausalLM, AutoTokenizer
from vgrout.antipasto import wrap_model_with_antipasto
from vgrout.extract_vhack_grad import completion_nll, resolve_dtype
from vgrout.pairs import PAIRS
from vgrout.extract_vhack_grad import load_v_hack
# Passed as cache_root= to wrap_model_with_antipasto (presumably where the
# SVD factors are cached -- confirm against vgrout.antipasto).
CACHE_ROOT = Path("svd_cache")
# Base directory for the default v_hack input path and the default report path in Config.
OUT_DIR = Path("out")
@dataclass
class Config:
    """Options for the held-out v_hack check; parsed from the CLI by tyro."""

    # Path/identifier handed to AutoTokenizer and AutoModelForCausalLM.from_pretrained.
    model: str = "out/baked/qwen3_4b_rh25"
    # Precision name resolved by resolve_dtype.
    dtype: str = "bf16"  # must match extract_vhack_grad.py and train.py
    # File that load_v_hack reads the trained v_hack from.
    v_hack_path: Path = OUT_DIR / "vhack" / "v_hack_rh25.safetensors"
    # File the per-module scores are written to by main().
    out_path: Path = OUT_DIR / "vhack_heldout_cos_rh25.safetensors"
    # Number of pairs, taken from the end of PAIRS, that count as held out.
    n_heldout: int = 2
def _subspace_energy(basis: torch.Tensor, diff: torch.Tensor) -> float:
    """Return ``||basis @ (diff / ||diff||)||`` as a Python float.

    With orthonormal rows in ``basis`` this is the fraction of ``diff``'s norm
    that lies in the row space, i.e. a value in [0, 1]. A (numerically) zero
    ``diff`` has no direction, so it scores 0.0 instead of dividing by zero.
    """
    norm = diff.norm()
    if norm < 1e-12:
        return 0.0
    # Gradients are collected as CPU float32; align the basis with them.
    # This is a no-op when device and dtype already match.
    return (basis.to(diff) @ (diff / norm)).norm().item()


def main(cfg: Config) -> int:
    """Score how well held-out hack-vs-clean gradient diffs land in v_hack.

    Loads the model, wraps it with antipasto, loads the trained v_hack, and
    back-propagates the completion NLL of the hack and clean completion of
    each held-out pair. Per module, the pair-averaged ``g_hack - g_clean`` on
    ``delta_S`` is scored with ``_subspace_energy`` against ``v_hack[name]``.
    A per-suffix table and summary lines are printed, and the per-module
    scores are written to ``cfg.out_path``.

    Args:
        cfg: run configuration (model, dtype, paths, number of held-out pairs).

    Returns:
        Process exit code: 1 if the gate fails (or v_hack has no modules),
        0 otherwise. A missed target only logs a warning.

    Raises:
        ValueError: if ``cfg.n_heldout`` is smaller than 1.
        RuntimeError: if a wrapped module's ``delta_S`` received no gradient.
    """
    # PAIRS[-0:] is the *whole* list and a negative count slices from the
    # front, so reject anything below 1 before the slice can silently pick
    # the wrong pairs.
    if cfg.n_heldout < 1:
        raise ValueError(f"n_heldout must be >= 1, got {cfg.n_heldout}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = resolve_dtype(cfg.dtype)
    logger.info(f"device={device} model={cfg.model} dtype={cfg.dtype}")

    held = PAIRS[-cfg.n_heldout:]
    logger.info(f"held-out pairs: {len(held)}")

    tokenizer = AutoTokenizer.from_pretrained(cfg.model)
    model = AutoModelForCausalLM.from_pretrained(
        cfg.model, dtype=dtype, attn_implementation="sdpa"
    ).to(device)
    model.eval()
    wrappers = wrap_model_with_antipasto(
        model, model_name=cfg.model, cache_root=CACHE_ROOT, svd_device=device,
    )
    v_hack = load_v_hack(cfg.v_hack_path, cfg.model, wrappers)
    logger.info(f"loaded v_hack: {len(v_hack)} modules")

    # One delta_S gradient per held-out pair and module, kept on CPU in float32.
    grads_hack: dict[str, list[torch.Tensor]] = defaultdict(list)
    grads_clean: dict[str, list[torch.Tensor]] = defaultdict(list)
    for pi, pair in enumerate(held):
        for label, completion in (("hack", pair.hack), ("clean", pair.clean)):
            # Fresh gradients per completion so hack and clean never mix.
            model.zero_grad(set_to_none=True)
            loss = completion_nll(model, tokenizer, pair.prompt, completion, device)
            loss.backward()
            bucket = grads_hack if label == "hack" else grads_clean
            for name, info in wrappers.items():
                grad = info["delta_S"].grad
                if grad is None:
                    raise RuntimeError(
                        f"no gradient reached delta_S of {name!r} ({label} completion)"
                    )
                bucket[name].append(grad.detach().float().cpu().clone())
        # Logged once per pair; the loss shown is the last (clean) completion's.
        logger.info(f" held pair {pi+1}/{len(held)} loss={loss.item():.3f}")

    # Per-module score. V is [k, r] with orthonormal rows; the held-out diff
    # direction should land in that subspace, so the score is an energy
    # fraction in [0, 1] rather than a signed cosine.
    names: list[str] = []
    all_cos: list[float] = []
    cos_by_suffix: dict[str, list[float]] = defaultdict(list)
    for name, V in v_hack.items():
        mean_hack = torch.stack(grads_hack[name]).mean(0)
        mean_clean = torch.stack(grads_clean[name]).mean(0)
        energy = _subspace_energy(V, mean_hack - mean_clean)
        cos_by_suffix[name.rsplit(".", 1)[-1]].append(energy)
        names.append(name)
        all_cos.append(energy)

    if not all_cos:
        # median() of an empty tensor raises; report the real problem instead.
        logger.error("v_hack has no modules; nothing to score")
        return 1

    agg_rows = []
    for suf, vals in sorted(cos_by_suffix.items()):
        t = torch.tensor(vals)
        agg_rows.append({
            "suffix": suf,
            "n": len(vals),
            "mean_energy": f"{t.mean():.3f}",
            "median_energy": f"{t.median():.3f}",
            "min": f"{t.min():.3f}",
            "max": f"{t.max():.3f}",
        })

    t_all = torch.tensor(all_cos)
    mean_energy = t_all.mean().item()
    median_energy = t_all.median().item()
    cue = "🟢" if median_energy > 0.30 else ("🟡" if median_energy > 0.10 else "🔴")
    print("\nSHOULD: median_energy > 0.30 (held-out diff lands in trained subspace). "
          "Prior synthetic-pair run got ~0.01 -- that was the smoking gun.\n")
    print(tabulate(agg_rows, headers="keys", tablefmt="tsv", floatfmt=".3f"))
    print()
    print(f"out: {cfg.out_path}")
    print(f"argv: verify_vhack_heldout --model={cfg.model} --v-hack-path={cfg.v_hack_path}")
    print(f"main metric: median_energy={median_energy:.3f} [modules={len(all_cos)}]")
    print(f"{cue} modules={len(all_cos)} mean={mean_energy:.3f} median={median_energy:.3f}")

    # Save for downstream plotting / sanity: scores as a single tensor, module
    # names in the metadata header (JSON-encoded preserves order).
    cfg.out_path.parent.mkdir(parents=True, exist_ok=True)
    save_file(
        {"cos": torch.tensor(all_cos, dtype=torch.float32)},
        str(cfg.out_path),
        metadata={"model": cfg.model, "dtype": cfg.dtype, "names": json.dumps(names)},
    )

    # NOTE(review): the score is a norm, hence never negative, so "> 0" only
    # fails for modules scoring exactly 0.0. This gate is therefore much weaker
    # than the printed median_energy > 0.30 expectation; it is kept unchanged
    # so existing exit-code behaviour is preserved -- confirm the intent.
    frac_pos = (t_all > 0).float().mean().item()
    if not frac_pos > 0.50:
        logger.error(f"GATE FAIL: frac>0 = {frac_pos:.3f} <= 0.50")
        return 1
    if mean_energy > 0.20:
        logger.info(f"TARGET PASS: mean_cos = {mean_energy:+.3f} > 0.20")
    else:
        logger.warning(f"TARGET MISS: mean_cos = {mean_energy:+.3f} <= 0.20 (gate passes but signal weak)")
    return 0
if __name__ == "__main__":
    # Build Config from the command line, run the check, and hand main()'s
    # return value to the shell as the process exit status.
    exit_code = main(tyro.cli(Config))
    sys.exit(exit_code)