mirror of
https://github.com/wassname/isokl_steering_calibration.git
synced 2026-06-27 17:01:10 +08:00
feat(spaghetti): per-trajectory KL/p95 normalization + shared y-axis
- Add --normalize-kl / --calib-tokens (default on, 20 tokens) to spaghetti_kl_alive.py: each trajectory is divided by its own p95(KL[:calib_tokens]). Calibration target collapses to y=1. - Share y-axis across all alpha panels (global p99 ymax) for direct cross-alpha comparison. - Add OLMo-2 1B w=4096 figure to figs/ and a README section documenting the result, including the unexplained 'random dead traces at low alpha' observation.
This commit is contained in:
@@ -98,6 +98,44 @@ Out of scope:
|
||||
- A norm-matching baseline. Iso-KL says nothing about whether two methods of equal KL move the same distance in activation space.
|
||||
- A threshold sweep. The pmass < 0.95 cutoff is a single point; neighbours might disagree.
|
||||
|
||||
## Per-trajectory KL normalization (OLMo-2 1B, w=4096)
|
||||
|
||||
The figures above use raw per-token KL in nats. That works when every panel shares an alpha-scale, but cross-alpha comparison gets harder once trajectories vary in their own baseline KL noise (some prompts are just spikier than others, even at alpha = 1).
|
||||
|
||||
So we re-plot KL divided by each trajectory's own p95 over the first 20 tokens. The dashed reference at y = 1 is then the calibration scale by construction. Values above 1 are excursions above that trajectory's own early-rollout budget. The y-axis is shared across panels so the alpha sweep reads as one picture.
|
||||
|
||||

|
||||
|
||||
*Per-token `KL(steered || base) / p95(KL[:20])` for OLMo-2-0425-1B (mean_diff, w=4096, n=24 held-out prompts). One panel per alpha. Blue = alive at that token (`pmass_eval >= 0.75` at the nearest fork s <= t); orange = dead (irreversible). Black tick = EOS (right-censored). Dashed = calibration scale = 1. Black line = median across alive+dead trajectories. Threshold here is 0.75, looser than the 0.95 used for Qwen, because OLMo's early pmass is noisier.*
|
||||
|
||||
How to read it, in order:
|
||||
|
||||
1. *alpha = 0.* No steering, KL = 0, denom = 0, all trajectories skipped. Empty panel is correct.
|
||||
2. *alpha = 0.25 to 0.75.* Median sits well below 1. Most trajectories never approach the calibration scale even past 4000 tokens. The 1-nat budget set on a w = 4096 calibration window has long-horizon headroom for OLMo-2 1B.
|
||||
3. *alpha = 1.0.* Median hovers a bit under 1 with one outlier traveller above 2. Iso-KL is doing what it claims at the trajectory-median level on a different prompt set than calibration saw.
|
||||
4. *alpha = 1.5 to 2.0.* Median crosses 1 and stays there. Most trajectories die. KL is now sustainably over budget.
|
||||
5. *alpha = 4.0.* Every trajectory dies. Median rides above 1 throughout. Same instant-collapse story as the Qwen panel.
|
||||
|
||||
What surprised me, and what doesn't yet have an explanation:
|
||||
|
||||
- *Random dead traces at low alpha.* At alpha = 0.25 to 0.75, KL stays comfortably under 1 for almost every token, and yet 7 to 12 of 24 trajectories register a death event at some fork. They are not dying because the budget was exceeded; they are dying for some other reason that p95 KL does not see. Possibilities: a single high-KL spike at one token that p95 averages away; a low-KL but format-fragile region where small logit shifts knock pmass below 0.75; calibration noise on the pmass side (threshold 0.75 is loose, and OLMo's base pmass at some forks is already close to it). I don't know which yet.
|
||||
- *alpha = 1.0 has fewer deaths than alpha = 0.75 (11 vs 12).* Within noise at n = 24, but the monotonicity of "more steering = more death" is not crisp at this scale.
|
||||
- *Censoring drops to ~0 above alpha = 0.75.* Steering past the calibrated dose seems to suppress EOS, so trajectories run to the window length even when off-format. That makes died counts more comparable across high-alpha panels but means low-alpha "censored" is partly the model's own brevity, not the steering's.
|
||||
|
||||
So: per-trajectory normalization makes the cross-alpha comparison cleaner, and confirms iso-KL is roughly delivering its claimed dose on held-out prompts. It also makes the "KL is necessary but not sufficient" point sharper: the low-alpha panels show death events at well under the calibration scale.
|
||||
|
||||
Reproduce:
|
||||
|
||||
```bash
|
||||
uv run --extra all python scripts/spaghetti_kl_alive.py \
|
||||
--runs-root outputs/olmo1b_w4096_dense/_OLMo-2-0425-1B_mean_diff_s0_w4096_dense_n24_single \
|
||||
--out outputs/olmo1b_w4096_dense/OLMo-2-0425-1B_mean_diff_s0_w4096_dense_n24/figs_thr075_quantile_shaded/spaghetti_norm \
|
||||
--window 4096 --threshold 0.75 --model-contains OLMo-2-0425-1B \
|
||||
--roll 301 --line-lw 0.4
|
||||
```
|
||||
|
||||
Toggle off normalization with `--no-normalize-kl`; change the calibration window with `--calib-tokens N`.
|
||||
|
||||
## Honest caveats
|
||||
|
||||
- *Iso-KL is one fairness criterion, not the criterion.* Matching p95 per-token KL means matching distributional disagreement under greedy decoding in the calibration window. It does not match intervention norm, behavioural effect size, or human-perceived quality.
|
||||
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 401 KiB |
@@ -37,13 +37,17 @@ from matplotlib.colors import to_rgba
|
||||
|
||||
try:
|
||||
import seaborn as sns
|
||||
sns.set_theme(context="notebook", style="whitegrid", palette="deep", font_scale=0.9)
|
||||
sns.set_theme(context="notebook", style="white", palette="deep", font_scale=0.9)
|
||||
plt.rcParams.update({
|
||||
"axes.titlesize": 10, "axes.labelsize": 9,
|
||||
"axes.spines.top": False, "axes.spines.right": False,
|
||||
"axes.grid": False,
|
||||
"figure.facecolor": "#faf7f2",
|
||||
"axes.facecolor": "#faf7f2",
|
||||
"savefig.facecolor": "#faf7f2",
|
||||
})
|
||||
except Exception:
|
||||
plt.style.use("ggplot")
|
||||
plt.style.use("default")
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -56,7 +60,11 @@ class Args:
|
||||
metric: str = "pmass_eval" # 'pmass_eval' is paired with KL prompts
|
||||
model_contains: str = "Qwen3.5-0.8B"
|
||||
kl_log: bool = False
|
||||
roll: int = 11 # smooth KL a bit so the spaghetti is readable
|
||||
roll: int = 65 # smooth KL a lot so the spaghetti is readable
|
||||
line_lw: float = 0.5 # per-trajectory linewidth -- thick enough to see
|
||||
line_alpha: float | None = None # if None, scale by cohort size: clamp(8/n, .2, .7)
|
||||
normalize_kl: bool = True # divide each KL trajectory by p95(KL[:calib_tokens])
|
||||
calib_tokens: int = 20 # window over which to compute the per-traj p95 denom
|
||||
|
||||
|
||||
def load_cell(d: Path, alpha: str, T: int):
|
||||
@@ -136,16 +144,28 @@ def main(a: Args):
|
||||
n_panels = len(a.alphas)
|
||||
fig, axes = plt.subplots(1, n_panels, figsize=(3.0 * n_panels, 3.4),
|
||||
sharey=False, squeeze=False)
|
||||
color_alive = to_rgba("#1a9850", 0.65) # green, translucent so overlap stays readable
|
||||
color_dead = to_rgba("#d7191c", 0.55) # stronger red so all-dead panels are visible
|
||||
# alpha-per-line is computed per panel from the cohort size so n=32 is not a paintbrush
|
||||
# paper-friendly Wong-ish pair: muted teal-blue (alive) vs warm orange (dead).
|
||||
# colourblind-safe, prints well, doesn't scream "stop sign".
|
||||
GREEN = "#3a7ca5" # steel teal -- "alive / in budget"
|
||||
RED = "#d97706" # warm amber -- "dead / out of budget"
|
||||
|
||||
summary_rows = []
|
||||
panel_state: list[dict] = [] # collected per-panel info; ylim applied after loop
|
||||
all_visible_kls: list[float] = []
|
||||
for j, alpha in enumerate(a.alphas):
|
||||
ax = axes[0, j]
|
||||
all_trajs = []
|
||||
for d in cells:
|
||||
all_trajs.extend(load_cell(d, alpha, a.window))
|
||||
n = len(all_trajs)
|
||||
# compute per-line alpha for this panel (more lines -> more transparent)
|
||||
if a.line_alpha is not None:
|
||||
la = float(a.line_alpha)
|
||||
else:
|
||||
la = float(np.clip(8.0 / max(n, 1), 0.20, 0.70))
|
||||
color_alive = to_rgba(GREEN, la)
|
||||
color_dead = to_rgba(RED, la)
|
||||
n_died = 0
|
||||
n_censored = 0
|
||||
median_rows = []
|
||||
@@ -160,6 +180,15 @@ def main(a: Args):
|
||||
T = min(len(kl), gl)
|
||||
max_gen_len = max(max_gen_len, T)
|
||||
kl = _rolling_mean(kl[:T], a.roll)
|
||||
if a.normalize_kl:
|
||||
head = kl[: min(a.calib_tokens, len(kl))]
|
||||
head = head[np.isfinite(head)]
|
||||
denom = float(np.nanpercentile(head, 95)) if head.size else float("nan")
|
||||
if np.isfinite(denom) and denom > 1e-8:
|
||||
kl = kl / denom
|
||||
else:
|
||||
# no usable scale (e.g. alpha=0 -> KL==0); skip this traj
|
||||
continue
|
||||
visible_kls.extend([float(x) for x in kl if np.isfinite(x)])
|
||||
median_rows.append(np.pad(kl, (0, max(0, a.window - T)), constant_values=np.nan)[: a.window])
|
||||
alive = alive_mask_for_t(pmv, fork, T, a.threshold, gl)
|
||||
@@ -181,9 +210,9 @@ def main(a: Args):
|
||||
alpha=0.6, zorder=3)
|
||||
|
||||
if dead_segments:
|
||||
ax.add_collection(LineCollection(dead_segments, colors=dead_colors, linewidths=0.9, zorder=1))
|
||||
ax.add_collection(LineCollection(dead_segments, colors=dead_colors, linewidths=a.line_lw, zorder=1, capstyle="round", joinstyle="round", antialiaseds=True))
|
||||
if alive_segments:
|
||||
ax.add_collection(LineCollection(alive_segments, colors=alive_colors, linewidths=0.9, zorder=2))
|
||||
ax.add_collection(LineCollection(alive_segments, colors=alive_colors, linewidths=a.line_lw, zorder=2, capstyle="round", joinstyle="round", antialiaseds=True))
|
||||
if median_rows:
|
||||
mat = np.asarray(median_rows)
|
||||
finite_cols = np.where(np.isfinite(mat).any(axis=0))[0]
|
||||
@@ -195,30 +224,52 @@ def main(a: Args):
|
||||
ax.plot(med_x, med, color="black", lw=1.3, alpha=0.9, zorder=4)
|
||||
|
||||
y_hi = _kl_ymax(visible_kls)
|
||||
if y_hi >= 1.0:
|
||||
ax.axhline(1.0, color="black", lw=0.7, ls=":", label="KL=1 calib target")
|
||||
all_visible_kls.extend(visible_kls)
|
||||
# calib target: at alpha=1, KL p95 = 1 nat by construction; at general alpha,
|
||||
# expected KL ~ alpha^2 (small-step quadratic). Show as horizontal reference.
|
||||
# If normalize_kl: each traj is divided by its own p95(KL[:calib_tokens]),
|
||||
# so the target collapses to y=1 by construction (independent of alpha).
|
||||
if a.normalize_kl:
|
||||
target = 1.0
|
||||
target_label = rf"calib p95 (tokens<{a.calib_tokens}) = 1"
|
||||
else:
|
||||
ax.text(0.98, 0.92, "KL=1 off-scale", transform=ax.transAxes,
|
||||
ha="right", va="top", fontsize=7, color="0.25")
|
||||
try:
|
||||
target = float(alpha) ** 2
|
||||
except ValueError:
|
||||
target = float("nan")
|
||||
target_label = rf"calib target $\alpha^2$={target:.2g}"
|
||||
# axhline added unconditionally; off-scale handled later when global ylim is known
|
||||
ax.axhline(target, color="black", lw=0.9, ls="--", label=target_label) if np.isfinite(target) and target > 0 else None
|
||||
ax.set_title(rf"$\alpha={alpha}$ n={n} died={n_died} cens={n_censored}")
|
||||
ax.set_xlabel("token t")
|
||||
if j == 0:
|
||||
ax.set_ylabel("per-token KL")
|
||||
ax.set_ylabel("KL / p95(KL[:%d])" % a.calib_tokens if a.normalize_kl else "per-token KL")
|
||||
if a.kl_log:
|
||||
ax.set_yscale("symlog", linthresh=0.1)
|
||||
ax.set_xlim(-1, _panel_xmax(max_gen_len, a.window))
|
||||
ax.set_ylim(-y_hi * 0.05, y_hi)
|
||||
# data-driven y-lim sanity
|
||||
# rely on auto
|
||||
panel_state.append({"ax": ax, "target": target, "target_label": target_label})
|
||||
summary_rows.append({"alpha": alpha, "n": n, "n_died": n_died, "n_censored": n_censored})
|
||||
|
||||
# shared y-axis across all panels for direct cross-alpha comparison
|
||||
y_hi_global = _kl_ymax(all_visible_kls)
|
||||
for ps in panel_state:
|
||||
ax = ps["ax"]
|
||||
ax.set_ylim(-y_hi_global * 0.05, y_hi_global)
|
||||
target = ps["target"]
|
||||
if np.isfinite(target) and target > 0 and target > y_hi_global:
|
||||
ax.text(0.98, 0.92, f"{ps['target_label']} off-scale",
|
||||
transform=ax.transAxes, ha="right", va="top", fontsize=7, color="0.25")
|
||||
|
||||
|
||||
# legend on first panel only
|
||||
from matplotlib.lines import Line2D
|
||||
handles = [
|
||||
Line2D([0],[0], color=color_alive, lw=2, label=f"alive (pmass >= {a.threshold})"),
|
||||
Line2D([0],[0], color=color_dead, lw=2, label=f"dead (pmass < {a.threshold})"),
|
||||
Line2D([0],[0], color="black", marker="|", lw=0, label="EOS (right-censored)"),
|
||||
Line2D([0],[0], color="black", lw=0.7, ls=":", label="KL=1 target if in view"),
|
||||
Line2D([0],[0], color="black", lw=0.9, ls="--",
|
||||
label=(rf"calib p95 (tokens<{a.calib_tokens}) = 1" if a.normalize_kl
|
||||
else r"calib target $\alpha^2$")),
|
||||
]
|
||||
axes[0, 0].legend(handles=handles, loc="upper right", fontsize=7, frameon=True)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user