mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 18:59:35 +08:00
feat(plot): hack-overlay gains a solve-rate subplot (Tufte two-panel)
Stacked hack (top) + solve (bottom) sharing x; EMA-5; onset dot on hack only; arms direct-labelled once on solve with y de-collision + leader lines (the three non-route arms overlap, so their labels would otherwise stack). routing2 reads hack~0 / solve highest at a glance. Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
+47
-24
@@ -283,42 +283,65 @@ def plot(runs: list[dict], out: Path) -> None:
|
||||
logger.info(f"wrote {out} ({len(runs)} runs, arms={arms})")
|
||||
|
||||
|
||||
def plot_hack_overlay(runs: list[dict], out: Path) -> None:
|
||||
"""Single panel: mean hack_s per arm overlaid, the headline arm-vs-arm view.
|
||||
Faint per-seed lines for spread, bold EMA mean per arm, onset tick at the
|
||||
arm's mean hack-onset, direct-labelled (no legend)."""
|
||||
by_arm: dict[str, list[dict]] = defaultdict(list)
|
||||
for r in runs:
|
||||
by_arm[classify(r)].append(r)
|
||||
arms = [a for a in ARM_ORDER if a in by_arm]
|
||||
|
||||
fig, ax = plt.subplots(figsize=(5.2, 3.4))
|
||||
def _overlay_panel(ax, by_arm, arms, key, *, label, with_onset):
|
||||
"""Overlay one metric (key) per arm on ax: faint per-seed EMA lines + bold
|
||||
EMA mean, optional mean-onset dot. Direct labels (only on the unlabeled-x panel)
|
||||
are de-collided in y so overlapping arms don't stack their text (collision test)."""
|
||||
ends = [] # (y_endpoint, x_endpoint, arm, color) for direct labels
|
||||
for arm in arms:
|
||||
rs = by_arm[arm]
|
||||
rs = [r for r in by_arm[arm] if key in r]
|
||||
if not rs:
|
||||
continue
|
||||
color = ARM_COLORS[arm]
|
||||
stacked = []
|
||||
for r in rs:
|
||||
ys = _ema(r["hack_s"])
|
||||
ys = _ema(r[key])
|
||||
ax.plot(r["steps"], ys, color=color, lw=0.6, alpha=0.25, solid_capstyle="round")
|
||||
stacked.append(ys)
|
||||
L = min(len(y) for y in stacked)
|
||||
ym = np.nanmean(np.stack([y[:L] for y in stacked]), axis=0)
|
||||
xm = rs[0]["steps"][:L]
|
||||
ax.plot(xm, ym, color=color, lw=2.0, solid_capstyle="round")
|
||||
onsets = [s for r in rs if (s := _onset(r["steps"], r["hack_s"])) is not None]
|
||||
if onsets:
|
||||
s0 = float(np.mean(onsets))
|
||||
ax.plot(s0, np.interp(s0, xm, ym), marker="o", ms=4, color=color, zorder=3)
|
||||
ax.annotate(arm, (xm[-1], ym[-1]), color=color, fontsize=8,
|
||||
xytext=(4, 0), textcoords="offset points", va="center")
|
||||
|
||||
if with_onset:
|
||||
onsets = [s for r in rs if (s := _onset(r["steps"], r["hack_s"])) is not None]
|
||||
if onsets:
|
||||
s0 = float(np.mean(onsets))
|
||||
ax.plot(s0, np.interp(s0, xm, ym), marker="o", ms=4, color=color, zorder=3)
|
||||
ends.append((float(ym[-1]), float(xm[-1]), arm, color))
|
||||
ax.set_ylim(0, 1)
|
||||
ax.set_xlabel("optimizer step")
|
||||
ax.set_ylabel("student hack rate (hack_s)")
|
||||
ax.spines["top"].set_visible(False)
|
||||
ax.spines["right"].set_visible(False)
|
||||
ax.set_ylabel(label)
|
||||
ax.spines[["top", "right"]].set_visible(False)
|
||||
ax.tick_params(labelsize=8)
|
||||
ax.set_title("Student hack rate by arm (EMA-5; dot = mean onset)", fontsize=10)
|
||||
# direct-label only one panel (caller passes with_onset=False for it) -- the
|
||||
# other shares colours, so labelling both is redundant ink (eraser test).
|
||||
if with_onset:
|
||||
return
|
||||
ends.sort(key=lambda e: e[0]) # bottom-to-top by endpoint
|
||||
gap = 0.052 # min y-separation in data units
|
||||
placed = []
|
||||
for y, x, arm, color in ends:
|
||||
y_lab = y if not placed else max(y, placed[-1] + gap)
|
||||
placed.append(y_lab)
|
||||
arrow = dict(arrowstyle="-", color=color, lw=0.5, shrinkA=0, shrinkB=0)
|
||||
ax.annotate(arm, xy=(x, y), xytext=(x + 1.0, y_lab), textcoords="data",
|
||||
color=color, fontsize=8, va="center",
|
||||
arrowprops=arrow if abs(y_lab - y) > 1e-3 else None)
|
||||
|
||||
|
||||
def plot_hack_overlay(runs: list[dict], out: Path) -> None:
|
||||
"""Two stacked panels sharing x: student hack rate (top) and solve rate (bottom)
|
||||
per arm. Faint per-seed EMA lines + bold EMA-5 mean; onset dot on the hack panel;
|
||||
arms direct-labelled once on the solve panel (shared colours, no legend)."""
|
||||
by_arm: dict[str, list[dict]] = defaultdict(list)
|
||||
for r in runs:
|
||||
by_arm[classify(r)].append(r)
|
||||
arms = [a for a in ARM_ORDER if a in by_arm]
|
||||
|
||||
fig, (ax_h, ax_s) = plt.subplots(2, 1, figsize=(5.2, 5.2), sharex=True)
|
||||
_overlay_panel(ax_h, by_arm, arms, "hack_s", label="hack rate", with_onset=True)
|
||||
_overlay_panel(ax_s, by_arm, arms, "gt_s", label="solve rate", with_onset=False)
|
||||
ax_s.set_xlabel("optimizer step")
|
||||
ax_h.set_title("Hack vs solve rate by arm (EMA-5; dot = mean hack onset)", fontsize=10)
|
||||
fig.tight_layout()
|
||||
out.parent.mkdir(parents=True, exist_ok=True)
|
||||
fig.savefig(out, dpi=150, bbox_inches="tight")
|
||||
|
||||
Reference in New Issue
Block a user