"""Per-step training-dynamics small multiples: vanilla vs static vs online erasure. Tufte small multiples. Columns = arm (vanilla / static G_hack erasure / online G_hack erasure); rows = metric group: row 0 hack_s + solve(gt_s) student reward-hack rate vs ground-truth solve row 1 cos_pre_t + cos_pre_s live-grad alignment with v_hack (teacher / student) Each panel overlays one thin line per seed and one bold mean line. The first step where the student starts hacking (hack_s > 0) is marked per seed with an open tick on the hack curve -- the onset point, which is where cos_pre_t starts to diverge from the (refreshed) v_hack. Data source: logs/*.log per-step rows (the durable source results.py also uses). We parse by HEADER NAME, not fixed index, because newer runs add columns (refr). Arm classification (from the preset line `arm=`, covering old --arm and new --intervention logs): vanilla arm=vanilla (intervention=none) static erasure arm=projected, no --vhack-refresh-every (frozen v_hack) online erasure arm=projected, --vhack-refresh-every=N>0 (re-extracted) routing arm=routing (intervention=route) For routing we plot the ABLATED-eval hack/solve (hack_abl/solve_abl, measured with delta_S_hack zeroed every --eval-ablate-every steps), NOT the training-time hack_s: the routed forward still hacks during training, so the training curve would falsely read "route doesn't work". The ablated curve is the deployment model. (none/erase plot training-time hack_s; their intervention acts at train time.) Usage: uv run python scripts/plot_dynamics.py logs/*converge*.log uv run python scripts/plot_dynamics.py logs/ # whole dir uv run python scripts/plot_dynamics.py A.log B.log --out out/dynamics.png Scales to 3 seeds x 3 arms: pass all 9 logs, they auto-group by (arm, seed). """ from __future__ import annotations import argparse import re from collections import defaultdict from pathlib import Path import matplotlib.pyplot as plt import numpy as np from loguru import logger # --- parse ----------------------------------------------------------------- # Series we plot, by cleaned header name. frac "7/28" -> 0.25; float "+0.264". RATE_COLS = {"hack_s": "hack", "gt_s": "solve"} COS_COLS = {"cos_pre_t": "teacher", "cos_pre_s": "student"} _HDR_TOK = re.compile(r"[A-Za-z_]+") # strip ↑↓? decorations: "hack_s?" -> "hack_s" def _val(tok: str) -> float | None: """Parse a per-step cell: frac n/d, signed float, or T/F/-/nan.""" if "/" in tok: a, b = tok.split("/") return int(a) / int(b) if int(b) else None if tok in ("T", "F", "-", "nan"): return None return float(tok) def parse_log(path: Path) -> dict | None: """Return {arm, refr, seed, vhack, steps: int[], : float[]} or None.""" txt = path.read_text(errors="replace") argv = next((l for l in txt.splitlines() if "argv:" in l), None) preset = next((l for l in txt.splitlines() if "preset=" in l and "arm=" in l), "") if argv is None: return None def grab(pat, s, default=None): ms = re.findall(pat, s) return ms[-1] if ms else default # arm = derived display name in the preset line (vanilla/projected/routing), # the one source that covers both old (--arm) and new (--intervention) logs. arm = grab(r"\barm=(\w+)", preset, "vanilla") refr = int(grab(r"--vhack-refresh-every=(\d+)", argv, "0")) seed = grab(r"seed=(\d+)", preset, "?") vhack = grab(r"v-hack-path=out/(\S+?)\.safetensors", argv, "-") # header line: the one containing both "step" and "hack_s" hdr = next((l for l in txt.splitlines() if "ref_eq" in l and "hack_s" in l), None) if hdr is None: return None names = [_HDR_TOK.match(t).group(0) for t in hdr.split("| INFO |", 1)[1].split()] idx = {n: i for i, n in enumerate(names)} series: dict[str, list[float]] = defaultdict(list) steps: list[int] = [] # Also parse the route ablated-eval columns when present (older logs lack # them -> skip). For routing we plot THESE, not the training-time hack_s. abl = {"hack_abl", "solve_abl"} & set(idx) wanted = {**RATE_COLS, **COS_COLS, **{c: c for c in abl}} for line in txt.splitlines(): if "| INFO |" not in line: continue row = line.split("| INFO |", 1)[1].split() if not row or not row[0].isdigit() or len(row) < len(names): continue steps.append(int(row[idx["step"]])) for col in wanted: series[col].append(_val(row[idx[col]])) if not steps: return None run = dict(arm=arm, refr=refr, seed=seed, vhack=vhack, steps=np.array(steps), **{k: np.array(v, dtype=float) for k, v in series.items()}) # COHERENCE-GAP FIX: route's training-time hack_s looks vanilla (the routed # forward still hacks); routing's benefit only shows once delta_S_hack is # ablated at eval. So for routing, plot the ablated series under the same # hack_s/gt_s keys -> all downstream (panels, onset, overlay) reads it. if arm == "routing" and "hack_abl" in run: run["hack_s"] = run["hack_abl"] run["gt_s"] = run["solve_abl"] return run def classify(run: dict) -> str: if run["arm"] == "vanilla": return "vanilla" if run["arm"] == "routing": return "routing" # arm == projected -> erasure, split by refresh return "online erasure" if run["refr"] > 0 else "static erasure" # --- plot ------------------------------------------------------------------ ARM_ORDER = ["vanilla", "static erasure", "online erasure", "routing"] # Distinct colour per series -- the two rows measure different things, so they # must not share a palette (hack != teacher-cos). Row 0: red hack vs green # solve. Row 1: blue teacher-cos vs amber student-cos. RATE_COLORS = {"hack_s": "#c1432b", "gt_s": "#2f7d4f"} COS_COLORS = {"cos_pre_t": "#33508c", "cos_pre_s": "#c98a2b"} # Arm colours for the single-panel hack overlay (arms, not series): grey vanilla # baseline -> amber static -> blue online, ordered by increasing intervention. # TODO(color): make this a quality-ordered red->green ramp instead of fixed # per-arm hues -- red = vanilla (worst, most hacking), green = best method # (anticipated gradient routing). As arms grow (static/online/grad-routing/ # confessions), assign colour by method rank along a perceptual RdYlGn ramp so # the reader sees "redder = hacks more" at a glance. ARM_COLORS = {"vanilla": "#7a7a7a", "static erasure": "#c98a2b", "online erasure": "#33508c", "routing": "#2f7d4f"} def _onset(steps: np.ndarray, hack: np.ndarray) -> int | None: """First step where RAW hack_s > 0 (the hack-onset point). Computed on the unsmoothed series -- EMA would blur the very step we want to mark.""" nz = np.flatnonzero(hack > 0) return int(steps[nz[0]]) if len(nz) else None def _ema(y: np.ndarray, span: int = 5) -> np.ndarray: """Causal EMA, span=5. Less lag than a trailing SMA(5) since it weights recent steps more. NaNs hold the previous smoothed value (don't reset it).""" a = 2.0 / (span + 1) out = np.empty_like(y) prev = np.nan for i, v in enumerate(y): if np.isnan(v): out[i] = prev else: prev = v if np.isnan(prev) else a * v + (1 - a) * prev out[i] = prev return out def _series_panel(ax, runs, cols, colors, ylim, label_series=False): """Overlay per-seed thin EMA lines + bold mean-of-EMA for each series.""" ends = [] # (endpoint_y, label, color) for direct labels for col, label in cols.items(): color = colors[col] stacked = [] for r in runs: ys = _ema(r[col]) ax.plot(r["steps"], ys, color=color, lw=0.7, alpha=0.35, solid_capstyle="round") stacked.append(ys) # mean over seeds of the smoothed series (runs share the step grid within an arm) L = min(len(y) for y in stacked) ym = np.nanmean(np.stack([y[:L] for y in stacked]), axis=0) xm = runs[0]["steps"][:L] ax.plot(xm, ym, color=color, lw=1.8, solid_capstyle="round") ends.append((ym[-1], xm[-1], label, color)) # Direct labels in the leftmost column only -- colour carries the series # across the row, so per-panel repeats are redundant ink. Nudge by the # ACTUAL endpoint ordering (higher line -> label up, lower -> down): the two # cos lines cross, so a fixed up/down stagger would land each label on the # wrong line. if label_series: ends.sort(key=lambda e: e[0]) # lowest endpoint first dy = {0: -6, len(ends) - 1: 6} if len(ends) > 1 else {0: 0} for rank, (y, x, label, color) in enumerate(ends): ax.annotate(label, (x, y), color=color, fontsize=8, xytext=(3, dy.get(rank, 0)), textcoords="offset points", va="center") if ylim: ax.set_ylim(*ylim) def plot(runs: list[dict], out: Path) -> None: by_arm: dict[str, list[dict]] = defaultdict(list) for r in runs: by_arm[classify(r)].append(r) arms = [a for a in ARM_ORDER if a in by_arm] if not arms: raise SystemExit("no runs classified into arms") fig, axes = plt.subplots(2, len(arms), figsize=(3.0 * len(arms), 4.4), sharex=True, sharey="row", squeeze=False) cos_lo = min(np.nanmin(r[c]) for r in runs for c in COS_COLS) for col, arm in enumerate(arms): rs = by_arm[arm] n_seed = len({r["seed"] for r in rs}) axes[0][col].set_title(f"{arm}\n(n={n_seed} seed{'s' if n_seed > 1 else ''})", fontsize=9) _series_panel(axes[0][col], rs, RATE_COLS, RATE_COLORS, ylim=(0, 1), label_series=(col == 0)) _series_panel(axes[1][col], rs, COS_COLS, COS_COLORS, ylim=(min(-0.05, cos_lo - 0.02), 0.45), label_series=(col == 0)) axes[1][col].axhline(0, color="0.8", lw=0.6, zorder=0) axes[1][col].set_xlabel("optimizer step") # Mean hack-onset: one dashed vertical reference line spanning BOTH rows # so the cos-divergence can be read against the moment hacking starts. onsets = [s for r in rs if (s := _onset(r["steps"], r["hack_s"])) is not None] if onsets: s0 = float(np.mean(onsets)) for row in (0, 1): axes[row][col].axvline(s0, color="0.55", lw=0.8, ls=(0, (4, 3)), zorder=0) axes[0][col].annotate("first hack", (s0, 1.0), color="0.4", fontsize=7, xytext=(2, -2), textcoords="offset points", va="top") axes[0][0].set_ylabel("student rate") axes[1][0].set_ylabel("cos(grad, v_hack)") # range-frame: drop top/right spines, keep ink on data for ax in axes.flat: ax.spines["top"].set_visible(False) ax.spines["right"].set_visible(False) ax.tick_params(labelsize=8) fig.suptitle("Training dynamics: G_hack erasure vs vanilla " "(EMA-5 smoothed; dashed line = mean hack onset)", fontsize=10) fig.tight_layout(rect=(0, 0, 1, 0.96)) out.parent.mkdir(parents=True, exist_ok=True) fig.savefig(out, dpi=150, bbox_inches="tight") logger.info(f"wrote {out} ({len(runs)} runs, arms={arms})") def plot_hack_overlay(runs: list[dict], out: Path) -> None: """Single panel: mean hack_s per arm overlaid, the headline arm-vs-arm view. Faint per-seed lines for spread, bold EMA mean per arm, onset tick at the arm's mean hack-onset, direct-labelled (no legend).""" by_arm: dict[str, list[dict]] = defaultdict(list) for r in runs: by_arm[classify(r)].append(r) arms = [a for a in ARM_ORDER if a in by_arm] fig, ax = plt.subplots(figsize=(5.2, 3.4)) for arm in arms: rs = by_arm[arm] color = ARM_COLORS[arm] stacked = [] for r in rs: ys = _ema(r["hack_s"]) ax.plot(r["steps"], ys, color=color, lw=0.6, alpha=0.25, solid_capstyle="round") stacked.append(ys) L = min(len(y) for y in stacked) ym = np.nanmean(np.stack([y[:L] for y in stacked]), axis=0) xm = rs[0]["steps"][:L] ax.plot(xm, ym, color=color, lw=2.0, solid_capstyle="round") onsets = [s for r in rs if (s := _onset(r["steps"], r["hack_s"])) is not None] if onsets: s0 = float(np.mean(onsets)) ax.plot(s0, np.interp(s0, xm, ym), marker="o", ms=4, color=color, zorder=3) ax.annotate(arm, (xm[-1], ym[-1]), color=color, fontsize=8, xytext=(4, 0), textcoords="offset points", va="center") ax.set_ylim(0, 1) ax.set_xlabel("optimizer step") ax.set_ylabel("student hack rate (hack_s)") ax.spines["top"].set_visible(False) ax.spines["right"].set_visible(False) ax.tick_params(labelsize=8) ax.set_title("Student hack rate by arm (EMA-5; dot = mean onset)", fontsize=10) fig.tight_layout() out.parent.mkdir(parents=True, exist_ok=True) fig.savefig(out, dpi=150, bbox_inches="tight") logger.info(f"wrote {out}") # --- cli ------------------------------------------------------------------- def _gather(paths: list[str]) -> list[Path]: out: list[Path] = [] for p in paths: pp = Path(p) if pp.is_dir(): out += sorted(pp.glob("*.log")) elif any(c in p for c in "*?["): out += sorted(Path().glob(p)) else: out.append(pp) return out def main() -> None: ap = argparse.ArgumentParser(description=__doc__) ap.add_argument("logs", nargs="+", help="log files, globs, or dirs") ap.add_argument("--out", type=Path, default=Path("out/dynamics.png")) args = ap.parse_args() files = _gather(args.logs) runs = [r for f in files if (r := parse_log(f))] if not runs: raise SystemExit(f"no parseable runs in {len(files)} files") for r in runs: logger.info(f"{classify(r):16s} seed={r['seed']} steps={len(r['steps'])} {r['vhack']}") plot(runs, args.out) # second figure: single-panel arm-vs-arm overlay of the headline metric plot_hack_overlay(runs, args.out.with_name(args.out.stem + "_hack_overlay.png")) if __name__ == "__main__": main()