heal: kl_agg knob (mean|rmse|p95|max) -- outlier-aggregate the per-position KL barrier

mean dilutes the few incoherent positions that carry the collapse: #101's token
loops had mean per-position kl_rev ~0.38, under the tau=0.5 hinge, so the barrier
never fired (journal h/i). Incoherence is outlier-driven, so rmse/p95/max are
sensitive to it (scripts/diag_kl_agg.py synthetic: same loop = rmse 1.5 / p95 3.8
/ max 8.1 vs coherent ~0.03; sep ratio grows 21x->58x->77x->85x from mean to max).
rmse default for the new arm (smooth dense gradient). eps inside the sqrt: B=0
LoRA init zeros every kl_pos at step 0 and bare sqrt(0) has inf grad -> 0*nan.
mean stays the config default = no change to existing runs. Queued as the next
loop arm (kl_rev rmse, ref=base, tau=1.0).

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
wassname
2026-06-06 14:05:30 +08:00
parent 026de8fd74
commit 2b1d2b7493
3 changed files with 76 additions and 2 deletions
+5
View File
@@ -63,6 +63,11 @@ class RunConfig:
# (weights-space shrink, not a loss term), so the two compose: e.g. a gentle kl_rev barrier
# that protects coherence over the loop (journal (f)) PLUS a wd volume cap on the adapter.
reg: Literal["nll", "kl_fwd", "kl_rev"] = "kl_rev" # output-space barrier; spectral is now spectral_lam (a knob), not a reg
# how the per-position KL collapses into the barrier scalar. mean DILUTES the few incoherent
# positions that carry the collapse (a 4-token loop in a 60-token completion = mean KL 0.38 < tau=0.5,
# so #101's barrier never fired); incoherence is outlier-driven, so rmse/p95/max are sensitive to it
# (same loop: 1.5/3.8/8.1 vs coherent ~0.03). rmse = smooth dense gradient (train default), p95/max sparser.
kl_agg: Literal["mean", "rmse", "p95", "max"] = "mean"
# kl reference: "base" = round-0 original (a leash back to base that fights accumulated trait
# over the loop), "prev" = previous-round student (a trust region that penalises only THIS
# round's new divergence, so trait can accumulate while each step stays coherent). At round 0
+17 -2
View File
@@ -25,6 +25,21 @@ def _kl_per_pos(logp_a, logp_b): # KL(a || b) summed over vocab, per position
return (logp_a.exp() * (logp_a - logp_b)).sum(-1)
def _agg_kl(kl_pos, how: str):
"""Collapse per-position KL into the barrier scalar. mean DILUTES a few incoherent
positions: a 4-token loop in a 60-token completion raised mean KL only to 0.38, under
tau=0.5, so #101's barrier never fired on the collapse. Incoherence is outlier-driven
(a handful of base-improbable spikes), so an outlier-sensitive aggregate catches it where
mean cannot (same synthetic loop: rmse 1.5, p95 3.8, max 8.1 vs coherent ~0.03). rmse is
smooth with dense gradient (best for training); p95/max are sparser (gradient to ~1 pos)."""
if how == "mean": return kl_pos.mean()
# +eps inside the sqrt: B=0 LoRA init makes every kl_pos exactly 0 at step 0, and bare
# sqrt(0) has an infinite gradient (0/0), which the relu's zero-derivative turns into 0*nan.
if how == "rmse": return (kl_pos.pow(2).mean() + 1e-8).sqrt()
if how == "p95": return torch.quantile(kl_pos, 0.95)
if how == "max": return kl_pos.max()
def _spectral_div(lora, n_iter: int = 3) -> torch.Tensor:
"""Mean operator norm σ_max(ΔW) over the adapter's layers, ΔW = (alpha/r)·B@A.
@@ -160,9 +175,9 @@ def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg:
tgt = ids.input_ids[0, 1:]
sft = F.nll_loss(logp[mask], tgt[mask])
if cfg.reg == "kl_fwd":
div = _kl_per_pos(logp0[mask], logp[mask]).mean()
div = _agg_kl(_kl_per_pos(logp0[mask], logp[mask]), cfg.kl_agg)
elif cfg.reg == "kl_rev":
div = _kl_per_pos(logp[mask], logp0[mask]).mean()
div = _agg_kl(_kl_per_pos(logp[mask], logp0[mask]), cfg.kl_agg)
else:
div = torch.zeros((), device=model.device) # nll
barrier = lam_eff * torch.relu(div - cfg.tau)