mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 21:07:17 +08:00
55937a86fb
git mv src/projected_grpo -> src/vgrout and find-replace the module name in
all imports (.py), `-m projected_grpo.*` invocations (justfile), and the
[project] name (pyproject; setuptools auto-discovers via where=["src"]).
Left RESEARCH_JOURNAL.md untouched: its commands/paths are dated lab notes
tied to past commits, so rewriting them would falsify provenance. Repo dir,
git remote, and absolute paths unchanged.
Verified: `import vgrout` and `python -m vgrout.train --help` load the full
graph; verify_rewards.py + verify_gate_anchor.py (both import vgrout) pass.
Full `just smoke` is blocked upstream by missing gitignored data artifacts
(out/pools/{substrate,teacher_pool}, out/vhack/*smoke*), unrelated to the rename.
276 lines
13 KiB
Python
276 lines
13 KiB
Python
"""Multi-loophole substrate per-mode dynamics (#137/#148): how much of each loophole
|
|
does each intervention let the student learn, and how fast?
|
|
|
|
The substrate run interleaves all K modes in ONE log via the hk_<mode> columns.
Each column holds the student's per-batch hack COUNT for that mode. The streaming
log dropped the rollout denominator, so a per-mode rate is not recoverable. Older
logs printed a cumulative "n/d" instead and are skipped. We parse those counts and
EMA-smooth them -- the instantaneous count is what shows a method SUPPRESSING a
mode over time, which the monotone cumulative curve hides. Pass --cumulative for
the running total instead.
|
|
|
|
Two core layouts (both emitted by default):
|
|
by-method : one panel per intervention (vanilla / erase / route); one coloured
|
|
line per hack type. Reads "how many of K classes does THIS method let through".
|
|
by-hack : one panel per hack type; one line per method (mean over seeds, thin
|
|
per-seed). Reads "for THIS loophole, which method suppresses it best".
|
|
|
|
Route caveat (load-bearing): hk_<mode> is the TRAINING-time count. The routed forward
still hacks during training, while the deployed model (quarantine knob deleted) gives
the real number. The log has aggregate hack_deploy but NOT per-mode deploy, so route's
per-mode curve is train-time only and overstates route. It is drawn DASHED in by-hack,
and its by-method panel carries a "train-time (deploy lower)" note. TODO: log per-mode
deploy in train.py to make route's per-mode curve honest; until then, read route's
real number off the aggregate figures (plot_dynamics).
|
|
|
|
This is the single plotting ENTRYPOINT (`just plot`): it emits the per-mode cut
|
|
(by-method, by-hack) AND delegates the aggregate "total hacks per arm" + cos-alignment
|
|
figures to plot_dynamics.plot/plot_hack_overlay (reuse, not reimplement). plot_dynamics
|
|
owns route's deploy-curve substitution and the cos rows; this script owns parse_hk.
|
|
|
|
Usage:
|
|
uv run python scripts/plot_substrate.py logs/*_sub4_*.log # both layouts -> out/figs/
|
|
uv run python scripts/plot_substrate.py A.log B.log --out-stem out/figs/sub4
|
|
uv run python scripts/plot_substrate.py <run>.log --cumulative --ema-span 6
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import re
|
|
from pathlib import Path
|
|
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
from loguru import logger
|
|
|
|
from vgrout.figs import save_fig
|
|
|
|
# hk_ column header -> (display mode, colour). Order = panel/legend order.
|
|
# Colourblind-safe-ish qualitative set; one hue per loophole, reused across panels.
|
|
HK = {
|
|
"hk_rt": ("run_tests", "#c1432b"),
|
|
"hk_fm": ("file_marker", "#7b3294"),
|
|
"hk_so": ("stdout_marker", "#b8860b"),
|
|
"hk_se": ("sentinel", "#2f7d4f"),
|
|
"hk_eq": ("eq_override", "#33508c"),
|
|
}
|
|
# method -> (display label, colour, dashed?). dashed = per-mode curve is train-time
|
|
# only (route: the routed forward still hacks; deploy is lower and not logged per-mode).
|
|
METHODS = {
|
|
"vanilla": ("vanilla", "#444444", False),
|
|
"erase": ("erase", "#c1432b", False),
|
|
"route": ("route (train-time)", "#33508c", True),
|
|
}
|
|
_HDR_TOK = re.compile(r"[A-Za-z_]+") # "hack_s?" -> "hack_s"
|
|
|
|
|
|
def classify(txt: str) -> str:
    """Map a run log to its intervention name: vanilla / erase / route.

    The arm comes from the first line carrying both ``preset=`` and ``arm=``
    (this also covers --intervention logs). If there is no such line, or its
    ``arm=`` value cannot be read, the run counts as vanilla.

    Arms outside the known mapping (e.g. route2's routing2_act) are returned
    verbatim. The plotters only draw methods listed in METHODS, so an unmapped
    arm is dropped from the train-dynamics panels instead of crashing the whole
    `just plot`.
    """
    arm = "vanilla"
    for line in txt.splitlines():
        if "preset=" in line and "arm=" in line:
            hit = re.search(r"\barm=(\w+)", line)
            if hit is not None:
                arm = hit.group(1)
            break  # only the first preset/arm line counts
    aliases = {"vanilla": "vanilla", "projected": "erase", "routing": "route"}
    return aliases.get(arm, arm)
|
|
|
|
|
|
def parse_hk(path: Path) -> dict | None:
    """Parse the per-mode hack counts out of one substrate run log.

    Returns ``{method, seed, steps, <hk_col>: counts, ...}``:
      - ``method`` comes from classify().
      - ``seed`` comes from the filename ("seed<N>" or "_s<N>"), else "?".
      - ``steps``, plus one array per hk_ column listed in the header, are
        numpy arrays with one entry per logged step.

    Returns None (rather than raising) when the log isn't a usable
    multi-loophole run:
      - no ``| INFO |`` header line with both ``ref_eq`` and ``hk_rt``;
      - no ``hk_`` or ``step`` column;
      - old-format "n/d" hk cells;
      - no parseable per-step rows.
    That lets `just plot` glob a broad set of logs (old single-mode / aborted
    runs mixed in) without crashing; main() logs which paths were skipped.
    """
    sep = "| INFO |"
    txt = path.read_text(errors="replace")
    lines = txt.splitlines()
    # The header must carry the INFO separator too. Column names are read from
    # the text after it, so a header-looking line without one would IndexError.
    hdr = next((l for l in lines if sep in l and "ref_eq" in l and "hk_rt" in l), None)
    if hdr is None:
        return None
    # Strip decorations ("hack_s?" -> "hack_s"). A token that doesn't start with
    # a letter/underscore is kept verbatim instead of crashing on match() -> None.
    names = [tok.group(0) if (tok := _HDR_TOK.match(t)) else t
             for t in hdr.split(sep, 1)[1].split()]
    idx = {n: i for i, n in enumerate(names)}
    present = [k for k in HK if k in idx]  # 4-mode substrate dropped hk_eq; plot only what's logged
    if not present or "step" not in idx:
        return None  # e.g. "hk_rt" only appeared inside some other header token
    steps, counts = [], {k: [] for k in present}
    for line in lines:
        if sep not in line:
            continue
        row = line.split(sep, 1)[1].split()
        if not row or not row[0].isdigit() or len(row) < len(names):
            continue
        # hk_<mode> is the per-batch hack COUNT (current step, not cumulative,
        # no /denominator). The streaming log dropped the rollout denominator,
        # so a per-mode RATE is no longer recoverable and we plot the count
        # directly. Old logs print "n/d" (cumulative): incompatible units, so
        # skip the whole log.
        if "/" in row[idx[present[0]]]:
            return None
        steps.append(int(row[idx["step"]]))
        for k in present:
            counts[k].append(int(row[idx[k]]))
    if not steps:
        return None  # header present but no parseable per-step rows (e.g. diverged/aborted)
    seed = re.search(r"seed(\d+)", path.name) or re.search(r"_s(\d+)", path.name)
    return dict(
        method=classify(txt),
        seed=seed.group(1) if seed else "?",
        steps=np.array(steps),
        **{k: np.array(v) for k, v in counts.items()},
    )
|
|
|
|
|
|
def ema(y: np.ndarray, span: int) -> np.ndarray:
    """Exponential moving average with smoothing factor ``2 / (span + 1)``.

    NaN entries do not update the average. They take the last smoothed value
    (carried across the gap), and NaNs before the first finite sample stay NaN.

    Args:
        y: 1-D series to smooth; NaN marks a missing step.
        span: EMA span in samples; must be >= 1. span=1 returns y with its
            NaN gaps forward-filled.

    Returns:
        Float array with len(y) entries.

    Raises:
        ValueError: if span < 1. The smoothing factor would leave (0, 1]:
            span=0 oscillates, and span=-1 divides by zero.
    """
    if span < 1:
        raise ValueError(f"ema span must be >= 1, got {span}")
    a = 2.0 / (span + 1.0)
    out = np.full(len(y), np.nan)
    m = None
    for i, v in enumerate(y):
        if not np.isnan(v):
            m = v if m is None else a * v + (1 - a) * m
        if m is not None:  # leading NaNs (no sample yet) stay NaN
            out[i] = m
    return out
|
|
|
|
|
|
def rate(count: np.ndarray, *, cumulative: bool, span: int) -> np.ndarray:
    """Curve to plot for one hack mode, given its per-batch (instantaneous) hack counts.

    cumulative=True  -> running total of the counts (np.cumsum; dtype kept).
    cumulative=False -> the counts cast to float and EMA-smoothed with ``span``.
    """
    return np.cumsum(count) if cumulative else ema(count.astype(float), span)
|
|
|
|
|
|
def _despine(ax):
    """Hide the top/right spines of ``ax`` and add faint horizontal gridlines."""
    for side in ("top", "right"):
        ax.spines[side].set_visible(False)
    ax.grid(axis="y", lw=0.4, alpha=0.35)
|
|
|
|
|
|
def _onset(x, y) -> int | None:
|
|
nz = np.where(np.nan_to_num(y) > 0)[0]
|
|
return int(x[nz[0]]) if len(nz) else None
|
|
|
|
|
|
def plot_by_method(runs, ylabel, cumulative, span, out: Path):
    """One panel per method; one line per hack type. Multi-seed -> mean bold + per-seed thin.

    Only methods listed in METHODS and hk_ modes present in every run are drawn.
    Each panel title reports how many modes ever rose above zero ("learned").
    Methods flagged dashed in METHODS (route) get a "train-time" note.
    The figure is saved to ``out``.

    Args:
        runs: dicts as returned by parse_hk().
        ylabel: y-axis label.
        cumulative, span: forwarded to rate().
        out: output image path.
    """
    methods = [m for m in METHODS if any(r["method"] == m for r in runs)]
    modes = [k for k in HK if all(k in r for r in runs)]
    if not methods:
        # plt.subplots(1, 0) raises. This happens when every run is an unmapped
        # arm (e.g. route2's routing2_act), which classify() passes through verbatim.
        logger.warning(f"skipped {out} (by-method): no run has a method in {list(METHODS)}")
        return
    fig, axes = plt.subplots(1, len(methods), figsize=(3.5 * len(methods), 3.6),
                             sharey=True, sharex=True, squeeze=False)
    axes = axes[0]
    x_max = 0
    for ax, method in zip(axes, methods):
        grp = [r for r in runs if r["method"] == method]
        # Truncate every seed to the shortest run so they stack; x comes from the first seed.
        L = min(len(r["steps"]) for r in grp)
        x = grp[0]["steps"][:L]
        x_max = max(x_max, x[-1])
        n_learned = 0
        for k in modes:
            mode, color = HK[k]
            stk = np.stack([rate(r[k], cumulative=cumulative, span=span)[:L] for r in grp])
            ymean = np.nanmean(stk, axis=0)
            if len(grp) > 1:
                for ys in stk:
                    ax.plot(x, ys, color=color, lw=0.6, alpha=0.30)
            ax.plot(x, ymean, color=color, lw=1.8, solid_capstyle="round")
            n_learned += _onset(x, ymean) is not None
            y_end = np.nan_to_num(ymean[-1])
            ax.annotate(f"{mode} {y_end:.1f}", (x[-1], y_end),
                        color=color, fontsize=7, va="center", xytext=(5, 0), textcoords="offset points")
        label, _, dashed = METHODS[method]
        ax.set_title(f"{label} ({n_learned}/{len(modes)} learned)", fontsize=9)
        ax.set_xlabel("GRPO step")
        _despine(ax)
        if dashed:
            ax.text(0.03, 0.97, "train-time\n(deploy lower)", transform=ax.transAxes,
                    fontsize=6.5, va="top", color="#888")
    # sharex=True links every panel's x-limits, so a per-panel set_xlim would let
    # the LAST method's run length clip every other panel. Set the limits once,
    # from the longest run; the 1.30 headroom keeps the end-of-line labels on canvas.
    axes[0].set_xlim(0, x_max * 1.30)
    axes[0].set_ylabel(ylabel)
    axes[0].set_ylim(-0.02, None)
    fig.tight_layout()
    save_fig(fig, out)
    logger.info(f"wrote {out} (by-method, {len(methods)} methods)")
|
|
|
|
|
|
def plot_by_hack(runs, ylabel, cumulative, span, out: Path):
    """One panel per hack type; one line per method (mean over seeds, thin per-seed).

    Only methods listed in METHODS and hk_ modes present in every run are drawn.
    Methods flagged dashed in METHODS (route: train-time only) use dashed lines.
    The figure is saved to ``out``.

    Args:
        runs: dicts as returned by parse_hk().
        ylabel: y-axis label.
        cumulative, span: forwarded to rate().
        out: output image path.
    """
    methods = [m for m in METHODS if any(r["method"] == m for r in runs)]
    modes = [k for k in HK if all(k in r for r in runs)]
    if not methods or not modes:
        # No modes: plt.subplots(1, 0) raises. No methods: every panel is empty and
        # there is no x to size the axis from. Either way, skip instead of crashing.
        logger.warning(f"skipped {out} (by-hack): methods={methods}, modes={modes}")
        return
    fig, axes = plt.subplots(1, len(modes), figsize=(3.2 * len(modes), 3.6),
                             sharey=True, sharex=True, squeeze=False)
    axes = axes[0]
    x_max = 0
    for ax, k in zip(axes, modes):
        mode, _ = HK[k]
        for method in methods:
            grp = [r for r in runs if r["method"] == method]
            # Truncate every seed to the shortest run so they stack; x comes from the first seed.
            L = min(len(r["steps"]) for r in grp)
            x = grp[0]["steps"][:L]
            x_max = max(x_max, x[-1])
            stk = np.stack([rate(r[k], cumulative=cumulative, span=span)[:L] for r in grp])
            ymean = np.nanmean(stk, axis=0)
            label, color, dashed = METHODS[method]
            ls = "--" if dashed else "-"
            if len(grp) > 1:
                for ys in stk:
                    ax.plot(x, ys, color=color, lw=0.6, alpha=0.25, ls=ls)
            ax.plot(x, ymean, color=color, lw=1.8, ls=ls, solid_capstyle="round")
            ax.annotate(label, (x[-1], np.nan_to_num(ymean[-1])), color=color, fontsize=7,
                        va="center", xytext=(5, 0), textcoords="offset points")
        ax.set_title(mode, fontsize=9)
        ax.set_xlabel("GRPO step")
        _despine(ax)
    # sharex=True links every panel's x-limits, so sizing them from one method's x
    # would clip longer runs. Use the longest run across all methods; the 1.45
    # headroom keeps the end-of-line method labels on canvas.
    axes[0].set_xlim(0, x_max * 1.45)
    axes[0].set_ylabel(ylabel)
    axes[0].set_ylim(-0.02, None)
    fig.tight_layout()
    save_fig(fig, out)
    logger.info(f"wrote {out} (by-hack, {len(modes)} modes)")
|
|
|
|
|
|
def main() -> None:
    """Single plotting entrypoint (`just plot`). Emits up to FOUR figures from
    one set of logs, reusing two parsers/owners:

        <stem>_by_method.png               per-mode, panel per method (this script's parse_hk)
        <stem>_by_hack.png                 per-mode, panel per hack   (this script's parse_hk)
        <stem>_aggregate.png               aggregate small-multiples  (plot_dynamics.plot)
        <stem>_aggregate_hack_overlay.png  arm-vs-arm hack overlay    (plot_dynamics)

    The aggregate pair is the "total hacks per arm" core plot. It is delegated to
    plot_dynamics (which owns the deploy-curve substitution for routing and the
    cos-alignment rows), NOT reimplemented here. --no-aggregate skips it (e.g. on
    logs without the cos_pre/deploy columns). Aggregate failures are logged rather
    than raised, so the per-mode figures survive them.

    Exits via SystemExit when none of the logs is a substrate run, and with an
    argparse usage error when --ema-span < 1.
    """
    ap = argparse.ArgumentParser(
        description=__doc__,
        # keep the module docstring's line layout (layout table, usage lines) in --help
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    ap.add_argument("logs", nargs="+", type=Path)
    ap.add_argument("--out-stem", type=Path, default=Path("out/figs/substrate"),
                    help="writes <stem>_by_method.png, _by_hack.png, _aggregate*.png")
    ap.add_argument("--cumulative", action="store_true",
                    help="running total of per-batch hack counts instead of the EMA-smoothed count")
    ap.add_argument("--ema-span", type=int, default=6,
                    help="EMA span in steps for the instantaneous curves (>= 1)")
    ap.add_argument("--no-aggregate", action="store_true",
                    help="skip the plot_dynamics aggregate + overlay figures")
    args = ap.parse_args()
    if args.ema_span < 1:
        ap.error(f"--ema-span must be >= 1, got {args.ema_span}")
    stem = args.out_stem

    # 1-2. Per-mode small multiples (this script owns these). Skip, rather than
    # crash on, logs that aren't multi-loophole substrate runs: the glob may catch
    # old single-mode/aborted runs. Log which were dropped so the skip isn't silent.
    parsed = {p: parse_hk(p) for p in args.logs}
    skipped = [p for p, r in parsed.items() if r is None]
    if skipped:
        logger.warning(f"skipped {len(skipped)} non-substrate log(s): "
                       + ", ".join(p.name for p in skipped))
    runs = [r for r in parsed.values() if r is not None]
    if not runs:
        raise SystemExit("no substrate runs in the glob (need hk_rt columns)")
    logger.info(f"parsed {len(runs)} runs: " + ", ".join(f"{r['method']}/s{r['seed']}" for r in runs))
    ylabel = "cumulative hacks" if args.cumulative else f"hacks/batch (EMA span {args.ema_span})"
    plot_by_method(runs, ylabel, args.cumulative, args.ema_span, stem.with_name(stem.name + "_by_method.png"))
    plot_by_hack(runs, ylabel, args.cumulative, args.ema_span, stem.with_name(stem.name + "_by_hack.png"))

    # 3-4. Aggregate "total hacks per arm" + hack overlay. These reuse plot_dynamics,
    # which owns route's deploy-curve substitution and the cos-alignment rows.
    # Non-fatal: the two per-mode figures above are the substrate deliverable.
    # plot_dynamics assumes the older erase/route column set (cin_t etc.) and
    # KeyErrors on a route2 log, so a delegation failure must not sink `just plot`.
    if not args.no_aggregate:
        try:
            import plot_dynamics as plot_dyn  # not `pd`: avoid reading as pandas
            agg_runs = [r for p in args.logs if (r := plot_dyn.parse_log(p))]
            if agg_runs:
                agg = stem.with_name(stem.name + "_aggregate.png")
                plot_dyn.plot(agg_runs, agg)
                plot_dyn.plot_hack_overlay(agg_runs, agg.with_name(agg.stem + "_hack_overlay.png"))
            else:
                logger.warning("no runs had aggregate columns (cos_pre/hack_s) -- skipped aggregate figs")
        except Exception as e:  # deliberate best-effort boundary (see comment above)
            logger.warning(f"aggregate delegation (plot_dynamics) failed, per-mode figs still written: {e!r}")


if __name__ == "__main__":
    main()
|