mirror of
https://github.com/wassname/isokl_steering_calibration.git
synced 2026-06-27 17:01:10 +08:00
survival: legend outside, contrast palette, jitter for overlapping S=1 lines
This commit is contained in:
@@ -1,31 +1,23 @@
|
||||
# iso-kl-figure
|
||||
|
||||
Calibrate one steering knob so different methods deliver the same per-token KL budget. Then test whether the model still answers a forced-choice question while it's being steered.
|
||||
How strongly should we steer a model? How do we compare steering methods when one might be strong and one weak? These are calibration questions.
|
||||
|
||||
Headline result on Qwen3.5-0.8B (mean_diff, w=512, n=8 held-out prompts):
|
||||
Treat steering as an intervention: we want maximum behavior change with minimum side effects, like incoherence, format collapse, or random off-target damage.
|
||||
|
||||
- alpha in {0, 0.25, 0.5, 0.75, 1.0, 1.5}: 0 of 8 trajectories die at any fork.
|
||||
- alpha = 2.0: 8 of 8 die. Median death time t = 64.
|
||||
- alpha = 4.0: 8 of 8 die at t = 0.
|
||||
A useful analogy is a car on the road. A small nudge to the steering wheel gets corrected by the driver. A larger nudge causes a lane change. A very large nudge causes a crash the driver cannot recover from. Some crashes happen immediately. Some take a few seconds to develop, after the driver tries and fails to correct.
|
||||
|
||||
A phase transition at roughly 2x the iso-KL coefficient. Figures and table below.
|
||||
We measure the distribution shift caused by steering, especially the worst 5% of per-token KL that would cause a "crash", and find the largest scalar coefficient C that keeps that 5% below a safe threshold (default: 1 nat). At `alpha = 1` the steering is delivering exactly that calibrated dose. At `alpha = 2` it is doing twice the dose. *Iso-KL* means: comparing two methods at the same alpha = 1 means they are spending the same per-token KL budget, so any behavioural difference is not just one of them being louder.
|
||||
|
||||
## Glossary
|
||||
Then comes the survival question. If the calibration only looked at the first w tokens of the rollout, do crashes still happen later? Most trajectories either stabilize or go off track within the first 50-100 tokens, but we want to see the long tail too. So at every fork t we ask: can the model still produce one of two valid answer tokens (`true` or `false`) when forced into a JSON-schema prefill? If yes, *alive* at t. If the probability mass on those tokens drops below 0.95, *dead*, and dead stays dead. Survival S(t) is the fraction still alive at token t, with right-censoring for rollouts that hit EOS before t.
|
||||
|
||||
We deliberately reuse a few terms from survival analysis and steering. Definitions below are the operational ones used in this repo, not the textbook ones.
|
||||
Headline result on Qwen3.5-0.8B (`mean_diff`, w=512, n=8 held-out prompts):
|
||||
|
||||
- `alpha` in {0, 0.25, 0.5, 0.75, 1.0, 1.5}: 0 of 8 trajectories die at any fork.
|
||||
- `alpha = 2.0`: 8 of 8 die. Median death time t = 64.
|
||||
- `alpha = 4.0`: 8 of 8 die at t = 0.
|
||||
|
||||
A phase transition at roughly 2x the iso-KL coefficient. The calibrated dose has long-horizon headroom; doubling the dose is a slow crash; quadrupling it is an instant one. Figures and table below.
|
||||
|
||||
- *steering coefficient (alpha)*: scalar multiplier on a steering vector added to residual-stream activations at one layer. `alpha = 1.0` is the iso-KL-calibrated coefficient. `alpha = 2.0` is twice that. `alpha = 0` is the unsteered base model.
|
||||
- *iso-KL calibration*: bisection on alpha such that the 95th-percentile per-token `KL(steered || base)` over a w-token calibration window equals 1 nat. The output is `c_star`, the coefficient at `alpha = 1`. This makes different steering methods comparable: any two methods at `alpha = 1` are spending the same per-token KL budget.
|
||||
- *p95 KL*: 95th percentile of per-token `KL(steered || base)` over the calibration window. Robust to single-token spikes.
|
||||
- *window (w)*: length in tokens of the calibration rollout. `w = 512` for the figures here.
|
||||
- *pmass*: probability mass the model puts on the set `{true, false}` after we splice in a JSON schema prefill (`'\nI should answer now.</think>{"value": '`) at fork token t. A forced-choice probe: can the model still produce one of the two valid answer tokens, or has steering wrecked the format?
|
||||
- *pmass_eval*: same probe, on a held-out prompt set, evaluated at every fork t.
|
||||
- *fork point t*: the token in the rollout where we splice the schema and read pmass.
|
||||
- *death*: irreversible event. A trajectory is *alive* at fork t if `pmass_eval(t) >= 0.95`. The first time it drops below, it is dead and stays dead. The 0.95 threshold is where the JSON schema stops being a stable attractor; the model is producing other tokens.
|
||||
- *right-censoring*: the rollout hit EOS at length L < t, so we never see fork t. The trajectory drops out of the at-risk denominator from L onward. Incomplete information, not death.
|
||||
- *survival S(t)*: fraction of trajectories still alive at fork t, with right-censored trajectories removed from the at-risk set. Kaplan-Meier estimator.
|
||||
- *mean_diff*: steering vector built as the difference of mean activations on a contrastive prompt pair.
|
||||
- *spaghetti plot*: every individual KL trajectory drawn as one thin line, alive segments green, dead segments red, black line is the median.
|
||||
|
||||
## Spaghetti: per-token KL trajectories, coloured by survival
|
||||
|
||||
@@ -112,3 +104,21 @@ Out of scope:
|
||||
- *Pmass is a proxy.* It scores whether the model still produces one of two specific tokens after we splice in a schema. A model that has gone off-format but is otherwise coherent gets scored as dead.
|
||||
- *n = 8.* Held-out prompts, not held-out seeds. The phase-transition shape is consistent at this scale; the exact half-life at alpha = 2 is not.
|
||||
- *One model so far.* Gemma 4B, Gemma 12B (4-bit), OLMo-2 1B, and OLMo-3 7B at w = 4096 are queued. The phase-transition story may or may not survive scaling.
|
||||
|
||||
|
||||
## Glossary
|
||||
|
||||
Operational definitions used in this repo, not textbook ones. Skim and come back as needed.
|
||||
|
||||
- *steering coefficient (alpha)*: scalar multiplier on the steering vector added to residual-stream activations at one layer. `alpha = 1.0` is the iso-KL-calibrated coefficient C. `alpha = 0` is the unsteered base model.
|
||||
- *iso-KL calibration*: bisection on alpha such that the 95th-percentile per-token `KL(steered || base)` over a w-token calibration window equals 1 nat. Output is `c_star`, the coefficient at `alpha = 1`.
|
||||
- *p95 KL*: 95th percentile of per-token `KL(steered || base)` over the calibration window. The "worst 5%" measure used by calibration.
|
||||
- *window (w)*: length in tokens of the calibration rollout. `w = 512` for the figures here.
|
||||
- *pmass*: probability the model puts on the set `{true, false}` after we splice in a JSON schema prefill (`'\nI should answer now.</think>{"value": '`) at fork token t.
|
||||
- *pmass_eval*: pmass on a held-out prompt set, at every fork t.
|
||||
- *fork point t*: the token in the rollout where we splice the schema and read pmass.
|
||||
- *death*: irreversible. Alive at t means `pmass_eval(t) >= 0.95`. First time it drops below, dead, and stays dead. Below 0.95 the JSON schema stops being a stable attractor.
|
||||
- *right-censoring*: the rollout hit EOS at length L < t, so we never see fork t. Drops out of the at-risk denominator from L onward; not death.
|
||||
- *survival S(t)*: Kaplan-Meier estimator of the fraction alive at fork t.
|
||||
- *mean_diff*: steering vector built as the difference of mean activations on a contrastive prompt pair.
|
||||
- *spaghetti plot*: every individual KL trajectory drawn as one thin line, alive segments green, dead red, black line is the median.
|
||||
@@ -0,0 +1,90 @@
|
||||
{
|
||||
"panel_models": [
|
||||
{
|
||||
"model": "google/gemma-3-27b-it",
|
||||
"scores": {
|
||||
"jargon_defined": {
|
||||
"score": 2,
|
||||
"reason": "The glossary is comprehensive and provides operational definitions for all key terms, including those the researcher is likely familiar with (KL divergence, steering vectors) and the novel ones (pmass, death, right-censoring)."
|
||||
},
|
||||
"figure_interpretability": {
|
||||
"score": 2,
|
||||
"reason": "The figures are well-explained with numbered steps that guide the reader through the key takeaways. The captions clearly state what is being plotted and the axes are understandable. The 'spaghetti plot' and 'survival' plot explanations are particularly strong."
|
||||
},
|
||||
"headline_clarity": {
|
||||
"score": 2,
|
||||
"reason": "The headline result is clearly stated: a phase transition occurs around 2x the iso-KL coefficient. The supporting data (alpha values and death rates) are presented concisely and directly support the claim."
|
||||
},
|
||||
"caveats_visible": {
|
||||
"score": 2,
|
||||
"reason": "The 'Honest caveats' section explicitly outlines the limitations of the study, including the specific fairness criterion used (iso-KL), the proxy nature of pmass, the small sample size (n=8), and the single model tested. This demonstrates a responsible and transparent approach."
|
||||
}
|
||||
},
|
||||
"total": 8,
|
||||
"max": 8,
|
||||
"ratio": 1.0
|
||||
},
|
||||
{
|
||||
"model": "meta-llama/llama-3.3-70b-instruct",
|
||||
"scores": {
|
||||
"jargon_defined": {
|
||||
"score": 2,
|
||||
"reason": "The README provides a clear glossary that defines technical terms such as iso-KL calibration, steering coefficient, and pmass, making it easier for a researcher to understand the concepts."
|
||||
},
|
||||
"figure_interpretability": {
|
||||
"score": 2,
|
||||
"reason": "The README provides detailed explanations of how to read the spaghetti and survival plots, including descriptions of the colors, lines, and axes, making it easy to interpret the figures correctly."
|
||||
},
|
||||
"headline_clarity": {
|
||||
"score": 2,
|
||||
"reason": "The headline result is clearly stated, and the phase transition claim is easy to understand, with a clear description of the findings and their implications."
|
||||
},
|
||||
"caveats_visible": {
|
||||
"score": 2,
|
||||
"reason": "The README includes an 'Honest caveats' section that clearly outlines the limitations and potential biases of the study, including the use of a single fairness criterion, the proxy nature of pmass, and the limited sample size."
|
||||
}
|
||||
},
|
||||
"total": 8,
|
||||
"max": 8,
|
||||
"ratio": 1.0
|
||||
},
|
||||
{
|
||||
"model": "qwen/qwen3-235b-a22b-thinking-2507",
|
||||
"scores": {
|
||||
"jargon_defined": {
|
||||
"score": 2,
|
||||
"reason": "All specialized terms like iso-KL calibration, pmass, and death are operationally defined with clear connections to KL divergence and steering vectors, making them accessible to the target researcher."
|
||||
},
|
||||
"figure_interpretability": {
|
||||
"score": 2,
|
||||
"reason": "Both figures include explicit step-by-step reading guides that explain color coding, axes, and diagnostic patterns, enabling correct interpretation of survival transitions and KL trajectories."
|
||||
},
|
||||
"headline_clarity": {
|
||||
"score": 2,
|
||||
"reason": "The phase-transition claim (death at alpha=2.0 but not below) is stated upfront with precise numerical results and reinforced through figure explanations, leaving no ambiguity about the threshold behavior."
|
||||
},
|
||||
"caveats_visible": {
|
||||
"score": 2,
|
||||
"reason": "Key limitations (n=8, single model, pmass proxy nature, iso-KL's narrow scope) are highlighted in a dedicated 'Honest caveats' section with concrete explanations of their implications."
|
||||
}
|
||||
},
|
||||
"total": 8,
|
||||
"max": 8,
|
||||
"ratio": 1.0
|
||||
}
|
||||
],
|
||||
"summary": {
|
||||
"mean_ratio": 1.0,
|
||||
"verdict": "ready"
|
||||
},
|
||||
"prompt": "Evaluate this README as a technical post for a researcher who knows what KL divergence and steering vectors are but has never seen iso-KL calibration or this 'pmass' survival framing. Can they (a) understand what iso-KL calibration does, (b) read the spaghetti and survival figures correctly, (c) understand what 'death' and 'pmass' mean here, (d) come away with the headline phase-transition claim and its caveats?",
|
||||
"scores": [
|
||||
"jargon_defined",
|
||||
"figure_interpretability",
|
||||
"headline_clarity",
|
||||
"caveats_visible"
|
||||
],
|
||||
"docs": [
|
||||
"README.md"
|
||||
]
|
||||
}
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 46 KiB After Width: | Height: | Size: 52 KiB |
+1
@@ -1,5 +1,6 @@
|
||||
| metric | threshold | alpha | n | S_mid | S_end | t_S<=0.5 |
|
||||
|:-----------|------------:|--------:|----:|--------:|--------:|-----------:|
|
||||
| pmass_eval | 0.950 | 0.000 | 8 | 1.000 | 1.000 | |
|
||||
| pmass_eval | 0.950 | 0.250 | 8 | 1.000 | 1.000 | |
|
||||
| pmass_eval | 0.950 | 0.500 | 8 | 1.000 | 1.000 | |
|
||||
| pmass_eval | 0.950 | 0.750 | 8 | 1.000 | 1.000 | |
|
||||
|
||||
BIN
Binary file not shown.
|
Before Width: | Height: | Size: 46 KiB After Width: | Height: | Size: 52 KiB |
+13
-8
@@ -169,10 +169,11 @@ def survival_pmass(P: np.ndarray, fork: list[int], gen_lens: np.ndarray,
|
||||
def main(a: Args):
|
||||
root = Path(a.runs_root); out = Path(a.out); out.mkdir(parents=True, exist_ok=True)
|
||||
n_panels = len(a.thresholds)
|
||||
fig, axes = plt.subplots(1, n_panels, figsize=(4.6 * n_panels, 3.4),
|
||||
fig, axes = plt.subplots(1, n_panels, figsize=(5.6 * n_panels, 3.6),
|
||||
sharey=True, squeeze=False)
|
||||
cmap = plt.get_cmap("viridis")
|
||||
colors = {alpha: cmap(i / max(1, len(a.alphas) - 1)) for i, alpha in enumerate(a.alphas)}
|
||||
# categorical colors with strong contrast across alpha (avoid viridis dark cluster)
|
||||
palette = ["#000000", "#377eb8", "#4daf4a", "#984ea3", "#ff7f00", "#a65628", "#e41a1c", "#f781bf"]
|
||||
colors = {alpha: palette[i % len(palette)] for i, alpha in enumerate(a.alphas)}
|
||||
|
||||
rows_summary = []
|
||||
for j, thr in enumerate(a.thresholds):
|
||||
@@ -193,7 +194,9 @@ def main(a: Args):
|
||||
xlabel = "fork token t"
|
||||
else:
|
||||
raise SystemExit(f"unknown metric {a.metric!r}; use 'kl', 'pmass', or 'pmass_eval'")
|
||||
ax.step(xs, S, where="post", color=colors[alpha], lw=2.0,
|
||||
# tiny vertical jitter so overlapping S=1.0 lines remain individually visible
|
||||
jitter = 0.004 * (list(a.alphas).index(alpha) - (len(a.alphas) - 1) / 2.0)
|
||||
ax.step(xs, S + jitter, where="post", color=colors[alpha], lw=2.2, alpha=0.9,
|
||||
label=rf"$\alpha={alpha}$ (n={n})")
|
||||
below_half = np.where(S <= 0.5)[0]
|
||||
t50 = int(xs[below_half[0]]) if len(below_half) else None
|
||||
@@ -205,13 +208,15 @@ def main(a: Args):
|
||||
if n_panels > 1:
|
||||
ax.set_title(f"threshold = {thr:g}")
|
||||
ax.set_xlabel(xlabel)
|
||||
ax.set_ylim(-0.02, 1.05)
|
||||
ax.set_ylim(-0.04, 1.08)
|
||||
if j == 0:
|
||||
ax.set_ylabel("fraction of trajectories alive")
|
||||
ax.legend(loc="lower left", fontsize=8, frameon=False)
|
||||
ax.axvline(20, color="k", ls=":", lw=0.7)
|
||||
# legend outside on the right, never covers the axes
|
||||
ax.legend(loc="center left", bbox_to_anchor=(1.02, 0.5),
|
||||
fontsize=9, frameon=False, title=r"$\alpha$ (n)", title_fontsize=9)
|
||||
ax.axvline(20, color="k", ls=":", lw=0.7, alpha=0.5)
|
||||
|
||||
fig.suptitle(f"Survival, {a.model_contains}", fontsize=10)
|
||||
fig.suptitle(f"Survival, {a.model_contains}", fontsize=11)
|
||||
fig.tight_layout(rect=(0, 0, 1, 0.94))
|
||||
out_p = out / f"survival_{a.metric}.png"
|
||||
fig.savefig(out_p, dpi=160, bbox_inches="tight")
|
||||
|
||||
Reference in New Issue
Block a user