fix(#41): routeA gate hardening from fresh-eyes review

- assert finite gate scores (one NaN act would poison the rolling buffer ->
  NaN thresholds -> silent all-keep routing for up to route_buffer rollouts)
- d = d*m: degenerate Otsu (t_lo==t_hi) could emit (m=0,d=1), a mask state
  that trains nothing; rout now requires the quarantine on
- solve-mix discrimination collects only when solve teachers exist (a
  hack-only [False]*n list tripped the end-of-run line with a nan solve side)
- warn when authored-pair AUROC < 0.9 (extraction sanity; tiny-random smoke
  sits ~0.78 so warning, not assert)

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
wassname
2026-06-11 12:48:11 +00:00
parent adca442253
commit f646e57028
+12 -1
View File
@@ -262,6 +262,11 @@ def main(cfg: Config) -> int:
f"routeA v_act: {v_act.shape[0]} modules x r={v_act.shape[1]} " f"routeA v_act: {v_act.shape[0]} modules x r={v_act.shape[1]} "
f"(tstat={cfg.vact_tstat}); authored-pair dot gap={(sh.mean() - sc.mean()).item():+.3e}, " f"(tstat={cfg.vact_tstat}); authored-pair dot gap={(sh.mean() - sc.mean()).item():+.3e}, "
f"pair AUROC={pair_auroc:.2f}. SHOULD: pair AUROC ~1.0 ELSE extraction broken.") f"pair AUROC={pair_auroc:.2f}. SHOULD: pair AUROC ~1.0 ELSE extraction broken.")
if pair_auroc < 0.9:
# ~1.0 expected (pairs scored against the v built FROM them); the tiny-random
# smoke model sits ~0.78, so warn rather than assert.
logger.warning(f"routeA v_act: pair AUROC={pair_auroc:.2f} < 0.9 -- on a real "
f"model this means extraction is broken or the pairset is wrong")
if cfg.routeA_random_v_seed is not None: if cfg.routeA_random_v_seed is not None:
v_act = haar_unit_rows(tuple(v_act.shape), cfg.routeA_random_v_seed) v_act = haar_unit_rows(tuple(v_act.shape), cfg.routeA_random_v_seed)
logger.info(f"routeA: OVERRODE v_act with Haar-random unit rows " logger.info(f"routeA: OVERRODE v_act with Haar-random unit rows "
@@ -525,6 +530,9 @@ def main(cfg: Config) -> int:
absorb (1,0, both blocks train); z >= t_hi rout (1,1, deployed detached). absorb (1,0, both blocks train); z >= t_hi rout (1,1, deployed detached).
Warmup: pinned absorb until the buffer holds route_warmup scores -- too few Warmup: pinned absorb until the buffer holds route_warmup scores -- too few
points to place thresholds, and absorb keeps both blocks learning.""" points to place thresholds, and absorb keeps both blocks learning."""
# A single NaN act would poison the whole buffer -> NaN thresholds -> every
# comparison False -> silent all-keep routing for up to route_buffer rollouts.
assert torch.isfinite(dots).all(), f"non-finite gate scores: {dots}"
if len(act_buf) < cfg.route_warmup: if len(act_buf) < cfg.route_warmup:
G_n = dots.shape[0] G_n = dots.shape[0]
return (torch.ones(G_n, device=device), torch.zeros(G_n, device=device), return (torch.ones(G_n, device=device), torch.zeros(G_n, device=device),
@@ -535,6 +543,7 @@ def main(cfg: Config) -> int:
z = (dots - mu) / sd z = (dots - mu) / sd
m = (z > t_lo).float().to(device) # absorb + rout -> quarantine trains m = (z > t_lo).float().to(device) # absorb + rout -> quarantine trains
d = (z >= t_hi).float().to(device) # top zone -> rout -> deployed detached d = (z >= t_hi).float().to(device) # top zone -> rout -> deployed detached
d = d * m # degenerate Otsu (t_lo==t_hi) could give (0,1) = trains nothing; rout requires m=1
logger.debug(f"routeA gate: buf={len(act_buf)} mu={mu:+.3e} sd={sd:.3e} " logger.debug(f"routeA gate: buf={len(act_buf)} mu={mu:+.3e} sd={sd:.3e} "
f"t_lo={t_lo:+.2f} t_hi={t_hi:+.2f} (z) | batch z " f"t_lo={t_lo:+.2f} t_hi={t_hi:+.2f} (z) | batch z "
f"min={z.min().item():+.2f} max={z.max().item():+.2f}") f"min={z.min().item():+.2f} max={z.max().item():+.2f}")
@@ -852,7 +861,9 @@ def main(cfg: Config) -> int:
# hack teachers (d->1) and KEEPS the solve teachers (d->0); equal shares mean # hack teachers (d->1) and KEEPS the solve teachers (d->0); equal shares mean
# the gate is non-directional (the shrinkage null). Teacher SOURCE is our # the gate is non-directional (the shrinkage null). Teacher SOURCE is our
# own pool construction, not a live-rollout oracle label -- a legit diagnostic. # own pool construction, not a live-rollout oracle label -- a legit diagnostic.
if teacher_is_solve: # any() not truthiness: a hack-only teacher list ([False]*n) must not trip
# the end-of-run discrimination line with a nan solve side.
if any(teacher_is_solve):
is_solve_t = torch.tensor(teacher_is_solve, device=d_vec.device, dtype=torch.bool) is_solve_t = torch.tensor(teacher_is_solve, device=d_vec.device, dtype=torch.bool)
d_teach = d_vec[-len(teacher_is_solve):] d_teach = d_vec[-len(teacher_is_solve):]
if (~is_solve_t).any(): if (~is_solve_t).any():