feat(spaghetti): per-trajectory KL/p95 normalization + shared y-axis

- Add --normalize-kl / --calib-tokens (default on, 20 tokens) to
  spaghetti_kl_alive.py: each trajectory is divided by its own
  p95(KL[:calib_tokens]). Calibration target collapses to y=1.
- Share y-axis across all alpha panels (global p99 ymax) for
  direct cross-alpha comparison.
- Add OLMo-2 1B w=4096 figure to figs/ and a README section
  documenting the result, including the unexplained 'random
  dead traces at low alpha' observation.
This commit is contained in:
wassname
2026-05-08 08:17:21 +08:00
parent 7d3fd37743
commit dbff43cffb
3 changed files with 105 additions and 16 deletions
+38
View File
@@ -98,6 +98,44 @@ Out of scope:
- A norm-matching baseline. Iso-KL says nothing about whether two methods of equal KL move the same distance in activation space.
- A threshold sweep. The pmass < 0.95 cutoff is a single point; neighbours might disagree.
## Per-trajectory KL normalization (OLMo-2 1B, w=4096)
The figures above use raw per-token KL in nats. That works when every panel shares an alpha-scale, but cross-alpha comparison gets harder once trajectories vary in their own baseline KL noise (some prompts are just spikier than others, even at alpha = 1).
So we re-plot KL divided by each trajectory's own p95 over the first 20 tokens. The dashed reference at y = 1 is then the calibration scale by construction. Values above 1 are excursions above that trajectory's own early-rollout budget. The y-axis is shared across panels so the alpha sweep reads as one picture.
![KL/p95 trajectories, OLMo-2 1B, w=4096, n=24](figs/spaghetti_olmo2_1b_w4096_norm.png)
*Per-token `KL(steered || base) / p95(KL[:20])` for OLMo-2-0425-1B (mean_diff, w=4096, n=24 held-out prompts). One panel per alpha. Blue = alive at that token (`pmass_eval >= 0.75` at the nearest fork s <= t); orange = dead (irreversible). Black tick = EOS (right-censored). Dashed = calibration scale = 1. Black line = median across alive+dead trajectories. Threshold here is 0.75, looser than the 0.95 used for Qwen, because OLMo's early pmass is noisier.*
How to read it, in order:
1. *alpha = 0.* No steering, KL = 0, denom = 0, all trajectories skipped. Empty panel is correct.
2. *alpha = 0.25 to 0.75.* Median sits well below 1. Most trajectories never approach the calibration scale even past 4000 tokens. The 1-nat budget set on a w = 4096 calibration window has long-horizon headroom for OLMo-2 1B.
3. *alpha = 1.0.* Median hovers a bit under 1 with one outlier traveller above 2. Iso-KL is doing what it claims at the trajectory-median level on a different prompt set than calibration saw.
4. *alpha = 1.5 to 2.0.* Median crosses 1 and stays there. Most trajectories die. KL is now sustainably over budget.
5. *alpha = 4.0.* Every trajectory dies. Median rides above 1 throughout. Same instant-collapse story as the Qwen panel.
What surprised me, and what doesn't yet have an explanation:
- *Random dead traces at low alpha.* At alpha = 0.25 to 0.75, KL stays comfortably under 1 for almost every token, and yet 7 to 12 of 24 trajectories register a death event at some fork. They are not dying because the budget was exceeded; they are dying for some other reason that p95 KL does not see. Possibilities: a single high-KL spike at one token that p95 averages away; a low-KL but format-fragile region where small logit shifts knock pmass below 0.75; calibration noise on the pmass side (threshold 0.75 is loose, and OLMo's base pmass at some forks is already close to it). I don't know which yet.
- *alpha = 1.0 has fewer deaths than alpha = 0.75 (11 vs 12).* Within noise at n = 24, but the monotonicity of "more steering = more death" is not crisp at this scale.
- *Censoring drops to ~0 above alpha = 0.75.* Steering past the calibrated dose seems to suppress EOS, so trajectories run to the window length even when off-format. That makes died counts more comparable across high-alpha panels but means low-alpha "censored" is partly the model's own brevity, not the steering's.
So: per-trajectory normalization makes the cross-alpha comparison cleaner, and confirms iso-KL is roughly delivering its claimed dose on held-out prompts. It also makes the "KL is necessary but not sufficient" point sharper: the low-alpha panels show death events at well under the calibration scale.
Reproduce:
```bash
uv run --extra all python scripts/spaghetti_kl_alive.py \
--runs-root outputs/olmo1b_w4096_dense/_OLMo-2-0425-1B_mean_diff_s0_w4096_dense_n24_single \
--out outputs/olmo1b_w4096_dense/OLMo-2-0425-1B_mean_diff_s0_w4096_dense_n24/figs_thr075_quantile_shaded/spaghetti_norm \
--window 4096 --threshold 0.75 --model-contains OLMo-2-0425-1B \
--roll 301 --line-lw 0.4
```
Toggle off normalization with `--no-normalize-kl`; change the calibration window with `--calib-tokens N`.
## Honest caveats
- *Iso-KL is one fairness criterion, not the criterion.* Matching p95 per-token KL means matching distributional disagreement under greedy decoding in the calibration window. It does not match intervention norm, behavioural effect size, or human-perceived quality.