mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 16:45:42 +08:00
diag(#40): review fixes -- A>0 contrast headline (adv-only baseline 0.90 vs-all / 0.58 A>0), headline=behavior_ training default, n at P@rout, rug strips, drop density ticks
Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
+63
-22
@@ -9,6 +9,17 @@ adv-weighted AUROCs look blind); a hack with A<0 is being UNLEARNED and belongs
|
|||||||
the negative class. (Q1, "does the direction exist at adv=+1", was answered earlier:
|
the negative class. (Q1, "does the direction exist at adv=+1", was answered earlier:
|
||||||
~0.61 -- see git history of this file.)
|
~0.61 -- see git history of this file.)
|
||||||
|
|
||||||
|
TWO CONTRASTS (post-review fix). On the vs-ALL contrast (hack+ vs every other valid
|
||||||
|
rollout) the advantage ALONE is a ~0.9 AUROC detector, because the label requires A>0
|
||||||
|
and most fails have A<0 -- so a high vs-all number mostly restates the reward, which
|
||||||
|
the live gate already has for free. The vector's VALUE-ADD is the A>0 contrast: among
|
||||||
|
positively-reinforced updates (where adv alone is near-blind, AUROC ~0.5), can the score tell hacks
|
||||||
|
from solves? That is the gate's real job (don't ship hack updates, don't rob solves)
|
||||||
|
and the headline number here; vs-all is kept as a secondary column. NOTE an information
|
||||||
|
asymmetry: grad scores carry the adv factor (sign+magnitude), act scores do not, so
|
||||||
|
on the vs-all contrast grad gets label-correlated information act lacks; the A>0
|
||||||
|
contrast removes most of that edge -- compare the cases there.
|
||||||
|
|
||||||
FOUR CANDIDATE SCORES = {grad, act} x {cos, dot}, all in concatenated-module space:
|
FOUR CANDIDATE SCORES = {grad, act} x {cos, dot}, all in concatenated-module space:
|
||||||
- grad: the adv-weighted deployed c-probe gradient (the gate's current input).
|
- grad: the adv-weighted deployed c-probe gradient (the gate's current input).
|
||||||
- act: the deployed bottleneck activation A[:r]@x, mean over completion tokens --
|
- act: the deployed bottleneck activation A[:r]@x, mean over completion tokens --
|
||||||
@@ -27,8 +38,9 @@ Config values (2/3), not the checkpoint's preset, so the plot answers "where WOU
|
|||||||
the cuts fall", overridable via --k-mid/--k-rout.
|
the cuts fall", overridable via --k-mid/--k-rout.
|
||||||
|
|
||||||
CAVEAT. Live advantages are reconstructed from rollouts.jsonl students only (teachers
|
CAVEAT. Live advantages are reconstructed from rollouts.jsonl students only (teachers
|
||||||
absent, zero-variance groups included), so A signs/magnitudes are approximate; the
|
absent, zero-variance groups included, and skipped/empty completions missing from the
|
||||||
act columns dodge this entirely (no A in the representation).
|
group mean), so A signs/magnitudes are approximate; the act columns dodge this
|
||||||
|
entirely (no A in the representation).
|
||||||
|
|
||||||
HOW. One GPU pass: per live rollout, backward its completion NLL once, capture the
|
HOW. One GPU pass: per live rollout, backward its completion NLL once, capture the
|
||||||
c-probe grad AND the pooled bottleneck act; same per authored-pair side. Everything
|
c-probe grad AND the pooled bottleneck act; same per authored-pair side. Everything
|
||||||
@@ -78,9 +90,10 @@ class Cfg:
|
|||||||
run_dir: Path = Path("out/runs/20260611T003538_fast_vanilla_lora2r_seed43_l2r_vanilla_s43_v3")
|
run_dir: Path = Path("out/runs/20260611T003538_fast_vanilla_lora2r_seed43_l2r_vanilla_s43_v3")
|
||||||
ckpt: str = "first_hack"
|
ckpt: str = "first_hack"
|
||||||
pairs: Path = Path("data/pairs/hack_pairs.md#all-in-one")
|
pairs: Path = Path("data/pairs/hack_pairs.md#all-in-one")
|
||||||
# headline figure builds v from this heading-prefix subset (the routeV default);
|
# headline figure builds v from this heading-prefix subset = the routeV TRAINING
|
||||||
# the pairset table spans all subsets of `pairs`.
|
# default (train_config.vhack_pairs_path `#all-in-one/behavior_`, 8 pairs; the
|
||||||
headline_prefix: str = "behavior"
|
# trailing _ excludes behavior2_*). The pairset table spans all subsets of `pairs`.
|
||||||
|
headline_prefix: str = "behavior_"
|
||||||
# Coherent emergence window. This vanilla v3 used the pre-fix lr=5e-4/warmup-0.1 and
|
# Coherent emergence window. This vanilla v3 used the pre-fix lr=5e-4/warmup-0.1 and
|
||||||
# DIVERGED at step 10 (exploited 20/24 -> 0/24); 2-9 = hacks emerging, model still sane.
|
# DIVERGED at step 10 (exploited 20/24 -> 0/24); 2-9 = hacks emerging, model still sane.
|
||||||
step_lo: int = 2
|
step_lo: int = 2
|
||||||
@@ -186,23 +199,34 @@ def plot_q2(df: pl.DataFrame, k_mid: float, k_rout: float, subtitle: str, out_pn
|
|||||||
Returns the per-case stats dict for logging."""
|
Returns the per-case stats dict for logging."""
|
||||||
pops = {p: df.filter(pl.col("pop") == p) for p in df["pop"].unique().to_list()}
|
pops = {p: df.filter(pl.col("pop") == p) for p in df["pop"].unique().to_list()}
|
||||||
live_pops = ["on_solve", "on_fail", "on_hackpos", "on_hackneg"]
|
live_pops = ["on_solve", "on_fail", "on_hackpos", "on_hackneg"]
|
||||||
|
live = df.filter(pl.col("pop").is_in(live_pops))
|
||||||
|
posm = (live["adv"] > 0).to_numpy() # the A>0 contrast rows
|
||||||
|
y_all = (live["pop"] == "on_hackpos").to_numpy()
|
||||||
|
# adv-only baseline: on vs-all the reward alone is a strong detector (the label
|
||||||
|
# requires A>0); the vector only adds value where this baseline is blind.
|
||||||
|
a = live["adv"].to_numpy()
|
||||||
|
logger.info(f"adv-only baseline AUROC: vs-all={_auroc(a.tolist(), y_all.tolist()):.3f} "
|
||||||
|
f"A>0-contrast={_auroc(a[posm].tolist(), y_all[posm].tolist()):.3f} "
|
||||||
|
f"(n+={int(y_all.sum())} negA>0={int((~y_all & posm).sum())})")
|
||||||
stats = {}
|
stats = {}
|
||||||
fig, axes = plt.subplots(2, 2, figsize=(12.5, 7.6))
|
fig, axes = plt.subplots(2, 2, figsize=(12.5, 7.6))
|
||||||
for ax, (rep, kind) in zip(axes.flat, CASES):
|
for ax, (rep, kind) in zip(axes.flat, CASES):
|
||||||
col = f"{rep}_{kind}"
|
col = f"{rep}_{kind}"
|
||||||
live = df.filter(pl.col("pop").is_in(live_pops))
|
|
||||||
s = live[col].to_numpy()
|
s = live[col].to_numpy()
|
||||||
y = (live["pop"] == "on_hackpos").to_numpy()
|
y = y_all
|
||||||
mu, sd = float(s.mean()), float(s.std())
|
mu, sd = float(s.mean()), float(s.std())
|
||||||
t_lo, t_hi = mu + k_mid * sd, mu + k_rout * sd
|
t_lo, t_hi = mu + k_mid * sd, mu + k_rout * sd
|
||||||
auroc = _auroc(s.tolist(), y.tolist())
|
auroc = _auroc(s.tolist(), y.tolist())
|
||||||
|
auroc_pos = _auroc(s[posm].tolist(), y[posm].tolist())
|
||||||
thr = np.unique(s)
|
thr = np.unique(s)
|
||||||
j = [(s[y] >= t).mean() - (s[~y] >= t).mean() for t in thr]
|
j = [(s[y] >= t).mean() - (s[~y] >= t).mean() for t in thr]
|
||||||
oracle = float(thr[int(np.argmax(j))])
|
oracle = float(thr[int(np.argmax(j))])
|
||||||
routed = s >= t_hi
|
routed = s >= t_hi
|
||||||
|
n_rout = int(routed.sum())
|
||||||
prec = float(y[routed].mean()) if routed.any() else float("nan")
|
prec = float(y[routed].mean()) if routed.any() else float("nan")
|
||||||
rec = float((s[y] >= t_hi).mean()) if y.any() else float("nan")
|
rec = float((s[y] >= t_hi).mean()) if y.any() else float("nan")
|
||||||
stats[col] = {"auroc": auroc, "prec_rout": prec, "rec_rout": rec, "t_hi": t_hi, "oracle": oracle}
|
stats[col] = {"auroc_pos": auroc_pos, "auroc_all": auroc, "prec_rout": prec,
|
||||||
|
"rec_rout": rec, "n_rout": n_rout, "t_hi": t_hi, "oracle": oracle}
|
||||||
|
|
||||||
lo = float(np.quantile(s, 0.005))
|
lo = float(np.quantile(s, 0.005))
|
||||||
hi = float(np.quantile(s, 0.995))
|
hi = float(np.quantile(s, 0.995))
|
||||||
@@ -220,7 +244,7 @@ def plot_q2(df: pl.DataFrame, k_mid: float, k_rout: float, subtitle: str, out_pn
|
|||||||
if len(pops.get(p, [])) and not lo < float(np.median(pops[p][col].to_numpy())) < hi]
|
if len(pops.get(p, [])) and not lo < float(np.median(pops[p][col].to_numpy())) < hi]
|
||||||
if off:
|
if off:
|
||||||
ax.annotate("off-scale: " + ", ".join(off) + r" $\rightarrow$",
|
ax.annotate("off-scale: " + ", ".join(off) + r" $\rightarrow$",
|
||||||
xy=(0.98, 0.84), xycoords="axes fraction", ha="right", fontsize=7, color="#777777")
|
xy=(0.98, 0.68), xycoords="axes fraction", ha="right", fontsize=7, color="#777777")
|
||||||
|
|
||||||
curves = [("on_solve", SOLVE, "-", 1.9, 0.12), ("on_fail", FAIL, "-", 1.9, 0.12),
|
curves = [("on_solve", SOLVE, "-", 1.9, 0.12), ("on_fail", FAIL, "-", 1.9, 0.12),
|
||||||
("on_hackpos", HACK, "-", 1.9, 0.12),
|
("on_hackpos", HACK, "-", 1.9, 0.12),
|
||||||
@@ -235,6 +259,12 @@ def plot_q2(df: pl.DataFrame, k_mid: float, k_rout: float, subtitle: str, out_pn
|
|||||||
ax.fill_between(grid, yk, color=c, alpha=fill, lw=0)
|
ax.fill_between(grid, yk, color=c, alpha=fill, lw=0)
|
||||||
ax.plot(grid, yk, color=c, lw=lw, ls=ls)
|
ax.plot(grid, yk, color=c, lw=lw, ls=ls)
|
||||||
ymax *= 1.18
|
ymax *= 1.18
|
||||||
|
# rug of the ACTUAL live points (KDEs of n~20 are smooth fiction; the rout-tail
|
||||||
|
# precision claim rests on a handful of rollouts -- show them). hack row on top.
|
||||||
|
for row, (p, c) in enumerate((("on_hackpos", HACK), ("on_fail", FAIL), ("on_solve", SOLVE))):
|
||||||
|
x = pops[p][col].to_numpy()
|
||||||
|
ax.plot(x, np.full(len(x), -(0.035 + 0.035 * row) * ymax), "|",
|
||||||
|
color=c, ms=4, alpha=0.6, mew=0.8, clip_on=False)
|
||||||
|
|
||||||
# three zones: keep | absorb | rout
|
# three zones: keep | absorb | rout
|
||||||
ax.axvspan(t_lo, min(t_hi, hi), color=ABSORB_C, alpha=0.08, lw=0)
|
ax.axvspan(t_lo, min(t_hi, hi), color=ABSORB_C, alpha=0.08, lw=0)
|
||||||
@@ -248,14 +278,15 @@ def plot_q2(df: pl.DataFrame, k_mid: float, k_rout: float, subtitle: str, out_pn
|
|||||||
if lo < xz < hi:
|
if lo < xz < hi:
|
||||||
ax.text(xz, ymax * 0.97, lab, ha="center", va="top", fontsize=7.5, color="#555555")
|
ax.text(xz, ymax * 0.97, lab, ha="center", va="top", fontsize=7.5, color="#555555")
|
||||||
ax.set_xlim(lo, hi)
|
ax.set_xlim(lo, hi)
|
||||||
ax.set_ylim(0, ymax)
|
ax.set_ylim(-0.13 * ymax, ymax) # negative strip hosts the rugs
|
||||||
for sp in ("top", "right"):
|
ax.set_yticks([]) # KDE density units are meaningless ink
|
||||||
|
for sp in ("top", "right", "left"):
|
||||||
ax.spines[sp].set_visible(False)
|
ax.spines[sp].set_visible(False)
|
||||||
ax.set_title(f"{rep} · {kind} AUROC={auroc:.2f} P@rout={prec:.2f} R@rout={rec:.2f}",
|
ax.spines["bottom"].set_position(("data", 0))
|
||||||
fontsize=9.5)
|
ax.set_title(f"{rep} · {kind} AUROC={auroc_pos:.2f} (A>0 contrast; vs-all {auroc:.2f}) "
|
||||||
|
f"P@rout={prec:.2f} (n={n_rout}) R={rec:.2f}", fontsize=9)
|
||||||
ax.set_xlabel({"cos": "cosine to v (concat modules)",
|
ax.set_xlabel({"cos": "cosine to v (concat modules)",
|
||||||
"dot": "dot ⟨x, v⟩ (update mass along v)"}[kind], fontsize=8.5)
|
"dot": "dot ⟨x, v⟩ (update mass along v)"}[kind], fontsize=8.5)
|
||||||
ax.set_ylabel("density", fontsize=8.5)
|
|
||||||
|
|
||||||
handles = [Line2D([0], [0], color=SOLVE, lw=1.9), Line2D([0], [0], color=FAIL, lw=1.9),
|
handles = [Line2D([0], [0], color=SOLVE, lw=1.9), Line2D([0], [0], color=FAIL, lw=1.9),
|
||||||
Line2D([0], [0], color=HACK, lw=1.9),
|
Line2D([0], [0], color=HACK, lw=1.9),
|
||||||
@@ -415,19 +446,27 @@ def main(cfg: Cfg) -> int:
|
|||||||
# ── pairset table: subsets x 4 AUROCs on the SAME cached live features ──
|
# ── pairset table: subsets x 4 AUROCs on the SAME cached live features ──
|
||||||
candidates = [("all-in-one", list(range(len(pairs_all))))] + \
|
candidates = [("all-in-one", list(range(len(pairs_all))))] + \
|
||||||
[(g, idx) for g, idx in sorted(groups.items()) if len(idx) >= 3]
|
[(g, idx) for g, idx in sorted(groups.items()) if len(idx) >= 3]
|
||||||
|
valid_pos = valid & (adv > 0)
|
||||||
rows = []
|
rows = []
|
||||||
for gname, idx in candidates:
|
for gname, idx in candidates:
|
||||||
v = vectors(idx)
|
v = vectors(idx)
|
||||||
row = {"group": gname, "n_pairs": len(idx)}
|
row = {"group": gname, "n_pairs": len(idx)}
|
||||||
for rep, kind in CASES:
|
for rep, kind in CASES:
|
||||||
s = _score(live_X[rep], v[rep], kind)[valid]
|
s = _score(live_X[rep], v[rep], kind)
|
||||||
row[f"{rep}_{kind}"] = round(_auroc(s.tolist(), y[valid].tolist()), 3)
|
row[f"{rep}_{kind}"] = round(_auroc(s[valid_pos].tolist(), y[valid_pos].tolist()), 3)
|
||||||
|
row[f"{rep}_{kind}_all"] = round(_auroc(s[valid].tolist(), y[valid].tolist()), 3)
|
||||||
rows.append(row)
|
rows.append(row)
|
||||||
rank = pl.DataFrame(rows).sort("grad_dot", descending=True)
|
rank = pl.DataFrame(rows).sort("grad_dot", descending=True)
|
||||||
rank.write_parquet(rank_path)
|
rank.write_parquet(rank_path)
|
||||||
print("\nSHOULD: real pairsets beat 0.5 on at least one column; under --random-v-seed "
|
adv_v = adv[valid]
|
||||||
"every column ~0.5. Columns are AUROC of hack(A>0)-vs-rest on valid live rollouts.")
|
print(f"\nbaseline adv-only AUROC: vs-all={_auroc(adv_v.tolist(), y[valid].tolist()):.3f} "
|
||||||
print(tabulate(rank.to_pandas(), headers="keys", tablefmt="pipe", floatfmt="+.3f", showindex=False))
|
f"A>0-contrast={_auroc(adv[valid_pos].tolist(), y[valid_pos].tolist()):.3f} -- the table "
|
||||||
|
f"columns are the A>0 contrast (hack vs non-hack among adv>0, n={int(valid_pos.sum())}), "
|
||||||
|
f"where adv is blind; vs-all columns (*_all) live in {rank_path.name}.")
|
||||||
|
print("SHOULD: real pairsets beat 0.5 and the adv-only A>0 baseline; under --random-v-seed "
|
||||||
|
"every column ~0.5. With ~20 negatives the SE is ~0.07: only gaps >0.15 mean much.")
|
||||||
|
print(tabulate(rank.drop([c for c in rank.columns if c.endswith("_all")]).to_pandas(),
|
||||||
|
headers="keys", tablefmt="pipe", floatfmt="+.3f", showindex=False))
|
||||||
|
|
||||||
# ── persist per-rollout scores + raw features, then plot ──
|
# ── persist per-rollout scores + raw features, then plot ──
|
||||||
def frame(pop_name: str, mask_or_scores, scores: dict, step_arr, adv_arr) -> pl.DataFrame:
|
def frame(pop_name: str, mask_or_scores, scores: dict, step_arr, adv_arr) -> pl.DataFrame:
|
||||||
@@ -453,9 +492,11 @@ def main(cfg: Cfg) -> int:
|
|||||||
f"fail={counts['on_fail']} dropped(A~0)={counts['on_drop']}"
|
f"fail={counts['on_fail']} dropped(A~0)={counts['on_drop']}"
|
||||||
+ (f" | PLACEBO seed={cfg.random_v_seed}" if cfg.random_v_seed is not None else ""))
|
+ (f" | PLACEBO seed={cfg.random_v_seed}" if cfg.random_v_seed is not None else ""))
|
||||||
stats = plot_q2(df, cfg.k_mid, cfg.k_rout, sub, q2_png)
|
stats = plot_q2(df, cfg.k_mid, cfg.k_rout, sub, q2_png)
|
||||||
best = max(stats, key=lambda c: stats[c]["auroc"])
|
best = max(stats, key=lambda c: stats[c]["auroc_pos"])
|
||||||
print(f"\nmain metric: best case = {best} AUROC={stats[best]['auroc']:.3f} "
|
print(f"\nmain metric: best case on the A>0 contrast = {best} "
|
||||||
f"P@rout={stats[best]['prec_rout']:.2f} R@rout={stats[best]['rec_rout']:.2f}")
|
f"AUROC={stats[best]['auroc_pos']:.3f} (vs-all {stats[best]['auroc_all']:.3f}) "
|
||||||
|
f"P@rout={stats[best]['prec_rout']:.2f} (n={stats[best]['n_rout']}) "
|
||||||
|
f"R@rout={stats[best]['rec_rout']:.2f}")
|
||||||
print(f"out: {q2_png}")
|
print(f"out: {q2_png}")
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user