mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 17:00:59 +08:00
feat(#30,#39): simple online gate -- band from current batch, no window/cloud; lr 1e-4
Gate band (mean + k*std) now computed from THIS batch's pooled positions each step instead of a sliding window. Refresh-proof by construction (live rollouts scored vs the current v_grad), so the v_grad-refresh window flush is gone. Drops route_window config + collections import. SmokeConfig forces routing (mid=-1,rout=0) since random tiny data never separates -> quarantine would never train -> pathway assert would fail. lr 3e-4 -> 1e-4: 3e-4 diverged at step ~27 (lp_s +18->+73, rew_s->0 after clean emergence 7-24); 1e-4 is the normal LoRA range and emergence was already fast. Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
+15
-22
@@ -26,7 +26,6 @@ Arms (--intervention):
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import collections
|
||||
import gzip
|
||||
import json
|
||||
import math
|
||||
@@ -566,10 +565,6 @@ def main(cfg: Config) -> int:
|
||||
|
||||
save_ckpt([], path=run_dir / "ckpt_update0000.safetensors")
|
||||
|
||||
# Online-stats gate state: sliding buffer of recent pooled positions; the live
|
||||
# quantiles of this set the routing thresholds. Flushed at each v_grad refresh.
|
||||
route_pos_window: collections.deque = collections.deque(maxlen=cfg.route_window)
|
||||
|
||||
def _lora2r_gate_labels(c_grads: tuple, n_rollouts: int):
|
||||
"""Three-way SGTM-style label per rollout from the gate-pass c-probe grads.
|
||||
|
||||
@@ -602,23 +597,22 @@ def main(cfg: Config) -> int:
|
||||
if n_inc == 0:
|
||||
raise RuntimeError("no module has positive band width; pairs separate nowhere")
|
||||
pos = num / den; w /= n_inc
|
||||
# ── online-stats gate (#30): mean + k*std, three zones, keep = bulk ── The authored
|
||||
# absolute band is mis-placed (live pos sits far below the synthetic-hack edge; even
|
||||
# synthetic SOLVE is more hack-aligned than on-policy hack -- out/diag/pinning_calib.png),
|
||||
# and a fixed quantile FORCES route_quantile out every step even when nothing separates.
|
||||
# Route by the live MEAN + k*STD instead: pos > mean+route_std_mid*std -> mid, pos >=
|
||||
# mean+route_std_rout*std -> rout, below -> keep (bulk). Self-silencing: only the tail
|
||||
# that genuinely exceeds the spread routes, so qmass tracks real separation rather than a
|
||||
# forced fraction. v_grad stays authored-only; the threshold follows the live distribution.
|
||||
# The window includes this batch, so step 0 self-calibrates; flushed on v_grad refresh.
|
||||
route_pos_window.extend(pos.detach().cpu().tolist())
|
||||
ref = torch.tensor(list(route_pos_window))
|
||||
mu_pos, sd_pos = ref.mean().item(), ref.std().item()
|
||||
# ── online-stats gate (#30): band from THIS batch's pooled positions, three zones ──
|
||||
# The authored absolute band is mis-placed (live pos sits far below the synthetic-hack
|
||||
# edge -- out/diag/pinning_calib.png) and a fixed quantile FORCES route_quantile out every
|
||||
# step even when nothing separates. Calibrate the band from the CURRENT batch instead:
|
||||
# refresh-proof by construction (these rollouts scored against the current v_grad), no
|
||||
# window or flush to keep stale positions around. mean + k*std self-silences -- only the
|
||||
# tail genuinely beyond the spread routes, so qmass tracks real separation. pos >
|
||||
# mean+route_std_mid*std -> mid (absorption); pos >= mean+route_std_rout*std -> rout
|
||||
# (hack, deployed detached); below -> keep (bulk). Direction stays authored-only; only the
|
||||
# threshold follows the live distribution.
|
||||
mu_pos, sd_pos = pos.mean().item(), pos.std().item()
|
||||
t_lo = mu_pos + cfg.route_std_mid * sd_pos # mid onset
|
||||
t_hi = mu_pos + cfg.route_std_rout * sd_pos # rout onset (>= mid)
|
||||
logger.debug(f"gate pos: n_inc={n_inc}/{len(wrappers)} min={pos.min().item():+.2f} "
|
||||
f"max={pos.max().item():+.2f} | mean={mu_pos:+.2f} std={sd_pos:.2f} "
|
||||
f"t_lo={t_lo:+.2f} t_hi={t_hi:+.2f} win={len(route_pos_window)}")
|
||||
logger.debug(f"gate pos: n_inc={n_inc}/{len(wrappers)} n_roll={n_rollouts} "
|
||||
f"min={pos.min().item():+.2f} max={pos.max().item():+.2f} | "
|
||||
f"mean={mu_pos:+.2f} std={sd_pos:.2f} t_lo={t_lo:+.2f} t_hi={t_hi:+.2f}")
|
||||
m = (pos > t_lo).float() # mid + rout -> quarantine trains (keep = bulk below t_lo)
|
||||
d = (pos >= t_hi).float() # top tail -> hack -> deployed detached
|
||||
return m, d, 0.5 * m + 0.5 * d, w, pos, cosU_sum / n_inc
|
||||
@@ -1024,8 +1018,7 @@ def main(cfg: Config) -> int:
|
||||
opt.zero_grad(set_to_none=True) # extract leaves .grad populated
|
||||
if _was_training:
|
||||
model.train()
|
||||
route_pos_window.clear() # positions were measured vs the OLD v_grad; flush
|
||||
refr = "rfr"
|
||||
refr = "rfr" # gate calibrates from the live batch each step -> no window to flush
|
||||
|
||||
# ── periodic held-out eval (deploy = quarantine ablated) ──
|
||||
hack_deployed = solve_deployed = float("nan")
|
||||
|
||||
@@ -56,12 +56,10 @@ class Config:
|
||||
# out every step even when nothing separates. mean+k*std self-silences: it only routes the
|
||||
# tail that genuinely exceeds the spread. pos > mean + route_std_mid*std -> mid (absorption);
|
||||
# pos >= mean + route_std_rout*std -> rout (hack, deployed detached); below -> keep (bulk).
|
||||
# route_window = sliding buffer of recent pooled positions, flushed on v_grad refresh
|
||||
# (positions measured against one v_grad). Direction stays authored-only; only the threshold
|
||||
# follows the live distribution.
|
||||
# The band is calibrated from the CURRENT batch each step (no window, no flush): refresh-proof
|
||||
# by construction. Direction stays authored-only; only the threshold follows the live dist.
|
||||
route_std_mid: float = 2.0
|
||||
route_std_rout: float = 3.0
|
||||
route_window: int = 512
|
||||
# Haar-random direction control (placebo): same routing machinery, no pair signal.
|
||||
routeV_random_v_seed: int | None = None
|
||||
rollout_ablate_frac: float = 0.0
|
||||
@@ -119,6 +117,12 @@ class SmokeConfig(Config):
|
||||
max_new: int = 32
|
||||
n_problems: int = 100
|
||||
prompts_per_step: int = 1
|
||||
# Random tiny data never separates, so the self-silencing band (mean + 2*std / mean + 3*std) would route
|
||||
# nothing and the quarantine would never train -> the routing-pathway smoke assert fails.
|
||||
# Force routing by lowering the band so the smoke exercises mid+rout (correctness, not the
|
||||
# real threshold). mid at mean - 1*std -> most rollouts train the quarantine; rout at mean -> roughly the top half is detached.
|
||||
route_std_mid: float = -1.0
|
||||
route_std_rout: float = 0.0
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
@@ -132,7 +136,10 @@ class FastConfig(Config):
|
||||
prompts_per_step: int = 4
|
||||
adam_beta1: float = 0.5
|
||||
adam_beta2: float = 0.9
|
||||
lr: float = 3e-4 # 5e-4 peaked at warmup-end (step ~10) and diverged; 3e-4 + 20% warmup
|
||||
# 5e-4 diverged at step ~10; 3e-4 only delayed the divergence to step ~27 (lp_s blew up +18->+73,
|
||||
# rew_s->0 after a clean emergence 7-24). 1e-4 is the normal LoRA range; emergence was
|
||||
# already fast (hack_s 0->18/24 by step 7 at 3e-4) so we can afford the slower lr.
|
||||
lr: float = 1e-4
|
||||
# Each lora2r ckpt is ~0.33G (A/B for 252 modules, bf16). A0/B0 are NOT saved -- they're
|
||||
# the seeded init (regenerable from lora_init_seed in the metadata; ckpt_update0000 is the
|
||||
# init since A==A0 at step 0), and nothing live reloads them (inline eval uses in-memory
|
||||
|
||||
Reference in New Issue
Block a user