Files
evil_MoE/scripts/plot_dynamics.py
T
wassname ff82fbb940 plot_dynamics: per-step deploy curve from hk_abl + routing2 arm
The routing arms' benefit shows on the DEPLOYED model (quarantine deleted).
Prefer the dense per-step proxy hk_abl/slv_abl (every step, rollout_ablate_frac>0)
over the sparse held-out hk_dep eval for the plotted hack_s/gt_s curve; fall back
to hk_dep for runs that predate the proxy.

- parse hk_abl/slv_abl; routing+routing2 substitute it (else hk_dep) into hack_s/gt_s
- classify/ARM_ORDER/ARM_COLORS recognise routing2
- gate cos cols (cin_t/cin_s) by presence: vanilla/routing2 lack them, so parse
  and panels skip them instead of KeyError (also fixes a pre-existing vanilla crash)

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
2026-06-01 06:25:04 +00:00

355 lines
16 KiB
Python

"""Per-step training-dynamics small multiples: vanilla vs G_hack erasure vs routing.
Tufte small multiples. Columns = arm (vanilla / static G_hack erasure /
online G_hack erasure / routing / routing2; only arms present in the logs);
rows = metric group:
row 0 hack_s + solve(gt_s) student reward-hack rate vs ground-truth solve
row 1 cin_t + cin_s live-grad alignment with v_hack (teacher / student);
empty for arms that do not log these columns (vanilla, routing2)
Each panel overlays one thin line per seed and one bold mean line (all EMA-5
smoothed). The mean over seeds of the first step where the student starts
hacking (raw hack_s > 0) is marked by a dashed vertical line spanning both
rows -- the onset point, which is where cin_t starts to diverge from the
(refreshed) v_hack.
Data source: logs/*.log per-step rows (the durable source results.py also uses).
We parse by HEADER NAME, not fixed index, because newer runs add columns (refr).
Arm classification (from the preset line `arm=`, covering old --arm and new
--intervention logs):
vanilla arm=vanilla (intervention=none)
static erasure arm=projected, no --vhack-refresh-every (frozen v_hack)
online erasure arm=projected, --vhack-refresh-every=N>0 (re-extracted)
routing arm=routing (intervention=route)
routing2 arm=routing2
For routing/routing2 we plot the DEPLOYED model's hack/solve (quarantine knob
deleted), NOT the training-time hack_s: the routed forward still hacks during
training, so the training curve would falsely read "route doesn't work". We
prefer the dense per-step proxy hk_abl/slv_abl (ablated rollout slice, needs
rollout_ablate_frac>0) and fall back to the sparse held-out eval hk_dep/slv_dep
(measured every --eval-ablate-every steps) for runs that predate the proxy.
(none/erase plot training-time hack_s; their intervention acts at train
time.)
Usage:
uv run python scripts/plot_dynamics.py logs/*converge*.log
uv run python scripts/plot_dynamics.py logs/ # whole dir
uv run python scripts/plot_dynamics.py A.log B.log --out out/dynamics.png
Scales to any number of seeds x arms (e.g. 3 seeds x 5 arms: pass all 15 logs);
runs auto-group by arm, with one thin line per seed.
"""
from __future__ import annotations

import argparse
import glob
import re
from collections import defaultdict
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from loguru import logger

from projected_grpo.figs import link_latest
# --- parse -----------------------------------------------------------------
# Series we plot, by cleaned header name. frac "7/28" -> 0.25; float "+0.264".
RATE_COLS = {"hack_s": "hack", "gt_s": "solve"}
# Current streaming-table display headers (StepLogger _Col.header): the live-grad
# v_hack alignment prints as cin_t/cin_s, the route deploy-eval as hk_dep/slv_dep.
COS_COLS = {"cin_t": "teacher", "cin_s": "student"}
_HDR_TOK = re.compile(r"[A-Za-z_]+") # strip ↑↓? decorations: "hack_s?" -> "hack_s"
def _val(tok: str) -> float | None:
"""Parse a per-step cell: frac n/d, signed float, or T/F/-/nan."""
if "/" in tok:
a, b = tok.split("/")
return int(a) / int(b) if int(b) else None
if tok in ("T", "F", "-", "nan"):
return None
return float(tok)
def parse_log(path: Path) -> dict | None:
    """Parse one training log into a run dict, or None if it is not usable.

    Returns ``{arm, refr, seed, vhack, steps: int[], <series>: float[]}``.
    Every ``<series>`` array is aligned with ``steps``, and placeholder cells
    (T/F/-/nan, n/0) are NaN. The series are the RATE_COLS/COS_COLS columns
    and the route deploy columns that this log actually has.

    Returns None when the log has no ``argv:`` line, no per-step table header
    (an INFO line naming both ``ref_eq`` and ``hack_s``), no ``step`` column,
    or no data rows.

    For routing/routing2 runs, ``hack_s``/``gt_s`` are replaced by the
    deployed-model series (see the COHERENCE-GAP comment below).
    """
    lines = path.read_text(errors="replace").splitlines()  # split once, scan many times
    argv = next((l for l in lines if "argv:" in l), None)
    if argv is None:
        return None
    preset = next((l for l in lines if "preset=" in l and "arm=" in l), "")

    def grab(pat, s, default=None):
        """Return the last capture of regex ``pat`` in ``s``, else ``default``."""
        ms = re.findall(pat, s)
        return ms[-1] if ms else default

    # arm = derived display name in the preset line (vanilla/projected/routing...),
    # the one source that covers both old (--arm) and new (--intervention) logs.
    arm = grab(r"\barm=(\w+)", preset, "vanilla")
    refr = int(grab(r"--vhack-refresh-every=(\d+)", argv, "0"))
    seed = grab(r"seed=(\d+)", preset, "?")
    vhack = grab(r"v-hack-path=out/(?:vhack/)?(\S+?)\.safetensors", argv, "-")

    # Per-step table header: the INFO line containing both "ref_eq" and "hack_s".
    # Requiring the INFO marker guarantees the split below finds it.
    hdr = next((l for l in lines
                if "| INFO |" in l and "ref_eq" in l and "hack_s" in l), None)
    if hdr is None:
        return None
    # Clean each header token ("hack_s?" -> "hack_s"). A token with no leading
    # letter/underscore is kept verbatim so column positions stay aligned.
    names = []
    for tok in hdr.split("| INFO |", 1)[1].split():
        m = _HDR_TOK.match(tok)
        names.append(m.group(0) if m else tok)
    idx = {n: i for i, n in enumerate(names)}
    if "step" not in idx:
        return None

    # Route DEPLOY-eval columns, when present (non-route logs lack them).
    # hk_abl/slv_abl = the FREE per-step deploy proxy (ablated rollout slice,
    # rollout_ablate_frac>0); hk_dep/slv_dep = the held-out greedy eval, only on
    # eval_ablate_every steps.
    deploy = sorted({"hk_dep", "slv_dep", "hk_abl", "slv_abl"} & set(idx))
    # Only parse columns this log actually has: non-projecting arms (vanilla,
    # routing2) lack cin_t/cin_s, so gate by presence rather than KeyError.
    wanted = [c for c in (*RATE_COLS, *COS_COLS) if c in idx] + deploy
    col_pos = [(c, idx[c]) for c in wanted]
    step_pos = idx["step"]

    steps: list[int] = []
    series: dict[str, list[float | None]] = {c: [] for c in wanted}
    for line in lines:
        if "| INFO |" not in line:
            continue
        row = line.split("| INFO |", 1)[1].split()
        # data rows start with an integer step and are at least as wide as the header
        if not row or not row[0].isdigit() or len(row) < len(names):
            continue
        steps.append(int(row[step_pos]))
        for col, pos in col_pos:
            series[col].append(_val(row[pos]))
    if not steps:
        return None
    run = dict(arm=arm, refr=refr, seed=seed, vhack=vhack, steps=np.array(steps),
               **{k: np.array(v, dtype=float) for k, v in series.items()})
    # COHERENCE-GAP FIX: routing's training-time hack_s looks vanilla (the routed
    # forward still hacks); the benefit only shows on the DEPLOYED model
    # (quarantine knob deleted). So for routing/routing2, plot the deploy series
    # under the hack_s/gt_s keys, and everything downstream (panels, onset,
    # overlay) reads it.
    # Prefer the DENSE per-step proxy (hk_abl) over the sparse held-out eval
    # (hk_dep). Fall back to hk_dep for older runs that predate the proxy, or
    # when the proxy column exists but holds no finite value (otherwise it would
    # plot as an empty curve). A pair is used only if both its hack and solve
    # columns exist.
    if arm in ("routing", "routing2"):
        pairs = [(hk, slv) for hk, slv in (("hk_abl", "slv_abl"), ("hk_dep", "slv_dep"))
                 if hk in run and slv in run]
        with_data = [p for p in pairs if np.isfinite(run[p[0]]).any()]
        pick = next(iter(with_data or pairs), None)
        if pick is None:
            logger.warning(f"{path.name}: {arm} log has no deploy columns; "
                           "plotting training-time hack_s")
        else:
            hk, slv = pick
            run["hack_s"], run["gt_s"] = run[hk], run[slv]
    return run
def classify(run: dict) -> str:
    """Map a parsed run to its display arm (a member of ARM_ORDER).

    vanilla / routing / routing2 pass straight through. Every other arm is
    treated as projected (G_hack erasure) and split by whether v_hack was
    re-extracted during training (refr > 0 -> online, else static).
    """
    arm = run["arm"]
    if arm in ("vanilla", "routing", "routing2"):
        return arm
    # arm == projected -> erasure, split by refresh
    if run["refr"] > 0:
        return "online erasure"
    return "static erasure"
# --- plot ------------------------------------------------------------------
# Column order of the small multiples; arms absent from the logs are dropped.
ARM_ORDER = ["vanilla", "static erasure", "online erasure", "routing", "routing2"]
# Per-SERIES colours. The two rows measure different things, so they must not
# share a palette (hack != teacher-cos). Row 0: red hack vs green solve.
# Row 1: blue teacher-cos vs amber student-cos.
RATE_COLORS = dict(hack_s="#c1432b", gt_s="#2f7d4f")
COS_COLORS = dict(cin_t="#33508c", cin_s="#c98a2b")
# Per-ARM colours for the single-panel hack overlay, listed in ARM_ORDER order:
# grey vanilla baseline, amber static, blue online erasure, green routing,
# purple routing2.
# TODO(color): make this a quality-ordered red->green ramp instead of fixed
# per-arm hues -- red = vanilla (worst, most hacking), green = best method
# (anticipated gradient routing). As arms grow (static/online/grad-routing/
# confessions), assign colour by method rank along a perceptual RdYlGn ramp so
# the reader sees "redder = hacks more" at a glance.
ARM_COLORS = dict(zip(ARM_ORDER, ("#7a7a7a", "#c98a2b", "#33508c", "#2f7d4f", "#7d2f6f")))
def _onset(steps: np.ndarray, hack: np.ndarray) -> int | None:
"""First step where RAW hack_s > 0 (the hack-onset point). Computed on the
unsmoothed series -- EMA would blur the very step we want to mark."""
nz = np.flatnonzero(hack > 0)
return int(steps[nz[0]]) if len(nz) else None
def _ema(y: np.ndarray, span: int = 5) -> np.ndarray:
"""Causal EMA, span=5. Less lag than a trailing SMA(5) since it weights
recent steps more. NaNs hold the previous smoothed value (don't reset it)."""
a = 2.0 / (span + 1)
out = np.empty_like(y)
prev = np.nan
for i, v in enumerate(y):
if np.isnan(v):
out[i] = prev
else:
prev = v if np.isnan(prev) else a * v + (1 - a) * prev
out[i] = prev
return out
def _series_panel(ax, runs, cols, colors, ylim, label_series=False):
    """Draw one small-multiple panel: per-seed thin EMA lines + bold mean-of-EMA.

    ax: matplotlib Axes to draw on.
    runs: parsed run dicts for ONE arm (one thin line per run).
    cols: {series key: direct-label text}; series missing from every run are
        skipped, and runs lacking a series are left out of that series' mean.
    colors: {series key: line colour}.
    ylim: (lo, hi) y-limits, or a falsy value to keep autoscaling.
    label_series: direct-label each series at its mean line's endpoint
        (the caller enables this for the leftmost column only).
    """
    ends = []  # (endpoint_y, endpoint_x, label, color) for direct labels
    for col, label in cols.items():
        color = colors[col]
        present = [r for r in runs if col in r]
        if not present:  # arm lacks this series (e.g. no cos cols for routing2/vanilla)
            continue
        stacked = []
        for r in present:
            ys = _ema(r[col])
            ax.plot(r["steps"], ys, color=color, lw=0.7, alpha=0.35, solid_capstyle="round")
            stacked.append(ys)
        # Mean over seeds of the smoothed series, truncated to the shortest run
        # (runs share the step grid within an arm). Take x from a run that HAS
        # this series: runs[0] may lack it and be shorter than L, which would
        # give x and y different lengths.
        L = min(len(y) for y in stacked)
        ym = np.nanmean(np.stack([y[:L] for y in stacked]), axis=0)
        xm = present[0]["steps"][:L]
        ax.plot(xm, ym, color=color, lw=1.8, solid_capstyle="round")
        ends.append((ym[-1], xm[-1], label, color))
    # Direct labels in the leftmost column only -- colour carries the series
    # across the row, so per-panel repeats are redundant ink. Nudge by the
    # ACTUAL endpoint ordering (higher line -> label up, lower -> down): the two
    # cos lines cross, so a fixed up/down stagger would land each label on the
    # wrong line.
    if label_series:
        ends.sort(key=lambda e: e[0])  # lowest endpoint first
        dy = {0: -6, len(ends) - 1: 6} if len(ends) > 1 else {0: 0}
        for rank, (y, x, text, color) in enumerate(ends):
            ax.annotate(text, (x, y), color=color, fontsize=8,
                        xytext=(3, dy.get(rank, 0)), textcoords="offset points", va="center")
    if ylim:
        ax.set_ylim(*ylim)
def plot(runs: list[dict], out: Path) -> None:
    """Write the small-multiples figure to ``out``.

    There is one column per arm present in ``runs`` (in ARM_ORDER order) and two rows:
    row 0 = hack/solve rates on [0, 1], row 1 = cos(grad, v_hack) on a range
    shared across columns. A dashed vertical line in both rows marks the arm's
    mean hack onset. Creates ``out``'s parent directory and closes the figure
    after saving.

    Raises SystemExit if no run classifies into an arm of ARM_ORDER.
    """
    by_arm: dict[str, list[dict]] = defaultdict(list)
    for r in runs:
        by_arm[classify(r)].append(r)
    arms = [a for a in ARM_ORDER if a in by_arm]
    if not arms:
        raise SystemExit("no runs classified into arms")
    fig, axes = plt.subplots(2, len(arms), figsize=(3.0 * len(arms), 4.4),
                             sharex=True, sharey="row", squeeze=False)
    # Shared cos-row range: the default [-0.05, 0.45] window, widened (with a
    # 0.02 margin) to fit every finite cos value. NaN cells are dropped so an
    # all-NaN series cannot turn the limits into NaN (matplotlib rejects NaN
    # limits), and the top adapts too, so values above 0.45 are not clipped.
    finite_cos = [r[c][np.isfinite(r[c])] for r in runs for c in COS_COLS if c in r]
    cos_all = np.concatenate(finite_cos) if finite_cos else np.empty(0)
    cos_lo = float(cos_all.min()) if cos_all.size else 0.0
    cos_hi = float(cos_all.max()) if cos_all.size else 0.0
    cos_ylim = (min(-0.05, cos_lo - 0.02), max(0.45, cos_hi + 0.02))
    for col, arm in enumerate(arms):
        rs = by_arm[arm]
        top, bottom = axes[0][col], axes[1][col]
        n_seed = len({r["seed"] for r in rs})
        top.set_title(f"{arm}\n(n={n_seed} seed{'s' if n_seed > 1 else ''})",
                      fontsize=9)
        _series_panel(top, rs, RATE_COLS, RATE_COLORS, ylim=(0, 1),
                      label_series=(col == 0))
        _series_panel(bottom, rs, COS_COLS, COS_COLORS, ylim=cos_ylim,
                      label_series=(col == 0))
        bottom.axhline(0, color="0.8", lw=0.6, zorder=0)
        bottom.set_xlabel("optimizer step")
        # Mean hack-onset: one dashed vertical reference line spanning BOTH rows
        # so the cos-divergence can be read against the moment hacking starts.
        onsets = [s for r in rs if (s := _onset(r["steps"], r["hack_s"])) is not None]
        if onsets:
            s0 = float(np.mean(onsets))
            for ax in (top, bottom):
                ax.axvline(s0, color="0.55", lw=0.8, ls=(0, (4, 3)), zorder=0)
            top.annotate("first hack", (s0, 1.0), color="0.4", fontsize=7,
                         xytext=(2, -2), textcoords="offset points", va="top")
    axes[0][0].set_ylabel("student rate")
    axes[1][0].set_ylabel("cos(grad, v_hack)")
    # range-frame: drop top/right spines, keep ink on data
    for ax in axes.flat:
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        ax.tick_params(labelsize=8)
    fig.suptitle("Training dynamics: G_hack erasure vs vanilla "
                 "(EMA-5 smoothed; dashed line = mean hack onset)", fontsize=10)
    fig.tight_layout(rect=(0, 0, 1, 0.96))
    out.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out, dpi=150, bbox_inches="tight")
    plt.close(fig)  # release the figure; pyplot otherwise keeps it alive
    logger.info(f"wrote {out} ({len(runs)} runs, arms={arms})")
def plot_hack_overlay(runs: list[dict], out: Path) -> None:
    """Single panel: mean hack_s per arm overlaid, the headline arm-vs-arm view.

    Faint per-seed EMA lines show the spread, and a bold line shows the
    mean-of-EMA per arm. A dot marks the arm's mean hack onset on that line,
    and each arm is direct-labelled at its endpoint (no legend). For routing
    arms, ``hack_s`` holds whatever parse_log substituted (the deploy series
    when logged). Creates ``out``'s parent directory and closes the figure
    after saving.
    """
    by_arm: dict[str, list[dict]] = defaultdict(list)
    for r in runs:
        by_arm[classify(r)].append(r)
    arms = [a for a in ARM_ORDER if a in by_arm]
    fig, ax = plt.subplots(figsize=(5.2, 3.4))
    for arm in arms:
        rs = by_arm[arm]
        color = ARM_COLORS[arm]
        stacked = []
        for r in rs:
            ys = _ema(r["hack_s"])
            ax.plot(r["steps"], ys, color=color, lw=0.6, alpha=0.25, solid_capstyle="round")
            stacked.append(ys)
        L = min(len(y) for y in stacked)
        ym = np.nanmean(np.stack([y[:L] for y in stacked]), axis=0)
        xm = rs[0]["steps"][:L]
        ax.plot(xm, ym, color=color, lw=2.0, solid_capstyle="round")
        onsets = [s for r in rs if (s := _onset(r["steps"], r["hack_s"])) is not None]
        finite = np.isfinite(ym)
        if onsets and finite.any():
            s0 = float(np.mean(onsets))
            # Interpolate on finite points only: a NaN neighbour (e.g. before a
            # sparse deploy eval's first sample) would make np.interp return NaN
            # and the dot would silently disappear.
            ax.plot(s0, np.interp(s0, xm[finite], ym[finite]),
                    marker="o", ms=4, color=color, zorder=3)
        ax.annotate(arm, (xm[-1], ym[-1]), color=color, fontsize=8,
                    xytext=(4, 0), textcoords="offset points", va="center")
    ax.set_ylim(0, 1)
    ax.set_xlabel("optimizer step")
    ax.set_ylabel("student hack rate (hack_s)")
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.tick_params(labelsize=8)
    ax.set_title("Student hack rate by arm (EMA-5; dot = mean onset)", fontsize=10)
    fig.tight_layout()
    out.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out, dpi=150, bbox_inches="tight")
    plt.close(fig)  # release the figure; pyplot otherwise keeps it alive
    logger.info(f"wrote {out}")
# --- cli -------------------------------------------------------------------
def _gather(paths: list[str]) -> list[Path]:
out: list[Path] = []
for p in paths:
pp = Path(p)
if pp.is_dir():
out += sorted(pp.glob("*.log"))
elif any(c in p for c in "*?["):
out += sorted(Path().glob(p))
else:
out.append(pp)
return out
def main() -> None:
    """CLI entry: parse the given logs and write both figures.

    Writes the small-multiples figure to --out and a single-panel hack overlay
    next to it as ``<stem>_hack_overlay<suffix>``. The overlay uses the same
    image format as --out (.png when --out has no suffix), then points the
    docs/figs "latest" links at both files.
    """
    ap = argparse.ArgumentParser(description=__doc__)
    ap.add_argument("logs", nargs="+", help="log files, globs, or dirs")
    ap.add_argument("--out", type=Path, default=Path("out/figs/dynamics.png"))
    args = ap.parse_args()
    files = _gather(args.logs)
    runs = [r for f in files if (r := parse_log(f))]
    if not runs:
        raise SystemExit(f"no parseable runs in {len(files)} files")
    for r in runs:
        logger.info(f"{classify(r):16s} seed={r['seed']} steps={len(r['steps'])} {r['vhack']}")
    args.out.parent.mkdir(parents=True, exist_ok=True)
    plot(runs, args.out)
    # second figure: single-panel arm-vs-arm overlay of the headline metric,
    # in the same format as --out (previously always .png)
    suffix = args.out.suffix or ".png"
    overlay = args.out.with_name(f"{args.out.stem}_hack_overlay{suffix}")
    plot_hack_overlay(runs, overlay)
    for p in (args.out, overlay):
        logger.info(f"docs/figs latest -> {link_latest(p)}")


if __name__ == "__main__":
    main()