mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 16:30:30 +08:00
plot_dynamics: hack≡0 tags on overlay, labels on top panel, 2-panel train-vs-deploy
- overlay: floor hack panel below 0 so a pinned-at-0 line shows; direct-label the TOP (hack) panel not the bottom (read top-to-bottom); tag any arm whose series sits at 0 with $\equiv 0$. - train-vs-deploy: replace the 2x2 with one panel per arm, 4 series each -- colour=metric (red hack/green solve), linestyle=train(dashed)/deploy(solid). The route gap (dashed-red up, solid-red at 0) and vanilla overlap (train==deploy) read in one panel. two-axis legend (colour=metric, style=train/deploy). Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
+77
-56
@@ -40,6 +40,7 @@ from pathlib import Path
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from matplotlib.lines import Line2D
|
||||
from loguru import logger
|
||||
|
||||
from projected_grpo.figs import link_latest, save_fig, arm_label
|
||||
@@ -344,11 +345,12 @@ def plot(runs: list[dict], out: Path) -> None:
|
||||
logger.info(f"wrote {out} ({len(runs)} runs, arms={[arm_label(a) for a in arms]})")
|
||||
|
||||
|
||||
def _overlay_panel(ax, by_arm, arms, key, *, label, with_onset):
|
||||
def _overlay_panel(ax, by_arm, arms, key, *, label, with_onset, label_arms, ylim=(0, 1)):
|
||||
"""Overlay one metric (key) per arm on ax: faint per-seed EMA lines + bold
|
||||
EMA mean, optional mean-onset dot. Direct labels (only on the unlabeled-x panel)
|
||||
are de-collided in y so overlapping arms don't stack their text (collision test)."""
|
||||
ends = [] # (y_endpoint, x_endpoint, arm, color) for direct labels
|
||||
EMA mean, optional mean-onset dot. When label_arms, direct-label each arm at its
|
||||
endpoint (de-collided in y). An arm whose mean series sits at zero gets a
|
||||
"$\\equiv 0$" tag so a pinned-at-zero line reads as a finding, not a missing line."""
|
||||
ends = [] # (y_endpoint, x_endpoint, arm, color, is_zero) for direct labels
|
||||
for arm in arms:
|
||||
rs = [r for r in by_arm[arm] if key in r]
|
||||
if not rs:
|
||||
@@ -368,39 +370,42 @@ def _overlay_panel(ax, by_arm, arms, key, *, label, with_onset):
|
||||
if onsets:
|
||||
s0 = float(np.mean(onsets))
|
||||
ax.plot(s0, np.interp(s0, xm, ym), marker="o", ms=4, color=color, zorder=3)
|
||||
ends.append((float(ym[-1]), float(xm[-1]), arm, color))
|
||||
ax.set_ylim(0, 1)
|
||||
ends.append((float(ym[-1]), float(xm[-1]), arm, color, float(np.nanmax(ym)) < 0.02))
|
||||
ax.set_ylim(*ylim)
|
||||
ax.set_ylabel(label)
|
||||
ax.spines[["top", "right"]].set_visible(False)
|
||||
ax.tick_params(labelsize=8)
|
||||
# direct-label only one panel (caller passes with_onset=False for it) -- the
|
||||
# other shares colours, so labelling both is redundant ink (eraser test).
|
||||
if with_onset:
|
||||
if not label_arms: # other panel shares colours -- redundant ink
|
||||
return
|
||||
ends.sort(key=lambda e: e[0]) # bottom-to-top by endpoint
|
||||
gap = 0.052 # min y-separation in data units
|
||||
gap = 0.052 * (ylim[1] - ylim[0]) # min y-separation, scaled to the range
|
||||
placed = []
|
||||
for y, x, arm, color in ends:
|
||||
for y, x, arm, color, is_zero in ends:
|
||||
y_lab = y if not placed else max(y, placed[-1] + gap)
|
||||
placed.append(y_lab)
|
||||
text = arm_label(arm) + (r" $\equiv 0$" if is_zero else "")
|
||||
arrow = dict(arrowstyle="-", color=color, lw=0.5, shrinkA=0, shrinkB=0)
|
||||
ax.annotate(arm_label(arm), xy=(x, y), xytext=(x + 1.0, y_lab), textcoords="data",
|
||||
ax.annotate(text, xy=(x, y), xytext=(x + 1.0, y_lab), textcoords="data",
|
||||
color=color, fontsize=8, va="center",
|
||||
arrowprops=arrow if abs(y_lab - y) > 1e-3 else None)
|
||||
|
||||
|
||||
def plot_hack_overlay(runs: list[dict], out: Path) -> None:
|
||||
"""Two stacked panels sharing x: student hack rate (top) and solve rate (bottom)
|
||||
per arm. Faint per-seed EMA lines + bold EMA-5 mean; onset dot on the hack panel;
|
||||
arms direct-labelled once on the solve panel (shared colours, no legend)."""
|
||||
per arm. Faint per-seed EMA lines + bold EMA-5 mean; onset dot on the hack panel.
|
||||
Arms are direct-labelled on the TOP (hack) panel -- readers scan top-to-bottom, and
|
||||
the hack panel carries the headline (an arm pinned at 0 gets a $\\equiv 0$ tag)."""
|
||||
by_arm: dict[str, list[dict]] = defaultdict(list)
|
||||
for r in runs:
|
||||
by_arm[classify(r)].append(r)
|
||||
arms = [a for a in ARM_ORDER if a in by_arm]
|
||||
|
||||
fig, (ax_h, ax_s) = plt.subplots(2, 1, figsize=(5.2, 5.2), sharex=True)
|
||||
_overlay_panel(ax_h, by_arm, arms, "hack_s", label="hack rate", with_onset=True)
|
||||
_overlay_panel(ax_s, by_arm, arms, "gt_s", label="solve rate", with_onset=False)
|
||||
# floor the hack panel below 0 so a route line pinned at 0 draws above the axis
|
||||
_overlay_panel(ax_h, by_arm, arms, "hack_s", label="hack rate",
|
||||
with_onset=True, label_arms=True, ylim=(-0.035, 1.0))
|
||||
_overlay_panel(ax_s, by_arm, arms, "gt_s", label="solve rate",
|
||||
with_onset=False, label_arms=False, ylim=(0, 1.0))
|
||||
ax_s.set_xlabel("optimizer step")
|
||||
if SHOW_TITLE:
|
||||
ax_h.set_title("Hack vs solve rate by arm (EMA-5; dot = mean hack onset)", fontsize=10)
|
||||
@@ -410,57 +415,73 @@ def plot_hack_overlay(runs: list[dict], out: Path) -> None:
|
||||
|
||||
|
||||
def plot_train_vs_deploy(runs: list[dict], out: Path) -> None:
|
||||
"""2x2 small multiple: rows = train (adapter ON) / deploy (adapter OFF), cols = arm.
|
||||
The story in one figure: vanilla train == deploy (no quarantine, the reward
|
||||
hack is in the deployed weights); route trains while hacking but deploys clean,
|
||||
the hack is held in the deletable quarantine adapter. Same red=hack/green=solve
|
||||
as the other figures."""
|
||||
"""One panel per arm, four series each: {hack, solve} x {train, deploy}.
|
||||
Colour = metric (red hack / green solve); linestyle = train (adapter on, dashed)
|
||||
vs deploy (adapter off, solid). The route gap is the result -- dashed-red (train)
|
||||
rises while solid-red (deploy) sits at 0, because the hack lives in the deletable
|
||||
quarantine. For vanilla the dashed/solid pair coincides (train==deploy: the hack is
|
||||
in the shipped weights, nothing to delete). Matched n=64 eval on every series."""
|
||||
by_arm: dict[str, list[dict]] = defaultdict(list)
|
||||
for r in runs:
|
||||
by_arm[classify(r)].append(r)
|
||||
arms = [a for a in ARM_ORDER if a in by_arm]
|
||||
red, green = RATE_COLORS["hack_s"], RATE_COLORS["gt_s"]
|
||||
rows = [
|
||||
("train (adapter on)", {"hack_train": "hack", "solve_train": "solve"},
|
||||
{"hack_train": red, "solve_train": green}),
|
||||
("deploy (adapter off)", {"hk_dep": "hack", "slv_dep": "solve"},
|
||||
{"hk_dep": red, "slv_dep": green}),
|
||||
TRAIN_LS, DEPLOY_LS = (0, (4, 2)), "-"
|
||||
# (series_key, colour, linestyle, is_hack)
|
||||
SERIES = [
|
||||
("hack_train", red, TRAIN_LS, True),
|
||||
("hk_dep", red, DEPLOY_LS, True),
|
||||
("solve_train", green, TRAIN_LS, False),
|
||||
("slv_dep", green, DEPLOY_LS, False),
|
||||
]
|
||||
fig, axes = plt.subplots(2, len(arms), figsize=(3.0 * len(arms), 4.8),
|
||||
fig, axes = plt.subplots(1, len(arms), figsize=(3.4 * len(arms), 3.2),
|
||||
sharex=True, sharey=True, squeeze=False)
|
||||
for ci, arm in enumerate(arms):
|
||||
axes[0][ci].set_title(arm_label(arm), fontsize=10)
|
||||
for ri, (rlabel, cols, colors) in enumerate(rows):
|
||||
ax = axes[ri][ci]
|
||||
_series_panel(ax, by_arm[arm], cols, colors, ylim=(-0.035, 1.0),
|
||||
label_series=(ci == 0))
|
||||
hk_key = next(iter(cols))
|
||||
hk = [r[hk_key] for r in by_arm[arm] if hk_key in r]
|
||||
if hk and np.nanmax([np.nanmax(h) for h in hk]) < 0.02:
|
||||
ax.annotate("hack ≡ 0", (0.04, 0.0), xycoords=("axes fraction", "data"),
|
||||
color=red, fontsize=8, va="bottom",
|
||||
xytext=(0, 3), textcoords="offset points")
|
||||
# teacher-off curriculum: shade the teacher-ON region [0, toff] + a line at
|
||||
# the cut, so "hacks were teacher-seeded here, on-policy after" is visible.
|
||||
toffs = {r.get("teacher_off") for r in by_arm[arm] if r.get("teacher_off")}
|
||||
if toffs:
|
||||
toff = max(toffs)
|
||||
ax.axvspan(0, toff, color="0.85", alpha=0.5, zorder=0)
|
||||
ax.axvline(toff, color="0.55", lw=0.8, ls=(0, (4, 3)), zorder=1)
|
||||
if ri == 0:
|
||||
ax.annotate("teacher off", (toff, 1.0), color="0.4", fontsize=7,
|
||||
xytext=(2, -2), textcoords="offset points", va="top")
|
||||
if ci == 0:
|
||||
ax.set_ylabel(rlabel)
|
||||
ax.spines[["top", "right"]].set_visible(False)
|
||||
ax.tick_params(labelsize=8)
|
||||
for ax in axes[-1]:
|
||||
ax = axes[0][ci]
|
||||
ax.set_title(arm_label(arm), fontsize=10)
|
||||
deploy_hack_zero = False
|
||||
for key, color, ls, is_hack in SERIES:
|
||||
rs = [r for r in by_arm[arm] if key in r]
|
||||
if not rs:
|
||||
continue
|
||||
stacked = [_ema(r[key]) for r in rs]
|
||||
L = min(len(y) for y in stacked)
|
||||
ym = np.nanmean(np.stack([y[:L] for y in stacked]), axis=0)
|
||||
xm = rs[0]["steps"][:L]
|
||||
ax.plot(xm, ym, color=color, ls=ls, lw=1.8, solid_capstyle="round")
|
||||
if key == "hk_dep" and np.nanmax(ym) < 0.02:
|
||||
deploy_hack_zero = True
|
||||
if deploy_hack_zero: # the route headline: solid-red pinned at 0
|
||||
ax.annotate(r"deploy hack $\equiv 0$", (0.04, 0.0),
|
||||
xycoords=("axes fraction", "data"), color=red, fontsize=8,
|
||||
va="bottom", xytext=(0, 3), textcoords="offset points")
|
||||
# teacher-off curriculum: shade the teacher-ON region so "seeded here, on-policy
|
||||
# after" stays visible in the C4 bootstrap variant (jobs 93/94).
|
||||
toffs = {r.get("teacher_off") for r in by_arm[arm] if r.get("teacher_off")}
|
||||
if toffs:
|
||||
toff = max(toffs)
|
||||
ax.axvspan(0, toff, color="0.85", alpha=0.5, zorder=0)
|
||||
ax.axvline(toff, color="0.55", lw=0.8, ls=(0, (4, 3)), zorder=1)
|
||||
ax.annotate("teacher off", (toff, 1.0), color="0.4", fontsize=7,
|
||||
xytext=(2, -2), textcoords="offset points", va="top")
|
||||
ax.set_ylim(-0.035, 1.0)
|
||||
ax.set_xlabel("optimizer step")
|
||||
ax.spines[["top", "right"]].set_visible(False)
|
||||
ax.tick_params(labelsize=8)
|
||||
axes[0][0].set_ylabel("rate")
|
||||
# two-axis legend: colour = metric, linestyle = train vs deploy
|
||||
handles = [
|
||||
Line2D([], [], color=red, lw=1.8, label="hack"),
|
||||
Line2D([], [], color=green, lw=1.8, label="solve"),
|
||||
Line2D([], [], color="0.3", lw=1.8, ls=TRAIN_LS, label="train (adapter on)"),
|
||||
Line2D([], [], color="0.3", lw=1.8, ls=DEPLOY_LS, label="deploy (adapter off)"),
|
||||
]
|
||||
axes[0][-1].legend(handles=handles, fontsize=7, frameon=False, loc="upper left")
|
||||
if SHOW_TITLE:
|
||||
fig.suptitle("Train (adapter on) vs deploy (adapter off): vanilla puts the "
|
||||
"reward hack in the weights, route in the deletable adapter (EMA-5)",
|
||||
fig.suptitle("Train (adapter on) vs deploy (adapter off): vanilla bakes the "
|
||||
"hack into the weights, route holds it in the deletable adapter",
|
||||
fontsize=10)
|
||||
fig.tight_layout(rect=(0, 0, 1, 0.95))
|
||||
fig.tight_layout(rect=(0, 0, 1, 0.93))
|
||||
else:
|
||||
fig.tight_layout()
|
||||
save_fig(fig, out)
|
||||
|
||||
Reference in New Issue
Block a user