Headline (gemma-3-4b-it s42, care-over-authority): aggregating the kl_rev barrier by rmse over token positions (not the mean) holds coherence flat at 0.997 across all 8 rounds, where the mean aggregate collapses to 0.62 by r7 (token loops). Mean dilutes the few incoherent positions under the tau gate; rmse is outlier-sensitive and fires on them. Cost is depth (rmse run leashes to base, trait stays shallow); matched control still running. - plot.py: coherence panel -> log-incoherence (1-coh, log axis, down=coherent); map coherence axis matches; red steer kept on the over-pipeline panels only. - heal.py: fix kl_agg=p95 crash (torch.quantile rejects bf16 -> .float()). - run.py: persist per-round adapter gens (adapter_gen) for the outputs table. - config.py: coh_floor early-stop knob. - README: results table (mean vs rmse), trajectory figure, outputs-over-loop appendix (per-round completions as quotes); spec persona corrected to pos-neg. - docs/reviews: kl_agg review, pool saturation test, care-lens plan. Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
7.4 KiB
Overall
The code matches the stated idea: compute per-position KL (summed over vocab), then aggregate over positions with mean / RMSE / p95 / max, and feed that scalar into a ReLU hinge. The forward vs reverse KL dispatch is correct.
The only real correctness issue I see is around the RMSE epsilon in mixed-precision (fp16) training.
1. _agg_kl / RMSE epsilon and dispatch
Verdict: should-fix (if you ever run in fp16), otherwise fine.
-
Gradient at all-zero KL without eps
- Without eps,
rmse = sqrt(mean(kl_pos^2)). - At
kl_pos == 0for all positions,mean(kl_pos^2) = 0andrmse = 0. - The theoretical gradient of RMSE at 0 is not well-defined; the autograd implementation uses the chain rule:
d rmse / d mean_sq = 0.5 / sqrt(mean_sq)→0.5 / 0at 0 →inf/nand mean_sq / d kl_pos = 2 * kl_pos / n→ 0- In code, PyTorch does something like
grad_input = grad_output * 0.5 / sqrt(input). Withgrad_output = 0from the ReLU gate andsqrt(input)=0, you get0/0 -> nan.
- So the comment about the “infinite gradient (0/0)” at zero is essentially correct, and ReLU gating alone does not save you because the backward pass still evaluates
0 / sqrt(0)internally.
- Without eps,
-
Does the eps fix it?
if how == "rmse": return (kl_pos.pow(2).mean() + 1e-8).sqrt()-
In float32 / bfloat16,
1e-8is representable and > 0, so:- At init,
rmse = sqrt(1e-8) ≈ 1e-4 < tau, so the barrier is off. sqrtbackward sees input ≈ 1e-8, not 0, so no0/0and no NaNs.- The magnitude of
1e-8is negligible compared to typical KL scales O(1e-2–1e1).
- At init,
-
In float16,
1e-8underflows to 0 (min subnormal is ≈ 6e-8), so:(kl_pos.pow(2).mean() + 1e-8)at zero KL becomes exactly 0.- You are back to
sqrt(0)and the0/0gradient issue on the first backward. - So in fp16 the intended fix is ineffective.
Recommendation: Either
- enforce this computation in float32 (
kl_pos = kl_pos.float()before the RMSE), or - use a larger eps that survives fp16, e.g.
1e-6or even1e-5, and optionally cast to float32 for safety.
-
-
KL forward/reverse dispatch
if cfg.reg == "kl_fwd": div = _agg_kl(_kl_per_pos(logp0[mask], logp[mask]), cfg.kl_agg) # KL(base || student) elif cfg.reg == "kl_rev": div = _agg_kl(_kl_per_pos(logp[mask], logp0[mask]), cfg.kl_agg) # KL(student || base)This matches your convention (kl_rev = KL(student || base)). No bug here.
2. Gradient sparsity for p95 / max
Verdict: no correctness bug; design trade-off (nit).
torch.maxandtorch.quantiledo indeed send nonzero gradients to a small subset of positions:max: exactly the argmax position(s).quantile: positions at/around the quantile threshold.
- That is expected behavior; mathematically correct for what these operators mean.
- Given:
- The SFT loss provides dense gradients over all completion positions.
- The KL barrier is a regularizer meant to react to outlier tokens.
- It’s reasonable for the barrier’s gradient to be sparse without being a bug.
RMSE does give denser gradients and is a good default if you’re worried about optimization smoothness, but the sparsity of p95/max is not a correctness issue.
3. Tau scale across different aggregators
Verdict: not a bug, but user must retune tau (nit).
- Mean-KL vs RMSE/p95/max are on different numerical scales even for the same underlying per-position KL distribution.
- Your synthetic example shows ~20–80× larger values for RMSE/p95/max vs mean.
- Keeping a single
tauacross aggregators changes the effective trust region tightness.- E.g.
tau = 0.5tuned for mean-KL is much tighter if you switch to RMSE.
- E.g.
- This does not break correctness; it just means
tauis not directly comparable acrosskl_aggchoices. - For the intended run (rmse, barrier_ref=base):
- Synthetic coherent trait: RMSE ≈ 0.026
- Incoherent loop: RMSE ≈ 1.5
- So
tau ≈ 1.0is plausibly in the right ballpark to separate them.
Practical implication: users should expect to retune tau when changing kl_agg. That’s a configuration / documentation issue, not a code bug.
4. Axis of aggregation (positions vs vocab)
Verdict: implementation matches the stated idea (no issue).
-
_kl_per_posis:# KL(a || b) summed over vocab, per position return (logp_a.exp() * (logp_a - logp_b)).sum(-1) # shape: [positions] -
_agg_klthen reduces over the position axis (tokens in the completion), not over vocab. -
This is exactly what:
- Your synthetic experiment does: you computed stats over 60 positions.
- Your verbal description suggests: “a few base-improbable token positions spike”.
So the implemented interpretation is: “outliers = anomalous positions (timesteps) that have large KL(student||ref)”, which fits both the mechanism (loops at a few timesteps) and your earlier analysis.
An alternative “over vocab first” scheme (e.g. p95 over vocabulary contributions before summing across vocab) is conceptually different and doesn’t match your synthetic evidence. I don’t see a mathematical reason it would better capture incoherence than the current, standard “sum over vocab, then aggregate over positions” approach.
5. Does RMSE/p95 break the “nats” interpretation?
Verdict: conceptually fine; no math bug.
- Units:
- Per-position KL is in nats.
mean(kl_pos)is “expected KL per position” (still nats).rmse = sqrt(mean(kl_pos^2))also has units of nats (√(nats²) → nats), but it’s no longer an expectation—just an L2 norm of position-wise KLs.p95andmaxare quantiles/max of nats, also in nats.
- The hinge
relu(div - tau)and weightlam_effonly requiredivto be a scalar whose magnitude correlates with “how bad” divergence is; they don’t requiredivto be a true KL. - So you lose a clean probabilistic interpretation of
tauas “average KL in nats”, but you retain:- Monotonicity: higher outlier KLs → larger
div. - Correct physical units.
- Monotonicity: higher outlier KLs → larger
This is acceptable; you’re using div as a monotone, outlier-sensitive surrogate, not as a literal KL for e.g. information-theoretic bounds.
6. Other potential issues
-
RMSE eps in fp16
(already discussed under Q1)- Severity: should-fix.
- Fix: cast to float32 for the RMSE computation and/or use an eps that is representable in fp16.
-
Empty
maskedge case- If for some reason
maskhas zero True entries (no completion tokens), then:logp[mask]is shape[0, vocab]._kl_per_pos(...)⇒ empty tensor.mean,quantile, ormaxon an empty tensor will error or give NaN.
- If your data pipeline guarantees at least one completion token per example, this never happens. If not, this can blow up training.
Severity: should-fix if empty completions are possible; otherwise irrelevant.
- If for some reason
-
Dtype mismatch in
div = torch.zeros((), device=model.device)- If the model/logp is in half precision,
divdefaults to float32. PyTorch will upcast when adding, so this is safe but mildly inconsistent. - If you want strict consistency, you’d set
dtype=logp.dtype, but this is cosmetic.
Severity: nit.
- If the model/logp is in half precision,
Beyond the fp16 eps underflow and the possible empty-mask edge case, the math and gradient flow in the change look correct and aligned with your stated goal.