This commit is contained in:
wassname
2026-05-08 11:25:10 +08:00
parent 2a69829612
commit 77b296cc75
10 changed files with 4176 additions and 46 deletions
+3 -1
View File
@@ -1,4 +1,6 @@
# iso-kl-figure (short version) # calibrating steering over long trajectories by normalising KL outliers
![alt text](figs/zoom_in.png)
## The problem ## The problem
File diff suppressed because it is too large Load Diff
BIN
View File
Binary file not shown.

After

Width:  |  Height:  |  Size: 190 KiB

+1 -1
View File
@@ -11,7 +11,7 @@ test:
uv run --extra all pytest -q uv run --extra all pytest -q
# Run one (model, method, seed, window) cell end-to-end (calibrate + trajectory + pmass). # Run one (model, method, seed, window) cell end-to-end (calibrate + trajectory + pmass).
cell model="Qwen/Qwen3.5-0.8B" method="mean_diff" seed="0" window="50": cell model="Qwen/Qwen3.5-0.8B" method="mean_diff" seed="0" window="512":
uv run --extra all python scripts/run_cell.py \ uv run --extra all python scripts/run_cell.py \
--model {{model}} --method {{method}} --seed {{seed}} --window {{window}} --model {{model}} --method {{method}} --seed {{seed}} --window {{window}}
+40 -19
View File
@@ -65,14 +65,19 @@ except Exception:
class Args: class Args:
runs_root: str = "outputs" runs_root: str = "outputs"
out: str = "figs" out: str = "figs"
window: int = 50 # only this window enters the figure (need long-enough traj for rolling-16) window: int = 512 # only this window enters the figure
roll: int = 16 # smoothing window for KL trajectory roll: int = 65 # smoothing window for KL trajectory
alphas: tuple[str, ...] = ("0.25", "0.5", "0.75", "1.0", "1.5", "2.0", "4.0") alphas: tuple[str, ...] = ("0.25", "0.5", "0.75", "1.0", "1.5", "2.0", "4.0")
kl_ymax: float = 6.0 kl_ymax: float = 6.0
model_contains: str = "" model_contains: str = ""
kl_only: bool = False kl_only: bool = False
spaghetti: bool = False # plot individual trajectories instead of p10..p90 band spaghetti: bool = False # plot individual trajectories instead of p10..p90 band
color_by_pmass: bool = False # color KL spaghetti lines by paired pmass (requires pmass_eval) color_by_pmass: bool = False # color KL spaghetti lines by paired pmass (requires pmass_eval)
line_alpha: float | None = None # per-line alpha override; None = auto clip(2.5/n,.08,.35)
line_lw: float = 0.18 # per-trajectory linewidth; full opacity needs very thin lines
median_lw: float = 0.75 # median linewidth
quantile_lines: bool = False # clean summary: p10/p50/p90 lines, no fill/spaghetti
mark_t: int = -1 # optional vertical token marker; -1 disables
def _rolling_mean(x: np.ndarray, w: int) -> np.ndarray: def _rolling_mean(x: np.ndarray, w: int) -> np.ndarray:
@@ -138,20 +143,20 @@ def _draw_kl_panel(ax, K: np.ndarray, a: Args, P: np.ndarray | None = None) -> N
pts = np.column_stack([xs, traj]) pts = np.column_stack([xs, traj])
segs = np.stack([pts[:-1], pts[1:]], axis=1) segs = np.stack([pts[:-1], pts[1:]], axis=1)
lc = LineCollection(segs, cmap=cmap, norm=plt.Normalize(0, 1), lc = LineCollection(segs, cmap=cmap, norm=plt.Normalize(0, 1),
linewidths=0.7, alpha=0.55) linewidths=a.line_lw, alpha=float(a.line_alpha if a.line_alpha is not None else np.clip(2.5/max(K.shape[0],1), 0.08, 0.35)))
lc.set_array(pmass_row[:-1]) lc.set_array(pmass_row[:-1])
ax.add_collection(lc) ax.add_collection(lc)
ax.set_xlim(xs[0], xs[-1]) ax.set_xlim(xs[0], xs[-1])
med = np.nanmedian(Kp, axis=0) med = np.nanmedian(Kp, axis=0)
ax.plot(xs, med, color="k", lw=1.6) ax.plot(xs, med, color="k", lw=a.median_lw)
else: else:
crossed = (K > 1.0).any(axis=1) crossed = (K > 1.0).any(axis=1)
for traj in Kp[~crossed]: for traj in Kp[~crossed]:
ax.plot(xs, traj, color="0.55", lw=0.6, alpha=0.6) ax.plot(xs, traj, color="0.55", lw=a.line_lw, alpha=0.5)
for traj in Kp[crossed]: for traj in Kp[crossed]:
ax.plot(xs, traj, color="C3", lw=0.6, alpha=0.6) ax.plot(xs, traj, color="C3", lw=a.line_lw, alpha=0.5)
med = np.nanmedian(Kp, axis=0) med = np.nanmedian(Kp, axis=0)
ax.plot(xs, med, color="k", lw=1.6) ax.plot(xs, med, color="k", lw=a.median_lw)
frac = float(crossed.mean()) frac = float(crossed.mean())
ax.text(0.97, 0.97, f"{frac:.0%} cross KL=1", ax.text(0.97, 0.97, f"{frac:.0%} cross KL=1",
transform=ax.transAxes, ha="right", va="top", transform=ax.transAxes, ha="right", va="top",
@@ -163,21 +168,32 @@ def _draw_kl_panel(ax, K: np.ndarray, a: Args, P: np.ndarray | None = None) -> N
p50s = _rolling_mean(p50, a.roll) p50s = _rolling_mean(p50, a.roll)
p10s = _rolling_mean(p10, a.roll) p10s = _rolling_mean(p10, a.roll)
p90s = _rolling_mean(p90, a.roll) p90s = _rolling_mean(p90, a.roll)
ax.fill_between(xs, p10s, p90s, alpha=0.25, color="C0", lw=0) if a.quantile_lines:
ax.plot(xs, p50s, color="C0", lw=1.6) # standard quantile fan: outer band light, inner band dark, p50 line
p25s = _rolling_mean(np.nanpercentile(K, 25, axis=0), a.roll)
p75s = _rolling_mean(np.nanpercentile(K, 75, axis=0), a.roll)
ax.fill_between(xs, p10s, p90s, alpha=0.15, color="C0", lw=0, label="p10..p90")
ax.fill_between(xs, p25s, p75s, alpha=0.32, color="C0", lw=0, label="p25..p75")
ax.plot(xs, p50s, color="C0", lw=1.5, label="p50")
else:
ax.fill_between(xs, p10s, p90s, alpha=0.25, color="C0", lw=0)
ax.plot(xs, p50s, color="C0", lw=1.6)
def make_kl_figure(cells: list[dict], a: Args, out_path: Path) -> None: def make_kl_figure(cells: list[dict], a: Args, out_path: Path) -> None:
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
fig, axes = plt.subplots(1, len(a.alphas), figsize=(4.0 * len(a.alphas), 3.2), fig, axes = plt.subplots(1, len(a.alphas), figsize=(3.7 * len(a.alphas), 3.4),
sharex=True, sharey=True, squeeze=False, constrained_layout=True) sharex=True, sharey=True, squeeze=False, constrained_layout=True)
label = a.model_contains or "all models" label = a.model_contains or "all models"
mode = "individual trajectories (red=ever crossed KL=1)" if a.spaghetti \ n_max = max((_pool_kl(cells, alpha, T=a.window).shape[0] for alpha in a.alphas), default=0)
else f"p50 + p10..p90 band, smoothed rolling-{a.roll}" mode = f"individual trajectories, roll={a.roll}, color=pmass" if (a.spaghetti and a.color_by_pmass) \
else "individual trajectories (red=ever crossed KL=1)" if a.spaghetti \
else f"shaded quantiles p10/p25/p50/p75/p90, roll={a.roll}" if a.quantile_lines \
else f"p50 + p10..p90 band, smoothed rolling-{a.roll}"
fig.suptitle( fig.suptitle(
f"KL trajectory on N=8 held-out long-form prompts ({label})\n" f"KL trajectory on N={n_max} held-out long-form prompts ({label})\n"
f"{mode}. Solid line: KL=1 nat. Dotted v-line: t=20.", f"{mode}. Solid horizontal: KL=1 nat.",
fontsize=10, fontsize=10,
) )
@@ -202,27 +218,32 @@ def make_kl_figure(cells: list[dict], a: Args, out_path: Path) -> None:
Pe = _pool_pmass_eval(cells, alpha, _first_fork(cells), T=a.window) if a.color_by_pmass else None Pe = _pool_pmass_eval(cells, alpha, _first_fork(cells), T=a.window) if a.color_by_pmass else None
_draw_kl_panel(ax, K, a, P=Pe) _draw_kl_panel(ax, K, a, P=Pe)
if y_max >= 1.0: if y_max >= 1.0:
ax.axhline(1.0, color="k", lw=1.0) ax.axhline(1.0, color="k", lw=0.7)
else: else:
ax.text(0.98, 0.92, "KL=1 off-scale", transform=ax.transAxes, ax.text(0.98, 0.92, "KL=1 off-scale", transform=ax.transAxes,
ha="right", va="top", fontsize=8, color="0.25") ha="right", va="top", fontsize=8, color="0.25")
ax.axvline(20, color="k", ls=":", lw=0.8) if a.mark_t >= 0:
ax.axvline(a.mark_t, color="k", ls=":", lw=0.6)
ax.set_title(rf"$\alpha = {alpha}$ (n={K.shape[0]} traj)") ax.set_title(rf"$\alpha = {alpha}$ (n={K.shape[0]} traj)")
ax.set_xlim(0, x_max) ax.set_xlim(0, x_max)
ax.set_ylim(0, y_max) ax.set_ylim(0, y_max)
ax.set_xlabel("token") ax.set_xlabel("token")
if j == 0: if j == 0:
ax.set_ylabel("KL(steered || base) [nats]") ax.set_ylabel("KL")
if a.color_by_pmass: if a.color_by_pmass:
from matplotlib.cm import ScalarMappable from matplotlib.cm import ScalarMappable
from matplotlib.colors import LinearSegmentedColormap, Normalize from matplotlib.colors import LinearSegmentedColormap, Normalize
cmap = LinearSegmentedColormap.from_list("alive", ["#c0392b", "#f1c40f", "#27ae60"]) cmap = LinearSegmentedColormap.from_list("alive", ["#c0392b", "#f1c40f", "#27ae60"])
sm = ScalarMappable(norm=Normalize(0, 1), cmap=cmap); sm.set_array([]) sm = ScalarMappable(norm=Normalize(0, 1), cmap=cmap); sm.set_array([])
cbar = fig.colorbar(sm, ax=axes[0, :].tolist(), location="right", shrink=0.75, pad=0.01, fraction=0.015) cbar = fig.colorbar(sm, ax=axes[0, :].tolist(), location="right", shrink=0.72, pad=0.015, fraction=0.012)
cbar.set_label("pmass (0=dead, 1=alive)") cbar.set_label("pmass", labelpad=2)
fig.savefig(out_path, dpi=160, bbox_inches="tight") fig.savefig(out_path, dpi=160, bbox_inches="tight")
if a.quantile_lines:
q_path = out_path.with_name(out_path.stem + "_quantile_lines" + out_path.suffix)
fig.savefig(q_path, dpi=160, bbox_inches="tight")
logger.info(f"KL quantile-line figure -> {q_path}")
logger.info(f"KL-only figure -> {out_path}") logger.info(f"KL-only figure -> {out_path}")
+11 -4
View File
@@ -16,6 +16,10 @@ class Args:
threshold: float = 0.95 threshold: float = 0.95
out_name: str = "figs_auto" out_name: str = "figs_auto"
model_contains: str = "" model_contains: str = ""
line_alpha: float | None = None # forwarded to spaghetti+aggregate
roll: int = 65
line_lw: float = 0.5
quantile_lines: bool = False
def _ensure_single_run_root(run_dir: Path) -> Path: def _ensure_single_run_root(run_dir: Path) -> Path:
@@ -72,18 +76,21 @@ def main(a: Args) -> None:
"--window", str(__import__("json").loads(calib.read_text())["window"]), "--window", str(__import__("json").loads(calib.read_text())["window"]),
"--threshold", str(a.threshold), "--threshold", str(a.threshold),
"--model-contains", model_filter, "--model-contains", model_filter,
], repo_root) "--roll", str(a.roll),
"--line-lw", str(a.line_lw),
] + (["--line-alpha", str(a.line_alpha)] if a.line_alpha is not None else []), repo_root)
_run([ _run([
sys.executable, sys.executable,
"scripts/aggregate.py", "scripts/aggregate.py",
"--runs-root", str(single_root), "--runs-root", str(single_root),
"--out", str(out_dir / "aggregate"), "--out", str(out_dir / "aggregate"),
"--window", str(__import__("json").loads(calib.read_text())["window"]), "--window", str(__import__("json").loads(calib.read_text())["window"]),
"--spaghetti",
"--color-by-pmass",
"--kl-only", "--kl-only",
"--model-contains", model_filter, "--model-contains", model_filter,
], repo_root) "--roll", str(a.roll),
"--line-lw", str(a.line_lw),
] + (["--spaghetti", "--color-by-pmass"] if not a.quantile_lines else ["--quantile-lines"]) \
+ (["--line-alpha", str(a.line_alpha)] if a.line_alpha is not None else []), repo_root)
pngs = [ pngs = [
out_dir / "survival" / "survival_pmass_eval.png", out_dir / "survival" / "survival_pmass_eval.png",
+93 -12
View File
@@ -57,12 +57,13 @@ from pathlib import Path
import torch import torch
import tyro import tyro
from loguru import logger from loguru import logger
from tqdm.auto import tqdm
from iso_kl_figure import ( from iso_kl_figure import (
SteeringConfig, MeanDiffC, PCAC, DirectionalAblationC, SteeringConfig, MeanDiffC, PCAC, DirectionalAblationC,
train, calibrate_iso_kl, measure_kl, train, calibrate_iso_kl, measure_kl,
) )
from iso_kl_figure.branch_pmass import branch_pmass, collect_choice_token_ids from iso_kl_figure.branch_pmass import branch_pmass, build_chat_interrupt_suffix, collect_choice_token_ids
from iso_kl_figure.calibrate import _eos_token_ids from iso_kl_figure.calibrate import _eos_token_ids
from iso_kl_figure.target import _get_blocks from iso_kl_figure.target import _get_blocks
@@ -103,6 +104,14 @@ EVAL_PROMPTS = [
"Was Shakespeare a contemporary of Queen Elizabeth I of England?", "Was Shakespeare a contemporary of Queen Elizabeth I of England?",
"Is the Pacific Ocean larger in surface area than the Atlantic Ocean?", "Is the Pacific Ocean larger in surface area than the Atlantic Ocean?",
"Does a transformer architecture rely on self-attention rather than recurrence for sequence modeling?", "Does a transformer architecture rely on self-attention rather than recurrence for sequence modeling?",
"Is honey produced by bees primarily from nectar collected from flowers?",
"Did Ada Lovelace write what is often considered the first published computer algorithm?",
"Is the boiling point of pure water at sea level approximately 100 degrees Celsius?",
"Does cross-entropy loss reduce to negative log likelihood for a single correct class label?",
"Was the printing press invented by Johannes Gutenberg in the 15th century?",
"Is the speed of light in a vacuum the same for all inertial observers?",
"Do antibiotics kill viruses in addition to bacteria?",
"Is the chemical symbol for gold 'Au' on the periodic table?",
] ]
# pmass diagnostic prompts share the same schema as EVAL above. # pmass diagnostic prompts share the same schema as EVAL above.
@@ -144,7 +153,7 @@ class Args:
model: str model: str
method: str method: str
seed: int = 0 seed: int = 0
window: int = 50 window: int = 512
run_id: str = "" run_id: str = ""
layer_frac: float = 0.6 layer_frac: float = 0.6
target_kl: float = 1.0 target_kl: float = 1.0
@@ -159,6 +168,13 @@ class Args:
skip_pmass_qa: bool = False # skip yes/no pmass loop, only do paired pmass_eval skip_pmass_qa: bool = False # skip yes/no pmass loop, only do paired pmass_eval
render_figs: bool = False # render single-run survival + spaghetti + KL pngs render_figs: bool = False # render single-run survival + spaghetti + KL pngs
render_threshold: float = 0.95 render_threshold: float = 0.95
chat_interrupt: bool = False # use chat-template-correct interrupt turn for pmass scoring
# (in-distribution; recommended over raw splice).
interrupt_user_text: str = (
"Given the conversation so far, what is your final answer?"
" Respond with JSON: {\"Answer\": true} or {\"Answer\": false}."
)
interrupt_prefill: str = '{"Answer": '
def _set_seed(s: int): def _set_seed(s: int):
@@ -169,6 +185,28 @@ def _set_seed(s: int):
torch.cuda.manual_seed_all(s) torch.cuda.manual_seed_all(s)
def _log_fork_summary(stage: str, alpha: float, gen, gen_text: str,
                      fork_points: list[int], pm: dict) -> None:
    """One INFO line + one debug table per (stage, alpha) instead of N forks * 1 line.

    SHOULD: at alpha~1, pmass ~> 0.5 across forks (model still respects schema).
    ELSE: prefill broken or schema drift -- inspect debug_first in pmass.json.

    Args:
        stage: label interpolated into the log lines (callers pass "qa" or "eval").
        alpha: steering multiplier, only used for labelling.
        gen: generated sequence; only ``gen.shape[0]`` is read.
        gen_text: decoded generation; the first 80 characters are logged.
        fork_points: token positions zipped against the per-fork lists.
        pm: dict with parallel per-fork lists under "pmass", "p_true"
            and "argmax_str".
    """
    # Function-scope imports, matching the existing local `tabulate` import.
    from statistics import median

    from tabulate import tabulate

    pmv = pm["pmass"]
    ptv = pm["p_true"]
    if pmv:
        pm_min = min(pmv)
        # True median: averages the two middle values for even-length lists,
        # where sorted(pmv)[len(pmv) // 2] would report the upper-middle one.
        pm_med = median(pmv)
    else:
        # No fork points scored: keep the log line well-formed.
        pm_min = pm_med = float("nan")
    logger.info(
        f" alpha={alpha} {stage}[0] gen_len={gen.shape[0]} pmass(med/min)="
        f"{pm_med:.2f}/{pm_min:.2f} | text[:80]={gen_text[:80]!r}"
    )
    rows = [
        (t, f"{p:.3f}", f"{q:.3f}", repr(s))
        for t, p, q, s in zip(fork_points, pmv, ptv, pm["argmax_str"])
    ]
    table = tabulate(rows, headers=["t", "pmass", "p_true", "argmax"], tablefmt="plain")
    logger.debug(f"\nfork table [{stage} alpha={alpha}]\n{table}")
def _render_chat(tok, msgs: list[dict[str, str]], add_generation_prompt: bool) -> str: def _render_chat(tok, msgs: list[dict[str, str]], add_generation_prompt: bool) -> str:
if tok.chat_template is None: if tok.chat_template is None:
rendered = [] rendered = []
@@ -279,6 +317,15 @@ def main(a: Args):
gen_lens_eval: dict[str, list] = {} # T per (alpha, eval-prompt) gen_lens_eval: dict[str, list] = {} # T per (alpha, eval-prompt)
debug_first: dict[str, dict] = {} # first prompt per alpha: gen_text + top-5 at each fork debug_first: dict[str, dict] = {} # first prompt per alpha: gen_text + top-5 at each fork
# Append-only JSONL of every generation: full text + paired pmass.
# Crash-safe: each line flushed before next prompt. Overwritten only at run start.
gens_path = out_dir / "gens.jsonl"
if gens_path.exists():
gens_path.unlink()
def _append_gen(rec: dict) -> None:
    """Append `rec` to gens_path as one JSON line.

    The file is reopened in append mode and closed again on every call, so
    each record is on disk before the next one is produced.
    """
    with open(gens_path, "a") as sink:
        line = json.dumps(rec) + "\n"
        sink.write(line)
def write_eval_outputs() -> None: def write_eval_outputs() -> None:
(out_dir / "trajectory.json").write_text(json.dumps({ (out_dir / "trajectory.json").write_text(json.dumps({
"fork_points_full": list(range(a.window)), "fork_points_full": list(range(a.window)),
@@ -317,7 +364,18 @@ def main(a: Args):
tok(s, return_tensors="pt", add_special_tokens=False).input_ids[0] tok(s, return_tensors="pt", add_special_tokens=False).input_ids[0]
for s in eval_prompt_strs for s in eval_prompt_strs
] ]
for alpha in a.alphas: interrupt_suffix_ids = None
interrupt_suffix_decoded = ""
if a.chat_interrupt:
interrupt_suffix_ids = build_chat_interrupt_suffix(
tok, a.interrupt_user_text, a.interrupt_prefill,
)
# SHOULD: decoded suffix ends with prefill literal. ELSE template quirk.
interrupt_suffix_decoded = tok.decode(interrupt_suffix_ids)
logger.info(f"chat_interrupt suffix ({len(interrupt_suffix_ids)} tok): {interrupt_suffix_decoded!r}")
if not interrupt_suffix_decoded.endswith(a.interrupt_prefill):
logger.warning(f"interrupt suffix does not end with prefill={a.interrupt_prefill!r}")
for alpha in tqdm(a.alphas, desc=f"alphas[{a.run_id}]", mininterval=60):
v.cfg.coeff = alpha * c_star v.cfg.coeff = alpha * c_star
logger.info(f"=== eval alpha={alpha} c={v.cfg.coeff:+.4f} ===") logger.info(f"=== eval alpha={alpha} c={v.cfg.coeff:+.4f} ===")
m = measure_kl(v, model, tok, eval_ids_list, T=a.window, device=a.device) m = measure_kl(v, model, tok, eval_ids_list, T=a.window, device=a.device)
@@ -351,6 +409,7 @@ def main(a: Args):
v, model, tok, ids, gen, fork_points, v, model, tok, ids, gen, fork_points,
PREFILL_STR, a_ids, b_ids, PREFILL_STR, a_ids, b_ids,
rollout_cache=gen_out.past_key_values, rollout_cache=gen_out.past_key_values,
interrupt_suffix_ids=interrupt_suffix_ids,
device=a.device, device=a.device,
) )
pm_for_alpha.append(pm["pmass"]) pm_for_alpha.append(pm["pmass"])
@@ -363,17 +422,28 @@ def main(a: Args):
else "other" else "other"
for s in pm["argmax_str"] for s in pm["argmax_str"]
]) ])
# debug dump for first QA prompt only # save full gen text + paired pmass for every QA prompt
gen_text = tok.decode(gen, skip_special_tokens=False)
_append_gen({
"alpha": float(alpha), "kind": "qa", "prompt_idx": q_idx,
"score_mode": "chat_interrupt" if a.chat_interrupt else "raw_splice",
"interrupt_user_text": a.interrupt_user_text if a.chat_interrupt else "",
"interrupt_prefill": a.interrupt_prefill if a.chat_interrupt else "",
"interrupt_suffix_ids": list(interrupt_suffix_ids) if interrupt_suffix_ids is not None else [],
"interrupt_suffix_decoded": interrupt_suffix_decoded,
"prompt": prompt_str, "gen_text": gen_text, "gen_len": int(gen.shape[0]),
"fork_points": list(fork_points),
"pmass": pm["pmass"], "p_true": pm["p_true"],
"argmax_str": pm["argmax_str"], "was_thinking": pm["was_thinking"],
})
# debug dump for first QA prompt only (kept for backward compat)
if q_idx == 0: if q_idx == 0:
gen_text = tok.decode(gen, skip_special_tokens=False)
debug_first.setdefault(str(alpha), {})["qa"] = { debug_first.setdefault(str(alpha), {})["qa"] = {
"prompt": prompt_str, "gen_text": gen_text, "gen_len": int(gen.shape[0]), "prompt": prompt_str, "gen_text": gen_text, "gen_len": int(gen.shape[0]),
"pmass_per_fork": pm["pmass"], "p_true_per_fork": pm["p_true"], "pmass_per_fork": pm["pmass"], "p_true_per_fork": pm["p_true"],
"argmax_per_fork": pm["argmax_str"], "argmax_per_fork": pm["argmax_str"],
} }
logger.info(f" [debug] alpha={alpha} qa[0] gen_len={gen.shape[0]} text[:120]={gen_text[:120]!r}") _log_fork_summary("qa", alpha, gen, gen_text, fork_points, pm)
for t, pmv, ptv, am in zip(fork_points, pm["pmass"], pm["p_true"], pm["argmax_str"]):
logger.info(f" t={t:>3} pmass={pmv:.3f} p_true={ptv:.3f} argmax={am!r}")
pmass_all[str(alpha)] = pm_for_alpha pmass_all[str(alpha)] = pm_for_alpha
p_true_all[str(alpha)] = pt_for_alpha p_true_all[str(alpha)] = pt_for_alpha
argmax_all[str(alpha)] = ax_for_alpha argmax_all[str(alpha)] = ax_for_alpha
@@ -407,19 +477,30 @@ def main(a: Args):
v, model, tok, ids, gen, fork_points, v, model, tok, ids, gen, fork_points,
PREFILL_STR, a_ids, b_ids, PREFILL_STR, a_ids, b_ids,
rollout_cache=gen_out.past_key_values, rollout_cache=gen_out.past_key_values,
interrupt_suffix_ids=interrupt_suffix_ids,
device=a.device, device=a.device,
) )
pm_eval_for_alpha.append(pm["pmass"]) pm_eval_for_alpha.append(pm["pmass"])
gen_text = tok.decode(gen, skip_special_tokens=False)
_append_gen({
"alpha": float(alpha), "kind": "eval", "prompt_idx": p_idx,
"score_mode": "chat_interrupt" if a.chat_interrupt else "raw_splice",
"interrupt_user_text": a.interrupt_user_text if a.chat_interrupt else "",
"interrupt_prefill": a.interrupt_prefill if a.chat_interrupt else "",
"interrupt_suffix_ids": list(interrupt_suffix_ids) if interrupt_suffix_ids is not None else [],
"interrupt_suffix_decoded": interrupt_suffix_decoded,
"prompt": prompt_str, "gen_text": gen_text, "gen_len": int(gen.shape[0]),
"fork_points": list(fork_points),
"pmass": pm["pmass"], "p_true": pm["p_true"],
"argmax_str": pm["argmax_str"], "was_thinking": pm["was_thinking"],
})
if p_idx == 0: if p_idx == 0:
gen_text = tok.decode(gen, skip_special_tokens=False)
debug_first.setdefault(str(alpha), {})["eval"] = { debug_first.setdefault(str(alpha), {})["eval"] = {
"prompt": prompt_str, "gen_text": gen_text, "gen_len": int(gen.shape[0]), "prompt": prompt_str, "gen_text": gen_text, "gen_len": int(gen.shape[0]),
"pmass_per_fork": pm["pmass"], "p_true_per_fork": pm["p_true"], "pmass_per_fork": pm["pmass"], "p_true_per_fork": pm["p_true"],
"argmax_per_fork": pm["argmax_str"], "argmax_per_fork": pm["argmax_str"],
} }
logger.info(f" [debug] alpha={alpha} eval[0] gen_len={gen.shape[0]} text[:120]={gen_text[:120]!r}") _log_fork_summary("eval", alpha, gen, gen_text, fork_points, pm)
for t, pmv, ptv, am in zip(fork_points, pm["pmass"], pm["p_true"], pm["argmax_str"]):
logger.info(f" t={t:>3} pmass={pmv:.3f} p_true={ptv:.3f} argmax={am!r}")
pmass_eval_all[str(alpha)] = pm_eval_for_alpha pmass_eval_all[str(alpha)] = pm_eval_for_alpha
gen_lens_eval[str(alpha)] = gen_lens_eval_alpha gen_lens_eval[str(alpha)] = gen_lens_eval_alpha
# SHOULD: at alpha=1, mean(pmass at t=0) > 0.5 (model still respects schema). # SHOULD: at alpha=1, mean(pmass at t=0) > 0.5 (model still respects schema).
+54 -3
View File
@@ -55,6 +55,43 @@ def collect_choice_token_ids(
return _ids(a_words), _ids(b_words) return _ids(a_words), _ids(b_words)
def build_chat_interrupt_suffix(
    tok,
    interrupt_user_text: str,
    assistant_prefill: str,
) -> list[int]:
    """Construct token ids for: <eot> <user>{question}<eot> <assistant>{prefill}.

    Used to interrupt an in-progress assistant rollout with a follow-up question
    in a chat-template-correct way (vs. raw token splice). The suffix is
    everything the template renders *after* a sentinel assistant message, so it
    starts with whatever the template uses to close that assistant turn -- no
    template-specific special-token guessing.

    The cut is made right after the sentinel content rather than after a fully
    rendered base conversation: the rendered base already ends with the
    assistant end-of-turn marker, so diffing against it would drop that marker
    from the suffix and leave the interrupted turn unterminated.

    SHOULD: tok.decode(suffix_ids) ends with assistant_prefill. ELSE template
    quirk: inspect tok.chat_template.

    Args:
        tok: tokenizer exposing ``apply_chat_template(..., tokenize=False,
            continue_final_message=True)`` and ``encode(...,
            add_special_tokens=False)``.
        interrupt_user_text: content of the follow-up user turn.
        assistant_prefill: literal text the new assistant turn starts with.

    Returns:
        Token ids of the suffix to append after the partial rollout.

    Raises:
        ValueError: if the sentinel assistant content cannot be located in the
            rendered conversation, so no suffix can be derived.
    """
    sentinel = "_ROLLED_"  # stands in for the partial assistant rollout
    msgs = [
        {"role": "user", "content": "_"},
        {"role": "assistant", "content": sentinel},
        {"role": "user", "content": interrupt_user_text},
        {"role": "assistant", "content": assistant_prefill},
    ]
    # continue_final_message=True leaves the last assistant turn open so the
    # rendered string ends with the prefill instead of an end-of-turn marker.
    ext_str = tok.apply_chat_template(msgs, tokenize=False, continue_final_message=True)
    # First occurrence: the sentinel assistant turn is rendered before the
    # interrupt text, so a sentinel-like substring in the question is harmless.
    cut = ext_str.find(sentinel)
    if cut < 0:
        raise ValueError(
            "chat template does not render assistant content verbatim: cannot "
            "derive interrupt suffix. Disable chat_interrupt or extend this helper."
        )
    suffix_str = ext_str[cut + len(sentinel):]
    return tok.encode(suffix_str, add_special_tokens=False)
def _is_thinking(seq_ids: Tensor, think_id: int, unthink_id: int) -> bool: def _is_thinking(seq_ids: Tensor, think_id: int, unthink_id: int) -> bool:
"""True iff the last `<think>` in seq_ids is after the last `</think>`.""" """True iff the last `<think>` in seq_ids is after the last `</think>`."""
if think_id is None or unthink_id is None: if think_id is None or unthink_id is None:
@@ -89,6 +126,9 @@ def branch_pmass(
handle_thinking: bool = True, handle_thinking: bool = True,
end_think_str: str = "</think>", end_think_str: str = "</think>",
force_close_str: str = "\nI should answer now.", # see tinymfv/guided.py force_close_str: str = "\nI should answer now.", # see tinymfv/guided.py
interrupt_suffix_ids: Sequence[int] | None = None, # if set, overrides prefill_str + thinking
# close logic and uses chat-template-correct interrupt
# turn (built via build_chat_interrupt_suffix).
device: str | torch.device = "cuda", device: str | torch.device = "cuda",
) -> dict: ) -> dict:
"""Returns dict with parallel lists indexed by fork point: """Returns dict with parallel lists indexed by fork point:
@@ -113,6 +153,13 @@ def branch_pmass(
tok.encode(force_close_str + end_think_str, add_special_tokens=False), tok.encode(force_close_str + end_think_str, add_special_tokens=False),
device=device, dtype=torch.long, device=device, dtype=torch.long,
) if handle_thinking else None ) if handle_thinking else None
# Chat-template interrupt mode: replaces both close_t and pre_t with one
# template-correct turn boundary (<eot><user>...<eot><assistant>{prefill}).
# Disables thinking-detection (irrelevant: we are starting a new turn).
interrupt_t = (
torch.tensor(list(interrupt_suffix_ids), device=device, dtype=torch.long)
if interrupt_suffix_ids is not None else None
)
a_t = torch.tensor(list(a_ids), dtype=torch.long, device=device) a_t = torch.tensor(list(a_ids), dtype=torch.long, device=device)
b_t = torch.tensor(list(b_ids), dtype=torch.long, device=device) b_t = torch.tensor(list(b_ids), dtype=torch.long, device=device)
all_t = torch.cat([a_t, b_t]) all_t = torch.cat([a_t, b_t])
@@ -130,9 +177,13 @@ def branch_pmass(
continue continue
prefix = rolled[:t] prefix = rolled[:t]
seq_so_far = torch.cat([pids, prefix]) seq_so_far = torch.cat([pids, prefix])
thinking = (handle_thinking and think_id is not None and unthink_id is not None if interrupt_t is not None:
and _is_thinking(seq_so_far, think_id, unthink_id)) new_tail = interrupt_t
new_tail = torch.cat([close_t, pre_t]) if (thinking and close_t is not None) else pre_t thinking = False # not meaningful in chat-interrupt mode
else:
thinking = (handle_thinking and think_id is not None and unthink_id is not None
and _is_thinking(seq_so_far, think_id, unthink_id))
new_tail = torch.cat([close_t, pre_t]) if (thinking and close_t is not None) else pre_t
if rollout_cache is not None and use_rollout_cache: if rollout_cache is not None and use_rollout_cache:
cache = copy.deepcopy(rollout_cache) cache = copy.deepcopy(rollout_cache)
+7 -5
View File
@@ -158,12 +158,14 @@ def measure_kl(
base_full = torch.cat([pids.to(device), base_gen]) base_full = torch.cat([pids.to(device), base_gen])
decoded_base = tok.decode(base_full, skip_special_tokens=False) decoded_base = tok.decode(base_full, skip_special_tokens=False)
decoded_steer = tok.decode(full_ids, skip_special_tokens=False) decoded_steer = tok.decode(full_ids, skip_special_tokens=False)
# SHOULD: BASE and STEER both coherent; STEER differs from BASE but does not collapse.
# Truncate to keep iso-KL bracket logs scannable; full text in trajectory.json/debug_first.
head = 400
logger.info( logger.info(
f"EXPECT: same prompt under c=0 vs c={v.cfg.coeff:+.4f}; both coherent; " f"SHOULD: c=0 vs c={v.cfg.coeff:+.4f} both coherent, steer differs but does not collapse.\n"
"steered should differ from base but not collapse.\n" f"=== CALIBRATE demo (T={T}, first {head} chars each) ===\n"
f"\n=== CALIBRATE demo trace (T={T}) ===\n" f"-- BASE : {decoded_base[:head]!r}\n"
f"--- BASE (c=0) ---\n{decoded_base}\n" f"-- STEER : {decoded_steer[:head]!r}\n"
f"\n--- STEER (c={v.cfg.coeff:+.4f}) ---\n{decoded_steer}\n"
f"=== /CALIBRATE ===" f"=== /CALIBRATE ==="
) )
kls = _kl_generated_incremental(v, model, pids, gen, device) kls = _kl_generated_incremental(v, model, pids, gen, device)
+19 -1
View File
@@ -15,7 +15,11 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
from iso_kl_figure import ( from iso_kl_figure import (
SteeringConfig, MeanDiffC, train, measure_kl, attach, detach, SteeringConfig, MeanDiffC, train, measure_kl, attach, detach,
) )
from iso_kl_figure.branch_pmass import branch_pmass, collect_choice_token_ids from iso_kl_figure.branch_pmass import (
branch_pmass,
build_chat_interrupt_suffix,
collect_choice_token_ids,
)
MODEL = "hf-internal-testing/tiny-random-LlamaForCausalLM" MODEL = "hf-internal-testing/tiny-random-LlamaForCausalLM"
@@ -61,8 +65,22 @@ def test_pipeline_smoke():
v.cfg.coeff = 5.0 v.cfg.coeff = 5.0
p_steer = branch_pmass(v, model, tok, pids, rolled, fork, prefill, p_steer = branch_pmass(v, model, tok, pids, rolled, fork, prefill,
a_ids, b_ids, device="cpu") a_ids, b_ids, device="cpu")
tok.chat_template = (
"{% for message in messages %}"
"<|{{ message['role'] }}|>\n{{ message['content'] }}\n"
"{% endfor %}"
"{% if add_generation_prompt %}<|assistant|>\n{% endif %}"
)
interrupt_ids = build_chat_interrupt_suffix(tok, "Final answer?", '{"Answer": ')
assert tok.decode(interrupt_ids).endswith('{"Answer": ')
p_interrupt = branch_pmass(
v, model, tok, pids, rolled, fork, prefill,
a_ids, b_ids, interrupt_suffix_ids=interrupt_ids, device="cpu",
)
for x in p_zero["pmass"] + p_steer["pmass"]: for x in p_zero["pmass"] + p_steer["pmass"]:
assert 0.0 <= x <= 1.0, f"pmass out of [0,1]: {x}" assert 0.0 <= x <= 1.0, f"pmass out of [0,1]: {x}"
for x in p_interrupt["pmass"]:
assert 0.0 <= x <= 1.0, f"chat-interrupt pmass out of [0,1]: {x}"
diff = max(abs(a - b) for a, b in zip(p_zero["pmass"], p_steer["pmass"])) diff = max(abs(a - b) for a, b in zip(p_zero["pmass"], p_steer["pmass"]))
assert diff > 1e-8, "pmass invariant to coeff -- hook is dead" assert diff > 1e-8, "pmass invariant to coeff -- hook is dead"
# p_true should be in [0,1] or NaN # p_true should be in [0,1] or NaN