feat(spaghetti): per-trajectory KL/p95 normalization + shared y-axis

- Add --normalize-kl / --calib-tokens (default on, 20 tokens) to
  spaghetti_kl_alive.py: each trajectory is divided by its own
  p95(KL[:calib_tokens]). Calibration target collapses to y=1.
- Share y-axis across all alpha panels (global p99 ymax) for
  direct cross-alpha comparison.
- Add OLMo-2 1B w=4096 figure to figs/ and a README section
  documenting the result, including the unexplained 'random
  dead traces at low alpha' observation.
This commit is contained in:
wassname
2026-05-08 08:17:21 +08:00
parent 7d3fd37743
commit dbff43cffb
3 changed files with 105 additions and 16 deletions
+38
View File
@@ -98,6 +98,44 @@ Out of scope:
- A norm-matching baseline. Iso-KL says nothing about whether two methods of equal KL move the same distance in activation space. - A norm-matching baseline. Iso-KL says nothing about whether two methods of equal KL move the same distance in activation space.
- A threshold sweep. The pmass < 0.95 cutoff is a single point; neighbours might disagree. - A threshold sweep. The pmass < 0.95 cutoff is a single point; neighbours might disagree.
## Per-trajectory KL normalization (OLMo-2 1B, w=4096)
The figures above use raw per-token KL in nats. That works when every panel shares an alpha-scale, but cross-alpha comparison gets harder once trajectories vary in their own baseline KL noise (some prompts are just spikier than others, even at alpha = 1).
So we re-plot KL divided by each trajectory's own p95 over the first 20 tokens. The dashed reference at y = 1 is then the calibration scale by construction. Values above 1 are excursions above that trajectory's own early-rollout budget. The y-axis is shared across panels so the alpha sweep reads as one picture.
![KL/p95 trajectories, OLMo-2 1B, w=4096, n=24](figs/spaghetti_olmo2_1b_w4096_norm.png)
*Per-token `KL(steered || base) / p95(KL[:20])` for OLMo-2-0425-1B (mean_diff, w=4096, n=24 held-out prompts). One panel per alpha. Blue = alive at that token (`pmass_eval >= 0.75` at the nearest fork s <= t); orange = dead (irreversible). Black tick = EOS (right-censored). Dashed = calibration scale = 1. Black line = median across alive+dead trajectories. Threshold here is 0.75, looser than the 0.95 used for Qwen, because OLMo's early pmass is noisier.*
How to read it, in order:
1. *alpha = 0.* No steering, KL = 0, denom = 0, all trajectories skipped. Empty panel is correct.
2. *alpha = 0.25 to 0.75.* Median sits well below 1. Most trajectories never approach the calibration scale even past 4000 tokens. The 1-nat budget set on a w = 4096 calibration window has long-horizon headroom for OLMo-2 1B.
3. *alpha = 1.0.* Median hovers a bit under 1 with one outlier traveller above 2. Iso-KL is doing what it claims at the trajectory-median level on a different prompt set than calibration saw.
4. *alpha = 1.5 to 2.0.* Median crosses 1 and stays there. Most trajectories die. KL is now sustainably over budget.
5. *alpha = 4.0.* Every trajectory dies. Median rides above 1 throughout. Same instant-collapse story as the Qwen panel.
What surprised me, and what doesn't yet have an explanation:
- *Random dead traces at low alpha.* At alpha = 0.25 to 0.75, KL stays comfortably under 1 for almost every token, and yet 7 to 12 of 24 trajectories register a death event at some fork. They are not dying because the budget was exceeded; they are dying for some other reason that p95 KL does not see. Possibilities: a single high-KL spike at one token that p95 averages away; a low-KL but format-fragile region where small logit shifts knock pmass below 0.75; calibration noise on the pmass side (threshold 0.75 is loose, and OLMo's base pmass at some forks is already close to it). I don't know which yet.
- *alpha = 1.0 has fewer deaths than alpha = 0.75 (11 vs 12).* Within noise at n = 24, but the monotonicity of "more steering = more death" is not crisp at this scale.
- *Censoring drops to ~0 above alpha = 0.75.* Steering past the calibrated dose seems to suppress EOS, so trajectories run to the window length even when off-format. That makes died counts more comparable across high-alpha panels but means low-alpha "censored" is partly the model's own brevity, not the steering's.
So: per-trajectory normalization makes the cross-alpha comparison cleaner, and confirms iso-KL is roughly delivering its claimed dose on held-out prompts. It also makes the "KL is necessary but not sufficient" point sharper: the low-alpha panels show death events at well under the calibration scale.
Reproduce:
```bash
uv run --extra all python scripts/spaghetti_kl_alive.py \
--runs-root outputs/olmo1b_w4096_dense/_OLMo-2-0425-1B_mean_diff_s0_w4096_dense_n24_single \
--out outputs/olmo1b_w4096_dense/OLMo-2-0425-1B_mean_diff_s0_w4096_dense_n24/figs_thr075_quantile_shaded/spaghetti_norm \
--window 4096 --threshold 0.75 --model-contains OLMo-2-0425-1B \
--roll 301 --line-lw 0.4
```
Toggle off normalization with `--no-normalize-kl`; change the calibration window with `--calib-tokens N`.
## Honest caveats ## Honest caveats
- *Iso-KL is one fairness criterion, not the criterion.* Matching p95 per-token KL means matching distributional disagreement under greedy decoding in the calibration window. It does not match intervention norm, behavioural effect size, or human-perceived quality. - *Iso-KL is one fairness criterion, not the criterion.* Matching p95 per-token KL means matching distributional disagreement under greedy decoding in the calibration window. It does not match intervention norm, behavioural effect size, or human-perceived quality.
Binary file not shown.

After

Width:  |  Height:  |  Size: 401 KiB

+67 -16
View File
@@ -37,13 +37,17 @@ from matplotlib.colors import to_rgba
try: try:
import seaborn as sns import seaborn as sns
sns.set_theme(context="notebook", style="whitegrid", palette="deep", font_scale=0.9) sns.set_theme(context="notebook", style="white", palette="deep", font_scale=0.9)
plt.rcParams.update({ plt.rcParams.update({
"axes.titlesize": 10, "axes.labelsize": 9, "axes.titlesize": 10, "axes.labelsize": 9,
"axes.spines.top": False, "axes.spines.right": False, "axes.spines.top": False, "axes.spines.right": False,
"axes.grid": False,
"figure.facecolor": "#faf7f2",
"axes.facecolor": "#faf7f2",
"savefig.facecolor": "#faf7f2",
}) })
except Exception: except Exception:
plt.style.use("ggplot") plt.style.use("default")
@dataclass @dataclass
@@ -56,7 +60,11 @@ class Args:
metric: str = "pmass_eval" # 'pmass_eval' is paired with KL prompts metric: str = "pmass_eval" # 'pmass_eval' is paired with KL prompts
model_contains: str = "Qwen3.5-0.8B" model_contains: str = "Qwen3.5-0.8B"
kl_log: bool = False kl_log: bool = False
roll: int = 11 # smooth KL a bit so the spaghetti is readable roll: int = 65 # smooth KL a lot so the spaghetti is readable
line_lw: float = 0.5 # per-trajectory linewidth -- thick enough to see
line_alpha: float | None = None # if None, scale by cohort size: clamp(8/n, .2, .7)
normalize_kl: bool = True # divide each KL trajectory by p95(KL[:calib_tokens])
calib_tokens: int = 20 # window over which to compute the per-traj p95 denom
def load_cell(d: Path, alpha: str, T: int): def load_cell(d: Path, alpha: str, T: int):
@@ -136,16 +144,28 @@ def main(a: Args):
n_panels = len(a.alphas) n_panels = len(a.alphas)
fig, axes = plt.subplots(1, n_panels, figsize=(3.0 * n_panels, 3.4), fig, axes = plt.subplots(1, n_panels, figsize=(3.0 * n_panels, 3.4),
sharey=False, squeeze=False) sharey=False, squeeze=False)
color_alive = to_rgba("#1a9850", 0.65) # green, translucent so overlap stays readable # alpha-per-line is computed per panel from the cohort size so n=32 is not a paintbrush
color_dead = to_rgba("#d7191c", 0.55) # stronger red so all-dead panels are visible # paper-friendly Wong-ish pair: muted teal-blue (alive) vs warm orange (dead).
# colourblind-safe, prints well, doesn't scream "stop sign".
GREEN = "#3a7ca5" # steel teal -- "alive / in budget"
RED = "#d97706" # warm amber -- "dead / out of budget"
summary_rows = [] summary_rows = []
panel_state: list[dict] = [] # collected per-panel info; ylim applied after loop
all_visible_kls: list[float] = []
for j, alpha in enumerate(a.alphas): for j, alpha in enumerate(a.alphas):
ax = axes[0, j] ax = axes[0, j]
all_trajs = [] all_trajs = []
for d in cells: for d in cells:
all_trajs.extend(load_cell(d, alpha, a.window)) all_trajs.extend(load_cell(d, alpha, a.window))
n = len(all_trajs) n = len(all_trajs)
# compute per-line alpha for this panel (more lines -> more transparent)
if a.line_alpha is not None:
la = float(a.line_alpha)
else:
la = float(np.clip(8.0 / max(n, 1), 0.20, 0.70))
color_alive = to_rgba(GREEN, la)
color_dead = to_rgba(RED, la)
n_died = 0 n_died = 0
n_censored = 0 n_censored = 0
median_rows = [] median_rows = []
@@ -160,6 +180,15 @@ def main(a: Args):
T = min(len(kl), gl) T = min(len(kl), gl)
max_gen_len = max(max_gen_len, T) max_gen_len = max(max_gen_len, T)
kl = _rolling_mean(kl[:T], a.roll) kl = _rolling_mean(kl[:T], a.roll)
if a.normalize_kl:
head = kl[: min(a.calib_tokens, len(kl))]
head = head[np.isfinite(head)]
denom = float(np.nanpercentile(head, 95)) if head.size else float("nan")
if np.isfinite(denom) and denom > 1e-8:
kl = kl / denom
else:
# no usable scale (e.g. alpha=0 -> KL==0); skip this traj
continue
visible_kls.extend([float(x) for x in kl if np.isfinite(x)]) visible_kls.extend([float(x) for x in kl if np.isfinite(x)])
median_rows.append(np.pad(kl, (0, max(0, a.window - T)), constant_values=np.nan)[: a.window]) median_rows.append(np.pad(kl, (0, max(0, a.window - T)), constant_values=np.nan)[: a.window])
alive = alive_mask_for_t(pmv, fork, T, a.threshold, gl) alive = alive_mask_for_t(pmv, fork, T, a.threshold, gl)
@@ -181,9 +210,9 @@ def main(a: Args):
alpha=0.6, zorder=3) alpha=0.6, zorder=3)
if dead_segments: if dead_segments:
ax.add_collection(LineCollection(dead_segments, colors=dead_colors, linewidths=0.9, zorder=1)) ax.add_collection(LineCollection(dead_segments, colors=dead_colors, linewidths=a.line_lw, zorder=1, capstyle="round", joinstyle="round", antialiaseds=True))
if alive_segments: if alive_segments:
ax.add_collection(LineCollection(alive_segments, colors=alive_colors, linewidths=0.9, zorder=2)) ax.add_collection(LineCollection(alive_segments, colors=alive_colors, linewidths=a.line_lw, zorder=2, capstyle="round", joinstyle="round", antialiaseds=True))
if median_rows: if median_rows:
mat = np.asarray(median_rows) mat = np.asarray(median_rows)
finite_cols = np.where(np.isfinite(mat).any(axis=0))[0] finite_cols = np.where(np.isfinite(mat).any(axis=0))[0]
@@ -195,30 +224,52 @@ def main(a: Args):
ax.plot(med_x, med, color="black", lw=1.3, alpha=0.9, zorder=4) ax.plot(med_x, med, color="black", lw=1.3, alpha=0.9, zorder=4)
y_hi = _kl_ymax(visible_kls) y_hi = _kl_ymax(visible_kls)
if y_hi >= 1.0: all_visible_kls.extend(visible_kls)
ax.axhline(1.0, color="black", lw=0.7, ls=":", label="KL=1 calib target") # calib target: at alpha=1, KL p95 = 1 nat by construction; at general alpha,
# expected KL ~ alpha^2 (small-step quadratic). Show as horizontal reference.
# If normalize_kl: each traj is divided by its own p95(KL[:calib_tokens]),
# so the target collapses to y=1 by construction (independent of alpha).
if a.normalize_kl:
target = 1.0
target_label = rf"calib p95 (tokens<{a.calib_tokens}) = 1"
else: else:
ax.text(0.98, 0.92, "KL=1 off-scale", transform=ax.transAxes, try:
ha="right", va="top", fontsize=7, color="0.25") target = float(alpha) ** 2
except ValueError:
target = float("nan")
target_label = rf"calib target $\alpha^2$={target:.2g}"
# axhline added unconditionally; off-scale handled later when global ylim is known
ax.axhline(target, color="black", lw=0.9, ls="--", label=target_label) if np.isfinite(target) and target > 0 else None
ax.set_title(rf"$\alpha={alpha}$ n={n} died={n_died} cens={n_censored}") ax.set_title(rf"$\alpha={alpha}$ n={n} died={n_died} cens={n_censored}")
ax.set_xlabel("token t") ax.set_xlabel("token t")
if j == 0: if j == 0:
ax.set_ylabel("per-token KL") ax.set_ylabel("KL / p95(KL[:%d])" % a.calib_tokens if a.normalize_kl else "per-token KL")
if a.kl_log: if a.kl_log:
ax.set_yscale("symlog", linthresh=0.1) ax.set_yscale("symlog", linthresh=0.1)
ax.set_xlim(-1, _panel_xmax(max_gen_len, a.window)) ax.set_xlim(-1, _panel_xmax(max_gen_len, a.window))
ax.set_ylim(-y_hi * 0.05, y_hi) panel_state.append({"ax": ax, "target": target, "target_label": target_label})
# data-driven y-lim sanity
# rely on auto
summary_rows.append({"alpha": alpha, "n": n, "n_died": n_died, "n_censored": n_censored}) summary_rows.append({"alpha": alpha, "n": n, "n_died": n_died, "n_censored": n_censored})
# shared y-axis across all panels for direct cross-alpha comparison
y_hi_global = _kl_ymax(all_visible_kls)
for ps in panel_state:
ax = ps["ax"]
ax.set_ylim(-y_hi_global * 0.05, y_hi_global)
target = ps["target"]
if np.isfinite(target) and target > 0 and target > y_hi_global:
ax.text(0.98, 0.92, f"{ps['target_label']} off-scale",
transform=ax.transAxes, ha="right", va="top", fontsize=7, color="0.25")
# legend on first panel only # legend on first panel only
from matplotlib.lines import Line2D from matplotlib.lines import Line2D
handles = [ handles = [
Line2D([0],[0], color=color_alive, lw=2, label=f"alive (pmass >= {a.threshold})"), Line2D([0],[0], color=color_alive, lw=2, label=f"alive (pmass >= {a.threshold})"),
Line2D([0],[0], color=color_dead, lw=2, label=f"dead (pmass < {a.threshold})"), Line2D([0],[0], color=color_dead, lw=2, label=f"dead (pmass < {a.threshold})"),
Line2D([0],[0], color="black", marker="|", lw=0, label="EOS (right-censored)"), Line2D([0],[0], color="black", marker="|", lw=0, label="EOS (right-censored)"),
Line2D([0],[0], color="black", lw=0.7, ls=":", label="KL=1 target if in view"), Line2D([0],[0], color="black", lw=0.9, ls="--",
label=(rf"calib p95 (tokens<{a.calib_tokens}) = 1" if a.normalize_kl
else r"calib target $\alpha^2$")),
] ]
axes[0, 0].legend(handles=handles, loc="upper right", fontsize=7, frameon=True) axes[0, 0].legend(handles=handles, loc="upper right", fontsize=7, frameon=True)