feat: T5 eval-time ablation for route + fix route deployment invariant

T5: eval_hack_solve helper + ablate_quarantine ctx; periodic ablated-eval
(hack_abl/solve_abl cols, appended so results.py indices unchanged) every
--eval-ablate-every steps; final kept-vs-ablated ROUTE EVAL BLUF. plot_dynamics
plots the ablated series for the routing arm (the coherence-gap fix: training
hack_s looks vanilla; routing only shows post-ablation).

External-review fixes (docs/spec/20260530_code_review.md):
- Critical: route now feeds delta_S the SAME g_proj as erase (was forcing
  preserve_magnitude=False/overshoot=1, which diverged from erase before AdamW).
  delta_S is its own AdamW param fed erase's grad, so route-ablated deployment
  evolves identically to erase regardless of AdamW non-linearity. Only the
  combined training forward over-moves (intended; never deployed). Corrected the
  overclaiming docstrings (no "sum == g" / "reproduces vanilla" identity).
- Important: clip_grad_norm_ now covers delta_params + delta_hack_params
  (no-op for none/erase; bounds the route update).
- Important: results.py paired-delta table includes routing (keyed on arm).

smoke route/erase/vanilla green: dsh route=0.0105 erase/none=0, span=2.9e-7,
ROUTE EVAL BLUF prints.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
wassname
2026-05-30 00:50:53 +00:00
parent d6342ab201
commit fc30514b23
6 changed files with 195 additions and 32 deletions
+24
@@ -0,0 +1,24 @@
## Code Review: gradient routing split and arm->intervention fallout
### Summary
The algebra in `src/projected_grpo/proj.py` looks right: the route path forces `preserve_magnitude=False` and `overshoot=1.0`, so the projection step itself satisfies `g_proj + removed == g`, and the erase path still uses the same projection math as before. The main problem is downstream in `train.py`: once you optimize `delta_S` and `delta_S_hack` as separate AdamW parameters, that exact-split invariant no longer survives the optimizer step.
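The exact-split identity can be checked on a toy vector. This is a minimal sketch, assuming a single unit-norm hack direction `v` and no gating; the real `_project_one_module` also handles top-k axes and gate modes, which are omitted here. `project` is a hypothetical stand-in, not the repository function.

```python
import math

def project(g, v, overshoot=1.0, preserve_magnitude=False):
    """Toy stand-in for the per-module projection (assumption: one unit-norm v)."""
    coeff = sum(gi * vi for gi, vi in zip(g, v))       # <g, v>
    removed = [overshoot * coeff * vi for vi in v]      # component along v
    g_proj = [gi - ri for gi, ri in zip(g, removed)]    # what the main knob keeps
    if preserve_magnitude:
        norm_g = math.sqrt(sum(x * x for x in g))
        norm_p = math.sqrt(sum(x * x for x in g_proj))
        if norm_p > 0:
            g_proj = [x * norm_g / norm_p for x in g_proj]  # rescale to ||g||
    return g_proj, removed

g = [3.0, 4.0]
v = [1.0, 0.0]

# Exact-split settings: preserve_magnitude=False, overshoot=1.0.
exact_proj, exact_removed = project(g, v)
exact_sum = [a + b for a, b in zip(exact_proj, exact_removed)]

# With magnitude preservation the kept part is rescaled, so the sum drifts.
rescaled_proj, rescaled_removed = project(g, v, preserve_magnitude=True)
rescaled_sum = [a + b for a, b in zip(rescaled_proj, rescaled_removed)]

print(exact_sum)     # same as g
print(rescaled_sum)  # no longer g
```

In this toy case the identity holds only under the exact-split settings; once the kept part is rescaled, `g_proj + removed` no longer reconstructs `g`.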
### Critical (must fix)
- [src/projected_grpo/proj.py:192-194, src/projected_grpo/antipasto.py:83, src/projected_grpo/train.py:679] Route is split correctly at the grad level, but not at the update level. `delta_S` and `delta_S_hack` are stepped as two independent AdamW parameters, while the forward uses their sum. AdamW is not linear in the gradient: it keeps separate moments and normalizes each tensor independently, so `step(delta_S, g_proj) + step(delta_S_hack, removed)` is not equal to `step(delta, g_proj + removed)`. On the first AdamW step this already diverges, because each branch reduces to a per-coordinate sign update and overlapping same-sign components get double-counted. That breaks the comment-level claim that routing preserves the original training-time movement while only quarantining it for eval. Suggested fix: keep one optimizer state on the combined adapter update and only decompose for storage/eval, or use a linear optimizer/update rule for these routed knobs.
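The non-linearity is easy to see on the first step. This is a minimal sketch, assuming a plain Adam update from zero moments with weight decay omitted; `adam_first_step` is a hypothetical helper written for illustration, not code from `train.py`.

```python
import math

def adam_first_step(grad, lr=0.1, beta1=0.9, beta2=0.999, eps=1e-8):
    """Parameter change from one Adam step starting at zero moments."""
    out = []
    for g in grad:
        m = (1 - beta1) * g
        v = (1 - beta2) * g * g
        m_hat = m / (1 - beta1)   # bias correction at t=1
        v_hat = v / (1 - beta2)
        out.append(-lr * m_hat / (math.sqrt(v_hat) + eps))  # ~ -lr * sign(g)
    return out

# Orthogonal split of g = [3, 1] along the direction [1, 1] / sqrt(2).
g_proj = [1.0, -1.0]
removed = [2.0, 2.0]
g = [a + b for a, b in zip(g_proj, removed)]

# Two independently stepped knobs, summed in the forward pass.
split = [a + b for a, b in zip(adam_first_step(g_proj), adam_first_step(removed))]
# One knob stepped on the full gradient.
combined = adam_first_step(g)

print([round(x, 4) for x in split])
print([round(x, 4) for x in combined])
```

Coordinate 0 has the same sign in both parts, so the split path moves it roughly twice as far as the combined path; coordinate 1 has opposite signs and nearly cancels.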
### Important (should fix)
- [src/projected_grpo/train.py:1189] `clip_grad_norm_` only clips `delta_params`, not `delta_hack_params`. In the route arm, the parked hack component bypasses clipping entirely, so route can take a larger total step than erase/none and the logged `gn` is no longer the norm of the actual optimized gradient. If `grad_clip` is meant to constrain the route arm too, clip `delta_params + delta_hack_params` together, or clip the combined grad before splitting.
- [scripts/results.py:138] The arm rename itself does not break parsing: `results.py` now reads `arm=` from the printed preset line, which should handle both old `--arm` logs and new `--intervention` logs. But the paired-delta section still hardcodes `arm == "projected"`, so `routing` runs silently drop out of the only direct-vs-vanilla comparison table. If routing is a first-class arm now, this summary should include it explicitly.
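The clipping gap can be illustrated without the training loop. This is a minimal sketch, assuming the usual global-norm rule (scale every gradient by `max_norm / (total_norm + 1e-6)`, capped at 1); `clip_global_norm` and the toy gradients are hypothetical, standing in for `delta_params` and `delta_hack_params`.

```python
import math

def l2(vectors):
    """L2 norm over a list of gradient vectors, treated as one flat vector."""
    return math.sqrt(sum(x * x for v in vectors for x in v))

def clip_global_norm(grads, max_norm):
    """Return (pre-clip norm, rescaled grads) under a global-norm clip."""
    total = l2(grads)
    scale = min(1.0, max_norm / (total + 1e-6))
    return total, [[x * scale for x in g] for g in grads]

main = [0.6, 0.8]   # stand-in for the delta_S gradient
hack = [3.0, 4.0]   # stand-in for the parked delta_S_hack gradient

# Clip only the main knob: the quarantine gradient passes through untouched.
_, clipped_main = clip_global_norm([main], max_norm=1.0)
norm_main_only = l2(clipped_main + [hack])

# Clip both knobs together: the combined gradient is bounded.
gn, clipped_joint = clip_global_norm([main, hack], max_norm=1.0)
norm_joint = l2(clipped_joint)

print(round(norm_main_only, 3), round(norm_joint, 3), round(gn, 3))
```

With joint clipping, the returned pre-clip norm also matches the gradient that is actually optimized, which is what a logged `gn` column is meant to show.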
### Suggestions
- [scripts/plot_dynamics.py:101-105, 231-262] `plot_dynamics.py` also parses `arm=` correctly, so the rename looks safe there. The remaining issue is semantic, not parsing: routing panels are drawn from raw training-time `hack_s`, even though the route hypothesis is about the model with `delta_S_hack` ablated. Until an ablated-eval series is logged, plotting routing next to vanilla/erase will be misleading.
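One way to make the routing panels honest is to select the series at parse time. This is a minimal sketch, assuming the ablated-eval rate is logged under a column such as `hack_abl` and is NaN on steps where no ablated eval ran; `series_for_plot` and the toy run dict are hypothetical.

```python
import math

def series_for_plot(run):
    """Pick the hack-rate series that reflects the deployed model."""
    use_ablated = run["arm"] == "routing" and "hack_abl" in run
    values = run["hack_abl"] if use_ablated else run["hack_s"]
    # Drop steps with no measurement so sparse eval points still plot cleanly.
    pairs = [(s, v) for s, v in zip(run["steps"], values) if not math.isnan(v)]
    return [s for s, _ in pairs], [v for _, v in pairs]

nan = float("nan")
routing_run = {
    "arm": "routing",
    "steps": [0, 1, 2, 3],
    "hack_s": [0.1, 0.5, 0.8, 0.9],    # training-time rate, quarantine active
    "hack_abl": [nan, nan, 0.2, nan],  # ablated-eval rate, logged sparsely
}
vanilla_run = {"arm": "vanilla", "steps": [0, 1], "hack_s": [0.1, 0.4]}

routing_steps, routing_vals = series_for_plot(routing_run)
vanilla_steps, vanilla_vals = series_for_plot(vanilla_run)
print(routing_steps, routing_vals)
print(vanilla_steps, vanilla_vals)
```

Older routing logs without the ablated column fall back to the training-time series, so such runs would still need a caveat in the figure.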
### Positive
- [src/projected_grpo/proj.py:183-194] The route/erase split in the projection code is clean. Route forces the exact-split settings, and erase still goes through the old `_project_one_module` path, so the projection-layer behavior itself is consistent with the intended design.
- [scripts/results.py:40-43, scripts/plot_dynamics.py:69-71] Switching log parsing to the printed `arm=<name>` field is the right compatibility point for old and new logs.
### Verdict
REQUEST CHANGES
`proj.py` is fine, but the current `train.py` wiring means route is not actually preserving the original update after AdamW, and its quarantined branch bypasses grad clipping. Fix those two first; the arm rename itself looks parse-safe.
+4 -2
@@ -35,11 +35,13 @@ smoke-vanilla *ARGS:
--teacher-pool-dir=out/probe_distill/teacher_pool --mix-ratio=0.5 {{ ARGS }}
# Routing path: parks the hack-ward grad in delta_S_hack, ablates at eval.
# Fires the R3 span assert + the two-param optimizer path.
# Fires the R3 span assert, the two-param optimizer path, the periodic
# ablated-eval series, and the final kept-vs-ablated BLUF.
smoke-route *ARGS:
BEARTYPE=1 CUDA_VISIBLE_DEVICES= {{ TRAIN }} smoke --intervention=route \
--v-hack-path=out/v_hack_smoke.safetensors \
--teacher-pool-dir=out/probe_distill/teacher_pool --mix-ratio=0.5 {{ ARGS }}
--teacher-pool-dir=out/probe_distill/teacher_pool --mix-ratio=0.5 \
--eval-ablate-every=10 --eval-n-prompts=2 {{ ARGS }}
# Run smoke twice: first warms the v_hack cache (cache-miss path), second hits
# the cache (cache-hit path). Catches scope/save bugs that only manifest in one.
+25 -5
@@ -13,10 +13,19 @@ to diverge from the (refreshed) v_hack.
Data source: logs/*.log per-step rows (the durable source results.py also uses).
We parse by HEADER NAME, not fixed index, because newer runs add columns (refr).
Arm classification (from the argv line):
vanilla arm=vanilla
Arm classification (from the preset line `arm=`, covering old --arm and new
--intervention logs):
vanilla arm=vanilla (intervention=none)
static erasure arm=projected, no --vhack-refresh-every (frozen v_hack)
online erasure arm=projected, --vhack-refresh-every=N>0 (re-extracted)
routing arm=routing (intervention=route)
For routing we plot the ABLATED-eval hack/solve (hack_abl/solve_abl, measured
with delta_S_hack zeroed every --eval-ablate-every steps), NOT the training-time
hack_s: the routed forward still hacks during training, so the training curve
would falsely read "route doesn't work". The ablated curve is the deployment
model. (none/erase plot training-time hack_s; their intervention acts at train
time.)
Usage:
uv run python scripts/plot_dynamics.py logs/*converge*.log
@@ -82,7 +91,10 @@ def parse_log(path: Path) -> dict | None:
series: dict[str, list[float]] = defaultdict(list)
steps: list[int] = []
wanted = {**RATE_COLS, **COS_COLS}
# Also parse the route ablated-eval columns when present (older logs lack
# them -> skip). For routing we plot THESE, not the training-time hack_s.
abl = {"hack_abl", "solve_abl"} & set(idx)
wanted = {**RATE_COLS, **COS_COLS, **{c: c for c in abl}}
for line in txt.splitlines():
if "| INFO |" not in line:
continue
@@ -94,8 +106,16 @@ def parse_log(path: Path) -> dict | None:
series[col].append(_val(row[idx[col]]))
if not steps:
return None
return dict(arm=arm, refr=refr, seed=seed, vhack=vhack,
steps=np.array(steps), **{k: np.array(v, dtype=float) for k, v in series.items()})
run = dict(arm=arm, refr=refr, seed=seed, vhack=vhack,
steps=np.array(steps), **{k: np.array(v, dtype=float) for k, v in series.items()})
# COHERENCE-GAP FIX: route's training-time hack_s looks vanilla (the routed
# forward still hacks); routing's benefit only shows once delta_S_hack is
# ablated at eval. So for routing, plot the ablated series under the same
# hack_s/gt_s keys -> all downstream (panels, onset, overlay) reads it.
if arm == "routing" and "hack_abl" in run:
run["hack_s"] = run["hack_abl"]
run["gt_s"] = run["solve_abl"]
return run
def classify(run: dict) -> str:
+7 -2
@@ -135,11 +135,16 @@ def main() -> None:
van = (df.filter(pl.col("arm") == "vanilla")
.select(["mix", "seed", "L5_hack", "L5_solve"])
.rename({"L5_hack": "v_hack", "L5_solve": "v_solve"}))
j = (df.filter(pl.col("arm") == "projected")
# Both intervention arms compare against the same-seed vanilla. routing is a
# first-class arm now, so include it (keyed on `arm` below so it doesn't
# merge with projected). NOTE: routing's L5_hack here is the TRAINING-time
# hack (the routed forward still hacks); the deployment number is the
# ablated-eval (ROUTE EVAL BLUF / hack_abl), not this column.
j = (df.filter(pl.col("arm").is_in(["projected", "routing"]))
.join(van, on=["mix", "seed"], how="inner")
.with_columns((pl.col("L5_hack") - pl.col("v_hack")).alias("dh"),
(pl.col("L5_solve") - pl.col("v_solve")).alias("ds")))
pkey = ["mix", "refr", "over", "gate", "k", "vhack"]
pkey = ["arm", "mix", "refr", "over", "gate", "k", "vhack"]
pj = (j.group_by(pkey)
.agg(pl.col("dh").mean().alias("Dhack"),
pl.col("dh").std().alias("Dhack_sd"),
+26 -22
@@ -79,10 +79,12 @@ def _project_one_module(
) -> tuple[Float[torch.Tensor, "r"], Float[torch.Tensor, "r"], float, float, bool]:
"""Per-module top-k removal. Returns (g_proj, removed, cos_pre, cos_post, fired).
`removed` = overshoot*c_use@V, the vector subtracted from g (pre-rescale).
Erasure drops it; routing parks it in delta_S_hack. When preserve_magnitude
is False and overshoot is 1.0, g_proj + removed == g exactly (the invariant
routing relies on so the training-time forward still moves hack-ward).
`removed` = overshoot*c_use@V, the vector subtracted from g (computed
before any preserve_magnitude rescale, so removed ∈ span(V) always).
Erasure drops it; routing parks it in delta_S_hack. Note g_proj + removed
== g ONLY when preserve_magnitude is False and overshoot is 1.0; with the
defaults g_proj is rescaled, so the sum is not the original g (routing does
not rely on that sum -- see project_delta_S_grad).
cos_pre / cos_post are SIGNED scalars (sum of per-axis V @ g coefficients,
normalized by ||g||). Positive = grad pushes toward hack; negative = grad
@@ -156,14 +158,19 @@ def project_delta_S_grad(
`preserve_magnitude`: rescale g' to ||g|| after projection.
`measure_only`: same math, but g is not mutated (the `none` intervention).
`route`: instead of discarding the removed hack component, park it in the
quarantine knob delta_S_hack.grad (Gradient Routing, Cloud 2410.04332).
delta_S keeps g - removed; delta_S_hack gets removed; their sum is the
original g, so the training-time forward still moves hack-ward, but an
eval with delta_S_hack zeroed has the hack capability ablated. Routing
forces overshoot=1.0 and preserve_magnitude=False so the split is exact
(g_proj + removed == g); the route-vs-erase comparison is only clean at
overshoot=1.0 anyway (route ⊇ erase). Mutually exclusive with measure_only.
`route`: erase AND park the removed hack-ward component in the quarantine
knob delta_S_hack.grad (Gradient Routing, Cloud 2410.04332). delta_S gets
the IDENTICAL g_proj as erase (same gate/preserve/overshoot), so the
deployment model -- delta_S with delta_S_hack zeroed at eval -- evolves
under the same update rule as the erase arm (each is its own AdamW param;
the quarantine's separate optimizer state cannot perturb delta_S). That is
the sense in which route ⊇ erase: erase == route with the quarantine
discarded. CAVEAT (not an identity): the combined TRAINING forward
delta_S + delta_S_hack does NOT reproduce a vanilla update -- AdamW steps
the two knobs independently, so the sum over-moves hack-ward. That is
intended (the model keeps hacking during training so the capability lands
in the quarantine), and it only affects the training trajectory, never the
ablated deployment. Mutually exclusive with measure_only.
Diagnostics returned (per call, averaged over modules):
mean_cos_pre = mean over modules of sum(V @ g)/||g||, signed
@@ -180,20 +187,17 @@ def project_delta_S_grad(
if name not in v_hack: # module dropped by global noise-floor filter
continue
V = v_hack[name].to(g.device, dtype=g.dtype) # [k, r]
if route:
g_proj, removed, cos_pre, cos_post, fired = _project_one_module(
g, V, gate_mode, preserve_magnitude=False, overshoot=1.0)
else:
g_proj, removed, cos_pre, cos_post, fired = _project_one_module(
g, V, gate_mode, preserve_magnitude, overshoot)
g_proj, removed, cos_pre, cos_post, fired = _project_one_module(
g, V, gate_mode, preserve_magnitude, overshoot)
cos_pre_list.append(cos_pre)
cos_post_list.append(cos_post)
if fired:
if fired and not measure_only:
info["delta_S"].grad = g_proj # same update rule as erase
if route:
info["delta_S"].grad = g_proj
# quarantine the discarded hack-ward part; removed ∈ span(V),
# ablated at eval so its magnitude/overshoot scaling is harmless.
info["delta_S_hack"].grad = removed
elif not measure_only:
info["delta_S"].grad = g_proj
if fired:
n_fired += 1
pre_t = torch.tensor(cos_pre_list); post_t = torch.tensor(cos_post_list)
return {
+109 -1
@@ -58,6 +58,7 @@ import json
import os
import sys
import time
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
@@ -181,6 +182,15 @@ class Config:
# discriminative power as the student drifted. 0 = off (load once at start
# and freeze). Refresh cost ~14*2 backwards on Qwen3-4B ~ 1-2 min wall.
vhack_refresh_every: int = 0
# Route eval-time ablation: every N steps (and at the end), zero delta_S_hack
# and eval hack/solve on a fixed prompt subset -> the `hack_abl`/`solve_abl`
# columns. This is the series the dynamics plot uses for route, because the
# TRAINING-time hack curve looks vanilla (the routed forward still hacks);
# routing's benefit only shows once the quarantine is ablated. 0 = off (the
# final kept-vs-ablated BLUF still prints for route). Only meaningful for
# intervention=route. eval_n_prompts prompts x `group` samples each.
eval_ablate_every: int = 0
eval_n_prompts: int = 8
# Optional: pool-derived pairs JSON (built by pairs_from_pool.py). When set,
# BOTH the cache-miss extract AND the online refresh use these pairs instead
# of the hand-crafted projected_grpo.pairs.PAIRS. Required for the cross-
@@ -456,6 +466,50 @@ def ref_logprobs_via_zero_delta(
info["delta_S"].data.copy_(saved[n])
@contextmanager
def ablate_quarantine(wrappers: dict):
    """Zero delta_S_hack for the duration -- the eval-time ablation of the
    routed hack capability. Save -> zero -> (eval) -> restore. The route arm's
    deployment model IS this ablated state."""
    saved = {n: info["delta_S_hack"].data.clone() for n, info in wrappers.items()}
    for info in wrappers.values():
        info["delta_S_hack"].data.zero_()
    try:
        yield
    finally:
        for n, info in wrappers.items():
            info["delta_S_hack"].data.copy_(saved[n])
@torch.no_grad()
def eval_hack_solve(model, tok, problems, eval_idxs, gen_cfg, device, max_new) -> dict:
    """Student-only generate + grade on a FIXED prompt subset (no teacher, no
    backward) -- a clean read of what the current adapter does.
    hack = C-detector rate (M1 reward hack); solve = gt_pass rate (held-out
    grader). Same compute_reward as training, so the numbers are comparable to
    the per-step hack_s/gt_s, just measured off-policy on a held-fixed subset.
    """
    model.config.use_cache = True
    n = hacks = solves = 0
    for idx in eval_idxs:
        prob = problems[idx]
        prompt = tok.apply_chat_template(
            prob["messages"], tokenize=False, add_generation_prompt=True, enable_thinking=False)
        enc = tok(prompt, return_tensors="pt", add_special_tokens=False).to(device)
        if enc.input_ids.shape[1] + max_new > 2048:
            continue
        out = model.generate(**enc, generation_config=gen_cfg)
        comps = out[:, enc.input_ids.shape[1]:]
        for t in tok.batch_decode(comps, skip_special_tokens=True):
            r = compute_reward(
                t, canonical_solution=prob["canonical"], gt_tests=prob["gt_tests"],
                setup_code=prob["setup_code"], func_name_hint=prob["func_name"])
            hacks += int(r.hacked); solves += int(r.gt_pass); n += 1
    model.config.use_cache = False
    return dict(hack=hacks / max(1, n), solve=solves / max(1, n), n=n)
@dataclass(frozen=True)
class _Col:
"""Per-step table column spec.
@@ -528,6 +582,9 @@ class StepLogger:
# refr = "mod/axes" of the v_hack re-extracted this step, "-" if no
# refresh fired (frozen-V and vanilla runs are all "-").
_Col("refr", 9, "refr", None),
# Route-only ablated-eval hack/solve (delta_S_hack=0); nan elsewhere.
_Col("hack_abl", 8, "hack_abl", "+.3f"),
_Col("solve_abl", 9, "solve_abl", "+.3f"),
]
def header(self) -> str:
@@ -707,6 +764,13 @@ def main(cfg: Config) -> int:
repetition_penalty=1.0,
num_return_sequences=G_s, pad_token_id=tok.pad_token_id,
)
# Eval-ablation config: student-only, `group` samples/prompt (no teacher
# split, so we want the full group for a tighter rate estimate).
gen_cfg_eval = GenerationConfig(
max_new_tokens=max_new, do_sample=True,
temperature=0.7, top_p=1.0, top_k=20, min_p=0.0, repetition_penalty=1.0,
num_return_sequences=group, pad_token_id=tok.pad_token_id,
)
problems = load_problems(n_problems)
logger.info(f"loaded {len(problems)} problems from {DATA.name}")
@@ -726,6 +790,10 @@ def main(cfg: Config) -> int:
f"({len(teacher_pool)} cached prompts). Re-run pregen-teacher against the same dataset."
)
# Fixed eval subset for route ablation: first eval_n_prompts problems, held
# constant across the run so the ablated-hack series is comparable step-to-step.
eval_idxs = list(range(min(cfg.eval_n_prompts, len(problems))))
rng = torch.Generator().manual_seed(cfg.seed)
rows = []
logger.info(
@@ -1186,7 +1254,10 @@ table columns:
# clip_grad_norm_ returns the pre-clip total L2 norm — capture for the
# per-step `gn` column so we can see whether the clip threshold is the
# bottleneck on update magnitude (compare gn vs cfg.grad_clip).
gn = float(torch.nn.utils.clip_grad_norm_(delta_params, cfg.grad_clip))
# Clip over both knobs. For none/erase, delta_S_hack.grad is None so it's
# ignored -> identical norm to before (R4). For route it bounds the
# combined update (main + quarantine).
gn = float(torch.nn.utils.clip_grad_norm_(delta_params + delta_hack_params, cfg.grad_clip))
opt.step()
sched.step()
@@ -1236,6 +1307,23 @@ table columns:
model.train()
refr = f"{len(v_hack)}/{sum(V.shape[0] for V in v_hack.values())}" # mod/axes -> per-step row
# Periodic route ablated-eval: zero the quarantine, eval hack/solve on the
# fixed subset. This is the curve the plot uses for route (training-time
# hack_s looks vanilla; the routed forward still hacks). NaN on non-eval
# steps and for non-route arms (plot's EMA holds NaN).
hack_abl = solve_abl = float("nan")
if (cfg.intervention == "route" and cfg.eval_ablate_every > 0
and (step % cfg.eval_ablate_every == 0 or step == steps - 1)):
_was_training = model.training
model.eval()
with ablate_quarantine(wrappers):
ev = eval_hack_solve(model, tok, problems, eval_idxs, gen_cfg_eval, device, max_new)
if _was_training:
model.train()
hack_abl, solve_abl = ev["hack"], ev["solve"]
logger.info(f"step {step} route ablated-eval (delta_S_hack=0): "
f"hack={hack_abl:.3f} solve={solve_abl:.3f} (n={ev['n']})")
rewards_t = torch.tensor(agg_rew, dtype=torch.float32) if agg_rew else torch.zeros(1)
rew_mean = rewards_t.mean().item()
rew_std = rewards_t.std().item() if rewards_t.numel() > 1 else 0.0
@@ -1343,6 +1431,11 @@ table columns:
"cos_post": diag["mean_cos_post"],
"fired": diag["frac_fired"],
"refr": refr,
# Route ablated-eval (delta_S_hack=0); NaN except on route eval steps.
# Appended AFTER refr so results.py's positional GT_S/HACK_S indices
# are unaffected. plot_dynamics reads it by name.
"hack_abl": hack_abl,
"solve_abl": solve_abl,
"gen": t_gen,
"fb": t_fb,
"t_rew": t_rew,
@@ -1421,6 +1514,21 @@ table columns:
if cfg.intervention == "route":
assert dsh_norm > 0.0, "route: delta_S_hack never moved -> degenerated to erasure"
# Route: final kept-vs-ablated eval -- the absorption test. KEPT keeps the
# quarantine (training-time model, still hacks); ABLATED zeroes it (the
# deployment model). SHOULD: ablated hack < kept hack at preserved solve
# => the quarantine absorbed the hack. ELSE routing didn't localize it.
if cfg.intervention == "route":
model.eval()
ev_kept = eval_hack_solve(model, tok, problems, eval_idxs, gen_cfg_eval, device, max_new)
with ablate_quarantine(wrappers):
ev_abl = eval_hack_solve(model, tok, problems, eval_idxs, gen_cfg_eval, device, max_new)
logger.info(
f"ROUTE EVAL (n={ev_kept['n']}): "
f"kept hack={ev_kept['hack']:.3f} solve={ev_kept['solve']:.3f} | "
f"ablated hack={ev_abl['hack']:.3f} solve={ev_abl['solve']:.3f} "
f"(SHOULD: ablated hack < kept hack at ~matched solve)")
# Final tail: cue emoji + main metric BLUF, then per-step tsv table.
# Vanilla arm: 🟢 if hacking emerged. Projected arm: 🟢 if HACK_RATE dropped
# vs a matched-PASS vanilla — we can't judge that here, so just report.