mirror of
https://github.com/wassname/isokl_steering_calibration.git
synced 2026-06-27 18:05:12 +08:00
417 lines
18 KiB
Python
417 lines
18 KiB
Python
"""Aggregate per-cell outputs into Figure 1 + headline table.
|
|
|
|
Layout (default --no-kl-only): 2 rows x N alpha cols.
|
|
Top row: KL(steered || base) per token on EVAL_PROMPTS.
|
|
Band mode (default): p50 line + p10..p90 band, rolling-{roll} smooth.
|
|
Spaghetti mode: individual trajectories.
|
|
default coloring: red if traj ever crosses KL=1, grey otherwise.
|
|
--color-by-pmass: segments coloured by paired pmass_eval(t) using a
|
|
red->yellow->green colormap (dead->alive). Requires
|
|
pmass.json with non-empty pmass_eval (run_cell.py
|
|
with --compute-pmass).
|
|
Bottom row: forked-answer pmass at fork_points.
|
|
Uses pmass[alpha] (legacy yes/no reasoning prompts). Different prompt set
|
|
than the KL panel above, so the row-to-row link is across cells, not
|
|
paired per-trajectory. For paired analysis see scripts/survival.py with
|
|
--metric pmass on pmass_eval.
|
|
|
|
Reading caveats:
|
|
- alpha=1 sits near KL=1 on CALIB_PROMPTS by construction. KL on EVAL_PROMPTS
|
|
only tests generalisation, not budget choice -- not an independent test.
|
|
- The honest coherence test is pmass / pmass_eval (real forced-choice mass
|
|
drop), see survival.py for KM-style curves.
|
|
|
|
Table: per (model, method, window) c_star mean +/- std across seeds.
|
|
|
|
Usage:
|
|
# smoothed band:
|
|
python scripts/aggregate.py --runs_root outputs --out figs/
|
|
# raw spaghetti:
|
|
python scripts/aggregate.py --runs_root outputs/qwen05_w512 \
|
|
--out figs/qwen05_pretty_raw --spaghetti --roll 1 \
|
|
--alphas 0.5 1.0 2.0 4.0
|
|
# KL spaghetti coloured by paired pmass_eval:
|
|
python scripts/aggregate.py --runs_root outputs/qwen05_w512 \
|
|
--out figs/qwen05_color --spaghetti --color-by-pmass
|
|
"""
|
|
from __future__ import annotations
|
|
import json
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import polars as pl
|
|
import tyro
|
|
from loguru import logger
|
|
from tabulate import tabulate
|
|
import matplotlib.pyplot as plt
|
|
|
|
# Plot styling is best-effort: use seaborn's theme when it can be imported and
# applied, otherwise fall back to a style that ships with matplotlib.
_RC_OVERRIDES = {
    "axes.titlesize": 11,
    "axes.labelsize": 10,
    "axes.spines.top": False,
    "axes.spines.right": False,
    "figure.titlesize": 11,
    "figure.dpi": 110,
}

try:
    import seaborn as sns

    sns.set_theme(context="notebook", style="whitegrid", palette="deep", font_scale=0.95)
    plt.rcParams.update(_RC_OVERRIDES)
except Exception:
    # Any failure above (missing package, theming error) lands here on purpose.
    plt.style.use("ggplot")
|
|
|
|
|
|
@dataclass
class Args:
    """Command-line options for the aggregation script.

    Parsed with ``tyro.cli(Args)`` at the bottom of this file. The trailing
    comments on each field describe how the field is used by the plotting
    helpers in this module.
    """

    runs_root: str = "outputs"
    out: str = "figs"
    window: int = 512  # only this window enters the figure
    roll: int = 65  # smoothing window for KL trajectory
    alphas: tuple[str, ...] = ("0.25", "0.5", "0.75", "1.0", "1.5", "2.0", "4.0")
    kl_ymax: float = 6.0
    model_contains: str = ""
    kl_only: bool = False
    spaghetti: bool = False  # plot individual trajectories instead of p10..p90 band
    color_by_pmass: bool = False  # color KL spaghetti lines by paired pmass (requires pmass_eval)
    line_alpha: float | None = None  # per-line alpha override; None = auto clip(2.5/n,.08,.35)
    line_lw: float = 0.18  # per-trajectory linewidth; full opacity needs very thin lines
    median_lw: float = 0.75  # median linewidth
    quantile_lines: bool = False  # band mode: shaded p10..p90 and p25..p75 bands plus p50 line
    mark_t: int = -1  # optional vertical token marker; -1 disables
|
|
|
|
|
|
def _rolling_mean(x: np.ndarray, w: int) -> np.ndarray:
|
|
"""Rolling mean along axis 0 with edge padding so output len == input len."""
|
|
if len(x) < w or w <= 1:
|
|
return x
|
|
pad = w // 2
|
|
xp = np.pad(x, pad, mode="edge")
|
|
kernel = np.ones(w) / w
|
|
return np.convolve(xp, kernel, mode="valid")[: len(x)]
|
|
|
|
|
|
def load_cells(root: Path, window: int, model_contains: str = "") -> list[dict]:
    """Collect the run directories under ``root`` that belong in the figure.

    A sub-directory is kept when it does not start with ``_``, contains
    ``calib.json``, ``trajectory.json`` and ``pmass.json``, its calib
    ``window`` equals ``window``, and (when ``model_contains`` is non-empty)
    its calib ``model`` contains that substring. Directories whose trajectory
    file lacks ``per_prompt_per_t_kl`` are skipped with a warning.

    Each returned dict is the calib metadata plus ``id`` (directory name),
    ``traj`` (parsed trajectory.json) and ``pmass`` (parsed pmass.json).
    Directories are visited in sorted order.
    """
    loaded: list[dict] = []
    for run_dir in sorted(root.iterdir()):
        if not run_dir.is_dir() or run_dir.name.startswith("_"):
            continue

        calib_path = run_dir / "calib.json"
        traj_path = run_dir / "trajectory.json"
        pmass_path = run_dir / "pmass.json"
        if not all(p.exists() for p in (calib_path, traj_path, pmass_path)):
            continue

        meta = json.loads(calib_path.read_text())
        if meta.get("window") != window:
            continue
        if model_contains and model_contains not in meta.get("model", ""):
            continue

        traj = json.loads(traj_path.read_text())
        pm = json.loads(pmass_path.read_text())
        # Outputs written before the per-prompt KL redesign cannot be plotted.
        if "per_prompt_per_t_kl" not in traj:
            logger.warning(f"skipping stale {run_dir.name} (no per_prompt_per_t_kl)")
            continue

        loaded.append({"id": run_dir.name, **meta, "traj": traj, "pmass": pm})
    return loaded
|
|
|
|
|
|
def _draw_kl_panel(ax, K: np.ndarray, a: Args, P: np.ndarray | None = None) -> None:
|
|
"""Draw KL trajectories on ax. spaghetti mode: thin per-trajectory lines.
|
|
color_by_pmass: each segment colored by pmass(t) using a green->red colormap
|
|
(1.0 = lively green, 0.0 = dead red). Otherwise colored by whether traj
|
|
ever crossed KL=1. Smoothing per-line via _rolling_mean(roll). Black median.
|
|
"""
|
|
if not K.size:
|
|
return
|
|
finite_cols = np.where(np.isfinite(K).any(axis=0))[0]
|
|
if finite_cols.size == 0:
|
|
return
|
|
K = K[:, : finite_cols[-1] + 1]
|
|
if P is not None and P.size:
|
|
P = P[:, : K.shape[1]]
|
|
xs = np.arange(K.shape[1])
|
|
if a.spaghetti:
|
|
# optional per-line smoothing
|
|
Kp = np.array([_rolling_mean(row, a.roll) for row in K]) if a.roll > 1 else K
|
|
if a.color_by_pmass and P is not None and P.size and P.shape == K.shape:
|
|
# LineCollection per trajectory, color = pmass(t).
|
|
# Draw greenest (alive) first, reddest (dead) last so red sits on top
|
|
# of the green pile -- otherwise the eye loses the dead trajectories.
|
|
from matplotlib.collections import LineCollection
|
|
from matplotlib.colors import LinearSegmentedColormap
|
|
cmap = LinearSegmentedColormap.from_list("alive", ["#c0392b", "#f1c40f", "#27ae60"]) # dead->alive
|
|
order = np.argsort(-np.nanmean(P, axis=1)) # high pmass first, low pmass last
|
|
for traj, pmass_row in zip(Kp[order], P[order]):
|
|
pts = np.column_stack([xs, traj])
|
|
segs = np.stack([pts[:-1], pts[1:]], axis=1)
|
|
lc = LineCollection(segs, cmap=cmap, norm=plt.Normalize(0, 1),
|
|
linewidths=a.line_lw, alpha=float(a.line_alpha if a.line_alpha is not None else np.clip(2.5/max(K.shape[0],1), 0.08, 0.35)))
|
|
lc.set_array(pmass_row[:-1])
|
|
ax.add_collection(lc)
|
|
ax.set_xlim(xs[0], xs[-1])
|
|
med = np.nanmedian(Kp, axis=0)
|
|
ax.plot(xs, med, color="k", lw=a.median_lw)
|
|
else:
|
|
crossed = (K > 1.0).any(axis=1)
|
|
for traj in Kp[~crossed]:
|
|
ax.plot(xs, traj, color="0.55", lw=a.line_lw, alpha=0.5)
|
|
for traj in Kp[crossed]:
|
|
ax.plot(xs, traj, color="C3", lw=a.line_lw, alpha=0.5)
|
|
med = np.nanmedian(Kp, axis=0)
|
|
ax.plot(xs, med, color="k", lw=a.median_lw)
|
|
frac = float(crossed.mean())
|
|
ax.text(0.97, 0.97, f"{frac:.0%} cross KL=1",
|
|
transform=ax.transAxes, ha="right", va="top",
|
|
fontsize=8, color="C3" if frac > 0.5 else "0.3")
|
|
else:
|
|
p50 = np.nanpercentile(K, 50, axis=0)
|
|
p10 = np.nanpercentile(K, 10, axis=0)
|
|
p90 = np.nanpercentile(K, 90, axis=0)
|
|
p50s = _rolling_mean(p50, a.roll)
|
|
p10s = _rolling_mean(p10, a.roll)
|
|
p90s = _rolling_mean(p90, a.roll)
|
|
if a.quantile_lines:
|
|
# standard quantile fan: outer band light, inner band dark, p50 line
|
|
p25s = _rolling_mean(np.nanpercentile(K, 25, axis=0), a.roll)
|
|
p75s = _rolling_mean(np.nanpercentile(K, 75, axis=0), a.roll)
|
|
ax.fill_between(xs, p10s, p90s, alpha=0.15, color="C0", lw=0, label="p10..p90")
|
|
ax.fill_between(xs, p25s, p75s, alpha=0.32, color="C0", lw=0, label="p25..p75")
|
|
ax.plot(xs, p50s, color="C0", lw=1.5, label="p50")
|
|
else:
|
|
ax.fill_between(xs, p10s, p90s, alpha=0.25, color="C0", lw=0)
|
|
ax.plot(xs, p50s, color="C0", lw=1.6)
|
|
|
|
|
|
def make_kl_figure(cells: list[dict], a: Args, out_path: Path) -> None:
    """Render the KL-only figure (one row, one panel per alpha) to ``out_path``.

    Each panel shows the pooled KL trajectories for one entry of ``a.alphas``,
    drawn by ``_draw_kl_panel``. All panels share axes: the x-range ends
    shortly after the last token with a finite KL in any panel (capped at
    ``a.window``), and the y-range is 1.4x the 99th percentile of the finite,
    non-negative values (capped at ``a.kl_ymax``). A KL=1 line is drawn when it
    fits in the y-range, otherwise the panel is annotated as off-scale.
    ``a.mark_t >= 0`` adds a dotted vertical marker.

    With ``a.quantile_lines`` the same figure is additionally saved next to
    ``out_path`` with a ``_quantile_lines`` suffix. The figure is closed
    before returning.
    """
    fig, axes = plt.subplots(
        1, len(a.alphas), figsize=(3.7 * len(a.alphas), 3.4),
        sharex=True, sharey=True, squeeze=False, constrained_layout=True,
    )
    label = a.model_contains or "all models"

    # Pool once per alpha; the result feeds the title, the axis limits and the panels.
    pooled = {alpha: _pool_kl(cells, alpha, T=a.window) for alpha in a.alphas}
    n_max = max((K.shape[0] for K in pooled.values()), default=0)

    if a.spaghetti and a.color_by_pmass:
        mode = f"individual trajectories, roll={a.roll}, color=pmass"
    elif a.spaghetti:
        mode = "individual trajectories (red=ever crossed KL=1)"
    elif a.quantile_lines:
        mode = f"shaded quantiles p10/p25/p50/p75/p90, roll={a.roll}"
    else:
        mode = f"p50 + p10..p90 band, smoothed rolling-{a.roll}"
    fig.suptitle(
        f"KL trajectory on N={n_max} held-out long-form prompts ({label})\n"
        f"{mode}. Solid horizontal: KL=1 nat.",
        fontsize=10,
    )

    # Shared axis limits, derived from the finite part of every panel's data.
    x_stop = 1
    y_chunks: list[np.ndarray] = []
    for K in pooled.values():
        finite_cols = np.where(np.isfinite(K).any(axis=0))[0]
        if finite_cols.size:
            last = int(finite_cols[-1] + 1)
            x_stop = max(x_stop, last)
            vals = K[:, :last]
            y_chunks.append(vals[np.isfinite(vals) & (vals >= 0.0)])
    x_max = float(min(a.window, max(20, int(x_stop * 1.05))))
    y_data = np.concatenate(y_chunks) if y_chunks else np.zeros(0)
    if y_data.size:
        y_max = float(min(a.kl_ymax, max(1e-3, np.nanpercentile(y_data, 99) * 1.4)))
    else:
        y_max = 1.1

    fork_points = _first_fork(cells) if a.color_by_pmass else None
    for j, alpha in enumerate(a.alphas):
        ax = axes[0, j]
        K = pooled[alpha]
        Pe = (
            _pool_pmass_eval(cells, alpha, fork_points, T=a.window)
            if a.color_by_pmass else None
        )
        _draw_kl_panel(ax, K, a, P=Pe)
        if y_max >= 1.0:
            ax.axhline(1.0, color="k", lw=0.7)
        else:
            ax.text(0.98, 0.92, "KL=1 off-scale", transform=ax.transAxes,
                    ha="right", va="top", fontsize=8, color="0.25")
        if a.mark_t >= 0:
            ax.axvline(a.mark_t, color="k", ls=":", lw=0.6)
        ax.set_title(rf"$\alpha = {alpha}$ (n={K.shape[0]} traj)")
        ax.set_xlim(0, x_max)
        ax.set_ylim(0, y_max)
        ax.set_xlabel("token")
        if j == 0:
            ax.set_ylabel("KL")

    if a.color_by_pmass:
        from matplotlib.cm import ScalarMappable
        from matplotlib.colors import LinearSegmentedColormap, Normalize

        cmap = LinearSegmentedColormap.from_list("alive", ["#c0392b", "#f1c40f", "#27ae60"])
        sm = ScalarMappable(norm=Normalize(0, 1), cmap=cmap)
        sm.set_array([])
        cbar = fig.colorbar(sm, ax=axes[0, :].tolist(), location="right",
                            shrink=0.72, pad=0.015, fraction=0.012)
        cbar.set_label("pmass", labelpad=2)

    fig.savefig(out_path, dpi=160, bbox_inches="tight")
    if a.quantile_lines:
        q_path = out_path.with_name(out_path.stem + "_quantile_lines" + out_path.suffix)
        fig.savefig(q_path, dpi=160, bbox_inches="tight")
        logger.info(f"KL quantile-line figure -> {q_path}")
    logger.info(f"KL-only figure -> {out_path}")
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
|
|
|
|
|
|
def make_table(root: Path) -> pl.DataFrame:
    """Build the headline c_star table from every ``calib.json`` under ``root``.

    Sub-directories starting with ``_`` or lacking ``calib.json`` are ignored.
    Rows are grouped by (model_short, method, window), where ``model_short``
    is the part of ``model`` after the last ``/``. Each group reports the mean
    and std of ``c_star``, the number of runs in the group (``n_seeds``) and
    the coefficient of variation ``c_cv = c_std / |c_mean|``. An empty frame
    is returned when no calib file is found.
    """
    run_dirs = [
        d for d in sorted(root.iterdir())
        if d.is_dir() and not d.name.startswith("_")
    ]
    records = [
        json.loads(calib.read_text()) | {"run": calib.parent.name}
        for calib in (d / "calib.json" for d in run_dirs)
        if calib.exists()
    ]
    if not records:
        return pl.DataFrame()

    keys = ["model_short", "method", "window"]
    model_short = pl.col("model").str.split("/").list.last().alias("model_short")
    c_star = pl.col("c_star")
    summary = (
        pl.DataFrame(records)
        .with_columns(model_short)
        .group_by(keys)
        .agg(
            c_star.mean().alias("c_mean"),
            c_star.std().alias("c_std"),
            pl.len().alias("n_seeds"),
        )
        .sort(keys)
    )
    return summary.with_columns(
        (pl.col("c_std") / pl.col("c_mean").abs()).alias("c_cv"),
    )
|
|
|
|
|
|
def _pool_kl(cells: list[dict], alpha: str, T: int) -> np.ndarray:
|
|
"""Stack per-prompt KL trajectories from all cells -> (N, T) ndarray."""
|
|
rows = []
|
|
for c in cells:
|
|
per_prompt = c["traj"]["per_prompt_per_t_kl"].get(alpha, [])
|
|
for r in per_prompt:
|
|
arr = np.full(T, np.nan)
|
|
arr[: len(r)] = r[: T]
|
|
rows.append(arr)
|
|
return np.array(rows) if rows else np.zeros((0, T))
|
|
|
|
|
|
def _first_fork(cells: list[dict]) -> list[int]:
|
|
for c in cells:
|
|
if c["pmass"].get("computed", True):
|
|
return c["pmass"]["fork_points"]
|
|
return []
|
|
|
|
|
|
def _pool_pmass_eval(cells: list[dict], alpha: str, fork_points: list[int], T: int) -> np.ndarray:
|
|
"""Pool pmass on EVAL_PROMPTS (paired with KL) and interpolate from fork
|
|
points to per-token, returning (N, T). NaN for cells without pmass_eval.
|
|
"""
|
|
rows = []
|
|
for c in cells:
|
|
if not c["pmass"].get("computed", True):
|
|
continue
|
|
per_prompt = c["pmass"].get("pmass_eval", {}).get(alpha, [])
|
|
for r in per_prompt:
|
|
xs = np.array(fork_points[: len(r)])
|
|
ys = np.array(r, dtype=float)
|
|
t = np.arange(T)
|
|
interp = np.interp(t, xs, ys, left=ys[0], right=ys[-1])
|
|
rows.append(interp)
|
|
return np.array(rows) if rows else np.zeros((0, T))
|
|
|
|
|
|
def _pool_pmass(cells: list[dict], alpha: str) -> tuple[np.ndarray, list[int]]:
|
|
rows = []
|
|
fork = None
|
|
for c in cells:
|
|
if not c["pmass"].get("computed", True):
|
|
continue
|
|
f = c["pmass"]["fork_points"]
|
|
if fork is None: fork = f
|
|
per_prompt = c["pmass"]["pmass"].get(alpha, [])
|
|
for r in per_prompt:
|
|
rows.append(r)
|
|
if not rows:
|
|
return np.zeros((0, len(fork or []))), fork or []
|
|
L = max(len(r) for r in rows)
|
|
arr = np.full((len(rows), L), np.nan)
|
|
for i, r in enumerate(rows):
|
|
arr[i, : len(r)] = r
|
|
return arr, fork
|
|
|
|
|
|
def make_figure(cells: list[dict], a: Args, out_path: Path) -> None:
    """Render the two-row Figure 1 (KL on top, forked pmass below) to ``out_path``.

    One column per entry of ``a.alphas``. The top row is drawn by
    ``_draw_kl_panel`` from the pooled KL trajectories. The bottom row shows
    the pooled fork-point pmass, either as individual lines plus median
    (``a.spaghetti``) or as a p10..p90 band with a p50 line.

    The title reports the model names and the method count found in
    ``cells`` rather than fixed values. The dotted vertical marker sits at
    ``a.mark_t`` when that is >= 0 and at the historical default t=20
    otherwise. The figure is closed before returning.
    """
    fig, axes = plt.subplots(
        2, len(a.alphas), figsize=(4.0 * len(a.alphas), 5.5),
        sharex="row", sharey="row", squeeze=False,
    )
    n_cells = len(cells)

    # Describe what was actually loaded instead of a fixed experiment setup.
    model_names = sorted({str(c["model"]).split("/")[-1] for c in cells if c.get("model")})
    model_label = ", ".join(model_names) or a.model_contains or "all models"
    methods = {c["method"] for c in cells if c.get("method") is not None}
    methods_note = f" ({len(methods)} methods)" if methods else ""
    # Keep the marker this figure always had unless the caller picks one.
    mark_t = a.mark_t if a.mark_t >= 0 else 20

    fig.suptitle(
        f"Figure 1: iso-KL calibration on {model_label}\n"
        "Top: KL(steered || base) per token on long-form held-out prompts.\n"
        "Bottom: forked-answer pmass on yes/no reasoning prompts (high=alive, low=collapsed).\n"
        f"{n_cells} cells{methods_note}. "
        f"Solid h-line: KL=1 nat. Dotted v-line: t={mark_t}.",
        fontsize=10,
    )

    fork_points = _first_fork(cells) if a.color_by_pmass else None
    for j, alpha in enumerate(a.alphas):
        # ---- top: KL ----
        ax = axes[0, j]
        K = _pool_kl(cells, alpha, T=a.window)
        Pe = (
            _pool_pmass_eval(cells, alpha, fork_points, T=a.window)
            if a.color_by_pmass else None
        )
        _draw_kl_panel(ax, K, a, P=Pe)
        ax.axhline(1.0, color="k", lw=1.0)
        ax.axvline(mark_t, color="k", ls=":", lw=0.8)
        ax.set_title(rf"$\alpha = {alpha}$ (n={K.shape[0]} traj)")
        ax.set_ylim(0, a.kl_ymax)
        if j == 0:
            ax.set_ylabel("KL(steered || base) [nats]")
        ax.set_xlabel("token")

        # ---- bottom: pmass ----
        ax2 = axes[1, j]
        P, fork = _pool_pmass(cells, alpha)
        if P.size:
            xs = np.array(fork[: P.shape[1]])
            if a.spaghetti:
                for row in P:
                    ax2.plot(xs, row, color="C1", lw=0.6, alpha=0.5)
                ax2.plot(xs, np.nanmedian(P, axis=0), color="k", lw=1.6, marker="o", ms=3)
            else:
                p10, p50, p90 = np.nanpercentile(P, [10, 50, 90], axis=0)
                ax2.fill_between(xs, p10, p90, alpha=0.25, color="C1", lw=0)
                ax2.plot(xs, p50, color="C1", lw=1.6, marker="o", ms=3)
        ax2.axvline(mark_t, color="k", ls=":", lw=0.8)
        ax2.set_ylim(-0.02, 1.05)
        if j == 0:
            ax2.set_ylabel('pmass = p(true/1) + p(false/0)\nat fork t')
        ax2.set_xlabel("fork token")

    if a.color_by_pmass:
        from matplotlib.cm import ScalarMappable
        from matplotlib.colors import LinearSegmentedColormap, Normalize

        cmap = LinearSegmentedColormap.from_list("alive", ["#c0392b", "#f1c40f", "#27ae60"])
        sm = ScalarMappable(norm=Normalize(0, 1), cmap=cmap)
        sm.set_array([])
        cbar = fig.colorbar(sm, ax=axes[0, :].tolist(), location="right", shrink=0.8, pad=0.02)
        cbar.set_label("pmass (0=dead, 1=alive)")

    # Leave headroom for the multi-line suptitle.
    fig.tight_layout(rect=(0, 0, 1, 0.88))
    fig.savefig(out_path, dpi=140, bbox_inches="tight")
    logger.info(f"figure -> {out_path}")
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
|
|
|
|
|
|
def main(a: Args):
    """Load the cells, write the c_star table, then render the requested figure.

    Exits via ``SystemExit`` when no cell matches ``a.window`` (and
    ``a.model_contains``). The table is written as ``table.csv`` and
    ``table.md`` into ``a.out`` when at least one calib file exists. The
    figure is ``figure1_kl_only.png`` with ``a.kl_only`` and ``figure1.png``
    otherwise.
    """
    out_dir = Path(a.out)
    out_dir.mkdir(parents=True, exist_ok=True)
    runs_root = Path(a.runs_root)

    cells = load_cells(runs_root, window=a.window, model_contains=a.model_contains)
    if not cells:
        raise SystemExit(f"no cells with window={a.window} under {runs_root}")
    logger.info(f"loaded {len(cells)} cells (window={a.window})")

    table = make_table(runs_root)
    if not table.is_empty():
        table.write_csv(out_dir / "table.csv")
        md = tabulate(table.rows(), headers=table.columns, tablefmt="pipe", floatfmt=".3f")
        md_path = out_dir / "table.md"
        md_path.write_text(md)
        logger.info(f"table -> {md_path}\n{md}")

    if a.kl_only:
        render, filename = make_kl_figure, "figure1_kl_only.png"
    else:
        render, filename = make_figure, "figure1.png"
    render(cells, a, out_dir / filename)


if __name__ == "__main__":
    main(tyro.cli(Args))
|