diff --git a/src/projected_grpo/train.py b/src/projected_grpo/train.py index 15f7f16..52f4c0c 100644 --- a/src/projected_grpo/train.py +++ b/src/projected_grpo/train.py @@ -1091,6 +1091,7 @@ def main(cfg: Config) -> int: # has an intrinsic lp_t ~ -11.9 (uniform logp) but it stays flat, so it never DROPS. # Abort if lp_t falls this far below its best for 2 steps running (advantage dead). DIVERGENCE_DROP = 5.0 # nats below best (e^5 ~ 150x worse ppl); never in healthy runs + WARN_DROP = 3.0 # softer: log a warning before the hard abort dumped_hack_classes: set[str] = set() # first full example of each hack class -> verbose log teacher_dumped = False # Per-mode learning tracker (the substrate UAT: did the student learn EACH hack, @@ -1625,8 +1626,16 @@ def main(cfg: Config) -> int: if cfg.route2_mask == "act": from .extract_vhack_grad import extract_v_act _v = extract_v_act(model, tok, wrappers, MASK_PAIRS, n_heldout=2, device=device) + # Mean |cos(old, new)| over modules = how much the mask direction + # moved. Near 1.0 => stable hack subspace; low => v_act is chasing + # a drifting target (the staleness this refresh is meant to fix). + _ov = [] for name, info in wrappers.items(): - info["layer"]._antipasto_v_act.data.copy_(_v[name].to(device)) + old = info["layer"]._antipasto_v_act + new = _v[name].to(device, dtype=old.dtype) + _ov.append((old @ new).abs() / (old.norm().clamp_min(1e-9) * new.norm().clamp_min(1e-9))) + old.data.copy_(new) + _act_overlap = float(torch.stack(_ov).mean()) else: from .extract_vhack_grad import extract_v_hack _, _, raw_grads, _ = extract_v_hack( @@ -1648,8 +1657,10 @@ def main(cfg: Config) -> int: # refreshed. NOTE: this fires AFTER opt.step(), so if the model is # already diverging the re-extracted direction is extracted on a broken # model -- watch lp_t / ppl_t around the refresh step. + _ov_str = f", basis_overlap_with_prev={_act_overlap:.3f}" if cfg.route2_mask == "act" else "" logger.info(f"route2 {cfg.route2_mask}-mask refreshed@step{step} " - f"({len(wrappers)} modules, quarantine ablated during extract)") + f"({len(wrappers)} modules, quarantine ablated during extract{_ov_str}) " + f"SHOULD: overlap ~1 => stable hack subspace; low => v_act chasing a drifting target") if v_hack is not None and do_refresh: from .extract_vhack_grad import extract_v_hack if cfg.vhack_pairs_path is not None: @@ -1872,7 +1883,13 @@ def main(cfg: Config) -> int: ppl_t = math.exp(-lp_t_mean) if math.isfinite(lp_t_mean) else float("inf") if math.isfinite(lp_t_mean): lp_t_best = max(lp_t_best, lp_t_mean) - diverged = math.isfinite(lp_t_mean) and lp_t_mean < lp_t_best - DIVERGENCE_DROP + drop = lp_t_best - lp_t_mean if math.isfinite(lp_t_mean) else 0.0 + # Soft warning at a smaller drop than the hard abort -- an early "ppl is + # climbing, watch for divergence (lr too high?)" before things are lost. + if WARN_DROP <= drop < DIVERGENCE_DROP: + logger.warning(f"step {step}: lp_t={lp_t_mean:.1f} is {drop:.1f} nats below best " + f"{lp_t_best:.1f} (ppl_t={ppl_t:.0e}) -- coherence slipping, lr too high?") + diverged = math.isfinite(lp_t_mean) and drop > DIVERGENCE_DROP diverged_steps = diverged_steps + 1 if diverged else 0 if diverged_steps >= 2: logger.error(