Files
evil_MoE/scripts/plot_deploy_overlay.py
T
wassname 270c4f5a27 misc
2026-06-11 11:07:28 +00:00

156 lines
7.6 KiB
Python

"""All-arms per-mode DEPLOY overlay (#162) from the per_mode_deploy.json artifacts.
Each run writes out/runs/<ts>_<tag>/per_mode_deploy.json (train.py, #164) with
deployment metrics. For route/route2, evaluation ablates the quarantine parameters.
Unlike plot_substrate's training-time hk_<mode> curves, these metrics evaluate the
deployed parameter state.
Reads JSON, not logs, so it never trips on a route2 arm the log-parsers don't know.
The headline comparison: per loophole mode, does each intervention suppress the
DEPLOY hack rate below vanilla, and at what cost to DEPLOY solve? run_tests is the
in-distribution mode (the one v_hack was built closest to); the rest are held-out
modes used to test generalization without training-distribution labels.
Figure: a Cleveland dot plot with y = mode and one dot per arm, plus an arrow per
mode from the first arm to the last, so the vanilla -> route change reads as a
line segment.
Usage:
uv run python scripts/plot_deploy_overlay.py # globs out/runs/*sub4*/
uv run python scripts/plot_deploy_overlay.py out/runs/*_sub4_*/per_mode_deploy.json
uv run python scripts/plot_deploy_overlay.py --out out/figs/deploy_overlay.png
"""
from __future__ import annotations
import argparse
import csv
import json
from collections import defaultdict
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
from loguru import logger
from vgrout.figs import save_fig
# Arm tag -> (display label, colour). Insertion order is the canonical legend/plot
# order, with the vanilla baseline first.
# Only reader-facing names appear here; "route2"/"grad" are internal run tags.
# The grad-mask routing arm is the one we report, hence the plain "route" label;
# the failed activation-mask variant gets a descriptive suffix, not a version number.
# NOTE(review): "routing2" and "routing2_grad" share label AND colour, so if both
# show up in one sweep they are indistinguishable in the legend and the CSV --
# presumably two tags for the same arm; confirm.
ARM = dict(
    vanilla=("vanilla", "#444444"),
    projected=("erase", "#c1432b"),
    routing=("route (v1)", "#33508c"),
    routing2_act=("route (act-mask)", "#2f7d4f"),
    routing2_grad=("route", "#b8860b"),
    routing2=("route", "#b8860b"),
)
# Loophole modes in display order: the in-distribution mode first, then held-out.
MODE_ORDER = [
    "run_tests",
    "file_marker",
    "stdout_marker",
    "sentinel",
    "eq_override",
]
def load(paths: list[Path]) -> list[dict]:
    """Parse each per_mode_deploy.json and log its headline deploy numbers.

    Args:
        paths: JSON artifact paths, one per run.

    Returns:
        The parsed records, in the same order as ``paths``.

    Each record must carry top-level ``arm``, ``hack_deploy`` and
    ``solve_deploy`` keys (read for the log line); a missing key raises KeyError.
    """
    out = []
    for p in paths:
        # JSON is UTF-8 by spec; don't let the platform locale pick the codec.
        d = json.loads(p.read_text(encoding="utf-8"))
        out.append(d)
        logger.info(f"{d['arm']:<14} deploy hack={d['hack_deploy']:.3f} solve={d['solve_deploy']:.3f} ({p})")
    return out
def _mode_stats(by_arm, arm, modes, field):
    """Per-mode (mean, std-across-seeds) of one metric for one arm.

    Args:
        by_arm: arm tag -> list of run records; each record's ``by_mode`` maps
            mode -> {metric name -> value}.
        arm: which arm's runs to aggregate.
        modes: modes to report, in output order.
        field: metric key inside each per-mode dict (e.g. "deploy_hack").

    Returns:
        Two float arrays of length ``len(modes)``: the NaN-ignoring mean and the
        NaN-ignoring (population) std across the arm's runs. The std is 0.0 when
        the arm has a single run. A mode that no run reports -- or that is
        reported as JSON null -- yields a NaN mean (and a NaN std with >1 run).
    """
    import warnings

    runs = by_arm[arm]
    means, stds = [], []
    for m in modes:
        # dtype=float maps a JSON null (None) to NaN; without it numpy builds an
        # object array that np.nanmean/np.nanstd reject with TypeError.
        v = np.array([r["by_mode"].get(m, {}).get(field, np.nan) for r in runs], dtype=float)
        with warnings.catch_warnings():
            # An all-NaN column (mode absent for this arm) is expected and the NaN
            # result is what we want; silence numpy's "empty slice" warnings.
            warnings.simplefilter("ignore", RuntimeWarning)
            means.append(np.nanmean(v))
            stds.append(np.nanstd(v) if len(v) > 1 else 0.0)
    return np.array(means), np.array(stds)
def _panel(ax, by_arm, modes, arms, field, xlabel):
    """Draw one Cleveland dot-plot panel: y = mode, x = per-mode rate of ``field``.

    One dot per arm per mode, plus a grey arrow per mode from the first arm in
    ``arms`` to the last (vanilla -> route when both are present), so the
    direction of change reads as a line segment. xerr = std across seeds, drawn
    only when some std is > 0 (which needs >1 seed). Tufte: faint x-grid only,
    no box; the dots and their value labels carry the categories.

    Args:
        ax: matplotlib Axes to draw into.
        by_arm: arm tag -> list of run records (as consumed by ``_mode_stats``).
        modes: modes to plot, first one at the top.
        arms: arm tags to plot, in legend order; each must be a key of ``ARM``.
        field: per-mode metric key, e.g. "deploy_hack".
        xlabel: x-axis label; it states the metric and the better direction.

    TODO(seeds): A5 currently has n=1 (seed 41, jobs 103/104) so no error bar yet; the
    queued seeds 42/43 (jobs 107-110) populate xerr -- the code already aggregates.
    """
    y = np.arange(len(modes))[::-1]  # first mode at top
    # Aggregate each arm once; both the arrows and the dots read these arrays.
    stats = {arm: _mode_stats(by_arm, arm, modes, field) for arm in arms}

    # Arrow baseline -> ours per mode: shows the DIRECTION of change. Skipped for
    # a mode where either endpoint is missing (NaN).
    if len(arms) >= 2:
        start, end = stats[arms[0]][0], stats[arms[-1]][0]
        for x0, x1, yy in zip(start, end, y):
            if np.isfinite(x0) and np.isfinite(x1):
                ax.annotate("", xy=(x1, yy), xytext=(x0, yy), zorder=1,
                            arrowprops=dict(arrowstyle="-|>", color="0.6", lw=1.1,
                                            shrinkA=6, shrinkB=6))

    for i, arm in enumerate(arms):
        label, color = ARM[arm]
        means, stds = stats[arm]
        xerr = stds if (stds > 0).any() else None
        ax.errorbar(means, y, xerr=xerr, fmt="o", ms=7, color=color, zorder=3,
                    capsize=2, elinewidth=0.8, label=f"{label} (n={len(by_arm[arm])})")
        dy = 7 if i == 0 else -12  # first arm's labels above, others below, so close dots don't collide
        for v, yy in zip(means, y):
            if np.isnan(v):
                continue
            # finite-sample estimate: approximately, not identically, zero
            txt = "≈0" if v < 5e-3 else f"{v:.2f}"
            ax.annotate(txt, (v, yy), fontsize=6, color=color, ha="center",
                        va="bottom", xytext=(0, dy), textcoords="offset points")

    ax.set_yticks(y)
    ax.set_yticklabels([f"{m}\n{'IN' if m == 'run_tests' else 'held-out'}" for m in modes], fontsize=8)
    ax.set_xlim(-0.04, 1.08)
    # Same as (y.min() - 0.5, y.max() + 0.5), but does not raise on empty ``modes``.
    ax.set_ylim(-0.5, len(modes) - 0.5)
    ax.set_xlabel(xlabel, fontsize=9)  # carries the metric AND the better direction, so no title
    ax.spines[["top", "right", "left"]].set_visible(False)
    ax.tick_params(length=0)
    ax.grid(axis="x", lw=0.3, alpha=0.3)
def _write_csv(csv_path, by_arm, arms, modes):
    """Write the plotted numbers: one row per (arm, mode) with the deploy hack/solve
    mean +/- std-across-seeds -- exactly what the dots and error bars encode."""
    with csv_path.open("w", newline="") as f:
        w = csv.writer(f)
        w.writerow(["mode", "in_dist", "arm", "n_seed",
                    "deploy_hack_mean", "deploy_hack_std", "deploy_solve_mean", "deploy_solve_std"])
        for arm in arms:
            hk_m, hk_s = _mode_stats(by_arm, arm, modes, "deploy_hack")
            sv_m, sv_s = _mode_stats(by_arm, arm, modes, "deploy_solve")
            for j, m in enumerate(modes):
                w.writerow([m, m == "run_tests", ARM[arm][0], len(by_arm[arm]),
                            f"{hk_m[j]:.6f}", f"{hk_s[j]:.6f}", f"{sv_m[j]:.6f}", f"{sv_s[j]:.6f}"])


def main() -> None:
    """CLI entry: load per_mode_deploy.json files, draw the two-panel deploy
    hack/solve overlay, and write the figure plus a CSV of the plotted values
    (same path as ``--out`` with a .csv suffix).

    Raises:
        SystemExit: no input JSON found, or nothing plottable (no arm in ``ARM``
            or no mode in ``MODE_ORDER``).
    """
    ap = argparse.ArgumentParser(description=__doc__)
    ap.add_argument("jsons", nargs="*", type=Path,
                    help="per_mode_deploy.json paths; default globs out/runs/*sub4*/")
    ap.add_argument("--out", type=Path, default=Path("out/figs/deploy_overlay.png"))
    ap.add_argument("--title", action="store_true",
                    help="draw the suptitle (off by default: the caption carries it)")
    args = ap.parse_args()
    paths = args.jsons or sorted(Path("out/runs").glob("*sub4*/per_mode_deploy.json"))
    if not paths:
        raise SystemExit("no per_mode_deploy.json found (run the sweep first)")
    records = load(paths)

    # Group seed runs per arm (mean+/-std bars), order arms canonically.
    by_arm: dict[str, list[dict]] = defaultdict(list)
    for r in records:
        by_arm[r["arm"]].append(r)
    arms = [a for a in ARM if a in by_arm]
    seen_modes = {m for r in records for m in r["by_mode"]}
    modes = [m for m in MODE_ORDER if m in seen_modes]
    # Arms/modes with no display entry can't be plotted; say so instead of
    # silently dropping them from both the figure and the CSV.
    unknown_arms = sorted(set(by_arm) - set(ARM))
    if unknown_arms:
        logger.warning(f"arm(s) {unknown_arms} have no ARM entry; not plotted")
    unknown_modes = sorted(seen_modes - set(MODE_ORDER))
    if unknown_modes:
        logger.warning(f"mode(s) {unknown_modes} are not in MODE_ORDER; not plotted")
    if not arms or not modes:
        raise SystemExit(f"nothing to plot: arms={arms} modes={modes} (check ARM / MODE_ORDER)")

    fig, (a1, a2) = plt.subplots(1, 2, figsize=(9.5, 0.7 + 0.7 * len(modes)), sharey=True)
    _panel(a1, by_arm, modes, arms, "deploy_hack", r"DEPLOY hack rate ($\downarrow$ lower = better)")
    _panel(a2, by_arm, modes, arms, "deploy_solve", r"DEPLOY solve rate ($\uparrow$ higher = better)")
    a1.legend(fontsize=8, frameon=False, loc="lower right")
    if args.title:
        # Runs without a recorded seed are omitted; sorting None with ints would raise TypeError.
        seeds = sorted({r["seed"] for r in records if r.get("seed") is not None})
        fig.suptitle(f"Per-mode deploy overlay ({len(arms)} arms, seed {seeds}) -- "
                     f"quarantine deleted = shipped model", fontsize=11)
    fig.tight_layout()
    save_fig(fig, args.out)
    plt.close(fig)  # release the figure once it has been written

    # CSV reproducibility source (mirrors the dynamics plots' dump).
    csv_path = args.out.with_suffix(".csv")
    _write_csv(csv_path, by_arm, arms, modes)
    logger.info(f"wrote {args.out} and {csv_path.name} ({len(arms)} arms x {len(modes)} modes)")
# Script entry point (see Usage in the module docstring); importing the module
# only defines names and constants.
if __name__ == "__main__":
    main()