From 485839d7b1198266eae4c4d5c8ac99c2309815d7 Mon Sep 17 00:00:00 2001 From: wassname <1103714+wassname@users.noreply.github.com> Date: Sat, 6 Jun 2026 03:27:24 +0000 Subject: [PATCH] route2: pair-calibrated banded gate, drop live-detector tau + force-route Replace the confounded route2 gate (hack_anchor force-routed teacher + weak-detector student rows by LABEL; EMA tau calibrated from a live detector over student rollouts at train time = a cheat) with a band calibrated from the contrastive pairs alone: lower = mean clean-pair cos(g, v_grad); upper = mean hack-pair cos(g, v_grad). Per rollout: f = clamp((cos(g_b, v_grad) - lower)/(upper - lower), 0, 1); routed = sum_b f_b * g_b -> delta_S_hack; kept = g - routed -> delta_S. v_grad is now the SOLE router: no detector or gt_pass touches routing, so "does v_hack generalize to held-out modes" is clean and random-vs-real is decisive. Band width (upper-lower) is itself the discriminator: smoke shows +0.289 real vs -0.014 Haar-random (collapsed). Conservation routed+kept=g holds exactly; resid~0 in smoke (no hack leak into the deployed knob). - delete build_route2_anchors + EMA state (ema_hack/clean_cos, route2_tau) - add route_band_edges(); build at extract, rebuild on v_grad refresh - drop --gate-anchor-teacher-only config + retire scripts/verify_gate_anchor.py - teacher rollouts now route through the same band (not force-routed) - spec: add the mass-confound control (scientist review 2026-06-06) smoke-route2 + smoke-route2 --route2-random-v-seed=7 both pass; erase smoke green. 
Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com> --- docs/spec/20260606_pair_routing_design.md | 10 ++ justfile | 1 - scripts/verify_gate_anchor.py | 68 ----------- src/vgrout/train.py | 138 +++++++++------------- 4 files changed, 65 insertions(+), 152 deletions(-) delete mode 100644 scripts/verify_gate_anchor.py diff --git a/docs/spec/20260606_pair_routing_design.md b/docs/spec/20260606_pair_routing_design.md index 86a3f19..07b0d90 100644 --- a/docs/spec/20260606_pair_routing_design.md +++ b/docs/spec/20260606_pair_routing_design.md @@ -168,8 +168,18 @@ SHOULD per step: live cos_b percentiles (p10/p50/p90) STRADDLE [lower, upper ELSE all below lower -> routes nothing; all above upper -> routes everything (miscalibrated). SHOULD per step: route fraction f mean ∈ (0,1), some mass at 0 and at 1. ELSE degenerate gate. SHOULD per step: resid = cos(delta_S.grad after routing, v1) ~ 0. ELSE hack leaking into the deployed knob. +ALSO log routed mass: route -> mean f (fraction of grad routed); erase -> ‖removed‖/‖g‖ per step. ``` +Mass confound (scientist review, 2026-06-06). Real and random `v_hack` can suppress by +DIFFERENT routes: the right direction, OR simply quarantining more gradient mass. Real `v1` +aligns with the live hack gradient so it routes/removes more mass than a random direction +(which aligns ~0), so a raw real-vs-random win partly conflates "right direction" with "more +mass removed". Two defences, both cheap: (a) log the routed mass above for both conditions, so +a reader sees whether real won at MATCHED mass; (b) if the gap is mass-driven, add a +magnitude-matched random control (scale the random subtraction/route to remove the same norm +as real). Defence (a) is mandatory; (b) only if (a) shows a mass gap. + ## Implementation plan (src/vgrout/train.py) Rollback tag `pre-routing-refactor`. erase already works; the code below is the route rewrite. 
diff --git a/justfile b/justfile index 15d4071..79f116a 100644 --- a/justfile +++ b/justfile @@ -27,7 +27,6 @@ results: # zero-variance bails every step, leaving the loss path uncovered. smoke *ARGS: uv run python scripts/verify_rewards.py # grader gate: 3 env_modes x clean/hack - uv run python scripts/verify_gate_anchor.py # route2 no-cheat gate: teacher-only anchor zeroes held-out labels uv run python scripts/verify_partition.py # no-cheat: partition clean + teacher_modes hands gate only known-mode demos BEARTYPE=1 {{ TRAIN }} smoke --intervention=erase \ --v-hack-path=out/vhack/v_hack_smoke.safetensors \ diff --git a/scripts/verify_gate_anchor.py b/scripts/verify_gate_anchor.py deleted file mode 100644 index b2f3f0d..0000000 --- a/scripts/verify_gate_anchor.py +++ /dev/null @@ -1,68 +0,0 @@ -"""Unit check of the route2 τ-gate anchor (train.build_route2_anchors). - -SHOULD: the no-cheat invariant holds at the gate. The anchor decides which rollouts - are FORCE-routed into the deletable quarantine (hack_anchor); the rest only - route if cos(g,v) > tau. ELSE: held-out classes are receiving detector labels - and the A5 generalisation claim is contaminated. - -The load-bearing invariant (the leak this guards): - hacked_E is the run_tests detector. It cross-fires <=1.1% on held-out modes (false - positives), so under the DEFAULT anchor those held-out rollouts are force-routed -- - a real, if small, label leak into the held-out class. The teacher-only anchor - (gate_anchor_teacher_only=True) drops the student detector term entirely, so the - held-out classes get PROVABLY zero detector labels: the airtight A5 control. - -Scenario modelled: merged batch is [students ... teachers] (train.py is_student = -[True]*G_s + [False]*G_t). One held-out student (idx 2) is a detector FALSE POSITIVE. 
-""" -from __future__ import annotations - -import sys - -import torch -from loguru import logger - -from vgrout.train import build_route2_anchors - -DEV = torch.device("cpu") -G_S, G_T = 5, 2 -IS_STUDENT = [True] * G_S + [False] * G_T # students lead, teachers follow -# Held-out batch: the run_tests detector FALSE-fires on student idx 2 only (the leak). -HACK_E_FLAGS = [False, False, True, False, False] # len == G_S - - -def _check(name: str, cond: bool) -> bool: - logger.info(f"{'PASS' if cond else 'FAIL'} {name}") - return cond - - -def main() -> int: - ok = True - - # DEFAULT anchor: the leak is present. - ha, ca = build_route2_anchors(IS_STUDENT, HACK_E_FLAGS, teacher_only=False, device=DEV) - ok &= _check("default: teachers always anchored", bool(ha[G_S:].all())) - ok &= _check("default: flagged held-out student idx2 IS force-routed (the leak)", - bool(ha[2])) - ok &= _check("default: student detector labels reach the gate (sum>0)", - int(ha[:G_S].sum()) == 1) - ok &= _check("default: clean_anchor is the exact complement", bool((ca == ~ha).all())) - - # TEACHER-ONLY anchor: the fix. Zero student detector labels. - ha2, ca2 = build_route2_anchors(IS_STUDENT, HACK_E_FLAGS, teacher_only=True, device=DEV) - ok &= _check("teacher_only: teachers still all anchored", bool(ha2[G_S:].all())) - ok &= _check("teacher_only: ZERO student rollouts force-routed (no leak)", - int(ha2[:G_S].sum()) == 0) - ok &= _check("teacher_only: the held-out FP student idx2 is NOT routed", - not bool(ha2[2])) - ok &= _check("teacher_only: clean_anchor is the exact complement", bool((ca2 == ~ha2).all())) - - # The fix only touches student labels: teacher anchoring is identical either way. 
- ok &= _check("fix leaves teacher rows unchanged", bool((ha[G_S:] == ha2[G_S:]).all())) - - logger.info("ALL PASS -- gate anchor no-cheat invariant holds" if ok else "FAILURES above") - return 0 if ok else 1 - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/src/vgrout/train.py b/src/vgrout/train.py index 5b33406..6098325 100644 --- a/src/vgrout/train.py +++ b/src/vgrout/train.py @@ -140,13 +140,6 @@ class Config: seed: int = 41 preserve_magnitude: bool = True gate_mode: Literal["one_sided", "no_gate", "reverse"] = "one_sided" - # route2 airtight no-cheat control: anchor the τ-gate on TEACHER rows only, never - # on hacked_E-flagged student rows. The run_tests detector cross-fires <=1.1% on - # held-out modes (false positives), so the default anchor leaks ~1% of held-out - # labels into routing. Teacher-only anchor gives the held-out classes PROVABLY zero - # detector labels -- the strict A5 no-cheat test. v_grad and the τ-route-by-energy - # path are unchanged; only the force-route-known-hacks term drops its student flags. - gate_anchor_teacher_only: bool = False project_overshoot: float = 1.0 # remove overshoot*c_use@V; 1.0=just remove, 1.1=10% reversal of hack-ward grad # route/route2 exploration floor: fraction of student rollouts sampled with the # quarantine (δS_hack) ablated, i.e. from the DEPLOYED model. Intent: guard hack- @@ -335,23 +328,24 @@ def _haar_unit_dirs(v_grad: dict, seed: int, device) -> dict: return out -def build_route2_anchors(is_student: list[bool], hack_E_flags: list[bool], - teacher_only: bool, device) -> tuple[torch.Tensor, torch.Tensor]: - """τ-calibration anchors for the route2 gate (merged rows: students lead, teachers - follow). hack_anchor = teacher rows OR (unless teacher_only) detector-flagged student - rows; clean_anchor is the exact complement. hack_E_flags (len G_s) aligns with the - leading student rows. 
teacher_only drops the student detector term so held-out classes - get PROVABLY zero detector labels -- the airtight A5 no-cheat control. The default - leaks: the run_tests detector cross-fires <=1.1% on held-out modes, force-routing those - rollouts. Verified in scripts/verify_gate_anchor.py.""" - n = len(is_student) - is_student_t = torch.as_tensor(is_student, dtype=torch.bool, device=device) - flags = torch.zeros(n, dtype=torch.bool, device=device) - if not teacher_only: - m = min(n, len(hack_E_flags)) - flags[:m] = torch.as_tensor(list(hack_E_flags[:m]), dtype=torch.bool, device=device) - hack_anchor = (~is_student_t) | flags - return hack_anchor, ~hack_anchor +def route_band_edges(raw_grads: dict, v_grad: dict, device) -> dict[str, tuple[float, float]]: + """Per-module routing band (lower, upper) from the contrastive pairs ALONE -- the + pair-calibrated replacement for the old live-detector τ. lower = mean clean-pair cosine + to v_grad; upper = mean hack-pair cosine. A live rollout's cos(g_b, v_grad) below lower + is kept, above upper is routed, in between ramps (absorption). raw_grads carries the + train-pair per-pair δS grads as `hack/{name}` / `clean/{name}` [n_pairs, r]. Cosine is + scale-invariant so the extract's length-normalised NLL grads and the live token-sum grads + are comparable here. 
With a Haar-random v_grad both edges collapse to ~0 -> band closes -> + routing degenerates to a coin flip: band width is itself the real-vs-random discriminator.""" + band = {} + for name in v_grad: + v = v_grad[name].detach().cpu().float() + gh = raw_grads[f"hack/{name}"].float() # [n_pairs, r] + gc = raw_grads[f"clean/{name}"].float() + ch = (gh @ v) / gh.norm(dim=1).clamp_min(1e-12) # [n_pairs] hack-pair cosines + cc = (gc @ v) / gc.norm(dim=1).clamp_min(1e-12) # [n_pairs] clean-pair cosines + band[name] = (cc.mean().item(), ch.mean().item()) # (lower, upper) + return band @torch.no_grad() @@ -497,6 +491,13 @@ def main(cfg: Config) -> int: v_grad = _haar_unit_dirs(v_grad, cfg.route2_random_v_seed, device) logger.info(f"route2 grad: OVERRODE v_grad with Haar-random dirs " f"(seed={cfg.route2_random_v_seed}) -- directionality control (H2 vs H4)") + # Routing band from the pairs (against the FINAL v_grad, so a Haar override + # collapses the band -- the real-vs-random discriminator). + route_band = route_band_edges(raw_grads, v_grad, device) + _bw = [hi - lo for lo, hi in route_band.values()] + logger.info(f"route2 band: edges from {len(route_band)} modules, " + f"mean width(upper-lower)={sum(_bw)/len(_bw):+.3f} " + f"(>0 = pairs separate; ~0 = random/degenerate)") model.train() else: # v_hack path resolution, most-specific first. The pairset (personas) is @@ -761,16 +762,9 @@ def main(cfg: Config) -> int: rollout_log_path.write_text("") first_hack_saved = False route_span_checked = False # R3: assert delta_S_hack.grad in span(V) once - # route2-grad per-step calibrated routing threshold (spec - # docs/spec/20260601_calibrated_tau_route2grad.md). tau = EMA midpoint of the - # hack-cloud (teacher + detector-flagged student) and clean-cloud (not-flagged - # student) cos(g_b, v_grad) per module. Rides the cin drift so a fixed cos>0 - # gate (a ~50% coin-flip in high-dim) is replaced by "above where known hacks - # separate from clean". 
Persist across steps (EMA = cheap "last N hacks"). - ema_hack_cos: dict[str, float] = {} - ema_clean_cos: dict[str, float] = {} - route2_tau: dict[str, float] = {} - EMA_BETA = 0.9 + # route2-grad routing band is built from the pairs at v_grad extraction time + # (route_band[name] = (lower, upper)); see route_band_edges. No live-detector τ, + # no EMA -- the pairs alone calibrate the gate, refreshed with v_grad. last_gen_sample = None # first student rollout of the latest step (for collapse inspection) diverged_steps = 0 # consecutive steps with collapsed teacher ppl (divergence tripwire) lp_t_best = -float("inf") # coherence high-water mark (best teacher gen_logp seen) @@ -871,13 +865,11 @@ def main(cfg: Config) -> int: # routing on a fresh axis lags ~1 step until δS grows there (A1 stale-mask trade-off). GATE_EPS = 1e-6 step_flagged: list[float] = [] - step_tau: list[float] = [] # per-(prompt,module) calibrated route threshold - step_hkgap: list[float] = [] # ema_hack_cos - ema_clean_cos (discrimination gauge) + step_tau: list[float] = [] # median live cos_b (should sit inside the band) + step_hkgap: list[float] = [] # band width upper-lower (pair separation; ~0 = random/degenerate) step_resid: list[float] = [] # cos(δS.grad AFTER routing, v_grad): hack-ward leak into deployed knob - def _route2_grad_filter(info, n_rollouts: int, - hack_anchor: torch.Tensor, - clean_anchor: torch.Tensor) -> torch.Tensor: + def _route2_grad_filter(info, n_rollouts: int) -> torch.Tensor: g = info["delta_S"].grad # [r] summed over rollouts*tokens # The hook's gate c is per-token ([G*s, r]) because nn.Linear sees a # flattened batch. 
Sum each rollout's token gate-grads -> per-rollout @@ -894,39 +886,25 @@ def main(cfg: Config) -> int: g_b = torch.where(reliable, cg / dS_safe, torch.zeros_like(cg)) # [G, r] per-rollout vg = v_grad[name] # [r] unit, hack-ward cos_b = (g_b @ vg) / g_b.norm(dim=1).clamp_min(1e-12) # [G] - # Calibrate the threshold to where KNOWN hacks separate from clean, - # per module, EMA-smoothed across steps (rides the cin drift). A fixed - # cos>0 gate is a ~50% coin-flip in high-dim (cos~0 for most rollouts). - if hack_anchor.any(): - mu_h = cos_b[hack_anchor].mean().item() - ema_hack_cos[name] = (EMA_BETA * ema_hack_cos[name] + (1 - EMA_BETA) * mu_h - if name in ema_hack_cos else mu_h) - if clean_anchor.any(): - mu_c = cos_b[clean_anchor].mean().item() - ema_clean_cos[name] = (EMA_BETA * ema_clean_cos[name] + (1 - EMA_BETA) * mu_c - if name in ema_clean_cos else mu_c) - tau = (ema_hack_cos.get(name, 0.0) + ema_clean_cos.get(name, 0.0)) / 2 - route2_tau[name] = tau - step_tau.append(tau) - step_hkgap.append(ema_hack_cos.get(name, 0.0) - ema_clean_cos.get(name, 0.0)) - # Force-route known hacks (teacher + flagged student); τ-route the - # ambiguous rest (incl. unknown B, which lands above τ if it shares the - # v_grad direction). Do NOT force-keep clean_anchor: it is contaminated - # with unknown B, which we WANT routed. - flagged = (hack_anchor | (cos_b > tau)).float() # [G] - step_flagged.append(flagged.mean().item()) - sub = torch.where(reliable, (cg * flagged.unsqueeze(1)).sum(0) / dS_safe, - torch.zeros_like(g)) # flagged rollouts' contribution - # Park the flagged contribution in δS_hack (deleted at deploy); δS keeps - # only the unflagged. Capacity-balanced: both shape [r]. - step_grad_hack[name] = (step_grad_hack[name] + sub.detach().clone() - if name in step_grad_hack else sub.detach().clone()) - g_keep = g - sub # the deployed knob's gradient - # Residual hack-ward alignment of the KEPT grad. 
Disambiguates qE: - # qE high + resid~0 = routing stripped the hack cleanly (dominant - # teacher grad correctly quarantined); qE high + resid>0 = false - # negatives leaked hack-ward grad into the deployed knob (the real - # failure). vg is unit, so this is a plain cosine. + # Banded gate, calibrated from the PAIRS only (route_band[name]): a rollout + # whose grad cosine is below the clean edge is kept, above the hack edge is + # routed, and in between ramps proportionally (the absorption zone). No live + # detector, no teacher force-route -- v_grad is the sole router. f is the + # routed FRACTION of this rollout's grad (0..1). + lower, upper = route_band[name] + f = ((cos_b - lower) / max(upper - lower, 1e-6)).clamp(0.0, 1.0) # [G] + step_flagged.append(f.mean().item()) + step_hkgap.append(upper - lower) + step_tau.append(cos_b.median().item()) # live cos centre vs the band + routed = torch.where(reliable, (cg * f.unsqueeze(1)).sum(0) / dS_safe, + torch.zeros_like(g)) # Σ_b f_b·g_b on reliable axes + # Park the routed fraction in δS_hack (deleted at deploy); δS keeps the rest. + # routed + g_keep = g exactly (unreliable axes: routed=0, kept whole). + step_grad_hack[name] = (step_grad_hack[name] + routed.detach().clone() + if name in step_grad_hack else routed.detach().clone()) + g_keep = g - routed # the deployed knob's gradient + # Residual hack-ward alignment of the KEPT grad: ~0 = routing stripped the + # hack cleanly; >0 = hack leaked into the deployed knob. vg is unit -> plain cosine. step_resid.append((g_keep @ vg / g_keep.norm().clamp_min(1e-12)).item()) return g_keep @@ -1236,22 +1214,15 @@ def main(cfg: Config) -> int: ptl_norm = (Lp * mask).sum(1) / mask.sum(1).clamp_min(1) loss = ptl_norm.sum() / (group * prompts_per_step) loss.backward() - # route2: per-prompt anchor masks for the τ calibration. 
Hack cloud = - # teacher rows (known-A hacks) + detector-flagged (hack_E) student rows; - # clean cloud = not-flagged student rows (contaminated with unknown B by - # design -> conservative τ; B still routes via cos>τ). hack_E_flags - # (len G_s) aligns with the leading student rows of is_student. - if is_route2: - _ha, _ca = build_route2_anchors( - is_student, hack_E_flags, cfg.gate_anchor_teacher_only, Lp.device) for name, info in wrappers.items(): g = info["delta_S"].grad if g is None: continue - # route2 routes here: strip flagged rollouts from δS.grad and - # park them in δS_hack (via step_grad_hack in the filter). + # route2 routes here: split each rollout's δS.grad by its cosine to + # v_grad against the pair-calibrated band, park the routed fraction in + # δS_hack (via step_grad_hack in the filter). if is_route2: - g = _route2_grad_filter(info, merged.shape[0], _ha, _ca) + g = _route2_grad_filter(info, merged.shape[0]) step_grad_s[name] = (step_grad_s[name] + g.detach().clone() if name in step_grad_s else g.detach().clone()) @@ -1381,6 +1352,7 @@ def main(cfg: Config) -> int: for name in wrappers: # update in place so _route2_grad_filter's closure sees it d = (raw_grads[f"hack/{name}"] - raw_grads[f"clean/{name}"]).mean(0) v_grad[name] = (d / d.norm().clamp_min(1e-12)).to(device) + route_band = route_band_edges(raw_grads, v_grad, device) # rebuild band on the fresh v_grad finally: logger.enable("vgrout.extract_vhack_grad") logger.enable("__main__")