# isokl_steering_calibration/scripts/run_cell.py
# (file-viewer header residue: 411 lines, 19 KiB, Python)

"""End-to-end runner for one (model, method, seed, window) cell.
Artefacts (per cell, under outputs/<run_id>/):
calib.json c_star + run metadata (model, method, seed, layer, window).
history.json bisection history (without per-token KL arrays).
trajectory.json per-token KL on EVAL_PROMPTS for each alpha:
per_t_p95_kl[alpha]: list[T]
per_prompt_per_t_kl[alpha]: list[N_prompts][T]
pmass.json forked-answer probability mass at fork_points:
pmass[alpha]: yes/no reasoning prompts (legacy probe)
pmass_eval[alpha]: SAME prompts as trajectory.json (paired w KL)
gen_lens_qa[alpha], gen_lens_eval[alpha]: T per rollout
(use for right-censoring -- NaN at t > T means rollout
EOS'd before that fork, NOT a measurement failure).
debug_first[alpha]: gen_text + per-fork pmass for the
FIRST qa & eval prompt (sanity-check; see survival.py).
pmass_eval is the metric to use when you want a per-trajectory
coherence signal aligned with the KL trajectory.
results.csv one row per alpha with kl_p95/mean/max.
Flow:
1. Load model + tokenizer; set seed.
2. Build pos/neg pair (cheap content-vs-refusal); train Vector v.
3. Calibrate iso-KL: bisect coeff so per-token kl_p95 hits target_kl (default 1)
over T=window tokens on CALIB_PROMPTS.
4. For each alpha, set coeff=alpha*c_star and:
a) measure_kl on EVAL_PROMPTS -> per_prompt_per_t_kl trajectories.
b) (--compute-pmass) yes/no questions: rollout, then at each fork point
prefill `\n{"choice": ` and read mass on {true/True/ true/.../1/ 1}
and {false/.../0}. Skip with --skip-pmass-qa.
c) (--compute-pmass) repeat (b) on EVAL_PROMPTS -> pmass_eval, paired
with the KL trajectory on the same rollouts. This is the honest
coherence signal; KL is largely a calibration tautology at alpha=1.
Fork-point modes:
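--fork-step N linear: {0, N, 2N, ...} up to window (window itself is included only when N divides it)
--fork-log log-spaced: 0 plus --n-log-forks - 1 geometrically spaced points from 1 to window, rounded to integers and de-duplicated (roughly {0, 1, 2, 4, 8, ..., window})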
KV-cache:
Rollouts use generate(..., return_dict_in_generate=True). branch_pmass
uses full-recompute by default (cache reuse adds ~10% bf16 noise; verified
by scripts/audit_pmass.py). Pass use_rollout_cache=True only if you accept
the noise.
Usage (tyro):
python scripts/run_cell.py --model Qwen/Qwen2.5-0.5B-Instruct --method mean_diff \
--seed 0 --window 512 --compute-pmass --skip-pmass-qa --fork-log
"""
from __future__ import annotations
import json
import os
import subprocess
import sys
from dataclasses import dataclass, field
from pathlib import Path
import torch
import tyro
from loguru import logger
from iso_kl_figure import (
SteeringConfig, MeanDiffC, PCAC, DirectionalAblationC,
train, calibrate_iso_kl, measure_kl,
)
from iso_kl_figure.branch_pmass import branch_pmass, collect_choice_token_ids
# Prompts used for iso-KL calibration. main() also zips them with POS_NEG to
# build the pos/neg training texts, so the two lists should stay the same
# length (zip silently truncates to the shorter one).
CALIB_PROMPTS = [
    "The eiffel tower is in Paris.",
    "Tell me a greentext story about a small village during the Martian carrot bubble.",
    "Walk me through computing the integral of x^2 from 0 to 1.",
    "Briefly describe how plate tectonics shape mountain ranges.",
]
# EVAL prompts are yes/no questions paired with schema -- same contract as
# tinymfv/guided.py so branch_pmass at any fork can score format-compliance.
# _build_guided_prompt ends every prompt with <think>\n; at fork t we
# inject "\nI should answer now.</think>\n{\"choice\": " and read p(true|false).
EVAL_PROMPTS = [
    "Is the sky blue primarily because of Rayleigh scattering of sunlight?",
    "Does Hamlet die at the end of Shakespeare's play 'Hamlet'?",
    "Is excessive money supply growth one cause of inflation in modern economies?",
    "Do octopuses have three hearts and blue copper-based blood?",
    "Does photosynthesis use light energy to convert CO2 and water into glucose?",
    "Was the Eiffel Tower originally built as the entrance arch for the 1889 World's Fair?",
    "Can plate tectonic collisions form mountain ranges over geological time?",
    "Is a circular import one possible cause of a Python script crashing on import?",
]
# pmass diagnostic prompts share the same schema as EVAL above.
# _SCHEMA is appended to the user turn of every guided prompt.
_SCHEMA = (
    'Reason for at least 80 words before the final answer. Then answer only with '
    '{"choice": true} or {"choice": false}. If you use numeric JSON, use 1 for true and 0 for false.'
)
# Questions for the legacy yes/no pmass probe (skipped with --skip-pmass-qa).
_QUESTIONS = [
    "Is the Eiffel Tower located in Paris, France?",
    "Do dolphins breathe air through lungs rather than gills?",
    "Was Albert Einstein born in Germany?",
    "Can humans naturally see ultraviolet light?",
    "Did the Roman Empire fall in 476 CE?",
    "Are tomatoes botanically classified as fruits?",
    "Is the speed of light approximately 300,000 km/s?",
    "Do octopuses have three hearts?",
]
# Answer prefix handed to branch_pmass at every fork point; also recorded in
# pmass.json so downstream analysis knows which prefill produced the numbers.
PREFILL_STR = '\n{"choice": '
# Pos/neg pair: a generic content-vs-refusal direction. Cheap, not the main
# point -- the figure tests calibration *behaviour*, not direction quality.
# Each tuple is (positive assistant reply, negative assistant reply).
POS_NEG = [
    ("Sure, here is the answer:", "I cannot help with that."),
    ("Here are some details:", "Sorry, I can't assist with that."),
    ("Of course, let me explain.", "I won't be able to help."),
    ("Yes, that makes sense.", "No, I have to decline."),
]
# Maps the --method CLI value to the steering config class instantiated in main().
METHOD_MAP = {
    "mean_diff": MeanDiffC,
    "pca": PCAC,
    "directional_ablation": DirectionalAblationC,
}
@dataclass
class Args:
    """Command-line options for one (model, method, seed, window) cell.

    Parsed by ``tyro.cli(Args)`` in the ``__main__`` guard and consumed by ``main``.
    """
    # Model identifier handed to AutoTokenizer / AutoModelForCausalLM.from_pretrained.
    model: str
    # Steering method; looked up in METHOD_MAP, so it must be one of its keys.
    method: str
    # Seed applied to the Python, NumPy and torch RNGs before the model is loaded.
    seed: int = 0
    # Token horizon T: used for calibration, measure_kl and as max_new_tokens for rollouts.
    window: int = 50
    # Output sub-directory name; when empty, main() derives it from model/method/seed/window.
    run_id: str = ""
    # Target layer index is int(layer_frac * num_hidden_layers).
    layer_frac: float = 0.6
    # kl_p95 value that calibrate_iso_kl is asked to hit.
    target_kl: float = 1.0
    # Parent directory for all run outputs (outputs/<run_id>/).
    out_root: str = "outputs"
    # Device string passed to model.to(...) and to the measurement helpers.
    device: str = "cuda"
    # Name of a torch dtype attribute, resolved with getattr(torch, dtype).
    dtype: str = "bfloat16"
    # Multipliers of the calibrated coefficient c_star evaluated after calibration.
    alphas: tuple[float, ...] = (0.25, 0.5, 0.75, 1.0, 1.5, 2.0, 4.0)
    # Spacing of linear fork points; ignored when fork_log is set.
    fork_step: int = 5
    fork_log: bool = False # log-spaced fork points {0,1,2,4,8,...,window}
    n_log_forks: int = 14 # number of log-spaced forks (incl. 0)
    # Enables the rollout + branch_pmass loops; without it pmass.json holds empty lists.
    compute_pmass: bool = False
    skip_pmass_qa: bool = False # skip yes/no pmass loop, only do paired pmass_eval
    render_figs: bool = False # render single-run survival + spaghetti + KL pngs
    # Forwarded to scripts/render_run_figs.py as --threshold.
    render_threshold: float = 0.95
def _set_seed(s: int):
    """Seed the Python, NumPy and torch RNGs (CPU and, if present, every CUDA device) with ``s``."""
    import random

    import numpy as np

    # Same order as before: stdlib, NumPy, then torch.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(s)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(s)
def _build_guided_prompt(tok, user_text: str, schema_hint: str = _SCHEMA) -> str:
    """Render a single-turn chat prompt that ends in an opened ``<think>`` block.

    Mirrors tinymfv/guided.py: the user turn is ``user_text`` followed by the
    schema hint (when non-empty), run through the tokenizer's chat template,
    and suffixed with '<think>\\n' so the model is in thinking mode at every
    fork point. branch_pmass detects this and splices
    '\\nI should answer now.</think>{prefill}' before scoring.
    """
    content = user_text
    if schema_hint:
        content = f"{user_text}\n\n{schema_hint}"
    conversation = [{"role": "user", "content": content}]
    try:
        rendered = tok.apply_chat_template(
            conversation, tokenize=False, add_generation_prompt=True
        )
    except TypeError:
        # Tokenizers whose apply_chat_template rejects add_generation_prompt.
        rendered = tok.apply_chat_template(conversation, tokenize=False)
    return rendered + "<think>\n"
# Decoded argmax strings that count as a "true" / "false" answer.
_TRUE_ARGMAX = frozenset({"true", "True", " true", " True", "1", " 1"})
_FALSE_ARGMAX = frozenset({"false", "False", " false", " False", "0", " 0"})


def _validate_args(a: Args) -> None:
    """Reject options that would otherwise only fail after the model load / calibration."""
    if a.method not in METHOD_MAP:
        # KeyError is what the plain METHOD_MAP[a.method] lookup raised before.
        raise KeyError(f"unknown method {a.method!r}; expected one of {sorted(METHOD_MAP)}")
    if a.window < 1:
        raise ValueError(f"window must be >= 1, got {a.window}")
    if a.fork_log:
        if a.n_log_forks < 1:
            raise ValueError(f"n_log_forks must be >= 1, got {a.n_log_forks}")
    elif a.fork_step < 1:
        raise ValueError(f"fork_step must be >= 1, got {a.fork_step}")


def _fork_points_for(a: Args) -> list[int]:
    """Return the sorted token offsets at which branch_pmass is evaluated."""
    if not a.fork_log:
        return list(range(0, a.window + 1, a.fork_step))
    import numpy as np
    # Geometric spacing from 1 to window; rounding can create duplicates, hence unique().
    raw = np.unique(np.round(np.geomspace(1, a.window, a.n_log_forks - 1)).astype(int))
    candidates = [0] + [int(x) for x in raw]
    return sorted({fp for fp in candidates if fp <= a.window})


def _answer_label(argmax_str: str) -> str:
    """Map a decoded argmax token to 'true', 'false' or 'other'."""
    if argmax_str in _TRUE_ARGMAX:
        return "true"
    if argmax_str in _FALSE_ARGMAX:
        return "false"
    return "other"


def _steered_rollout_pmass(v, model, tok, ids, a: Args, fork_points, a_ids, b_ids):
    """Greedy-decode up to ``a.window`` tokens under steering, then score every fork point.

    Returns ``(gen, pm)``: the generated token ids (prompt stripped) and the
    dict returned by branch_pmass. The rollout's ``past_key_values`` is handed
    to branch_pmass as ``rollout_cache``; whether it is actually reused is
    decided inside branch_pmass (the module docstring says full recompute is
    the default).
    """
    with v(model):
        gen_out = model.generate(
            ids.unsqueeze(0).to(a.device),
            max_new_tokens=a.window,
            pad_token_id=tok.pad_token_id, eos_token_id=tok.eos_token_id,
            do_sample=False,
            use_cache=True,
            return_dict_in_generate=True,
        )
    gen = gen_out.sequences[0, ids.shape[0]:]
    # NOTE(review): branch_pmass receives `v` and is assumed to apply steering
    # itself, so it is called after the `with v(model)` block has exited --
    # confirm against iso_kl_figure.branch_pmass.
    pm = branch_pmass(
        v, model, tok, ids, gen, fork_points,
        PREFILL_STR, a_ids, b_ids,
        rollout_cache=gen_out.past_key_values,
        device=a.device,
    )
    return gen, pm


def _record_debug_first(debug_first, alpha, kind, prompt_str, gen, pm, tok, fork_points) -> None:
    """Store and log the generated text plus per-fork pmass for the first prompt of a set."""
    gen_text = tok.decode(gen, skip_special_tokens=False)
    debug_first.setdefault(str(alpha), {})[kind] = {
        "prompt": prompt_str, "gen_text": gen_text, "gen_len": int(gen.shape[0]),
        "pmass_per_fork": pm["pmass"], "p_true_per_fork": pm["p_true"],
        "argmax_per_fork": pm["argmax_str"],
    }
    logger.info(f" [debug] alpha={alpha} {kind}[0] gen_len={gen.shape[0]} text[:120]={gen_text[:120]!r}")
    for t, pmv, ptv, am in zip(fork_points, pm["pmass"], pm["p_true"], pm["argmax_str"]):
        logger.info(f" t={t:>3} pmass={pmv:.3f} p_true={ptv:.3f} argmax={am!r}")


def _log_pmass_at_t0(alpha, label: str, per_prompt_pmass) -> None:
    """Log mean/min/max of pmass at the first fork point across prompts.

    SHOULD: at alpha=1, mean(pmass at t=0) > 0.5 (model still respects schema).
    ELSE: prefill string broken or chat template off.
    """
    import numpy as np
    first = [row[0] for row in per_prompt_pmass if len(row) > 0]
    t0 = np.asarray(first, dtype=float)
    t0 = t0[np.isfinite(t0)]
    if t0.size == 0:
        # Say so explicitly instead of swallowing the empty-array error.
        logger.warning(f" alpha={alpha} pmass@t=0 ({label}): no finite values to summarise")
        return
    logger.info(
        f" alpha={alpha} pmass@t=0 ({label}): "
        f"mean={t0.mean():.3f} min={t0.min():.3f} max={t0.max():.3f}"
    )


def main(a: Args):
    """Run one (model, method, seed, window) cell and write its artefacts.

    Steps: validate options, load model + tokenizer, train the steering
    vector on the CALIB_PROMPTS/POS_NEG pairs, calibrate the coefficient with
    calibrate_iso_kl, then for every alpha measure KL on the guided
    EVAL_PROMPTS and (with --compute-pmass) roll out and score fork points.
    Writes history.json, calib.json, trajectory.json, pmass.json and
    results.csv under ``<out_root>/<run_id>/`` and optionally renders figures.

    Raises:
        KeyError: ``a.method`` is not a key of METHOD_MAP.
        ValueError: ``a.window``, ``a.fork_step`` or ``a.n_log_forks`` is not positive
            (the latter two only in the fork mode that uses them).
        subprocess.CalledProcessError: the figure-rendering subprocess failed.
    """
    # Fail fast: these used to surface only after the model load / calibration.
    _validate_args(a)

    if not a.run_id:
        a.run_id = f"{a.model.split('/')[-1]}_{a.method}_s{a.seed}_w{a.window}"
    out_dir = Path(a.out_root) / a.run_id
    out_dir.mkdir(parents=True, exist_ok=True)
    logger.add(out_dir / "run.log", level="INFO")
    _set_seed(a.seed)

    from transformers import AutoModelForCausalLM, AutoTokenizer
    dtype = getattr(torch, a.dtype)
    tok = AutoTokenizer.from_pretrained(a.model)
    if tok.pad_token_id is None:
        tok.pad_token_id = tok.eos_token_id
    model = AutoModelForCausalLM.from_pretrained(a.model, torch_dtype=dtype).to(a.device)
    model.eval()

    n_layers = model.config.num_hidden_layers
    layer = int(a.layer_frac * n_layers)
    logger.info(f"model={a.model} n_layers={n_layers} target_layer={layer}")

    cfg_cls = METHOD_MAP[a.method]
    cfg = cfg_cls(coeff=1.0, layers=(layer,))
    # One chat-formatted (user, assistant) text per CALIB_PROMPTS/POS_NEG pair.
    pos = [tok.apply_chat_template([{"role": "user", "content": u},
                                    {"role": "assistant", "content": p}],
                                   tokenize=False)
           for u, (p, _) in zip(CALIB_PROMPTS, POS_NEG)]
    neg = [tok.apply_chat_template([{"role": "user", "content": u},
                                    {"role": "assistant", "content": n}],
                                   tokenize=False)
           for u, (_, n) in zip(CALIB_PROMPTS, POS_NEG)]
    v = train(model, tok, pos, neg, cfg, batch_size=4, max_length=128)

    logger.info("=== calibrate ===")
    c_star, history = calibrate_iso_kl(
        v, model, tok, CALIB_PROMPTS,
        target_kl=a.target_kl, target_stat="kl_p95",
        T=a.window, device=a.device,
    )
    v.cfg.coeff = c_star
    logger.info(f"c_star = {c_star:+.4f}")

    # Strip per_prompt_per_t from history to keep file size small.
    hist_slim = [{k: val for k, val in h.items() if k != "per_prompt_per_t"}
                 for h in history]
    (out_dir / "history.json").write_text(json.dumps(hist_slim, indent=2))
    (out_dir / "calib.json").write_text(json.dumps({
        "c_star": c_star, "target_kl": a.target_kl, "window": a.window,
        "method": a.method, "model": a.model, "seed": a.seed, "layer": layer,
    }, indent=2))

    # Token ids whose probability mass is read as the "true" / "false" answer.
    a_ids, b_ids = collect_choice_token_ids(tok)
    if a.compute_pmass:
        logger.info(f"choice ids: a(true)={a_ids} b(false)={b_ids}")

    # -- trajectory + pmass at each alpha on held-out prompts
    rows = []
    fork_points = _fork_points_for(a)
    logger.info(f"fork_points (n={len(fork_points)}): {fork_points}")

    trajectory: dict[str, list] = {}
    per_prompt_traj: dict[str, list] = {}
    pmass_all: dict[str, list] = {}
    pmass_eval_all: dict[str, list] = {}  # pmass paired with EVAL_PROMPTS (same prompts as KL)
    p_true_all: dict[str, list] = {}
    argmax_all: dict[str, list] = {}
    thinking_all: dict[str, list] = {}
    answer_label_all: dict[str, list] = {}
    gen_lens_qa: dict[str, list] = {}  # T per (alpha, qa-prompt) -- right-censoring info
    gen_lens_eval: dict[str, list] = {}  # T per (alpha, eval-prompt)
    debug_first: dict[str, dict] = {}  # first prompt per alpha: gen_text + per-fork pmass/argmax

    # Pre-build guided EVAL ids ONCE so measure_kl and pmass_eval roll out the
    # same prompt -- otherwise spaghetti coloring (KL trajectory colored by
    # pmass) is meaningless because the rollouts diverge.
    eval_prompt_strs = [_build_guided_prompt(tok, p, _SCHEMA) for p in EVAL_PROMPTS]
    eval_ids_list = [
        tok(s, return_tensors="pt", add_special_tokens=False).input_ids[0]
        for s in eval_prompt_strs
    ]

    for alpha in a.alphas:
        key = str(alpha)
        v.cfg.coeff = alpha * c_star
        logger.info(f"=== eval alpha={alpha} c={v.cfg.coeff:+.4f} ===")
        m = measure_kl(v, model, tok, eval_ids_list, T=a.window, device=a.device)
        trajectory[key] = m["per_t_p95"]
        per_prompt_traj[key] = m["per_prompt_per_t"]
        rows.append({"alpha": alpha, "coeff": v.cfg.coeff, "kl_p95": m["kl_p95"],
                     "kl_mean": m["kl_mean"], "kl_max": m["kl_max"]})

        # Legacy yes/no probe on _QUESTIONS (skipped with --skip-pmass-qa).
        run_qa = a.compute_pmass and not a.skip_pmass_qa
        pm_for_alpha, pt_for_alpha, ax_for_alpha, wt_for_alpha, ans_for_alpha = [], [], [], [], []
        gen_lens_qa_alpha = []
        if run_qa:
            for q_idx, q in enumerate(_QUESTIONS):
                prompt_str = _build_guided_prompt(tok, q, _SCHEMA)
                ids = tok(prompt_str, return_tensors="pt", add_special_tokens=False).input_ids[0]
                gen, pm = _steered_rollout_pmass(v, model, tok, ids, a, fork_points, a_ids, b_ids)
                gen_lens_qa_alpha.append(int(gen.shape[0]))
                pm_for_alpha.append(pm["pmass"])
                pt_for_alpha.append(pm["p_true"])
                ax_for_alpha.append(pm["argmax_str"])
                wt_for_alpha.append(pm["was_thinking"])
                ans_for_alpha.append([_answer_label(s) for s in pm["argmax_str"]])
                # debug dump for first QA prompt only
                if q_idx == 0:
                    _record_debug_first(debug_first, alpha, "qa", prompt_str, gen, pm, tok, fork_points)
        # Always recorded, so a skipped probe shows up as empty lists in pmass.json.
        pmass_all[key] = pm_for_alpha
        p_true_all[key] = pt_for_alpha
        argmax_all[key] = ax_for_alpha
        thinking_all[key] = wt_for_alpha
        answer_label_all[key] = ans_for_alpha
        gen_lens_qa[key] = gen_lens_qa_alpha

        # Paired pmass on EVAL_PROMPTS (the same guided prompts measure_kl saw)
        # so KL trajectories can be colored by pmass at each fork. The prefill
        # asks "if you committed now, can you still produce a valid choice?"
        # -- which is exactly the coherence signal we want.
        pm_eval_for_alpha = []
        gen_lens_eval_alpha = []
        if a.compute_pmass:
            for p_idx, (prompt_str, ids) in enumerate(zip(eval_prompt_strs, eval_ids_list)):
                gen, pm = _steered_rollout_pmass(v, model, tok, ids, a, fork_points, a_ids, b_ids)
                gen_lens_eval_alpha.append(int(gen.shape[0]))
                pm_eval_for_alpha.append(pm["pmass"])
                if p_idx == 0:
                    _record_debug_first(debug_first, alpha, "eval", prompt_str, gen, pm, tok, fork_points)
        pmass_eval_all[key] = pm_eval_for_alpha
        gen_lens_eval[key] = gen_lens_eval_alpha

        # Sanity check on whichever prompt sets actually ran. Reading only the
        # QA list made this a silent no-op under --skip-pmass-qa.
        if a.compute_pmass:
            if run_qa:
                _log_pmass_at_t0(alpha, "qa", pm_for_alpha)
            _log_pmass_at_t0(alpha, "eval", pm_eval_for_alpha)

    (out_dir / "trajectory.json").write_text(json.dumps({
        "fork_points_full": list(range(a.window)),
        "per_t_p95_kl": trajectory,
        "per_prompt_per_t_kl": per_prompt_traj,
    }, indent=2))
    (out_dir / "pmass.json").write_text(json.dumps({
        "fork_points": fork_points,
        "pmass": pmass_all,
        "pmass_eval": pmass_eval_all,
        "p_true": p_true_all,
        "argmax_str": argmax_all,
        "was_thinking": thinking_all,
        "answer_label": answer_label_all,
        "gen_lens_qa": gen_lens_qa,  # T per (alpha, qa-prompt) for right-censoring
        "gen_lens_eval": gen_lens_eval,  # T per (alpha, eval-prompt) for right-censoring
        "debug_first": debug_first,  # first prompt per alpha: gen_text + per-fork pmass/argmax
        "prefill": PREFILL_STR,
        "schema": _SCHEMA,
        "questions": _QUESTIONS,
        "computed": a.compute_pmass,
    }, indent=2))

    import csv
    with open(out_dir / "results.csv", "w", newline="") as f:
        w = csv.DictWriter(f, fieldnames=["alpha", "coeff", "kl_p95", "kl_mean", "kl_max"])
        w.writeheader()
        w.writerows(rows)

    if a.render_figs:
        repo_root = Path(__file__).resolve().parents[1]
        cmd = [
            sys.executable,
            "scripts/render_run_figs.py",
            # Absolute path: the child runs with cwd=repo_root, so a relative
            # out_dir would point at the wrong directory whenever this script
            # was launched from somewhere else.
            "--run-dir", str(out_dir.resolve()),
            "--threshold", str(a.render_threshold),
            "--model-contains", a.model.split("/")[-1],
        ]
        logger.info("rendering single-run figures")
        subprocess.run(cmd, check=True, cwd=repo_root)
    logger.info(f"DONE -> {out_dir}")
# Script entry point: tyro builds the CLI from the Args dataclass and main() runs the cell.
if __name__ == "__main__":
    main(tyro.cli(Args))