From 0f59b1351bdc26d178c678aebd185fd44fc7eb7d Mon Sep 17 00:00:00 2001 From: wassname <1103714+wassname@users.noreply.github.com> Date: Tue, 9 Jun 2026 02:25:37 +0000 Subject: [PATCH] feat: online_stats gate for routeV -- live q5/q95 band calibration New routeV_gate="online_stats" mode: use the empirical per-rollout cosine distribution (q5/q95 pooled across all modules each step) as the routing band thresholds, instead of the pair-derived route_band. Direction v_grad still from authored pairs; only thresholds are online/adaptive. Motivation: the pair-derived band sits above the live cosine distribution (median live cos ~-0.06), causing frout to cliff as GRPO advantage flattens. Online thresholds adapt to the actual live distribution, so the 5/95 tails always route regardless of where the raw cosines land. Config: routeV_gate="online_stats", online_stats_lo=0.05, online_stats_hi=0.95. Step-0 prior: (-0.5, 0.5) neutral band (pairs not used for calibration). Band update: post-opt.step(), torch.quantile over that step's module*rollout cosines. No-cheat: v_grad from authored pairs only; thresholds from the cosine distribution of live student rollouts (no oracle/labeling of live rollouts as hack/clean). Also: add online_stats to results_deploy._arm(); justfile queue-online-stats recipe. Queued as job 22 (s43, authored pairs, priority 12, after 19/20/21). Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com> --- justfile | 14 ++++++++++++++ scripts/results_deploy.py | 1 + src/vgrout/train.py | 38 +++++++++++++++++++++++++++++++++++++- 3 files changed, 52 insertions(+), 1 deletion(-) diff --git a/justfile b/justfile index d1f352e..55f17e1 100644 --- a/justfile +++ b/justfile @@ -157,6 +157,20 @@ queue-dir6 seed='43': # routeV deploy_hack < vanilla at matched solve, significant across the 3 seeds (paired t, # alpha=0.01 like the paper); ablations (random/vampire) ~ vanilla confirm directionality. 
TEACHER_RT := "out/pools/teacher_pool_runtests_dense" + +# H: online_stats gate -- calibrate band thresholds from the LIVE cosine distribution +# (q5/q95 of per-rollout cosines pooled across all modules each step). Direction v_grad +# still from authored pairs; only thresholds are online. Avoids the pair-band mis-calibration +# (pair cosines are off-distribution; live routing often cliffs because pair band sits above +# live rollout cosines). Expected: sustained rout (no frout cliff) even past step 20. +queue-online-stats seed="43": + #!/usr/bin/env bash + set -eu + pueue add -w "$PWD" -o 12 \ + -l "why: online_stats gate s{{seed}} -- live q5/q95 band (no pair threshold); resolve: sustained rout vs grad-cosine cliff, test if adaptive thresholds improve deploy suppression" \ + -- {{ TRAIN }} fast --intervention=routeV --routeV-gate=online_stats \ + --vhack-pairs-path=None --seed={{seed}} --out-tag=_dir8_routeV_onlinestats_s{{seed}} + queue-broad: #!/usr/bin/env bash set -eu diff --git a/scripts/results_deploy.py b/scripts/results_deploy.py index 50da183..27bf652 100644 --- a/scripts/results_deploy.py +++ b/scripts/results_deploy.py @@ -64,6 +64,7 @@ def _arm(argv: str) -> str: if "--intervention=none" in argv: return "vanilla" gate = ("act_vote" if "--routeV-gate=act_vote" in argv else + "online_stats" if "--routeV-gate=online_stats" in argv else "lora" if "lora_frozen_b" in argv else "per-token" if "--routeV-per-token" in argv else "grad-cos") return f"routeV/{gate}" + ("·randV" if "--routeV-random-v-seed" in argv else "") diff --git a/src/vgrout/train.py b/src/vgrout/train.py index 6d432b3..8292cce 100644 --- a/src/vgrout/train.py +++ b/src/vgrout/train.py @@ -208,7 +208,11 @@ class Config: # arm: different space (act not grad) + different aggregation (one f per rollout, # shared across modules). Tests whether the precision framing predicts deploy # suppression, and stresses H2 absorption (does gate choice matter at deploy at all?). 
- routeV_gate: Literal["grad_cosine", "act_vote"] = "grad_cosine" + routeV_gate: Literal["grad_cosine", "act_vote", "online_stats"] = "grad_cosine" + # online_stats gate: per-rollout cosines pooled over each step; their q(lo)/q(hi) set the + # band for the next step. Pairs supply the direction (v_grad) only, not the thresholds. + online_stats_lo: float = 0.05 # lower quantile -> keep tail + online_stats_hi: float = 0.95 # upper quantile -> route tail # Per-source cin diagnostic: split each prompt's backward into student-only # + teacher-only passes (~2x backward time). 1 = every step (default; full # signal); N>1 = only every Nth step (combined backward elsewhere, ~halves @@ -534,6 +538,7 @@ def main(cfg: Config) -> int: # are hidden, so v_hack=None just means no subspace machinery). v_grad = None # set only by the routeV grad-mask branch below As_dir = act_w = vote_band = None # set only by the act_vote gate branch below + _online_band: list = [None] # online_stats gate: (lo, hi) updated each step; None = use the neutral (-0.5, 0.5) prior if cfg.intervention in ("none", "routeV"): if cfg.intervention == "none" and cfg.v_hack_path is not None: logger.info(f"vanilla arm: ignoring --v-hack-path={cfg.v_hack_path} " @@ -995,6 +1000,7 @@ def main(cfg: Config) -> int: # modules (the global activation vote, computed post-backward before the per-module # routing). 1-element list so the filter closure reads the current step's value. 
_step_f_roll: list[torch.Tensor | None] = [None] + _step_online_cos: list[torch.Tensor] = [] # online_stats: per-module [G] cosines, meant to be cleared each step -- NOTE(review): no reset is visible in this patch; confirm # routeV: recover the per-rollout δS grad from the gate (c.grad = δS * g_b), # flag rollouts whose grad points hack-ward (cos(g_b, v_grad) > τ), and route @@ -1040,6 +1046,27 @@ def main(cfg: Config) -> int: _kn, _rn, _on, _ke, _re, _oe = _zone_stats(f, g_b.norm(dim=1)) step_zkeep.append(_kn); step_zresid.append(_rn); step_zrout.append(_on) step_zkeepE.append(_ke); step_zresidE.append(_re); step_zroutE.append(_oe) + elif cfg.routeV_gate == "online_stats": + # Online-stats gate: band thresholds from the LIVE cosine distribution + # (q5/q95 across all modules*rollouts of the previous step), not from pairs. Direction + # (v_grad) still comes from authored pairs -- only calibration is online. + # Before the first band update (step 0) a fixed neutral prior is used instead. + cg = cg_full.sum(1) # [G, r] + g_b = torch.where(reliable, cg / dS_safe, torch.zeros_like(cg)) # [G, r] + cos_b = (g_b @ vg) / g_b.norm(dim=1).clamp_min(1e-12) # [G] + _step_online_cos.append(cos_b.detach()) # accumulate; band updated post-step + # step-0 prior: neutral (-0.5, 0.5) so some routing always fires before the + # live distribution bootstraps. Pair-derived (lower, upper) is not used for + # threshold calibration -- that is the whole point of online_stats. 
+ lo, hi = _online_band[0] if _online_band[0] is not None else (-0.5, 0.5) + band_w = max(hi - lo, 1e-6) + f = ((cos_b - lo) / band_w).clamp(0.0, 1.0) # [G] + routed = torch.where(reliable, (cg * f.unsqueeze(1)).sum(0) / dS_safe, + torch.zeros_like(g)) + step_flagged.append(f.mean().item()) + _kn, _rn, _on, _ke, _re, _oe = _zone_stats(f, g_b.norm(dim=1)) + step_zkeep.append(_kn); step_zresid.append(_rn); step_zrout.append(_on) + step_zkeepE.append(_ke); step_zresidE.append(_re); step_zroutE.append(_oe) elif cfg.routeV_per_token: g_u = torch.where(reliable, cg_full / dS_safe, torch.zeros_like(cg_full)) # [G, s, r] cos_u = (g_u @ vg) / g_u.norm(dim=2).clamp_min(1e-12) # [G, s] @@ -1545,6 +1572,15 @@ def main(cfg: Config) -> int: opt.step() sched.step() + # online_stats gate: update band from this step's pooled cosines (all modules * rollouts). + # Uses previous step's band for routing (so the update is one step lagged, which is fine). + if is_routeV and cfg.routeV_gate == "online_stats" and _step_online_cos: + all_cos = torch.cat(_step_online_cos).float() + lo = torch.quantile(all_cos, cfg.online_stats_lo).item() + hi = torch.quantile(all_cos, cfg.online_stats_hi).item() + _online_band[0] = (lo, max(hi, lo + 1e-4)) + logger.debug(f"online_stats band update: lo={lo:+.3f} hi={hi:+.3f} n={len(all_cos)}") + # ── v_hack / v_grad refresh ── # Online v_hack refresh: re-extract against the *current* model so the # hack subspace tracks where the student is being pulled now (rather