mirror of
https://github.com/wassname/isokl_steering_calibration.git
synced 2026-06-27 15:15:52 +08:00
write up
This commit is contained in:
@@ -1,4 +1,6 @@
|
||||
# iso-kl-figure (short version)
|
||||
# calibrating steering over long trajectories by normalising KL outliers
|
||||
|
||||

|
||||
|
||||
## The problem
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Binary file not shown.
|
After Width: | Height: | Size: 190 KiB |
@@ -11,7 +11,7 @@ test:
|
||||
uv run --extra all pytest -q
|
||||
|
||||
# Run one (model, method, seed, window) cell end-to-end (calibrate + trajectory + pmass).
|
||||
cell model="Qwen/Qwen3.5-0.8B" method="mean_diff" seed="0" window="50":
|
||||
cell model="Qwen/Qwen3.5-0.8B" method="mean_diff" seed="0" window="512":
|
||||
uv run --extra all python scripts/run_cell.py \
|
||||
--model {{model}} --method {{method}} --seed {{seed}} --window {{window}}
|
||||
|
||||
|
||||
+37
-16
@@ -65,14 +65,19 @@ except Exception:
|
||||
class Args:
|
||||
runs_root: str = "outputs"
|
||||
out: str = "figs"
|
||||
window: int = 50 # only this window enters the figure (need long-enough traj for rolling-16)
|
||||
roll: int = 16 # smoothing window for KL trajectory
|
||||
window: int = 512 # only this window enters the figure
|
||||
roll: int = 65 # smoothing window for KL trajectory
|
||||
alphas: tuple[str, ...] = ("0.25", "0.5", "0.75", "1.0", "1.5", "2.0", "4.0")
|
||||
kl_ymax: float = 6.0
|
||||
model_contains: str = ""
|
||||
kl_only: bool = False
|
||||
spaghetti: bool = False # plot individual trajectories instead of p10..p90 band
|
||||
color_by_pmass: bool = False # color KL spaghetti lines by paired pmass (requires pmass_eval)
|
||||
line_alpha: float | None = None # per-line alpha override; None = auto clip(2.5/n,.08,.35)
|
||||
line_lw: float = 0.18 # per-trajectory linewidth; full opacity needs very thin lines
|
||||
median_lw: float = 0.75 # median linewidth
|
||||
quantile_lines: bool = False # clean summary: p10/p50/p90 lines, no fill/spaghetti
|
||||
mark_t: int = -1 # optional vertical token marker; -1 disables
|
||||
|
||||
|
||||
def _rolling_mean(x: np.ndarray, w: int) -> np.ndarray:
|
||||
@@ -138,20 +143,20 @@ def _draw_kl_panel(ax, K: np.ndarray, a: Args, P: np.ndarray | None = None) -> N
|
||||
pts = np.column_stack([xs, traj])
|
||||
segs = np.stack([pts[:-1], pts[1:]], axis=1)
|
||||
lc = LineCollection(segs, cmap=cmap, norm=plt.Normalize(0, 1),
|
||||
linewidths=0.7, alpha=0.55)
|
||||
linewidths=a.line_lw, alpha=float(a.line_alpha if a.line_alpha is not None else np.clip(2.5/max(K.shape[0],1), 0.08, 0.35)))
|
||||
lc.set_array(pmass_row[:-1])
|
||||
ax.add_collection(lc)
|
||||
ax.set_xlim(xs[0], xs[-1])
|
||||
med = np.nanmedian(Kp, axis=0)
|
||||
ax.plot(xs, med, color="k", lw=1.6)
|
||||
ax.plot(xs, med, color="k", lw=a.median_lw)
|
||||
else:
|
||||
crossed = (K > 1.0).any(axis=1)
|
||||
for traj in Kp[~crossed]:
|
||||
ax.plot(xs, traj, color="0.55", lw=0.6, alpha=0.6)
|
||||
ax.plot(xs, traj, color="0.55", lw=a.line_lw, alpha=0.5)
|
||||
for traj in Kp[crossed]:
|
||||
ax.plot(xs, traj, color="C3", lw=0.6, alpha=0.6)
|
||||
ax.plot(xs, traj, color="C3", lw=a.line_lw, alpha=0.5)
|
||||
med = np.nanmedian(Kp, axis=0)
|
||||
ax.plot(xs, med, color="k", lw=1.6)
|
||||
ax.plot(xs, med, color="k", lw=a.median_lw)
|
||||
frac = float(crossed.mean())
|
||||
ax.text(0.97, 0.97, f"{frac:.0%} cross KL=1",
|
||||
transform=ax.transAxes, ha="right", va="top",
|
||||
@@ -163,6 +168,14 @@ def _draw_kl_panel(ax, K: np.ndarray, a: Args, P: np.ndarray | None = None) -> N
|
||||
p50s = _rolling_mean(p50, a.roll)
|
||||
p10s = _rolling_mean(p10, a.roll)
|
||||
p90s = _rolling_mean(p90, a.roll)
|
||||
if a.quantile_lines:
|
||||
# standard quantile fan: outer band light, inner band dark, p50 line
|
||||
p25s = _rolling_mean(np.nanpercentile(K, 25, axis=0), a.roll)
|
||||
p75s = _rolling_mean(np.nanpercentile(K, 75, axis=0), a.roll)
|
||||
ax.fill_between(xs, p10s, p90s, alpha=0.15, color="C0", lw=0, label="p10..p90")
|
||||
ax.fill_between(xs, p25s, p75s, alpha=0.32, color="C0", lw=0, label="p25..p75")
|
||||
ax.plot(xs, p50s, color="C0", lw=1.5, label="p50")
|
||||
else:
|
||||
ax.fill_between(xs, p10s, p90s, alpha=0.25, color="C0", lw=0)
|
||||
ax.plot(xs, p50s, color="C0", lw=1.6)
|
||||
|
||||
@@ -170,14 +183,17 @@ def _draw_kl_panel(ax, K: np.ndarray, a: Args, P: np.ndarray | None = None) -> N
|
||||
def make_kl_figure(cells: list[dict], a: Args, out_path: Path) -> None:
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
fig, axes = plt.subplots(1, len(a.alphas), figsize=(4.0 * len(a.alphas), 3.2),
|
||||
fig, axes = plt.subplots(1, len(a.alphas), figsize=(3.7 * len(a.alphas), 3.4),
|
||||
sharex=True, sharey=True, squeeze=False, constrained_layout=True)
|
||||
label = a.model_contains or "all models"
|
||||
mode = "individual trajectories (red=ever crossed KL=1)" if a.spaghetti \
|
||||
n_max = max((_pool_kl(cells, alpha, T=a.window).shape[0] for alpha in a.alphas), default=0)
|
||||
mode = f"individual trajectories, roll={a.roll}, color=pmass" if (a.spaghetti and a.color_by_pmass) \
|
||||
else "individual trajectories (red=ever crossed KL=1)" if a.spaghetti \
|
||||
else f"shaded quantiles p10/p25/p50/p75/p90, roll={a.roll}" if a.quantile_lines \
|
||||
else f"p50 + p10..p90 band, smoothed rolling-{a.roll}"
|
||||
fig.suptitle(
|
||||
f"KL trajectory on N=8 held-out long-form prompts ({label})\n"
|
||||
f"{mode}. Solid line: KL=1 nat. Dotted v-line: t=20.",
|
||||
f"KL trajectory on N={n_max} held-out long-form prompts ({label})\n"
|
||||
f"{mode}. Solid horizontal: KL=1 nat.",
|
||||
fontsize=10,
|
||||
)
|
||||
|
||||
@@ -202,27 +218,32 @@ def make_kl_figure(cells: list[dict], a: Args, out_path: Path) -> None:
|
||||
Pe = _pool_pmass_eval(cells, alpha, _first_fork(cells), T=a.window) if a.color_by_pmass else None
|
||||
_draw_kl_panel(ax, K, a, P=Pe)
|
||||
if y_max >= 1.0:
|
||||
ax.axhline(1.0, color="k", lw=1.0)
|
||||
ax.axhline(1.0, color="k", lw=0.7)
|
||||
else:
|
||||
ax.text(0.98, 0.92, "KL=1 off-scale", transform=ax.transAxes,
|
||||
ha="right", va="top", fontsize=8, color="0.25")
|
||||
ax.axvline(20, color="k", ls=":", lw=0.8)
|
||||
if a.mark_t >= 0:
|
||||
ax.axvline(a.mark_t, color="k", ls=":", lw=0.6)
|
||||
ax.set_title(rf"$\alpha = {alpha}$ (n={K.shape[0]} traj)")
|
||||
ax.set_xlim(0, x_max)
|
||||
ax.set_ylim(0, y_max)
|
||||
ax.set_xlabel("token")
|
||||
if j == 0:
|
||||
ax.set_ylabel("KL(steered || base) [nats]")
|
||||
ax.set_ylabel("KL")
|
||||
|
||||
if a.color_by_pmass:
|
||||
from matplotlib.cm import ScalarMappable
|
||||
from matplotlib.colors import LinearSegmentedColormap, Normalize
|
||||
cmap = LinearSegmentedColormap.from_list("alive", ["#c0392b", "#f1c40f", "#27ae60"])
|
||||
sm = ScalarMappable(norm=Normalize(0, 1), cmap=cmap); sm.set_array([])
|
||||
cbar = fig.colorbar(sm, ax=axes[0, :].tolist(), location="right", shrink=0.75, pad=0.01, fraction=0.015)
|
||||
cbar.set_label("pmass (0=dead, 1=alive)")
|
||||
cbar = fig.colorbar(sm, ax=axes[0, :].tolist(), location="right", shrink=0.72, pad=0.015, fraction=0.012)
|
||||
cbar.set_label("pmass", labelpad=2)
|
||||
|
||||
fig.savefig(out_path, dpi=160, bbox_inches="tight")
|
||||
if a.quantile_lines:
|
||||
q_path = out_path.with_name(out_path.stem + "_quantile_lines" + out_path.suffix)
|
||||
fig.savefig(q_path, dpi=160, bbox_inches="tight")
|
||||
logger.info(f"KL quantile-line figure -> {q_path}")
|
||||
logger.info(f"KL-only figure -> {out_path}")
|
||||
|
||||
|
||||
|
||||
@@ -16,6 +16,10 @@ class Args:
|
||||
threshold: float = 0.95
|
||||
out_name: str = "figs_auto"
|
||||
model_contains: str = ""
|
||||
line_alpha: float | None = None # forwarded to spaghetti+aggregate
|
||||
roll: int = 65
|
||||
line_lw: float = 0.5
|
||||
quantile_lines: bool = False
|
||||
|
||||
|
||||
def _ensure_single_run_root(run_dir: Path) -> Path:
|
||||
@@ -72,18 +76,21 @@ def main(a: Args) -> None:
|
||||
"--window", str(__import__("json").loads(calib.read_text())["window"]),
|
||||
"--threshold", str(a.threshold),
|
||||
"--model-contains", model_filter,
|
||||
], repo_root)
|
||||
"--roll", str(a.roll),
|
||||
"--line-lw", str(a.line_lw),
|
||||
] + (["--line-alpha", str(a.line_alpha)] if a.line_alpha is not None else []), repo_root)
|
||||
_run([
|
||||
sys.executable,
|
||||
"scripts/aggregate.py",
|
||||
"--runs-root", str(single_root),
|
||||
"--out", str(out_dir / "aggregate"),
|
||||
"--window", str(__import__("json").loads(calib.read_text())["window"]),
|
||||
"--spaghetti",
|
||||
"--color-by-pmass",
|
||||
"--kl-only",
|
||||
"--model-contains", model_filter,
|
||||
], repo_root)
|
||||
"--roll", str(a.roll),
|
||||
"--line-lw", str(a.line_lw),
|
||||
] + (["--spaghetti", "--color-by-pmass"] if not a.quantile_lines else ["--quantile-lines"]) \
|
||||
+ (["--line-alpha", str(a.line_alpha)] if a.line_alpha is not None else []), repo_root)
|
||||
|
||||
pngs = [
|
||||
out_dir / "survival" / "survival_pmass_eval.png",
|
||||
|
||||
+93
-12
@@ -57,12 +57,13 @@ from pathlib import Path
|
||||
import torch
|
||||
import tyro
|
||||
from loguru import logger
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from iso_kl_figure import (
|
||||
SteeringConfig, MeanDiffC, PCAC, DirectionalAblationC,
|
||||
train, calibrate_iso_kl, measure_kl,
|
||||
)
|
||||
from iso_kl_figure.branch_pmass import branch_pmass, collect_choice_token_ids
|
||||
from iso_kl_figure.branch_pmass import branch_pmass, build_chat_interrupt_suffix, collect_choice_token_ids
|
||||
from iso_kl_figure.calibrate import _eos_token_ids
|
||||
from iso_kl_figure.target import _get_blocks
|
||||
|
||||
@@ -103,6 +104,14 @@ EVAL_PROMPTS = [
|
||||
"Was Shakespeare a contemporary of Queen Elizabeth I of England?",
|
||||
"Is the Pacific Ocean larger in surface area than the Atlantic Ocean?",
|
||||
"Does a transformer architecture rely on self-attention rather than recurrence for sequence modeling?",
|
||||
"Is honey produced by bees primarily from nectar collected from flowers?",
|
||||
"Did Ada Lovelace write what is often considered the first published computer algorithm?",
|
||||
"Is the boiling point of pure water at sea level approximately 100 degrees Celsius?",
|
||||
"Does cross-entropy loss reduce to negative log likelihood for a single correct class label?",
|
||||
"Was the printing press invented by Johannes Gutenberg in the 15th century?",
|
||||
"Is the speed of light in a vacuum the same for all inertial observers?",
|
||||
"Do antibiotics kill viruses in addition to bacteria?",
|
||||
"Is the chemical symbol for gold 'Au' on the periodic table?",
|
||||
]
|
||||
|
||||
# pmass diagnostic prompts share the same schema as EVAL above.
|
||||
@@ -144,7 +153,7 @@ class Args:
|
||||
model: str
|
||||
method: str
|
||||
seed: int = 0
|
||||
window: int = 50
|
||||
window: int = 512
|
||||
run_id: str = ""
|
||||
layer_frac: float = 0.6
|
||||
target_kl: float = 1.0
|
||||
@@ -159,6 +168,13 @@ class Args:
|
||||
skip_pmass_qa: bool = False # skip yes/no pmass loop, only do paired pmass_eval
|
||||
render_figs: bool = False # render single-run survival + spaghetti + KL pngs
|
||||
render_threshold: float = 0.95
|
||||
chat_interrupt: bool = False # use chat-template-correct interrupt turn for pmass scoring
|
||||
# (in-distribution; recommended over raw splice).
|
||||
interrupt_user_text: str = (
|
||||
"Given the conversation so far, what is your final answer?"
|
||||
" Respond with JSON: {\"Answer\": true} or {\"Answer\": false}."
|
||||
)
|
||||
interrupt_prefill: str = '{"Answer": '
|
||||
|
||||
|
||||
def _set_seed(s: int):
|
||||
@@ -169,6 +185,28 @@ def _set_seed(s: int):
|
||||
torch.cuda.manual_seed_all(s)
|
||||
|
||||
|
||||
def _log_fork_summary(stage: str, alpha: float, gen, gen_text: str,
                      fork_points: list[int], pm: dict) -> None:
    """Emit a compact fork report: one INFO line and one DEBUG table per (stage, alpha).

    Replaces logging one line per fork point. The INFO line carries the
    generation length, the middle element of the sorted pmass values, the
    minimum pmass and the first 80 characters of the generated text; the DEBUG
    record carries a plain-format table with one row per fork point.

    SHOULD: at alpha~1, pmass ~> 0.5 across forks (model still respects schema).
    ELSE: prefill broken or schema drift -- inspect debug_first in pmass.json.

    Args:
        stage: label placed in the log line (callers pass "qa" or "eval").
        alpha: steering multiplier being reported.
        gen: generated ids; only ``gen.shape[0]`` is read.
        gen_text: decoded generation, truncated to 80 chars for the INFO line.
        fork_points: token positions, zipped against the per-fork lists in ``pm``.
        pm: mapping providing the "pmass", "p_true" and "argmax_str" lists.
    """
    from tabulate import tabulate

    pmass_vals = pm["pmass"]
    p_true_vals = pm["p_true"]

    if pmass_vals:
        lowest = min(pmass_vals)
        ordered = sorted(pmass_vals)
        # Element at index len // 2 of the sorted values (upper middle for even counts).
        middle = ordered[len(ordered) // 2]
    else:
        # Nothing scored: report NaN rather than failing on an empty sequence.
        lowest = float("nan")
        middle = float("nan")

    logger.info(
        f" alpha={alpha} {stage}[0] gen_len={gen.shape[0]} pmass(med/min)="
        f"{middle:.2f}/{lowest:.2f} | text[:80]={gen_text[:80]!r}"
    )

    table_rows = []
    for fork_t, pmass_t, p_true_t, argmax_t in zip(
        fork_points, pmass_vals, p_true_vals, pm["argmax_str"]
    ):
        table_rows.append((fork_t, f"{pmass_t:.3f}", f"{p_true_t:.3f}", repr(argmax_t)))

    table = tabulate(
        table_rows,
        headers=["t", "pmass", "p_true", "argmax"],
        tablefmt="plain",
    )
    logger.debug(f"\nfork table [{stage} alpha={alpha}]\n{table}")
|
||||
|
||||
|
||||
def _render_chat(tok, msgs: list[dict[str, str]], add_generation_prompt: bool) -> str:
|
||||
if tok.chat_template is None:
|
||||
rendered = []
|
||||
@@ -279,6 +317,15 @@ def main(a: Args):
|
||||
gen_lens_eval: dict[str, list] = {} # T per (alpha, eval-prompt)
|
||||
debug_first: dict[str, dict] = {} # first prompt per alpha: gen_text + top-5 at each fork
|
||||
|
||||
# Append-only JSONL of every generation: full text + paired pmass.
|
||||
# Crash-safe: each line flushed before next prompt. Overwritten only at run start.
|
||||
gens_path = out_dir / "gens.jsonl"
|
||||
if gens_path.exists():
|
||||
gens_path.unlink()
|
||||
def _append_gen(rec: dict) -> None:
|
||||
with open(gens_path, "a") as f:
|
||||
f.write(json.dumps(rec) + "\n")
|
||||
|
||||
def write_eval_outputs() -> None:
|
||||
(out_dir / "trajectory.json").write_text(json.dumps({
|
||||
"fork_points_full": list(range(a.window)),
|
||||
@@ -317,7 +364,18 @@ def main(a: Args):
|
||||
tok(s, return_tensors="pt", add_special_tokens=False).input_ids[0]
|
||||
for s in eval_prompt_strs
|
||||
]
|
||||
for alpha in a.alphas:
|
||||
interrupt_suffix_ids = None
|
||||
interrupt_suffix_decoded = ""
|
||||
if a.chat_interrupt:
|
||||
interrupt_suffix_ids = build_chat_interrupt_suffix(
|
||||
tok, a.interrupt_user_text, a.interrupt_prefill,
|
||||
)
|
||||
# SHOULD: decoded suffix ends with prefill literal. ELSE template quirk.
|
||||
interrupt_suffix_decoded = tok.decode(interrupt_suffix_ids)
|
||||
logger.info(f"chat_interrupt suffix ({len(interrupt_suffix_ids)} tok): {interrupt_suffix_decoded!r}")
|
||||
if not interrupt_suffix_decoded.endswith(a.interrupt_prefill):
|
||||
logger.warning(f"interrupt suffix does not end with prefill={a.interrupt_prefill!r}")
|
||||
for alpha in tqdm(a.alphas, desc=f"alphas[{a.run_id}]", mininterval=60):
|
||||
v.cfg.coeff = alpha * c_star
|
||||
logger.info(f"=== eval alpha={alpha} c={v.cfg.coeff:+.4f} ===")
|
||||
m = measure_kl(v, model, tok, eval_ids_list, T=a.window, device=a.device)
|
||||
@@ -351,6 +409,7 @@ def main(a: Args):
|
||||
v, model, tok, ids, gen, fork_points,
|
||||
PREFILL_STR, a_ids, b_ids,
|
||||
rollout_cache=gen_out.past_key_values,
|
||||
interrupt_suffix_ids=interrupt_suffix_ids,
|
||||
device=a.device,
|
||||
)
|
||||
pm_for_alpha.append(pm["pmass"])
|
||||
@@ -363,17 +422,28 @@ def main(a: Args):
|
||||
else "other"
|
||||
for s in pm["argmax_str"]
|
||||
])
|
||||
# debug dump for first QA prompt only
|
||||
if q_idx == 0:
|
||||
# save full gen text + paired pmass for every QA prompt
|
||||
gen_text = tok.decode(gen, skip_special_tokens=False)
|
||||
_append_gen({
|
||||
"alpha": float(alpha), "kind": "qa", "prompt_idx": q_idx,
|
||||
"score_mode": "chat_interrupt" if a.chat_interrupt else "raw_splice",
|
||||
"interrupt_user_text": a.interrupt_user_text if a.chat_interrupt else "",
|
||||
"interrupt_prefill": a.interrupt_prefill if a.chat_interrupt else "",
|
||||
"interrupt_suffix_ids": list(interrupt_suffix_ids) if interrupt_suffix_ids is not None else [],
|
||||
"interrupt_suffix_decoded": interrupt_suffix_decoded,
|
||||
"prompt": prompt_str, "gen_text": gen_text, "gen_len": int(gen.shape[0]),
|
||||
"fork_points": list(fork_points),
|
||||
"pmass": pm["pmass"], "p_true": pm["p_true"],
|
||||
"argmax_str": pm["argmax_str"], "was_thinking": pm["was_thinking"],
|
||||
})
|
||||
# debug dump for first QA prompt only (kept for backward compat)
|
||||
if q_idx == 0:
|
||||
debug_first.setdefault(str(alpha), {})["qa"] = {
|
||||
"prompt": prompt_str, "gen_text": gen_text, "gen_len": int(gen.shape[0]),
|
||||
"pmass_per_fork": pm["pmass"], "p_true_per_fork": pm["p_true"],
|
||||
"argmax_per_fork": pm["argmax_str"],
|
||||
}
|
||||
logger.info(f" [debug] alpha={alpha} qa[0] gen_len={gen.shape[0]} text[:120]={gen_text[:120]!r}")
|
||||
for t, pmv, ptv, am in zip(fork_points, pm["pmass"], pm["p_true"], pm["argmax_str"]):
|
||||
logger.info(f" t={t:>3} pmass={pmv:.3f} p_true={ptv:.3f} argmax={am!r}")
|
||||
_log_fork_summary("qa", alpha, gen, gen_text, fork_points, pm)
|
||||
pmass_all[str(alpha)] = pm_for_alpha
|
||||
p_true_all[str(alpha)] = pt_for_alpha
|
||||
argmax_all[str(alpha)] = ax_for_alpha
|
||||
@@ -407,19 +477,30 @@ def main(a: Args):
|
||||
v, model, tok, ids, gen, fork_points,
|
||||
PREFILL_STR, a_ids, b_ids,
|
||||
rollout_cache=gen_out.past_key_values,
|
||||
interrupt_suffix_ids=interrupt_suffix_ids,
|
||||
device=a.device,
|
||||
)
|
||||
pm_eval_for_alpha.append(pm["pmass"])
|
||||
if p_idx == 0:
|
||||
gen_text = tok.decode(gen, skip_special_tokens=False)
|
||||
_append_gen({
|
||||
"alpha": float(alpha), "kind": "eval", "prompt_idx": p_idx,
|
||||
"score_mode": "chat_interrupt" if a.chat_interrupt else "raw_splice",
|
||||
"interrupt_user_text": a.interrupt_user_text if a.chat_interrupt else "",
|
||||
"interrupt_prefill": a.interrupt_prefill if a.chat_interrupt else "",
|
||||
"interrupt_suffix_ids": list(interrupt_suffix_ids) if interrupt_suffix_ids is not None else [],
|
||||
"interrupt_suffix_decoded": interrupt_suffix_decoded,
|
||||
"prompt": prompt_str, "gen_text": gen_text, "gen_len": int(gen.shape[0]),
|
||||
"fork_points": list(fork_points),
|
||||
"pmass": pm["pmass"], "p_true": pm["p_true"],
|
||||
"argmax_str": pm["argmax_str"], "was_thinking": pm["was_thinking"],
|
||||
})
|
||||
if p_idx == 0:
|
||||
debug_first.setdefault(str(alpha), {})["eval"] = {
|
||||
"prompt": prompt_str, "gen_text": gen_text, "gen_len": int(gen.shape[0]),
|
||||
"pmass_per_fork": pm["pmass"], "p_true_per_fork": pm["p_true"],
|
||||
"argmax_per_fork": pm["argmax_str"],
|
||||
}
|
||||
logger.info(f" [debug] alpha={alpha} eval[0] gen_len={gen.shape[0]} text[:120]={gen_text[:120]!r}")
|
||||
for t, pmv, ptv, am in zip(fork_points, pm["pmass"], pm["p_true"], pm["argmax_str"]):
|
||||
logger.info(f" t={t:>3} pmass={pmv:.3f} p_true={ptv:.3f} argmax={am!r}")
|
||||
_log_fork_summary("eval", alpha, gen, gen_text, fork_points, pm)
|
||||
pmass_eval_all[str(alpha)] = pm_eval_for_alpha
|
||||
gen_lens_eval[str(alpha)] = gen_lens_eval_alpha
|
||||
# SHOULD: at alpha=1, mean(pmass at t=0) > 0.5 (model still respects schema).
|
||||
|
||||
@@ -55,6 +55,43 @@ def collect_choice_token_ids(
|
||||
return _ids(a_words), _ids(b_words)
|
||||
|
||||
|
||||
def build_chat_interrupt_suffix(
    tok,
    interrupt_user_text: str,
    assistant_prefill: str,
) -> list[int]:
    """Construct token ids for: <eot> <user>{question}<eot> <assistant>{prefill}.

    Used to interrupt an in-progress assistant rollout with a follow-up question
    in a chat-template-correct way (vs. raw token splice). The suffix is found
    by string-diffing the rendered template before vs. after appending the
    interrupt turn -- no template-specific special-token guessing. The exact
    delimiters in the result are whatever ``tok``'s chat template emits.

    SHOULD: tok.decode(suffix_ids) ends with assistant_prefill. ELSE template
    quirk: inspect tok.chat_template.

    Args:
        tok: tokenizer exposing ``apply_chat_template`` and ``encode``.
        interrupt_user_text: content of the injected user turn.
        assistant_prefill: start of the assistant reply, left open for continuation.

    Returns:
        Token ids of the rendered text that follows the placeholder
        user/assistant exchange, encoded without special tokens added.

    Raises:
        ValueError: if the rendering of the extended conversation does not
            start with the rendering of the base conversation (even after
            stripping the base's trailing whitespace), so no suffix can be
            derived by prefix-diffing.
    """
    # Placeholder exchange: only its rendered length matters, the content is
    # discarded when the common prefix is cut away below.
    base_msgs = [
        {"role": "user", "content": "_"},
        {"role": "assistant", "content": "_ROLLED_"},
    ]
    ext_msgs = base_msgs + [
        {"role": "user", "content": interrupt_user_text},
        {"role": "assistant", "content": assistant_prefill},
    ]
    base_str = tok.apply_chat_template(base_msgs, tokenize=False)
    # continue_final_message=True keeps the prefilled assistant turn open.
    ext_str = tok.apply_chat_template(ext_msgs, tokenize=False, continue_final_message=True)
    if not ext_str.startswith(base_str):
        # template inserts trailing whitespace in `base_str` -- strip and retry
        base_str = base_str.rstrip()
    if not ext_str.startswith(base_str):
        raise ValueError(
            "chat template not prefix-stable: cannot derive interrupt suffix. "
            "Disable chat_interrupt (pass interrupt_suffix_ids=None) or extend this helper."
        )
    suffix_str = ext_str[len(base_str):]
    return tok.encode(suffix_str, add_special_tokens=False)
|
||||
|
||||
|
||||
def _is_thinking(seq_ids: Tensor, think_id: int, unthink_id: int) -> bool:
|
||||
"""True iff the last `<think>` in seq_ids is after the last `</think>`."""
|
||||
if think_id is None or unthink_id is None:
|
||||
@@ -89,6 +126,9 @@ def branch_pmass(
|
||||
handle_thinking: bool = True,
|
||||
end_think_str: str = "</think>",
|
||||
force_close_str: str = "\nI should answer now.", # see tinymfv/guided.py
|
||||
interrupt_suffix_ids: Sequence[int] | None = None, # if set, overrides prefill_str + thinking
|
||||
# close logic and uses chat-template-correct interrupt
|
||||
# turn (built via build_chat_interrupt_suffix).
|
||||
device: str | torch.device = "cuda",
|
||||
) -> dict:
|
||||
"""Returns dict with parallel lists indexed by fork point:
|
||||
@@ -113,6 +153,13 @@ def branch_pmass(
|
||||
tok.encode(force_close_str + end_think_str, add_special_tokens=False),
|
||||
device=device, dtype=torch.long,
|
||||
) if handle_thinking else None
|
||||
# Chat-template interrupt mode: replaces both close_t and pre_t with one
|
||||
# template-correct turn boundary (<eot><user>...<eot><assistant>{prefill}).
|
||||
# Disables thinking-detection (irrelevant: we are starting a new turn).
|
||||
interrupt_t = (
|
||||
torch.tensor(list(interrupt_suffix_ids), device=device, dtype=torch.long)
|
||||
if interrupt_suffix_ids is not None else None
|
||||
)
|
||||
a_t = torch.tensor(list(a_ids), dtype=torch.long, device=device)
|
||||
b_t = torch.tensor(list(b_ids), dtype=torch.long, device=device)
|
||||
all_t = torch.cat([a_t, b_t])
|
||||
@@ -130,6 +177,10 @@ def branch_pmass(
|
||||
continue
|
||||
prefix = rolled[:t]
|
||||
seq_so_far = torch.cat([pids, prefix])
|
||||
if interrupt_t is not None:
|
||||
new_tail = interrupt_t
|
||||
thinking = False # not meaningful in chat-interrupt mode
|
||||
else:
|
||||
thinking = (handle_thinking and think_id is not None and unthink_id is not None
|
||||
and _is_thinking(seq_so_far, think_id, unthink_id))
|
||||
new_tail = torch.cat([close_t, pre_t]) if (thinking and close_t is not None) else pre_t
|
||||
|
||||
@@ -158,12 +158,14 @@ def measure_kl(
|
||||
base_full = torch.cat([pids.to(device), base_gen])
|
||||
decoded_base = tok.decode(base_full, skip_special_tokens=False)
|
||||
decoded_steer = tok.decode(full_ids, skip_special_tokens=False)
|
||||
# SHOULD: BASE and STEER both coherent; STEER differs from BASE but does not collapse.
|
||||
# Truncate to keep iso-KL bracket logs scannable; full text in trajectory.json/debug_first.
|
||||
head = 400
|
||||
logger.info(
|
||||
f"EXPECT: same prompt under c=0 vs c={v.cfg.coeff:+.4f}; both coherent; "
|
||||
"steered should differ from base but not collapse.\n"
|
||||
f"\n=== CALIBRATE demo trace (T={T}) ===\n"
|
||||
f"--- BASE (c=0) ---\n{decoded_base}\n"
|
||||
f"\n--- STEER (c={v.cfg.coeff:+.4f}) ---\n{decoded_steer}\n"
|
||||
f"SHOULD: c=0 vs c={v.cfg.coeff:+.4f} both coherent, steer differs but does not collapse.\n"
|
||||
f"=== CALIBRATE demo (T={T}, first {head} chars each) ===\n"
|
||||
f"-- BASE : {decoded_base[:head]!r}\n"
|
||||
f"-- STEER : {decoded_steer[:head]!r}\n"
|
||||
f"=== /CALIBRATE ==="
|
||||
)
|
||||
kls = _kl_generated_incremental(v, model, pids, gen, device)
|
||||
|
||||
+19
-1
@@ -15,7 +15,11 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from iso_kl_figure import (
|
||||
SteeringConfig, MeanDiffC, train, measure_kl, attach, detach,
|
||||
)
|
||||
from iso_kl_figure.branch_pmass import branch_pmass, collect_choice_token_ids
|
||||
from iso_kl_figure.branch_pmass import (
|
||||
branch_pmass,
|
||||
build_chat_interrupt_suffix,
|
||||
collect_choice_token_ids,
|
||||
)
|
||||
|
||||
|
||||
MODEL = "hf-internal-testing/tiny-random-LlamaForCausalLM"
|
||||
@@ -61,8 +65,22 @@ def test_pipeline_smoke():
|
||||
v.cfg.coeff = 5.0
|
||||
p_steer = branch_pmass(v, model, tok, pids, rolled, fork, prefill,
|
||||
a_ids, b_ids, device="cpu")
|
||||
tok.chat_template = (
|
||||
"{% for message in messages %}"
|
||||
"<|{{ message['role'] }}|>\n{{ message['content'] }}\n"
|
||||
"{% endfor %}"
|
||||
"{% if add_generation_prompt %}<|assistant|>\n{% endif %}"
|
||||
)
|
||||
interrupt_ids = build_chat_interrupt_suffix(tok, "Final answer?", '{"Answer": ')
|
||||
assert tok.decode(interrupt_ids).endswith('{"Answer": ')
|
||||
p_interrupt = branch_pmass(
|
||||
v, model, tok, pids, rolled, fork, prefill,
|
||||
a_ids, b_ids, interrupt_suffix_ids=interrupt_ids, device="cpu",
|
||||
)
|
||||
for x in p_zero["pmass"] + p_steer["pmass"]:
|
||||
assert 0.0 <= x <= 1.0, f"pmass out of [0,1]: {x}"
|
||||
for x in p_interrupt["pmass"]:
|
||||
assert 0.0 <= x <= 1.0, f"chat-interrupt pmass out of [0,1]: {x}"
|
||||
diff = max(abs(a - b) for a, b in zip(p_zero["pmass"], p_steer["pmass"]))
|
||||
assert diff > 1e-8, "pmass invariant to coeff -- hook is dead"
|
||||
# p_true should be in [0,1] or NaN
|
||||
|
||||
Reference in New Issue
Block a user