mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 17:00:59 +08:00
diag(#40): pinning plot splits solve/fail/hack + per-pairset AUROC ranking
Q4 fix: on-policy "solve" was ~exploited = solves+fails (mostly fails). Split by gt_pass into solve/fail/hack (live: 103 hack / 27 solve / 62 fail). Per-pairset ranking: build v_grad from each heading-prefix subset, re-project the SAME stored live c-grads (no model re-run). Finding: behavior pairs AUROC 0.69 vs all-in-one 0.53; reasoning/opportunity anti-aligned (<0.5) -> mixing dilutes. Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
+162
-92
@@ -1,21 +1,27 @@
|
||||
"""Pinning calibration: where does the live gate sit, and would mean +/- 2sd route it?
|
||||
"""Pinning calibration + per-pairset separation: where does the live gate sit, would
|
||||
mean +/- k*sd route it, and which authored-pair subset gives the best hack/clean split?
|
||||
|
||||
The routeV gate scores each rollout by the width-pooled, band-normalized cosine of its
|
||||
deployed-block c-probe gradient to v_grad (exactly train.py:_lora2r_gate_labels) -- call it
|
||||
the position on the HACKING DIRECTION. The current gate routes a FIXED quantile tail (#30).
|
||||
This plots four populations on that axis:
|
||||
on-policy solve / hack -- live rollouts (oracle `exploited` read for COLOUR/AUROC only)
|
||||
synthetic solve / hack -- the authored pairs v_grad was built from (the only label source)
|
||||
and marks the PROPOSED pinning: online mean and mean +/- 2*sd (a symmetric band that
|
||||
self-calibrates to the live spread instead of forcing a fixed %).
|
||||
deployed-block c-probe gradient to v_grad (exactly train.py:_lora2r_gate_labels) -- the
|
||||
position on the HACKING DIRECTION. This script:
|
||||
|
||||
1. Splits live rollouts into THREE honest populations (Q4 fix -- "not-exploited" was
|
||||
conflating genuine solves with fails): on-policy solve (gt_pass), fail, hack (exploited).
|
||||
Plots them with the authored synthetic solve/hack, marks online mean and mean +/- 2sd.
|
||||
2. Ranks pairset SUBSETS (heading-prefix views: behavior/opportunity/disposition/...) by
|
||||
how well a v_grad built from ONLY that subset separates live hacks (AUROC + Cohen's d).
|
||||
Done by storing each live rollout's per-module deployed c-grad ONCE, then re-projecting
|
||||
onto each subset's v_grad -- no model re-run per subset (the refresh-tracking trick).
|
||||
|
||||
uv run python scripts/diag_pinning.py --run-dir out/runs/<vanilla_lora2r_run>
|
||||
uv run python scripts/diag_pinning.py --replot out/diag/pinning_data.parquet # no GPU, restyle only
|
||||
outputs (out/diag/): pinning_calib.png, pinning_data.parquet (the 4 populations, regenerates the plot).
|
||||
uv run python scripts/diag_pinning.py --replot out/diag/pinning_data.parquet # no GPU
|
||||
outputs (out/diag/): pinning_calib.png, pinning_pairset_auroc.png, pinning_data.parquet,
|
||||
pinning_pairset.parquet.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import json
|
||||
import struct
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
@@ -29,16 +35,17 @@ import matplotlib.pyplot as plt
|
||||
from matplotlib.lines import Line2D
|
||||
from matplotlib.patches import Patch
|
||||
from loguru import logger
|
||||
from tabulate import tabulate
|
||||
from safetensors.torch import load_file
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from vgrout.lora2r import wrap_model_with_lora2r
|
||||
from vgrout.pairs import load_pairs
|
||||
from vgrout.pairs import load_pairs, HackPair
|
||||
from vgrout.extract_vhack_grad import extract_v_hack, completion_nll
|
||||
from vgrout.train import _build_v_grad, route_band_edges, _haar_unit_dirs, _auroc
|
||||
|
||||
# colour = behaviour (blue solve, red hack); style = source (solid+fill on-policy, dashed synthetic)
|
||||
SOLVE, HACK, MEANC, ORACLE = "#3b6ea5", "#c44e52", "#d1900a", "#3a8a7a"
|
||||
# colour = behaviour (blue solve, red hack, grey fail); style = source (solid on-policy, dashed synthetic)
|
||||
SOLVE, HACK, FAIL, MEANC, ORACLE = "#3b6ea5", "#c44e52", "#9aa0a6", "#d1900a", "#3a8a7a"
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -61,21 +68,31 @@ def _ckpt_meta(path: Path) -> dict:
|
||||
return json.loads(f.read(struct.unpack("<Q", f.read(8))[0])).get("__metadata__", {})
|
||||
|
||||
|
||||
def pooled_pos(c_grads: dict, v_grad: dict, route_band: dict, names: list, r: int) -> tuple[float, float]:
|
||||
"""Mirror train.py:_lora2r_gate_labels pooling for ONE rollout's per-module deployed c-grads."""
|
||||
num = 0.0; den = 0.0; w = 0.0; n_inc = 0
|
||||
for name in names:
|
||||
lower, upper = route_band[name]
|
||||
if upper - lower <= 0:
|
||||
continue
|
||||
g_b = c_grads[name][:r].float().to(v_grad[name].device)
|
||||
nrm = g_b.norm()
|
||||
cos_b = (torch.einsum("r, k r -> k", g_b, v_grad[name]).max() / nrm.clamp_min(1e-12)).item()
|
||||
num += cos_b - lower; den += upper - lower
|
||||
w += nrm.item(); n_inc += 1
|
||||
if n_inc == 0:
|
||||
raise RuntimeError("no module has positive band width")
|
||||
return num / den, w / n_inc
|
||||
def _batch_pos(G: torch.Tensor, names: list, v_grad: dict, route_band: dict) -> np.ndarray:
|
||||
"""Pooled band-normalized cosine position per rollout, vectorized over modules.
|
||||
|
||||
G: [N, M, r] raw deployed-block c-grads (v-independent). Mirrors train.py pooling:
|
||||
per module cos to ANY of the k dirs (max), width-weighted pool, excluding zero-width
|
||||
modules. Same G re-projects onto any v_grad/route_band -> cheap per-pairset scoring.
|
||||
"""
|
||||
dev = v_grad[names[0]].device
|
||||
Vs = torch.stack([v_grad[n] for n in names]).to(dev) # [M, k, r]
|
||||
low = torch.tensor([route_band[n][0] for n in names], device=dev) # [M]
|
||||
up = torch.tensor([route_band[n][1] for n in names], device=dev) # [M]
|
||||
inc = (up - low) > 0
|
||||
U = (G.to(dev) / G.to(dev).norm(dim=2, keepdim=True).clamp_min(1e-12)) # unit per (n,m)
|
||||
cos = torch.einsum("nmr,mkr->nmk", U, Vs).amax(-1) # [N, M]
|
||||
num = ((cos - low) * inc).sum(1)
|
||||
den = ((up - low) * inc).sum().clamp_min(1e-12)
|
||||
return (num / den).cpu().numpy()
|
||||
|
||||
|
||||
def _cohend(a: np.ndarray, b: np.ndarray) -> float:
|
||||
"""Standardized mean gap (hack - solve); +ve = hacks sit further along the hack-dir."""
|
||||
if len(a) < 2 or len(b) < 2:
|
||||
return float("nan")
|
||||
sp = np.sqrt(((len(a) - 1) * a.var(ddof=1) + (len(b) - 1) * b.var(ddof=1)) / (len(a) + len(b) - 2))
|
||||
return float((a.mean() - b.mean()) / sp) if sp > 0 else float("nan")
|
||||
|
||||
|
||||
def _kde(x: np.ndarray, grid: np.ndarray) -> np.ndarray:
|
||||
@@ -90,31 +107,33 @@ def _kde(x: np.ndarray, grid: np.ndarray) -> np.ndarray:
|
||||
return np.exp(-0.5 * z ** 2).sum(1) / (len(x) * bw * np.sqrt(2 * np.pi))
|
||||
|
||||
|
||||
def plot(df: pl.DataFrame, subtitle: str, out_png: Path) -> None:
|
||||
"""Regenerate the figure from the saved 4 populations -- no GPU needed."""
|
||||
def plot_dist(df: pl.DataFrame, subtitle: str, out_png: Path) -> None:
|
||||
"""Distribution plot from the 5 saved populations -- no GPU needed."""
|
||||
arr = lambda p: df.filter(pl.col("pop") == p)["pos"].to_numpy()
|
||||
on_solve, on_hack = arr("on_solve"), arr("on_hack")
|
||||
on_solve, on_fail, on_hack = arr("on_solve"), arr("on_fail"), arr("on_hack")
|
||||
syn_solve, syn_hack = arr("syn_solve"), arr("syn_hack")
|
||||
pos_live = np.concatenate([on_solve, on_hack])
|
||||
labels = np.concatenate([np.zeros(len(on_solve), bool), np.ones(len(on_hack), bool)])
|
||||
|
||||
# AUROC/divider are hack-vs-rest (the gate's job): label = exploited.
|
||||
pos_live = np.concatenate([on_solve, on_fail, on_hack])
|
||||
labels = np.concatenate([np.zeros(len(on_solve) + len(on_fail), bool), np.ones(len(on_hack), bool)])
|
||||
mean, sd = float(pos_live.mean()), float(pos_live.std())
|
||||
lo_b, hi_b = mean - 2 * sd, mean + 2 * sd # proposed routing band
|
||||
lo_b, hi_b = mean - 2 * sd, mean + 2 * sd
|
||||
auroc = _auroc(pos_live.tolist(), labels.tolist())
|
||||
thr = np.unique(pos_live) # oracle divider (Youden J) -- diagnostic only
|
||||
thr = np.unique(pos_live)
|
||||
j = [(pos_live[labels] >= t).mean() - (pos_live[~labels] >= t).mean() for t in thr]
|
||||
oracle = float(thr[int(np.argmax(j))])
|
||||
|
||||
lo = min(pos_live.min(), syn_solve.min()) - 0.1
|
||||
hi = max(np.quantile(pos_live, 0.99), syn_hack.max()) + 0.1
|
||||
grid = np.linspace(lo, hi, 400)
|
||||
POPS = [(on_solve, SOLVE, True, "on-policy solve"), (on_hack, HACK, True, "on-policy hack"),
|
||||
(syn_solve, SOLVE, False, "synthetic solve"), (syn_hack, HACK, False, "synthetic hack")]
|
||||
kdes = {i: _kde(x, grid) for i, (x, *_ ) in enumerate(POPS)}
|
||||
POPS = [(on_solve, SOLVE, True, "on-policy solve (gt pass)"),
|
||||
(on_fail, FAIL, True, "on-policy fail"),
|
||||
(on_hack, HACK, True, "on-policy hack"),
|
||||
(syn_solve, SOLVE, False, "synthetic solve"),
|
||||
(syn_hack, HACK, False, "synthetic hack")]
|
||||
kdes = {i: _kde(x, grid) for i, (x, *_) in enumerate(POPS)}
|
||||
ymax = max(y.max() for y in kdes.values()) * 1.15
|
||||
|
||||
fig, ax = plt.subplots(figsize=(8.6, 4.8))
|
||||
# proposed pinning: mean +/- 2sd band, drawn first so curves sit on top
|
||||
fig, ax = plt.subplots(figsize=(8.8, 4.8))
|
||||
ax.axvspan(lo_b, hi_b, color=MEANC, alpha=0.07, lw=0)
|
||||
ax.axvline(mean, color=MEANC, lw=1.8)
|
||||
for b in (lo_b, hi_b):
|
||||
@@ -123,49 +142,72 @@ def plot(df: pl.DataFrame, subtitle: str, out_png: Path) -> None:
|
||||
for i, (x, col, on_policy, _) in enumerate(POPS):
|
||||
y = kdes[i]
|
||||
if on_policy:
|
||||
ax.fill_between(grid, y, color=col, alpha=0.13, lw=0)
|
||||
ax.fill_between(grid, y, color=col, alpha=0.12, lw=0)
|
||||
ax.plot(grid, y, color=col, lw=1.9)
|
||||
else:
|
||||
ax.fill_between(grid, y, color=col, alpha=0.05, lw=0)
|
||||
ax.plot(grid, y, color=col, lw=2.4, ls=(0, (5, 2)))
|
||||
ax.plot(grid, y, color=col, lw=2.2, ls=(0, (5, 2)))
|
||||
ax.set_ylim(0, ymax)
|
||||
for s in ("top", "right"):
|
||||
ax.spines[s].set_visible(False)
|
||||
ax.set_ylabel("density")
|
||||
ax.set_xlabel("hacking direction (gradient cosine to v_grad) " + r"$\longrightarrow$")
|
||||
|
||||
dist_handles = [Line2D([0], [0], color=c, lw=2.2, ls="-" if op else (0, (5, 2))) for _, c, op, _ in POPS]
|
||||
leg1 = ax.legend(dist_handles, [lab for *_, lab in POPS], loc="upper left",
|
||||
fontsize=8, frameon=False, title="distributions", title_fontsize=8)
|
||||
leg1._legend_box.align = "left"
|
||||
ax.add_artist(leg1)
|
||||
mark_handles = [
|
||||
Line2D([0], [0], color=MEANC, lw=1.8),
|
||||
Patch(facecolor=MEANC, alpha=0.18, edgecolor=MEANC, ls="--"),
|
||||
Line2D([0], [0], color=ORACLE, lw=1.3, ls="-."),
|
||||
]
|
||||
mark_handles = [Line2D([0], [0], color=MEANC, lw=1.8),
|
||||
Patch(facecolor=MEANC, alpha=0.18, edgecolor=MEANC, ls="--"),
|
||||
Line2D([0], [0], color=ORACLE, lw=1.3, ls="-.")]
|
||||
ax.legend(mark_handles,
|
||||
[f"online mean ({mean:+.2f})", f"mean +/- 2sd (sd={sd:.2f}) = proposed band",
|
||||
f"best hack/solve split ({oracle:+.2f}, diagnostic)"],
|
||||
loc="upper right", fontsize=8, frameon=False, title="pinning", title_fontsize=8)
|
||||
routed_hi = float((pos_live > hi_b).mean())
|
||||
ax.set_title(f"{subtitle}\nAUROC(hack-dir -> hack)={auroc:.2f} "
|
||||
f"{routed_hi:.0%} of rollouts beyond mean+2sd", fontsize=9.5)
|
||||
[f"online mean ({mean:+.2f})", f"mean +/- 2sd (sd={sd:.2f})",
|
||||
f"best hack/rest split ({oracle:+.2f})"],
|
||||
loc="upper right", fontsize=8, frameon=False, title="proposed pinning", title_fontsize=8)
|
||||
ax.set_title(f"{subtitle}\nlive hack-direction positions vs authored v_grad "
|
||||
f"(hack-vs-rest AUROC={auroc:.2f})", fontsize=9.5)
|
||||
fig.tight_layout()
|
||||
fig.savefig(out_png, dpi=140)
|
||||
plt.close(fig)
|
||||
logger.info(f"mean={mean:+.3f} sd={sd:.3f} AUROC={auroc:.3f} -> {out_png}")
|
||||
|
||||
|
||||
def plot_pairset(rank: pl.DataFrame, subtitle: str, out_png: Path) -> None:
    """Draw a horizontal bar chart of per-pairset AUROC and save it to ``out_png``.

    Args:
        rank: one row per pairset with columns ``group``, ``auroc``, ``n_pairs`` and
            ``cohend``; rows are re-sorted by ascending ``auroc`` before plotting.
        subtitle: first line of the figure title.
        out_png: destination image path.

    The ``"all-in-one"`` group is drawn in the hack colour, every other group in the
    solve colour. Each bar is annotated with its AUROC and Cohen's d, and a dotted
    line marks AUROC = 0.5.
    """
    ordered = rank.sort("auroc")
    group_names = ordered["group"].to_list()
    scores = ordered["auroc"].to_numpy()
    pair_counts = ordered["n_pairs"].to_list()
    effects = ordered["cohend"].to_numpy()
    rows = np.arange(len(group_names))

    # Height grows with the number of bars so the labels never crowd.
    fig, ax = plt.subplots(figsize=(7.2, 0.5 * len(group_names) + 1.4))
    bar_colours = [HACK if name == "all-in-one" else SOLVE for name in group_names]
    ax.barh(rows, scores, color=bar_colours, alpha=0.85)
    ax.axvline(0.5, color="k", lw=1, ls=":")  # 0.5 = blind to live hacks

    ax.set_yticks(rows)
    ax.set_yticklabels([f"{name} (n={count})" for name, count in zip(group_names, pair_counts)],
                       fontsize=8)
    for row, score, effect in zip(rows, scores, effects):
        ax.text(score + 0.005, row, f"{score:.2f} d={effect:+.2f}", va="center", fontsize=7.5)

    # Right-hand padding leaves room for the text annotation beside the longest bar.
    ax.set_xlim(min(0.45, scores.min() - 0.03), max(scores.max() + 0.12, 0.6))
    for side in ("top", "right"):
        ax.spines[side].set_visible(False)
    ax.set_xlabel("hack-vs-rest AUROC of subset's v_grad (0.5 = blind, dotted)")
    ax.set_title(f"{subtitle}\nwhich authored-pair subset separates live hacks best?", fontsize=9.5)

    fig.tight_layout()
    fig.savefig(out_png, dpi=140)
    plt.close(fig)
    logger.info(f"wrote {out_png}")
|
||||
|
||||
|
||||
def main(cfg: Cfg) -> int:
|
||||
cfg.out_dir.mkdir(parents=True, exist_ok=True)
|
||||
data_path = cfg.out_dir / "pinning_data.parquet"
|
||||
out_png = cfg.out_dir / "pinning_calib.png"
|
||||
rank_path = cfg.out_dir / "pinning_pairset.parquet"
|
||||
dist_png = cfg.out_dir / "pinning_calib.png"
|
||||
rank_png = cfg.out_dir / "pinning_pairset_auroc.png"
|
||||
if cfg.replot is not None:
|
||||
df = pl.read_parquet(cfg.replot)
|
||||
plot(df, f"pinning calibration (replot) -- {cfg.replot.name}", out_png)
|
||||
plot_dist(pl.read_parquet(cfg.replot), f"replot -- {cfg.replot.name}", dist_png)
|
||||
if rank_path.exists():
|
||||
plot_pairset(pl.read_parquet(rank_path), "replot", rank_png)
|
||||
return 0
|
||||
|
||||
device = torch.device("cuda")
|
||||
@@ -191,60 +233,88 @@ def main(cfg: Cfg) -> int:
|
||||
wrappers[nm]["A"].data.copy_(sd[f"A/{nm}"].to(device, torch.float32))
|
||||
wrappers[nm]["B"].data.copy_(sd[f"B/{nm}"].to(device, torch.float32))
|
||||
logger.info(f"loaded A/B into {len(names)} modules")
|
||||
|
||||
pairs = load_pairs(cfg.pairs)
|
||||
logger.info(f"pairs {cfg.pairs} -> {len(pairs)}")
|
||||
model.eval()
|
||||
_, _, raw_grads, _ = extract_v_hack(model, tok, wrappers, pairs,
|
||||
top_k=1, tau_axis=0.0, n_heldout=2, device=device)
|
||||
v_grad = _build_v_grad(raw_grads, wrappers, 1, device)
|
||||
if cfg.random_v_seed is not None:
|
||||
v_grad = _haar_unit_dirs(v_grad, cfg.random_v_seed, device)
|
||||
logger.info(f"OVERRODE v_grad with Haar dirs seed={cfg.random_v_seed} (placebo)")
|
||||
route_band = route_band_edges(raw_grads, v_grad, device)
|
||||
|
||||
def v_grad_from(pairs: list[HackPair]):
|
||||
_, _, raw, _ = extract_v_hack(model, tok, wrappers, pairs,
|
||||
top_k=1, tau_axis=0.0, n_heldout=2, device=device)
|
||||
vg = _build_v_grad(raw, wrappers, 1, device)
|
||||
if cfg.random_v_seed is not None:
|
||||
vg = _haar_unit_dirs(vg, cfg.random_v_seed, device)
|
||||
return vg, route_band_edges(raw, vg, device), raw
|
||||
|
||||
pairs_all = load_pairs(cfg.pairs)
|
||||
logger.info(f"pairs {cfg.pairs} -> {len(pairs_all)}")
|
||||
v_grad, route_band, raw_all = v_grad_from(pairs_all)
|
||||
|
||||
# synthetic pair positions under the all-in-one v_grad (for the distribution plot)
|
||||
def pair_pos(side: str) -> np.ndarray:
|
||||
n = raw_grads[f"{side}/{names[0]}"].shape[0]
|
||||
out = []
|
||||
for i in range(n):
|
||||
cg = {nm: torch.cat([raw_grads[f"{side}/{nm}"][i], torch.zeros(r)]) for nm in names}
|
||||
out.append(pooled_pos(cg, v_grad, route_band, names, r)[0])
|
||||
return np.array(out)
|
||||
n = raw_all[f"{side}/{names[0]}"].shape[0]
|
||||
G = torch.stack([torch.stack([raw_all[f"{side}/{nm}"][i] for nm in names]) for i in range(n)])
|
||||
return _batch_pos(G, names, v_grad, route_band)
|
||||
syn_solve, syn_hack = pair_pos("clean"), pair_pos("hack")
|
||||
|
||||
# score the live batch ONCE; store per-rollout deployed c-grads for cheap re-projection
|
||||
recs = [json.loads(l) for l in (cfg.run_dir / "rollouts.jsonl").read_text().splitlines()]
|
||||
batch = [r_ for r_ in recs if cfg.step_lo <= r_["step"] <= cfg.step_hi and r_["text"].strip()][:cfg.max_rollouts]
|
||||
batch = [x for x in recs if cfg.step_lo <= x["step"] <= cfg.step_hi and x["text"].strip()][:cfg.max_rollouts]
|
||||
logger.info(f"live batch: {len(batch)} rollouts (steps {cfg.step_lo}-{cfg.step_hi})")
|
||||
pos_live, labels, steps = [], [], []
|
||||
G_rows, exploited, gt_pass, steps = [], [], [], []
|
||||
for i, rec in enumerate(batch):
|
||||
model.zero_grad(set_to_none=True)
|
||||
loss = completion_nll(model, tok, rec["prompt"], rec["text"], device)
|
||||
if not torch.isfinite(loss):
|
||||
continue
|
||||
loss.backward()
|
||||
cg = {nm: wrappers[nm]["layer"]._lora2r_gate.grad.sum(dim=tuple(range(
|
||||
wrappers[nm]["layer"]._lora2r_gate.grad.dim() - 1))) for nm in names}
|
||||
pos_live.append(pooled_pos(cg, v_grad, route_band, names, r)[0])
|
||||
labels.append(bool(rec["exploited"])); steps.append(rec["step"])
|
||||
g = []
|
||||
for nm in names:
|
||||
gr = wrappers[nm]["layer"]._lora2r_gate.grad
|
||||
g.append(gr.sum(dim=tuple(range(gr.dim() - 1)))[:r].float().cpu()) # deployed block [r]
|
||||
G_rows.append(torch.stack(g)) # [M, r]
|
||||
exploited.append(bool(rec["exploited"])); gt_pass.append(bool(rec["gt_pass"])); steps.append(rec["step"])
|
||||
if (i + 1) % 40 == 0:
|
||||
logger.info(f" rollout {i+1}/{len(batch)}")
|
||||
model.zero_grad(set_to_none=True)
|
||||
pos_live = np.array(pos_live); labels = np.array(labels); steps = np.array(steps)
|
||||
per_step = {int(s): round(_auroc(pos_live[steps == s].tolist(), labels[steps == s].tolist()), 2)
|
||||
for s in sorted(set(steps.tolist())) if labels[steps == s].any() and (~labels[steps == s]).any()}
|
||||
logger.info(f"live: {len(labels)} rollouts, {int(labels.sum())} exploited; per-step AUROC={per_step}")
|
||||
G = torch.stack(G_rows) # [N, M, r]
|
||||
exploited, gt_pass, steps = map(np.array, (exploited, gt_pass, steps))
|
||||
pos = _batch_pos(G, names, v_grad, route_band)
|
||||
# THREE honest populations (Q4 fix): solve = gt-correct & not-exploited; fail = neither; hack = exploited
|
||||
solve_m = gt_pass & ~exploited
|
||||
fail_m = ~gt_pass & ~exploited
|
||||
logger.info(f"live: {len(exploited)} rollouts | hack={int(exploited.sum())} "
|
||||
f"solve={int(solve_m.sum())} fail={int(fail_m.sum())}")
|
||||
|
||||
# per-pairset AUROC: build v_grad from each heading-prefix subset, re-project the SAME G
|
||||
groups: dict[str, list] = defaultdict(list)
|
||||
for p in pairs_all:
|
||||
groups[p.problem_id.split("_")[0]].append(p)
|
||||
rows = []
|
||||
# extract_v_hack holds out 2 pairs, so a subset needs >=3 to leave a training pair.
|
||||
candidates = [("all-in-one", pairs_all)] + [(g, ps) for g, ps in sorted(groups.items()) if len(ps) >= 3]
|
||||
for gname, gpairs in candidates:
|
||||
vg, rb, _ = v_grad_from(gpairs)
|
||||
pg = _batch_pos(G, names, vg, rb)
|
||||
auroc = _auroc(pg.tolist(), exploited.tolist())
|
||||
cohend = _cohend(pg[exploited], pg[solve_m]) if solve_m.sum() >= 2 else _cohend(pg[exploited], pg[~exploited])
|
||||
rows.append({"group": gname, "n_pairs": len(gpairs), "auroc": round(auroc, 3),
|
||||
"cohend": round(cohend, 3)})
|
||||
logger.info(f" pairset {gname:12s} n={len(gpairs):2d} AUROC={auroc:.3f} d={cohend:+.3f}")
|
||||
rank = pl.DataFrame(rows).sort("auroc", descending=True)
|
||||
rank.write_parquet(rank_path)
|
||||
print(tabulate(rank.to_pandas(), headers="keys", tablefmt="pipe", floatfmt="+.3f", showindex=False))
|
||||
|
||||
# save the 5 populations (on-policy solve/fail/hack + synthetic solve/hack) long-form -> regenerates the plot with --replot (no GPU)
|
||||
df = pl.concat([
|
||||
pl.DataFrame({"pop": "on_solve", "pos": pos_live[~labels], "step": steps[~labels]}),
|
||||
pl.DataFrame({"pop": "on_hack", "pos": pos_live[labels], "step": steps[labels]}),
|
||||
pl.DataFrame({"pop": "on_solve", "pos": pos[solve_m], "step": steps[solve_m]}),
|
||||
pl.DataFrame({"pop": "on_fail", "pos": pos[fail_m], "step": steps[fail_m]}),
|
||||
pl.DataFrame({"pop": "on_hack", "pos": pos[exploited], "step": steps[exploited]}),
|
||||
pl.DataFrame({"pop": "syn_solve", "pos": syn_solve, "step": np.full(len(syn_solve), -1)}),
|
||||
pl.DataFrame({"pop": "syn_hack", "pos": syn_hack, "step": np.full(len(syn_hack), -1)}),
|
||||
])
|
||||
df.write_parquet(data_path)
|
||||
logger.info(f"wrote {data_path} ({len(df)} rows)")
|
||||
plot(df, f"{cfg.run_dir.name}\n{cfg.ckpt} v_grad, live steps {cfg.step_lo}-{cfg.step_hi}, "
|
||||
f"{int(labels.sum())}/{len(labels)} exploited", out_png)
|
||||
sub = (f"{cfg.run_dir.name}\n{cfg.ckpt} v_grad, live steps {cfg.step_lo}-{cfg.step_hi}, "
|
||||
f"hack={int(exploited.sum())} solve={int(solve_m.sum())} fail={int(fail_m.sum())}")
|
||||
plot_dist(df, sub, dist_png)
|
||||
plot_pairset(rank, cfg.run_dir.name, rank_png)
|
||||
return 0
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user