mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-07-03 16:40:03 +08:00
route2 refresh basis-overlap log + soft ppl-drop warning
- route2 v_act/v_grad refresh now logs basis_overlap_with_prev (mean |cos| of old vs new mask direction) -- matches the clean-repo guard; a bare refresh bool carried no info, overlap shows if the mask chases a drifting target. - divergence tripwire gets a soft logger.warning at 3-nat lp_t drop before the 5-nat hard abort (early 'coherence slipping, lr too high?' heads-up). - threshold note: healthy lp_t runs -0.5..-2.5, collapse ~-11, so an absolute <-1 warning would false-fire; relative-drop-from-best is the right test. Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
@@ -1091,6 +1091,7 @@ def main(cfg: Config) -> int:
|
||||
# has an intrinsic lp_t ~ -11.9 (uniform logp) but it stays flat, so it never DROPS.
|
||||
# Abort if lp_t falls this far below its best for 2 steps running (advantage dead).
|
||||
DIVERGENCE_DROP = 5.0 # nats below best (e^5 ~ 150x worse ppl); never in healthy runs
|
||||
WARN_DROP = 3.0 # softer: log a warning before the hard abort
|
||||
dumped_hack_classes: set[str] = set() # first full example of each hack class -> verbose log
|
||||
teacher_dumped = False
|
||||
# Per-mode learning tracker (the substrate UAT: did the student learn EACH hack,
|
||||
@@ -1625,8 +1626,16 @@ def main(cfg: Config) -> int:
|
||||
if cfg.route2_mask == "act":
|
||||
from .extract_vhack_grad import extract_v_act
|
||||
_v = extract_v_act(model, tok, wrappers, MASK_PAIRS, n_heldout=2, device=device)
|
||||
# Mean |cos(old, new)| over modules = how much the mask direction
|
||||
# moved. Near 1.0 => stable hack subspace; low => v_act is chasing
|
||||
# a drifting target (the staleness this refresh is meant to fix).
|
||||
_ov = []
|
||||
for name, info in wrappers.items():
|
||||
info["layer"]._antipasto_v_act.data.copy_(_v[name].to(device))
|
||||
old = info["layer"]._antipasto_v_act
|
||||
new = _v[name].to(device, dtype=old.dtype)
|
||||
_ov.append((old @ new).abs() / (old.norm().clamp_min(1e-9) * new.norm().clamp_min(1e-9)))
|
||||
old.data.copy_(new)
|
||||
_act_overlap = float(torch.stack(_ov).mean())
|
||||
else:
|
||||
from .extract_vhack_grad import extract_v_hack
|
||||
_, _, raw_grads, _ = extract_v_hack(
|
||||
@@ -1648,8 +1657,10 @@ def main(cfg: Config) -> int:
|
||||
# refreshed. NOTE: this fires AFTER opt.step(), so if the model is
|
||||
# already diverging the re-extracted direction is extracted on a broken
|
||||
# model -- watch lp_t / ppl_t around the refresh step.
|
||||
_ov_str = f", basis_overlap_with_prev={_act_overlap:.3f}" if cfg.route2_mask == "act" else ""
|
||||
logger.info(f"route2 {cfg.route2_mask}-mask refreshed@step{step} "
|
||||
f"({len(wrappers)} modules, quarantine ablated during extract)")
|
||||
f"({len(wrappers)} modules, quarantine ablated during extract{_ov_str}) "
|
||||
f"SHOULD: overlap ~1 => stable hack subspace; low => v_act chasing a drifting target")
|
||||
if v_hack is not None and do_refresh:
|
||||
from .extract_vhack_grad import extract_v_hack
|
||||
if cfg.vhack_pairs_path is not None:
|
||||
@@ -1872,7 +1883,13 @@ def main(cfg: Config) -> int:
|
||||
ppl_t = math.exp(-lp_t_mean) if math.isfinite(lp_t_mean) else float("inf")
|
||||
if math.isfinite(lp_t_mean):
|
||||
lp_t_best = max(lp_t_best, lp_t_mean)
|
||||
diverged = math.isfinite(lp_t_mean) and lp_t_mean < lp_t_best - DIVERGENCE_DROP
|
||||
drop = lp_t_best - lp_t_mean if math.isfinite(lp_t_mean) else 0.0
|
||||
# Soft warning at a smaller drop than the hard abort -- an early "ppl is
|
||||
# climbing, watch for divergence (lr too high?)" before things are lost.
|
||||
if WARN_DROP <= drop < DIVERGENCE_DROP:
|
||||
logger.warning(f"step {step}: lp_t={lp_t_mean:.1f} is {drop:.1f} nats below best "
|
||||
f"{lp_t_best:.1f} (ppl_t={ppl_t:.0e}) -- coherence slipping, lr too high?")
|
||||
diverged = math.isfinite(lp_t_mean) and drop > DIVERGENCE_DROP
|
||||
diverged_steps = diverged_steps + 1 if diverged else 0
|
||||
if diverged_steps >= 2:
|
||||
logger.error(
|
||||
|
||||
Reference in New Issue
Block a user