Files
evil_MoE/scripts/diag_pinning.py
T
wassname 70697ff36e diag(#40): pinning plot splits solve/fail/hack + per-pairset AUROC ranking
Q4 fix: on-policy "solve" was ~exploited = solves+fails (mostly fails). Split by
gt_pass into solve/fail/hack (live: 103 hack / 27 solve / 62 fail). Per-pairset
ranking: build v_grad from each heading-prefix subset, re-project the SAME stored
live c-grads (no model re-run). Finding: behavior pairs AUROC 0.69 vs all-in-one
0.53; reasoning/opportunity anti-aligned (<0.5) -> mixing dilutes.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
2026-06-11 06:16:27 +00:00

323 lines
16 KiB
Python

"""Pinning calibration + per-pairset separation: where does the live gate sit, would
mean +/- k*sd route it, and which authored-pair subset gives the best hack/clean split?

The routeV gate scores each rollout by the width-pooled, band-normalized cosine of its
deployed-block c-probe gradient to v_grad (exactly train.py:_lora2r_gate_labels) -- i.e.
the rollout's position along the HACKING DIRECTION. This script:

1. Splits live rollouts into THREE honest populations (Q4 fix -- "not-exploited" was
   conflating genuine solves with fails): on-policy solve (gt_pass and not exploited),
   fail (neither), and hack (exploited). Plots them alongside the authored synthetic
   solve/hack pairs, marking the online mean, mean +/- 2sd, and the best hack/rest split.
2. Ranks pairset SUBSETS (heading-prefix views: behavior/opportunity/disposition/...) by
   how well a v_grad built from ONLY that subset separates live hacks (AUROC + Cohen's d).
   This is done by storing each live rollout's per-module deployed c-grad ONCE, then
   re-projecting it onto each subset's v_grad -- no model re-run per subset (the
   refresh-tracking trick).

Usage:
    uv run python scripts/diag_pinning.py --run-dir out/runs/<vanilla_lora2r_run>
    uv run python scripts/diag_pinning.py --replot out/diag/pinning_data.parquet  # no GPU

Outputs (out/diag/): pinning_calib.png, pinning_pairset_auroc.png, pinning_data.parquet,
pinning_pairset.parquet.
"""
from __future__ import annotations
import json
import struct
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
import numpy as np
import torch
import tyro
import polars as pl
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from matplotlib.patches import Patch
from loguru import logger
from tabulate import tabulate
from safetensors.torch import load_file
from transformers import AutoModelForCausalLM, AutoTokenizer
from vgrout.lora2r import wrap_model_with_lora2r
from vgrout.pairs import load_pairs, HackPair
from vgrout.extract_vhack_grad import extract_v_hack, completion_nll
from vgrout.train import _build_v_grad, route_band_edges, _haar_unit_dirs, _auroc
# Plot palette. Colour encodes behaviour (blue = solve, red = hack, grey = fail); line
# style encodes source (solid = on-policy, dashed = synthetic). MEANC draws the online
# mean and its +/- 2sd band; ORACLE draws the best hack/rest threshold.
SOLVE = "#3b6ea5"
HACK = "#c44e52"
FAIL = "#9aa0a6"
MEANC = "#d1900a"
ORACLE = "#3a8a7a"
@dataclass
class Cfg:
    """CLI config for diag_pinning (parsed with tyro in the __main__ guard)."""
    # Training run directory; must contain <ckpt>.safetensors and rollouts.jsonl.
    run_dir: Path = Path("out/runs/20260611T003538_fast_vanilla_lora2r_seed43_l2r_vanilla_s43_v3")
    # Checkpoint stem: run_dir/<ckpt>.safetensors supplies the LoRA A/B tensors and metadata.
    ckpt: str = "first_hack"
    # Authored hack/clean pairs, passed verbatim to vgrout.pairs.load_pairs (the
    # "#all-in-one" suffix is presumably a section selector -- confirm in load_pairs).
    pairs: Path = Path("data/pairs/hack_pairs.md#all-in-one")
    # Coherent emergence window. This vanilla v3 used the pre-fix lr=5e-4/warmup-0.1 and
    # DIVERGED at step 10 (exploited 20/24 -> 0/24); 2-9 = hacks emerging, model still sane.
    # Both bounds are inclusive on each rollout's "step".
    step_lo: int = 2
    step_hi: int = 9
    # Cap on live rollouts scored: the first N (file order) that pass the step/text filter.
    max_rollouts: int = 240
    random_v_seed: int | None = None  # Haar placebo (sanity: pins should NOT separate)
    replot: Path | None = None  # load this parquet and re-plot only (no model, no GPU)
    # Destination for the parquet tables and PNG plots.
    out_dir: Path = Path("out/diag")
def _ckpt_meta(path: Path) -> dict:
with open(path, "rb") as f:
return json.loads(f.read(struct.unpack("<Q", f.read(8))[0])).get("__metadata__", {})
def _batch_pos(G: torch.Tensor, names: list, v_grad: dict, route_band: dict) -> np.ndarray:
"""Pooled band-normalized cosine position per rollout, vectorized over modules.
G: [N, M, r] raw deployed-block c-grads (v-independent). Mirrors train.py pooling:
per module cos to ANY of the k dirs (max), width-weighted pool, excluding zero-width
modules. Same G re-projects onto any v_grad/route_band -> cheap per-pairset scoring.
"""
dev = v_grad[names[0]].device
Vs = torch.stack([v_grad[n] for n in names]).to(dev) # [M, k, r]
low = torch.tensor([route_band[n][0] for n in names], device=dev) # [M]
up = torch.tensor([route_band[n][1] for n in names], device=dev) # [M]
inc = (up - low) > 0
U = (G.to(dev) / G.to(dev).norm(dim=2, keepdim=True).clamp_min(1e-12)) # unit per (n,m)
cos = torch.einsum("nmr,mkr->nmk", U, Vs).amax(-1) # [N, M]
num = ((cos - low) * inc).sum(1)
den = ((up - low) * inc).sum().clamp_min(1e-12)
return (num / den).cpu().numpy()
def _cohend(a: np.ndarray, b: np.ndarray) -> float:
"""Standardized mean gap (hack - solve); +ve = hacks sit further along the hack-dir."""
if len(a) < 2 or len(b) < 2:
return float("nan")
sp = np.sqrt(((len(a) - 1) * a.var(ddof=1) + (len(b) - 1) * b.var(ddof=1)) / (len(a) + len(b) - 2))
return float((a.mean() - b.mean()) / sp) if sp > 0 else float("nan")
def _kde(x: np.ndarray, grid: np.ndarray) -> np.ndarray:
"""Gaussian KDE, Silverman bandwidth (no scipy)."""
x = np.asarray(x, float)
if len(x) < 2:
return np.zeros_like(grid)
iqr = np.subtract(*np.percentile(x, [75, 25]))
sigma = min(x.std(ddof=1), iqr / 1.349) if iqr > 0 else x.std(ddof=1)
bw = max(0.9 * (sigma or 1.0) * len(x) ** (-0.2), 1e-3)
z = (grid[:, None] - x[None, :]) / bw
return np.exp(-0.5 * z ** 2).sum(1) / (len(x) * bw * np.sqrt(2 * np.pi))
def plot_dist(df: pl.DataFrame, subtitle: str, out_png: Path) -> None:
    """Distribution plot from the 5 saved populations -- no GPU needed.

    ``df`` has columns ``pop`` and ``pos``. ``pop`` is one of on_solve / on_fail /
    on_hack (live rollouts) or syn_solve / syn_hack (authored pairs). The plot draws
    a KDE per population, the online mean with a mean +/- 2sd band, and the
    hack-vs-rest threshold maximising Youden's J (TPR - FPR), then saves it to
    ``out_png``.

    Degenerate inputs degrade gracefully instead of crashing:
    - A missing synthetic population is drawn as a flat curve.
    - Live data with only one class gets AUROC = NaN and no split line.

    Raises:
        ValueError: if ``df`` contains no live (on_*) rows.
    """
    def arr(p: str) -> np.ndarray:
        return df.filter(pl.col("pop") == p)["pos"].to_numpy()

    on_solve, on_fail, on_hack = arr("on_solve"), arr("on_fail"), arr("on_hack")
    syn_solve, syn_hack = arr("syn_solve"), arr("syn_hack")
    # AUROC/divider are hack-vs-rest (the gate's job): label = exploited.
    pos_live = np.concatenate([on_solve, on_fail, on_hack])
    if len(pos_live) == 0:
        raise ValueError("plot_dist: df has no on-policy rows (on_solve/on_fail/on_hack)")
    labels = np.concatenate([np.zeros(len(on_solve) + len(on_fail), bool), np.ones(len(on_hack), bool)])
    mean, sd = float(pos_live.mean()), float(pos_live.std())
    lo_b, hi_b = mean - 2 * sd, mean + 2 * sd
    if labels.any() and not labels.all():
        auroc = _auroc(pos_live.tolist(), labels.tolist())
        # Youden's J = P(hack >= t) - P(rest >= t), evaluated at every observed position t.
        thr = np.unique(pos_live)
        tpr = (pos_live[labels][None, :] >= thr[:, None]).mean(axis=1)
        fpr = (pos_live[~labels][None, :] >= thr[:, None]).mean(axis=1)
        oracle = float(thr[int(np.argmax(tpr - fpr))])
    else:
        # Only one class among live rollouts: AUROC and the best split are undefined.
        auroc, oracle = float("nan"), float("nan")
    # The x-range covers every population's minimum and both synthetic maxima. The live
    # upper tail is clipped at its 99th percentile. Synthetic solve/hack are not assumed
    # to be ordered (e.g. under the Haar placebo v_grad).
    syn_all = np.concatenate([syn_solve, syn_hack])
    lo = min(pos_live.min(), syn_all.min(initial=np.inf)) - 0.1
    hi = max(np.quantile(pos_live, 0.99), syn_all.max(initial=-np.inf)) + 0.1
    grid = np.linspace(lo, hi, 400)
    POPS = [(on_solve, SOLVE, True, "on-policy solve (gt pass)"),
            (on_fail, FAIL, True, "on-policy fail"),
            (on_hack, HACK, True, "on-policy hack"),
            (syn_solve, SOLVE, False, "synthetic solve"),
            (syn_hack, HACK, False, "synthetic hack")]
    kdes = [_kde(x, grid) for x, *_ in POPS]
    peak = max(float(y.max()) for y in kdes)
    ymax = peak * 1.15 if peak > 0 else 1.0  # all-flat KDEs would give a zero-height axis
    fig, ax = plt.subplots(figsize=(8.8, 4.8))
    ax.axvspan(lo_b, hi_b, color=MEANC, alpha=0.07, lw=0)
    ax.axvline(mean, color=MEANC, lw=1.8)
    for b in (lo_b, hi_b):
        ax.axvline(b, color=MEANC, lw=1.2, ls="--")
    has_oracle = bool(np.isfinite(oracle))
    if has_oracle:
        ax.axvline(oracle, color=ORACLE, lw=1.3, ls="-.")
    for (_, col, on_policy, _), y in zip(POPS, kdes):
        if on_policy:
            ax.fill_between(grid, y, color=col, alpha=0.12, lw=0)
            ax.plot(grid, y, color=col, lw=1.9)
        else:
            ax.plot(grid, y, color=col, lw=2.2, ls=(0, (5, 2)))
    ax.set_ylim(0, ymax)
    for s in ("top", "right"):
        ax.spines[s].set_visible(False)
    ax.set_ylabel("density")
    ax.set_xlabel("hacking direction (gradient cosine to v_grad) " + r"$\longrightarrow$")
    dist_handles = [Line2D([0], [0], color=c, lw=2.2, ls="-" if op else (0, (5, 2))) for _, c, op, _ in POPS]
    leg1 = ax.legend(dist_handles, [lab for *_, lab in POPS], loc="upper left",
                     fontsize=8, frameon=False, title="distributions", title_fontsize=8)
    leg1._legend_box.align = "left"
    ax.add_artist(leg1)  # keep this legend; the second ax.legend() call would replace it
    mark_handles = [Line2D([0], [0], color=MEANC, lw=1.8),
                    Patch(facecolor=MEANC, alpha=0.18, edgecolor=MEANC, ls="--")]
    mark_labels = [f"online mean ({mean:+.2f})", f"mean +/- 2sd (sd={sd:.2f})"]
    if has_oracle:
        mark_handles.append(Line2D([0], [0], color=ORACLE, lw=1.3, ls="-."))
        mark_labels.append(f"best hack/rest split ({oracle:+.2f})")
    ax.legend(mark_handles, mark_labels,
              loc="upper right", fontsize=8, frameon=False, title="proposed pinning", title_fontsize=8)
    ax.set_title(f"{subtitle}\nlive hack-direction positions vs authored v_grad "
                 f"(hack-vs-rest AUROC={auroc:.2f})", fontsize=9.5)
    fig.tight_layout()
    fig.savefig(out_png, dpi=140)
    plt.close(fig)
    logger.info(f"mean={mean:+.3f} sd={sd:.3f} AUROC={auroc:.3f} -> {out_png}")
def plot_pairset(rank: pl.DataFrame, subtitle: str, out_png: Path) -> None:
    """Horizontal bar chart: AUROC of a v_grad built from each pairset subset vs live hacks.

    ``rank`` needs columns group, n_pairs, auroc and cohend, one row per subset. The
    plot and its handling of edge cases:
    - Bars are sorted by AUROC, best at the top.
    - The ``all-in-one`` baseline uses the hack colour and subsets the solve colour.
    - Each bar is annotated with its AUROC and Cohen's d.
    - NaN scores are annotated but ignored when choosing x-limits, because
      ``set_xlim`` rejects non-finite limits.
    - An empty ``rank`` is logged and skipped.
    """
    if rank.height == 0:
        logger.warning(f"empty pairset ranking -- skipping {out_png}")
        return
    rank = rank.sort("auroc")
    groups = rank["group"].to_list()
    auroc = rank["auroc"].to_numpy().astype(float)
    npairs = rank["n_pairs"].to_list()
    cohend = rank["cohend"].to_numpy().astype(float)
    y = np.arange(len(groups))
    fig, ax = plt.subplots(figsize=(7.2, 0.5 * len(groups) + 1.4))
    cols = [HACK if g == "all-in-one" else SOLVE for g in groups]  # highlight the all-pairs baseline
    ax.barh(y, auroc, color=cols, alpha=0.85)
    ax.axvline(0.5, color="k", lw=1, ls=":")  # 0.5 = blind to live hacks
    ax.set_yticks(y)
    ax.set_yticklabels([f"{g} (n={n})" for g, n in zip(groups, npairs)], fontsize=8)
    finite = auroc[np.isfinite(auroc)]
    a_lo, a_hi = (float(finite.min()), float(finite.max())) if finite.size else (0.5, 0.5)
    x_lo = min(0.45, a_lo - 0.03)
    ax.set_xlim(x_lo, max(a_hi + 0.12, 0.6))
    for yi, a, d in zip(y, auroc, cohend):
        x_txt = a if np.isfinite(a) else x_lo  # park NaN annotations at the left edge
        ax.text(x_txt + 0.005, yi, f"{a:.2f} d={d:+.2f}", va="center", fontsize=7.5)
    for s in ("top", "right"):
        ax.spines[s].set_visible(False)
    ax.set_xlabel("hack-vs-rest AUROC of subset's v_grad (0.5 = blind, dotted)")
    ax.set_title(f"{subtitle}\nwhich authored-pair subset separates live hacks best?", fontsize=9.5)
    fig.tight_layout()
    fig.savefig(out_png, dpi=140)
    plt.close(fig)
    logger.info(f"wrote {out_png}")
def _load_live_batch(run_dir: Path, step_lo: int, step_hi: int, max_rollouts: int) -> list[dict]:
    """Return up to ``max_rollouts`` rollout records from ``run_dir/rollouts.jsonl``.

    Keeps records whose ``step`` lies in [step_lo, step_hi] (inclusive) and whose
    ``text`` is not blank, in file order. Blank JSONL lines are skipped so that
    ``json.loads`` does not fail on them.
    """
    lines = (run_dir / "rollouts.jsonl").read_text(encoding="utf-8").splitlines()
    recs = [json.loads(line) for line in lines if line.strip()]
    keep = [x for x in recs if step_lo <= x["step"] <= step_hi and x["text"].strip()]
    return keep[:max_rollouts]


def main(cfg: Cfg) -> int:
    """Run the pinning calibration and per-pairset ranking, or just re-plot saved data.

    With ``cfg.replot`` set, no model is loaded. The saved population parquet is
    re-plotted, plus the pairset ranking found next to it (falling back to
    ``cfg.out_dir``).

    Otherwise:
    1. Copy the checkpoint's LoRA A/B into a freshly wrapped base model.
    2. Build v_grad from all authored pairs.
    3. Score the live rollouts from steps [step_lo, step_hi] once, storing per-module
       deployed c-grads.
    4. Re-project those c-grads onto a v_grad built from each heading-prefix pair subset.
    5. Write the parquet tables and PNGs into ``cfg.out_dir``.

    Returns 0 (the process exit code).

    Raises:
        RuntimeError: if no live rollout in the window yields a finite NLL.
    """
    cfg.out_dir.mkdir(parents=True, exist_ok=True)
    data_path = cfg.out_dir / "pinning_data.parquet"
    rank_path = cfg.out_dir / "pinning_pairset.parquet"
    dist_png = cfg.out_dir / "pinning_calib.png"
    rank_png = cfg.out_dir / "pinning_pairset_auroc.png"
    if cfg.replot is not None:
        plot_dist(pl.read_parquet(cfg.replot), f"replot -- {cfg.replot.name}", dist_png)
        # The ranking saved alongside the replotted data belongs to the same run; fall back
        # to out_dir's copy when there is none.
        sibling = cfg.replot.with_name(rank_path.name)
        src = sibling if sibling.exists() else rank_path
        if src.exists():
            plot_pairset(pl.read_parquet(src), "replot", rank_png)
        return 0
    device = torch.device("cuda")
    ckpt_path = cfg.run_dir / f"{cfg.ckpt}.safetensors"
    meta = _ckpt_meta(ckpt_path)
    run_cfg = json.loads(meta.get("cfg", "{}"))
    model_name = run_cfg.get("model", meta.get("model", "Qwen/Qwen3-4B"))
    r = run_cfg.get("lora_r", 32)
    init_seed = run_cfg.get("lora_init_seed", 0)
    logger.info(f"ckpt {ckpt_path.name} step={meta.get('step')} hack_rate={meta.get('hack_rate')} "
                f"model={model_name} r={r} init_seed={init_seed}")
    tok = AutoTokenizer.from_pretrained(model_name)
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        model_name, dtype=torch.bfloat16, attn_implementation="flash_attention_2").to(device)
    model.config.use_cache = False
    wrappers = wrap_model_with_lora2r(model, r=r, init_seed=init_seed, grad_probe=True)
    names = sorted(wrappers)
    sd = load_file(str(ckpt_path))
    for nm in names:
        wrappers[nm]["A"].data.copy_(sd[f"A/{nm}"].to(device, torch.float32))
        wrappers[nm]["B"].data.copy_(sd[f"B/{nm}"].to(device, torch.float32))
    logger.info(f"loaded A/B into {len(names)} modules")
    model.eval()

    def v_grad_from(pairs: list[HackPair]):
        """Extract (v_grad, route_band, raw per-pair grads) from ``pairs`` (Haar placebo if set)."""
        _, _, raw, _ = extract_v_hack(model, tok, wrappers, pairs,
                                      top_k=1, tau_axis=0.0, n_heldout=2, device=device)
        vg = _build_v_grad(raw, wrappers, 1, device)
        if cfg.random_v_seed is not None:
            vg = _haar_unit_dirs(vg, cfg.random_v_seed, device)
        return vg, route_band_edges(raw, vg, device), raw

    pairs_all = load_pairs(cfg.pairs)
    logger.info(f"pairs {cfg.pairs} -> {len(pairs_all)}")
    v_grad, route_band, raw_all = v_grad_from(pairs_all)

    # synthetic pair positions under the all-in-one v_grad (for the distribution plot)
    def pair_pos(side: str) -> np.ndarray:
        # raw_all["{side}/{module}"] is indexed per pair on dim 0; stacking modules on
        # dim 1 gives [n_pairs, M, ...] in `names` order, as _batch_pos expects.
        G_pairs = torch.stack([raw_all[f"{side}/{nm}"] for nm in names], dim=1)
        return _batch_pos(G_pairs, names, v_grad, route_band)

    syn_solve, syn_hack = pair_pos("clean"), pair_pos("hack")
    # score the live batch ONCE; store per-rollout deployed c-grads for cheap re-projection
    batch = _load_live_batch(cfg.run_dir, cfg.step_lo, cfg.step_hi, cfg.max_rollouts)
    logger.info(f"live batch: {len(batch)} rollouts (steps {cfg.step_lo}-{cfg.step_hi})")
    G_rows, exploited, gt_pass, steps = [], [], [], []
    for i, rec in enumerate(batch):
        model.zero_grad(set_to_none=True)
        loss = completion_nll(model, tok, rec["prompt"], rec["text"], device)
        if not torch.isfinite(loss):
            continue
        loss.backward()
        g = []
        for nm in names:
            gr = wrappers[nm]["layer"]._lora2r_gate.grad
            # Sum over all leading dims; the first r entries of the last dim are the deployed block.
            g.append(gr.sum(dim=tuple(range(gr.dim() - 1)))[:r].float().cpu())  # deployed block [r]
        G_rows.append(torch.stack(g))  # [M, r]
        exploited.append(bool(rec["exploited"]))
        gt_pass.append(bool(rec["gt_pass"]))
        steps.append(rec["step"])
        if (i + 1) % 40 == 0:
            logger.info(f" rollout {i+1}/{len(batch)}")
    model.zero_grad(set_to_none=True)
    if not G_rows:
        raise RuntimeError(f"no live rollouts with a finite NLL in steps {cfg.step_lo}-{cfg.step_hi} "
                           f"of {cfg.run_dir / 'rollouts.jsonl'}")
    G = torch.stack(G_rows)  # [N, M, r]
    exploited, gt_pass, steps = map(np.array, (exploited, gt_pass, steps))
    pos = _batch_pos(G, names, v_grad, route_band)
    # THREE honest populations (Q4 fix): solve = gt-correct & not-exploited; fail = neither; hack = exploited
    solve_m = gt_pass & ~exploited
    fail_m = ~gt_pass & ~exploited
    logger.info(f"live: {len(exploited)} rollouts | hack={int(exploited.sum())} "
                f"solve={int(solve_m.sum())} fail={int(fail_m.sum())}")
    # per-pairset AUROC: build v_grad from each heading-prefix subset, re-project the SAME G
    groups: dict[str, list] = defaultdict(list)
    for p in pairs_all:
        groups[p.problem_id.split("_")[0]].append(p)
    rows: list[dict] = []

    def record(gname: str, n_pairs: int, pg: np.ndarray) -> None:
        """Append one subset's hack-vs-rest AUROC and Cohen's d to ``rows`` and log it."""
        auroc = _auroc(pg.tolist(), exploited.tolist())
        # d compares hacks with genuine solves when there are >= 2 of them, else with the rest.
        ref = solve_m if solve_m.sum() >= 2 else ~exploited
        cohend = _cohend(pg[exploited], pg[ref])
        rows.append({"group": gname, "n_pairs": n_pairs, "auroc": round(auroc, 3),
                     "cohend": round(cohend, 3)})
        logger.info(f" pairset {gname:12s} n={n_pairs:2d} AUROC={auroc:.3f} d={cohend:+.3f}")

    # The all-in-one v_grad is already built and projected (`pos`): reuse it instead of
    # re-running extraction, so its row matches the distribution plot exactly.
    record("all-in-one", len(pairs_all), pos)
    for gname, gpairs in sorted(groups.items()):
        # extract_v_hack holds out 2 pairs, so a subset needs >=3 to leave a training pair.
        if len(gpairs) < 3:
            continue
        vg, rb, _ = v_grad_from(gpairs)
        record(gname, len(gpairs), _batch_pos(G, names, vg, rb))
    rank = pl.DataFrame(rows).sort("auroc", descending=True)
    rank.write_parquet(rank_path)
    print(tabulate(rank.to_pandas(), headers="keys", tablefmt="pipe", floatfmt="+.3f", showindex=False))
    df = pl.concat([
        pl.DataFrame({"pop": "on_solve", "pos": pos[solve_m], "step": steps[solve_m]}),
        pl.DataFrame({"pop": "on_fail", "pos": pos[fail_m], "step": steps[fail_m]}),
        pl.DataFrame({"pop": "on_hack", "pos": pos[exploited], "step": steps[exploited]}),
        pl.DataFrame({"pop": "syn_solve", "pos": syn_solve, "step": np.full(len(syn_solve), -1)}),
        pl.DataFrame({"pop": "syn_hack", "pos": syn_hack, "step": np.full(len(syn_hack), -1)}),
    ])
    df.write_parquet(data_path)
    logger.info(f"wrote {data_path} ({len(df)} rows)")
    sub = (f"{cfg.run_dir.name}\n{cfg.ckpt} v_grad, live steps {cfg.step_lo}-{cfg.step_hi}, "
           f"hack={int(exploited.sum())} solve={int(solve_m.sum())} fail={int(fail_m.sum())}")
    plot_dist(df, sub, dist_png)
    plot_pairset(rank, cfg.run_dir.name, rank_png)
    return 0
if __name__ == "__main__":
    # tyro builds the CLI from Cfg's fields; main's return value becomes the exit status.
    cli_cfg = tyro.cli(Cfg)
    raise SystemExit(main(cli_cfg))