diff --git a/README.md b/README.md
index 96c2fea..da8f2a5 100644
--- a/README.md
+++ b/README.md
@@ -98,6 +98,44 @@ Out of scope:
 - A norm-matching baseline. Iso-KL says nothing about whether two methods of equal KL move the same distance in activation space.
 - A threshold sweep. The pmass < 0.95 cutoff is a single point; neighbours might disagree.
 
+## Per-trajectory KL normalization (OLMo-2 1B, w=4096)
+
+The figures above use raw per-token KL in nats. That works when every panel shares an alpha-scale, but cross-alpha comparison gets harder once trajectories vary in their own baseline KL noise (some prompts are just spikier than others, even at alpha = 1).
+
+So we re-plot KL divided by each trajectory's own p95 over the first 20 tokens. The dashed reference at y = 1 is then the calibration scale by construction. Values above 1 are excursions above that trajectory's own early-rollout budget. The y-axis is shared across panels so the alpha sweep reads as one picture.
+
+![KL/p95 trajectories, OLMo-2 1B, w=4096, n=24](figs/spaghetti_olmo2_1b_w4096_norm.png)
+
+*Per-token `KL(steered || base) / p95(KL[:20])` for OLMo-2-0425-1B (mean_diff, w=4096, n=24 held-out prompts). One panel per alpha. Blue = alive at that token (`pmass_eval >= 0.75` at the nearest fork s <= t); orange = dead (irreversible). Black tick = EOS (right-censored). Dashed = calibration scale = 1. Black line = median across alive+dead trajectories. Threshold here is 0.75, looser than the 0.95 used for Qwen, because OLMo's early pmass is noisier.*
+
+How to read it, in order:
+
+1. *alpha = 0.* No steering, KL = 0, denom = 0, all trajectories skipped. Empty panel is correct.
+2. *alpha = 0.25 to 0.75.* Median sits well below 1. Most trajectories never approach the calibration scale even past 4000 tokens. The 1-nat budget set on a w = 4096 calibration window has long-horizon headroom for OLMo-2 1B.
+3. *alpha = 1.0.* Median hovers a bit under 1 with one outlier traveller above 2. Iso-KL is doing what it claims at the trajectory-median level on a different prompt set than calibration saw.
+4. *alpha = 1.5 to 2.0.* Median crosses 1 and stays there. Most trajectories die. KL is now sustainably over budget.
+5. *alpha = 4.0.* Every trajectory dies. Median rides above 1 throughout. Same instant-collapse story as the Qwen panel.
+
+What surprised me, and what doesn't yet have an explanation:
+
+- *Random dead traces at low alpha.* At alpha = 0.25 to 0.75, KL stays comfortably under 1 for almost every token, and yet 7 to 12 of 24 trajectories register a death event at some fork. They are not dying because the budget was exceeded; they are dying for some other reason that p95 KL does not see. Possibilities: a single high-KL spike at one token that p95 averages away; a low-KL but format-fragile region where small logit shifts knock pmass below 0.75; calibration noise on the pmass side (threshold 0.75 is loose, and OLMo's base pmass at some forks is already close to it). I don't know which yet.
+- *alpha = 1.0 has fewer deaths than alpha = 0.75 (11 vs 12).* Within noise at n = 24, but the monotonicity of "more steering = more death" is not crisp at this scale.
+- *Censoring drops to ~0 above alpha = 0.75.* Steering past the calibrated dose seems to suppress EOS, so trajectories run to the window length even when off-format. That makes died counts more comparable across high-alpha panels but means low-alpha "censored" is partly the model's own brevity, not the steering's.
+
+So: per-trajectory normalization makes the cross-alpha comparison cleaner, and confirms iso-KL is roughly delivering its claimed dose on held-out prompts. It also makes the "KL is necessary but not sufficient" point sharper: the low-alpha panels show death events at well under the calibration scale.
+
+Reproduce:
+
+```bash
+uv run --extra all python scripts/spaghetti_kl_alive.py \
+  --runs-root outputs/olmo1b_w4096_dense/_OLMo-2-0425-1B_mean_diff_s0_w4096_dense_n24_single \
+  --out outputs/olmo1b_w4096_dense/OLMo-2-0425-1B_mean_diff_s0_w4096_dense_n24/figs_thr075_quantile_shaded/spaghetti_norm \
+  --window 4096 --threshold 0.75 --model-contains OLMo-2-0425-1B \
+  --roll 301 --line-lw 0.4
+```
+
+Toggle off normalization with `--no-normalize-kl`; change the calibration window with `--calib-tokens N`.
+
 ## Honest caveats
 
 - *Iso-KL is one fairness criterion, not the criterion.* Matching p95 per-token KL means matching distributional disagreement under greedy decoding in the calibration window. It does not match intervention norm, behavioural effect size, or human-perceived quality.
diff --git a/figs/spaghetti_olmo2_1b_w4096_norm.png b/figs/spaghetti_olmo2_1b_w4096_norm.png
new file mode 100644
index 0000000..3377a38
Binary files /dev/null and b/figs/spaghetti_olmo2_1b_w4096_norm.png differ
diff --git a/scripts/spaghetti_kl_alive.py b/scripts/spaghetti_kl_alive.py
index 1618532..a077d47 100644
--- a/scripts/spaghetti_kl_alive.py
+++ b/scripts/spaghetti_kl_alive.py
@@ -37,13 +37,17 @@ from matplotlib.colors import to_rgba
 
 try:
     import seaborn as sns
-    sns.set_theme(context="notebook", style="whitegrid", palette="deep", font_scale=0.9)
+    sns.set_theme(context="notebook", style="white", palette="deep", font_scale=0.9)
     plt.rcParams.update({
         "axes.titlesize": 10, "axes.labelsize": 9,
         "axes.spines.top": False, "axes.spines.right": False,
+        "axes.grid": False,
+        "figure.facecolor": "#faf7f2",
+        "axes.facecolor": "#faf7f2",
+        "savefig.facecolor": "#faf7f2",
     })
 except Exception:
-    plt.style.use("ggplot")
+    plt.style.use("default")
 
 
 @dataclass
@@ -56,7 +60,11 @@ class Args:
     metric: str = "pmass_eval"  # 'pmass_eval' is paired with KL prompts
     model_contains: str = "Qwen3.5-0.8B"
     kl_log: bool = False
-    roll: int = 11  # smooth KL a bit so the spaghetti is readable
+    roll: int = 65  # smooth KL a lot so the spaghetti is readable
+    line_lw: float = 0.5  # per-trajectory linewidth -- thick enough to see
+    line_alpha: float | None = None  # if None, scale by cohort size: clamp(8/n, .2, .7)
+    normalize_kl: bool = True  # divide each KL trajectory by p95(KL[:calib_tokens])
+    calib_tokens: int = 20  # window over which to compute the per-traj p95 denom
 
 
 def load_cell(d: Path, alpha: str, T: int):
@@ -136,16 +144,28 @@ def main(a: Args):
     n_panels = len(a.alphas)
     fig, axes = plt.subplots(1, n_panels, figsize=(3.0 * n_panels, 3.4), sharey=False, squeeze=False)
 
-    color_alive = to_rgba("#1a9850", 0.65)  # green, translucent so overlap stays readable
-    color_dead = to_rgba("#d7191c", 0.55)  # stronger red so all-dead panels are visible
+    # alpha-per-line is computed per panel from the cohort size so n=32 is not a paintbrush
+    # paper-friendly Wong-ish pair: muted teal-blue (alive) vs warm orange (dead).
+    # colourblind-safe, prints well, doesn't scream "stop sign".
+    GREEN = "#3a7ca5"  # steel teal -- "alive / in budget"
+    RED = "#d97706"  # warm amber -- "dead / out of budget"
 
     summary_rows = []
+    panel_state: list[dict] = []  # collected per-panel info; ylim applied after loop
+    all_visible_kls: list[float] = []
     for j, alpha in enumerate(a.alphas):
         ax = axes[0, j]
         all_trajs = []
         for d in cells:
             all_trajs.extend(load_cell(d, alpha, a.window))
         n = len(all_trajs)
+        # compute per-line alpha for this panel (more lines -> more transparent)
+        if a.line_alpha is not None:
+            la = float(a.line_alpha)
+        else:
+            la = float(np.clip(8.0 / max(n, 1), 0.20, 0.70))
+        color_alive = to_rgba(GREEN, la)
+        color_dead = to_rgba(RED, la)
         n_died = 0
         n_censored = 0
         median_rows = []
@@ -160,6 +180,15 @@ def main(a: Args):
             T = min(len(kl), gl)
             max_gen_len = max(max_gen_len, T)
             kl = _rolling_mean(kl[:T], a.roll)
+            if a.normalize_kl:
+                head = kl[: min(a.calib_tokens, len(kl))]
+                head = head[np.isfinite(head)]
+                denom = float(np.nanpercentile(head, 95)) if head.size else float("nan")
+                if np.isfinite(denom) and denom > 1e-8:
+                    kl = kl / denom
+                else:
+                    # no usable scale (e.g. alpha=0 -> KL==0); skip this traj
+                    continue
             visible_kls.extend([float(x) for x in kl if np.isfinite(x)])
             median_rows.append(np.pad(kl, (0, max(0, a.window - T)), constant_values=np.nan)[: a.window])
             alive = alive_mask_for_t(pmv, fork, T, a.threshold, gl)
@@ -181,9 +210,9 @@ def main(a: Args):
                        alpha=0.6, zorder=3)
 
         if dead_segments:
-            ax.add_collection(LineCollection(dead_segments, colors=dead_colors, linewidths=0.9, zorder=1))
+            ax.add_collection(LineCollection(dead_segments, colors=dead_colors, linewidths=a.line_lw, zorder=1, capstyle="round", joinstyle="round", antialiaseds=True))
         if alive_segments:
-            ax.add_collection(LineCollection(alive_segments, colors=alive_colors, linewidths=0.9, zorder=2))
+            ax.add_collection(LineCollection(alive_segments, colors=alive_colors, linewidths=a.line_lw, zorder=2, capstyle="round", joinstyle="round", antialiaseds=True))
         if median_rows:
             mat = np.asarray(median_rows)
             finite_cols = np.where(np.isfinite(mat).any(axis=0))[0]
@@ -195,30 +224,52 @@ def main(a: Args):
             ax.plot(med_x, med, color="black", lw=1.3, alpha=0.9, zorder=4)
 
         y_hi = _kl_ymax(visible_kls)
-        if y_hi >= 1.0:
-            ax.axhline(1.0, color="black", lw=0.7, ls=":", label="KL=1 calib target")
+        all_visible_kls.extend(visible_kls)
+        # calib target: at alpha=1, KL p95 = 1 nat by construction; at general alpha,
+        # expected KL ~ alpha^2 (small-step quadratic). Show as horizontal reference.
+        # If normalize_kl: each traj is divided by its own p95(KL[:calib_tokens]),
+        # so the target collapses to y=1 by construction (independent of alpha).
+        if a.normalize_kl:
+            target = 1.0
+            target_label = rf"calib p95 (tokens<{a.calib_tokens}) = 1"
         else:
-            ax.text(0.98, 0.92, "KL=1 off-scale", transform=ax.transAxes,
-                    ha="right", va="top", fontsize=7, color="0.25")
+            try:
+                target = float(alpha) ** 2
+            except ValueError:
+                target = float("nan")
+            target_label = rf"calib target $\alpha^2$={target:.2g}"
+        # axhline drawn only for a finite, positive target; off-scale handled later when global ylim is known
+        ax.axhline(target, color="black", lw=0.9, ls="--", label=target_label) if np.isfinite(target) and target > 0 else None
         ax.set_title(rf"$\alpha={alpha}$ n={n} died={n_died} cens={n_censored}")
         ax.set_xlabel("token t")
         if j == 0:
-            ax.set_ylabel("per-token KL")
+            ax.set_ylabel("KL / p95(KL[:%d])" % a.calib_tokens if a.normalize_kl else "per-token KL")
         if a.kl_log:
             ax.set_yscale("symlog", linthresh=0.1)
         ax.set_xlim(-1, _panel_xmax(max_gen_len, a.window))
-        ax.set_ylim(-y_hi * 0.05, y_hi)
-        # data-driven y-lim sanity
-        # rely on auto
+        panel_state.append({"ax": ax, "target": target, "target_label": target_label})
         summary_rows.append({"alpha": alpha, "n": n, "n_died": n_died, "n_censored": n_censored})
 
+    # shared y-axis across all panels for direct cross-alpha comparison
+    y_hi_global = _kl_ymax(all_visible_kls)
+    for ps in panel_state:
+        ax = ps["ax"]
+        ax.set_ylim(-y_hi_global * 0.05, y_hi_global)
+        target = ps["target"]
+        if np.isfinite(target) and target > 0 and target > y_hi_global:
+            ax.text(0.98, 0.92, f"{ps['target_label']} off-scale",
+                    transform=ax.transAxes, ha="right", va="top", fontsize=7, color="0.25")
+
+    # legend on first panel only
     from matplotlib.lines import Line2D
     handles = [
         Line2D([0],[0], color=color_alive, lw=2, label=f"alive (pmass >= {a.threshold})"),
         Line2D([0],[0], color=color_dead, lw=2, label=f"dead (pmass < {a.threshold})"),
         Line2D([0],[0], color="black", marker="|", lw=0, label="EOS (right-censored)"),
-        Line2D([0],[0], color="black", lw=0.7, ls=":", label="KL=1 target if in view"),
+        Line2D([0],[0], color="black", lw=0.9, ls="--",
+               label=(rf"calib p95 (tokens<{a.calib_tokens}) = 1" if a.normalize_kl
+                      else r"calib target $\alpha^2$")),
     ]
     axes[0, 0].legend(handles=handles, loc="upper right", fontsize=7, frameon=True)
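
The per-trajectory normalization added in `main` can be exercised in isolation. Below is a minimal sketch, assuming only NumPy. The function name `normalize_by_head_p95` and its `eps` parameter are hypothetical and not part of the script; the logic mirrors the `if a.normalize_kl:` branch above (finite-filter the head, take its p95, skip the trajectory when no usable scale exists).

```python
import numpy as np


def normalize_by_head_p95(kl, calib_tokens=20, eps=1e-8):
    """Divide a KL trajectory by the p95 of its first `calib_tokens` values.

    Hypothetical helper mirroring the diff's normalize_kl branch.
    Returns None when there is no usable scale (e.g. alpha=0, KL == 0),
    which corresponds to the `continue` that skips the trajectory.
    """
    kl = np.asarray(kl, dtype=float)
    head = kl[: min(calib_tokens, len(kl))]
    head = head[np.isfinite(head)]  # drop NaN/inf before taking the percentile
    if head.size == 0:
        return None
    denom = float(np.percentile(head, 95))
    if not np.isfinite(denom) or denom <= eps:
        return None
    return kl / denom


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    traj = np.abs(rng.normal(0.3, 0.1, size=200))  # synthetic KL values, not real data
    scaled = normalize_by_head_p95(traj)
    # By construction the head p95 of the scaled trajectory is 1.
    print(np.percentile(scaled[:20], 95))
    # An all-zero trajectory (the alpha = 0 case) has no scale and is skipped.
    print(normalize_by_head_p95(np.zeros(200)))
```

Note that in the script the division happens after `_rolling_mean`, so with `--roll 301` the p95 is taken over smoothed values; the sketch omits smoothing and operates on whatever array it is given.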