wip: dense alpha sweep + auto-render figs (Qwen3.5-0.8B w=512); queued OLMo-2/Gemma 4B/Gemma 12B/OLMo-3 7B at w=4096

This commit is contained in:
wassname
2026-05-06 05:37:33 +08:00
parent 0bd7a11d2d
commit bd34b7580c
144 changed files with 342825 additions and 210 deletions
+36
View File
@@ -0,0 +1,36 @@
"""Debug: verify pmass measurement on un-steered baseline."""
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
model = AutoModelForCausalLM.from_pretrained(
"Qwen/Qwen2.5-0.5B-Instruct", torch_dtype=torch.float32
).eval()
prefill = '\n{"choice": '
schema = (
'Think briefly, then answer immediately and only with: '
'{"choice": true} or {"choice": false}. Do not output 1 or 0.'
)
q = "Is the Eiffel Tower located in Paris, France?"
msgs = [{"role": "user", "content": schema + " " + q}]
chat = tok.apply_chat_template(msgs, add_generation_prompt=True, tokenize=False)
full = chat + prefill
ids = tok(full, return_tensors="pt").input_ids
print(f"prompt len={ids.shape[1]} tokens")
with torch.no_grad():
logits = model(ids).logits[0, -1]
probs = torch.softmax(logits, dim=-1)
top = probs.topk(10)
print("\nTop-10 next tokens after full prefill (NO steer, NO rollout):")
for p, i in zip(top.values, top.indices):
print(f" id={int(i):6d} p={float(p):.4f} tok={tok.decode([int(i)])!r}")
a_ids = [16, 1866, 2514, 830, 3007] # 1 true True ' true' ' True'
b_ids = [15, 3849, 4049] # 0 false False
sa = float(probs[a_ids].sum())
sb = float(probs[b_ids].sum())
print(f"\nsum a_ids (true/1)={sa:.4f} sum b_ids (false/0)={sb:.4f} pmass={sa+sb:.4f}")
print(f"argmax: id={int(probs.argmax())} tok={tok.decode([int(probs.argmax())])!r}")
+320 -99
View File
@@ -1,148 +1,369 @@
"""Aggregate per-cell outputs into Figure 1 + the headline table.
"""Aggregate per-cell outputs into Figure 1 + headline table.
Figure 1: two stacked subplots.
Top: per-token p95 KL trajectory. x = token offset; y = KL(steer || base).
Colour by method, linestyle by alpha (solid=1, dashed=2), seed bands
as thin lines, faceted by model. Horizontal at target_kl=1.
Bottom: branch-pmass at fork points. x = fork token offset; y = mean pmass
across held-out prompts; bands = +/- 1 std across seeds.
Layout (default --no-kl-only): 2 rows x N alpha cols.
Top row: KL(steered || base) per token on EVAL_PROMPTS.
Band mode (default): p50 line + p10..p90 band, rolling-{roll} smooth.
Spaghetti mode: individual trajectories.
default coloring: red if traj ever crosses KL=1, grey otherwise.
--color-by-pmass: segments coloured by paired pmass_eval(t) using a
red->yellow->green colormap (dead->alive). Requires
pmass.json with non-empty pmass_eval (run_cell.py
with --compute-pmass).
Bottom row: forked-answer pmass at fork_points.
Uses pmass[alpha] (legacy yes/no reasoning prompts). Different prompt set
than the KL panel above, so the row-to-row link is across cells, not
paired per-trajectory. For paired analysis see scripts/survival.py with
--metric pmass on pmass_eval.
Table: one row per (model, method), columns = c_star (mean +/- std across seeds),
KL_p95 @ alpha=1, KL_p95 @ alpha=2, pmass @ alpha=1, pmass @ alpha=2.
Reading caveats:
- alpha=1 sits near KL=1 on CALIB_PROMPTS by construction. KL on EVAL_PROMPTS
only tests generalisation, not budget choice -- not an independent test.
- The honest coherence test is pmass / pmass_eval (real forced-choice mass
drop), see survival.py for KM-style curves.
Table: per (model, method, window) c_star mean +/- std across seeds.
Usage:
# smoothed band:
python scripts/aggregate.py --runs_root outputs --out figs/
# raw spaghetti:
python scripts/aggregate.py --runs_root outputs_qwen05_w512 \
--out figs_qwen05_pretty_raw --spaghetti --roll 1 \
--alphas 0.5 1.0 2.0 4.0
# KL spaghetti coloured by paired pmass_eval:
python scripts/aggregate.py --runs_root outputs_qwen05_w512 \
--out figs_qwen05_color --spaghetti --color-by-pmass
"""
from __future__ import annotations
import json
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
import numpy as np
import polars as pl
import tyro
from loguru import logger
from tabulate import tabulate
import matplotlib.pyplot as plt
try:
import seaborn as sns
sns.set_theme(context="notebook", style="whitegrid", palette="deep", font_scale=0.95)
plt.rcParams.update({
"axes.titlesize": 11,
"axes.labelsize": 10,
"axes.spines.top": False,
"axes.spines.right": False,
"figure.titlesize": 11,
"figure.dpi": 110,
})
except Exception:
plt.style.use("ggplot")
@dataclass
class Args:
runs_root: str = "outputs"
out: str = "figs"
window: int = 50 # only this window enters the figure (need long-enough traj for rolling-16)
roll: int = 16 # smoothing window for KL trajectory
alphas: tuple[str, ...] = ("0.25", "0.5", "0.75", "1.0", "1.5", "2.0", "4.0")
kl_ymax: float = 6.0
model_contains: str = ""
kl_only: bool = False
spaghetti: bool = False # plot individual trajectories instead of p10..p90 band
color_by_pmass: bool = False # color KL spaghetti lines by paired pmass (requires pmass_eval)
def load_cells(root: Path) -> list[dict]:
def _rolling_mean(x: np.ndarray, w: int) -> np.ndarray:
"""Rolling mean along axis 0 with edge padding so output len == input len."""
if len(x) < w or w <= 1:
return x
pad = w // 2
xp = np.pad(x, pad, mode="edge")
kernel = np.ones(w) / w
return np.convolve(xp, kernel, mode="valid")[: len(x)]
def load_cells(root: Path, window: int, model_contains: str = "") -> list[dict]:
cells = []
for d in sorted(root.iterdir()):
if not d.is_dir():
if not d.is_dir() or d.name.startswith("_"):
continue
calib = d / "calib.json"; traj_p = d / "trajectory.json"; pm_p = d / "pmass.json"
if not (calib.exists() and traj_p.exists() and pm_p.exists()):
continue
meta = json.loads(calib.read_text())
if meta.get("window") != window:
continue
if model_contains and model_contains not in meta.get("model", ""):
continue
traj = json.loads(traj_p.read_text())
pm = json.loads(pm_p.read_text())
# Skip stale outputs that lack per-prompt KL (pre-redesign).
if "per_prompt_per_t_kl" not in traj:
logger.warning(f"skipping stale {d.name} (no per_prompt_per_t_kl)")
continue
cells.append({"id": d.name, **meta, "traj": traj, "pmass": pm})
return cells
def _draw_kl_panel(ax, K: np.ndarray, a: Args, P: np.ndarray | None = None) -> None:
"""Draw KL trajectories on ax. spaghetti mode: thin per-trajectory lines.
color_by_pmass: each segment colored by pmass(t) using a green->red colormap
(1.0 = lively green, 0.0 = dead red). Otherwise colored by whether traj
ever crossed KL=1. Smoothing per-line via _rolling_mean(roll). Black median.
"""
if not K.size:
return
xs = np.arange(K.shape[1])
if a.spaghetti:
# optional per-line smoothing
Kp = np.array([_rolling_mean(row, a.roll) for row in K]) if a.roll > 1 else K
if a.color_by_pmass and P is not None and P.size and P.shape == K.shape:
# LineCollection per trajectory, color = pmass(t).
# Draw greenest (alive) first, reddest (dead) last so red sits on top
# of the green pile -- otherwise the eye loses the dead trajectories.
from matplotlib.collections import LineCollection
from matplotlib.colors import LinearSegmentedColormap
cmap = LinearSegmentedColormap.from_list("alive", ["#c0392b", "#f1c40f", "#27ae60"]) # dead->alive
order = np.argsort(-np.nanmean(P, axis=1)) # high pmass first, low pmass last
for traj, pmass_row in zip(Kp[order], P[order]):
pts = np.column_stack([xs, traj])
segs = np.stack([pts[:-1], pts[1:]], axis=1)
lc = LineCollection(segs, cmap=cmap, norm=plt.Normalize(0, 1),
linewidths=0.7, alpha=0.55)
lc.set_array(pmass_row[:-1])
ax.add_collection(lc)
ax.set_xlim(xs[0], xs[-1])
med = np.nanmedian(Kp, axis=0)
ax.plot(xs, med, color="k", lw=1.6)
else:
crossed = (K > 1.0).any(axis=1)
for traj in Kp[~crossed]:
ax.plot(xs, traj, color="0.55", lw=0.6, alpha=0.6)
for traj in Kp[crossed]:
ax.plot(xs, traj, color="C3", lw=0.6, alpha=0.6)
med = np.nanmedian(Kp, axis=0)
ax.plot(xs, med, color="k", lw=1.6)
frac = float(crossed.mean())
ax.text(0.97, 0.97, f"{frac:.0%} cross KL=1",
transform=ax.transAxes, ha="right", va="top",
fontsize=8, color="C3" if frac > 0.5 else "0.3")
else:
p50 = np.nanpercentile(K, 50, axis=0)
p10 = np.nanpercentile(K, 10, axis=0)
p90 = np.nanpercentile(K, 90, axis=0)
p50s = _rolling_mean(p50, a.roll)
p10s = _rolling_mean(p10, a.roll)
p90s = _rolling_mean(p90, a.roll)
ax.fill_between(xs, p10s, p90s, alpha=0.25, color="C0", lw=0)
ax.plot(xs, p50s, color="C0", lw=1.6)
def make_kl_figure(cells: list[dict], a: Args, out_path: Path) -> None:
import matplotlib.pyplot as plt
fig, axes = plt.subplots(1, len(a.alphas), figsize=(4.0 * len(a.alphas), 3.2),
sharex=True, sharey=True, squeeze=False)
label = a.model_contains or "all models"
mode = "individual trajectories (red=ever crossed KL=1)" if a.spaghetti \
else f"p50 + p10..p90 band, smoothed rolling-{a.roll}"
fig.suptitle(
f"KL trajectory on N=8 held-out long-form prompts ({label})\n"
f"{mode}. Solid line: KL=1 nat. Dotted v-line: t=20.",
fontsize=10,
)
for j, alpha in enumerate(a.alphas):
ax = axes[0, j]
K = _pool_kl(cells, alpha, T=a.window)
Pe = _pool_pmass_eval(cells, alpha, _first_fork(cells), T=a.window) if a.color_by_pmass else None
_draw_kl_panel(ax, K, a, P=Pe)
ax.axhline(1.0, color="k", lw=1.0)
ax.axvline(20, color="k", ls=":", lw=0.8)
ax.set_title(rf"$\alpha = {alpha}$ (n={K.shape[0]} traj)")
ax.set_ylim(0, a.kl_ymax)
ax.set_xlabel("token")
if j == 0:
ax.set_ylabel("KL(steered || base) [nats]")
if a.color_by_pmass:
from matplotlib.cm import ScalarMappable
from matplotlib.colors import LinearSegmentedColormap, Normalize
cmap = LinearSegmentedColormap.from_list("alive", ["#c0392b", "#f1c40f", "#27ae60"])
sm = ScalarMappable(norm=Normalize(0, 1), cmap=cmap); sm.set_array([])
cbar = fig.colorbar(sm, ax=axes[0, :].tolist(), location="right", shrink=0.8, pad=0.02)
cbar.set_label("pmass (0=dead, 1=alive)")
fig.tight_layout(rect=(0, 0, 1, 0.86))
fig.savefig(out_path, dpi=160, bbox_inches="tight")
logger.info(f"KL-only figure -> {out_path}")
def make_table(root: Path) -> pl.DataFrame:
rows = []
for d in sorted(root.iterdir()):
if not d.is_dir() or d.name.startswith("_"):
continue
calib = d / "calib.json"
if not calib.exists():
continue
meta = json.loads(calib.read_text())
traj = json.loads((d / "trajectory.json").read_text())
pmass = json.loads((d / "pmass.json").read_text())
cells.append({"id": d.name, **meta, "traj": traj, "pmass": pmass})
return cells
rows.append(json.loads(calib.read_text()) | {"run": d.name})
if not rows:
return pl.DataFrame()
df = pl.DataFrame(rows)
df = df.with_columns(pl.col("model").str.split("/").list.last().alias("model_short"))
g = (df.group_by(["model_short", "method", "window"])
.agg(pl.col("c_star").mean().alias("c_mean"),
pl.col("c_star").std().alias("c_std"),
pl.len().alias("n_seeds"))
.sort(["model_short", "method", "window"]))
g = g.with_columns(
(pl.col("c_std") / pl.col("c_mean").abs()).alias("c_cv"),
)
return g
def make_table(cells: list[dict]) -> pl.DataFrame:
def _pool_kl(cells: list[dict], alpha: str, T: int) -> np.ndarray:
"""Stack per-prompt KL trajectories from all cells -> (N, T) ndarray."""
rows = []
by_mm = defaultdict(list)
for c in cells:
by_mm[(c["model"], c["method"])].append(c)
for (model, method), group in by_mm.items():
c_stars = [g["c_star"] for g in group]
# pmass: mean over fork_points and prompts at each alpha, then across seeds
for alpha in ("1.0", "2.0"):
kls = []
pms = []
for g in group:
kls.append(g["traj"]["per_t_p95_kl"][alpha])
pms.append(g["pmass"]["pmass"][alpha])
kls_flat = [x for arr in kls for x in arr]
pms_flat = [x for prompt in pms for arr in prompt for x in arr]
rows.append({
"model": model.split("/")[-1],
"method": method,
"alpha": float(alpha),
"c_star_mean": sum(c_stars) / len(c_stars),
"n_seeds": len(group),
"kl_p95_mean": sum(kls_flat) / max(len(kls_flat), 1),
"pmass_mean": sum(pms_flat) / max(len(pms_flat), 1),
})
return pl.DataFrame(rows)
per_prompt = c["traj"]["per_prompt_per_t_kl"].get(alpha, [])
for r in per_prompt:
arr = np.full(T, np.nan)
arr[: len(r)] = r[: T]
rows.append(arr)
return np.array(rows) if rows else np.zeros((0, T))
def make_figure(cells: list[dict], out_path: Path) -> None:
def _first_fork(cells: list[dict]) -> list[int]:
for c in cells:
if c["pmass"].get("computed", True):
return c["pmass"]["fork_points"]
return []
def _pool_pmass_eval(cells: list[dict], alpha: str, fork_points: list[int], T: int) -> np.ndarray:
"""Pool pmass on EVAL_PROMPTS (paired with KL) and interpolate from fork
points to per-token, returning (N, T). NaN for cells without pmass_eval.
"""
rows = []
for c in cells:
if not c["pmass"].get("computed", True):
continue
per_prompt = c["pmass"].get("pmass_eval", {}).get(alpha, [])
for r in per_prompt:
xs = np.array(fork_points[: len(r)])
ys = np.array(r, dtype=float)
t = np.arange(T)
interp = np.interp(t, xs, ys, left=ys[0], right=ys[-1])
rows.append(interp)
return np.array(rows) if rows else np.zeros((0, T))
def _pool_pmass(cells: list[dict], alpha: str) -> tuple[np.ndarray, list[int]]:
rows = []
fork = None
for c in cells:
if not c["pmass"].get("computed", True):
continue
f = c["pmass"]["fork_points"]
if fork is None: fork = f
per_prompt = c["pmass"]["pmass"].get(alpha, [])
for r in per_prompt:
rows.append(r)
if not rows:
return np.zeros((0, len(fork or []))), fork or []
L = max(len(r) for r in rows)
arr = np.full((len(rows), L), np.nan)
for i, r in enumerate(rows):
arr[i, : len(r)] = r
return arr, fork
def make_figure(cells: list[dict], a: Args, out_path: Path) -> None:
import matplotlib.pyplot as plt
import numpy as np
models = sorted({c["model"] for c in cells})
methods = sorted({c["method"] for c in cells})
fig, axes = plt.subplots(2, len(models), figsize=(5 * len(models), 7),
sharex="col", squeeze=False)
cmap = plt.get_cmap("tab10")
method_color = {m: cmap(i) for i, m in enumerate(methods)}
for ci, model in enumerate(models):
ax_kl = axes[0, ci]
ax_pm = axes[1, ci]
ax_kl.set_title(model.split("/")[-1])
ax_kl.axhline(1.0, color="black", linestyle=":", linewidth=0.8, alpha=0.5)
ax_kl.set_ylabel("p95 KL(steer || base)")
ax_pm.set_xlabel("token offset")
ax_pm.set_ylabel("branch pmass")
ax_pm.set_ylim(-0.02, 1.02)
ax_kl.set_yscale("log")
fig, axes = plt.subplots(2, len(a.alphas), figsize=(4.0 * len(a.alphas), 5.5),
sharex="row", sharey="row", squeeze=False)
n_cells = len(cells)
fig.suptitle(
f"Figure 1: iso-KL calibration on Qwen2.5-0.5B-Instruct\n"
f"Top: KL(steered || base) per token on long-form held-out prompts.\n"
f"Bottom: forked-answer pmass on yes/no reasoning prompts (high=alive, low=collapsed).\n"
f"{n_cells} cells (3 methods x 1 seed) x N=8 prompts. "
f"Solid h-line: KL=1 nat. Dotted v-line: t=20.",
fontsize=10,
)
for method in methods:
for alpha, ls in [("1.0", "-"), ("2.0", "--")]:
kls = [c["traj"]["per_t_p95_kl"][alpha]
for c in cells if c["model"] == model and c["method"] == method]
if not kls:
continue
arr = np.array(kls)
x = np.arange(arr.shape[1])
ax_kl.plot(x, arr.mean(0), color=method_color[method],
linestyle=ls, linewidth=2,
label=f"{method} a={alpha}")
if arr.shape[0] > 1:
ax_kl.fill_between(x, arr.min(0), arr.max(0),
color=method_color[method], alpha=0.12)
for j, alpha in enumerate(a.alphas):
# ---- top: KL ----
ax = axes[0, j]
K = _pool_kl(cells, alpha, T=a.window)
Pe = _pool_pmass_eval(cells, alpha, _first_fork(cells), T=a.window) if a.color_by_pmass else None
_draw_kl_panel(ax, K, a, P=Pe)
ax.axhline(1.0, color="k", lw=1.0)
ax.axvline(20, color="k", ls=":", lw=0.8)
ax.set_title(rf"$\alpha = {alpha}$ (n={K.shape[0]} traj)")
ax.set_ylim(0, a.kl_ymax)
if j == 0:
ax.set_ylabel("KL(steered || base) [nats]")
ax.set_xlabel("token")
pms = [c["pmass"]["pmass"][alpha]
for c in cells if c["model"] == model and c["method"] == method]
if not pms:
continue
# pms: list of (n_seed) of (n_prompt) of (n_fork)
pms_arr = np.array(pms) # (n_seed, n_prompt, n_fork)
fork = cells[0]["pmass"]["fork_points"]
mean = pms_arr.mean(axis=(0, 1))
std = pms_arr.std(axis=(0, 1))
ax_pm.plot(fork, mean, color=method_color[method],
linestyle=ls, linewidth=2)
ax_pm.fill_between(fork, mean - std, mean + std,
color=method_color[method], alpha=0.12)
if ci == 0:
ax_kl.legend(fontsize=8, loc="upper left")
fig.tight_layout()
fig.savefig(out_path, dpi=150, bbox_inches="tight")
# ---- bottom: pmass ----
ax2 = axes[1, j]
P, fork = _pool_pmass(cells, alpha)
if P.size:
xs = np.array(fork[: P.shape[1]])
if a.spaghetti:
for row in P:
ax2.plot(xs, row, color="C1", lw=0.6, alpha=0.5)
ax2.plot(xs, np.nanmedian(P, axis=0), color="k", lw=1.6, marker="o", ms=3)
else:
p50 = np.nanpercentile(P, 50, axis=0)
p10 = np.nanpercentile(P, 10, axis=0)
p90 = np.nanpercentile(P, 90, axis=0)
ax2.fill_between(xs, p10, p90, alpha=0.25, color="C1", lw=0)
ax2.plot(xs, p50, color="C1", lw=1.6, marker="o", ms=3)
ax2.axvline(20, color="k", ls=":", lw=0.8)
ax2.set_ylim(-0.02, 1.05)
if j == 0:
ax2.set_ylabel('pmass = p(true/1) + p(false/0)\nat fork t')
ax2.set_xlabel("fork token")
if a.color_by_pmass:
from matplotlib.cm import ScalarMappable
from matplotlib.colors import LinearSegmentedColormap, Normalize
cmap = LinearSegmentedColormap.from_list("alive", ["#c0392b", "#f1c40f", "#27ae60"])
sm = ScalarMappable(norm=Normalize(0, 1), cmap=cmap); sm.set_array([])
cbar = fig.colorbar(sm, ax=axes[0, :].tolist(), location="right", shrink=0.8, pad=0.02)
cbar.set_label("pmass (0=dead, 1=alive)")
fig.tight_layout(rect=(0, 0, 1, 0.88))
fig.savefig(out_path, dpi=140, bbox_inches="tight")
logger.info(f"figure -> {out_path}")
def main(a: Args):
out = Path(a.out); out.mkdir(parents=True, exist_ok=True)
cells = load_cells(Path(a.runs_root))
root = Path(a.runs_root)
cells = load_cells(root, window=a.window, model_contains=a.model_contains)
if not cells:
raise SystemExit(f"no cells under {a.runs_root}")
logger.info(f"loaded {len(cells)} cells")
raise SystemExit(f"no cells with window={a.window} under {root}")
logger.info(f"loaded {len(cells)} cells (window={a.window})")
df = make_table(cells)
df.write_csv(out / "table.csv")
md = df.to_pandas().to_markdown(index=False, floatfmt=".3f")
(out / "table.md").write_text(md)
logger.info(f"table -> {out/'table.md'}\n{md}")
df = make_table(root)
if not df.is_empty():
df.write_csv(out / "table.csv")
md = tabulate(df.rows(), headers=df.columns, tablefmt="pipe", floatfmt=".3f")
(out / "table.md").write_text(md)
logger.info(f"table -> {out/'table.md'}\n{md}")
make_figure(cells, out / "figure1.png")
if a.kl_only:
make_kl_figure(cells, a, out / "figure1_kl_only.png")
else:
make_figure(cells, a, out / "figure1.png")
if __name__ == "__main__":
+243
View File
@@ -0,0 +1,243 @@
"""Empirical audit of branch_pmass: load model, generate one rollout per alpha,
print decoded gen text and top-k tokens at the prefill end-point, and dump
everything to JSON for review.
Goal: distinguish between
(a) pmass measurement is wrong (top tokens at prefill end DON'T match the
pmass we read out of the schema-token groups), vs
(b) pmass is right but the model just doesn't put mass on schema tokens
naturally (steering isn't the problem, the prompt+prefill is), vs
(c) pmass is right and steering really does collapse coherence.
For first prompt per alpha:
- decoded full generation text
- per fork point: top-10 tokens with prob, plus pmass(true)+pmass(false), p_true
Usage:
uv run --extra all python scripts/audit_pmass.py \
--model Qwen/Qwen3.5-0.8B --window 64 --out audit_pmass.json
"""
from __future__ import annotations
import json
from dataclasses import dataclass, field
from pathlib import Path
import torch
import tyro
from loguru import logger
# Import the live module so any audit reflects current code
from iso_kl_figure import (
MeanDiffC, PCAC, DirectionalAblationC,
train, calibrate_iso_kl,
)
from iso_kl_figure.branch_pmass import collect_choice_token_ids, branch_pmass
# Re-import constants from run_cell so audit uses the same prompts/schema
import importlib.util, sys
_rc_path = Path(__file__).parent / "run_cell.py"
_spec = importlib.util.spec_from_file_location("_run_cell", _rc_path)
_rc = importlib.util.module_from_spec(_spec)
sys.modules["_run_cell"] = _rc
_spec.loader.exec_module(_rc)
CALIB_PROMPTS = _rc.CALIB_PROMPTS
EVAL_PROMPTS = _rc.EVAL_PROMPTS
_QUESTIONS = _rc._QUESTIONS
_SCHEMA = _rc._SCHEMA
PREFILL_STR = _rc.PREFILL_STR
POS_NEG = _rc.POS_NEG
METHOD_MAP = {"mean_diff": MeanDiffC, "pca": PCAC, "directional_ablation": DirectionalAblationC}
@dataclass
class Args:
model: str = "Qwen/Qwen3.5-0.8B"
method: str = "mean_diff"
seed: int = 0
window: int = 64
layer_frac: float = 0.6
target_kl: float = 1.0
device: str = "cuda"
dtype: str = "bfloat16"
alphas: tuple[float, ...] = (0.0, 0.25, 0.5, 0.75, 1.0, 1.5, 2.0, 4.0)
fork_points: tuple[int, ...] = (0, 8, 16, 32, 64)
out: str = "audit_pmass.json"
top_k: int = 10
use_qa_prompt: bool = True # True: yes/no q+_SCHEMA; False: long-form EVAL_PROMPTS[0]
skip_calib: bool = False # if True, use fixed_coeffs as raw c per alpha (no iso-KL bisection)
fixed_coeffs: tuple[float, ...] = (0.0, 0.25, 0.5, 0.75, 1.0, 1.5, 2.0, 4.0) # used when skip_calib=True; aligned with --alphas
cross_check_branch_pmass: bool = True # also call branch_pmass and compare to local recompute
def _set_seed(s: int):
import random, numpy as np
random.seed(s); np.random.seed(s); torch.manual_seed(s)
if torch.cuda.is_available(): torch.cuda.manual_seed_all(s)
@torch.no_grad()
def topk_at_prefill_end(v, model, tok, prompt_ids, rolled_ids, fork_points,
prefill_str, a_ids, b_ids, k=10, device="cuda"):
"""Mirror branch_pmass logic but ALSO return top-k tokens for inspection."""
import copy
pids = prompt_ids.to(device); rolled = rolled_ids.to(device)
P = pids.shape[0]; T = rolled.shape[0]
pre_t = torch.tensor(tok.encode(prefill_str, add_special_tokens=False),
device=device, dtype=torch.long)
a_t = torch.tensor(list(a_ids), dtype=torch.long, device=device)
b_t = torch.tensor(list(b_ids), dtype=torch.long, device=device)
all_t = torch.cat([a_t, b_t])
out_per_fork = []
for t in fork_points:
if t > T:
out_per_fork.append({"t": int(t), "skipped": True, "reason": f"t>T={T}"})
continue
prefix = rolled[:t]
seq = torch.cat([pids, prefix, pre_t]).unsqueeze(0)
with v(model):
logits = model(seq).logits[0, -1].float()
probs = torch.softmax(logits, dim=-1)
# top-k
tk_p, tk_i = probs.topk(k)
topk = [(tok.decode([int(i)]), float(p)) for p, i in zip(tk_p.tolist(), tk_i.tolist())]
pa = float(probs[a_t].sum()); pb = float(probs[b_t].sum())
pm = pa + pb
pt = pa / pm if pm > 0 else float("nan")
out_per_fork.append({
"t": int(t),
"skipped": False,
"topk": topk,
"p_true_group": pa,
"p_false_group": pb,
"pmass": pm,
"p_true": pt,
"argmax": tok.decode([int(probs.argmax())]),
})
return out_per_fork
def main(a: Args):
_set_seed(a.seed)
from transformers import AutoModelForCausalLM, AutoTokenizer
dtype = getattr(torch, a.dtype)
logger.info(f"loading model={a.model}")
tok = AutoTokenizer.from_pretrained(a.model)
if tok.pad_token_id is None: tok.pad_token_id = tok.eos_token_id
model = AutoModelForCausalLM.from_pretrained(a.model, torch_dtype=dtype).to(a.device)
model.eval()
n_layers = model.config.num_hidden_layers
layer = int(a.layer_frac * n_layers)
cfg_cls = METHOD_MAP[a.method]
cfg = cfg_cls(coeff=1.0, layers=(layer,))
# Train + calibrate
pos = [tok.apply_chat_template([{"role": "user", "content": u},
{"role": "assistant", "content": p}], tokenize=False)
for u, (p, _) in zip(CALIB_PROMPTS, POS_NEG)]
neg = [tok.apply_chat_template([{"role": "user", "content": u},
{"role": "assistant", "content": n}], tokenize=False)
for u, (_, n) in zip(CALIB_PROMPTS, POS_NEG)]
v = train(model, tok, pos, neg, cfg, batch_size=4, max_length=128)
if a.skip_calib:
c_star = 1.0 # fixed_coeffs are absolute
logger.info(f"skip_calib: using fixed_coeffs={a.fixed_coeffs} as raw c")
else:
c_star, _ = calibrate_iso_kl(v, model, tok, CALIB_PROMPTS,
target_kl=a.target_kl, target_stat="kl_p95",
T=a.window, device=a.device)
logger.info(f"c_star={c_star:+.4f}")
a_ids, b_ids = collect_choice_token_ids(tok)
logger.info(f"a_ids (true-group)={a_ids} -> tokens={[tok.decode([i]) for i in a_ids]}")
logger.info(f"b_ids (false-group)={b_ids} -> tokens={[tok.decode([i]) for i in b_ids]}")
# Single prompt
if a.use_qa_prompt:
prompt = f"{_QUESTIONS[0]}\n\n{_SCHEMA}"
else:
prompt = EVAL_PROMPTS[0]
logger.info(f"prompt: {prompt[:120]}...")
ids = tok.apply_chat_template([{"role": "user", "content": prompt}],
add_generation_prompt=True,
return_tensors="pt").input_ids[0]
out = {
"model": a.model, "method": a.method, "seed": a.seed,
"c_star": c_star, "layer": layer,
"prompt": prompt, "use_qa_prompt": a.use_qa_prompt,
"prefill": PREFILL_STR,
"a_ids": a_ids, "b_ids": b_ids,
"a_tokens": [tok.decode([i]) for i in a_ids],
"b_tokens": [tok.decode([i]) for i in b_ids],
"fork_points": list(a.fork_points),
"alphas": {},
}
for i_alpha, alpha in enumerate(a.alphas):
if a.skip_calib:
v.cfg.coeff = a.fixed_coeffs[i_alpha]
else:
v.cfg.coeff = alpha * c_star
logger.info(f"=== alpha={alpha} coeff={v.cfg.coeff:+.4f} ===")
with v(model):
gen_out = model.generate(
ids.unsqueeze(0).to(a.device),
max_new_tokens=a.window,
pad_token_id=tok.pad_token_id,
eos_token_id=tok.eos_token_id,
do_sample=False,
return_dict_in_generate=True,
)
gen = gen_out.sequences[0, ids.shape[0]:]
gen_text = tok.decode(gen, skip_special_tokens=False)
gen_len = int(gen.shape[0])
logger.info(f" gen_len={gen_len} text[:200]={gen_text[:200]!r}")
per_fork = topk_at_prefill_end(
v, model, tok, ids, gen, list(a.fork_points),
PREFILL_STR, a_ids, b_ids, k=a.top_k, device=a.device,
)
# cross-check: compare to production branch_pmass output
bp_compare = None
if a.cross_check_branch_pmass:
bp = branch_pmass(
v, model, tok, ids, gen, list(a.fork_points),
PREFILL_STR, a_ids, b_ids,
rollout_cache=getattr(gen_out, "past_key_values", None),
device=a.device,
)
bp_compare = {
"pmass": bp["pmass"], "p_true": bp["p_true"],
"argmax_str": bp["argmax_str"], "was_thinking": bp["was_thinking"],
}
for row, bpm, bpt, bam in zip(per_fork, bp["pmass"], bp["p_true"], bp["argmax_str"]):
if row.get("skipped"): continue
local_pm = row["pmass"]; local_pt = row["p_true"]; local_am = row["argmax"]
tag = "OK"
if not (abs(local_pm - bpm) < 1e-3 and (local_am == bam)):
tag = "MISMATCH"
logger.info(f" cross-check t={row['t']:>3} {tag}: local pm={local_pm:.4f} argmax={local_am!r} | branch_pmass pm={bpm:.4f} argmax={bam!r}")
# log inline
for row in per_fork:
if row.get("skipped"):
logger.info(f" t={row['t']:>3} SKIPPED ({row['reason']})")
else:
top3 = ", ".join(f"{tok!r}={p:.3f}" for tok, p in row["topk"][:3])
logger.info(f" t={row['t']:>3} pmass={row['pmass']:.3f} "
f"p_true={row['p_true']:.3f} argmax={row['argmax']!r} top3=[{top3}]")
out["alphas"][str(alpha)] = {
"coeff": float(v.cfg.coeff),
"gen_text": gen_text, "gen_len": gen_len, "per_fork": per_fork,
"branch_pmass_compare": bp_compare,
}
Path(a.out).write_text(json.dumps(out, indent=2, default=str))
logger.info(f"DONE -> {a.out}")
if __name__ == "__main__":
main(tyro.cli(Args))
+22 -1
View File
@@ -5,9 +5,11 @@ cd "$(dirname "$0")/.."
ROOT="$PWD"
MODELS=(
"Qwen/Qwen3.5-0.8B"
"Qwen/Qwen2.5-0.5B-Instruct"
"meta-llama/Llama-3.2-1B-Instruct"
"Qwen/Qwen3-4B-Instruct-2507"
"Qwen/Qwen3-4B-Thinking-2507"
)
METHODS=("mean_diff" "directional_ablation" "pca")
SEEDS=(0 1 2)
@@ -16,6 +18,7 @@ WINDOWS=(20 50)
# Priority: small models first so the figure starts populating; 4B last.
prio_for_model() {
case "$1" in
*0.8B*) echo 35 ;;
*0.5B*) echo 30 ;;
*1B*) echo 20 ;;
*4B*) echo 10 ;;
@@ -31,7 +34,25 @@ for model in "${MODELS[@]}"; do
for window in "${WINDOWS[@]}"; do
run_id="$(basename "$model")_${method}_s${seed}_w${window}"
if [ -f "outputs/${run_id}/calib.json" ]; then
echo "skip ${run_id} (already done)"; continue
if env -i HOME="$HOME" PATH="$ROOT/.venv/bin:/usr/bin:/usr/local/bin" "$ROOT/.venv/bin/python" - <<PY
import json
from pathlib import Path
run = Path("outputs/${run_id}")
traj = json.loads((run / "trajectory.json").read_text()) if (run / "trajectory.json").exists() else {}
pm = json.loads((run / "pmass.json").read_text()) if (run / "pmass.json").exists() else {}
ok = (
"per_prompt_per_t_kl" in traj
and "schema" in pm
and "4.0" in traj.get("per_t_p95_kl", {})
and "4.0" in pm.get("pmass", {})
)
raise SystemExit(0 if ok else 1)
PY
then
echo "skip ${run_id} (fresh)"
continue
fi
echo "rerun ${run_id} (stale outputs)"
fi
label="why: stability of iso-KL calib for ${method} on ${model} (seed ${seed}, T=${window}); resolve: include cell in figure1 if it converges, else flag as bracket-pinned"
pueue add -w "$ROOT" -o "$prio" -l "$label" -- \
+98
View File
@@ -0,0 +1,98 @@
from __future__ import annotations
import os
import subprocess
import sys
from dataclasses import dataclass
from pathlib import Path
import tyro
from loguru import logger
@dataclass
class Args:
run_dir: str
threshold: float = 0.95
out_name: str = "figs_auto"
model_contains: str = ""
def _ensure_single_run_root(run_dir: Path) -> Path:
root = run_dir.parent / f"_{run_dir.name}_single"
root.mkdir(parents=True, exist_ok=True)
link = root / run_dir.name
if link.is_symlink() or link.exists():
try:
if link.resolve() == run_dir.resolve():
return root
except Exception:
pass
if link.is_symlink() or link.is_file():
link.unlink()
else:
raise RuntimeError(f"refusing to replace non-file staging path: {link}")
os.symlink(run_dir.resolve(), link)
return root
def _run(cmd: list[str], cwd: Path) -> None:
logger.info("$ " + " ".join(cmd))
subprocess.run(cmd, cwd=cwd, check=True)
def main(a: Args) -> None:
run_dir = Path(a.run_dir).resolve()
repo_root = Path(__file__).resolve().parents[1]
single_root = _ensure_single_run_root(run_dir)
out_dir = run_dir / a.out_name
out_dir.mkdir(parents=True, exist_ok=True)
calib = run_dir / "calib.json"
if not calib.exists():
raise SystemExit(f"missing {calib}")
model_filter = a.model_contains or run_dir.name.split("_", 1)[0]
common = [
sys.executable,
"scripts/survival.py",
"--runs-root", str(single_root),
"--out", str(out_dir / "survival"),
"--window", str(__import__("json").loads(calib.read_text())["window"]),
"--metric", "pmass_eval",
"--thresholds", str(a.threshold),
"--model-contains", model_filter,
]
_run(common, repo_root)
_run([
sys.executable,
"scripts/spaghetti_kl_alive.py",
"--runs-root", str(single_root),
"--out", str(out_dir / "spaghetti"),
"--window", str(__import__("json").loads(calib.read_text())["window"]),
"--threshold", str(a.threshold),
"--model-contains", model_filter,
], repo_root)
_run([
sys.executable,
"scripts/aggregate.py",
"--runs-root", str(single_root),
"--out", str(out_dir / "aggregate"),
"--window", str(__import__("json").loads(calib.read_text())["window"]),
"--spaghetti",
"--color-by-pmass",
"--kl-only",
"--model-contains", model_filter,
], repo_root)
pngs = [
out_dir / "survival" / "survival_pmass_eval.png",
out_dir / "spaghetti" / "kl_alive_spaghetti.png",
out_dir / "aggregate" / "figure1_kl_only.png",
]
for png in pngs:
logger.info(f"PNG -> {png}")
if __name__ == "__main__":
main(tyro.cli(Args))
+260 -48
View File
@@ -1,27 +1,56 @@
"""End-to-end runner for one (model, method, seed, window) cell.
Flow:
1. Load model + tokenizer (HF), set seed.
2. Build pos/neg prompts (cheap pair); train the steering Vector v.
3. Calibrate iso-KL at target_kl=1 over T=window tokens. Save full history
(incl. per-token KL arrays) to outputs/<run_id>/history.json.
4. Re-run measure_kl at coeff=alpha*c_star (alpha in {1, 2}) on a held-out
prompt set so the trajectory plot reflects generalisation, not the
calibration set itself. Save per-token p95 KL to trajectory.json.
5. For each held-out prompt, rollout T_eval tokens under the steered model,
then branch-pmass at fork_points {0, 5, ..., T_eval}. Save to
pmass.json. Use a JSON-format suffix so target tokens are well-defined.
Artefacts (per cell, under outputs/<run_id>/):
calib.json c_star + run metadata (model, method, seed, layer, window).
history.json bisection history (without per-token KL arrays).
trajectory.json per-token KL on EVAL_PROMPTS for each alpha:
per_t_p95_kl[alpha]: list[T]
per_prompt_per_t_kl[alpha]: list[N_prompts][T]
pmass.json forked-answer probability mass at fork_points:
pmass[alpha]: yes/no reasoning prompts (legacy probe)
pmass_eval[alpha]: SAME prompts as trajectory.json (paired w KL)
gen_lens_qa[alpha], gen_lens_eval[alpha]: T per rollout
(use for right-censoring -- NaN at t > T means rollout
EOS'd before that fork, NOT a measurement failure).
debug_first[alpha]: gen_text + per-fork pmass for the
FIRST qa & eval prompt (sanity-check; see survival.py).
pmass_eval is the metric to use when you want a per-trajectory
coherence signal aligned with the KL trajectory.
results.csv one row per alpha with kl_p95/mean/max.
Outputs one CSV row per (alpha, prompt) into outputs/<run_id>/results.csv
plus the artefacts above.
Flow:
1. Load model + tokenizer; set seed.
2. Build pos/neg pair (cheap content-vs-refusal); train Vector v.
3. Calibrate iso-KL: bisect coeff so per-token kl_p95 hits target_kl (default 1)
over T=window tokens on CALIB_PROMPTS.
4. For each alpha, set coeff=alpha*c_star and:
a) measure_kl on EVAL_PROMPTS -> per_prompt_per_t_kl trajectories.
b) (--compute-pmass) yes/no questions: rollout, then at each fork point
prefill `\n{"choice": ` and read mass on {true/True/ true/.../1/ 1}
and {false/.../0}. Skip with --skip-pmass-qa.
c) (--compute-pmass) repeat (b) on EVAL_PROMPTS -> pmass_eval, paired
with the KL trajectory on the same rollouts. This is the honest
coherence signal; KL is largely a calibration tautology at alpha=1.
Fork-point modes:
--fork-step N linear: {0, N, 2N, ..., window}
--fork-log log-spaced: {0, 1, 2, 4, 8, ..., window} (with --n-log-forks)
KV-cache:
Rollouts use generate(..., return_dict_in_generate=True). branch_pmass
uses full-recompute by default (cache reuse adds ~10% bf16 noise; verified
by scripts/audit_pmass.py). Pass use_rollout_cache=True only if you accept
the noise.
Usage (tyro):
python scripts/run_cell.py --model Qwen/Qwen2.5-0.5B-Instruct --method mean_diff \
--seed 0 --window 50 --run_id qwen05b_md_s0_w50
--seed 0 --window 512 --compute-pmass --skip-pmass-qa --fork-log
"""
from __future__ import annotations
import json
import os
import subprocess
import sys
from dataclasses import dataclass, field
from pathlib import Path
@@ -33,7 +62,7 @@ from iso_kl_figure import (
SteeringConfig, MeanDiffC, PCAC, DirectionalAblationC,
train, calibrate_iso_kl, measure_kl,
)
from iso_kl_figure.branch_pmass import branch_pmass
from iso_kl_figure.branch_pmass import branch_pmass, collect_choice_token_ids
CALIB_PROMPTS = [
@@ -43,14 +72,38 @@ CALIB_PROMPTS = [
"Briefly describe how plate tectonics shape mountain ranges.",
]
# Held-out prompts -- different from calibration; tests generalisation.
# EVAL prompts are yes/no questions paired with schema -- same contract as
# tinymfv/guided.py so branch_pmass at any fork can score format-compliance.
# The model is steered into <think>\n at the prompt boundary; at fork t we
# inject "\nI should answer now.</think>\n{\"choice\": " and read p(true|false).
EVAL_PROMPTS = [
"Explain why the sky is blue in two short sentences.",
"Summarise the plot of Hamlet in one paragraph.",
"List three causes of inflation.",
"Write a haiku about an octopus librarian.",
"Is the sky blue primarily because of Rayleigh scattering of sunlight?",
"Does Hamlet die at the end of Shakespeare's play 'Hamlet'?",
"Is excessive money supply growth one cause of inflation in modern economies?",
"Do octopuses have three hearts and blue copper-based blood?",
"Does photosynthesis use light energy to convert CO2 and water into glucose?",
"Was the Eiffel Tower originally built as the entrance arch for the 1889 World's Fair?",
"Can plate tectonic collisions form mountain ranges over geological time?",
"Is a circular import one possible cause of a Python script crashing on import?",
]
# pmass diagnostic prompts share the same schema as EVAL above.
_SCHEMA = (
'Reason for at least 80 words before the final answer. Then answer only with '
'{"choice": true} or {"choice": false}. If you use numeric JSON, use 1 for true and 0 for false.'
)
_QUESTIONS = [
"Is the Eiffel Tower located in Paris, France?",
"Do dolphins breathe air through lungs rather than gills?",
"Was Albert Einstein born in Germany?",
"Can humans naturally see ultraviolet light?",
"Did the Roman Empire fall in 476 CE?",
"Are tomatoes botanically classified as fruits?",
"Is the speed of light approximately 300,000 km/s?",
"Do octopuses have three hearts?",
]
PREFILL_STR = '\n{"choice": '
# Pos/neg pair: a generic content-vs-refusal direction. Cheap, not the main
# point -- the figure tests calibration *behaviour*, not direction quality.
POS_NEG = [
@@ -80,9 +133,14 @@ class Args:
out_root: str = "outputs"
device: str = "cuda"
dtype: str = "bfloat16"
suffix_str: str = ' Final answer in JSON: {"value": '
target_words: list[str] = field(default_factory=lambda: ["true", "false", "yes", "no"])
alphas: tuple[float, ...] = (0.25, 0.5, 0.75, 1.0, 1.5, 2.0, 4.0)
fork_step: int = 5
fork_log: bool = False # log-spaced fork points {0,1,2,4,8,...,window}
n_log_forks: int = 14 # number of log-spaced forks (incl. 0)
compute_pmass: bool = False
skip_pmass_qa: bool = False # skip yes/no pmass loop, only do paired pmass_eval
render_figs: bool = False # render single-run survival + spaghetti + KL pngs
render_threshold: float = 0.95
def _set_seed(s: int):
@@ -93,6 +151,19 @@ def _set_seed(s: int):
torch.cuda.manual_seed_all(s)
def _build_guided_prompt(tok, user_text: str, schema_hint: str = _SCHEMA) -> str:
"""Match tinymfv/guided.py: chat template + '<think>\\n' suffix so the model
is in thinking mode at every fork point. branch_pmass detects this and
splices '\\nI should answer now.</think>{prefill}' before scoring."""
full_user = f"{user_text}\n\n{schema_hint}" if schema_hint else user_text
msgs = [{"role": "user", "content": full_user}]
try:
p = tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
except TypeError:
p = tok.apply_chat_template(msgs, tokenize=False)
return p + "<think>\n"
def main(a: Args):
if not a.run_id:
a.run_id = f"{a.model.split('/')[-1]}_{a.method}_s{a.seed}_w{a.window}"
@@ -134,56 +205,186 @@ def main(a: Args):
)
v.cfg.coeff = c_star
logger.info(f"c_star = {c_star:+.4f}")
(out_dir / "history.json").write_text(json.dumps(history, indent=2))
# Strip per_prompt_per_t from history to keep file size small.
hist_slim = [{k: v_ for k, v_ in h.items() if k != "per_prompt_per_t"}
for h in history]
(out_dir / "history.json").write_text(json.dumps(hist_slim, indent=2))
(out_dir / "calib.json").write_text(json.dumps({
"c_star": c_star, "target_kl": a.target_kl, "window": a.window,
"method": a.method, "model": a.model, "seed": a.seed, "layer": layer,
}, indent=2))
# -- trajectory + pmass at alpha in {1, 2} on held-out prompts
# Sanity-check pmass on the BASE model (no steering): should be ~1.0,
# otherwise the prefill/schema isn't priming the right tokens.
a_ids, b_ids = collect_choice_token_ids(tok)
if a.compute_pmass:
logger.info(f"choice ids: a(true)={a_ids} b(false)={b_ids}")
# -- trajectory + pmass at each alpha on held-out prompts
rows = []
fork_points = list(range(0, a.window + 1, a.fork_step))
if a.fork_log:
# log-spaced including 0 and window: {0, 1, 2, 4, 8, ..., window}
import numpy as _np
raw = _np.unique(_np.round(_np.geomspace(1, a.window, a.n_log_forks - 1)).astype(int))
fork_points = [0] + [int(x) for x in raw]
fork_points = sorted(set(fp for fp in fork_points if fp <= a.window))
else:
fork_points = list(range(0, a.window + 1, a.fork_step))
logger.info(f"fork_points (n={len(fork_points)}): {fork_points}")
trajectory: dict[str, list] = {}
per_prompt_traj: dict[str, list] = {}
pmass_all: dict[str, list] = {}
for alpha in (1.0, 2.0):
pmass_eval_all: dict[str, list] = {} # pmass paired with EVAL_PROMPTS (same prompts as KL)
p_true_all: dict[str, list] = {}
argmax_all: dict[str, list] = {}
thinking_all: dict[str, list] = {}
answer_label_all: dict[str, list] = {}
gen_lens_qa: dict[str, list] = {} # T per (alpha, qa-prompt) -- right-censoring info
gen_lens_eval: dict[str, list] = {} # T per (alpha, eval-prompt)
debug_first: dict[str, dict] = {} # first prompt per alpha: gen_text + top-5 at each fork
# Pre-build guided EVAL ids ONCE so measure_kl and pmass_eval roll out the
# same prompt -- otherwise spaghetti coloring (KL trajectory colored by
# pmass) is meaningless because the rollouts diverge.
eval_prompt_strs = [_build_guided_prompt(tok, p, _SCHEMA) for p in EVAL_PROMPTS]
eval_ids_list = [
tok(s, return_tensors="pt", add_special_tokens=False).input_ids[0]
for s in eval_prompt_strs
]
for alpha in a.alphas:
v.cfg.coeff = alpha * c_star
logger.info(f"=== eval alpha={alpha} c={v.cfg.coeff:+.4f} ===")
m = measure_kl(v, model, tok, EVAL_PROMPTS, T=a.window, device=a.device)
m = measure_kl(v, model, tok, eval_ids_list, T=a.window, device=a.device)
trajectory[str(alpha)] = m["per_t_p95"]
per_prompt_traj[str(alpha)] = m["per_prompt_per_t"]
rows.append({"alpha": alpha, "coeff": v.cfg.coeff, "kl_p95": m["kl_p95"],
"kl_mean": m["kl_mean"], "kl_max": m["kl_max"]})
# pmass per held-out prompt
pm_for_alpha = []
for p in EVAL_PROMPTS:
ids = tok.apply_chat_template(
[{"role": "user", "content": p}],
add_generation_prompt=True, return_tensors="pt",
).input_ids[0]
pad = tok.pad_token_id
with v(model):
gen = model.generate(
ids.unsqueeze(0).to(a.device),
max_new_tokens=a.window,
pad_token_id=pad, eos_token_id=tok.eos_token_id,
do_sample=False,
)[0, ids.shape[0]:]
pm = branch_pmass(
v, model, tok, ids, gen, fork_points,
a.suffix_str, a.target_words, device=a.device,
)
pm_for_alpha.append(pm["pmass"])
pm_for_alpha, pt_for_alpha, ax_for_alpha, wt_for_alpha, ans_for_alpha = [], [], [], [], []
gen_lens_qa_alpha = []
gen_lens_eval_alpha = []
if a.compute_pmass and not a.skip_pmass_qa:
for q_idx, q in enumerate(_QUESTIONS):
prompt_str = _build_guided_prompt(tok, q, _SCHEMA)
ids = tok(prompt_str, return_tensors="pt", add_special_tokens=False).input_ids[0]
pad = tok.pad_token_id
with v(model):
gen_out = model.generate(
ids.unsqueeze(0).to(a.device),
max_new_tokens=a.window,
pad_token_id=pad, eos_token_id=tok.eos_token_id,
do_sample=False,
use_cache=True,
return_dict_in_generate=True,
)
gen = gen_out.sequences[0, ids.shape[0]:]
gen_lens_qa_alpha.append(int(gen.shape[0]))
# KV cache from steered rollout: ~100x speedup at each fork.
pm = branch_pmass(
v, model, tok, ids, gen, fork_points,
PREFILL_STR, a_ids, b_ids,
rollout_cache=gen_out.past_key_values,
device=a.device,
)
pm_for_alpha.append(pm["pmass"])
pt_for_alpha.append(pm["p_true"])
ax_for_alpha.append(pm["argmax_str"])
wt_for_alpha.append(pm["was_thinking"])
ans_for_alpha.append([
"true" if s in {"true", "True", " true", " True", "1", " 1"}
else "false" if s in {"false", "False", " false", " False", "0", " 0"}
else "other"
for s in pm["argmax_str"]
])
# debug dump for first QA prompt only
if q_idx == 0:
gen_text = tok.decode(gen, skip_special_tokens=False)
debug_first.setdefault(str(alpha), {})["qa"] = {
"prompt": prompt_str, "gen_text": gen_text, "gen_len": int(gen.shape[0]),
"pmass_per_fork": pm["pmass"], "p_true_per_fork": pm["p_true"],
"argmax_per_fork": pm["argmax_str"],
}
logger.info(f" [debug] alpha={alpha} qa[0] gen_len={gen.shape[0]} text[:120]={gen_text[:120]!r}")
for t, pmv, ptv, am in zip(fork_points, pm["pmass"], pm["p_true"], pm["argmax_str"]):
logger.info(f" t={t:>3} pmass={pmv:.3f} p_true={ptv:.3f} argmax={am!r}")
pmass_all[str(alpha)] = pm_for_alpha
p_true_all[str(alpha)] = pt_for_alpha
argmax_all[str(alpha)] = ax_for_alpha
thinking_all[str(alpha)] = wt_for_alpha
answer_label_all[str(alpha)] = ans_for_alpha
gen_lens_qa[str(alpha)] = gen_lens_qa_alpha
# Paired pmass on EVAL_PROMPTS (same long-form prompts as KL) so we can
# color KL trajectories by pmass at each fork. Long-form prompts won't
# naturally produce schema tokens, but the prefill forces the question
# "if you committed now, can you still produce a valid choice?" -- which
# is exactly the coherence signal we want.
pm_eval_for_alpha = []
if a.compute_pmass:
for p_idx, p in enumerate(EVAL_PROMPTS):
prompt_str = eval_prompt_strs[p_idx]
ids = eval_ids_list[p_idx]
pad = tok.pad_token_id
with v(model):
gen_out = model.generate(
ids.unsqueeze(0).to(a.device),
max_new_tokens=a.window,
pad_token_id=pad, eos_token_id=tok.eos_token_id,
do_sample=False,
use_cache=True,
return_dict_in_generate=True,
)
gen = gen_out.sequences[0, ids.shape[0]:]
gen_lens_eval_alpha.append(int(gen.shape[0]))
pm = branch_pmass(
v, model, tok, ids, gen, fork_points,
PREFILL_STR, a_ids, b_ids,
rollout_cache=gen_out.past_key_values,
device=a.device,
)
pm_eval_for_alpha.append(pm["pmass"])
if p_idx == 0:
gen_text = tok.decode(gen, skip_special_tokens=False)
debug_first.setdefault(str(alpha), {})["eval"] = {
"prompt": prompt_str, "gen_text": gen_text, "gen_len": int(gen.shape[0]),
"pmass_per_fork": pm["pmass"], "p_true_per_fork": pm["p_true"],
"argmax_per_fork": pm["argmax_str"],
}
logger.info(f" [debug] alpha={alpha} eval[0] gen_len={gen.shape[0]} text[:120]={gen_text[:120]!r}")
for t, pmv, ptv, am in zip(fork_points, pm["pmass"], pm["p_true"], pm["argmax_str"]):
logger.info(f" t={t:>3} pmass={pmv:.3f} p_true={ptv:.3f} argmax={am!r}")
pmass_eval_all[str(alpha)] = pm_eval_for_alpha
gen_lens_eval[str(alpha)] = gen_lens_eval_alpha
# SHOULD: at alpha=1, mean(pmass at t=0) > 0.5 (model still respects schema).
# ELSE: prefill string broken or chat template off.
if a.compute_pmass:
try:
import numpy as _np
t0 = _np.array([row[0] for row in pm_for_alpha])
logger.info(f" alpha={alpha} pmass@t=0: mean={t0.mean():.3f} min={t0.min():.3f} max={t0.max():.3f}")
except Exception:
pass
(out_dir / "trajectory.json").write_text(json.dumps({
"fork_points_full": list(range(a.window)),
"per_t_p95_kl": trajectory,
"per_prompt_per_t_kl": per_prompt_traj,
}, indent=2))
(out_dir / "pmass.json").write_text(json.dumps({
"fork_points": fork_points,
"pmass": pmass_all,
"suffix": a.suffix_str,
"target_words": a.target_words,
"pmass_eval": pmass_eval_all,
"p_true": p_true_all,
"argmax_str": argmax_all,
"was_thinking": thinking_all,
"answer_label": answer_label_all,
"gen_lens_qa": gen_lens_qa, # T per (alpha, qa-prompt) for right-censoring
"gen_lens_eval": gen_lens_eval, # T per (alpha, eval-prompt) for right-censoring
"debug_first": debug_first, # first prompt per alpha: gen_text + per-fork pmass/argmax
"prefill": PREFILL_STR,
"schema": _SCHEMA,
"questions": _QUESTIONS,
"computed": a.compute_pmass,
}, indent=2))
import csv
with open(out_dir / "results.csv", "w", newline="") as f:
@@ -191,6 +392,17 @@ def main(a: Args):
w.writeheader()
for r in rows:
w.writerow(r)
if a.render_figs:
repo_root = Path(__file__).resolve().parents[1]
cmd = [
sys.executable,
"scripts/render_run_figs.py",
"--run-dir", str(out_dir),
"--threshold", str(a.render_threshold),
"--model-contains", a.model.split("/")[-1],
]
logger.info("rendering single-run figures")
subprocess.run(cmd, check=True, cwd=repo_root)
logger.info(f"DONE -> {out_dir}")
+215
View File
@@ -0,0 +1,215 @@
"""KL-vs-survival spaghetti: per-prompt KL trajectory colored by alive/dead.
Each rollout is a thin line. KL(t) is plotted; the line is colored by whether
the rollout is currently 'alive' (pmass at the nearest fork s<=t is >= threshold)
or 'dead' (it has dropped below threshold at some s<=t -- death is irreversible).
Right-censoring: if the rollout EOS'd before t (gen_len < t), the line stops at
gen_len (drawn as a small terminal dot) -- it is 'complete' not 'dead'.
Panels: one per alpha. Title: KL ceiling (calibration target) + n trajectories.
Reading guide: at alpha=1, KL is bounded near the calibration target by
construction (~1 nat). Lines staying blue/green => model still coherent under
that KL budget. Red lines => model collapsed even though KL stayed in budget
(steering hit a degenerate region without raising KL much).
Usage:
python scripts/spaghetti_kl_alive.py \
--runs-root outputs_qwen35_w512_v3 \
--out figs_qwen35_w512_kl_alive \
--window 512 \
--alphas 0.0 0.25 0.5 0.75 1.0 1.5 2.0 4.0 \
--threshold 0.8 \
--model-contains Qwen3.5-0.8B
"""
from __future__ import annotations
import json
from dataclasses import dataclass
from pathlib import Path
import numpy as np
import tyro
from loguru import logger
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
from matplotlib.colors import to_rgba
try:
import seaborn as sns
sns.set_theme(context="notebook", style="whitegrid", palette="deep", font_scale=0.9)
plt.rcParams.update({
"axes.titlesize": 10, "axes.labelsize": 9,
"axes.spines.top": False, "axes.spines.right": False,
})
except Exception:
plt.style.use("ggplot")
@dataclass
class Args:
runs_root: str = "outputs_qwen35_w512_v3"
out: str = "figs_qwen35_w512_kl_alive"
window: int = 512
alphas: tuple[str, ...] = ("0.0", "0.25", "0.5", "0.75", "1.0", "1.5", "2.0", "4.0")
threshold: float = 0.8 # pmass < threshold = dead
metric: str = "pmass_eval" # 'pmass_eval' is paired with KL prompts
model_contains: str = "Qwen3.5-0.8B"
kl_log: bool = True
roll: int = 11 # smooth KL a bit so the spaghetti is readable
def load_cell(d: Path, alpha: str, T: int):
"""Load one cell: returns list of (kl_traj, pmass_per_fork, fork_points, gen_len)."""
calib = d / "calib.json"
traj_p = d / "trajectory.json"
pm_p = d / "pmass.json"
if not (calib.exists() and traj_p.exists() and pm_p.exists()):
return []
traj = json.loads(traj_p.read_text())
pm = json.loads(pm_p.read_text())
if not pm.get("computed", True):
return []
fork = pm["fork_points"]
kl_per_prompt = traj.get("per_prompt_per_t_kl", {}).get(alpha, [])
pmass_per_prompt = pm.get("pmass_eval", {}).get(alpha, [])
glens = pm.get("gen_lens_eval", {}).get(alpha, [])
if not kl_per_prompt or not pmass_per_prompt:
return []
out = []
for i, (kl, pmv) in enumerate(zip(kl_per_prompt, pmass_per_prompt)):
gl = int(glens[i]) if i < len(glens) else len(kl)
out.append((np.asarray(kl, dtype=float), np.asarray(pmv, dtype=float), fork, gl))
return out
def alive_mask_for_t(pmass_per_fork: np.ndarray, fork: list[int],
T: int, threshold: float, gen_len: int) -> np.ndarray:
"""Return per-token (T,) mask: True=alive, False=dead. Uses nearest-fork-<=t.
Once dead at some fork, dead forever after that fork. Right-censoring at gen_len."""
# walking running-min over forks, then broadcast to per-token via "nearest fork <= t"
rmin = np.minimum.accumulate(np.where(np.isnan(pmass_per_fork), np.inf, pmass_per_fork))
dead_at_fork = rmin < threshold
# for each token t, find largest fork s such that fork[s] <= t
fork_arr = np.array(fork)
alive = np.ones(T, dtype=bool)
for t in range(min(T, gen_len)):
# idx of last fork <= t
idx = int(np.searchsorted(fork_arr, t, side="right") - 1)
if idx >= 0 and dead_at_fork[idx]:
alive[t] = False
# tokens beyond gen_len: censored, mark via separate signal (we'll just truncate display)
return alive
def _rolling_mean(x: np.ndarray, w: int) -> np.ndarray:
if w <= 1 or len(x) < w:
return x
pad = w // 2
xp = np.pad(x, pad, mode="edge")
kernel = np.ones(w, dtype=float) / w
return np.convolve(xp, kernel, mode="valid")[: len(x)]
def main(a: Args):
root = Path(a.runs_root)
out = Path(a.out); out.mkdir(parents=True, exist_ok=True)
cells = [d for d in sorted(root.iterdir()) if d.is_dir() and not d.name.startswith("_")]
cells = [d for d in cells if a.model_contains in d.name]
logger.info(f"found {len(cells)} cells in {root} matching {a.model_contains!r}")
n_panels = len(a.alphas)
fig, axes = plt.subplots(1, n_panels, figsize=(3.0 * n_panels, 3.4),
sharey=True, squeeze=False)
color_alive = to_rgba("#1a9850", 0.65) # green, translucent so overlap stays readable
color_dead = to_rgba("#d7191c", 0.55) # stronger red so all-dead panels are visible
summary_rows = []
for j, alpha in enumerate(a.alphas):
ax = axes[0, j]
all_trajs = []
for d in cells:
all_trajs.extend(load_cell(d, alpha, a.window))
n = len(all_trajs)
n_died = 0
n_censored = 0
median_rows = []
dead_segments = []
alive_segments = []
dead_colors = []
alive_colors = []
# build all dead segments first, then alive segments on top
for kl, pmv, fork, gl in all_trajs:
T = min(len(kl), gl)
kl = _rolling_mean(kl[:T], a.roll)
median_rows.append(np.pad(kl, (0, max(0, a.window - T)), constant_values=np.nan)[: a.window])
alive = alive_mask_for_t(pmv, fork, T, a.threshold, gl)
xs = np.arange(T)
if T < 2:
continue
pts = np.stack([xs, kl], axis=1).reshape(-1, 1, 2)
segs = np.concatenate([pts[:-1], pts[1:]], axis=1)
seg_alive = alive[:-1]
dead_segments.extend([seg for seg, al in zip(segs, seg_alive) if not al])
alive_segments.extend([seg for seg, al in zip(segs, seg_alive) if al])
dead_colors.extend([color_dead] * int((~seg_alive).sum()))
alive_colors.extend([color_alive] * int(seg_alive.sum()))
if not alive.all():
n_died += 1
if gl < a.window:
n_censored += 1
ax.scatter([gl - 1], [kl[-1]], s=6, c="black", marker="|",
alpha=0.6, zorder=3)
if dead_segments:
ax.add_collection(LineCollection(dead_segments, colors=dead_colors, linewidths=0.9, zorder=1))
if alive_segments:
ax.add_collection(LineCollection(alive_segments, colors=alive_colors, linewidths=0.9, zorder=2))
if median_rows:
med = np.nanmedian(np.asarray(median_rows), axis=0)
med_x = np.arange(len(med))
if np.nanmax(np.abs(med)) < 1e-6:
med = med + 1e-4
ax.plot(med_x, med, color="black", lw=1.3, alpha=0.9, zorder=4)
ax.axhline(1.0, color="black", lw=0.7, ls=":", label="KL=1 calib target")
ax.set_title(rf"$\alpha={alpha}$ n={n} died={n_died} cens={n_censored}")
ax.set_xlabel("token t")
if j == 0:
ax.set_ylabel("per-token KL")
if a.kl_log:
ax.set_yscale("symlog", linthresh=0.1)
ax.set_xlim(-5, a.window + 5)
ax.autoscale_view()
# data-driven y-lim sanity
# rely on auto
summary_rows.append({"alpha": alpha, "n": n, "n_died": n_died, "n_censored": n_censored})
# legend on first panel only
from matplotlib.lines import Line2D
handles = [
Line2D([0],[0], color=color_alive, lw=2, label=f"alive (pmass >= {a.threshold})"),
Line2D([0],[0], color=color_dead, lw=2, label=f"dead (pmass < {a.threshold})"),
Line2D([0],[0], color="black", marker="|", lw=0, label="EOS (right-censored)"),
Line2D([0],[0], color="black", lw=0.7, ls=":", label="KL=1 calib target"),
]
axes[0, 0].legend(handles=handles, loc="upper right", fontsize=7, frameon=True)
fig.suptitle(
f"KL trajectories coloured by survival (pmass < {a.threshold} = dead). "
f"Model: {a.model_contains}. window={a.window}.",
fontsize=10,
)
fig.tight_layout(rect=(0, 0, 1, 0.92))
out_p = out / "kl_alive_spaghetti.png"
fig.savefig(out_p, dpi=160, bbox_inches="tight")
logger.info(f"figure -> {out_p}")
from tabulate import tabulate
md = tabulate(summary_rows, headers="keys", tablefmt="pipe")
(out / "kl_alive_summary.md").write_text(md)
logger.info("\n" + md)
if __name__ == "__main__":
main(tyro.cli(Args))
+227
View File
@@ -0,0 +1,227 @@
"""Kaplan-Meier-style survival curves for steered trajectories.
Motivation:
At alpha=1 KL is near the calibration target by construction (we bisected
coeff to make it so), so "alpha=1 stays at KL=1" is circular. The honest
test is whether the model is still COHERENT -- measured here as pmass on
forced-choice answer tokens. Survival curves separate alphas more cleanly
than per-token bands.
Death modes (--metric):
kl proxy. alive(t) := running_max KL(s) over s<=t < threshold (default 1.0).
Largely redundant with calibration at alpha=1; read with care.
pmass real. Right-censored Kaplan-Meier on forced-choice mass.
alive(t) := running_min pmass(s) >= threshold AND rollout reached s.
Death is irreversible (KM convention).
Right-censoring: rollouts that EOS'd before fork t drop OUT of the
denominator at t (they are 'complete', not 'dead'). Uses
gen_lens_qa[alpha] / gen_lens_eval[alpha] from pmass.json.
Reads pmass[alpha] (yes/no reasoning prompts). To use the EVAL_PROMPTS
paired with KL instead, switch the loader to pmass_eval[alpha].
Inputs: outputs/<run>/trajectory.json + pmass.json.
Usage:
python scripts/survival.py --runs_root outputs_qwen05_w512 \
--out figs_qwen05_survival --alphas 0.5 1.0 2.0 4.0 \
--metric pmass --thresholds 0.5 0.8 0.95
"""
from __future__ import annotations
import json
from dataclasses import dataclass
from pathlib import Path
import numpy as np
import tyro
from loguru import logger
import matplotlib.pyplot as plt
try:
import seaborn as sns
sns.set_theme(context="notebook", style="whitegrid", palette="deep", font_scale=0.95)
plt.rcParams.update({
"axes.titlesize": 11, "axes.labelsize": 10,
"axes.spines.top": False, "axes.spines.right": False,
})
except Exception:
plt.style.use("ggplot")
@dataclass
class Args:
runs_root: str = "outputs_qwen05_spaghetti"
out: str = "figs_qwen05_survival"
window: int = 50
alphas: tuple[str, ...] = ("0.25", "0.5", "0.75", "1.0", "1.5", "2.0", "4.0")
thresholds: tuple[float, ...] = (1.0,)
model_contains: str = "Qwen2.5-0.5B"
metric: str = "kl" # 'kl', 'pmass', or 'pmass_eval' (paired w/ KL prompts)
def _load_kl(root: Path, alpha: str, T: int, model_contains: str) -> np.ndarray:
rows = []
for d in sorted(root.iterdir()):
if not d.is_dir() or d.name.startswith("_"):
continue
calib = d / "calib.json"; traj_p = d / "trajectory.json"
if not (calib.exists() and traj_p.exists()):
continue
meta = json.loads(calib.read_text())
if meta.get("window") != T or model_contains not in meta.get("model", ""):
continue
traj = json.loads(traj_p.read_text())
per = traj.get("per_prompt_per_t_kl", {}).get(alpha, [])
for r in per:
arr = np.full(T, np.nan)
arr[: len(r)] = r[: T]
rows.append(arr)
return np.array(rows) if rows else np.zeros((0, T))
def _load_pmass(root: Path, alpha: str, T: int, model_contains: str,
key: str = "pmass",
) -> tuple[np.ndarray, list[int], np.ndarray]:
"""Return (N, F) pmass array + fork_points + (N,) gen_lens for right-censoring.
NaN in the pmass array means the rollout EOS'd before that fork ('complete,
not dead' -- right-censored). gen_lens is the rollout length T per row.
key in {'pmass','pmass_eval'} -> uses 'gen_lens_qa' / 'gen_lens_eval'.
"""
rows = []; gen_lens = []; fork = None
glen_key = "gen_lens_qa" if key == "pmass" else "gen_lens_eval"
for d in sorted(root.iterdir()):
if not d.is_dir() or d.name.startswith("_"):
continue
calib = d / "calib.json"; pm_p = d / "pmass.json"
if not (calib.exists() and pm_p.exists()):
continue
meta = json.loads(calib.read_text())
if meta.get("window") != T or model_contains not in meta.get("model", ""):
continue
pm = json.loads(pm_p.read_text())
if not pm.get("computed", True):
continue
if fork is None:
fork = pm["fork_points"]
per = pm.get(key, {}).get(alpha, [])
glens = pm.get(glen_key, {}).get(alpha, [])
for i, r in enumerate(per):
rows.append(r)
# fall back to T (window) if gen_lens not saved (legacy outputs)
gen_lens.append(int(glens[i]) if i < len(glens) else T)
if not rows:
return np.zeros((0, len(fork or []))), fork or [], np.zeros(0, dtype=int)
L = max(len(r) for r in rows)
arr = np.full((len(rows), L), np.nan)
for i, r in enumerate(rows):
arr[i, : len(r)] = r
return arr, fork, np.array(gen_lens, dtype=int)
def survival_kl(K: np.ndarray, threshold: float) -> np.ndarray:
"""S(t) = fraction with running_max(K) < threshold."""
if K.size == 0:
return np.zeros(0)
rmax = np.maximum.accumulate(np.nan_to_num(K, nan=-np.inf), axis=1)
return (rmax < threshold).mean(axis=0)
def survival_pmass(P: np.ndarray, fork: list[int], gen_lens: np.ndarray,
threshold: float) -> np.ndarray:
"""Right-censored Kaplan-Meier on pmass.
A trajectory is 'at risk' at fork t iff its rollout reached t (gen_len >= t).
A trajectory 'dies' at fork t if pmass(t) < threshold AND it has not died yet.
Right-censored trajectories (gen_len < t) drop out of the denominator at t
(they are 'complete', not 'dead' -- per user 2026-05-05).
KM estimate: S(t) = prod_{s<=t} (1 - d_s / n_s) where d_s = deaths at s,
n_s = at-risk just before s. Reduces to "fraction alive" if no censoring.
"""
if P.size == 0:
return np.zeros(0)
N, F = P.shape
fork_arr = np.array(fork[:F])
# at_risk[i,j] = True if rollout i reached fork[j] (gen_lens[i] >= fork[j])
at_risk = gen_lens[:, None] >= fork_arr[None, :]
# 'dead' event: pmass below threshold (treat NaN as not-dead since we right-censor via at_risk)
P_filled = np.nan_to_num(P, nan=np.inf)
died_at = (P_filled < threshold) & at_risk # (N, F)
# Make death irreversible: once dead, stays dead at all later forks (where still at risk)
ever_dead = np.maximum.accumulate(died_at.astype(np.int8), axis=1).astype(bool)
# KM hazard at each fork: d_s / n_s
S = np.ones(F)
s = 1.0
for j in range(F):
n_s = int(at_risk[:, j].sum())
if n_s == 0:
S[j] = s
continue
# new deaths at this fork: ever_dead at j AND not at j-1
if j == 0:
d_s = int(ever_dead[:, j].sum())
else:
d_s = int((ever_dead[:, j] & ~ever_dead[:, j-1] & at_risk[:, j]).sum())
s *= (1.0 - d_s / n_s)
S[j] = s
return S
def main(a: Args):
root = Path(a.runs_root); out = Path(a.out); out.mkdir(parents=True, exist_ok=True)
n_panels = len(a.thresholds)
fig, axes = plt.subplots(1, n_panels, figsize=(4.6 * n_panels, 3.4),
sharey=True, squeeze=False)
cmap = plt.get_cmap("viridis")
colors = {alpha: cmap(i / max(1, len(a.alphas) - 1)) for i, alpha in enumerate(a.alphas)}
rows_summary = []
for j, thr in enumerate(a.thresholds):
ax = axes[0, j]
for alpha in a.alphas:
if a.metric == "kl":
K = _load_kl(root, alpha, a.window, a.model_contains)
if K.size == 0:
logger.warning(f"no data alpha={alpha}"); continue
S = survival_kl(K, thr); xs = np.arange(len(S)); n = K.shape[0]
xlabel = "token t"
elif a.metric in ("pmass", "pmass_eval"):
P, fork, gen_lens = _load_pmass(root, alpha, a.window, a.model_contains, key=a.metric)
if P.size == 0:
logger.warning(f"no {a.metric} alpha={alpha}"); continue
S = survival_pmass(P, fork, gen_lens, thr)
xs = np.array(fork[: len(S)]); n = P.shape[0]
xlabel = "fork token t"
else:
raise SystemExit(f"unknown metric {a.metric!r}; use 'kl', 'pmass', or 'pmass_eval'")
ax.step(xs, S, where="post", color=colors[alpha], lw=2.0,
label=rf"$\alpha={alpha}$ (n={n})")
below_half = np.where(S <= 0.5)[0]
t50 = int(xs[below_half[0]]) if len(below_half) else None
rows_summary.append({"metric": a.metric, "threshold": thr, "alpha": alpha,
"n": int(n),
"S_mid": float(S[len(S)//2]),
"S_end": float(S[-1]),
"t_S<=0.5": t50})
if n_panels > 1:
ax.set_title(f"threshold = {thr:g}")
ax.set_xlabel(xlabel)
ax.set_ylim(-0.02, 1.05)
if j == 0:
ax.set_ylabel("fraction of trajectories alive")
ax.legend(loc="lower left", fontsize=8, frameon=False)
ax.axvline(20, color="k", ls=":", lw=0.7)
fig.suptitle(f"Survival, {a.model_contains}", fontsize=10)
fig.tight_layout(rect=(0, 0, 1, 0.94))
out_p = out / f"survival_{a.metric}.png"
fig.savefig(out_p, dpi=160, bbox_inches="tight")
logger.info(f"survival -> {out_p}")
from tabulate import tabulate
md = tabulate(rows_summary, headers="keys", tablefmt="pipe", floatfmt=".3f")
(out / f"survival_{a.metric}.md").write_text(md)
logger.info("\n" + md)
if __name__ == "__main__":
main(tyro.cli(Args))