feat: online_stats gate for routeV -- live q5/q95 band calibration

New routeV_gate="online_stats" mode: use the empirical per-rollout cosine
distribution (q5/q95 pooled across all modules each step) as the routing
band thresholds, instead of the pair-derived route_band. Direction v_grad
still from authored pairs; only thresholds are online/adaptive.

Motivation: the pair-derived band sits above the live cosine distribution
(median live cos ~-0.06), causing frout to cliff as GRPO advantage flattens.
Online thresholds adapt to the actual live distribution, so the band always
brackets the live cosines: roughly the top 5% is fully routed and the bottom 5%
fully kept, regardless of where the raw cosines land.

Config: routeV_gate="online_stats", online_stats_lo=0.05, online_stats_hi=0.95.
Step-0 prior: (-0.5, 0.5) neutral band (pairs not used for calibration).
Band update: post-opt.step(), torch.quantile over that step's module*rollout cosines.
No-cheat: v_grad from authored pairs only; thresholds from the cosine distribution
of live student rollouts (no oracle/labeling of live rollouts as hack/clean).

Also: add online_stats to results_deploy._arm(); justfile queue-online-stats recipe.
Queued as job 22 (s43, authored pairs, priority 12, after 19/20/21).

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
wassname
2026-06-09 02:25:37 +00:00
parent 0412dc56d1
commit 0f59b1351b
3 changed files with 52 additions and 1 deletions
+14
View File
@@ -157,6 +157,20 @@ queue-dir6 seed='43':
# routeV deploy_hack < vanilla at matched solve, significant across the 3 seeds (paired t,
# alpha=0.01 like the paper); ablations (random/vampire) ~ vanilla confirm directionality.
TEACHER_RT := "out/pools/teacher_pool_runtests_dense"
# H: online_stats gate -- calibrate band thresholds from the LIVE cosine distribution
# (q5/q95 of per-rollout cosines pooled across all modules each step). Direction v_grad
# still from authored pairs; only thresholds are online. Avoids the pair-band mis-calibration
# (pair cosines are off-distribution; live routing often cliffs because pair band sits above
# live rollout cosines). Expected: sustained rout (no frout cliff) even past step 20.
queue-online-stats seed="43":
#!/usr/bin/env bash
# Enqueue one routeV training run that uses the online_stats gate, for the given seed.
# pueue flags: -w = working directory for the job, -o = queue priority, -l = job label.
# NOTE(review): --vhack-pairs-path=None is passed although the direction is said to come
# from authored pairs -- confirm that None selects the default authored pairs.
set -eu
pueue add -w "$PWD" -o 12 \
-l "why: online_stats gate s{{seed}} -- live q5/q95 band (no pair threshold); resolve: sustained rout vs grad-cosine cliff, test if adaptive thresholds improve deploy suppression" \
-- {{ TRAIN }} fast --intervention=routeV --routeV-gate=online_stats \
--vhack-pairs-path=None --seed={{seed}} --out-tag=_dir8_routeV_onlinestats_s{{seed}}
queue-broad:
#!/usr/bin/env bash
set -eu
+1
View File
@@ -64,6 +64,7 @@ def _arm(argv: str) -> str:
if "--intervention=none" in argv:
return "vanilla"
gate = ("act_vote" if "--routeV-gate=act_vote" in argv else
"online_stats" if "--routeV-gate=online_stats" in argv else
"lora" if "lora_frozen_b" in argv else
"per-token" if "--routeV-per-token" in argv else "grad-cos")
return f"routeV/{gate}" + ("·randV" if "--routeV-random-v-seed" in argv else "")
+37 -1
View File
@@ -208,7 +208,11 @@ class Config:
# arm: different space (act not grad) + different aggregation (one f per rollout,
# shared across modules). Tests whether the precision framing predicts deploy
# suppression, and stresses H2 absorption (does gate choice matter at deploy at all?).
routeV_gate: Literal["grad_cosine", "act_vote"] = "grad_cosine"
routeV_gate: Literal["grad_cosine", "act_vote", "online_stats"] = "grad_cosine"
# online_stats gate: the q(lo)/q(hi) quantiles of the previous step's per-rollout
# cosines (pooled across modules; per-step pool, not a rolling history) set the
# routing band. Pairs are not used for threshold calibration -- direction only.
online_stats_lo: float = 0.05 # lower quantile -> keep tail
online_stats_hi: float = 0.95 # upper quantile -> route tail
# Per-source cin diagnostic: split each prompt's backward into student-only
# + teacher-only passes (~2x backward time). 1 = every step (default; full
# signal); N>1 = only every Nth step (combined backward elsewhere, ~halves
@@ -534,6 +538,7 @@ def main(cfg: Config) -> int:
# are hidden, so v_hack=None just means no subspace machinery).
v_grad = None # set only by the routeV grad-mask branch below
As_dir = act_w = vote_band = None # set only by the act_vote gate branch below
_online_band: list = [None] # online_stats gate: (lo, hi) updated after each step; None = not calibrated yet (neutral (-0.5, 0.5) prior is used)
if cfg.intervention in ("none", "routeV"):
if cfg.intervention == "none" and cfg.v_hack_path is not None:
logger.info(f"vanilla arm: ignoring --v-hack-path={cfg.v_hack_path} "
@@ -995,6 +1000,7 @@ def main(cfg: Config) -> int:
# modules (the global activation vote, computed post-backward before the per-module
# routing). 1-element list so the filter closure reads the current step's value.
_step_f_roll: list[torch.Tensor | None] = [None]
_step_online_cos: list[torch.Tensor] = [] # online_stats: per-module [G] cosines, cleared each step
# routeV: recover the per-rollout δS grad from the gate (c.grad = δS * g_b),
# flag rollouts whose grad points hack-ward (cos(g_b, v_grad) > τ), and route
@@ -1040,6 +1046,27 @@ def main(cfg: Config) -> int:
_kn, _rn, _on, _ke, _re, _oe = _zone_stats(f, g_b.norm(dim=1))
step_zkeep.append(_kn); step_zresid.append(_rn); step_zrout.append(_on)
step_zkeepE.append(_ke); step_zresidE.append(_re); step_zroutE.append(_oe)
elif cfg.routeV_gate == "online_stats":
# Online-stats gate: band thresholds come from the LIVE cosine distribution
# (q_lo/q_hi pooled across all modules*rollouts of the previous step), not from pairs.
# Direction (v_grad) still comes from authored pairs -- only calibration is online.
# Until the first band update (step 0) a fixed neutral prior is used, not the pair band.
cg = cg_full.sum(1) # [G, r]
g_b = torch.where(reliable, cg / dS_safe, torch.zeros_like(cg)) # [G, r]
cos_b = (g_b @ vg) / g_b.norm(dim=1).clamp_min(1e-12) # [G]
_step_online_cos.append(cos_b.detach()) # accumulate; band updated post-step
# step-0 prior: neutral (-0.5, 0.5) so some routing always fires before the
# live distribution bootstraps. Pair-derived (lower, upper) is not used for
# threshold calibration -- that is the whole point of online_stats.
lo, hi = _online_band[0] if _online_band[0] is not None else (-0.5, 0.5)
band_w = max(hi - lo, 1e-6)
f = ((cos_b - lo) / band_w).clamp(0.0, 1.0) # [G]
routed = torch.where(reliable, (cg * f.unsqueeze(1)).sum(0) / dS_safe,
torch.zeros_like(g))
step_flagged.append(f.mean().item())
_kn, _rn, _on, _ke, _re, _oe = _zone_stats(f, g_b.norm(dim=1))
step_zkeep.append(_kn); step_zresid.append(_rn); step_zrout.append(_on)
step_zkeepE.append(_ke); step_zresidE.append(_re); step_zroutE.append(_oe)
elif cfg.routeV_per_token:
g_u = torch.where(reliable, cg_full / dS_safe, torch.zeros_like(cg_full)) # [G, s, r]
cos_u = (g_u @ vg) / g_u.norm(dim=2).clamp_min(1e-12) # [G, s]
@@ -1545,6 +1572,15 @@ def main(cfg: Config) -> int:
opt.step()
sched.step()
# online_stats gate: update band from this step's pooled cosines (all modules * rollouts).
# Uses previous step's band for routing (so the update is one step lagged, which is fine).
if is_routeV and cfg.routeV_gate == "online_stats" and _step_online_cos:
all_cos = torch.cat(_step_online_cos).float()
lo = torch.quantile(all_cos, cfg.online_stats_lo).item()
hi = torch.quantile(all_cos, cfg.online_stats_hi).item()
_online_band[0] = (lo, max(hi, lo + 1e-4))
logger.debug(f"online_stats band update: lo={lo:+.3f} hi={hi:+.3f} n={len(all_cos)}")
# ── v_hack / v_grad refresh ──
# Online v_hack refresh: re-extract against the *current* model so the
# hack subspace tracks where the student is being pulled now (rather