From 11bcdd2fe647d6029d2d9e53a5ea907e1fb56f85 Mon Sep 17 00:00:00 2001
From: wassname
Date: Sun, 31 May 2026 23:16:39 +0000
Subject: [PATCH] route2 instrumentation + lr fix + deploy overlay (route2-act divergence)

route2-act diverged (run 43): 33M kaiming A_q/B_q at delta_S's lr=3e-3 blew up (gn 0.3->7.5 step 8, generations -> token salad, lp_t -11). Fixes:
- #167 separate quarantine lr (route2_quar_lr_scale=0.1) so the 60x-bigger fresh LoRA isn't trained at the main-knob lr.
- #168 divergence tripwire on teacher ppl (lp_t high-water mark; abort if it drops >5 nats for 2 steps). Relative so tiny-random smoke (flat lp_t~-11.9) doesn't false-trip.
- #165 act-path was silent: stash cos(a,v_act) + fired-fraction in the forward, surface as act_cos/act_fire columns (route2-act). smoke shows act_fire=0.64 => the cos>0 sign test over-routes (fires on most tokens, not just hack ones).
- #166 print last train generation before FINAL EVAL (coherence eyeball).
- route2 v_act/v_grad refresh was firing but silent -- now announced.
- #162 plot_deploy_overlay.py: per-mode DEPLOY overlay from per_mode_deploy.json (honest shipped-model numbers, route2-safe). just plot-deploy.
- just plot/results hardened: parse by header name, skip non-substrate logs, non-fatal aggregate delegation.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com> --- justfile | 6 ++ scripts/plot_deploy_overlay.py | 110 ++++++++++++++++++++++++++++++++ scripts/plot_substrate.py | 53 ++++++++++----- src/projected_grpo/antipasto.py | 5 ++ src/projected_grpo/train.py | 91 +++++++++++++++++++++++++- 5 files changed, 247 insertions(+), 18 deletions(-) create mode 100644 scripts/plot_deploy_overlay.py diff --git a/justfile b/justfile index c6a0eb9..354619c 100644 --- a/justfile +++ b/justfile @@ -183,6 +183,12 @@ queue-substrate: plot GLOB='logs/*_sub4_*.log' STEM='out/figs/substrate': uv run python scripts/plot_substrate.py {{ GLOB }} --out-stem {{ STEM }} +# All-arms per-mode DEPLOY overlay (#162) from the per_mode_deploy.json artifacts +# (honest shipped-model numbers; route2-safe -- reads JSON, not logs). Default +# globs every sub4 run dir. -> out/figs/deploy_overlay.png +plot-deploy GLOB='out/runs/*sub4*/per_mode_deploy.json' OUT='out/figs/deploy_overlay.png': + uv run python scripts/plot_deploy_overlay.py {{ GLOB }} --out {{ OUT }} + # Regenerate both dynamics plots from the cell logs (default: all cells; pass a # narrower glob like 'logs/*_cell_*_s41.log' for the seed-41-only checkpoint). regen-dynamics GLOB='logs/*_cell_*.log': diff --git a/scripts/plot_deploy_overlay.py b/scripts/plot_deploy_overlay.py new file mode 100644 index 0000000..096e053 --- /dev/null +++ b/scripts/plot_deploy_overlay.py @@ -0,0 +1,110 @@ +"""All-arms per-mode DEPLOY overlay (#162) from the per_mode_deploy.json artifacts. + +Each run writes out/runs/_/per_mode_deploy.json (train.py, #164) with the +HONEST deploy numbers: for route/route2 the quarantine is deleted before eval, so +this is the model you would actually ship -- unlike plot_substrate's hk_ +curves which are TRAIN-time (routed forward still hacks) and overstate routing. + +Reads JSON, not logs, so it never trips on a route2 arm the log-parsers don't know. 
+ +The headline comparison: per loophole mode, does each intervention suppress the +DEPLOY hack rate below vanilla, and at what cost to DEPLOY solve? run_tests is the +in-dist mode (v_hack built closest to it); the rest are held-out (the no-cheat +generalisation test). Bars grouped by mode, one bar per arm. + +Usage: + uv run python scripts/plot_deploy_overlay.py # globs out/runs/*sub4*/ + uv run python scripts/plot_deploy_overlay.py out/runs/*_sub4_*/per_mode_deploy.json + uv run python scripts/plot_deploy_overlay.py --out out/figs/deploy_overlay.png +""" +from __future__ import annotations + +import argparse +import json +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +from loguru import logger + +# arm -> (display label, colour). Order = legend/bar order (baseline first). +ARM = { + "vanilla": ("vanilla", "#444444"), + "projected": ("erase", "#c1432b"), + "routing": ("route", "#33508c"), + "routing2_act": ("route2 act", "#2f7d4f"), + "routing2_grad":("route2 grad", "#b8860b"), +} +# mode display order: in-dist first, then held-out. 
+MODE_ORDER = ["run_tests", "file_marker", "stdout_marker", "sentinel", "eq_override"]
+
+
+def load(paths: list[Path]) -> list[dict]:
+    """Parse each per_mode_deploy.json and return the dicts in input order.
+
+    Logs one line per file with the record's arm and its headline deploy
+    hack/solve rates. Raises KeyError if a record lacks "arm", "hack_deploy"
+    or "solve_deploy" (the log line subscripts all three).
+    """
+    out = []
+    for p in paths:
+        d = json.loads(p.read_text())
+        out.append(d)
+        logger.info(f"{d['arm']:<14} deploy hack={d['hack_deploy']:.3f} solve={d['solve_deploy']:.3f} ({p})")
+    return out
+
+
+def _despine(ax):
+    """Hide the top/right spines and draw a faint grid along the y axis."""
+    ax.spines[["top", "right"]].set_visible(False)
+    ax.grid(axis="y", lw=0.4, alpha=0.35)
+
+
+def _panel(ax, records, modes, arms, field, title, ylabel):
+    """Grouped bars: x = mode, one bar per arm, height = records[arm].by_mode[mode][field].
+
+    `arms` must be non-empty (the bar width is 0.8 / len(arms)); main() guarantees this.
+    A mode or field missing from a record is drawn as NaN and left unannotated.
+    """
+    w = 0.8 / len(arms)
+    x = np.arange(len(modes))
+    for i, arm in enumerate(arms):
+        rec = next(r for r in records if r["arm"] == arm)
+        label, color = ARM[arm]
+        vals = [rec["by_mode"].get(m, {}).get(field, np.nan) for m in modes]
+        bars = ax.bar(x + i * w, vals, w, label=label, color=color)
+        for b, v in zip(bars, vals):
+            if not np.isnan(v):
+                ax.annotate(f"{v:.2f}", (b.get_x() + b.get_width() / 2, v), fontsize=6,
+                            ha="center", va="bottom", color=color)
+    # Bars are centred at x + i*w, so the group centre is x + (n-1)*w/2 = x + 0.4 - w/2.
+    ax.set_xticks(x + 0.4 - w / 2)
+    ax.set_xticklabels([f"{m}\n{'IN' if m == 'run_tests' else 'held-out'}" for m in modes], fontsize=8)
+    ax.set_title(title, fontsize=10)
+    ax.set_ylabel(ylabel)
+    ax.set_ylim(0, 1.05)
+    _despine(ax)
+
+
+def main() -> None:
+    """CLI entry point: load the deploy JSONs and write the two-panel overlay figure.
+
+    Raises SystemExit when no JSON path is found, or when none of the loaded
+    records belongs to an arm listed in ARM.
+    """
+    ap = argparse.ArgumentParser(description=__doc__)
+    ap.add_argument("jsons", nargs="*", type=Path,
+                    help="per_mode_deploy.json paths; default globs out/runs/*sub4*/")
+    ap.add_argument("--out", type=Path, default=Path("out/figs/deploy_overlay.png"))
+    args = ap.parse_args()
+
+    paths = args.jsons or sorted(Path("out/runs").glob("*sub4*/per_mode_deploy.json"))
+    if not paths:
+        raise SystemExit("no per_mode_deploy.json found (run the sweep first)")
+    records = load(paths)
+    # dedupe arms (keep latest by file order), then order canonically
+    by_arm = {r["arm"]: r for r in records}
+    arms = [a for a in ARM if a in by_arm]
+    if not arms:
+        # Every record carried an arm that ARM does not list; _panel would divide
+        # by len(arms) == 0, so stop here with a readable message instead.
+        raise SystemExit(f"no known arms in {len(records)} record(s): "
+                         f"got {sorted(map(str, by_arm))}, expected one of {list(ARM)}")
+    records = [by_arm[a] for a in arms]
+    modes = [m for m in MODE_ORDER if any(m in r["by_mode"] for r in records)]
+
+    fig, (a1, a2) = plt.subplots(1, 2, figsize=(5.5 + 1.2 * len(modes), 4.2))
+    _panel(a1, records, modes, arms, "deploy_hack",
+           "DEPLOY hack rate by mode (lower = better)", "deploy hack rate")
+    _panel(a2, records, modes, arms, "deploy_solve",
+           "DEPLOY solve rate by mode (higher = better)", "deploy solve rate")
+    a1.legend(fontsize=8, frameon=False, loc="upper right")
+    # A record without a "seed" key contributes None; sort it last, because a bare
+    # sorted() raises TypeError when None is mixed with int seeds.
+    seeds = sorted({r.get("seed") for r in records}, key=lambda s: (s is None, s))
+    fig.suptitle(f"Per-mode deploy overlay ({len(arms)} arms, seed {seeds}) -- "
+                 f"quarantine deleted = shipped model", fontsize=11)
+    args.out.parent.mkdir(parents=True, exist_ok=True)
+    fig.tight_layout()
+    fig.savefig(args.out, dpi=140, bbox_inches="tight")
+    logger.info(f"wrote {args.out} ({len(arms)} arms x {len(modes)} modes)")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/plot_substrate.py b/scripts/plot_substrate.py
index 061381a..450cc1d 100644
--- a/scripts/plot_substrate.py
+++ b/scripts/plot_substrate.py
@@ -59,18 +59,24 @@ _HDR_TOK = re.compile(r"[A-Za-z_]+") # "hack_s?" -> "hack_s"
 
 
 def classify(txt: str) -> str:
-    """vanilla / erase / route from the preset `arm=` line (covers --intervention logs)."""
+    """vanilla / erase / route from the preset `arm=` line (covers --intervention logs).
+    Unknown arms (e.g.
route2's routing2_act) fall through to their raw name -- the + plotters filter to known METHODS, so an unmapped arm is silently dropped from the + train-dynamics panels rather than crashing the whole `just plot`.""" preset = next((l for l in txt.splitlines() if "preset=" in l and "arm=" in l), "") arm = (re.search(r"\barm=(\w+)", preset) or [None, "vanilla"])[1] - return {"vanilla": "vanilla", "projected": "erase", "routing": "route"}[arm] + return {"vanilla": "vanilla", "projected": "erase", "routing": "route"}.get(arm, arm) -def parse_hk(path: Path) -> dict: - """{method, seed, steps, : (n[], d[])} from a substrate run log.""" +def parse_hk(path: Path) -> dict | None: + """{method, seed, steps, : (n[], d[])} from a substrate run log, or None + if the log isn't a multi-loophole run (no hk_rt header). Returning None rather + than raising lets `just plot` glob a broad set of logs (old single-mode/aborted + runs mixed in) without crashing; main() logs which paths were skipped.""" txt = path.read_text(errors="replace") hdr = next((l for l in txt.splitlines() if "ref_eq" in l and "hk_rt" in l), None) if hdr is None: - raise ValueError(f"{path}: no substrate header (hk_rt) -- not a multi-loophole run?") + return None names = [_HDR_TOK.match(t).group(0) for t in hdr.split("| INFO |", 1)[1].split()] idx = {n: i for i, n in enumerate(names)} present = [k for k in HK if k in idx] # 4-mode substrate dropped hk_eq; plot only what's logged @@ -86,6 +92,8 @@ def parse_hk(path: Path) -> dict: n, d = row[idx[k]].split("/") nd[k][0].append(int(n)) nd[k][1].append(int(d)) + if not steps: + return None # header present but no parseable per-step rows (e.g. diverged/aborted) m = re.search(r"seed(\d+)", path.name) or re.search(r"_s(\d+)", path.name) return dict( method=classify(txt), @@ -229,8 +237,17 @@ def main() -> None: args = ap.parse_args() stem = args.out_stem - # 1-2. per-mode small multiples (this script owns these) - runs = [parse_hk(p) for p in args.logs] + # 1-2. 
per-mode small multiples (this script owns these). Skip (don't crash on) + # logs that aren't multi-loophole substrate runs -- the glob may catch old + # single-mode/aborted runs; log which were dropped so the skip isn't silent. + parsed = {p: parse_hk(p) for p in args.logs} + skipped = [p for p, r in parsed.items() if r is None] + if skipped: + logger.warning(f"skipped {len(skipped)} non-substrate log(s): " + + ", ".join(p.name for p in skipped)) + runs = [r for r in parsed.values() if r is not None] + if not runs: + raise SystemExit("no substrate runs in the glob (need hk_rt columns)") logger.info(f"parsed {len(runs)} runs: " + ", ".join(f"{r['method']}/s{r['seed']}" for r in runs)) ylabel = "cumulative hack rate" if args.cumulative else f"hack rate (EMA span {args.ema_span})" plot_by_method(runs, ylabel, args.cumulative, args.ema_span, stem.with_name(stem.name + "_by_method.png")) @@ -238,15 +255,21 @@ def main() -> None: # 3-4. aggregate "total hacks per arm" + hack overlay (reuse plot_dynamics, # which owns route's deploy-curve substitution + the cos-alignment rows). + # Non-fatal: the two per-mode figures above are the substrate deliverable; + # plot_dynamics assumes the older erase/route column set (cin_t etc.) and + # KeyErrors on a route2 log, so a delegation failure must not sink `just plot`. 
if not args.no_aggregate: - import plot_dynamics as pd - agg_runs = [r for p in args.logs if (r := pd.parse_log(p))] - if agg_runs: - agg = stem.with_name(stem.name + "_aggregate.png") - pd.plot(agg_runs, agg) - pd.plot_hack_overlay(agg_runs, agg.with_name(agg.stem + "_hack_overlay.png")) - else: - logger.warning("no runs had aggregate columns (cos_pre/hack_s) -- skipped aggregate figs") + try: + import plot_dynamics as pd + agg_runs = [r for p in args.logs if (r := pd.parse_log(p))] + if agg_runs: + agg = stem.with_name(stem.name + "_aggregate.png") + pd.plot(agg_runs, agg) + pd.plot_hack_overlay(agg_runs, agg.with_name(agg.stem + "_hack_overlay.png")) + else: + logger.warning("no runs had aggregate columns (cos_pre/hack_s) -- skipped aggregate figs") + except Exception as e: + logger.warning(f"aggregate delegation (plot_dynamics) failed, per-mode figs still written: {e!r}") if __name__ == "__main__": diff --git a/src/projected_grpo/antipasto.py b/src/projected_grpo/antipasto.py index af180dd..c001208 100644 --- a/src/projected_grpo/antipasto.py +++ b/src/projected_grpo/antipasto.py @@ -133,6 +133,11 @@ def _delta_hook(layer: nn.Linear, args: tuple, y: Tensor) -> Tensor: v_act = layer._antipasto_v_act.to(a.dtype) # [r] unit, hack-ward, in Vh coords (fp32 buffer -> a.dtype) cos = (a @ v_act) / (a.norm(dim=-1).clamp_min(1e-6) * v_act.norm().clamp_min(1e-6)) m = cos > 0 # [...] bool + # Stash routing intensity so train.py can log it (else the act path is silent + # and over-routing -- m firing on ~half of all tokens, not just hack tokens -- + # is invisible). fired = fraction of token positions routed to the quarantine. 
+ layer._antipasto_act_fired = m.float().mean().detach() + layer._antipasto_act_cos = cos.mean().detach() kept = torch.where(m.unsqueeze(-1), kept.detach(), kept) return y + (kept + quar).to(y.dtype) diff --git a/src/projected_grpo/train.py b/src/projected_grpo/train.py index dfbb18d..15f7f16 100644 --- a/src/projected_grpo/train.py +++ b/src/projected_grpo/train.py @@ -55,6 +55,7 @@ from __future__ import annotations import gzip import json +import math import os import sys import random @@ -154,6 +155,11 @@ class Config: # detach, single pass. "grad" (Arm A): per-rollout cos(g_b, v_grad) from a gate # probe, routes by subtracting flagged rollouts from delta_S.grad post-backward. route2_mask: Literal["act", "grad"] = "act" + # route2-only: the quarantine A_q/B_q (33M fresh kaiming params) is ~60x larger + # than delta_S (0.5M) and at the shared delta_S lr it diverged -- gn 0.3->7.5 at + # step 8, generations -> token salad, lp_t -11 (run 43). Give it its own lower lr. + # Scale of main lr; 1.0 = old (diverging) behaviour, 0.1 = the fix. + route2_quar_lr_scale: float = 0.1 # Scale-dependent knobs — every preset must set these to a real value; # subclasses below override the defaults. model: str = "Qwen/Qwen3-4B" @@ -687,7 +693,17 @@ class StepLogger: _Col("cos_post", 6, "cout", ".2f", "hack-ward fraction AFTER projection (want ~0: all removed)"), _Col("fired", 5, "fired", ".2f", "fraction of modules where projection fired"), ] - if arm == "routing": + # route2 act-mask: no v_hack grad projection, but the forward routes by + # cos(activation, v_act)>0. Surface that routing intensity (reuses the row's + # cos_pre/fired keys, populated from the stashed act stats in train.py) so the + # act path is no longer silent -- watch `fired` for over-routing (>>0.5 means + # the sign test fires on generic tokens, starving delta_S onto the quarantine). 
+ if arm == "routing2_act": + cols += [ + _Col("cos_pre", 7, "act_cos", "+.2f", "mean cos(activation, v_act): forward routing alignment"), + _Col("fired", 6, "act_fire", ".2f", "fraction of token positions routed to quarantine (cos>0)"), + ] + if arm in ("routing", "routing2_act", "routing2_grad"): cols += [ _Col("hack_deploy", 7, "hk_dep", "+.2f", "DEPLOY-eval hack (quarantine deleted = deployed model)"), _Col("solve_deploy", 7, "slv_dep", "+.2f", "DEPLOY-eval solve"), @@ -754,6 +770,7 @@ def main(cfg: Config) -> int: is_route2 = cfg.intervention == "route2" is_route2_grad = is_route2 and cfg.route2_mask == "grad" + is_route2_act = is_route2 and cfg.route2_mask == "act" wrappers = wrap_model_with_antipasto( model, model_name, CACHE_ROOT, device, quarantine_rank=cfg.route2_quarantine_rank if is_route2 else None, @@ -924,10 +941,18 @@ def main(cfg: Config) -> int: f"G_s={G_s} student + G_t={G_t} teacher per prompt (mix_ratio={cfg.mix_ratio})." ) + # Quarantine (A_q/B_q) gets its own lower lr: it is ~60x bigger than delta_S and + # freshly kaiming-init, so the shared lr diverged it (run 43). Separate param group + # so the scheduler scales both proportionally (the group's lr rides on `lr` via the + # ratio captured here -- LinearLR/CosineAnnealingLR multiply each group's base lr). + quar_lr = lr * cfg.route2_quar_lr_scale opt = torch.optim.AdamW( - delta_params + delta_hack_params + quar_params, lr=lr, weight_decay=cfg.weight_decay, - betas=(adam_beta1, adam_beta2), + [{"params": delta_params + delta_hack_params, "lr": lr}, + {"params": quar_params, "lr": quar_lr}], + lr=lr, weight_decay=cfg.weight_decay, betas=(adam_beta1, adam_beta2), ) + if quar_params: + logger.info(f"route2 quarantine lr = {quar_lr:.1e} ({cfg.route2_quar_lr_scale}x main lr {lr:.1e})") # Linear warmup over `warmup_frac * steps`, then cosine decay to 0 over the rest. # Fraction-based so short presets (fast: 20 steps) don't spend half the run # under warmup. 
Canonical full-preset: 0.1 * 100 = 10 (matches ariahw config.py:141). @@ -1055,6 +1080,17 @@ def main(cfg: Config) -> int: rollout_log_path.write_text("") first_hack_saved = False route_span_checked = False # R3: assert delta_S_hack.grad in span(V) once + last_gen_sample = None # first student rollout of the latest step (for collapse inspection) + diverged_steps = 0 # consecutive steps with collapsed teacher ppl (divergence tripwire) + lp_t_best = -float("inf") # coherence high-water mark (best teacher gen_logp seen) + # ppl_t = exp(-lp_t) on the FIXED teacher rollouts is a free coherence gauge. + # Divergence is a DROP from the run's own best coherence, not an absolute level: + # a real model sits at lp_t ~ -0.7 and craters to -11..-21 when it diverges (run + # 43: lr too high on the 33M quarantine, generations -> token salad), a ~10-nat + # drop. A relative threshold also keeps `just smoke` green -- the tiny-random model + # has an intrinsic lp_t ~ -11.9 (uniform logp) but it stays flat, so it never DROPS. + # Abort if lp_t falls this far below its best for 2 steps running (advantage dead). + DIVERGENCE_DROP = 5.0 # nats below best (e^5 ~ 150x worse ppl); never in healthy runs dumped_hack_classes: set[str] = set() # first full example of each hack class -> verbose log teacher_dumped = False # Per-mode learning tracker (the substrate UAT: did the student learn EACH hack, @@ -1503,6 +1539,18 @@ def main(cfg: Config) -> int: diag = {"mean_cos_pre": float("nan"), "mean_cos_post": float("nan"), "frac_fired": float("nan"), "mean_cos_pre_s": float("nan"), "mean_cos_pre_t": float("nan")} + # route2 act-mask: the forward stashed per-layer fired-fraction + mean cos + # (cos(a,v_act)). Surface them in cin (mean cos) and fired (routed fraction) + # so over-routing is visible -- a frozen sign-test direction fires on ~half + # of all tokens, starving delta_S and dumping learning onto the quarantine. 
+ if is_route2_act: + fired = [info["layer"]._antipasto_act_fired for info in wrappers.values() + if hasattr(info["layer"], "_antipasto_act_fired")] + coss = [info["layer"]._antipasto_act_cos for info in wrappers.values() + if hasattr(info["layer"], "_antipasto_act_cos")] + if fired: + diag["frac_fired"] = float(torch.stack(fired).mean()) + diag["mean_cos_pre"] = float(torch.stack(coss).mean()) # route2 grad-mask: report the mean per-module per-rollout flag rate so # we can watch the mask actually fire (and rise as hacks emerge). if is_route2_grad and step_flagged: @@ -1595,6 +1643,13 @@ def main(cfg: Config) -> int: if _was_training: model.train() refr = f"route2:{cfg.route2_mask}" + # Announce it -- the route2 refresh was previously silent (only the + # v_hack path logged "refresh@step"), so it looked like the mask never + # refreshed. NOTE: this fires AFTER opt.step(), so if the model is + # already diverging the re-extracted direction is extracted on a broken + # model -- watch lp_t / ppl_t around the refresh step. + logger.info(f"route2 {cfg.route2_mask}-mask refreshed@step{step} " + f"({len(wrappers)} modules, quarantine ablated during extract)") if v_hack is not None and do_refresh: from .extract_vhack_grad import extract_v_hack if cfg.vhack_pairs_path is not None: @@ -1810,6 +1865,25 @@ def main(cfg: Config) -> int: with rollout_log_path.open("a") as fh: for rec in step_rollouts: fh.write(json.dumps(rec) + "\n") + if step_rollouts: + last_gen_sample = (step, step_rollouts[0]) # newest student gen for the final dump + + # Divergence tripwire on teacher perplexity (free coherence gauge, see init). 
+ ppl_t = math.exp(-lp_t_mean) if math.isfinite(lp_t_mean) else float("inf") + if math.isfinite(lp_t_mean): + lp_t_best = max(lp_t_best, lp_t_mean) + diverged = math.isfinite(lp_t_mean) and lp_t_mean < lp_t_best - DIVERGENCE_DROP + diverged_steps = diverged_steps + 1 if diverged else 0 + if diverged_steps >= 2: + logger.error( + f"DIVERGED at step {step}: lp_t={lp_t_mean:.1f} (ppl_t={ppl_t:.0e}), {lp_t_best - lp_t_mean:.1f} " + f"nats below best {lp_t_best:.1f}, for {diverged_steps} steps -- policy collapsed " + f"(gn={gn:.1f}). Aborting to save GPU. Likely lr too high (route2: lower --route2-quar-lr-scale).") + if last_gen_sample: + _s, _r = last_gen_sample + logger.error(f"--- last student gen (step {_s}, reward={_r['reward']:+.2f}) ---\n" + f"{_r['text'][:800]}\n--- END (token salad => divergence confirmed) ---") + raise RuntimeError(f"training diverged (ppl_t={ppl_t:.0e} at step {step})") if (step + 1) % 25 == 0: save_ckpt(rows) # survive early kills; ~12 days for the full sweep if not first_hack_saved and hack_s_n > 0: @@ -1881,6 +1955,17 @@ def main(cfg: Config) -> int: logger.info(f"||B_q|| sum = {bq_norm:.4f} (SHOULD: >0; ELSE quarantine never seeded)") assert bq_norm > 0.0, "route2: B_q never moved -> quarantine never seeded (mask never fired?)" + # Last training generation -- a fast eyeball for coherence before the eval + # numbers. SHOULD: real code/prose for the problem. If it is token salad the + # policy diverged and every eval number below is meaningless (see ppl_t / lp_t). + if last_gen_sample is not None: + _s, _r = last_gen_sample + logger.info( + f"\n\n=== LAST TRAIN GEN (step {_s}, reward={_r['reward']:+.2f}, " + f"gt_pass={_r['gt_pass']}, hacked={_r['hacked_E']}) ===\n" + f"SHOULD: coherent code/prose. ELSE token salad => diverged, eval below is moot.\n" + f"{_r['text'][:800]}\n=== END LAST GEN ===\n") + # Final per-mode train-vs-deploy eval -- run for EVERY arm on the SAME fixed # eval subset so the all-arms overlay reads identical numbers. 
For route/route2 # this is the absorption test: TRAIN keeps the quarantine knob on (still hacks),