diff --git a/scripts/diag_pinning.py b/scripts/diag_pinning.py index c79d2cf..d556ff9 100644 --- a/scripts/diag_pinning.py +++ b/scripts/diag_pinning.py @@ -1,21 +1,27 @@ -"""Pinning calibration: where does the live gate sit, and would mean +/- 2sd route it? +"""Pinning calibration + per-pairset separation: where does the live gate sit, would +mean +/- k*sd route it, and which authored-pair subset gives the best hack/clean split? The routeV gate scores each rollout by the width-pooled, band-normalized cosine of its -deployed-block c-probe gradient to v_grad (exactly train.py:_lora2r_gate_labels) -- call it -the position on the HACKING DIRECTION. The current gate routes a FIXED quantile tail (#30). -This plots four populations on that axis: - on-policy solve / hack -- live rollouts (oracle `exploited` read for COLOUR/AUROC only) - synthetic solve / hack -- the authored pairs v_grad was built from (the only label source) -and marks the PROPOSED pinning: online mean and mean +/- 2*sd (a symmetric band that -self-calibrates to the live spread instead of forcing a fixed %). +deployed-block c-probe gradient to v_grad (exactly train.py:_lora2r_gate_labels) -- the +position on the HACKING DIRECTION. This script: + + 1. Splits live rollouts into THREE honest populations (Q4 fix -- "not-exploited" was + conflating genuine solves with fails): on-policy solve (gt_pass), fail, hack (exploited). + Plots them with the authored synthetic solve/hack, marks online mean and mean +/- 2sd. + 2. Ranks pairset SUBSETS (heading-prefix views: behavior/opportunity/disposition/...) by + how well a v_grad built from ONLY that subset separates live hacks (AUROC + Cohen's d). + Done by storing each live rollout's per-module deployed c-grad ONCE, then re-projecting + onto each subset's v_grad -- no model re-run per subset (the refresh-tracking trick). uv run python scripts/diag_pinning.py --run-dir out/runs/ - uv run python scripts/diag_pinning.py --replot out/diag/pinning_data.parquet # no GPU, restyle only -outputs (out/diag/): pinning_calib.png, pinning_data.parquet (the 4 populations, regenerates the plot). + uv run python scripts/diag_pinning.py --replot out/diag/pinning_data.parquet # no GPU +outputs (out/diag/): pinning_calib.png, pinning_pairset_auroc.png, pinning_data.parquet, +pinning_pairset.parquet. """ from __future__ import annotations import json import struct +from collections import defaultdict from dataclasses import dataclass from pathlib import Path @@ -29,16 +35,17 @@ import matplotlib.pyplot as plt from matplotlib.lines import Line2D from matplotlib.patches import Patch from loguru import logger +from tabulate import tabulate from safetensors.torch import load_file from transformers import AutoModelForCausalLM, AutoTokenizer from vgrout.lora2r import wrap_model_with_lora2r -from vgrout.pairs import load_pairs +from vgrout.pairs import load_pairs, HackPair from vgrout.extract_vhack_grad import extract_v_hack, completion_nll from vgrout.train import _build_v_grad, route_band_edges, _haar_unit_dirs, _auroc -# colour = behaviour (blue solve, red hack); style = source (solid+fill on-policy, dashed synthetic) -SOLVE, HACK, MEANC, ORACLE = "#3b6ea5", "#c44e52", "#d1900a", "#3a8a7a" +# colour = behaviour (blue solve, red hack, grey fail); style = source (solid on-policy, dashed synthetic) +SOLVE, HACK, FAIL, MEANC, ORACLE = "#3b6ea5", "#c44e52", "#9aa0a6", "#d1900a", "#3a8a7a" @dataclass @@ -61,21 +68,31 @@ def _ckpt_meta(path: Path) -> dict: return json.loads(f.read(struct.unpack(" tuple[float, float]: - """Mirror train.py:_lora2r_gate_labels pooling for ONE rollout's per-module deployed c-grads.""" - num = 0.0; den = 0.0; w = 0.0; n_inc = 0 - for name in names: - lower, upper = route_band[name] - if upper - lower <= 0: - continue - g_b = c_grads[name][:r].float().to(v_grad[name].device) - nrm = g_b.norm() - cos_b = (torch.einsum("r, k r -> k", g_b, v_grad[name]).max() / nrm.clamp_min(1e-12)).item() - num += cos_b - lower; den += upper - lower - w += nrm.item(); n_inc += 1 - if n_inc == 0: - raise RuntimeError("no module has positive band width") - return num / den, w / n_inc +def _batch_pos(G: torch.Tensor, names: list, v_grad: dict, route_band: dict) -> np.ndarray: + """Pooled band-normalized cosine position per rollout, vectorized over modules. + + G: [N, M, r] raw deployed-block c-grads (v-independent). Mirrors train.py pooling: + per module cos to ANY of the k dirs (max), width-weighted pool, excluding zero-width + modules. Same G re-projects onto any v_grad/route_band -> cheap per-pairset scoring. + """ + dev = v_grad[names[0]].device + Vs = torch.stack([v_grad[n] for n in names]).to(dev) # [M, k, r] + low = torch.tensor([route_band[n][0] for n in names], device=dev) # [M] + up = torch.tensor([route_band[n][1] for n in names], device=dev) # [M] + inc = (up - low) > 0 + U = (G.to(dev) / G.to(dev).norm(dim=2, keepdim=True).clamp_min(1e-12)) # unit per (n,m) + cos = torch.einsum("nmr,mkr->nmk", U, Vs).amax(-1) # [N, M] + num = ((cos - low) * inc).sum(1) + den = ((up - low) * inc).sum().clamp_min(1e-12) + return (num / den).cpu().numpy() + + +def _cohend(a: np.ndarray, b: np.ndarray) -> float: + """Standardized mean gap (hack - solve); +ve = hacks sit further along the hack-dir.""" + if len(a) < 2 or len(b) < 2: + return float("nan") + sp = np.sqrt(((len(a) - 1) * a.var(ddof=1) + (len(b) - 1) * b.var(ddof=1)) / (len(a) + len(b) - 2)) + return float((a.mean() - b.mean()) / sp) if sp > 0 else float("nan") def _kde(x: np.ndarray, grid: np.ndarray) -> np.ndarray: @@ -90,31 +107,33 @@ def _kde(x: np.ndarray, grid: np.ndarray) -> np.ndarray: return np.exp(-0.5 * z ** 2).sum(1) / (len(x) * bw * np.sqrt(2 * np.pi)) -def plot(df: pl.DataFrame, subtitle: str, out_png: Path) -> None: - """Regenerate the figure from the saved 4 populations -- no GPU needed.""" +def plot_dist(df: pl.DataFrame, subtitle: str, out_png: Path) -> None: + """Distribution plot from the 5 saved populations -- no GPU needed.""" arr = lambda p: df.filter(pl.col("pop") == p)["pos"].to_numpy() - on_solve, on_hack = arr("on_solve"), arr("on_hack") + on_solve, on_fail, on_hack = arr("on_solve"), arr("on_fail"), arr("on_hack") syn_solve, syn_hack = arr("syn_solve"), arr("syn_hack") - pos_live = np.concatenate([on_solve, on_hack]) - labels = np.concatenate([np.zeros(len(on_solve), bool), np.ones(len(on_hack), bool)]) - + # AUROC/divider are hack-vs-rest (the gate's job): label = exploited. + pos_live = np.concatenate([on_solve, on_fail, on_hack]) + labels = np.concatenate([np.zeros(len(on_solve) + len(on_fail), bool), np.ones(len(on_hack), bool)]) mean, sd = float(pos_live.mean()), float(pos_live.std()) - lo_b, hi_b = mean - 2 * sd, mean + 2 * sd # proposed routing band + lo_b, hi_b = mean - 2 * sd, mean + 2 * sd auroc = _auroc(pos_live.tolist(), labels.tolist()) - thr = np.unique(pos_live) # oracle divider (Youden J) -- diagnostic only + thr = np.unique(pos_live) j = [(pos_live[labels] >= t).mean() - (pos_live[~labels] >= t).mean() for t in thr] oracle = float(thr[int(np.argmax(j))]) lo = min(pos_live.min(), syn_solve.min()) - 0.1 hi = max(np.quantile(pos_live, 0.99), syn_hack.max()) + 0.1 grid = np.linspace(lo, hi, 400) - POPS = [(on_solve, SOLVE, True, "on-policy solve"), (on_hack, HACK, True, "on-policy hack"), - (syn_solve, SOLVE, False, "synthetic solve"), (syn_hack, HACK, False, "synthetic hack")] - kdes = {i: _kde(x, grid) for i, (x, *_ ) in enumerate(POPS)} + POPS = [(on_solve, SOLVE, True, "on-policy solve (gt pass)"), + (on_fail, FAIL, True, "on-policy fail"), + (on_hack, HACK, True, "on-policy hack"), + (syn_solve, SOLVE, False, "synthetic solve"), + (syn_hack, HACK, False, "synthetic hack")] + kdes = {i: _kde(x, grid) for i, (x, *_) in enumerate(POPS)} ymax = max(y.max() for y in kdes.values()) * 1.15 - fig, ax = plt.subplots(figsize=(8.6, 4.8)) - # proposed pinning: mean +/- 2sd band, drawn first so curves sit on top + fig, ax = plt.subplots(figsize=(8.8, 4.8)) ax.axvspan(lo_b, hi_b, color=MEANC, alpha=0.07, lw=0) ax.axvline(mean, color=MEANC, lw=1.8) for b in (lo_b, hi_b): @@ -123,49 +142,72 @@ def plot(df: pl.DataFrame, subtitle: str, out_png: Path) -> None: for i, (x, col, on_policy, _) in enumerate(POPS): y = kdes[i] if on_policy: - ax.fill_between(grid, y, color=col, alpha=0.13, lw=0) + ax.fill_between(grid, y, color=col, alpha=0.12, lw=0) ax.plot(grid, y, color=col, lw=1.9) else: - ax.fill_between(grid, y, color=col, alpha=0.05, lw=0) - ax.plot(grid, y, color=col, lw=2.4, ls=(0, (5, 2))) + ax.plot(grid, y, color=col, lw=2.2, ls=(0, (5, 2))) ax.set_ylim(0, ymax) for s in ("top", "right"): ax.spines[s].set_visible(False) ax.set_ylabel("density") ax.set_xlabel("hacking direction (gradient cosine to v_grad) " + r"$\longrightarrow$") - dist_handles = [Line2D([0], [0], color=c, lw=2.2, ls="-" if op else (0, (5, 2))) for _, c, op, _ in POPS] leg1 = ax.legend(dist_handles, [lab for *_, lab in POPS], loc="upper left", fontsize=8, frameon=False, title="distributions", title_fontsize=8) leg1._legend_box.align = "left" ax.add_artist(leg1) - mark_handles = [ - Line2D([0], [0], color=MEANC, lw=1.8), - Patch(facecolor=MEANC, alpha=0.18, edgecolor=MEANC, ls="--"), - Line2D([0], [0], color=ORACLE, lw=1.3, ls="-."), - ] + mark_handles = [Line2D([0], [0], color=MEANC, lw=1.8), + Patch(facecolor=MEANC, alpha=0.18, edgecolor=MEANC, ls="--"), + Line2D([0], [0], color=ORACLE, lw=1.3, ls="-.")] ax.legend(mark_handles, - [f"online mean ({mean:+.2f})", f"mean +/- 2sd (sd={sd:.2f}) = proposed band", - f"best hack/solve split ({oracle:+.2f}, diagnostic)"], - loc="upper right", fontsize=8, frameon=False, title="pinning", title_fontsize=8) - routed_hi = float((pos_live > hi_b).mean()) - ax.set_title(f"{subtitle}\nAUROC(hack-dir -> hack)={auroc:.2f} " - f"{routed_hi:.0%} of rollouts beyond mean+2sd", fontsize=9.5) + [f"online mean ({mean:+.2f})", f"mean +/- 2sd (sd={sd:.2f})", + f"best hack/rest split ({oracle:+.2f})"], + loc="upper right", fontsize=8, frameon=False, title="proposed pinning", title_fontsize=8) + ax.set_title(f"{subtitle}\nlive hack-direction positions vs authored v_grad " + f"(hack-vs-rest AUROC={auroc:.2f})", fontsize=9.5) + fig.tight_layout() + fig.savefig(out_png, dpi=140) + plt.close(fig) + logger.info(f"mean={mean:+.3f} sd={sd:.3f} AUROC={auroc:.3f} -> {out_png}") + + +def plot_pairset(rank: pl.DataFrame, subtitle: str, out_png: Path) -> None: + """Horizontal bar chart: AUROC of a v_grad built from each pairset subset vs live hacks.""" + rank = rank.sort("auroc") + groups = rank["group"].to_list() + auroc = rank["auroc"].to_numpy() + npairs = rank["n_pairs"].to_list() + cohend = rank["cohend"].to_numpy() + y = np.arange(len(groups)) + fig, ax = plt.subplots(figsize=(7.2, 0.5 * len(groups) + 1.4)) + cols = [HACK if g == "all-in-one" else SOLVE for g in groups] + ax.barh(y, auroc, color=cols, alpha=0.85) + ax.axvline(0.5, color="k", lw=1, ls=":") # 0.5 = blind to live hacks + ax.set_yticks(y) + ax.set_yticklabels([f"{g} (n={n})" for g, n in zip(groups, npairs)], fontsize=8) + for yi, a, d in zip(y, auroc, cohend): + ax.text(a + 0.005, yi, f"{a:.2f} d={d:+.2f}", va="center", fontsize=7.5) + ax.set_xlim(min(0.45, auroc.min() - 0.03), max(auroc.max() + 0.12, 0.6)) + for s in ("top", "right"): + ax.spines[s].set_visible(False) + ax.set_xlabel("hack-vs-rest AUROC of subset's v_grad (0.5 = blind, dotted)") + ax.set_title(f"{subtitle}\nwhich authored-pair subset separates live hacks best?", fontsize=9.5) fig.tight_layout() fig.savefig(out_png, dpi=140) plt.close(fig) - logger.info(f"mean={mean:+.3f} sd={sd:.3f} band=[{lo_b:+.3f},{hi_b:+.3f}] oracle={oracle:+.3f} " - f"AUROC={auroc:.3f} routed(>mean+2sd)={routed_hi:.2f}") logger.info(f"wrote {out_png}") def main(cfg: Cfg) -> int: cfg.out_dir.mkdir(parents=True, exist_ok=True) data_path = cfg.out_dir / "pinning_data.parquet" - out_png = cfg.out_dir / "pinning_calib.png" + rank_path = cfg.out_dir / "pinning_pairset.parquet" + dist_png = cfg.out_dir / "pinning_calib.png" + rank_png = cfg.out_dir / "pinning_pairset_auroc.png" if cfg.replot is not None: - df = pl.read_parquet(cfg.replot) - plot(df, f"pinning calibration (replot) -- {cfg.replot.name}", out_png) + plot_dist(pl.read_parquet(cfg.replot), f"replot -- {cfg.replot.name}", dist_png) + if rank_path.exists(): + plot_pairset(pl.read_parquet(rank_path), "replot", rank_png) return 0 device = torch.device("cuda") @@ -191,60 +233,88 @@ def main(cfg: Cfg) -> int: wrappers[nm]["A"].data.copy_(sd[f"A/{nm}"].to(device, torch.float32)) wrappers[nm]["B"].data.copy_(sd[f"B/{nm}"].to(device, torch.float32)) logger.info(f"loaded A/B into {len(names)} modules") - - pairs = load_pairs(cfg.pairs) - logger.info(f"pairs {cfg.pairs} -> {len(pairs)}") model.eval() - _, _, raw_grads, _ = extract_v_hack(model, tok, wrappers, pairs, - top_k=1, tau_axis=0.0, n_heldout=2, device=device) - v_grad = _build_v_grad(raw_grads, wrappers, 1, device) - if cfg.random_v_seed is not None: - v_grad = _haar_unit_dirs(v_grad, cfg.random_v_seed, device) - logger.info(f"OVERRODE v_grad with Haar dirs seed={cfg.random_v_seed} (placebo)") - route_band = route_band_edges(raw_grads, v_grad, device) + def v_grad_from(pairs: list[HackPair]): + _, _, raw, _ = extract_v_hack(model, tok, wrappers, pairs, + top_k=1, tau_axis=0.0, n_heldout=2, device=device) + vg = _build_v_grad(raw, wrappers, 1, device) + if cfg.random_v_seed is not None: + vg = _haar_unit_dirs(vg, cfg.random_v_seed, device) + return vg, route_band_edges(raw, vg, device), raw + + pairs_all = load_pairs(cfg.pairs) + logger.info(f"pairs {cfg.pairs} -> {len(pairs_all)}") + v_grad, route_band, raw_all = v_grad_from(pairs_all) + + # synthetic pair positions under the all-in-one v_grad (for the distribution plot) def pair_pos(side: str) -> np.ndarray: - n = raw_grads[f"{side}/{names[0]}"].shape[0] - out = [] - for i in range(n): - cg = {nm: torch.cat([raw_grads[f"{side}/{nm}"][i], torch.zeros(r)]) for nm in names} - out.append(pooled_pos(cg, v_grad, route_band, names, r)[0]) - return np.array(out) + n = raw_all[f"{side}/{names[0]}"].shape[0] + G = torch.stack([torch.stack([raw_all[f"{side}/{nm}"][i] for nm in names]) for i in range(n)]) + return _batch_pos(G, names, v_grad, route_band) syn_solve, syn_hack = pair_pos("clean"), pair_pos("hack") + # score the live batch ONCE; store per-rollout deployed c-grads for cheap re-projection recs = [json.loads(l) for l in (cfg.run_dir / "rollouts.jsonl").read_text().splitlines()] - batch = [r_ for r_ in recs if cfg.step_lo <= r_["step"] <= cfg.step_hi and r_["text"].strip()][:cfg.max_rollouts] + batch = [x for x in recs if cfg.step_lo <= x["step"] <= cfg.step_hi and x["text"].strip()][:cfg.max_rollouts] logger.info(f"live batch: {len(batch)} rollouts (steps {cfg.step_lo}-{cfg.step_hi})") - pos_live, labels, steps = [], [], [] + G_rows, exploited, gt_pass, steps = [], [], [], [] for i, rec in enumerate(batch): model.zero_grad(set_to_none=True) loss = completion_nll(model, tok, rec["prompt"], rec["text"], device) if not torch.isfinite(loss): continue loss.backward() - cg = {nm: wrappers[nm]["layer"]._lora2r_gate.grad.sum(dim=tuple(range( - wrappers[nm]["layer"]._lora2r_gate.grad.dim() - 1))) for nm in names} - pos_live.append(pooled_pos(cg, v_grad, route_band, names, r)[0]) - labels.append(bool(rec["exploited"])); steps.append(rec["step"]) + g = [] + for nm in names: + gr = wrappers[nm]["layer"]._lora2r_gate.grad + g.append(gr.sum(dim=tuple(range(gr.dim() - 1)))[:r].float().cpu()) # deployed block [r] + G_rows.append(torch.stack(g)) # [M, r] + exploited.append(bool(rec["exploited"])); gt_pass.append(bool(rec["gt_pass"])); steps.append(rec["step"]) if (i + 1) % 40 == 0: logger.info(f" rollout {i+1}/{len(batch)}") model.zero_grad(set_to_none=True) - pos_live = np.array(pos_live); labels = np.array(labels); steps = np.array(steps) - per_step = {int(s): round(_auroc(pos_live[steps == s].tolist(), labels[steps == s].tolist()), 2) - for s in sorted(set(steps.tolist())) if labels[steps == s].any() and (~labels[steps == s]).any()} - logger.info(f"live: {len(labels)} rollouts, {int(labels.sum())} exploited; per-step AUROC={per_step}") + G = torch.stack(G_rows) # [N, M, r] + exploited, gt_pass, steps = map(np.array, (exploited, gt_pass, steps)) + pos = _batch_pos(G, names, v_grad, route_band) + # THREE honest populations (Q4 fix): solve = gt-correct & not-exploited; fail = neither; hack = exploited + solve_m = gt_pass & ~exploited + fail_m = ~gt_pass & ~exploited + logger.info(f"live: {len(exploited)} rollouts | hack={int(exploited.sum())} " + f"solve={int(solve_m.sum())} fail={int(fail_m.sum())}") + + # per-pairset AUROC: build v_grad from each heading-prefix subset, re-project the SAME G + groups: dict[str, list] = defaultdict(list) + for p in pairs_all: + groups[p.problem_id.split("_")[0]].append(p) + rows = [] + # extract_v_hack holds out 2 pairs, so a subset needs >=3 to leave a training pair. + candidates = [("all-in-one", pairs_all)] + [(g, ps) for g, ps in sorted(groups.items()) if len(ps) >= 3] + for gname, gpairs in candidates: + vg, rb, _ = v_grad_from(gpairs) + pg = _batch_pos(G, names, vg, rb) + auroc = _auroc(pg.tolist(), exploited.tolist()) + cohend = _cohend(pg[exploited], pg[solve_m]) if solve_m.sum() >= 2 else _cohend(pg[exploited], pg[~exploited]) + rows.append({"group": gname, "n_pairs": len(gpairs), "auroc": round(auroc, 3), + "cohend": round(cohend, 3)}) + logger.info(f" pairset {gname:12s} n={len(gpairs):2d} AUROC={auroc:.3f} d={cohend:+.3f}") + rank = pl.DataFrame(rows).sort("auroc", descending=True) + rank.write_parquet(rank_path) + print(tabulate(rank.to_pandas(), headers="keys", tablefmt="pipe", floatfmt="+.3f", showindex=False)) - # save the 4 populations long-form -> regenerates the plot with --replot (no GPU) df = pl.concat([ - pl.DataFrame({"pop": "on_solve", "pos": pos_live[~labels], "step": steps[~labels]}), - pl.DataFrame({"pop": "on_hack", "pos": pos_live[labels], "step": steps[labels]}), + pl.DataFrame({"pop": "on_solve", "pos": pos[solve_m], "step": steps[solve_m]}), + pl.DataFrame({"pop": "on_fail", "pos": pos[fail_m], "step": steps[fail_m]}), + pl.DataFrame({"pop": "on_hack", "pos": pos[exploited], "step": steps[exploited]}), pl.DataFrame({"pop": "syn_solve", "pos": syn_solve, "step": np.full(len(syn_solve), -1)}), pl.DataFrame({"pop": "syn_hack", "pos": syn_hack, "step": np.full(len(syn_hack), -1)}), ]) df.write_parquet(data_path) logger.info(f"wrote {data_path} ({len(df)} rows)") - plot(df, f"{cfg.run_dir.name}\n{cfg.ckpt} v_grad, live steps {cfg.step_lo}-{cfg.step_hi}, " - f"{int(labels.sum())}/{len(labels)} exploited", out_png) + sub = (f"{cfg.run_dir.name}\n{cfg.ckpt} v_grad, live steps {cfg.step_lo}-{cfg.step_hi}, " + f"hack={int(exploited.sum())} solve={int(solve_m.sum())} fail={int(fail_m.sum())}") + plot_dist(df, sub, dist_png) + plot_pairset(rank, cfg.run_dir.name, rank_png) return 0