mirror of
https://github.com/wassname/isokl_steering_calibration.git
synced 2026-06-27 15:15:52 +08:00
244 lines
9.8 KiB
Python
244 lines
9.8 KiB
Python
"""Empirical audit of branch_pmass: load model, generate one rollout per alpha,
print decoded gen text and top-k tokens at the prefill end-point, and dump
everything to JSON for review.

Goal: distinguish between
  (a) pmass measurement is wrong (top tokens at prefill end DON'T match the
      pmass we read out of the schema-token groups), vs
  (b) pmass is right but the model just doesn't put mass on schema tokens
      naturally (steering isn't the problem, the prompt+prefill is), vs
  (c) pmass is right and steering really does collapse coherence.

For first prompt per alpha:
  - decoded full generation text
  - per fork point: top-10 tokens with prob, plus pmass(true)+pmass(false), p_true

Usage:
    uv run --extra all python scripts/audit_pmass.py \
        --model Qwen/Qwen3.5-0.8B --window 64 --out audit_pmass.json
"""
|
|
from __future__ import annotations
|
|
import json
|
|
from dataclasses import dataclass, field
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
import tyro
|
|
from loguru import logger
|
|
|
|
# Import the live module so any audit reflects current code
|
|
from iso_kl_figure import (
|
|
MeanDiffC, PCAC, DirectionalAblationC,
|
|
train, calibrate_iso_kl,
|
|
)
|
|
from iso_kl_figure.branch_pmass import collect_choice_token_ids, branch_pmass
|
|
|
|
|
|
# Load the sibling run_cell.py by file path (it is a script, not an installed
# package) so this audit shares exactly the prompts/schema that run_cell uses.
import importlib.util
import sys

_rc_path = Path(__file__).with_name("run_cell.py")
_spec = importlib.util.spec_from_file_location("_run_cell", _rc_path)
_rc = importlib.util.module_from_spec(_spec)
# Register the module before executing it, so lookups of "_run_cell" in
# sys.modules made while its body runs can resolve.
sys.modules["_run_cell"] = _rc
_spec.loader.exec_module(_rc)

# Prompt sets re-exported from run_cell.
CALIB_PROMPTS, EVAL_PROMPTS = _rc.CALIB_PROMPTS, _rc.EVAL_PROMPTS
# Question list, answer schema and prefill string re-exported from run_cell.
_QUESTIONS, _SCHEMA, PREFILL_STR = _rc._QUESTIONS, _rc._SCHEMA, _rc.PREFILL_STR
# (positive, negative) assistant completions, zipped against CALIB_PROMPTS in main().
POS_NEG = _rc.POS_NEG

# CLI method name -> steering config class imported from iso_kl_figure.
METHOD_MAP = dict(
    mean_diff=MeanDiffC,
    pca=PCAC,
    directional_ablation=DirectionalAblationC,
)
|
|
|
|
|
|
@dataclass
class Args:
    """Command-line options for the branch_pmass audit (parsed by ``tyro.cli``)."""

    # --- model / steering setup ---
    model: str = "Qwen/Qwen3.5-0.8B"
    # Key into METHOD_MAP selecting the steering config class.
    method: str = "mean_diff"
    seed: int = 0
    # Max new tokens per rollout; also passed as T to calibrate_iso_kl.
    window: int = 64
    # Steered layer index is int(layer_frac * num_hidden_layers).
    layer_frac: float = 0.6
    target_kl: float = 1.0
    device: str = "cuda"
    # Name of a torch dtype attribute, resolved with getattr(torch, dtype).
    dtype: str = "bfloat16"

    # --- sweep ---
    # Multipliers applied to the calibrated c_star (one rollout per entry).
    alphas: tuple[float, ...] = (0.0, 0.25, 0.5, 0.75, 1.0, 1.5, 2.0, 4.0)
    # Number of generated tokens kept before the prefill is appended.
    fork_points: tuple[int, ...] = (0, 8, 16, 32, 64)

    # --- output / inspection ---
    out: str = "audit_pmass.json"
    top_k: int = 10
    # True: yes/no q+_SCHEMA; False: long-form EVAL_PROMPTS[0]
    use_qa_prompt: bool = True
    # if True, use fixed_coeffs as raw c per alpha (no iso-KL bisection)
    skip_calib: bool = False
    # used when skip_calib=True; aligned with --alphas
    fixed_coeffs: tuple[float, ...] = (0.0, 0.25, 0.5, 0.75, 1.0, 1.5, 2.0, 4.0)
    # also call branch_pmass and compare to local recompute
    cross_check_branch_pmass: bool = True
|
|
|
|
|
|
def _set_seed(s: int):
    """Seed Python's ``random``, NumPy and torch with ``s``.

    Also seeds every CUDA device when CUDA is available.
    """
    import random

    import numpy as np

    # Same order as before: stdlib RNG, NumPy global RNG, torch CPU RNG.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(s)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(s)
|
|
|
|
|
|
@torch.no_grad()
def topk_at_prefill_end(v, model, tok, prompt_ids, rolled_ids, fork_points,
                        prefill_str, a_ids, b_ids, k=10, device="cuda"):
    """Mirror branch_pmass logic but ALSO return top-k tokens for inspection.

    For each fork point ``t`` the sequence ``prompt_ids + rolled_ids[:t] +
    encode(prefill_str)`` is run through ``model`` inside ``with v(model)``, and
    the next-token distribution at the last position is summarised.

    Args:
        v: callable such that ``v(model)`` is a context manager (the steering
            wrapper in this script).
        model: callable whose result exposes ``.logits``; indexed as
            ``logits[0, -1]``.
        tok: tokenizer providing ``encode(str, add_special_tokens=False)`` and
            ``decode(list_of_ids)``.
        prompt_ids: 1-D tensor of prompt token ids.
        rolled_ids: 1-D tensor of generated token ids (the rollout).
        fork_points: iterable of prefix lengths into ``rolled_ids``.
        prefill_str: text appended after the prefix before reading the logits.
        a_ids: token ids of the "true" group.
        b_ids: token ids of the "false" group.
        k: number of top tokens to report; clamped to the vocabulary size so an
            oversized value cannot make ``topk`` raise.
        device: device the id tensors are moved to / created on.

    Returns:
        One dict per fork point, in order. Fork points with ``t`` greater than
        the rollout length yield ``{"t", "skipped": True, "reason"}``; all
        others yield ``t``, ``skipped`` (False), ``topk`` (list of
        ``(decoded_token, prob)``), ``p_true_group``, ``p_false_group``,
        ``pmass`` (their sum), ``p_true`` (``p_true_group / pmass``, NaN when
        ``pmass`` is 0) and ``argmax`` (decoded most-likely token).
    """
    pids = prompt_ids.to(device)
    rolled = rolled_ids.to(device)
    T = rolled.shape[0]
    pre_t = torch.tensor(tok.encode(prefill_str, add_special_tokens=False),
                         device=device, dtype=torch.long)
    a_t = torch.tensor(list(a_ids), dtype=torch.long, device=device)
    b_t = torch.tensor(list(b_ids), dtype=torch.long, device=device)

    out_per_fork = []
    for t in fork_points:
        if t > T:
            # The rollout is shorter than this fork point: nothing to measure.
            out_per_fork.append({"t": int(t), "skipped": True, "reason": f"t>T={T}"})
            continue

        seq = torch.cat([pids, rolled[:t], pre_t]).unsqueeze(0)
        with v(model):
            # Upcast so the softmax is not computed in the model's low precision.
            logits = model(seq).logits[0, -1].float()
        probs = torch.softmax(logits, dim=-1)

        # top-k (never ask for more entries than the vocabulary has)
        tk_p, tk_i = probs.topk(min(int(k), probs.shape[-1]))
        topk = [(tok.decode([int(i)]), float(p))
                for p, i in zip(tk_p.tolist(), tk_i.tolist())]

        pa = float(probs[a_t].sum())
        pb = float(probs[b_t].sum())
        pm = pa + pb
        pt = pa / pm if pm > 0 else float("nan")
        out_per_fork.append({
            "t": int(t),
            "skipped": False,
            "topk": topk,
            "p_true_group": pa,
            "p_false_group": pb,
            "pmass": pm,
            "p_true": pt,
            "argmax": tok.decode([int(probs.argmax())]),
        })
    return out_per_fork
|
|
|
|
|
|
def main(a: Args):
    """Run the pmass audit for a single prompt across all alphas and dump JSON.

    Pipeline:
      1. Seed RNGs, load tokenizer + model, build the steering config for the
         layer at ``int(layer_frac * num_hidden_layers)``.
      2. Build positive/negative chat-formatted texts from CALIB_PROMPTS and
         POS_NEG and pass them to ``train`` to obtain the steering object ``v``.
      3. Unless ``skip_calib`` is set, obtain ``c_star`` from
         ``calibrate_iso_kl``; otherwise ``fixed_coeffs`` are used as raw coeffs.
      4. For every alpha: greedy-generate under ``with v(model)``, recompute
         top-k / pmass per fork point via ``topk_at_prefill_end``, optionally
         cross-check against ``branch_pmass``, and log a summary.
      5. Write everything to ``a.out`` as JSON.

    Args:
        a: parsed command-line options.

    Raises:
        ValueError: if ``skip_calib`` is set but ``fixed_coeffs`` has fewer
            entries than ``alphas`` (checked before any model is loaded).
    """
    # Fail fast: previously this surfaced as an IndexError in the middle of the
    # sweep, after the model load and earlier rollouts had already been paid for.
    if a.skip_calib and len(a.fixed_coeffs) < len(a.alphas):
        raise ValueError(
            f"skip_calib needs one fixed_coeff per alpha: got "
            f"{len(a.fixed_coeffs)} fixed_coeffs for {len(a.alphas)} alphas"
        )

    _set_seed(a.seed)
    from transformers import AutoModelForCausalLM, AutoTokenizer

    dtype = getattr(torch, a.dtype)
    logger.info(f"loading model={a.model}")
    tok = AutoTokenizer.from_pretrained(a.model)
    if tok.pad_token_id is None:
        # generate() below is given pad_token_id explicitly; fall back to EOS.
        tok.pad_token_id = tok.eos_token_id
    model = AutoModelForCausalLM.from_pretrained(a.model, torch_dtype=dtype).to(a.device)
    model.eval()

    n_layers = model.config.num_hidden_layers
    layer = int(a.layer_frac * n_layers)
    cfg_cls = METHOD_MAP[a.method]
    cfg = cfg_cls(coeff=1.0, layers=(layer,))

    # Train + calibrate
    pos = [tok.apply_chat_template([{"role": "user", "content": u},
                                    {"role": "assistant", "content": p}], tokenize=False)
           for u, (p, _) in zip(CALIB_PROMPTS, POS_NEG)]
    neg = [tok.apply_chat_template([{"role": "user", "content": u},
                                    {"role": "assistant", "content": n}], tokenize=False)
           for u, (_, n) in zip(CALIB_PROMPTS, POS_NEG)]
    v = train(model, tok, pos, neg, cfg, batch_size=4, max_length=128)
    if a.skip_calib:
        c_star = 1.0  # fixed_coeffs are absolute
        logger.info(f"skip_calib: using fixed_coeffs={a.fixed_coeffs} as raw c")
    else:
        c_star, _ = calibrate_iso_kl(v, model, tok, CALIB_PROMPTS,
                                     target_kl=a.target_kl, target_stat="kl_p95",
                                     T=a.window, device=a.device)
        logger.info(f"c_star={c_star:+.4f}")

    a_ids, b_ids = collect_choice_token_ids(tok)
    logger.info(f"a_ids (true-group)={a_ids} -> tokens={[tok.decode([i]) for i in a_ids]}")
    logger.info(f"b_ids (false-group)={b_ids} -> tokens={[tok.decode([i]) for i in b_ids]}")

    # Single prompt
    if a.use_qa_prompt:
        prompt = f"{_QUESTIONS[0]}\n\n{_SCHEMA}"
    else:
        prompt = EVAL_PROMPTS[0]
    logger.info(f"prompt: {prompt[:120]}...")

    # return_dict=True guarantees a mapping with .input_ids; without it some
    # transformers versions return the bare id tensor and .input_ids would fail.
    ids = tok.apply_chat_template([{"role": "user", "content": prompt}],
                                  add_generation_prompt=True,
                                  return_tensors="pt",
                                  return_dict=True).input_ids[0]

    out = {
        "model": a.model, "method": a.method, "seed": a.seed,
        "c_star": c_star, "layer": layer,
        "prompt": prompt, "use_qa_prompt": a.use_qa_prompt,
        "prefill": PREFILL_STR,
        "a_ids": a_ids, "b_ids": b_ids,
        "a_tokens": [tok.decode([i]) for i in a_ids],
        "b_tokens": [tok.decode([i]) for i in b_ids],
        "fork_points": list(a.fork_points),
        "alphas": {},
    }

    for i_alpha, alpha in enumerate(a.alphas):
        if a.skip_calib:
            v.cfg.coeff = a.fixed_coeffs[i_alpha]
        else:
            v.cfg.coeff = alpha * c_star
        logger.info(f"=== alpha={alpha} coeff={v.cfg.coeff:+.4f} ===")

        # Greedy rollout under steering. The helpers below enter v(model)
        # themselves, so they are called outside this context.
        with v(model):
            gen_out = model.generate(
                ids.unsqueeze(0).to(a.device),
                max_new_tokens=a.window,
                pad_token_id=tok.pad_token_id,
                eos_token_id=tok.eos_token_id,
                do_sample=False,
                return_dict_in_generate=True,
            )
        gen = gen_out.sequences[0, ids.shape[0]:]  # strip the prompt tokens
        gen_text = tok.decode(gen, skip_special_tokens=False)
        gen_len = int(gen.shape[0])
        logger.info(f" gen_len={gen_len} text[:200]={gen_text[:200]!r}")

        per_fork = topk_at_prefill_end(
            v, model, tok, ids, gen, list(a.fork_points),
            PREFILL_STR, a_ids, b_ids, k=a.top_k, device=a.device,
        )

        # cross-check: compare to production branch_pmass output
        bp_compare = None
        if a.cross_check_branch_pmass:
            bp = branch_pmass(
                v, model, tok, ids, gen, list(a.fork_points),
                PREFILL_STR, a_ids, b_ids,
                rollout_cache=getattr(gen_out, "past_key_values", None),
                device=a.device,
            )
            bp_compare = {
                "pmass": bp["pmass"], "p_true": bp["p_true"],
                "argmax_str": bp["argmax_str"], "was_thinking": bp["was_thinking"],
            }
            # NOTE(review): assumes branch_pmass returns one entry per fork
            # point in the same order as per_fork - confirm against branch_pmass.
            for row, bpm, _bpt, bam in zip(per_fork, bp["pmass"], bp["p_true"], bp["argmax_str"]):
                if row.get("skipped"):
                    continue
                local_pm = row["pmass"]
                local_am = row["argmax"]
                # Agreement = pmass within 1e-3 AND identical argmax string.
                agrees = abs(local_pm - bpm) < 1e-3 and local_am == bam
                tag = "OK" if agrees else "MISMATCH"
                logger.info(f" cross-check t={row['t']:>3} {tag}: local pm={local_pm:.4f} argmax={local_am!r} | branch_pmass pm={bpm:.4f} argmax={bam!r}")

        # log inline
        for row in per_fork:
            if row.get("skipped"):
                logger.info(f" t={row['t']:>3} SKIPPED ({row['reason']})")
            else:
                top3 = ", ".join(f"{tok_str!r}={p:.3f}" for tok_str, p in row["topk"][:3])
                logger.info(f" t={row['t']:>3} pmass={row['pmass']:.3f} "
                            f"p_true={row['p_true']:.3f} argmax={row['argmax']!r} top3=[{top3}]")

        out["alphas"][str(alpha)] = {
            "coeff": float(v.cfg.coeff),
            "gen_text": gen_text, "gen_len": gen_len, "per_fork": per_fork,
            "branch_pmass_compare": bp_compare,
        }

    # default=str stringifies anything json cannot encode natively.
    Path(a.out).write_text(json.dumps(out, indent=2, default=str))
    logger.info(f"DONE -> {a.out}")
|
|
|
|
|
|
# Script entry point: parse CLI flags into Args with tyro, then run the audit.
if __name__ == "__main__":
    cli_args = tyro.cli(Args)
    main(cli_args)
|