This commit is contained in:
wassname
2026-05-08 11:25:10 +08:00
parent 2a69829612
commit 77b296cc75
10 changed files with 4176 additions and 46 deletions
+40 -19
View File
@@ -65,14 +65,19 @@ except Exception:
class Args:
runs_root: str = "outputs"
out: str = "figs"
window: int = 50 # only this window enters the figure (need long-enough traj for rolling-16)
roll: int = 16 # smoothing window for KL trajectory
window: int = 512 # only this window enters the figure
roll: int = 65 # smoothing window for KL trajectory
alphas: tuple[str, ...] = ("0.25", "0.5", "0.75", "1.0", "1.5", "2.0", "4.0")
kl_ymax: float = 6.0
model_contains: str = ""
kl_only: bool = False
spaghetti: bool = False # plot individual trajectories instead of p10..p90 band
color_by_pmass: bool = False # color KL spaghetti lines by paired pmass (requires pmass_eval)
line_alpha: float | None = None # per-line alpha override; None = auto clip(2.5/n,.08,.35)
line_lw: float = 0.18 # per-trajectory linewidth; full opacity needs very thin lines
median_lw: float = 0.75 # median linewidth
quantile_lines: bool = False # clean summary: p10/p50/p90 lines, no fill/spaghetti
mark_t: int = -1 # optional vertical token marker; -1 disables
def _rolling_mean(x: np.ndarray, w: int) -> np.ndarray:
@@ -138,20 +143,20 @@ def _draw_kl_panel(ax, K: np.ndarray, a: Args, P: np.ndarray | None = None) -> N
pts = np.column_stack([xs, traj])
segs = np.stack([pts[:-1], pts[1:]], axis=1)
lc = LineCollection(segs, cmap=cmap, norm=plt.Normalize(0, 1),
linewidths=0.7, alpha=0.55)
linewidths=a.line_lw, alpha=float(a.line_alpha if a.line_alpha is not None else np.clip(2.5/max(K.shape[0],1), 0.08, 0.35)))
lc.set_array(pmass_row[:-1])
ax.add_collection(lc)
ax.set_xlim(xs[0], xs[-1])
med = np.nanmedian(Kp, axis=0)
ax.plot(xs, med, color="k", lw=1.6)
ax.plot(xs, med, color="k", lw=a.median_lw)
else:
crossed = (K > 1.0).any(axis=1)
for traj in Kp[~crossed]:
ax.plot(xs, traj, color="0.55", lw=0.6, alpha=0.6)
ax.plot(xs, traj, color="0.55", lw=a.line_lw, alpha=0.5)
for traj in Kp[crossed]:
ax.plot(xs, traj, color="C3", lw=0.6, alpha=0.6)
ax.plot(xs, traj, color="C3", lw=a.line_lw, alpha=0.5)
med = np.nanmedian(Kp, axis=0)
ax.plot(xs, med, color="k", lw=1.6)
ax.plot(xs, med, color="k", lw=a.median_lw)
frac = float(crossed.mean())
ax.text(0.97, 0.97, f"{frac:.0%} cross KL=1",
transform=ax.transAxes, ha="right", va="top",
@@ -163,21 +168,32 @@ def _draw_kl_panel(ax, K: np.ndarray, a: Args, P: np.ndarray | None = None) -> N
p50s = _rolling_mean(p50, a.roll)
p10s = _rolling_mean(p10, a.roll)
p90s = _rolling_mean(p90, a.roll)
ax.fill_between(xs, p10s, p90s, alpha=0.25, color="C0", lw=0)
ax.plot(xs, p50s, color="C0", lw=1.6)
if a.quantile_lines:
# standard quantile fan: outer band light, inner band dark, p50 line
p25s = _rolling_mean(np.nanpercentile(K, 25, axis=0), a.roll)
p75s = _rolling_mean(np.nanpercentile(K, 75, axis=0), a.roll)
ax.fill_between(xs, p10s, p90s, alpha=0.15, color="C0", lw=0, label="p10..p90")
ax.fill_between(xs, p25s, p75s, alpha=0.32, color="C0", lw=0, label="p25..p75")
ax.plot(xs, p50s, color="C0", lw=1.5, label="p50")
else:
ax.fill_between(xs, p10s, p90s, alpha=0.25, color="C0", lw=0)
ax.plot(xs, p50s, color="C0", lw=1.6)
def make_kl_figure(cells: list[dict], a: Args, out_path: Path) -> None:
import matplotlib.pyplot as plt
fig, axes = plt.subplots(1, len(a.alphas), figsize=(4.0 * len(a.alphas), 3.2),
fig, axes = plt.subplots(1, len(a.alphas), figsize=(3.7 * len(a.alphas), 3.4),
sharex=True, sharey=True, squeeze=False, constrained_layout=True)
label = a.model_contains or "all models"
mode = "individual trajectories (red=ever crossed KL=1)" if a.spaghetti \
else f"p50 + p10..p90 band, smoothed rolling-{a.roll}"
n_max = max((_pool_kl(cells, alpha, T=a.window).shape[0] for alpha in a.alphas), default=0)
mode = f"individual trajectories, roll={a.roll}, color=pmass" if (a.spaghetti and a.color_by_pmass) \
else "individual trajectories (red=ever crossed KL=1)" if a.spaghetti \
else f"shaded quantiles p10/p25/p50/p75/p90, roll={a.roll}" if a.quantile_lines \
else f"p50 + p10..p90 band, smoothed rolling-{a.roll}"
fig.suptitle(
f"KL trajectory on N=8 held-out long-form prompts ({label})\n"
f"{mode}. Solid line: KL=1 nat. Dotted v-line: t=20.",
f"KL trajectory on N={n_max} held-out long-form prompts ({label})\n"
f"{mode}. Solid horizontal: KL=1 nat.",
fontsize=10,
)
@@ -202,27 +218,32 @@ def make_kl_figure(cells: list[dict], a: Args, out_path: Path) -> None:
Pe = _pool_pmass_eval(cells, alpha, _first_fork(cells), T=a.window) if a.color_by_pmass else None
_draw_kl_panel(ax, K, a, P=Pe)
if y_max >= 1.0:
ax.axhline(1.0, color="k", lw=1.0)
ax.axhline(1.0, color="k", lw=0.7)
else:
ax.text(0.98, 0.92, "KL=1 off-scale", transform=ax.transAxes,
ha="right", va="top", fontsize=8, color="0.25")
ax.axvline(20, color="k", ls=":", lw=0.8)
if a.mark_t >= 0:
ax.axvline(a.mark_t, color="k", ls=":", lw=0.6)
ax.set_title(rf"$\alpha = {alpha}$ (n={K.shape[0]} traj)")
ax.set_xlim(0, x_max)
ax.set_ylim(0, y_max)
ax.set_xlabel("token")
if j == 0:
ax.set_ylabel("KL(steered || base) [nats]")
ax.set_ylabel("KL")
if a.color_by_pmass:
from matplotlib.cm import ScalarMappable
from matplotlib.colors import LinearSegmentedColormap, Normalize
cmap = LinearSegmentedColormap.from_list("alive", ["#c0392b", "#f1c40f", "#27ae60"])
sm = ScalarMappable(norm=Normalize(0, 1), cmap=cmap); sm.set_array([])
cbar = fig.colorbar(sm, ax=axes[0, :].tolist(), location="right", shrink=0.75, pad=0.01, fraction=0.015)
cbar.set_label("pmass (0=dead, 1=alive)")
cbar = fig.colorbar(sm, ax=axes[0, :].tolist(), location="right", shrink=0.72, pad=0.015, fraction=0.012)
cbar.set_label("pmass", labelpad=2)
fig.savefig(out_path, dpi=160, bbox_inches="tight")
if a.quantile_lines:
q_path = out_path.with_name(out_path.stem + "_quantile_lines" + out_path.suffix)
fig.savefig(q_path, dpi=160, bbox_inches="tight")
logger.info(f"KL quantile-line figure -> {q_path}")
logger.info(f"KL-only figure -> {out_path}")
+11 -4
View File
@@ -16,6 +16,10 @@ class Args:
threshold: float = 0.95
out_name: str = "figs_auto"
model_contains: str = ""
line_alpha: float | None = None # forwarded to spaghetti+aggregate
roll: int = 65
line_lw: float = 0.5
quantile_lines: bool = False
def _ensure_single_run_root(run_dir: Path) -> Path:
@@ -72,18 +76,21 @@ def main(a: Args) -> None:
"--window", str(__import__("json").loads(calib.read_text())["window"]),
"--threshold", str(a.threshold),
"--model-contains", model_filter,
], repo_root)
"--roll", str(a.roll),
"--line-lw", str(a.line_lw),
] + (["--line-alpha", str(a.line_alpha)] if a.line_alpha is not None else []), repo_root)
_run([
sys.executable,
"scripts/aggregate.py",
"--runs-root", str(single_root),
"--out", str(out_dir / "aggregate"),
"--window", str(__import__("json").loads(calib.read_text())["window"]),
"--spaghetti",
"--color-by-pmass",
"--kl-only",
"--model-contains", model_filter,
], repo_root)
"--roll", str(a.roll),
"--line-lw", str(a.line_lw),
] + (["--spaghetti", "--color-by-pmass"] if not a.quantile_lines else ["--quantile-lines"]) \
+ (["--line-alpha", str(a.line_alpha)] if a.line_alpha is not None else []), repo_root)
pngs = [
out_dir / "survival" / "survival_pmass_eval.png",
+93 -12
View File
@@ -57,12 +57,13 @@ from pathlib import Path
import torch
import tyro
from loguru import logger
from tqdm.auto import tqdm
from iso_kl_figure import (
SteeringConfig, MeanDiffC, PCAC, DirectionalAblationC,
train, calibrate_iso_kl, measure_kl,
)
from iso_kl_figure.branch_pmass import branch_pmass, collect_choice_token_ids
from iso_kl_figure.branch_pmass import branch_pmass, build_chat_interrupt_suffix, collect_choice_token_ids
from iso_kl_figure.calibrate import _eos_token_ids
from iso_kl_figure.target import _get_blocks
@@ -103,6 +104,14 @@ EVAL_PROMPTS = [
"Was Shakespeare a contemporary of Queen Elizabeth I of England?",
"Is the Pacific Ocean larger in surface area than the Atlantic Ocean?",
"Does a transformer architecture rely on self-attention rather than recurrence for sequence modeling?",
"Is honey produced by bees primarily from nectar collected from flowers?",
"Did Ada Lovelace write what is often considered the first published computer algorithm?",
"Is the boiling point of pure water at sea level approximately 100 degrees Celsius?",
"Does cross-entropy loss reduce to negative log likelihood for a single correct class label?",
"Was the printing press invented by Johannes Gutenberg in the 15th century?",
"Is the speed of light in a vacuum the same for all inertial observers?",
"Do antibiotics kill viruses in addition to bacteria?",
"Is the chemical symbol for gold 'Au' on the periodic table?",
]
# pmass diagnostic prompts share the same schema as EVAL above.
@@ -144,7 +153,7 @@ class Args:
model: str
method: str
seed: int = 0
window: int = 50
window: int = 512
run_id: str = ""
layer_frac: float = 0.6
target_kl: float = 1.0
@@ -159,6 +168,13 @@ class Args:
skip_pmass_qa: bool = False # skip yes/no pmass loop, only do paired pmass_eval
render_figs: bool = False # render single-run survival + spaghetti + KL pngs
render_threshold: float = 0.95
chat_interrupt: bool = False # use chat-template-correct interrupt turn for pmass scoring
# (in-distribution; recommended over raw splice).
interrupt_user_text: str = (
"Given the conversation so far, what is your final answer?"
" Respond with JSON: {\"Answer\": true} or {\"Answer\": false}."
)
interrupt_prefill: str = '{"Answer": '
def _set_seed(s: int):
@@ -169,6 +185,28 @@ def _set_seed(s: int):
torch.cuda.manual_seed_all(s)
def _log_fork_summary(stage: str, alpha: float, gen, gen_text: str,
                      fork_points: list[int], pm: dict) -> None:
    """Log one INFO line + one DEBUG table per (stage, alpha) instead of N forks * 1 line.

    SHOULD: at alpha~1, pmass ~> 0.5 across forks (model still respects schema).
    ELSE: prefill broken or schema drift -- inspect debug_first in pmass.json.

    Args:
        stage: label interpolated into the log text (callers pass "qa" / "eval").
        alpha: alpha value being evaluated; only interpolated into the log text.
        gen: generated sequence; only ``gen.shape[0]`` is read (logged as gen_len).
        gen_text: decoded generation; its first 80 characters are logged.
        fork_points: one entry per fork, shown in the ``t`` column of the table.
        pm: mapping with list-valued keys ``"pmass"``, ``"p_true"`` and
            ``"argmax_str"``, zipped row-by-row with ``fork_points``.
    """
    from statistics import median

    from tabulate import tabulate

    pmass = pm["pmass"]
    p_true = pm["p_true"]
    if pmass:
        pm_min = min(pmass)
        # True median: for an even number of forks this averages the two
        # middle values rather than reporting the upper-middle one as "med".
        pm_med = median(pmass)
    else:
        # No forks scored: keep the summary line printable.
        pm_min = pm_med = float("nan")
    logger.info(
        f" alpha={alpha} {stage}[0] gen_len={gen.shape[0]} pmass(med/min)="
        f"{pm_med:.2f}/{pm_min:.2f} | text[:80]={gen_text[:80]!r}"
    )
    rows = [
        (t, f"{p:.3f}", f"{q:.3f}", repr(s))
        for t, p, q, s in zip(fork_points, pmass, p_true, pm["argmax_str"])
    ]
    table = tabulate(rows, headers=["t", "pmass", "p_true", "argmax"], tablefmt="plain")
    logger.debug(f"\nfork table [{stage} alpha={alpha}]\n{table}")
def _render_chat(tok, msgs: list[dict[str, str]], add_generation_prompt: bool) -> str:
if tok.chat_template is None:
rendered = []
@@ -279,6 +317,15 @@ def main(a: Args):
gen_lens_eval: dict[str, list] = {} # T per (alpha, eval-prompt)
debug_first: dict[str, dict] = {} # first prompt per alpha: gen_text + top-5 at each fork
# Append-only JSONL of every generation: full text + paired pmass.
# Crash-safe: each line flushed before next prompt. Overwritten only at run start.
gens_path = out_dir / "gens.jsonl"
if gens_path.exists():
gens_path.unlink()
def _append_gen(rec: dict) -> None:
    """Append ``rec`` to ``gens_path`` as a single JSON line."""
    # The file is re-opened and closed for every record, so each line is
    # written out before the next prompt is processed.
    with gens_path.open("a") as sink:
        sink.write(f"{json.dumps(rec)}\n")
def write_eval_outputs() -> None:
(out_dir / "trajectory.json").write_text(json.dumps({
"fork_points_full": list(range(a.window)),
@@ -317,7 +364,18 @@ def main(a: Args):
tok(s, return_tensors="pt", add_special_tokens=False).input_ids[0]
for s in eval_prompt_strs
]
for alpha in a.alphas:
interrupt_suffix_ids = None
interrupt_suffix_decoded = ""
if a.chat_interrupt:
interrupt_suffix_ids = build_chat_interrupt_suffix(
tok, a.interrupt_user_text, a.interrupt_prefill,
)
# SHOULD: decoded suffix ends with prefill literal. ELSE template quirk.
interrupt_suffix_decoded = tok.decode(interrupt_suffix_ids)
logger.info(f"chat_interrupt suffix ({len(interrupt_suffix_ids)} tok): {interrupt_suffix_decoded!r}")
if not interrupt_suffix_decoded.endswith(a.interrupt_prefill):
logger.warning(f"interrupt suffix does not end with prefill={a.interrupt_prefill!r}")
for alpha in tqdm(a.alphas, desc=f"alphas[{a.run_id}]", mininterval=60):
v.cfg.coeff = alpha * c_star
logger.info(f"=== eval alpha={alpha} c={v.cfg.coeff:+.4f} ===")
m = measure_kl(v, model, tok, eval_ids_list, T=a.window, device=a.device)
@@ -351,6 +409,7 @@ def main(a: Args):
v, model, tok, ids, gen, fork_points,
PREFILL_STR, a_ids, b_ids,
rollout_cache=gen_out.past_key_values,
interrupt_suffix_ids=interrupt_suffix_ids,
device=a.device,
)
pm_for_alpha.append(pm["pmass"])
@@ -363,17 +422,28 @@ def main(a: Args):
else "other"
for s in pm["argmax_str"]
])
# debug dump for first QA prompt only
# save full gen text + paired pmass for every QA prompt
gen_text = tok.decode(gen, skip_special_tokens=False)
_append_gen({
"alpha": float(alpha), "kind": "qa", "prompt_idx": q_idx,
"score_mode": "chat_interrupt" if a.chat_interrupt else "raw_splice",
"interrupt_user_text": a.interrupt_user_text if a.chat_interrupt else "",
"interrupt_prefill": a.interrupt_prefill if a.chat_interrupt else "",
"interrupt_suffix_ids": list(interrupt_suffix_ids) if interrupt_suffix_ids is not None else [],
"interrupt_suffix_decoded": interrupt_suffix_decoded,
"prompt": prompt_str, "gen_text": gen_text, "gen_len": int(gen.shape[0]),
"fork_points": list(fork_points),
"pmass": pm["pmass"], "p_true": pm["p_true"],
"argmax_str": pm["argmax_str"], "was_thinking": pm["was_thinking"],
})
# debug dump for first QA prompt only (kept for backward compat)
if q_idx == 0:
gen_text = tok.decode(gen, skip_special_tokens=False)
debug_first.setdefault(str(alpha), {})["qa"] = {
"prompt": prompt_str, "gen_text": gen_text, "gen_len": int(gen.shape[0]),
"pmass_per_fork": pm["pmass"], "p_true_per_fork": pm["p_true"],
"argmax_per_fork": pm["argmax_str"],
}
logger.info(f" [debug] alpha={alpha} qa[0] gen_len={gen.shape[0]} text[:120]={gen_text[:120]!r}")
for t, pmv, ptv, am in zip(fork_points, pm["pmass"], pm["p_true"], pm["argmax_str"]):
logger.info(f" t={t:>3} pmass={pmv:.3f} p_true={ptv:.3f} argmax={am!r}")
_log_fork_summary("qa", alpha, gen, gen_text, fork_points, pm)
pmass_all[str(alpha)] = pm_for_alpha
p_true_all[str(alpha)] = pt_for_alpha
argmax_all[str(alpha)] = ax_for_alpha
@@ -407,19 +477,30 @@ def main(a: Args):
v, model, tok, ids, gen, fork_points,
PREFILL_STR, a_ids, b_ids,
rollout_cache=gen_out.past_key_values,
interrupt_suffix_ids=interrupt_suffix_ids,
device=a.device,
)
pm_eval_for_alpha.append(pm["pmass"])
gen_text = tok.decode(gen, skip_special_tokens=False)
_append_gen({
"alpha": float(alpha), "kind": "eval", "prompt_idx": p_idx,
"score_mode": "chat_interrupt" if a.chat_interrupt else "raw_splice",
"interrupt_user_text": a.interrupt_user_text if a.chat_interrupt else "",
"interrupt_prefill": a.interrupt_prefill if a.chat_interrupt else "",
"interrupt_suffix_ids": list(interrupt_suffix_ids) if interrupt_suffix_ids is not None else [],
"interrupt_suffix_decoded": interrupt_suffix_decoded,
"prompt": prompt_str, "gen_text": gen_text, "gen_len": int(gen.shape[0]),
"fork_points": list(fork_points),
"pmass": pm["pmass"], "p_true": pm["p_true"],
"argmax_str": pm["argmax_str"], "was_thinking": pm["was_thinking"],
})
if p_idx == 0:
gen_text = tok.decode(gen, skip_special_tokens=False)
debug_first.setdefault(str(alpha), {})["eval"] = {
"prompt": prompt_str, "gen_text": gen_text, "gen_len": int(gen.shape[0]),
"pmass_per_fork": pm["pmass"], "p_true_per_fork": pm["p_true"],
"argmax_per_fork": pm["argmax_str"],
}
logger.info(f" [debug] alpha={alpha} eval[0] gen_len={gen.shape[0]} text[:120]={gen_text[:120]!r}")
for t, pmv, ptv, am in zip(fork_points, pm["pmass"], pm["p_true"], pm["argmax_str"]):
logger.info(f" t={t:>3} pmass={pmv:.3f} p_true={ptv:.3f} argmax={am!r}")
_log_fork_summary("eval", alpha, gen, gen_text, fork_points, pm)
pmass_eval_all[str(alpha)] = pm_eval_for_alpha
gen_lens_eval[str(alpha)] = gen_lens_eval_alpha
# SHOULD: at alpha=1, mean(pmass at t=0) > 0.5 (model still respects schema).