diag(#40): z-norm scores within family, winsorized 2-threshold Otsu zones, --feats offline mode

Synthetic and live scores share an axis with meaningful zero (raw scores carry a
common <mu,v> offset since v = mean diff is not orthogonal to the family mean).
Zones come from label-free online stats (EMA mean/std + Otsu valley), replacing
mean+k*sd which placed both cuts beyond every distribution. Winsorize at 1/99%
before Otsu: variance-maximizing cuts otherwise buy a class for one outlier.
Fresh-eyes review verified z-norm is affine (AUROCs unchanged), zones label-free,
P/R recompute from parquet matches titles.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
wassname
2026-06-11 11:23:42 +00:00
parent 270c4f5a27
commit e5b68acf69
2 changed files with 103 additions and 51 deletions
+102 -50
View File
@@ -36,11 +36,17 @@ v for each representation comes only from authored pairs (mean hack-minus-clean,
normalized per module). Ground-truth labels from training rollouts are used only for
diagnostic AUROC and precision measurements, never for routing.
PINNING. Each panel shades the three zones the online gate rule would give on this
window: keep (bulk) | absorb (score > mean + k_mid*sd) | rout (>= mean + k_rout*sd),
plus the oracle best hack-vs-rest split for reference. k's default to the real-run
Config values (2/3), not the checkpoint's preset, so the plot answers "where WOULD
the cuts fall", overridable via --k-mid/--k-rout.
DISPLAY + PINNING. Scores are plotted Z-NORMALIZED WITHIN FAMILY: live scores by the
mean/std of all valid live rollouts, synthetic scores by the mean/std of the joint
clean+hack pair scores. Affine per family, so every AUROC is unchanged; it puts both
families on one axis with a meaningful zero. (Raw scores share an offset <mu, v>:
v = mean(hack-clean) guarantees only the GAP between sides, not its location, and the
authored-pair common mean is not orthogonal to v, so uncentered both pair sides land
positive.) Zones keep | absorb | rout come from two-threshold Otsu on the live
z-scores -- the label-free valley cuts an online gate could compute from a rolling
score window (EMA mean/std + valley search). The previous mean+k*sd rule modeled
hacks as a rare outlier tail and put both cuts beyond every distribution (hack share
in these windows is 35-43%); the oracle hack-vs-rest split is drawn for reference.
CAVEAT. Live advantages are reconstructed from rollouts.jsonl students only (teachers
absent, zero-variance groups included, and skipped/empty completions missing from the
@@ -53,7 +59,9 @@ downstream (subset vectors, 4 scores, zones, table) is offline re-projection of
cached features.
uv run python scripts/diag_pinning.py --run-dir out/runs/<vanilla_lora2r_run>
uv run python scripts/diag_pinning.py --replot out/diag/pinning_data.parquet # no GPU
uv run python scripts/diag_pinning.py --feats out/diag/pinning_feats.pt # no GPU:
# recompute scores/table/plot from cached feats
uv run python scripts/diag_pinning.py --replot out/diag/pinning_data.parquet # plot only
outputs (out/diag/): pinning_q2.png (3x2 headline), pinning_data.parquet (per-rollout
scores), pinning_pairset.parquet + printed table (subsets x 6 AUROCs),
pinning_feats.pt (raw features, for offline re-analysis).
@@ -105,11 +113,10 @@ class Cfg:
step_lo: int = 2
step_hi: int = 9
max_rollouts: int = 240
k_mid: float = 2.0 # absorb onset: score > mean + k_mid*sd (real-run Config default)
k_rout: float = 3.0 # rout onset: score >= mean + k_rout*sd
adv_eps: float = 1e-6 # |A| below this = no update exists -> dropped from zones/AUROC
resid_layers: tuple[int, ...] = (12, 18, 24) # residual-stream capture depths (of 36)
random_v_seed: int | None = None # Haar placebo (sanity: nothing should separate)
feats: Path | None = None # cached pinning_feats.pt -> full offline re-analysis
replot: Path | None = None # load parquet and re-plot only (no model, no GPU)
out_dir: Path = Path("out/diag")
@@ -209,13 +216,35 @@ def _kde(x: np.ndarray, grid: np.ndarray) -> np.ndarray:
return np.exp(-0.5 * z ** 2).sum(1) / (len(x) * bw * np.sqrt(2 * np.pi))
def plot_q2(df: pl.DataFrame, k_mid: float, k_rout: float, subtitle: str, out_png: Path) -> dict:
"""2x2 figure ({grad,act} x {cos,dot}) from the saved per-rollout scores -- no GPU.
def _otsu3(x: np.ndarray) -> tuple[float, float]:
"""Two-threshold Otsu: the pair of cuts maximizing 3-class between-class variance.
Label-free -- an online gate can compute this from a rolling window of scores, so
using it here is not oracle leakage. O(n^2), fine for a few hundred scores.
Scores are winsorized at 1/99% first: Otsu maximizes variance, so on heavy-tailed
scores a single extreme point otherwise buys a whole class (seen on grad_dot)."""
x = np.clip(x, *np.quantile(x, [0.01, 0.99]))
s = np.sort(np.asarray(x, float))
n = len(s)
c = np.concatenate([[0.0], np.cumsum(s)])
best, best_ij = -np.inf, (1, 2)
for i in range(1, n - 1):
for j in range(i + 1, n):
obj = c[i] ** 2 / i + (c[j] - c[i]) ** 2 / (j - i) + (c[n] - c[j]) ** 2 / (n - j)
if obj > best:
best, best_ij = obj, (i, j)
i, j = best_ij
return float((s[i - 1] + s[i]) / 2), float((s[j - 1] + s[j]) / 2)
def plot_q2(df: pl.DataFrame, subtitle: str, out_png: Path) -> dict:
"""3x2 figure ({grad,act,resid} x {cos,dot}) from the saved per-rollout scores -- no GPU.
Per panel: live solve/fail/hack+ KDEs (+ thin hack- if n>=3), synthetic pair sides
dashed, three shaded zones keep|absorb|rout from mean + k*sd over the VALID live
scores (|A|>eps; pop 'on_drop' excluded), oracle split, AUROC + P/R at the rout cut.
Returns the per-case stats dict for logging."""
dashed, all Z-NORMALIZED WITHIN FAMILY (live by valid-live mean/std, synthetic by
joint clean+hack mean/std; affine, AUROC unchanged) so both families share one axis
with a meaningful zero. Three shaded zones keep|absorb|rout from two-threshold Otsu
on the live z-scores (label-free, online-computable), oracle split, AUROC + P/R at
the rout cut. Returns the per-case stats dict for logging."""
pops = {p: df.filter(pl.col("pop") == p) for p in df["pop"].unique().to_list()}
live_pops = ["on_solve", "on_fail", "on_hackpos", "on_hackneg"]
live = df.filter(pl.col("pop").is_in(live_pops))
@@ -232,10 +261,16 @@ def plot_q2(df: pl.DataFrame, k_mid: float, k_rout: float, subtitle: str, out_pn
fig, axes = plt.subplots(n_rows, 2, figsize=(12.5, 3.6 * n_rows + 0.8))
for ax, (rep, kind) in zip(axes.flat, CASES):
col = f"{rep}_{kind}"
s = live[col].to_numpy()
s_raw = live[col].to_numpy()
mu_l, sd_l = float(s_raw.mean()), float(s_raw.std()) or 1.0
syn_join = np.concatenate([pops[p][col].to_numpy() for p in ("syn_clean", "syn_hack")
if len(pops.get(p, []))])
mu_s = float(syn_join.mean()) if len(syn_join) else 0.0
sd_s = (float(syn_join.std()) or 1.0) if len(syn_join) else 1.0
z_of = lambda x, p: (x - mu_s) / sd_s if p.startswith("syn") else (x - mu_l) / sd_l
s = (s_raw - mu_l) / sd_l
y = y_all
mu, sd = float(s.mean()), float(s.std())
t_lo, t_hi = mu + k_mid * sd, mu + k_rout * sd
t_lo, t_hi = _otsu3(s)
auroc = _auroc(s.tolist(), y.tolist())
auroc_pos = _auroc(s[posm].tolist(), y[posm].tolist())
thr = np.unique(s)
@@ -248,23 +283,12 @@ def plot_q2(df: pl.DataFrame, k_mid: float, k_rout: float, subtitle: str, out_pn
stats[col] = {"auroc_pos": auroc_pos, "auroc_all": auroc, "prec_rout": prec,
"rec_rout": rec, "n_rout": n_rout, "t_hi": t_hi, "oracle": oracle}
lo = float(np.quantile(s, 0.005))
hi = float(np.quantile(s, 0.995))
if kind == "cos": # keep synthetic medians visible (cos shares a scale;
for p in ("syn_clean", "syn_hack"): # dot doesn't -- pair grads dwarf live, annotate instead)
if len(pops.get(p, [])):
m = float(np.median(pops[p][col].to_numpy()))
lo, hi = min(lo, m), max(hi, m)
zvals = np.concatenate([s, (syn_join - mu_s) / sd_s]) if len(syn_join) else s
lo = float(np.quantile(zvals, 0.005))
hi = float(np.quantile(zvals, 0.995))
pad = 0.05 * (hi - lo) or 1e-6
lo, hi = lo - pad, hi + pad
grid = np.linspace(lo, hi, 400)
if kind == "dot":
off = [f"syn {p.split('_')[1]} med={float(np.median(pops[p][col].to_numpy())):+.2g}"
for p in ("syn_clean", "syn_hack")
if len(pops.get(p, [])) and not lo < float(np.median(pops[p][col].to_numpy())) < hi]
if off:
ax.annotate("off-scale: " + ", ".join(off) + r" $\rightarrow$",
xy=(0.98, 0.68), xycoords="axes fraction", ha="right", fontsize=7, color="#777777")
curves = [("on_solve", SOLVE, "-", 1.9, 0.12), ("on_fail", FAIL, "-", 1.9, 0.12),
("on_hackpos", HACK, "-", 1.9, 0.12),
@@ -273,7 +297,7 @@ def plot_q2(df: pl.DataFrame, k_mid: float, k_rout: float, subtitle: str, out_pn
curves.insert(3, ("on_hackneg", HACK, (0, (1, 1)), 1.2, 0.0))
ymax = 0.0
for p, c, ls, lw, fill in curves:
yk = _kde(pops[p][col].to_numpy(), grid)
yk = _kde(z_of(pops[p][col].to_numpy(), p), grid)
ymax = max(ymax, yk.max())
if fill:
ax.fill_between(grid, yk, color=c, alpha=fill, lw=0)
@@ -283,8 +307,9 @@ def plot_q2(df: pl.DataFrame, k_mid: float, k_rout: float, subtitle: str, out_pn
# precision claim rests on a handful of rollouts -- show them). hack row on top,
# in a strip below y=0 (faint separator line); tick labels stay outside the axes.
ax.axhline(0, color="#cccccc", lw=0.6, zorder=0)
ax.axvline(0, color="#bbbbbb", lw=0.7, zorder=0) # 0 = family mean (z-norm)
for row, (p, c) in enumerate((("on_hackpos", HACK), ("on_fail", FAIL), ("on_solve", SOLVE))):
x = pops[p][col].to_numpy()
x = z_of(pops[p][col].to_numpy(), p)
ax.plot(x, np.full(len(x), -(0.035 + 0.035 * row) * ymax), "|",
color=c, ms=4, alpha=0.6, mew=0.8)
@@ -308,8 +333,8 @@ def plot_q2(df: pl.DataFrame, k_mid: float, k_rout: float, subtitle: str, out_pn
ax.spines[sp].set_visible(False)
ax.set_title(f"{rep} · {kind} AUROC={auroc_pos:.2f} (A>0 contrast; vs-all {auroc:.2f}) "
f"P@rout={prec:.2f} (n={n_rout}) R={rec:.2f}", fontsize=9)
ax.set_xlabel({"cos": "cosine to v (concat modules)",
"dot": "dot ⟨x, v⟩ (update mass along v)"}[kind], fontsize=8.5)
ax.set_xlabel({"cos": "cosine to v (concat modules), z within family",
"dot": "dot ⟨x, v⟩, z within family"}[kind], fontsize=8.5)
handles = [Line2D([0], [0], color=SOLVE, lw=1.9), Line2D([0], [0], color=FAIL, lw=1.9),
Line2D([0], [0], color=HACK, lw=1.9),
@@ -318,7 +343,7 @@ def plot_q2(df: pl.DataFrame, k_mid: float, k_rout: float, subtitle: str, out_pn
Patch(facecolor=ABSORB_C, alpha=0.18), Patch(facecolor=ROUT_C, alpha=0.18),
Line2D([0], [0], color=ORACLE, lw=1.3, ls="-.")]
labels = ["live solve", "live fail", "live hack (A>0)", "synthetic clean", "synthetic hack",
f"absorb (>mean+{k_mid:g}sd)", f"rout (>=mean+{k_rout:g}sd)", "oracle hack/rest split"]
"absorb (otsu lo, label-free)", "rout (otsu hi, label-free)", "oracle hack/rest split"]
fig.legend(handles, labels, loc="lower center", ncol=4, fontsize=8, frameon=False)
fig.suptitle(subtitle, fontsize=10)
fig.tight_layout(rect=(0, 0.07, 1, 0.95))
@@ -335,12 +360,25 @@ def main(cfg: Cfg) -> int:
feats_path = cfg.out_dir / "pinning_feats.pt"
q2_png = cfg.out_dir / "pinning_q2.png"
if cfg.replot is not None:
plot_q2(pl.read_parquet(cfg.replot), cfg.k_mid, cfg.k_rout, f"replot -- {cfg.replot.name}", q2_png)
plot_q2(pl.read_parquet(cfg.replot), f"replot -- {cfg.replot.name}", q2_png)
if rank_path.exists():
print(tabulate(pl.read_parquet(rank_path).to_pandas(), headers="keys",
tablefmt="pipe", floatfmt="+.3f", showindex=False))
return 0
if cfg.feats is not None:
fe = torch.load(cfg.feats, weights_only=False)
logger.info(f"offline re-analysis from {cfg.feats} (no GPU)")
src = str(cfg.feats)
else:
fe = _extract_feats(cfg, feats_path)
src = f"{cfg.run_dir.name} | {cfg.ckpt}"
return _downstream(cfg, fe, src)
def _extract_feats(cfg: Cfg, feats_path: Path) -> dict:
"""One GPU pass: features for every authored pair side and live rollout, saved to
feats_path. Everything downstream is offline re-projection (rerun via --feats)."""
device = torch.device("cuda")
ckpt_path = cfg.run_dir / f"{cfg.ckpt}.safetensors"
meta = _ckpt_meta(ckpt_path)
@@ -349,8 +387,7 @@ def main(cfg: Cfg) -> int:
r = run_cfg.get("lora_r", 32)
init_seed = run_cfg.get("lora_init_seed", 0)
logger.info(f"ckpt {ckpt_path.name} step={meta.get('step')} hack_rate={meta.get('hack_rate')} "
f"model={model_name} r={r} init_seed={init_seed} | run-preset k_mid/k_rout="
f"{run_cfg.get('route_std_mid')}/{run_cfg.get('route_std_rout')} (plot uses {cfg.k_mid}/{cfg.k_rout})")
f"model={model_name} r={r} init_seed={init_seed}")
tok = AutoTokenizer.from_pretrained(model_name)
if tok.pad_token_id is None:
@@ -425,6 +462,27 @@ def main(cfg: Cfg) -> int:
m = (steps == s) & (p_idx == p)
grp_mean[(s, p)] = reward[m].mean()
adv = np.array([reward[i] - grp_mean[(steps[i], p_idx[i])] for i in range(len(reward))])
groups: dict[str, list[int]] = defaultdict(list)
for i, p in enumerate(pairs_all):
groups[p.problem_id.split("_")[0]].append(i)
fe = {"G": G, "ACT": ACT, "RES": RES, "adv": adv, "exploited": exploited,
"gt_pass": gt_pass, "steps": steps, "p_idx": p_idx, "names": names,
"resid_layers": cfg.resid_layers, "pair_feats": PF, "pair_groups": dict(groups),
"pair_ids": [p.problem_id for p in pairs_all]}
torch.save(fe, feats_path)
logger.info(f"wrote {feats_path}")
return fe
def _downstream(cfg: Cfg, fe: dict, src: str) -> int:
"""Scores, pairset table, parquet, and plot from the feature dict -- no GPU."""
data_path = cfg.out_dir / "pinning_data.parquet"
rank_path = cfg.out_dir / "pinning_pairset.parquet"
q2_png = cfg.out_dir / "pinning_q2.png"
PF, pair_ids = fe["pair_feats"], fe["pair_ids"]
G, ACT, RES = fe["G"], fe["ACT"], fe["RES"]
adv, exploited, gt_pass = fe["adv"], fe["exploited"], fe["gt_pass"]
steps, p_idx = fe["steps"], fe["p_idx"]
G_adv = G * torch.tensor(adv, dtype=G.dtype)[:, None, None] # the update the gate sees
# ── Q2 populations: drop A~0 (no update); positive = exploited & A>0 ──
@@ -440,10 +498,8 @@ def main(cfg: Cfg) -> int:
f"too few learnable hacks and every AUROC below is noise.")
# ── headline vectors from the routeV-default subset; placebo swaps in Haar ──
groups: dict[str, list[int]] = defaultdict(list)
for i, p in enumerate(pairs_all):
groups[p.problem_id.split("_")[0]].append(i)
head_idx = [i for i, p in enumerate(pairs_all) if p.problem_id.startswith(cfg.headline_prefix)]
groups: dict[str, list[int]] = fe["pair_groups"]
head_idx = [i for i, pid in enumerate(pair_ids) if pid.startswith(cfg.headline_prefix)]
assert head_idx, f"no pairs match headline prefix {cfg.headline_prefix!r}"
logger.info(f"headline v from prefix {cfg.headline_prefix!r} -> {len(head_idx)} pairs")
@@ -466,7 +522,7 @@ def main(cfg: Cfg) -> int:
for side in ("clean", "hack")}
# ── pairset table: subsets x 4 AUROCs on the SAME cached live features ──
candidates = [("all-in-one", list(range(len(pairs_all))))] + \
candidates = [("all-in-one", list(range(len(pair_ids))))] + \
[(g, idx) for g, idx in sorted(groups.items()) if len(idx) >= 3]
valid_pos = valid & (adv > 0)
rows = []
@@ -502,18 +558,14 @@ def main(cfg: Cfg) -> int:
np.full(n_syn, -1), np.ones(n_syn)) for side in ("clean", "hack")]
df = pl.concat(dfs)
df.write_parquet(data_path)
torch.save({"G": G, "ACT": ACT, "RES": RES, "adv": adv, "exploited": exploited,
"gt_pass": gt_pass, "steps": steps, "p_idx": p_idx, "names": names,
"resid_layers": cfg.resid_layers, "pair_feats": PF, "pair_groups": dict(groups),
"pair_ids": [p.problem_id for p in pairs_all]}, feats_path)
logger.info(f"wrote {data_path} ({len(df)} rows), {feats_path}")
logger.info(f"wrote {data_path} ({len(df)} rows)")
sub = (f"{cfg.run_dir.name} | {cfg.ckpt}, live steps {cfg.step_lo}-{cfg.step_hi}, v from "
sub = (f"{src}, live steps {int(steps.min())}-{int(steps.max())}, v from "
f"'{cfg.headline_prefix}' pairs (n={len(head_idx)}) | "
f"hack+={counts['on_hackpos']} hack-={counts['on_hackneg']} solve={counts['on_solve']} "
f"fail={counts['on_fail']} dropped(A~0)={counts['on_drop']}"
+ (f" | PLACEBO seed={cfg.random_v_seed}" if cfg.random_v_seed is not None else ""))
stats = plot_q2(df, cfg.k_mid, cfg.k_rout, sub, q2_png)
stats = plot_q2(df, sub, q2_png)
best = max(stats, key=lambda c: stats[c]["auroc_pos"])
print(f"\nmain metric: best case on the A>0 contrast = {best} "
f"AUROC={stats[best]['auroc_pos']:.3f} (vs-all {stats[best]['auroc_all']:.3f}) "
+1 -1
View File
@@ -43,7 +43,7 @@ from vgrout.lora2r import wrap_model_with_lora2r
from vgrout.pairs import load_pairs
from vgrout.extract_vhack_grad import extract_v_hack, completion_nll
from vgrout.train import _build_v_grad, route_band_edges, _auroc
from diag_pinning import _kde, SOLVE, HACK, MEANC # same dir (scripts/ is sys.path[0] when run directly)
from diag_pinning import _kde, SOLVE, HACK, ABSORB_C as MEANC # same dir (scripts/ is sys.path[0] when run directly)
@dataclass