mirror of
https://github.com/wassname/steer-heal-love.git
synced 2026-06-27 17:02:34 +08:00
heal: kl_agg knob (mean|rmse|p95|max) -- outlier-aggregate the per-position KL barrier
mean dilutes the few incoherent positions that carry the collapse: #101's token loops had mean per-position kl_rev ~0.38, under the tau=0.5 hinge, so the barrier never fired (journal h/i). Incoherence is outlier-driven, so rmse/p95/max are sensitive to it (scripts/diag_kl_agg.py synthetic: same loop = rmse 1.5 / p95 3.8 / max 8.1 vs coherent ~0.03; sep ratio grows 21x->58x->77x->85x from mean to max). rmse default for the new arm (smooth dense gradient). eps inside the sqrt: B=0 LoRA init zeros every kl_pos at step 0 and bare sqrt(0) has inf grad -> 0*nan. mean stays the config default = no change to existing runs. Queued as the next loop arm (kl_rev rmse, ref=base, tau=1.0). Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
@@ -0,0 +1,54 @@
|
|||||||
|
"""Why mean-KL is blind to the coherence collapse, and rmse/p95 are not (journal-supporting).
|
||||||
|
|
||||||
|
No GPU, no model: synthetic next-token distributions (ml-debug Part 3 loss-surface check).
|
||||||
|
A coherent-trait student shifts a little mass toward a base-PLAUSIBLE token at every position;
|
||||||
|
an incoherent student is base everywhere except a few positions that spike on a base-IMPROBABLE
|
||||||
|
token (a token loop). We aggregate the per-position KL the way the heal barrier does and show
|
||||||
|
that mean dilutes the loop under the hinge threshold while outlier aggregates catch it.
|
||||||
|
"""
|
||||||
|
import numpy as np
|
||||||
|
from tabulate import tabulate
|
||||||
|
|
||||||
|
rng = np.random.default_rng(0)
|
||||||
|
V, T = 200, 60 # vocab, positions in a completion
|
||||||
|
|
||||||
|
|
||||||
|
def softmax(z):
|
||||||
|
z = z - z.max(-1, keepdims=True)
|
||||||
|
e = np.exp(z)
|
||||||
|
return e / e.sum(-1, keepdims=True)
|
||||||
|
|
||||||
|
|
||||||
|
base_logits = rng.standard_normal((T, V))
|
||||||
|
p_ref = softmax(base_logits)
|
||||||
|
order = np.argsort(p_ref.mean(0))
|
||||||
|
trait_tok = order[len(order) // 2] # mid-prob = base-PLAUSIBLE (where coherent trait lands)
|
||||||
|
loop_tok = order[3] # near-lowest = base-IMPROBABLE (where a loop lands)
|
||||||
|
|
||||||
|
tl = base_logits.copy(); tl[:, trait_tok] += 1.6 # broad small shift, EVERY position
|
||||||
|
p_trait = softmax(tl)
|
||||||
|
ll = base_logits.copy()
|
||||||
|
for t in (12, 13, 14, 15): # 4 spiked positions out of 60
|
||||||
|
ll[t] = -10; ll[t, loop_tok] = 12.0
|
||||||
|
p_loop = softmax(ll)
|
||||||
|
|
||||||
|
|
||||||
|
def kl_pos(p, q): # per-position KL(student || base), vocab summed (as in heal._kl_per_pos)
|
||||||
|
return (p * (np.log(np.clip(p, 1e-9, 1)) - np.log(np.clip(q, 1e-9, 1)))).sum(-1)
|
||||||
|
|
||||||
|
|
||||||
|
AGGS = {"mean_t": lambda k: k.mean(),
|
||||||
|
"rmse_t": lambda k: np.sqrt((k ** 2).mean()),
|
||||||
|
"p95_t": lambda k: np.percentile(k, 95),
|
||||||
|
"max_t": lambda k: k.max()}
|
||||||
|
rows = []
|
||||||
|
for name, p in [("coherent trait", p_trait), ("incoherent loop", p_loop)]:
|
||||||
|
k = kl_pos(p, p_ref)
|
||||||
|
rows.append([name] + [f"{f(k):.3f}" for f in AGGS.values()])
|
||||||
|
rows.append(["sep ratio (loop/trait)"] +
|
||||||
|
[f"{f(kl_pos(p_loop, p_ref)) / f(kl_pos(p_trait, p_ref)):.1f}x" for f in AGGS.values()])
|
||||||
|
print(tabulate(rows, headers=["student (60 positions)", *AGGS], tablefmt="github"))
|
||||||
|
print("\nSHOULD: incoherent-loop mean_t KL ~0.38 sits UNDER a tau=0.5 hinge, so relu(mean-tau)=0 and the")
|
||||||
|
print("barrier never fires (the #101 collapse). The SAME loop has rmse_t ~1.5 / p95_t ~3.8, well over tau,")
|
||||||
|
print("so an outlier-aggregated barrier fires on it. If mean_t separated loop from trait as well as rmse_t,")
|
||||||
|
print("the outlier aggregation would buy nothing -- the point is the sep ratio GROWS from mean to rmse/p95.")
|
||||||
@@ -63,6 +63,11 @@ class RunConfig:
|
|||||||
# (weights-space shrink, not a loss term), so the two compose: e.g. a gentle kl_rev barrier
|
# (weights-space shrink, not a loss term), so the two compose: e.g. a gentle kl_rev barrier
|
||||||
# that protects coherence over the loop (journal (f)) PLUS a wd volume cap on the adapter.
|
# that protects coherence over the loop (journal (f)) PLUS a wd volume cap on the adapter.
|
||||||
reg: Literal["nll", "kl_fwd", "kl_rev"] = "kl_rev" # output-space barrier; spectral is now spectral_lam (a knob), not a reg
|
reg: Literal["nll", "kl_fwd", "kl_rev"] = "kl_rev" # output-space barrier; spectral is now spectral_lam (a knob), not a reg
|
||||||
|
# how the per-position KL collapses into the barrier scalar. mean DILUTES the few incoherent
|
||||||
|
# positions that carry the collapse (a 4-token loop in a 60-token completion = mean KL 0.38 < tau=0.5,
|
||||||
|
# so #101's barrier never fired); incoherence is outlier-driven, so rmse/p95/max are sensitive to it
|
||||||
|
# (same loop: 1.5/3.8/8.1 vs coherent ~0.03). rmse = smooth dense gradient (train default), p95/max sparser.
|
||||||
|
kl_agg: Literal["mean", "rmse", "p95", "max"] = "mean"
|
||||||
# kl reference: "base" = round-0 original (a leash back to base that fights accumulated trait
|
# kl reference: "base" = round-0 original (a leash back to base that fights accumulated trait
|
||||||
# over the loop), "prev" = previous-round student (a trust region that penalises only THIS
|
# over the loop), "prev" = previous-round student (a trust region that penalises only THIS
|
||||||
# round's new divergence, so trait can accumulate while each step stays coherent). At round 0
|
# round's new divergence, so trait can accumulate while each step stays coherent). At round 0
|
||||||
|
|||||||
+17
-2
@@ -25,6 +25,21 @@ def _kl_per_pos(logp_a, logp_b): # KL(a || b) summed over vocab, per position
|
|||||||
return (logp_a.exp() * (logp_a - logp_b)).sum(-1)
|
return (logp_a.exp() * (logp_a - logp_b)).sum(-1)
|
||||||
|
|
||||||
|
|
||||||
|
def _agg_kl(kl_pos, how: str):
|
||||||
|
"""Collapse per-position KL into the barrier scalar. mean DILUTES a few incoherent
|
||||||
|
positions: a 4-token loop in a 60-token completion raised mean KL only to 0.38, under
|
||||||
|
tau=0.5, so #101's barrier never fired on the collapse. Incoherence is outlier-driven
|
||||||
|
(a handful of base-improbable spikes), so an outlier-sensitive aggregate catches it where
|
||||||
|
mean cannot (same synthetic loop: rmse 1.5, p95 3.8, max 8.1 vs coherent ~0.03). rmse is
|
||||||
|
smooth with dense gradient (best for training); p95/max are sparser (gradient to ~1 pos)."""
|
||||||
|
if how == "mean": return kl_pos.mean()
|
||||||
|
# +eps inside the sqrt: B=0 LoRA init makes every kl_pos exactly 0 at step 0, and bare
|
||||||
|
# sqrt(0) has an infinite gradient (0/0), which the relu's zero-derivative turns into 0*nan.
|
||||||
|
if how == "rmse": return (kl_pos.pow(2).mean() + 1e-8).sqrt()
|
||||||
|
if how == "p95": return torch.quantile(kl_pos, 0.95)
|
||||||
|
if how == "max": return kl_pos.max()
|
||||||
|
|
||||||
|
|
||||||
def _spectral_div(lora, n_iter: int = 3) -> torch.Tensor:
|
def _spectral_div(lora, n_iter: int = 3) -> torch.Tensor:
|
||||||
"""Mean operator norm σ_max(ΔW) over the adapter's layers, ΔW = (alpha/r)·B@A.
|
"""Mean operator norm σ_max(ΔW) over the adapter's layers, ΔW = (alpha/r)·B@A.
|
||||||
|
|
||||||
@@ -160,9 +175,9 @@ def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg:
|
|||||||
tgt = ids.input_ids[0, 1:]
|
tgt = ids.input_ids[0, 1:]
|
||||||
sft = F.nll_loss(logp[mask], tgt[mask])
|
sft = F.nll_loss(logp[mask], tgt[mask])
|
||||||
if cfg.reg == "kl_fwd":
|
if cfg.reg == "kl_fwd":
|
||||||
div = _kl_per_pos(logp0[mask], logp[mask]).mean()
|
div = _agg_kl(_kl_per_pos(logp0[mask], logp[mask]), cfg.kl_agg)
|
||||||
elif cfg.reg == "kl_rev":
|
elif cfg.reg == "kl_rev":
|
||||||
div = _kl_per_pos(logp[mask], logp0[mask]).mean()
|
div = _agg_kl(_kl_per_pos(logp[mask], logp0[mask]), cfg.kl_agg)
|
||||||
else:
|
else:
|
||||||
div = torch.zeros((), device=model.device) # nll
|
div = torch.zeros((), device=model.device) # nll
|
||||||
barrier = lam_eff * torch.relu(div - cfg.tau)
|
barrier = lam_eff * torch.relu(div - cfg.tau)
|
||||||
|
|||||||
Reference in New Issue
Block a user