Files
evil_MoE/scripts/diag_pairs_compare.py
T
wassname dae52b2a7d cleanup: consolidate pairs modules into build scripts + add solve_train to table
- Delete src/vgrout/pairs_v2.py and src/vgrout/pairs_intent.py; move all data
  into scripts/pairset_build_intent.py (self-contained, exports 3 JSONs).
- Export: pairs_intent_think.json (6), pairs_intent_funcname.json (6),
  pairs_intent_concept.json (6 diagnostic).
- Update diag_cosine_dist.py and diag_pairs_compare.py to load from JSON
  instead of importing Python modules; drop tainted v2/allv2 pairsets
  from the diag sweep (print-without-assert axis).
- train.py final table: add solve_rate_s, computed the same way as hack_rate_s, so
  the per-run end-of-training table shows the actual training solve rate (was "-").

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
2026-06-09 09:17:42 +00:00

149 lines
6.7 KiB
Python

"""Pairs comparison: which PAIR-SET gives the best high-precision hack direction?
Companion to diag_cosine_dist.py. That script sweeps SCORE x SPACE x FILTER for ONE
pair-set; this one fixes the winning config (grad cosine, the p@10=0.70 corner) and
sweeps the PAIR-SET axis -- authored vs pool-derived (prog_wide family) vs intent.
This is the table I should have built before committing a GPU run to a pairs guess.
The expensive part (140 live-rollout backward passes) is independent of the pairs, so
we cache the per-module live grads ONCE and loop the cheap v_grad build + scoring over
pair-sets. No oracle labels touch training; `exploited` is an offline plot-only label.
uv run python -m scripts.diag_pairs_compare # 4B, free GPU, ~20 min
outputs: out/diag/pairs_compare.csv (pairset x AUROC/p@10/p@20 at grad cosine).
"""
from __future__ import annotations
import json
import struct
from dataclasses import dataclass
from pathlib import Path
import torch
import tyro
import polars as pl
from loguru import logger
from tabulate import tabulate
from safetensors.torch import load_file
from transformers import AutoModelForCausalLM, AutoTokenizer
from vgrout.antipasto import wrap_model_with_antipasto
from vgrout.extract_vhack_grad import extract_v_hack, completion_nll
from vgrout.pairs_from_pool import load_pairs_json
from vgrout.train import CACHE_ROOT
from scripts.diag_cosine_dist import _auroc, _prec_at_k
_PS = Path("out/pairsets")
# Pair-set name -> zero-arg loader. Nothing is read at import time: each loader opens
# its JSON under _PS only when called, so main() can skip a missing file (it catches
# FileNotFoundError per entry). Insertion order is the sweep order.
PAIRSETS = {
    label: (lambda fname=fname: load_pairs_json(_PS / fname))
    for label, fname in (
        ("authored_all", "pairs_authored.json"),
        ("funcname", "pairs_intent_funcname.json"),
        ("think", "pairs_intent_think.json"),
        ("prog_wide", "prog_wide.json"),
        ("prog_wide_clean", "prog_wide_clean.json"),
        ("prog_wider", "prog_wider.json"),
        ("prog_widest", "prog_widest.json"),
        ("heldout_known_rt", "heldout_known_runtests.json"),
    )
}
@dataclass
class Cfg:
    """CLI config (parsed by tyro) for the pair-set comparison sweep."""

    # training run directory; must contain the checkpoint safetensors and rollouts.jsonl
    run_dir: Path = Path("out/runs/20260607T134234_fast_routingV_seed43_dir6_routeV_pertoken_s43")
    # checkpoint stem: {ckpt}.safetensors -> delta_S, {ckpt}_hack.safetensors -> delta_S_hack
    ckpt: str = "first_hack"
    # inclusive step window [step_lo, step_hi] of live rollouts to score
    step_lo: int = 5
    step_hi: int = 9
    # cap on rollouts taken from the window (applied before non-finite-loss ones are dropped)
    max_rollouts: int = 140
    # directory that receives pairs_compare.csv
    out_dir: Path = Path("out/diag")
def main(cfg: Cfg) -> int:
    """Rank every pair-set in PAIRSETS by how well its grad-cosine score flags exploited rollouts.

    Steps:
      1. Read the model name from the checkpoint's safetensors header, load the model,
         wrap it with antipasto adapters and copy in the saved ``delta_S`` /
         ``delta_S_hack`` tensors.
      2. For each live rollout in ``[cfg.step_lo, cfg.step_hi]`` (at most
         ``cfg.max_rollouts``), backprop ``completion_nll`` once and cache the per-module
         ``delta_S`` grads and their norms. These do not depend on the pair-set, so the
         expensive backward passes happen only once.
      3. For each pair-set, build a per-module unit direction ``v_grad`` from
         ``extract_v_hack`` and score every cached rollout by cosine against it. Two
         scores are produced: one over all modules and one over the "keep75" subset.
      4. Write ``cfg.out_dir / "pairs_compare.csv"`` and print the ranked table.

    ``exploited`` labels are used only for offline scoring, never for building v_grad.

    Returns:
        0, used as the process exit status.

    Raises:
        RuntimeError: no live rollout in the step window produced a finite loss, or
            every pair-set file is missing. Previously these surfaced as an opaque
            ``torch.stack`` / ZeroDivisionError / polars column error.
    """
    device = torch.device("cuda")
    kept_path = cfg.run_dir / f"{cfg.ckpt}.safetensors"
    hack_path = cfg.run_dir / f"{cfg.ckpt}_hack.safetensors"
    # safetensors layout: 8-byte little-endian header length, then a JSON header; run
    # info (model name, step) lives in its optional "__metadata__" entry.
    with open(kept_path, "rb") as f:
        header_len = struct.unpack("<Q", f.read(8))[0]
        meta = json.loads(f.read(header_len)).get("__metadata__", {})
    model_name = meta.get("model", "Qwen/Qwen3-4B")
    logger.info(f"ckpt {kept_path.name} step={meta.get('step')} model={model_name}")

    tok = AutoTokenizer.from_pretrained(model_name)
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        model_name, dtype=torch.bfloat16, attn_implementation="flash_attention_2").to(device)
    model.config.use_cache = False
    wrappers = wrap_model_with_antipasto(model, model_name, CACHE_ROOT, device, grad_probe=False)
    names = sorted(wrappers)
    kept, hack = load_file(str(kept_path)), load_file(str(hack_path))
    for nm in names:
        wrappers[nm]["delta_S"].data.copy_(kept[nm].to(device))
        wrappers[nm]["delta_S_hack"].data.copy_(hack[nm].to(device))
    model.eval()

    # ── cache live-rollout per-module grads ONCE (independent of pair-set) ──
    lines = (cfg.run_dir / "rollouts.jsonl").read_text().splitlines()
    recs = [json.loads(line) for line in lines if line.strip()]  # tolerate blank lines
    batch = [r for r in recs
             if cfg.step_lo <= r["step"] <= cfg.step_hi and r["text"].strip()][:cfg.max_rollouts]
    gb_cache, gn_cache, labels = [], [], []
    n_skipped = 0
    for i, r in enumerate(batch):
        model.zero_grad(set_to_none=True)
        loss = completion_nll(model, tok, r["prompt"], r["text"], device)
        if not torch.isfinite(loss):
            n_skipped += 1
            continue
        loss.backward()
        gb = {nm: wrappers[nm]["delta_S"].grad.flatten().float().cpu() for nm in names}
        gb_cache.append(gb)
        gn_cache.append(torch.tensor([gb[nm].norm().item() for nm in names]))  # |g_m|, pairs-independent
        labels.append(bool(r["exploited"]))
        if (i + 1) % 40 == 0:
            logger.info(f" cached {i+1}/{len(batch)} live grads")
    model.zero_grad(set_to_none=True)
    if not labels:
        raise RuntimeError(
            f"no usable live rollouts in steps [{cfg.step_lo}, {cfg.step_hi}] of "
            f"{cfg.run_dir / 'rollouts.jsonl'} ({len(batch)} selected, "
            f"{n_skipped} with non-finite loss)")
    if n_skipped:
        logger.warning(f"skipped {n_skipped}/{len(batch)} rollouts with non-finite loss")
    gn_stack = torch.stack(gn_cache)  # [n_roll, n_mod]
    n_pos = sum(labels)
    base_rate = n_pos / len(labels)
    logger.info(f"live: {len(labels)} rollouts, {n_pos} exploited ({base_rate:.2f} base rate)")
    if n_pos in (0, len(labels)):
        logger.warning("all live rollouts share one label; AUROC / precision@k are uninformative")

    # ── score one pair-set at grad cosine (all + keep75 noise filter) ──
    all_mask = torch.ones(len(names), dtype=torch.bool)  # constant, so build it once

    def score_pairset(pairs):
        """Return (cos_all, cos_keep75): one grad-cosine per cached rollout for this pair-set."""
        _, v_sv, raw_grads, _ = extract_v_hack(
            model, tok, wrappers, pairs, top_k=1, tau_axis=0.0, n_heldout=2, device=device)
        # Per-module direction: mean over pairs of (hack grad - clean grad), L2-normalised.
        v_grad = {}
        for nm in names:
            diff = (raw_grads[f"hack/{nm}"] - raw_grads[f"clean/{nm}"]).mean(0).flatten().float().cpu()
            v_grad[nm] = diff / diff.norm().clamp_min(1e-12)
        # keep75: drop modules whose v_sv[nm][0] falls in the bottom quartile
        # (presumably the leading singular value from extract_v_hack -- confirm there).
        sv0 = torch.tensor([v_sv[nm][0].item() for nm in names])
        keep75 = sv0 >= sv0.quantile(0.25)
        cos_all, cos_k75 = [], []
        for gb, gn in zip(gb_cache, gn_stack):
            dot = torch.tensor([(gb[nm] @ v_grad[nm]).item() for nm in names])
            for mask, out in ((all_mask, cos_all), (keep75, cos_k75)):
                # Cosine between the concatenated grad and the concatenated v over the
                # masked modules. Each v_grad[nm] has unit norm (or is zero), so the
                # concatenated v has norm sqrt(n_masked).
                d = dot[mask].sum()
                gmag = gn[mask].pow(2).sum().sqrt()
                nmod = int(mask.sum())
                out.append((d / (gmag.clamp_min(1e-12) * nmod ** 0.5)).item())
        return cos_all, cos_k75

    rows = []
    for name, fn in PAIRSETS.items():
        try:
            pairs = fn()
        except FileNotFoundError:
            logger.warning(f"{name}: pairset file missing, skipping")
            continue
        # Re-assert eval mode for every pair-set; presumably extract_v_hack may
        # change the model mode (TODO confirm).
        model.eval()
        ca, ck = score_pairset(pairs)
        row = dict(pairset=name, n=len(pairs),
                   AUROC=round(_auroc(ca, labels), 3),
                   p10=round(_prec_at_k(ca, labels, 10), 3),
                   p20=round(_prec_at_k(ca, labels, 20), 3),
                   p10_keep75=round(_prec_at_k(ck, labels, 10), 3))
        rows.append(row)
        logger.info(f"{name} (n={len(pairs)}): grad-cosine p10={row['p10']} AUROC={row['AUROC']}")
    if not rows:
        raise RuntimeError(f"all {len(PAIRSETS)} pair-set files are missing; nothing to compare")

    df = pl.DataFrame(rows).sort(["p10", "AUROC"], descending=True)
    cfg.out_dir.mkdir(parents=True, exist_ok=True)
    df.write_csv(cfg.out_dir / "pairs_compare.csv")
    print("\n=== pair-sets ranked by grad-cosine precision@10 (base rate "
          f"{base_rate:.2f}, {len(labels)} live rollouts) ===")
    print(tabulate(df.to_pandas(), headers="keys", tablefmt="pipe", showindex=False))
    return 0
if __name__ == "__main__":
    # tyro turns Cfg's fields into CLI flags; main()'s int return is the exit status.
    _cli_cfg = tyro.cli(Cfg)
    raise SystemExit(main(_cli_cfg))