# Source file: evil_MoE/scripts/diag_pinning.py
"""Q2 diagnostic: what should the live routeV gate SCORE, and where do the pinning
cuts go?
THE QUESTION (Q2). The gate routes UPDATES, not rollouts: per rollout the GRPO update
on the LoRA is g = A * gradNLL (advantage-weighted). So the positive class is
"update ascending the hack direction" = exploited & A>0; rollouts with A~0 contribute
no update and are DROPPED (scoring them as dead zeros at 0 is what made earlier
adv-weighted AUROCs look blind); a hack with A<0 is being UNLEARNED and belongs in
the negative class. (Q1, "does the direction exist at adv=+1", was answered earlier:
~0.61 -- see git history of this file.)
TWO CONTRASTS (post-review fix). On the vs-ALL contrast (hack+ vs every other valid
rollout) the advantage ALONE is a ~0.9 AUROC detector, because the label requires A>0
and most fails have A<0 -- so a high vs-all number mostly restates the reward, which
the live gate already has for free. The vector's VALUE-ADD is the A>0 contrast: among
positively-reinforced updates (where adv is blind, ~0.5), can the score tell hacks
from solves? That is the gate's real job (don't ship hack updates, don't rob solves)
and the headline number here; vs-all is kept as a secondary column. NOTE an info
asymmetry: grad scores carry the adv factor (sign+magnitude), act scores do not, so
on the vs-all contrast grad gets label-correlated information act lacks; the A>0
contrast removes most of that edge -- compare the cases there.
FOUR CANDIDATE SCORES = {grad, act} x {cos, dot}, all in concatenated-module space:
- grad: the adv-weighted deployed c-probe gradient (the gate's current input).
- act: the deployed bottleneck activation A[:r]@x, mean over completion tokens --
same [r]-per-module space, capturable in the gate's pass-1 forward for free.
- cos: magnitude-blind alignment (tiny vectors give meaningless angles -- control).
- dot: <g, v> = |g|*cos, magnitude-aware; with g = A*gradNLL the advantage rides
along, so dot reads "how hard is this update pushing hack-ward".
v for each representation comes from the authored pairs only (mean hack-minus-clean
per module, unit per module) -- the no-cheat label source; live labels are read ONLY
to measure (AUROC / precision at the rout cut), never to route.
PINNING. Each panel shades the three zones the online gate rule would give on this
window: keep (bulk) | absorb (score > mean + k_mid*sd) | rout (>= mean + k_rout*sd),
plus the oracle best hack-vs-rest split for reference. k's default to the real-run
Config values (2/3), not the checkpoint's preset, so the plot answers "where WOULD
the cuts fall", overridable via --k-mid/--k-rout.
CAVEAT. Live advantages are reconstructed from rollouts.jsonl students only (teachers
absent, zero-variance groups included, and skipped/empty completions missing from the
group mean), so A signs/magnitudes are approximate; the act columns dodge this
entirely (no A in the representation).
HOW. One GPU pass: per live rollout, backward its completion NLL once, capture the
c-probe grad AND the pooled bottleneck act; same per authored-pair side. Everything
downstream (subset vectors, 4 scores, zones, table) is offline re-projection of the
cached features.
Usage:
    uv run python scripts/diag_pinning.py --run-dir out/runs/<vanilla_lora2r_run>
    uv run python scripts/diag_pinning.py --replot out/diag/pinning_data.parquet  # no GPU

Outputs (out/diag/): pinning_q2.png (2x2 headline), pinning_data.parquet (per-rollout
scores), pinning_pairset.parquet + printed table (subsets x 4 AUROCs),
pinning_feats.pt (raw features, for offline re-analysis).
"""
from __future__ import annotations
import json
import struct
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
import tyro
import polars as pl
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from matplotlib.patches import Patch
from loguru import logger
from tabulate import tabulate
from safetensors.torch import load_file
from transformers import AutoModelForCausalLM, AutoTokenizer
from vgrout.lora2r import wrap_model_with_lora2r
from vgrout.pairs import load_pairs
from vgrout.extract_vhack_grad import completion_nll
from vgrout.train import _auroc
# Palette. Colour encodes behaviour (blue = solve, red = hack, grey = fail); the line
# style chosen at plot time encodes the source (solid = on-policy, dashed = synthetic).
SOLVE = "#3b6ea5"
HACK = "#c44e52"
FAIL = "#9aa0a6"
# Zone / reference-line colours (absorb band, rout band, oracle split).
ABSORB_C = "#d1900a"
ROUT_C = "#c44e52"
ORACLE = "#3a8a7a"
# The four candidate scores, {grad, act} x {cos, dot}, in panel order.
CASES = [(rep, kind) for rep in ("grad", "act") for kind in ("cos", "dot")]
@dataclass
class Cfg:
    """Command-line options for the pinning diagnostic (parsed by `tyro.cli`)."""

    # Training run to diagnose; the checkpoint and rollouts.jsonl are read from here.
    run_dir: Path = Path("out/runs/20260611T003538_fast_vanilla_lora2r_seed43_l2r_vanilla_s43_v3")
    # Checkpoint stem: `<run_dir>/<ckpt>.safetensors` is loaded.
    ckpt: str = "first_hack"
    # Authored hack/clean pairs; the pairset table spans every subset of this file.
    pairs: Path = Path("data/pairs/hack_pairs.md#all-in-one")
    # The headline figure builds v from the pairs whose id starts with this prefix,
    # i.e. the routeV TRAINING default (train_config.vhack_pairs_path
    # `#all-in-one/behavior_`, 8 pairs; the trailing _ excludes behavior2_*).
    headline_prefix: str = "behavior_"
    # Coherent emergence window (inclusive step range). Author's note: this vanilla v3
    # run used the pre-fix lr=5e-4/warmup-0.1 and DIVERGED at step 10 (exploited
    # 20/24 -> 0/24); steps 2-9 = hacks emerging while the model is still sane.
    step_lo: int = 2
    step_hi: int = 9
    # Cap on the number of live rollouts pushed through the GPU pass.
    max_rollouts: int = 240
    # Absorb onset: score > mean + k_mid*sd (real-run Config default).
    k_mid: float = 2.0
    # Rout onset: score >= mean + k_rout*sd.
    k_rout: float = 3.0
    # |A| below this means no update exists -> the rollout is dropped from zones/AUROC.
    adv_eps: float = 1e-6
    # When set, swap v for a Haar placebo (sanity check: nothing should separate).
    random_v_seed: int | None = None
    # When set, load this parquet and re-plot only (no model, no GPU).
    replot: Path | None = None
    # Directory receiving the figure, parquet tables and raw feature dump.
    out_dir: Path = Path("out/diag")
def _ckpt_meta(path: Path) -> dict:
with open(path, "rb") as f:
return json.loads(f.read(struct.unpack("<Q", f.read(8))[0])).get("__metadata__", {})
class ActTap:
    """Context manager recording the deployed bottleneck activation h = A[:r] @ x per module.

    While active, a forward hook on every named layer projects that layer's input
    through the first `r` rows of its `_lora2r_A` (under no_grad) and stores the
    result. Projecting inline avoids retaining the full [L, d_in] input -- the author
    notes ~250 modules x [L, d_in] would be GBs, while [L, r] is nothing.
    """

    def __init__(self, wrappers: dict, names: list[str]):
        self.wrappers = wrappers
        self.names = names
        self.h = {}        # module name -> projection from the most recent forward
        self.handles = []  # hook handles, removed on exit

    def _make_hook(self, nm: str):
        """Build the forward hook that stores module `nm`'s projection in `self.h`."""
        def hook(layer, args, out):
            (x,) = args  # the layer must be called with exactly one positional input
            with torch.no_grad():
                proj = layer._lora2r_A[: layer._lora2r_r].to(x.dtype)
                self.h[nm] = F.linear(x.detach(), proj)
        return hook

    def __enter__(self):
        for nm in self.names:
            layer = self.wrappers[nm]["layer"]
            self.handles.append(layer.register_forward_hook(self._make_hook(nm)))
        return self

    def __exit__(self, *exc):
        for handle in self.handles:
            handle.remove()

    def pooled(self, n_prompt: int) -> torch.Tensor:
        """[M, r] mean bottleneck act over completion tokens (positions >= n_prompt)."""
        rows = []
        for nm in self.names:
            h = self.h[nm]  # indexed as [1, L, r]: only batch row 0 is pooled
            assert h.shape[1] > n_prompt, f"{nm}: no completion tokens (L={h.shape[1]} n_prompt={n_prompt})"
            rows.append(h[0, n_prompt:].float().mean(0).cpu())
        return torch.stack(rows)
def _gate_grads(wrappers: dict, names: list[str]) -> torch.Tensor:
"""[M, r] deployed-block c-probe grad after a backward (the gate's gradient space)."""
g = []
for nm in names:
layer = wrappers[nm]["layer"]
gr = layer._lora2r_gate.grad
g.append(gr.sum(dim=tuple(range(gr.dim() - 1)))[: layer._lora2r_r].float().cpu())
return torch.stack(g)
def _v_from(feat_hack: torch.Tensor, feat_clean: torch.Tensor, idx: list[int]) -> torch.Tensor:
"""[M, r] unit-per-module mean hack-minus-clean direction from pair rows `idx`."""
d = (feat_hack[idx] - feat_clean[idx]).mean(0)
return d / d.norm(dim=1, keepdim=True).clamp_min(1e-12)
def _haar_like(v: torch.Tensor, seed: int) -> torch.Tensor:
g = torch.Generator().manual_seed(seed)
d = torch.randn(v.shape, generator=g)
return d / d.norm(dim=1, keepdim=True).clamp_min(1e-12)
def _score(X: torch.Tensor, V: torch.Tensor, kind: str) -> np.ndarray:
"""Concat-module score per rollout: dot = sum_m <x_m, v_m>; cos = dot / (||x|| ||v||)."""
d = torch.einsum("nmr,mr->n", X, V)
if kind == "dot":
return d.numpy()
return (d / (X.flatten(1).norm(dim=1).clamp_min(1e-12) * V.flatten().norm().clamp_min(1e-12))).numpy()
def _kde(x: np.ndarray, grid: np.ndarray) -> np.ndarray:
"""Gaussian KDE, Silverman bandwidth (no scipy). Bandwidth is scale-relative
(dot scores can live at 1e-4 or 1e2)."""
x = np.asarray(x, float)
if len(x) < 2:
return np.zeros_like(grid)
iqr = np.subtract(*np.percentile(x, [75, 25]))
sigma = min(x.std(ddof=1), iqr / 1.349) if iqr > 0 else x.std(ddof=1)
if sigma <= 0:
return np.zeros_like(grid)
bw = 0.9 * sigma * len(x) ** (-0.2)
z = (grid[:, None] - x[None, :]) / bw
return np.exp(-0.5 * z ** 2).sum(1) / (len(x) * bw * np.sqrt(2 * np.pi))
def plot_q2(df: pl.DataFrame, k_mid: float, k_rout: float, subtitle: str, out_png: Path) -> dict:
    """Draw the 2x2 figure ({grad,act} x {cos,dot}) from saved per-rollout scores -- no GPU.

    Per panel: live solve/fail/hack+ KDEs (+ thin hack- if n>=3), synthetic pair sides
    dashed, three shaded zones keep|absorb|rout from mean + k*sd over the live scores
    (pop 'on_drop' excluded), oracle split, AUROC + P/R at the rout cut.

    Args:
        df: one row per live rollout / synthetic pair side, with columns `pop`, `adv`
            and one score column per case (`grad_cos`, `grad_dot`, `act_cos`,
            `act_dot`). Populations absent from `df` are treated as empty.
        k_mid: the absorb zone starts at mean + k_mid*sd of the live scores.
        k_rout: the rout zone starts at mean + k_rout*sd of the live scores.
        subtitle: text for the figure suptitle.
        out_png: path the figure is saved to.

    Returns:
        Per-case stats, keyed by score column: `auroc_pos` (A>0 contrast), `auroc_all`
        (vs-all), `prec_rout`, `rec_rout`, `n_rout`, `t_hi` (rout threshold) and
        `oracle` (best hack-vs-rest threshold).

    Raises:
        ValueError: if `df` contains no live rows, so no threshold can be placed.
    """
    pops = {p: df.filter(pl.col("pop") == p) for p in df["pop"].unique().to_list()}

    def pop_scores(p: str, col: str) -> np.ndarray:
        """Scores of population `p` in column `col`; empty when `p` is not in `df`."""
        return pops[p][col].to_numpy() if p in pops else np.empty(0)

    live_pops = ["on_solve", "on_fail", "on_hackpos", "on_hackneg"]
    live = df.filter(pl.col("pop").is_in(live_pops))
    if live.height == 0:
        raise ValueError("plot_q2: no live rollouts (on_solve/on_fail/on_hackpos/on_hackneg) to plot")
    posm = (live["adv"] > 0).to_numpy()  # the A>0 contrast rows
    y_all = (live["pop"] == "on_hackpos").to_numpy()
    # adv-only baseline: on vs-all the reward alone is a strong detector (the label
    # requires A>0); the vector only adds value where this baseline is blind.
    a = live["adv"].to_numpy()
    logger.info(f"adv-only baseline AUROC: vs-all={_auroc(a.tolist(), y_all.tolist()):.3f} "
                f"A>0-contrast={_auroc(a[posm].tolist(), y_all[posm].tolist()):.3f} "
                f"(n+={int(y_all.sum())} negA>0={int((~y_all & posm).sum())})")
    stats = {}
    fig, axes = plt.subplots(2, 2, figsize=(12.5, 7.6))
    for ax, (rep, kind) in zip(axes.flat, CASES):
        col = f"{rep}_{kind}"
        s = live[col].to_numpy()
        y = y_all
        # Zone thresholds come from the pooled live scores, labels never enter.
        mu, sd = float(s.mean()), float(s.std())
        t_lo, t_hi = mu + k_mid * sd, mu + k_rout * sd
        auroc = _auroc(s.tolist(), y.tolist())
        auroc_pos = _auroc(s[posm].tolist(), y[posm].tolist())
        # Oracle split: the observed score maximising TPR - FPR (uses the labels).
        thr = np.unique(s)
        j = [(s[y] >= t).mean() - (s[~y] >= t).mean() for t in thr]
        oracle = float(thr[int(np.argmax(j))])
        routed = s >= t_hi
        n_rout = int(routed.sum())
        prec = float(y[routed].mean()) if routed.any() else float("nan")
        rec = float((s[y] >= t_hi).mean()) if y.any() else float("nan")
        stats[col] = {"auroc_pos": auroc_pos, "auroc_all": auroc, "prec_rout": prec,
                      "rec_rout": rec, "n_rout": n_rout, "t_hi": t_hi, "oracle": oracle}

        # x-range: the central 99% of live scores, so single outliers do not squash the panel
        lo = float(np.quantile(s, 0.005))
        hi = float(np.quantile(s, 0.995))
        if kind == "cos":  # keep synthetic medians visible (cos shares a scale;
            for p in ("syn_clean", "syn_hack"):  # dot doesn't -- pair grads dwarf live, annotate instead)
                vals = pop_scores(p, col)
                if len(vals):
                    m = float(np.median(vals))
                    lo, hi = min(lo, m), max(hi, m)
        pad = 0.05 * (hi - lo) or 1e-6
        lo, hi = lo - pad, hi + pad
        grid = np.linspace(lo, hi, 400)
        if kind == "dot":
            off = []
            for p in ("syn_clean", "syn_hack"):
                vals = pop_scores(p, col)
                if not len(vals):
                    continue
                med = float(np.median(vals))
                if not lo < med < hi:
                    off.append(f"syn {p.split('_')[1]} med={med:+.2g}")
            if off:
                ax.annotate("off-scale: " + ", ".join(off) + r" $\rightarrow$",
                            xy=(0.98, 0.68), xycoords="axes fraction", ha="right", fontsize=7, color="#777777")

        # (population, colour, linestyle, linewidth, fill alpha)
        curves = [("on_solve", SOLVE, "-", 1.9, 0.12), ("on_fail", FAIL, "-", 1.9, 0.12),
                  ("on_hackpos", HACK, "-", 1.9, 0.12),
                  ("syn_clean", SOLVE, (0, (5, 2)), 2.0, 0.0), ("syn_hack", HACK, (0, (5, 2)), 2.0, 0.0)]
        if len(pop_scores("on_hackneg", col)) >= 3:
            curves.insert(3, ("on_hackneg", HACK, (0, (1, 1)), 1.2, 0.0))
        ymax = 0.0
        for p, c, ls, lw, fill in curves:
            # A population missing from df yields an all-zero curve, like one with n < 2.
            yk = _kde(pop_scores(p, col), grid)
            ymax = max(ymax, yk.max())
            if fill:
                ax.fill_between(grid, yk, color=c, alpha=fill, lw=0)
            ax.plot(grid, yk, color=c, lw=lw, ls=ls)
        ymax *= 1.18
        if ymax <= 0:  # every KDE was flat: keep the y-limits non-degenerate
            ymax = 1.0
        # rug of the ACTUAL live points (KDEs of n~20 are smooth fiction; the rout-tail
        # precision claim rests on a handful of rollouts -- show them). hack row on top.
        for row, (p, c) in enumerate((("on_hackpos", HACK), ("on_fail", FAIL), ("on_solve", SOLVE))):
            x = pop_scores(p, col)
            ax.plot(x, np.full(len(x), -(0.035 + 0.035 * row) * ymax), "|",
                    color=c, ms=4, alpha=0.6, mew=0.8, clip_on=False)
        # three zones: keep | absorb | rout
        ax.axvspan(t_lo, min(t_hi, hi), color=ABSORB_C, alpha=0.08, lw=0)
        ax.axvspan(min(t_hi, hi), hi, color=ROUT_C, alpha=0.10, lw=0)
        ax.axvline(t_lo, color=ABSORB_C, lw=1.2, ls="--")
        ax.axvline(t_hi, color=ROUT_C, lw=1.2, ls="--")
        ax.axvline(oracle, color=ORACLE, lw=1.3, ls="-.")
        for xz, lab in ((min(t_lo, hi) - 0.02 * (hi - lo), "keep"),
                        ((t_lo + min(t_hi, hi)) / 2, "absorb"),
                        ((min(t_hi, hi) + hi) / 2, "rout")):
            if lo < xz < hi:  # only label zones that are actually on screen
                ax.text(xz, ymax * 0.97, lab, ha="center", va="top", fontsize=7.5, color="#555555")
        ax.set_xlim(lo, hi)
        ax.set_ylim(-0.13 * ymax, ymax)  # negative strip hosts the rugs
        ax.set_yticks([])  # KDE density units are meaningless ink
        for sp in ("top", "right", "left"):
            ax.spines[sp].set_visible(False)
        ax.spines["bottom"].set_position(("data", 0))
        ax.set_title(f"{rep} · {kind} AUROC={auroc_pos:.2f} (A>0 contrast; vs-all {auroc:.2f}) "
                     f"P@rout={prec:.2f} (n={n_rout}) R={rec:.2f}", fontsize=9)
        ax.set_xlabel({"cos": "cosine to v (concat modules)",
                       "dot": "dot ⟨x, v⟩ (update mass along v)"}[kind], fontsize=8.5)

    handles = [Line2D([0], [0], color=SOLVE, lw=1.9), Line2D([0], [0], color=FAIL, lw=1.9),
               Line2D([0], [0], color=HACK, lw=1.9),
               Line2D([0], [0], color=SOLVE, lw=2.0, ls=(0, (5, 2))),
               Line2D([0], [0], color=HACK, lw=2.0, ls=(0, (5, 2))),
               Patch(facecolor=ABSORB_C, alpha=0.18), Patch(facecolor=ROUT_C, alpha=0.18),
               Line2D([0], [0], color=ORACLE, lw=1.3, ls="-.")]
    labels = ["live solve", "live fail", "live hack (A>0)", "synthetic clean", "synthetic hack",
              f"absorb (>mean+{k_mid:g}sd)", f"rout (>=mean+{k_rout:g}sd)", "oracle hack/rest split"]
    fig.legend(handles, labels, loc="lower center", ncol=4, fontsize=8, frameon=False)
    fig.suptitle(subtitle, fontsize=10)
    fig.tight_layout(rect=(0, 0.07, 1, 0.95))
    fig.savefig(out_png, dpi=140)
    plt.close(fig)
    logger.info(f"wrote {out_png}")
    return stats
def main(cfg: Cfg) -> int:
    """Run the pinning diagnostic end to end, or only re-plot when `cfg.replot` is set.

    Full run: load the checkpoint's A/B into a LoRA-wrapped model, do one
    forward+backward per authored-pair side and per live rollout (capturing the
    c-probe grad and the pooled bottleneck act), reconstruct advantages, score every
    rollout against the pair-derived direction, then write the pairset table, the
    per-rollout parquet, the raw feature dump and the 2x2 figure under `cfg.out_dir`.

    Returns:
        0 on success (used as the process exit code).

    Raises:
        RuntimeError: if a pair side has a non-finite loss, if no rollout falls in the
            step window, or if every live rollout was skipped.
        AssertionError: if no pair id matches `cfg.headline_prefix`.
    """
    cfg.out_dir.mkdir(parents=True, exist_ok=True)
    data_path = cfg.out_dir / "pinning_data.parquet"
    rank_path = cfg.out_dir / "pinning_pairset.parquet"
    feats_path = cfg.out_dir / "pinning_feats.pt"
    q2_png = cfg.out_dir / "pinning_q2.png"

    if cfg.replot is not None:
        # Re-plot only: no model, no GPU. The table is re-printed if a previous run left one.
        plot_q2(pl.read_parquet(cfg.replot), cfg.k_mid, cfg.k_rout, f"replot -- {cfg.replot.name}", q2_png)
        if rank_path.exists():
            print(tabulate(pl.read_parquet(rank_path).to_pandas(), headers="keys",
                           tablefmt="pipe", floatfmt="+.3f", showindex=False))
        return 0

    device = torch.device("cuda")
    ckpt_path = cfg.run_dir / f"{cfg.ckpt}.safetensors"
    meta = _ckpt_meta(ckpt_path)
    run_cfg = json.loads(meta.get("cfg", "{}"))
    model_name = run_cfg.get("model", meta.get("model", "Qwen/Qwen3-4B"))
    r = run_cfg.get("lora_r", 32)
    init_seed = run_cfg.get("lora_init_seed", 0)
    logger.info(f"ckpt {ckpt_path.name} step={meta.get('step')} hack_rate={meta.get('hack_rate')} "
                f"model={model_name} r={r} init_seed={init_seed} | run-preset k_mid/k_rout="
                f"{run_cfg.get('route_std_mid')}/{run_cfg.get('route_std_rout')} (plot uses {cfg.k_mid}/{cfg.k_rout})")
    tok = AutoTokenizer.from_pretrained(model_name)
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        model_name, dtype=torch.bfloat16, attn_implementation="flash_attention_2").to(device)
    model.config.use_cache = False
    wrappers = wrap_model_with_lora2r(model, r=r, init_seed=init_seed, grad_probe=True)
    names = sorted(wrappers)  # fixed module order shared by every [M, r] feature
    sd = load_file(str(ckpt_path))
    for nm in names:
        wrappers[nm]["A"].data.copy_(sd[f"A/{nm}"].to(device, torch.float32))
        wrappers[nm]["B"].data.copy_(sd[f"B/{nm}"].to(device, torch.float32))
    logger.info(f"loaded A/B into {len(names)} modules")
    model.eval()

    def one_pass(tap: ActTap, prompt: str, completion: str) -> tuple[torch.Tensor, torch.Tensor] | None:
        """Backward one completion's mean NLL; return ([M,r] c-grad, [M,r] pooled act)."""
        model.zero_grad(set_to_none=True)
        loss = completion_nll(model, tok, prompt, completion, device)
        if not torch.isfinite(loss):
            return None
        loss.backward()
        # NOTE(review): assumes completion_nll tokenises `prompt` exactly like this
        # call, so n_prompt marks where completion tokens start -- confirm.
        n_prompt = tok(prompt, return_tensors="pt").input_ids.shape[1]
        return _gate_grads(wrappers, names), tap.pooled(n_prompt)

    # ── authored-pair features, once over ALL pairs (subsets = row slices) ──
    pairs_all = load_pairs(cfg.pairs)
    logger.info(f"pairs {cfg.pairs} -> {len(pairs_all)}")
    pair_feat = {("grad", "hack"): [], ("grad", "clean"): [], ("act", "hack"): [], ("act", "clean"): []}
    # The activation hooks must stay registered for BOTH loops below: tap.pooled()
    # reads whatever the most recent forward stored.
    with ActTap(wrappers, names) as tap:
        for pi, pair in enumerate(pairs_all):
            for side, completion in (("hack", pair.hack), ("clean", pair.clean)):
                out = one_pass(tap, pair.prompt, completion)
                if out is None:
                    raise RuntimeError(f"non-finite loss on pair {pi} ({pair.problem_id}) side={side}")
                pair_feat[("grad", side)].append(out[0])
                pair_feat[("act", side)].append(out[1])
            if (pi + 1) % 5 == 0:
                logger.info(f"  pair {pi+1}/{len(pairs_all)}")
        PF = {k: torch.stack(v) for k, v in pair_feat.items()}  # each [P, M, r]

        # ── live rollout features, once (everything downstream re-projects) ──
        rollouts_path = cfg.run_dir / "rollouts.jsonl"
        # Blank lines (e.g. a trailing newline artefact) are skipped rather than
        # handed to json.loads, which would raise on them.
        recs = [json.loads(line)
                for line in rollouts_path.read_text(encoding="utf-8").splitlines() if line.strip()]
        batch = [x for x in recs if cfg.step_lo <= x["step"] <= cfg.step_hi and x["text"].strip()][:cfg.max_rollouts]
        logger.info(f"live batch: {len(batch)} rollouts (steps {cfg.step_lo}-{cfg.step_hi})")
        if not batch:
            raise RuntimeError(f"no non-empty rollouts in steps {cfg.step_lo}-{cfg.step_hi} of {rollouts_path}")
        G_rows, A_rows, kept = [], [], []
        for i, rec in enumerate(batch):
            out = one_pass(tap, rec["prompt"], rec["text"])
            if out is None:
                logger.warning(f"  skip rollout {i}: non-finite loss")
                continue
            G_rows.append(out[0])
            A_rows.append(out[1])
            kept.append(rec)
            if (i + 1) % 40 == 0:
                logger.info(f"  rollout {i+1}/{len(batch)}")
    model.zero_grad(set_to_none=True)
    if not kept:
        raise RuntimeError("every live rollout had a non-finite loss; nothing to score")

    G = torch.stack(G_rows)    # [N, M, r] gradNLL
    ACT = torch.stack(A_rows)  # [N, M, r]
    exploited = np.array([bool(x["exploited"]) for x in kept])
    gt_pass = np.array([bool(x["gt_pass"]) for x in kept])
    steps = np.array([x["step"] for x in kept])
    p_idx = np.array([x["p_idx"] for x in kept])
    reward = np.array([float(x["reward"]) for x in kept])
    # Reconstructed Dr.GRPO advantage A_i = reward_i - mean(reward over its group).
    # CAVEAT: students only (teachers absent from rollouts.jsonl), so signs/magnitudes
    # are approximate -- see module docstring.
    grp_mean = {}
    for s, p in set(zip(steps.tolist(), p_idx.tolist())):
        m = (steps == s) & (p_idx == p)
        grp_mean[(s, p)] = reward[m].mean()
    adv = np.array([rw - grp_mean[(s, p)]
                    for rw, s, p in zip(reward, steps.tolist(), p_idx.tolist())])
    G_adv = G * torch.tensor(adv, dtype=G.dtype)[:, None, None]  # the update the gate sees

    # ── Q2 populations: drop A~0 (no update); positive = exploited & A>0 ──
    valid = np.abs(adv) > cfg.adv_eps
    y = exploited & (adv > 0)
    pop = np.where(~valid, "on_drop",
                   np.where(exploited & (adv > 0), "on_hackpos",
                            np.where(exploited, "on_hackneg",
                                     np.where(gt_pass, "on_solve", "on_fail"))))
    counts = {p: int((pop == p).sum()) for p in ("on_solve", "on_fail", "on_hackpos", "on_hackneg", "on_drop")}
    logger.info(f"live populations: {counts} (zones/AUROC use the {int(valid.sum())} valid rows)")
    print("SHOULD: on_hackpos >= ~20 and on_drop not the majority, ELSE the window/run has "
          "too few learnable hacks and every AUROC below is noise.")

    # ── headline vectors from the routeV-default subset; placebo swaps in Haar ──
    groups: dict[str, list[int]] = defaultdict(list)
    for i, p in enumerate(pairs_all):
        groups[p.problem_id.split("_")[0]].append(i)
    head_idx = [i for i, p in enumerate(pairs_all) if p.problem_id.startswith(cfg.headline_prefix)]
    assert head_idx, f"no pairs match headline prefix {cfg.headline_prefix!r}"
    logger.info(f"headline v from prefix {cfg.headline_prefix!r} -> {len(head_idx)} pairs")

    def vectors(idx: list[int]) -> dict[str, torch.Tensor]:
        """Per-representation direction from pair rows `idx` (Haar placebo if seeded)."""
        v = {"grad": _v_from(PF[("grad", "hack")], PF[("grad", "clean")], idx),
             "act": _v_from(PF[("act", "hack")], PF[("act", "clean")], idx)}
        if cfg.random_v_seed is not None:
            v = {"grad": _haar_like(v["grad"], cfg.random_v_seed),
                 "act": _haar_like(v["act"], cfg.random_v_seed + 1)}
        return v

    v_head = vectors(head_idx)
    live_X = {"grad": G_adv, "act": ACT}
    syn_X = {("grad", "clean"): PF[("grad", "clean")], ("grad", "hack"): PF[("grad", "hack")],
             ("act", "clean"): PF[("act", "clean")], ("act", "hack"): PF[("act", "hack")]}

    def score_cols(v: dict, X: dict[str, torch.Tensor]) -> dict[str, np.ndarray]:
        """One score array per case, keyed `<rep>_<kind>`."""
        return {f"{rep}_{kind}": _score(X[rep], v[rep], kind) for rep, kind in CASES}

    live_scores = score_cols(v_head, live_X)
    syn_scores = {side: score_cols(v_head, {"grad": syn_X[("grad", side)][head_idx],
                                            "act": syn_X[("act", side)][head_idx]})
                  for side in ("clean", "hack")}

    # ── pairset table: subsets x 4 AUROCs on the SAME cached live features ──
    candidates = [("all-in-one", list(range(len(pairs_all))))] + \
                 [(g, idx) for g, idx in sorted(groups.items()) if len(idx) >= 3]
    valid_pos = valid & (adv > 0)
    rows = []
    for gname, idx in candidates:
        v = vectors(idx)
        row = {"group": gname, "n_pairs": len(idx)}
        for rep, kind in CASES:
            s = _score(live_X[rep], v[rep], kind)
            row[f"{rep}_{kind}"] = round(_auroc(s[valid_pos].tolist(), y[valid_pos].tolist()), 3)
            row[f"{rep}_{kind}_all"] = round(_auroc(s[valid].tolist(), y[valid].tolist()), 3)
        rows.append(row)
    rank = pl.DataFrame(rows).sort("grad_dot", descending=True)
    rank.write_parquet(rank_path)
    adv_v = adv[valid]
    print(f"\nbaseline adv-only AUROC: vs-all={_auroc(adv_v.tolist(), y[valid].tolist()):.3f} "
          f"A>0-contrast={_auroc(adv[valid_pos].tolist(), y[valid_pos].tolist()):.3f} -- the table "
          f"columns are the A>0 contrast (hack vs non-hack among adv>0, n={int(valid_pos.sum())}), "
          f"where adv is blind; vs-all columns (*_all) live in {rank_path.name}.")
    print("SHOULD: real pairsets beat 0.5 and the adv-only A>0 baseline; under --random-v-seed "
          "every column ~0.5. With ~20 negatives the SE is ~0.07: only gaps >0.15 mean much.")
    print(tabulate(rank.drop([c for c in rank.columns if c.endswith("_all")]).to_pandas(),
                   headers="keys", tablefmt="pipe", floatfmt="+.3f", showindex=False))

    # ── persist per-rollout scores + raw features, then plot ──
    def frame(pop_name: str, mask_or_scores, scores: dict, step_arr, adv_arr) -> pl.DataFrame:
        """Rows of one population; `mask_or_scores` is a row mask, or None for all rows."""
        return pl.DataFrame({"pop": pop_name, "step": step_arr, "adv": adv_arr,
                             **{c: scores[c][mask_or_scores] if mask_or_scores is not None else scores[c]
                                for c in scores}})

    # Only populations that actually occur get a frame (empty ones are omitted).
    dfs = [frame(p, pop == p, live_scores, steps[pop == p], adv[pop == p])
           for p in counts if counts[p] > 0]
    n_syn = len(head_idx)
    # Synthetic rows carry placeholder step=-1 and adv=1.
    dfs += [frame(f"syn_{side}", None, syn_scores[side],
                  np.full(n_syn, -1), np.ones(n_syn)) for side in ("clean", "hack")]
    df = pl.concat(dfs)
    df.write_parquet(data_path)
    torch.save({"G": G, "ACT": ACT, "adv": adv, "exploited": exploited, "gt_pass": gt_pass,
                "steps": steps, "p_idx": p_idx, "names": names,
                "pair_feats": PF, "pair_groups": dict(groups),
                "pair_ids": [p.problem_id for p in pairs_all]}, feats_path)
    logger.info(f"wrote {data_path} ({len(df)} rows), {feats_path}")

    sub = (f"{cfg.run_dir.name} | {cfg.ckpt}, live steps {cfg.step_lo}-{cfg.step_hi}, v from "
           f"'{cfg.headline_prefix}' pairs (n={len(head_idx)}) | "
           f"hack+={counts['on_hackpos']} hack-={counts['on_hackneg']} solve={counts['on_solve']} "
           f"fail={counts['on_fail']} dropped(A~0)={counts['on_drop']}"
           + (f" | PLACEBO seed={cfg.random_v_seed}" if cfg.random_v_seed is not None else ""))
    stats = plot_q2(df, cfg.k_mid, cfg.k_rout, sub, q2_png)
    best = max(stats, key=lambda c: stats[c]["auroc_pos"])
    print(f"\nmain metric: best case on the A>0 contrast = {best} "
          f"AUROC={stats[best]['auroc_pos']:.3f} (vs-all {stats[best]['auroc_all']:.3f}) "
          f"P@rout={stats[best]['prec_rout']:.2f} (n={stats[best]['n_rout']}) "
          f"R@rout={stats[best]['rec_rout']:.2f}")
    print(f"out: {q2_png}")
    return 0
if __name__ == "__main__":
    # Parse the CLI into a Cfg, run, and surface main()'s return value as the exit code.
    exit_code = main(tyro.cli(Cfg))
    raise SystemExit(exit_code)