From 24aea19beea9c40ea82916684efccf92c39704c1 Mon Sep 17 00:00:00 2001 From: wassname <1103714+wassname@users.noreply.github.com> Date: Thu, 11 Jun 2026 10:58:16 +0000 Subject: [PATCH] diag(#40): offline follow-up -- pooling variants, synthetic common-mode, env_mode join Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com> --- scripts/diag_pinning_followup.py | 185 +++++++++++++++++++++++++++++++ 1 file changed, 185 insertions(+) create mode 100644 scripts/diag_pinning_followup.py diff --git a/scripts/diag_pinning_followup.py b/scripts/diag_pinning_followup.py new file mode 100644 index 0000000..2b7c839 --- /dev/null +++ b/scripts/diag_pinning_followup.py @@ -0,0 +1,185 @@ +"""Follow-up to diag_pinning.py, all offline from cached pinning_feats.pt (no GPU). + +Three questions raised on the 2026-06-11 Q2 results: + +Q-A (module weighting). _score concatenates modules with v unit-normalized PER +module, so each module's contribution is ||x_m|| * cos_m: the live feature norm is +the implicit weight and the pair-separation magnitude per module is discarded. Is +that hurting? Compare poolings: (concat, unit-v) vs (concat, raw-diff v = modules +weighted by pair separation) vs (equal-weight mean of per-module cosines), and for +resid score each layer alone (residual norms grow with depth, so concat may be +mostly the deepest layer). + +Q-B (synthetic vs live apples-to-apples). Live grad scores use G*adv (|adv| ~ 0.2, +sign flips for adv<0); synthetic pair sides are scored as raw gradNLL (implicit +adv=+1). On cos panels that is comparable up to sign, BUT the synthetic medians sit +off zero while live pops straddle it. Test: score raw live G (no adv) per pop, and +a common-mode-centered variant (subtract the mean pair feature from both synthetic +sides and live) -- if centering restores hack/clean symmetry the offset is a shared +component (authored-pair style/NLL gradient), not a scoring bug. + +Q-C (multimodality = loophole modes?). rollouts.jsonl carries env_mode per rollout. +Label each hack+ rollout by mode and place the modes on the score axis: if the hack +KDE bumps are modes, per-mode score means separate. + + uv run python scripts/diag_pinning_followup.py +outputs: printed tables + out/diag/pinning_followup_modes.png +""" +from __future__ import annotations +import json +from pathlib import Path + +import numpy as np +import torch +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +from tabulate import tabulate + +from vgrout.train import _auroc + +ROOT = Path("/workspace/projected_grpo") +RUNS = { + "v3": (ROOT / "out/diag", ROOT / "out/runs/20260611T003538_fast_vanilla_lora2r_seed43_l2r_vanilla_s43_v3"), + "v4": (ROOT / "out/diag_v4", ROOT / "out/runs/20260611T022655_fast_vanilla_lora2r_seed43_l2r_vanilla_s43_v4"), + "v5": (ROOT / "out/diag_v5", ROOT / "out/runs/20260611T055637_fast_vanilla_lora2r_seed43_l2r_vanilla_s43_v5"), +} +HEAD_PREFIX = "behavior_" +MODE_COLORS = {"run_tests": "#c44e52", "sentinel": "#d1900a", "stdout_marker": "#3a8a7a", "file_marker": "#7a5aa0"} + + +def unit_rows(v: torch.Tensor) -> torch.Tensor: + return v / v.norm(dim=-1, keepdim=True).clamp_min(1e-12) + + +def cos_concat(X: torch.Tensor, V: torch.Tensor) -> np.ndarray: + d = torch.einsum("nmr,mr->n", X, V) + return (d / (X.flatten(1).norm(dim=1).clamp_min(1e-12) * V.flatten().norm())).numpy() + + +def cos_equal(X: torch.Tensor, V: torch.Tensor) -> np.ndarray: + """Equal-weight mean over modules of the per-module cosine.""" + Vu = unit_rows(V) + c = torch.einsum("nmr,mr->nm", unit_rows(X), Vu) + return c.mean(1).numpy() + + +def load_mode_labels(fe: dict, run_dir: Path) -> np.ndarray: + """env_mode per kept rollout, aligned by order-preserving (step, p_idx) match + against the same filter diag_pinning used (step window, nonempty text, cap 240).""" + steps, p_idx = fe["steps"], fe["p_idx"] + lo, hi = int(steps.min()), int(steps.max()) + recs = [json.loads(l) for l in (run_dir / "rollouts.jsonl").read_text().splitlines()] + batch = [x for x in recs if lo <= x["step"] <= hi and x["text"].strip()][:240] + modes, bi = [], 0 + for s, p in zip(steps.tolist(), p_idx.tolist()): + while not (batch[bi]["step"] == s and batch[bi]["p_idx"] == p): + bi += 1 # rollout was skipped (non-finite loss) in the diag pass + modes.append(batch[bi]["env_mode"]) + bi += 1 + return np.array(modes) + + +def main() -> int: + fig, axes = plt.subplots(len(RUNS), 1, figsize=(9, 2.4 * len(RUNS)), sharex=False) + pool_rows, layer_rows, syn_rows, mode_rows = [], [], [], [] + for ax, (tag, (diag_dir, run_dir)) in zip(np.atleast_1d(axes), RUNS.items()): + fe = torch.load(diag_dir / "pinning_feats.pt", weights_only=False) + G, ACT, RES, adv = fe["G"], fe["ACT"], fe["RES"], fe["adv"] + exploited, gt_pass = fe["exploited"], fe["gt_pass"] + PF = fe["pair_feats"] + head = [i for i, pid in enumerate(fe["pair_ids"]) if pid.startswith(HEAD_PREFIX)] + valid = np.abs(adv) > 1e-6 + pos = valid & (adv > 0) + y = exploited & (adv > 0) + au = lambda s: _auroc(s[pos].tolist(), y[pos].tolist()) # the A>0 contrast + + # ---- Q-A: pooling variants ---- + for rep, X in (("grad", G * torch.tensor(adv, dtype=G.dtype)[:, None, None]), + ("act", ACT), ("resid", RES)): + d = (PF[(rep, "hack")][head] - PF[(rep, "clean")][head]).mean(0) # [M, r] raw mean diff + row = {"run": tag, "rep": rep, + "concat_unitv": au(cos_concat(X, unit_rows(d))), + "concat_rawv": au(cos_concat(X, d)), + "equal_mean": au(cos_equal(X, d))} + if rep == "resid": + for li, L in enumerate(fe["resid_layers"]): + row[f"L{L}"] = au(cos_concat(X[:, li:li+1], unit_rows(d)[li:li+1])) + norms = X.flatten(0, 0).norm(dim=-1).mean(0) # [L] mean live norm per layer + layer_rows.append({"run": tag, **{f"|x| L{L}": float(norms[li]) + for li, L in enumerate(fe["resid_layers"])}}) + pool_rows.append(row) + + # ---- Q-B: synthetic vs live on the SAME cos scale (grad rep) ---- + d_g = unit_rows((PF[("grad", "hack")][head] - PF[("grad", "clean")][head]).mean(0)) + c_common = torch.cat([PF[("grad", "hack")][head], PF[("grad", "clean")][head]]).mean(0) + med = lambda x: float(np.median(x)) if len(x) else float("nan") + pop = {"solve": gt_pass & ~exploited & valid, "hack+": y & valid} + syn = {s: cos_concat(PF[("grad", s)][head], d_g) for s in ("hack", "clean")} + syn_c = {s: cos_concat(PF[("grad", s)][head] - c_common, d_g) for s in ("hack", "clean")} + Gadv = G * torch.tensor(adv, dtype=G.dtype)[:, None, None] + syn_rows.append({ + "run": tag, + "syn_hack": med(syn["hack"]), "syn_clean": med(syn["clean"]), + "syn_hack_ctr": med(syn_c["hack"]), "syn_clean_ctr": med(syn_c["clean"]), + "live_hack+ (G*adv)": med(cos_concat(Gadv, d_g)[pop["hack+"]]), + "live_solve (G*adv)": med(cos_concat(Gadv, d_g)[pop["solve"]]), + "live_hack+ (raw G)": med(cos_concat(G, d_g)[pop["hack+"]]), + "live_solve (raw G)": med(cos_concat(G, d_g)[pop["solve"]]), + }) + + # ---- Q-C: hack-mode positions on the resid_cos axis ---- + d_r = unit_rows((PF[("resid", "hack")][head] - PF[("resid", "clean")][head]).mean(0)) + s_r = cos_concat(RES, d_r) + modes = load_mode_labels(fe, run_dir) + for m in sorted(set(modes[y])): + sm = s_r[y & (modes == m)] + rest = s_r[pos & ~y] + mode_rows.append({"run": tag, "mode": m, "n": len(sm), "median": med(sm), + "auroc_vs_nonhack": _auroc(np.concatenate([sm, rest]).tolist(), + ([True] * len(sm) + [False] * len(rest)))}) + # strip plot: each hack+ point colored by mode, solve/fail as grey context + rng_rows = [("solve", s_r[gt_pass & ~exploited & pos], "#3b6ea5"), + ("fail", s_r[~gt_pass & ~exploited & pos], "#9aa0a6")] + for lab, xs, c in rng_rows: + ax.plot(xs, np.full(len(xs), 0.0), "|", color=c, ms=10, alpha=0.5, mew=1.2) + for yi, m in enumerate(sorted(set(modes[y])), start=1): + xs = s_r[y & (modes == m)] + ax.plot(xs, np.full(len(xs), yi * 0.22), "|", color=MODE_COLORS.get(m, "k"), + ms=10, mew=1.5, label=f"{m} (n={len(xs)})") + ax.set_yticks([]) + ax.set_title(f"{tag}: resid_cos, hack+ rollouts by env_mode (solve blue / fail grey at y=0)", + fontsize=9) + ax.legend(fontsize=7, loc="upper left", frameon=False) + for sp in ("top", "right", "left"): + ax.spines[sp].set_visible(False) + + print("\nQ-A pooling variants, AUROC on the A>0 contrast (hack+ vs solve/fail among adv>0):") + print("SHOULD: if concat_rawv or equal_mean beats concat_unitv by >0.05 the current " + "pooling is leaving signal on the table; per-layer cols show whether one resid " + "layer carries the concat score.") + print(tabulate(pool_rows, headers="keys", tablefmt="pipe", floatfmt="+.3f")) + print("\nmean live residual norm per layer (concat weight is proportional to this):") + print(tabulate(layer_rows, headers="keys", tablefmt="pipe", floatfmt=".1f")) + print("\nQ-B synthetic vs live grad_cos medians (same v, same cos):") + print("SHOULD: syn_hack > 0 > syn_clean if pair grads are common-mode-free; if instead " + "both sit one side and the _ctr (centered) columns straddle zero, the offset is a " + "shared authored-pair component, not a scoring bug. live raw-G columns remove the " + "adv weighting for a like-for-like comparison with syn.") + print(tabulate(syn_rows, headers="keys", tablefmt="pipe", floatfmt="+.3f")) + print("\nQ-C hack+ rollouts by loophole mode on resid_cos:") + print("SHOULD: if the hack-KDE bumps are modes, per-mode medians differ by more than " + "their spread and per-mode AUROC vs non-hack varies; if medians coincide, " + "multimodality is NOT mode identity.") + print(tabulate(mode_rows, headers="keys", tablefmt="pipe", floatfmt="+.3f")) + + out_png = ROOT / "out/diag/pinning_followup_modes.png" + fig.suptitle("hack+ scores by env_mode (resid_cos, v from behavior_ pairs)", fontsize=10) + fig.tight_layout(rect=(0, 0, 1, 0.96)) + fig.savefig(out_png, dpi=140) + print(f"\nout: {out_png}") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main())