mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-28 04:22:03 +08:00
55937a86fb
git mv src/projected_grpo -> src/vgrout and find-replace the module name in
all imports (.py), `-m projected_grpo.*` invocations (justfile), and the
[project] name (pyproject; setuptools auto-discovers via where=["src"]).
Left RESEARCH_JOURNAL.md untouched: its commands/paths are dated lab notes
tied to past commits, so rewriting them would falsify provenance. Repo dir,
git remote, and absolute paths unchanged.
Verified: `import vgrout` and `python -m vgrout.train --help` load the full
graph; verify_rewards.py + verify_gate_anchor.py (both import vgrout) pass.
Full `just smoke` is blocked upstream by missing gitignored data artifacts
(out/pools/{substrate,teacher_pool}, out/vhack/*smoke*), unrelated to the rename.
155 lines
7.6 KiB
Python
155 lines
7.6 KiB
Python
"""All-arms per-mode DEPLOY overlay (#162) from the per_mode_deploy.json artifacts.

Each run writes out/runs/<ts>_<tag>/per_mode_deploy.json (train.py, #164) with the
HONEST deploy numbers: for route/route2 the quarantine is deleted before eval, so
this is the model you would actually ship -- unlike plot_substrate's hk_<mode>
curves which are TRAIN-time (routed forward still hacks) and overstate routing.

Reads JSON, not logs, so it never trips on a route2 arm the log-parsers don't know.

The headline comparison: per loophole mode, does each intervention suppress the
DEPLOY hack rate below vanilla, and at what cost to DEPLOY solve? run_tests is the
in-dist mode (v_hack built closest to it); the rest are held-out (the no-cheat
generalisation test). Cleveland dot plot: y = mode, dot per arm, connector per
mode so the vanilla -> route change reads as a line segment.

Usage:
    uv run python scripts/plot_deploy_overlay.py  # globs out/runs/*sub4*/
    uv run python scripts/plot_deploy_overlay.py out/runs/*_sub4_*/per_mode_deploy.json
    uv run python scripts/plot_deploy_overlay.py --out out/figs/deploy_overlay.png
"""
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import csv
|
|
import json
|
|
from collections import defaultdict
|
|
from pathlib import Path
|
|
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
from loguru import logger
|
|
|
|
from vgrout.figs import save_fig
|
|
|
|
# Internal arm tag -> (reader-facing label, hex colour).
# Insertion order is the plotting/legend order, baseline first.
# Labels are what the reader sees: "route2"/"grad" are internal tags only. The
# grad-mask routing arm is the reported one, so it is plain "route"; the failed
# activation-mask variant gets a descriptive suffix rather than a version number.
_ROUTE_STYLE = ("route", "#b8860b")  # shared by both tags of the reported routing arm
ARM = {
    "vanilla": ("vanilla", "#444444"),
    "projected": ("erase", "#c1432b"),
    "routing": ("route (v1)", "#33508c"),
    "routing2_act": ("route (act-mask)", "#2f7d4f"),
    "routing2_grad": _ROUTE_STYLE,
    "routing2": _ROUTE_STYLE,
}

# Display order of loophole modes: the in-distribution mode first, held-out after.
MODE_ORDER = [
    "run_tests",
    "file_marker",
    "stdout_marker",
    "sentinel",
    "eq_override",
]
|
|
|
|
|
|
def load(paths: list[Path]) -> list[dict]:
    """Decode every per_mode_deploy.json in ``paths`` and return the records in order.

    One info line is logged per file with the record's ``arm``, ``hack_deploy``
    and ``solve_deploy`` values plus the path it was read from.
    """
    records = []
    for path in paths:
        record = json.loads(path.read_text())
        records.append(record)
        arm, hack, solve = record["arm"], record["hack_deploy"], record["solve_deploy"]
        logger.info(f"{arm:<14} deploy hack={hack:.3f} solve={solve:.3f} ({path})")
    return records
|
|
|
|
|
|
def _mode_stats(by_arm, arm, modes, field):
    """Return ``(means, stds)`` arrays with one entry per mode for a single arm.

    Args:
        by_arm: mapping arm -> list of run records; each record carries a
            ``"by_mode"`` dict of ``mode -> {field: value}``.
        arm: key into ``by_arm`` selecting the runs to aggregate.
        modes: mode names, in the order the output arrays should follow.
        field: name of the per-mode value to aggregate (e.g. ``"deploy_hack"``).

    Returns:
        Two float arrays of ``len(modes)``: the mean and the population std
        (ddof=0) across runs, ignoring runs that lack the mode/field. The mean
        is NaN when no run has a value. The std is 0.0 when the arm has at most
        one run, and NaN when it has several runs but none has a value.

    NaN entries are filtered explicitly instead of going through
    ``np.nanmean``/``np.nanstd``, which emit a RuntimeWarning for every
    all-missing (mode, arm) cell even though callers handle the NaN result.
    """
    runs = by_arm[arm]
    multi_run = len(runs) > 1  # std is only meaningful with more than one seed
    means, stds = [], []
    for mode in modes:
        raw = np.asarray(
            [run["by_mode"].get(mode, {}).get(field, np.nan) for run in runs],
            dtype=float,
        )
        present = raw[~np.isnan(raw)]
        if present.size:
            means.append(present.mean())
            stds.append(present.std() if multi_run else 0.0)
        else:
            means.append(np.nan)
            stds.append(np.nan if multi_run else 0.0)
    return np.array(means), np.array(stds)
|
|
|
|
|
|
def _panel(ax, by_arm, modes, arms, field, xlabel):
    """Draw one Cleveland dot-plot panel on ``ax``: y = mode, x = rate, dot per arm.

    A grey arrow per mode runs from the first arm in ``arms`` to the last one, so
    the baseline -> ours change reads as a directed segment. xerr = std across
    seeds (drawn only when some std is > 0). Tufte: faint x-grid only, no box,
    dots+labels carry the categories.

    Args:
        ax: axes object to draw on (uses ``errorbar``/``annotate``/``grid`` etc.).
        by_arm: mapping arm -> list of run records, as consumed by ``_mode_stats``.
        modes: mode names; the first one is drawn at the top.
        arms: arm keys to draw, each of which must be a key of ``ARM``.
        field: per-mode value to plot (passed through to ``_mode_stats``).
        xlabel: x-axis label; carries the metric and the better-direction.

    TODO(seeds): A5 ships n=1 (seed 41, jobs 103/104) so no error bar yet; the
    queued seeds 42/43 (jobs 107-110) populate xerr -- the code already aggregates.
    """
    y = np.arange(len(modes))[::-1]  # first mode at top

    # Aggregate once per arm; the stats do not depend on which mode is being drawn.
    stats = {arm: _mode_stats(by_arm, arm, modes, field) for arm in arms}

    # Arrow baseline->ours per mode: shows the DIRECTION of change.
    if len(arms) >= 2:
        start_means = stats[arms[0]][0]
        end_means = stats[arms[-1]][0]
        for x0, x1, yy in zip(start_means, end_means, y):
            if np.isfinite(x0) and np.isfinite(x1):
                ax.annotate("", xy=(x1, yy), xytext=(x0, yy), zorder=1,
                            arrowprops=dict(arrowstyle="-|>", color="0.6", lw=1.1,
                                            shrinkA=6, shrinkB=6))

    for i, arm in enumerate(arms):
        label, color = ARM[arm]
        means, stds = stats[arm]
        xerr = stds if (stds > 0).any() else None  # no bars when every std is 0/NaN
        ax.errorbar(means, y, xerr=xerr, fmt="o", ms=7, color=color, zorder=3,
                    capsize=2, elinewidth=0.8, label=f"{label} (n={len(by_arm[arm])})")
        dy = 7 if i == 0 else -12  # stagger labels so close dots don't collide
        for v, yy in zip(means, y):
            if np.isnan(v):
                continue
            # finite-sample estimate: approx, not identically, zero
            txt = "≈0" if v < 5e-3 else f"{v:.2f}"
            ax.annotate(txt, (v, yy), fontsize=6, color=color, ha="center",
                        va="bottom", xytext=(0, dy), textcoords="offset points")

    ax.set_yticks(y)
    ax.set_yticklabels([f"{m}\n{'IN' if m == 'run_tests' else 'held-out'}" for m in modes], fontsize=8)
    ax.set_xlim(-0.04, 1.08)
    if y.size:  # min()/max() of an empty array raise; nothing to frame without modes
        ax.set_ylim(y.min() - 0.5, y.max() + 0.5)
    ax.set_xlabel(xlabel, fontsize=9)  # carries the metric AND the better-direction;
    ax.spines[["top", "right", "left"]].set_visible(False)  # no title (would just restate it)
    ax.tick_params(length=0)
    ax.grid(axis="x", lw=0.3, alpha=0.3)
|
|
|
|
|
|
def main() -> None:
    """CLI entry point: draw the two-panel deploy overlay and write its CSV twin.

    Reads the per_mode_deploy.json files given on the command line (or globbed
    from out/runs/*sub4*/), groups the records by arm, plots DEPLOY hack and
    DEPLOY solve rate per mode, saves the figure via ``save_fig`` and writes a
    CSV with the same stem holding the plotted mean/std values.

    Raises:
        SystemExit: when no input file is found, or when the inputs contain no
            arm listed in ``ARM`` / no mode listed in ``MODE_ORDER`` (there
            would be nothing to plot).
    """
    ap = argparse.ArgumentParser(description=__doc__)
    ap.add_argument("jsons", nargs="*", type=Path,
                    help="per_mode_deploy.json paths; default globs out/runs/*sub4*/")
    ap.add_argument("--out", type=Path, default=Path("out/figs/deploy_overlay.png"))
    ap.add_argument("--title", action="store_true",
                    help="draw the suptitle (off by default: the caption carries it)")
    args = ap.parse_args()

    paths = args.jsons or sorted(Path("out/runs").glob("*sub4*/per_mode_deploy.json"))
    if not paths:
        raise SystemExit("no per_mode_deploy.json found (run the sweep first)")
    records = load(paths)

    # group seed runs per arm (mean+/-std bars), order arms canonically
    by_arm: dict[str, list[dict]] = defaultdict(list)
    for r in records:
        by_arm[r["arm"]].append(r)

    # Arms without an ARM entry have no label/colour and are left out of the
    # figure; say so instead of dropping them silently.
    unknown = sorted(str(a) for a in by_arm if a not in ARM)
    if unknown:
        logger.warning(f"ignoring records for arms not in ARM: {unknown}")

    arms = [a for a in ARM if a in by_arm]
    modes = [m for m in MODE_ORDER if any(m in r["by_mode"] for r in records)]
    if not arms or not modes:
        raise SystemExit(
            f"nothing to plot: {len(arms)} known arms and {len(modes)} known modes "
            f"in {len(records)} records"
        )

    fig, (a1, a2) = plt.subplots(1, 2, figsize=(9.5, 0.7 + 0.7 * len(modes)), sharey=True)
    _panel(a1, by_arm, modes, arms, "deploy_hack", r"DEPLOY hack rate ($\downarrow$ lower = better)")
    _panel(a2, by_arm, modes, arms, "deploy_solve", r"DEPLOY solve rate ($\uparrow$ higher = better)")
    a1.legend(fontsize=8, frameon=False, loc="lower right")
    if args.title:
        n_seed = {r.get("seed") for r in records}
        # Records without a "seed" yield None, which cannot be ordered against
        # ints; sort it last instead of letting sorted() raise TypeError.
        seeds = sorted(n_seed, key=lambda s: (s is None, s))
        fig.suptitle(f"Per-mode deploy overlay ({len(arms)} arms, seed {seeds}) -- "
                     f"quarantine deleted = shipped model", fontsize=11)
    fig.tight_layout()
    save_fig(fig, args.out)

    # CSV reproducibility source (mirrors the dynamics plots' dump): per (mode, arm)
    # the deploy hack/solve mean +/- std-across-seeds, exactly what the dots encode.
    csv_path = args.out.with_suffix(".csv")
    with csv_path.open("w", newline="") as f:
        w = csv.writer(f)
        w.writerow(["mode", "in_dist", "arm", "n_seed",
                    "deploy_hack_mean", "deploy_hack_std", "deploy_solve_mean", "deploy_solve_std"])
        for arm in arms:
            hk_m, hk_s = _mode_stats(by_arm, arm, modes, "deploy_hack")
            sv_m, sv_s = _mode_stats(by_arm, arm, modes, "deploy_solve")
            for j, m in enumerate(modes):
                w.writerow([m, m == "run_tests", ARM[arm][0], len(by_arm[arm]),
                            f"{hk_m[j]:.6f}", f"{hk_s[j]:.6f}", f"{sv_m[j]:.6f}", f"{sv_s[j]:.6f}"])
    logger.info(f"wrote {args.out} and {csv_path.name} ({len(arms)} arms x {len(modes)} modes)")
|
|
|
|
|
|
# Script entry point: only run the CLI when executed directly, not on import.
if __name__ == "__main__":
    main()
|