mirror of
https://github.com/wassname/isokl_steering_calibration.git
synced 2026-06-27 15:15:52 +08:00
write up
This commit is contained in:
@@ -1,4 +1,6 @@
|
|||||||
# iso-kl-figure (short version)
|
# calibrating steering over long trajectories by normalising KL outliers
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
## The problem
|
## The problem
|
||||||
|
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
Binary file not shown.
|
After Width: | Height: | Size: 190 KiB |
@@ -11,7 +11,7 @@ test:
|
|||||||
uv run --extra all pytest -q
|
uv run --extra all pytest -q
|
||||||
|
|
||||||
# Run one (model, method, seed, window) cell end-to-end (calibrate + trajectory + pmass).
|
# Run one (model, method, seed, window) cell end-to-end (calibrate + trajectory + pmass).
|
||||||
cell model="Qwen/Qwen3.5-0.8B" method="mean_diff" seed="0" window="50":
|
cell model="Qwen/Qwen3.5-0.8B" method="mean_diff" seed="0" window="512":
|
||||||
uv run --extra all python scripts/run_cell.py \
|
uv run --extra all python scripts/run_cell.py \
|
||||||
--model {{model}} --method {{method}} --seed {{seed}} --window {{window}}
|
--model {{model}} --method {{method}} --seed {{seed}} --window {{window}}
|
||||||
|
|
||||||
|
|||||||
+40
-19
@@ -65,14 +65,19 @@ except Exception:
|
|||||||
class Args:
|
class Args:
|
||||||
runs_root: str = "outputs"
|
runs_root: str = "outputs"
|
||||||
out: str = "figs"
|
out: str = "figs"
|
||||||
window: int = 50 # only this window enters the figure (need long-enough traj for rolling-16)
|
window: int = 512 # only this window enters the figure
|
||||||
roll: int = 16 # smoothing window for KL trajectory
|
roll: int = 65 # smoothing window for KL trajectory
|
||||||
alphas: tuple[str, ...] = ("0.25", "0.5", "0.75", "1.0", "1.5", "2.0", "4.0")
|
alphas: tuple[str, ...] = ("0.25", "0.5", "0.75", "1.0", "1.5", "2.0", "4.0")
|
||||||
kl_ymax: float = 6.0
|
kl_ymax: float = 6.0
|
||||||
model_contains: str = ""
|
model_contains: str = ""
|
||||||
kl_only: bool = False
|
kl_only: bool = False
|
||||||
spaghetti: bool = False # plot individual trajectories instead of p10..p90 band
|
spaghetti: bool = False # plot individual trajectories instead of p10..p90 band
|
||||||
color_by_pmass: bool = False # color KL spaghetti lines by paired pmass (requires pmass_eval)
|
color_by_pmass: bool = False # color KL spaghetti lines by paired pmass (requires pmass_eval)
|
||||||
|
line_alpha: float | None = None # per-line alpha override; None = auto clip(2.5/n,.08,.35)
|
||||||
|
line_lw: float = 0.18 # per-trajectory linewidth; full opacity needs very thin lines
|
||||||
|
median_lw: float = 0.75 # median linewidth
|
||||||
|
quantile_lines: bool = False # clean summary: p10/p50/p90 lines, no fill/spaghetti
|
||||||
|
mark_t: int = -1 # optional vertical token marker; -1 disables
|
||||||
|
|
||||||
|
|
||||||
def _rolling_mean(x: np.ndarray, w: int) -> np.ndarray:
|
def _rolling_mean(x: np.ndarray, w: int) -> np.ndarray:
|
||||||
@@ -138,20 +143,20 @@ def _draw_kl_panel(ax, K: np.ndarray, a: Args, P: np.ndarray | None = None) -> N
|
|||||||
pts = np.column_stack([xs, traj])
|
pts = np.column_stack([xs, traj])
|
||||||
segs = np.stack([pts[:-1], pts[1:]], axis=1)
|
segs = np.stack([pts[:-1], pts[1:]], axis=1)
|
||||||
lc = LineCollection(segs, cmap=cmap, norm=plt.Normalize(0, 1),
|
lc = LineCollection(segs, cmap=cmap, norm=plt.Normalize(0, 1),
|
||||||
linewidths=0.7, alpha=0.55)
|
linewidths=a.line_lw, alpha=float(a.line_alpha if a.line_alpha is not None else np.clip(2.5/max(K.shape[0],1), 0.08, 0.35)))
|
||||||
lc.set_array(pmass_row[:-1])
|
lc.set_array(pmass_row[:-1])
|
||||||
ax.add_collection(lc)
|
ax.add_collection(lc)
|
||||||
ax.set_xlim(xs[0], xs[-1])
|
ax.set_xlim(xs[0], xs[-1])
|
||||||
med = np.nanmedian(Kp, axis=0)
|
med = np.nanmedian(Kp, axis=0)
|
||||||
ax.plot(xs, med, color="k", lw=1.6)
|
ax.plot(xs, med, color="k", lw=a.median_lw)
|
||||||
else:
|
else:
|
||||||
crossed = (K > 1.0).any(axis=1)
|
crossed = (K > 1.0).any(axis=1)
|
||||||
for traj in Kp[~crossed]:
|
for traj in Kp[~crossed]:
|
||||||
ax.plot(xs, traj, color="0.55", lw=0.6, alpha=0.6)
|
ax.plot(xs, traj, color="0.55", lw=a.line_lw, alpha=0.5)
|
||||||
for traj in Kp[crossed]:
|
for traj in Kp[crossed]:
|
||||||
ax.plot(xs, traj, color="C3", lw=0.6, alpha=0.6)
|
ax.plot(xs, traj, color="C3", lw=a.line_lw, alpha=0.5)
|
||||||
med = np.nanmedian(Kp, axis=0)
|
med = np.nanmedian(Kp, axis=0)
|
||||||
ax.plot(xs, med, color="k", lw=1.6)
|
ax.plot(xs, med, color="k", lw=a.median_lw)
|
||||||
frac = float(crossed.mean())
|
frac = float(crossed.mean())
|
||||||
ax.text(0.97, 0.97, f"{frac:.0%} cross KL=1",
|
ax.text(0.97, 0.97, f"{frac:.0%} cross KL=1",
|
||||||
transform=ax.transAxes, ha="right", va="top",
|
transform=ax.transAxes, ha="right", va="top",
|
||||||
@@ -163,21 +168,32 @@ def _draw_kl_panel(ax, K: np.ndarray, a: Args, P: np.ndarray | None = None) -> N
|
|||||||
p50s = _rolling_mean(p50, a.roll)
|
p50s = _rolling_mean(p50, a.roll)
|
||||||
p10s = _rolling_mean(p10, a.roll)
|
p10s = _rolling_mean(p10, a.roll)
|
||||||
p90s = _rolling_mean(p90, a.roll)
|
p90s = _rolling_mean(p90, a.roll)
|
||||||
ax.fill_between(xs, p10s, p90s, alpha=0.25, color="C0", lw=0)
|
if a.quantile_lines:
|
||||||
ax.plot(xs, p50s, color="C0", lw=1.6)
|
# standard quantile fan: outer band light, inner band dark, p50 line
|
||||||
|
p25s = _rolling_mean(np.nanpercentile(K, 25, axis=0), a.roll)
|
||||||
|
p75s = _rolling_mean(np.nanpercentile(K, 75, axis=0), a.roll)
|
||||||
|
ax.fill_between(xs, p10s, p90s, alpha=0.15, color="C0", lw=0, label="p10..p90")
|
||||||
|
ax.fill_between(xs, p25s, p75s, alpha=0.32, color="C0", lw=0, label="p25..p75")
|
||||||
|
ax.plot(xs, p50s, color="C0", lw=1.5, label="p50")
|
||||||
|
else:
|
||||||
|
ax.fill_between(xs, p10s, p90s, alpha=0.25, color="C0", lw=0)
|
||||||
|
ax.plot(xs, p50s, color="C0", lw=1.6)
|
||||||
|
|
||||||
|
|
||||||
def make_kl_figure(cells: list[dict], a: Args, out_path: Path) -> None:
|
def make_kl_figure(cells: list[dict], a: Args, out_path: Path) -> None:
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
fig, axes = plt.subplots(1, len(a.alphas), figsize=(4.0 * len(a.alphas), 3.2),
|
fig, axes = plt.subplots(1, len(a.alphas), figsize=(3.7 * len(a.alphas), 3.4),
|
||||||
sharex=True, sharey=True, squeeze=False, constrained_layout=True)
|
sharex=True, sharey=True, squeeze=False, constrained_layout=True)
|
||||||
label = a.model_contains or "all models"
|
label = a.model_contains or "all models"
|
||||||
mode = "individual trajectories (red=ever crossed KL=1)" if a.spaghetti \
|
n_max = max((_pool_kl(cells, alpha, T=a.window).shape[0] for alpha in a.alphas), default=0)
|
||||||
else f"p50 + p10..p90 band, smoothed rolling-{a.roll}"
|
mode = f"individual trajectories, roll={a.roll}, color=pmass" if (a.spaghetti and a.color_by_pmass) \
|
||||||
|
else "individual trajectories (red=ever crossed KL=1)" if a.spaghetti \
|
||||||
|
else f"shaded quantiles p10/p25/p50/p75/p90, roll={a.roll}" if a.quantile_lines \
|
||||||
|
else f"p50 + p10..p90 band, smoothed rolling-{a.roll}"
|
||||||
fig.suptitle(
|
fig.suptitle(
|
||||||
f"KL trajectory on N=8 held-out long-form prompts ({label})\n"
|
f"KL trajectory on N={n_max} held-out long-form prompts ({label})\n"
|
||||||
f"{mode}. Solid line: KL=1 nat. Dotted v-line: t=20.",
|
f"{mode}. Solid horizontal: KL=1 nat.",
|
||||||
fontsize=10,
|
fontsize=10,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -202,27 +218,32 @@ def make_kl_figure(cells: list[dict], a: Args, out_path: Path) -> None:
|
|||||||
Pe = _pool_pmass_eval(cells, alpha, _first_fork(cells), T=a.window) if a.color_by_pmass else None
|
Pe = _pool_pmass_eval(cells, alpha, _first_fork(cells), T=a.window) if a.color_by_pmass else None
|
||||||
_draw_kl_panel(ax, K, a, P=Pe)
|
_draw_kl_panel(ax, K, a, P=Pe)
|
||||||
if y_max >= 1.0:
|
if y_max >= 1.0:
|
||||||
ax.axhline(1.0, color="k", lw=1.0)
|
ax.axhline(1.0, color="k", lw=0.7)
|
||||||
else:
|
else:
|
||||||
ax.text(0.98, 0.92, "KL=1 off-scale", transform=ax.transAxes,
|
ax.text(0.98, 0.92, "KL=1 off-scale", transform=ax.transAxes,
|
||||||
ha="right", va="top", fontsize=8, color="0.25")
|
ha="right", va="top", fontsize=8, color="0.25")
|
||||||
ax.axvline(20, color="k", ls=":", lw=0.8)
|
if a.mark_t >= 0:
|
||||||
|
ax.axvline(a.mark_t, color="k", ls=":", lw=0.6)
|
||||||
ax.set_title(rf"$\alpha = {alpha}$ (n={K.shape[0]} traj)")
|
ax.set_title(rf"$\alpha = {alpha}$ (n={K.shape[0]} traj)")
|
||||||
ax.set_xlim(0, x_max)
|
ax.set_xlim(0, x_max)
|
||||||
ax.set_ylim(0, y_max)
|
ax.set_ylim(0, y_max)
|
||||||
ax.set_xlabel("token")
|
ax.set_xlabel("token")
|
||||||
if j == 0:
|
if j == 0:
|
||||||
ax.set_ylabel("KL(steered || base) [nats]")
|
ax.set_ylabel("KL")
|
||||||
|
|
||||||
if a.color_by_pmass:
|
if a.color_by_pmass:
|
||||||
from matplotlib.cm import ScalarMappable
|
from matplotlib.cm import ScalarMappable
|
||||||
from matplotlib.colors import LinearSegmentedColormap, Normalize
|
from matplotlib.colors import LinearSegmentedColormap, Normalize
|
||||||
cmap = LinearSegmentedColormap.from_list("alive", ["#c0392b", "#f1c40f", "#27ae60"])
|
cmap = LinearSegmentedColormap.from_list("alive", ["#c0392b", "#f1c40f", "#27ae60"])
|
||||||
sm = ScalarMappable(norm=Normalize(0, 1), cmap=cmap); sm.set_array([])
|
sm = ScalarMappable(norm=Normalize(0, 1), cmap=cmap); sm.set_array([])
|
||||||
cbar = fig.colorbar(sm, ax=axes[0, :].tolist(), location="right", shrink=0.75, pad=0.01, fraction=0.015)
|
cbar = fig.colorbar(sm, ax=axes[0, :].tolist(), location="right", shrink=0.72, pad=0.015, fraction=0.012)
|
||||||
cbar.set_label("pmass (0=dead, 1=alive)")
|
cbar.set_label("pmass", labelpad=2)
|
||||||
|
|
||||||
fig.savefig(out_path, dpi=160, bbox_inches="tight")
|
fig.savefig(out_path, dpi=160, bbox_inches="tight")
|
||||||
|
if a.quantile_lines:
|
||||||
|
q_path = out_path.with_name(out_path.stem + "_quantile_lines" + out_path.suffix)
|
||||||
|
fig.savefig(q_path, dpi=160, bbox_inches="tight")
|
||||||
|
logger.info(f"KL quantile-line figure -> {q_path}")
|
||||||
logger.info(f"KL-only figure -> {out_path}")
|
logger.info(f"KL-only figure -> {out_path}")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -16,6 +16,10 @@ class Args:
|
|||||||
threshold: float = 0.95
|
threshold: float = 0.95
|
||||||
out_name: str = "figs_auto"
|
out_name: str = "figs_auto"
|
||||||
model_contains: str = ""
|
model_contains: str = ""
|
||||||
|
line_alpha: float | None = None # forwarded to spaghetti+aggregate
|
||||||
|
roll: int = 65
|
||||||
|
line_lw: float = 0.5
|
||||||
|
quantile_lines: bool = False
|
||||||
|
|
||||||
|
|
||||||
def _ensure_single_run_root(run_dir: Path) -> Path:
|
def _ensure_single_run_root(run_dir: Path) -> Path:
|
||||||
@@ -72,18 +76,21 @@ def main(a: Args) -> None:
|
|||||||
"--window", str(__import__("json").loads(calib.read_text())["window"]),
|
"--window", str(__import__("json").loads(calib.read_text())["window"]),
|
||||||
"--threshold", str(a.threshold),
|
"--threshold", str(a.threshold),
|
||||||
"--model-contains", model_filter,
|
"--model-contains", model_filter,
|
||||||
], repo_root)
|
"--roll", str(a.roll),
|
||||||
|
"--line-lw", str(a.line_lw),
|
||||||
|
] + (["--line-alpha", str(a.line_alpha)] if a.line_alpha is not None else []), repo_root)
|
||||||
_run([
|
_run([
|
||||||
sys.executable,
|
sys.executable,
|
||||||
"scripts/aggregate.py",
|
"scripts/aggregate.py",
|
||||||
"--runs-root", str(single_root),
|
"--runs-root", str(single_root),
|
||||||
"--out", str(out_dir / "aggregate"),
|
"--out", str(out_dir / "aggregate"),
|
||||||
"--window", str(__import__("json").loads(calib.read_text())["window"]),
|
"--window", str(__import__("json").loads(calib.read_text())["window"]),
|
||||||
"--spaghetti",
|
|
||||||
"--color-by-pmass",
|
|
||||||
"--kl-only",
|
"--kl-only",
|
||||||
"--model-contains", model_filter,
|
"--model-contains", model_filter,
|
||||||
], repo_root)
|
"--roll", str(a.roll),
|
||||||
|
"--line-lw", str(a.line_lw),
|
||||||
|
] + (["--spaghetti", "--color-by-pmass"] if not a.quantile_lines else ["--quantile-lines"]) \
|
||||||
|
+ (["--line-alpha", str(a.line_alpha)] if a.line_alpha is not None else []), repo_root)
|
||||||
|
|
||||||
pngs = [
|
pngs = [
|
||||||
out_dir / "survival" / "survival_pmass_eval.png",
|
out_dir / "survival" / "survival_pmass_eval.png",
|
||||||
|
|||||||
+93
-12
@@ -57,12 +57,13 @@ from pathlib import Path
|
|||||||
import torch
|
import torch
|
||||||
import tyro
|
import tyro
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
from tqdm.auto import tqdm
|
||||||
|
|
||||||
from iso_kl_figure import (
|
from iso_kl_figure import (
|
||||||
SteeringConfig, MeanDiffC, PCAC, DirectionalAblationC,
|
SteeringConfig, MeanDiffC, PCAC, DirectionalAblationC,
|
||||||
train, calibrate_iso_kl, measure_kl,
|
train, calibrate_iso_kl, measure_kl,
|
||||||
)
|
)
|
||||||
from iso_kl_figure.branch_pmass import branch_pmass, collect_choice_token_ids
|
from iso_kl_figure.branch_pmass import branch_pmass, build_chat_interrupt_suffix, collect_choice_token_ids
|
||||||
from iso_kl_figure.calibrate import _eos_token_ids
|
from iso_kl_figure.calibrate import _eos_token_ids
|
||||||
from iso_kl_figure.target import _get_blocks
|
from iso_kl_figure.target import _get_blocks
|
||||||
|
|
||||||
@@ -103,6 +104,14 @@ EVAL_PROMPTS = [
|
|||||||
"Was Shakespeare a contemporary of Queen Elizabeth I of England?",
|
"Was Shakespeare a contemporary of Queen Elizabeth I of England?",
|
||||||
"Is the Pacific Ocean larger in surface area than the Atlantic Ocean?",
|
"Is the Pacific Ocean larger in surface area than the Atlantic Ocean?",
|
||||||
"Does a transformer architecture rely on self-attention rather than recurrence for sequence modeling?",
|
"Does a transformer architecture rely on self-attention rather than recurrence for sequence modeling?",
|
||||||
|
"Is honey produced by bees primarily from nectar collected from flowers?",
|
||||||
|
"Did Ada Lovelace write what is often considered the first published computer algorithm?",
|
||||||
|
"Is the boiling point of pure water at sea level approximately 100 degrees Celsius?",
|
||||||
|
"Does cross-entropy loss reduce to negative log likelihood for a single correct class label?",
|
||||||
|
"Was the printing press invented by Johannes Gutenberg in the 15th century?",
|
||||||
|
"Is the speed of light in a vacuum the same for all inertial observers?",
|
||||||
|
"Do antibiotics kill viruses in addition to bacteria?",
|
||||||
|
"Is the chemical symbol for gold 'Au' on the periodic table?",
|
||||||
]
|
]
|
||||||
|
|
||||||
# pmass diagnostic prompts share the same schema as EVAL above.
|
# pmass diagnostic prompts share the same schema as EVAL above.
|
||||||
@@ -144,7 +153,7 @@ class Args:
|
|||||||
model: str
|
model: str
|
||||||
method: str
|
method: str
|
||||||
seed: int = 0
|
seed: int = 0
|
||||||
window: int = 50
|
window: int = 512
|
||||||
run_id: str = ""
|
run_id: str = ""
|
||||||
layer_frac: float = 0.6
|
layer_frac: float = 0.6
|
||||||
target_kl: float = 1.0
|
target_kl: float = 1.0
|
||||||
@@ -159,6 +168,13 @@ class Args:
|
|||||||
skip_pmass_qa: bool = False # skip yes/no pmass loop, only do paired pmass_eval
|
skip_pmass_qa: bool = False # skip yes/no pmass loop, only do paired pmass_eval
|
||||||
render_figs: bool = False # render single-run survival + spaghetti + KL pngs
|
render_figs: bool = False # render single-run survival + spaghetti + KL pngs
|
||||||
render_threshold: float = 0.95
|
render_threshold: float = 0.95
|
||||||
|
chat_interrupt: bool = False # use chat-template-correct interrupt turn for pmass scoring
|
||||||
|
# (in-distribution; recommended over raw splice).
|
||||||
|
interrupt_user_text: str = (
|
||||||
|
"Given the conversation so far, what is your final answer?"
|
||||||
|
" Respond with JSON: {\"Answer\": true} or {\"Answer\": false}."
|
||||||
|
)
|
||||||
|
interrupt_prefill: str = '{"Answer": '
|
||||||
|
|
||||||
|
|
||||||
def _set_seed(s: int):
|
def _set_seed(s: int):
|
||||||
@@ -169,6 +185,28 @@ def _set_seed(s: int):
|
|||||||
torch.cuda.manual_seed_all(s)
|
torch.cuda.manual_seed_all(s)
|
||||||
|
|
||||||
|
|
||||||
|
def _log_fork_summary(stage: str, alpha: float, gen, gen_text: str,
                      fork_points: list[int], pm: dict) -> None:
    """Emit a compact fork report for one (stage, alpha) pair.

    Writes a single INFO line (generation length, median/min pmass, first 80
    characters of the generated text) followed by a single DEBUG message that
    holds a plain-text table with one row per fork point, instead of one log
    line per fork.

    SHOULD: at alpha~1, pmass ~> 0.5 across forks (model still respects schema).
    ELSE: prefill broken or schema drift -- inspect debug_first in pmass.json.
    """
    from tabulate import tabulate

    pmass_vals = pm["pmass"]
    p_true_vals = pm["p_true"]
    if pmass_vals:
        lowest = min(pmass_vals)
        # Element at index len // 2 of the sorted values: for an even count
        # this is the upper of the two middle values, not their average.
        middle = sorted(pmass_vals)[len(pmass_vals) // 2]
    else:
        # Nothing to summarise; report NaN for both statistics.
        lowest = float("nan")
        middle = float("nan")

    logger.info(
        f" alpha={alpha} {stage}[0] gen_len={gen.shape[0]} pmass(med/min)="
        f"{middle:.2f}/{lowest:.2f} | text[:80]={gen_text[:80]!r}"
    )

    # One row per fork; zip stops at the shortest of the parallel lists.
    rows = []
    for t, p, q, s in zip(fork_points, pmass_vals, p_true_vals, pm["argmax_str"]):
        rows.append((t, f"{p:.3f}", f"{q:.3f}", repr(s)))
    table = tabulate(rows, headers=["t", "pmass", "p_true", "argmax"], tablefmt="plain")
    logger.debug(f"\nfork table [{stage} alpha={alpha}]\n{table}")
|
||||||
|
|
||||||
|
|
||||||
def _render_chat(tok, msgs: list[dict[str, str]], add_generation_prompt: bool) -> str:
|
def _render_chat(tok, msgs: list[dict[str, str]], add_generation_prompt: bool) -> str:
|
||||||
if tok.chat_template is None:
|
if tok.chat_template is None:
|
||||||
rendered = []
|
rendered = []
|
||||||
@@ -279,6 +317,15 @@ def main(a: Args):
|
|||||||
gen_lens_eval: dict[str, list] = {} # T per (alpha, eval-prompt)
|
gen_lens_eval: dict[str, list] = {} # T per (alpha, eval-prompt)
|
||||||
debug_first: dict[str, dict] = {} # first prompt per alpha: gen_text + top-5 at each fork
|
debug_first: dict[str, dict] = {} # first prompt per alpha: gen_text + top-5 at each fork
|
||||||
|
|
||||||
|
# Append-only JSONL of every generation: full text + paired pmass.
|
||||||
|
# Crash-safe: each line flushed before next prompt. Overwritten only at run start.
|
||||||
|
gens_path = out_dir / "gens.jsonl"
|
||||||
|
if gens_path.exists():
|
||||||
|
gens_path.unlink()
|
||||||
|
def _append_gen(rec: dict) -> None:
|
||||||
|
with open(gens_path, "a") as f:
|
||||||
|
f.write(json.dumps(rec) + "\n")
|
||||||
|
|
||||||
def write_eval_outputs() -> None:
|
def write_eval_outputs() -> None:
|
||||||
(out_dir / "trajectory.json").write_text(json.dumps({
|
(out_dir / "trajectory.json").write_text(json.dumps({
|
||||||
"fork_points_full": list(range(a.window)),
|
"fork_points_full": list(range(a.window)),
|
||||||
@@ -317,7 +364,18 @@ def main(a: Args):
|
|||||||
tok(s, return_tensors="pt", add_special_tokens=False).input_ids[0]
|
tok(s, return_tensors="pt", add_special_tokens=False).input_ids[0]
|
||||||
for s in eval_prompt_strs
|
for s in eval_prompt_strs
|
||||||
]
|
]
|
||||||
for alpha in a.alphas:
|
interrupt_suffix_ids = None
|
||||||
|
interrupt_suffix_decoded = ""
|
||||||
|
if a.chat_interrupt:
|
||||||
|
interrupt_suffix_ids = build_chat_interrupt_suffix(
|
||||||
|
tok, a.interrupt_user_text, a.interrupt_prefill,
|
||||||
|
)
|
||||||
|
# SHOULD: decoded suffix ends with prefill literal. ELSE template quirk.
|
||||||
|
interrupt_suffix_decoded = tok.decode(interrupt_suffix_ids)
|
||||||
|
logger.info(f"chat_interrupt suffix ({len(interrupt_suffix_ids)} tok): {interrupt_suffix_decoded!r}")
|
||||||
|
if not interrupt_suffix_decoded.endswith(a.interrupt_prefill):
|
||||||
|
logger.warning(f"interrupt suffix does not end with prefill={a.interrupt_prefill!r}")
|
||||||
|
for alpha in tqdm(a.alphas, desc=f"alphas[{a.run_id}]", mininterval=60):
|
||||||
v.cfg.coeff = alpha * c_star
|
v.cfg.coeff = alpha * c_star
|
||||||
logger.info(f"=== eval alpha={alpha} c={v.cfg.coeff:+.4f} ===")
|
logger.info(f"=== eval alpha={alpha} c={v.cfg.coeff:+.4f} ===")
|
||||||
m = measure_kl(v, model, tok, eval_ids_list, T=a.window, device=a.device)
|
m = measure_kl(v, model, tok, eval_ids_list, T=a.window, device=a.device)
|
||||||
@@ -351,6 +409,7 @@ def main(a: Args):
|
|||||||
v, model, tok, ids, gen, fork_points,
|
v, model, tok, ids, gen, fork_points,
|
||||||
PREFILL_STR, a_ids, b_ids,
|
PREFILL_STR, a_ids, b_ids,
|
||||||
rollout_cache=gen_out.past_key_values,
|
rollout_cache=gen_out.past_key_values,
|
||||||
|
interrupt_suffix_ids=interrupt_suffix_ids,
|
||||||
device=a.device,
|
device=a.device,
|
||||||
)
|
)
|
||||||
pm_for_alpha.append(pm["pmass"])
|
pm_for_alpha.append(pm["pmass"])
|
||||||
@@ -363,17 +422,28 @@ def main(a: Args):
|
|||||||
else "other"
|
else "other"
|
||||||
for s in pm["argmax_str"]
|
for s in pm["argmax_str"]
|
||||||
])
|
])
|
||||||
# debug dump for first QA prompt only
|
# save full gen text + paired pmass for every QA prompt
|
||||||
|
gen_text = tok.decode(gen, skip_special_tokens=False)
|
||||||
|
_append_gen({
|
||||||
|
"alpha": float(alpha), "kind": "qa", "prompt_idx": q_idx,
|
||||||
|
"score_mode": "chat_interrupt" if a.chat_interrupt else "raw_splice",
|
||||||
|
"interrupt_user_text": a.interrupt_user_text if a.chat_interrupt else "",
|
||||||
|
"interrupt_prefill": a.interrupt_prefill if a.chat_interrupt else "",
|
||||||
|
"interrupt_suffix_ids": list(interrupt_suffix_ids) if interrupt_suffix_ids is not None else [],
|
||||||
|
"interrupt_suffix_decoded": interrupt_suffix_decoded,
|
||||||
|
"prompt": prompt_str, "gen_text": gen_text, "gen_len": int(gen.shape[0]),
|
||||||
|
"fork_points": list(fork_points),
|
||||||
|
"pmass": pm["pmass"], "p_true": pm["p_true"],
|
||||||
|
"argmax_str": pm["argmax_str"], "was_thinking": pm["was_thinking"],
|
||||||
|
})
|
||||||
|
# debug dump for first QA prompt only (kept for backward compat)
|
||||||
if q_idx == 0:
|
if q_idx == 0:
|
||||||
gen_text = tok.decode(gen, skip_special_tokens=False)
|
|
||||||
debug_first.setdefault(str(alpha), {})["qa"] = {
|
debug_first.setdefault(str(alpha), {})["qa"] = {
|
||||||
"prompt": prompt_str, "gen_text": gen_text, "gen_len": int(gen.shape[0]),
|
"prompt": prompt_str, "gen_text": gen_text, "gen_len": int(gen.shape[0]),
|
||||||
"pmass_per_fork": pm["pmass"], "p_true_per_fork": pm["p_true"],
|
"pmass_per_fork": pm["pmass"], "p_true_per_fork": pm["p_true"],
|
||||||
"argmax_per_fork": pm["argmax_str"],
|
"argmax_per_fork": pm["argmax_str"],
|
||||||
}
|
}
|
||||||
logger.info(f" [debug] alpha={alpha} qa[0] gen_len={gen.shape[0]} text[:120]={gen_text[:120]!r}")
|
_log_fork_summary("qa", alpha, gen, gen_text, fork_points, pm)
|
||||||
for t, pmv, ptv, am in zip(fork_points, pm["pmass"], pm["p_true"], pm["argmax_str"]):
|
|
||||||
logger.info(f" t={t:>3} pmass={pmv:.3f} p_true={ptv:.3f} argmax={am!r}")
|
|
||||||
pmass_all[str(alpha)] = pm_for_alpha
|
pmass_all[str(alpha)] = pm_for_alpha
|
||||||
p_true_all[str(alpha)] = pt_for_alpha
|
p_true_all[str(alpha)] = pt_for_alpha
|
||||||
argmax_all[str(alpha)] = ax_for_alpha
|
argmax_all[str(alpha)] = ax_for_alpha
|
||||||
@@ -407,19 +477,30 @@ def main(a: Args):
|
|||||||
v, model, tok, ids, gen, fork_points,
|
v, model, tok, ids, gen, fork_points,
|
||||||
PREFILL_STR, a_ids, b_ids,
|
PREFILL_STR, a_ids, b_ids,
|
||||||
rollout_cache=gen_out.past_key_values,
|
rollout_cache=gen_out.past_key_values,
|
||||||
|
interrupt_suffix_ids=interrupt_suffix_ids,
|
||||||
device=a.device,
|
device=a.device,
|
||||||
)
|
)
|
||||||
pm_eval_for_alpha.append(pm["pmass"])
|
pm_eval_for_alpha.append(pm["pmass"])
|
||||||
|
gen_text = tok.decode(gen, skip_special_tokens=False)
|
||||||
|
_append_gen({
|
||||||
|
"alpha": float(alpha), "kind": "eval", "prompt_idx": p_idx,
|
||||||
|
"score_mode": "chat_interrupt" if a.chat_interrupt else "raw_splice",
|
||||||
|
"interrupt_user_text": a.interrupt_user_text if a.chat_interrupt else "",
|
||||||
|
"interrupt_prefill": a.interrupt_prefill if a.chat_interrupt else "",
|
||||||
|
"interrupt_suffix_ids": list(interrupt_suffix_ids) if interrupt_suffix_ids is not None else [],
|
||||||
|
"interrupt_suffix_decoded": interrupt_suffix_decoded,
|
||||||
|
"prompt": prompt_str, "gen_text": gen_text, "gen_len": int(gen.shape[0]),
|
||||||
|
"fork_points": list(fork_points),
|
||||||
|
"pmass": pm["pmass"], "p_true": pm["p_true"],
|
||||||
|
"argmax_str": pm["argmax_str"], "was_thinking": pm["was_thinking"],
|
||||||
|
})
|
||||||
if p_idx == 0:
|
if p_idx == 0:
|
||||||
gen_text = tok.decode(gen, skip_special_tokens=False)
|
|
||||||
debug_first.setdefault(str(alpha), {})["eval"] = {
|
debug_first.setdefault(str(alpha), {})["eval"] = {
|
||||||
"prompt": prompt_str, "gen_text": gen_text, "gen_len": int(gen.shape[0]),
|
"prompt": prompt_str, "gen_text": gen_text, "gen_len": int(gen.shape[0]),
|
||||||
"pmass_per_fork": pm["pmass"], "p_true_per_fork": pm["p_true"],
|
"pmass_per_fork": pm["pmass"], "p_true_per_fork": pm["p_true"],
|
||||||
"argmax_per_fork": pm["argmax_str"],
|
"argmax_per_fork": pm["argmax_str"],
|
||||||
}
|
}
|
||||||
logger.info(f" [debug] alpha={alpha} eval[0] gen_len={gen.shape[0]} text[:120]={gen_text[:120]!r}")
|
_log_fork_summary("eval", alpha, gen, gen_text, fork_points, pm)
|
||||||
for t, pmv, ptv, am in zip(fork_points, pm["pmass"], pm["p_true"], pm["argmax_str"]):
|
|
||||||
logger.info(f" t={t:>3} pmass={pmv:.3f} p_true={ptv:.3f} argmax={am!r}")
|
|
||||||
pmass_eval_all[str(alpha)] = pm_eval_for_alpha
|
pmass_eval_all[str(alpha)] = pm_eval_for_alpha
|
||||||
gen_lens_eval[str(alpha)] = gen_lens_eval_alpha
|
gen_lens_eval[str(alpha)] = gen_lens_eval_alpha
|
||||||
# SHOULD: at alpha=1, mean(pmass at t=0) > 0.5 (model still respects schema).
|
# SHOULD: at alpha=1, mean(pmass at t=0) > 0.5 (model still respects schema).
|
||||||
|
|||||||
@@ -55,6 +55,43 @@ def collect_choice_token_ids(
|
|||||||
return _ids(a_words), _ids(b_words)
|
return _ids(a_words), _ids(b_words)
|
||||||
|
|
||||||
|
|
||||||
|
def build_chat_interrupt_suffix(
    tok,
    interrupt_user_text: str,
    assistant_prefill: str,
) -> list[int]:
    """Construct token ids for: <eot> <user>{question}<eot> <assistant>{prefill}.

    Used to interrupt an in-progress assistant rollout with a follow-up question
    in a chat-template-correct way (vs. raw token splice). The suffix is found
    by string-diffing the rendered template before vs. after appending the
    interrupt turn -- no template-specific special-token guessing.

    Args:
        tok: tokenizer-like object; this function calls
            ``tok.apply_chat_template(msgs, tokenize=False, ...)`` and
            ``tok.encode(text, add_special_tokens=False)`` on it.
        interrupt_user_text: content of the follow-up user turn.
        assistant_prefill: text the new assistant turn is started with.

    Returns:
        Token ids of everything the template renders after the placeholder
        (user, assistant) exchange, i.e. the interrupt turn plus the prefill.

    Raises:
        ValueError: if the rendering with the interrupt turn does not start
            with the rendering without it, even after stripping trailing
            whitespace from the latter.

    SHOULD: tok.decode(suffix_ids) ends with assistant_prefill. ELSE template
    quirk: inspect tok.chat_template.
    """
    # Placeholder exchange; only the text rendered *after* it is kept.
    base_msgs = [
        {"role": "user", "content": "_"},
        {"role": "assistant", "content": "_ROLLED_"},
    ]
    ext_msgs = base_msgs + [
        {"role": "user", "content": interrupt_user_text},
        {"role": "assistant", "content": assistant_prefill},
    ]
    base_str = tok.apply_chat_template(base_msgs, tokenize=False)
    # continue_final_message=True leaves the last assistant turn open so the
    # rendering ends on the prefill.
    ext_str = tok.apply_chat_template(ext_msgs, tokenize=False, continue_final_message=True)
    if not ext_str.startswith(base_str):
        # template inserts trailing whitespace in `base_str` -- strip and retry
        base_str = base_str.rstrip()
        if not ext_str.startswith(base_str):
            raise ValueError(
                "chat template not prefix-stable: cannot derive interrupt suffix. "
                "Disable chat_interrupt (leave interrupt_suffix_ids=None) or extend this helper."
            )
    suffix_str = ext_str[len(base_str):]
    return tok.encode(suffix_str, add_special_tokens=False)
|
||||||
|
|
||||||
|
|
||||||
def _is_thinking(seq_ids: Tensor, think_id: int, unthink_id: int) -> bool:
|
def _is_thinking(seq_ids: Tensor, think_id: int, unthink_id: int) -> bool:
|
||||||
"""True iff the last `<think>` in seq_ids is after the last `</think>`."""
|
"""True iff the last `<think>` in seq_ids is after the last `</think>`."""
|
||||||
if think_id is None or unthink_id is None:
|
if think_id is None or unthink_id is None:
|
||||||
@@ -89,6 +126,9 @@ def branch_pmass(
|
|||||||
handle_thinking: bool = True,
|
handle_thinking: bool = True,
|
||||||
end_think_str: str = "</think>",
|
end_think_str: str = "</think>",
|
||||||
force_close_str: str = "\nI should answer now.", # see tinymfv/guided.py
|
force_close_str: str = "\nI should answer now.", # see tinymfv/guided.py
|
||||||
|
interrupt_suffix_ids: Sequence[int] | None = None, # if set, overrides prefill_str + thinking
|
||||||
|
# close logic and uses chat-template-correct interrupt
|
||||||
|
# turn (built via build_chat_interrupt_suffix).
|
||||||
device: str | torch.device = "cuda",
|
device: str | torch.device = "cuda",
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""Returns dict with parallel lists indexed by fork point:
|
"""Returns dict with parallel lists indexed by fork point:
|
||||||
@@ -113,6 +153,13 @@ def branch_pmass(
|
|||||||
tok.encode(force_close_str + end_think_str, add_special_tokens=False),
|
tok.encode(force_close_str + end_think_str, add_special_tokens=False),
|
||||||
device=device, dtype=torch.long,
|
device=device, dtype=torch.long,
|
||||||
) if handle_thinking else None
|
) if handle_thinking else None
|
||||||
|
# Chat-template interrupt mode: replaces both close_t and pre_t with one
|
||||||
|
# template-correct turn boundary (<eot><user>...<eot><assistant>{prefill}).
|
||||||
|
# Disables thinking-detection (irrelevant: we are starting a new turn).
|
||||||
|
interrupt_t = (
|
||||||
|
torch.tensor(list(interrupt_suffix_ids), device=device, dtype=torch.long)
|
||||||
|
if interrupt_suffix_ids is not None else None
|
||||||
|
)
|
||||||
a_t = torch.tensor(list(a_ids), dtype=torch.long, device=device)
|
a_t = torch.tensor(list(a_ids), dtype=torch.long, device=device)
|
||||||
b_t = torch.tensor(list(b_ids), dtype=torch.long, device=device)
|
b_t = torch.tensor(list(b_ids), dtype=torch.long, device=device)
|
||||||
all_t = torch.cat([a_t, b_t])
|
all_t = torch.cat([a_t, b_t])
|
||||||
@@ -130,9 +177,13 @@ def branch_pmass(
|
|||||||
continue
|
continue
|
||||||
prefix = rolled[:t]
|
prefix = rolled[:t]
|
||||||
seq_so_far = torch.cat([pids, prefix])
|
seq_so_far = torch.cat([pids, prefix])
|
||||||
thinking = (handle_thinking and think_id is not None and unthink_id is not None
|
if interrupt_t is not None:
|
||||||
and _is_thinking(seq_so_far, think_id, unthink_id))
|
new_tail = interrupt_t
|
||||||
new_tail = torch.cat([close_t, pre_t]) if (thinking and close_t is not None) else pre_t
|
thinking = False # not meaningful in chat-interrupt mode
|
||||||
|
else:
|
||||||
|
thinking = (handle_thinking and think_id is not None and unthink_id is not None
|
||||||
|
and _is_thinking(seq_so_far, think_id, unthink_id))
|
||||||
|
new_tail = torch.cat([close_t, pre_t]) if (thinking and close_t is not None) else pre_t
|
||||||
|
|
||||||
if rollout_cache is not None and use_rollout_cache:
|
if rollout_cache is not None and use_rollout_cache:
|
||||||
cache = copy.deepcopy(rollout_cache)
|
cache = copy.deepcopy(rollout_cache)
|
||||||
|
|||||||
@@ -158,12 +158,14 @@ def measure_kl(
|
|||||||
base_full = torch.cat([pids.to(device), base_gen])
|
base_full = torch.cat([pids.to(device), base_gen])
|
||||||
decoded_base = tok.decode(base_full, skip_special_tokens=False)
|
decoded_base = tok.decode(base_full, skip_special_tokens=False)
|
||||||
decoded_steer = tok.decode(full_ids, skip_special_tokens=False)
|
decoded_steer = tok.decode(full_ids, skip_special_tokens=False)
|
||||||
|
# SHOULD: BASE and STEER both coherent; STEER differs from BASE but does not collapse.
|
||||||
|
# Truncate to keep iso-KL bracket logs scannable; full text in trajectory.json/debug_first.
|
||||||
|
head = 400
|
||||||
logger.info(
|
logger.info(
|
||||||
f"EXPECT: same prompt under c=0 vs c={v.cfg.coeff:+.4f}; both coherent; "
|
f"SHOULD: c=0 vs c={v.cfg.coeff:+.4f} both coherent, steer differs but does not collapse.\n"
|
||||||
"steered should differ from base but not collapse.\n"
|
f"=== CALIBRATE demo (T={T}, first {head} chars each) ===\n"
|
||||||
f"\n=== CALIBRATE demo trace (T={T}) ===\n"
|
f"-- BASE : {decoded_base[:head]!r}\n"
|
||||||
f"--- BASE (c=0) ---\n{decoded_base}\n"
|
f"-- STEER : {decoded_steer[:head]!r}\n"
|
||||||
f"\n--- STEER (c={v.cfg.coeff:+.4f}) ---\n{decoded_steer}\n"
|
|
||||||
f"=== /CALIBRATE ==="
|
f"=== /CALIBRATE ==="
|
||||||
)
|
)
|
||||||
kls = _kl_generated_incremental(v, model, pids, gen, device)
|
kls = _kl_generated_incremental(v, model, pids, gen, device)
|
||||||
|
|||||||
+19
-1
@@ -15,7 +15,11 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
|
|||||||
from iso_kl_figure import (
|
from iso_kl_figure import (
|
||||||
SteeringConfig, MeanDiffC, train, measure_kl, attach, detach,
|
SteeringConfig, MeanDiffC, train, measure_kl, attach, detach,
|
||||||
)
|
)
|
||||||
from iso_kl_figure.branch_pmass import branch_pmass, collect_choice_token_ids
|
from iso_kl_figure.branch_pmass import (
|
||||||
|
branch_pmass,
|
||||||
|
build_chat_interrupt_suffix,
|
||||||
|
collect_choice_token_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
MODEL = "hf-internal-testing/tiny-random-LlamaForCausalLM"
|
MODEL = "hf-internal-testing/tiny-random-LlamaForCausalLM"
|
||||||
@@ -61,8 +65,22 @@ def test_pipeline_smoke():
|
|||||||
v.cfg.coeff = 5.0
|
v.cfg.coeff = 5.0
|
||||||
p_steer = branch_pmass(v, model, tok, pids, rolled, fork, prefill,
|
p_steer = branch_pmass(v, model, tok, pids, rolled, fork, prefill,
|
||||||
a_ids, b_ids, device="cpu")
|
a_ids, b_ids, device="cpu")
|
||||||
|
tok.chat_template = (
|
||||||
|
"{% for message in messages %}"
|
||||||
|
"<|{{ message['role'] }}|>\n{{ message['content'] }}\n"
|
||||||
|
"{% endfor %}"
|
||||||
|
"{% if add_generation_prompt %}<|assistant|>\n{% endif %}"
|
||||||
|
)
|
||||||
|
interrupt_ids = build_chat_interrupt_suffix(tok, "Final answer?", '{"Answer": ')
|
||||||
|
assert tok.decode(interrupt_ids).endswith('{"Answer": ')
|
||||||
|
p_interrupt = branch_pmass(
|
||||||
|
v, model, tok, pids, rolled, fork, prefill,
|
||||||
|
a_ids, b_ids, interrupt_suffix_ids=interrupt_ids, device="cpu",
|
||||||
|
)
|
||||||
for x in p_zero["pmass"] + p_steer["pmass"]:
|
for x in p_zero["pmass"] + p_steer["pmass"]:
|
||||||
assert 0.0 <= x <= 1.0, f"pmass out of [0,1]: {x}"
|
assert 0.0 <= x <= 1.0, f"pmass out of [0,1]: {x}"
|
||||||
|
for x in p_interrupt["pmass"]:
|
||||||
|
assert 0.0 <= x <= 1.0, f"chat-interrupt pmass out of [0,1]: {x}"
|
||||||
diff = max(abs(a - b) for a, b in zip(p_zero["pmass"], p_steer["pmass"]))
|
diff = max(abs(a - b) for a, b in zip(p_zero["pmass"], p_steer["pmass"]))
|
||||||
assert diff > 1e-8, "pmass invariant to coeff -- hook is dead"
|
assert diff > 1e-8, "pmass invariant to coeff -- hook is dead"
|
||||||
# p_true should be in [0,1] or NaN
|
# p_true should be in [0,1] or NaN
|
||||||
|
|||||||
Reference in New Issue
Block a user