mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 16:45:42 +08:00
fix(#41): routeA gate hardening from fresh-eyes review
- assert finite gate scores (one NaN act would poison the rolling buffer -> NaN thresholds -> silent all-keep routing for up to route_buffer rollouts) - d = d*m: degenerate Otsu (t_lo==t_hi) could emit (m=0,d=1), a mask state that trains nothing; rout now requires the quarantine on - solve-mix discrimination collects only when solve teachers exist (a hack-only [False]*n list tripped the end-of-run line with a nan solve side) - warn when authored-pair AUROC < 0.9 (extraction sanity; tiny-random smoke sits ~0.78 so warning, not assert) Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
+12
-1
@@ -262,6 +262,11 @@ def main(cfg: Config) -> int:
|
|||||||
f"routeA v_act: {v_act.shape[0]} modules x r={v_act.shape[1]} "
|
f"routeA v_act: {v_act.shape[0]} modules x r={v_act.shape[1]} "
|
||||||
f"(tstat={cfg.vact_tstat}); authored-pair dot gap={(sh.mean() - sc.mean()).item():+.3e}, "
|
f"(tstat={cfg.vact_tstat}); authored-pair dot gap={(sh.mean() - sc.mean()).item():+.3e}, "
|
||||||
f"pair AUROC={pair_auroc:.2f}. SHOULD: pair AUROC ~1.0 ELSE extraction broken.")
|
f"pair AUROC={pair_auroc:.2f}. SHOULD: pair AUROC ~1.0 ELSE extraction broken.")
|
||||||
|
if pair_auroc < 0.9:
|
||||||
|
# ~1.0 expected (pairs scored against the v built FROM them); the tiny-random
|
||||||
|
# smoke model sits ~0.78, so warn rather than assert.
|
||||||
|
logger.warning(f"routeA v_act: pair AUROC={pair_auroc:.2f} < 0.9 -- on a real "
|
||||||
|
f"model this means extraction is broken or the pairset is wrong")
|
||||||
if cfg.routeA_random_v_seed is not None:
|
if cfg.routeA_random_v_seed is not None:
|
||||||
v_act = haar_unit_rows(tuple(v_act.shape), cfg.routeA_random_v_seed)
|
v_act = haar_unit_rows(tuple(v_act.shape), cfg.routeA_random_v_seed)
|
||||||
logger.info(f"routeA: OVERRODE v_act with Haar-random unit rows "
|
logger.info(f"routeA: OVERRODE v_act with Haar-random unit rows "
|
||||||
@@ -525,6 +530,9 @@ def main(cfg: Config) -> int:
|
|||||||
absorb (1,0, both blocks train); z >= t_hi rout (1,1, deployed detached).
|
absorb (1,0, both blocks train); z >= t_hi rout (1,1, deployed detached).
|
||||||
Warmup: pinned absorb until the buffer holds route_warmup scores -- too few
|
Warmup: pinned absorb until the buffer holds route_warmup scores -- too few
|
||||||
points to place thresholds, and absorb keeps both blocks learning."""
|
points to place thresholds, and absorb keeps both blocks learning."""
|
||||||
|
# A single NaN act would poison the whole buffer -> NaN thresholds -> every
|
||||||
|
# comparison False -> silent all-keep routing for up to route_buffer rollouts.
|
||||||
|
assert torch.isfinite(dots).all(), f"non-finite gate scores: {dots}"
|
||||||
if len(act_buf) < cfg.route_warmup:
|
if len(act_buf) < cfg.route_warmup:
|
||||||
G_n = dots.shape[0]
|
G_n = dots.shape[0]
|
||||||
return (torch.ones(G_n, device=device), torch.zeros(G_n, device=device),
|
return (torch.ones(G_n, device=device), torch.zeros(G_n, device=device),
|
||||||
@@ -535,6 +543,7 @@ def main(cfg: Config) -> int:
|
|||||||
z = (dots - mu) / sd
|
z = (dots - mu) / sd
|
||||||
m = (z > t_lo).float().to(device) # absorb + rout -> quarantine trains
|
m = (z > t_lo).float().to(device) # absorb + rout -> quarantine trains
|
||||||
d = (z >= t_hi).float().to(device) # top zone -> rout -> deployed detached
|
d = (z >= t_hi).float().to(device) # top zone -> rout -> deployed detached
|
||||||
|
d = d * m # degenerate Otsu (t_lo==t_hi) could give (0,1) = trains nothing; rout requires m=1
|
||||||
logger.debug(f"routeA gate: buf={len(act_buf)} mu={mu:+.3e} sd={sd:.3e} "
|
logger.debug(f"routeA gate: buf={len(act_buf)} mu={mu:+.3e} sd={sd:.3e} "
|
||||||
f"t_lo={t_lo:+.2f} t_hi={t_hi:+.2f} (z) | batch z "
|
f"t_lo={t_lo:+.2f} t_hi={t_hi:+.2f} (z) | batch z "
|
||||||
f"min={z.min().item():+.2f} max={z.max().item():+.2f}")
|
f"min={z.min().item():+.2f} max={z.max().item():+.2f}")
|
||||||
@@ -852,7 +861,9 @@ def main(cfg: Config) -> int:
|
|||||||
# hack teachers (d->1) and KEEPS the solve teachers (d->0); equal shares mean
|
# hack teachers (d->1) and KEEPS the solve teachers (d->0); equal shares mean
|
||||||
# the gate is non-directional (the shrinkage null). Teacher SOURCE is our
|
# the gate is non-directional (the shrinkage null). Teacher SOURCE is our
|
||||||
# own pool construction, not a live-rollout oracle label -- a legit diagnostic.
|
# own pool construction, not a live-rollout oracle label -- a legit diagnostic.
|
||||||
if teacher_is_solve:
|
# any() not truthiness: a hack-only teacher list ([False]*n) must not trip
|
||||||
|
# the end-of-run discrimination line with a nan solve side.
|
||||||
|
if any(teacher_is_solve):
|
||||||
is_solve_t = torch.tensor(teacher_is_solve, device=d_vec.device, dtype=torch.bool)
|
is_solve_t = torch.tensor(teacher_is_solve, device=d_vec.device, dtype=torch.bool)
|
||||||
d_teach = d_vec[-len(teacher_is_solve):]
|
d_teach = d_vec[-len(teacher_is_solve):]
|
||||||
if (~is_solve_t).any():
|
if (~is_solve_t).any():
|
||||||
|
|||||||
Reference in New Issue
Block a user