mirror of
https://github.com/wassname/isokl_steering_calibration.git
synced 2026-06-27 17:01:10 +08:00
feat(spaghetti): per-trajectory KL/p95 normalization + shared y-axis

- Add --normalize-kl / --calib-tokens (default on, 20 tokens) to spaghetti_kl_alive.py: each trajectory is divided by its own p95(KL[:calib_tokens]). Calibration target collapses to y=1.
- Share y-axis across all alpha panels (global p99 ymax) for direct cross-alpha comparison.
- Add OLMo-2 1B w=4096 figure to figs/ and a README section documenting the result, including the unexplained 'random dead traces at low alpha' observation.
@@ -98,6 +98,44 @@ Out of scope:
- A norm-matching baseline. Iso-KL says nothing about whether two methods of equal KL move the same distance in activation space.
- A threshold sweep. The pmass < 0.95 cutoff is a single point; neighbouring thresholds might disagree.

## Per-trajectory KL normalization (OLMo-2 1B, w=4096)

The figures above use raw per-token KL in nats. That works when every panel shares an alpha-scale, but cross-alpha comparison gets harder once trajectories differ in their own baseline KL noise (some prompts are just spikier than others, even at alpha = 1).

So we re-plot KL divided by each trajectory's own 95th percentile (p95) over the first 20 tokens. The dashed reference at y = 1 is then the calibration scale by construction. Values above 1 are excursions above that trajectory's own early-rollout budget. The y-axis is shared across panels so the alpha sweep reads as one picture.
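
The normalization described above can be sketched in a few lines. This is a minimal illustration, not the code in `spaghetti_kl_alive.py`: the function names `normalize_kl` and `shared_ymax`, the `eps` guard, and the choice to return `None` for a skipped trajectory are all assumptions made for the example.

```python
import numpy as np

def normalize_kl(kl, calib_tokens=20, eps=1e-12):
    """Divide one trajectory's per-token KL by its own p95 over the first calib_tokens.

    Returns None when there is nothing to calibrate on or the denominator is ~0
    (as happens at alpha = 0), so the caller can skip the trajectory.
    """
    kl = np.asarray(kl, dtype=float)
    head = kl[:calib_tokens]
    if head.size == 0:
        return None
    denom = np.percentile(head, 95)
    if not np.isfinite(denom) or denom <= eps:
        return None
    return kl / denom

def shared_ymax(trajectories, q=99):
    """One y-limit for every panel: the q-th percentile of all pooled normalized values."""
    pooled = np.concatenate([t for t in trajectories if t is not None])
    return float(np.percentile(pooled, q))
```

With this convention a trajectory whose KL never leaves its early-rollout level sits at y = 1, and an all-zero trajectory is dropped rather than divided by zero.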

*Per-token `KL(steered || base) / p95(KL[:20])` for OLMo-2-0425-1B (mean_diff, w=4096, n=24 held-out prompts). One panel per alpha. Blue = alive at that token (`pmass_eval >= 0.75` at the nearest fork s <= t); orange = dead (irreversible). Black tick = EOS (right-censored). Dashed = calibration scale = 1. Black line = median across alive+dead trajectories. Threshold here is 0.75, looser than the 0.95 used for Qwen, because OLMo's early pmass is noisier.*
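
One way to read the alive/dead colouring in the caption is sketched below. It rests on two assumptions that the caption does not spell out: "irreversible" is taken to mean that once any fork falls below the threshold, every later token is dead, and tokens before the first fork are treated as alive. The function name `alive_mask` and its arguments are hypothetical.

```python
import numpy as np

def alive_mask(fork_pos, fork_pmass, n_tokens, threshold=0.75):
    """Per-token alive flag from sparse fork evaluations.

    fork_pos must be sorted ascending; fork_pmass[i] is pmass_eval at fork_pos[i].
    """
    fork_pos = np.asarray(fork_pos)
    below = np.asarray(fork_pmass, dtype=float) < threshold
    # Assumed meaning of "irreversible": a fork is dead if it or any earlier fork was below.
    dead_at_fork = np.cumsum(below) > 0
    # Index of the nearest fork s <= t for each token t; -1 means no fork seen yet.
    idx = np.searchsorted(fork_pos, np.arange(n_tokens), side="right") - 1
    alive = np.ones(n_tokens, dtype=bool)
    has_fork = idx >= 0
    alive[has_fork] = ~dead_at_fork[idx[has_fork]]
    return alive
```

For example, forks at tokens 0, 4 and 8 with pmass 0.9, 0.6 and 0.9 give four alive tokens followed by dead ones, even though the last fork recovers above the threshold.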

How to read it, in order:

1. *alpha = 0.* No steering, so KL = 0 and the p95 denominator is 0; every trajectory is skipped. The empty panel is expected.
2. *alpha = 0.25 to 0.75.* Median sits well below 1. Most trajectories never approach the calibration scale even past 4000 tokens. The 1-nat budget set on a w = 4096 calibration window has long-horizon headroom for OLMo-2 1B.
3. *alpha = 1.0.* Median hovers a bit under 1, with one outlier trajectory above 2. Iso-KL is doing what it claims at the trajectory-median level on a different prompt set than calibration saw.
4. *alpha = 1.5 to 2.0.* Median crosses 1 and stays there. Most trajectories die. KL is now persistently over budget.
5. *alpha = 4.0.* Every trajectory dies. Median rides above 1 throughout. Same instant-collapse story as the Qwen panel.

What surprised me, and what doesn't yet have an explanation:

- *Random dead traces at low alpha.* At alpha = 0.25 to 0.75, KL stays comfortably under 1 for almost every token, and yet 7 to 12 of 24 trajectories register a death event at some fork. They are not dying because the budget was exceeded; they are dying for some other reason that p95 KL does not see. Possibilities: a single high-KL spike at one token, which a p95 statistic ignores; a low-KL but format-fragile region where small logit shifts knock pmass below 0.75; calibration noise on the pmass side (threshold 0.75 is loose, and OLMo's base pmass at some forks is already close to it). I don't know which yet.
- *alpha = 1.0 has fewer deaths than alpha = 0.75 (11 vs 12).* Within noise at n = 24, but the monotonicity of "more steering = more death" is not crisp at this scale.
- *Censoring drops to ~0 above alpha = 0.75.* Steering past the calibrated dose seems to suppress EOS, so trajectories run to the window length even when off-format. That makes death counts more comparable across high-alpha panels but means low-alpha "censored" is partly the model's own brevity, not the steering's.
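
Two of the points above can be checked with toy numbers. The sketch below uses synthetic values only (a flat 0.1-nat trajectory with one 5-nat spike, and a death probability of 0.5); none of these figures come from the runs. The binomial spread also treats trajectories as independent draws, which is a rough assumption because the same prompts are reused across alphas.

```python
import math
import numpy as np

def spike_vs_p95(n_tokens=4096, floor=0.1, spike=5.0, spike_pos=1000):
    """Return (p95, max) of a flat KL trajectory containing a single spike."""
    kl = np.full(n_tokens, floor)
    kl[spike_pos] = spike
    return float(np.percentile(kl, 95)), float(kl.max())

def death_count_sd(n=24, p=0.5):
    """Standard deviation of a binomial death count with n trajectories."""
    return math.sqrt(n * p * (1 - p))

p95, peak = spike_vs_p95()
print(f"p95 = {p95:.3f}, max = {peak:.1f}")          # the spike leaves p95 at the floor
print(f"sd of death count at n=24: {death_count_sd():.2f}")
```

A lone spike does not move p95 at all, which is why it could cause a death that the normalized plot never shows. A standard deviation of roughly 2.4 deaths at n = 24 is also consistent with reading 11 vs 12 as noise.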

So: per-trajectory normalization makes the cross-alpha comparison cleaner, and confirms iso-KL is roughly delivering its claimed dose on held-out prompts. It also makes the "KL is necessary but not sufficient" point sharper: the low-alpha panels show death events at well under the calibration scale.

Reproduce:

```bash
uv run --extra all python scripts/spaghetti_kl_alive.py \
  --runs-root outputs/olmo1b_w4096_dense/_OLMo-2-0425-1B_mean_diff_s0_w4096_dense_n24_single \
  --out outputs/olmo1b_w4096_dense/OLMo-2-0425-1B_mean_diff_s0_w4096_dense_n24/figs_thr075_quantile_shaded/spaghetti_norm \
  --window 4096 --threshold 0.75 --model-contains OLMo-2-0425-1B \
  --roll 301 --line-lw 0.4
```

Toggle off normalization with `--no-normalize-kl`; change the calibration window with `--calib-tokens N`.
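
For readers wondering how a default-on flag with a `--no-` counterpart is usually wired up, here is a hedged sketch using the standard library. It is an assumption about one plausible implementation, not a copy of the script's parser; `build_parser` is a hypothetical name, and `argparse.BooleanOptionalAction` requires Python 3.9 or later.

```python
import argparse

def build_parser():
    """Hypothetical parser exposing only the two normalization flags."""
    parser = argparse.ArgumentParser(description="sketch of the normalization flags")
    # BooleanOptionalAction creates both --normalize-kl and --no-normalize-kl.
    parser.add_argument(
        "--normalize-kl",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="divide each trajectory by its own p95(KL[:calib_tokens])",
    )
    parser.add_argument(
        "--calib-tokens",
        type=int,
        default=20,
        help="number of leading tokens used for the p95 denominator",
    )
    return parser

args = build_parser().parse_args(["--no-normalize-kl", "--calib-tokens", "50"])
print(args.normalize_kl, args.calib_tokens)
```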

## Honest caveats

- *Iso-KL is one fairness criterion, not the criterion.* Matching p95 per-token KL means matching distributional disagreement under greedy decoding in the calibration window. It does not match intervention norm, behavioural effect size, or human-perceived quality.