mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-28 02:00:23 +08:00
8e38d0f419
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
104 lines
3.9 KiB
Python
104 lines
3.9 KiB
Python
"""Phase-1 emergence plot: does each loophole emerge under vanilla GRPO?
|
|
|
|
One solid/dashed line pair per env_mode on a single axes: hack rate (exploited)
is drawn solid and solve rate (gt_correct) dashed, both in that mode's own hue;
a loophole "emerges" if hack rises from ~0. A faint vertical line marks the mean
hack-onset step. Single-seed by default; if you pass several logs for the same
env_mode, their EMA curves are averaged (truncated to the shortest run), not
overlaid. Reuses plot_dynamics.parse_log so the column parsing stays in one
place; groups by env_mode (from argv --env-mode) instead of intervention-arm
(all emergence runs are vanilla, so arm grouping would collapse them).
|
|
|
|
Usage:
|
|
uv run python scripts/plot_emergence.py logs/*_emerge_*.log
|
|
uv run python scripts/plot_emergence.py logs/ --out out/figs/emergence.png
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import re
|
|
from collections import defaultdict
|
|
from pathlib import Path
|
|
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
from loguru import logger
|
|
|
|
from projected_grpo.figs import link_latest
|
|
from plot_dynamics import _ema, _gather, _onset, parse_log
|
|
|
|
# One qualitative hue per env_mode. A mode's solve curve reuses its hack hue
# (dashed) so the two lines read as a pair. Insertion order matters: plot()
# draws the known modes in exactly this order before any unlisted ones.
MODE_COLORS: dict[str, str] = dict(
    run_tests="#c1432b",
    eq_override="#33508c",
    exit_code="#b8860b",
)
|
|
|
|
|
|
def _env_mode(path: Path) -> str:
    """Return the env_mode a run was launched with, as recorded in its log.

    Scans the whole log text for ``--env-mode=X`` or ``--env-mode X`` (``X``
    being word characters) and returns ``X``. If the flag appears several
    times, the last occurrence is used. Logs without the flag (old-style
    logs) default to ``"run_tests"``.
    """
    # Decode explicitly as UTF-8: a bare read_text() uses the platform locale
    # encoding, so the same log could decode differently on another machine.
    # errors="replace" keeps stray undecodable bytes from aborting the scan.
    text = path.read_text(encoding="utf-8", errors="replace")
    matches = re.findall(r"--env-mode[= ](\w+)", text)
    return matches[-1] if matches else "run_tests"
|
|
|
|
|
|
def _plot_mode(ax, mode: str, rs: list[dict], color: str) -> None:
    """Draw one env_mode's seed-averaged hack/solve curves and onset marker on ``ax``."""
    if not rs:
        return
    # (series key, linestyle): hack solid, solve dashed, same hue.
    for key, ls in (("hack_s", "-"), ("gt_s", (0, (4, 2)))):
        stacked = [_ema(r[key]) for r in rs]
        # Truncate to the shortest run so the seeds can be stacked and averaged.
        L = min(len(y) for y in stacked)
        if L == 0:
            # Nothing to draw; the original code crashed on xm[-1] here.
            logger.warning(f"{mode}: empty {key} series, skipping")
            continue
        ym = np.nanmean(np.stack([y[:L] for y in stacked]), axis=0)
        # x-axis taken from the first run; assumes all runs of a mode share a
        # step grid -- TODO confirm against parse_log.
        xm = rs[0]["steps"][:L]
        ax.plot(xm, ym, color=color, lw=2.0 if key == "hack_s" else 1.2,
                ls=ls, solid_capstyle="round")
        if key == "hack_s":
            # Label the hack curve at its last finite point; a trailing NaN
            # mean would otherwise place the label at an undefined position.
            finite = np.flatnonzero(np.isfinite(ym))
            if finite.size:
                i = finite[-1]
                ax.annotate(mode, (xm[i], ym[i]), color=color, fontsize=8,
                            xytext=(4, 0), textcoords="offset points", va="center")
    # Mean onset over the runs where _onset found one (it returns None otherwise).
    onsets = [s for r in rs if (s := _onset(r["steps"], r["hack_s"])) is not None]
    if onsets:
        ax.axvline(float(np.mean(onsets)), color=color, lw=0.7, ls=(0, (2, 3)),
                   alpha=0.5, zorder=0)


def plot(runs_by_mode: dict[str, list[dict]], out: Path) -> None:
    """Draw the emergence figure (hack + solve rate per env_mode) and save it.

    Each run dict is read for ``"steps"``, ``"hack_s"`` and ``"gt_s"`` (as
    produced by ``plot_dynamics.parse_log`` -- assumed, not re-checked here).
    For each mode:

    - The EMA of each series is averaged across that mode's runs with NaNs
      ignored, after truncating to the shortest run.
    - The averaged curves are plotted: hack solid/thick, solve dashed/thin,
      both in the mode's hue.
    - The hack curve is labelled with the mode name.
    - A faint vertical line marks the mean hack-onset step.

    Args:
        runs_by_mode: env_mode -> parsed runs for that mode.
        out: output image path; missing parent directories are created.
    """
    # Known modes first in MODE_COLORS order, then any unexpected ones (grey).
    modes = [m for m in MODE_COLORS if m in runs_by_mode] + [
        m for m in runs_by_mode if m not in MODE_COLORS
    ]
    fig, ax = plt.subplots(figsize=(6.0, 3.8))
    try:
        for mode in modes:
            _plot_mode(ax, mode, runs_by_mode[mode], MODE_COLORS.get(mode, "#555555"))

        ax.set_ylim(0, 1)
        ax.set_xlabel("optimizer step")
        ax.set_ylabel("rate")
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        ax.tick_params(labelsize=8)
        ax.set_title("Loophole emergence under vanilla GRPO "
                     "(solid=hack/exploited, dashed=solve/gt_correct; EMA-5)", fontsize=9)
        fig.tight_layout()
        out.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out, dpi=150, bbox_inches="tight")
    finally:
        # pyplot keeps every figure alive until closed; release it even on error.
        plt.close(fig)
    logger.info(f"wrote {out} (modes={modes})")
|
|
|
|
|
|
def main() -> None:
    """CLI entry point: collect logs, group parsed runs by env_mode, plot them.

    Each positional argument is expanded by ``_gather``. Logs for which
    ``parse_log`` returns None are skipped, and the script exits with a
    message if none remain. After the figure is written, ``link_latest`` is
    called on the output path and its result is logged.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("logs", nargs="+", help="log files, globs, or dirs")
    parser.add_argument("--out", type=Path, default=Path("out/figs/emergence.png"))
    opts = parser.parse_args()

    paths = _gather(opts.logs)
    grouped: dict[str, list[dict]] = defaultdict(list)
    for path in paths:
        run = parse_log(path)
        if run is not None:
            grouped[_env_mode(path)].append(run)
    if not grouped:
        raise SystemExit(f"no parseable runs in {len(paths)} files")

    for mode, runs in grouped.items():
        lengths = [len(run['steps']) for run in runs]
        logger.info(f"{mode:14s} {len(runs)} run(s), steps={lengths}")

    plot(grouped, opts.out)
    logger.info(f"docs/figs latest -> {link_latest(opts.out)}")
|
|
|
|
|
|
if __name__ == "__main__":
    # main() returns None, so a normal run still exits with status 0.
    raise SystemExit(main())
|