Files
isokl_steering_calibration/scripts/aggregate.py
T

371 lines
15 KiB
Python

"""Aggregate per-cell outputs into Figure 1 + headline table.
Layout (default, --no-kl-only): 2 rows x one column per alpha -> figure1.png.
With --kl-only, only the KL row is drawn -> figure1_kl_only.png.
Top row: KL(steered || base) per token on EVAL_PROMPTS.
  Band mode (default): p50 line + p10..p90 band, each smoothed with a
    rolling mean over --roll tokens.
  Spaghetti mode (--spaghetti): individual trajectories.
    Default coloring: red if the raw (unsmoothed) trajectory ever exceeds
    KL=1, grey otherwise.
    --color-by-pmass: segments colored by paired pmass_eval(t) using a
    red->yellow->green colormap (dead->alive). Only changes line colors in
    spaghetti mode. Requires pmass.json with non-empty pmass_eval
    (run_cell.py with --compute-pmass).
Bottom row: forked-answer pmass at fork_points.
Uses pmass[alpha] (legacy yes/no reasoning prompts). Different prompt set
than the KL panel above, so the row-to-row link is across cells, not
paired per-trajectory. For paired analysis see scripts/survival.py with
--metric pmass on pmass_eval.
Reading caveats:
- alpha=1 sits near KL=1 on CALIB_PROMPTS by construction, so KL on
  EVAL_PROMPTS only tests how the calibration generalises, not the choice
  of budget -- it is not an independent test.
- The honest coherence test is pmass / pmass_eval (a real drop in
  forced-choice probability mass); see survival.py for Kaplan-Meier-style
  survival curves.
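Table (table.csv + table.md): per (model, method, window) c_star mean, std,
coefficient of variation (std/|mean|) and run count across seeds. It covers
every run under --runs_root, regardless of --window / --model_contains.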
Usage:
# smoothed band:
python scripts/aggregate.py --runs_root outputs --out figs/
# raw spaghetti:
python scripts/aggregate.py --runs_root outputs_qwen05_w512 \
--out figs_qwen05_pretty_raw --spaghetti --roll 1 \
--alphas 0.5 1.0 2.0 4.0
# KL spaghetti coloured by paired pmass_eval:
python scripts/aggregate.py --runs_root outputs_qwen05_w512 \
--out figs_qwen05_color --spaghetti --color-by-pmass
"""
from __future__ import annotations
import json
from dataclasses import dataclass
from pathlib import Path
import numpy as np
import polars as pl
import tyro
from loguru import logger
from tabulate import tabulate
import matplotlib.pyplot as plt
# Optional cosmetic theme: seaborn if it is installed, otherwise matplotlib's
# built-in ggplot style. Styling must never block aggregation, but failures
# other than a missing seaborn are reported instead of silently swallowed.
try:
    import seaborn as sns
except ImportError:
    plt.style.use("ggplot")
else:
    try:
        sns.set_theme(context="notebook", style="whitegrid", palette="deep", font_scale=0.95)
        plt.rcParams.update({
            "axes.titlesize": 11,
            "axes.labelsize": 10,
            "axes.spines.top": False,
            "axes.spines.right": False,
            "figure.titlesize": 11,
            "figure.dpi": 110,
        })
    except Exception as err:  # cosmetic only: keep going, but say why
        logger.warning(f"seaborn theme setup failed ({err!r}); falling back to ggplot")
        plt.style.use("ggplot")
@dataclass
class Args:
    """Command-line options for aggregating per-cell run outputs (parsed by tyro.cli)."""

    runs_root: str = "outputs"  # parent dir; each subdirectory not starting with "_" is one run cell
    out: str = "figs"  # output dir for table.csv, table.md and the figure PNG (created if missing)
    window: int = 50 # only this window enters the figure (need long-enough traj for rolling-16)
    # NOTE: `window` also sets T, the number of KL tokens per trajectory in the
    # figures; the c_star table is built from every run and is NOT filtered by it.
    roll: int = 16 # smoothing window for KL trajectory
    # Alpha labels are used verbatim as keys into the per-cell JSON dicts, so
    # they must match the stored string exactly ("1.0" and "1" are different
    # keys); an unknown label silently yields an empty panel.
    alphas: tuple[str, ...] = ("0.25", "0.5", "0.75", "1.0", "1.5", "2.0", "4.0")
    kl_ymax: float = 6.0  # upper y-limit of every KL panel
    model_contains: str = ""  # keep only cells whose calib "model" contains this substring ("" keeps all)
    kl_only: bool = False  # draw only the KL row (figure1_kl_only.png) instead of the 2-row figure1.png
    spaghetti: bool = False # plot individual trajectories instead of p10..p90 band
    color_by_pmass: bool = False # color KL spaghetti lines by paired pmass (requires pmass_eval)
def _rolling_mean(x: np.ndarray, w: int) -> np.ndarray:
"""Rolling mean along axis 0 with edge padding so output len == input len."""
if len(x) < w or w <= 1:
return x
pad = w // 2
xp = np.pad(x, pad, mode="edge")
kernel = np.ones(w) / w
return np.convolve(xp, kernel, mode="valid")[: len(x)]
def load_cells(root: Path, window: int, model_contains: str = "") -> list[dict]:
    """Load every complete run cell under ``root`` matching ``window`` / ``model_contains``.

    A cell is a subdirectory of ``root`` (names starting with "_" are skipped)
    containing calib.json, trajectory.json and pmass.json. A cell is skipped
    when any of the following holds:

    * its calib "window" differs from ``window``;
    * ``model_contains`` is non-empty and is not a substring of its calib "model";
    * its trajectory has no "per_prompt_per_t_kl" (stale, pre-redesign output;
      a warning is logged).

    Returns:
        One dict per cell, in sorted directory-name order. Each holds "id"
        (the directory name), the calib.json fields, "traj" (trajectory.json)
        and "pmass" (pmass.json). The list is empty when ``root`` is not a
        directory, so the caller can report that cleanly instead of crashing
        inside ``iterdir``.
    """
    if not root.is_dir():
        logger.warning(f"runs_root {root} is not a directory")
        return []
    cells = []
    for d in sorted(root.iterdir()):
        if not d.is_dir() or d.name.startswith("_"):
            continue
        calib_p = d / "calib.json"
        traj_p = d / "trajectory.json"
        pm_p = d / "pmass.json"
        if not (calib_p.exists() and traj_p.exists() and pm_p.exists()):
            continue
        # JSON is UTF-8; don't depend on the platform's locale encoding.
        meta = json.loads(calib_p.read_text(encoding="utf-8"))
        if meta.get("window") != window:
            continue
        if model_contains and model_contains not in meta.get("model", ""):
            continue
        traj = json.loads(traj_p.read_text(encoding="utf-8"))
        # Skip stale outputs that lack per-prompt KL (pre-redesign).
        if "per_prompt_per_t_kl" not in traj:
            logger.warning(f"skipping stale {d.name} (no per_prompt_per_t_kl)")
            continue
        pm = json.loads(pm_p.read_text(encoding="utf-8"))
        cells.append({"id": d.name, **meta, "traj": traj, "pmass": pm})
    return cells
def _draw_kl_panel(ax, K: np.ndarray, a: Args, P: np.ndarray | None = None) -> None:
    """Render one alpha's pooled KL trajectories ``K`` (n_traj x T) onto ``ax``.

    An empty ``K`` leaves ``ax`` untouched.

    * Band mode (``a.spaghetti`` off): the p10 / p50 / p90 of K across
      trajectories, per token, are each smoothed with
      ``_rolling_mean(., a.roll)``. They are drawn as a shaded p10..p90 band
      plus a p50 line. ``P`` is ignored.
    * Spaghetti mode: each row of K is smoothed on its own (when
      ``a.roll > 1``) and drawn as a thin line, followed by a thick black
      nan-median line.

      - If ``a.color_by_pmass`` is set and ``P`` is non-empty with the same
        shape as K, each segment is colored by P on a red (0, dead) ->
        yellow -> green (1, alive) scale. Rows are drawn by decreasing
        nan-mean P, so low-pmass lines end up on top.
      - Otherwise a line is red if its *unsmoothed* KL ever exceeds 1 and
        grey if not, and the share of red lines is annotated top-right.
    """
    if K.size == 0:
        return
    steps = np.arange(K.shape[1])

    if not a.spaghetti:
        lo, mid, hi = (
            _rolling_mean(np.nanpercentile(K, q, axis=0), a.roll) for q in (10, 50, 90)
        )
        ax.fill_between(steps, lo, hi, alpha=0.25, color="C0", lw=0)
        ax.plot(steps, mid, color="C0", lw=1.6)
        return

    smooth = K if a.roll <= 1 else np.array([_rolling_mean(row, a.roll) for row in K])
    paired = a.color_by_pmass and P is not None and P.size and P.shape == K.shape

    if paired:
        from matplotlib.collections import LineCollection
        from matplotlib.colors import LinearSegmentedColormap

        dead_to_alive = LinearSegmentedColormap.from_list(
            "alive", ["#c0392b", "#f1c40f", "#27ae60"]
        )
        # Greenest (alive) first, reddest (dead) last, so the dead
        # trajectories stay visible on top of the green pile.
        for idx in np.argsort(-np.nanmean(P, axis=1)):
            pts = np.column_stack([steps, smooth[idx]])
            segments = np.stack([pts[:-1], pts[1:]], axis=1)
            lines = LineCollection(
                segments, cmap=dead_to_alive, norm=plt.Normalize(0, 1),
                linewidths=0.7, alpha=0.55,
            )
            lines.set_array(P[idx][:-1])
            ax.add_collection(lines)
        # add_collection does not autoscale the view, so pin the x-range explicitly.
        ax.set_xlim(steps[0], steps[-1])
        ax.plot(steps, np.nanmedian(smooth, axis=0), color="k", lw=1.6)
        return

    above = (K > 1.0).any(axis=1)
    for group, shade in ((smooth[~above], "0.55"), (smooth[above], "C3")):
        for row in group:
            ax.plot(steps, row, color=shade, lw=0.6, alpha=0.6)
    ax.plot(steps, np.nanmedian(smooth, axis=0), color="k", lw=1.6)
    share = float(above.mean())
    ax.text(
        0.97, 0.97, f"{share:.0%} cross KL=1",
        transform=ax.transAxes, ha="right", va="top",
        fontsize=8, color="C3" if share > 0.5 else "0.3",
    )
def make_kl_figure(cells: list[dict], a: Args, out_path: Path) -> None:
    """Save the KL-only variant of Figure 1: one row, one KL panel per alpha.

    Each panel pools the per-prompt KL trajectories of ``cells`` for its alpha
    (``_draw_kl_panel`` renders band or spaghetti) and marks KL=1 (solid) and
    t=20 (dotted).

    With ``--spaghetti --color-by-pmass``, lines are colored by the paired
    pmass_eval. A shared colorbar is added only when at least one panel could
    actually be colored. Panels whose pmass_eval does not pair with KL fall
    back to KL=1 coloring, with a warning.

    Args:
        cells: cells as returned by ``load_cells``.
        a: parsed CLI options.
        out_path: PNG path to write (160 dpi).
    """
    n_cols = len(a.alphas)
    fig, axes = plt.subplots(1, n_cols, figsize=(4.0 * n_cols, 3.2),
                             sharex=True, sharey=True, squeeze=False)
    # _draw_kl_panel only uses P in spaghetti mode, so only pool it there.
    want_pmass = a.spaghetti and a.color_by_pmass
    if a.color_by_pmass and not a.spaghetti:
        logger.warning("--color-by-pmass only affects --spaghetti mode; ignoring it")
    fork = _first_fork(cells) if want_pmass else []
    any_colored = False
    for j, alpha in enumerate(a.alphas):
        ax = axes[0, j]
        K = _pool_kl(cells, alpha, T=a.window)
        Pe = _pool_pmass_eval(cells, alpha, fork, T=a.window) if want_pmass else None
        if Pe is not None:
            if Pe.size and Pe.shape == K.shape:
                any_colored = True
            elif K.size:
                logger.warning(
                    f"alpha={alpha}: pmass_eval {Pe.shape} not paired with KL {K.shape}; "
                    "using KL=1 coloring for this panel"
                )
        _draw_kl_panel(ax, K, a, P=Pe)
        ax.axhline(1.0, color="k", lw=1.0)
        ax.axvline(20, color="k", ls=":", lw=0.8)
        ax.set_title(rf"$\alpha = {alpha}$ (n={K.shape[0]} traj)")
        ax.set_ylim(0, a.kl_ymax)
        ax.set_xlabel("token")
        if j == 0:
            ax.set_ylabel("KL(steered || base) [nats]")

    # Describe the coloring that was actually drawn.
    if any_colored:
        mode = "individual trajectories (color = paired pmass_eval)"
    elif a.spaghetti:
        mode = "individual trajectories (red=ever crossed KL=1)"
    else:
        mode = f"p50 + p10..p90 band, smoothed rolling-{a.roll}"
    label = a.model_contains or "all models"
    fig.suptitle(
        f"KL trajectory on N=8 held-out long-form prompts ({label})\n"
        f"{mode}. Solid line: KL=1 nat. Dotted v-line: t=20.",
        fontsize=10,
    )
    if any_colored:
        from matplotlib.cm import ScalarMappable
        from matplotlib.colors import LinearSegmentedColormap, Normalize
        cmap = LinearSegmentedColormap.from_list("alive", ["#c0392b", "#f1c40f", "#27ae60"])
        sm = ScalarMappable(norm=Normalize(0, 1), cmap=cmap)
        sm.set_array([])
        cbar = fig.colorbar(sm, ax=axes[0, :].tolist(), location="right", shrink=0.8, pad=0.02)
        cbar.set_label("pmass (0=dead, 1=alive)")
    fig.tight_layout(rect=(0, 0, 1, 0.86))
    fig.savefig(out_path, dpi=160, bbox_inches="tight")
    plt.close(fig)  # don't accumulate open figures when called repeatedly
    logger.info(f"KL-only figure -> {out_path}")
def make_table(root: Path) -> pl.DataFrame:
    """Summarise calibrated c_star across all runs under ``root``.

    Every subdirectory (names starting with "_" skipped) that has a
    calib.json contributes one record: its calib fields plus "run" (the
    directory name). Records are grouped by (model_short, method, window),
    where model_short is the part of "model" after the last "/". The table
    is not filtered by window or model.

    Returns:
        One row per group with c_mean, c_std, n_seeds (record count) and
        c_cv = c_std / |c_mean|, sorted by the group keys. An empty DataFrame
        is returned when no calib.json is found.
    """
    calib_files = [
        run_dir / "calib.json"
        for run_dir in sorted(root.iterdir())
        if run_dir.is_dir() and not run_dir.name.startswith("_")
    ]
    records = [
        json.loads(path.read_text()) | {"run": path.parent.name}
        for path in calib_files
        if path.exists()
    ]
    if not records:
        return pl.DataFrame()

    keys = ["model_short", "method", "window"]
    with_short = pl.DataFrame(records).with_columns(
        pl.col("model").str.split("/").list.last().alias("model_short")
    )
    summary = (
        with_short.group_by(keys)
        .agg(
            pl.col("c_star").mean().alias("c_mean"),
            pl.col("c_star").std().alias("c_std"),
            pl.len().alias("n_seeds"),
        )
        .sort(keys)
    )
    return summary.with_columns((pl.col("c_std") / pl.col("c_mean").abs()).alias("c_cv"))
def _pool_kl(cells: list[dict], alpha: str, T: int) -> np.ndarray:
"""Stack per-prompt KL trajectories from all cells -> (N, T) ndarray."""
rows = []
for c in cells:
per_prompt = c["traj"]["per_prompt_per_t_kl"].get(alpha, [])
for r in per_prompt:
arr = np.full(T, np.nan)
arr[: len(r)] = r[: T]
rows.append(arr)
return np.array(rows) if rows else np.zeros((0, T))
def _first_fork(cells: list[dict]) -> list[int]:
for c in cells:
if c["pmass"].get("computed", True):
return c["pmass"]["fork_points"]
return []
def _pool_pmass_eval(cells: list[dict], alpha: str, fork_points: list[int], T: int) -> np.ndarray:
"""Pool pmass on EVAL_PROMPTS (paired with KL) and interpolate from fork
points to per-token, returning (N, T). NaN for cells without pmass_eval.
"""
rows = []
for c in cells:
if not c["pmass"].get("computed", True):
continue
per_prompt = c["pmass"].get("pmass_eval", {}).get(alpha, [])
for r in per_prompt:
xs = np.array(fork_points[: len(r)])
ys = np.array(r, dtype=float)
t = np.arange(T)
interp = np.interp(t, xs, ys, left=ys[0], right=ys[-1])
rows.append(interp)
return np.array(rows) if rows else np.zeros((0, T))
def _pool_pmass(cells: list[dict], alpha: str) -> tuple[np.ndarray, list[int]]:
rows = []
fork = None
for c in cells:
if not c["pmass"].get("computed", True):
continue
f = c["pmass"]["fork_points"]
if fork is None: fork = f
per_prompt = c["pmass"]["pmass"].get(alpha, [])
for r in per_prompt:
rows.append(r)
if not rows:
return np.zeros((0, len(fork or []))), fork or []
L = max(len(r) for r in rows)
arr = np.full((len(rows), L), np.nan)
for i, r in enumerate(rows):
arr[i, : len(r)] = r
return arr, fork
def make_figure(cells: list[dict], a: Args, out_path: Path) -> None:
    """Save the two-row Figure 1: KL per token (top) and forked-answer pmass (bottom).

    There is one column per alpha in ``a.alphas``.

    * Top row: drawn with ``_draw_kl_panel`` (band or spaghetti). KL=1 is
      marked with a solid line and t=20 with a dotted line.
    * Bottom row: pools pmass on the legacy yes/no prompts at the fork
      positions, shown as a p10..p90 band with p50 markers. In spaghetti mode
      it shows thin per-prompt lines with a black nan-median instead.

    The two rows use different prompt sets, so they are not paired per
    trajectory. With ``--spaghetti --color-by-pmass`` a pmass colorbar is
    added only if at least one KL panel could actually be colored.

    Args:
        cells: cells as returned by ``load_cells``.
        a: parsed CLI options.
        out_path: PNG path to write (140 dpi).
    """
    n_cols = len(a.alphas)
    fig, axes = plt.subplots(2, n_cols, figsize=(4.0 * n_cols, 5.5),
                             sharex="row", sharey="row", squeeze=False)
    n_cells = len(cells)
    # Describe what was actually loaded rather than hard-coding one model / cell grid.
    models = sorted({str(c.get("model", "?")).split("/")[-1] for c in cells})
    n_methods = len({c.get("method") for c in cells})
    fig.suptitle(
        f"Figure 1: iso-KL calibration on {', '.join(models)}\n"
        f"Top: KL(steered || base) per token on long-form held-out prompts.\n"
        f"Bottom: forked-answer pmass on yes/no reasoning prompts (high=alive, low=collapsed).\n"
        f"{n_cells} cells ({n_methods} methods) x N=8 prompts. "
        f"Solid h-line: KL=1 nat. Dotted v-line: t=20.",
        fontsize=10,
    )
    # _draw_kl_panel only uses P in spaghetti mode, so only pool it there.
    want_pmass = a.spaghetti and a.color_by_pmass
    if a.color_by_pmass and not a.spaghetti:
        logger.warning("--color-by-pmass only affects --spaghetti mode; ignoring it")
    eval_fork = _first_fork(cells) if want_pmass else []
    any_colored = False
    for j, alpha in enumerate(a.alphas):
        # ---- top: KL ----
        ax = axes[0, j]
        K = _pool_kl(cells, alpha, T=a.window)
        Pe = _pool_pmass_eval(cells, alpha, eval_fork, T=a.window) if want_pmass else None
        if Pe is not None:
            if Pe.size and Pe.shape == K.shape:
                any_colored = True
            elif K.size:
                logger.warning(
                    f"alpha={alpha}: pmass_eval {Pe.shape} not paired with KL {K.shape}; "
                    "using KL=1 coloring for this panel"
                )
        _draw_kl_panel(ax, K, a, P=Pe)
        ax.axhline(1.0, color="k", lw=1.0)
        ax.axvline(20, color="k", ls=":", lw=0.8)
        ax.set_title(rf"$\alpha = {alpha}$ (n={K.shape[0]} traj)")
        ax.set_ylim(0, a.kl_ymax)
        if j == 0:
            ax.set_ylabel("KL(steered || base) [nats]")
        ax.set_xlabel("token")
        # ---- bottom: pmass ----
        ax2 = axes[1, j]
        P, fork = _pool_pmass(cells, alpha)
        # Only columns that have a fork position can be plotted.
        n_forks = min(P.shape[1], len(fork))
        if P.size and n_forks:
            P = P[:, :n_forks]
            xs = np.asarray(fork[:n_forks])
            if a.spaghetti:
                for row in P:
                    ax2.plot(xs, row, color="C1", lw=0.6, alpha=0.5)
                ax2.plot(xs, np.nanmedian(P, axis=0), color="k", lw=1.6, marker="o", ms=3)
            else:
                p50 = np.nanpercentile(P, 50, axis=0)
                p10 = np.nanpercentile(P, 10, axis=0)
                p90 = np.nanpercentile(P, 90, axis=0)
                ax2.fill_between(xs, p10, p90, alpha=0.25, color="C1", lw=0)
                ax2.plot(xs, p50, color="C1", lw=1.6, marker="o", ms=3)
        ax2.axvline(20, color="k", ls=":", lw=0.8)
        ax2.set_ylim(-0.02, 1.05)
        if j == 0:
            ax2.set_ylabel('pmass = p(true/1) + p(false/0)\nat fork t')
        ax2.set_xlabel("fork token")
    if any_colored:
        from matplotlib.cm import ScalarMappable
        from matplotlib.colors import LinearSegmentedColormap, Normalize
        cmap = LinearSegmentedColormap.from_list("alive", ["#c0392b", "#f1c40f", "#27ae60"])
        sm = ScalarMappable(norm=Normalize(0, 1), cmap=cmap)
        sm.set_array([])
        cbar = fig.colorbar(sm, ax=axes[0, :].tolist(), location="right", shrink=0.8, pad=0.02)
        cbar.set_label("pmass (0=dead, 1=alive)")
    fig.tight_layout(rect=(0, 0, 1, 0.88))
    fig.savefig(out_path, dpi=140, bbox_inches="tight")
    plt.close(fig)  # don't accumulate open figures when called repeatedly
    logger.info(f"figure -> {out_path}")
def main(a: Args):
    """Load matching cells, write the c_star table, then draw the requested figure.

    Exits via SystemExit when no cell under ``a.runs_root`` matches
    ``a.window`` (and ``a.model_contains``). table.csv and table.md are
    written only when the table is non-empty.
    """
    runs_root = Path(a.runs_root)
    out_dir = Path(a.out)
    out_dir.mkdir(parents=True, exist_ok=True)

    cells = load_cells(runs_root, window=a.window, model_contains=a.model_contains)
    if not cells:
        raise SystemExit(f"no cells with window={a.window} under {runs_root}")
    logger.info(f"loaded {len(cells)} cells (window={a.window})")

    table = make_table(runs_root)
    if not table.is_empty():
        table.write_csv(out_dir / "table.csv")
        md_path = out_dir / "table.md"
        markdown = tabulate(table.rows(), headers=table.columns, tablefmt="pipe", floatfmt=".3f")
        md_path.write_text(markdown)
        logger.info(f"table -> {md_path}\n{markdown}")

    draw, filename = (
        (make_kl_figure, "figure1_kl_only.png") if a.kl_only else (make_figure, "figure1.png")
    )
    draw(cells, a, out_dir / filename)
if __name__ == "__main__":
    # Parse CLI flags into Args, then run the aggregation.
    cli_args = tyro.cli(Args)
    main(cli_args)