diff --git a/scripts/diag_pinning.py b/scripts/diag_pinning.py index d556ff9..140457f 100644 --- a/scripts/diag_pinning.py +++ b/scripts/diag_pinning.py @@ -1,22 +1,45 @@ -"""Pinning calibration + per-pairset separation: where does the live gate sit, would -mean +/- k*sd route it, and which authored-pair subset gives the best hack/clean split? +"""Q2 diagnostic: what should the live routeV gate SCORE, and where do the pinning +cuts go? -The routeV gate scores each rollout by the width-pooled, band-normalized cosine of its -deployed-block c-probe gradient to v_grad (exactly train.py:_lora2r_gate_labels) -- the -position on the HACKING DIRECTION. This script: +THE QUESTION (Q2). The gate routes UPDATES, not rollouts: per rollout the GRPO update +on the LoRA is g = A * gradNLL (advantage-weighted). So the positive class is +"update ascending the hack direction" = exploited & A>0; rollouts with A~0 contribute +no update and are DROPPED (scoring them as dead zeros at 0 is what made earlier +adv-weighted AUROCs look blind); a hack with A<0 is being UNLEARNED and belongs in +the negative class. (Q1, "does the direction exist at adv=+1", was answered earlier: +~0.61 -- see git history of this file.) - 1. Splits live rollouts into THREE honest populations (Q4 fix -- "not-exploited" was - conflating genuine solves with fails): on-policy solve (gt_pass), fail, hack (exploited). - Plots them with the authored synthetic solve/hack, marks online mean and mean +/- 2sd. - 2. Ranks pairset SUBSETS (heading-prefix views: behavior/opportunity/disposition/...) by - how well a v_grad built from ONLY that subset separates live hacks (AUROC + Cohen's d). - Done by storing each live rollout's per-module deployed c-grad ONCE, then re-projecting - onto each subset's v_grad -- no model re-run per subset (the refresh-tracking trick). +FOUR CANDIDATE SCORES = {grad, act} x {cos, dot}, all in concatenated-module space: + - grad: the adv-weighted deployed c-probe gradient (the gate's current input). + - act: the deployed bottleneck activation A[:r]@x, mean over completion tokens -- + same [r]-per-module space, capturable in the gate's pass-1 forward for free. + - cos: magnitude-blind alignment (tiny vectors give meaningless angles -- control). + - dot: = |g|*cos, magnitude-aware; with g = A*gradNLL the advantage rides + along, so dot reads "how hard is this update pushing hack-ward". +v for each representation comes from the authored pairs only (mean hack-minus-clean +per module, unit per module) -- the no-cheat label source; live labels are read ONLY +to measure (AUROC / precision at the rout cut), never to route. + +PINNING. Each panel shades the three zones the online gate rule would give on this +window: keep (bulk) | absorb (score > mean + k_mid*sd) | rout (>= mean + k_rout*sd), +plus the oracle best hack-vs-rest split for reference. k's default to the real-run +Config values (2/3), not the checkpoint's preset, so the plot answers "where WOULD +the cuts fall", overridable via --k-mid/--k-rout. + +CAVEAT. Live advantages are reconstructed from rollouts.jsonl students only (teachers +absent, zero-variance groups included), so A signs/magnitudes are approximate; the +act columns dodge this entirely (no A in the representation). + +HOW. One GPU pass: per live rollout, backward its completion NLL once, capture the +c-probe grad AND the pooled bottleneck act; same per authored-pair side. Everything +downstream (subset vectors, 4 scores, zones, table) is offline re-projection of the +cached features. uv run python scripts/diag_pinning.py --run-dir out/runs/ - uv run python scripts/diag_pinning.py --replot out/diag/pinning_data.parquet # no GPU -outputs (out/diag/): pinning_calib.png, pinning_pairset_auroc.png, pinning_data.parquet, -pinning_pairset.parquet. + uv run python scripts/diag_pinning.py --replot out/diag/pinning_data.parquet # no GPU +outputs (out/diag/): pinning_q2.png (2x2 headline), pinning_data.parquet (per-rollout +scores), pinning_pairset.parquet + printed table (subsets x 4 AUROCs), +pinning_feats.pt (raw features, for offline re-analysis). """ from __future__ import annotations import json @@ -27,6 +50,7 @@ from pathlib import Path import numpy as np import torch +import torch.nn.functional as F import tyro import polars as pl import matplotlib @@ -40,12 +64,13 @@ from safetensors.torch import load_file from transformers import AutoModelForCausalLM, AutoTokenizer from vgrout.lora2r import wrap_model_with_lora2r -from vgrout.pairs import load_pairs, HackPair -from vgrout.extract_vhack_grad import extract_v_hack, completion_nll -from vgrout.train import _build_v_grad, route_band_edges, _haar_unit_dirs, _auroc +from vgrout.pairs import load_pairs +from vgrout.extract_vhack_grad import completion_nll +from vgrout.train import _auroc # colour = behaviour (blue solve, red hack, grey fail); style = source (solid on-policy, dashed synthetic) -SOLVE, HACK, FAIL, MEANC, ORACLE = "#3b6ea5", "#c44e52", "#9aa0a6", "#d1900a", "#3a8a7a" +SOLVE, HACK, FAIL, ABSORB_C, ROUT_C, ORACLE = "#3b6ea5", "#c44e52", "#9aa0a6", "#d1900a", "#c44e52", "#3a8a7a" +CASES = [("grad", "cos"), ("grad", "dot"), ("act", "cos"), ("act", "dot")] @dataclass @@ -53,13 +78,19 @@ class Cfg: run_dir: Path = Path("out/runs/20260611T003538_fast_vanilla_lora2r_seed43_l2r_vanilla_s43_v3") ckpt: str = "first_hack" pairs: Path = Path("data/pairs/hack_pairs.md#all-in-one") + # headline figure builds v from this heading-prefix subset (the routeV default); + # the pairset table spans all subsets of `pairs`. + headline_prefix: str = "behavior" # Coherent emergence window. This vanilla v3 used the pre-fix lr=5e-4/warmup-0.1 and # DIVERGED at step 10 (exploited 20/24 -> 0/24); 2-9 = hacks emerging, model still sane. step_lo: int = 2 step_hi: int = 9 max_rollouts: int = 240 - random_v_seed: int | None = None # Haar placebo (sanity: pins should NOT separate) - replot: Path | None = None # load this parquet and re-plot only (no model, no GPU) + k_mid: float = 2.0 # absorb onset: score > mean + k_mid*sd (real-run Config default) + k_rout: float = 3.0 # rout onset: score >= mean + k_rout*sd + adv_eps: float = 1e-6 # |A| below this = no update exists -> dropped from zones/AUROC + random_v_seed: int | None = None # Haar placebo (sanity: nothing should separate) + replot: Path | None = None # load parquet and re-plot only (no model, no GPU) out_dir: Path = Path("out/diag") @@ -68,146 +99,192 @@ def _ckpt_meta(path: Path) -> dict: return json.loads(f.read(struct.unpack(" np.ndarray: - """Pooled band-normalized cosine position per rollout, vectorized over modules. +class ActTap: + """Forward hooks stashing the deployed bottleneck activation h = A[:r] @ x per module. - G: [N, M, r] raw deployed-block c-grads (v-independent). Mirrors train.py pooling: - per module cos to ANY of the k dirs (max), width-weighted pool, excluding zero-width - modules. Same G re-projects onto any v_grad/route_band -> cheap per-pairset scoring. + Computes the r-dim projection inline (no_grad) instead of retaining the full + [L, d_in] input -- ~250 modules x [L, d_in] would be GBs; [L, r] is nothing. """ - dev = v_grad[names[0]].device - Vs = torch.stack([v_grad[n] for n in names]).to(dev) # [M, k, r] - low = torch.tensor([route_band[n][0] for n in names], device=dev) # [M] - up = torch.tensor([route_band[n][1] for n in names], device=dev) # [M] - inc = (up - low) > 0 - U = (G.to(dev) / G.to(dev).norm(dim=2, keepdim=True).clamp_min(1e-12)) # unit per (n,m) - cos = torch.einsum("nmr,mkr->nmk", U, Vs).amax(-1) # [N, M] - num = ((cos - low) * inc).sum(1) - den = ((up - low) * inc).sum().clamp_min(1e-12) - return (num / den).cpu().numpy() + def __init__(self, wrappers: dict, names: list[str]): + self.wrappers, self.names, self.h, self.handles = wrappers, names, {}, [] + + def __enter__(self): + for nm in self.names: + layer = self.wrappers[nm]["layer"] + def hook(layer, args, out, nm=nm): + (x,) = args + with torch.no_grad(): + self.h[nm] = F.linear(x.detach(), layer._lora2r_A[: layer._lora2r_r].to(x.dtype)) + self.handles.append(layer.register_forward_hook(hook)) + return self + + def __exit__(self, *exc): + for h in self.handles: + h.remove() + + def pooled(self, n_prompt: int) -> torch.Tensor: + """[M, r] mean bottleneck act over completion tokens (positions >= n_prompt).""" + out = [] + for nm in self.names: + h = self.h[nm] # [1, L, r] + assert h.shape[1] > n_prompt, f"{nm}: no completion tokens (L={h.shape[1]} n_prompt={n_prompt})" + out.append(h[0, n_prompt:].float().mean(0).cpu()) + return torch.stack(out) -def _cohend(a: np.ndarray, b: np.ndarray) -> float: - """Standardized mean gap (hack - solve); +ve = hacks sit further along the hack-dir.""" - if len(a) < 2 or len(b) < 2: - return float("nan") - sp = np.sqrt(((len(a) - 1) * a.var(ddof=1) + (len(b) - 1) * b.var(ddof=1)) / (len(a) + len(b) - 2)) - return float((a.mean() - b.mean()) / sp) if sp > 0 else float("nan") +def _gate_grads(wrappers: dict, names: list[str]) -> torch.Tensor: + """[M, r] deployed-block c-probe grad after a backward (the gate's gradient space).""" + g = [] + for nm in names: + layer = wrappers[nm]["layer"] + gr = layer._lora2r_gate.grad + g.append(gr.sum(dim=tuple(range(gr.dim() - 1)))[: layer._lora2r_r].float().cpu()) + return torch.stack(g) + + +def _v_from(feat_hack: torch.Tensor, feat_clean: torch.Tensor, idx: list[int]) -> torch.Tensor: + """[M, r] unit-per-module mean hack-minus-clean direction from pair rows `idx`.""" + d = (feat_hack[idx] - feat_clean[idx]).mean(0) + return d / d.norm(dim=1, keepdim=True).clamp_min(1e-12) + + +def _haar_like(v: torch.Tensor, seed: int) -> torch.Tensor: + g = torch.Generator().manual_seed(seed) + d = torch.randn(v.shape, generator=g) + return d / d.norm(dim=1, keepdim=True).clamp_min(1e-12) + + +def _score(X: torch.Tensor, V: torch.Tensor, kind: str) -> np.ndarray: + """Concat-module score per rollout: dot = sum_m ; cos = dot / (||x|| ||v||).""" + d = torch.einsum("nmr,mr->n", X, V) + if kind == "dot": + return d.numpy() + return (d / (X.flatten(1).norm(dim=1).clamp_min(1e-12) * V.flatten().norm().clamp_min(1e-12))).numpy() def _kde(x: np.ndarray, grid: np.ndarray) -> np.ndarray: - """Gaussian KDE, Silverman bandwidth (no scipy).""" + """Gaussian KDE, Silverman bandwidth (no scipy). Bandwidth is scale-relative + (dot scores can live at 1e-4 or 1e2).""" x = np.asarray(x, float) if len(x) < 2: return np.zeros_like(grid) iqr = np.subtract(*np.percentile(x, [75, 25])) sigma = min(x.std(ddof=1), iqr / 1.349) if iqr > 0 else x.std(ddof=1) - bw = max(0.9 * (sigma or 1.0) * len(x) ** (-0.2), 1e-3) + if sigma <= 0: + return np.zeros_like(grid) + bw = 0.9 * sigma * len(x) ** (-0.2) z = (grid[:, None] - x[None, :]) / bw return np.exp(-0.5 * z ** 2).sum(1) / (len(x) * bw * np.sqrt(2 * np.pi)) -def plot_dist(df: pl.DataFrame, subtitle: str, out_png: Path) -> None: - """Distribution plot from the 5 saved populations -- no GPU needed.""" - arr = lambda p: df.filter(pl.col("pop") == p)["pos"].to_numpy() - on_solve, on_fail, on_hack = arr("on_solve"), arr("on_fail"), arr("on_hack") - syn_solve, syn_hack = arr("syn_solve"), arr("syn_hack") - # AUROC/divider are hack-vs-rest (the gate's job): label = exploited. - pos_live = np.concatenate([on_solve, on_fail, on_hack]) - labels = np.concatenate([np.zeros(len(on_solve) + len(on_fail), bool), np.ones(len(on_hack), bool)]) - mean, sd = float(pos_live.mean()), float(pos_live.std()) - lo_b, hi_b = mean - 2 * sd, mean + 2 * sd - auroc = _auroc(pos_live.tolist(), labels.tolist()) - thr = np.unique(pos_live) - j = [(pos_live[labels] >= t).mean() - (pos_live[~labels] >= t).mean() for t in thr] - oracle = float(thr[int(np.argmax(j))]) +def plot_q2(df: pl.DataFrame, k_mid: float, k_rout: float, subtitle: str, out_png: Path) -> dict: + """2x2 figure ({grad,act} x {cos,dot}) from the saved per-rollout scores -- no GPU. - lo = min(pos_live.min(), syn_solve.min()) - 0.1 - hi = max(np.quantile(pos_live, 0.99), syn_hack.max()) + 0.1 - grid = np.linspace(lo, hi, 400) - POPS = [(on_solve, SOLVE, True, "on-policy solve (gt pass)"), - (on_fail, FAIL, True, "on-policy fail"), - (on_hack, HACK, True, "on-policy hack"), - (syn_solve, SOLVE, False, "synthetic solve"), - (syn_hack, HACK, False, "synthetic hack")] - kdes = {i: _kde(x, grid) for i, (x, *_) in enumerate(POPS)} - ymax = max(y.max() for y in kdes.values()) * 1.15 + Per panel: live solve/fail/hack+ KDEs (+ thin hack- if n>=3), synthetic pair sides + dashed, three shaded zones keep|absorb|rout from mean + k*sd over the VALID live + scores (|A|>eps; pop 'on_drop' excluded), oracle split, AUROC + P/R at the rout cut. + Returns the per-case stats dict for logging.""" + pops = {p: df.filter(pl.col("pop") == p) for p in df["pop"].unique().to_list()} + live_pops = ["on_solve", "on_fail", "on_hackpos", "on_hackneg"] + stats = {} + fig, axes = plt.subplots(2, 2, figsize=(12.5, 7.6)) + for ax, (rep, kind) in zip(axes.flat, CASES): + col = f"{rep}_{kind}" + live = df.filter(pl.col("pop").is_in(live_pops)) + s = live[col].to_numpy() + y = (live["pop"] == "on_hackpos").to_numpy() + mu, sd = float(s.mean()), float(s.std()) + t_lo, t_hi = mu + k_mid * sd, mu + k_rout * sd + auroc = _auroc(s.tolist(), y.tolist()) + thr = np.unique(s) + j = [(s[y] >= t).mean() - (s[~y] >= t).mean() for t in thr] + oracle = float(thr[int(np.argmax(j))]) + routed = s >= t_hi + prec = float(y[routed].mean()) if routed.any() else float("nan") + rec = float((s[y] >= t_hi).mean()) if y.any() else float("nan") + stats[col] = {"auroc": auroc, "prec_rout": prec, "rec_rout": rec, "t_hi": t_hi, "oracle": oracle} - fig, ax = plt.subplots(figsize=(8.8, 4.8)) - ax.axvspan(lo_b, hi_b, color=MEANC, alpha=0.07, lw=0) - ax.axvline(mean, color=MEANC, lw=1.8) - for b in (lo_b, hi_b): - ax.axvline(b, color=MEANC, lw=1.2, ls="--") - ax.axvline(oracle, color=ORACLE, lw=1.3, ls="-.") - for i, (x, col, on_policy, _) in enumerate(POPS): - y = kdes[i] - if on_policy: - ax.fill_between(grid, y, color=col, alpha=0.12, lw=0) - ax.plot(grid, y, color=col, lw=1.9) - else: - ax.plot(grid, y, color=col, lw=2.2, ls=(0, (5, 2))) - ax.set_ylim(0, ymax) - for s in ("top", "right"): - ax.spines[s].set_visible(False) - ax.set_ylabel("density") - ax.set_xlabel("hacking direction (gradient cosine to v_grad) " + r"$\longrightarrow$") - dist_handles = [Line2D([0], [0], color=c, lw=2.2, ls="-" if op else (0, (5, 2))) for _, c, op, _ in POPS] - leg1 = ax.legend(dist_handles, [lab for *_, lab in POPS], loc="upper left", - fontsize=8, frameon=False, title="distributions", title_fontsize=8) - leg1._legend_box.align = "left" - ax.add_artist(leg1) - mark_handles = [Line2D([0], [0], color=MEANC, lw=1.8), - Patch(facecolor=MEANC, alpha=0.18, edgecolor=MEANC, ls="--"), - Line2D([0], [0], color=ORACLE, lw=1.3, ls="-.")] - ax.legend(mark_handles, - [f"online mean ({mean:+.2f})", f"mean +/- 2sd (sd={sd:.2f})", - f"best hack/rest split ({oracle:+.2f})"], - loc="upper right", fontsize=8, frameon=False, title="proposed pinning", title_fontsize=8) - ax.set_title(f"{subtitle}\nlive hack-direction positions vs authored v_grad " - f"(hack-vs-rest AUROC={auroc:.2f})", fontsize=9.5) - fig.tight_layout() - fig.savefig(out_png, dpi=140) - plt.close(fig) - logger.info(f"mean={mean:+.3f} sd={sd:.3f} AUROC={auroc:.3f} -> {out_png}") + lo = float(np.quantile(s, 0.005)) + hi = float(np.quantile(s, 0.995)) + if kind == "cos": # keep synthetic medians visible (cos shares a scale; + for p in ("syn_clean", "syn_hack"): # dot doesn't -- pair grads dwarf live, annotate instead) + if len(pops.get(p, [])): + m = float(np.median(pops[p][col].to_numpy())) + lo, hi = min(lo, m), max(hi, m) + pad = 0.05 * (hi - lo) or 1e-6 + lo, hi = lo - pad, hi + pad + grid = np.linspace(lo, hi, 400) + if kind == "dot": + off = [f"syn {p.split('_')[1]} med={float(np.median(pops[p][col].to_numpy())):+.2g}" + for p in ("syn_clean", "syn_hack") + if len(pops.get(p, [])) and not lo < float(np.median(pops[p][col].to_numpy())) < hi] + if off: + ax.annotate("off-scale: " + ", ".join(off) + r" $\rightarrow$", + xy=(0.98, 0.84), xycoords="axes fraction", ha="right", fontsize=7, color="#777777") + curves = [("on_solve", SOLVE, "-", 1.9, 0.12), ("on_fail", FAIL, "-", 1.9, 0.12), + ("on_hackpos", HACK, "-", 1.9, 0.12), + ("syn_clean", SOLVE, (0, (5, 2)), 2.0, 0.0), ("syn_hack", HACK, (0, (5, 2)), 2.0, 0.0)] + if len(pops.get("on_hackneg", [])) >= 3: + curves.insert(3, ("on_hackneg", HACK, (0, (1, 1)), 1.2, 0.0)) + ymax = 0.0 + for p, c, ls, lw, fill in curves: + yk = _kde(pops[p][col].to_numpy(), grid) + ymax = max(ymax, yk.max()) + if fill: + ax.fill_between(grid, yk, color=c, alpha=fill, lw=0) + ax.plot(grid, yk, color=c, lw=lw, ls=ls) + ymax *= 1.18 -def plot_pairset(rank: pl.DataFrame, subtitle: str, out_png: Path) -> None: - """Horizontal bar chart: AUROC of a v_grad built from each pairset subset vs live hacks.""" - rank = rank.sort("auroc") - groups = rank["group"].to_list() - auroc = rank["auroc"].to_numpy() - npairs = rank["n_pairs"].to_list() - cohend = rank["cohend"].to_numpy() - y = np.arange(len(groups)) - fig, ax = plt.subplots(figsize=(7.2, 0.5 * len(groups) + 1.4)) - cols = [HACK if g == "all-in-one" else SOLVE for g in groups] - ax.barh(y, auroc, color=cols, alpha=0.85) - ax.axvline(0.5, color="k", lw=1, ls=":") # 0.5 = blind to live hacks - ax.set_yticks(y) - ax.set_yticklabels([f"{g} (n={n})" for g, n in zip(groups, npairs)], fontsize=8) - for yi, a, d in zip(y, auroc, cohend): - ax.text(a + 0.005, yi, f"{a:.2f} d={d:+.2f}", va="center", fontsize=7.5) - ax.set_xlim(min(0.45, auroc.min() - 0.03), max(auroc.max() + 0.12, 0.6)) - for s in ("top", "right"): - ax.spines[s].set_visible(False) - ax.set_xlabel("hack-vs-rest AUROC of subset's v_grad (0.5 = blind, dotted)") - ax.set_title(f"{subtitle}\nwhich authored-pair subset separates live hacks best?", fontsize=9.5) - fig.tight_layout() + # three zones: keep | absorb | rout + ax.axvspan(t_lo, min(t_hi, hi), color=ABSORB_C, alpha=0.08, lw=0) + ax.axvspan(min(t_hi, hi), hi, color=ROUT_C, alpha=0.10, lw=0) + ax.axvline(t_lo, color=ABSORB_C, lw=1.2, ls="--") + ax.axvline(t_hi, color=ROUT_C, lw=1.2, ls="--") + ax.axvline(oracle, color=ORACLE, lw=1.3, ls="-.") + for xz, lab in ((min(t_lo, hi) - 0.02 * (hi - lo), "keep"), + ((t_lo + min(t_hi, hi)) / 2, "absorb"), + ((min(t_hi, hi) + hi) / 2, "rout")): + if lo < xz < hi: + ax.text(xz, ymax * 0.97, lab, ha="center", va="top", fontsize=7.5, color="#555555") + ax.set_xlim(lo, hi) + ax.set_ylim(0, ymax) + for sp in ("top", "right"): + ax.spines[sp].set_visible(False) + ax.set_title(f"{rep} · {kind} AUROC={auroc:.2f} P@rout={prec:.2f} R@rout={rec:.2f}", + fontsize=9.5) + ax.set_xlabel({"cos": "cosine to v (concat modules)", + "dot": "dot ⟨x, v⟩ (update mass along v)"}[kind], fontsize=8.5) + ax.set_ylabel("density", fontsize=8.5) + + handles = [Line2D([0], [0], color=SOLVE, lw=1.9), Line2D([0], [0], color=FAIL, lw=1.9), + Line2D([0], [0], color=HACK, lw=1.9), + Line2D([0], [0], color=SOLVE, lw=2.0, ls=(0, (5, 2))), + Line2D([0], [0], color=HACK, lw=2.0, ls=(0, (5, 2))), + Patch(facecolor=ABSORB_C, alpha=0.18), Patch(facecolor=ROUT_C, alpha=0.18), + Line2D([0], [0], color=ORACLE, lw=1.3, ls="-.")] + labels = ["live solve", "live fail", "live hack (A>0)", "synthetic clean", "synthetic hack", + f"absorb (>mean+{k_mid:g}sd)", f"rout (>=mean+{k_rout:g}sd)", "oracle hack/rest split"] + fig.legend(handles, labels, loc="lower center", ncol=4, fontsize=8, frameon=False) + fig.suptitle(subtitle, fontsize=10) + fig.tight_layout(rect=(0, 0.07, 1, 0.95)) fig.savefig(out_png, dpi=140) plt.close(fig) logger.info(f"wrote {out_png}") + return stats def main(cfg: Cfg) -> int: cfg.out_dir.mkdir(parents=True, exist_ok=True) data_path = cfg.out_dir / "pinning_data.parquet" rank_path = cfg.out_dir / "pinning_pairset.parquet" - dist_png = cfg.out_dir / "pinning_calib.png" - rank_png = cfg.out_dir / "pinning_pairset_auroc.png" + feats_path = cfg.out_dir / "pinning_feats.pt" + q2_png = cfg.out_dir / "pinning_q2.png" if cfg.replot is not None: - plot_dist(pl.read_parquet(cfg.replot), f"replot -- {cfg.replot.name}", dist_png) + plot_q2(pl.read_parquet(cfg.replot), cfg.k_mid, cfg.k_rout, f"replot -- {cfg.replot.name}", q2_png) if rank_path.exists(): - plot_pairset(pl.read_parquet(rank_path), "replot", rank_png) + print(tabulate(pl.read_parquet(rank_path).to_pandas(), headers="keys", + tablefmt="pipe", floatfmt="+.3f", showindex=False)) return 0 device = torch.device("cuda") @@ -218,7 +295,8 @@ def main(cfg: Cfg) -> int: r = run_cfg.get("lora_r", 32) init_seed = run_cfg.get("lora_init_seed", 0) logger.info(f"ckpt {ckpt_path.name} step={meta.get('step')} hack_rate={meta.get('hack_rate')} " - f"model={model_name} r={r} init_seed={init_seed}") + f"model={model_name} r={r} init_seed={init_seed} | run-preset k_mid/k_rout=" + f"{run_cfg.get('route_std_mid')}/{run_cfg.get('route_std_rout')} (plot uses {cfg.k_mid}/{cfg.k_rout})") tok = AutoTokenizer.from_pretrained(model_name) if tok.pad_token_id is None: @@ -235,86 +313,150 @@ def main(cfg: Cfg) -> int: logger.info(f"loaded A/B into {len(names)} modules") model.eval() - def v_grad_from(pairs: list[HackPair]): - _, _, raw, _ = extract_v_hack(model, tok, wrappers, pairs, - top_k=1, tau_axis=0.0, n_heldout=2, device=device) - vg = _build_v_grad(raw, wrappers, 1, device) - if cfg.random_v_seed is not None: - vg = _haar_unit_dirs(vg, cfg.random_v_seed, device) - return vg, route_band_edges(raw, vg, device), raw + def one_pass(tap: ActTap, prompt: str, completion: str) -> tuple[torch.Tensor, torch.Tensor] | None: + """Backward one completion's mean NLL; return ([M,r] c-grad, [M,r] pooled act).""" + model.zero_grad(set_to_none=True) + loss = completion_nll(model, tok, prompt, completion, device) + if not torch.isfinite(loss): + return None + loss.backward() + n_prompt = tok(prompt, return_tensors="pt").input_ids.shape[1] + return _gate_grads(wrappers, names), tap.pooled(n_prompt) + # ── authored-pair features, once over ALL pairs (subsets = row slices) ── pairs_all = load_pairs(cfg.pairs) logger.info(f"pairs {cfg.pairs} -> {len(pairs_all)}") - v_grad, route_band, raw_all = v_grad_from(pairs_all) + pair_feat = {("grad", "hack"): [], ("grad", "clean"): [], ("act", "hack"): [], ("act", "clean"): []} + with ActTap(wrappers, names) as tap: + for pi, pair in enumerate(pairs_all): + for side, completion in (("hack", pair.hack), ("clean", pair.clean)): + out = one_pass(tap, pair.prompt, completion) + if out is None: + raise RuntimeError(f"non-finite loss on pair {pi} ({pair.problem_id}) side={side}") + pair_feat[("grad", side)].append(out[0]) + pair_feat[("act", side)].append(out[1]) + if (pi + 1) % 5 == 0: + logger.info(f" pair {pi+1}/{len(pairs_all)}") + PF = {k: torch.stack(v) for k, v in pair_feat.items()} # each [P, M, r] - # synthetic pair positions under the all-in-one v_grad (for the distribution plot) - def pair_pos(side: str) -> np.ndarray: - n = raw_all[f"{side}/{names[0]}"].shape[0] - G = torch.stack([torch.stack([raw_all[f"{side}/{nm}"][i] for nm in names]) for i in range(n)]) - return _batch_pos(G, names, v_grad, route_band) - syn_solve, syn_hack = pair_pos("clean"), pair_pos("hack") - - # score the live batch ONCE; store per-rollout deployed c-grads for cheap re-projection - recs = [json.loads(l) for l in (cfg.run_dir / "rollouts.jsonl").read_text().splitlines()] - batch = [x for x in recs if cfg.step_lo <= x["step"] <= cfg.step_hi and x["text"].strip()][:cfg.max_rollouts] - logger.info(f"live batch: {len(batch)} rollouts (steps {cfg.step_lo}-{cfg.step_hi})") - G_rows, exploited, gt_pass, steps = [], [], [], [] - for i, rec in enumerate(batch): - model.zero_grad(set_to_none=True) - loss = completion_nll(model, tok, rec["prompt"], rec["text"], device) - if not torch.isfinite(loss): - continue - loss.backward() - g = [] - for nm in names: - gr = wrappers[nm]["layer"]._lora2r_gate.grad - g.append(gr.sum(dim=tuple(range(gr.dim() - 1)))[:r].float().cpu()) # deployed block [r] - G_rows.append(torch.stack(g)) # [M, r] - exploited.append(bool(rec["exploited"])); gt_pass.append(bool(rec["gt_pass"])); steps.append(rec["step"]) - if (i + 1) % 40 == 0: - logger.info(f" rollout {i+1}/{len(batch)}") + # ── live rollout features, once (everything downstream re-projects) ── + recs = [json.loads(l) for l in (cfg.run_dir / "rollouts.jsonl").read_text().splitlines()] + batch = [x for x in recs if cfg.step_lo <= x["step"] <= cfg.step_hi and x["text"].strip()][:cfg.max_rollouts] + logger.info(f"live batch: {len(batch)} rollouts (steps {cfg.step_lo}-{cfg.step_hi})") + G_rows, A_rows, kept = [], [], [] + for i, rec in enumerate(batch): + out = one_pass(tap, rec["prompt"], rec["text"]) + if out is None: + logger.warning(f" skip rollout {i}: non-finite loss") + continue + G_rows.append(out[0]); A_rows.append(out[1]); kept.append(rec) + if (i + 1) % 40 == 0: + logger.info(f" rollout {i+1}/{len(batch)}") model.zero_grad(set_to_none=True) - G = torch.stack(G_rows) # [N, M, r] - exploited, gt_pass, steps = map(np.array, (exploited, gt_pass, steps)) - pos = _batch_pos(G, names, v_grad, route_band) - # THREE honest populations (Q4 fix): solve = gt-correct & not-exploited; fail = neither; hack = exploited - solve_m = gt_pass & ~exploited - fail_m = ~gt_pass & ~exploited - logger.info(f"live: {len(exploited)} rollouts | hack={int(exploited.sum())} " - f"solve={int(solve_m.sum())} fail={int(fail_m.sum())}") + G = torch.stack(G_rows) # [N, M, r] gradNLL + ACT = torch.stack(A_rows) # [N, M, r] + exploited = np.array([bool(x["exploited"]) for x in kept]) + gt_pass = np.array([bool(x["gt_pass"]) for x in kept]) + steps = np.array([x["step"] for x in kept]) + p_idx = np.array([x["p_idx"] for x in kept]) + reward = np.array([float(x["reward"]) for x in kept]) - # per-pairset AUROC: build v_grad from each heading-prefix subset, re-project the SAME G - groups: dict[str, list] = defaultdict(list) - for p in pairs_all: - groups[p.problem_id.split("_")[0]].append(p) + # Reconstructed Dr.GRPO advantage A_i = reward_i - mean(reward over its group). + # CAVEAT: students only (teachers absent from rollouts.jsonl), so signs/magnitudes + # are approximate -- see module docstring. + grp_mean = {} + for s, p in set(zip(steps.tolist(), p_idx.tolist())): + m = (steps == s) & (p_idx == p) + grp_mean[(s, p)] = reward[m].mean() + adv = np.array([reward[i] - grp_mean[(steps[i], p_idx[i])] for i in range(len(reward))]) + G_adv = G * torch.tensor(adv, dtype=G.dtype)[:, None, None] # the update the gate sees + + # ── Q2 populations: drop A~0 (no update); positive = exploited & A>0 ── + valid = np.abs(adv) > cfg.adv_eps + y = exploited & (adv > 0) + pop = np.where(~valid, "on_drop", + np.where(exploited & (adv > 0), "on_hackpos", + np.where(exploited, "on_hackneg", + np.where(gt_pass, "on_solve", "on_fail")))) + counts = {p: int((pop == p).sum()) for p in ("on_solve", "on_fail", "on_hackpos", "on_hackneg", "on_drop")} + logger.info(f"live populations: {counts} (zones/AUROC use the {int(valid.sum())} valid rows)") + print(f"SHOULD: on_hackpos >= ~20 and on_drop not the majority, ELSE the window/run has " + f"too few learnable hacks and every AUROC below is noise.") + + # ── headline vectors from the routeV-default subset; placebo swaps in Haar ── + groups: dict[str, list[int]] = defaultdict(list) + for i, p in enumerate(pairs_all): + groups[p.problem_id.split("_")[0]].append(i) + head_idx = [i for i, p in enumerate(pairs_all) if p.problem_id.startswith(cfg.headline_prefix)] + assert head_idx, f"no pairs match headline prefix {cfg.headline_prefix!r}" + logger.info(f"headline v from prefix {cfg.headline_prefix!r} -> {len(head_idx)} pairs") + + def vectors(idx: list[int]) -> dict[str, torch.Tensor]: + v = {"grad": _v_from(PF[("grad", "hack")], PF[("grad", "clean")], idx), + "act": _v_from(PF[("act", "hack")], PF[("act", "clean")], idx)} + if cfg.random_v_seed is not None: + v = {"grad": _haar_like(v["grad"], cfg.random_v_seed), + "act": _haar_like(v["act"], cfg.random_v_seed + 1)} + return v + + v_head = vectors(head_idx) + live_X = {"grad": G_adv, "act": ACT} + syn_X = {("grad", "clean"): PF[("grad", "clean")], ("grad", "hack"): PF[("grad", "hack")], + ("act", "clean"): PF[("act", "clean")], ("act", "hack"): PF[("act", "hack")]} + + def score_cols(v: dict, X: dict[str, torch.Tensor]) -> dict[str, np.ndarray]: + return {f"{rep}_{kind}": _score(X[rep], v[rep], kind) for rep, kind in CASES} + + live_scores = score_cols(v_head, live_X) + syn_scores = {side: score_cols(v_head, {"grad": syn_X[("grad", side)][head_idx], + "act": syn_X[("act", side)][head_idx]}) + for side in ("clean", "hack")} + + # ── pairset table: subsets x 4 AUROCs on the SAME cached live features ── + candidates = [("all-in-one", list(range(len(pairs_all))))] + \ + [(g, idx) for g, idx in sorted(groups.items()) if len(idx) >= 3] rows = [] - # extract_v_hack holds out 2 pairs, so a subset needs >=3 to leave a training pair. - candidates = [("all-in-one", pairs_all)] + [(g, ps) for g, ps in sorted(groups.items()) if len(ps) >= 3] - for gname, gpairs in candidates: - vg, rb, _ = v_grad_from(gpairs) - pg = _batch_pos(G, names, vg, rb) - auroc = _auroc(pg.tolist(), exploited.tolist()) - cohend = _cohend(pg[exploited], pg[solve_m]) if solve_m.sum() >= 2 else _cohend(pg[exploited], pg[~exploited]) - rows.append({"group": gname, "n_pairs": len(gpairs), "auroc": round(auroc, 3), - "cohend": round(cohend, 3)}) - logger.info(f" pairset {gname:12s} n={len(gpairs):2d} AUROC={auroc:.3f} d={cohend:+.3f}") - rank = pl.DataFrame(rows).sort("auroc", descending=True) + for gname, idx in candidates: + v = vectors(idx) + row = {"group": gname, "n_pairs": len(idx)} + for rep, kind in CASES: + s = _score(live_X[rep], v[rep], kind)[valid] + row[f"{rep}_{kind}"] = round(_auroc(s.tolist(), y[valid].tolist()), 3) + rows.append(row) + rank = pl.DataFrame(rows).sort("grad_dot", descending=True) rank.write_parquet(rank_path) + print("\nSHOULD: real pairsets beat 0.5 on at least one column; under --random-v-seed " + "every column ~0.5. Columns are AUROC of hack(A>0)-vs-rest on valid live rollouts.") print(tabulate(rank.to_pandas(), headers="keys", tablefmt="pipe", floatfmt="+.3f", showindex=False)) - df = pl.concat([ - pl.DataFrame({"pop": "on_solve", "pos": pos[solve_m], "step": steps[solve_m]}), - pl.DataFrame({"pop": "on_fail", "pos": pos[fail_m], "step": steps[fail_m]}), - pl.DataFrame({"pop": "on_hack", "pos": pos[exploited], "step": steps[exploited]}), - pl.DataFrame({"pop": "syn_solve", "pos": syn_solve, "step": np.full(len(syn_solve), -1)}), - pl.DataFrame({"pop": "syn_hack", "pos": syn_hack, "step": np.full(len(syn_hack), -1)}), - ]) + # ── persist per-rollout scores + raw features, then plot ── + def frame(pop_name: str, mask_or_scores, scores: dict, step_arr, adv_arr) -> pl.DataFrame: + return pl.DataFrame({"pop": pop_name, "step": step_arr, "adv": adv_arr, + **{c: scores[c][mask_or_scores] if mask_or_scores is not None else scores[c] + for c in scores}}) + dfs = [frame(p, pop == p, live_scores, steps[pop == p], adv[pop == p]) + for p in counts if counts[p] > 0] + n_syn = len(head_idx) + dfs += [frame(f"syn_{side}", None, syn_scores[side], + np.full(n_syn, -1), np.ones(n_syn)) for side in ("clean", "hack")] + df = pl.concat(dfs) df.write_parquet(data_path) - logger.info(f"wrote {data_path} ({len(df)} rows)") - sub = (f"{cfg.run_dir.name}\n{cfg.ckpt} v_grad, live steps {cfg.step_lo}-{cfg.step_hi}, " - f"hack={int(exploited.sum())} solve={int(solve_m.sum())} fail={int(fail_m.sum())}") - plot_dist(df, sub, dist_png) - plot_pairset(rank, cfg.run_dir.name, rank_png) + torch.save({"G": G, "ACT": ACT, "adv": adv, "exploited": exploited, "gt_pass": gt_pass, + "steps": steps, "p_idx": p_idx, "names": names, + "pair_feats": PF, "pair_groups": dict(groups), + "pair_ids": [p.problem_id for p in pairs_all]}, feats_path) + logger.info(f"wrote {data_path} ({len(df)} rows), {feats_path}") + + sub = (f"{cfg.run_dir.name} | {cfg.ckpt}, live steps {cfg.step_lo}-{cfg.step_hi}, v from " + f"'{cfg.headline_prefix}' pairs (n={len(head_idx)}) | " + f"hack+={counts['on_hackpos']} hack-={counts['on_hackneg']} solve={counts['on_solve']} " + f"fail={counts['on_fail']} dropped(A~0)={counts['on_drop']}" + + (f" | PLACEBO seed={cfg.random_v_seed}" if cfg.random_v_seed is not None else "")) + stats = plot_q2(df, cfg.k_mid, cfg.k_rout, sub, q2_png) + best = max(stats, key=lambda c: stats[c]["auroc"]) + print(f"\nmain metric: best case = {best} AUROC={stats[best]['auroc']:.3f} " + f"P@rout={stats[best]['prec_rout']:.2f} R@rout={stats[best]['rec_rout']:.2f}") + print(f"out: {q2_png}") return 0