mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-28 02:45:47 +08:00
4621488cc0
Code writes+reads the new scheme; migrate_out_dirs.py moved 225 loose artifacts (0 left at top level). Per-run checkpoints+rollouts now group under runs/<ts>_<run_id>/ as train.safetensors/rollouts.jsonl. Figures land in out/figs/ with a stable docs/figs/<name>.png symlink (figs.link_latest). justfile also gains run-cell REFRESH param (online-erasure arm). Smoke + smoke-vanilla + results all green on new paths. Requeue manifest preserves the why/resolve labels that pueue reset wiped. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
336 lines
15 KiB
Python
336 lines
15 KiB
Python
"""Per-step training-dynamics small multiples: vanilla vs static vs online erasure.
|
|
|
|
Tufte small multiples. Columns = arm (vanilla / static G_hack erasure /
|
|
online G_hack erasure); rows = metric group:
|
|
row 0 hack_s + solve(gt_s) student reward-hack rate vs ground-truth solve
|
|
row 1 cos_pre_t + cos_pre_s live-grad alignment with v_hack (teacher / student)
|
|
|
|
Each panel overlays one thin line per seed and one bold mean line. The first
|
|
step where the student starts hacking (hack_s > 0) is marked per seed with an
|
|
open tick on the hack curve -- the onset point, which is where cos_pre_t starts
|
|
to diverge from the (refreshed) v_hack.
|
|
|
|
Data source: logs/*.log per-step rows (the durable source results.py also uses).
|
|
We parse by HEADER NAME, not fixed index, because newer runs add columns (refr).
|
|
|
|
Arm classification (from the preset line `arm=`, covering old --arm and new
|
|
--intervention logs):
|
|
vanilla arm=vanilla (intervention=none)
|
|
static erasure arm=projected, no --vhack-refresh-every (frozen v_hack)
|
|
online erasure arm=projected, --vhack-refresh-every=N>0 (re-extracted)
|
|
routing arm=routing (intervention=route)
|
|
|
|
For routing we plot the SHIP-eval hack/solve (hack_ship/solve_ship, the deployed
|
|
model = quarantine knob deleted, measured every --eval-ablate-every steps), NOT
|
|
the training-time hack_s: the routed forward still hacks during training, so the training curve
|
|
would falsely read "route doesn't work". The ablated curve is the deployment
|
|
model. (none/erase plot training-time hack_s; their intervention acts at train
|
|
time.)
|
|
|
|
Usage:
|
|
uv run python scripts/plot_dynamics.py logs/*converge*.log
|
|
uv run python scripts/plot_dynamics.py logs/ # whole dir
|
|
uv run python scripts/plot_dynamics.py A.log B.log --out out/dynamics.png
|
|
|
|
Scales to 3 seeds x 3 arms: pass all 9 logs, they auto-group by (arm, seed).
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import re
|
|
from collections import defaultdict
|
|
from pathlib import Path
|
|
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
from loguru import logger
|
|
|
|
from projected_grpo.figs import link_latest
|
|
|
|
# --- parse -----------------------------------------------------------------
|
|
|
|
# Series we plot, by cleaned header name. frac "7/28" -> 0.25; float "+0.264".
|
|
RATE_COLS = {"hack_s": "hack", "gt_s": "solve"}
|
|
COS_COLS = {"cos_pre_t": "teacher", "cos_pre_s": "student"}
|
|
_HDR_TOK = re.compile(r"[A-Za-z_]+") # strip ↑↓? decorations: "hack_s?" -> "hack_s"
|
|
|
|
|
|
def _val(tok: str) -> float | None:
|
|
"""Parse a per-step cell: frac n/d, signed float, or T/F/-/nan."""
|
|
if "/" in tok:
|
|
a, b = tok.split("/")
|
|
return int(a) / int(b) if int(b) else None
|
|
if tok in ("T", "F", "-", "nan"):
|
|
return None
|
|
return float(tok)
|
|
|
|
|
|
def parse_log(path: Path) -> dict | None:
    """Parse one training log into per-step series.

    Returns ``{arm, refr, seed, vhack, steps: int[], <series>: float[]}``, or
    None when the log has no ``argv:`` line, no per-step header, a header that
    lacks one of the plotted columns, or no per-step rows.

    Columns are located by HEADER NAME, not fixed index, because newer runs
    add columns (e.g. ``refr``). For the routing arm the SHIP-eval series (the
    deployed model) replace ``hack_s``/``gt_s``; see the note at the end.
    """
    txt = path.read_text(errors="replace")
    lines = txt.splitlines()  # split once; scanned several times below
    argv = next((ln for ln in lines if "argv:" in ln), None)
    if argv is None:
        return None
    preset = next((ln for ln in lines if "preset=" in ln and "arm=" in ln), "")

    def grab(pat, s, default=None):
        # last match wins (a later occurrence overrides an earlier one)
        ms = re.findall(pat, s)
        return ms[-1] if ms else default

    # arm = derived display name in the preset line (vanilla/projected/routing),
    # the one source that covers both old (--arm) and new (--intervention) logs.
    arm = grab(r"\barm=(\w+)", preset, "vanilla")
    # NOTE(review): only the "--flag=N" spelling is recognised in the argv line.
    refr = int(grab(r"--vhack-refresh-every=(\d+)", argv, "0"))
    seed = grab(r"seed=(\d+)", preset, "?")
    vhack = grab(r"v-hack-path=out/(?:vhack/)?(\S+?)\.safetensors", argv, "-")

    # Per-step header: the first INFO line containing both "ref_eq" and "hack_s".
    hdr = next((ln for ln in lines
                if "ref_eq" in ln and "hack_s" in ln and "| INFO |" in ln), None)
    if hdr is None:
        return None
    # One name per header token so names[i] lines up with row[i]. search (not
    # match) also strips leading decorations; a token with no [A-Za-z_] run
    # keeps its raw text instead of crashing.
    names = []
    for tok in hdr.split("| INFO |", 1)[1].split():
        m = _HDR_TOK.search(tok)
        names.append(m.group(0) if m else tok)
    idx = {n: i for i, n in enumerate(names)}

    missing = [c for c in ("step", *RATE_COLS, *COS_COLS) if c not in idx]
    if missing:
        logger.warning(f"{path}: header lacks {missing}; skipping")
        return None

    # Also parse the route SHIP-eval columns when present (older logs lack them
    # -> skip). For routing we plot THESE (deployed model), not training-time
    # hack_s. Renamed hack_abl/solve_abl -> hack_ship/solve_ship 2026-05-30;
    # accept both so old evidence logs still parse.
    ship = {"hack_abl", "solve_abl", "hack_ship", "solve_ship"} & set(idx)
    wanted = [*RATE_COLS, *COS_COLS, *sorted(ship)]

    series: dict[str, list[float | None]] = defaultdict(list)
    steps: list[int] = []
    for line in lines:
        if "| INFO |" not in line:
            continue
        row = line.split("| INFO |", 1)[1].split()
        # per-step rows start with an integer and are at least header-wide
        if not row or not row[0].isdigit() or len(row) < len(names):
            continue
        steps.append(int(row[idx["step"]]))
        for col in wanted:
            series[col].append(_val(row[idx[col]]))
    if not steps:
        return None

    # None cells become NaN via dtype=float.
    run = dict(arm=arm, refr=refr, seed=seed, vhack=vhack,
               steps=np.array(steps),
               **{k: np.array(v, dtype=float) for k, v in series.items()})

    # COHERENCE-GAP FIX: route's training-time hack_s looks vanilla (the routed
    # forward still hacks); routing's benefit only shows on the DEPLOYED model
    # (quarantine knob deleted). So for routing, plot the ship series under the
    # hack_s/gt_s keys -> all downstream (panels, onset, overlay) reads it.
    if arm == "routing":
        hk = next((k for k in ("hack_ship", "hack_abl") if k in run), None)
        sk = next((k for k in ("solve_ship", "solve_abl") if k in run), None)
        if hk and sk:
            run["hack_s"], run["gt_s"] = run[hk], run[sk]
        else:
            logger.warning(f"{path}: routing log has no SHIP-eval columns; "
                           "plotting training-time hack_s/gt_s instead")
    return run
|
|
|
|
|
|
def classify(run: dict) -> str:
    """Map a parsed run to its display arm.

    vanilla -> "vanilla", routing -> "routing", and projected -> either
    "static erasure" (refr == 0, frozen v_hack) or "online erasure"
    (refr > 0, re-extracted). Any other arm name is returned unchanged rather
    than being mislabelled as erasure; callers that filter by ARM_ORDER will
    leave such runs out.
    """
    arm = run["arm"]
    if arm == "projected":
        return "online erasure" if run["refr"] > 0 else "static erasure"
    return arm
|
|
|
|
|
|
# --- plot ------------------------------------------------------------------
|
|
|
|
# Column order of the small-multiples figure; also the draw order of the overlay.
ARM_ORDER = ["vanilla", "static erasure", "online erasure", "routing"]
# One colour per SERIES -- the two rows measure different things, so they must
# not share a palette (hack != teacher-cos). Row 0: red hack vs green solve.
# Row 1: blue teacher-cos vs amber student-cos.
RATE_COLORS = dict(hack_s="#c1432b", gt_s="#2f7d4f")
COS_COLORS = dict(cos_pre_t="#33508c", cos_pre_s="#c98a2b")
# One colour per ARM for the single-panel hack overlay, paired in ARM_ORDER:
# grey vanilla baseline, amber static, blue online, green routing.
# TODO(color): make this a quality-ordered red->green ramp instead of fixed
# per-arm hues -- red = vanilla (worst, most hacking), green = best method
# (anticipated gradient routing). As arms grow (static/online/grad-routing/
# confessions), assign colour by method rank along a perceptual RdYlGn ramp so
# the reader sees "redder = hacks more" at a glance.
ARM_COLORS = dict(zip(ARM_ORDER, ("#7a7a7a", "#c98a2b", "#33508c", "#2f7d4f")))
|
|
|
|
|
|
def _onset(steps: np.ndarray, hack: np.ndarray) -> int | None:
|
|
"""First step where RAW hack_s > 0 (the hack-onset point). Computed on the
|
|
unsmoothed series -- EMA would blur the very step we want to mark."""
|
|
nz = np.flatnonzero(hack > 0)
|
|
return int(steps[nz[0]]) if len(nz) else None
|
|
|
|
|
|
def _ema(y: np.ndarray, span: int = 5) -> np.ndarray:
|
|
"""Causal EMA, span=5. Less lag than a trailing SMA(5) since it weights
|
|
recent steps more. NaNs hold the previous smoothed value (don't reset it)."""
|
|
a = 2.0 / (span + 1)
|
|
out = np.empty_like(y)
|
|
prev = np.nan
|
|
for i, v in enumerate(y):
|
|
if np.isnan(v):
|
|
out[i] = prev
|
|
else:
|
|
prev = v if np.isnan(prev) else a * v + (1 - a) * prev
|
|
out[i] = prev
|
|
return out
|
|
|
|
|
|
def _series_panel(ax, runs, cols, colors, ylim, label_series=False):
    """Draw every series in ``cols`` on ``ax``: faint EMA per run, bold mean.

    ``cols`` maps run key -> direct-label text and ``colors`` maps run key ->
    line colour. The bold line is the NaN-mean across runs of the EMA-smoothed
    series, truncated to the shortest run and drawn against the first run's
    steps (runs in one arm are assumed to share a step grid). When
    ``label_series`` is set the bold lines get direct labels at their right
    end. ``ylim`` is applied only when truthy.
    """
    endpoints = []  # (last mean y, last step, label text, colour)
    for key, text in cols.items():
        shade = colors[key]
        smoothed = [_ema(run[key]) for run in runs]
        for run, ys in zip(runs, smoothed):
            ax.plot(run["steps"], ys, color=shade, lw=0.7, alpha=0.35, solid_capstyle="round")
        n = min(map(len, smoothed))
        mean_y = np.nanmean(np.stack([ys[:n] for ys in smoothed]), axis=0)
        mean_x = runs[0]["steps"][:n]
        ax.plot(mean_x, mean_y, color=shade, lw=1.8, solid_capstyle="round")
        endpoints.append((mean_y[-1], mean_x[-1], text, shade))

    # Labels go in the leftmost column only: colour already carries each series
    # across the row. Offsets follow the ACTUAL endpoint order (highest line
    # nudged up, lowest down) because the cos lines cross, so a fixed stagger
    # could put a label on the wrong line.
    if label_series:
        endpoints.sort(key=lambda e: e[0])  # lowest endpoint first
        last = len(endpoints) - 1
        offsets = {0: -6, last: 6} if last > 0 else {0: 0}
        for rank, (y, x, text, shade) in enumerate(endpoints):
            ax.annotate(text, (x, y), color=shade, fontsize=8,
                        xytext=(3, offsets.get(rank, 0)),
                        textcoords="offset points", va="center")
    if ylim:
        ax.set_ylim(*ylim)
|
|
|
|
|
|
def plot(runs: list[dict], out: Path) -> None:
    """Write the small-multiples figure to ``out``: one column per arm, two rows.

    Row 0: student hack rate and ground-truth solve rate (EMA-smoothed).
    Row 1: teacher/student cos(grad, v_hack).
    A dashed vertical line spanning both rows marks each arm's mean hack onset
    across seeds. Raises SystemExit when no run falls into an arm in ARM_ORDER.
    """
    by_arm: dict[str, list[dict]] = defaultdict(list)
    for r in runs:
        by_arm[classify(r)].append(r)
    arms = [a for a in ARM_ORDER if a in by_arm]
    if not arms:
        raise SystemExit("no runs classified into arms")

    fig, axes = plt.subplots(2, len(arms), figsize=(3.0 * len(arms), 4.4),
                             sharex=True, sharey="row", squeeze=False)

    # Shared cos y-range, pooled over the FINITE values of every plotted run.
    # min() over per-run np.nanmin() is NaN-order-dependent: a single all-NaN
    # cos run can hide another run's real minimum. The top grows past 0.45
    # instead of clipping high-alignment runs off the panel.
    cos_vals = np.concatenate(
        [np.ravel(r[c]) for a in arms for r in by_arm[a] for c in COS_COLS])
    cos_vals = cos_vals[np.isfinite(cos_vals)]
    if cos_vals.size:
        cos_ylim = (min(-0.05, cos_vals.min() - 0.02),
                    max(0.45, cos_vals.max() + 0.02))
    else:
        cos_ylim = (-0.05, 0.45)

    for col, arm in enumerate(arms):
        rs = by_arm[arm]
        top, bottom = axes[0][col], axes[1][col]
        n_seed = len({r["seed"] for r in rs})
        top.set_title(f"{arm}\n(n={n_seed} seed{'s' if n_seed > 1 else ''})",
                      fontsize=9)
        _series_panel(top, rs, RATE_COLS, RATE_COLORS, ylim=(0, 1),
                      label_series=(col == 0))
        _series_panel(bottom, rs, COS_COLS, COS_COLORS, ylim=cos_ylim,
                      label_series=(col == 0))
        bottom.axhline(0, color="0.8", lw=0.6, zorder=0)
        bottom.set_xlabel("optimizer step")

        # Mean hack-onset: one dashed vertical reference line spanning BOTH rows
        # so the cos-divergence can be read against the moment hacking starts.
        onsets = [s for r in rs if (s := _onset(r["steps"], r["hack_s"])) is not None]
        if onsets:
            s0 = float(np.mean(onsets))
            for ax in (top, bottom):
                ax.axvline(s0, color="0.55", lw=0.8, ls=(0, (4, 3)), zorder=0)
            top.annotate("first hack", (s0, 1.0), color="0.4", fontsize=7,
                         xytext=(2, -2), textcoords="offset points", va="top")

    axes[0][0].set_ylabel("student rate")
    axes[1][0].set_ylabel("cos(grad, v_hack)")
    # range-frame: drop top/right spines, keep ink on data
    for ax in axes.flat:
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        ax.tick_params(labelsize=8)

    fig.suptitle("Training dynamics: G_hack erasure vs vanilla "
                 "(EMA-5 smoothed; dashed line = mean hack onset)", fontsize=10)
    fig.tight_layout(rect=(0, 0, 1, 0.96))
    out.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out, dpi=150, bbox_inches="tight")
    plt.close(fig)  # release the figure from pyplot's registry
    logger.info(f"wrote {out} ({len(runs)} runs, arms={arms})")
|
|
|
|
|
|
def plot_hack_overlay(runs: list[dict], out: Path) -> None:
    """Single panel: mean hack_s per arm overlaid -- the headline arm-vs-arm view.

    Faint per-seed EMA lines show spread, and a bold line is the arm's mean of
    the EMA-smoothed series. A dot marks the arm's mean hack onset, and each
    arm is direct-labelled at its line end (no legend). Raises SystemExit when
    no run falls into an arm in ARM_ORDER, matching plot().
    """
    by_arm: dict[str, list[dict]] = defaultdict(list)
    for r in runs:
        by_arm[classify(r)].append(r)
    arms = [a for a in ARM_ORDER if a in by_arm]
    if not arms:
        raise SystemExit("no runs classified into arms")

    fig, ax = plt.subplots(figsize=(5.2, 3.4))
    for arm in arms:
        rs = by_arm[arm]
        color = ARM_COLORS[arm]
        stacked = []
        for r in rs:
            ys = _ema(r["hack_s"])
            ax.plot(r["steps"], ys, color=color, lw=0.6, alpha=0.25, solid_capstyle="round")
            stacked.append(ys)
        # mean over seeds, truncated to the shortest run (shared step grid assumed)
        L = min(len(y) for y in stacked)
        ym = np.nanmean(np.stack([y[:L] for y in stacked]), axis=0)
        xm = rs[0]["steps"][:L]
        ax.plot(xm, ym, color=color, lw=2.0, solid_capstyle="round")
        onsets = [s for r in rs if (s := _onset(r["steps"], r["hack_s"])) is not None]
        if onsets:
            s0 = float(np.mean(onsets))
            ax.plot(s0, np.interp(s0, xm, ym), marker="o", ms=4, color=color, zorder=3)
        ax.annotate(arm, (xm[-1], ym[-1]), color=color, fontsize=8,
                    xytext=(4, 0), textcoords="offset points", va="center")

    ax.set_ylim(0, 1)
    ax.set_xlabel("optimizer step")
    ax.set_ylabel("student hack rate (hack_s)")
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.tick_params(labelsize=8)
    ax.set_title("Student hack rate by arm (EMA-5; dot = mean onset)", fontsize=10)
    fig.tight_layout()
    out.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out, dpi=150, bbox_inches="tight")
    plt.close(fig)  # release the figure from pyplot's registry
    logger.info(f"wrote {out}")
|
|
|
|
|
|
# --- cli -------------------------------------------------------------------
|
|
|
|
def _gather(paths: list[str]) -> list[Path]:
|
|
out: list[Path] = []
|
|
for p in paths:
|
|
pp = Path(p)
|
|
if pp.is_dir():
|
|
out += sorted(pp.glob("*.log"))
|
|
elif any(c in p for c in "*?["):
|
|
out += sorted(Path().glob(p))
|
|
else:
|
|
out.append(pp)
|
|
return out
|
|
|
|
|
|
def main() -> None:
    """CLI entry point: parse the given logs and write both figures.

    Writes the small-multiples figure to ``--out`` and the arm overlay next to
    it as ``<stem>_hack_overlay.png``, then logs the link returned by
    ``link_latest`` for each. Exits (SystemExit) when no file parses.
    """
    ap = argparse.ArgumentParser(
        description=__doc__,
        # keep the module docstring's line breaks (usage examples, arm table)
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    ap.add_argument("logs", nargs="+", help="log files, globs, or dirs")
    ap.add_argument("--out", type=Path, default=Path("out/figs/dynamics.png"))
    args = ap.parse_args()

    files = _gather(args.logs)
    runs = [r for f in files if (r := parse_log(f))]
    if not runs:
        raise SystemExit(f"no parseable runs in {len(files)} files")
    if len(runs) < len(files):
        logger.warning(f"skipped {len(files) - len(runs)} of {len(files)} files (unparseable)")
    for r in runs:
        logger.info(f"{classify(r):16s} seed={r['seed']} steps={len(r['steps'])} {r['vhack']}")

    args.out.parent.mkdir(parents=True, exist_ok=True)
    plot(runs, args.out)
    # second figure: single-panel arm-vs-arm overlay of the headline metric
    overlay = args.out.with_name(args.out.stem + "_hack_overlay.png")
    plot_hack_overlay(runs, overlay)
    for p in (args.out, overlay):
        logger.info(f"docs/figs latest -> {link_latest(p)}")
|
|
|
|
|
|
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|