survival: legend outside, contrast palette, jitter for overlapping S=1 lines

This commit is contained in:
wassname
2026-05-06 05:53:35 +08:00
parent 9d830cb3f8
commit 8dd427fc2e
6 changed files with 134 additions and 28 deletions
+30 -20
View File
@@ -1,31 +1,23 @@
# iso-kl-figure
Calibrate one steering knob so different methods deliver the same per-token KL budget. Then test whether the model still answers a forced-choice question while it's being steered.
How strongly should we steer a model? How do we compare steering methods when one might be strong and one weak? These are calibration questions.
Headline result on Qwen3.5-0.8B (mean_diff, w=512, n=8 held-out prompts):
Treat steering as an intervention: we want maximum behavior change with minimum side effects, like incoherence, format collapse, or random off-target damage.
- alpha in {0, 0.25, 0.5, 0.75, 1.0, 1.5}: 0 of 8 trajectories die at any fork.
- alpha = 2.0: 8 of 8 die. Median death time t = 64.
- alpha = 4.0: 8 of 8 die at t = 0.
A useful analogy is a car on the road. A small nudge to the steering wheel gets corrected by the driver. A larger nudge causes a lane change. A very large nudge causes a crash the driver cannot recover from. Some crashes happen immediately. Some take a few seconds to develop, after the driver tries and fails to correct.
A phase transition at roughly 2x the iso-KL coefficient. Figures and table below.
We measure the distribution shift caused by steering, especially the worst 5% of per-token KL that would cause a "crash", and find the largest scalar coefficient C that keeps that 5% below a safe threshold (default: 1 nat). At `alpha = 1` the steering is delivering exactly that calibrated dose. At `alpha = 2` it is doing twice the dose. *Iso-KL* means: comparing two methods at the same alpha = 1 means they are spending the same per-token KL budget, so any behavioural difference is not just one of them being louder.
## Glossary
Then comes the survival question. If the calibration only looked at the first w tokens of the rollout, do crashes still happen later? Most trajectories either stabilize or go off track within the first 50-100 tokens, but we want to see the long tail too. So at every fork t we ask: can the model still produce one of two valid answer tokens (`true` or `false`) when forced into a JSON-schema prefill? If yes, *alive* at t. If the probability mass on those tokens drops below 0.95, *dead*, and dead stays dead. Survival S(t) is the fraction still alive at token t, with right-censoring for rollouts that hit EOS before t.
We deliberately reuse a few terms from survival analysis and steering. Definitions below are the operational ones used in this repo, not the textbook ones.
Headline result on Qwen3.5-0.8B (`mean_diff`, w=512, n=8 held-out prompts):
- `alpha` in {0, 0.25, 0.5, 0.75, 1.0, 1.5}: 0 of 8 trajectories die at any fork.
- `alpha = 2.0`: 8 of 8 die. Median death time t = 64.
- `alpha = 4.0`: 8 of 8 die at t = 0.
A phase transition at roughly 2x the iso-KL coefficient. The calibrated dose has long-horizon headroom; doubling the dose is a slow crash; quadrupling it is an instant one. Figures and table below.
- *steering coefficient (alpha)*: scalar multiplier on a steering vector added to residual-stream activations at one layer. `alpha = 1.0` is the iso-KL-calibrated coefficient. `alpha = 2.0` is twice that. `alpha = 0` is the unsteered base model.
- *iso-KL calibration*: bisection on alpha such that the 95th-percentile per-token `KL(steered || base)` over a w-token calibration window equals 1 nat. The output is `c_star`, the coefficient at `alpha = 1`. This makes different steering methods comparable: any two methods at `alpha = 1` are spending the same per-token KL budget.
- *p95 KL*: 95th percentile of per-token `KL(steered || base)` over the calibration window. Robust to single-token spikes.
- *window (w)*: length in tokens of the calibration rollout. `w = 512` for the figures here.
- *pmass*: probability mass the model puts on the set `{true, false}` after we splice in a JSON schema prefill (`'\nI should answer now.</think>{"value": '`) at fork token t. A forced-choice probe: can the model still produce one of the two valid answer tokens, or has steering wrecked the format?
- *pmass_eval*: same probe, on a held-out prompt set, evaluated at every fork t.
- *fork point t*: the token in the rollout where we splice the schema and read pmass.
- *death*: irreversible event. A trajectory is *alive* at fork t if `pmass_eval(t) >= 0.95`. The first time it drops below, it is dead and stays dead. The 0.95 threshold is where the JSON schema stops being a stable attractor; the model is producing other tokens.
- *right-censoring*: the rollout hit EOS at length L < t, so we never see fork t. The trajectory drops out of the at-risk denominator from L onward. Incomplete information, not death.
- *survival S(t)*: fraction of trajectories still alive at fork t, with right-censored trajectories removed from the at-risk set. Kaplan-Meier estimator.
- *mean_diff*: steering vector built as the difference of mean activations on a contrastive prompt pair.
- *spaghetti plot*: every individual KL trajectory drawn as one thin line, alive segments green, dead segments red, black line is the median.
## Spaghetti: per-token KL trajectories, coloured by survival
@@ -112,3 +104,21 @@ Out of scope:
- *Pmass is a proxy.* It scores whether the model still produces one of two specific tokens after we splice in a schema. A model that has gone off-format but is otherwise coherent gets scored as dead.
- *n = 8.* Held-out prompts, not held-out seeds. The phase-transition shape is consistent at this scale; the exact half-life at alpha = 2 is not.
- *One model so far.* Gemma 4B, Gemma 12B (4-bit), OLMo-2 1B, and OLMo-3 7B at w = 4096 are queued. The phase-transition story may or may not survive scaling.
## Glossary
Operational definitions used in this repo, not textbook ones. Skim and come back as needed.
- *steering coefficient (alpha)*: scalar multiplier on the steering vector added to residual-stream activations at one layer. `alpha = 1.0` is the iso-KL-calibrated coefficient C. `alpha = 0` is the unsteered base model.
- *iso-KL calibration*: bisection on alpha such that the 95th-percentile per-token `KL(steered || base)` over a w-token calibration window equals 1 nat. Output is `c_star`, the coefficient at `alpha = 1`.
- *p95 KL*: 95th percentile of per-token `KL(steered || base)` over the calibration window. The "worst 5%" measure used by calibration.
- *window (w)*: length in tokens of the calibration rollout. `w = 512` for the figures here.
- *pmass*: probability the model puts on the set `{true, false}` after we splice in a JSON schema prefill (`'\nI should answer now.</think>{"value": '`) at fork token t.
- *pmass_eval*: pmass on a held-out prompt set, at every fork t.
- *fork point t*: the token in the rollout where we splice the schema and read pmass.
- *death*: irreversible. Alive at t means `pmass_eval(t) >= 0.95`. First time it drops below, dead, and stays dead. Below 0.95 the JSON schema stops being a stable attractor.
- *right-censoring*: the rollout hit EOS at length L < t, so we never see fork t. Drops out of the at-risk denominator from L onward; not death.
- *survival S(t)*: Kaplan-Meier estimator of the fraction alive at fork t.
- *mean_diff*: steering vector built as the difference of mean activations on a contrastive prompt pair.
- *spaghetti plot*: every individual KL trajectory drawn as one thin line, alive segments green, dead red, black line is the median.