Files
evil_MoE/scripts/diag_cosine_dist.py
T
wassname dae52b2a7d cleanup: consolidate pairs modules into build scripts + add solve_train to table
- Delete src/vgrout/pairs_v2.py and src/vgrout/pairs_intent.py; move all data
  into scripts/pairset_build_intent.py (self-contained, exports 3 JSONs).
- Export: pairs_intent_think.json (6), pairs_intent_funcname.json (6),
  pairs_intent_concept.json (6 diagnostic).
- Update diag_cosine_dist.py and diag_pairs_compare.py to load from JSON
  instead of importing Python modules; drop tainted v2/allv2 pairsets
  from the diag sweep (print-without-assert axis).
- train.py final table: add solve_rate_s computed same as hack_rate_s, so
  the per-run end-of-training table shows actual training solve rate (was "-").

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
2026-06-09 09:17:42 +00:00

316 lines
16 KiB
Python

"""Pinning diagnostic: how well does the contrastive hack direction PREDICT a live
hack rollout, in GRADIENT vs ACTIVATION space, by cosine vs projection vs magnitude?
One checkpoint (default: job 9 per-token `first_hack` = step 7). We build the
contrastive hack direction from the authored pairs in two spaces:
GRAD: v_grad = unit(mean_pairs(g_hack - g_clean)) on delta_S (g = per-rollout NLL grad)
ACT: v_act = unit(mean_pairs(As_hack - As_clean)) (As = Vh@x, mean over completion tokens)
Then for live rollouts (oracle `exploited`, offline plot-only label) we score each
rollout against the direction and ask how well the score separates hack from clean.
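Four score families (GLOBAL = concat over modules; v is unit per module):
cosine = <g, v> / (|g| |v|) -- direction only (magnitude-blind)
projection = <g, v> / |v| = |g| cos -- direction x gradient magnitude
magnitude = |g| -- magnitude only (null check)
vote = |D_m|-weighted mean of per-module cosines -- soft version of a module filter
each in {grad, act} x {module filters all / keep75 / top25 / top15 / top05: keep modules whose
contrastive |D_m| is >= the {min, 25th, 75th, 85th, 95th} percentile; keep75 is the noise-floor
filter that drops the bottom-25% singular values}.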
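Separability metric (we want high PRECISION / confident tail, not coverage):
AUROC + precision@10 / @20 of "score predicts exploited".
Also logged: an oracle ceiling, i.e. the AUROC of the ideal live direction (mu_hack - mu_clean fit on
the live rollouts themselves, 2-fold CV and in-sample). It is a diagnostic upper bound only.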
uv run python scripts/diag_cosine_dist.py # 4B, free GPU, ~4-6 min
outputs (out/diag/): cosine_grad.png, cosine_act.png (histograms),
cosine_dist.parquet (histogram long-form), live_scores.parquet (per-rollout scores),
separability.csv (methods x AUROC/precision). Replay without a GPU: nbs/cosine_dist.ipynb.
"""
from __future__ import annotations
import json
import struct
from dataclasses import dataclass
from pathlib import Path
from typing import Literal
import torch
import torch.nn.functional as F
import tyro
import polars as pl
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from loguru import logger
from tabulate import tabulate
from safetensors.torch import load_file
from transformers import AutoModelForCausalLM, AutoTokenizer
from vgrout.antipasto import wrap_model_with_antipasto
from vgrout.extract_vhack_grad import extract_v_hack, completion_nll
from vgrout.pairs_from_pool import load_pairs_json
from vgrout.train import CACHE_ROOT
_PS = Path("out/pairsets")  # directory holding the exported pairset JSONs


@dataclass
class Cfg:
    """CLI config (parsed with tyro) for the cosine / projection separability diagnostic.

    Fields:
        run_dir: training-run directory; must contain `{ckpt}.safetensors`,
            `{ckpt}_hack.safetensors` and `rollouts.jsonl`.
        ckpt: checkpoint file stem inside `run_dir`.
        step_lo, step_hi: inclusive training-step window of live rollouts to score.
        max_rollouts: cap on live rollouts (first N with non-empty text, file order).
        bins: histogram bins for the cosine plots (wider = less spiky).
        pairs: pairset used to build the hack direction -- "all" (pairs_authored.json)
            or one intent pairset: "think" | "funcname" | "concept"
            (pairs_intent_<name>.json). Typed as a Literal so tyro rejects an
            unknown choice at parse time, before the slow model load.
        out_dir: where plots / parquet / csv outputs are written.
    """
    run_dir: Path = Path("out/runs/20260607T134234_fast_routingV_seed43_dir6_routeV_pertoken_s43")
    ckpt: str = "first_hack"
    step_lo: int = 5
    step_hi: int = 9
    max_rollouts: int = 140
    bins: int = 15  # histogram bins (wider = less spiky)
    pairs: Literal["all", "think", "funcname", "concept"] = "all"
    out_dir: Path = Path("out/diag")
def _auroc(score, label) -> float:
    """AUROC as the pairwise win rate P(score_pos > score_neg); a tie counts as 0.5.

    Exhaustive O(n_pos * n_neg) comparison -- fine at this script's n (~140 rollouts).
    Returns NaN when either class is empty.
    """
    pos, neg = [], []
    for s, y in zip(score, label):
        (pos if y else neg).append(s)
    if not (pos and neg):
        return float("nan")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))
def _prec_at_k(score, label, k) -> float:
    """Precision@k: fraction of positive labels among the k highest-scoring items.

    Ties keep their original order (Python's sort is stable, also with reverse=True).
    If fewer than k items exist, only the available ones are counted (denominator
    min(k, n)). This lets a short all-positive list still score 1.0 instead of being
    capped at n/k. Returns NaN when nothing can be ranked (n == 0 or k <= 0), matching
    `_auroc`'s NaN for undefined cases instead of raising ZeroDivisionError.
    """
    n_top = min(k, len(score))
    if n_top <= 0:
        return float("nan")
    top = sorted(range(len(score)), key=lambda i: score[i], reverse=True)[:n_top]
    return sum(label[i] for i in top) / n_top
def _concat(d: dict, names: list) -> torch.Tensor:
    """Flatten each d[name] (in the order given by `names`) to fp32 on CPU and join them into one 1-D vector."""
    parts = []
    for key in names:
        flat = d[key].reshape(-1)
        parts.append(flat.to(dtype=torch.float32, device="cpu"))
    return torch.cat(parts)
def main(cfg: Cfg) -> int:
    """Pin how well the pair-derived hack direction predicts live hack rollouts.

    Pipeline (one checkpoint):
      1. Resolve and load the selected pairset. This happens before the slow model load,
         so an unknown `cfg.pairs` or a missing JSON fails fast. Only the selected file
         is read.
      2. Load `{ckpt}.safetensors` into every wrapper's `delta_S` and
         `{ckpt}_hack.safetensors` into `delta_S_hack`.
      3. Build the GRAD direction: the per-module unit mean of the (hack - clean) pair
         gradients returned by `extract_v_hack`.
      4. Build the ACT direction: the per-module unit mean of the (hack - clean)
         Vh-projected activations over the train pairs.
      5. Score the live rollouts. Take each rollout in [step_lo, step_hi] with non-empty
         text, up to `max_rollouts`. Compute its NLL gradient on `delta_S` and its
         captured activations, and score both against the two directions.
      6. Report AUROC / precision@k of every score vs the `exploited` label, plus an
         oracle ideal-direction ceiling. Write parquet/csv/png outputs to `cfg.out_dir`.

    Returns:
        0 on success; 1 if no live rollout had a finite loss (nothing is written).
    Raises:
        ValueError: if `cfg.pairs` is not a known pairset name.
    """
    # ── pairset: resolve + load before the slow model load ──
    pairset_files = {
        "all": "pairs_authored.json",
        "think": "pairs_intent_think.json",
        "funcname": "pairs_intent_funcname.json",
        "concept": "pairs_intent_concept.json",
    }
    if cfg.pairs not in pairset_files:
        raise ValueError(f"unknown pairs={cfg.pairs!r}; expected one of {sorted(pairset_files)}")
    pairset = load_pairs_json(_PS / pairset_files[cfg.pairs])
    logger.info(f"pairs={cfg.pairs} -> {len(pairset)} pairs")
    # extract_v_hack is told to hold out n_heldout pairs; the ACT direction is built from
    # the rest. NOTE(review): this slice assumes the held-out pairs are the LAST n_heldout
    # (as the previous hard-coded [:-2] implied) -- confirm against extract_v_hack.
    n_heldout = 2
    train_pairs = pairset[: max(len(pairset) - n_heldout, 0)]

    # ── checkpoint + model ──
    device = torch.device("cuda")
    kept_path = cfg.run_dir / f"{cfg.ckpt}.safetensors"
    hack_path = cfg.run_dir / f"{cfg.ckpt}_hack.safetensors"
    # safetensors header: 8-byte little-endian length, then JSON; "__metadata__" is optional.
    with open(kept_path, "rb") as f:
        meta = json.loads(f.read(struct.unpack("<Q", f.read(8))[0])).get("__metadata__", {})
    model_name = meta.get("model", "Qwen/Qwen3-4B")
    logger.info(f"ckpt {kept_path.name} step={meta.get('step')} hack_rate={meta.get('hack_rate')} model={model_name}")
    tok = AutoTokenizer.from_pretrained(model_name)
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        model_name, dtype=torch.bfloat16, attn_implementation="flash_attention_2").to(device)
    model.config.use_cache = False
    wrappers = wrap_model_with_antipasto(model, model_name, CACHE_ROOT, device, grad_probe=False)
    names = sorted(wrappers)
    kept, hack = load_file(str(kept_path)), load_file(str(hack_path))
    for nm in names:
        wrappers[nm]["delta_S"].data.copy_(kept[nm].to(device))
        wrappers[nm]["delta_S_hack"].data.copy_(hack[nm].to(device))
    logger.info(f"loaded adapter into {len(names)} modules")

    def unit(x: torch.Tensor) -> torch.Tensor:
        """x / |x| with the norm clamped at 1e-12, so a zero vector stays zero."""
        return x / x.norm().clamp_min(1e-12)

    # ── GRAD direction + per-module singular value (weights for the module filters) ──
    model.eval()
    _, v_sv, raw_grads, _ = extract_v_hack(model, tok, wrappers, pairset,
                                           top_k=1, tau_axis=0.0, n_heldout=n_heldout, device=device)
    v_grad = {nm: unit((raw_grads[f"hack/{nm}"] - raw_grads[f"clean/{nm}"]).mean(0)) for nm in names}  # cpu unit
    sv0 = torch.tensor([v_sv[nm][0].item() for nm in names])  # [n_mod] grad contrastive |D_m| (=S0)

    def make_masks(w):
        """Module filters over a per-space weight w (=|D_m|): keep modules with w >= quantile.

        all / keep75 / top25 / top15 / top05 = no cut / q.25 / q.75 / q.85 / q.95.
        """
        return {"all": torch.ones(len(names), dtype=bool),
                "keep75": w >= w.quantile(0.25),
                "top25": w >= w.quantile(0.75),
                "top15": w >= w.quantile(0.85),
                "top05": w >= w.quantile(0.95)}

    # ── activation capture hooks (registered after the grad extract so it is unaffected) ──
    As_cap: dict[str, torch.Tensor] = {}
    state = {"plen": 0}  # prompt length in tokens, set before each forward

    def make_hook(nm):
        Vh = wrappers[nm]["layer"]._antipasto_Vh

        def hook(_l, inp, _o):
            # Project the layer input onto Vh, then average positions plen-1 .. end
            # (last prompt token + completion tokens) of batch item 0.
            a = F.linear(inp[0], Vh)
            As_cap[nm] = a[0, state["plen"] - 1:, :].mean(0).detach().float().cpu()
        return hook

    handles = [wrappers[nm]["layer"].register_forward_hook(make_hook(nm)) for nm in names]

    def grab_As(prompt, completion):
        """One no-grad forward on prompt+completion; return the captured As per module."""
        # NOTE(review): assumes tokenizing the prompt alone yields the same prefix as
        # tokenizing prompt+completion (no merge across the boundary).
        state["plen"] = tok(prompt, return_tensors="pt").input_ids.shape[1]
        ids = tok(prompt + completion, return_tensors="pt").input_ids.to(device)
        with torch.no_grad():
            model(ids)
        return {nm: As_cap[nm].clone() for nm in names}

    # ── ACT direction from the same train pairs ──
    As_h = {nm: [] for nm in names}
    As_c = {nm: [] for nm in names}
    for p in train_pairs:
        ah, ac = grab_As(p.prompt, p.hack), grab_As(p.prompt, p.clean)
        for nm in names:
            As_h[nm].append(ah[nm])
            As_c[nm].append(ac[nm])
    As_D = {nm: (torch.stack(As_h[nm]) - torch.stack(As_c[nm])).mean(0) for nm in names}  # act contrastive
    As_dir = {nm: unit(As_D[nm]) for nm in names}  # cpu unit
    act_w = torch.tensor([As_D[nm].norm().item() for nm in names])  # act |D_m|

    def per_module_dot_norm(g: dict, vdir: dict):
        """Return (dot[n_mod], gnorm[n_mod]) as cpu float tensors; vdir is unit per module."""
        dot = torch.tensor([(g[nm].flatten().float().cpu() @ vdir[nm].flatten().float()).item() for nm in names])
        gn = torch.tensor([g[nm].flatten().float().cpu().norm().item() for nm in names])
        return dot, gn

    # ── live rollouts: grad (backward) + act (same forward) per-module dot/norm ──
    recs = [json.loads(l) for l in (cfg.run_dir / "rollouts.jsonl").read_text().splitlines()]
    batch = [r for r in recs if cfg.step_lo <= r["step"] <= cfg.step_hi and r["text"].strip()][:cfg.max_rollouts]
    logger.info(f"live batch: {len(batch)} rollouts (steps {cfg.step_lo}-{cfg.step_hi})")
    G_dot, G_gn, A_dot, A_an, labels = [], [], [], [], []
    Gvec, Avec = [], []  # concat per-rollout g / As (fp16) for the ideal-direction ceiling
    for i, r in enumerate(batch):
        state["plen"] = tok(r["prompt"], return_tensors="pt").input_ids.shape[1]
        model.zero_grad(set_to_none=True)
        loss = completion_nll(model, tok, r["prompt"], r["text"], device)  # fires act hooks
        if not torch.isfinite(loss):
            continue
        As_sample = {nm: As_cap[nm].clone() for nm in names}
        loss.backward()
        gb = {nm: wrappers[nm]["delta_S"].grad for nm in names}
        gd, gn = per_module_dot_norm(gb, v_grad)
        ad, an = per_module_dot_norm(As_sample, As_dir)
        G_dot.append(gd)
        G_gn.append(gn)
        A_dot.append(ad)
        A_an.append(an)
        Gvec.append(_concat(gb, names).half())
        Avec.append(_concat(As_sample, names).half())
        labels.append(bool(r["exploited"]))
        if (i + 1) % 25 == 0:
            logger.info(f"  rollout {i+1}/{len(batch)}")
    for h in handles:
        h.remove()
    model.zero_grad(set_to_none=True)
    if not labels:
        # torch.stack([]) and the base-rate division below would crash; bail out cleanly.
        logger.error(f"no finite live rollouts in steps {cfg.step_lo}-{cfg.step_hi}; nothing written")
        return 1
    G_dot, G_gn = torch.stack(G_dot), torch.stack(G_gn)  # [n_roll, n_mod]
    A_dot, A_an = torch.stack(A_dot), torch.stack(A_an)
    n_pos = sum(labels)
    logger.info(f"live: {len(labels)} rollouts, {n_pos} exploited ({n_pos/len(labels):.2f} base rate)")

    # ── global scores per rollout: cosine / projection / magnitude / vote, per filter ──
    def scores(dot, gn, mask, w):
        """Per-rollout scores over the modules kept by `mask`.

        With v unit per module, |concat v| = sqrt(n_kept), so
        cosine = sum_m<g,v> / (|g| sqrt(n)) and projection = sum_m<g,v> / sqrt(n).
        vote = w-weighted mean of per-module cosines: a soft version of the hard
        module filter that down-weights weak modules instead of cutting them.
        """
        dm, gm = dot[:, mask], gn[:, mask]
        d = dm.sum(1)  # sum_m <g,v>
        gmag = gm.pow(2).sum(1).sqrt()  # |concat g|
        nmod = int(mask.sum())
        cos = d / (gmag.clamp_min(1e-12) * (nmod ** 0.5))
        proj = d / (nmod ** 0.5)
        cos_m = dm / gm.clamp_min(1e-12)  # per-module cosine [n_roll, n_kept]
        wm = w[mask]
        vote = (cos_m * wm).sum(1) / wm.sum().clamp_min(1e-12)
        return {"cosine": cos, "projection": proj, "magnitude": gmag, "vote": vote}

    # ── ceiling: ideal direction (mu_hack - mu_clean) fit on live rollouts, 2-fold CV ──
    # The oracle-labelled best linear direction -- a diagnostic upper bound only (never trains).
    # If ceiling >> our pair-direction AUROC, the pairs are the problem (on-distribution pairs
    # would help); if the ceiling is also low, the distributions fundamentally overlap.
    def ideal_auroc(vecs):
        """Return (2-fold-CV AUROC, in-sample AUROC) of the mean-difference direction."""
        X = torch.stack(vecs).float()  # [n, D]
        y = torch.tensor(labels)
        sc = torch.zeros(len(y))
        for fold in (0, 1):  # 2-fold CV by index parity: fit on one half, score the other
            te = (torch.arange(len(y)) % 2 == fold)
            tr = ~te
            vd = unit(X[tr][y[tr]].mean(0) - X[tr][~y[tr]].mean(0))
            sc[te] = (X[te] @ vd) / X[te].norm(dim=1).clamp_min(1e-12)
        insample = unit(X[y].mean(0) - X[~y].mean(0))
        sc_in = (X @ insample) / X.norm(dim=1).clamp_min(1e-12)
        return _auroc(sc.tolist(), labels), _auroc(sc_in.tolist(), labels)

    g_cv, g_in = ideal_auroc(Gvec)
    a_cv, a_in = ideal_auroc(Avec)
    logger.info(f"IDEAL-direction ceiling (oracle, diagnostic): "
                f"grad AUROC cv={g_cv:.3f} in-sample={g_in:.3f} | act cv={a_cv:.3f} in-sample={a_in:.3f}")
    weights = {"grad": sv0, "act": act_w}
    score_cols = {}
    for space, (dot, gn) in {"grad": (G_dot, G_gn), "act": (A_dot, A_an)}.items():
        w = weights[space]
        masks = make_masks(w)
        logger.info(f"{space} filter sizes: " + ", ".join(f"{k}={int(m.sum())}" for k, m in masks.items()))
        for filt, mask in masks.items():
            for fam, v in scores(dot, gn, mask, w).items():
                score_cols[f"{space}.{fam}.{filt}"] = v.tolist()

    # ── separability table ──
    sep_rows = []
    for name, s in score_cols.items():
        space, fam, filt = name.split(".")
        sep_rows.append(dict(space=space, score=fam, filter=filt,
                             AUROC=round(_auroc(s, labels), 3),
                             p_at_10=round(_prec_at_k(s, labels, 10), 3),
                             p_at_20=round(_prec_at_k(s, labels, 20), 3)))
    sep = pl.DataFrame(sep_rows).sort("AUROC", descending=True)

    # ── persist: histogram long-form (global cosine) + per-rollout scores + separability ──
    cfg.out_dir.mkdir(parents=True, exist_ok=True)

    # histogram populations need the PAIR cosines too (built from raw_grads / As_h, As_c)
    def pair_global_cos(per_pair_dot_src, vdir, n):
        """Global cosine (over all modules) of each of n pair vectors to vdir."""
        out = []
        for i in range(n):
            g = {nm: per_pair_dot_src(nm, i) for nm in names}
            dot, gn = per_module_dot_norm(g, vdir)
            out.append((dot.sum() / (gn.pow(2).sum().sqrt().clamp_min(1e-12) * (len(names) ** 0.5))).item())
        return out

    npg = raw_grads[f"hack/{names[0]}"].shape[0]
    hist = []
    for pop, src in [("pair_hack", lambda nm, i: raw_grads[f"hack/{nm}"][i]),
                     ("pair_clean", lambda nm, i: raw_grads[f"clean/{nm}"][i])]:
        for c in pair_global_cos(src, v_grad, npg):
            hist.append(dict(space="grad", pop=pop, cos=c))
    for pop, src in [("pair_hack", lambda nm, i: As_h[nm][i]),
                     ("pair_clean", lambda nm, i: As_c[nm][i])]:
        for c in pair_global_cos(src, As_dir, len(train_pairs)):
            hist.append(dict(space="act", pop=pop, cos=c))
    for sp, col in [("grad", "grad.cosine.all"), ("act", "act.cosine.all")]:
        for c, y in zip(score_cols[col], labels):
            hist.append(dict(space=sp, pop="live_hack" if y else "live_clean", cos=c))
    pl.DataFrame(hist).write_parquet(cfg.out_dir / "cosine_dist.parquet")
    pl.DataFrame({**score_cols, "exploited": labels}).write_parquet(cfg.out_dir / "live_scores.parquet")
    sep.write_csv(cfg.out_dir / "separability.csv")
    logger.info("\n=== separability (AUROC of score -> exploited; >0.5 = predictive) ===\n"
                + tabulate(sep.to_pandas(), headers="keys", tablefmt="github", showindex=False))

    # ── histograms (cosine, both spaces) ──
    hdf = pl.DataFrame(hist)
    colors = {"pair_clean": "tab:blue", "pair_hack": "tab:red",
              "live_clean": "tab:cyan", "live_hack": "tab:orange"}
    for space in ("grad", "act"):
        plt.figure(figsize=(9, 5))
        for pop, c in colors.items():
            v = hdf.filter((pl.col("space") == space) & (pl.col("pop") == pop))["cos"].to_numpy()
            if len(v):
                plt.hist(v, bins=cfg.bins, density=True, histtype="step", lw=2, color=c,
                         label=f"{pop} (n={len(v)}, p50={float(pl.Series(v).median()):+.2f})")
        au = sep.filter((pl.col("space") == space) & (pl.col("score") == "cosine") & (pl.col("filter") == "all"))["AUROC"][0]
        plt.xlabel(f"global cosine to hack direction ({space} space)")
        plt.ylabel("density")
        plt.title(f"{space}-space: {cfg.ckpt} (step {meta.get('step')}), live AUROC(cos)={au}")
        plt.legend(fontsize=8)
        plt.tight_layout()
        plt.savefig(cfg.out_dir / f"cosine_{space}.png", dpi=130)
        plt.close()
    logger.info(f"wrote {cfg.out_dir}/cosine_{{grad,act}}.png, cosine_dist.parquet, "
                f"live_scores.parquet, separability.csv")
    return 0
if __name__ == "__main__":
raise SystemExit(main(tyro.cli(Cfg)))