diff --git a/src/vgrout/train.py b/src/vgrout/train.py index 8f7f3ba..7a247c2 100644 --- a/src/vgrout/train.py +++ b/src/vgrout/train.py @@ -26,7 +26,6 @@ Arms (--intervention): """ from __future__ import annotations -import collections import gzip import json import math @@ -566,10 +565,6 @@ def main(cfg: Config) -> int: save_ckpt([], path=run_dir / "ckpt_update0000.safetensors") - # Online-stats gate state: sliding buffer of recent pooled positions; the live - # quantiles of this set the routing thresholds. Flushed at each v_grad refresh. - route_pos_window: collections.deque = collections.deque(maxlen=cfg.route_window) - def _lora2r_gate_labels(c_grads: tuple, n_rollouts: int): """Three-way SGTM-style label per rollout from the gate-pass c-probe grads. @@ -602,23 +597,22 @@ def main(cfg: Config) -> int: if n_inc == 0: raise RuntimeError("no module has positive band width; pairs separate nowhere") pos = num / den; w /= n_inc - # ── online-stats gate (#30): mean + k*std, three zones, keep = bulk ── The authored - # absolute band is mis-placed (live pos sits far below the synthetic-hack edge; even - # synthetic SOLVE is more hack-aligned than on-policy hack -- out/diag/pinning_calib.png), - # and a fixed quantile FORCES route_quantile out every step even when nothing separates. - # Route by the live MEAN + k*STD instead: pos > mean+route_std_mid*std -> mid, pos >= - # mean+route_std_rout*std -> rout, below -> keep (bulk). Self-silencing: only the tail - # that genuinely exceeds the spread routes, so qmass tracks real separation rather than a - # forced fraction. v_grad stays authored-only; the threshold follows the live distribution. - # The window includes this batch, so step 0 self-calibrates; flushed on v_grad refresh. 
- route_pos_window.extend(pos.detach().cpu().tolist()) - ref = torch.tensor(list(route_pos_window)) - mu_pos, sd_pos = ref.mean().item(), ref.std().item() + # ── online-stats gate (#30): band from THIS batch's pooled positions, three zones ── + # The authored absolute band is mis-placed (live pos sits far below the synthetic-hack + # edge -- out/diag/pinning_calib.png) and a fixed quantile FORCES route_quantile out every + # step even when nothing separates. Calibrate the band from the CURRENT batch instead: + # refresh-proof by construction (these rollouts are scored against the current v_grad), no + # window or flush to keep stale positions around. mean + k*std self-silences -- only the + # tail genuinely beyond the spread routes, so qmass tracks real separation. pos > + # mean+route_std_mid*std -> mid (absorption); pos >= mean+route_std_rout*std -> rout + # (hack, deployed detached); below -> keep (bulk). Direction stays authored-only; only the + # threshold follows the live distribution. 
+ mu_pos, sd_pos = pos.mean().item(), pos.std().item() t_lo = mu_pos + cfg.route_std_mid * sd_pos # mid onset t_hi = mu_pos + cfg.route_std_rout * sd_pos # rout onset (>= mid) - logger.debug(f"gate pos: n_inc={n_inc}/{len(wrappers)} min={pos.min().item():+.2f} " - f"max={pos.max().item():+.2f} | mean={mu_pos:+.2f} std={sd_pos:.2f} " - f"t_lo={t_lo:+.2f} t_hi={t_hi:+.2f} win={len(route_pos_window)}") + logger.debug(f"gate pos: n_inc={n_inc}/{len(wrappers)} n_roll={n_rollouts} " + f"min={pos.min().item():+.2f} max={pos.max().item():+.2f} | " + f"mean={mu_pos:+.2f} std={sd_pos:.2f} t_lo={t_lo:+.2f} t_hi={t_hi:+.2f}") m = (pos > t_lo).float() # mid + rout -> quarantine trains (keep = bulk below t_lo) d = (pos >= t_hi).float() # top tail -> hack -> deployed detached return m, d, 0.5 * m + 0.5 * d, w, pos, cosU_sum / n_inc @@ -1024,8 +1018,7 @@ def main(cfg: Config) -> int: opt.zero_grad(set_to_none=True) # extract leaves .grad populated if _was_training: model.train() - route_pos_window.clear() # positions were measured vs the OLD v_grad; flush - refr = "rfr" + refr = "rfr" # gate calibrates from the live batch each step -> no window to flush # ── periodic held-out eval (deploy = quarantine ablated) ── hack_deployed = solve_deployed = float("nan") diff --git a/src/vgrout/train_config.py b/src/vgrout/train_config.py index e2225aa..e9cced8 100644 --- a/src/vgrout/train_config.py +++ b/src/vgrout/train_config.py @@ -56,12 +56,10 @@ class Config: # out every step even when nothing separates. mean+k*std self-silences: it only routes the # tail that genuinely exceeds the spread. pos > mean + route_std_mid*std -> mid (absorption); # pos >= mean + route_std_rout*std -> rout (hack, deployed detached); below -> keep (bulk). - # route_window = sliding buffer of recent pooled positions, flushed on v_grad refresh - # (positions measured against one v_grad). Direction stays authored-only; only the threshold - # follows the live distribution. 
+ # The band is calibrated from the CURRENT batch each step (no window, no flush): refresh-proof + # by construction. Direction stays authored-only; only the threshold follows the live distribution. route_std_mid: float = 2.0 route_std_rout: float = 3.0 - route_window: int = 512 # Haar-random direction control (placebo): same routing machinery, no pair signal. routeV_random_v_seed: int | None = None rollout_ablate_frac: float = 0.0 @@ -119,6 +117,12 @@ class SmokeConfig(Config): max_new: int = 32 n_problems: int = 100 prompts_per_step: int = 1 + # Random tiny data never separates, so the self-silencing band (mean+2*std / mean+3*std) would route + # nothing and the quarantine would never train -> the routing-pathway smoke assert fails. + # Force routing by lowering the band so the smoke exercises mid+rout (correctness, not the + # real threshold). mid below mean -> most rollouts train the quarantine; rout at mean -> top ~half are detached. + route_std_mid: float = -1.0 + route_std_rout: float = 0.0 @dataclass(kw_only=True) @@ -132,7 +136,10 @@ class FastConfig(Config): prompts_per_step: int = 4 adam_beta1: float = 0.5 adam_beta2: float = 0.9 - lr: float = 3e-4 # 5e-4 peaked at warmup-end (step ~10) and diverged; 3e-4 + 20% warmup + # 5e-4 diverged at step ~10, 3e-4 just pushed it to step ~27 (lp_s blew up +18->+73, + # rew_s->0 after a clean emergence 7-24). 1e-4 is the normal LoRA range; emergence was + # already fast (hack_s 0->18/24 by step 7 at 3e-4) so we can afford the slower lr. + lr: float = 1e-4 # Each lora2r ckpt is ~0.33G (A/B for 252 modules, bf16). A0/B0 are NOT saved -- they're # the seeded init (regenerable from lora_init_seed in the metadata; ckpt_update0000 is the # init since A==A0 at step 0), and nothing live reloads them (inline eval uses in-memory