Files
evil_MoE/scripts/plot_emergence.py
T

104 lines
3.9 KiB
Python

"""Phase-1 emergence plot: does each loophole emerge under vanilla GRPO?

One hue per env_mode on a single axes: the hack rate (exploited) is drawn solid
and the solve rate (gt_correct) dashed in the same hue, so each mode's pair reads
together. A loophole "emerges" if its hack rate rises from ~0; a faint vertical
line marks the mean hack-onset step. Single-seed by default; passing several logs
for the same mode averages their EMA-smoothed curves. Reuses
plot_dynamics.parse_log so the column parsing stays in one place, and groups by
env_mode (read from the argv --env-mode flag) instead of by intervention arm: all
emergence runs are vanilla, so arm grouping would collapse them into one group.

Usage:
    uv run python scripts/plot_emergence.py logs/*_emerge_*.log
    uv run python scripts/plot_emergence.py logs/ --out out/figs/emergence.png
"""
from __future__ import annotations
import argparse
import re
from collections import defaultdict
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
from loguru import logger
from projected_grpo.figs import link_latest
from plot_dynamics import _ema, _gather, _onset, parse_log
# One qualitative hue per env_mode. The hack curve (solid) and the solve curve
# (dashed) of a mode share its hue so the two read as a pair; key order is the
# order in which known modes are drawn.
MODE_COLORS = dict(
    run_tests="#c1432b",
    eq_override="#33508c",
    exit_code="#b8860b",
)
def _env_mode(path: Path) -> str:
txt = path.read_text(errors="replace")
m = re.findall(r"--env-mode[= ](\w+)", txt)
if m:
return m[-1]
# default run_tests env when the flag is absent (old-style logs)
return "run_tests"
def _mean_ema(rs: list[dict], key: str) -> tuple:
    """Average the EMA-smoothed ``key`` series of several runs.

    Each run's ``r[key]`` is smoothed with ``_ema``. All smoothed series are
    truncated to the shortest one so they stack into a 2-D array, then averaged
    across runs with ``np.nanmean`` (NaN entries are ignored).

    Args:
        rs: Parsed runs (as returned by ``parse_log``) for one env_mode.
        key: Series to average, e.g. ``"hack_s"`` or ``"gt_s"``.

    Returns:
        ``(steps, mean)``, where ``steps`` is the first run's ``"steps"``
        truncated to the common length. ``mean`` is empty when ``rs`` is empty
        or any run has no points.
    """
    if not rs:
        return [], np.empty(0)
    smoothed = [_ema(r[key]) for r in rs]
    n = min(len(y) for y in smoothed)
    # NOTE(review): x values come from the first run only, i.e. this assumes all
    # runs for a mode were logged on the same step grid.
    steps = rs[0]["steps"][:n]
    if n == 0:
        return steps, np.empty(0)
    mean = np.nanmean(np.stack([y[:n] for y in smoothed]), axis=0)
    return steps, mean


def plot(runs_by_mode: dict[str, list[dict]], out: Path) -> None:
    """Draw hack/solve emergence curves for every env_mode and save them to ``out``.

    Each mode gets one hue (from ``MODE_COLORS``, grey for unknown modes):
    - the mean-of-EMA hack rate is drawn solid and labelled with the mode name
      at its last point;
    - the mean-of-EMA solve rate is drawn dashed;
    - the mean of the runs' ``_onset`` steps (computed on the raw hack series)
      is marked with a faint vertical line.

    Curves with no points are skipped with a warning instead of crashing.

    Args:
        runs_by_mode: Parsed runs grouped by env_mode.
        out: Destination image path; parent directories are created.
    """
    # Known modes first, in MODE_COLORS order, so colors and layering are stable
    # across invocations; unknown modes follow in insertion order.
    modes = [m for m in MODE_COLORS if m in runs_by_mode] + \
        [m for m in runs_by_mode if m not in MODE_COLORS]
    fig, ax = plt.subplots(figsize=(6.0, 3.8))
    try:
        for mode in modes:
            rs = runs_by_mode[mode]
            color = MODE_COLORS.get(mode, "#555555")
            # (series key, line style, line width): hack solid + thick, solve dashed + thin.
            for key, ls, lw in (("hack_s", "-", 2.0), ("gt_s", (0, (4, 2)), 1.2)):
                xm, ym = _mean_ema(rs, key)
                if len(ym) == 0:
                    logger.warning(f"{mode}: no points for {key}; skipping curve")
                    continue
                ax.plot(xm, ym, color=color, lw=lw, ls=ls, solid_capstyle="round")
                if key == "hack_s":
                    # Label the hack curve at its right end instead of using a legend.
                    ax.annotate(mode, (xm[-1], ym[-1]), color=color, fontsize=8,
                                xytext=(4, 0), textcoords="offset points", va="center")
            onsets = [s for r in rs if (s := _onset(r["steps"], r["hack_s"])) is not None]
            if onsets:
                ax.axvline(float(np.mean(onsets)), color=color, lw=0.7,
                           ls=(0, (2, 3)), alpha=0.5, zorder=0)
        ax.set_ylim(0, 1)
        ax.set_xlabel("optimizer step")
        ax.set_ylabel("rate")
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        ax.tick_params(labelsize=8)
        ax.set_title("Loophole emergence under vanilla GRPO "
                     "(solid=hack/exploited, dashed=solve/gt_correct; EMA-5)", fontsize=9)
        fig.tight_layout()
        out.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out, dpi=150, bbox_inches="tight")
    finally:
        # pyplot keeps every figure it creates alive until it is explicitly closed.
        plt.close(fig)
    logger.info(f"wrote {out} (modes={modes})")
def main() -> None:
    """Command-line entry point for the emergence plot.

    Expands the positional log arguments with ``_gather`` and parses each file
    with ``parse_log``, silently dropping files it cannot parse. Surviving runs
    are bucketed by the env_mode recorded in their log. The bucket sizes are
    logged, the figure is rendered with ``plot``, and ``link_latest`` is pointed
    at the output.

    Raises:
        SystemExit: if none of the given files yields a parseable run.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("logs", nargs="+", help="log files, globs, or dirs")
    parser.add_argument("--out", type=Path, default=Path("out/figs/emergence.png"))
    args = parser.parse_args()

    files = _gather(args.logs)
    grouped: dict[str, list[dict]] = defaultdict(list)
    for log_path in files:
        run = parse_log(log_path)
        # The env_mode lookup re-reads the file, so only do it for parseable runs.
        if run is not None:
            grouped[_env_mode(log_path)].append(run)

    if not grouped:
        raise SystemExit(f"no parseable runs in {len(files)} files")

    for mode, rs in grouped.items():
        logger.info(f"{mode:14s} {len(rs)} run(s), steps={[len(r['steps']) for r in rs]}")

    plot(grouped, args.out)
    logger.info(f"docs/figs latest -> {link_latest(args.out)}")


if __name__ == "__main__":
    main()