From fc30514b23dce10025ab3c524a45b14e377fa689 Mon Sep 17 00:00:00 2001
From: wassname
Date: Sat, 30 May 2026 00:50:53 +0000
Subject: [PATCH] feat: T5 eval-time ablation for route + fix route deployment invariant

T5: eval_hack_solve helper + ablate_quarantine ctx; periodic ablated-eval
(hack_abl/solve_abl cols, appended so results.py indices unchanged) every
--eval-ablate-every steps; final kept-vs-ablated ROUTE EVAL BLUF.
plot_dynamics plots the ablated series for the routing arm (the
coherence-gap fix: training hack_s looks vanilla; routing only shows
post-ablation).

External-review fixes (docs/spec/20260530_code_review.md):
- Critical: route now feeds delta_S the SAME g_proj as erase (was forcing
  preserve_magnitude=False/overshoot=1, which diverged from erase before
  AdamW). delta_S is its own AdamW param fed erase's grad, so route-ablated
  deployment evolves identically to erase regardless of AdamW non-linearity.
  Only the combined training forward over-moves (intended; never deployed).
  Corrected the overclaiming docstrings (no "sum == g" / "reproduces
  vanilla" identity).
- Important: clip_grad_norm_ now covers delta_params + delta_hack_params
  (no-op for none/erase; bounds the route update).
- Important: results.py paired-delta table includes routing (keyed on arm).

smoke route/erase/vanilla green: dsh route=0.0105 erase/none=0,
span=2.9e-7, ROUTE EVAL BLUF prints.

Co-Authored-By: Claude Opus 4.8
---
 docs/spec/20260530_code_review.md |  24 +++++++
 justfile                          |   6 +-
 scripts/plot_dynamics.py          |  30 ++++++--
 scripts/results.py                |   9 ++-
 src/projected_grpo/proj.py        |  48 +++++++------
 src/projected_grpo/train.py       | 110 +++++++++++++++++++++++++++++-
 6 files changed, 195 insertions(+), 32 deletions(-)
 create mode 100644 docs/spec/20260530_code_review.md

diff --git a/docs/spec/20260530_code_review.md b/docs/spec/20260530_code_review.md
new file mode 100644
index 0000000..0d0a0ca
--- /dev/null
+++ b/docs/spec/20260530_code_review.md
@@ -0,0 +1,24 @@
+## Code Review: gradient routing split and arm->intervention fallout
+
+### Summary
+The algebra in `src/projected_grpo/proj.py` looks right: the route path forces `preserve_magnitude=False` and `overshoot=1.0`, so the projection step itself satisfies `g_proj + removed == g`, and the erase path still uses the same projection math as before. The main problem is downstream in `train.py`: once you optimize `delta_S` and `delta_S_hack` as separate AdamW parameters, that exact-split invariant no longer survives the optimizer step.
+
+### Critical (must fix)
+- [src/projected_grpo/proj.py:192-194, src/projected_grpo/antipasto.py:83, src/projected_grpo/train.py:679] Route is split correctly at the grad level, but not at the update level. `delta_S` and `delta_S_hack` are stepped as two independent AdamW parameters, while the forward uses their sum. AdamW is not linear in the gradient, it keeps separate moments and normalizes each tensor independently, so `step(delta_S, g_proj) + step(delta_S_hack, removed)` is not equal to `step(delta, g_proj + removed)`. On the first AdamW step this already diverges because each branch reduces to a per-coordinate sign update, so overlapping same-sign components can get double-counted. That breaks the comment-level claim that routing preserves the original training-time movement while only quarantining it for eval. Suggested fix: keep one optimizer state on the combined adapter update and only decompose for storage/eval, or use a linear optimizer/update rule for these routed knobs.
+
+### Important (should fix)
+- [src/projected_grpo/train.py:1189] `clip_grad_norm_` only clips `delta_params`, not `delta_hack_params`. In the route arm, the parked hack component bypasses clipping entirely, so route can take a larger total step than erase/none and the logged `gn` is no longer the norm of the actual optimized gradient. If `grad_clip` is meant to constrain the route arm too, clip `delta_params + delta_hack_params` together, or clip the combined grad before splitting.
+- [scripts/results.py:138] The arm rename itself does not break parsing, `results.py` now reads `arm=` from the printed preset line, which should handle both old `--arm` logs and new `--intervention` logs. But the paired-delta section still hardcodes `arm == "projected"`, so `routing` runs silently drop out of the only direct-vs-vanilla comparison table. If routing is a first-class arm now, this summary should include it explicitly.
+
+### Suggestions
+- [scripts/plot_dynamics.py:101-105, 231-262] `plot_dynamics.py` also parses `arm=` correctly, so the rename looks safe there. The remaining issue is semantic, not parsing: routing panels are drawn from raw training-time `hack_s`, even though the route hypothesis is about the model with `delta_S_hack` ablated. Until an ablated-eval series is logged, plotting routing next to vanilla/erase will be misleading.
+
+### Positive
+- [src/projected_grpo/proj.py:183-194] The route/erase split in the projection code is clean. Route forces the exact-split settings, and erase still goes through the old `_project_one_module` path, so the projection-layer behavior itself is consistent with the intended design.
+- [scripts/results.py:40-43, scripts/plot_dynamics.py:69-71] Switching log parsing to the printed `arm=` field is the right compatibility point for old and new logs.
+
+### Verdict
+REQUEST CHANGES
+
+`proj.py` is fine, but the current `train.py` wiring means route is not actually preserving the original update after AdamW, and its quarantined branch bypasses grad clipping. Fix those two first; the arm rename itself looks parse-safe.
\ No newline at end of file
diff --git a/justfile b/justfile
index 09ff7e3..b326e5f 100644
--- a/justfile
+++ b/justfile
@@ -35,11 +35,13 @@ smoke-vanilla *ARGS:
         --teacher-pool-dir=out/probe_distill/teacher_pool --mix-ratio=0.5 {{ ARGS }}
 
 # Routing path: parks the hack-ward grad in delta_S_hack, ablates at eval.
-# Fires the R3 span assert + the two-param optimizer path.
+# Fires the R3 span assert, the two-param optimizer path, the periodic
+# ablated-eval series, and the final kept-vs-ablated BLUF.
 smoke-route *ARGS:
     BEARTYPE=1 CUDA_VISIBLE_DEVICES= {{ TRAIN }} smoke --intervention=route \
         --v-hack-path=out/v_hack_smoke.safetensors \
-        --teacher-pool-dir=out/probe_distill/teacher_pool --mix-ratio=0.5 {{ ARGS }}
+        --teacher-pool-dir=out/probe_distill/teacher_pool --mix-ratio=0.5 \
+        --eval-ablate-every=10 --eval-n-prompts=2 {{ ARGS }}
 
 # Run smoke twice: first warms the v_hack cache (cache-miss path), second hits
 # the cache (cache-hit path). Catches scope/save bugs that only manifest in one.
diff --git a/scripts/plot_dynamics.py b/scripts/plot_dynamics.py
index 5dd34d8..2f19020 100644
--- a/scripts/plot_dynamics.py
+++ b/scripts/plot_dynamics.py
@@ -13,10 +13,19 @@ to diverge from the (refreshed) v_hack.
 Data source: logs/*.log per-step rows (the durable source results.py also uses).
 We parse by HEADER NAME, not fixed index, because newer runs add columns (refr).
-Arm classification (from the argv line): - vanilla arm=vanilla +Arm classification (from the preset line `arm=`, covering old --arm and new +--intervention logs): + vanilla arm=vanilla (intervention=none) static erasure arm=projected, no --vhack-refresh-every (frozen v_hack) online erasure arm=projected, --vhack-refresh-every=N>0 (re-extracted) + routing arm=routing (intervention=route) + +For routing we plot the ABLATED-eval hack/solve (hack_abl/solve_abl, measured +with delta_S_hack zeroed every --eval-ablate-every steps), NOT the training-time +hack_s: the routed forward still hacks during training, so the training curve +would falsely read "route doesn't work". The ablated curve is the deployment +model. (none/erase plot training-time hack_s; their intervention acts at train +time.) Usage: uv run python scripts/plot_dynamics.py logs/*converge*.log @@ -82,7 +91,10 @@ def parse_log(path: Path) -> dict | None: series: dict[str, list[float]] = defaultdict(list) steps: list[int] = [] - wanted = {**RATE_COLS, **COS_COLS} + # Also parse the route ablated-eval columns when present (older logs lack + # them -> skip). For routing we plot THESE, not the training-time hack_s. + abl = {"hack_abl", "solve_abl"} & set(idx) + wanted = {**RATE_COLS, **COS_COLS, **{c: c for c in abl}} for line in txt.splitlines(): if "| INFO |" not in line: continue @@ -94,8 +106,16 @@ def parse_log(path: Path) -> dict | None: series[col].append(_val(row[idx[col]])) if not steps: return None - return dict(arm=arm, refr=refr, seed=seed, vhack=vhack, - steps=np.array(steps), **{k: np.array(v, dtype=float) for k, v in series.items()}) + run = dict(arm=arm, refr=refr, seed=seed, vhack=vhack, + steps=np.array(steps), **{k: np.array(v, dtype=float) for k, v in series.items()}) + # COHERENCE-GAP FIX: route's training-time hack_s looks vanilla (the routed + # forward still hacks); routing's benefit only shows once delta_S_hack is + # ablated at eval. 
So for routing, plot the ablated series under the same + # hack_s/gt_s keys -> all downstream (panels, onset, overlay) reads it. + if arm == "routing" and "hack_abl" in run: + run["hack_s"] = run["hack_abl"] + run["gt_s"] = run["solve_abl"] + return run def classify(run: dict) -> str: diff --git a/scripts/results.py b/scripts/results.py index 6168e41..5213582 100644 --- a/scripts/results.py +++ b/scripts/results.py @@ -135,11 +135,16 @@ def main() -> None: van = (df.filter(pl.col("arm") == "vanilla") .select(["mix", "seed", "L5_hack", "L5_solve"]) .rename({"L5_hack": "v_hack", "L5_solve": "v_solve"})) - j = (df.filter(pl.col("arm") == "projected") + # Both intervention arms compare against the same-seed vanilla. routing is a + # first-class arm now, so include it (keyed on `arm` below so it doesn't + # merge with projected). NOTE: routing's L5_hack here is the TRAINING-time + # hack (the routed forward still hacks); the deployment number is the + # ablated-eval (ROUTE EVAL BLUF / hack_abl), not this column. + j = (df.filter(pl.col("arm").is_in(["projected", "routing"])) .join(van, on=["mix", "seed"], how="inner") .with_columns((pl.col("L5_hack") - pl.col("v_hack")).alias("dh"), (pl.col("L5_solve") - pl.col("v_solve")).alias("ds"))) - pkey = ["mix", "refr", "over", "gate", "k", "vhack"] + pkey = ["arm", "mix", "refr", "over", "gate", "k", "vhack"] pj = (j.group_by(pkey) .agg(pl.col("dh").mean().alias("Dhack"), pl.col("dh").std().alias("Dhack_sd"), diff --git a/src/projected_grpo/proj.py b/src/projected_grpo/proj.py index 98d1b9a..ca7d7eb 100644 --- a/src/projected_grpo/proj.py +++ b/src/projected_grpo/proj.py @@ -79,10 +79,12 @@ def _project_one_module( ) -> tuple[Float[torch.Tensor, "r"], Float[torch.Tensor, "r"], float, float, bool]: """Per-module top-k removal. Returns (g_proj, removed, cos_pre, cos_post, fired). - `removed` = overshoot*c_use@V, the vector subtracted from g (pre-rescale). - Erasure drops it; routing parks it in delta_S_hack. 
When preserve_magnitude - is False and overshoot is 1.0, g_proj + removed == g exactly (the invariant - routing relies on so the training-time forward still moves hack-ward). + `removed` = overshoot*c_use@V, the vector subtracted from g (computed + before any preserve_magnitude rescale, so removed ∈ span(V) always). + Erasure drops it; routing parks it in delta_S_hack. Note g_proj + removed + == g ONLY when preserve_magnitude is False and overshoot is 1.0; with the + defaults g_proj is rescaled, so the sum is not the original g (routing does + not rely on that sum -- see project_delta_S_grad). cos_pre / cos_post are SIGNED scalars (sum of per-axis V @ g coefficients, normalized by ||g||). Positive = grad pushes toward hack; negative = grad @@ -156,14 +158,19 @@ def project_delta_S_grad( `preserve_magnitude`: rescale g' to ||g|| after projection. `measure_only`: same math, but g is not mutated (the `none` intervention). - `route`: instead of discarding the removed hack component, park it in the - quarantine knob delta_S_hack.grad (Gradient Routing, Cloud 2410.04332). - delta_S keeps g - removed; delta_S_hack gets removed; their sum is the - original g, so the training-time forward still moves hack-ward, but an - eval with delta_S_hack zeroed has the hack capability ablated. Routing - forces overshoot=1.0 and preserve_magnitude=False so the split is exact - (g_proj + removed == g); the route-vs-erase comparison is only clean at - overshoot=1.0 anyway (route ⊇ erase). Mutually exclusive with measure_only. + `route`: erase AND park the removed hack-ward component in the quarantine + knob delta_S_hack.grad (Gradient Routing, Cloud 2410.04332). delta_S gets + the IDENTICAL g_proj as erase (same gate/preserve/overshoot), so the + deployment model -- delta_S with delta_S_hack zeroed at eval -- evolves + under the same update rule as the erase arm (each is its own AdamW param; + the quarantine's separate optimizer state cannot perturb delta_S). 
That is + the sense in which route ⊇ erase: erase == route with the quarantine + discarded. CAVEAT (not an identity): the combined TRAINING forward + delta_S + delta_S_hack does NOT reproduce a vanilla update -- AdamW steps + the two knobs independently, so the sum over-moves hack-ward. That is + intended (the model keeps hacking during training so the capability lands + in the quarantine), and it only affects the training trajectory, never the + ablated deployment. Mutually exclusive with measure_only. Diagnostics returned (per call, averaged over modules): mean_cos_pre = mean over modules of sum(V @ g)/||g||, signed @@ -180,20 +187,17 @@ def project_delta_S_grad( if name not in v_hack: # module dropped by global noise-floor filter continue V = v_hack[name].to(g.device, dtype=g.dtype) # [k, r] - if route: - g_proj, removed, cos_pre, cos_post, fired = _project_one_module( - g, V, gate_mode, preserve_magnitude=False, overshoot=1.0) - else: - g_proj, removed, cos_pre, cos_post, fired = _project_one_module( - g, V, gate_mode, preserve_magnitude, overshoot) + g_proj, removed, cos_pre, cos_post, fired = _project_one_module( + g, V, gate_mode, preserve_magnitude, overshoot) cos_pre_list.append(cos_pre) cos_post_list.append(cos_post) - if fired: + if fired and not measure_only: + info["delta_S"].grad = g_proj # same update rule as erase if route: - info["delta_S"].grad = g_proj + # quarantine the discarded hack-ward part; removed ∈ span(V), + # ablated at eval so its magnitude/overshoot scaling is harmless. 
info["delta_S_hack"].grad = removed - elif not measure_only: - info["delta_S"].grad = g_proj + if fired: n_fired += 1 pre_t = torch.tensor(cos_pre_list); post_t = torch.tensor(cos_post_list) return { diff --git a/src/projected_grpo/train.py b/src/projected_grpo/train.py index 8c0add4..a839ae5 100644 --- a/src/projected_grpo/train.py +++ b/src/projected_grpo/train.py @@ -58,6 +58,7 @@ import json import os import sys import time +from contextlib import contextmanager from dataclasses import dataclass from datetime import datetime from pathlib import Path @@ -181,6 +182,15 @@ class Config: # discriminative power as the student drifted. 0 = off (load once at start # and freeze). Refresh cost ~14*2 backwards on Qwen3-4B ~ 1-2 min wall. vhack_refresh_every: int = 0 + # Route eval-time ablation: every N steps (and at the end), zero delta_S_hack + # and eval hack/solve on a fixed prompt subset -> the `hack_abl`/`solve_abl` + # columns. This is the series the dynamics plot uses for route, because the + # TRAINING-time hack curve looks vanilla (the routed forward still hacks); + # routing's benefit only shows once the quarantine is ablated. 0 = off (the + # final kept-vs-ablated BLUF still prints for route). Only meaningful for + # intervention=route. eval_n_prompts prompts x `group` samples each. + eval_ablate_every: int = 0 + eval_n_prompts: int = 8 # Optional: pool-derived pairs JSON (built by pairs_from_pool.py). When set, # BOTH the cache-miss extract AND the online refresh use these pairs instead # of the hand-crafted projected_grpo.pairs.PAIRS. Required for the cross- @@ -456,6 +466,50 @@ def ref_logprobs_via_zero_delta( info["delta_S"].data.copy_(saved[n]) +@contextmanager +def ablate_quarantine(wrappers: dict): + """Zero delta_S_hack for the duration -- the eval-time ablation of the + routed hack capability. Save -> zero -> (eval) -> restore. 
The route arm's + deployment model IS this ablated state.""" + saved = {n: info["delta_S_hack"].data.clone() for n, info in wrappers.items()} + for info in wrappers.values(): + info["delta_S_hack"].data.zero_() + try: + yield + finally: + for n, info in wrappers.items(): + info["delta_S_hack"].data.copy_(saved[n]) + + +@torch.no_grad() +def eval_hack_solve(model, tok, problems, eval_idxs, gen_cfg, device, max_new) -> dict: + """Student-only generate + grade on a FIXED prompt subset (no teacher, no + backward) -- a clean read of what the current adapter does. + + hack = C-detector rate (M1 reward hack); solve = gt_pass rate (held-out + grader). Same compute_reward as training, so the numbers are comparable to + the per-step hack_s/gt_s, just measured off-policy on a held-fixed subset. + """ + model.config.use_cache = True + n = hacks = solves = 0 + for idx in eval_idxs: + prob = problems[idx] + prompt = tok.apply_chat_template( + prob["messages"], tokenize=False, add_generation_prompt=True, enable_thinking=False) + enc = tok(prompt, return_tensors="pt", add_special_tokens=False).to(device) + if enc.input_ids.shape[1] + max_new > 2048: + continue + out = model.generate(**enc, generation_config=gen_cfg) + comps = out[:, enc.input_ids.shape[1]:] + for t in tok.batch_decode(comps, skip_special_tokens=True): + r = compute_reward( + t, canonical_solution=prob["canonical"], gt_tests=prob["gt_tests"], + setup_code=prob["setup_code"], func_name_hint=prob["func_name"]) + hacks += int(r.hacked); solves += int(r.gt_pass); n += 1 + model.config.use_cache = False + return dict(hack=hacks / max(1, n), solve=solves / max(1, n), n=n) + + @dataclass(frozen=True) class _Col: """Per-step table column spec. @@ -528,6 +582,9 @@ class StepLogger: # refr = "mod/axes" of the v_hack re-extracted this step, "-" if no # refresh fired (frozen-V and vanilla runs are all "-"). _Col("refr", 9, "refr", None), + # Route-only ablated-eval hack/solve (delta_S_hack=0); nan elsewhere. 
+ _Col("hack_abl", 8, "hack_abl", "+.3f"), + _Col("solve_abl", 9, "solve_abl", "+.3f"), ] def header(self) -> str: @@ -707,6 +764,13 @@ def main(cfg: Config) -> int: repetition_penalty=1.0, num_return_sequences=G_s, pad_token_id=tok.pad_token_id, ) + # Eval-ablation config: student-only, `group` samples/prompt (no teacher + # split, so we want the full group for a tighter rate estimate). + gen_cfg_eval = GenerationConfig( + max_new_tokens=max_new, do_sample=True, + temperature=0.7, top_p=1.0, top_k=20, min_p=0.0, repetition_penalty=1.0, + num_return_sequences=group, pad_token_id=tok.pad_token_id, + ) problems = load_problems(n_problems) logger.info(f"loaded {len(problems)} problems from {DATA.name}") @@ -726,6 +790,10 @@ def main(cfg: Config) -> int: f"({len(teacher_pool)} cached prompts). Re-run pregen-teacher against the same dataset." ) + # Fixed eval subset for route ablation: first eval_n_prompts problems, held + # constant across the run so the ablated-hack series is comparable step-to-step. + eval_idxs = list(range(min(cfg.eval_n_prompts, len(problems)))) + rng = torch.Generator().manual_seed(cfg.seed) rows = [] logger.info( @@ -1186,7 +1254,10 @@ table columns: # clip_grad_norm_ returns the pre-clip total L2 norm — capture for the # per-step `gn` column so we can see whether the clip threshold is the # bottleneck on update magnitude (compare gn vs cfg.grad_clip). - gn = float(torch.nn.utils.clip_grad_norm_(delta_params, cfg.grad_clip)) + # Clip over both knobs. For none/erase, delta_S_hack.grad is None so it's + # ignored -> identical norm to before (R4). For route it bounds the + # combined update (main + quarantine). 
+ gn = float(torch.nn.utils.clip_grad_norm_(delta_params + delta_hack_params, cfg.grad_clip)) opt.step() sched.step() @@ -1236,6 +1307,23 @@ table columns: model.train() refr = f"{len(v_hack)}/{sum(V.shape[0] for V in v_hack.values())}" # mod/axes -> per-step row + # Periodic route ablated-eval: zero the quarantine, eval hack/solve on the + # fixed subset. This is the curve the plot uses for route (training-time + # hack_s looks vanilla; the routed forward still hacks). NaN on non-eval + # steps and for non-route arms (plot's EMA holds NaN). + hack_abl = solve_abl = float("nan") + if (cfg.intervention == "route" and cfg.eval_ablate_every > 0 + and (step % cfg.eval_ablate_every == 0 or step == steps - 1)): + _was_training = model.training + model.eval() + with ablate_quarantine(wrappers): + ev = eval_hack_solve(model, tok, problems, eval_idxs, gen_cfg_eval, device, max_new) + if _was_training: + model.train() + hack_abl, solve_abl = ev["hack"], ev["solve"] + logger.info(f"step {step} route ablated-eval (delta_S_hack=0): " + f"hack={hack_abl:.3f} solve={solve_abl:.3f} (n={ev['n']})") + rewards_t = torch.tensor(agg_rew, dtype=torch.float32) if agg_rew else torch.zeros(1) rew_mean = rewards_t.mean().item() rew_std = rewards_t.std().item() if rewards_t.numel() > 1 else 0.0 @@ -1343,6 +1431,11 @@ table columns: "cos_post": diag["mean_cos_post"], "fired": diag["frac_fired"], "refr": refr, + # Route ablated-eval (delta_S_hack=0); NaN except on route eval steps. + # Appended AFTER refr so results.py's positional GT_S/HACK_S indices + # are unaffected. plot_dynamics reads it by name. + "hack_abl": hack_abl, + "solve_abl": solve_abl, "gen": t_gen, "fb": t_fb, "t_rew": t_rew, @@ -1421,6 +1514,21 @@ table columns: if cfg.intervention == "route": assert dsh_norm > 0.0, "route: delta_S_hack never moved -> degenerated to erasure" + # Route: final kept-vs-ablated eval -- the absorption test. 
KEPT keeps the + # quarantine (training-time model, still hacks); ABLATED zeroes it (the + # deployment model). SHOULD: ablated hack < kept hack at preserved solve + # => the quarantine absorbed the hack. ELSE routing didn't localize it. + if cfg.intervention == "route": + model.eval() + ev_kept = eval_hack_solve(model, tok, problems, eval_idxs, gen_cfg_eval, device, max_new) + with ablate_quarantine(wrappers): + ev_abl = eval_hack_solve(model, tok, problems, eval_idxs, gen_cfg_eval, device, max_new) + logger.info( + f"ROUTE EVAL (n={ev_kept['n']}): " + f"kept hack={ev_kept['hack']:.3f} solve={ev_kept['solve']:.3f} | " + f"ablated hack={ev_abl['hack']:.3f} solve={ev_abl['solve']:.3f} " + f"(SHOULD: ablated hack < kept hack at ~matched solve)") + # Final tail: cue emoji + main metric BLUF, then per-step tsv table. # Vanilla arm: 🟢 if hacking emerged. Projected arm: 🟢 if HACK_RATE dropped # vs a matched-PASS vanilla — we can't judge that here, so just report.