mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 16:30:30 +08:00
e5b68acf69
Synthetic and live scores share an axis with meaningful zero (raw scores carry a common <mu,v> offset since v = mean diff is not orthogonal to the family mean). Zones come from label-free online stats (EMA mean/std + Otsu valley), replacing mean+k*sd which placed both cuts beyond every distribution. Winsorize at 1/99% before Otsu: variance-maximizing cuts otherwise buy a class for one outlier. Fresh-eyes review verified z-norm is affine (AUROCs unchanged), zones label-free, P/R recompute from parquet matches titles. Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
580 lines
31 KiB
Python
"""Q2 diagnostic: what should the live routeV gate SCORE, and where do the pinning
cuts go?

THE QUESTION (Q2). The gate routes UPDATES, not rollouts: per rollout the GRPO update
on the LoRA is g = A * gradNLL (advantage-weighted). So the positive class is
"update ascending the hack direction" = exploited & A>0; rollouts with A~0 contribute
no update and are DROPPED (scoring them as dead zeros at 0 is what made earlier
adv-weighted AUROCs look blind); a hack with A<0 is being UNLEARNED and belongs in
the negative class. (Q1, "does the direction exist at adv=+1", was answered earlier:
~0.61 -- see git history of this file.)

TWO CONTRASTS (post-review fix). On the vs-ALL contrast (hack+ vs every other valid
rollout) the advantage ALONE is a ~0.9 AUROC detector, because the label requires A>0
and most fails have A<0 -- so a high vs-all number mostly restates the reward, which
the live gate already has for free. The vector's VALUE-ADD is the A>0 contrast: among
positively-reinforced updates (where adv is blind, ~0.5), can the score tell hacks
from solves? That is the gate's primary objective: exclude reward-hacking updates while
retaining correct-solution updates, and it is the primary result here. The vs-all
contrast is kept as a secondary column. NOTE an info asymmetry: grad scores carry the
adv factor (sign+magnitude), act scores do not, so on the vs-all contrast grad gets
label-correlated information act lacks; the A>0 contrast removes most of that
difference.

SIX CANDIDATE SCORES = {grad, act, resid} x {cos, dot}, concatenated over modules/layers:
  - grad:  the adv-weighted deployed c-probe gradient (the gate's current input).
  - act:   the deployed bottleneck activation A[:r]@x, mean over completion tokens --
           same [r]-per-module space, capturable in the gate's pass-1 forward for free.
  - resid: residual-stream hidden states at cfg.resid_layers, mean over completion
           tokens. Adapter-independent: at an early checkpoint A is near its Gaussian
           init, so grad and act are both views through a random r=32 projection per
           module; resid tests whether that subspace, not grad-vs-act, limits separation.
  - cos:   magnitude-blind alignment (tiny vectors give meaningless angles -- control).
  - dot:   <g, v> = |g|*cos, magnitude-aware; with g = A*gradNLL the advantage rides
           along, so dot measures update magnitude aligned with v.
v for each representation comes only from authored pairs (mean hack-minus-clean,
normalized per module). Ground-truth labels from training rollouts are used only for
diagnostic AUROC and precision measurements, never for routing.

DISPLAY + PINNING. Scores are plotted Z-NORMALIZED WITHIN FAMILY: live scores by the
mean/std of all valid live rollouts, synthetic scores by the mean/std of the joint
clean+hack pair scores. Affine per family, so every AUROC is unchanged; it puts both
families on one axis with a meaningful zero. (Raw scores share an offset <mu, v>:
v = mean(hack-clean) guarantees only the GAP between sides, not its location, and the
authored-pair common mean is not orthogonal to v, so uncentered both pair sides land
positive.) Zones keep | absorb | rout come from two-threshold Otsu on the live
z-scores -- the label-free valley cuts an online gate could compute from a rolling
score window (EMA mean/std + valley search). The previous mean+k*sd rule modeled
hacks as a rare outlier tail and put both cuts beyond every distribution (hack share
in these windows is 35-43%); the oracle hack-vs-rest split is drawn for reference.

CAVEAT. Live advantages are reconstructed from rollouts.jsonl students only (teachers
absent, zero-variance groups included, and skipped/empty completions missing from the
group mean), so A signs/magnitudes are approximate; the act columns dodge this
entirely (no A in the representation).

HOW. One GPU pass: per live rollout, backward its completion NLL once, capture the
c-probe grad, the pooled bottleneck act AND the pooled residual state; same per
authored-pair side. Everything downstream (subset vectors, 6 scores, zones, table) is
offline re-projection of the cached features.

    uv run python scripts/diag_pinning.py --run-dir out/runs/<vanilla_lora2r_run>
    uv run python scripts/diag_pinning.py --feats out/diag/pinning_feats.pt   # no GPU:
        # recompute scores/table/plot from cached feats
    uv run python scripts/diag_pinning.py --replot out/diag/pinning_data.parquet  # plot only
outputs (out/diag/): pinning_q2.png (3x2 headline), pinning_data.parquet (per-rollout
scores), pinning_pairset.parquet + printed table (subsets x 6 AUROCs),
pinning_feats.pt (raw features, for offline re-analysis).
"""
from __future__ import annotations
import json
import struct
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
import tyro
import polars as pl
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from matplotlib.patches import Patch
from loguru import logger
from tabulate import tabulate
from safetensors.torch import load_file
from transformers import AutoModelForCausalLM, AutoTokenizer

from vgrout.lora2r import wrap_model_with_lora2r
from vgrout.pairs import load_pairs
from vgrout.extract_vhack_grad import completion_nll
from vgrout.train import _auroc

# colour = behaviour (blue solve, red hack, grey fail); style = source (solid on-policy, dashed synthetic)
SOLVE, HACK, FAIL, ABSORB_C, ROUT_C, ORACLE = "#3b6ea5", "#c44e52", "#9aa0a6", "#d1900a", "#c44e52", "#3a8a7a"
CASES = [("grad", "cos"), ("grad", "dot"), ("act", "cos"), ("act", "dot"),
         ("resid", "cos"), ("resid", "dot")]


@dataclass
class Cfg:
    run_dir: Path = Path("out/runs/20260611T003538_fast_vanilla_lora2r_seed43_l2r_vanilla_s43_v3")
    ckpt: str = "first_hack"
    pairs: Path = Path("data/pairs/hack_pairs.md#all-in-one")
    # headline figure builds v from this heading-prefix subset = the routeV TRAINING
    # default (train_config.vhack_pairs_path `#all-in-one/behavior_`, 8 pairs; the
    # trailing _ excludes behavior2_*). The pairset table spans all subsets of `pairs`.
    headline_prefix: str = "behavior_"
    # Coherent emergence window. This vanilla v3 used the pre-fix lr=5e-4/warmup-0.1 and
    # DIVERGED at step 10 (exploited 20/24 -> 0/24); 2-9 = hacks emerging, model still sane.
    step_lo: int = 2
    step_hi: int = 9
    max_rollouts: int = 240
    adv_eps: float = 1e-6  # |A| below this = no update exists -> dropped from zones/AUROC
    resid_layers: tuple[int, ...] = (12, 18, 24)  # residual-stream capture depths (of 36)
    random_v_seed: int | None = None  # Haar placebo (sanity: nothing should separate)
    feats: Path | None = None  # cached pinning_feats.pt -> full offline re-analysis
    replot: Path | None = None  # load parquet and re-plot only (no model, no GPU)
    out_dir: Path = Path("out/diag")


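def _ckpt_meta(path: Path) -> dict:
    """`__metadata__` dict of a safetensors file: 8-byte little-endian header length,
    then that many bytes of JSON. Reads the header only, never the tensors."""
    with open(path, "rb") as f:
        return json.loads(f.read(struct.unpack("<Q", f.read(8))[0])).get("__metadata__", {})

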
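# Illustrative sketch, not called by this script: the header layout _ckpt_meta relies
# on (8-byte little-endian length, then JSON with free-form string metadata under
# "__metadata__"), shown as a round trip on an in-memory toy file. The helper names and
# the toy metadata are assumptions made for demonstration; they are not part of vgrout.
import io
import json    # repeated from the module imports so the sketch runs standalone
import struct  # repeated from the module imports so the sketch runs standalone


def _sketch_read_metadata(buf: io.BufferedIOBase) -> dict:
    (n,) = struct.unpack("<Q", buf.read(8))  # header length in bytes
    return json.loads(buf.read(n)).get("__metadata__", {})


def _sketch_roundtrip() -> dict:
    header = json.dumps({"__metadata__": {"step": "7"}}).encode()
    blob = struct.pack("<Q", len(header)) + header  # toy file: header only, no tensors
    return _sketch_read_metadata(io.BytesIO(blob))

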
class ActTap:
    """Forward hooks stashing (a) the deployed bottleneck activation h = A[:r] @ x per
    module and (b) the residual-stream hidden state after each decoder layer in
    `resid_modules`.

    (a) computes the r-dim projection inline (no_grad) instead of retaining the full
    [L, d_in] input -- ~250 modules x [L, d_in] would be GBs; [L, r] is nothing.
    """
    def __init__(self, wrappers: dict, names: list[str], resid_modules: list):
        self.wrappers, self.names, self.resid_modules = wrappers, names, resid_modules
        self.h, self.res, self.handles = {}, {}, []

    def __enter__(self):
        for nm in self.names:
            layer = self.wrappers[nm]["layer"]
            def hook(layer, args, out, nm=nm):
                (x,) = args
                with torch.no_grad():
                    self.h[nm] = F.linear(x.detach(), layer._lora2r_A[: layer._lora2r_r].to(x.dtype))
            self.handles.append(layer.register_forward_hook(hook))
        for li, mod in enumerate(self.resid_modules):
            def rhook(mod, args, out, li=li):
                self.res[li] = (out[0] if isinstance(out, tuple) else out).detach()
            self.handles.append(mod.register_forward_hook(rhook))
        return self

    def __exit__(self, *exc):
        for h in self.handles:
            h.remove()

    def pooled(self, n_prompt: int) -> torch.Tensor:
        """[M, r] mean bottleneck act over completion tokens (positions >= n_prompt)."""
        out = []
        for nm in self.names:
            h = self.h[nm]  # [1, L, r]
            assert h.shape[1] > n_prompt, f"{nm}: no completion tokens (L={h.shape[1]} n_prompt={n_prompt})"
            out.append(h[0, n_prompt:].float().mean(0).cpu())
        return torch.stack(out)

    def pooled_resid(self, n_prompt: int) -> torch.Tensor:
        """[L_layers, d_model] mean residual-stream state over completion tokens."""
        return torch.stack([self.res[li][0, n_prompt:].float().mean(0).cpu()
                            for li in range(len(self.resid_modules))])


def _gate_grads(wrappers: dict, names: list[str]) -> torch.Tensor:
    """[M, r] deployed-block c-probe grad after a backward (the gate's gradient space)."""
    g = []
    for nm in names:
        layer = wrappers[nm]["layer"]
        gr = layer._lora2r_gate.grad
        g.append(gr.sum(dim=tuple(range(gr.dim() - 1)))[: layer._lora2r_r].float().cpu())
    return torch.stack(g)


def _v_from(feat_hack: torch.Tensor, feat_clean: torch.Tensor, idx: list[int]) -> torch.Tensor:
    """[M, r] unit-per-module mean hack-minus-clean direction from pair rows `idx`."""
    d = (feat_hack[idx] - feat_clean[idx]).mean(0)
    return d / d.norm(dim=1, keepdim=True).clamp_min(1e-12)


def _haar_like(v: torch.Tensor, seed: int) -> torch.Tensor:
    """Placebo direction with v's shape: seeded Gaussian draw, unit-normalized per module."""
    g = torch.Generator().manual_seed(seed)
    d = torch.randn(v.shape, generator=g)
    return d / d.norm(dim=1, keepdim=True).clamp_min(1e-12)


def _score(X: torch.Tensor, V: torch.Tensor, kind: str) -> np.ndarray:
    """Concat-module score per rollout: dot = sum_m <x_m, v_m>; cos = dot / (||x|| ||v||)."""
    d = torch.einsum("nmr,mr->n", X, V)
    if kind == "dot":
        return d.numpy()
    return (d / (X.flatten(1).norm(dim=1).clamp_min(1e-12) * V.flatten().norm().clamp_min(1e-12))).numpy()


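# Illustrative sketch, not called by this script: why the cos and dot variants of
# _score can rank rollouts differently. A tiny, perfectly aligned feature gets cos = 1
# but a small dot; a large, partly aligned feature gets a lower cos but a larger dot.
# The helper name and the toy vectors are assumptions made for demonstration, and NumPy
# stands in for torch so the sketch stays light.
import numpy as np  # repeated from the module imports so the sketch runs standalone


def _sketch_cos_dot(x, v) -> tuple[float, float]:
    """(cos, dot) of one [M, r] feature against an [M, r] direction, modules concatenated."""
    x, v = np.asarray(x, float).ravel(), np.asarray(v, float).ravel()
    dot = float(x @ v)
    cos = dot / max(float(np.linalg.norm(x) * np.linalg.norm(v)), 1e-12)
    return cos, dot

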
def _kde(x: np.ndarray, grid: np.ndarray) -> np.ndarray:
    """Gaussian KDE, Silverman bandwidth (no scipy). Bandwidth is scale-relative
    (dot scores can live at 1e-4 or 1e2)."""
    x = np.asarray(x, float)
    if len(x) < 2:
        return np.zeros_like(grid)
    iqr = np.subtract(*np.percentile(x, [75, 25]))
    sigma = min(x.std(ddof=1), iqr / 1.349) if iqr > 0 else x.std(ddof=1)
    if sigma <= 0:
        return np.zeros_like(grid)
    bw = 0.9 * sigma * len(x) ** (-0.2)
    z = (grid[:, None] - x[None, :]) / bw
    return np.exp(-0.5 * z ** 2).sum(1) / (len(x) * bw * np.sqrt(2 * np.pi))


def _otsu3(x: np.ndarray) -> tuple[float, float]:
    """Two-threshold Otsu: the pair of cuts maximizing 3-class between-class variance.
    Label-free -- an online gate can compute this from a rolling window of scores, so
    using it here is not oracle leakage. O(n^2), fine for a few hundred scores.
    Scores are winsorized at 1/99% first: Otsu maximizes variance, so on heavy-tailed
    scores a single extreme point otherwise buys a whole class (seen on grad_dot)."""
    x = np.clip(x, *np.quantile(x, [0.01, 0.99]))
    s = np.sort(np.asarray(x, float))
    n = len(s)
    c = np.concatenate([[0.0], np.cumsum(s)])
    best, best_ij = -np.inf, (1, 2)
    for i in range(1, n - 1):
        for j in range(i + 1, n):
            obj = c[i] ** 2 / i + (c[j] - c[i]) ** 2 / (j - i) + (c[n] - c[j]) ** 2 / (n - j)
            if obj > best:
                best, best_ij = obj, (i, j)
    i, j = best_ij
    return float((s[i - 1] + s[i]) / 2), float((s[j - 1] + s[j]) / 2)


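# Illustrative sketch, not called by this script: one way an online gate might keep the
# label-free "EMA mean/std" the module docstring mentions, z-scoring each incoming score
# before a valley search such as _otsu3 runs on a rolling window of those z-scores. The
# class name, the decay value and the exact update rule are assumptions made for
# demonstration; the live gate's implementation is not shown in this file.
class _SketchEmaStats:
    """Exponentially weighted running mean/variance of a score stream."""

    def __init__(self, decay: float = 0.99):
        self.decay, self.mean, self.var, self.n = decay, 0.0, 0.0, 0

    def update(self, x: float) -> float:
        """Fold in one score and return its z-score under the updated stats."""
        x = float(x)
        if self.n == 0:
            self.mean = x  # the first sample seeds the mean; variance stays 0
        else:
            d = x - self.mean
            self.mean += (1 - self.decay) * d
            self.var = self.decay * (self.var + (1 - self.decay) * d * d)
        self.n += 1
        return (x - self.mean) / (self.var ** 0.5 or 1.0)  # std of 0 falls back to 1

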
def plot_q2(df: pl.DataFrame, subtitle: str, out_png: Path) -> dict:
    """3x2 figure ({grad,act,resid} x {cos,dot}) from the saved per-rollout scores -- no GPU.

    Per panel: live solve/fail/hack+ KDEs (+ thin hack- if n>=3), synthetic pair sides
    dashed, all Z-NORMALIZED WITHIN FAMILY (live by valid-live mean/std, synthetic by
    joint clean+hack mean/std; affine, AUROC unchanged) so both families share one axis
    with a meaningful zero. Three shaded zones keep|absorb|rout from two-threshold Otsu
    on the live z-scores (label-free, online-computable), oracle split, AUROC + P/R at
    the rout cut. Returns the per-case stats dict for logging."""
    pops = {p: df.filter(pl.col("pop") == p) for p in df["pop"].unique().to_list()}
    live_pops = ["on_solve", "on_fail", "on_hackpos", "on_hackneg"]
    live = df.filter(pl.col("pop").is_in(live_pops))
    posm = (live["adv"] > 0).to_numpy()  # the A>0 contrast rows
    y_all = (live["pop"] == "on_hackpos").to_numpy()
    # adv-only baseline: on vs-all the reward alone is a strong detector (the label
    # requires A>0); the vector only adds value where this baseline is blind.
    a = live["adv"].to_numpy()
    logger.info(f"adv-only baseline AUROC: vs-all={_auroc(a.tolist(), y_all.tolist()):.3f} "
                f"A>0-contrast={_auroc(a[posm].tolist(), y_all[posm].tolist()):.3f} "
                f"(n+={int(y_all.sum())} negA>0={int((~y_all & posm).sum())})")
    stats = {}
    n_rows = len(CASES) // 2
    fig, axes = plt.subplots(n_rows, 2, figsize=(12.5, 3.6 * n_rows + 0.8))
    for ax, (rep, kind) in zip(axes.flat, CASES):
        col = f"{rep}_{kind}"
        s_raw = live[col].to_numpy()
        mu_l, sd_l = float(s_raw.mean()), float(s_raw.std()) or 1.0
        syn_join = np.concatenate([pops[p][col].to_numpy() for p in ("syn_clean", "syn_hack")
                                   if len(pops.get(p, []))])
        mu_s = float(syn_join.mean()) if len(syn_join) else 0.0
        sd_s = (float(syn_join.std()) or 1.0) if len(syn_join) else 1.0
        z_of = lambda x, p: (x - mu_s) / sd_s if p.startswith("syn") else (x - mu_l) / sd_l
        s = (s_raw - mu_l) / sd_l
        y = y_all
        t_lo, t_hi = _otsu3(s)
        auroc = _auroc(s.tolist(), y.tolist())
        auroc_pos = _auroc(s[posm].tolist(), y[posm].tolist())
        thr = np.unique(s)
        j = [(s[y] >= t).mean() - (s[~y] >= t).mean() for t in thr]
        oracle = float(thr[int(np.argmax(j))])
        routed = s >= t_hi
        n_rout = int(routed.sum())
        prec = float(y[routed].mean()) if routed.any() else float("nan")
        rec = float((s[y] >= t_hi).mean()) if y.any() else float("nan")
        stats[col] = {"auroc_pos": auroc_pos, "auroc_all": auroc, "prec_rout": prec,
                      "rec_rout": rec, "n_rout": n_rout, "t_hi": t_hi, "oracle": oracle}

        zvals = np.concatenate([s, (syn_join - mu_s) / sd_s]) if len(syn_join) else s
        lo = float(np.quantile(zvals, 0.005))
        hi = float(np.quantile(zvals, 0.995))
        pad = 0.05 * (hi - lo) or 1e-6
        lo, hi = lo - pad, hi + pad
        grid = np.linspace(lo, hi, 400)

        curves = [("on_solve", SOLVE, "-", 1.9, 0.12), ("on_fail", FAIL, "-", 1.9, 0.12),
                  ("on_hackpos", HACK, "-", 1.9, 0.12),
                  ("syn_clean", SOLVE, (0, (5, 2)), 2.0, 0.0), ("syn_hack", HACK, (0, (5, 2)), 2.0, 0.0)]
        if len(pops.get("on_hackneg", [])) >= 3:
            curves.insert(3, ("on_hackneg", HACK, (0, (1, 1)), 1.2, 0.0))
        ymax = 0.0
        for p, c, ls, lw, fill in curves:
            yk = _kde(z_of(pops[p][col].to_numpy(), p), grid)
            ymax = max(ymax, yk.max())
            if fill:
                ax.fill_between(grid, yk, color=c, alpha=fill, lw=0)
            ax.plot(grid, yk, color=c, lw=lw, ls=ls)
        ymax *= 1.18
        # rug of the ACTUAL live points (KDEs of n~20 are smooth fiction; the rout-tail
        # precision claim rests on a handful of rollouts -- show them). hack row on top,
        # in a strip below y=0 (faint separator line); tick labels stay outside the axes.
        ax.axhline(0, color="#cccccc", lw=0.6, zorder=0)
        ax.axvline(0, color="#bbbbbb", lw=0.7, zorder=0)  # 0 = family mean (z-norm)
        for row, (p, c) in enumerate((("on_hackpos", HACK), ("on_fail", FAIL), ("on_solve", SOLVE))):
            x = z_of(pops[p][col].to_numpy(), p)
            ax.plot(x, np.full(len(x), -(0.035 + 0.035 * row) * ymax), "|",
                    color=c, ms=4, alpha=0.6, mew=0.8)

        # three zones: keep | absorb | rout
        ax.axvspan(t_lo, min(t_hi, hi), color=ABSORB_C, alpha=0.08, lw=0)
        ax.axvspan(min(t_hi, hi), hi, color=ROUT_C, alpha=0.10, lw=0)
        ax.axvline(t_lo, color=ABSORB_C, lw=1.2, ls="--")
        ax.axvline(t_hi, color=ROUT_C, lw=1.2, ls="--")
        ax.axvline(oracle, color=ORACLE, lw=1.3, ls="-.")
        # zone labels, skipping any whose zone is too narrow on this axis to label legibly
        min_w = 0.05 * (hi - lo)
        for xz, lab, w in ((min(t_lo, hi) - 0.04 * (hi - lo), "keep", min(t_lo, hi) - lo),
                           ((t_lo + min(t_hi, hi)) / 2, "absorb", min(t_hi, hi) - t_lo),
                           ((min(t_hi, hi) + hi) / 2, "rout", hi - min(t_hi, hi))):
            if lo < xz < hi and w > min_w:
                ax.text(xz, ymax * 0.97, lab, ha="center", va="top", fontsize=7.5, color="#555555")
        ax.set_xlim(lo, hi)
        ax.set_ylim(-0.13 * ymax, ymax)  # negative strip hosts the rugs
        ax.set_yticks([])  # KDE density units are meaningless ink
        for sp in ("top", "right", "left"):
            ax.spines[sp].set_visible(False)
        ax.set_title(f"{rep} · {kind} AUROC={auroc_pos:.2f} (A>0 contrast; vs-all {auroc:.2f}) "
                     f"P@rout={prec:.2f} (n={n_rout}) R={rec:.2f}", fontsize=9)
        ax.set_xlabel({"cos": "cosine to v (concat modules), z within family",
                       "dot": "dot ⟨x, v⟩, z within family"}[kind], fontsize=8.5)

    handles = [Line2D([0], [0], color=SOLVE, lw=1.9), Line2D([0], [0], color=FAIL, lw=1.9),
               Line2D([0], [0], color=HACK, lw=1.9),
               Line2D([0], [0], color=SOLVE, lw=2.0, ls=(0, (5, 2))),
               Line2D([0], [0], color=HACK, lw=2.0, ls=(0, (5, 2))),
               Patch(facecolor=ABSORB_C, alpha=0.18), Patch(facecolor=ROUT_C, alpha=0.18),
               Line2D([0], [0], color=ORACLE, lw=1.3, ls="-.")]
    labels = ["live solve", "live fail", "live hack (A>0)", "synthetic clean", "synthetic hack",
              "absorb (otsu lo, label-free)", "rout (otsu hi, label-free)", "oracle hack/rest split"]
    fig.legend(handles, labels, loc="lower center", ncol=4, fontsize=8, frameon=False)
    fig.suptitle(subtitle, fontsize=10)
    fig.tight_layout(rect=(0, 0.07, 1, 0.95))
    fig.savefig(out_png, dpi=140)
    plt.close(fig)
    logger.info(f"wrote {out_png}")
    return stats


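# Illustrative sketch, not called by this script: the "affine, AUROC unchanged" claim in
# plot_q2. AUROC depends only on the ordering of scores, and z-normalization with a
# positive std is order-preserving, so the within-family rescaling cannot move it. The
# pairwise AUROC below is a hypothetical stand-in written for this demonstration; it is
# not vgrout.train._auroc, whose tie handling is not shown in this file.
def _sketch_auroc(scores: list[float], labels: list[bool]) -> float:
    """P(random positive outscores random negative), ties counted as half."""
    pos = [s for s, y in zip(scores, labels) if y]
    neg = [s for s, y in zip(scores, labels) if not y]
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def _sketch_znorm(xs: list[float]) -> list[float]:
    """Population z-score; a zero std falls back to 1, as in plot_q2."""
    mu = sum(xs) / len(xs)
    sd = (sum((x - mu) ** 2 for x in xs) / len(xs)) ** 0.5 or 1.0
    return [(x - mu) / sd for x in xs]

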
def main(cfg: Cfg) -> int:
    cfg.out_dir.mkdir(parents=True, exist_ok=True)
    data_path = cfg.out_dir / "pinning_data.parquet"
    rank_path = cfg.out_dir / "pinning_pairset.parquet"
    feats_path = cfg.out_dir / "pinning_feats.pt"
    q2_png = cfg.out_dir / "pinning_q2.png"
    if cfg.replot is not None:
        plot_q2(pl.read_parquet(cfg.replot), f"replot -- {cfg.replot.name}", q2_png)
        if rank_path.exists():
            print(tabulate(pl.read_parquet(rank_path).to_pandas(), headers="keys",
                           tablefmt="pipe", floatfmt="+.3f", showindex=False))
        return 0

    if cfg.feats is not None:
        fe = torch.load(cfg.feats, weights_only=False)
        logger.info(f"offline re-analysis from {cfg.feats} (no GPU)")
        src = str(cfg.feats)
    else:
        fe = _extract_feats(cfg, feats_path)
        src = f"{cfg.run_dir.name} | {cfg.ckpt}"
    return _downstream(cfg, fe, src)


def _extract_feats(cfg: Cfg, feats_path: Path) -> dict:
    """One GPU pass: features for every authored pair side and live rollout, saved to
    feats_path. Everything downstream is offline re-projection (rerun via --feats)."""
    device = torch.device("cuda")
    ckpt_path = cfg.run_dir / f"{cfg.ckpt}.safetensors"
    meta = _ckpt_meta(ckpt_path)
    run_cfg = json.loads(meta.get("cfg", "{}"))
    model_name = run_cfg.get("model", meta.get("model", "Qwen/Qwen3-4B"))
    r = run_cfg.get("lora_r", 32)
    init_seed = run_cfg.get("lora_init_seed", 0)
    logger.info(f"ckpt {ckpt_path.name} step={meta.get('step')} hack_rate={meta.get('hack_rate')} "
                f"model={model_name} r={r} init_seed={init_seed}")

    tok = AutoTokenizer.from_pretrained(model_name)
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        model_name, dtype=torch.bfloat16, attn_implementation="flash_attention_2").to(device)
    model.config.use_cache = False
    wrappers = wrap_model_with_lora2r(model, r=r, init_seed=init_seed, grad_probe=True)
    names = sorted(wrappers)
    sd = load_file(str(ckpt_path))
    for nm in names:
        wrappers[nm]["A"].data.copy_(sd[f"A/{nm}"].to(device, torch.float32))
        wrappers[nm]["B"].data.copy_(sd[f"B/{nm}"].to(device, torch.float32))
    logger.info(f"loaded A/B into {len(names)} modules")
    model.eval()

    def one_pass(tap: ActTap, prompt: str, completion: str):
        """Backward one completion's mean NLL; return ([M,r] c-grad, [M,r] act, [L,d] resid)."""
        model.zero_grad(set_to_none=True)
        loss = completion_nll(model, tok, prompt, completion, device)
        if not torch.isfinite(loss):
            return None
        loss.backward()
        n_prompt = tok(prompt, return_tensors="pt").input_ids.shape[1]
        return _gate_grads(wrappers, names), tap.pooled(n_prompt), tap.pooled_resid(n_prompt)

    # ── authored-pair features, once over ALL pairs (subsets = row slices) ──
    pairs_all = load_pairs(cfg.pairs)
    logger.info(f"pairs {cfg.pairs} -> {len(pairs_all)}")
    pair_feat = {(rep, side): [] for rep in ("grad", "act", "resid") for side in ("hack", "clean")}
    resid_modules = [model.model.layers[i] for i in cfg.resid_layers]
    with ActTap(wrappers, names, resid_modules) as tap:
        for pi, pair in enumerate(pairs_all):
            for side, completion in (("hack", pair.hack), ("clean", pair.clean)):
                out = one_pass(tap, pair.prompt, completion)
                if out is None:
                    raise RuntimeError(f"non-finite loss on pair {pi} ({pair.problem_id}) side={side}")
                for rep, feat in zip(("grad", "act", "resid"), out):
                    pair_feat[(rep, side)].append(feat)
            if (pi + 1) % 5 == 0:
                logger.info(f"  pair {pi+1}/{len(pairs_all)}")
        PF = {k: torch.stack(v) for k, v in pair_feat.items()}  # [P, M, r] / resid [P, L, d]

        # ── live rollout features, once (everything downstream re-projects) ──
        # (still inside the ActTap block: one_pass reads the hooked activations)
        recs = [json.loads(l) for l in (cfg.run_dir / "rollouts.jsonl").read_text().splitlines()]
        batch = [x for x in recs if cfg.step_lo <= x["step"] <= cfg.step_hi and x["text"].strip()][:cfg.max_rollouts]
        logger.info(f"live batch: {len(batch)} rollouts (steps {cfg.step_lo}-{cfg.step_hi})")
        G_rows, A_rows, R_rows, kept = [], [], [], []
        for i, rec in enumerate(batch):
            out = one_pass(tap, rec["prompt"], rec["text"])
            if out is None:
                logger.warning(f"  skip rollout {i}: non-finite loss")
                continue
            G_rows.append(out[0]); A_rows.append(out[1]); R_rows.append(out[2]); kept.append(rec)
            if (i + 1) % 40 == 0:
                logger.info(f"  rollout {i+1}/{len(batch)}")
    model.zero_grad(set_to_none=True)
    G = torch.stack(G_rows)    # [N, M, r] gradNLL
    ACT = torch.stack(A_rows)  # [N, M, r]
    RES = torch.stack(R_rows)  # [N, L, d_model]
    exploited = np.array([bool(x["exploited"]) for x in kept])
    gt_pass = np.array([bool(x["gt_pass"]) for x in kept])
    steps = np.array([x["step"] for x in kept])
    p_idx = np.array([x["p_idx"] for x in kept])
    reward = np.array([float(x["reward"]) for x in kept])

    # Reconstructed Dr.GRPO advantage A_i = reward_i - mean(reward over its group).
    # CAVEAT: students only (teachers absent from rollouts.jsonl), so signs/magnitudes
    # are approximate -- see module docstring.
    grp_mean = {}
    for s, p in set(zip(steps.tolist(), p_idx.tolist())):
        m = (steps == s) & (p_idx == p)
        grp_mean[(s, p)] = reward[m].mean()
    adv = np.array([reward[i] - grp_mean[(steps[i], p_idx[i])] for i in range(len(reward))])
    groups: dict[str, list[int]] = defaultdict(list)
    for i, p in enumerate(pairs_all):
        groups[p.problem_id.split("_")[0]].append(i)
    fe = {"G": G, "ACT": ACT, "RES": RES, "adv": adv, "exploited": exploited,
          "gt_pass": gt_pass, "steps": steps, "p_idx": p_idx, "names": names,
          "resid_layers": cfg.resid_layers, "pair_feats": PF, "pair_groups": dict(groups),
          "pair_ids": [p.problem_id for p in pairs_all]}
    torch.save(fe, feats_path)
    logger.info(f"wrote {feats_path}")
    return fe


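# Illustrative sketch, not called by this script: the advantage reconstruction used in
# _extract_feats, A_i = reward_i - mean(reward over rollouts sharing (step, p_idx)),
# on toy data. It shows that a single-member or zero-variance group yields A = 0, which
# is what the adv_eps filter later drops. The helper name and the toy values are
# assumptions made for demonstration.
from collections import defaultdict  # repeated so the sketch runs standalone


def _sketch_group_advantage(steps: list[int], p_idx: list[int], reward: list[float]) -> list[float]:
    total: dict = defaultdict(float)
    count: dict = defaultdict(int)
    for s, p, r in zip(steps, p_idx, reward):
        total[(s, p)] += r
        count[(s, p)] += 1
    return [r - total[(s, p)] / count[(s, p)] for s, p, r in zip(steps, p_idx, reward)]

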
def _downstream(cfg: Cfg, fe: dict, src: str) -> int:
    """Scores, pairset table, parquet, and plot from the feature dict -- no GPU."""
    data_path = cfg.out_dir / "pinning_data.parquet"
    rank_path = cfg.out_dir / "pinning_pairset.parquet"
    q2_png = cfg.out_dir / "pinning_q2.png"
    PF, pair_ids = fe["pair_feats"], fe["pair_ids"]
    G, ACT, RES = fe["G"], fe["ACT"], fe["RES"]
    adv, exploited, gt_pass = fe["adv"], fe["exploited"], fe["gt_pass"]
    steps, p_idx = fe["steps"], fe["p_idx"]
    G_adv = G * torch.tensor(adv, dtype=G.dtype)[:, None, None]  # the update the gate sees

    # ── Q2 populations: drop A~0 (no update); positive = exploited & A>0 ──
    valid = np.abs(adv) > cfg.adv_eps
    y = exploited & (adv > 0)
    pop = np.where(~valid, "on_drop",
          np.where(exploited & (adv > 0), "on_hackpos",
          np.where(exploited, "on_hackneg",
          np.where(gt_pass, "on_solve", "on_fail"))))
    counts = {p: int((pop == p).sum()) for p in ("on_solve", "on_fail", "on_hackpos", "on_hackneg", "on_drop")}
    logger.info(f"live populations: {counts} (zones/AUROC use the {int(valid.sum())} valid rows)")
    print(f"SHOULD: on_hackpos >= ~20 and on_drop not the majority, ELSE the window/run has "
          f"too few learnable hacks and every AUROC below is noise.")

    # ── headline vectors from the routeV-default subset; placebo swaps in Haar ──
    groups: dict[str, list[int]] = fe["pair_groups"]
    head_idx = [i for i, pid in enumerate(pair_ids) if pid.startswith(cfg.headline_prefix)]
    assert head_idx, f"no pairs match headline prefix {cfg.headline_prefix!r}"
    logger.info(f"headline v from prefix {cfg.headline_prefix!r} -> {len(head_idx)} pairs")

    REPS = ("grad", "act", "resid")

    def vectors(idx: list[int]) -> dict[str, torch.Tensor]:
        v = {rep: _v_from(PF[(rep, "hack")], PF[(rep, "clean")], idx) for rep in REPS}
        if cfg.random_v_seed is not None:
            v = {rep: _haar_like(v[rep], cfg.random_v_seed + i) for i, rep in enumerate(REPS)}
        return v

    v_head = vectors(head_idx)
    live_X = {"grad": G_adv, "act": ACT, "resid": RES}

    def score_cols(v: dict, X: dict[str, torch.Tensor]) -> dict[str, np.ndarray]:
        return {f"{rep}_{kind}": _score(X[rep], v[rep], kind) for rep, kind in CASES}

    live_scores = score_cols(v_head, live_X)
    syn_scores = {side: score_cols(v_head, {rep: PF[(rep, side)][head_idx] for rep in REPS})
                  for side in ("clean", "hack")}

    # ── pairset table: subsets x 6 AUROCs on the SAME cached live features ──
    candidates = [("all-in-one", list(range(len(pair_ids))))] + \
                 [(g, idx) for g, idx in sorted(groups.items()) if len(idx) >= 3]
    valid_pos = valid & (adv > 0)
    rows = []
    for gname, idx in candidates:
        v = vectors(idx)
        row = {"group": gname, "n_pairs": len(idx)}
        for rep, kind in CASES:
            s = _score(live_X[rep], v[rep], kind)
            row[f"{rep}_{kind}"] = round(_auroc(s[valid_pos].tolist(), y[valid_pos].tolist()), 3)
            row[f"{rep}_{kind}_all"] = round(_auroc(s[valid].tolist(), y[valid].tolist()), 3)
        rows.append(row)
    rank = pl.DataFrame(rows).sort("grad_dot", descending=True)
    rank.write_parquet(rank_path)
    adv_v = adv[valid]
    print(f"\nbaseline adv-only AUROC: vs-all={_auroc(adv_v.tolist(), y[valid].tolist()):.3f} "
          f"A>0-contrast={_auroc(adv[valid_pos].tolist(), y[valid_pos].tolist()):.3f} -- the table "
          f"columns are the A>0 contrast (hack vs non-hack among adv>0, n={int(valid_pos.sum())}), "
          f"where adv is blind; vs-all columns (*_all) live in {rank_path.name}.")
    print("SHOULD: real pairsets beat 0.5 and the adv-only A>0 baseline; under --random-v-seed "
          "every column ~0.5. With ~20 negatives the SE is ~0.07: only gaps >0.15 mean much.")
    print(tabulate(rank.drop([c for c in rank.columns if c.endswith("_all")]).to_pandas(),
                   headers="keys", tablefmt="pipe", floatfmt="+.3f", showindex=False))

    # ── persist per-rollout scores + raw features, then plot ──
    def frame(pop_name: str, mask_or_scores, scores: dict, step_arr, adv_arr) -> pl.DataFrame:
        return pl.DataFrame({"pop": pop_name, "step": step_arr, "adv": adv_arr,
                             **{c: scores[c][mask_or_scores] if mask_or_scores is not None else scores[c]
                                for c in scores}})
    dfs = [frame(p, pop == p, live_scores, steps[pop == p], adv[pop == p])
           for p in counts if counts[p] > 0]
    n_syn = len(head_idx)
    dfs += [frame(f"syn_{side}", None, syn_scores[side],
                  np.full(n_syn, -1), np.ones(n_syn)) for side in ("clean", "hack")]
    df = pl.concat(dfs)
    df.write_parquet(data_path)
    logger.info(f"wrote {data_path} ({len(df)} rows)")

    sub = (f"{src}, live steps {int(steps.min())}-{int(steps.max())}, v from "
           f"'{cfg.headline_prefix}' pairs (n={len(head_idx)}) | "
           f"hack+={counts['on_hackpos']} hack-={counts['on_hackneg']} solve={counts['on_solve']} "
           f"fail={counts['on_fail']} dropped(A~0)={counts['on_drop']}"
           + (f" | PLACEBO seed={cfg.random_v_seed}" if cfg.random_v_seed is not None else ""))
    stats = plot_q2(df, sub, q2_png)
    best = max(stats, key=lambda c: stats[c]["auroc_pos"])
    print(f"\nmain metric: best case on the A>0 contrast = {best} "
          f"AUROC={stats[best]['auroc_pos']:.3f} (vs-all {stats[best]['auroc_all']:.3f}) "
          f"P@rout={stats[best]['prec_rout']:.2f} (n={stats[best]['n_rout']}) "
          f"R@rout={stats[best]['rec_rout']:.2f}")
    print(f"out: {q2_png}")
    return 0


if __name__ == "__main__":
    raise SystemExit(main(tyro.cli(Cfg)))