mirror of
https://github.com/wassname/isokl_steering_calibration.git
synced 2026-06-27 18:05:12 +08:00
150 lines
5.6 KiB
Python
150 lines
5.6 KiB
Python
"""Aggregate per-cell outputs into Figure 1 + the headline table.
|
|
|
|
Figure 1: two stacked subplots.
|
|
Top: per-token p95 KL trajectory. x = token offset; y = KL(steer || base).
|
|
Colour by method, linestyle by alpha (solid=1, dashed=2), seed bands
|
|
as thin lines, faceted by model. Horizontal at target_kl=1.
|
|
Bottom: branch-pmass at fork points. x = fork token offset; y = mean pmass
|
|
across held-out prompts; bands = +/- 1 std across seeds.
|
|
|
|
Table: one row per (model, method), columns = c_star (mean +/- std across seeds),
|
|
KL_p95 @ alpha=1, KL_p95 @ alpha=2, pmass @ alpha=1, pmass @ alpha=2.
|
|
|
|
Usage:
|
|
python scripts/aggregate.py --runs_root outputs --out figs/
|
|
"""
|
|
from __future__ import annotations
|
|
import json
|
|
from collections import defaultdict
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
|
|
import polars as pl
|
|
import tyro
|
|
from loguru import logger
|
|
|
|
|
|
@dataclass
|
|
class Args:
|
|
runs_root: str = "outputs"
|
|
out: str = "figs"
|
|
|
|
|
|
def load_cells(root: Path) -> list[dict]:
|
|
cells = []
|
|
for d in sorted(root.iterdir()):
|
|
if not d.is_dir():
|
|
continue
|
|
calib = d / "calib.json"
|
|
if not calib.exists():
|
|
continue
|
|
meta = json.loads(calib.read_text())
|
|
traj = json.loads((d / "trajectory.json").read_text())
|
|
pmass = json.loads((d / "pmass.json").read_text())
|
|
cells.append({"id": d.name, **meta, "traj": traj, "pmass": pmass})
|
|
return cells
|
|
|
|
|
|
def make_table(cells: list[dict]) -> pl.DataFrame:
|
|
rows = []
|
|
by_mm = defaultdict(list)
|
|
for c in cells:
|
|
by_mm[(c["model"], c["method"])].append(c)
|
|
for (model, method), group in by_mm.items():
|
|
c_stars = [g["c_star"] for g in group]
|
|
# pmass: mean over fork_points and prompts at each alpha, then across seeds
|
|
for alpha in ("1.0", "2.0"):
|
|
kls = []
|
|
pms = []
|
|
for g in group:
|
|
kls.append(g["traj"]["per_t_p95_kl"][alpha])
|
|
pms.append(g["pmass"]["pmass"][alpha])
|
|
kls_flat = [x for arr in kls for x in arr]
|
|
pms_flat = [x for prompt in pms for arr in prompt for x in arr]
|
|
rows.append({
|
|
"model": model.split("/")[-1],
|
|
"method": method,
|
|
"alpha": float(alpha),
|
|
"c_star_mean": sum(c_stars) / len(c_stars),
|
|
"n_seeds": len(group),
|
|
"kl_p95_mean": sum(kls_flat) / max(len(kls_flat), 1),
|
|
"pmass_mean": sum(pms_flat) / max(len(pms_flat), 1),
|
|
})
|
|
return pl.DataFrame(rows)
|
|
|
|
|
|
def make_figure(cells: list[dict], out_path: Path) -> None:
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
models = sorted({c["model"] for c in cells})
|
|
methods = sorted({c["method"] for c in cells})
|
|
fig, axes = plt.subplots(2, len(models), figsize=(5 * len(models), 7),
|
|
sharex="col", squeeze=False)
|
|
cmap = plt.get_cmap("tab10")
|
|
method_color = {m: cmap(i) for i, m in enumerate(methods)}
|
|
|
|
for ci, model in enumerate(models):
|
|
ax_kl = axes[0, ci]
|
|
ax_pm = axes[1, ci]
|
|
ax_kl.set_title(model.split("/")[-1])
|
|
ax_kl.axhline(1.0, color="black", linestyle=":", linewidth=0.8, alpha=0.5)
|
|
ax_kl.set_ylabel("p95 KL(steer || base)")
|
|
ax_pm.set_xlabel("token offset")
|
|
ax_pm.set_ylabel("branch pmass")
|
|
ax_pm.set_ylim(-0.02, 1.02)
|
|
ax_kl.set_yscale("log")
|
|
|
|
for method in methods:
|
|
for alpha, ls in [("1.0", "-"), ("2.0", "--")]:
|
|
kls = [c["traj"]["per_t_p95_kl"][alpha]
|
|
for c in cells if c["model"] == model and c["method"] == method]
|
|
if not kls:
|
|
continue
|
|
arr = np.array(kls)
|
|
x = np.arange(arr.shape[1])
|
|
ax_kl.plot(x, arr.mean(0), color=method_color[method],
|
|
linestyle=ls, linewidth=2,
|
|
label=f"{method} a={alpha}")
|
|
if arr.shape[0] > 1:
|
|
ax_kl.fill_between(x, arr.min(0), arr.max(0),
|
|
color=method_color[method], alpha=0.12)
|
|
|
|
pms = [c["pmass"]["pmass"][alpha]
|
|
for c in cells if c["model"] == model and c["method"] == method]
|
|
if not pms:
|
|
continue
|
|
# pms: list of (n_seed) of (n_prompt) of (n_fork)
|
|
pms_arr = np.array(pms) # (n_seed, n_prompt, n_fork)
|
|
fork = cells[0]["pmass"]["fork_points"]
|
|
mean = pms_arr.mean(axis=(0, 1))
|
|
std = pms_arr.std(axis=(0, 1))
|
|
ax_pm.plot(fork, mean, color=method_color[method],
|
|
linestyle=ls, linewidth=2)
|
|
ax_pm.fill_between(fork, mean - std, mean + std,
|
|
color=method_color[method], alpha=0.12)
|
|
if ci == 0:
|
|
ax_kl.legend(fontsize=8, loc="upper left")
|
|
fig.tight_layout()
|
|
fig.savefig(out_path, dpi=150, bbox_inches="tight")
|
|
logger.info(f"figure -> {out_path}")
|
|
|
|
|
|
def main(a: Args):
|
|
out = Path(a.out); out.mkdir(parents=True, exist_ok=True)
|
|
cells = load_cells(Path(a.runs_root))
|
|
if not cells:
|
|
raise SystemExit(f"no cells under {a.runs_root}")
|
|
logger.info(f"loaded {len(cells)} cells")
|
|
|
|
df = make_table(cells)
|
|
df.write_csv(out / "table.csv")
|
|
md = df.to_pandas().to_markdown(index=False, floatfmt=".3f")
|
|
(out / "table.md").write_text(md)
|
|
logger.info(f"table -> {out/'table.md'}\n{md}")
|
|
|
|
make_figure(cells, out / "figure1.png")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main(tyro.cli(Args))
|