mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 16:30:30 +08:00
feat: online_stats gate for routeV -- live q5/q95 band calibration
New routeV_gate="online_stats" mode: use the empirical per-rollout cosine distribution (q5/q95 pooled across all modules each step) as the routing band thresholds, instead of the pair-derived route_band. Direction v_grad still from authored pairs; only thresholds are online/adaptive. Motivation: the pair-derived band sits above the live cosine distribution (median live cos ~-0.06), causing frout to cliff as GRPO advantage flattens. Online thresholds adapt to the actual live distribution, so the 5/95 tails always route regardless of where the raw cosines land. Config: routeV_gate="online_stats", online_stats_lo=0.05, online_stats_hi=0.95. Step-0 prior: (-0.5, 0.5) neutral band (pairs not used for calibration). Band update: post-opt.step(), torch.quantile over that step's module*rollout cosines. No-cheat: v_grad from authored pairs only; thresholds from the cosine distribution of live student rollouts (no oracle/labeling of live rollouts as hack/clean). Also: add online_stats to results_deploy._arm(); justfile queue-online-stats recipe. Queued as job 22 (s43, authored pairs, priority 12, after 19/20/21). Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
@@ -157,6 +157,20 @@ queue-dir6 seed='43':
|
||||
# routeV deploy_hack < vanilla at matched solve, significant across the 3 seeds (paired t,
|
||||
# alpha=0.01 like the paper); ablations (random/vampire) ~ vanilla confirm directionality.
|
||||
TEACHER_RT := "out/pools/teacher_pool_runtests_dense"
|
||||
|
||||
# H: online_stats gate -- calibrate band thresholds from the LIVE cosine distribution
|
||||
# (q5/q95 of per-rollout cosines pooled across all modules each step). Direction v_grad
|
||||
# still from authored pairs; only thresholds are online. Avoids the pair-band mis-calibration
|
||||
# (pair cosines are off-distribution; live routing often cliffs because pair band sits above
|
||||
# live rollout cosines). Expected: sustained rout (no frout cliff) even past step 20.
|
||||
queue-online-stats seed="43":
|
||||
#!/usr/bin/env bash
|
||||
set -eu
|
||||
pueue add -w "$PWD" -o 12 \
|
||||
-l "why: online_stats gate s{{seed}} -- live q5/q95 band (no pair threshold); resolve: sustained rout vs grad-cosine cliff, test if adaptive thresholds improve deploy suppression" \
|
||||
-- {{ TRAIN }} fast --intervention=routeV --routeV-gate=online_stats \
|
||||
--vhack-pairs-path=None --seed={{seed}} --out-tag=_dir8_routeV_onlinestats_s{{seed}}
|
||||
|
||||
queue-broad:
|
||||
#!/usr/bin/env bash
|
||||
set -eu
|
||||
|
||||
@@ -64,6 +64,7 @@ def _arm(argv: str) -> str:
|
||||
if "--intervention=none" in argv:
|
||||
return "vanilla"
|
||||
gate = ("act_vote" if "--routeV-gate=act_vote" in argv else
|
||||
"online_stats" if "--routeV-gate=online_stats" in argv else
|
||||
"lora" if "lora_frozen_b" in argv else
|
||||
"per-token" if "--routeV-per-token" in argv else "grad-cos")
|
||||
return f"routeV/{gate}" + ("·randV" if "--routeV-random-v-seed" in argv else "")
|
||||
|
||||
+37
-1
@@ -208,7 +208,11 @@ class Config:
|
||||
# arm: different space (act not grad) + different aggregation (one f per rollout,
|
||||
# shared across modules). Tests whether the precision framing predicts deploy
|
||||
# suppression, and stresses H2 absorption (does gate choice matter at deploy at all?).
|
||||
routeV_gate: Literal["grad_cosine", "act_vote"] = "grad_cosine"
|
||||
routeV_gate: Literal["grad_cosine", "act_vote", "online_stats"] = "grad_cosine"
|
||||
# online_stats gate: rolling buffer of per-rollout cosines; q(lo)/q(hi) set the
|
||||
# band each step. No pairs needed for threshold calibration -- direction only.
|
||||
online_stats_lo: float = 0.05 # lower quantile -> keep tail
|
||||
online_stats_hi: float = 0.95 # upper quantile -> route tail
|
||||
# Per-source cin diagnostic: split each prompt's backward into student-only
|
||||
# + teacher-only passes (~2x backward time). 1 = every step (default; full
|
||||
# signal); N>1 = only every Nth step (combined backward elsewhere, ~halves
|
||||
@@ -534,6 +538,7 @@ def main(cfg: Config) -> int:
|
||||
# are hidden, so v_hack=None just means no subspace machinery).
|
||||
v_grad = None # set only by the routeV grad-mask branch below
|
||||
As_dir = act_w = vote_band = None # set only by the act_vote gate branch below
|
||||
_online_band: list = [None] # online_stats gate: (lo, hi) updated each step; None = use pair band
|
||||
if cfg.intervention in ("none", "routeV"):
|
||||
if cfg.intervention == "none" and cfg.v_hack_path is not None:
|
||||
logger.info(f"vanilla arm: ignoring --v-hack-path={cfg.v_hack_path} "
|
||||
@@ -995,6 +1000,7 @@ def main(cfg: Config) -> int:
|
||||
# modules (the global activation vote, computed post-backward before the per-module
|
||||
# routing). 1-element list so the filter closure reads the current step's value.
|
||||
_step_f_roll: list[torch.Tensor | None] = [None]
|
||||
_step_online_cos: list[torch.Tensor] = [] # online_stats: per-module [G] cosines, cleared each step
|
||||
|
||||
# routeV: recover the per-rollout δS grad from the gate (c.grad = δS * g_b),
|
||||
# flag rollouts whose grad points hack-ward (cos(g_b, v_grad) > τ), and route
|
||||
@@ -1040,6 +1046,27 @@ def main(cfg: Config) -> int:
|
||||
_kn, _rn, _on, _ke, _re, _oe = _zone_stats(f, g_b.norm(dim=1))
|
||||
step_zkeep.append(_kn); step_zresid.append(_rn); step_zrout.append(_on)
|
||||
step_zkeepE.append(_ke); step_zresidE.append(_re); step_zroutE.append(_oe)
|
||||
elif cfg.routeV_gate == "online_stats":
|
||||
# Online-stats gate: band thresholds from the LIVE rolling cosine distribution
|
||||
# (q5/q95 across all modules*rollouts this step), not from pairs. Direction
|
||||
# (v_grad) still comes from authored pairs -- only calibration is online.
|
||||
# Fallback to pair-derived band on step 0 (buffer empty).
|
||||
cg = cg_full.sum(1) # [G, r]
|
||||
g_b = torch.where(reliable, cg / dS_safe, torch.zeros_like(cg)) # [G, r]
|
||||
cos_b = (g_b @ vg) / g_b.norm(dim=1).clamp_min(1e-12) # [G]
|
||||
_step_online_cos.append(cos_b.detach()) # accumulate; band updated post-step
|
||||
# step-0 prior: neutral (-0.5, 0.5) so some routing always fires before the
|
||||
# live distribution bootstraps. Pair-derived (lower, upper) is not used for
|
||||
# threshold calibration -- that is the whole point of online_stats.
|
||||
lo, hi = _online_band[0] if _online_band[0] is not None else (-0.5, 0.5)
|
||||
band_w = max(hi - lo, 1e-6)
|
||||
f = ((cos_b - lo) / band_w).clamp(0.0, 1.0) # [G]
|
||||
routed = torch.where(reliable, (cg * f.unsqueeze(1)).sum(0) / dS_safe,
|
||||
torch.zeros_like(g))
|
||||
step_flagged.append(f.mean().item())
|
||||
_kn, _rn, _on, _ke, _re, _oe = _zone_stats(f, g_b.norm(dim=1))
|
||||
step_zkeep.append(_kn); step_zresid.append(_rn); step_zrout.append(_on)
|
||||
step_zkeepE.append(_ke); step_zresidE.append(_re); step_zroutE.append(_oe)
|
||||
elif cfg.routeV_per_token:
|
||||
g_u = torch.where(reliable, cg_full / dS_safe, torch.zeros_like(cg_full)) # [G, s, r]
|
||||
cos_u = (g_u @ vg) / g_u.norm(dim=2).clamp_min(1e-12) # [G, s]
|
||||
@@ -1545,6 +1572,15 @@ def main(cfg: Config) -> int:
|
||||
opt.step()
|
||||
sched.step()
|
||||
|
||||
# online_stats gate: update band from this step's pooled cosines (all modules * rollouts).
|
||||
# Uses previous step's band for routing (so the update is one step lagged, which is fine).
|
||||
if is_routeV and cfg.routeV_gate == "online_stats" and _step_online_cos:
|
||||
all_cos = torch.cat(_step_online_cos).float()
|
||||
lo = torch.quantile(all_cos, cfg.online_stats_lo).item()
|
||||
hi = torch.quantile(all_cos, cfg.online_stats_hi).item()
|
||||
_online_band[0] = (lo, max(hi, lo + 1e-4))
|
||||
logger.debug(f"online_stats band update: lo={lo:+.3f} hi={hi:+.3f} n={len(all_cos)}")
|
||||
|
||||
# ── v_hack / v_grad refresh ──
|
||||
# Online v_hack refresh: re-extract against the *current* model so the
|
||||
# hack subspace tracks where the student is being pulled now (rather
|
||||
|
||||
Reference in New Issue
Block a user