heal: kl_agg knob (mean|rmse|p95|max) -- outlier-aggregate the per-position KL barrier

mean dilutes the few incoherent positions that carry the collapse: #101's token
loops had mean per-position kl_rev ~0.38, under the tau=0.5 hinge, so the barrier
never fired (journal h/i). Incoherence is outlier-driven, so rmse/p95/max are
sensitive to it (scripts/diag_kl_agg.py synthetic: same loop = rmse 1.5 / p95 3.8
/ max 8.1 vs coherent ~0.03; sep ratio grows 21x->58x->77x->85x from mean to max).
rmse default for the new arm (smooth dense gradient). eps inside the sqrt: B=0
LoRA init zeros every kl_pos at step 0 and bare sqrt(0) has inf grad -> 0*nan.
mean stays the config default = no change to existing runs. Queued as the next
loop arm (kl_rev rmse, ref=base, tau=1.0).

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
wassname
2026-06-06 14:05:30 +08:00
parent 026de8fd74
commit 2b1d2b7493
3 changed files with 76 additions and 2 deletions
+54
View File
@@ -0,0 +1,54 @@
"""Why mean-KL is blind to the coherence collapse, and rmse/p95 are not (journal-supporting).
No GPU, no model: synthetic next-token distributions (ml-debug Part 3 loss-surface check).
A coherent-trait student shifts a little mass toward a base-PLAUSIBLE token at every position;
an incoherent student is base everywhere except a few positions that spike on a base-IMPROBABLE
token (a token loop). We aggregate the per-position KL the way the heal barrier does and show
that mean dilutes the loop under the hinge threshold while outlier aggregates catch it.
"""
import numpy as np
from tabulate import tabulate
rng = np.random.default_rng(0)
V, T = 200, 60 # vocab, positions in a completion
def softmax(z):
z = z - z.max(-1, keepdims=True)
e = np.exp(z)
return e / e.sum(-1, keepdims=True)
base_logits = rng.standard_normal((T, V))
p_ref = softmax(base_logits)
order = np.argsort(p_ref.mean(0))
trait_tok = order[len(order) // 2] # mid-prob = base-PLAUSIBLE (where coherent trait lands)
loop_tok = order[3] # near-lowest = base-IMPROBABLE (where a loop lands)
tl = base_logits.copy(); tl[:, trait_tok] += 1.6 # broad small shift, EVERY position
p_trait = softmax(tl)
ll = base_logits.copy()
for t in (12, 13, 14, 15): # 4 spiked positions out of 60
ll[t] = -10; ll[t, loop_tok] = 12.0
p_loop = softmax(ll)
def kl_pos(p, q): # per-position KL(student || base), vocab summed (as in heal._kl_per_pos)
return (p * (np.log(np.clip(p, 1e-9, 1)) - np.log(np.clip(q, 1e-9, 1)))).sum(-1)
AGGS = {"mean_t": lambda k: k.mean(),
"rmse_t": lambda k: np.sqrt((k ** 2).mean()),
"p95_t": lambda k: np.percentile(k, 95),
"max_t": lambda k: k.max()}
rows = []
for name, p in [("coherent trait", p_trait), ("incoherent loop", p_loop)]:
k = kl_pos(p, p_ref)
rows.append([name] + [f"{f(k):.3f}" for f in AGGS.values()])
rows.append(["sep ratio (loop/trait)"] +
[f"{f(kl_pos(p_loop, p_ref)) / f(kl_pos(p_trait, p_ref)):.1f}x" for f in AGGS.values()])
print(tabulate(rows, headers=["student (60 positions)", *AGGS], tablefmt="github"))
print("\nSHOULD: incoherent-loop mean_t KL ~0.38 sits UNDER a tau=0.5 hinge, so relu(mean-tau)=0 and the")
print("barrier never fires (the #101 collapse). The SAME loop has rmse_t ~1.5 / p95_t ~3.8, well over tau,")
print("so an outlier-aggregated barrier fires on it. If mean_t separated loop from trait as well as rmse_t,")
print("the outlier aggregation would buy nothing -- the point is the sep ratio GROWS from mean to rmse/p95.")
+5
View File
@@ -63,6 +63,11 @@ class RunConfig:
# (weights-space shrink, not a loss term), so the two compose: e.g. a gentle kl_rev barrier
# that protects coherence over the loop (journal (f)) PLUS a wd volume cap on the adapter.
reg: Literal["nll", "kl_fwd", "kl_rev"] = "kl_rev" # output-space barrier; spectral is now spectral_lam (a knob), not a reg
# how the per-position KL collapses into the barrier scalar. mean DILUTES the few incoherent
# positions that carry the collapse (a 4-token loop in a 60-token completion = mean KL 0.38 < tau=0.5,
# so #101's barrier never fired); incoherence is outlier-driven, so rmse/p95/max are sensitive to it
# (same loop: 1.5/3.8/8.1 vs coherent ~0.03). rmse = smooth dense gradient (train default), p95/max sparser.
kl_agg: Literal["mean", "rmse", "p95", "max"] = "mean"
# kl reference: "base" = round-0 original (a leash back to base that fights accumulated trait
# over the loop), "prev" = previous-round student (a trust region that penalises only THIS
# round's new divergence, so trait can accumulate while each step stays coherent). At round 0
+17 -2
View File
@@ -25,6 +25,21 @@ def _kl_per_pos(logp_a, logp_b): # KL(a || b) summed over vocab, per position
return (logp_a.exp() * (logp_a - logp_b)).sum(-1)
def _agg_kl(kl_pos, how: str):
"""Collapse per-position KL into the barrier scalar. mean DILUTES a few incoherent
positions: a 4-token loop in a 60-token completion raised mean KL only to 0.38, under
tau=0.5, so #101's barrier never fired on the collapse. Incoherence is outlier-driven
(a handful of base-improbable spikes), so an outlier-sensitive aggregate catches it where
mean cannot (same synthetic loop: rmse 1.5, p95 3.8, max 8.1 vs coherent ~0.03). rmse is
smooth with dense gradient (best for training); p95/max are sparser (gradient to ~1 pos)."""
if how == "mean": return kl_pos.mean()
# +eps inside the sqrt: B=0 LoRA init makes every kl_pos exactly 0 at step 0, and bare
# sqrt(0) has an infinite gradient (0/0), which the relu's zero-derivative turns into 0*nan.
if how == "rmse": return (kl_pos.pow(2).mean() + 1e-8).sqrt()
if how == "p95": return torch.quantile(kl_pos, 0.95)
if how == "max": return kl_pos.max()
def _spectral_div(lora, n_iter: int = 3) -> torch.Tensor:
"""Mean operator norm σ_max(ΔW) over the adapter's layers, ΔW = (alpha/r)·B@A.
@@ -160,9 +175,9 @@ def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg:
tgt = ids.input_ids[0, 1:]
sft = F.nll_loss(logp[mask], tgt[mask])
if cfg.reg == "kl_fwd":
div = _kl_per_pos(logp0[mask], logp[mask]).mean()
div = _agg_kl(_kl_per_pos(logp0[mask], logp[mask]), cfg.kl_agg)
elif cfg.reg == "kl_rev":
div = _kl_per_pos(logp[mask], logp0[mask]).mean()
div = _agg_kl(_kl_per_pos(logp[mask], logp0[mask]), cfg.kl_agg)
else:
div = torch.zeros((), device=model.device) # nll
barrier = lam_eff * torch.relu(div - cfg.tau)