mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-28 01:15:10 +08:00
19deef4fb9
- blog: mark as erase-n=2 draft, note route2/exploration-floor/deploy-eval are the current direction; embed dyn_sub4_hack_overlay.png (force-added); ASCII em-dashes; de-bold the arm list (#15 tell) - README: add route2 arm + apples-to-apples deploy-eval to 'What we compare'; stale banner on the n=1 mix=0.5 findings - plot_dynamics: remove _mark_if_sparse (asymmetric sparse-only dots); EMA-held line for all arms - train.py: fix 'held-out greedy' -> 'held-out eval subset, T=0.7' (deploy eval is sampled, not greedy) Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
386 lines
18 KiB
Python
386 lines
18 KiB
Python
"""Training-dynamics small multiples: deployed hack vs solve, one column per arm.

Tufte small multiples, single row. Columns = arm (vanilla / static G_hack
erasure / online G_hack erasure / routing2); the panel shows the DEPLOYED
model's hack_s (red) and solve/gt_s (green) over training. Per-seed thin lines
+ bold mean; the mean hack-onset step (first hack_s > 0) is a dashed vertical.
A second figure, <out-stem>_hack_overlay.png, overlays all arms on two stacked
panels (hack rate on top, solve rate below).

APPLES-TO-APPLES. We plot the DEPLOY-eval (hk_dep/slv_dep) for every arm when
present: the same estimator across arms (n=64, T=0.7, every --eval-ablate-every
steps). For route/route2 the deployed model = quarantine knob zeroed; for
vanilla/erase deploy == the trained model. Sparse deploy-eval steps are EMA-held
between samples, drawn as a plain line (same as the dense curves).
Exception: a log with finite hk_abl/slv_abl (the dense per-step proxy from the
ablated rollout slice, rollout_ablate_frac>0) plots that instead of hk_dep, so
arms only share one estimator if they agree on whether hk_abl is logged.

Older logs that gated the eval to route only fall back to per-step training
hack_s for vanilla/erase (noisier, n=28, but estimates the same deployed rate
since those arms have no quarantine).

Data source: logs/*.log per-step rows (the durable source results.py also uses).
We parse by HEADER NAME, not fixed index, because newer runs add columns (refr).

Arm classification (from the preset line `arm=`, covering old --arm and new
--intervention logs):
  vanilla          arm=vanilla (intervention=none)
  static erasure   arm=projected, no --vhack-refresh-every (frozen v_hack)
  online erasure   arm=projected, --vhack-refresh-every=N>0 (re-extracted)
  routing2         arm=routing2 (intervention=route2)
  (routing         arm=routing, route v1 -- classified but not plotted)

Usage:
  uv run python scripts/plot_dynamics.py logs/*converge*.log
  uv run python scripts/plot_dynamics.py logs/            # whole dir
  uv run python scripts/plot_dynamics.py A.log B.log --out out/dynamics.png

Scales to 3 seeds x 4 arms: pass all 12 logs; they auto-group by arm, with one
thin line per seed.
"""
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import re
|
|
from collections import defaultdict
|
|
from pathlib import Path
|
|
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
from loguru import logger
|
|
|
|
from projected_grpo.figs import link_latest
|
|
|
|
# --- parse -----------------------------------------------------------------
|
|
|
|
# Per-step columns to plot, keyed by their cleaned header name -> short label.
# Cells are decoded by _val: fraction "7/28" -> 0.25, signed float "+0.264".
RATE_COLS = dict(hack_s="hack", gt_s="solve")
# Header tokens can carry trailing decoration glyphs (arrows, "?"); keep only
# the leading identifier run: "hack_s?" -> "hack_s".
_HDR_TOK = re.compile(r"[A-Za-z_]+")
|
|
|
|
|
|
def _val(tok: str) -> float | None:
    """Decode one per-step table cell into a float, or None for "no value".

    Accepted forms:
      * fraction ``n/d`` -> n / d, or None when d == 0;
      * placeholder ``T``, ``F``, ``-`` or ``nan`` -> None;
      * anything else is handed to ``float`` (e.g. ``+0.264``).
    """
    num, sep, den = tok.partition("/")
    if sep:
        denom = int(den)
        return int(num) / denom if denom else None
    if tok in {"T", "F", "-", "nan"}:
        return None
    return float(tok)
|
|
|
|
|
|
def parse_log(path: Path) -> dict | None:
    """Parse one training log into per-step series for plotting.

    Returns {arm, refr, seed, vhack, steps: int[], <series>: float[]}, or None
    when the log has no ``argv:`` line, no per-step header row, a header
    without a ``step`` column, or no data rows.

    Metadata:
      arm   -- ``arm=`` on the preset line (default "vanilla"); the one source
               covering both old (--arm) and new (--intervention) logs.
      refr  -- --vhack-refresh-every from argv (0 when absent).
      seed  -- ``seed=`` on the preset line ("?" when absent).
      vhack -- stem of the --v-hack-path safetensors file ("-" when absent).

    Series: every RATE_COLS column the header has, plus whichever deploy
    columns exist (hk_abl/slv_abl, hk_dep/slv_dep). Cells are decoded with
    _val; cells it maps to None become NaN.

    hack_s/gt_s are then REPLACED by a deploy-time estimate when one has
    finite data (first match wins):
      1. hk_abl/slv_abl -- dense per-step proxy from the ablated rollout slice
         (rollout_ablate_frac > 0);
      2. hk_dep/slv_dep -- held-out eval subset, T=0.7, logged only every
         eval_ablate_every steps (presumably NaN on the other steps);
    otherwise the training-time hack_s/gt_s are kept (older logs that gated
    the deploy eval to route only).
    """
    lines = path.read_text(errors="replace").splitlines()
    argv = next((ln for ln in lines if "argv:" in ln), None)
    if argv is None:
        return None
    preset = next((ln for ln in lines if "preset=" in ln and "arm=" in ln), "")

    def grab(pat, s, default=None):
        # last match wins, so a flag repeated in argv takes its final value
        ms = re.findall(pat, s)
        return ms[-1] if ms else default

    arm = grab(r"\barm=(\w+)", preset, "vanilla")
    refr = int(grab(r"--vhack-refresh-every=(\d+)", argv, "0"))
    seed = grab(r"seed=(\d+)", preset, "?")
    vhack = grab(r"v-hack-path=out/(?:vhack/)?(\S+?)\.safetensors", argv, "-")

    # Header row: the first INFO line naming both the ref_eq and hack_s columns.
    hdr = next((ln for ln in lines
                if "| INFO |" in ln and "ref_eq" in ln and "hack_s" in ln), None)
    if hdr is None:
        return None
    # Real column headers start with a letter/underscore: _HDR_TOK strips
    # trailing decoration ("hack_s?" -> "hack_s") and pure-symbol tokens are
    # dropped, so a stray glyph in an old log's header doesn't crash the parse.
    names = [m.group(0) for t in hdr.split("| INFO |", 1)[1].split()
             if (m := _HDR_TOK.match(t))]
    idx = {n: i for i, n in enumerate(names)}
    if "step" not in idx:  # every data row is keyed by its step
        return None

    # Only parse columns this log actually has: some arms' logs lack columns,
    # and logs without the route deploy eval lack the deploy columns.
    deploy = {"hk_dep", "slv_dep", "hk_abl", "slv_abl"} & set(idx)
    wanted = [k for k in RATE_COLS if k in idx] + sorted(deploy)

    series: dict[str, list[float | None]] = defaultdict(list)
    steps: list[int] = []
    for line in lines:
        if "| INFO |" not in line:
            continue
        row = line.split("| INFO |", 1)[1].split()
        # per-step rows start with the integer step and fill every header column
        if not row or not row[0].isdigit() or len(row) < len(names):
            continue
        steps.append(int(row[idx["step"]]))
        for col in wanted:
            series[col].append(_val(row[idx[col]]))
    if not steps:
        return None

    run = dict(arm=arm, refr=refr, seed=seed, vhack=vhack, steps=np.array(steps),
               **{k: np.array(v, dtype=float) for k, v in series.items()})

    def _has_data(key):
        # Test FINITE values, not column presence: no-floor logs carry an
        # all-NaN hk_dep/hk_abl column.
        return key in run and np.isfinite(run[key]).any()

    # NOTE(review): hk_abl outranks hk_dep, so if only some arms log hk_abl the
    # arms are no longer on one estimator -- confirm that is intended.
    # Requiring the matching slv_* column avoids a KeyError on partial logs.
    for hk, slv in (("hk_abl", "slv_abl"), ("hk_dep", "slv_dep")):
        if _has_data(hk) and slv in run:
            run["hack_s"] = run[hk]
            run["gt_s"] = run[slv]
            break
    return run
|
|
|
|
|
|
def classify(run: dict) -> str:
    """Return the display arm for a parsed run.

    vanilla / routing / routing2 pass through unchanged. Every other arm is
    treated as projected erasure and split by whether v_hack is re-extracted
    during training (``refr`` > 0 -> online, else static).
    """
    arm = run["arm"]
    if arm in ("vanilla", "routing", "routing2"):
        return arm
    if run["refr"] > 0:
        return "online erasure"
    return "static erasure"
|
|
|
|
|
|
# --- plot ------------------------------------------------------------------
|
|
|
|
# Arms to plot, left to right. "routing" (route v1, single quarantine) is
# deprecated -- superseded by routing2 (scale-matched quarantine). classify()
# still tags v1 logs as "routing" so they don't get misread as erasure, but it
# is left out of ARM_ORDER so it isn't plotted.
ARM_ORDER = ["vanilla", "static erasure", "online erasure", "routing2"]
# Series colours for the small-multiples panels: red hack vs green solve.
RATE_COLORS = {"hack_s": "#c1432b", "gt_s": "#2f7d4f"}
# Arm colours for the two-panel hack/solve overlay (arms, not series): grey
# vanilla baseline -> amber static -> blue online -> purple routing2, ordered by
# increasing intervention; green is kept for the unplotted v1 routing.
# TODO(color): make this a quality-ordered red->green ramp instead of fixed
# per-arm hues -- red = vanilla (worst, most hacking), green = best method
# (anticipated gradient routing). As arms grow (static/online/grad-routing/
# confessions), assign colour by method rank along a perceptual RdYlGn ramp so
# the reader sees "redder = hacks more" at a glance.
ARM_COLORS = {"vanilla": "#7a7a7a", "static erasure": "#c98a2b",
              "online erasure": "#33508c", "routing": "#2f7d4f",
              "routing2": "#7d2f6f"}
|
|
|
|
|
|
def _onset(steps: np.ndarray, hack: np.ndarray) -> int | None:
    """Step at which the RAW hack series first goes strictly positive.

    Returns None if it never does (NaN entries never count). Fed the
    unsmoothed series on purpose: an EMA would smear the exact step we mark.
    """
    hacked = hack > 0
    if not hacked.any():
        return None
    return int(steps[np.argmax(hacked)])
|
|
|
|
|
|
def _ema(y: np.ndarray, span: int = 5) -> np.ndarray:
    """Causal exponential moving average with alpha = 2 / (span + 1).

    Weights recent steps more than a trailing SMA of the same span, so it lags
    less. The first non-NaN sample seeds the average. NaNs do not reset it:
    a NaN entry repeats the previous smoothed value (leading NaNs stay NaN).
    This is what holds a sparse series (e.g. the deploy eval) flat between
    samples.

    The input is coerced to float, so int arrays or plain lists are smoothed
    correctly instead of being truncated into an integer output buffer.
    """
    y = np.asarray(y, dtype=float)
    alpha = 2.0 / (span + 1)
    out = np.empty_like(y)
    prev = np.nan
    for i, v in enumerate(y):
        if not np.isnan(v):
            prev = v if np.isnan(prev) else alpha * v + (1 - alpha) * prev
        out[i] = prev
    return out
|
|
|
|
|
|
def _series_panel(ax, runs, cols, colors, ylim, label_series=False):
    """Overlay per-run thin EMA lines and a bold mean-of-EMA line per series.

    cols maps column name -> display label; colors maps column name -> colour.
    Runs lacking a column are skipped for that series; a series no run has is
    skipped entirely. When label_series is true, each bold line gets a direct
    text label at its right endpoint. ylim, if truthy, is passed to set_ylim.
    """
    ends = []  # (endpoint_y, endpoint_x, label, color) for direct labels
    for col, label in cols.items():
        color = colors[col]
        present = [r for r in runs if col in r]
        if not present:  # no run in this arm logs this series
            continue
        stacked = []
        for r in present:
            ys = _ema(r[col])
            ax.plot(r["steps"], ys, color=color, lw=0.7, alpha=0.35, solid_capstyle="round")
            stacked.append(ys)
        # Mean over runs of the smoothed series, truncated to the shortest run
        # (assumes runs within an arm share a step grid). Take x from a run
        # that HAS this series: runs[0] may lack it (or be shorter), which
        # would misalign x against y.
        n = min(len(y) for y in stacked)
        if n == 0:
            continue
        ym = np.nanmean(np.stack([y[:n] for y in stacked]), axis=0)
        xm = present[0]["steps"][:n]
        ax.plot(xm, ym, color=color, lw=1.8, solid_capstyle="round")
        ends.append((ym[-1], xm[-1], label, color))
    # Direct labels (the caller enables them in one panel only -- colour
    # carries each series across the row). Nudge by the ACTUAL endpoint
    # ordering (highest line -> label up, lowest -> down) rather than a fixed
    # stagger, so lines that cross don't end up with each other's labels.
    if label_series:
        ends.sort(key=lambda e: e[0])  # lowest endpoint first
        dy = {0: -6, len(ends) - 1: 6} if len(ends) > 1 else {0: 0}
        for rank, (y, x, text, color) in enumerate(ends):
            ax.annotate(text, (x, y), color=color, fontsize=8,
                        xytext=(3, dy.get(rank, 0)), textcoords="offset points", va="center")
    if ylim:
        ax.set_ylim(*ylim)
|
|
|
|
|
|
def plot(runs: list[dict], out: Path) -> None:
    """Write the small-multiples figure to `out`: one panel per arm.

    Arms follow ARM_ORDER; runs whose arm is not listed (e.g. deprecated
    "routing") are dropped. Each panel overlays hack_s (red) and gt_s (green)
    as faint per-run EMA lines plus a bold mean. A dashed vertical marks the
    mean hack-onset step over the runs that ever hack. The figure is closed
    after saving.

    Raises SystemExit if no run falls into a plotted arm.
    """
    by_arm: dict[str, list[dict]] = defaultdict(list)
    for r in runs:
        by_arm[classify(r)].append(r)
    arms = [a for a in ARM_ORDER if a in by_arm]
    if not arms:
        raise SystemExit("no runs classified into arms")

    fig, axes = plt.subplots(1, len(arms), figsize=(3.0 * len(arms), 2.6),
                             sharex=True, sharey=True, squeeze=False)
    for col, arm in enumerate(arms):
        ax = axes[0][col]
        rs = by_arm[arm]
        n_seed = len({r["seed"] for r in rs})
        ax.set_title(f"{arm}\n(n={n_seed} seed{'s' if n_seed > 1 else ''})", fontsize=9)
        # series labels only in the leftmost panel; colour carries them across
        _series_panel(ax, rs, RATE_COLS, RATE_COLORS, ylim=(0, 1), label_series=(col == 0))
        ax.set_xlabel("optimizer step")
        # onset from the plotted hack_s (a deploy series if parse_log swapped one in)
        onsets = [s for r in rs if (s := _onset(r["steps"], r["hack_s"])) is not None]
        if onsets:
            s0 = float(np.mean(onsets))
            ax.axvline(s0, color="0.55", lw=0.8, ls=(0, (4, 3)), zorder=0)
            ax.annotate("first hack", (s0, 1.0), color="0.4", fontsize=7,
                        xytext=(2, -2), textcoords="offset points", va="top")

    axes[0][0].set_ylabel("deployed rate")
    # range-frame: drop top/right spines, keep ink on data
    for ax in axes.flat:
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        ax.tick_params(labelsize=8)

    # NOTE(review): the title names the n=64 deploy-eval, but a run may be
    # plotting the hk_abl proxy or training hack_s instead (see parse_log).
    fig.suptitle("Training dynamics: deployed hack vs solve by arm "
                 "(deploy-eval n=64 T=0.7; EMA-5; dashed = mean hack onset)", fontsize=10)
    fig.tight_layout(rect=(0, 0, 1, 0.96))
    out.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out, dpi=150, bbox_inches="tight")
    plt.close(fig)  # pyplot keeps every figure alive until it is closed
    logger.info(f"wrote {out} ({len(runs)} runs, arms={arms})")
|
|
|
|
|
|
def _overlay_panel(ax, by_arm, arms, key, *, label, with_onset):
    """Draw metric `key` for every arm in `arms` onto a single axis.

    Per arm (colour from ARM_COLORS): one faint EMA line per run, a bold
    nan-mean of those EMA lines, and -- when with_onset -- a dot on the mean
    line at the mean hack-onset step. Arms whose runs all lack `key` are
    skipped. The axis is fixed to y in [0, 1] with top/right spines hidden.

    Direct arm labels are drawn only when with_onset is False (the caller uses
    that for exactly one panel, so shared colours are not labelled twice).
    Labels are stacked bottom-to-top with a minimum vertical gap so
    overlapping arms don't pile their text; a label pushed off its line gets a
    thin leader line back to the endpoint.
    """
    endpoints = []  # (y_end, x_end, arm, colour)
    for arm in arms:
        arm_runs = [r for r in by_arm[arm] if key in r]
        if not arm_runs:
            continue
        arm_color = ARM_COLORS[arm]
        curves = []
        for r in arm_runs:
            smooth = _ema(r[key])
            ax.plot(r["steps"], smooth, color=arm_color, lw=0.6, alpha=0.25,
                    solid_capstyle="round")
            curves.append(smooth)
        n = min(map(len, curves))
        mean_y = np.nanmean(np.stack([c[:n] for c in curves]), axis=0)
        mean_x = arm_runs[0]["steps"][:n]
        ax.plot(mean_x, mean_y, color=arm_color, lw=2.0, solid_capstyle="round")
        if with_onset:
            onset_steps = [s for s in (_onset(r["steps"], r["hack_s"]) for r in arm_runs)
                           if s is not None]
            if onset_steps:
                s_mean = float(np.mean(onset_steps))
                ax.plot(s_mean, np.interp(s_mean, mean_x, mean_y), marker="o", ms=4,
                        color=arm_color, zorder=3)
        endpoints.append((float(mean_y[-1]), float(mean_x[-1]), arm, arm_color))

    ax.set_ylim(0, 1)
    ax.set_ylabel(label)
    ax.spines[["top", "right"]].set_visible(False)
    ax.tick_params(labelsize=8)
    if with_onset:  # the onset panel stays unlabelled
        return

    min_gap = 0.052  # minimum label separation, in data units
    prev_y = None
    for y, x, arm, color in sorted(endpoints, key=lambda e: e[0]):
        y_text = y if prev_y is None else max(y, prev_y + min_gap)
        prev_y = y_text
        leader = None
        if abs(y_text - y) > 1e-3:
            leader = dict(arrowstyle="-", color=color, lw=0.5, shrinkA=0, shrinkB=0)
        ax.annotate(arm, xy=(x, y), xytext=(x + 1.0, y_text), textcoords="data",
                    color=color, fontsize=8, va="center", arrowprops=leader)
|
|
|
|
|
|
def plot_hack_overlay(runs: list[dict], out: Path) -> None:
    """Write the arm-vs-arm overlay figure to `out`.

    Two stacked panels sharing x: hack rate (hack_s, top) and solve rate
    (gt_s, bottom), one colour per arm. Each arm gets faint per-run EMA lines
    plus a bold EMA-5 mean; the hack panel marks the mean hack onset with a
    dot; arms are direct-labelled once, on the solve panel (shared colours, no
    legend). The figure is closed after saving.

    Raises SystemExit if no run falls into a plotted arm (same guard as plot()).
    """
    by_arm: dict[str, list[dict]] = defaultdict(list)
    for r in runs:
        by_arm[classify(r)].append(r)
    arms = [a for a in ARM_ORDER if a in by_arm]
    if not arms:
        raise SystemExit("no runs classified into arms")

    fig, (ax_h, ax_s) = plt.subplots(2, 1, figsize=(5.2, 5.2), sharex=True)
    _overlay_panel(ax_h, by_arm, arms, "hack_s", label="hack rate", with_onset=True)
    _overlay_panel(ax_s, by_arm, arms, "gt_s", label="solve rate", with_onset=False)
    ax_s.set_xlabel("optimizer step")
    ax_h.set_title("Hack vs solve rate by arm (EMA-5; dot = mean hack onset)", fontsize=10)
    fig.tight_layout()
    out.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out, dpi=150, bbox_inches="tight")
    plt.close(fig)  # pyplot keeps every figure alive until it is closed
    logger.info(f"wrote {out}")
|
|
|
|
|
|
# --- cli -------------------------------------------------------------------
|
|
|
|
def _gather(paths: list[str]) -> list[Path]:
    """Expand CLI `logs` arguments into a flat list of log paths.

    Per argument, in order:
      * directory    -> its top-level ``*.log`` files, sorted;
      * glob pattern -> (contains ``*``, ``?`` or ``[``) its matches, sorted.
                        Absolute patterns such as ``/data/logs/*.log`` (e.g. a
                        quoted glob the shell didn't expand) are supported:
                        Path.glob only takes relative patterns, so the pattern
                        is split into its anchor and the relative remainder;
      * anything else -> taken literally (existence is not checked here).
    """
    out: list[Path] = []
    for p in paths:
        pp = Path(p)
        if pp.is_dir():
            out += sorted(pp.glob("*.log"))
        elif any(c in p for c in "*?["):
            if pp.is_absolute():
                # "/data/logs/*.log" -> Path("/").glob("data/logs/*.log")
                root = Path(pp.anchor)
                out += sorted(root.glob(str(pp.relative_to(root))))
            else:
                out += sorted(Path().glob(p))
        else:
            out.append(pp)
    return out
|
|
|
|
|
|
def _latest_per_arm(files: list[Path], min_steps: int) -> list[Path]:
    """Keep one log per arm: the newest (by file NAME) with >= min_steps rows.

    Lets `just dyn` auto-pick the freshest full-length run for each arm instead
    of hand-globbing. "Newest" = lexicographically greatest file name, relying
    on timestamp-prefixed names (no mtime races). Sorting by name rather than
    by full path matters when logs come from several directories: a full-path
    sort ranks by directory first and could keep an older run.

    Unparseable logs and logs with fewer than min_steps rows are skipped.
    Returned paths are ordered by when each arm first got a qualifying log
    (dict insertion order).
    """
    newest: dict[str, Path] = {}
    for f in sorted(files, key=lambda p: p.name):  # ascending; later overwrites
        r = parse_log(f)
        if r is None or len(r["steps"]) < min_steps:
            continue
        newest[classify(r)] = f
    return list(newest.values())
|
|
|
|
|
|
def main() -> None:
    """CLI entry point.

    Gathers the given logs (files, globs or dirs), optionally keeps only the
    newest per arm, and parses them, dropping runs shorter than --min-steps.
    It then writes the small-multiples figure to --out and the
    <stem>_hack_overlay.png overlay next to it. Finally it calls link_latest
    on each figure and logs the result. Exits with a message when no run
    parses.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("logs", nargs="+", help="log files, globs, or dirs")
    parser.add_argument("--out", type=Path, default=Path("out/figs/dynamics.png"))
    parser.add_argument("--latest-per-arm", action="store_true",
                        help="keep only the newest log per arm (with >= --min-steps rows)")
    parser.add_argument("--min-steps", type=int, default=0,
                        help="drop runs shorter than this many logged steps")
    args = parser.parse_args()

    log_paths = _gather(args.logs)
    if args.latest_per_arm:
        log_paths = _latest_per_arm(log_paths, args.min_steps)

    runs = []
    for log_path in log_paths:
        parsed = parse_log(log_path)
        if parsed and len(parsed["steps"]) >= args.min_steps:
            runs.append(parsed)
    if not runs:
        raise SystemExit(f"no parseable runs in {len(log_paths)} files")

    for r in runs:
        logger.info(f"{classify(r):16s} seed={r['seed']} steps={len(r['steps'])} {r['vhack']}")

    args.out.parent.mkdir(parents=True, exist_ok=True)
    plot(runs, args.out)
    # second figure: arm-vs-arm overlay of hack and solve rates
    overlay = args.out.with_name(f"{args.out.stem}_hack_overlay.png")
    plot_hack_overlay(runs, overlay)
    for fig_path in (args.out, overlay):
        logger.info(f"docs/figs latest -> {link_latest(fig_path)}")
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point (see the module docstring for usage); importing this
    # module does not run main().
    main()
|