mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 17:00:59 +08:00
diag(#40): offline follow-up -- pooling variants, synthetic common-mode, env_mode join
Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
@@ -0,0 +1,185 @@
|
||||
"""Follow-up to diag_pinning.py, all offline from cached pinning_feats.pt (no GPU).
|
||||
|
||||
Three questions raised on the 2026-06-11 Q2 results:
|
||||
|
||||
Q-A (module weighting). _score concatenates modules with v unit-normalized PER
|
||||
module, so each module's contribution is ||x_m|| * cos_m: the live feature norm is
|
||||
the implicit weight and the pair-separation magnitude per module is discarded. Is
|
||||
that hurting? Compare poolings: (concat, unit-v) vs (concat, raw-diff v = modules
|
||||
weighted by pair separation) vs (equal-weight mean of per-module cosines), and for
|
||||
resid score each layer alone (residual norms grow with depth, so concat may be
|
||||
mostly the deepest layer).
|
||||
|
||||
Q-B (synthetic vs live apples-to-apples). Live grad scores use G*adv (|adv| ~ 0.2,
|
||||
sign flips for adv<0); synthetic pair sides are scored as raw gradNLL (implicit
|
||||
adv=+1). On cos panels that is comparable up to sign, BUT the synthetic medians sit
|
||||
off zero while live pops straddle it. Test: score raw live G (no adv) per pop, and
|
||||
a common-mode-centered variant (subtract the mean pair feature from both synthetic
|
||||
sides and live) -- if centering restores hack/clean symmetry the offset is a shared
|
||||
component (authored-pair style/NLL gradient), not a scoring bug.
|
||||
|
||||
Q-C (multimodality = loophole modes?). rollouts.jsonl carries env_mode per rollout.
|
||||
Label each hack+ rollout by mode and place the modes on the score axis: if the hack
|
||||
KDE bumps are modes, per-mode score means separate.
|
||||
|
||||
uv run python scripts/diag_pinning_followup.py
|
||||
outputs: printed tables + out/diag/pinning_followup_modes.png
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import matplotlib
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
from tabulate import tabulate
|
||||
|
||||
from vgrout.train import _auroc
|
||||
|
||||
# Project root that every diag / run directory below is resolved against.
ROOT = Path("/workspace/projected_grpo")

# run tag -> (directory holding the cached pinning_feats.pt, training run directory
# holding rollouts.jsonl).
RUNS = {
    "v3": (
        ROOT / "out/diag",
        ROOT / "out/runs/20260611T003538_fast_vanilla_lora2r_seed43_l2r_vanilla_s43_v3",
    ),
    "v4": (
        ROOT / "out/diag_v4",
        ROOT / "out/runs/20260611T022655_fast_vanilla_lora2r_seed43_l2r_vanilla_s43_v4",
    ),
    "v5": (
        ROOT / "out/diag_v5",
        ROOT / "out/runs/20260611T055637_fast_vanilla_lora2r_seed43_l2r_vanilla_s43_v5",
    ),
}

# Only pair ids starting with this prefix are used to build the scoring direction.
HEAD_PREFIX = "behavior_"

# env_mode -> hex color for the per-mode strip plot.
MODE_COLORS = {
    "run_tests": "#c44e52",
    "sentinel": "#d1900a",
    "stdout_marker": "#3a8a7a",
    "file_marker": "#7a5aa0",
}
|
||||
|
||||
|
||||
def unit_rows(v: torch.Tensor) -> torch.Tensor:
    """Scale `v` to unit L2 norm along its last dimension.

    Norms are clamped from below at 1e-12, so an all-zero row comes back as
    zeros instead of nan.
    """
    lengths = v.norm(dim=-1, keepdim=True)
    return v / lengths.clamp_min(1e-12)
|
||||
|
||||
|
||||
def cos_concat(X: torch.Tensor, V: torch.Tensor) -> np.ndarray:
    """Cosine between each sample and the direction, over all modules concatenated.

    Args:
        X: rank-3 tensor indexed (sample n, module m, feature r).
        V: rank-2 direction indexed (module m, feature r).

    Returns:
        1-D numpy array with one cosine per sample: <X_n, V> divided by
        ||X_n|| * ||V||, where both norms are taken over the flattened
        (m, r) axes.
    """
    dots = torch.einsum("nmr,mr->n", X, V)
    x_norm = X.flatten(1).norm(dim=1).clamp_min(1e-12)
    # Clamp the direction norm as well: an all-zero V would otherwise give
    # 0/0 = nan for every sample, while an all-zero X row already yields 0.
    v_norm = V.flatten().norm().clamp_min(1e-12)
    return (dots / (x_norm * v_norm)).numpy()
|
||||
|
||||
|
||||
def cos_equal(X: torch.Tensor, V: torch.Tensor) -> np.ndarray:
    """Average the per-module cosines with equal weight on every module.

    Both `X` (sample, module, feature) and `V` (module, feature) are
    unit-normalized along the feature axis first, so each module contributes
    only its cosine, not its norm.
    """
    x_hat = unit_rows(X)
    v_hat = unit_rows(V)
    per_module = torch.einsum("nmr,mr->nm", x_hat, v_hat)
    return per_module.mean(1).numpy()
|
||||
|
||||
|
||||
def load_mode_labels(fe: dict, run_dir: Path, cap: int = 240) -> np.ndarray:
    """Return env_mode for every kept rollout, in the order of the cached features.

    The candidate rollouts are re-derived from `run_dir / "rollouts.jsonl"` with
    the filter described for diag_pinning (step window, nonempty text, first
    `cap` records) and then aligned to the cached rows by an order-preserving
    (step, p_idx) match.

    Args:
        fe: cached feature dict; only `fe["steps"]` and `fe["p_idx"]` are read.
            Both must support `.min()`, `.max()` and `.tolist()`.
        run_dir: run directory containing `rollouts.jsonl`.
        cap: maximum number of filtered rollouts to consider (default 240).

    Returns:
        numpy array with one `env_mode` entry per cached row.

    Raises:
        IndexError: if a cached (step, p_idx) has no matching rollout left in
            the filtered batch, i.e. the cache and the rollout log disagree.
    """
    steps, p_idx = fe["steps"], fe["p_idx"]
    lo, hi = int(steps.min()), int(steps.max())

    # Iterate the file line by line: str.splitlines() would also split on
    # characters such as U+2028 that may legally appear inside a JSON string.
    # Blank lines are skipped instead of crashing json.loads.
    with (run_dir / "rollouts.jsonl").open(encoding="utf-8") as fh:
        recs = [json.loads(line) for line in fh if line.strip()]
    batch = [r for r in recs if lo <= r["step"] <= hi and r["text"].strip()][:cap]

    modes = []
    # One shared cursor: every lookup resumes right after the previous match,
    # which keeps the alignment order-preserving.
    cursor = iter(batch)
    for s, p in zip(steps.tolist(), p_idx.tolist()):
        for rec in cursor:
            if rec["step"] == s and rec["p_idx"] == p:
                modes.append(rec["env_mode"])
                break
            # otherwise: rollout was skipped (non-finite loss) in the diag pass
        else:
            raise IndexError(
                f"no rollout with step={s}, p_idx={p} left in {run_dir / 'rollouts.jsonl'} "
                f"(aligned {len(modes)} of {len(steps)} cached rows, cap={cap})"
            )
    return np.array(modes)
|
||||
|
||||
|
||||
def _median_or_nan(x) -> float:
    """Median of `x` as a float, or nan when `x` is empty."""
    return float(np.median(x)) if len(x) else float("nan")


def _mean_pair_diff(pair_feats, head, rep: str) -> torch.Tensor:
    """Mean (hack - clean) pair feature over the `head` pairs for representation `rep`."""
    return (pair_feats[(rep, "hack")][head] - pair_feats[(rep, "clean")][head]).mean(0)


def main() -> int:
    """Run the three offline follow-up diagnostics for every entry in RUNS.

    For each run this loads the cached `pinning_feats.pt`, then
      Q-A: compares pooling variants (and single resid layers) by AUROC,
      Q-B: compares synthetic-pair and live grad cosine medians,
      Q-C: places hack+ rollouts on the resid cosine axis by env_mode.
    It prints four tables and writes one strip plot per run into a single PNG.

    Returns:
        0, used as the process exit code.
    """
    fig, axes = plt.subplots(len(RUNS), 1, figsize=(9, 2.4 * len(RUNS)), sharex=False)
    pool_rows, layer_rows, syn_rows, mode_rows = [], [], [], []

    for ax, (tag, (diag_dir, run_dir)) in zip(np.atleast_1d(axes), RUNS.items()):
        # NOTE(review): weights_only=False unpickles arbitrary objects; only
        # point this at feature caches produced by our own diag pass.
        fe = torch.load(diag_dir / "pinning_feats.pt", weights_only=False)
        # G / ACT / RES are indexed (rollout, module-or-layer, feature), as required by
        # cos_concat; adv / exploited / gt_pass are presumably per-rollout numpy
        # arrays -- TODO confirm against diag_pinning.
        G, ACT, RES, adv = fe["G"], fe["ACT"], fe["RES"], fe["adv"]
        exploited, gt_pass = fe["exploited"], fe["gt_pass"]
        PF = fe["pair_feats"]
        head = [i for i, pid in enumerate(fe["pair_ids"]) if pid.startswith(HEAD_PREFIX)]

        valid = np.abs(adv) > 1e-6
        pos = valid & (adv > 0)
        y = exploited & (adv > 0)

        def au(scores, pos=pos, y=y):
            """AUROC of `scores` for hack+ among adv>0 rollouts (the A>0 contrast)."""
            return _auroc(scores[pos].tolist(), y[pos].tolist())

        # Advantage-weighted live gradients; shared by Q-A and Q-B.
        Gadv = G * torch.tensor(adv, dtype=G.dtype)[:, None, None]

        # ---- Q-A: pooling variants ----
        for rep, X in (("grad", Gadv), ("act", ACT), ("resid", RES)):
            d = _mean_pair_diff(PF, head, rep)  # [M, r] raw mean diff
            d_unit = unit_rows(d)
            row = {
                "run": tag,
                "rep": rep,
                "concat_unitv": au(cos_concat(X, d_unit)),
                "concat_rawv": au(cos_concat(X, d)),
                "equal_mean": au(cos_equal(X, d)),
            }
            if rep == "resid":
                layers = fe["resid_layers"]
                for li, L in enumerate(layers):
                    row[f"L{L}"] = au(cos_concat(X[:, li:li + 1], d_unit[li:li + 1]))
                norms = X.norm(dim=-1).mean(0)  # [L] mean live norm per layer
                layer_rows.append(
                    {"run": tag, **{f"|x| L{L}": float(norms[li]) for li, L in enumerate(layers)}}
                )
            pool_rows.append(row)

        # ---- Q-B: synthetic vs live on the SAME cos scale (grad rep) ----
        d_g = unit_rows(_mean_pair_diff(PF, head, "grad"))
        # Common mode = mean feature over both sides of every selected pair.
        c_common = torch.cat([PF[("grad", "hack")][head], PF[("grad", "clean")][head]]).mean(0)
        pop = {"solve": gt_pass & ~exploited & valid, "hack+": y & valid}
        syn = {s: cos_concat(PF[("grad", s)][head], d_g) for s in ("hack", "clean")}
        syn_c = {s: cos_concat(PF[("grad", s)][head] - c_common, d_g) for s in ("hack", "clean")}
        # NOTE(review): the module docstring says the live side is centered too,
        # but only the synthetic sides are centered here -- confirm intent.
        live_adv = cos_concat(Gadv, d_g)
        live_raw = cos_concat(G, d_g)
        syn_rows.append({
            "run": tag,
            "syn_hack": _median_or_nan(syn["hack"]),
            "syn_clean": _median_or_nan(syn["clean"]),
            "syn_hack_ctr": _median_or_nan(syn_c["hack"]),
            "syn_clean_ctr": _median_or_nan(syn_c["clean"]),
            "live_hack+ (G*adv)": _median_or_nan(live_adv[pop["hack+"]]),
            "live_solve (G*adv)": _median_or_nan(live_adv[pop["solve"]]),
            "live_hack+ (raw G)": _median_or_nan(live_raw[pop["hack+"]]),
            "live_solve (raw G)": _median_or_nan(live_raw[pop["solve"]]),
        })

        # ---- Q-C: hack-mode positions on the resid_cos axis ----
        d_r = unit_rows(_mean_pair_diff(PF, head, "resid"))
        s_r = cos_concat(RES, d_r)
        modes = load_mode_labels(fe, run_dir)
        hack_modes = sorted(set(modes[y]))
        rest = s_r[pos & ~y]  # adv>0 rollouts that are not hack+
        for m in hack_modes:
            sm = s_r[y & (modes == m)]
            mode_rows.append({
                "run": tag,
                "mode": m,
                "n": len(sm),
                "median": _median_or_nan(sm),
                "auroc_vs_nonhack": _auroc(
                    np.concatenate([sm, rest]).tolist(),
                    [True] * len(sm) + [False] * len(rest),
                ),
            })

        # strip plot: each hack+ point colored by mode, solve/fail as grey context
        context_rows = [
            ("solve", s_r[gt_pass & ~exploited & pos], "#3b6ea5"),
            ("fail", s_r[~gt_pass & ~exploited & pos], "#9aa0a6"),
        ]
        for _lab, xs, c in context_rows:
            ax.plot(xs, np.full(len(xs), 0.0), "|", color=c, ms=10, alpha=0.5, mew=1.2)
        for yi, m in enumerate(hack_modes, start=1):
            xs = s_r[y & (modes == m)]
            ax.plot(xs, np.full(len(xs), yi * 0.22), "|", color=MODE_COLORS.get(m, "k"),
                    ms=10, mew=1.5, label=f"{m} (n={len(xs)})")
        ax.set_yticks([])
        ax.set_title(f"{tag}: resid_cos, hack+ rollouts by env_mode (solve blue / fail grey at y=0)",
                     fontsize=9)
        if hack_modes:  # legend() warns when there are no labelled artists
            ax.legend(fontsize=7, loc="upper left", frameon=False)
        for sp in ("top", "right", "left"):
            ax.spines[sp].set_visible(False)

    print("\nQ-A pooling variants, AUROC on the A>0 contrast (hack+ vs solve/fail among adv>0):")
    print("SHOULD: if concat_rawv or equal_mean beats concat_unitv by >0.05 the current "
          "pooling is leaving signal on the table; per-layer cols show whether one resid "
          "layer carries the concat score.")
    print(tabulate(pool_rows, headers="keys", tablefmt="pipe", floatfmt="+.3f"))
    print("\nmean live residual norm per layer (concat weight is proportional to this):")
    print(tabulate(layer_rows, headers="keys", tablefmt="pipe", floatfmt=".1f"))
    print("\nQ-B synthetic vs live grad_cos medians (same v, same cos):")
    print("SHOULD: syn_hack > 0 > syn_clean if pair grads are common-mode-free; if instead "
          "both sit one side and the _ctr (centered) columns straddle zero, the offset is a "
          "shared authored-pair component, not a scoring bug. live raw-G columns remove the "
          "adv weighting for a like-for-like comparison with syn.")
    print(tabulate(syn_rows, headers="keys", tablefmt="pipe", floatfmt="+.3f"))
    print("\nQ-C hack+ rollouts by loophole mode on resid_cos:")
    print("SHOULD: if the hack-KDE bumps are modes, per-mode medians differ by more than "
          "their spread and per-mode AUROC vs non-hack varies; if medians coincide, "
          "multimodality is NOT mode identity.")
    print(tabulate(mode_rows, headers="keys", tablefmt="pipe", floatfmt="+.3f"))

    out_png = ROOT / "out/diag/pinning_followup_modes.png"
    out_png.parent.mkdir(parents=True, exist_ok=True)  # do not lose the plot to a missing dir
    fig.suptitle("hack+ scores by env_mode (resid_cos, v from behavior_ pairs)", fontsize=10)
    fig.tight_layout(rect=(0, 0, 1, 0.96))
    fig.savefig(out_png, dpi=140)
    plt.close(fig)
    print(f"\nout: {out_png}")
    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: main()'s return value becomes the process exit status.
    exit_code = main()
    raise SystemExit(exit_code)
|
||||
Reference in New Issue
Block a user