From f646e570289b042ba8cdd37fcbae3cc0bf32ff05 Mon Sep 17 00:00:00 2001 From: wassname <1103714+wassname@users.noreply.github.com> Date: Thu, 11 Jun 2026 12:48:11 +0000 Subject: [PATCH] fix(#41): routeA gate hardening from fresh-eyes review - assert finite gate scores (one NaN act would poison the rolling buffer -> NaN thresholds -> silent all-keep routing for up to route_buffer rollouts) - d = d*m: degenerate Otsu (t_lo==t_hi) could emit (m=0,d=1), a mask state that trains nothing; rout now requires the quarantine on - solve-mix discrimination collects only when solve teachers exist (a hack-only [False]*n list tripped the end-of-run line with a nan solve side) - warn when authored-pair AUROC < 0.9 (extraction sanity; tiny-random smoke sits ~0.78 so warning, not assert) Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com> --- src/vgrout/train.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/src/vgrout/train.py b/src/vgrout/train.py index acd400e..04d758b 100644 --- a/src/vgrout/train.py +++ b/src/vgrout/train.py @@ -262,6 +262,11 @@ def main(cfg: Config) -> int: f"routeA v_act: {v_act.shape[0]} modules x r={v_act.shape[1]} " f"(tstat={cfg.vact_tstat}); authored-pair dot gap={(sh.mean() - sc.mean()).item():+.3e}, " f"pair AUROC={pair_auroc:.2f}. SHOULD: pair AUROC ~1.0 ELSE extraction broken.") + if pair_auroc < 0.9: + # ~1.0 expected (pairs scored against the v built FROM them); the tiny-random + # smoke model sits ~0.78, so warn rather than assert. + logger.warning(f"routeA v_act: pair AUROC={pair_auroc:.2f} < 0.9 -- on a real " + f"model this means extraction is broken or the pairset is wrong") if cfg.routeA_random_v_seed is not None: v_act = haar_unit_rows(tuple(v_act.shape), cfg.routeA_random_v_seed) logger.info(f"routeA: OVERRODE v_act with Haar-random unit rows " @@ -525,6 +530,9 @@ def main(cfg: Config) -> int: absorb (1,0, both blocks train); z >= t_hi rout (1,1, deployed detached). Warmup: pinned absorb until the buffer holds route_warmup scores -- too few points to place thresholds, and absorb keeps both blocks learning.""" + # A single NaN act would poison the whole buffer -> NaN thresholds -> every + # comparison False -> silent all-keep routing for up to route_buffer rollouts. + assert torch.isfinite(dots).all(), f"non-finite gate scores: {dots}" if len(act_buf) < cfg.route_warmup: G_n = dots.shape[0] return (torch.ones(G_n, device=device), torch.zeros(G_n, device=device), @@ -535,6 +543,7 @@ def main(cfg: Config) -> int: z = (dots - mu) / sd m = (z > t_lo).float().to(device) # absorb + rout -> quarantine trains d = (z >= t_hi).float().to(device) # top zone -> rout -> deployed detached + d = d * m # degenerate Otsu (t_lo==t_hi) could give (0,1) = trains nothing; rout requires m=1 logger.debug(f"routeA gate: buf={len(act_buf)} mu={mu:+.3e} sd={sd:.3e} " f"t_lo={t_lo:+.2f} t_hi={t_hi:+.2f} (z) | batch z " f"min={z.min().item():+.2f} max={z.max().item():+.2f}") @@ -852,7 +861,9 @@ def main(cfg: Config) -> int: # hack teachers (d->1) and KEEPS the solve teachers (d->0); equal shares mean # the gate is non-directional (the shrinkage null). Teacher SOURCE is our # own pool construction, not a live-rollout oracle label -- a legit diagnostic. - if teacher_is_solve: + # any() not truthiness: a hack-only teacher list ([False]*n) must not trip + # the end-of-run discrimination line with a nan solve side. + if any(teacher_is_solve): is_solve_t = torch.tensor(teacher_is_solve, device=d_vec.device, dtype=torch.bool) d_teach = d_vec[-len(teacher_is_solve):] if (~is_solve_t).any():