mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 16:30:30 +08:00
plot_deploy_overlay: Cleveland dot plot replaces grouped bars (tufte)
y=mode, dot per arm, thin connector per mode so vanilla->route change reads as a line segment. Faint x-grid only, no box (dots+labels carry categories), labels staggered to avoid collision, xerr=seed std when n>1. Kills the invisible zero-bar problem and shows the per-mode drop directly. Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
@@ -10,7 +10,8 @@ Reads JSON, not logs, so it never trips on a route2 arm the log-parsers don't kn
|
||||
The headline comparison: per loophole mode, does each intervention suppress the
|
||||
DEPLOY hack rate below vanilla, and at what cost to DEPLOY solve? run_tests is the
|
||||
in-dist mode (v_hack built closest to it); the rest are held-out (the no-cheat
|
||||
generalisation test). Bars grouped by mode, one bar per arm.
|
||||
generalisation test). Cleveland dot plot: y = mode, dot per arm, connector per
|
||||
mode so the vanilla -> route change reads as a line segment.
|
||||
|
||||
Usage:
|
||||
uv run python scripts/plot_deploy_overlay.py # globs out/runs/*sub4*/
|
||||
@@ -55,43 +56,49 @@ def load(paths: list[Path]) -> list[dict]:
|
||||
return out
|
||||
|
||||
|
||||
def _despine(ax):
|
||||
ax.spines[["top", "right"]].set_visible(False)
|
||||
ax.grid(axis="y", lw=0.4, alpha=0.35)
|
||||
def _mode_stats(by_arm, arm, modes, field):
|
||||
"""(mean, std-across-seeds) per mode for one arm; std=0 at n=1."""
|
||||
means, stds = [], []
|
||||
for m in modes:
|
||||
v = [r["by_mode"].get(m, {}).get(field, np.nan) for r in by_arm[arm]]
|
||||
means.append(np.nanmean(v))
|
||||
stds.append(np.nanstd(v) if len(v) > 1 else 0.0)
|
||||
return np.array(means), np.array(stds)
|
||||
|
||||
|
||||
def _panel(ax, by_arm, modes, arms, field, title, ylabel):
|
||||
"""Grouped bars: x = mode, one bar per arm, height = mean over seed runs of
|
||||
by_mode[mode][field]; error bar = std across seeds (drawn only when >1 seed).
|
||||
TODO(seeds): A5 currently ships n=1 (seed 41 only, jobs 103/104) so no error
|
||||
bar appears. Pass per-seed JSONs (a5 vanilla+route2 seeds 42/43, queued) to
|
||||
populate the error bars -- the code already aggregates them."""
|
||||
w = 0.8 / len(arms)
|
||||
x = np.arange(len(modes))
|
||||
def _panel(ax, by_arm, modes, arms, field, title, xlabel):
|
||||
"""Cleveland dot plot: y = mode, x = rate. One dot per arm with a thin connector
|
||||
per mode, so the arm-to-arm change reads as a line segment (vanilla -> route).
|
||||
xerr = std across seeds (drawn only when >1 seed). Tufte: faint x-grid only, no
|
||||
box, dots+labels carry the categories.
|
||||
TODO(seeds): A5 ships n=1 (seed 41, jobs 103/104) so no error bar yet; the
|
||||
queued seeds 42/43 (jobs 107-110) populate xerr -- the code already aggregates."""
|
||||
y = np.arange(len(modes))[::-1] # first mode at top
|
||||
for j in range(len(modes)): # connector between arms, per mode
|
||||
xs = [_mode_stats(by_arm, a, modes, field)[0][j] for a in arms]
|
||||
ax.plot(xs, [y[j]] * len(arms), color="0.75", lw=1.0, zorder=1)
|
||||
for i, arm in enumerate(arms):
|
||||
recs = by_arm[arm]
|
||||
label, color = ARM[arm]
|
||||
per_mode = [[r["by_mode"].get(m, {}).get(field, np.nan) for r in recs] for m in modes]
|
||||
means = np.array([np.nanmean(v) for v in per_mode])
|
||||
stds = np.array([np.nanstd(v) if len(v) > 1 else 0.0 for v in per_mode])
|
||||
n_seed = len(recs)
|
||||
yerr = stds if (stds > 0).any() else None
|
||||
bars = ax.bar(x + i * w, means, w, label=f"{label} (n={n_seed})", color=color,
|
||||
yerr=yerr, capsize=2, error_kw=dict(lw=0.8, alpha=0.8))
|
||||
for b, v in zip(bars, means):
|
||||
means, stds = _mode_stats(by_arm, arm, modes, field)
|
||||
xerr = stds if (stds > 0).any() else None
|
||||
ax.errorbar(means, y, xerr=xerr, fmt="o", ms=7, color=color, zorder=3,
|
||||
capsize=2, elinewidth=0.8, label=f"{label} (n={len(by_arm[arm])})")
|
||||
dy = 7 if i == 0 else -12 # stagger labels so close dots don't collide
|
||||
for v, yy in zip(means, y):
|
||||
if np.isnan(v):
|
||||
continue
|
||||
# a zero-height bar is invisible -- mark it "≡0" so the reader sees a
|
||||
# finding, not a missing bar (same convention as the line plots).
|
||||
txt = "≡0" if v < 5e-3 else f"{v:.2f}"
|
||||
ax.annotate(txt, (b.get_x() + b.get_width() / 2, v), fontsize=6,
|
||||
ha="center", va="bottom", color=color)
|
||||
ax.set_xticks(x + 0.4 - w / 2)
|
||||
ax.set_xticklabels([f"{m}\n{'IN' if m == 'run_tests' else 'held-out'}" for m in modes], fontsize=8)
|
||||
txt = "≡0" if v < 5e-3 else f"{v:.2f}" # a dot on the axis still needs the finding marked
|
||||
ax.annotate(txt, (v, yy), fontsize=6, color=color, ha="center",
|
||||
va="bottom", xytext=(0, dy), textcoords="offset points")
|
||||
ax.set_yticks(y)
|
||||
ax.set_yticklabels([f"{m}\n{'IN' if m == 'run_tests' else 'held-out'}" for m in modes], fontsize=8)
|
||||
ax.set_xlim(-0.04, 1.08)
|
||||
ax.set_ylim(y.min() - 0.5, y.max() + 0.5)
|
||||
ax.set_xlabel(xlabel)
|
||||
ax.set_title(title, fontsize=10)
|
||||
ax.set_ylabel(ylabel)
|
||||
ax.set_ylim(0, 1.05)
|
||||
_despine(ax)
|
||||
ax.spines[["top", "right", "left"]].set_visible(False)
|
||||
ax.tick_params(length=0)
|
||||
ax.grid(axis="x", lw=0.3, alpha=0.3)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
@@ -114,12 +121,12 @@ def main() -> None:
|
||||
arms = [a for a in ARM if a in by_arm]
|
||||
modes = [m for m in MODE_ORDER if any(m in r["by_mode"] for r in records)]
|
||||
|
||||
fig, (a1, a2) = plt.subplots(1, 2, figsize=(5.5 + 1.2 * len(modes), 4.2))
|
||||
fig, (a1, a2) = plt.subplots(1, 2, figsize=(9.5, 0.7 + 0.7 * len(modes)), sharey=True)
|
||||
_panel(a1, by_arm, modes, arms, "deploy_hack",
|
||||
"DEPLOY hack rate by mode (lower = better)", "deploy hack rate")
|
||||
"DEPLOY hack rate (lower = better)", "deploy hack rate")
|
||||
_panel(a2, by_arm, modes, arms, "deploy_solve",
|
||||
"DEPLOY solve rate by mode (higher = better)", "deploy solve rate")
|
||||
a1.legend(fontsize=8, frameon=False, loc="upper right")
|
||||
"DEPLOY solve rate (higher = better)", "deploy solve rate")
|
||||
a1.legend(fontsize=8, frameon=False, loc="lower right")
|
||||
if args.title:
|
||||
n_seed = {r.get("seed") for r in records}
|
||||
fig.suptitle(f"Per-mode deploy overlay ({len(arms)} arms, seed {sorted(n_seed)}) -- "
|
||||
|
||||
Reference in New Issue
Block a user