From 2b1d2b7493970584b780a7963a50adc8e8374673 Mon Sep 17 00:00:00 2001 From: wassname <1103714+wassname@users.noreply.github.com> Date: Sat, 6 Jun 2026 14:05:30 +0800 Subject: [PATCH] heal: kl_agg knob (mean|rmse|p95|max) -- outlier-aggregate the per-position KL barrier mean dilutes the few incoherent positions that carry the collapse: #101's token loops had mean per-position kl_rev ~0.38, under the tau=0.5 hinge, so the barrier never fired (journal h/i). Incoherence is outlier-driven, so rmse/p95/max are sensitive to it (scripts/diag_kl_agg.py synthetic: same loop = rmse 1.5 / p95 3.8 / max 8.1 vs coherent ~0.03; sep ratio grows 21x->58x->77x->85x from mean to max). rmse default for the new arm (smooth dense gradient). eps inside the sqrt: B=0 LoRA init zeros every kl_pos at step 0 and bare sqrt(0) has inf grad -> 0*nan. mean stays the config default = no change to existing runs. Queued as the next loop arm (kl_rev rmse, ref=base, tau=1.0). Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com> --- scripts/diag_kl_agg.py | 54 ++++++++++++++++++++++++++++++++++++++++ src/steer_heal/config.py | 5 ++++ src/steer_heal/heal.py | 19 ++++++++++++-- 3 files changed, 76 insertions(+), 2 deletions(-) create mode 100644 scripts/diag_kl_agg.py diff --git a/scripts/diag_kl_agg.py b/scripts/diag_kl_agg.py new file mode 100644 index 0000000..a207377 --- /dev/null +++ b/scripts/diag_kl_agg.py @@ -0,0 +1,54 @@ +"""Why mean-KL is blind to the coherence collapse, and rmse/p95 are not (journal-supporting). + +No GPU, no model: synthetic next-token distributions (ml-debug Part 3 loss-surface check). +A coherent-trait student shifts a little mass toward a base-PLAUSIBLE token at every position; +an incoherent student is base everywhere except a few positions that spike on a base-IMPROBABLE +token (a token loop). We aggregate the per-position KL the way the heal barrier does and show +that mean dilutes the loop under the hinge threshold while outlier aggregates catch it. +""" +import numpy as np +from tabulate import tabulate + +rng = np.random.default_rng(0) +V, T = 200, 60 # vocab, positions in a completion + + +def softmax(z): + z = z - z.max(-1, keepdims=True) + e = np.exp(z) + return e / e.sum(-1, keepdims=True) + + +base_logits = rng.standard_normal((T, V)) +p_ref = softmax(base_logits) +order = np.argsort(p_ref.mean(0)) +trait_tok = order[len(order) // 2] # mid-prob = base-PLAUSIBLE (where coherent trait lands) +loop_tok = order[3] # near-lowest = base-IMPROBABLE (where a loop lands) + +tl = base_logits.copy(); tl[:, trait_tok] += 1.6 # broad small shift, EVERY position +p_trait = softmax(tl) +ll = base_logits.copy() +for t in (12, 13, 14, 15): # 4 spiked positions out of 60 + ll[t] = -10; ll[t, loop_tok] = 12.0 +p_loop = softmax(ll) + + +def kl_pos(p, q): # per-position KL(student || base), vocab summed (as in heal._kl_per_pos) + return (p * (np.log(np.clip(p, 1e-9, 1)) - np.log(np.clip(q, 1e-9, 1)))).sum(-1) + + +AGGS = {"mean_t": lambda k: k.mean(), + "rmse_t": lambda k: np.sqrt((k ** 2).mean()), + "p95_t": lambda k: np.percentile(k, 95), + "max_t": lambda k: k.max()} +rows = [] +for name, p in [("coherent trait", p_trait), ("incoherent loop", p_loop)]: + k = kl_pos(p, p_ref) + rows.append([name] + [f"{f(k):.3f}" for f in AGGS.values()]) +rows.append(["sep ratio (loop/trait)"] + + [f"{f(kl_pos(p_loop, p_ref)) / f(kl_pos(p_trait, p_ref)):.1f}x" for f in AGGS.values()]) +print(tabulate(rows, headers=["student (60 positions)", *AGGS], tablefmt="github")) +print("\nSHOULD: incoherent-loop mean_t KL ~0.38 sits UNDER a tau=0.5 hinge, so relu(mean-tau)=0 and the") +print("barrier never fires (the #101 collapse). The SAME loop has rmse_t ~1.5 / p95_t ~3.8, well over tau,") +print("so an outlier-aggregated barrier fires on it. If mean_t separated loop from trait as well as rmse_t,") +print("the outlier aggregation would buy nothing -- the point is the sep ratio GROWS from mean to rmse/p95.") diff --git a/src/steer_heal/config.py b/src/steer_heal/config.py index 62f70b0..bf28aef 100644 --- a/src/steer_heal/config.py +++ b/src/steer_heal/config.py @@ -63,6 +63,11 @@ class RunConfig: # (weights-space shrink, not a loss term), so the two compose: e.g. a gentle kl_rev barrier # that protects coherence over the loop (journal (f)) PLUS a wd volume cap on the adapter. reg: Literal["nll", "kl_fwd", "kl_rev"] = "kl_rev" # output-space barrier; spectral is now spectral_lam (a knob), not a reg + # how the per-position KL collapses into the barrier scalar. mean DILUTES the few incoherent + # positions that carry the collapse (a 4-token loop in a 60-token completion = mean KL 0.38 < tau=0.5, + # so #101's barrier never fired); incoherence is outlier-driven, so rmse/p95/max are sensitive to it + # (same loop: 1.5/3.8/8.1 vs coherent ~0.03). rmse = smooth dense gradient (train default), p95/max sparser. + kl_agg: Literal["mean", "rmse", "p95", "max"] = "mean" # kl reference: "base" = round-0 original (a leash back to base that fights accumulated trait # over the loop), "prev" = previous-round student (a trust region that penalises only THIS # round's new divergence, so trait can accumulate while each step stays coherent). At round 0 diff --git a/src/steer_heal/heal.py b/src/steer_heal/heal.py index 9a9c212..c282d75 100644 --- a/src/steer_heal/heal.py +++ b/src/steer_heal/heal.py @@ -25,6 +25,21 @@ def _kl_per_pos(logp_a, logp_b): # KL(a || b) summed over vocab, per position return (logp_a.exp() * (logp_a - logp_b)).sum(-1) +def _agg_kl(kl_pos, how: str): + """Collapse per-position KL into the barrier scalar. mean DILUTES a few incoherent + positions: a 4-token loop in a 60-token completion raised mean KL only to 0.38, under + tau=0.5, so #101's barrier never fired on the collapse. Incoherence is outlier-driven + (a handful of base-improbable spikes), so an outlier-sensitive aggregate catches it where + mean cannot (same synthetic loop: rmse 1.5, p95 3.8, max 8.1 vs coherent ~0.03). rmse is + smooth with dense gradient (best for training); p95/max are sparser (gradient to ~1 pos).""" + if how == "mean": return kl_pos.mean() + # +eps inside the sqrt: B=0 LoRA init makes every kl_pos exactly 0 at step 0, and bare + # sqrt(0) has an infinite gradient (0/0), which the relu's zero-derivative turns into 0*nan. + if how == "rmse": return (kl_pos.pow(2).mean() + 1e-8).sqrt() + if how == "p95": return torch.quantile(kl_pos, 0.95) + if how == "max": return kl_pos.max() + + def _spectral_div(lora, n_iter: int = 3) -> torch.Tensor: """Mean operator norm σ_max(ΔW) over the adapter's layers, ΔW = (alpha/r)·B@A. @@ -160,9 +175,9 @@ def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg: tgt = ids.input_ids[0, 1:] sft = F.nll_loss(logp[mask], tgt[mask]) if cfg.reg == "kl_fwd": - div = _kl_per_pos(logp0[mask], logp[mask]).mean() + div = _agg_kl(_kl_per_pos(logp0[mask], logp[mask]), cfg.kl_agg) elif cfg.reg == "kl_rev": - div = _kl_per_pos(logp[mask], logp0[mask]).mean() + div = _agg_kl(_kl_per_pos(logp[mask], logp0[mask]), cfg.kl_agg) else: div = torch.zeros((), device=model.device) # nll barrier = lam_eff * torch.relu(div - cfg.tau)