mirror of
https://github.com/wassname/steer-heal-love.git
synced 2026-06-27 16:47:16 +08:00
heal: kl_agg knob (mean|rmse|p95|max) -- outlier-aggregate the per-position KL barrier
mean dilutes the few incoherent positions that carry the collapse: #101's token loops had mean per-position kl_rev ~0.38, under the tau=0.5 hinge, so the barrier never fired (journal h/i). Incoherence is outlier-driven, so rmse/p95/max are sensitive to it (scripts/diag_kl_agg.py synthetic: same loop = rmse 1.5 / p95 3.8 / max 8.1 vs coherent ~0.03; sep ratio grows 21x->58x->77x->85x from mean to max). rmse is the default for the new arm (smooth dense gradient). eps goes inside the sqrt: B=0 LoRA init zeros every kl_pos at step 0, and bare sqrt(0) has an infinite gradient, which backprop multiplies by zero -> 0*inf = nan. mean stays the config default = no change to existing runs. Queued as the next loop arm (kl_rev rmse, ref=base, tau=1.0). Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
@@ -63,6 +63,11 @@ class RunConfig:
|
||||
# (weights-space shrink, not a loss term), so the two compose: e.g. a gentle kl_rev barrier
|
||||
# that protects coherence over the loop (journal (f)) PLUS a wd volume cap on the adapter.
|
||||
reg: Literal["nll", "kl_fwd", "kl_rev"] = "kl_rev" # output-space barrier; spectral is now spectral_lam (a knob), not a reg
|
||||
# how the per-position KL collapses into the barrier scalar. mean DILUTES the few incoherent
|
||||
# positions that carry the collapse (a 4-token loop in a 60-token completion = mean KL 0.38 < tau=0.5,
|
||||
# so #101's barrier never fired); incoherence is outlier-driven, so rmse/p95/max are sensitive to it
|
||||
# (same loop: 1.5/3.8/8.1 vs coherent ~0.03). rmse = smooth dense gradient (the pick for the new training arm; the config default below stays mean), p95/max sparser.
|
||||
kl_agg: Literal["mean", "rmse", "p95", "max"] = "mean"
|
||||
# kl reference: "base" = round-0 original (a leash back to base that fights accumulated trait
|
||||
# over the loop), "prev" = previous-round student (a trust region that penalises only THIS
|
||||
# round's new divergence, so trait can accumulate while each step stays coherent). At round 0
|
||||
|
||||
+17
-2
@@ -25,6 +25,21 @@ def _kl_per_pos(logp_a, logp_b): # KL(a || b) summed over vocab, per position
|
||||
return (logp_a.exp() * (logp_a - logp_b)).sum(-1)
|
||||
|
||||
|
||||
def _agg_kl(kl_pos, how: str):
|
||||
"""Collapse per-position KL into the barrier scalar. mean DILUTES a few incoherent
|
||||
positions: a 4-token loop in a 60-token completion raised mean KL only to 0.38, under
|
||||
tau=0.5, so #101's barrier never fired on the collapse. Incoherence is outlier-driven
|
||||
(a handful of base-improbable spikes), so an outlier-sensitive aggregate catches it where
|
||||
mean cannot (same synthetic loop: rmse 1.5, p95 3.8, max 8.1 vs coherent ~0.03). rmse is
|
||||
smooth with dense gradient (best for training); p95/max are sparser (gradient to ~1 pos)."""
|
||||
if how == "mean": return kl_pos.mean()
|
||||
# +eps inside the sqrt: B=0 LoRA init makes every kl_pos exactly 0 at step 0, and bare
|
||||
# sqrt(0) has an infinite gradient (0/0), which the relu's zero-derivative turns into 0*nan.
|
||||
if how == "rmse": return (kl_pos.pow(2).mean() + 1e-8).sqrt()
|
||||
if how == "p95": return torch.quantile(kl_pos, 0.95)
|
||||
if how == "max": return kl_pos.max()
|
||||
|
||||
|
||||
def _spectral_div(lora, n_iter: int = 3) -> torch.Tensor:
|
||||
"""Mean operator norm σ_max(ΔW) over the adapter's layers, ΔW = (alpha/r)·B@A.
|
||||
|
||||
@@ -160,9 +175,9 @@ def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg:
|
||||
tgt = ids.input_ids[0, 1:]
|
||||
sft = F.nll_loss(logp[mask], tgt[mask])
|
||||
if cfg.reg == "kl_fwd":
|
||||
div = _kl_per_pos(logp0[mask], logp[mask]).mean()
|
||||
div = _agg_kl(_kl_per_pos(logp0[mask], logp[mask]), cfg.kl_agg)
|
||||
elif cfg.reg == "kl_rev":
|
||||
div = _kl_per_pos(logp[mask], logp0[mask]).mean()
|
||||
div = _agg_kl(_kl_per_pos(logp[mask], logp0[mask]), cfg.kl_agg)
|
||||
else:
|
||||
div = torch.zeros((), device=model.device) # nll
|
||||
barrier = lam_eff * torch.relu(div - cfg.tau)
|
||||
|
||||
Reference in New Issue
Block a user