mirror of
https://github.com/wassname/isokl_steering_calibration.git
synced 2026-06-27 16:45:53 +08:00
wip: dense alpha sweep + auto-render figs (Qwen3.5-0.8B w=512); queued OLMo-2/Gemma 4B/Gemma 12B/OLMo-3 7B at w=4096
This commit is contained in:
@@ -0,0 +1,36 @@
|
||||
"""Debug: verify pmass measurement on un-steered baseline.

Builds one yes/no prompt with the JSON-choice schema, appends the answer
prefill, runs a single forward pass (no steering, no rollout) and prints
the top-10 next tokens plus the probability mass on the hard-coded
true/false token groups.
"""
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B-Instruct", torch_dtype=torch.float32
).eval()

prefill = '\n{"choice": '
schema = (
    'Think briefly, then answer immediately and only with: '
    '{"choice": true} or {"choice": false}. Do not output 1 or 0.'
)
q = "Is the Eiffel Tower located in Paris, France?"

msgs = [{"role": "user", "content": schema + " " + q}]
chat = tok.apply_chat_template(msgs, add_generation_prompt=True, tokenize=False)
full = chat + prefill
ids = tok(full, return_tensors="pt").input_ids
print(f"prompt len={ids.shape[1]} tokens")

with torch.no_grad():
    logits = model(ids).logits[0, -1]
probs = torch.softmax(logits, dim=-1)
top = probs.topk(10)
print("\nTop-10 next tokens after full prefill (NO steer, NO rollout):")
for p, i in zip(top.values, top.indices):
    print(f" id={int(i):6d} p={float(p):.4f} tok={tok.decode([int(i)])!r}")

a_ids = [16, 1866, 2514, 830, 3007]  # 1 true True ' true' ' True'
b_ids = [15, 3849, 4049]  # 0 false False

# The ids above are only meaningful for this tokenizer. Decode them and flag
# any id whose surface form differs from what the comments claim, so a wrong
# id cannot silently skew the pmass read-out this script is meant to verify.
_expected = dict(zip(
    a_ids + b_ids,
    ["1", "true", "True", " true", " True", "0", "false", "False"],
))
for _tid, _want in _expected.items():
    _got = tok.decode([_tid])
    if _got != _want:
        print(f"WARNING: id={_tid} decodes to {_got!r}, expected {_want!r}")

sa = float(probs[a_ids].sum())
sb = float(probs[b_ids].sum())
print(f"\nsum a_ids (true/1)={sa:.4f} sum b_ids (false/0)={sb:.4f} pmass={sa+sb:.4f}")
print(f"argmax: id={int(probs.argmax())} tok={tok.decode([int(probs.argmax())])!r}")
|
||||
+320
-99
@@ -1,148 +1,369 @@
|
||||
"""Aggregate per-cell outputs into Figure 1 + the headline table.
|
||||
"""Aggregate per-cell outputs into Figure 1 + headline table.
|
||||
|
||||
Figure 1: two stacked subplots.
|
||||
Top: per-token p95 KL trajectory. x = token offset; y = KL(steer || base).
|
||||
Colour by method, linestyle by alpha (solid=1, dashed=2), seed bands
|
||||
as thin lines, faceted by model. Horizontal at target_kl=1.
|
||||
Bottom: branch-pmass at fork points. x = fork token offset; y = mean pmass
|
||||
across held-out prompts; bands = +/- 1 std across seeds.
|
||||
Layout (default --no-kl-only): 2 rows x N alpha cols.
|
||||
Top row: KL(steered || base) per token on EVAL_PROMPTS.
|
||||
Band mode (default): p50 line + p10..p90 band, rolling-{roll} smooth.
|
||||
Spaghetti mode: individual trajectories.
|
||||
default coloring: red if traj ever crosses KL=1, grey otherwise.
|
||||
--color-by-pmass: segments coloured by paired pmass_eval(t) using a
|
||||
red->yellow->green colormap (dead->alive). Requires
|
||||
pmass.json with non-empty pmass_eval (run_cell.py
|
||||
with --compute-pmass).
|
||||
Bottom row: forked-answer pmass at fork_points.
|
||||
Uses pmass[alpha] (legacy yes/no reasoning prompts). Different prompt set
|
||||
than the KL panel above, so the row-to-row link is across cells, not
|
||||
paired per-trajectory. For paired analysis see scripts/survival.py with
|
||||
--metric pmass on pmass_eval.
|
||||
|
||||
Table: one row per (model, method), columns = c_star (mean +/- std across seeds),
|
||||
KL_p95 @ alpha=1, KL_p95 @ alpha=2, pmass @ alpha=1, pmass @ alpha=2.
|
||||
Reading caveats:
|
||||
- alpha=1 sits near KL=1 on CALIB_PROMPTS by construction. KL on EVAL_PROMPTS
|
||||
only tests generalisation, not budget choice -- not an independent test.
|
||||
- The honest coherence test is pmass / pmass_eval (real forced-choice mass
|
||||
drop), see survival.py for KM-style curves.
|
||||
|
||||
Table: per (model, method, window) c_star mean +/- std across seeds.
|
||||
|
||||
Usage:
|
||||
# smoothed band:
|
||||
python scripts/aggregate.py --runs_root outputs --out figs/
|
||||
# raw spaghetti:
|
||||
python scripts/aggregate.py --runs_root outputs_qwen05_w512 \
|
||||
--out figs_qwen05_pretty_raw --spaghetti --roll 1 \
|
||||
--alphas 0.5 1.0 2.0 4.0
|
||||
# KL spaghetti coloured by paired pmass_eval:
|
||||
python scripts/aggregate.py --runs_root outputs_qwen05_w512 \
|
||||
--out figs_qwen05_color --spaghetti --color-by-pmass
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import polars as pl
|
||||
import tyro
|
||||
from loguru import logger
|
||||
from tabulate import tabulate
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
try:
|
||||
import seaborn as sns
|
||||
sns.set_theme(context="notebook", style="whitegrid", palette="deep", font_scale=0.95)
|
||||
plt.rcParams.update({
|
||||
"axes.titlesize": 11,
|
||||
"axes.labelsize": 10,
|
||||
"axes.spines.top": False,
|
||||
"axes.spines.right": False,
|
||||
"figure.titlesize": 11,
|
||||
"figure.dpi": 110,
|
||||
})
|
||||
except Exception:
|
||||
plt.style.use("ggplot")
|
||||
|
||||
|
||||
@dataclass
class Args:
    """Command-line options for the aggregation script."""

    # Directory that is scanned for per-run output folders.
    runs_root: str = "outputs"
    # Directory the table and figures are written to.
    out: str = "figs"
    window: int = 50  # only this window enters the figure (need long-enough traj for rolling-16)
    roll: int = 16  # smoothing window for KL trajectory
    # Alpha values to plot, one column each; strings because they are used as JSON keys.
    alphas: tuple[str, ...] = ("0.25", "0.5", "0.75", "1.0", "1.5", "2.0", "4.0")
    # Upper y-limit of the KL panels.
    kl_ymax: float = 6.0
    # Substring filter on the model name; empty string disables the filter.
    model_contains: str = ""
    # When set, only the KL figure is produced.
    kl_only: bool = False
    spaghetti: bool = False  # plot individual trajectories instead of p10..p90 band
    color_by_pmass: bool = False  # color KL spaghetti lines by paired pmass (requires pmass_eval)
|
||||
|
||||
|
||||
def load_cells(root: Path) -> list[dict]:
|
||||
def _rolling_mean(x: np.ndarray, w: int) -> np.ndarray:
    """Rolling mean of a 1-D series with edge padding so output len == input len.

    The series is edge-padded by ``w // 2`` on both sides and averaged with a
    length-``w`` box kernel. NaN samples are ignored inside each window (the
    mean is taken over the valid samples only), and positions that are NaN in
    the input stay NaN in the output, so a NaN-padded tail does not wipe out
    the last valid samples of a shorter trajectory.

    Args:
        x: 1-D array of values (``np.convolve`` only accepts 1-D input).
        w: window length in samples.

    Returns:
        Array of the same length as ``x``. ``x`` itself is returned unchanged
        when ``w <= 1`` or the series is shorter than the window.
    """
    if len(x) < w or w <= 1:
        return x
    values = np.asarray(x, dtype=float)
    pad = w // 2
    padded = np.pad(values, pad, mode="edge")
    valid = ~np.isnan(padded)
    box = np.ones(w)
    n = len(values)
    # For even w the "valid" convolution is one sample longer than x; keep the first n.
    sums = np.convolve(np.where(valid, padded, 0.0), box, mode="valid")[:n]
    counts = np.convolve(valid.astype(float), box, mode="valid")[:n]
    with np.errstate(invalid="ignore", divide="ignore"):
        out = sums / counts  # 0/0 -> NaN where a window holds no valid sample
    out[np.isnan(values)] = np.nan
    return out
|
||||
|
||||
|
||||
def load_cells(root: Path, window: int, model_contains: str = "") -> list[dict]:
|
||||
cells = []
|
||||
for d in sorted(root.iterdir()):
|
||||
if not d.is_dir():
|
||||
if not d.is_dir() or d.name.startswith("_"):
|
||||
continue
|
||||
calib = d / "calib.json"; traj_p = d / "trajectory.json"; pm_p = d / "pmass.json"
|
||||
if not (calib.exists() and traj_p.exists() and pm_p.exists()):
|
||||
continue
|
||||
meta = json.loads(calib.read_text())
|
||||
if meta.get("window") != window:
|
||||
continue
|
||||
if model_contains and model_contains not in meta.get("model", ""):
|
||||
continue
|
||||
traj = json.loads(traj_p.read_text())
|
||||
pm = json.loads(pm_p.read_text())
|
||||
# Skip stale outputs that lack per-prompt KL (pre-redesign).
|
||||
if "per_prompt_per_t_kl" not in traj:
|
||||
logger.warning(f"skipping stale {d.name} (no per_prompt_per_t_kl)")
|
||||
continue
|
||||
cells.append({"id": d.name, **meta, "traj": traj, "pmass": pm})
|
||||
return cells
|
||||
|
||||
|
||||
def _draw_kl_panel(ax, K: np.ndarray, a: Args, P: np.ndarray | None = None) -> None:
    """Draw the KL trajectories of one alpha column on ``ax``.

    Band mode (``a.spaghetti`` false): p50 line plus a p10..p90 band, each
    smoothed with ``_rolling_mean(a.roll)``.

    Spaghetti mode: one thin line per trajectory (smoothed per line when
    ``a.roll > 1``) and a black median line on top.
      * ``a.color_by_pmass`` with a usable ``P`` (same shape as ``K``): every
        segment is coloured by pmass(t) on a red -> yellow -> green colormap
        (0.0 = dead red, 1.0 = alive green).
      * otherwise: red if the raw trajectory ever exceeds KL=1, grey if not,
        plus a text label with the fraction that crossed.

    Args:
        ax: matplotlib axes to draw on.
        K: (n_traj, T) array of per-token KL values; NaN where missing.
        a: parsed CLI options (uses ``spaghetti``, ``roll``, ``color_by_pmass``).
        P: optional (n_traj, T) array of pmass values paired row-for-row with K.
    """
    if not K.size:
        return
    xs = np.arange(K.shape[1])

    if not a.spaghetti:
        p10, p50, p90 = (
            _rolling_mean(np.nanpercentile(K, pct, axis=0), a.roll)
            for pct in (10, 50, 90)
        )
        ax.fill_between(xs, p10, p90, alpha=0.25, color="C0", lw=0)
        ax.plot(xs, p50, color="C0", lw=1.6)
        return

    # optional per-line smoothing
    Kp = np.array([_rolling_mean(row, a.roll) for row in K]) if a.roll > 1 else K

    paired = P is not None and bool(P.size) and P.shape == K.shape
    if a.color_by_pmass and not paired:
        # Without this the panel silently falls back to red/grey colouring.
        logger.warning(
            f"color_by_pmass requested but pmass array is unusable "
            f"(P shape={None if P is None else P.shape}, K shape={K.shape}); "
            f"falling back to crossed/not-crossed colouring"
        )

    if a.color_by_pmass and paired:
        # LineCollection per trajectory, color = pmass(t).
        # Draw greenest (alive) first, reddest (dead) last so red sits on top
        # of the green pile -- otherwise the eye loses the dead trajectories.
        from matplotlib.collections import LineCollection
        from matplotlib.colors import LinearSegmentedColormap

        cmap = LinearSegmentedColormap.from_list("alive", ["#c0392b", "#f1c40f", "#27ae60"])  # dead->alive
        order = np.argsort(-np.nanmean(P, axis=1))  # high pmass first, low pmass last
        for traj, pmass_row in zip(Kp[order], P[order]):
            pts = np.column_stack([xs, traj])
            segs = np.stack([pts[:-1], pts[1:]], axis=1)
            lc = LineCollection(segs, cmap=cmap, norm=plt.Normalize(0, 1),
                                linewidths=0.7, alpha=0.55)
            lc.set_array(pmass_row[:-1])
            ax.add_collection(lc)
        # add_collection does not autoscale the x axis.
        ax.set_xlim(xs[0], xs[-1])
        ax.plot(xs, np.nanmedian(Kp, axis=0), color="k", lw=1.6)
        return

    # Crossing is judged on the raw (unsmoothed) values.
    crossed = (K > 1.0).any(axis=1)
    for traj in Kp[~crossed]:
        ax.plot(xs, traj, color="0.55", lw=0.6, alpha=0.6)
    for traj in Kp[crossed]:
        ax.plot(xs, traj, color="C3", lw=0.6, alpha=0.6)
    ax.plot(xs, np.nanmedian(Kp, axis=0), color="k", lw=1.6)
    frac = float(crossed.mean())
    ax.text(0.97, 0.97, f"{frac:.0%} cross KL=1",
            transform=ax.transAxes, ha="right", va="top",
            fontsize=8, color="C3" if frac > 0.5 else "0.3")
|
||||
|
||||
|
||||
def make_kl_figure(cells: list[dict], a: Args, out_path: Path) -> None:
    """Render the KL-only figure: one row, one column per alpha, and save it.

    Each column pools the per-prompt KL trajectories of all ``cells`` for that
    alpha and draws them with ``_draw_kl_panel``. A solid horizontal line marks
    KL=1 and a dotted vertical line marks t=20. With ``a.color_by_pmass`` a
    shared pmass colorbar is added.

    Args:
        cells: loaded run cells (dicts carrying ``traj`` and ``pmass`` entries).
        a: parsed CLI options.
        out_path: image file to write.
    """
    n_cols = len(a.alphas)
    fig, axes = plt.subplots(1, n_cols, figsize=(4.0 * n_cols, 3.2),
                             sharex=True, sharey=True, squeeze=False)
    label = a.model_contains or "all models"
    if a.spaghetti:
        mode = "individual trajectories (red=ever crossed KL=1)"
    else:
        mode = f"p50 + p10..p90 band, smoothed rolling-{a.roll}"
    fig.suptitle(
        f"KL trajectory on N=8 held-out long-form prompts ({label})\n"
        f"{mode}. Solid line: KL=1 nat. Dotted v-line: t=20.",
        fontsize=10,
    )

    # Fork points do not depend on alpha; look them up once.
    fork_points = _first_fork(cells) if a.color_by_pmass else None

    for j, alpha in enumerate(a.alphas):
        ax = axes[0, j]
        K = _pool_kl(cells, alpha, T=a.window)
        Pe = _pool_pmass_eval(cells, alpha, fork_points, T=a.window) if a.color_by_pmass else None
        _draw_kl_panel(ax, K, a, P=Pe)
        ax.axhline(1.0, color="k", lw=1.0)
        ax.axvline(20, color="k", ls=":", lw=0.8)
        ax.set_title(rf"$\alpha = {alpha}$ (n={K.shape[0]} traj)")
        ax.set_ylim(0, a.kl_ymax)
        ax.set_xlabel("token")
        if j == 0:
            ax.set_ylabel("KL(steered || base) [nats]")

    if a.color_by_pmass:
        from matplotlib.cm import ScalarMappable
        from matplotlib.colors import LinearSegmentedColormap, Normalize

        cmap = LinearSegmentedColormap.from_list("alive", ["#c0392b", "#f1c40f", "#27ae60"])
        sm = ScalarMappable(norm=Normalize(0, 1), cmap=cmap)
        sm.set_array([])
        cbar = fig.colorbar(sm, ax=axes[0, :].tolist(), location="right", shrink=0.8, pad=0.02)
        cbar.set_label("pmass (0=dead, 1=alive)")

    fig.tight_layout(rect=(0, 0, 1, 0.86))
    fig.savefig(out_path, dpi=160, bbox_inches="tight")
    # pyplot keeps every figure alive until it is closed explicitly.
    plt.close(fig)
    logger.info(f"KL-only figure -> {out_path}")
|
||||
|
||||
|
||||
def make_table(root: Path) -> pl.DataFrame:
|
||||
rows = []
|
||||
for d in sorted(root.iterdir()):
|
||||
if not d.is_dir() or d.name.startswith("_"):
|
||||
continue
|
||||
calib = d / "calib.json"
|
||||
if not calib.exists():
|
||||
continue
|
||||
meta = json.loads(calib.read_text())
|
||||
traj = json.loads((d / "trajectory.json").read_text())
|
||||
pmass = json.loads((d / "pmass.json").read_text())
|
||||
cells.append({"id": d.name, **meta, "traj": traj, "pmass": pmass})
|
||||
return cells
|
||||
rows.append(json.loads(calib.read_text()) | {"run": d.name})
|
||||
if not rows:
|
||||
return pl.DataFrame()
|
||||
df = pl.DataFrame(rows)
|
||||
df = df.with_columns(pl.col("model").str.split("/").list.last().alias("model_short"))
|
||||
g = (df.group_by(["model_short", "method", "window"])
|
||||
.agg(pl.col("c_star").mean().alias("c_mean"),
|
||||
pl.col("c_star").std().alias("c_std"),
|
||||
pl.len().alias("n_seeds"))
|
||||
.sort(["model_short", "method", "window"]))
|
||||
g = g.with_columns(
|
||||
(pl.col("c_std") / pl.col("c_mean").abs()).alias("c_cv"),
|
||||
)
|
||||
return g
|
||||
|
||||
|
||||
def make_table(cells: list[dict]) -> pl.DataFrame:
|
||||
def _pool_kl(cells: list[dict], alpha: str, T: int) -> np.ndarray:
|
||||
"""Stack per-prompt KL trajectories from all cells -> (N, T) ndarray."""
|
||||
rows = []
|
||||
by_mm = defaultdict(list)
|
||||
for c in cells:
|
||||
by_mm[(c["model"], c["method"])].append(c)
|
||||
for (model, method), group in by_mm.items():
|
||||
c_stars = [g["c_star"] for g in group]
|
||||
# pmass: mean over fork_points and prompts at each alpha, then across seeds
|
||||
for alpha in ("1.0", "2.0"):
|
||||
kls = []
|
||||
pms = []
|
||||
for g in group:
|
||||
kls.append(g["traj"]["per_t_p95_kl"][alpha])
|
||||
pms.append(g["pmass"]["pmass"][alpha])
|
||||
kls_flat = [x for arr in kls for x in arr]
|
||||
pms_flat = [x for prompt in pms for arr in prompt for x in arr]
|
||||
rows.append({
|
||||
"model": model.split("/")[-1],
|
||||
"method": method,
|
||||
"alpha": float(alpha),
|
||||
"c_star_mean": sum(c_stars) / len(c_stars),
|
||||
"n_seeds": len(group),
|
||||
"kl_p95_mean": sum(kls_flat) / max(len(kls_flat), 1),
|
||||
"pmass_mean": sum(pms_flat) / max(len(pms_flat), 1),
|
||||
})
|
||||
return pl.DataFrame(rows)
|
||||
per_prompt = c["traj"]["per_prompt_per_t_kl"].get(alpha, [])
|
||||
for r in per_prompt:
|
||||
arr = np.full(T, np.nan)
|
||||
arr[: len(r)] = r[: T]
|
||||
rows.append(arr)
|
||||
return np.array(rows) if rows else np.zeros((0, T))
|
||||
|
||||
|
||||
def make_figure(cells: list[dict], out_path: Path) -> None:
|
||||
def _first_fork(cells: list[dict]) -> list[int]:
    """Return the ``fork_points`` of the first cell whose pmass was computed.

    A cell counts as computed unless its ``pmass`` dict carries an explicit
    falsy ``"computed"`` entry. Returns an empty list when no cell qualifies.
    """
    return next(
        (c["pmass"]["fork_points"] for c in cells if c["pmass"].get("computed", True)),
        [],
    )
|
||||
|
||||
|
||||
def _pool_pmass_eval(cells: list[dict], alpha: str, fork_points: list[int], T: int) -> np.ndarray:
    """Pool pmass on EVAL_PROMPTS (paired with KL), interpolated to per-token.

    For every computed cell, each per-prompt ``pmass_eval[alpha]`` row (one
    value per fork point) is linearly interpolated onto tokens ``0..T-1``;
    tokens outside the fork range take the first/last value.

    Cells with ``computed`` false, or without ``pmass_eval`` for this alpha,
    contribute no rows. A row that cannot be interpolated (empty row, or no
    fork points) becomes an all-NaN row so the row index stays aligned with
    the prompt index. If a row and ``fork_points`` differ in length, the common
    prefix is used.

    Returns:
        (N, T) float array; shape (0, T) when nothing was pooled.
    """
    t = np.arange(T)
    rows = []
    for c in cells:
        pm = c["pmass"]
        if not pm.get("computed", True):
            continue
        for r in pm.get("pmass_eval", {}).get(alpha, []):
            n = min(len(fork_points), len(r))
            if n == 0:
                # np.interp rejects empty sample arrays; keep a placeholder row.
                rows.append(np.full(T, np.nan))
                continue
            xs = np.asarray(fork_points[:n], dtype=float)
            ys = np.asarray(r[:n], dtype=float)
            rows.append(np.interp(t, xs, ys, left=ys[0], right=ys[-1]))
    return np.array(rows) if rows else np.zeros((0, T))
|
||||
|
||||
|
||||
def _pool_pmass(cells: list[dict], alpha: str) -> tuple[np.ndarray, list[int]]:
    """Stack the per-prompt ``pmass[alpha]`` rows of all computed cells.

    Rows may differ in length; shorter rows are right-padded with NaN up to
    the longest one. The fork points of the first computed cell are returned
    alongside (other cells' fork points are not compared against them).

    Returns:
        ``(arr, fork)`` where ``arr`` is (n_rows, longest_row). When no row was
        found, ``arr`` has shape (0, len(fork)) and ``fork`` is ``[]`` if no
        computed cell exists.
    """
    fork = None
    rows = []
    for c in cells:
        pm = c["pmass"]
        if not pm.get("computed", True):
            continue
        cell_fork = pm["fork_points"]
        if fork is None:
            fork = cell_fork
        rows.extend(pm["pmass"].get(alpha, []))

    if not rows:
        fork = fork or []
        return np.zeros((0, len(fork))), fork

    width = max(len(r) for r in rows)
    arr = np.full((len(rows), width), np.nan)
    for i, r in enumerate(rows):
        arr[i, : len(r)] = r
    return arr, fork
|
||||
|
||||
|
||||
def make_figure(cells: list[dict], a: Args, out_path: Path) -> None:
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
models = sorted({c["model"] for c in cells})
|
||||
methods = sorted({c["method"] for c in cells})
|
||||
fig, axes = plt.subplots(2, len(models), figsize=(5 * len(models), 7),
|
||||
sharex="col", squeeze=False)
|
||||
cmap = plt.get_cmap("tab10")
|
||||
method_color = {m: cmap(i) for i, m in enumerate(methods)}
|
||||
|
||||
for ci, model in enumerate(models):
|
||||
ax_kl = axes[0, ci]
|
||||
ax_pm = axes[1, ci]
|
||||
ax_kl.set_title(model.split("/")[-1])
|
||||
ax_kl.axhline(1.0, color="black", linestyle=":", linewidth=0.8, alpha=0.5)
|
||||
ax_kl.set_ylabel("p95 KL(steer || base)")
|
||||
ax_pm.set_xlabel("token offset")
|
||||
ax_pm.set_ylabel("branch pmass")
|
||||
ax_pm.set_ylim(-0.02, 1.02)
|
||||
ax_kl.set_yscale("log")
|
||||
fig, axes = plt.subplots(2, len(a.alphas), figsize=(4.0 * len(a.alphas), 5.5),
|
||||
sharex="row", sharey="row", squeeze=False)
|
||||
n_cells = len(cells)
|
||||
fig.suptitle(
|
||||
f"Figure 1: iso-KL calibration on Qwen2.5-0.5B-Instruct\n"
|
||||
f"Top: KL(steered || base) per token on long-form held-out prompts.\n"
|
||||
f"Bottom: forked-answer pmass on yes/no reasoning prompts (high=alive, low=collapsed).\n"
|
||||
f"{n_cells} cells (3 methods x 1 seed) x N=8 prompts. "
|
||||
f"Solid h-line: KL=1 nat. Dotted v-line: t=20.",
|
||||
fontsize=10,
|
||||
)
|
||||
|
||||
for method in methods:
|
||||
for alpha, ls in [("1.0", "-"), ("2.0", "--")]:
|
||||
kls = [c["traj"]["per_t_p95_kl"][alpha]
|
||||
for c in cells if c["model"] == model and c["method"] == method]
|
||||
if not kls:
|
||||
continue
|
||||
arr = np.array(kls)
|
||||
x = np.arange(arr.shape[1])
|
||||
ax_kl.plot(x, arr.mean(0), color=method_color[method],
|
||||
linestyle=ls, linewidth=2,
|
||||
label=f"{method} a={alpha}")
|
||||
if arr.shape[0] > 1:
|
||||
ax_kl.fill_between(x, arr.min(0), arr.max(0),
|
||||
color=method_color[method], alpha=0.12)
|
||||
for j, alpha in enumerate(a.alphas):
|
||||
# ---- top: KL ----
|
||||
ax = axes[0, j]
|
||||
K = _pool_kl(cells, alpha, T=a.window)
|
||||
Pe = _pool_pmass_eval(cells, alpha, _first_fork(cells), T=a.window) if a.color_by_pmass else None
|
||||
_draw_kl_panel(ax, K, a, P=Pe)
|
||||
ax.axhline(1.0, color="k", lw=1.0)
|
||||
ax.axvline(20, color="k", ls=":", lw=0.8)
|
||||
ax.set_title(rf"$\alpha = {alpha}$ (n={K.shape[0]} traj)")
|
||||
ax.set_ylim(0, a.kl_ymax)
|
||||
if j == 0:
|
||||
ax.set_ylabel("KL(steered || base) [nats]")
|
||||
ax.set_xlabel("token")
|
||||
|
||||
pms = [c["pmass"]["pmass"][alpha]
|
||||
for c in cells if c["model"] == model and c["method"] == method]
|
||||
if not pms:
|
||||
continue
|
||||
# pms: list of (n_seed) of (n_prompt) of (n_fork)
|
||||
pms_arr = np.array(pms) # (n_seed, n_prompt, n_fork)
|
||||
fork = cells[0]["pmass"]["fork_points"]
|
||||
mean = pms_arr.mean(axis=(0, 1))
|
||||
std = pms_arr.std(axis=(0, 1))
|
||||
ax_pm.plot(fork, mean, color=method_color[method],
|
||||
linestyle=ls, linewidth=2)
|
||||
ax_pm.fill_between(fork, mean - std, mean + std,
|
||||
color=method_color[method], alpha=0.12)
|
||||
if ci == 0:
|
||||
ax_kl.legend(fontsize=8, loc="upper left")
|
||||
fig.tight_layout()
|
||||
fig.savefig(out_path, dpi=150, bbox_inches="tight")
|
||||
# ---- bottom: pmass ----
|
||||
ax2 = axes[1, j]
|
||||
P, fork = _pool_pmass(cells, alpha)
|
||||
if P.size:
|
||||
xs = np.array(fork[: P.shape[1]])
|
||||
if a.spaghetti:
|
||||
for row in P:
|
||||
ax2.plot(xs, row, color="C1", lw=0.6, alpha=0.5)
|
||||
ax2.plot(xs, np.nanmedian(P, axis=0), color="k", lw=1.6, marker="o", ms=3)
|
||||
else:
|
||||
p50 = np.nanpercentile(P, 50, axis=0)
|
||||
p10 = np.nanpercentile(P, 10, axis=0)
|
||||
p90 = np.nanpercentile(P, 90, axis=0)
|
||||
ax2.fill_between(xs, p10, p90, alpha=0.25, color="C1", lw=0)
|
||||
ax2.plot(xs, p50, color="C1", lw=1.6, marker="o", ms=3)
|
||||
ax2.axvline(20, color="k", ls=":", lw=0.8)
|
||||
ax2.set_ylim(-0.02, 1.05)
|
||||
if j == 0:
|
||||
ax2.set_ylabel('pmass = p(true/1) + p(false/0)\nat fork t')
|
||||
ax2.set_xlabel("fork token")
|
||||
|
||||
if a.color_by_pmass:
|
||||
from matplotlib.cm import ScalarMappable
|
||||
from matplotlib.colors import LinearSegmentedColormap, Normalize
|
||||
cmap = LinearSegmentedColormap.from_list("alive", ["#c0392b", "#f1c40f", "#27ae60"])
|
||||
sm = ScalarMappable(norm=Normalize(0, 1), cmap=cmap); sm.set_array([])
|
||||
cbar = fig.colorbar(sm, ax=axes[0, :].tolist(), location="right", shrink=0.8, pad=0.02)
|
||||
cbar.set_label("pmass (0=dead, 1=alive)")
|
||||
|
||||
fig.tight_layout(rect=(0, 0, 1, 0.88))
|
||||
fig.savefig(out_path, dpi=140, bbox_inches="tight")
|
||||
logger.info(f"figure -> {out_path}")
|
||||
|
||||
|
||||
def main(a: Args):
|
||||
out = Path(a.out); out.mkdir(parents=True, exist_ok=True)
|
||||
cells = load_cells(Path(a.runs_root))
|
||||
root = Path(a.runs_root)
|
||||
cells = load_cells(root, window=a.window, model_contains=a.model_contains)
|
||||
if not cells:
|
||||
raise SystemExit(f"no cells under {a.runs_root}")
|
||||
logger.info(f"loaded {len(cells)} cells")
|
||||
raise SystemExit(f"no cells with window={a.window} under {root}")
|
||||
logger.info(f"loaded {len(cells)} cells (window={a.window})")
|
||||
|
||||
df = make_table(cells)
|
||||
df.write_csv(out / "table.csv")
|
||||
md = df.to_pandas().to_markdown(index=False, floatfmt=".3f")
|
||||
(out / "table.md").write_text(md)
|
||||
logger.info(f"table -> {out/'table.md'}\n{md}")
|
||||
df = make_table(root)
|
||||
if not df.is_empty():
|
||||
df.write_csv(out / "table.csv")
|
||||
md = tabulate(df.rows(), headers=df.columns, tablefmt="pipe", floatfmt=".3f")
|
||||
(out / "table.md").write_text(md)
|
||||
logger.info(f"table -> {out/'table.md'}\n{md}")
|
||||
|
||||
make_figure(cells, out / "figure1.png")
|
||||
if a.kl_only:
|
||||
make_kl_figure(cells, a, out / "figure1_kl_only.png")
|
||||
else:
|
||||
make_figure(cells, a, out / "figure1.png")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -0,0 +1,243 @@
|
||||
"""Empirical audit of branch_pmass: load model, generate one rollout per alpha,
|
||||
print decoded gen text and top-k tokens at the prefill end-point, and dump
|
||||
everything to JSON for review.
|
||||
|
||||
Goal: distinguish between
|
||||
(a) pmass measurement is wrong (top tokens at prefill end DON'T match the
|
||||
pmass we read out of the schema-token groups), vs
|
||||
(b) pmass is right but the model just doesn't put mass on schema tokens
|
||||
naturally (steering isn't the problem, the prompt+prefill is), vs
|
||||
(c) pmass is right and steering really does collapse coherence.
|
||||
|
||||
For first prompt per alpha:
|
||||
- decoded full generation text
|
||||
- per fork point: top-10 tokens with prob, plus pmass(true)+pmass(false), p_true
|
||||
|
||||
Usage:
|
||||
uv run --extra all python scripts/audit_pmass.py \
|
||||
--model Qwen/Qwen3.5-0.8B --window 64 --out audit_pmass.json
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import tyro
|
||||
from loguru import logger
|
||||
|
||||
# Import the live module so any audit reflects current code
|
||||
from iso_kl_figure import (
|
||||
MeanDiffC, PCAC, DirectionalAblationC,
|
||||
train, calibrate_iso_kl,
|
||||
)
|
||||
from iso_kl_figure.branch_pmass import collect_choice_token_ids, branch_pmass
|
||||
|
||||
|
||||
# Re-import constants from run_cell so audit uses the same prompts/schema
|
||||
import importlib.util, sys
|
||||
# Load the sibling run_cell.py by file path (it is a script, not a package
# module) so the audit shares its prompts, schema and prefill string.
_rc_path = Path(__file__).parent / "run_cell.py"
_spec = importlib.util.spec_from_file_location("_run_cell", _rc_path)
if _spec is None or _spec.loader is None:
    # Fail with a clear message instead of an AttributeError on None below.
    raise ImportError(f"cannot load run_cell module from {_rc_path}")
_rc = importlib.util.module_from_spec(_spec)
# Register before executing, following the importlib "import a source file
# directly" recipe.
sys.modules["_run_cell"] = _rc
_spec.loader.exec_module(_rc)

CALIB_PROMPTS = _rc.CALIB_PROMPTS
EVAL_PROMPTS = _rc.EVAL_PROMPTS
_QUESTIONS = _rc._QUESTIONS
_SCHEMA = _rc._SCHEMA
PREFILL_STR = _rc.PREFILL_STR
POS_NEG = _rc.POS_NEG
# CLI method name -> steering config class.
METHOD_MAP = {"mean_diff": MeanDiffC, "pca": PCAC, "directional_ablation": DirectionalAblationC}
|
||||
|
||||
|
||||
@dataclass
class Args:
    """Command-line options for the pmass audit."""

    model: str = "Qwen/Qwen3.5-0.8B"
    # Key into METHOD_MAP selecting the steering config class.
    method: str = "mean_diff"
    seed: int = 0
    # Calibration horizon and max_new_tokens of the audited rollout.
    window: int = 64
    # Steered layer index is int(layer_frac * num_hidden_layers).
    layer_frac: float = 0.6
    target_kl: float = 1.0
    device: str = "cuda"
    # Name of a torch dtype attribute, resolved with getattr(torch, dtype).
    dtype: str = "bfloat16"
    alphas: tuple[float, ...] = (0.0, 0.25, 0.5, 0.75, 1.0, 1.5, 2.0, 4.0)
    # Rollout offsets at which the prefill is appended and pmass is read.
    fork_points: tuple[int, ...] = (0, 8, 16, 32, 64)
    out: str = "audit_pmass.json"
    top_k: int = 10
    use_qa_prompt: bool = True  # True: yes/no q+_SCHEMA; False: long-form EVAL_PROMPTS[0]
    skip_calib: bool = False  # if True, use fixed_coeffs as raw c per alpha (no iso-KL bisection)
    fixed_coeffs: tuple[float, ...] = (0.0, 0.25, 0.5, 0.75, 1.0, 1.5, 2.0, 4.0)  # used when skip_calib=True; aligned with --alphas
    cross_check_branch_pmass: bool = True  # also call branch_pmass and compare to local recompute
|
||||
|
||||
|
||||
def _set_seed(s: int):
    """Seed Python's ``random``, NumPy and torch (plus all CUDA devices if present)."""
    import random

    import numpy as np

    random.seed(s)
    np.random.seed(s)
    torch.manual_seed(s)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(s)
|
||||
|
||||
|
||||
@torch.no_grad()
def topk_at_prefill_end(v, model, tok, prompt_ids, rolled_ids, fork_points,
                        prefill_str, a_ids, b_ids, k=10, device="cuda"):
    """Mirror branch_pmass logic but ALSO return top-k tokens for inspection.

    For each fork point ``t`` the sequence ``prompt + rolled[:t] + prefill`` is
    run through ``model`` inside the ``v(model)`` context, and the next-token
    distribution at the last position is summarised.

    Args:
        v: steering object; ``v(model)`` is used as a context manager.
        model: causal LM; ``model(seq).logits`` is read.
        tok: tokenizer (``encode`` for the prefill, ``decode`` for token ids).
        prompt_ids: 1-D tensor of prompt token ids.
        rolled_ids: 1-D tensor of generated token ids.
        fork_points: iterable of rollout offsets to evaluate.
        prefill_str: text appended after the rollout prefix.
        a_ids: token ids of the "true" group.
        b_ids: token ids of the "false" group.
        k: number of top tokens to report.
        device: device the tensors are moved to.

    Returns:
        One dict per fork point. Forks with ``t`` beyond the rollout length are
        marked ``{"skipped": True, "reason": ...}``; the others carry ``topk``
        (decoded token, prob), ``p_true_group``, ``p_false_group``, ``pmass``
        (their sum), ``p_true`` (true share of pmass, NaN if pmass is 0) and
        the decoded ``argmax`` token.
    """
    pids = prompt_ids.to(device)
    rolled = rolled_ids.to(device)
    T = rolled.shape[0]
    pre_t = torch.tensor(tok.encode(prefill_str, add_special_tokens=False),
                         device=device, dtype=torch.long)
    a_t = torch.tensor(list(a_ids), dtype=torch.long, device=device)
    b_t = torch.tensor(list(b_ids), dtype=torch.long, device=device)

    out_per_fork = []
    for t in fork_points:
        if t > T:
            # The rollout ended before this offset; nothing to fork from.
            out_per_fork.append({"t": int(t), "skipped": True, "reason": f"t>T={T}"})
            continue
        seq = torch.cat([pids, rolled[:t], pre_t]).unsqueeze(0)
        with v(model):
            logits = model(seq).logits[0, -1].float()
        probs = torch.softmax(logits, dim=-1)

        # top-k
        tk_p, tk_i = probs.topk(k)
        topk = [(tok.decode([int(i)]), float(p)) for p, i in zip(tk_p.tolist(), tk_i.tolist())]

        pa = float(probs[a_t].sum())
        pb = float(probs[b_t].sum())
        pm = pa + pb
        pt = pa / pm if pm > 0 else float("nan")
        out_per_fork.append({
            "t": int(t),
            "skipped": False,
            "topk": topk,
            "p_true_group": pa,
            "p_false_group": pb,
            "pmass": pm,
            "p_true": pt,
            "argmax": tok.decode([int(probs.argmax())]),
        })
    return out_per_fork
|
||||
|
||||
|
||||
def main(a: Args):
    """Run the pmass audit for one model/method and dump the results to JSON.

    Steps: load model and tokenizer, train the steering vector on the
    calibration pairs, obtain ``c_star`` (iso-KL calibration, or 1.0 with
    ``skip_calib``), then for every alpha generate one greedy rollout under
    steering and inspect the next-token distribution at each fork point,
    optionally cross-checking against ``branch_pmass``.

    Raises:
        ValueError: ``skip_calib`` is set and ``fixed_coeffs`` has fewer
            entries than ``alphas``.
    """
    # Validate before the expensive model load; otherwise this surfaces as an
    # IndexError in the middle of the alpha loop.
    if a.skip_calib and len(a.fixed_coeffs) < len(a.alphas):
        raise ValueError(
            f"skip_calib needs one fixed coeff per alpha: got "
            f"{len(a.fixed_coeffs)} fixed_coeffs for {len(a.alphas)} alphas"
        )

    _set_seed(a.seed)
    from transformers import AutoModelForCausalLM, AutoTokenizer

    dtype = getattr(torch, a.dtype)
    logger.info(f"loading model={a.model}")
    tok = AutoTokenizer.from_pretrained(a.model)
    if tok.pad_token_id is None:
        tok.pad_token_id = tok.eos_token_id
    model = AutoModelForCausalLM.from_pretrained(a.model, torch_dtype=dtype).to(a.device)
    model.eval()

    n_layers = model.config.num_hidden_layers
    layer = int(a.layer_frac * n_layers)
    cfg_cls = METHOD_MAP[a.method]
    cfg = cfg_cls(coeff=1.0, layers=(layer,))

    # Train + calibrate
    pos = [tok.apply_chat_template([{"role": "user", "content": u},
                                    {"role": "assistant", "content": p}], tokenize=False)
           for u, (p, _) in zip(CALIB_PROMPTS, POS_NEG)]
    neg = [tok.apply_chat_template([{"role": "user", "content": u},
                                    {"role": "assistant", "content": n}], tokenize=False)
           for u, (_, n) in zip(CALIB_PROMPTS, POS_NEG)]
    v = train(model, tok, pos, neg, cfg, batch_size=4, max_length=128)
    if a.skip_calib:
        c_star = 1.0  # fixed_coeffs are absolute
        logger.info(f"skip_calib: using fixed_coeffs={a.fixed_coeffs} as raw c")
    else:
        c_star, _ = calibrate_iso_kl(v, model, tok, CALIB_PROMPTS,
                                     target_kl=a.target_kl, target_stat="kl_p95",
                                     T=a.window, device=a.device)
    logger.info(f"c_star={c_star:+.4f}")

    a_ids, b_ids = collect_choice_token_ids(tok)
    logger.info(f"a_ids (true-group)={a_ids} -> tokens={[tok.decode([i]) for i in a_ids]}")
    logger.info(f"b_ids (false-group)={b_ids} -> tokens={[tok.decode([i]) for i in b_ids]}")

    # Single prompt
    if a.use_qa_prompt:
        prompt = f"{_QUESTIONS[0]}\n\n{_SCHEMA}"
    else:
        prompt = EVAL_PROMPTS[0]
    logger.info(f"prompt: {prompt[:120]}...")

    # return_dict=True makes the `.input_ids` access independent of the
    # library's default return type for tokenized chat templates.
    ids = tok.apply_chat_template([{"role": "user", "content": prompt}],
                                  add_generation_prompt=True,
                                  return_tensors="pt",
                                  return_dict=True).input_ids[0]

    out = {
        "model": a.model, "method": a.method, "seed": a.seed,
        "c_star": c_star, "layer": layer,
        "prompt": prompt, "use_qa_prompt": a.use_qa_prompt,
        "prefill": PREFILL_STR,
        "a_ids": a_ids, "b_ids": b_ids,
        "a_tokens": [tok.decode([i]) for i in a_ids],
        "b_tokens": [tok.decode([i]) for i in b_ids],
        "fork_points": list(a.fork_points),
        "alphas": {},
    }

    for i_alpha, alpha in enumerate(a.alphas):
        if a.skip_calib:
            v.cfg.coeff = a.fixed_coeffs[i_alpha]
        else:
            v.cfg.coeff = alpha * c_star
        logger.info(f"=== alpha={alpha} coeff={v.cfg.coeff:+.4f} ===")
        with v(model):
            gen_out = model.generate(
                ids.unsqueeze(0).to(a.device),
                max_new_tokens=a.window,
                pad_token_id=tok.pad_token_id,
                eos_token_id=tok.eos_token_id,
                do_sample=False,
                return_dict_in_generate=True,
            )
        # Drop the prompt; keep only the newly generated tokens.
        gen = gen_out.sequences[0, ids.shape[0]:]
        gen_text = tok.decode(gen, skip_special_tokens=False)
        gen_len = int(gen.shape[0])
        logger.info(f" gen_len={gen_len} text[:200]={gen_text[:200]!r}")

        per_fork = topk_at_prefill_end(
            v, model, tok, ids, gen, list(a.fork_points),
            PREFILL_STR, a_ids, b_ids, k=a.top_k, device=a.device,
        )
        # cross-check: compare to production branch_pmass output
        bp_compare = None
        if a.cross_check_branch_pmass:
            bp = branch_pmass(
                v, model, tok, ids, gen, list(a.fork_points),
                PREFILL_STR, a_ids, b_ids,
                rollout_cache=getattr(gen_out, "past_key_values", None),
                device=a.device,
            )
            bp_compare = {
                "pmass": bp["pmass"], "p_true": bp["p_true"],
                "argmax_str": bp["argmax_str"], "was_thinking": bp["was_thinking"],
            }
            # NOTE(review): assumes branch_pmass returns one entry per fork
            # point in the same order as per_fork (skipped forks included) --
            # confirm against branch_pmass.
            for row, bpm, bam in zip(per_fork, bp["pmass"], bp["argmax_str"]):
                if row.get("skipped"):
                    continue
                local_pm = row["pmass"]
                local_am = row["argmax"]
                agrees = abs(local_pm - bpm) < 1e-3 and local_am == bam
                tag = "OK" if agrees else "MISMATCH"
                logger.info(f" cross-check t={row['t']:>3} {tag}: local pm={local_pm:.4f} argmax={local_am!r} | branch_pmass pm={bpm:.4f} argmax={bam!r}")
        # log inline
        for row in per_fork:
            if row.get("skipped"):
                logger.info(f" t={row['t']:>3} SKIPPED ({row['reason']})")
            else:
                top3 = ", ".join(f"{tk_str!r}={p:.3f}" for tk_str, p in row["topk"][:3])
                logger.info(f" t={row['t']:>3} pmass={row['pmass']:.3f} "
                            f"p_true={row['p_true']:.3f} argmax={row['argmax']!r} top3=[{top3}]")
        out["alphas"][str(alpha)] = {
            "coeff": float(v.cfg.coeff),
            "gen_text": gen_text, "gen_len": gen_len, "per_fork": per_fork,
            "branch_pmass_compare": bp_compare,
        }

    Path(a.out).write_text(json.dumps(out, indent=2, default=str))
    logger.info(f"DONE -> {a.out}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(tyro.cli(Args))
|
||||
+22
-1
@@ -5,9 +5,11 @@ cd "$(dirname "$0")/.."
|
||||
ROOT="$PWD"
|
||||
|
||||
MODELS=(
|
||||
"Qwen/Qwen3.5-0.8B"
|
||||
"Qwen/Qwen2.5-0.5B-Instruct"
|
||||
"meta-llama/Llama-3.2-1B-Instruct"
|
||||
"Qwen/Qwen3-4B-Instruct-2507"
|
||||
"Qwen/Qwen3-4B-Thinking-2507"
|
||||
)
|
||||
METHODS=("mean_diff" "directional_ablation" "pca")
|
||||
SEEDS=(0 1 2)
|
||||
@@ -16,6 +18,7 @@ WINDOWS=(20 50)
|
||||
# Priority: small models first so the figure starts populating; 4B last.
|
||||
prio_for_model() {
|
||||
case "$1" in
|
||||
*0.8B*) echo 35 ;;
|
||||
*0.5B*) echo 30 ;;
|
||||
*1B*) echo 20 ;;
|
||||
*4B*) echo 10 ;;
|
||||
@@ -31,7 +34,25 @@ for model in "${MODELS[@]}"; do
|
||||
for window in "${WINDOWS[@]}"; do
|
||||
run_id="$(basename "$model")_${method}_s${seed}_w${window}"
|
||||
if [ -f "outputs/${run_id}/calib.json" ]; then
|
||||
echo "skip ${run_id} (already done)"; continue
|
||||
if env -i HOME="$HOME" PATH="$ROOT/.venv/bin:/usr/bin:/usr/local/bin" "$ROOT/.venv/bin/python" - <<PY
|
||||
import json
|
||||
from pathlib import Path
|
||||
run = Path("outputs/${run_id}")
|
||||
traj = json.loads((run / "trajectory.json").read_text()) if (run / "trajectory.json").exists() else {}
|
||||
pm = json.loads((run / "pmass.json").read_text()) if (run / "pmass.json").exists() else {}
|
||||
ok = (
|
||||
"per_prompt_per_t_kl" in traj
|
||||
and "schema" in pm
|
||||
and "4.0" in traj.get("per_t_p95_kl", {})
|
||||
and "4.0" in pm.get("pmass", {})
|
||||
)
|
||||
raise SystemExit(0 if ok else 1)
|
||||
PY
|
||||
then
|
||||
echo "skip ${run_id} (fresh)"
|
||||
continue
|
||||
fi
|
||||
echo "rerun ${run_id} (stale outputs)"
|
||||
fi
|
||||
label="why: stability of iso-KL calib for ${method} on ${model} (seed ${seed}, T=${window}); resolve: include cell in figure1 if it converges, else flag as bracket-pinned"
|
||||
pueue add -w "$ROOT" -o "$prio" -l "$label" -- \
|
||||
|
||||
@@ -0,0 +1,98 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import tyro
|
||||
from loguru import logger
|
||||
|
||||
|
||||
@dataclass
class Args:
    """Command-line options for rendering the figures of a single run directory."""

    # Run output directory; main() exits with "missing ..." unless it contains calib.json.
    run_dir: str
    # Forwarded to scripts/survival.py (--thresholds) and
    # scripts/spaghetti_kl_alive.py (--threshold).
    threshold: float = 0.95
    # Name of the sub-directory of run_dir that receives the rendered figures.
    out_name: str = "figs_auto"
    # Value passed as --model-contains; when empty, main() falls back to the
    # part of the run directory name before its first "_".
    model_contains: str = ""
|
||||
|
||||
|
||||
def _ensure_single_run_root(run_dir: Path) -> Path:
|
||||
root = run_dir.parent / f"_{run_dir.name}_single"
|
||||
root.mkdir(parents=True, exist_ok=True)
|
||||
link = root / run_dir.name
|
||||
if link.is_symlink() or link.exists():
|
||||
try:
|
||||
if link.resolve() == run_dir.resolve():
|
||||
return root
|
||||
except Exception:
|
||||
pass
|
||||
if link.is_symlink() or link.is_file():
|
||||
link.unlink()
|
||||
else:
|
||||
raise RuntimeError(f"refusing to replace non-file staging path: {link}")
|
||||
os.symlink(run_dir.resolve(), link)
|
||||
return root
|
||||
|
||||
|
||||
def _run(cmd: list[str], cwd: Path) -> None:
    """Log a command line, then execute it in ``cwd``.

    ``check=True`` makes a non-zero exit status raise
    ``subprocess.CalledProcessError``.
    """
    rendered = " ".join(cmd)
    logger.info("$ " + rendered)
    subprocess.run(cmd, check=True, cwd=cwd)
|
||||
|
||||
|
||||
def main(a: Args) -> None:
    """Render the survival, spaghetti and KL-only figures for one run directory.

    Reads ``window`` from ``<run_dir>/calib.json``, stages the run behind a
    single-run root (see ``_ensure_single_run_root``) and invokes
    ``scripts/survival.py``, ``scripts/spaghetti_kl_alive.py`` and
    ``scripts/aggregate.py`` from the repository root, writing into
    ``<run_dir>/<out_name>/{survival,spaghetti,aggregate}``.

    Raises:
        SystemExit: if ``calib.json`` is missing from the run directory.
        subprocess.CalledProcessError: if any of the scripts exits non-zero.
    """
    import json  # local import: this module has no file-level json import

    run_dir = Path(a.run_dir).resolve()
    repo_root = Path(__file__).resolve().parents[1]

    # Validate before touching the filesystem so that a wrong run_dir does not
    # leave a staging root or an empty figure directory behind.
    calib = run_dir / "calib.json"
    if not calib.exists():
        raise SystemExit(f"missing {calib}")
    # Parse calib.json once; all three scripts receive the same window.
    window = str(json.loads(calib.read_text())["window"])

    single_root = _ensure_single_run_root(run_dir)
    out_dir = run_dir / a.out_name
    out_dir.mkdir(parents=True, exist_ok=True)

    # Default filter: the part of the run directory name before the first "_".
    model_filter = a.model_contains or run_dir.name.split("_", 1)[0]

    _run([
        sys.executable,
        "scripts/survival.py",
        "--runs-root", str(single_root),
        "--out", str(out_dir / "survival"),
        "--window", window,
        "--metric", "pmass_eval",
        "--thresholds", str(a.threshold),
        "--model-contains", model_filter,
    ], repo_root)
    _run([
        sys.executable,
        "scripts/spaghetti_kl_alive.py",
        "--runs-root", str(single_root),
        "--out", str(out_dir / "spaghetti"),
        "--window", window,
        "--threshold", str(a.threshold),
        "--model-contains", model_filter,
    ], repo_root)
    _run([
        sys.executable,
        "scripts/aggregate.py",
        "--runs-root", str(single_root),
        "--out", str(out_dir / "aggregate"),
        "--window", window,
        "--spaghetti",
        "--color-by-pmass",
        "--kl-only",
        "--model-contains", model_filter,
    ], repo_root)

    # Paths the three scripts are expected to have written; logged for
    # convenience, not checked for existence.
    pngs = [
        out_dir / "survival" / "survival_pmass_eval.png",
        out_dir / "spaghetti" / "kl_alive_spaghetti.png",
        out_dir / "aggregate" / "figure1_kl_only.png",
    ]
    for png in pngs:
        logger.info(f"PNG -> {png}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: tyro builds an Args instance from the command line
    # and main() runs with it.
    main(tyro.cli(Args))
|
||||
+260
-48
@@ -1,27 +1,56 @@
|
||||
"""End-to-end runner for one (model, method, seed, window) cell.
|
||||
|
||||
Flow:
|
||||
1. Load model + tokenizer (HF), set seed.
|
||||
2. Build pos/neg prompts (cheap pair); train the steering Vector v.
|
||||
3. Calibrate iso-KL at target_kl=1 over T=window tokens. Save full history
|
||||
(incl. per-token KL arrays) to outputs/<run_id>/history.json.
|
||||
4. Re-run measure_kl at coeff=alpha*c_star (alpha in {1, 2}) on a held-out
|
||||
prompt set so the trajectory plot reflects generalisation, not the
|
||||
calibration set itself. Save per-token p95 KL to trajectory.json.
|
||||
5. For each held-out prompt, rollout T_eval tokens under the steered model,
|
||||
then branch-pmass at fork_points {0, 5, ..., T_eval}. Save to
|
||||
pmass.json. Use a JSON-format suffix so target tokens are well-defined.
|
||||
Artefacts (per cell, under outputs/<run_id>/):
|
||||
calib.json c_star + run metadata (model, method, seed, layer, window).
|
||||
history.json bisection history (without per-token KL arrays).
|
||||
trajectory.json per-token KL on EVAL_PROMPTS for each alpha:
|
||||
per_t_p95_kl[alpha]: list[T]
|
||||
per_prompt_per_t_kl[alpha]: list[N_prompts][T]
|
||||
pmass.json forked-answer probability mass at fork_points:
|
||||
pmass[alpha]: yes/no reasoning prompts (legacy probe)
|
||||
pmass_eval[alpha]: SAME prompts as trajectory.json (paired w KL)
|
||||
gen_lens_qa[alpha], gen_lens_eval[alpha]: T per rollout
|
||||
(use for right-censoring -- NaN at t > T means rollout
|
||||
EOS'd before that fork, NOT a measurement failure).
|
||||
debug_first[alpha]: gen_text + per-fork pmass for the
|
||||
FIRST qa & eval prompt (sanity-check; see survival.py).
|
||||
pmass_eval is the metric to use when you want a per-trajectory
|
||||
coherence signal aligned with the KL trajectory.
|
||||
results.csv one row per alpha with kl_p95/mean/max.
|
||||
|
||||
Outputs one CSV row per (alpha, prompt) into outputs/<run_id>/results.csv
|
||||
plus the artefacts above.
|
||||
Flow:
|
||||
1. Load model + tokenizer; set seed.
|
||||
2. Build pos/neg pair (cheap content-vs-refusal); train Vector v.
|
||||
3. Calibrate iso-KL: bisect coeff so per-token kl_p95 hits target_kl (default 1)
|
||||
over T=window tokens on CALIB_PROMPTS.
|
||||
4. For each alpha, set coeff=alpha*c_star and:
|
||||
a) measure_kl on EVAL_PROMPTS -> per_prompt_per_t_kl trajectories.
|
||||
b) (--compute-pmass) yes/no questions: rollout, then at each fork point
|
||||
prefill `\n{"choice": ` and read mass on {true/True/ true/.../1/ 1}
|
||||
and {false/.../0}. Skip with --skip-pmass-qa.
|
||||
c) (--compute-pmass) repeat (b) on EVAL_PROMPTS -> pmass_eval, paired
|
||||
with the KL trajectory on the same rollouts. This is the honest
|
||||
coherence signal; KL is largely a calibration tautology at alpha=1.
|
||||
|
||||
Fork-point modes:
|
||||
--fork-step N linear: {0, N, 2N, ..., window}
|
||||
--fork-log log-spaced: {0, 1, 2, 4, 8, ..., window} (with --n-log-forks)
|
||||
|
||||
KV-cache:
|
||||
Rollouts use generate(..., return_dict_in_generate=True). branch_pmass
|
||||
uses full-recompute by default (cache reuse adds ~10% bf16 noise; verified
|
||||
by scripts/audit_pmass.py). Pass use_rollout_cache=True only if you accept
|
||||
the noise.
|
||||
|
||||
Usage (tyro):
|
||||
python scripts/run_cell.py --model Qwen/Qwen2.5-0.5B-Instruct --method mean_diff \
|
||||
--seed 0 --window 50 --run_id qwen05b_md_s0_w50
|
||||
--seed 0 --window 512 --compute-pmass --skip-pmass-qa --fork-log
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
@@ -33,7 +62,7 @@ from iso_kl_figure import (
|
||||
SteeringConfig, MeanDiffC, PCAC, DirectionalAblationC,
|
||||
train, calibrate_iso_kl, measure_kl,
|
||||
)
|
||||
from iso_kl_figure.branch_pmass import branch_pmass
|
||||
from iso_kl_figure.branch_pmass import branch_pmass, collect_choice_token_ids
|
||||
|
||||
|
||||
CALIB_PROMPTS = [
|
||||
@@ -43,14 +72,38 @@ CALIB_PROMPTS = [
|
||||
"Briefly describe how plate tectonics shape mountain ranges.",
|
||||
]
|
||||
|
||||
# Held-out prompts -- different from calibration; tests generalisation.
|
||||
# EVAL prompts are yes/no questions paired with schema -- same contract as
|
||||
# tinymfv/guided.py so branch_pmass at any fork can score format-compliance.
|
||||
# The model is steered into <think>\n at the prompt boundary; at fork t we
|
||||
# inject "\nI should answer now.</think>\n{\"choice\": " and read p(true|false).
|
||||
EVAL_PROMPTS = [
|
||||
"Explain why the sky is blue in two short sentences.",
|
||||
"Summarise the plot of Hamlet in one paragraph.",
|
||||
"List three causes of inflation.",
|
||||
"Write a haiku about an octopus librarian.",
|
||||
"Is the sky blue primarily because of Rayleigh scattering of sunlight?",
|
||||
"Does Hamlet die at the end of Shakespeare's play 'Hamlet'?",
|
||||
"Is excessive money supply growth one cause of inflation in modern economies?",
|
||||
"Do octopuses have three hearts and blue copper-based blood?",
|
||||
"Does photosynthesis use light energy to convert CO2 and water into glucose?",
|
||||
"Was the Eiffel Tower originally built as the entrance arch for the 1889 World's Fair?",
|
||||
"Can plate tectonic collisions form mountain ranges over geological time?",
|
||||
"Is a circular import one possible cause of a Python script crashing on import?",
|
||||
]
|
||||
|
||||
# pmass diagnostic prompts share the same schema as EVAL above.
|
||||
_SCHEMA = (
|
||||
'Reason for at least 80 words before the final answer. Then answer only with '
|
||||
'{"choice": true} or {"choice": false}. If you use numeric JSON, use 1 for true and 0 for false.'
|
||||
)
|
||||
_QUESTIONS = [
|
||||
"Is the Eiffel Tower located in Paris, France?",
|
||||
"Do dolphins breathe air through lungs rather than gills?",
|
||||
"Was Albert Einstein born in Germany?",
|
||||
"Can humans naturally see ultraviolet light?",
|
||||
"Did the Roman Empire fall in 476 CE?",
|
||||
"Are tomatoes botanically classified as fruits?",
|
||||
"Is the speed of light approximately 300,000 km/s?",
|
||||
"Do octopuses have three hearts?",
|
||||
]
|
||||
PREFILL_STR = '\n{"choice": '
|
||||
|
||||
# Pos/neg pair: a generic content-vs-refusal direction. Cheap, not the main
|
||||
# point -- the figure tests calibration *behaviour*, not direction quality.
|
||||
POS_NEG = [
|
||||
@@ -80,9 +133,14 @@ class Args:
|
||||
out_root: str = "outputs"
|
||||
device: str = "cuda"
|
||||
dtype: str = "bfloat16"
|
||||
suffix_str: str = ' Final answer in JSON: {"value": '
|
||||
target_words: list[str] = field(default_factory=lambda: ["true", "false", "yes", "no"])
|
||||
alphas: tuple[float, ...] = (0.25, 0.5, 0.75, 1.0, 1.5, 2.0, 4.0)
|
||||
fork_step: int = 5
|
||||
fork_log: bool = False # log-spaced fork points {0,1,2,4,8,...,window}
|
||||
n_log_forks: int = 14 # number of log-spaced forks (incl. 0)
|
||||
compute_pmass: bool = False
|
||||
skip_pmass_qa: bool = False # skip yes/no pmass loop, only do paired pmass_eval
|
||||
render_figs: bool = False # render single-run survival + spaghetti + KL pngs
|
||||
render_threshold: float = 0.95
|
||||
|
||||
|
||||
def _set_seed(s: int):
|
||||
@@ -93,6 +151,19 @@ def _set_seed(s: int):
|
||||
torch.cuda.manual_seed_all(s)
|
||||
|
||||
|
||||
def _build_guided_prompt(tok, user_text: str, schema_hint: str = _SCHEMA) -> str:
    """Render a single-turn chat prompt that ends inside an open ``<think>`` block.

    The user message is ``user_text`` followed by a blank line and
    ``schema_hint`` (or ``user_text`` alone when the hint is empty). It is
    rendered with the tokenizer's chat template, and ``'<think>\\n'`` is
    appended so generation starts in thinking mode.

    Returns the prompt as a string (not token ids).
    """
    content = user_text
    if schema_hint:
        content = f"{user_text}\n\n{schema_hint}"
    conversation = [{"role": "user", "content": content}]
    try:
        rendered = tok.apply_chat_template(
            conversation, tokenize=False, add_generation_prompt=True
        )
    except TypeError:
        # Fallback for tokenizers whose apply_chat_template rejects the
        # add_generation_prompt keyword.
        rendered = tok.apply_chat_template(conversation, tokenize=False)
    return rendered + "<think>\n"
|
||||
|
||||
|
||||
def main(a: Args):
|
||||
if not a.run_id:
|
||||
a.run_id = f"{a.model.split('/')[-1]}_{a.method}_s{a.seed}_w{a.window}"
|
||||
@@ -134,56 +205,186 @@ def main(a: Args):
|
||||
)
|
||||
v.cfg.coeff = c_star
|
||||
logger.info(f"c_star = {c_star:+.4f}")
|
||||
(out_dir / "history.json").write_text(json.dumps(history, indent=2))
|
||||
# Strip per_prompt_per_t from history to keep file size small.
|
||||
hist_slim = [{k: v_ for k, v_ in h.items() if k != "per_prompt_per_t"}
|
||||
for h in history]
|
||||
(out_dir / "history.json").write_text(json.dumps(hist_slim, indent=2))
|
||||
(out_dir / "calib.json").write_text(json.dumps({
|
||||
"c_star": c_star, "target_kl": a.target_kl, "window": a.window,
|
||||
"method": a.method, "model": a.model, "seed": a.seed, "layer": layer,
|
||||
}, indent=2))
|
||||
|
||||
# -- trajectory + pmass at alpha in {1, 2} on held-out prompts
|
||||
# Sanity-check pmass on the BASE model (no steering): should be ~1.0,
|
||||
# otherwise the prefill/schema isn't priming the right tokens.
|
||||
a_ids, b_ids = collect_choice_token_ids(tok)
|
||||
if a.compute_pmass:
|
||||
logger.info(f"choice ids: a(true)={a_ids} b(false)={b_ids}")
|
||||
|
||||
# -- trajectory + pmass at each alpha on held-out prompts
|
||||
rows = []
|
||||
fork_points = list(range(0, a.window + 1, a.fork_step))
|
||||
if a.fork_log:
|
||||
# log-spaced including 0 and window: {0, 1, 2, 4, 8, ..., window}
|
||||
import numpy as _np
|
||||
raw = _np.unique(_np.round(_np.geomspace(1, a.window, a.n_log_forks - 1)).astype(int))
|
||||
fork_points = [0] + [int(x) for x in raw]
|
||||
fork_points = sorted(set(fp for fp in fork_points if fp <= a.window))
|
||||
else:
|
||||
fork_points = list(range(0, a.window + 1, a.fork_step))
|
||||
logger.info(f"fork_points (n={len(fork_points)}): {fork_points}")
|
||||
trajectory: dict[str, list] = {}
|
||||
per_prompt_traj: dict[str, list] = {}
|
||||
pmass_all: dict[str, list] = {}
|
||||
for alpha in (1.0, 2.0):
|
||||
pmass_eval_all: dict[str, list] = {} # pmass paired with EVAL_PROMPTS (same prompts as KL)
|
||||
p_true_all: dict[str, list] = {}
|
||||
argmax_all: dict[str, list] = {}
|
||||
thinking_all: dict[str, list] = {}
|
||||
answer_label_all: dict[str, list] = {}
|
||||
gen_lens_qa: dict[str, list] = {} # T per (alpha, qa-prompt) -- right-censoring info
|
||||
gen_lens_eval: dict[str, list] = {} # T per (alpha, eval-prompt)
|
||||
debug_first: dict[str, dict] = {} # first prompt per alpha: gen_text + top-5 at each fork
|
||||
# Pre-build guided EVAL ids ONCE so measure_kl and pmass_eval roll out the
|
||||
# same prompt -- otherwise spaghetti coloring (KL trajectory colored by
|
||||
# pmass) is meaningless because the rollouts diverge.
|
||||
eval_prompt_strs = [_build_guided_prompt(tok, p, _SCHEMA) for p in EVAL_PROMPTS]
|
||||
eval_ids_list = [
|
||||
tok(s, return_tensors="pt", add_special_tokens=False).input_ids[0]
|
||||
for s in eval_prompt_strs
|
||||
]
|
||||
for alpha in a.alphas:
|
||||
v.cfg.coeff = alpha * c_star
|
||||
logger.info(f"=== eval alpha={alpha} c={v.cfg.coeff:+.4f} ===")
|
||||
m = measure_kl(v, model, tok, EVAL_PROMPTS, T=a.window, device=a.device)
|
||||
m = measure_kl(v, model, tok, eval_ids_list, T=a.window, device=a.device)
|
||||
trajectory[str(alpha)] = m["per_t_p95"]
|
||||
per_prompt_traj[str(alpha)] = m["per_prompt_per_t"]
|
||||
rows.append({"alpha": alpha, "coeff": v.cfg.coeff, "kl_p95": m["kl_p95"],
|
||||
"kl_mean": m["kl_mean"], "kl_max": m["kl_max"]})
|
||||
|
||||
# pmass per held-out prompt
|
||||
pm_for_alpha = []
|
||||
for p in EVAL_PROMPTS:
|
||||
ids = tok.apply_chat_template(
|
||||
[{"role": "user", "content": p}],
|
||||
add_generation_prompt=True, return_tensors="pt",
|
||||
).input_ids[0]
|
||||
pad = tok.pad_token_id
|
||||
with v(model):
|
||||
gen = model.generate(
|
||||
ids.unsqueeze(0).to(a.device),
|
||||
max_new_tokens=a.window,
|
||||
pad_token_id=pad, eos_token_id=tok.eos_token_id,
|
||||
do_sample=False,
|
||||
)[0, ids.shape[0]:]
|
||||
pm = branch_pmass(
|
||||
v, model, tok, ids, gen, fork_points,
|
||||
a.suffix_str, a.target_words, device=a.device,
|
||||
)
|
||||
pm_for_alpha.append(pm["pmass"])
|
||||
pm_for_alpha, pt_for_alpha, ax_for_alpha, wt_for_alpha, ans_for_alpha = [], [], [], [], []
|
||||
gen_lens_qa_alpha = []
|
||||
gen_lens_eval_alpha = []
|
||||
if a.compute_pmass and not a.skip_pmass_qa:
|
||||
for q_idx, q in enumerate(_QUESTIONS):
|
||||
prompt_str = _build_guided_prompt(tok, q, _SCHEMA)
|
||||
ids = tok(prompt_str, return_tensors="pt", add_special_tokens=False).input_ids[0]
|
||||
pad = tok.pad_token_id
|
||||
with v(model):
|
||||
gen_out = model.generate(
|
||||
ids.unsqueeze(0).to(a.device),
|
||||
max_new_tokens=a.window,
|
||||
pad_token_id=pad, eos_token_id=tok.eos_token_id,
|
||||
do_sample=False,
|
||||
use_cache=True,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
gen = gen_out.sequences[0, ids.shape[0]:]
|
||||
gen_lens_qa_alpha.append(int(gen.shape[0]))
|
||||
# KV cache from steered rollout: ~100x speedup at each fork.
|
||||
pm = branch_pmass(
|
||||
v, model, tok, ids, gen, fork_points,
|
||||
PREFILL_STR, a_ids, b_ids,
|
||||
rollout_cache=gen_out.past_key_values,
|
||||
device=a.device,
|
||||
)
|
||||
pm_for_alpha.append(pm["pmass"])
|
||||
pt_for_alpha.append(pm["p_true"])
|
||||
ax_for_alpha.append(pm["argmax_str"])
|
||||
wt_for_alpha.append(pm["was_thinking"])
|
||||
ans_for_alpha.append([
|
||||
"true" if s in {"true", "True", " true", " True", "1", " 1"}
|
||||
else "false" if s in {"false", "False", " false", " False", "0", " 0"}
|
||||
else "other"
|
||||
for s in pm["argmax_str"]
|
||||
])
|
||||
# debug dump for first QA prompt only
|
||||
if q_idx == 0:
|
||||
gen_text = tok.decode(gen, skip_special_tokens=False)
|
||||
debug_first.setdefault(str(alpha), {})["qa"] = {
|
||||
"prompt": prompt_str, "gen_text": gen_text, "gen_len": int(gen.shape[0]),
|
||||
"pmass_per_fork": pm["pmass"], "p_true_per_fork": pm["p_true"],
|
||||
"argmax_per_fork": pm["argmax_str"],
|
||||
}
|
||||
logger.info(f" [debug] alpha={alpha} qa[0] gen_len={gen.shape[0]} text[:120]={gen_text[:120]!r}")
|
||||
for t, pmv, ptv, am in zip(fork_points, pm["pmass"], pm["p_true"], pm["argmax_str"]):
|
||||
logger.info(f" t={t:>3} pmass={pmv:.3f} p_true={ptv:.3f} argmax={am!r}")
|
||||
pmass_all[str(alpha)] = pm_for_alpha
|
||||
p_true_all[str(alpha)] = pt_for_alpha
|
||||
argmax_all[str(alpha)] = ax_for_alpha
|
||||
thinking_all[str(alpha)] = wt_for_alpha
|
||||
answer_label_all[str(alpha)] = ans_for_alpha
|
||||
gen_lens_qa[str(alpha)] = gen_lens_qa_alpha
|
||||
|
||||
# Paired pmass on EVAL_PROMPTS (the same guided yes/no prompts that
# measure_kl rolls out above) so we can color KL trajectories by pmass at
# each fork. At every fork the prefill forces the question
# "if you committed now, can you still produce a valid choice?" -- which
# is exactly the coherence signal we want.
|
||||
pm_eval_for_alpha = []
|
||||
if a.compute_pmass:
|
||||
for p_idx, p in enumerate(EVAL_PROMPTS):
|
||||
prompt_str = eval_prompt_strs[p_idx]
|
||||
ids = eval_ids_list[p_idx]
|
||||
pad = tok.pad_token_id
|
||||
with v(model):
|
||||
gen_out = model.generate(
|
||||
ids.unsqueeze(0).to(a.device),
|
||||
max_new_tokens=a.window,
|
||||
pad_token_id=pad, eos_token_id=tok.eos_token_id,
|
||||
do_sample=False,
|
||||
use_cache=True,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
gen = gen_out.sequences[0, ids.shape[0]:]
|
||||
gen_lens_eval_alpha.append(int(gen.shape[0]))
|
||||
pm = branch_pmass(
|
||||
v, model, tok, ids, gen, fork_points,
|
||||
PREFILL_STR, a_ids, b_ids,
|
||||
rollout_cache=gen_out.past_key_values,
|
||||
device=a.device,
|
||||
)
|
||||
pm_eval_for_alpha.append(pm["pmass"])
|
||||
if p_idx == 0:
|
||||
gen_text = tok.decode(gen, skip_special_tokens=False)
|
||||
debug_first.setdefault(str(alpha), {})["eval"] = {
|
||||
"prompt": prompt_str, "gen_text": gen_text, "gen_len": int(gen.shape[0]),
|
||||
"pmass_per_fork": pm["pmass"], "p_true_per_fork": pm["p_true"],
|
||||
"argmax_per_fork": pm["argmax_str"],
|
||||
}
|
||||
logger.info(f" [debug] alpha={alpha} eval[0] gen_len={gen.shape[0]} text[:120]={gen_text[:120]!r}")
|
||||
for t, pmv, ptv, am in zip(fork_points, pm["pmass"], pm["p_true"], pm["argmax_str"]):
|
||||
logger.info(f" t={t:>3} pmass={pmv:.3f} p_true={ptv:.3f} argmax={am!r}")
|
||||
pmass_eval_all[str(alpha)] = pm_eval_for_alpha
|
||||
gen_lens_eval[str(alpha)] = gen_lens_eval_alpha
|
||||
# SHOULD: at alpha=1, mean(pmass at t=0) > 0.5 (model still respects schema).
|
||||
# ELSE: prefill string broken or chat template off.
|
||||
if a.compute_pmass:
|
||||
try:
|
||||
import numpy as _np
|
||||
t0 = _np.array([row[0] for row in pm_for_alpha])
|
||||
logger.info(f" alpha={alpha} pmass@t=0: mean={t0.mean():.3f} min={t0.min():.3f} max={t0.max():.3f}")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
(out_dir / "trajectory.json").write_text(json.dumps({
|
||||
"fork_points_full": list(range(a.window)),
|
||||
"per_t_p95_kl": trajectory,
|
||||
"per_prompt_per_t_kl": per_prompt_traj,
|
||||
}, indent=2))
|
||||
(out_dir / "pmass.json").write_text(json.dumps({
|
||||
"fork_points": fork_points,
|
||||
"pmass": pmass_all,
|
||||
"suffix": a.suffix_str,
|
||||
"target_words": a.target_words,
|
||||
"pmass_eval": pmass_eval_all,
|
||||
"p_true": p_true_all,
|
||||
"argmax_str": argmax_all,
|
||||
"was_thinking": thinking_all,
|
||||
"answer_label": answer_label_all,
|
||||
"gen_lens_qa": gen_lens_qa, # T per (alpha, qa-prompt) for right-censoring
|
||||
"gen_lens_eval": gen_lens_eval, # T per (alpha, eval-prompt) for right-censoring
|
||||
"debug_first": debug_first, # first prompt per alpha: gen_text + per-fork pmass/argmax
|
||||
"prefill": PREFILL_STR,
|
||||
"schema": _SCHEMA,
|
||||
"questions": _QUESTIONS,
|
||||
"computed": a.compute_pmass,
|
||||
}, indent=2))
|
||||
import csv
|
||||
with open(out_dir / "results.csv", "w", newline="") as f:
|
||||
@@ -191,6 +392,17 @@ def main(a: Args):
|
||||
w.writeheader()
|
||||
for r in rows:
|
||||
w.writerow(r)
|
||||
if a.render_figs:
|
||||
repo_root = Path(__file__).resolve().parents[1]
|
||||
cmd = [
|
||||
sys.executable,
|
||||
"scripts/render_run_figs.py",
|
||||
"--run-dir", str(out_dir),
|
||||
"--threshold", str(a.render_threshold),
|
||||
"--model-contains", a.model.split("/")[-1],
|
||||
]
|
||||
logger.info("rendering single-run figures")
|
||||
subprocess.run(cmd, check=True, cwd=repo_root)
|
||||
logger.info(f"DONE -> {out_dir}")
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,215 @@
|
||||
"""KL-vs-survival spaghetti: per-prompt KL trajectory colored by alive/dead.
|
||||
|
||||
Each rollout is a thin line. KL(t) is plotted; the line is colored by whether
|
||||
the rollout is currently 'alive' (pmass at the nearest fork s<=t is >= threshold)
|
||||
or 'dead' (it has dropped below threshold at some s<=t -- death is irreversible).
|
||||
|
||||
Right-censoring: if the rollout EOS'd before t (gen_len < t), the line stops at
|
||||
gen_len (drawn as a small terminal dot) -- it is 'complete' not 'dead'.
|
||||
|
||||
Panels: one per alpha. Title: KL ceiling (calibration target) + n trajectories.
|
||||
|
||||
Reading guide: at alpha=1, KL is bounded near the calibration target by
|
||||
construction (~1 nat). Lines staying blue/green => model still coherent under
|
||||
that KL budget. Red lines => model collapsed even though KL stayed in budget
|
||||
(steering hit a degenerate region without raising KL much).
|
||||
|
||||
Usage:
|
||||
python scripts/spaghetti_kl_alive.py \
|
||||
--runs-root outputs_qwen35_w512_v3 \
|
||||
--out figs_qwen35_w512_kl_alive \
|
||||
--window 512 \
|
||||
--alphas 0.0 0.25 0.5 0.75 1.0 1.5 2.0 4.0 \
|
||||
--threshold 0.8 \
|
||||
--model-contains Qwen3.5-0.8B
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import tyro
|
||||
from loguru import logger
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.collections import LineCollection
|
||||
from matplotlib.colors import to_rgba
|
||||
|
||||
# Plot styling is best-effort: use the seaborn theme when seaborn imports and
# configures cleanly, otherwise fall back to matplotlib's built-in "ggplot"
# style so the script still runs without seaborn.
try:
    import seaborn as sns
    sns.set_theme(context="notebook", style="whitegrid", palette="deep", font_scale=0.9)
    plt.rcParams.update({
        "axes.titlesize": 10, "axes.labelsize": 9,
        "axes.spines.top": False, "axes.spines.right": False,
    })
except Exception:
    plt.style.use("ggplot")
|
||||
|
||||
|
||||
@dataclass
class Args:
    """Command-line options for the KL-vs-survival spaghetti figure."""

    # Directory whose sub-directories are the per-run cells to plot.
    runs_root: str = "outputs_qwen35_w512_v3"
    # Output directory for kl_alive_spaghetti.png and kl_alive_summary.md.
    out: str = "figs_qwen35_w512_kl_alive"
    # Number of token positions on the x-axis; shorter trajectories are
    # NaN-padded to this length for the median line.
    window: int = 512
    # One panel per alpha; strings because they are used as JSON dict keys.
    alphas: tuple[str, ...] = ("0.0", "0.25", "0.5", "0.75", "1.0", "1.5", "2.0", "4.0")
    threshold: float = 0.8  # pmass < threshold = dead
    # NOTE(review): not read by load_cell()/main() in this file, which always
    # load "pmass_eval" -- confirm whether this option is meant to be wired in.
    metric: str = "pmass_eval"  # 'pmass_eval' is paired with KL prompts
    # Only cell directories whose name contains this substring are plotted.
    model_contains: str = "Qwen3.5-0.8B"
    # When True the y-axis uses a symlog scale (linthresh=0.1).
    kl_log: bool = True
    roll: int = 11  # smooth KL a bit so the spaghetti is readable
|
||||
|
||||
|
||||
def load_cell(d: Path, alpha: str, T: int):
    """Read the paired KL / pmass trajectories of one run cell for one alpha.

    Returns a list with one ``(kl_traj, pmass_per_fork, fork_points, gen_len)``
    tuple per prompt, where the first two entries are float arrays. The list is
    empty when any of calib.json / trajectory.json / pmass.json is missing,
    when pmass.json is flagged ``"computed": false``, or when either series has
    no entry for ``alpha``. ``gen_len`` falls back to the KL trajectory length
    when no generation length was recorded for that prompt.

    ``T`` is accepted for call-site symmetry but is not used here.
    """
    required = [d / name for name in ("calib.json", "trajectory.json", "pmass.json")]
    if any(not path.exists() for path in required):
        return []
    _, traj_path, pmass_path = required

    trajectory = json.loads(traj_path.read_text())
    pmass_blob = json.loads(pmass_path.read_text())
    if not pmass_blob.get("computed", True):
        return []

    fork_points = pmass_blob["fork_points"]
    kl_rows = trajectory.get("per_prompt_per_t_kl", {}).get(alpha, [])
    pmass_rows = pmass_blob.get("pmass_eval", {}).get(alpha, [])
    gen_lens = pmass_blob.get("gen_lens_eval", {}).get(alpha, [])
    if not (kl_rows and pmass_rows):
        return []

    cells = []
    for idx, (kl_row, pmass_row) in enumerate(zip(kl_rows, pmass_rows)):
        if idx < len(gen_lens):
            length = int(gen_lens[idx])
        else:
            length = len(kl_row)
        kl_arr = np.asarray(kl_row, dtype=float)
        pmass_arr = np.asarray(pmass_row, dtype=float)
        cells.append((kl_arr, pmass_arr, fork_points, length))
    return cells
|
||||
|
||||
|
||||
def alive_mask_for_t(pmass_per_fork: np.ndarray, fork: list[int],
                     T: int, threshold: float, gen_len: int) -> np.ndarray:
    """Return a per-token ``(T,)`` boolean mask: True = alive, False = dead.

    A rollout is dead at a fork once the running minimum of pmass over all
    forks up to and including it drops below ``threshold``; death is therefore
    irreversible. NaN pmass entries are ignored (treated as +inf) and never
    cause death on their own. Each token ``t`` inherits the state of the last
    fork ``s`` with ``fork[s] <= t``; tokens before the first fork are alive.

    Only tokens ``t < min(T, gen_len)`` are evaluated. Positions at or beyond
    ``gen_len`` are right-censored and stay True -- callers truncate the
    display instead of reading them as alive.

    ``fork`` must be sorted ascending (required by ``np.searchsorted``).
    """
    # Running minimum over forks, with NaN neutralised so it cannot win the min.
    finite = np.where(np.isnan(pmass_per_fork), np.inf, pmass_per_fork)
    dead_at_fork = np.minimum.accumulate(finite) < threshold

    alive = np.ones(T, dtype=bool)
    n_tokens = min(T, gen_len)
    if n_tokens <= 0:
        return alive

    # One vectorised lookup instead of a Python loop with a scalar
    # searchsorted per token: index of the last fork <= t, or -1 if none.
    last_fork = np.searchsorted(np.asarray(fork), np.arange(n_tokens), side="right") - 1
    covered = np.flatnonzero(last_fork >= 0)
    alive[covered] = ~dead_at_fork[last_fork[covered]]
    return alive
|
||||
|
||||
|
||||
def _rolling_mean(x: np.ndarray, w: int) -> np.ndarray:
|
||||
if w <= 1 or len(x) < w:
|
||||
return x
|
||||
pad = w // 2
|
||||
xp = np.pad(x, pad, mode="edge")
|
||||
kernel = np.ones(w, dtype=float) / w
|
||||
return np.convolve(xp, kernel, mode="valid")[: len(x)]
|
||||
|
||||
|
||||
def main(a: Args):
    """Draw one KL-trajectory panel per alpha, coloured by alive/dead state.

    Collects every cell directory under ``a.runs_root`` whose name contains
    ``a.model_contains`` (directories starting with "_" are skipped), plots
    each rollout's smoothed per-token KL as green (alive) or red (dead)
    segments, overlays the per-token median, and writes
    ``kl_alive_spaghetti.png`` plus a markdown summary table
    (``kl_alive_summary.md``) into ``a.out``.
    """
    root = Path(a.runs_root)
    out = Path(a.out); out.mkdir(parents=True, exist_ok=True)
    cells = [d for d in sorted(root.iterdir()) if d.is_dir() and not d.name.startswith("_")]
    cells = [d for d in cells if a.model_contains in d.name]
    logger.info(f"found {len(cells)} cells in {root} matching {a.model_contains!r}")

    n_panels = len(a.alphas)
    # squeeze=False keeps `axes` 2-D even for a single alpha, so axes[0, j] always works.
    fig, axes = plt.subplots(1, n_panels, figsize=(3.0 * n_panels, 3.4),
                             sharey=True, squeeze=False)
    color_alive = to_rgba("#1a9850", 0.65)  # green, translucent so overlap stays readable
    color_dead = to_rgba("#d7191c", 0.55)  # stronger red so all-dead panels are visible

    summary_rows = []
    for j, alpha in enumerate(a.alphas):
        ax = axes[0, j]
        # Pool the trajectories of every matching cell for this alpha.
        all_trajs = []
        for d in cells:
            all_trajs.extend(load_cell(d, alpha, a.window))
        n = len(all_trajs)
        n_died = 0
        n_censored = 0
        median_rows = []
        dead_segments = []
        alive_segments = []
        dead_colors = []
        alive_colors = []
        # build all dead segments first, then alive segments on top
        for kl, pmv, fork, gl in all_trajs:
            # Plot only tokens that exist in both the KL series and the rollout.
            T = min(len(kl), gl)
            kl = _rolling_mean(kl[:T], a.roll)
            # NaN-pad to a.window so rows stack into a rectangular array for nanmedian.
            median_rows.append(np.pad(kl, (0, max(0, a.window - T)), constant_values=np.nan)[: a.window])
            alive = alive_mask_for_t(pmv, fork, T, a.threshold, gl)
            xs = np.arange(T)
            # Fewer than two points gives no line segment; such rollouts still
            # count in n and the median but not in died / cens below.
            if T < 2:
                continue
            # Consecutive (t, KL) points -> array of 2-point segments for LineCollection.
            pts = np.stack([xs, kl], axis=1).reshape(-1, 1, 2)
            segs = np.concatenate([pts[:-1], pts[1:]], axis=1)
            # A segment takes the state of its starting token.
            seg_alive = alive[:-1]
            dead_segments.extend([seg for seg, al in zip(segs, seg_alive) if not al])
            alive_segments.extend([seg for seg, al in zip(segs, seg_alive) if al])
            dead_colors.extend([color_dead] * int((~seg_alive).sum()))
            alive_colors.extend([color_alive] * int(seg_alive.sum()))
            if not alive.all():
                n_died += 1
            # Rollout ended before the window: mark its last token as right-censored.
            if gl < a.window:
                n_censored += 1
                ax.scatter([gl - 1], [kl[-1]], s=6, c="black", marker="|",
                           alpha=0.6, zorder=3)

        # zorder puts dead lines (1) under alive lines (2), under censor marks (3)
        # and the median (4).
        if dead_segments:
            ax.add_collection(LineCollection(dead_segments, colors=dead_colors, linewidths=0.9, zorder=1))
        if alive_segments:
            ax.add_collection(LineCollection(alive_segments, colors=alive_colors, linewidths=0.9, zorder=2))
        if median_rows:
            med = np.nanmedian(np.asarray(median_rows), axis=0)
            med_x = np.arange(len(med))
            # Lift an (almost) all-zero median slightly off zero.
            # NOTE(review): presumably so the line stays visible on the axis -- confirm.
            if np.nanmax(np.abs(med)) < 1e-6:
                med = med + 1e-4
            ax.plot(med_x, med, color="black", lw=1.3, alpha=0.9, zorder=4)

        ax.axhline(1.0, color="black", lw=0.7, ls=":", label="KL=1 calib target")
        ax.set_title(rf"$\alpha={alpha}$ n={n} died={n_died} cens={n_censored}")
        ax.set_xlabel("token t")
        if j == 0:
            ax.set_ylabel("per-token KL")
        if a.kl_log:
            ax.set_yscale("symlog", linthresh=0.1)
        ax.set_xlim(-5, a.window + 5)
        # Rescale the view explicitly after adding the LineCollections.
        ax.autoscale_view()
        # data-driven y-lim sanity
        # rely on auto
        summary_rows.append({"alpha": alpha, "n": n, "n_died": n_died, "n_censored": n_censored})

    # legend on first panel only
    from matplotlib.lines import Line2D
    # Proxy artists: LineCollections carry no labels, so the legend is built by hand.
    handles = [
        Line2D([0],[0], color=color_alive, lw=2, label=f"alive (pmass >= {a.threshold})"),
        Line2D([0],[0], color=color_dead, lw=2, label=f"dead (pmass < {a.threshold})"),
        Line2D([0],[0], color="black", marker="|", lw=0, label="EOS (right-censored)"),
        Line2D([0],[0], color="black", lw=0.7, ls=":", label="KL=1 calib target"),
    ]
    axes[0, 0].legend(handles=handles, loc="upper right", fontsize=7, frameon=True)

    fig.suptitle(
        f"KL trajectories coloured by survival (pmass < {a.threshold} = dead). "
        f"Model: {a.model_contains}. window={a.window}.",
        fontsize=10,
    )
    # Leave the top 8% of the figure for the suptitle.
    fig.tight_layout(rect=(0, 0, 1, 0.92))
    out_p = out / "kl_alive_spaghetti.png"
    fig.savefig(out_p, dpi=160, bbox_inches="tight")
    logger.info(f"figure -> {out_p}")

    # Per-alpha counts as a pipe-format markdown table, written next to the figure and logged.
    from tabulate import tabulate
    md = tabulate(summary_rows, headers="keys", tablefmt="pipe")
    (out / "kl_alive_summary.md").write_text(md)
    logger.info("\n" + md)
|
||||
|
||||
|
||||
# Script entry point: parse CLI flags into Args with tyro, then run main.
if __name__ == "__main__":
    cli_args = tyro.cli(Args)
    main(cli_args)
|
||||
@@ -0,0 +1,227 @@
|
||||
"""Kaplan-Meier-style survival curves for steered trajectories.
|
||||
|
||||
Motivation:
|
||||
At alpha=1 KL is near the calibration target by construction (we bisected
|
||||
coeff to make it so), so "alpha=1 stays at KL=1" is circular. The honest
|
||||
test is whether the model is still COHERENT -- measured here as pmass on
|
||||
forced-choice answer tokens. Survival curves separate alphas more cleanly
|
||||
than per-token bands.
|
||||
|
||||
Death modes (--metric):
|
||||
kl proxy. alive(t) := running_max KL(s) over s<=t < threshold (default 1.0).
|
||||
Largely redundant with calibration at alpha=1; read with care.
|
||||
pmass real. Right-censored Kaplan-Meier on forced-choice mass.
|
||||
alive(t) := running_min pmass(s) >= threshold AND rollout reached s.
|
||||
Death is irreversible (KM convention).
|
||||
Right-censoring: rollouts that EOS'd before fork t drop OUT of the
|
||||
denominator at t (they are 'complete', not 'dead'). Uses
|
||||
gen_lens_qa[alpha] / gen_lens_eval[alpha] from pmass.json.
|
||||
Reads pmass[alpha] (yes/no reasoning prompts). To use the EVAL_PROMPTS
paired with KL instead, pass --metric pmass_eval (reads pmass_eval[alpha]).
|
||||
|
||||
Inputs: outputs/<run>/calib.json plus trajectory.json (--metric kl) or pmass.json (--metric pmass / pmass_eval).
|
||||
|
||||
Usage:
|
||||
python scripts/survival.py --runs_root outputs_qwen05_w512 \
|
||||
--out figs_qwen05_survival --alphas 0.5 1.0 2.0 4.0 \
|
||||
--metric pmass --thresholds 0.5 0.8 0.95
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import tyro
|
||||
from loguru import logger
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# Plot styling is best-effort: use the seaborn theme when it can be imported
# and applied, otherwise fall back to matplotlib's built-in "ggplot" style.
try:
    import seaborn as sns

    sns.set_theme(context="notebook", style="whitegrid", palette="deep", font_scale=0.95)
    _rc_overrides = {
        "axes.titlesize": 11,
        "axes.labelsize": 10,
        "axes.spines.top": False,
        "axes.spines.right": False,
    }
    plt.rcParams.update(_rc_overrides)
except Exception:
    plt.style.use("ggplot")
|
||||
|
||||
|
||||
@dataclass
class Args:
    """Command-line options for the survival-curve script (parsed via tyro)."""

    # Directory whose sub-directories are scanned for calib.json + data files.
    runs_root: str = "outputs_qwen05_spaghetti"
    # Directory that receives survival_<metric>.png and survival_<metric>.md.
    out: str = "figs_qwen05_survival"
    # Only runs whose calib.json "window" equals this value are used; it is
    # also the padded length of each KL trajectory.
    window: int = 50
    # Alpha keys (strings, as stored in the JSON files) to draw, one curve each.
    alphas: tuple[str, ...] = ("0.25", "0.5", "0.75", "1.0", "1.5", "2.0", "4.0")
    # Death thresholds; one panel is drawn per threshold.
    thresholds: tuple[float, ...] = (1.0,)
    # Substring that must occur in calib.json "model" for a run to be used.
    model_contains: str = "Qwen2.5-0.5B"
    # 'kl', 'pmass', or 'pmass_eval' (paired w/ KL prompts)
    metric: str = "kl"
|
||||
|
||||
|
||||
def _load_kl(root: Path, alpha: str, T: int, model_contains: str) -> np.ndarray:
|
||||
rows = []
|
||||
for d in sorted(root.iterdir()):
|
||||
if not d.is_dir() or d.name.startswith("_"):
|
||||
continue
|
||||
calib = d / "calib.json"; traj_p = d / "trajectory.json"
|
||||
if not (calib.exists() and traj_p.exists()):
|
||||
continue
|
||||
meta = json.loads(calib.read_text())
|
||||
if meta.get("window") != T or model_contains not in meta.get("model", ""):
|
||||
continue
|
||||
traj = json.loads(traj_p.read_text())
|
||||
per = traj.get("per_prompt_per_t_kl", {}).get(alpha, [])
|
||||
for r in per:
|
||||
arr = np.full(T, np.nan)
|
||||
arr[: len(r)] = r[: T]
|
||||
rows.append(arr)
|
||||
return np.array(rows) if rows else np.zeros((0, T))
|
||||
|
||||
|
||||
def _load_pmass(root: Path, alpha: str, T: int, model_contains: str,
|
||||
key: str = "pmass",
|
||||
) -> tuple[np.ndarray, list[int], np.ndarray]:
|
||||
"""Return (N, F) pmass array + fork_points + (N,) gen_lens for right-censoring.
|
||||
NaN in the pmass array means the rollout EOS'd before that fork ('complete,
|
||||
not dead' -- right-censored). gen_lens is the rollout length T per row.
|
||||
key in {'pmass','pmass_eval'} -> uses 'gen_lens_qa' / 'gen_lens_eval'.
|
||||
"""
|
||||
rows = []; gen_lens = []; fork = None
|
||||
glen_key = "gen_lens_qa" if key == "pmass" else "gen_lens_eval"
|
||||
for d in sorted(root.iterdir()):
|
||||
if not d.is_dir() or d.name.startswith("_"):
|
||||
continue
|
||||
calib = d / "calib.json"; pm_p = d / "pmass.json"
|
||||
if not (calib.exists() and pm_p.exists()):
|
||||
continue
|
||||
meta = json.loads(calib.read_text())
|
||||
if meta.get("window") != T or model_contains not in meta.get("model", ""):
|
||||
continue
|
||||
pm = json.loads(pm_p.read_text())
|
||||
if not pm.get("computed", True):
|
||||
continue
|
||||
if fork is None:
|
||||
fork = pm["fork_points"]
|
||||
per = pm.get(key, {}).get(alpha, [])
|
||||
glens = pm.get(glen_key, {}).get(alpha, [])
|
||||
for i, r in enumerate(per):
|
||||
rows.append(r)
|
||||
# fall back to T (window) if gen_lens not saved (legacy outputs)
|
||||
gen_lens.append(int(glens[i]) if i < len(glens) else T)
|
||||
if not rows:
|
||||
return np.zeros((0, len(fork or []))), fork or [], np.zeros(0, dtype=int)
|
||||
L = max(len(r) for r in rows)
|
||||
arr = np.full((len(rows), L), np.nan)
|
||||
for i, r in enumerate(rows):
|
||||
arr[i, : len(r)] = r
|
||||
return arr, fork, np.array(gen_lens, dtype=int)
|
||||
|
||||
|
||||
def survival_kl(K: np.ndarray, threshold: float) -> np.ndarray:
    """Fraction of trajectories whose running-max KL is still below threshold.

    ``K`` is an (N, T) array of per-token KL. NaN entries are replaced by
    -inf before the running maximum, so they never trigger a death. Returns
    a length-T array S(t), or an empty array when ``K`` is empty.
    """
    if K.size == 0:
        return np.zeros(0)
    nan_as_floor = np.nan_to_num(K, nan=-np.inf)
    running_peak = np.maximum.accumulate(nan_as_floor, axis=1)
    still_alive = running_peak < threshold
    return still_alive.mean(axis=0)
|
||||
|
||||
|
||||
def survival_pmass(P: np.ndarray, fork: list[int], gen_lens: np.ndarray,
                   threshold: float) -> np.ndarray:
    """Right-censored Kaplan-Meier survival estimate on pmass.

    Args:
        P: (N, F) pmass per rollout and fork; NaN entries never count as a
            death (censoring is decided by ``gen_lens``, not by NaN).
        fork: fork token positions; the first F entries are used.
        gen_lens: (N,) rollout lengths. Rollout i has reached fork t iff
            ``gen_lens[i] >= t``; otherwise it is right-censored at t.
        threshold: a rollout dies at the first reached fork with
            ``pmass < threshold``. Death is irreversible.

    Returns:
        Length-F array ``S(t) = prod_{s<=t} (1 - d_s / n_s)``, where n_s
        counts rollouts that reached fork s and were not dead before it, and
        d_s counts those of them that die at s. With no censoring this equals
        the fraction of rollouts still alive. Forks with ``n_s == 0`` carry
        the previous value forward. Empty ``P`` gives an empty array.

    Raises:
        ValueError: if ``fork`` has fewer entries than ``P`` has columns.
    """
    if P.size == 0:
        return np.zeros(0)
    n_forks = P.shape[1]
    if len(fork) < n_forks:
        raise ValueError(
            f"fork has {len(fork)} entries but P has {n_forks} columns"
        )
    fork_arr = np.asarray(fork[:n_forks])
    lens = np.asarray(gen_lens)

    # reached[i, j]: rollout i got as far as fork j (not yet right-censored).
    reached = lens[:, None] >= fork_arr[None, :]
    # NaN -> +inf so a missing measurement can never register as a death.
    below = np.nan_to_num(P, nan=np.inf) < threshold
    died_at = below & reached
    ever_dead = np.logical_or.accumulate(died_at, axis=1)

    # dead_before[i, j]: rollout i had already died at some fork < j.
    dead_before = np.zeros_like(ever_dead)
    dead_before[:, 1:] = ever_dead[:, :-1]

    # The KM risk set must exclude rollouts that are already dead; counting
    # them in the denominator would underestimate the hazard at later forks.
    at_risk = reached & ~dead_before
    n_s = at_risk.sum(axis=0)
    d_s = (died_at & at_risk).sum(axis=0)

    hazard = np.divide(d_s, n_s, out=np.zeros(n_forks), where=n_s > 0)
    return np.cumprod(1.0 - hazard)
|
||||
|
||||
|
||||
def main(a: Args):
    """Draw survival curves and write a summary table.

    One panel is drawn per threshold in ``a.thresholds`` and one step curve
    per alpha in ``a.alphas`` that has data. Outputs, written into ``a.out``:
    ``survival_<metric>.png`` and ``survival_<metric>.md``.

    Raises:
        SystemExit: if ``a.metric`` is not 'kl', 'pmass', or 'pmass_eval'.
    """
    # Validate the metric once, up front. This also fixes the x-label before
    # any plotting, so it is defined even when no alpha has data.
    if a.metric == "kl":
        xlabel = "token t"
    elif a.metric in ("pmass", "pmass_eval"):
        xlabel = "fork token t"
    else:
        raise SystemExit(f"unknown metric {a.metric!r}; use 'kl', 'pmass', or 'pmass_eval'")

    root = Path(a.runs_root)
    out = Path(a.out)
    out.mkdir(parents=True, exist_ok=True)

    # Loading does not depend on the threshold, so read each alpha's data
    # from disk once instead of once per panel.
    loaded = {}
    for alpha in a.alphas:
        if a.metric == "kl":
            K = _load_kl(root, alpha, a.window, a.model_contains)
            if K.size == 0:
                logger.warning(f"no data alpha={alpha}")
                continue
            loaded[alpha] = (K, None, None)
        else:
            P, fork, gen_lens = _load_pmass(root, alpha, a.window, a.model_contains, key=a.metric)
            if P.size == 0:
                logger.warning(f"no {a.metric} alpha={alpha}")
                continue
            loaded[alpha] = (P, fork, gen_lens)

    n_panels = len(a.thresholds)
    fig, axes = plt.subplots(1, n_panels, figsize=(4.6 * n_panels, 3.4),
                             sharey=True, squeeze=False)
    cmap = plt.get_cmap("viridis")
    colors = {alpha: cmap(i / max(1, len(a.alphas) - 1)) for i, alpha in enumerate(a.alphas)}

    rows_summary = []
    for j, thr in enumerate(a.thresholds):
        ax = axes[0, j]
        for alpha, (data, fork, gen_lens) in loaded.items():
            if a.metric == "kl":
                S = survival_kl(data, thr)
                xs = np.arange(len(S))
            else:
                S = survival_pmass(data, fork, gen_lens, thr)
                xs = np.array(fork[: len(S)])
            n = data.shape[0]
            ax.step(xs, S, where="post", color=colors[alpha], lw=2.0,
                    label=rf"$\alpha={alpha}$ (n={n})")
            # First x position where survival has dropped to one half or below.
            below_half = np.where(S <= 0.5)[0]
            t50 = int(xs[below_half[0]]) if len(below_half) else None
            rows_summary.append({"metric": a.metric, "threshold": thr, "alpha": alpha,
                                 "n": int(n),
                                 "S_mid": float(S[len(S)//2]),
                                 "S_end": float(S[-1]),
                                 "t_S<=0.5": t50})
        if n_panels > 1:
            ax.set_title(f"threshold = {thr:g}")
        ax.set_xlabel(xlabel)
        ax.set_ylim(-0.02, 1.05)
        if j == 0:
            ax.set_ylabel("fraction of trajectories alive")
            if loaded:
                # Skip the legend when nothing was plotted (no labelled artists).
                ax.legend(loc="lower left", fontsize=8, frameon=False)
        ax.axvline(20, color="k", ls=":", lw=0.7)

    fig.suptitle(f"Survival, {a.model_contains}", fontsize=10)
    fig.tight_layout(rect=(0, 0, 1, 0.94))
    out_p = out / f"survival_{a.metric}.png"
    fig.savefig(out_p, dpi=160, bbox_inches="tight")
    logger.info(f"survival -> {out_p}")

    from tabulate import tabulate
    md = tabulate(rows_summary, headers="keys", tablefmt="pipe", floatfmt=".3f")
    (out / f"survival_{a.metric}.md").write_text(md)
    logger.info("\n" + md)
|
||||
|
||||
|
||||
# Script entry point: parse CLI flags into Args with tyro, then run main.
if __name__ == "__main__":
    cli_args = tyro.cli(Args)
    main(cli_args)
|
||||
Reference in New Issue
Block a user