Files
isokl_steering_calibration/scripts/aggregate.py
T
wassname 77b296cc75 write up
2026-05-08 11:25:10 +08:00

417 lines
18 KiB
Python

"""Aggregate per-cell outputs into Figure 1 + headline table.
Layout (default --no-kl-only): 2 rows x N alpha cols.
Top row: KL(steered || base) per token on EVAL_PROMPTS.
Band mode (default): p50 line + p10..p90 band, rolling-{roll} smooth.
Spaghetti mode: individual trajectories.
default coloring: red if traj ever crosses KL=1, grey otherwise.
--color-by-pmass: segments coloured by paired pmass_eval(t) using a
red->yellow->green colormap (dead->alive). Requires
pmass.json with non-empty pmass_eval (run_cell.py
with --compute-pmass).
Bottom row: forked-answer pmass at fork_points.
Uses pmass[alpha] (legacy yes/no reasoning prompts). Different prompt set
than the KL panel above, so the row-to-row link is across cells, not
paired per-trajectory. For paired analysis see scripts/survival.py with
--metric pmass on pmass_eval.
Reading caveats:
- alpha=1 sits near KL=1 on CALIB_PROMPTS by construction. KL on EVAL_PROMPTS
only tests generalisation, not budget choice -- not an independent test.
- The honest coherence test is pmass / pmass_eval (real forced-choice mass
drop), see survival.py for KM-style curves.
Table: per (model, method, window) c_star mean, std, CV (std/|mean|) and run count, over every run under --runs_root (not filtered by --window or --model_contains).
Usage:
# smoothed band:
python scripts/aggregate.py --runs_root outputs --out figs/
# raw spaghetti:
python scripts/aggregate.py --runs_root outputs/qwen05_w512 \
--out figs/qwen05_pretty_raw --spaghetti --roll 1 \
--alphas 0.5 1.0 2.0 4.0
# KL spaghetti coloured by paired pmass_eval:
python scripts/aggregate.py --runs_root outputs/qwen05_w512 \
--out figs/qwen05_color --spaghetti --color-by-pmass
"""
from __future__ import annotations
import json
from dataclasses import dataclass
from pathlib import Path
import numpy as np
import polars as pl
import tyro
from loguru import logger
from tabulate import tabulate
import matplotlib.pyplot as plt
# Optional seaborn styling. Any failure (seaborn not installed, or theming
# raising) deliberately falls back to matplotlib's built-in "ggplot" style so
# the script still renders figures.
try:
    import seaborn as sns

    sns.set_theme(
        context="notebook",
        style="whitegrid",
        palette="deep",
        font_scale=0.95,
    )
    plt.rcParams.update(
        {
            "axes.titlesize": 11,
            "axes.labelsize": 10,
            # hide the top/right frame lines
            "axes.spines.top": False,
            "axes.spines.right": False,
            "figure.titlesize": 11,
            "figure.dpi": 110,
        }
    )
except Exception:
    plt.style.use("ggplot")
@dataclass
class Args:
runs_root: str = "outputs"
out: str = "figs"
window: int = 512 # only this window enters the figure
roll: int = 65 # smoothing window for KL trajectory
alphas: tuple[str, ...] = ("0.25", "0.5", "0.75", "1.0", "1.5", "2.0", "4.0")
kl_ymax: float = 6.0
model_contains: str = ""
kl_only: bool = False
spaghetti: bool = False # plot individual trajectories instead of p10..p90 band
color_by_pmass: bool = False # color KL spaghetti lines by paired pmass (requires pmass_eval)
line_alpha: float | None = None # per-line alpha override; None = auto clip(2.5/n,.08,.35)
line_lw: float = 0.18 # per-trajectory linewidth; full opacity needs very thin lines
median_lw: float = 0.75 # median linewidth
quantile_lines: bool = False # clean summary: p10/p50/p90 lines, no fill/spaghetti
mark_t: int = -1 # optional vertical token marker; -1 disables
def _rolling_mean(x: np.ndarray, w: int) -> np.ndarray:
    """Centred, NaN-aware rolling mean of a 1-D array; output length == input length.

    The ends are extended with the first/last sample (``np.pad(mode="edge")``)
    so every position has a full window. NaNs count as missing: each output is
    the mean of the non-NaN samples in its window, and positions that are NaN
    in ``x`` stay NaN. ``_pool_kl`` NaN-pads shorter trajectories, and a plain
    convolution would otherwise turn the last ``w // 2`` real tokens of every
    such trajectory into NaN.

    Args:
        x: 1-D array.
        w: window length. If ``w <= 1`` or ``len(x) < w``, ``x`` is returned
            unchanged (same object).

    Returns:
        Float array of length ``len(x)``. Output ``i`` averages
        ``x[i - w//2 : i - w//2 + w]`` (for even ``w`` this is one sample
        left of centre).
    """
    if len(x) < w or w <= 1:
        return x
    x = np.asarray(x, dtype=float)
    pad = w // 2
    present = ~np.isnan(x)
    # Pad values and the presence mask the same way: a trailing NaN pad then
    # contributes nothing, so windows near a trajectory's end shrink instead
    # of becoming NaN.
    vals = np.pad(np.where(present, x, 0.0), pad, mode="edge")
    counts = np.pad(present.astype(float), pad, mode="edge")
    kernel = np.ones(w)
    sums = np.convolve(vals, kernel, mode="valid")[: len(x)]
    n = np.convolve(counts, kernel, mode="valid")[: len(x)]
    with np.errstate(invalid="ignore", divide="ignore"):
        out = sums / n
    out[(n == 0) | ~present] = np.nan
    return out
def load_cells(root: Path, window: int, model_contains: str = "") -> list[dict]:
    """Load every complete per-cell run directory directly under ``root``.

    A cell is a subdirectory (names starting with ``_`` are ignored) holding
    ``calib.json``, ``trajectory.json`` and ``pmass.json``. A cell is kept only
    if calib's ``"window"`` equals ``window`` and, when ``model_contains`` is
    non-empty, calib's ``"model"`` contains it (a missing/null model never
    matches). Cells whose trajectory lacks ``per_prompt_per_t_kl`` are stale
    and skipped with a warning.

    Returns:
        One dict per cell in sorted directory order:
        ``{"id": <dir name>, **calib, "traj": <trajectory.json>, "pmass": <pmass.json>}``.
    """
    def _read_json(path: Path):
        # JSON is UTF-8 by spec; don't depend on the platform locale encoding.
        return json.loads(path.read_text(encoding="utf-8"))

    cells = []
    for d in sorted(root.iterdir()):
        if not d.is_dir() or d.name.startswith("_"):
            continue
        calib, traj_p, pm_p = d / "calib.json", d / "trajectory.json", d / "pmass.json"
        if not (calib.exists() and traj_p.exists() and pm_p.exists()):
            continue
        meta = _read_json(calib)
        if meta.get("window") != window:
            continue
        # `or ""` also covers an explicit null model, which `in` would reject.
        if model_contains and model_contains not in (meta.get("model") or ""):
            continue
        traj = _read_json(traj_p)
        # Skip stale outputs that lack per-prompt KL (pre-redesign) before
        # paying for the pmass read.
        if "per_prompt_per_t_kl" not in traj:
            logger.warning(f"skipping stale {d.name} (no per_prompt_per_t_kl)")
            continue
        cells.append({"id": d.name, **meta, "traj": traj, "pmass": _read_json(pm_p)})
    return cells
def _draw_kl_panel(ax, K: np.ndarray, a: Args, P: np.ndarray | None = None) -> None:
    """Draw pooled KL trajectories ``K`` (n_traj, T) onto ``ax``.

    Trailing columns that are NaN in every row are trimmed first, and ``P`` (if
    given and non-empty) is cut to the same width. Nothing is drawn when ``K``
    is empty or entirely NaN. Dispatches to ``_draw_kl_spaghetti`` when
    ``a.spaghetti`` is set, otherwise to ``_draw_kl_band``.

    Args:
        ax: matplotlib Axes to draw on.
        K: KL trajectories, one row per prompt, NaN past each row's end.
        a: plotting options (spaghetti, roll, line_lw, line_alpha, ...).
        P: optional per-token pmass whose rows are paired with ``K``; only used
            by pmass-coloured spaghetti (``a.color_by_pmass``).
    """
    if not K.size:
        return
    finite_cols = np.where(np.isfinite(K).any(axis=0))[0]
    if finite_cols.size == 0:
        return
    K = K[:, : finite_cols[-1] + 1]
    if P is not None and P.size:
        P = P[:, : K.shape[1]]
    xs = np.arange(K.shape[1])
    if a.spaghetti:
        _draw_kl_spaghetti(ax, K, xs, a, P)
    else:
        _draw_kl_band(ax, K, xs, a)


def _draw_kl_spaghetti(ax, K: np.ndarray, xs: np.ndarray, a: Args, P: np.ndarray | None) -> None:
    """Thin per-trajectory KL lines plus a black nan-median line.

    Each line is smoothed with ``_rolling_mean(roll)`` when ``a.roll > 1``.
    With ``a.color_by_pmass`` and ``P`` of the same shape as ``K``, every
    segment is coloured by pmass(t) on a red->yellow->green map
    (0.0 = dead red, 1.0 = alive green). Otherwise a line is red ("C3") if its
    raw (unsmoothed) trajectory ever exceeds KL=1 and grey if not, and the
    crossing fraction is printed in the top-right corner. A warning is logged
    when pmass colouring was requested but ``P`` cannot be used.
    """
    Kp = np.array([_rolling_mean(row, a.roll) for row in K]) if a.roll > 1 else K
    paired = P is not None and P.size > 0 and P.shape == K.shape
    if a.color_by_pmass and paired:
        from matplotlib.collections import LineCollection
        from matplotlib.colors import LinearSegmentedColormap

        cmap = LinearSegmentedColormap.from_list(
            "alive", ["#c0392b", "#f1c40f", "#27ae60"]
        )  # dead->alive
        # Loop-invariant per-line alpha: explicit override, else scale with n.
        line_alpha = float(
            a.line_alpha if a.line_alpha is not None
            else np.clip(2.5 / max(K.shape[0], 1), 0.08, 0.35)
        )
        # Draw greenest (alive) first, reddest (dead) last so red sits on top
        # of the green pile -- otherwise the eye loses the dead trajectories.
        order = np.argsort(-np.nanmean(P, axis=1))
        for traj, pmass_row in zip(Kp[order], P[order]):
            pts = np.column_stack([xs, traj])
            segs = np.stack([pts[:-1], pts[1:]], axis=1)  # (T-1, 2, 2): segment t -> t+1
            lc = LineCollection(segs, cmap=cmap, norm=plt.Normalize(0, 1),
                                linewidths=a.line_lw, alpha=line_alpha)
            lc.set_array(pmass_row[:-1])  # segment t -> t+1 takes pmass(t)
            ax.add_collection(lc)
        ax.set_xlim(xs[0], xs[-1])
        ax.plot(xs, np.nanmedian(Kp, axis=0), color="k", lw=a.median_lw)
        return

    if a.color_by_pmass:
        shape = None if P is None else P.shape
        logger.warning(
            f"color_by_pmass: pmass_eval shape {shape} does not match KL shape "
            f"{K.shape}; colouring by KL=1 crossing instead"
        )
    # Honour an explicit a.line_alpha here too; the default stays 0.5.
    line_alpha = a.line_alpha if a.line_alpha is not None else 0.5
    crossed = (K > 1.0).any(axis=1)
    for traj in Kp[~crossed]:
        ax.plot(xs, traj, color="0.55", lw=a.line_lw, alpha=line_alpha)
    for traj in Kp[crossed]:
        ax.plot(xs, traj, color="C3", lw=a.line_lw, alpha=line_alpha)
    ax.plot(xs, np.nanmedian(Kp, axis=0), color="k", lw=a.median_lw)
    frac = float(crossed.mean())
    ax.text(0.97, 0.97, f"{frac:.0%} cross KL=1",
            transform=ax.transAxes, ha="right", va="top",
            fontsize=8, color="C3" if frac > 0.5 else "0.3")


def _draw_kl_band(ax, K: np.ndarray, xs: np.ndarray, a: Args) -> None:
    """Across-trajectory KL quantiles, each smoothed with ``_rolling_mean(roll)``.

    Default: p10..p90 band plus p50 line. With ``a.quantile_lines``: a light
    p10..p90 band, a darker p25..p75 band and the p50 line, each with ``label=`` set.
    """
    def smoothed_pct(q: float) -> np.ndarray:
        return _rolling_mean(np.nanpercentile(K, q, axis=0), a.roll)

    p10s, p50s, p90s = smoothed_pct(10), smoothed_pct(50), smoothed_pct(90)
    if a.quantile_lines:
        # standard quantile fan: outer band light, inner band dark, p50 line
        ax.fill_between(xs, p10s, p90s, alpha=0.15, color="C0", lw=0, label="p10..p90")
        ax.fill_between(xs, smoothed_pct(25), smoothed_pct(75), alpha=0.32, color="C0",
                        lw=0, label="p25..p75")
        ax.plot(xs, p50s, color="C0", lw=1.5, label="p50")
    else:
        ax.fill_between(xs, p10s, p90s, alpha=0.25, color="C0", lw=0)
        ax.plot(xs, p50s, color="C0", lw=1.6)
def make_kl_figure(cells: list[dict], a: Args, out_path: Path) -> None:
    """Render the KL-only figure: one row of panels, one column per alpha.

    All panels share x/y axes. The x range ends ~5% past the last token with
    any finite KL (at least 20, at most ``a.window``). The y range is 1.4x the
    99th percentile of all finite non-negative KL values, floored at 1e-3 and
    capped at ``a.kl_ymax`` (1.1 when there is no data). The KL=1 line is drawn
    only when it lies inside the y range. Saves to ``out_path`` and, with
    ``a.quantile_lines``, an identical copy with a ``_quantile_lines`` suffix.
    The figure is closed afterwards.
    """
    # Pool each alpha once and reuse it for sizing and drawing.
    pooled = {alpha: _pool_kl(cells, alpha, T=a.window) for alpha in a.alphas}
    fork_points = _first_fork(cells) if a.color_by_pmass else None

    fig, axes = plt.subplots(1, len(a.alphas), figsize=(3.7 * len(a.alphas), 3.4),
                             sharex=True, sharey=True, squeeze=False, constrained_layout=True)
    label = a.model_contains or "all models"
    # NOTE(review): n_max counts pooled trajectories (cells x prompts) for the
    # largest alpha, while the title calls them prompts -- confirm wording.
    n_max = max((K.shape[0] for K in pooled.values()), default=0)
    if a.spaghetti and a.color_by_pmass:
        mode = f"individual trajectories, roll={a.roll}, color=pmass"
    elif a.spaghetti:
        mode = "individual trajectories (red=ever crossed KL=1)"
    elif a.quantile_lines:
        mode = f"shaded quantiles p10/p25/p50/p75/p90, roll={a.roll}"
    else:
        mode = f"p50 + p10..p90 band, smoothed rolling-{a.roll}"
    fig.suptitle(
        f"KL trajectory on N={n_max} held-out long-form prompts ({label})\n"
        f"{mode}. Solid horizontal: KL=1 nat.",
        fontsize=10,
    )

    # Shared axis ranges over every panel's finite data.
    x_stop = 1
    kept: list[np.ndarray] = []
    for alpha in a.alphas:
        K = pooled[alpha]
        finite_cols = np.where(np.isfinite(K).any(axis=0))[0]
        if finite_cols.size:
            x_stop = max(x_stop, int(finite_cols[-1] + 1))
            vals = K[:, : finite_cols[-1] + 1]
            vals = vals[np.isfinite(vals)]
            kept.append(vals[vals >= 0.0])
    x_max = float(min(a.window, max(20, int(x_stop * 1.05))))
    y_data = np.concatenate(kept) if kept else np.empty(0)
    if y_data.size:
        y_max = float(min(a.kl_ymax, max(1e-3, np.nanpercentile(y_data, 99) * 1.4)))
    else:
        y_max = 1.1

    for j, alpha in enumerate(a.alphas):
        ax = axes[0, j]
        K = pooled[alpha]
        Pe = _pool_pmass_eval(cells, alpha, fork_points, T=a.window) if a.color_by_pmass else None
        _draw_kl_panel(ax, K, a, P=Pe)
        if y_max >= 1.0:
            ax.axhline(1.0, color="k", lw=0.7)
        else:
            ax.text(0.98, 0.92, "KL=1 off-scale", transform=ax.transAxes,
                    ha="right", va="top", fontsize=8, color="0.25")
        if a.mark_t >= 0:
            ax.axvline(a.mark_t, color="k", ls=":", lw=0.6)
        ax.set_title(rf"$\alpha = {alpha}$ (n={K.shape[0]} traj)")
        ax.set_xlim(0, x_max)
        ax.set_ylim(0, y_max)
        ax.set_xlabel("token")
        if j == 0:
            ax.set_ylabel("KL")
    if a.color_by_pmass:
        from matplotlib.cm import ScalarMappable
        from matplotlib.colors import LinearSegmentedColormap, Normalize
        cmap = LinearSegmentedColormap.from_list("alive", ["#c0392b", "#f1c40f", "#27ae60"])
        sm = ScalarMappable(norm=Normalize(0, 1), cmap=cmap)
        sm.set_array([])
        cbar = fig.colorbar(sm, ax=axes[0, :].tolist(), location="right", shrink=0.72,
                            pad=0.015, fraction=0.012)
        cbar.set_label("pmass", labelpad=2)
    fig.savefig(out_path, dpi=160, bbox_inches="tight")
    if a.quantile_lines:
        q_path = out_path.with_name(out_path.stem + "_quantile_lines" + out_path.suffix)
        fig.savefig(q_path, dpi=160, bbox_inches="tight")
        logger.info(f"KL quantile-line figure -> {q_path}")
    logger.info(f"KL-only figure -> {out_path}")
    # Release the figure; pyplot keeps every figure alive until closed.
    plt.close(fig)
def make_table(root: Path) -> pl.DataFrame:
    """Summarise calibrated ``c_star`` per (model, method, window) across runs.

    Reads ``calib.json`` from every direct subdirectory of ``root`` (names
    starting with ``_`` are ignored). Unlike ``load_cells``, this neither
    filters by window/model nor requires trajectory or pmass files.

    Returns:
        A frame sorted by the group keys, with columns:
        ``model_short`` (last ``/`` component of ``model``), ``method``,
        ``window``, ``c_mean``, ``c_std`` (polars sample std), ``n_seeds``
        (number of runs in the group) and ``c_cv = c_std / |c_mean|``.
        Returns an empty frame when no ``calib.json`` is found.
    """
    rows = []
    for d in sorted(root.iterdir()):
        if not d.is_dir() or d.name.startswith("_"):
            continue
        calib = d / "calib.json"
        if not calib.exists():
            continue
        # JSON is UTF-8 by spec; don't depend on the platform locale encoding.
        rows.append(json.loads(calib.read_text(encoding="utf-8")) | {"run": d.name})
    if not rows:
        return pl.DataFrame()
    # calib.json files from different script versions may carry different keys.
    # Scan every row for the schema, not just polars' default first 100, so keys
    # that only appear in later runs are not dropped.
    df = pl.DataFrame(rows, infer_schema_length=None)
    df = df.with_columns(pl.col("model").str.split("/").list.last().alias("model_short"))
    g = (df.group_by(["model_short", "method", "window"])
         .agg(pl.col("c_star").mean().alias("c_mean"),
              pl.col("c_star").std().alias("c_std"),
              pl.len().alias("n_seeds"))
         .sort(["model_short", "method", "window"]))
    return g.with_columns(
        (pl.col("c_std") / pl.col("c_mean").abs()).alias("c_cv"),
    )
def _pool_kl(cells: list[dict], alpha: str, T: int) -> np.ndarray:
    """Stack every cell's per-prompt KL trajectory for ``alpha`` into one array.

    Each trajectory is cut to its first ``T`` tokens and right-padded with NaN,
    giving an (N, T) float array with one row per prompt across all cells.
    When no trajectory exists for ``alpha``, returns a (0, T) array of zeros.
    """
    stacked: list[np.ndarray] = []
    for cell in cells:
        for kl_row in cell["traj"]["per_prompt_per_t_kl"].get(alpha, []):
            head = kl_row[:T]
            padded = np.full(T, np.nan)
            padded[: len(head)] = head
            stacked.append(padded)
    if not stacked:
        return np.zeros((0, T))
    return np.stack(stacked)
def _first_fork(cells: list[dict]) -> list[int]:
    """Return the ``fork_points`` of the first cell whose pmass was computed.

    A cell counts as computed unless its pmass dict has a falsy ``"computed"``
    entry. Returns ``[]`` when no cell qualifies.
    """
    return next(
        (cell["pmass"]["fork_points"] for cell in cells if cell["pmass"].get("computed", True)),
        [],
    )
def _pool_pmass_eval(cells: list[dict], alpha: str, fork_points: list[int], T: int) -> np.ndarray:
    """Pool per-prompt ``pmass_eval`` for ``alpha`` onto a per-token grid.

    ``pmass_eval`` comes from EVAL_PROMPTS, so its rows pair with the KL
    trajectories. Each prompt's values (one per fork point) are linearly
    interpolated over tokens ``0..T-1`` and held constant before the first and
    after the last fork point.

    Cells whose pmass has a falsy ``"computed"`` flag are skipped, and so are
    cells without ``pmass_eval`` for ``alpha``; neither contributes rows. A
    prompt with more values than ``fork_points`` is truncated to the common
    length. A prompt with no usable values contributes an all-NaN row, so the
    row count still equals the number of prompts. The same ``fork_points`` are
    used for every cell and must be increasing (``np.interp`` requirement).

    Returns:
        (N, T) float array, or a (0, T) array of zeros when nothing is found.
    """
    t = np.arange(T)
    rows = []
    for c in cells:
        if not c["pmass"].get("computed", True):
            continue
        for r in c["pmass"].get("pmass_eval", {}).get(alpha, []):
            # np.interp needs xp and fp of equal length.
            n = min(len(r), len(fork_points))
            if n == 0:
                rows.append(np.full(T, np.nan))
                continue
            xs = np.asarray(fork_points[:n], dtype=float)
            ys = np.asarray(r[:n], dtype=float)
            rows.append(np.interp(t, xs, ys, left=ys[0], right=ys[-1]))
    return np.array(rows) if rows else np.zeros((0, T))
def _pool_pmass(cells: list[dict], alpha: str) -> tuple[np.ndarray, list[int]]:
    """Pool the legacy forked-answer ``pmass[alpha]`` rows across computed cells.

    Cells whose pmass has a falsy ``"computed"`` flag are skipped.

    Returns:
        ``(arr, fork)``. ``arr`` is (N, L) with one row per prompt, NaN-padded
        to the longest row length L. ``fork`` is the ``fork_points`` list of
        the first computed cell. With no rows, ``arr`` is a (0, len(fork)) array
        of zeros, and ``fork`` falls back to ``[]`` when it is missing or empty.
    """
    first_fork = None
    pooled_rows = []
    for cell in cells:
        pm = cell["pmass"]
        if not pm.get("computed", True):
            continue
        forks = pm["fork_points"]
        if first_fork is None:
            first_fork = forks
        pooled_rows.extend(pm["pmass"].get(alpha, []))
    if not pooled_rows:
        fork_list = first_fork or []
        return np.zeros((0, len(fork_list))), fork_list
    width = max(map(len, pooled_rows))
    out = np.full((len(pooled_rows), width), np.nan)
    for i, row in enumerate(pooled_rows):
        out[i, : len(row)] = row
    return out, first_fork
def make_figure(cells: list[dict], a: Args, out_path: Path) -> None:
    """Render Figure 1: KL row (top) over pmass row (bottom), one column per alpha.

    Top row: pooled KL trajectories via ``_draw_kl_panel``, with a KL=1 line, a
    dotted marker at t=20, and the y range fixed to ``[0, a.kl_ymax]``.
    Bottom row: legacy forked-answer pmass at the fork points
    (``_pool_pmass``). Spaghetti mode draws per-prompt lines plus a
    nan-median; otherwise a p10..p90 band plus p50. The title's model names,
    method count and prompt count are derived from ``cells``. Saves a PNG to
    ``out_path`` and closes the figure.
    """
    fig, axes = plt.subplots(2, len(a.alphas), figsize=(4.0 * len(a.alphas), 5.5),
                             sharex="row", sharey="row", squeeze=False)
    n_cells = len(cells)
    # Derive title facts from the loaded cells rather than hard-coding one
    # model/config, so figures for other models are labelled correctly.
    models = sorted({str(c.get("model") or "").split("/")[-1] for c in cells} - {""})
    model_label = ", ".join(models) or "unknown model"
    n_methods = len({c.get("method") for c in cells})
    n_prompts = max(
        (len(c["traj"]["per_prompt_per_t_kl"].get(alpha, [])) for c in cells for alpha in a.alphas),
        default=0,
    )
    fig.suptitle(
        f"Figure 1: iso-KL calibration on {model_label}\n"
        f"Top: KL(steered || base) per token on long-form held-out prompts.\n"
        f"Bottom: forked-answer pmass on yes/no reasoning prompts (high=alive, low=collapsed).\n"
        f"{n_cells} cells ({n_methods} methods) x N={n_prompts} prompts. "
        f"Solid h-line: KL=1 nat. Dotted v-line: t=20.",
        fontsize=10,
    )
    fork_points = _first_fork(cells) if a.color_by_pmass else None
    for j, alpha in enumerate(a.alphas):
        # ---- top: KL ----
        ax = axes[0, j]
        K = _pool_kl(cells, alpha, T=a.window)
        Pe = _pool_pmass_eval(cells, alpha, fork_points, T=a.window) if a.color_by_pmass else None
        _draw_kl_panel(ax, K, a, P=Pe)
        ax.axhline(1.0, color="k", lw=1.0)
        ax.axvline(20, color="k", ls=":", lw=0.8)
        ax.set_title(rf"$\alpha = {alpha}$ (n={K.shape[0]} traj)")
        ax.set_ylim(0, a.kl_ymax)
        if j == 0:
            ax.set_ylabel("KL(steered || base) [nats]")
        ax.set_xlabel("token")
        # ---- bottom: pmass ----
        ax2 = axes[1, j]
        P, fork = _pool_pmass(cells, alpha)
        # Pair each pmass column with a fork point; plot() needs equal lengths
        # even if the fork list is shorter than the widest pmass row.
        n_fork = min(P.shape[1], len(fork))
        if P.shape[0] and n_fork:
            P = P[:, :n_fork]
            xs = np.array(fork[:n_fork])
            if a.spaghetti:
                for row in P:
                    ax2.plot(xs, row, color="C1", lw=0.6, alpha=0.5)
                ax2.plot(xs, np.nanmedian(P, axis=0), color="k", lw=1.6, marker="o", ms=3)
            else:
                p50 = np.nanpercentile(P, 50, axis=0)
                p10 = np.nanpercentile(P, 10, axis=0)
                p90 = np.nanpercentile(P, 90, axis=0)
                ax2.fill_between(xs, p10, p90, alpha=0.25, color="C1", lw=0)
                ax2.plot(xs, p50, color="C1", lw=1.6, marker="o", ms=3)
        ax2.axvline(20, color="k", ls=":", lw=0.8)
        ax2.set_ylim(-0.02, 1.05)
        if j == 0:
            ax2.set_ylabel('pmass = p(true/1) + p(false/0)\nat fork t')
        ax2.set_xlabel("fork token")
    if a.color_by_pmass:
        from matplotlib.cm import ScalarMappable
        from matplotlib.colors import LinearSegmentedColormap, Normalize
        cmap = LinearSegmentedColormap.from_list("alive", ["#c0392b", "#f1c40f", "#27ae60"])
        sm = ScalarMappable(norm=Normalize(0, 1), cmap=cmap)
        sm.set_array([])
        cbar = fig.colorbar(sm, ax=axes[0, :].tolist(), location="right", shrink=0.8, pad=0.02)
        cbar.set_label("pmass (0=dead, 1=alive)")
    fig.tight_layout(rect=(0, 0, 1, 0.88))
    fig.savefig(out_path, dpi=140, bbox_inches="tight")
    logger.info(f"figure -> {out_path}")
    # Release the figure; pyplot keeps every figure alive until closed.
    plt.close(fig)
def main(a: Args):
    """Load matching cells, write the c_star table (CSV + markdown), then draw the figure.

    Creates ``a.out`` if needed. Writes ``table.csv`` / ``table.md`` when any run
    under ``a.runs_root`` has a ``calib.json``. Then writes ``figure1_kl_only.png``
    (``a.kl_only``) or ``figure1.png``.

    Raises:
        SystemExit: when no cell under ``a.runs_root`` passes the window/model filters.
    """
    out_dir = Path(a.out)
    out_dir.mkdir(parents=True, exist_ok=True)
    root = Path(a.runs_root)

    cells = load_cells(root, window=a.window, model_contains=a.model_contains)
    if not cells:
        raise SystemExit(f"no cells with window={a.window} under {root}")
    logger.info(f"loaded {len(cells)} cells (window={a.window})")

    table = make_table(root)
    if not table.is_empty():
        table.write_csv(out_dir / "table.csv")
        md = tabulate(table.rows(), headers=table.columns, tablefmt="pipe", floatfmt=".3f")
        md_path = out_dir / "table.md"
        md_path.write_text(md)
        logger.info(f"table -> {md_path}\n{md}")

    draw, filename = (
        (make_kl_figure, "figure1_kl_only.png") if a.kl_only else (make_figure, "figure1.png")
    )
    draw(cells, a, out_dir / filename)
if __name__ == "__main__":
    # tyro builds the command-line interface from the Args dataclass fields.
    cli_args = tyro.cli(Args)
    main(cli_args)