Files
evil_MoE/scripts/diag_pinning.py
T
wassname 41d225a5ec writeup
2026-06-12 04:46:01 +00:00

589 lines
32 KiB
Python

"""Q2 diagnostic: what should the live routing gate SCORE, and where do the pinning
cuts go?
THE QUESTION (Q2). The gate routes UPDATES, not rollouts: per rollout the GRPO update
on the LoRA is g = A * gradNLL (advantage-weighted). So the positive class is
"update ascending the hack direction" = exploited & A>0; rollouts with A~0 contribute
no update and are DROPPED (scoring them as dead zeros at 0 is what made earlier
adv-weighted AUROCs look blind); a hack with A<0 is being UNLEARNED and belongs in
the negative class. (Q1, "does the direction exist at adv=+1", was answered earlier:
~0.61 -- see git history of this file.)
TWO CONTRASTS (post-review fix). On the vs-ALL contrast (hack+ vs every other valid
rollout) the advantage ALONE is a ~0.9 AUROC detector, because the label requires A>0
and most fails have A<0 -- so a high vs-all number mostly restates the reward, which
the live gate already has for free. The vector's VALUE-ADD is the A>0 contrast: among
positively-reinforced updates (where adv is blind, ~0.5), can the score tell hacks
from solves? That is the gate's primary objective: exclude reward-hacking updates while
retaining correct-solution updates, and it is the primary result here. The vs-all
contrast is kept as a secondary column. NOTE an info
asymmetry: grad scores carry the adv factor (sign+magnitude), act scores do not, so
on the vs-all contrast grad gets label-correlated information act lacks; the A>0
contrast removes most of that difference.
SIX CANDIDATE SCORES = {grad, act, resid} x {cos, dot}, concatenated over modules/layers:
- grad: the adv-weighted deployed c-probe gradient (the gate's current input).
- act: the deployed bottleneck activation A[:r]@x, mean over completion tokens --
same [r]-per-module space, capturable in the gate's pass-1 forward for free.
- resid: residual-stream hidden states at cfg.resid_layers, mean over completion
tokens. Adapter-independent: at an early checkpoint A is near its Gaussian
init, so grad and act are both views through a random r=32 projection per
module; resid tests whether that subspace, not grad-vs-act, limits separation.
- cos: magnitude-blind alignment (tiny vectors give meaningless angles -- control).
- dot: <g, v> = |g|*cos, magnitude-aware; with g = A*gradNLL the advantage rides
along, so dot measures update magnitude aligned with v.
v for each representation comes only from authored pairs (mean hack-minus-clean,
normalized per module). Ground-truth labels from training rollouts are used only for
diagnostic AUROC and precision measurements, never for routing.
DISPLAY + PINNING. Scores are plotted Z-NORMALIZED WITHIN FAMILY: live scores by the
mean/std of all valid live rollouts, synthetic scores by the mean/std of the joint
clean+hack pair scores. Affine per family, so every AUROC is unchanged; it puts both
families on one axis with a meaningful zero. (Raw scores share an offset <mu, v>:
v = mean(hack-clean) guarantees only the GAP between sides, not its location, and the
authored-pair common mean is not orthogonal to v, so uncentered both pair sides land
positive.) Zones keep | absorb | rout come from two-threshold Otsu on the live
z-scores -- the label-free valley cuts an online gate could compute from a rolling
score window (EMA mean/std + valley search). The previous mean+k*sd rule modeled
hacks as a rare outlier tail and put both cuts beyond every distribution (hack share
in these windows is 35-43%); the oracle hack-vs-rest split is drawn for reference.
CAVEAT. Live advantages are reconstructed from rollouts.jsonl students only (teachers
absent, zero-variance groups included, and skipped/empty completions missing from the
group mean), so A signs/magnitudes are approximate; the act columns dodge this
entirely (no A in the representation).
HOW. One GPU pass: per live rollout, backward its completion NLL once and capture the
c-probe grad, the pooled bottleneck act AND the pooled residual states; same per
authored-pair side. Everything downstream (subset vectors, 6 scores, zones, table) is
offline re-projection of the cached features.
uv run python scripts/diag_pinning.py --run-dir out/runs/<vanilla_lora2r_run>
uv run python scripts/diag_pinning.py --feats out/diag/pinning_feats.pt # no GPU:
# recompute scores/table/plot from cached feats
uv run python scripts/diag_pinning.py --replot out/diag/pinning_data.parquet # plot only
outputs (out/diag/): pinning_q2.png (3x2 headline), pinning_data.parquet (per-rollout
scores), pinning_pairset.parquet + printed table (subsets x 6 AUROCs),
pinning_feats.pt (raw features, for offline re-analysis).
"""
from __future__ import annotations
import json
import struct
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
import tyro
import polars as pl
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from matplotlib.patches import Patch
from loguru import logger
from tabulate import tabulate
from safetensors.torch import load_file
from transformers import AutoModelForCausalLM, AutoTokenizer
from vgrout.lora2r import wrap_model_with_lora2r
from vgrout.pairs import load_pairs
from vgrout.train import _auroc, _otsu3
# Palette: colour encodes behaviour (blue = solve, red = hack, grey = fail); line style
# encodes source (solid = on-policy rollouts, dashed = authored synthetic pairs).
SOLVE = "#3b6ea5"
HACK = "#c44e52"
FAIL = "#9aa0a6"
ABSORB_C = "#d1900a"  # absorb-zone shading and the lower (t_lo) cut line
ROUT_C = "#c44e52"  # rout-zone shading and the upper (t_hi) cut line; same red as HACK
ORACLE = "#3a8a7a"  # label-using oracle split, drawn for reference only
# The six (representation, score kind) panels in row-major order: one figure row per
# representation, cos on the left and dot on the right.
CASES = [(rep, kind) for rep in ("grad", "act", "resid") for kind in ("cos", "dot")]
@dataclass
class Cfg:
    """CLI configuration for the pinning diagnostic (parsed with tyro.cli in __main__).

    main() picks the mode: --replot (plot a saved parquet, no model) takes precedence
    over --feats (offline re-analysis of cached features); with neither set, one GPU
    pass extracts features from run_dir.
    """

    # training run directory; `<ckpt>.safetensors` and rollouts.jsonl are read from it
    run_dir: Path = Path("out/runs/20260611T003538_fast_vanilla_lora2r_seed43_l2r_vanilla_s43_v3")
    # checkpoint file stem inside run_dir (loaded as `<ckpt>.safetensors`)
    ckpt: str = "first_hack"
    # authored hack/clean pairs, passed verbatim to vgrout.pairs.load_pairs; the
    # `#all-in-one` suffix presumably selects a markdown heading -- confirm in load_pairs
    pairs: Path = Path("data/pairs/hack_pairs.md#all-in-one")
    # headline figure builds v from this heading-prefix subset = the routeA TRAINING
    # default (train_config.vhack_pairs_path `#all-in-one/behavior_`, 8 pairs; the
    # trailing _ excludes behavior2_*). The pairset table spans all subsets of `pairs`.
    # Matched with str.startswith against each pair's problem_id.
    headline_prefix: str = "behavior_"
    # Coherent emergence window. This vanilla v3 used the pre-fix lr=5e-4/warmup-0.1 and
    # DIVERGED at step 10 (exploited 20/24 -> 0/24); 2-9 = hacks emerging, model still sane.
    # Both bounds are inclusive on the rollout `step` field.
    step_lo: int = 2
    step_hi: int = 9
    # cap on live rollouts: the first N non-empty ones inside the window, in file order
    max_rollouts: int = 240
    adv_eps: float = 1e-6  # |A| below this = no update exists -> dropped from zones/AUROC
    # indices into model.model.layers whose outputs are captured as residual states
    resid_layers: tuple[int, ...] = (12, 18, 24)  # residual-stream capture depths (of 36)
    # when set, every v is replaced by a random unit-per-module direction, seeded
    # per representation (seed + 0/1/2 for grad/act/resid)
    random_v_seed: int | None = None  # Haar placebo (sanity: nothing should separate)
    feats: Path | None = None  # cached pinning_feats.pt -> full offline re-analysis
    replot: Path | None = None  # load parquet and re-plot only (no model, no GPU)
    # every output (parquets, feature cache, png) is written here
    out_dir: Path = Path("out/diag")
def _ckpt_meta(path: Path) -> dict:
with open(path, "rb") as f:
return json.loads(f.read(struct.unpack("<Q", f.read(8))[0])).get("__metadata__", {})
class ActTap:
    """Context manager of forward hooks that stash, for the most recent forward pass:

    (a) the deployed bottleneck activation h = A[:r] @ x of every module in `names`
        (the module is `wrappers[nm]["layer"]`), and
    (b) the output hidden state of every module in `resid_modules` (for tuple outputs,
        the first element).

    (a) computes the r-dim projection inline (no_grad) instead of retaining the full
    [L, d_in] input -- ~250 modules x [L, d_in] would be GBs; [L, r] is nothing.

    Each hook overwrites its slot, so only the latest forward is kept. On exit all
    hooks are removed AND the stash is cleared. A read after the context has closed
    therefore fails loudly (KeyError) instead of silently returning the last
    forward's activations, and the cached tensors are freed.
    """

    def __init__(self, wrappers: dict, names: list[str], resid_modules: list):
        self.wrappers, self.names, self.resid_modules = wrappers, names, resid_modules
        self.h: dict = {}  # module name -> bottleneck act of the last forward
        self.res: dict = {}  # resid_modules index -> hidden state of the last forward
        self.handles: list = []

    def __enter__(self):
        for nm in self.names:
            layer = self.wrappers[nm]["layer"]

            def hook(layer, args, out, nm=nm):  # nm bound as a default: no late binding
                (x,) = args
                with torch.no_grad():
                    self.h[nm] = F.linear(x.detach(), layer._lora2r_A[: layer._lora2r_r].to(x.dtype))

            self.handles.append(layer.register_forward_hook(hook))
        for li, mod in enumerate(self.resid_modules):

            def rhook(mod, args, out, li=li):
                self.res[li] = (out[0] if isinstance(out, tuple) else out).detach()

            self.handles.append(mod.register_forward_hook(rhook))
        return self

    def __exit__(self, *exc):
        for handle in self.handles:
            handle.remove()
        self.handles.clear()
        # drop stale activations: frees memory and makes post-exit reads raise
        self.h.clear()
        self.res.clear()

    def pooled(self, n_prompt: int) -> torch.Tensor:
        """[M, r] mean bottleneck act over completion tokens (positions >= n_prompt)."""
        out = []
        for nm in self.names:
            h = self.h[nm]  # [1, L, r]
            assert h.shape[1] > n_prompt, f"{nm}: no completion tokens (L={h.shape[1]} n_prompt={n_prompt})"
            out.append(h[0, n_prompt:].float().mean(0).cpu())
        return torch.stack(out)

    def pooled_resid(self, n_prompt: int) -> torch.Tensor:
        """[L_layers, d_model] mean residual-stream state over completion tokens
        (positions >= n_prompt); same completion-token check as `pooled`."""
        out = []
        for li in range(len(self.resid_modules)):
            res = self.res[li]
            assert res.shape[1] > n_prompt, (f"resid #{li}: no completion tokens "
                                             f"(L={res.shape[1]} n_prompt={n_prompt})")
            out.append(res[0, n_prompt:].float().mean(0).cpu())
        return torch.stack(out)
def _gate_grads(wrappers: dict, names: list[str]) -> torch.Tensor:
"""[M, r] deployed-block c-probe grad after a backward (the gate's gradient space)."""
g = []
for nm in names:
layer = wrappers[nm]["layer"]
gr = layer._lora2r_gate.grad
g.append(gr.sum(dim=tuple(range(gr.dim() - 1)))[: layer._lora2r_r].float().cpu())
return torch.stack(g)
def _v_from(feat_hack: torch.Tensor, feat_clean: torch.Tensor, idx: list[int]) -> torch.Tensor:
"""[M, r] unit-per-module mean hack-minus-clean direction from pair rows `idx`."""
d = (feat_hack[idx] - feat_clean[idx]).mean(0)
return d / d.norm(dim=1, keepdim=True).clamp_min(1e-12)
def _haar_like(v: torch.Tensor, seed: int) -> torch.Tensor:
g = torch.Generator().manual_seed(seed)
d = torch.randn(v.shape, generator=g)
return d / d.norm(dim=1, keepdim=True).clamp_min(1e-12)
def _score(X: torch.Tensor, V: torch.Tensor, kind: str) -> np.ndarray:
"""Concat-module score per rollout: dot = sum_m <x_m, v_m>; cos = dot / (||x|| ||v||)."""
d = torch.einsum("nmr,mr->n", X, V)
if kind == "dot":
return d.numpy()
return (d / (X.flatten(1).norm(dim=1).clamp_min(1e-12) * V.flatten().norm().clamp_min(1e-12))).numpy()
def _kde(x: np.ndarray, grid: np.ndarray) -> np.ndarray:
"""Gaussian KDE, Silverman bandwidth (no scipy). Bandwidth is scale-relative
(dot scores can live at 1e-4 or 1e2)."""
x = np.asarray(x, float)
if len(x) < 2:
return np.zeros_like(grid)
iqr = np.subtract(*np.percentile(x, [75, 25]))
sigma = min(x.std(ddof=1), iqr / 1.349) if iqr > 0 else x.std(ddof=1)
if sigma <= 0:
return np.zeros_like(grid)
bw = 0.9 * sigma * len(x) ** (-0.2)
z = (grid[:, None] - x[None, :]) / bw
return np.exp(-0.5 * z ** 2).sum(1) / (len(x) * bw * np.sqrt(2 * np.pi))
def completion_nll(model, tokenizer, prompt: str, completion: str, device) -> torch.Tensor:
    """Mean NLL over completion tokens only (length-normalized). The backward of this
    loss populates the c-probe grads read by _gate_grads (the retired grad-gate space,
    kept here as a diagnostic baseline).

    n_prompt is the token count of `prompt` tokenized on its own. The completion is
    every target position >= n_prompt in the tokenization of `prompt + completion`,
    which is approximate if the tokenizer merges tokens across the seam.

    Logits are sliced to the completion positions BEFORE the float32 log-softmax. The
    prompt positions never contributed to the loss or its gradient (they were masked
    out), but computing them still cost [L, vocab] float32 buffers. Returns a 0-d
    tensor, which is 0 when there are no completion positions.
    """
    prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    full_ids = tokenizer(prompt + completion, return_tensors="pt").input_ids.to(device)
    n_prompt = prompt_ids.shape[1]
    # logits[:, t] predicts full_ids[:, t + 1], so the completion targets
    # full_ids[:, n_prompt:] come from logits[:, n_prompt - 1:-1]; clamp for an empty prompt
    start = max(n_prompt - 1, 0)
    logits = model(full_ids).logits[:, start:-1]  # [1, n_completion, V]
    targets = full_ids[:, start + 1:]  # [1, n_completion]
    logp = F.log_softmax(logits.float(), dim=-1)
    nll = -logp.gather(-1, targets.unsqueeze(-1)).squeeze(-1)  # [1, n_completion]
    return nll.sum() / max(nll.numel(), 1)
def plot_q2(df: pl.DataFrame, subtitle: str, out_png: Path) -> dict:
    """3x2 figure ({grad,act,resid} x {cos,dot}) from the saved per-rollout scores -- no GPU.
    Per panel: live solve/fail/hack+ KDEs (+ thin hack- if n>=3), synthetic pair sides
    dashed, all Z-NORMALIZED WITHIN FAMILY (live by valid-live mean/std, synthetic by
    joint clean+hack mean/std; affine, AUROC unchanged) so both families share one axis
    with a meaningful zero. Three shaded zones keep|absorb|rout from two-threshold Otsu
    on the live z-scores (label-free, online-computable), oracle split, AUROC + P/R at
    the rout cut.

    Populations absent from `df` are skipped rather than raising KeyError. _downstream
    only writes live populations with at least one row, so, for example, a window with
    no solves has no "on_solve" rows at all. The oracle split is NaN (and not drawn)
    when the live rows hold a single class.

    Returns {"<rep>_<kind>": {auroc_pos, auroc_all, prec_rout, rec_rout, fhalf_rout,
    n_rout, t_hi, oracle}} for logging."""
    pops = {p: df.filter(pl.col("pop") == p) for p in df["pop"].unique().to_list()}

    def raw_scores(p, col):
        # scores of population p in column col, or None when p has no rows
        part = pops.get(p)
        return part[col].to_numpy() if part is not None and len(part) else None

    live_pops = ["on_solve", "on_fail", "on_hackpos", "on_hackneg"]
    live = df.filter(pl.col("pop").is_in(live_pops))
    posm = (live["adv"] > 0).to_numpy()  # the A>0 contrast rows
    y_all = (live["pop"] == "on_hackpos").to_numpy()
    # adv-only baseline: on vs-all the reward alone is a strong detector (the label
    # requires A>0); the vector only adds value where this baseline is blind.
    a = live["adv"].to_numpy()
    logger.info(f"adv-only baseline AUROC: vs-all={_auroc(a.tolist(), y_all.tolist()):.3f} "
                f"A>0-contrast={_auroc(a[posm].tolist(), y_all[posm].tolist()):.3f} "
                f"(n+={int(y_all.sum())} negA>0={int((~y_all & posm).sum())})")
    show_hackneg = len(pops.get("on_hackneg", [])) >= 3
    stats = {}
    n_rows = len(CASES) // 2
    fig, axes = plt.subplots(n_rows, 2, figsize=(12.5, 3.6 * n_rows + 0.8))
    for ax, (rep, kind) in zip(axes.flat, CASES):
        col = f"{rep}_{kind}"
        s_raw = live[col].to_numpy()
        mu_l, sd_l = float(s_raw.mean()), float(s_raw.std()) or 1.0
        syn_parts = [x for x in (raw_scores(p, col) for p in ("syn_clean", "syn_hack"))
                     if x is not None]
        # np.concatenate([]) raises, so a fully missing synthetic family is handled here
        syn_join = np.concatenate(syn_parts) if syn_parts else np.empty(0)
        mu_s = float(syn_join.mean()) if len(syn_join) else 0.0
        sd_s = (float(syn_join.std()) or 1.0) if len(syn_join) else 1.0

        def z_of(x, p, mu_s=mu_s, sd_s=sd_s, mu_l=mu_l, sd_l=sd_l):
            # z within family: synthetic rows by the pair stats, live rows by live stats
            return (x - mu_s) / sd_s if p.startswith("syn") else (x - mu_l) / sd_l

        s = (s_raw - mu_l) / sd_l
        y = y_all
        t_lo, t_hi = _otsu3(s)
        auroc = _auroc(s.tolist(), y.tolist())
        auroc_pos = _auroc(s[posm].tolist(), y[posm].tolist())
        # oracle = Youden-J-optimal hack-vs-rest cut (uses labels; reference line only).
        # Undefined with a single class: the empty-side means would be NaN.
        if y.any() and (~y).any():
            thr = np.unique(s)
            j = [(s[y] >= t).mean() - (s[~y] >= t).mean() for t in thr]
            oracle = float(thr[int(np.argmax(j))])
        else:
            oracle = float("nan")
        routed = s >= t_hi
        n_rout = int(routed.sum())
        prec = float(y[routed].mean()) if routed.any() else float("nan")
        rec = float((s[y] >= t_hi).mean()) if y.any() else float("nan")
        # F_beta at the rout cut, beta=0.5 (PRECISION-weighted). The routing cost is
        # asymmetric the OTHER way than naive intuition: a missed hack (false negative)
        # is absorbed -- SGTM is robust to 40-50% undiscovered forget data because the
        # routed subset localizes the capability regardless (paper_sgtm.md L64,160,362).
        # A false positive (clean routed to rout) has NO such safety net: that solve
        # update goes only to the quarantine and is ablated away -> lost capability. So
        # the rout cut should be high-PRECISION (pin only confident hacks; let the wide
        # absorb band catch the uncertain ones). AUROC ignores the threshold and the
        # imbalance; this scores the gate at its operating point. Measurement only -- it
        # needs hack labels, so it can never feed the live gate.
        b2 = 0.25  # beta=0.5 -> beta^2
        fbeta = float((1 + b2) * prec * rec / (b2 * prec + rec)) if (prec + rec) > 0 else 0.0
        stats[col] = {"auroc_pos": auroc_pos, "auroc_all": auroc, "prec_rout": prec,
                      "rec_rout": rec, "fhalf_rout": fbeta, "n_rout": n_rout, "t_hi": t_hi,
                      "oracle": oracle}
        zvals = np.concatenate([s, (syn_join - mu_s) / sd_s]) if len(syn_join) else s
        lo = float(np.quantile(zvals, 0.005))
        hi = float(np.quantile(zvals, 0.995))
        pad = 0.05 * (hi - lo) or 1e-6
        lo, hi = lo - pad, hi + pad
        grid = np.linspace(lo, hi, 400)
        curves = [("on_solve", SOLVE, "-", 1.9, 0.12), ("on_fail", FAIL, "-", 1.9, 0.12),
                  ("on_hackpos", HACK, "-", 1.9, 0.12),
                  ("syn_clean", SOLVE, (0, (5, 2)), 2.0, 0.0), ("syn_hack", HACK, (0, (5, 2)), 2.0, 0.0)]
        if show_hackneg:
            curves.insert(3, ("on_hackneg", HACK, (0, (1, 1)), 1.2, 0.0))
        ymax = 0.0
        for p, c, ls, lw, fill in curves:
            raw = raw_scores(p, col)
            if raw is None:
                continue  # population absent in this window
            yk = _kde(z_of(raw, p), grid)
            ymax = max(ymax, yk.max())
            if fill:
                ax.fill_between(grid, yk, color=c, alpha=fill, lw=0)
            ax.plot(grid, yk, color=c, lw=lw, ls=ls)
        # all-zero KDEs (every population < 2 rows or constant) would give a singular ylim
        ymax = (ymax or 1.0) * 1.18
        # rug of the ACTUAL live points (KDEs of n~20 are smooth fiction; the rout-tail
        # precision claim rests on a handful of rollouts -- show them). hack row on top,
        # in a strip below y=0 (faint separator line); tick labels stay outside the axes.
        ax.axhline(0, color="#cccccc", lw=0.6, zorder=0)
        ax.axvline(0, color="#bbbbbb", lw=0.7, zorder=0)  # 0 = family mean (z-norm)
        for row, (p, c) in enumerate((("on_hackpos", HACK), ("on_fail", FAIL), ("on_solve", SOLVE))):
            raw = raw_scores(p, col)
            if raw is None:
                continue
            x = z_of(raw, p)
            ax.plot(x, np.full(len(x), -(0.035 + 0.035 * row) * ymax), "|",
                    color=c, ms=4, alpha=0.6, mew=0.8)
        # three zones: keep | absorb | rout
        ax.axvspan(t_lo, min(t_hi, hi), color=ABSORB_C, alpha=0.08, lw=0)
        ax.axvspan(min(t_hi, hi), hi, color=ROUT_C, alpha=0.10, lw=0)
        ax.axvline(t_lo, color=ABSORB_C, lw=1.2, ls="--")
        ax.axvline(t_hi, color=ROUT_C, lw=1.2, ls="--")
        if np.isfinite(oracle):
            ax.axvline(oracle, color=ORACLE, lw=1.3, ls="-.")
        # zone labels, skipping any whose zone is too narrow on this axis to label legibly
        min_w = 0.05 * (hi - lo)
        for xz, lab, w in ((min(t_lo, hi) - 0.04 * (hi - lo), "keep", min(t_lo, hi) - lo),
                           ((t_lo + min(t_hi, hi)) / 2, "absorb", min(t_hi, hi) - t_lo),
                           ((min(t_hi, hi) + hi) / 2, "rout", hi - min(t_hi, hi))):
            if lo < xz < hi and w > min_w:
                ax.text(xz, ymax * 0.97, lab, ha="center", va="top", fontsize=7.5, color="#555555")
        ax.set_xlim(lo, hi)
        ax.set_ylim(-0.13 * ymax, ymax)  # negative strip hosts the rugs
        ax.set_yticks([])  # KDE density units are meaningless ink
        for sp in ("top", "right", "left"):
            ax.spines[sp].set_visible(False)
        ax.set_title(f"{rep} · {kind} AUROC={auroc_pos:.2f} (A>0 contrast; vs-all {auroc:.2f}) "
                     f"P@rout={prec:.2f} (n={n_rout}) R={rec:.2f} F0.5={fbeta:.2f}", fontsize=9)
        ax.set_xlabel({"cos": "cosine to v (concat modules), z within family",
                       "dot": "dot ⟨x, v⟩, z within family"}[kind], fontsize=8.5)
    handles = [Line2D([0], [0], color=SOLVE, lw=1.9), Line2D([0], [0], color=FAIL, lw=1.9),
               Line2D([0], [0], color=HACK, lw=1.9),
               Line2D([0], [0], color=SOLVE, lw=2.0, ls=(0, (5, 2))),
               Line2D([0], [0], color=HACK, lw=2.0, ls=(0, (5, 2))),
               Patch(facecolor=ABSORB_C, alpha=0.18), Patch(facecolor=ROUT_C, alpha=0.18),
               Line2D([0], [0], color=ORACLE, lw=1.3, ls="-.")]
    labels = ["live solve", "live fail", "live hack (A>0)", "synthetic clean", "synthetic hack",
              "absorb (otsu lo, label-free)", "rout (otsu hi, label-free)", "oracle hack/rest split"]
    if show_hackneg:  # the dotted hack- curve is drawn on every panel -> give it a key
        handles.insert(3, Line2D([0], [0], color=HACK, lw=1.2, ls=(0, (1, 1))))
        labels.insert(3, "live hack (A<0)")
    fig.legend(handles, labels, loc="lower center", ncol=(len(handles) + 1) // 2,
               fontsize=8, frameon=False)
    fig.suptitle(subtitle, fontsize=10)
    fig.tight_layout(rect=(0, 0.07, 1, 0.95))
    fig.savefig(out_png, dpi=140)
    plt.close(fig)
    logger.info(f"wrote {out_png}")
    return stats
def main(cfg: Cfg) -> int:
    """CLI entry point; returns the process exit code.

    --replot takes precedence: it re-plots a saved per-rollout parquet and prints the
    saved pairset table if one exists (no model, no GPU). Otherwise features come
    from --feats (cached, no GPU) or from a fresh GPU extraction pass, and
    _downstream scores, tabulates and plots them.
    """
    cfg.out_dir.mkdir(parents=True, exist_ok=True)
    rank_path = cfg.out_dir / "pinning_pairset.parquet"
    feats_path = cfg.out_dir / "pinning_feats.pt"
    q2_png = cfg.out_dir / "pinning_q2.png"
    if cfg.replot is not None:
        plot_q2(pl.read_parquet(cfg.replot), f"replot -- {cfg.replot.name}", q2_png)
        if rank_path.exists():
            print(tabulate(pl.read_parquet(rank_path).to_pandas(), headers="keys",
                           tablefmt="pipe", floatfmt="+.3f", showindex=False))
        return 0
    if cfg.feats is not None:
        # weights_only=False is required: the cache holds numpy arrays and plain Python
        # objects, not only tensors. It unpickles arbitrary objects, so only load
        # caches this script wrote itself.
        fe = torch.load(cfg.feats, weights_only=False)
        logger.info(f"offline re-analysis from {cfg.feats} (no GPU)")
        src = str(cfg.feats)
    else:
        fe = _extract_feats(cfg, feats_path)
        src = f"{cfg.run_dir.name} | {cfg.ckpt}"
    return _downstream(cfg, fe, src)
def _group_advantages(steps: np.ndarray, p_idx: np.ndarray, reward: np.ndarray) -> np.ndarray:
    """Reconstructed Dr.GRPO advantage A_i = reward_i - mean(reward of its group),
    where a group is every row sharing the same (step, p_idx)."""
    adv = np.empty(len(reward), dtype=float)
    for s, p in set(zip(steps.tolist(), p_idx.tolist())):
        m = (steps == s) & (p_idx == p)
        adv[m] = reward[m] - reward[m].mean()
    return adv


def _extract_feats(cfg: Cfg, feats_path: Path) -> dict:
    """One GPU pass: features for every authored pair side and live rollout, saved to
    feats_path. Everything downstream is offline re-projection (rerun via --feats).

    Each text (pair side or rollout) gets one forward+backward of its completion NLL,
    which yields three features: the c-probe grad [M, r], the pooled bottleneck act
    [M, r] and the pooled residual states [len(resid_layers), d_model]. Both loops
    run inside ONE ActTap context, because the act/resid features are read from its
    hooks, which are removed when the context exits.

    Returns (and torch.saves) a dict with:
    - G/ACT/RES: stacked live features;
    - per-rollout numpy arrays adv/exploited/gt_pass/steps/p_idx;
    - module names and resid_layers;
    - pair features keyed (rep, side);
    - pair groups (problem_id prefix before the first "_") and pair ids.

    Raises RuntimeError when there are no pairs, when the step window holds no
    non-empty rollouts, on a non-finite loss for an authored pair, or when every live
    rollout had a non-finite loss.
    """
    device = torch.device("cuda")
    ckpt_path = cfg.run_dir / f"{cfg.ckpt}.safetensors"
    meta = _ckpt_meta(ckpt_path)
    run_cfg = json.loads(meta.get("cfg", "{}"))
    model_name = run_cfg.get("model", meta.get("model", "Qwen/Qwen3-4B"))
    r = run_cfg.get("lora_r", 32)
    init_seed = run_cfg.get("lora_init_seed", 0)
    logger.info(f"ckpt {ckpt_path.name} step={meta.get('step')} hack_rate={meta.get('hack_rate')} "
                f"model={model_name} r={r} init_seed={init_seed}")

    # ── inputs first (CPU only): fail before paying for the model load ──
    pairs_all = load_pairs(cfg.pairs)
    logger.info(f"pairs {cfg.pairs} -> {len(pairs_all)}")
    if not pairs_all:
        raise RuntimeError(f"no authored pairs loaded from {cfg.pairs}")
    rollouts_path = cfg.run_dir / "rollouts.jsonl"
    # skip blank lines (e.g. a partially written log), which json.loads rejects
    recs = [json.loads(line) for line in rollouts_path.read_text().splitlines() if line.strip()]
    batch = [x for x in recs if cfg.step_lo <= x["step"] <= cfg.step_hi and x["text"].strip()][:cfg.max_rollouts]
    logger.info(f"live batch: {len(batch)} rollouts (steps {cfg.step_lo}-{cfg.step_hi})")
    if not batch:
        raise RuntimeError(f"no non-empty rollouts in steps {cfg.step_lo}-{cfg.step_hi} "
                           f"of {rollouts_path}")

    tok = AutoTokenizer.from_pretrained(model_name)
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        model_name, dtype=torch.bfloat16, attn_implementation="flash_attention_2").to(device)
    model.config.use_cache = False
    wrappers = wrap_model_with_lora2r(model, r=r, init_seed=init_seed, grad_probe=True)
    names = sorted(wrappers)
    sd = load_file(str(ckpt_path))
    for nm in names:
        wrappers[nm]["A"].data.copy_(sd[f"A/{nm}"].to(device, torch.float32))
        wrappers[nm]["B"].data.copy_(sd[f"B/{nm}"].to(device, torch.float32))
    logger.info(f"loaded A/B into {len(names)} modules")
    model.eval()

    def one_pass(tap: ActTap, prompt: str, completion: str):
        """Backward one completion's mean NLL; return ([M,r] c-grad, [M,r] act, [L,d] resid),
        or None when the loss is non-finite."""
        model.zero_grad(set_to_none=True)
        loss = completion_nll(model, tok, prompt, completion, device)
        if not torch.isfinite(loss):
            return None
        loss.backward()
        n_prompt = tok(prompt, return_tensors="pt").input_ids.shape[1]
        return _gate_grads(wrappers, names), tap.pooled(n_prompt), tap.pooled_resid(n_prompt)

    pair_feat = {(rep, side): [] for rep in ("grad", "act", "resid") for side in ("hack", "clean")}
    G_rows, A_rows, R_rows, kept = [], [], [], []
    resid_modules = [model.model.layers[i] for i in cfg.resid_layers]
    # ONE hook context spans both loops: tap.pooled* is only valid while hooks are installed
    with ActTap(wrappers, names, resid_modules) as tap:
        # ── authored-pair features, once over ALL pairs (subsets = row slices) ──
        for pi, pair in enumerate(pairs_all):
            for side, completion in (("hack", pair.hack), ("clean", pair.clean)):
                out = one_pass(tap, pair.prompt, completion)
                if out is None:
                    raise RuntimeError(f"non-finite loss on pair {pi} ({pair.problem_id}) side={side}")
                for rep, feat in zip(("grad", "act", "resid"), out):
                    pair_feat[(rep, side)].append(feat)
            if (pi + 1) % 5 == 0:
                logger.info(f" pair {pi+1}/{len(pairs_all)}")
        # ── live rollout features, once (everything downstream re-projects) ──
        for i, rec in enumerate(batch):
            out = one_pass(tap, rec["prompt"], rec["text"])
            if out is None:
                logger.warning(f" skip rollout {i}: non-finite loss")
                continue
            G_rows.append(out[0])
            A_rows.append(out[1])
            R_rows.append(out[2])
            kept.append(rec)
            if (i + 1) % 40 == 0:
                logger.info(f" rollout {i+1}/{len(batch)}")
    model.zero_grad(set_to_none=True)
    if not kept:
        raise RuntimeError(f"all {len(batch)} live rollouts had a non-finite loss")
    PF = {k: torch.stack(v) for k, v in pair_feat.items()}  # [P, M, r] / resid [P, L, d]
    G = torch.stack(G_rows)  # [N, M, r] gradNLL
    ACT = torch.stack(A_rows)  # [N, M, r]
    RES = torch.stack(R_rows)  # [N, L, d_model]
    exploited = np.array([bool(x["exploited"]) for x in kept])
    gt_pass = np.array([bool(x["gt_pass"]) for x in kept])
    steps = np.array([x["step"] for x in kept])
    p_idx = np.array([x["p_idx"] for x in kept])
    reward = np.array([float(x["reward"]) for x in kept])
    # Reconstructed Dr.GRPO advantage over the KEPT rows.
    # CAVEAT: students only (teachers absent from rollouts.jsonl), so signs/magnitudes
    # are approximate -- see module docstring.
    adv = _group_advantages(steps, p_idx, reward)
    groups: dict[str, list[int]] = defaultdict(list)
    for i, p in enumerate(pairs_all):
        groups[p.problem_id.split("_")[0]].append(i)
    fe = {"G": G, "ACT": ACT, "RES": RES, "adv": adv, "exploited": exploited,
          "gt_pass": gt_pass, "steps": steps, "p_idx": p_idx, "names": names,
          "resid_layers": cfg.resid_layers, "pair_feats": PF, "pair_groups": dict(groups),
          "pair_ids": [p.problem_id for p in pairs_all]}
    torch.save(fe, feats_path)
    logger.info(f"wrote {feats_path}")
    return fe
def _downstream(cfg: Cfg, fe: dict, src: str) -> int:
    """Scores, pairset table, parquet, and plot from the feature dict -- no GPU.

    Live rows are split into populations: on_drop (|A| <= adv_eps), on_hackpos
    (exploited & A>0), on_hackneg, on_solve and on_fail. Headline vectors come from
    the `headline_prefix` pair subset, and every pair subset with >= 3 pairs gets a
    table row. Per-rollout scores go to pinning_data.parquet and the table to
    pinning_pairset.parquet; plot_q2 then draws pinning_q2.png. Returns 0.
    """
    data_path = cfg.out_dir / "pinning_data.parquet"
    rank_path = cfg.out_dir / "pinning_pairset.parquet"
    q2_png = cfg.out_dir / "pinning_q2.png"
    PF, pair_ids = fe["pair_feats"], fe["pair_ids"]
    G, ACT, RES = fe["G"], fe["ACT"], fe["RES"]
    adv, exploited, gt_pass = fe["adv"], fe["exploited"], fe["gt_pass"]
    steps = fe["steps"]
    G_adv = G * torch.tensor(adv, dtype=G.dtype)[:, None, None]  # the update the gate sees
    # ── Q2 populations: drop A~0 (no update); positive = exploited & A>0 ──
    valid = np.abs(adv) > cfg.adv_eps
    y = exploited & (adv > 0)
    pop = np.where(~valid, "on_drop",
                   np.where(y, "on_hackpos",
                            np.where(exploited, "on_hackneg",
                                     np.where(gt_pass, "on_solve", "on_fail"))))
    counts = {p: int((pop == p).sum()) for p in ("on_solve", "on_fail", "on_hackpos", "on_hackneg", "on_drop")}
    logger.info(f"live populations: {counts} (zones/AUROC use the {int(valid.sum())} valid rows)")
    print(f"SHOULD: on_hackpos >= ~20 and on_drop not the majority, ELSE the window/run has "
          f"too few learnable hacks and every AUROC below is noise.")
    # ── headline vectors from the routeA-default subset; placebo swaps in Haar ──
    groups: dict[str, list[int]] = fe["pair_groups"]
    head_idx = [i for i, pid in enumerate(pair_ids) if pid.startswith(cfg.headline_prefix)]
    assert head_idx, f"no pairs match headline prefix {cfg.headline_prefix!r}"
    logger.info(f"headline v from prefix {cfg.headline_prefix!r} -> {len(head_idx)} pairs")
    REPS = ("grad", "act", "resid")

    def vectors(idx: list[int]) -> dict[str, torch.Tensor]:
        v = {rep: _v_from(PF[(rep, "hack")], PF[(rep, "clean")], idx) for rep in REPS}
        if cfg.random_v_seed is not None:
            v = {rep: _haar_like(v[rep], cfg.random_v_seed + i) for i, rep in enumerate(REPS)}
        return v

    v_head = vectors(head_idx)
    live_X = {"grad": G_adv, "act": ACT, "resid": RES}

    def score_cols(v: dict, X: dict[str, torch.Tensor]) -> dict[str, np.ndarray]:
        return {f"{rep}_{kind}": _score(X[rep], v[rep], kind) for rep, kind in CASES}

    live_scores = score_cols(v_head, live_X)
    syn_scores = {side: score_cols(v_head, {rep: PF[(rep, side)][head_idx] for rep in REPS})
                  for side in ("clean", "hack")}
    # ── pairset table: pair subsets x the 6 CASES, all scored on the SAME cached live
    #    features. Column "<rep>_<kind>" = A>0-contrast AUROC, "<rep>_<kind>_all" = vs-all. ──
    candidates = [("all-in-one", list(range(len(pair_ids))))] + \
                 [(g, idx) for g, idx in sorted(groups.items()) if len(idx) >= 3]
    valid_pos = valid & (adv > 0)
    rows = []
    for gname, idx in candidates:
        v = vectors(idx)
        row = {"group": gname, "n_pairs": len(idx)}
        for rep, kind in CASES:
            s = _score(live_X[rep], v[rep], kind)
            row[f"{rep}_{kind}"] = round(_auroc(s[valid_pos].tolist(), y[valid_pos].tolist()), 3)
            row[f"{rep}_{kind}_all"] = round(_auroc(s[valid].tolist(), y[valid].tolist()), 3)
        rows.append(row)
    rank = pl.DataFrame(rows).sort("grad_dot", descending=True)
    rank.write_parquet(rank_path)
    print(f"\nbaseline adv-only AUROC: vs-all={_auroc(adv[valid].tolist(), y[valid].tolist()):.3f} "
          f"A>0-contrast={_auroc(adv[valid_pos].tolist(), y[valid_pos].tolist()):.3f} -- the table "
          f"columns are the A>0 contrast (hack vs non-hack among adv>0, n={int(valid_pos.sum())}), "
          f"where adv is blind; vs-all columns (*_all) live in {rank_path.name}.")
    print("SHOULD: real pairsets beat 0.5 and the adv-only A>0 baseline; under --random-v-seed "
          "every column ~0.5. With ~20 negatives the SE is ~0.07: only gaps >0.15 mean much.")
    print(tabulate(rank.drop([c for c in rank.columns if c.endswith("_all")]).to_pandas(),
                   headers="keys", tablefmt="pipe", floatfmt="+.3f", showindex=False))

    # ── persist per-rollout scores + raw features, then plot ──
    def frame(pop_name: str, scores: dict, step_arr, adv_arr, mask=None) -> pl.DataFrame:
        # one parquet row per score row; `mask` selects the live rows of one population
        cols = {c: (vals if mask is None else vals[mask]) for c, vals in scores.items()}
        return pl.DataFrame({"pop": pop_name, "step": step_arr, "adv": adv_arr, **cols})

    dfs = [frame(p, live_scores, steps[pop == p], adv[pop == p], mask=pop == p)
           for p in counts if counts[p] > 0]
    n_syn = len(head_idx)
    dfs += [frame(f"syn_{side}", syn_scores[side], np.full(n_syn, -1), np.ones(n_syn))
            for side in ("clean", "hack")]
    df = pl.concat(dfs)
    df.write_parquet(data_path)
    logger.info(f"wrote {data_path} ({len(df)} rows)")
    sub = (f"{src}, live steps {int(steps.min())}-{int(steps.max())}, v from "
           f"'{cfg.headline_prefix}' pairs (n={len(head_idx)}) | "
           f"hack+={counts['on_hackpos']} hack-={counts['on_hackneg']} solve={counts['on_solve']} "
           f"fail={counts['on_fail']} dropped(A~0)={counts['on_drop']}"
           + (f" | PLACEBO seed={cfg.random_v_seed}" if cfg.random_v_seed is not None else ""))
    stats = plot_q2(df, sub, q2_png)

    def auroc_pos_key(case: str) -> float:
        # NaN compares False both ways, so an undefined AUROC could otherwise "win" max()
        val = float(stats[case]["auroc_pos"])
        return -np.inf if np.isnan(val) else val

    best = max(stats, key=auroc_pos_key)
    print(f"\nmain metric: best case on the A>0 contrast = {best} "
          f"AUROC={stats[best]['auroc_pos']:.3f} (vs-all {stats[best]['auroc_all']:.3f}) "
          f"P@rout={stats[best]['prec_rout']:.2f} (n={stats[best]['n_rout']}) "
          f"R@rout={stats[best]['rec_rout']:.2f} F0.5@rout={stats[best]['fhalf_rout']:.2f}")
    print(f"out: {q2_png}")
    return 0
if __name__ == "__main__":
    # parse the CLI into a Cfg, run, and hand main()'s return value to the shell
    cli_cfg = tyro.cli(Cfg)
    raise SystemExit(main(cli_cfg))