Replace README with short LW-style version; long form moved to README_long.md

This commit is contained in:
wassname
2026-05-08 11:22:25 +08:00
parent dbff43cffb
commit 2a69829612
2 changed files with 286 additions and 128 deletions
+124 -128
View File
@@ -1,162 +1,158 @@
# iso-kl-figure # iso-kl-figure (short version)
How strongly should we steer a model? How do we compare steering methods when one might be strong and one weak? These are calibration questions. ## The problem
Treat steering as an intervention: we want maximum behavior change with minimum side effects, like incoherence, format collapse, or random off-target damage. Activation steering has a knob. You pick a steering direction, multiply
it by a coefficient, and add the result into one residual stream. Small
coefficient: nothing happens. Large coefficient: the model breaks.
Somewhere in between is what you want.
A useful analogy is a car on the road. A small nudge to the steering wheel gets corrected by the driver. A larger nudge causes a lane change. A very large nudge causes a crash the driver cannot recover from. Some crashes happen immediately. Some take a few seconds to develop, after the driver tries and fails to correct. Most papers either pick one coefficient per method, or sweep coefficients
without normalising across methods. So when they say "method A is better
than method B" what you actually learn is "whoever turned A's knob got
more careful tuning than B's."
We measure the distribution shift caused by steering, especially the worst 5% of per-token KL that would cause a "crash", and find the largest scalar coefficient C that keeps that 5% below a safe threshold (default: 1 nat). At `alpha = 1` the steering is delivering exactly that calibrated dose. At `alpha = 2` it is doing twice the dose. *Iso-KL* means: comparing two methods at the same alpha = 1 means they are spending the same per-token KL budget, so any behavioural difference is not just one of them being louder. I want to compare methods more fairly. The natural way is to spend the
same "intervention budget" across methods, then ask what the budget bought
you in behaviour.
Then comes the survival question. If the calibration only looked at the first w tokens of the rollout, do crashes still happen later? Most trajectories either stabilize or go off track within the first 50-100 tokens, but we want to see the long tail too. So at every fork t we ask: can the model still produce one of two valid answer tokens (`true` or `false`) when forced into a JSON-schema prefill? If yes, *alive* at t. If the probability mass on those tokens drops below 0.95, *dead*, and dead stays dead. Survival S(t) is the fraction still alive at token t, with right-censoring for rollouts that hit EOS before t. ## The attempt: iso-KL calibration
Headline result on Qwen3.5-0.8B (`mean_diff`, w=512, n=8 held-out prompts): Pick the budget as a per-token KL divergence between the steered and base
distributions. Concretely:
- `alpha` in {0, 0.25, 0.5, 0.75, 1.0, 1.5}: 0 of 8 trajectories die at any fork. 1. Run the steered model on a calibration prompt for w tokens.
- `alpha = 2.0`: 8 of 8 die. Median death time t = 64. 2. Compute per-token `KL(steered || base)`.
- `alpha = 4.0`: 8 of 8 die at t = 0. 3. Bisect on the coefficient until the 95th percentile of those KLs equals
1 nat. Call the result `c_star`.
4. Define `alpha = 1` to mean "you are spending the calibrated budget."
`alpha = 2` is twice the dose, etc.
A phase transition at roughly 2x the iso-KL coefficient. The calibrated dose has long-horizon headroom; doubling the dose is a slow crash; quadrupling it is an instant one. Figures and table below. So at alpha = 1, two methods are spending the same per-token KL. Any
behavioural difference between them is not just one of them being louder.
Then sweep alpha from 0 to 4 and look at what happens.
## Spaghetti: per-token KL trajectories, coloured by survival ## Alive vs dead
![KL trajectories coloured by survival, Qwen3.5-0.8B](figs/post/spaghetti_qwen35_w512.png) Before showing the figure, one more piece. KL tells you how much the
distribution moved, but not whether the model still works. So I track a
separate signal: can the model still answer a yes/no question?
*Per-token KL(steered || base) for n=8 held-out long-form prompts on Qwen3.5-0.8B (mean_diff, w=512). One panel per alpha. Each thin line is one trajectory; green = alive at that token (pmass_eval >= 0.95), red = dead. Black line is the median. Dotted horizontal: KL = 1 nat (the calibration target).* At every fork token along the rollout I splice in a JSON schema prefill
(`'\nI should answer now.</think>{"value": '`) and check whether the
steered model puts probability mass `>= 0.75` on one of `{true, false}`.
If yes, the model is coherent enough at that point to commit to a
boolean answer: *alive*. If no, it can't even pick between true and
false when handed the schema on a plate: *dead*. Once dead, dead stays
dead. So "dead" is shorthand for "the model is no longer coherent enough
to answer the question, and isn't coming back."
How to read this plot, in order: ## The figure
1. *alpha = 0.* KL is exactly zero everywhere. This is the base model; nothing has been added. A sanity check that the splicing pipeline does not itself perturb the logits. OLMo-2 1B, mean-diff steering, w = 4096 calibration window, n = 24
2. *alpha = 0.25 to 1.0.* KL stays well below the 1-nat line for most tokens. The calibration window only required p95 = 1 nat, so most tokens are far below. All trajectories are green: the model is still producing schema-valid answers throughout. The "warm but not hot" zone. held-out prompts (different prompts than calibration saw). One panel per
3. *alpha = 1.5.* KL still mostly below 1 nat but with a fatter upper envelope. All trajectories still alive end-to-end on the held-out set, even though calibration only guaranteed survival up to roughly `alpha = 1`. Suggests headroom. alpha.
4. *alpha = 2.0.* KL median sits around 1 nat for most of the rollout, with a noisy plateau. The trajectories turn red part-way through. Steering is now strong enough to break format. The phase transition.
5. *alpha = 4.0.* Every trajectory is red from t = 0 and KL is roughly 4-6 nats throughout. The model is no longer answering the question; it is generating something else.
The diagnostic shape is the alpha = 2 panel: KL is only modestly above the calibration target (less than 2x in nats) but format collapse is total. KL is necessary but not sufficient as a steering budget. Going from "calibrated" to "calibrated x 2" is not a smooth degradation; it crosses a cliff. Y-axis is per-token KL divided by each trajectory's *own* p95 over the
first 20 tokens, so the dashed line at y = 1 is the calibration scale by
construction. Anything above 1 means the trajectory is over its own
early-rollout budget. Blue = alive, orange = dead, black ticks = EOS
(rollout finished naturally before the window ended).
## Survival: when do trajectories die? ![](figs/spaghetti_olmo2_1b_w4096_norm.png)
![Kaplan-Meier survival on pmass_eval, Qwen3.5-0.8B](figs/post/survival_qwen35_w512.png) How to read the panels, in order:
*Kaplan-Meier S(t) on the pmass_eval death event (pmass_eval < 0.95). One curve per alpha, n = 8 trajectories each. Dotted vertical: t = 20, the upper end of the calibration window's relevance for the answer span; everything to the right is generalization.* - **alpha = 0.** No steering. KL is exactly zero, denominator is zero, all
trajectories are skipped. Empty panel is the correct outcome.
- **alpha = 0.25 to 0.75.** Median KL sits well below 1 for the whole
rollout. The 1-nat budget set on a w = 4096 calibration window has long-
horizon headroom on a different prompt set.
- **alpha = 1.0.** Median hovers a bit under 1, with one outlier above 2.
Iso-KL is roughly delivering its claimed dose.
- **alpha = 1.5 to 2.0.** Median crosses 1 and stays there. Most
trajectories die. KL is sustainably over budget.
- **alpha = 4.0.** Everyone dies, fast.
| alpha | n | died | censored | S(mid) | S(end) | t at S<=0.5 | ## What works, what doesn't
|------:|--:|-----:|---------:|-------:|-------:|------------:|
| 0.00 | 8 | 0 | 0 | 1.000 | 1.000 | -- |
| 0.25 | 8 | 0 | 0 | 1.000 | 1.000 | -- |
| 0.50 | 8 | 0 | 0 | 1.000 | 1.000 | -- |
| 0.75 | 8 | 0 | 0 | 1.000 | 1.000 | -- |
| 1.00 | 8 | 0 | 0 | 1.000 | 1.000 | -- |
| 1.50 | 8 | 0 | 0 | 1.000 | 1.000 | -- |
| 2.00 | 8 | 8 | 0 | 0.586 | 0.321 | 64 |
| 4.00 | 8 | 8 | 0 | 0.000 | 0.000 | 0 |
How to read this table, in order: It kind of works. Doubling the calibrated dose is a slow crash. Quadrupling
it is instant. The shape of the alpha sweep matches what calibration
predicts.
1. *Censored is always 0* here, so survival numbers are not biased by short rollouts. Every trajectory reaches every fork t in `{0, 5, ..., w}`. It also doesn't work, in interesting ways:
2. *alpha 0 to 1.5*: `died = 0`, `S = 1.0` everywhere. Format holds for the full rollout. The held-out set generalizes past the calibration window (t = 20 dotted line) all the way to t = 512. The 1-nat KL budget set on calibration prompts continues to be a survivable budget on a different prompt set, more than an order of magnitude past where calibration looked.
3. *alpha = 2.0*: `died = 8` of 8 within the rollout; median death at t = 64. All trajectories eventually break, but they survive the first 64 tokens, about 3x the calibration window. Format collapse is gradual at this dose.
4. *alpha = 4.0*: median death at t = 0. The first fork already shows `pmass_eval < 0.95`. Format collapse is immediate.
5. *Read t at S<=0.5 as a half-life.* It scales nonlinearly with alpha: 64 tokens at 2x, 0 tokens at 4x. Doubling the dose past the calibration point does not double the half-life; it eliminates it.
The calibrated alpha (`alpha = 1`) earns a high-confidence survival claim within this experiment: 8 of 8 trajectories survived all 512 forks on a held-out prompt set. Doubling the dose still leaves a usable window. Quadrupling it does not. - **Random dead traces at low alpha.** At alpha = 0.25 to 0.75, KL stays
comfortably under the calibration scale and yet 7 to 12 of 24
trajectories die at some fork. They are not dying because the budget
was exceeded; they are dying for some other reason that p95 KL doesn't
see. Possibilities: a single high-KL spike at one token that p95
averages away; a low-KL but format-fragile region where small logit
shifts knock pmass below threshold; pmass at 0.75 just being noisy.
Probably some of each. So p95 KL is necessary but not sufficient as a
steering budget.
- **It doesn't calibrate cleanly across methods.** It works fine for
`mean_diff`. For more directional methods (e.g. directional ablation,
PCA) the same KL budget at alpha = 1 produces noticeably different
behavioural effects, because mean-mass shifts and directional
projections have different geometry, and one scalar KL doesn't
distinguish them.
The second point is probably unavoidable for any single scalar
calibration target. The intervention is multi-dimensional, the target is
one number.
## So is it useful?
Probably yes, weakly. The calibration this gives you does enable a
fair-er method comparison. I've used it in
[steering-lite](https://github.com/wassname/steering-lite) for a sweep
across four methods, where the iso-KL coefficient is what makes the
comparison not just a lottery on whose default coefficient was tuned
harder.
Better directions if anyone wants to improve this:
- A different lens on steering strength. Taimeskhanov, Vaiter, & Garreau
(2026), [Towards Understanding Steering
Strength](https://arxiv.org/abs/2602.02712), derive how steering
magnitude affects next-token probability and cross-entropy across 11
models, and report non-monotonic effects of strength. They look at the
immediate next-token effect rather than stability along a trajectory,
so it's a complementary view, not a replacement: theirs answers "what
does cranking the knob do at one token", mine asks "does the model
stay coherent for 4k tokens at a calibrated dose". Worth combining.
- Per-token cosine gating. Some steering methods scale the intervention
by the cosine between the residual and the steering direction, a cheap
way to suppress the intervention where it doesn't apply. (I forget who
first did this; pointers welcome.)
- A different calibration target. p95 KL is what I used because it is
cheap and intuitive; a behavioural target (e.g. preserve base
perplexity on a held-out corpus within delta) might be more honest
about what we actually care about.
- Accept that mean-mass shift is a coarse intervention and stop trying to
calibrate it past a certain precision. Use the calibration to put
methods in the same order of magnitude, then compare on downstream
behaviour.
Mostly I'm sharing this because the alternative in much of the literature
is "no calibration." This is cheap, fast, and strictly better than that.
I hope someone improves on it.
## Reproduce ## Reproduce
```bash ```bash
uv sync --extra all uv sync --extra all
just smoke # tiny-random model, ~1 min CPU
just calibrate # one (model, method, seed, window) cell
just trajectory
just plot
```
The headline run for the figures here:
```bash
uv run --extra all python scripts/run_cell.py \ uv run --extra all python scripts/run_cell.py \
--model Qwen/Qwen3.5-0.8B --method mean_diff --seed 0 --window 512 \ --model allenai/OLMo-2-0425-1B --method mean_diff --seed 0 --window 4096 \
--out-root outputs/qwen35_w512_dense \ --out-root outputs/olmo1b_w4096_dense
--run-id Qwen3.5-0.8B_mean_diff_s0_w512_dense \
--compute-pmass --skip-pmass-qa --fork-log \
--alphas 0.0 0.25 0.5 0.75 1.0 1.5 2.0 4.0 \
--render-figs --render-threshold 0.95
```
Outputs land under `outputs/qwen35_w512_dense/<run-id>/figs_auto/{survival,spaghetti,aggregate}/`.
## What this repo is and isn't
In scope:
- One figure family (spaghetti, survival, aggregate) and one calibration script.
- 3 methods: `mean_diff`, `directional_ablation`, `pca`.
- A `branch_pmass` metric: fork-and-teacher-force probability mass on a forced-choice answer token after schema prefill.
Out of scope:
- A method zoo beyond those three.
- A norm-matching baseline. Iso-KL says nothing about whether two methods of equal KL move the same distance in activation space.
- A threshold sweep. The pmass < 0.95 cutoff is a single point; neighbours might disagree.
## Per-trajectory KL normalization (OLMo-2 1B, w=4096)
The figures above use raw per-token KL in nats. That works when every panel shares an alpha-scale, but cross-alpha comparison gets harder once trajectories vary in their own baseline KL noise (some prompts are just spikier than others, even at alpha = 1).
So we re-plot KL divided by each trajectory's own p95 over the first 20 tokens. The dashed reference at y = 1 is then the calibration scale by construction. Values above 1 are excursions above that trajectory's own early-rollout budget. The y-axis is shared across panels so the alpha sweep reads as one picture.
![KL/p95 trajectories, OLMo-2 1B, w=4096, n=24](figs/spaghetti_olmo2_1b_w4096_norm.png)
*Per-token `KL(steered || base) / p95(KL[:20])` for OLMo-2-0425-1B (mean_diff, w=4096, n=24 held-out prompts). One panel per alpha. Blue = alive at that token (`pmass_eval >= 0.75` at the nearest fork s <= t); orange = dead (irreversible). Black tick = EOS (right-censored). Dashed = calibration scale = 1. Black line = median across alive+dead trajectories. Threshold here is 0.75, looser than the 0.95 used for Qwen, because OLMo's early pmass is noisier.*
How to read it, in order:
1. *alpha = 0.* No steering, KL = 0, denom = 0, all trajectories skipped. Empty panel is correct.
2. *alpha = 0.25 to 0.75.* Median sits well below 1. Most trajectories never approach the calibration scale even past 4000 tokens. The 1-nat budget set on a w = 4096 calibration window has long-horizon headroom for OLMo-2 1B.
3. *alpha = 1.0.* Median hovers a bit under 1 with one outlier traveller above 2. Iso-KL is doing what it claims at the trajectory-median level on a different prompt set than calibration saw.
4. *alpha = 1.5 to 2.0.* Median crosses 1 and stays there. Most trajectories die. KL is now sustainably over budget.
5. *alpha = 4.0.* Every trajectory dies. Median rides above 1 throughout. Same instant-collapse story as the Qwen panel.
What surprised me, and what doesn't yet have an explanation:
- *Random dead traces at low alpha.* At alpha = 0.25 to 0.75, KL stays comfortably under 1 for almost every token, and yet 7 to 12 of 24 trajectories register a death event at some fork. They are not dying because the budget was exceeded; they are dying for some other reason that p95 KL does not see. Possibilities: a single high-KL spike at one token that p95 averages away; a low-KL but format-fragile region where small logit shifts knock pmass below 0.75; calibration noise on the pmass side (threshold 0.75 is loose, and OLMo's base pmass at some forks is already close to it). I don't know which yet.
- *alpha = 1.0 has fewer deaths than alpha = 0.75 (11 vs 12).* Within noise at n = 24, but the monotonicity of "more steering = more death" is not crisp at this scale.
- *Censoring drops to ~0 above alpha = 0.75.* Steering past the calibrated dose seems to suppress EOS, so trajectories run to the window length even when off-format. That makes died counts more comparable across high-alpha panels but means low-alpha "censored" is partly the model's own brevity, not the steering's.
So: per-trajectory normalization makes the cross-alpha comparison cleaner, and confirms iso-KL is roughly delivering its claimed dose on held-out prompts. It also makes the "KL is necessary but not sufficient" point sharper: the low-alpha panels show death events at well under the calibration scale.
Reproduce:
```bash
uv run --extra all python scripts/spaghetti_kl_alive.py \ uv run --extra all python scripts/spaghetti_kl_alive.py \
--runs-root outputs/olmo1b_w4096_dense/_OLMo-2-0425-1B_mean_diff_s0_w4096_dense_n24_single \ --runs-root outputs/olmo1b_w4096_dense/_OLMo-2-0425-1B_mean_diff_s0_w4096_dense_n24_single \
--out outputs/olmo1b_w4096_dense/OLMo-2-0425-1B_mean_diff_s0_w4096_dense_n24/figs_thr075_quantile_shaded/spaghetti_norm \ --out figs/spaghetti_olmo2_1b_w4096_norm \
--window 4096 --threshold 0.75 --model-contains OLMo-2-0425-1B \ --window 4096 --threshold 0.75 --model-contains OLMo-2-0425-1B \
--roll 301 --line-lw 0.4 --roll 301 --line-lw 0.4
``` ```
Toggle off normalization with `--no-normalize-kl`; change the calibration window with `--calib-tokens N`. See [README_long.md](README_long.md) for the long version with the original Qwen run,
glossary, and survival analysis.
## Honest caveats
- *Iso-KL is one fairness criterion, not the criterion.* Matching p95 per-token KL means matching distributional disagreement under greedy decoding in the calibration window. It does not match intervention norm, behavioural effect size, or human-perceived quality.
- *Pmass is a proxy.* It scores whether the model still produces one of two specific tokens after we splice in a schema. A model that has gone off-format but is otherwise coherent gets scored as dead.
- *n = 8.* Held-out prompts, not held-out seeds. The phase-transition shape is consistent at this scale; the exact half-life at alpha = 2 is not.
- *One model so far.* Gemma 4B, Gemma 12B (4-bit), OLMo-2 1B, and OLMo-3 7B at w = 4096 are queued. The phase-transition story may or may not survive scaling.
## Glossary
Operational definitions used in this repo, not textbook ones. Skim and come back as needed.
- *steering coefficient (alpha)*: scalar multiplier on the steering vector added to residual-stream activations at one layer. `alpha = 1.0` is the iso-KL-calibrated coefficient C. `alpha = 0` is the unsteered base model.
- *iso-KL calibration*: bisection on alpha such that the 95th-percentile per-token `KL(steered || base)` over a w-token calibration window equals 1 nat. Output is `c_star`, the coefficient at `alpha = 1`.
- *p95 KL*: 95th percentile of per-token `KL(steered || base)` over the calibration window. The "worst 5%" measure used by calibration.
- *window (w)*: length in tokens of the calibration rollout. `w = 512` for the figures here.
- *pmass*: probability the model puts on the set `{true, false}` after we splice in a JSON schema prefill (`'\nI should answer now.</think>{"value": '`) at fork token t.
- *pmass_eval*: pmass on a held-out prompt set, at every fork t.
- *fork point t*: the token in the rollout where we splice the schema and read pmass.
- *death*: irreversible. Alive at t means `pmass_eval(t) >= 0.95`. First time it drops below, dead, and stays dead. Below 0.95 the JSON schema stops being a stable attractor.
- *right-censoring*: the rollout hit EOS at length L < t, so we never see fork t. Drops out of the at-risk denominator from L onward; not death.
- *survival S(t)*: Kaplan-Meier estimator of the fraction alive at fork t.
- *mean_diff*: steering vector built as the difference of mean activations on a contrastive prompt pair.
- *spaghetti plot*: every individual KL trajectory drawn as one thin line, alive segments green, dead red, black line is the median.
+162
View File
@@ -0,0 +1,162 @@
# iso-kl-figure
How strongly should we steer a model? How do we compare steering methods when one might be strong and one weak? These are calibration questions.
Treat steering as an intervention: we want maximum behavior change with minimum side effects, like incoherence, format collapse, or random off-target damage.
A useful analogy is a car on the road. A small nudge to the steering wheel gets corrected by the driver. A larger nudge causes a lane change. A very large nudge causes a crash the driver cannot recover from. Some crashes happen immediately. Some take a few seconds to develop, after the driver tries and fails to correct.
We measure the distribution shift caused by steering, especially the worst 5% of per-token KL that would cause a "crash", and find the largest scalar coefficient C that keeps that 5% below a safe threshold (default: 1 nat). At `alpha = 1` the steering is delivering exactly that calibrated dose. At `alpha = 2` it is doing twice the dose. *Iso-KL* means: comparing two methods at the same alpha = 1 means they are spending the same per-token KL budget, so any behavioural difference is not just one of them being louder.
Then comes the survival question. If the calibration only looked at the first w tokens of the rollout, do crashes still happen later? Most trajectories either stabilize or go off track within the first 50-100 tokens, but we want to see the long tail too. So at every fork t we ask: can the model still produce one of two valid answer tokens (`true` or `false`) when forced into a JSON-schema prefill? If yes, *alive* at t. If the probability mass on those tokens drops below 0.95, *dead*, and dead stays dead. Survival S(t) is the fraction still alive at token t, with right-censoring for rollouts that hit EOS before t.
Headline result on Qwen3.5-0.8B (`mean_diff`, w=512, n=8 held-out prompts):
- `alpha` in {0, 0.25, 0.5, 0.75, 1.0, 1.5}: 0 of 8 trajectories die at any fork.
- `alpha = 2.0`: 8 of 8 die. Median death time t = 64.
- `alpha = 4.0`: 8 of 8 die at t = 0.
A phase transition at roughly 2x the iso-KL coefficient. The calibrated dose has long-horizon headroom; doubling the dose is a slow crash; quadrupling it is an instant one. Figures and table below.
## Spaghetti: per-token KL trajectories, coloured by survival
![KL trajectories coloured by survival, Qwen3.5-0.8B](figs/post/spaghetti_qwen35_w512.png)
*Per-token KL(steered || base) for n=8 held-out long-form prompts on Qwen3.5-0.8B (mean_diff, w=512). One panel per alpha. Each thin line is one trajectory; green = alive at that token (pmass_eval >= 0.95), red = dead. Black line is the median. Dotted horizontal: KL = 1 nat (the calibration target).*
How to read this plot, in order:
1. *alpha = 0.* KL is exactly zero everywhere. This is the base model; nothing has been added. A sanity check that the splicing pipeline does not itself perturb the logits.
2. *alpha = 0.25 to 1.0.* KL stays well below the 1-nat line for most tokens. The calibration window only required p95 = 1 nat, so most tokens are far below. All trajectories are green: the model is still producing schema-valid answers throughout. The "warm but not hot" zone.
3. *alpha = 1.5.* KL still mostly below 1 nat but with a fatter upper envelope. All trajectories still alive end-to-end on the held-out set, even though calibration only guaranteed survival up to roughly `alpha = 1`. Suggests headroom.
4. *alpha = 2.0.* KL median sits around 1 nat for most of the rollout, with a noisy plateau. The trajectories turn red part-way through. Steering is now strong enough to break format. The phase transition.
5. *alpha = 4.0.* Every trajectory is red from t = 0 and KL is roughly 4-6 nats throughout. The model is no longer answering the question; it is generating something else.
The diagnostic shape is the alpha = 2 panel: KL is only modestly above the calibration target (less than 2x in nats) but format collapse is total. KL is necessary but not sufficient as a steering budget. Going from "calibrated" to "calibrated x 2" is not a smooth degradation; it crosses a cliff.
## Survival: when do trajectories die?
![Kaplan-Meier survival on pmass_eval, Qwen3.5-0.8B](figs/post/survival_qwen35_w512.png)
*Kaplan-Meier S(t) on the pmass_eval death event (pmass_eval < 0.95). One curve per alpha, n = 8 trajectories each. Dotted vertical: t = 20, the upper end of the calibration window's relevance for the answer span; everything to the right is generalization.*
| alpha | n | died | censored | S(mid) | S(end) | t at S<=0.5 |
|------:|--:|-----:|---------:|-------:|-------:|------------:|
| 0.00 | 8 | 0 | 0 | 1.000 | 1.000 | -- |
| 0.25 | 8 | 0 | 0 | 1.000 | 1.000 | -- |
| 0.50 | 8 | 0 | 0 | 1.000 | 1.000 | -- |
| 0.75 | 8 | 0 | 0 | 1.000 | 1.000 | -- |
| 1.00 | 8 | 0 | 0 | 1.000 | 1.000 | -- |
| 1.50 | 8 | 0 | 0 | 1.000 | 1.000 | -- |
| 2.00 | 8 | 8 | 0 | 0.586 | 0.321 | 64 |
| 4.00 | 8 | 8 | 0 | 0.000 | 0.000 | 0 |
How to read this table, in order:
1. *Censored is always 0* here, so survival numbers are not biased by short rollouts. Every trajectory reaches every fork t in `{0, 5, ..., w}`.
2. *alpha 0 to 1.5*: `died = 0`, `S = 1.0` everywhere. Format holds for the full rollout. The held-out set generalizes past the calibration window (t = 20 dotted line) all the way to t = 512. The 1-nat KL budget set on calibration prompts continues to be a survivable budget on a different prompt set, more than an order of magnitude past where calibration looked.
3. *alpha = 2.0*: `died = 8` of 8 within the rollout; median death at t = 64. All trajectories eventually break, but they survive the first 64 tokens, about 3x the calibration window. Format collapse is gradual at this dose.
4. *alpha = 4.0*: median death at t = 0. The first fork already shows `pmass_eval < 0.95`. Format collapse is immediate.
5. *Read t at S<=0.5 as a half-life.* It scales nonlinearly with alpha: 64 tokens at 2x, 0 tokens at 4x. Doubling the dose past the calibration point does not double the half-life; it eliminates it.
The calibrated alpha (`alpha = 1`) earns a high-confidence survival claim within this experiment: 8 of 8 trajectories survived all 512 forks on a held-out prompt set. Doubling the dose still leaves a usable window. Quadrupling it does not.
## Reproduce
```bash
uv sync --extra all
just smoke # tiny-random model, ~1 min CPU
just calibrate # one (model, method, seed, window) cell
just trajectory
just plot
```
The headline run for the figures here:
```bash
uv run --extra all python scripts/run_cell.py \
--model Qwen/Qwen3.5-0.8B --method mean_diff --seed 0 --window 512 \
--out-root outputs/qwen35_w512_dense \
--run-id Qwen3.5-0.8B_mean_diff_s0_w512_dense \
--compute-pmass --skip-pmass-qa --fork-log \
--alphas 0.0 0.25 0.5 0.75 1.0 1.5 2.0 4.0 \
--render-figs --render-threshold 0.95
```
Outputs land under `outputs/qwen35_w512_dense/<run-id>/figs_auto/{survival,spaghetti,aggregate}/`.
## What this repo is and isn't
In scope:
- One figure family (spaghetti, survival, aggregate) and one calibration script.
- 3 methods: `mean_diff`, `directional_ablation`, `pca`.
- A `branch_pmass` metric: fork-and-teacher-force probability mass on a forced-choice answer token after schema prefill.
Out of scope:
- A method zoo beyond those three.
- A norm-matching baseline. Iso-KL says nothing about whether two methods of equal KL move the same distance in activation space.
- A threshold sweep. The pmass < 0.95 cutoff is a single point; neighbours might disagree.
## Per-trajectory KL normalization (OLMo-2 1B, w=4096)
The figures above use raw per-token KL in nats. That works when every panel shares an alpha-scale, but cross-alpha comparison gets harder once trajectories vary in their own baseline KL noise (some prompts are just spikier than others, even at alpha = 1).
So we re-plot KL divided by each trajectory's own p95 over the first 20 tokens. The dashed reference at y = 1 is then the calibration scale by construction. Values above 1 are excursions above that trajectory's own early-rollout budget. The y-axis is shared across panels so the alpha sweep reads as one picture.
![KL/p95 trajectories, OLMo-2 1B, w=4096, n=24](figs/spaghetti_olmo2_1b_w4096_norm.png)
*Per-token `KL(steered || base) / p95(KL[:20])` for OLMo-2-0425-1B (mean_diff, w=4096, n=24 held-out prompts). One panel per alpha. Blue = alive at that token (`pmass_eval >= 0.75` at the nearest fork s <= t); orange = dead (irreversible). Black tick = EOS (right-censored). Dashed = calibration scale = 1. Black line = median across alive+dead trajectories. Threshold here is 0.75, looser than the 0.95 used for Qwen, because OLMo's early pmass is noisier.*
How to read it, in order:
1. *alpha = 0.* No steering, KL = 0, denom = 0, all trajectories skipped. Empty panel is correct.
2. *alpha = 0.25 to 0.75.* Median sits well below 1. Most trajectories never approach the calibration scale even past 4000 tokens. The 1-nat budget set on a w = 4096 calibration window has long-horizon headroom for OLMo-2 1B.
3. *alpha = 1.0.* Median hovers a bit under 1 with one outlier traveller above 2. Iso-KL is doing what it claims at the trajectory-median level on a different prompt set than calibration saw.
4. *alpha = 1.5 to 2.0.* Median crosses 1 and stays there. Most trajectories die. KL is now sustainably over budget.
5. *alpha = 4.0.* Every trajectory dies. Median rides above 1 throughout. Same instant-collapse story as the Qwen panel.
What surprised me, and what doesn't yet have an explanation:
- *Random dead traces at low alpha.* At alpha = 0.25 to 0.75, KL stays comfortably under 1 for almost every token, and yet 7 to 12 of 24 trajectories register a death event at some fork. They are not dying because the budget was exceeded; they are dying for some other reason that p95 KL does not see. Possibilities: a single high-KL spike at one token that p95 averages away; a low-KL but format-fragile region where small logit shifts knock pmass below 0.75; calibration noise on the pmass side (threshold 0.75 is loose, and OLMo's base pmass at some forks is already close to it). I don't know which yet.
- *alpha = 1.0 has fewer deaths than alpha = 0.75 (11 vs 12).* Within noise at n = 24, but the monotonicity of "more steering = more death" is not crisp at this scale.
- *Censoring drops to ~0 above alpha = 0.75.* Steering past the calibrated dose seems to suppress EOS, so trajectories run to the window length even when off-format. That makes died counts more comparable across high-alpha panels but means low-alpha "censored" is partly the model's own brevity, not the steering's.
So: per-trajectory normalization makes the cross-alpha comparison cleaner, and confirms iso-KL is roughly delivering its claimed dose on held-out prompts. It also makes the "KL is necessary but not sufficient" point sharper: the low-alpha panels show death events at well under the calibration scale.
Reproduce:
```bash
uv run --extra all python scripts/spaghetti_kl_alive.py \
--runs-root outputs/olmo1b_w4096_dense/_OLMo-2-0425-1B_mean_diff_s0_w4096_dense_n24_single \
--out outputs/olmo1b_w4096_dense/OLMo-2-0425-1B_mean_diff_s0_w4096_dense_n24/figs_thr075_quantile_shaded/spaghetti_norm \
--window 4096 --threshold 0.75 --model-contains OLMo-2-0425-1B \
--roll 301 --line-lw 0.4
```
Toggle off normalization with `--no-normalize-kl`; change the calibration window with `--calib-tokens N`.
## Honest caveats
- *Iso-KL is one fairness criterion, not the criterion.* Matching p95 per-token KL means matching distributional disagreement under greedy decoding in the calibration window. It does not match intervention norm, behavioural effect size, or human-perceived quality.
- *Pmass is a proxy.* It scores whether the model still produces one of two specific tokens after we splice in a schema. A model that has gone off-format but is otherwise coherent gets scored as dead.
- *n = 8.* Held-out prompts, not held-out seeds. The phase-transition shape is consistent at this scale; the exact half-life at alpha = 2 is not.
- *One model so far.* Gemma 4B, Gemma 12B (4-bit), OLMo-2 1B, and OLMo-3 7B at w = 4096 are queued. The phase-transition story may or may not survive scaling.
## Glossary
Operational definitions used in this repo, not textbook ones. Skim and come back as needed.
- *steering coefficient (alpha)*: scalar multiplier on the steering vector added to residual-stream activations at one layer. `alpha = 1.0` is the iso-KL-calibrated coefficient C. `alpha = 0` is the unsteered base model.
- *iso-KL calibration*: bisection on alpha such that the 95th-percentile per-token `KL(steered || base)` over a w-token calibration window equals 1 nat. Output is `c_star`, the coefficient at `alpha = 1`.
- *p95 KL*: 95th percentile of per-token `KL(steered || base)` over the calibration window. The "worst 5%" measure used by calibration.
- *window (w)*: length in tokens of the calibration rollout. `w = 512` for the figures here.
- *pmass*: probability the model puts on the set `{true, false}` after we splice in a JSON schema prefill (`'\nI should answer now.</think>{"value": '`) at fork token t.
- *pmass_eval*: pmass on a held-out prompt set, at every fork t.
- *fork point t*: the token in the rollout where we splice the schema and read pmass.
- *death*: irreversible. Alive at t means `pmass_eval(t) >= 0.95`. First time it drops below, dead, and stays dead. Below 0.95 the JSON schema stops being a stable attractor.
- *right-censoring*: the rollout hit EOS at length L < t, so we never see fork t. Drops out of the at-risk denominator from L onward; not death.
- *survival S(t)*: Kaplan-Meier estimator of the fraction alive at fork t.
- *mean_diff*: steering vector built as the difference of mean activations on a contrastive prompt pair.
- *spaghetti plot*: every individual KL trajectory drawn as one thin line, alive segments green, dead red, black line is the median.