Files
steer-heal-love/docs/reviews/kl_agg_review.md
T
wassname 48814897ef results: rmse outlier-KL barrier holds coherence over the loop; README + log-incoherence plot
Headline (gemma-3-4b-it s42, care-over-authority): aggregating the kl_rev
barrier by rmse over token positions (not the mean) holds coherence flat at
0.997 across all 8 rounds, where the mean aggregate collapses to 0.62 by r7
(token loops). Mean dilutes the few incoherent positions under the tau gate;
rmse is outlier-sensitive and fires on them. Cost is depth (rmse run leashes
to base, trait stays shallow); matched control still running.

- plot.py: coherence panel -> log-incoherence (1-coh, log axis, down=coherent);
  map coherence axis matches; red steer kept on the over-pipeline panels only.
- heal.py: fix kl_agg=p95 crash (torch.quantile rejects bf16 -> .float()).
- run.py: persist per-round adapter gens (adapter_gen) for the outputs table.
- config.py: coh_floor early-stop knob.
- README: results table (mean vs rmse), trajectory figure, outputs-over-loop
  appendix (per-round completions as quotes); spec persona corrected to pos-neg.
- docs/reviews: kl_agg review, pool saturation test, care-lens plan.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
2026-06-07 07:53:00 +08:00

7.4 KiB
Raw Blame History

Overall

The code matches the stated idea: compute per-position KL (summed over vocab), then aggregate over positions with mean / RMSE / p95 / max, and feed that scalar into a ReLU hinge. The forward vs reverse KL dispatch is correct.

The only real correctness issue I see is around the RMSE epsilon in mixed-precision (fp16) training.


1. _agg_kl / RMSE epsilon and dispatch

Verdict: should-fix (if you ever run in fp16), otherwise fine.

  • Gradient at all-zero KL without eps

    • Without eps, rmse = sqrt(mean(kl_pos^2)).
    • At kl_pos == 0 for all positions, mean(kl_pos^2) = 0 and rmse = 0.
    • The theoretical gradient of RMSE at 0 is not well-defined; the autograd implementation uses the chain rule:
      • d rmse / d mean_sq = 0.5 / sqrt(mean_sq)0.5 / 0 at 0 → inf/nan
      • d mean_sq / d kl_pos = 2 * kl_pos / n → 0
      • In code, PyTorch does something like grad_input = grad_output * 0.5 / sqrt(input). With grad_output = 0 from the ReLU gate and sqrt(input)=0, you get 0/0 -> nan.
    • So the comment about the “infinite gradient (0/0)” at zero is essentially correct, and ReLU gating alone does not save you because the backward pass still evaluates 0 / sqrt(0) internally.
  • Does the eps fix it?

    if how == "rmse": return (kl_pos.pow(2).mean() + 1e-8).sqrt()
    
    • In float32 / bfloat16, 1e-8 is representable and > 0, so:

      • At init, rmse = sqrt(1e-8) ≈ 1e-4 < tau, so the barrier is off.
      • sqrt backward sees input ≈ 1e-8, not 0, so no 0/0 and no NaNs.
      • The magnitude of 1e-8 is negligible compared to typical KL scales O(1e-21e1).
    • In float16, 1e-8 underflows to 0 (min subnormal is ≈ 6e-8), so:

      • (kl_pos.pow(2).mean() + 1e-8) at zero KL becomes exactly 0.
      • You are back to sqrt(0) and the 0/0 gradient issue on the first backward.
      • So in fp16 the intended fix is ineffective.

    Recommendation: Either

    • enforce this computation in float32 (kl_pos = kl_pos.float() before the RMSE), or
    • use a larger eps that survives fp16, e.g. 1e-6 or even 1e-5, and optionally cast to float32 for safety.
  • KL forward/reverse dispatch

    if cfg.reg == "kl_fwd":
        div = _agg_kl(_kl_per_pos(logp0[mask], logp[mask]), cfg.kl_agg)  # KL(base || student)
    elif cfg.reg == "kl_rev":
        div = _agg_kl(_kl_per_pos(logp[mask], logp0[mask]), cfg.kl_agg)  # KL(student || base)
    

    This matches your convention (kl_rev = KL(student || base)). No bug here.


2. Gradient sparsity for p95 / max

Verdict: no correctness bug; design trade-off (nit).

  • torch.max and torch.quantile do indeed send nonzero gradients to a small subset of positions:
    • max: exactly the argmax position(s).
    • quantile: positions at/around the quantile threshold.
  • That is expected behavior; mathematically correct for what these operators mean.
  • Given:
    • The SFT loss provides dense gradients over all completion positions.
    • The KL barrier is a regularizer meant to react to outlier tokens.
  • Its reasonable for the barriers gradient to be sparse without being a bug.

RMSE does give denser gradients and is a good default if youre worried about optimization smoothness, but the sparsity of p95/max is not a correctness issue.


3. Tau scale across different aggregators

Verdict: not a bug, but user must retune tau (nit).

  • Mean-KL vs RMSE/p95/max are on different numerical scales even for the same underlying per-position KL distribution.
    • Your synthetic example shows ~2080× larger values for RMSE/p95/max vs mean.
  • Keeping a single tau across aggregators changes the effective trust region tightness.
    • E.g. tau = 0.5 tuned for mean-KL is much tighter if you switch to RMSE.
  • This does not break correctness; it just means tau is not directly comparable across kl_agg choices.
  • For the intended run (rmse, barrier_ref=base):
    • Synthetic coherent trait: RMSE ≈ 0.026
    • Incoherent loop: RMSE ≈ 1.5
    • So tau ≈ 1.0 is plausibly in the right ballpark to separate them.

Practical implication: users should expect to retune tau when changing kl_agg. Thats a configuration / documentation issue, not a code bug.


4. Axis of aggregation (positions vs vocab)

Verdict: implementation matches the stated idea (no issue).

  • _kl_per_pos is:

    # KL(a || b) summed over vocab, per position
    return (logp_a.exp() * (logp_a - logp_b)).sum(-1)  # shape: [positions]
    
  • _agg_kl then reduces over the position axis (tokens in the completion), not over vocab.

  • This is exactly what:

    • Your synthetic experiment does: you computed stats over 60 positions.
    • Your verbal description suggests: “a few base-improbable token positions spike”.

So the implemented interpretation is: “outliers = anomalous positions (timesteps) that have large KL(student||ref)”, which fits both the mechanism (loops at a few timesteps) and your earlier analysis.

An alternative “over vocab first” scheme (e.g. p95 over vocabulary contributions before summing across vocab) is conceptually different and doesnt match your synthetic evidence. I dont see a mathematical reason it would better capture incoherence than the current, standard “sum over vocab, then aggregate over positions” approach.


5. Does RMSE/p95 break the “nats” interpretation?

Verdict: conceptually fine; no math bug.

  • Units:
    • Per-position KL is in nats.
    • mean(kl_pos) is “expected KL per position” (still nats).
    • rmse = sqrt(mean(kl_pos^2)) also has units of nats (√(nats²) → nats), but its no longer an expectation—just an L2 norm of position-wise KLs.
    • p95 and max are quantiles/max of nats, also in nats.
  • The hinge relu(div - tau) and weight lam_eff only require div to be a scalar whose magnitude correlates with “how bad” divergence is; they dont require div to be a true KL.
  • So you lose a clean probabilistic interpretation of tau as “average KL in nats”, but you retain:
    • Monotonicity: higher outlier KLs → larger div.
    • Correct physical units.

This is acceptable; youre using div as a monotone, outlier-sensitive surrogate, not as a literal KL for e.g. information-theoretic bounds.


6. Other potential issues

  1. RMSE eps in fp16
    (already discussed under Q1)

    • Severity: should-fix.
    • Fix: cast to float32 for the RMSE computation and/or use an eps that is representable in fp16.
  2. Empty mask edge case

    • If for some reason mask has zero True entries (no completion tokens), then:
      • logp[mask] is shape [0, vocab].
      • _kl_per_pos(...) ⇒ empty tensor.
      • mean, quantile, or max on an empty tensor will error or give NaN.
    • If your data pipeline guarantees at least one completion token per example, this never happens. If not, this can blow up training.

    Severity: should-fix if empty completions are possible; otherwise irrelevant.

  3. Dtype mismatch in div = torch.zeros((), device=model.device)

    • If the model/logp is in half precision, div defaults to float32. PyTorch will upcast when adding, so this is safe but mildly inconsistent.
    • If you want strict consistency, youd set dtype=logp.dtype, but this is cosmetic.

    Severity: nit.

Beyond the fp16 eps underflow and the possible empty-mask edge case, the math and gradient flow in the change look correct and aligned with your stated goal.