mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 16:45:42 +08:00
diag(#40): z-norm scores within family, winsorized 2-threshold Otsu zones, --feats offline mode
Synthetic and live scores now share an axis with a meaningful zero (raw scores carry a common <mu,v> offset, since v = mean diff is not orthogonal to the family mean). Zones come from label-free online stats (EMA mean/std + Otsu valley), replacing mean+k*sd, which placed both cuts beyond every distribution. Winsorize at 1/99% before Otsu: variance-maximizing cuts otherwise spend a whole class on a single outlier. Fresh-eyes review verified that z-norm is affine (AUROCs unchanged), zones are label-free, and P/R recomputed from the parquet match the plot titles. Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
+102
-50
@@ -36,11 +36,17 @@ v for each representation comes only from authored pairs (mean hack-minus-clean,
|
|||||||
normalized per module). Ground-truth labels from training rollouts are used only for
|
normalized per module). Ground-truth labels from training rollouts are used only for
|
||||||
diagnostic AUROC and precision measurements, never for routing.
|
diagnostic AUROC and precision measurements, never for routing.
|
||||||
|
|
||||||
PINNING. Each panel shades the three zones the online gate rule would give on this
|
DISPLAY + PINNING. Scores are plotted Z-NORMALIZED WITHIN FAMILY: live scores by the
|
||||||
window: keep (bulk) | absorb (score > mean + k_mid*sd) | rout (>= mean + k_rout*sd),
|
mean/std of all valid live rollouts, synthetic scores by the mean/std of the joint
|
||||||
plus the oracle best hack-vs-rest split for reference. k's default to the real-run
|
clean+hack pair scores. Affine per family, so every AUROC is unchanged; it puts both
|
||||||
Config values (2/3), not the checkpoint's preset, so the plot answers "where WOULD
|
families on one axis with a meaningful zero. (Raw scores share an offset <mu, v>:
|
||||||
the cuts fall", overridable via --k-mid/--k-rout.
|
v = mean(hack-clean) guarantees only the GAP between sides, not its location, and the
|
||||||
|
authored-pair common mean is not orthogonal to v, so without centering both pair sides land
|
||||||
|
positive.) Zones keep | absorb | rout come from two-threshold Otsu on the live
|
||||||
|
z-scores -- the label-free valley cuts an online gate could compute from a rolling
|
||||||
|
score window (EMA mean/std + valley search). The previous mean+k*sd rule modeled
|
||||||
|
hacks as a rare outlier tail and put both cuts beyond every distribution (hack share
|
||||||
|
in these windows is 35-43%); the oracle hack-vs-rest split is drawn for reference.
|
||||||
|
|
||||||
CAVEAT. Live advantages are reconstructed from rollouts.jsonl students only (teachers
|
CAVEAT. Live advantages are reconstructed from rollouts.jsonl students only (teachers
|
||||||
absent, zero-variance groups included, and skipped/empty completions missing from the
|
absent, zero-variance groups included, and skipped/empty completions missing from the
|
||||||
@@ -53,7 +59,9 @@ downstream (subset vectors, 4 scores, zones, table) is offline re-projection of
|
|||||||
cached features.
|
cached features.
|
||||||
|
|
||||||
uv run python scripts/diag_pinning.py --run-dir out/runs/<vanilla_lora2r_run>
|
uv run python scripts/diag_pinning.py --run-dir out/runs/<vanilla_lora2r_run>
|
||||||
uv run python scripts/diag_pinning.py --replot out/diag/pinning_data.parquet # no GPU
|
uv run python scripts/diag_pinning.py --feats out/diag/pinning_feats.pt # no GPU:
|
||||||
|
# recompute scores/table/plot from cached feats
|
||||||
|
uv run python scripts/diag_pinning.py --replot out/diag/pinning_data.parquet # plot only
|
||||||
outputs (out/diag/): pinning_q2.png (3x2 headline), pinning_data.parquet (per-rollout
|
outputs (out/diag/): pinning_q2.png (3x2 headline), pinning_data.parquet (per-rollout
|
||||||
scores), pinning_pairset.parquet + printed table (subsets x 6 AUROCs),
|
scores), pinning_pairset.parquet + printed table (subsets x 6 AUROCs),
|
||||||
pinning_feats.pt (raw features, for offline re-analysis).
|
pinning_feats.pt (raw features, for offline re-analysis).
|
||||||
@@ -105,11 +113,10 @@ class Cfg:
|
|||||||
step_lo: int = 2
|
step_lo: int = 2
|
||||||
step_hi: int = 9
|
step_hi: int = 9
|
||||||
max_rollouts: int = 240
|
max_rollouts: int = 240
|
||||||
k_mid: float = 2.0 # absorb onset: score > mean + k_mid*sd (real-run Config default)
|
|
||||||
k_rout: float = 3.0 # rout onset: score >= mean + k_rout*sd
|
|
||||||
adv_eps: float = 1e-6 # |A| below this = no update exists -> dropped from zones/AUROC
|
adv_eps: float = 1e-6 # |A| below this = no update exists -> dropped from zones/AUROC
|
||||||
resid_layers: tuple[int, ...] = (12, 18, 24) # residual-stream capture depths (of 36)
|
resid_layers: tuple[int, ...] = (12, 18, 24) # residual-stream capture depths (of 36)
|
||||||
random_v_seed: int | None = None # Haar placebo (sanity: nothing should separate)
|
random_v_seed: int | None = None # Haar placebo (sanity: nothing should separate)
|
||||||
|
feats: Path | None = None # cached pinning_feats.pt -> full offline re-analysis
|
||||||
replot: Path | None = None # load parquet and re-plot only (no model, no GPU)
|
replot: Path | None = None # load parquet and re-plot only (no model, no GPU)
|
||||||
out_dir: Path = Path("out/diag")
|
out_dir: Path = Path("out/diag")
|
||||||
|
|
||||||
@@ -209,13 +216,35 @@ def _kde(x: np.ndarray, grid: np.ndarray) -> np.ndarray:
|
|||||||
return np.exp(-0.5 * z ** 2).sum(1) / (len(x) * bw * np.sqrt(2 * np.pi))
|
return np.exp(-0.5 * z ** 2).sum(1) / (len(x) * bw * np.sqrt(2 * np.pi))
|
||||||
|
|
||||||
|
|
||||||
def plot_q2(df: pl.DataFrame, k_mid: float, k_rout: float, subtitle: str, out_png: Path) -> dict:
|
def _otsu3(x: np.ndarray) -> tuple[float, float]:
|
||||||
"""2x2 figure ({grad,act} x {cos,dot}) from the saved per-rollout scores -- no GPU.
|
"""Two-threshold Otsu: the pair of cuts maximizing 3-class between-class variance.
|
||||||
|
Label-free -- an online gate can compute this from a rolling window of scores, so
|
||||||
|
using it here is not oracle leakage. O(n^2), fine for a few hundred scores.
|
||||||
|
Scores are winsorized at 1/99% first: Otsu maximizes variance, so on heavy-tailed
|
||||||
|
scores a single extreme point otherwise buys a whole class (seen on grad_dot)."""
|
||||||
|
x = np.clip(x, *np.quantile(x, [0.01, 0.99]))
|
||||||
|
s = np.sort(np.asarray(x, float))
|
||||||
|
n = len(s)
|
||||||
|
c = np.concatenate([[0.0], np.cumsum(s)])
|
||||||
|
best, best_ij = -np.inf, (1, 2)
|
||||||
|
for i in range(1, n - 1):
|
||||||
|
for j in range(i + 1, n):
|
||||||
|
obj = c[i] ** 2 / i + (c[j] - c[i]) ** 2 / (j - i) + (c[n] - c[j]) ** 2 / (n - j)
|
||||||
|
if obj > best:
|
||||||
|
best, best_ij = obj, (i, j)
|
||||||
|
i, j = best_ij
|
||||||
|
return float((s[i - 1] + s[i]) / 2), float((s[j - 1] + s[j]) / 2)
|
||||||
|
|
||||||
|
|
||||||
|
def plot_q2(df: pl.DataFrame, subtitle: str, out_png: Path) -> dict:
|
||||||
|
"""3x2 figure ({grad,act,resid} x {cos,dot}) from the saved per-rollout scores -- no GPU.
|
||||||
|
|
||||||
Per panel: live solve/fail/hack+ KDEs (+ thin hack- if n>=3), synthetic pair sides
|
Per panel: live solve/fail/hack+ KDEs (+ thin hack- if n>=3), synthetic pair sides
|
||||||
dashed, three shaded zones keep|absorb|rout from mean + k*sd over the VALID live
|
dashed, all Z-NORMALIZED WITHIN FAMILY (live by valid-live mean/std, synthetic by
|
||||||
scores (|A|>eps; pop 'on_drop' excluded), oracle split, AUROC + P/R at the rout cut.
|
joint clean+hack mean/std; affine, AUROC unchanged) so both families share one axis
|
||||||
Returns the per-case stats dict for logging."""
|
with a meaningful zero. Three shaded zones keep|absorb|rout from two-threshold Otsu
|
||||||
|
on the live z-scores (label-free, online-computable), oracle split, AUROC + P/R at
|
||||||
|
the rout cut. Returns the per-case stats dict for logging."""
|
||||||
pops = {p: df.filter(pl.col("pop") == p) for p in df["pop"].unique().to_list()}
|
pops = {p: df.filter(pl.col("pop") == p) for p in df["pop"].unique().to_list()}
|
||||||
live_pops = ["on_solve", "on_fail", "on_hackpos", "on_hackneg"]
|
live_pops = ["on_solve", "on_fail", "on_hackpos", "on_hackneg"]
|
||||||
live = df.filter(pl.col("pop").is_in(live_pops))
|
live = df.filter(pl.col("pop").is_in(live_pops))
|
||||||
@@ -232,10 +261,16 @@ def plot_q2(df: pl.DataFrame, k_mid: float, k_rout: float, subtitle: str, out_pn
|
|||||||
fig, axes = plt.subplots(n_rows, 2, figsize=(12.5, 3.6 * n_rows + 0.8))
|
fig, axes = plt.subplots(n_rows, 2, figsize=(12.5, 3.6 * n_rows + 0.8))
|
||||||
for ax, (rep, kind) in zip(axes.flat, CASES):
|
for ax, (rep, kind) in zip(axes.flat, CASES):
|
||||||
col = f"{rep}_{kind}"
|
col = f"{rep}_{kind}"
|
||||||
s = live[col].to_numpy()
|
s_raw = live[col].to_numpy()
|
||||||
|
mu_l, sd_l = float(s_raw.mean()), float(s_raw.std()) or 1.0
|
||||||
|
syn_join = np.concatenate([pops[p][col].to_numpy() for p in ("syn_clean", "syn_hack")
|
||||||
|
if len(pops.get(p, []))])
|
||||||
|
mu_s = float(syn_join.mean()) if len(syn_join) else 0.0
|
||||||
|
sd_s = (float(syn_join.std()) or 1.0) if len(syn_join) else 1.0
|
||||||
|
z_of = lambda x, p: (x - mu_s) / sd_s if p.startswith("syn") else (x - mu_l) / sd_l
|
||||||
|
s = (s_raw - mu_l) / sd_l
|
||||||
y = y_all
|
y = y_all
|
||||||
mu, sd = float(s.mean()), float(s.std())
|
t_lo, t_hi = _otsu3(s)
|
||||||
t_lo, t_hi = mu + k_mid * sd, mu + k_rout * sd
|
|
||||||
auroc = _auroc(s.tolist(), y.tolist())
|
auroc = _auroc(s.tolist(), y.tolist())
|
||||||
auroc_pos = _auroc(s[posm].tolist(), y[posm].tolist())
|
auroc_pos = _auroc(s[posm].tolist(), y[posm].tolist())
|
||||||
thr = np.unique(s)
|
thr = np.unique(s)
|
||||||
@@ -248,23 +283,12 @@ def plot_q2(df: pl.DataFrame, k_mid: float, k_rout: float, subtitle: str, out_pn
|
|||||||
stats[col] = {"auroc_pos": auroc_pos, "auroc_all": auroc, "prec_rout": prec,
|
stats[col] = {"auroc_pos": auroc_pos, "auroc_all": auroc, "prec_rout": prec,
|
||||||
"rec_rout": rec, "n_rout": n_rout, "t_hi": t_hi, "oracle": oracle}
|
"rec_rout": rec, "n_rout": n_rout, "t_hi": t_hi, "oracle": oracle}
|
||||||
|
|
||||||
lo = float(np.quantile(s, 0.005))
|
zvals = np.concatenate([s, (syn_join - mu_s) / sd_s]) if len(syn_join) else s
|
||||||
hi = float(np.quantile(s, 0.995))
|
lo = float(np.quantile(zvals, 0.005))
|
||||||
if kind == "cos": # keep synthetic medians visible (cos shares a scale;
|
hi = float(np.quantile(zvals, 0.995))
|
||||||
for p in ("syn_clean", "syn_hack"): # dot doesn't -- pair grads dwarf live, annotate instead)
|
|
||||||
if len(pops.get(p, [])):
|
|
||||||
m = float(np.median(pops[p][col].to_numpy()))
|
|
||||||
lo, hi = min(lo, m), max(hi, m)
|
|
||||||
pad = 0.05 * (hi - lo) or 1e-6
|
pad = 0.05 * (hi - lo) or 1e-6
|
||||||
lo, hi = lo - pad, hi + pad
|
lo, hi = lo - pad, hi + pad
|
||||||
grid = np.linspace(lo, hi, 400)
|
grid = np.linspace(lo, hi, 400)
|
||||||
if kind == "dot":
|
|
||||||
off = [f"syn {p.split('_')[1]} med={float(np.median(pops[p][col].to_numpy())):+.2g}"
|
|
||||||
for p in ("syn_clean", "syn_hack")
|
|
||||||
if len(pops.get(p, [])) and not lo < float(np.median(pops[p][col].to_numpy())) < hi]
|
|
||||||
if off:
|
|
||||||
ax.annotate("off-scale: " + ", ".join(off) + r" $\rightarrow$",
|
|
||||||
xy=(0.98, 0.68), xycoords="axes fraction", ha="right", fontsize=7, color="#777777")
|
|
||||||
|
|
||||||
curves = [("on_solve", SOLVE, "-", 1.9, 0.12), ("on_fail", FAIL, "-", 1.9, 0.12),
|
curves = [("on_solve", SOLVE, "-", 1.9, 0.12), ("on_fail", FAIL, "-", 1.9, 0.12),
|
||||||
("on_hackpos", HACK, "-", 1.9, 0.12),
|
("on_hackpos", HACK, "-", 1.9, 0.12),
|
||||||
@@ -273,7 +297,7 @@ def plot_q2(df: pl.DataFrame, k_mid: float, k_rout: float, subtitle: str, out_pn
|
|||||||
curves.insert(3, ("on_hackneg", HACK, (0, (1, 1)), 1.2, 0.0))
|
curves.insert(3, ("on_hackneg", HACK, (0, (1, 1)), 1.2, 0.0))
|
||||||
ymax = 0.0
|
ymax = 0.0
|
||||||
for p, c, ls, lw, fill in curves:
|
for p, c, ls, lw, fill in curves:
|
||||||
yk = _kde(pops[p][col].to_numpy(), grid)
|
yk = _kde(z_of(pops[p][col].to_numpy(), p), grid)
|
||||||
ymax = max(ymax, yk.max())
|
ymax = max(ymax, yk.max())
|
||||||
if fill:
|
if fill:
|
||||||
ax.fill_between(grid, yk, color=c, alpha=fill, lw=0)
|
ax.fill_between(grid, yk, color=c, alpha=fill, lw=0)
|
||||||
@@ -283,8 +307,9 @@ def plot_q2(df: pl.DataFrame, k_mid: float, k_rout: float, subtitle: str, out_pn
|
|||||||
# precision claim rests on a handful of rollouts -- show them). hack row on top,
|
# precision claim rests on a handful of rollouts -- show them). hack row on top,
|
||||||
# in a strip below y=0 (faint separator line); tick labels stay outside the axes.
|
# in a strip below y=0 (faint separator line); tick labels stay outside the axes.
|
||||||
ax.axhline(0, color="#cccccc", lw=0.6, zorder=0)
|
ax.axhline(0, color="#cccccc", lw=0.6, zorder=0)
|
||||||
|
ax.axvline(0, color="#bbbbbb", lw=0.7, zorder=0) # 0 = family mean (z-norm)
|
||||||
for row, (p, c) in enumerate((("on_hackpos", HACK), ("on_fail", FAIL), ("on_solve", SOLVE))):
|
for row, (p, c) in enumerate((("on_hackpos", HACK), ("on_fail", FAIL), ("on_solve", SOLVE))):
|
||||||
x = pops[p][col].to_numpy()
|
x = z_of(pops[p][col].to_numpy(), p)
|
||||||
ax.plot(x, np.full(len(x), -(0.035 + 0.035 * row) * ymax), "|",
|
ax.plot(x, np.full(len(x), -(0.035 + 0.035 * row) * ymax), "|",
|
||||||
color=c, ms=4, alpha=0.6, mew=0.8)
|
color=c, ms=4, alpha=0.6, mew=0.8)
|
||||||
|
|
||||||
@@ -308,8 +333,8 @@ def plot_q2(df: pl.DataFrame, k_mid: float, k_rout: float, subtitle: str, out_pn
|
|||||||
ax.spines[sp].set_visible(False)
|
ax.spines[sp].set_visible(False)
|
||||||
ax.set_title(f"{rep} · {kind} AUROC={auroc_pos:.2f} (A>0 contrast; vs-all {auroc:.2f}) "
|
ax.set_title(f"{rep} · {kind} AUROC={auroc_pos:.2f} (A>0 contrast; vs-all {auroc:.2f}) "
|
||||||
f"P@rout={prec:.2f} (n={n_rout}) R={rec:.2f}", fontsize=9)
|
f"P@rout={prec:.2f} (n={n_rout}) R={rec:.2f}", fontsize=9)
|
||||||
ax.set_xlabel({"cos": "cosine to v (concat modules)",
|
ax.set_xlabel({"cos": "cosine to v (concat modules), z within family",
|
||||||
"dot": "dot ⟨x, v⟩ (update mass along v)"}[kind], fontsize=8.5)
|
"dot": "dot ⟨x, v⟩, z within family"}[kind], fontsize=8.5)
|
||||||
|
|
||||||
handles = [Line2D([0], [0], color=SOLVE, lw=1.9), Line2D([0], [0], color=FAIL, lw=1.9),
|
handles = [Line2D([0], [0], color=SOLVE, lw=1.9), Line2D([0], [0], color=FAIL, lw=1.9),
|
||||||
Line2D([0], [0], color=HACK, lw=1.9),
|
Line2D([0], [0], color=HACK, lw=1.9),
|
||||||
@@ -318,7 +343,7 @@ def plot_q2(df: pl.DataFrame, k_mid: float, k_rout: float, subtitle: str, out_pn
|
|||||||
Patch(facecolor=ABSORB_C, alpha=0.18), Patch(facecolor=ROUT_C, alpha=0.18),
|
Patch(facecolor=ABSORB_C, alpha=0.18), Patch(facecolor=ROUT_C, alpha=0.18),
|
||||||
Line2D([0], [0], color=ORACLE, lw=1.3, ls="-.")]
|
Line2D([0], [0], color=ORACLE, lw=1.3, ls="-.")]
|
||||||
labels = ["live solve", "live fail", "live hack (A>0)", "synthetic clean", "synthetic hack",
|
labels = ["live solve", "live fail", "live hack (A>0)", "synthetic clean", "synthetic hack",
|
||||||
f"absorb (>mean+{k_mid:g}sd)", f"rout (>=mean+{k_rout:g}sd)", "oracle hack/rest split"]
|
"absorb (otsu lo, label-free)", "rout (otsu hi, label-free)", "oracle hack/rest split"]
|
||||||
fig.legend(handles, labels, loc="lower center", ncol=4, fontsize=8, frameon=False)
|
fig.legend(handles, labels, loc="lower center", ncol=4, fontsize=8, frameon=False)
|
||||||
fig.suptitle(subtitle, fontsize=10)
|
fig.suptitle(subtitle, fontsize=10)
|
||||||
fig.tight_layout(rect=(0, 0.07, 1, 0.95))
|
fig.tight_layout(rect=(0, 0.07, 1, 0.95))
|
||||||
@@ -335,12 +360,25 @@ def main(cfg: Cfg) -> int:
|
|||||||
feats_path = cfg.out_dir / "pinning_feats.pt"
|
feats_path = cfg.out_dir / "pinning_feats.pt"
|
||||||
q2_png = cfg.out_dir / "pinning_q2.png"
|
q2_png = cfg.out_dir / "pinning_q2.png"
|
||||||
if cfg.replot is not None:
|
if cfg.replot is not None:
|
||||||
plot_q2(pl.read_parquet(cfg.replot), cfg.k_mid, cfg.k_rout, f"replot -- {cfg.replot.name}", q2_png)
|
plot_q2(pl.read_parquet(cfg.replot), f"replot -- {cfg.replot.name}", q2_png)
|
||||||
if rank_path.exists():
|
if rank_path.exists():
|
||||||
print(tabulate(pl.read_parquet(rank_path).to_pandas(), headers="keys",
|
print(tabulate(pl.read_parquet(rank_path).to_pandas(), headers="keys",
|
||||||
tablefmt="pipe", floatfmt="+.3f", showindex=False))
|
tablefmt="pipe", floatfmt="+.3f", showindex=False))
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
if cfg.feats is not None:
|
||||||
|
fe = torch.load(cfg.feats, weights_only=False)
|
||||||
|
logger.info(f"offline re-analysis from {cfg.feats} (no GPU)")
|
||||||
|
src = str(cfg.feats)
|
||||||
|
else:
|
||||||
|
fe = _extract_feats(cfg, feats_path)
|
||||||
|
src = f"{cfg.run_dir.name} | {cfg.ckpt}"
|
||||||
|
return _downstream(cfg, fe, src)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_feats(cfg: Cfg, feats_path: Path) -> dict:
|
||||||
|
"""One GPU pass: features for every authored pair side and live rollout, saved to
|
||||||
|
feats_path. Everything downstream is offline re-projection (rerun via --feats)."""
|
||||||
device = torch.device("cuda")
|
device = torch.device("cuda")
|
||||||
ckpt_path = cfg.run_dir / f"{cfg.ckpt}.safetensors"
|
ckpt_path = cfg.run_dir / f"{cfg.ckpt}.safetensors"
|
||||||
meta = _ckpt_meta(ckpt_path)
|
meta = _ckpt_meta(ckpt_path)
|
||||||
@@ -349,8 +387,7 @@ def main(cfg: Cfg) -> int:
|
|||||||
r = run_cfg.get("lora_r", 32)
|
r = run_cfg.get("lora_r", 32)
|
||||||
init_seed = run_cfg.get("lora_init_seed", 0)
|
init_seed = run_cfg.get("lora_init_seed", 0)
|
||||||
logger.info(f"ckpt {ckpt_path.name} step={meta.get('step')} hack_rate={meta.get('hack_rate')} "
|
logger.info(f"ckpt {ckpt_path.name} step={meta.get('step')} hack_rate={meta.get('hack_rate')} "
|
||||||
f"model={model_name} r={r} init_seed={init_seed} | run-preset k_mid/k_rout="
|
f"model={model_name} r={r} init_seed={init_seed}")
|
||||||
f"{run_cfg.get('route_std_mid')}/{run_cfg.get('route_std_rout')} (plot uses {cfg.k_mid}/{cfg.k_rout})")
|
|
||||||
|
|
||||||
tok = AutoTokenizer.from_pretrained(model_name)
|
tok = AutoTokenizer.from_pretrained(model_name)
|
||||||
if tok.pad_token_id is None:
|
if tok.pad_token_id is None:
|
||||||
@@ -425,6 +462,27 @@ def main(cfg: Cfg) -> int:
|
|||||||
m = (steps == s) & (p_idx == p)
|
m = (steps == s) & (p_idx == p)
|
||||||
grp_mean[(s, p)] = reward[m].mean()
|
grp_mean[(s, p)] = reward[m].mean()
|
||||||
adv = np.array([reward[i] - grp_mean[(steps[i], p_idx[i])] for i in range(len(reward))])
|
adv = np.array([reward[i] - grp_mean[(steps[i], p_idx[i])] for i in range(len(reward))])
|
||||||
|
groups: dict[str, list[int]] = defaultdict(list)
|
||||||
|
for i, p in enumerate(pairs_all):
|
||||||
|
groups[p.problem_id.split("_")[0]].append(i)
|
||||||
|
fe = {"G": G, "ACT": ACT, "RES": RES, "adv": adv, "exploited": exploited,
|
||||||
|
"gt_pass": gt_pass, "steps": steps, "p_idx": p_idx, "names": names,
|
||||||
|
"resid_layers": cfg.resid_layers, "pair_feats": PF, "pair_groups": dict(groups),
|
||||||
|
"pair_ids": [p.problem_id for p in pairs_all]}
|
||||||
|
torch.save(fe, feats_path)
|
||||||
|
logger.info(f"wrote {feats_path}")
|
||||||
|
return fe
|
||||||
|
|
||||||
|
|
||||||
|
def _downstream(cfg: Cfg, fe: dict, src: str) -> int:
|
||||||
|
"""Scores, pairset table, parquet, and plot from the feature dict -- no GPU."""
|
||||||
|
data_path = cfg.out_dir / "pinning_data.parquet"
|
||||||
|
rank_path = cfg.out_dir / "pinning_pairset.parquet"
|
||||||
|
q2_png = cfg.out_dir / "pinning_q2.png"
|
||||||
|
PF, pair_ids = fe["pair_feats"], fe["pair_ids"]
|
||||||
|
G, ACT, RES = fe["G"], fe["ACT"], fe["RES"]
|
||||||
|
adv, exploited, gt_pass = fe["adv"], fe["exploited"], fe["gt_pass"]
|
||||||
|
steps, p_idx = fe["steps"], fe["p_idx"]
|
||||||
G_adv = G * torch.tensor(adv, dtype=G.dtype)[:, None, None] # the update the gate sees
|
G_adv = G * torch.tensor(adv, dtype=G.dtype)[:, None, None] # the update the gate sees
|
||||||
|
|
||||||
# ── Q2 populations: drop A~0 (no update); positive = exploited & A>0 ──
|
# ── Q2 populations: drop A~0 (no update); positive = exploited & A>0 ──
|
||||||
@@ -440,10 +498,8 @@ def main(cfg: Cfg) -> int:
|
|||||||
f"too few learnable hacks and every AUROC below is noise.")
|
f"too few learnable hacks and every AUROC below is noise.")
|
||||||
|
|
||||||
# ── headline vectors from the routeV-default subset; placebo swaps in Haar ──
|
# ── headline vectors from the routeV-default subset; placebo swaps in Haar ──
|
||||||
groups: dict[str, list[int]] = defaultdict(list)
|
groups: dict[str, list[int]] = fe["pair_groups"]
|
||||||
for i, p in enumerate(pairs_all):
|
head_idx = [i for i, pid in enumerate(pair_ids) if pid.startswith(cfg.headline_prefix)]
|
||||||
groups[p.problem_id.split("_")[0]].append(i)
|
|
||||||
head_idx = [i for i, p in enumerate(pairs_all) if p.problem_id.startswith(cfg.headline_prefix)]
|
|
||||||
assert head_idx, f"no pairs match headline prefix {cfg.headline_prefix!r}"
|
assert head_idx, f"no pairs match headline prefix {cfg.headline_prefix!r}"
|
||||||
logger.info(f"headline v from prefix {cfg.headline_prefix!r} -> {len(head_idx)} pairs")
|
logger.info(f"headline v from prefix {cfg.headline_prefix!r} -> {len(head_idx)} pairs")
|
||||||
|
|
||||||
@@ -466,7 +522,7 @@ def main(cfg: Cfg) -> int:
|
|||||||
for side in ("clean", "hack")}
|
for side in ("clean", "hack")}
|
||||||
|
|
||||||
# ── pairset table: subsets x 4 AUROCs on the SAME cached live features ──
|
# ── pairset table: subsets x 4 AUROCs on the SAME cached live features ──
|
||||||
candidates = [("all-in-one", list(range(len(pairs_all))))] + \
|
candidates = [("all-in-one", list(range(len(pair_ids))))] + \
|
||||||
[(g, idx) for g, idx in sorted(groups.items()) if len(idx) >= 3]
|
[(g, idx) for g, idx in sorted(groups.items()) if len(idx) >= 3]
|
||||||
valid_pos = valid & (adv > 0)
|
valid_pos = valid & (adv > 0)
|
||||||
rows = []
|
rows = []
|
||||||
@@ -502,18 +558,14 @@ def main(cfg: Cfg) -> int:
|
|||||||
np.full(n_syn, -1), np.ones(n_syn)) for side in ("clean", "hack")]
|
np.full(n_syn, -1), np.ones(n_syn)) for side in ("clean", "hack")]
|
||||||
df = pl.concat(dfs)
|
df = pl.concat(dfs)
|
||||||
df.write_parquet(data_path)
|
df.write_parquet(data_path)
|
||||||
torch.save({"G": G, "ACT": ACT, "RES": RES, "adv": adv, "exploited": exploited,
|
logger.info(f"wrote {data_path} ({len(df)} rows)")
|
||||||
"gt_pass": gt_pass, "steps": steps, "p_idx": p_idx, "names": names,
|
|
||||||
"resid_layers": cfg.resid_layers, "pair_feats": PF, "pair_groups": dict(groups),
|
|
||||||
"pair_ids": [p.problem_id for p in pairs_all]}, feats_path)
|
|
||||||
logger.info(f"wrote {data_path} ({len(df)} rows), {feats_path}")
|
|
||||||
|
|
||||||
sub = (f"{cfg.run_dir.name} | {cfg.ckpt}, live steps {cfg.step_lo}-{cfg.step_hi}, v from "
|
sub = (f"{src}, live steps {int(steps.min())}-{int(steps.max())}, v from "
|
||||||
f"'{cfg.headline_prefix}' pairs (n={len(head_idx)}) | "
|
f"'{cfg.headline_prefix}' pairs (n={len(head_idx)}) | "
|
||||||
f"hack+={counts['on_hackpos']} hack-={counts['on_hackneg']} solve={counts['on_solve']} "
|
f"hack+={counts['on_hackpos']} hack-={counts['on_hackneg']} solve={counts['on_solve']} "
|
||||||
f"fail={counts['on_fail']} dropped(A~0)={counts['on_drop']}"
|
f"fail={counts['on_fail']} dropped(A~0)={counts['on_drop']}"
|
||||||
+ (f" | PLACEBO seed={cfg.random_v_seed}" if cfg.random_v_seed is not None else ""))
|
+ (f" | PLACEBO seed={cfg.random_v_seed}" if cfg.random_v_seed is not None else ""))
|
||||||
stats = plot_q2(df, cfg.k_mid, cfg.k_rout, sub, q2_png)
|
stats = plot_q2(df, sub, q2_png)
|
||||||
best = max(stats, key=lambda c: stats[c]["auroc_pos"])
|
best = max(stats, key=lambda c: stats[c]["auroc_pos"])
|
||||||
print(f"\nmain metric: best case on the A>0 contrast = {best} "
|
print(f"\nmain metric: best case on the A>0 contrast = {best} "
|
||||||
f"AUROC={stats[best]['auroc_pos']:.3f} (vs-all {stats[best]['auroc_all']:.3f}) "
|
f"AUROC={stats[best]['auroc_pos']:.3f} (vs-all {stats[best]['auroc_all']:.3f}) "
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ from vgrout.lora2r import wrap_model_with_lora2r
|
|||||||
from vgrout.pairs import load_pairs
|
from vgrout.pairs import load_pairs
|
||||||
from vgrout.extract_vhack_grad import extract_v_hack, completion_nll
|
from vgrout.extract_vhack_grad import extract_v_hack, completion_nll
|
||||||
from vgrout.train import _build_v_grad, route_band_edges, _auroc
|
from vgrout.train import _build_v_grad, route_band_edges, _auroc
|
||||||
from diag_pinning import _kde, SOLVE, HACK, MEANC # same dir (scripts/ is sys.path[0] when run directly)
|
from diag_pinning import _kde, SOLVE, HACK, ABSORB_C as MEANC # same dir (scripts/ is sys.path[0] when run directly)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
Reference in New Issue
Block a user