route2 refresh basis-overlap log + soft ppl-drop warning

- route2 v_act/v_grad refresh now logs basis_overlap_with_prev (mean |cos| of
  old vs new mask direction) -- matches the clean-repo guard; a bare refresh bool
  carried no info, overlap shows if the mask chases a drifting target.
- divergence tripwire gets a soft logger.warning at 3-nat lp_t drop before the
  5-nat hard abort (early 'coherence slipping, lr too high?' heads-up).
- threshold note: healthy lp_t runs -0.5..-2.5, collapse ~-11, so an absolute
  <-1 warning would false-fire; relative-drop-from-best is the right test.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
wassname
2026-06-01 00:39:43 +00:00
parent 11bcdd2fe6
commit 8ef78f6d14
+20 -3
View File
@@ -1091,6 +1091,7 @@ def main(cfg: Config) -> int:
# has an intrinsic lp_t ~ -11.9 (uniform logp) but it stays flat, so it never DROPS.
# Abort if lp_t falls this far below its best for 2 steps running (advantage dead).
DIVERGENCE_DROP = 5.0 # nats below best (e^5 ~ 150x worse ppl); never in healthy runs
WARN_DROP = 3.0 # softer: log a warning before the hard abort
dumped_hack_classes: set[str] = set() # first full example of each hack class -> verbose log
teacher_dumped = False
# Per-mode learning tracker (the substrate UAT: did the student learn EACH hack,
@@ -1625,8 +1626,16 @@ def main(cfg: Config) -> int:
if cfg.route2_mask == "act":
from .extract_vhack_grad import extract_v_act
_v = extract_v_act(model, tok, wrappers, MASK_PAIRS, n_heldout=2, device=device)
# Mean |cos(old, new)| over modules = how much the mask direction
# moved. Near 1.0 => stable hack subspace; low => v_act is chasing
# a drifting target (the staleness this refresh is meant to fix).
_ov = []
for name, info in wrappers.items():
info["layer"]._antipasto_v_act.data.copy_(_v[name].to(device))
old = info["layer"]._antipasto_v_act
new = _v[name].to(device, dtype=old.dtype)
_ov.append((old @ new).abs() / (old.norm().clamp_min(1e-9) * new.norm().clamp_min(1e-9)))
old.data.copy_(new)
_act_overlap = float(torch.stack(_ov).mean())
else:
from .extract_vhack_grad import extract_v_hack
_, _, raw_grads, _ = extract_v_hack(
@@ -1648,8 +1657,10 @@ def main(cfg: Config) -> int:
# refreshed. NOTE: this fires AFTER opt.step(), so if the model is
# already diverging the re-extracted direction is extracted on a broken
# model -- watch lp_t / ppl_t around the refresh step.
_ov_str = f", basis_overlap_with_prev={_act_overlap:.3f}" if cfg.route2_mask == "act" else ""
logger.info(f"route2 {cfg.route2_mask}-mask refreshed@step{step} "
f"({len(wrappers)} modules, quarantine ablated during extract)")
f"({len(wrappers)} modules, quarantine ablated during extract{_ov_str}) "
f"SHOULD: overlap ~1 => stable hack subspace; low => v_act chasing a drifting target")
if v_hack is not None and do_refresh:
from .extract_vhack_grad import extract_v_hack
if cfg.vhack_pairs_path is not None:
@@ -1872,7 +1883,13 @@ def main(cfg: Config) -> int:
ppl_t = math.exp(-lp_t_mean) if math.isfinite(lp_t_mean) else float("inf")
if math.isfinite(lp_t_mean):
lp_t_best = max(lp_t_best, lp_t_mean)
diverged = math.isfinite(lp_t_mean) and lp_t_mean < lp_t_best - DIVERGENCE_DROP
drop = lp_t_best - lp_t_mean if math.isfinite(lp_t_mean) else 0.0
# Soft warning at a smaller drop than the hard abort -- an early "ppl is
# climbing, watch for divergence (lr too high?)" before things are lost.
if WARN_DROP <= drop < DIVERGENCE_DROP:
logger.warning(f"step {step}: lp_t={lp_t_mean:.1f} is {drop:.1f} nats below best "
f"{lp_t_best:.1f} (ppl_t={ppl_t:.0e}) -- coherence slipping, lr too high?")
diverged = math.isfinite(lp_t_mean) and drop > DIVERGENCE_DROP
diverged_steps = diverged_steps + 1 if diverged else 0
if diverged_steps >= 2:
logger.error(