feat(#30,#39): simple online gate -- band from current batch, no window/cloud; lr 1e-4

Gate band (mean + k*std) now computed from THIS batch's pooled positions each step
instead of a sliding window. Refresh-proof by construction (live rollouts scored vs
the current v_grad), so the v_grad-refresh window flush is gone. Drops route_window
config + collections import. SmokeConfig forces routing (mid=-1,rout=0) since random
tiny data never separates -> quarantine would never train -> pathway assert would fail.

lr 3e-4 -> 1e-4: 3e-4 diverged at step ~27 (lp_s +18->+73, rew_s->0 after clean
emergence 7-24); 1e-4 is the normal LoRA range and emergence was already fast.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
wassname
2026-06-11 06:04:28 +00:00
parent 979daf84fd
commit 19687087b0
2 changed files with 27 additions and 27 deletions
+15 -22
View File
@@ -26,7 +26,6 @@ Arms (--intervention):
"""
from __future__ import annotations
import collections
import gzip
import json
import math
@@ -566,10 +565,6 @@ def main(cfg: Config) -> int:
save_ckpt([], path=run_dir / "ckpt_update0000.safetensors")
# Online-stats gate state: sliding buffer of recent pooled positions; the live
# quantiles of this set the routing thresholds. Flushed at each v_grad refresh.
route_pos_window: collections.deque = collections.deque(maxlen=cfg.route_window)
def _lora2r_gate_labels(c_grads: tuple, n_rollouts: int):
"""Three-way SGTM-style label per rollout from the gate-pass c-probe grads.
@@ -602,23 +597,22 @@ def main(cfg: Config) -> int:
if n_inc == 0:
raise RuntimeError("no module has positive band width; pairs separate nowhere")
pos = num / den; w /= n_inc
# ── online-stats gate (#30): mean + k*std, three zones, keep = bulk ── The authored
# absolute band is mis-placed (live pos sits far below the synthetic-hack edge; even
# synthetic SOLVE is more hack-aligned than on-policy hack -- out/diag/pinning_calib.png),
# and a fixed quantile FORCES route_quantile out every step even when nothing separates.
# Route by the live MEAN + k*STD instead: pos > mean+route_std_mid*std -> mid, pos >=
# mean+route_std_rout*std -> rout, below -> keep (bulk). Self-silencing: only the tail
# that genuinely exceeds the spread routes, so qmass tracks real separation rather than a
# forced fraction. v_grad stays authored-only; the threshold follows the live distribution.
# The window includes this batch, so step 0 self-calibrates; flushed on v_grad refresh.
route_pos_window.extend(pos.detach().cpu().tolist())
ref = torch.tensor(list(route_pos_window))
mu_pos, sd_pos = ref.mean().item(), ref.std().item()
# ── online-stats gate (#30): band from THIS batch's pooled positions, three zones ──
# The authored absolute band is mis-placed (live pos sits far below the synthetic-hack
# edge -- out/diag/pinning_calib.png) and a fixed quantile FORCES route_quantile out every
# step even when nothing separates. Calibrate the band from the CURRENT batch instead:
# refresh-proof by construction (these rollouts scored against the current v_grad), no
# window or flush to keep stale positions around. mean + k*std self-silences -- only the
# tail genuinely beyond the spread routes, so qmass tracks real separation. pos >
# mean+route_std_mid*std -> mid (absorption); pos >= mean+route_std_rout*std -> rout
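# (hack, deployed detached); below -> keep (bulk). Direction stays authored-only; only the
# threshold follows the live distribution.
# NOTE(review): with a single pooled position pos.std() is NaN (torch's default std is the
# unbiased estimator), so both comparisons below are False and nothing routes -- confirm
# every batch reaching this point has n_rollouts > 1.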
mu_pos, sd_pos = pos.mean().item(), pos.std().item()
t_lo = mu_pos + cfg.route_std_mid * sd_pos # mid onset
t_hi = mu_pos + cfg.route_std_rout * sd_pos # rout onset (>= mid)
logger.debug(f"gate pos: n_inc={n_inc}/{len(wrappers)} min={pos.min().item():+.2f} "
f"max={pos.max().item():+.2f} | mean={mu_pos:+.2f} std={sd_pos:.2f} "
f"t_lo={t_lo:+.2f} t_hi={t_hi:+.2f} win={len(route_pos_window)}")
logger.debug(f"gate pos: n_inc={n_inc}/{len(wrappers)} n_roll={n_rollouts} "
f"min={pos.min().item():+.2f} max={pos.max().item():+.2f} | "
f"mean={mu_pos:+.2f} std={sd_pos:.2f} t_lo={t_lo:+.2f} t_hi={t_hi:+.2f}")
m = (pos > t_lo).float() # mid + rout -> quarantine trains (keep = bulk below t_lo)
d = (pos >= t_hi).float() # top tail -> hack -> deployed detached
return m, d, 0.5 * m + 0.5 * d, w, pos, cosU_sum / n_inc
@@ -1024,8 +1018,7 @@ def main(cfg: Config) -> int:
opt.zero_grad(set_to_none=True) # extract leaves .grad populated
if _was_training:
model.train()
route_pos_window.clear() # positions were measured vs the OLD v_grad; flush
refr = "rfr"
refr = "rfr" # gate calibrates from the live batch each step -> no window to flush
# ── periodic held-out eval (deploy = quarantine ablated) ──
hack_deployed = solve_deployed = float("nan")
+12 -5
View File
@@ -56,12 +56,10 @@ class Config:
# out every step even when nothing separates. mean+k*std self-silences: it only routes the
# tail that genuinely exceeds the spread. pos > mean + route_std_mid*std -> mid (absorption);
# pos >= mean + route_std_rout*std -> rout (hack, deployed detached); below -> keep (bulk).
# route_window = sliding buffer of recent pooled positions, flushed on v_grad refresh
# (positions measured against one v_grad). Direction stays authored-only; only the threshold
# follows the live distribution.
# The band is calibrated from the CURRENT batch each step (no window, no flush): refresh-proof
# by construction. Direction stays authored-only; only the threshold follows the live dist.
route_std_mid: float = 2.0
route_std_rout: float = 3.0
route_window: int = 512
# Haar-random direction control (placebo): same routing machinery, no pair signal.
routeV_random_v_seed: int | None = None
rollout_ablate_frac: float = 0.0
@@ -119,6 +117,12 @@ class SmokeConfig(Config):
max_new: int = 32
n_problems: int = 100
prompts_per_step: int = 1
# Random tiny data never separates, so the default self-silencing band (mean+2*std for mid,
# mean+3*std for rout) would route nothing and the quarantine would never train -> the
# routing-pathway smoke assert fails.
# Force routing by lowering the band so the smoke exercises mid+rout (correctness, not the
# real threshold). mid below mean -> most train quarantine; rout at mean -> top ~half detach.
route_std_mid: float = -1.0
route_std_rout: float = 0.0
@dataclass(kw_only=True)
@@ -132,7 +136,10 @@ class FastConfig(Config):
prompts_per_step: int = 4
adam_beta1: float = 0.5
adam_beta2: float = 0.9
lr: float = 3e-4 # 5e-4 peaked at warmup-end (step ~10) and diverged; 3e-4 + 20% warmup
# 5e-4 diverged at step ~10, 3e-4 just pushed it to step ~27 (lp_s blew up +18->+73,
# rew_s->0 after a clean emergence 7-24). 1e-4 is the normal LoRA range; emergence was
# already fast (hack_s 0->18/24 by step 7 at 3e-4) so we can afford the slower lr.
lr: float = 1e-4
# Each lora2r ckpt is ~0.33G (A/B for 252 modules, bf16). A0/B0 are NOT saved -- they're
# the seeded init (regenerable from lora_init_seed in the metadata; ckpt_update0000 is the
# init since A==A0 at step 0), and nothing live reloads them (inline eval uses in-memory