mirror of
https://github.com/wassname/isokl_steering_calibration.git
synced 2026-06-27 16:45:53 +08:00
532 lines
26 KiB
Python
532 lines
26 KiB
Python
"""End-to-end runner for one (model, method, seed, window) cell.
|
|
|
|
Artefacts (per cell, under outputs/<run_id>/):
|
|
calib.json c_star + run metadata (model, method, seed, layer, window).
|
|
history.json bisection history (without per-token KL arrays).
|
|
trajectory.json per-token KL on EVAL_PROMPTS for each alpha:
|
|
per_t_p95_kl[alpha]: list[T]
|
|
per_prompt_per_t_kl[alpha]: list[N_prompts][T]
|
|
pmass.json forked-answer probability mass at fork_points:
|
|
pmass[alpha]: yes/no reasoning prompts (legacy probe)
|
|
pmass_eval[alpha]: SAME prompts as trajectory.json (paired w KL)
|
|
gen_lens_qa[alpha], gen_lens_eval[alpha]: T per rollout
|
|
(use for right-censoring -- NaN at t > T means rollout
|
|
EOS'd before that fork, NOT a measurement failure).
|
|
debug_first[alpha]: gen_text + per-fork pmass for the
|
|
FIRST qa & eval prompt (sanity-check; see survival.py).
|
|
pmass_eval is the metric to use when you want a per-trajectory
|
|
coherence signal aligned with the KL trajectory.
|
|
results.csv one row per alpha with kl_p95/mean/max.
|
|
|
|
Flow:
|
|
1. Load model + tokenizer; set seed.
|
|
2. Build pos/neg pair (cheap content-vs-refusal); train Vector v.
|
|
3. Calibrate iso-KL: bisect coeff so per-token kl_p95 hits target_kl (default 1)
|
|
over T=window tokens on CALIB_PROMPTS.
|
|
4. For each alpha, set coeff=alpha*c_star and:
|
|
a) measure_kl on EVAL_PROMPTS -> per_prompt_per_t_kl trajectories.
|
|
b) (--compute-pmass) yes/no questions: rollout, then at each fork point
|
|
prefill `\n{"choice": ` and read mass on {true/True/ true/.../1/ 1}
|
|
and {false/.../0}. Skip with --skip-pmass-qa.
|
|
c) (--compute-pmass) repeat (b) on EVAL_PROMPTS -> pmass_eval, paired
|
|
with the KL trajectory on the same rollouts. This is the honest
|
|
coherence signal; KL is largely a calibration tautology at alpha=1.
|
|
|
|
Fork-point modes:
|
|
--fork-step N linear: {0, N, 2N, ..., window}
|
|
--fork-log log-spaced: {0, 1, 2, 4, 8, ..., window} (with --n-log-forks)
|
|
|
|
KV-cache:
|
|
Rollouts use generate(..., return_dict_in_generate=True). branch_pmass
|
|
uses full-recompute by default (cache reuse adds ~10% bf16 noise; verified
|
|
by scripts/audit_pmass.py). Pass use_rollout_cache=True only if you accept
|
|
the noise.
|
|
|
|
Usage (tyro):
|
|
python scripts/run_cell.py --model Qwen/Qwen2.5-0.5B-Instruct --method mean_diff \
|
|
--seed 0 --window 512 --compute-pmass --skip-pmass-qa --fork-log
|
|
"""
|
|
from __future__ import annotations
|
|
import json
|
|
import os
|
|
import subprocess
|
|
import sys
|
|
from dataclasses import dataclass, field
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
import tyro
|
|
from loguru import logger
|
|
from tqdm.auto import tqdm
|
|
|
|
from iso_kl_figure import (
|
|
SteeringConfig, MeanDiffC, PCAC, DirectionalAblationC,
|
|
train, calibrate_iso_kl, measure_kl,
|
|
)
|
|
from iso_kl_figure.branch_pmass import branch_pmass, build_chat_interrupt_suffix, collect_choice_token_ids
|
|
from iso_kl_figure.calibrate import _eos_token_ids
|
|
from iso_kl_figure.target import _get_blocks
|
|
|
|
|
|
CALIB_PROMPTS = [
|
|
"The eiffel tower is in Paris.",
|
|
"Tell me a greentext story about a small village during the Martian carrot bubble.",
|
|
"Walk me through computing the integral of x^2 from 0 to 1.",
|
|
"Briefly describe how plate tectonics shape mountain ranges.",
|
|
]
|
|
|
|
# EVAL prompts are yes/no questions paired with schema -- same contract as
|
|
# tinymfv/guided.py so branch_pmass at any fork can score format-compliance.
|
|
# The model is steered into <think>\n at the prompt boundary; at fork t we
|
|
# inject "\nI should answer now.</think>\n{\"choice\": " and read p(true|false).
|
|
EVAL_PROMPTS = [
|
|
"Is the sky blue primarily because of Rayleigh scattering of sunlight?",
|
|
"Does Hamlet die at the end of Shakespeare's play 'Hamlet'?",
|
|
"Is excessive money supply growth one cause of inflation in modern economies?",
|
|
"Do octopuses have three hearts and blue copper-based blood?",
|
|
"Does photosynthesis use light energy to convert CO2 and water into glucose?",
|
|
"Was the Eiffel Tower originally built as the entrance arch for the 1889 World's Fair?",
|
|
"Can plate tectonic collisions form mountain ranges over geological time?",
|
|
"Is a circular import one possible cause of a Python script crashing on import?",
|
|
"Does water boil at a lower temperature at higher altitudes due to reduced atmospheric pressure?",
|
|
"Is the mitochondrion commonly described as the powerhouse of the eukaryotic cell?",
|
|
"Did the Apollo 11 mission land humans on the Moon in July 1969?",
|
|
"Is Mount Everest the tallest mountain on Earth measured from base to peak?",
|
|
"Can a closure in JavaScript capture variables from its enclosing lexical scope?",
|
|
"Does the human body have exactly 206 bones in adulthood?",
|
|
"Is sound generally faster in solids than in air at the same temperature?",
|
|
"Was the Great Wall of China built primarily as a single continuous wall in one dynasty?",
|
|
"Does relativity predict that time dilates for an observer moving at relativistic speeds?",
|
|
"Is DNA replication semi-conservative, with each daughter molecule keeping one parent strand?",
|
|
"Did the Roman Empire formally split into Western and Eastern halves in the 4th century CE?",
|
|
"Is gradient descent guaranteed to find the global minimum of an arbitrary non-convex loss?",
|
|
"Does the moon always show the same face to Earth because of tidal locking?",
|
|
"Was Shakespeare a contemporary of Queen Elizabeth I of England?",
|
|
"Is the Pacific Ocean larger in surface area than the Atlantic Ocean?",
|
|
"Does a transformer architecture rely on self-attention rather than recurrence for sequence modeling?",
|
|
"Is honey produced by bees primarily from nectar collected from flowers?",
|
|
"Did Ada Lovelace write what is often considered the first published computer algorithm?",
|
|
"Is the boiling point of pure water at sea level approximately 100 degrees Celsius?",
|
|
"Does cross-entropy loss reduce to negative log likelihood for a single correct class label?",
|
|
"Was the printing press invented by Johannes Gutenberg in the 15th century?",
|
|
"Is the speed of light in a vacuum the same for all inertial observers?",
|
|
"Do antibiotics kill viruses in addition to bacteria?",
|
|
"Is the chemical symbol for gold 'Au' on the periodic table?",
|
|
]
|
|
|
|
# pmass diagnostic prompts share the same schema as EVAL above.
|
|
_SCHEMA = (
|
|
'Reason for at least 80 words before the final answer. Then answer only with '
|
|
'{"choice": true} or {"choice": false}. If you use numeric JSON, use 1 for true and 0 for false.'
|
|
)
|
|
_QUESTIONS = [
|
|
"Is the Eiffel Tower located in Paris, France?",
|
|
"Do dolphins breathe air through lungs rather than gills?",
|
|
"Was Albert Einstein born in Germany?",
|
|
"Can humans naturally see ultraviolet light?",
|
|
"Did the Roman Empire fall in 476 CE?",
|
|
"Are tomatoes botanically classified as fruits?",
|
|
"Is the speed of light approximately 300,000 km/s?",
|
|
"Do octopuses have three hearts?",
|
|
]
|
|
PREFILL_STR = '\n{"choice": '
|
|
|
|
# Pos/neg pair: a generic content-vs-refusal direction. Cheap, not the main
|
|
# point -- the figure tests calibration *behaviour*, not direction quality.
|
|
POS_NEG = [
|
|
("Sure, here is the answer:", "I cannot help with that."),
|
|
("Here are some details:", "Sorry, I can't assist with that."),
|
|
("Of course, let me explain.", "I won't be able to help."),
|
|
("Yes, that makes sense.", "No, I have to decline."),
|
|
]
|
|
|
|
|
|
METHOD_MAP = {
|
|
"mean_diff": MeanDiffC,
|
|
"pca": PCAC,
|
|
"directional_ablation": DirectionalAblationC,
|
|
}
|
|
|
|
|
|
@dataclass
|
|
class Args:
|
|
model: str
|
|
method: str
|
|
seed: int = 0
|
|
window: int = 512
|
|
run_id: str = ""
|
|
layer_frac: float = 0.6
|
|
target_kl: float = 1.0
|
|
out_root: str = "outputs"
|
|
device: str = "cuda"
|
|
dtype: str = "bfloat16"
|
|
alphas: tuple[float, ...] = (0.25, 0.5, 0.75, 1.0, 1.5, 2.0, 4.0)
|
|
fork_step: int = 5
|
|
fork_log: bool = False # log-spaced fork points {0,1,2,4,8,...,window}
|
|
n_log_forks: int = 14 # number of log-spaced forks (incl. 0)
|
|
compute_pmass: bool = False
|
|
skip_pmass_qa: bool = False # skip yes/no pmass loop, only do paired pmass_eval
|
|
render_figs: bool = False # render single-run survival + spaghetti + KL pngs
|
|
render_threshold: float = 0.95
|
|
chat_interrupt: bool = False # use chat-template-correct interrupt turn for pmass scoring
|
|
# (in-distribution; recommended over raw splice).
|
|
interrupt_user_text: str = (
|
|
"Given the conversation so far, what is your final answer?"
|
|
" Respond with JSON: {\"Answer\": true} or {\"Answer\": false}."
|
|
)
|
|
interrupt_prefill: str = '{"Answer": '
|
|
|
|
|
|
def _set_seed(s: int):
|
|
import random
|
|
import numpy as np
|
|
random.seed(s); np.random.seed(s); torch.manual_seed(s)
|
|
if torch.cuda.is_available():
|
|
torch.cuda.manual_seed_all(s)
|
|
|
|
|
|
def _log_fork_summary(stage: str, alpha: float, gen, gen_text: str,
|
|
fork_points: list[int], pm: dict) -> None:
|
|
"""One INFO line + one debug table per (stage, alpha) instead of N forks * 1 line.
|
|
|
|
SHOULD: at alpha~1, pmass ~> 0.5 across forks (model still respects schema).
|
|
ELSE: prefill broken or schema drift -- inspect debug_first in pmass.json.
|
|
"""
|
|
from tabulate import tabulate
|
|
pmv = pm["pmass"]
|
|
ptv = pm["p_true"]
|
|
pm_min = min(pmv) if pmv else float("nan")
|
|
pm_med = sorted(pmv)[len(pmv) // 2] if pmv else float("nan")
|
|
logger.info(
|
|
f" alpha={alpha} {stage}[0] gen_len={gen.shape[0]} pmass(med/min)="
|
|
f"{pm_med:.2f}/{pm_min:.2f} | text[:80]={gen_text[:80]!r}"
|
|
)
|
|
rows = [(t, f"{p:.3f}", f"{q:.3f}", repr(s))
|
|
for t, p, q, s in zip(fork_points, pmv, ptv, pm["argmax_str"])]
|
|
table = tabulate(rows, headers=["t", "pmass", "p_true", "argmax"], tablefmt="plain")
|
|
logger.debug(f"\nfork table [{stage} alpha={alpha}]\n{table}")
|
|
|
|
|
|
def _render_chat(tok, msgs: list[dict[str, str]], add_generation_prompt: bool) -> str:
|
|
if tok.chat_template is None:
|
|
rendered = []
|
|
for msg in msgs:
|
|
role = msg["role"]
|
|
content = msg["content"]
|
|
if role == "user":
|
|
rendered.append(f"User: {content}\n")
|
|
elif role == "assistant":
|
|
rendered.append(f"Assistant: {content}")
|
|
else:
|
|
raise ValueError(f"plain prompt renderer does not support role={role!r}")
|
|
if add_generation_prompt:
|
|
rendered.append("Assistant: ")
|
|
return "".join(rendered)
|
|
return tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=add_generation_prompt)
|
|
|
|
|
|
def _build_guided_prompt(tok, user_text: str, schema_hint: str = _SCHEMA) -> str:
|
|
"""Match tinymfv/guided.py: chat template + '<think>\\n' suffix so the model
|
|
is in thinking mode at every fork point. branch_pmass detects this and
|
|
splices '\\nI should answer now.</think>{prefill}' before scoring."""
|
|
full_user = f"{user_text}\n\n{schema_hint}" if schema_hint else user_text
|
|
msgs = [{"role": "user", "content": full_user}]
|
|
p = _render_chat(tok, msgs, add_generation_prompt=True)
|
|
return p + "<think>\n"
|
|
|
|
|
|
def main(a: Args):
|
|
if not a.run_id:
|
|
a.run_id = f"{a.model.split('/')[-1]}_{a.method}_s{a.seed}_w{a.window}"
|
|
out_dir = Path(a.out_root) / a.run_id
|
|
out_dir.mkdir(parents=True, exist_ok=True)
|
|
logger.add(out_dir / "run.log", level="INFO")
|
|
|
|
_set_seed(a.seed)
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
dtype = getattr(torch, a.dtype)
|
|
tok = AutoTokenizer.from_pretrained(a.model)
|
|
if tok.pad_token_id is None:
|
|
tok.pad_token_id = tok.eos_token_id
|
|
model_kwargs = {"torch_dtype": dtype}
|
|
if "bnb-4bit" in a.model.lower():
|
|
model_kwargs["device_map"] = {"": a.device}
|
|
model = AutoModelForCausalLM.from_pretrained(a.model, **model_kwargs)
|
|
else:
|
|
model = AutoModelForCausalLM.from_pretrained(a.model, **model_kwargs).to(a.device)
|
|
model.eval()
|
|
|
|
n_layers = len(_get_blocks(model))
|
|
layer = int(a.layer_frac * n_layers)
|
|
logger.info(f"model={a.model} n_layers={n_layers} target_layer={layer}")
|
|
|
|
cfg_cls = METHOD_MAP[a.method]
|
|
cfg = cfg_cls(coeff=1.0, layers=(layer,))
|
|
|
|
pos = [_render_chat(tok, [{"role": "user", "content": u},
|
|
{"role": "assistant", "content": p}], add_generation_prompt=False)
|
|
for u, (p, _) in zip(CALIB_PROMPTS, POS_NEG)]
|
|
neg = [_render_chat(tok, [{"role": "user", "content": u},
|
|
{"role": "assistant", "content": n}], add_generation_prompt=False)
|
|
for u, (_, n) in zip(CALIB_PROMPTS, POS_NEG)]
|
|
v = train(model, tok, pos, neg, cfg, batch_size=4, max_length=128)
|
|
|
|
logger.info("=== calibrate ===")
|
|
c_star, history = calibrate_iso_kl(
|
|
v, model, tok, CALIB_PROMPTS,
|
|
target_kl=a.target_kl, target_stat="kl_p95",
|
|
T=a.window, device=a.device,
|
|
)
|
|
v.cfg.coeff = c_star
|
|
logger.info(f"c_star = {c_star:+.4f}")
|
|
# Strip per_prompt_per_t from history to keep file size small.
|
|
hist_slim = [{k: v_ for k, v_ in h.items() if k != "per_prompt_per_t"}
|
|
for h in history]
|
|
(out_dir / "history.json").write_text(json.dumps(hist_slim, indent=2))
|
|
(out_dir / "calib.json").write_text(json.dumps({
|
|
"c_star": c_star, "target_kl": a.target_kl, "window": a.window,
|
|
"method": a.method, "model": a.model, "seed": a.seed, "layer": layer,
|
|
}, indent=2))
|
|
|
|
# Sanity-check pmass on the BASE model (no steering): should be ~1.0,
|
|
# otherwise the prefill/schema isn't priming the right tokens.
|
|
a_ids, b_ids = collect_choice_token_ids(tok)
|
|
if a.compute_pmass:
|
|
logger.info(f"choice ids: a(true)={a_ids} b(false)={b_ids}")
|
|
|
|
# -- trajectory + pmass at each alpha on held-out prompts
|
|
rows = []
|
|
if a.fork_log:
|
|
# log-spaced including 0 and window: {0, 1, 2, 4, 8, ..., window}
|
|
import numpy as _np
|
|
raw = _np.unique(_np.round(_np.geomspace(1, a.window, a.n_log_forks - 1)).astype(int))
|
|
fork_points = [0] + [int(x) for x in raw]
|
|
fork_points = sorted(set(fp for fp in fork_points if fp <= a.window))
|
|
else:
|
|
fork_points = list(range(0, a.window + 1, a.fork_step))
|
|
logger.info(f"fork_points (n={len(fork_points)}): {fork_points}")
|
|
trajectory: dict[str, list] = {}
|
|
per_prompt_traj: dict[str, list] = {}
|
|
pmass_all: dict[str, list] = {}
|
|
pmass_eval_all: dict[str, list] = {} # pmass paired with EVAL_PROMPTS (same prompts as KL)
|
|
p_true_all: dict[str, list] = {}
|
|
argmax_all: dict[str, list] = {}
|
|
thinking_all: dict[str, list] = {}
|
|
answer_label_all: dict[str, list] = {}
|
|
gen_lens_qa: dict[str, list] = {} # T per (alpha, qa-prompt) -- right-censoring info
|
|
gen_lens_eval: dict[str, list] = {} # T per (alpha, eval-prompt)
|
|
debug_first: dict[str, dict] = {} # first prompt per alpha: gen_text + top-5 at each fork
|
|
|
|
# Append-only JSONL of every generation: full text + paired pmass.
|
|
# Crash-safe: each line flushed before next prompt. Overwritten only at run start.
|
|
gens_path = out_dir / "gens.jsonl"
|
|
if gens_path.exists():
|
|
gens_path.unlink()
|
|
def _append_gen(rec: dict) -> None:
|
|
with open(gens_path, "a") as f:
|
|
f.write(json.dumps(rec) + "\n")
|
|
|
|
def write_eval_outputs() -> None:
|
|
(out_dir / "trajectory.json").write_text(json.dumps({
|
|
"fork_points_full": list(range(a.window)),
|
|
"per_t_p95_kl": trajectory,
|
|
"per_prompt_per_t_kl": per_prompt_traj,
|
|
}, indent=2))
|
|
(out_dir / "pmass.json").write_text(json.dumps({
|
|
"fork_points": fork_points,
|
|
"pmass": pmass_all,
|
|
"pmass_eval": pmass_eval_all,
|
|
"p_true": p_true_all,
|
|
"argmax_str": argmax_all,
|
|
"was_thinking": thinking_all,
|
|
"answer_label": answer_label_all,
|
|
"gen_lens_qa": gen_lens_qa,
|
|
"gen_lens_eval": gen_lens_eval,
|
|
"debug_first": debug_first,
|
|
"prefill": PREFILL_STR,
|
|
"schema": _SCHEMA,
|
|
"questions": _QUESTIONS,
|
|
"computed": a.compute_pmass,
|
|
"completed_alphas": list(trajectory.keys()),
|
|
}, indent=2))
|
|
import csv
|
|
with open(out_dir / "results.csv", "w", newline="") as f:
|
|
w = csv.DictWriter(f, fieldnames=["alpha", "coeff", "kl_p95", "kl_mean", "kl_max"])
|
|
w.writeheader()
|
|
for r in rows:
|
|
w.writerow(r)
|
|
|
|
# Pre-build guided EVAL ids ONCE so measure_kl and pmass_eval roll out the
|
|
# same prompt -- otherwise spaghetti coloring (KL trajectory colored by
|
|
# pmass) is meaningless because the rollouts diverge.
|
|
eval_prompt_strs = [_build_guided_prompt(tok, p, _SCHEMA) for p in EVAL_PROMPTS]
|
|
eval_ids_list = [
|
|
tok(s, return_tensors="pt", add_special_tokens=False).input_ids[0]
|
|
for s in eval_prompt_strs
|
|
]
|
|
interrupt_suffix_ids = None
|
|
interrupt_suffix_decoded = ""
|
|
if a.chat_interrupt:
|
|
interrupt_suffix_ids = build_chat_interrupt_suffix(
|
|
tok, a.interrupt_user_text, a.interrupt_prefill,
|
|
)
|
|
# SHOULD: decoded suffix ends with prefill literal. ELSE template quirk.
|
|
interrupt_suffix_decoded = tok.decode(interrupt_suffix_ids)
|
|
logger.info(f"chat_interrupt suffix ({len(interrupt_suffix_ids)} tok): {interrupt_suffix_decoded!r}")
|
|
if not interrupt_suffix_decoded.endswith(a.interrupt_prefill):
|
|
logger.warning(f"interrupt suffix does not end with prefill={a.interrupt_prefill!r}")
|
|
for alpha in tqdm(a.alphas, desc=f"alphas[{a.run_id}]", mininterval=60):
|
|
v.cfg.coeff = alpha * c_star
|
|
logger.info(f"=== eval alpha={alpha} c={v.cfg.coeff:+.4f} ===")
|
|
m = measure_kl(v, model, tok, eval_ids_list, T=a.window, device=a.device)
|
|
trajectory[str(alpha)] = m["per_t_p95"]
|
|
per_prompt_traj[str(alpha)] = m["per_prompt_per_t"]
|
|
rows.append({"alpha": alpha, "coeff": v.cfg.coeff, "kl_p95": m["kl_p95"],
|
|
"kl_mean": m["kl_mean"], "kl_max": m["kl_max"]})
|
|
|
|
# pmass per held-out prompt
|
|
pm_for_alpha, pt_for_alpha, ax_for_alpha, wt_for_alpha, ans_for_alpha = [], [], [], [], []
|
|
gen_lens_qa_alpha = []
|
|
gen_lens_eval_alpha = []
|
|
if a.compute_pmass and not a.skip_pmass_qa:
|
|
for q_idx, q in enumerate(_QUESTIONS):
|
|
prompt_str = _build_guided_prompt(tok, q, _SCHEMA)
|
|
ids = tok(prompt_str, return_tensors="pt", add_special_tokens=False).input_ids[0]
|
|
pad = tok.pad_token_id
|
|
with v(model):
|
|
gen_out = model.generate(
|
|
ids.unsqueeze(0).to(a.device),
|
|
max_new_tokens=a.window,
|
|
pad_token_id=pad, eos_token_id=_eos_token_ids(tok),
|
|
do_sample=False,
|
|
use_cache=True,
|
|
return_dict_in_generate=True,
|
|
)
|
|
gen = gen_out.sequences[0, ids.shape[0]:]
|
|
gen_lens_qa_alpha.append(int(gen.shape[0]))
|
|
# KV cache from steered rollout: ~100x speedup at each fork.
|
|
pm = branch_pmass(
|
|
v, model, tok, ids, gen, fork_points,
|
|
PREFILL_STR, a_ids, b_ids,
|
|
rollout_cache=gen_out.past_key_values,
|
|
interrupt_suffix_ids=interrupt_suffix_ids,
|
|
device=a.device,
|
|
)
|
|
pm_for_alpha.append(pm["pmass"])
|
|
pt_for_alpha.append(pm["p_true"])
|
|
ax_for_alpha.append(pm["argmax_str"])
|
|
wt_for_alpha.append(pm["was_thinking"])
|
|
ans_for_alpha.append([
|
|
"true" if s in {"true", "True", " true", " True", "1", " 1"}
|
|
else "false" if s in {"false", "False", " false", " False", "0", " 0"}
|
|
else "other"
|
|
for s in pm["argmax_str"]
|
|
])
|
|
# save full gen text + paired pmass for every QA prompt
|
|
gen_text = tok.decode(gen, skip_special_tokens=False)
|
|
_append_gen({
|
|
"alpha": float(alpha), "kind": "qa", "prompt_idx": q_idx,
|
|
"score_mode": "chat_interrupt" if a.chat_interrupt else "raw_splice",
|
|
"interrupt_user_text": a.interrupt_user_text if a.chat_interrupt else "",
|
|
"interrupt_prefill": a.interrupt_prefill if a.chat_interrupt else "",
|
|
"interrupt_suffix_ids": list(interrupt_suffix_ids) if interrupt_suffix_ids is not None else [],
|
|
"interrupt_suffix_decoded": interrupt_suffix_decoded,
|
|
"prompt": prompt_str, "gen_text": gen_text, "gen_len": int(gen.shape[0]),
|
|
"fork_points": list(fork_points),
|
|
"pmass": pm["pmass"], "p_true": pm["p_true"],
|
|
"argmax_str": pm["argmax_str"], "was_thinking": pm["was_thinking"],
|
|
})
|
|
# debug dump for first QA prompt only (kept for backward compat)
|
|
if q_idx == 0:
|
|
debug_first.setdefault(str(alpha), {})["qa"] = {
|
|
"prompt": prompt_str, "gen_text": gen_text, "gen_len": int(gen.shape[0]),
|
|
"pmass_per_fork": pm["pmass"], "p_true_per_fork": pm["p_true"],
|
|
"argmax_per_fork": pm["argmax_str"],
|
|
}
|
|
_log_fork_summary("qa", alpha, gen, gen_text, fork_points, pm)
|
|
pmass_all[str(alpha)] = pm_for_alpha
|
|
p_true_all[str(alpha)] = pt_for_alpha
|
|
argmax_all[str(alpha)] = ax_for_alpha
|
|
thinking_all[str(alpha)] = wt_for_alpha
|
|
answer_label_all[str(alpha)] = ans_for_alpha
|
|
gen_lens_qa[str(alpha)] = gen_lens_qa_alpha
|
|
|
|
# Paired pmass on EVAL_PROMPTS (same long-form prompts as KL) so we can
|
|
# color KL trajectories by pmass at each fork. Long-form prompts won't
|
|
# naturally produce schema tokens, but the prefill forces the question
|
|
# "if you committed now, can you still produce a valid choice?" -- which
|
|
# is exactly the coherence signal we want.
|
|
pm_eval_for_alpha = []
|
|
if a.compute_pmass:
|
|
for p_idx, p in enumerate(EVAL_PROMPTS):
|
|
prompt_str = eval_prompt_strs[p_idx]
|
|
ids = eval_ids_list[p_idx]
|
|
pad = tok.pad_token_id
|
|
with v(model):
|
|
gen_out = model.generate(
|
|
ids.unsqueeze(0).to(a.device),
|
|
max_new_tokens=a.window,
|
|
pad_token_id=pad, eos_token_id=_eos_token_ids(tok),
|
|
do_sample=False,
|
|
use_cache=True,
|
|
return_dict_in_generate=True,
|
|
)
|
|
gen = gen_out.sequences[0, ids.shape[0]:]
|
|
gen_lens_eval_alpha.append(int(gen.shape[0]))
|
|
pm = branch_pmass(
|
|
v, model, tok, ids, gen, fork_points,
|
|
PREFILL_STR, a_ids, b_ids,
|
|
rollout_cache=gen_out.past_key_values,
|
|
interrupt_suffix_ids=interrupt_suffix_ids,
|
|
device=a.device,
|
|
)
|
|
pm_eval_for_alpha.append(pm["pmass"])
|
|
gen_text = tok.decode(gen, skip_special_tokens=False)
|
|
_append_gen({
|
|
"alpha": float(alpha), "kind": "eval", "prompt_idx": p_idx,
|
|
"score_mode": "chat_interrupt" if a.chat_interrupt else "raw_splice",
|
|
"interrupt_user_text": a.interrupt_user_text if a.chat_interrupt else "",
|
|
"interrupt_prefill": a.interrupt_prefill if a.chat_interrupt else "",
|
|
"interrupt_suffix_ids": list(interrupt_suffix_ids) if interrupt_suffix_ids is not None else [],
|
|
"interrupt_suffix_decoded": interrupt_suffix_decoded,
|
|
"prompt": prompt_str, "gen_text": gen_text, "gen_len": int(gen.shape[0]),
|
|
"fork_points": list(fork_points),
|
|
"pmass": pm["pmass"], "p_true": pm["p_true"],
|
|
"argmax_str": pm["argmax_str"], "was_thinking": pm["was_thinking"],
|
|
})
|
|
if p_idx == 0:
|
|
debug_first.setdefault(str(alpha), {})["eval"] = {
|
|
"prompt": prompt_str, "gen_text": gen_text, "gen_len": int(gen.shape[0]),
|
|
"pmass_per_fork": pm["pmass"], "p_true_per_fork": pm["p_true"],
|
|
"argmax_per_fork": pm["argmax_str"],
|
|
}
|
|
_log_fork_summary("eval", alpha, gen, gen_text, fork_points, pm)
|
|
pmass_eval_all[str(alpha)] = pm_eval_for_alpha
|
|
gen_lens_eval[str(alpha)] = gen_lens_eval_alpha
|
|
# SHOULD: at alpha=1, mean(pmass at t=0) > 0.5 (model still respects schema).
|
|
# ELSE: prefill string broken or chat template off.
|
|
if a.compute_pmass:
|
|
import numpy as _np
|
|
pmass_rows = pm_eval_for_alpha if a.skip_pmass_qa else pm_for_alpha
|
|
t0 = _np.array([row[0] for row in pmass_rows])
|
|
logger.info(f" alpha={alpha} pmass@t=0: mean={t0.mean():.3f} min={t0.min():.3f} max={t0.max():.3f}")
|
|
write_eval_outputs()
|
|
|
|
write_eval_outputs()
|
|
if a.render_figs:
|
|
repo_root = Path(__file__).resolve().parents[1]
|
|
cmd = [
|
|
sys.executable,
|
|
"scripts/render_run_figs.py",
|
|
"--run-dir", str(out_dir),
|
|
"--threshold", str(a.render_threshold),
|
|
"--model-contains", a.model.split("/")[-1],
|
|
]
|
|
logger.info("rendering single-run figures")
|
|
subprocess.run(cmd, check=True, cwd=repo_root)
|
|
logger.info(f"DONE -> {out_dir}")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main(tyro.cli(Args))
|