mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 17:15:58 +08:00
viz: floor->ceiling as two normalized panels (best vs control vs reference)
Rework per feedback: hack and solve are not opposites, so each gets its own floor->ceiling axis (0 = floor, 1 = ceiling) instead of sharing a zero -- this also stops solve (range ~0.13-0.22) from being squished next to hack (0-0.61). Kept minimal: routeV per-token (best) vs random-V (direction control) vs the SGTM gradient-routing paper, placed on the same floor->ceiling % axis (approximate; LM task). Reads: hack suppression 93% best / 84% control / ~98% reference (the 9pp best-vs-control gap is the direction signal); solve gained +17% best / -17% control / ~95% reference (far from ceiling -- the model barely learns to solve in 60 steps). Moved out/plots -> out/figs. Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
@@ -40,7 +40,7 @@ import matplotlib.pyplot as plt
|
||||
|
||||
RED, GREEN, GREY = "#c0392b", "#1e8449", "#9aa0a6"
|
||||
RUNS = Path("out/runs")
|
||||
OUT = Path("out/plots")
|
||||
OUT = Path("out/figs")
|
||||
CSV = OUT / "floor_ceiling.csv"
|
||||
PAPER_CEILING = 0.223 # Ariahw et al. no-loophole solve -- provisional fast-env ceiling
|
||||
|
||||
@@ -106,6 +106,14 @@ def build_csv() -> pl.DataFrame:
|
||||
|
||||
|
||||
# ── stage 2: plot from the csv ──────────────────────────────────────────────
|
||||
# Reference: the gradient-routing paper (SGTM, Mhaskar et al. 2025) reports its result as a
|
||||
# retain/forget trade-off vs a "perfect filter" oracle (= our ceiling) and "no filter" (= our
|
||||
# floor). Placed on the SAME floor->ceiling % axis (approximate; LM-unlearning task, not RL):
|
||||
# forget: leakage ~0.02 -> ~98% suppression; retain: ~5% compute penalty -> ~95% of oracle.
|
||||
SGTM_REF = dict(label="SGTM grad-routing\n(LM unlearn, ~approx)", hack_supp=0.98, solve_uplift=0.95)
|
||||
GOLD, DARK = "#c8920a", "#3a3a3a"
|
||||
|
||||
|
||||
def _anchors(df: pl.DataFrame) -> dict:
|
||||
g = lambda kind, col: df.filter(pl.col("kind") == kind)[col][0]
|
||||
ceil_status = g("anchor_ceiling", "status")
|
||||
@@ -115,51 +123,20 @@ def _anchors(df: pl.DataFrame) -> dict:
|
||||
provisional=ceil_status.startswith("FIXME"))
|
||||
|
||||
|
||||
def _panel_normalized(ax, methods: pl.DataFrame, a, title):
|
||||
base, vh, ceil = a["base_solve"], a["vanilla_hack"], a["ceiling"]
|
||||
labels = [l for l in methods["label"] if l != "vanilla GRPO"] # vanilla = the 0% hack anchor
|
||||
for yi, lab in enumerate(labels):
|
||||
r = methods.filter(pl.col("label") == lab).to_dicts()[0]
|
||||
hack_rm = (vh - r["hack_deploy"]) / vh
|
||||
solve_rc = (r["solve_deploy"] - base) / (ceil - base)
|
||||
ax.barh(yi + 0.18, hack_rm, height=0.32, color=RED, alpha=0.85)
|
||||
ax.text(hack_rm + 0.015, yi + 0.18, f"{r['hack_deploy']:.3f} ({hack_rm*100:.0f}%)",
|
||||
va="center", ha="left", fontsize=8, color=RED)
|
||||
ax.barh(yi - 0.18, solve_rc, height=0.32, color=GREEN, alpha=0.85)
|
||||
ax.text(solve_rc + 0.015 if solve_rc >= 0 else solve_rc - 0.015, yi - 0.18,
|
||||
f"{r['solve_deploy']:.3f} ({solve_rc*100:+.0f}%)",
|
||||
va="center", ha="left" if solve_rc >= 0 else "right", fontsize=8, color=GREEN)
|
||||
ax.axvline(0, color=GREY, lw=0.8)
|
||||
ax.axvline(1.0, color=GREY, lw=0.8, ls=":")
|
||||
ax.text(1.0, len(labels) - 0.35, "ceiling / no-hack", fontsize=7, color=GREY, ha="center")
|
||||
ax.set_yticks(range(len(labels))); ax.set_yticklabels(labels, fontsize=9)
|
||||
ax.set_xlim(-0.35, 1.25); ax.set_xlabel("fraction of floor→ceiling range (right = better)")
|
||||
def _bars(ax, rows, key, raws, title, xlabel, xlo):
|
||||
"""One floor->ceiling panel: horizontal bars in [xlo,1], 0=floor, 1.0=ceiling."""
|
||||
for yi, (lab, val, raw, col) in enumerate(rows):
|
||||
ax.barh(yi, val, height=0.55, color=col, alpha=0.9,
|
||||
hatch="//" if "approx" in lab else None, edgecolor="white")
|
||||
tip = f"{val*100:+.0f}%" if xlo < 0 else f"{val*100:.0f}%"
|
||||
rawtxt = f" ({raw})" if raw else ""
|
||||
ax.text(val + (0.02 if val >= 0 else -0.02), yi, tip + rawtxt,
|
||||
va="center", ha="left" if val >= 0 else "right", fontsize=8.5, color=col)
|
||||
ax.axvline(0, color=GREY, lw=1.0) # floor (labelled in xlabel)
|
||||
ax.axvline(1.0, color=GREY, lw=1.0, ls=":") # ceiling
|
||||
ax.set_yticks(range(len(rows))); ax.set_yticklabels([r[0] for r in rows], fontsize=8.5)
|
||||
ax.set_xlim(xlo, 1.18); ax.set_xlabel(xlabel, fontsize=8.5)
|
||||
ax.set_title(title, fontsize=10, loc="left")
|
||||
ax.text(0.01, 0.99, "red = hack removed (vs vanilla) green = solve recovered (base→ceiling)",
|
||||
transform=ax.transAxes, fontsize=7.5, color="#444", va="top")
|
||||
for s in ("top", "right", "left"):
|
||||
ax.spines[s].set_visible(False)
|
||||
ax.tick_params(left=False)
|
||||
|
||||
|
||||
def _panel_knob(ax, methods: pl.DataFrame):
    """Draw panel B: per-method arrows from the knob-ON rate to the knob-OFF rate.

    Each row of ``methods`` gets one y-slot holding two lanes: a red hack
    lane slightly above the slot and a green solve lane slightly below it.
    In each lane an arrow runs from the ``*_on`` value to the ``*_off``
    value, with faint dots marking both endpoints and text labels for the
    rates.

    Args:
        ax: matplotlib Axes to draw on.
        methods: frame with a ``label`` column plus ``hack_on``, ``hack_off``,
            ``solve_on`` and ``solve_off`` columns (read per row below).
    """
    names = list(methods["label"])
    lane = 0.16  # vertical offset of each lane from the row centre
    pad = 0.012  # horizontal gap between an endpoint and its text label

    for row_idx, name in enumerate(names):
        # First row matching this label, as a plain dict.
        rec = methods.filter(pl.col("label") == name).to_dicts()[0]

        # --- hack lane (above the row centre) ---
        y_hack = row_idx + lane
        hack_off, hack_on = rec["hack_off"], rec["hack_on"]
        ax.annotate(
            "",
            xy=(hack_off, y_hack),
            xytext=(hack_on, y_hack),
            arrowprops={"arrowstyle": "->", "color": RED, "lw": 1.6, "alpha": 0.9},
        )
        ax.plot([hack_on, hack_off], [y_hack, y_hack], "o", color=RED, ms=4, alpha=0.5)
        ax.text(hack_on + pad, y_hack, f"on {hack_on:.2f}",
                va="center", ha="left", fontsize=7, color=RED)
        ax.text(hack_off - pad, y_hack, f"{hack_off:.2f}",
                va="center", ha="right", fontsize=7.5, color=RED)

        # --- solve lane (below the row centre) ---
        y_solve = row_idx - lane
        solve_off, solve_on = rec["solve_off"], rec["solve_on"]
        ax.annotate(
            "",
            xy=(solve_off, y_solve),
            xytext=(solve_on, y_solve),
            arrowprops={"arrowstyle": "->", "color": GREEN, "lw": 1.6, "alpha": 0.9},
        )
        ax.plot([solve_on, solve_off], [y_solve, y_solve], "o", color=GREEN, ms=4, alpha=0.5)
        # The label sits to the right of whichever endpoint is further right.
        ax.text(max(solve_on, solve_off) + pad, y_solve, f"solve {solve_off:.2f}",
                va="center", ha="left", fontsize=7.5, color=GREEN)

    ax.set_yticks(range(len(names)))
    ax.set_yticklabels(names, fontsize=9)
    ax.set_xlim(-0.02, 0.80)
    ax.set_xlabel("rate (arrow = knob-ON → knob-OFF on held-out val; left = better for hack)")
    ax.set_title("B. the knob effect (held-out val n=32, L5 -- isolates the quarantine)", fontsize=10, loc="left")

    # Keep only the bottom spine; rows are identified by their tick labels.
    for side in ("top", "right", "left"):
        ax.spines[side].set_visible(False)
    ax.tick_params(left=False)
|
||||
@@ -167,14 +144,33 @@ def _panel_knob(ax, methods: pl.DataFrame):
|
||||
|
||||
def plot(df: pl.DataFrame) -> None:
|
||||
a = _anchors(df)
|
||||
methods = df.filter(pl.col("kind") == "method")
|
||||
prov = " [ceiling PROVISIONAL=0.223, FIXME job 24]" if a["provisional"] else ""
|
||||
fig, axes = plt.subplots(2, 1, figsize=(8.5, 8.0), gridspec_kw=dict(height_ratios=[1, 1.05]))
|
||||
_panel_normalized(axes[0], methods, a, f"A. normalized floor→ceiling, deploy (test n=119){prov}")
|
||||
_panel_knob(axes[1], methods)
|
||||
fig.suptitle("vGROUT: floor-to-ceiling method comparison (seed 43, 60-step fast)",
|
||||
fontsize=11, x=0.02, ha="left")
|
||||
fig.tight_layout(rect=(0, 0, 1, 0.97))
|
||||
base, vh, ceil = a["base_solve"], a["vanilla_hack"], a["ceiling"]
|
||||
pick = lambda lab: df.filter(pl.col("label") == lab).to_dicts()[0]
|
||||
best, rand = pick("routeV per-token"), pick("routeV random-V")
|
||||
|
||||
def hsupp(r): return (vh - r["hack_deploy"]) / vh
|
||||
def suplift(r): return (r["solve_deploy"] - base) / (ceil - base)
|
||||
|
||||
# rows: best (gold), random control (dark), SGTM reference (grey, hatched). Top row plots last.
|
||||
hack_rows = [
|
||||
(SGTM_REF["label"], SGTM_REF["hack_supp"], "~0.98 supp", GREY),
|
||||
("routeV random-V\n(direction control)", hsupp(rand), f"{rand['hack_deploy']:.3f}", DARK),
|
||||
("routeV per-token\n(best)", hsupp(best), f"{best['hack_deploy']:.3f}", GOLD),
|
||||
]
|
||||
solve_rows = [
|
||||
(SGTM_REF["label"], SGTM_REF["solve_uplift"], "~oracle", GREY),
|
||||
("routeV random-V\n(direction control)", suplift(rand), f"{rand['solve_deploy']:.3f}", DARK),
|
||||
("routeV per-token\n(best)", suplift(best), f"{best['solve_deploy']:.3f}", GOLD),
|
||||
]
|
||||
prov = " (ceiling PROVISIONAL=0.223, FIXME job 24)" if a["provisional"] else ""
|
||||
fig, (axl, axr) = plt.subplots(1, 2, figsize=(11, 3.2), sharey=True)
|
||||
_bars(axl, hack_rows, "hack", None,
|
||||
"hack suppressed", "floor (vanilla 0.613) → ceiling (no hack) · right = better", 0.0)
|
||||
_bars(axr, solve_rows, "solve", None,
|
||||
"solve gained", f"floor (base 0.126) → ceiling (no-loophole){prov} · right = better", -0.4)
|
||||
fig.suptitle("vGROUT floor→ceiling: best vs direction-control vs reference paper (test n=119, seed 43, 60-step fast)",
|
||||
fontsize=10.5, x=0.01, ha="left")
|
||||
fig.tight_layout(rect=(0, 0, 1, 0.94))
|
||||
for ext in ("pdf", "png"):
|
||||
fig.savefig(OUT / f"floor_ceiling.{ext}", dpi=150, bbox_inches="tight")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user