This commit is contained in:
wassname
2026-05-08 11:25:10 +08:00
parent 2a69829612
commit 77b296cc75
10 changed files with 4176 additions and 46 deletions
+3 -1
View File
@@ -1,4 +1,6 @@
# iso-kl-figure (short version)
# calibrating steering over long trajectories by normalising KL outliers
![alt text](figs/zoom_in.png)
## The problem
File diff suppressed because it is too large Load Diff
BIN
View File
Binary file not shown.

After

Width:  |  Height:  |  Size: 190 KiB

+1 -1
View File
@@ -11,7 +11,7 @@ test:
uv run --extra all pytest -q
# Run one (model, method, seed, window) cell end-to-end (calibrate + trajectory + pmass).
cell model="Qwen/Qwen3.5-0.8B" method="mean_diff" seed="0" window="50":
cell model="Qwen/Qwen3.5-0.8B" method="mean_diff" seed="0" window="512":
uv run --extra all python scripts/run_cell.py \
--model {{model}} --method {{method}} --seed {{seed}} --window {{window}}
+37 -16
View File
@@ -65,14 +65,19 @@ except Exception:
class Args:
runs_root: str = "outputs"
out: str = "figs"
window: int = 50 # only this window enters the figure (need long-enough traj for rolling-16)
roll: int = 16 # smoothing window for KL trajectory
window: int = 512 # only this window enters the figure
roll: int = 65 # smoothing window for KL trajectory
alphas: tuple[str, ...] = ("0.25", "0.5", "0.75", "1.0", "1.5", "2.0", "4.0")
kl_ymax: float = 6.0
model_contains: str = ""
kl_only: bool = False
spaghetti: bool = False # plot individual trajectories instead of p10..p90 band
color_by_pmass: bool = False # color KL spaghetti lines by paired pmass (requires pmass_eval)
line_alpha: float | None = None # per-line alpha override; None = auto clip(2.5/n,.08,.35)
line_lw: float = 0.18 # per-trajectory linewidth; full opacity needs very thin lines
median_lw: float = 0.75 # median linewidth
quantile_lines: bool = False # clean summary: p10/p50/p90 lines, no fill/spaghetti
mark_t: int = -1 # optional vertical token marker; -1 disables
def _rolling_mean(x: np.ndarray, w: int) -> np.ndarray:
@@ -138,20 +143,20 @@ def _draw_kl_panel(ax, K: np.ndarray, a: Args, P: np.ndarray | None = None) -> N
pts = np.column_stack([xs, traj])
segs = np.stack([pts[:-1], pts[1:]], axis=1)
lc = LineCollection(segs, cmap=cmap, norm=plt.Normalize(0, 1),
linewidths=0.7, alpha=0.55)
linewidths=a.line_lw, alpha=float(a.line_alpha if a.line_alpha is not None else np.clip(2.5/max(K.shape[0],1), 0.08, 0.35)))
lc.set_array(pmass_row[:-1])
ax.add_collection(lc)
ax.set_xlim(xs[0], xs[-1])
med = np.nanmedian(Kp, axis=0)
ax.plot(xs, med, color="k", lw=1.6)
ax.plot(xs, med, color="k", lw=a.median_lw)
else:
crossed = (K > 1.0).any(axis=1)
for traj in Kp[~crossed]:
ax.plot(xs, traj, color="0.55", lw=0.6, alpha=0.6)
ax.plot(xs, traj, color="0.55", lw=a.line_lw, alpha=0.5)
for traj in Kp[crossed]:
ax.plot(xs, traj, color="C3", lw=0.6, alpha=0.6)
ax.plot(xs, traj, color="C3", lw=a.line_lw, alpha=0.5)
med = np.nanmedian(Kp, axis=0)
ax.plot(xs, med, color="k", lw=1.6)
ax.plot(xs, med, color="k", lw=a.median_lw)
frac = float(crossed.mean())
ax.text(0.97, 0.97, f"{frac:.0%} cross KL=1",
transform=ax.transAxes, ha="right", va="top",
@@ -163,6 +168,14 @@ def _draw_kl_panel(ax, K: np.ndarray, a: Args, P: np.ndarray | None = None) -> N
p50s = _rolling_mean(p50, a.roll)
p10s = _rolling_mean(p10, a.roll)
p90s = _rolling_mean(p90, a.roll)
if a.quantile_lines:
# standard quantile fan: outer band light, inner band dark, p50 line
p25s = _rolling_mean(np.nanpercentile(K, 25, axis=0), a.roll)
p75s = _rolling_mean(np.nanpercentile(K, 75, axis=0), a.roll)
ax.fill_between(xs, p10s, p90s, alpha=0.15, color="C0", lw=0, label="p10..p90")
ax.fill_between(xs, p25s, p75s, alpha=0.32, color="C0", lw=0, label="p25..p75")
ax.plot(xs, p50s, color="C0", lw=1.5, label="p50")
else:
ax.fill_between(xs, p10s, p90s, alpha=0.25, color="C0", lw=0)
ax.plot(xs, p50s, color="C0", lw=1.6)
@@ -170,14 +183,17 @@ def _draw_kl_panel(ax, K: np.ndarray, a: Args, P: np.ndarray | None = None) -> N
def make_kl_figure(cells: list[dict], a: Args, out_path: Path) -> None:
import matplotlib.pyplot as plt
fig, axes = plt.subplots(1, len(a.alphas), figsize=(4.0 * len(a.alphas), 3.2),
fig, axes = plt.subplots(1, len(a.alphas), figsize=(3.7 * len(a.alphas), 3.4),
sharex=True, sharey=True, squeeze=False, constrained_layout=True)
label = a.model_contains or "all models"
mode = "individual trajectories (red=ever crossed KL=1)" if a.spaghetti \
n_max = max((_pool_kl(cells, alpha, T=a.window).shape[0] for alpha in a.alphas), default=0)
mode = f"individual trajectories, roll={a.roll}, color=pmass" if (a.spaghetti and a.color_by_pmass) \
else "individual trajectories (red=ever crossed KL=1)" if a.spaghetti \
else f"shaded quantiles p10/p25/p50/p75/p90, roll={a.roll}" if a.quantile_lines \
else f"p50 + p10..p90 band, smoothed rolling-{a.roll}"
fig.suptitle(
f"KL trajectory on N=8 held-out long-form prompts ({label})\n"
f"{mode}. Solid line: KL=1 nat. Dotted v-line: t=20.",
f"KL trajectory on N={n_max} held-out long-form prompts ({label})\n"
f"{mode}. Solid horizontal: KL=1 nat.",
fontsize=10,
)
@@ -202,27 +218,32 @@ def make_kl_figure(cells: list[dict], a: Args, out_path: Path) -> None:
Pe = _pool_pmass_eval(cells, alpha, _first_fork(cells), T=a.window) if a.color_by_pmass else None
_draw_kl_panel(ax, K, a, P=Pe)
if y_max >= 1.0:
ax.axhline(1.0, color="k", lw=1.0)
ax.axhline(1.0, color="k", lw=0.7)
else:
ax.text(0.98, 0.92, "KL=1 off-scale", transform=ax.transAxes,
ha="right", va="top", fontsize=8, color="0.25")
ax.axvline(20, color="k", ls=":", lw=0.8)
if a.mark_t >= 0:
ax.axvline(a.mark_t, color="k", ls=":", lw=0.6)
ax.set_title(rf"$\alpha = {alpha}$ (n={K.shape[0]} traj)")
ax.set_xlim(0, x_max)
ax.set_ylim(0, y_max)
ax.set_xlabel("token")
if j == 0:
ax.set_ylabel("KL(steered || base) [nats]")
ax.set_ylabel("KL")
if a.color_by_pmass:
from matplotlib.cm import ScalarMappable
from matplotlib.colors import LinearSegmentedColormap, Normalize
cmap = LinearSegmentedColormap.from_list("alive", ["#c0392b", "#f1c40f", "#27ae60"])
sm = ScalarMappable(norm=Normalize(0, 1), cmap=cmap); sm.set_array([])
cbar = fig.colorbar(sm, ax=axes[0, :].tolist(), location="right", shrink=0.75, pad=0.01, fraction=0.015)
cbar.set_label("pmass (0=dead, 1=alive)")
cbar = fig.colorbar(sm, ax=axes[0, :].tolist(), location="right", shrink=0.72, pad=0.015, fraction=0.012)
cbar.set_label("pmass", labelpad=2)
fig.savefig(out_path, dpi=160, bbox_inches="tight")
if a.quantile_lines:
q_path = out_path.with_name(out_path.stem + "_quantile_lines" + out_path.suffix)
fig.savefig(q_path, dpi=160, bbox_inches="tight")
logger.info(f"KL quantile-line figure -> {q_path}")
logger.info(f"KL-only figure -> {out_path}")
+11 -4
View File
@@ -16,6 +16,10 @@ class Args:
threshold: float = 0.95
out_name: str = "figs_auto"
model_contains: str = ""
line_alpha: float | None = None # forwarded to spaghetti+aggregate
roll: int = 65
line_lw: float = 0.5
quantile_lines: bool = False
def _ensure_single_run_root(run_dir: Path) -> Path:
@@ -72,18 +76,21 @@ def main(a: Args) -> None:
"--window", str(__import__("json").loads(calib.read_text())["window"]),
"--threshold", str(a.threshold),
"--model-contains", model_filter,
], repo_root)
"--roll", str(a.roll),
"--line-lw", str(a.line_lw),
] + (["--line-alpha", str(a.line_alpha)] if a.line_alpha is not None else []), repo_root)
_run([
sys.executable,
"scripts/aggregate.py",
"--runs-root", str(single_root),
"--out", str(out_dir / "aggregate"),
"--window", str(__import__("json").loads(calib.read_text())["window"]),
"--spaghetti",
"--color-by-pmass",
"--kl-only",
"--model-contains", model_filter,
], repo_root)
"--roll", str(a.roll),
"--line-lw", str(a.line_lw),
] + (["--spaghetti", "--color-by-pmass"] if not a.quantile_lines else ["--quantile-lines"]) \
+ (["--line-alpha", str(a.line_alpha)] if a.line_alpha is not None else []), repo_root)
pngs = [
out_dir / "survival" / "survival_pmass_eval.png",
+93 -12
View File
@@ -57,12 +57,13 @@ from pathlib import Path
import torch
import tyro
from loguru import logger
from tqdm.auto import tqdm
from iso_kl_figure import (
SteeringConfig, MeanDiffC, PCAC, DirectionalAblationC,
train, calibrate_iso_kl, measure_kl,
)
from iso_kl_figure.branch_pmass import branch_pmass, collect_choice_token_ids
from iso_kl_figure.branch_pmass import branch_pmass, build_chat_interrupt_suffix, collect_choice_token_ids
from iso_kl_figure.calibrate import _eos_token_ids
from iso_kl_figure.target import _get_blocks
@@ -103,6 +104,14 @@ EVAL_PROMPTS = [
"Was Shakespeare a contemporary of Queen Elizabeth I of England?",
"Is the Pacific Ocean larger in surface area than the Atlantic Ocean?",
"Does a transformer architecture rely on self-attention rather than recurrence for sequence modeling?",
"Is honey produced by bees primarily from nectar collected from flowers?",
"Did Ada Lovelace write what is often considered the first published computer algorithm?",
"Is the boiling point of pure water at sea level approximately 100 degrees Celsius?",
"Does cross-entropy loss reduce to negative log likelihood for a single correct class label?",
"Was the printing press invented by Johannes Gutenberg in the 15th century?",
"Is the speed of light in a vacuum the same for all inertial observers?",
"Do antibiotics kill viruses in addition to bacteria?",
"Is the chemical symbol for gold 'Au' on the periodic table?",
]
# pmass diagnostic prompts share the same schema as EVAL above.
@@ -144,7 +153,7 @@ class Args:
model: str
method: str
seed: int = 0
window: int = 50
window: int = 512
run_id: str = ""
layer_frac: float = 0.6
target_kl: float = 1.0
@@ -159,6 +168,13 @@ class Args:
skip_pmass_qa: bool = False # skip yes/no pmass loop, only do paired pmass_eval
render_figs: bool = False # render single-run survival + spaghetti + KL pngs
render_threshold: float = 0.95
chat_interrupt: bool = False # use chat-template-correct interrupt turn for pmass scoring
# (in-distribution; recommended over raw splice).
interrupt_user_text: str = (
"Given the conversation so far, what is your final answer?"
" Respond with JSON: {\"Answer\": true} or {\"Answer\": false}."
)
interrupt_prefill: str = '{"Answer": '
def _set_seed(s: int):
@@ -169,6 +185,28 @@ def _set_seed(s: int):
torch.cuda.manual_seed_all(s)
def _log_fork_summary(stage: str, alpha: float, gen, gen_text: str,
                      fork_points: list[int], pm: dict) -> None:
    """Summarise one rollout's fork scores: a single INFO line plus a DEBUG table.

    The INFO line reports ``gen.shape[0]``, the middle and minimum of
    ``pm["pmass"]`` (both NaN when the list is empty) and the first 80
    characters of ``gen_text``. The "median" is the element at index
    ``len // 2`` of the sorted list, i.e. the upper middle value for an
    even count. The DEBUG table has one row per fork point with the
    columns t / pmass / p_true / argmax, replacing one log line per fork.

    SHOULD: at alpha~1, pmass ~> 0.5 across forks (model still respects schema).
    ELSE: prefill broken or schema drift -- inspect debug_first in pmass.json.
    """
    from tabulate import tabulate

    pmass_vals = pm["pmass"]
    p_true_vals = pm["p_true"]

    if pmass_vals:
        lowest = min(pmass_vals)
        middle = sorted(pmass_vals)[len(pmass_vals) // 2]
    else:
        # Nothing was scored: keep the log line well-formed with NaN placeholders.
        lowest = float("nan")
        middle = float("nan")

    headline = (
        f" alpha={alpha} {stage}[0] gen_len={gen.shape[0]} pmass(med/min)="
        f"{middle:.2f}/{lowest:.2f} | text[:80]={gen_text[:80]!r}"
    )
    logger.info(headline)

    # One pre-formatted row per fork; zip stops at the shortest input.
    table_rows = []
    for fork_t, mass, p_true, top_str in zip(fork_points, pmass_vals, p_true_vals, pm["argmax_str"]):
        table_rows.append((fork_t, f"{mass:.3f}", f"{p_true:.3f}", repr(top_str)))

    rendered = tabulate(table_rows, headers=["t", "pmass", "p_true", "argmax"], tablefmt="plain")
    logger.debug(f"\nfork table [{stage} alpha={alpha}]\n{rendered}")
def _render_chat(tok, msgs: list[dict[str, str]], add_generation_prompt: bool) -> str:
if tok.chat_template is None:
rendered = []
@@ -279,6 +317,15 @@ def main(a: Args):
gen_lens_eval: dict[str, list] = {} # T per (alpha, eval-prompt)
debug_first: dict[str, dict] = {} # first prompt per alpha: gen_text + top-5 at each fork
# Append-only JSONL of every generation: full text + paired pmass.
# Crash-safe: each line flushed before next prompt. Overwritten only at run start.
gens_path = out_dir / "gens.jsonl"
if gens_path.exists():
gens_path.unlink()
def _append_gen(rec: dict) -> None:
with open(gens_path, "a") as f:
f.write(json.dumps(rec) + "\n")
def write_eval_outputs() -> None:
(out_dir / "trajectory.json").write_text(json.dumps({
"fork_points_full": list(range(a.window)),
@@ -317,7 +364,18 @@ def main(a: Args):
tok(s, return_tensors="pt", add_special_tokens=False).input_ids[0]
for s in eval_prompt_strs
]
for alpha in a.alphas:
interrupt_suffix_ids = None
interrupt_suffix_decoded = ""
if a.chat_interrupt:
interrupt_suffix_ids = build_chat_interrupt_suffix(
tok, a.interrupt_user_text, a.interrupt_prefill,
)
# SHOULD: decoded suffix ends with prefill literal. ELSE template quirk.
interrupt_suffix_decoded = tok.decode(interrupt_suffix_ids)
logger.info(f"chat_interrupt suffix ({len(interrupt_suffix_ids)} tok): {interrupt_suffix_decoded!r}")
if not interrupt_suffix_decoded.endswith(a.interrupt_prefill):
logger.warning(f"interrupt suffix does not end with prefill={a.interrupt_prefill!r}")
for alpha in tqdm(a.alphas, desc=f"alphas[{a.run_id}]", mininterval=60):
v.cfg.coeff = alpha * c_star
logger.info(f"=== eval alpha={alpha} c={v.cfg.coeff:+.4f} ===")
m = measure_kl(v, model, tok, eval_ids_list, T=a.window, device=a.device)
@@ -351,6 +409,7 @@ def main(a: Args):
v, model, tok, ids, gen, fork_points,
PREFILL_STR, a_ids, b_ids,
rollout_cache=gen_out.past_key_values,
interrupt_suffix_ids=interrupt_suffix_ids,
device=a.device,
)
pm_for_alpha.append(pm["pmass"])
@@ -363,17 +422,28 @@ def main(a: Args):
else "other"
for s in pm["argmax_str"]
])
# debug dump for first QA prompt only
if q_idx == 0:
# save full gen text + paired pmass for every QA prompt
gen_text = tok.decode(gen, skip_special_tokens=False)
_append_gen({
"alpha": float(alpha), "kind": "qa", "prompt_idx": q_idx,
"score_mode": "chat_interrupt" if a.chat_interrupt else "raw_splice",
"interrupt_user_text": a.interrupt_user_text if a.chat_interrupt else "",
"interrupt_prefill": a.interrupt_prefill if a.chat_interrupt else "",
"interrupt_suffix_ids": list(interrupt_suffix_ids) if interrupt_suffix_ids is not None else [],
"interrupt_suffix_decoded": interrupt_suffix_decoded,
"prompt": prompt_str, "gen_text": gen_text, "gen_len": int(gen.shape[0]),
"fork_points": list(fork_points),
"pmass": pm["pmass"], "p_true": pm["p_true"],
"argmax_str": pm["argmax_str"], "was_thinking": pm["was_thinking"],
})
# debug dump for first QA prompt only (kept for backward compat)
if q_idx == 0:
debug_first.setdefault(str(alpha), {})["qa"] = {
"prompt": prompt_str, "gen_text": gen_text, "gen_len": int(gen.shape[0]),
"pmass_per_fork": pm["pmass"], "p_true_per_fork": pm["p_true"],
"argmax_per_fork": pm["argmax_str"],
}
logger.info(f" [debug] alpha={alpha} qa[0] gen_len={gen.shape[0]} text[:120]={gen_text[:120]!r}")
for t, pmv, ptv, am in zip(fork_points, pm["pmass"], pm["p_true"], pm["argmax_str"]):
logger.info(f" t={t:>3} pmass={pmv:.3f} p_true={ptv:.3f} argmax={am!r}")
_log_fork_summary("qa", alpha, gen, gen_text, fork_points, pm)
pmass_all[str(alpha)] = pm_for_alpha
p_true_all[str(alpha)] = pt_for_alpha
argmax_all[str(alpha)] = ax_for_alpha
@@ -407,19 +477,30 @@ def main(a: Args):
v, model, tok, ids, gen, fork_points,
PREFILL_STR, a_ids, b_ids,
rollout_cache=gen_out.past_key_values,
interrupt_suffix_ids=interrupt_suffix_ids,
device=a.device,
)
pm_eval_for_alpha.append(pm["pmass"])
if p_idx == 0:
gen_text = tok.decode(gen, skip_special_tokens=False)
_append_gen({
"alpha": float(alpha), "kind": "eval", "prompt_idx": p_idx,
"score_mode": "chat_interrupt" if a.chat_interrupt else "raw_splice",
"interrupt_user_text": a.interrupt_user_text if a.chat_interrupt else "",
"interrupt_prefill": a.interrupt_prefill if a.chat_interrupt else "",
"interrupt_suffix_ids": list(interrupt_suffix_ids) if interrupt_suffix_ids is not None else [],
"interrupt_suffix_decoded": interrupt_suffix_decoded,
"prompt": prompt_str, "gen_text": gen_text, "gen_len": int(gen.shape[0]),
"fork_points": list(fork_points),
"pmass": pm["pmass"], "p_true": pm["p_true"],
"argmax_str": pm["argmax_str"], "was_thinking": pm["was_thinking"],
})
if p_idx == 0:
debug_first.setdefault(str(alpha), {})["eval"] = {
"prompt": prompt_str, "gen_text": gen_text, "gen_len": int(gen.shape[0]),
"pmass_per_fork": pm["pmass"], "p_true_per_fork": pm["p_true"],
"argmax_per_fork": pm["argmax_str"],
}
logger.info(f" [debug] alpha={alpha} eval[0] gen_len={gen.shape[0]} text[:120]={gen_text[:120]!r}")
for t, pmv, ptv, am in zip(fork_points, pm["pmass"], pm["p_true"], pm["argmax_str"]):
logger.info(f" t={t:>3} pmass={pmv:.3f} p_true={ptv:.3f} argmax={am!r}")
_log_fork_summary("eval", alpha, gen, gen_text, fork_points, pm)
pmass_eval_all[str(alpha)] = pm_eval_for_alpha
gen_lens_eval[str(alpha)] = gen_lens_eval_alpha
# SHOULD: at alpha=1, mean(pmass at t=0) > 0.5 (model still respects schema).
+51
View File
@@ -55,6 +55,43 @@ def collect_choice_token_ids(
return _ids(a_words), _ids(b_words)
def build_chat_interrupt_suffix(
    tok,
    interrupt_user_text: str,
    assistant_prefill: str,
) -> list[int]:
    """Construct token ids for: <eot> <user>{question}<eot> <assistant>{prefill}.

    Used to interrupt an in-progress assistant rollout with a follow-up question
    in a chat-template-correct way (vs. raw token splice). The suffix is found
    by string-diffing two renders of the chat template: a placeholder
    user/assistant exchange, and the same exchange extended with the interrupt
    user turn plus a partial assistant turn (``continue_final_message=True``).
    Whatever text the extended render adds after the base render is the suffix,
    so no template-specific special tokens are guessed.

    Args:
        tok: tokenizer exposing ``apply_chat_template`` and ``encode``.
        interrupt_user_text: content of the injected user turn.
        assistant_prefill: text the new assistant turn starts with.

    Returns:
        Token ids of the suffix, encoded with ``add_special_tokens=False``.

    Raises:
        ValueError: if the extended render does not start with the base render,
            even after stripping trailing whitespace from the base render.

    SHOULD: tok.decode(suffix_ids) ends with assistant_prefill. ELSE template
    quirk: inspect tok.chat_template.
    """
    base_msgs = [
        {"role": "user", "content": "_"},
        {"role": "assistant", "content": "_ROLLED_"},
    ]
    ext_msgs = base_msgs + [
        {"role": "user", "content": interrupt_user_text},
        {"role": "assistant", "content": assistant_prefill},
    ]
    base_str = tok.apply_chat_template(base_msgs, tokenize=False)
    ext_str = tok.apply_chat_template(ext_msgs, tokenize=False, continue_final_message=True)

    if not ext_str.startswith(base_str):
        # Some templates end the base render with whitespace that the extended
        # render does not reproduce at the same position; strip and retry.
        base_str = base_str.rstrip()
    if not ext_str.startswith(base_str):
        raise ValueError(
            "chat template not prefix-stable: cannot derive interrupt suffix. "
            "Disable chat_interrupt (pass interrupt_suffix_ids=None to branch_pmass) "
            "or extend this helper."
        )

    suffix_str = ext_str[len(base_str):]
    return tok.encode(suffix_str, add_special_tokens=False)
def _is_thinking(seq_ids: Tensor, think_id: int, unthink_id: int) -> bool:
"""True iff the last `<think>` in seq_ids is after the last `</think>`."""
if think_id is None or unthink_id is None:
@@ -89,6 +126,9 @@ def branch_pmass(
handle_thinking: bool = True,
end_think_str: str = "</think>",
force_close_str: str = "\nI should answer now.", # see tinymfv/guided.py
interrupt_suffix_ids: Sequence[int] | None = None, # if set, overrides prefill_str + thinking
# close logic and uses chat-template-correct interrupt
# turn (built via build_chat_interrupt_suffix).
device: str | torch.device = "cuda",
) -> dict:
"""Returns dict with parallel lists indexed by fork point:
@@ -113,6 +153,13 @@ def branch_pmass(
tok.encode(force_close_str + end_think_str, add_special_tokens=False),
device=device, dtype=torch.long,
) if handle_thinking else None
# Chat-template interrupt mode: replaces both close_t and pre_t with one
# template-correct turn boundary (<eot><user>...<eot><assistant>{prefill}).
# Disables thinking-detection (irrelevant: we are starting a new turn).
interrupt_t = (
torch.tensor(list(interrupt_suffix_ids), device=device, dtype=torch.long)
if interrupt_suffix_ids is not None else None
)
a_t = torch.tensor(list(a_ids), dtype=torch.long, device=device)
b_t = torch.tensor(list(b_ids), dtype=torch.long, device=device)
all_t = torch.cat([a_t, b_t])
@@ -130,6 +177,10 @@ def branch_pmass(
continue
prefix = rolled[:t]
seq_so_far = torch.cat([pids, prefix])
if interrupt_t is not None:
new_tail = interrupt_t
thinking = False # not meaningful in chat-interrupt mode
else:
thinking = (handle_thinking and think_id is not None and unthink_id is not None
and _is_thinking(seq_so_far, think_id, unthink_id))
new_tail = torch.cat([close_t, pre_t]) if (thinking and close_t is not None) else pre_t
+7 -5
View File
@@ -158,12 +158,14 @@ def measure_kl(
base_full = torch.cat([pids.to(device), base_gen])
decoded_base = tok.decode(base_full, skip_special_tokens=False)
decoded_steer = tok.decode(full_ids, skip_special_tokens=False)
# SHOULD: BASE and STEER both coherent; STEER differs from BASE but does not collapse.
# Truncate to keep iso-KL bracket logs scannable; full text in trajectory.json/debug_first.
head = 400
logger.info(
f"EXPECT: same prompt under c=0 vs c={v.cfg.coeff:+.4f}; both coherent; "
"steered should differ from base but not collapse.\n"
f"\n=== CALIBRATE demo trace (T={T}) ===\n"
f"--- BASE (c=0) ---\n{decoded_base}\n"
f"\n--- STEER (c={v.cfg.coeff:+.4f}) ---\n{decoded_steer}\n"
f"SHOULD: c=0 vs c={v.cfg.coeff:+.4f} both coherent, steer differs but does not collapse.\n"
f"=== CALIBRATE demo (T={T}, first {head} chars each) ===\n"
f"-- BASE : {decoded_base[:head]!r}\n"
f"-- STEER : {decoded_steer[:head]!r}\n"
f"=== /CALIBRATE ==="
)
kls = _kl_generated_incremental(v, model, pids, gen, device)
+19 -1
View File
@@ -15,7 +15,11 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
from iso_kl_figure import (
SteeringConfig, MeanDiffC, train, measure_kl, attach, detach,
)
from iso_kl_figure.branch_pmass import branch_pmass, collect_choice_token_ids
from iso_kl_figure.branch_pmass import (
branch_pmass,
build_chat_interrupt_suffix,
collect_choice_token_ids,
)
MODEL = "hf-internal-testing/tiny-random-LlamaForCausalLM"
@@ -61,8 +65,22 @@ def test_pipeline_smoke():
v.cfg.coeff = 5.0
p_steer = branch_pmass(v, model, tok, pids, rolled, fork, prefill,
a_ids, b_ids, device="cpu")
tok.chat_template = (
"{% for message in messages %}"
"<|{{ message['role'] }}|>\n{{ message['content'] }}\n"
"{% endfor %}"
"{% if add_generation_prompt %}<|assistant|>\n{% endif %}"
)
interrupt_ids = build_chat_interrupt_suffix(tok, "Final answer?", '{"Answer": ')
assert tok.decode(interrupt_ids).endswith('{"Answer": ')
p_interrupt = branch_pmass(
v, model, tok, pids, rolled, fork, prefill,
a_ids, b_ids, interrupt_suffix_ids=interrupt_ids, device="cpu",
)
for x in p_zero["pmass"] + p_steer["pmass"]:
assert 0.0 <= x <= 1.0, f"pmass out of [0,1]: {x}"
for x in p_interrupt["pmass"]:
assert 0.0 <= x <= 1.0, f"chat-interrupt pmass out of [0,1]: {x}"
diff = max(abs(a - b) for a, b in zip(p_zero["pmass"], p_steer["pmass"]))
assert diff > 1e-8, "pmass invariant to coeff -- hook is dead"
# p_true should be in [0,1] or NaN