Add batched data gen and bidir calibration

This commit is contained in:
wassname
2026-05-01 18:58:08 +08:00
parent b2ef8fef7b
commit a0f4e719af
7 changed files with 535 additions and 133 deletions
+7
View File
@@ -49,6 +49,13 @@ def main(cfg: SmokeCfg) -> None:
n_topics=2, # 2×1×2 = 4 pairs
n_personas=1,
n_samples=2,
data_batch_size=2,
data_min_new_tokens=16,
data_max_new_tokens=32,
data_temperature=0.7,
data_top_p=0.8,
data_top_k=20,
data_min_p=0.0,
)
replicate_main(rcfg)
print("[smoke] OK", flush=True)
+5 -1
View File
@@ -31,7 +31,11 @@ smoke-sweep:
# Generate +/- pair data for a behavior. Writes to out/data/{behavior}/.
data:
uv run python -m ws.data --model {{model}} --behavior {{behavior}}
uv run python -m ws.data --model-id {{model}} --behavior {{behavior}}
# One-off greedy persona collapse debugger.
debug-personas:
uv run python -m ws.debug_personas --model {{model}} --behavior {{behavior}} --out {{out}}
# Train a single adapter (positive or negative). Pos/neg controls system prompt at gen time.
train sign="pos":
+200 -25
View File
@@ -18,16 +18,19 @@ Output columns:
from __future__ import annotations
import json
from dataclasses import dataclass
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any
import torch
import tyro
from datasets import Dataset
from loguru import logger
from tqdm.auto import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from ws._tok_extras import chat_template_extras
from ws._log import get_argv, setup_logging
from ws._tok_extras import chat_template_extras, has_thinking_mode
REPO_ROOT = Path(__file__).resolve().parents[2]
DATA_DIR = REPO_ROOT / "data"
@@ -162,8 +165,14 @@ class DataCfg:
n_personas: int = 5
n_samples: int = 10
out: Path = Path("out/data")
max_new_tokens: int = 96
temperature: float = 0.8
batch_size: int = 8
min_new_tokens: int = 1024
max_new_tokens: int = 1280
temperature: float | None = None
top_p: float | None = None
top_k: int | None = None
min_p: float | None = None
presence_penalty: float = 0.0
seed: int = 0
@@ -212,21 +221,127 @@ def _build_specs(topics, n_personas: int, n_samples: int, behavior: str):
return specs
@torch.no_grad()
def _gen(model, tok, sys_prompt: str, user_prompt: str, max_new_tokens: int, temperature: float):
def _render_chat_prompt(tok, sys_prompt: str, user_prompt: str) -> str:
msgs = [{"role": "system", "content": sys_prompt}, {"role": "user", "content": user_prompt}]
text = tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True,
**chat_template_extras(tok))
inputs = tok(text, return_tensors="pt").to(model.device)
out = model.generate(
**inputs,
max_new_tokens=max_new_tokens,
do_sample=temperature > 0,
temperature=temperature if temperature > 0 else 1.0,
pad_token_id=tok.pad_token_id or tok.eos_token_id,
return tok.apply_chat_template(
msgs,
tokenize=False,
add_generation_prompt=True,
**chat_template_extras(tok),
)
gen = out[0, inputs["input_ids"].shape[1]:]
return tok.decode(gen, skip_special_tokens=True).strip()
def _sampling_defaults(tok, cfg: DataCfg) -> dict[str, Any]:
    """Resolve the sampling parameters used for data generation.

    Every sampling field on ``cfg`` that is ``None`` is replaced by a
    fallback; which fallback table applies depends on whether
    ``has_thinking_mode(tok)`` is true. Explicit values on ``cfg`` always win.

    Returns:
        A dict with the keys ``temperature``, ``top_p``, ``top_k`` and
        ``min_p`` (in that order).
    """
    if has_thinking_mode(tok):
        fallback = {"temperature": 0.6, "top_p": 0.95, "top_k": 20, "min_p": 0.0}
    else:
        fallback = {"temperature": 0.7, "top_p": 0.8, "top_k": 20, "min_p": 0.0}

    resolved: dict[str, Any] = {}
    for name, default in fallback.items():
        override = getattr(cfg, name)
        resolved[name] = default if override is None else override

    # A non-positive temperature switches the caller to greedy decoding.
    if resolved["temperature"] <= 0:
        logger.warning(
            "data generation temperature<=0 enables greedy decoding. "
            "Qwen recommends sampling for training data; use this only deliberately."
        )
    # presence_penalty is accepted on the config but is not forwarded anywhere.
    if cfg.presence_penalty:
        logger.warning(
            "presence_penalty is configured but this HF generate path does not support it directly; ignoring."
        )
    return resolved
def _trim_generation_ids(ids: torch.Tensor, eos_id: int | None, pad_id: int | None) -> tuple[torch.Tensor, bool]:
trimmed: list[int] = []
hit_eos = False
for tok_id in ids.tolist():
if pad_id is not None and tok_id == pad_id:
continue
trimmed.append(tok_id)
if eos_id is not None and tok_id == eos_id:
hit_eos = True
break
return torch.tensor(trimmed, dtype=torch.long), hit_eos
def _normalize_text(text: str) -> str:
    """Collapse every whitespace run to one space and drop leading/trailing whitespace."""
    words = text.split()
    return " ".join(words)
def _log_trace(tok, *, prompt_text: str, gen_ids: torch.Tensor, clean_text: str, label: str) -> None:
    """Log one prompt/continuation pair for manual inspection.

    Emits the rendered prompt, its token count with the first and last (up to)
    eight tokens, the continuation decoded with special tokens kept, the
    continuation's token strings, and the cleaned text.
    """
    prompt_ids = tok(prompt_text, return_tensors="pt").input_ids[0]
    window = min(8, len(prompt_ids))
    head_ids = prompt_ids[:window].tolist()
    first = tok.convert_ids_to_tokens(head_ids)
    tail_ids = prompt_ids[-window:].tolist()
    last = tok.convert_ids_to_tokens(tail_ids)

    raw_gen = tok.decode(gen_ids, skip_special_tokens=False)
    # Skip the tokenizer call entirely for an empty continuation.
    if len(gen_ids):
        raw_toks = tok.convert_ids_to_tokens(gen_ids.tolist())
    else:
        raw_toks = []

    logger.info(f"[{label}] full prompt (special tokens included):\n{prompt_text}")
    logger.info(f"[{label}] n_input_tokens={prompt_ids.shape[0]} first8={first} last8={last}")
    logger.info(f"[{label}] raw generated continuation: {raw_gen!r}")
    logger.info(f"[{label}] generated tokens: {raw_toks}")
    logger.info(f"[{label}] cleaned continuation: {clean_text!r}")
@torch.no_grad()
def _generate_batch(
    model,
    tok,
    prompts: list[str],
    cfg: DataCfg,
    sampling: dict[str, Any],
    *,
    trace_label: str,
) -> list[dict[str, Any]]:
    """Generate one continuation per prompt, in left-padded batches.

    Args:
        model: Model exposing ``generate`` and ``device``.
        tok: Tokenizer; its ``padding_side`` is switched to ``"left"`` for the
            duration of the call and restored afterwards, even on error.
        prompts: Already-rendered prompt strings.
        cfg: Supplies ``batch_size``, ``min_new_tokens`` and ``max_new_tokens``.
        sampling: Dict with ``temperature``, ``top_p``, ``top_k`` and ``min_p``.
            A temperature ``<= 0`` disables sampling.
        trace_label: Label used for the progress bar and the logged trace.

    Returns:
        One dict per prompt, in input order, with the keys ``clean_text``,
        ``raw_text``, ``gen_ids``, ``hit_eos`` and ``prompt_text``.

    Raises:
        ValueError: If ``max_new_tokens < min_new_tokens`` or ``batch_size < 1``.
    """
    if cfg.max_new_tokens < cfg.min_new_tokens:
        raise ValueError(
            f"max_new_tokens={cfg.max_new_tokens} must be >= min_new_tokens={cfg.min_new_tokens}"
        )
    # A zero step crashes inside range() and a negative step silently yields
    # no batches; reject both up front with a clear message.
    if cfg.batch_size < 1:
        raise ValueError(f"batch_size={cfg.batch_size} must be >= 1")

    # `pad or eos` would discard a legitimate pad id of 0, so test for None.
    pad_id = tok.pad_token_id if tok.pad_token_id is not None else tok.eos_token_id
    do_sample = sampling["temperature"] > 0

    results: list[dict[str, Any]] = []
    old_padding_side = tok.padding_side
    # Left padding keeps every prompt flush against its generated tokens, so
    # the continuation can be sliced off at a single shared column.
    tok.padding_side = "left"
    try:
        for start in tqdm(range(0, len(prompts), cfg.batch_size), desc=f"gen {trace_label}", mininterval=60):
            batch_prompts = prompts[start:start + cfg.batch_size]
            enc = tok(batch_prompts, return_tensors="pt", padding=True).to(model.device)
            out = model.generate(
                **enc,
                min_new_tokens=cfg.min_new_tokens,
                max_new_tokens=cfg.max_new_tokens,
                do_sample=do_sample,
                temperature=sampling["temperature"] if do_sample else 1.0,
                top_p=sampling["top_p"],
                top_k=sampling["top_k"],
                min_p=sampling["min_p"],
                pad_token_id=pad_id,
                eos_token_id=tok.eos_token_id,
            )
            # Drop the (padded) prompt columns; what remains is new tokens only.
            gen_block = out[:, enc["input_ids"].shape[1]:].cpu()
            for i, prompt_text in enumerate(batch_prompts):
                raw_ids, hit_eos = _trim_generation_ids(gen_block[i], tok.eos_token_id, tok.pad_token_id)
                clean = tok.decode(raw_ids, skip_special_tokens=True).rstrip()
                results.append({
                    "clean_text": clean,
                    "raw_text": tok.decode(raw_ids, skip_special_tokens=False),
                    "gen_ids": raw_ids,
                    "hit_eos": hit_eos,
                    "prompt_text": prompt_text,
                })
        # Trace the first sample once; the guard covers an empty prompt list.
        if results:
            _log_trace(
                tok,
                prompt_text=results[0]["prompt_text"],
                gen_ids=results[0]["gen_ids"],
                clean_text=results[0]["clean_text"],
                label=trace_label,
            )
    finally:
        tok.padding_side = old_padding_side
    return results
def assert_generated_pairs_diverged(ds: Dataset) -> None:
@@ -272,6 +387,10 @@ def assert_generated_pairs_diverged(ds: Dataset) -> None:
f"unique_pos={len({r['response_pos'].strip() for r in rows})}, "
f"unique_neg={len({r['response_neg'].strip() for r in rows})}"
)
logger.info(
"SHOULD: identical_pos_neg stay low, unique_pos/unique_neg stay high, "
"and empty generations stay at zero. Large identical counts mean persona collapse."
)
# TODO judge filter: paper §3 uses GPT-4.1-mini to drop rows where r_pos doesn't
@@ -311,24 +430,69 @@ def generate_pairs(cfg: DataCfg) -> Path:
)
model.eval()
rows = []
for i, spec in enumerate(tqdm(specs, desc=f"gen {cfg.behavior}", mininterval=60)):
sampling = _sampling_defaults(tok, cfg)
logger.info(f"sampling={sampling} thinking_mode={has_thinking_mode(tok)}")
logger.info(
"SHOULD: training data use sampling, not greedy decoding. "
"Each continuation should be >= min_new_tokens and ideally terminate on EOS before max_new_tokens."
)
rendered = []
for spec in specs:
sys_pos = sys_pos_list[spec["persona_idx"]]
sys_neg = sys_neg_list[spec["persona_idx"]]
r_pos = _gen(model, tok, sys_pos, spec["prompt"], cfg.max_new_tokens, cfg.temperature)
r_neg = _gen(model, tok, sys_neg, spec["prompt"], cfg.max_new_tokens, cfg.temperature)
rows.append({
"prompt": spec["prompt"],
"response_pos": r_pos,
"response_neg": r_neg,
rendered.append({
**spec,
"sys_prompt_pos": sys_pos,
"sys_prompt_neg": sys_neg,
"prompt_text_pos": _render_chat_prompt(tok, sys_pos, spec["prompt"]),
"prompt_text_neg": _render_chat_prompt(tok, sys_neg, spec["prompt"]),
})
pos_results = _generate_batch(
model,
tok,
[r["prompt_text_pos"] for r in rendered],
cfg,
sampling,
trace_label="data stage pair0 pos",
)
neg_results = _generate_batch(
model,
tok,
[r["prompt_text_neg"] for r in rendered],
cfg,
sampling,
trace_label="data stage pair0 neg",
)
rows = []
no_eos_pos = 0
no_eos_neg = 0
for spec, pos, neg in zip(rendered, pos_results, neg_results, strict=True):
no_eos_pos += int(not pos["hit_eos"])
no_eos_neg += int(not neg["hit_eos"])
rows.append({
"prompt": spec["prompt"],
"response_pos": pos["clean_text"],
"response_neg": neg["clean_text"],
"sys_prompt_pos": spec["sys_prompt_pos"],
"sys_prompt_neg": spec["sys_prompt_neg"],
"topic_idx": spec["topic_idx"],
"persona_idx": spec["persona_idx"],
"sample_idx": spec["sample_idx"],
"behavior": cfg.behavior,
})
logger.info(
f"eos coverage: pos={len(rows) - no_eos_pos}/{len(rows)} neg={len(rows) - no_eos_neg}/{len(rows)}"
)
if no_eos_pos or no_eos_neg:
logger.warning(
f"{no_eos_pos + no_eos_neg} generations hit max_new_tokens before EOS. "
"Increase max_new_tokens if this is common."
)
ds = Dataset.from_list(rows)
assert_generated_pairs_diverged(ds)
out_dir = cfg.out / cfg.behavior
@@ -342,3 +506,14 @@ def load_pairs(behavior: str, root: Path = Path("out/data")) -> Dataset:
ds = Dataset.load_from_disk(str(root / behavior))
assert_generated_pairs_diverged(ds)
return ds
def main(cfg: DataCfg) -> None:
    """CLI entry point: set up logging, record the invocation, generate pairs."""
    setup_logging("data")
    argv = get_argv()
    logger.info(f"argv: {argv}")
    cfg_dump = asdict(cfg)
    logger.info(f"data cfg: {cfg_dump}")
    generate_pairs(cfg)
if __name__ == "__main__":
main(tyro.cli(DataCfg))
+139
View File
@@ -0,0 +1,139 @@
"""One-off persona collapse debugger.
For each persona pair, greedy-generate short continuations on a fixed prompt
set and warn if left/right collapse to the same text.
"""
from __future__ import annotations
from dataclasses import asdict, dataclass
from pathlib import Path
import polars as pl
import torch
import tyro
from loguru import logger
from tabulate import tabulate
from transformers import AutoModelForCausalLM, AutoTokenizer
from ws._log import final_summary, get_argv, setup_logging
from ws.data import _normalize_text, _personas, _render_chat_prompt, _topics
@dataclass
class PersonaDebugCfg:
    """Command-line configuration for the persona collapse debugger."""

    # Model identifier handed to the tokenizer/model `from_pretrained` calls.
    model: str = "Qwen/Qwen3-0.6B"
    # Behavior name; selects personas/topics and the output sub-directory.
    behavior: str = "honesty"
    # Output root; results are written under `<out>/<behavior>/persona_debug`.
    out: Path = Path("out")
    # Upper bound on the number of topics turned into probe prompts.
    n_prompts: int = 8
    # Generation budget per probe.
    max_new_tokens: int = 100
    # Number of prompts per generate call.
    batch_size: int = 8
    # NOTE(review): not read by the code visible here — confirm whether it is still needed.
    seed: int = 0
@torch.no_grad()
def _greedy_batch(model, tok, prompts: list[str], batch_size: int, max_new_tokens: int) -> list[str]:
    """Greedy-decode one continuation per prompt, in left-padded batches.

    Args:
        model: Model exposing ``generate`` and ``device``.
        tok: Tokenizer; its ``padding_side`` is switched to ``"left"`` for the
            duration of the call and restored afterwards, even on error.
        prompts: Already-rendered prompt strings.
        batch_size: Number of prompts per ``generate`` call; must be >= 1.
        max_new_tokens: Generation budget per prompt.

    Returns:
        The decoded continuations (special tokens skipped, trailing whitespace
        stripped), in the same order as ``prompts``.

    Raises:
        ValueError: If ``batch_size < 1``.
    """
    # A zero step crashes inside range() and a negative one silently yields
    # nothing; fail early with a clear message instead.
    if batch_size < 1:
        raise ValueError(f"batch_size={batch_size} must be >= 1")

    # `pad or eos` would discard a legitimate pad id of 0, so test for None.
    pad_id = tok.pad_token_id if tok.pad_token_id is not None else tok.eos_token_id

    texts: list[str] = []
    saved_side = tok.padding_side
    # Left padding lets the prompt be sliced off at a single shared column.
    tok.padding_side = "left"
    try:
        for start in range(0, len(prompts), batch_size):
            chunk = prompts[start:start + batch_size]
            enc = tok(chunk, return_tensors="pt", padding=True).to(model.device)
            out = model.generate(
                **enc,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                temperature=1.0,
                pad_token_id=pad_id,
                eos_token_id=tok.eos_token_id,
            )
            new_tokens = out[:, enc["input_ids"].shape[1]:].cpu()
            texts.extend(
                tok.decode(new_tokens[i], skip_special_tokens=True).rstrip()
                for i in range(len(chunk))
            )
    finally:
        tok.padding_side = saved_side
    return texts
def main(cfg: PersonaDebugCfg) -> None:
    """Probe every persona pair with greedy decoding and report collapses.

    For each positive/negative persona pair, the same probe prompts are
    generated under both system prompts. A probe counts as "same" when the two
    continuations are equal after whitespace normalization. Per-probe rows go
    to ``per_prompt.csv`` and a per-persona roll-up goes to ``summary.csv``,
    both under ``<out>/<behavior>/persona_debug``.
    """
    setup_logging("debug_personas")
    logger.info(f"argv: {get_argv()}")
    logger.info(f"persona debug cfg: {asdict(cfg)}")

    tok = AutoTokenizer.from_pretrained(cfg.model)
    # Batched generation needs a pad token; fall back to EOS when none is set.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        cfg.model, torch_dtype=torch.bfloat16, device_map="auto"
    )
    model.eval()

    pos_personas, neg_personas = _personas(cfg.behavior)
    topics = _topics(cfg.behavior)[:cfg.n_prompts]
    # Sycophancy probes combine both topic fields; other behaviors use the first only.
    prompts: list[str] = [
        f"I think {a}. {b}" if cfg.behavior == "sycophancy" else a
        for a, b in topics
    ]

    rows = []
    persona_pairs = zip(pos_personas, neg_personas, strict=True)
    for persona_idx, (sys_pos, sys_neg) in enumerate(persona_pairs):
        rendered_pos = [_render_chat_prompt(tok, sys_pos, prompt) for prompt in prompts]
        rendered_neg = [_render_chat_prompt(tok, sys_neg, prompt) for prompt in prompts]
        gens_pos = _greedy_batch(model, tok, rendered_pos, cfg.batch_size, cfg.max_new_tokens)
        gens_neg = _greedy_batch(model, tok, rendered_neg, cfg.batch_size, cfg.max_new_tokens)

        persona_rows = [
            {
                "persona_idx": persona_idx,
                "prompt": prompt,
                "same": _normalize_text(gen_pos) == _normalize_text(gen_neg),
                "response_pos": gen_pos,
                "response_neg": gen_neg,
            }
            for prompt, gen_pos, gen_neg in zip(prompts, gens_pos, gens_neg, strict=True)
        ]
        rows.extend(persona_rows)

        identical = sum(int(r["same"]) for r in persona_rows)
        if identical:
            logger.warning(
                f"persona_idx={persona_idx} collapsed on {identical}/{len(prompts)} greedy probes; "
                "discard this pair from persona debugging."
            )

    df = pl.DataFrame(rows)
    out_dir = cfg.out / cfg.behavior / "persona_debug"
    out_dir.mkdir(parents=True, exist_ok=True)
    per_prompt_path = out_dir / "per_prompt.csv"
    summary_path = out_dir / "summary.csv"
    df.write_csv(per_prompt_path)

    # One row per persona pair; a pair is kept only if no probe collapsed.
    per_persona = df.group_by("persona_idx").agg(
        pl.len().alias("n_prompts"),
        pl.col("same").sum().alias("n_same"),
    )
    summary = per_persona.with_columns(
        (pl.col("n_same") / pl.col("n_prompts")).alias("same_rate"),
        (pl.col("n_same") == 0).alias("keep_pair"),
    ).sort("persona_idx")
    summary.write_csv(summary_path)

    print("\npersona_debug")
    print("SHOULD: left/right greedy probes differ for each persona pair. same_rate>0 means the persona contrast is weak or ignored.")
    print(tabulate(summary.to_pandas(), headers="keys", tablefmt="tsv", floatfmt="+.3f", showindex=False))

    all_kept = bool(summary["keep_pair"].all())
    cue = "🟢" if all_kept else "🟡"
    columns = ["persona_idx", "n_prompts", "n_same", "same_rate", "keep_pair"]
    final_summary(
        out=summary_path,
        argv=get_argv(),
        main_metric=f"keep_pairs={int(summary['keep_pair'].sum())}/{len(summary)}",
        cue=cue,
        table_rows=summary.select("persona_idx", "n_prompts", "n_same", "same_rate", "keep_pair").rows(),
        headers=columns,
        floatfmt="",
    )
if __name__ == "__main__":
main(tyro.cli(PersonaDebugCfg))
+25 -10
View File
@@ -104,6 +104,12 @@ def _eval_dilemmas_repe(model, tok, dirs, layers, alpha, dl, choice_ids, pmass_t
return rows
def _alpha_triplet(row: dict) -> tuple[float, float]:
    """Return ``(alpha_pos, alpha_neg)`` for one calibration row.

    Each side reads its own column when the row has it and otherwise falls
    back to the single ``calibrated_alpha`` column.
    """
    def _side(column: str) -> float:
        source = column if column in row else "calibrated_alpha"
        return float(row[source])

    return _side("alpha_pos"), _side("alpha_neg")
def main(cfg: DilemmasCalibratedCfg) -> None:
setup_logging("dilemmas_calibrated")
out_dir = cfg.out / cfg.behavior / "dilemmas_calibrated"
@@ -112,8 +118,12 @@ def main(cfg: DilemmasCalibratedCfg) -> None:
calib_path = cfg.out / cfg.behavior / "kl_calibration" / "summary.csv"
calib = pl.read_csv(calib_path)
logger.info(f"loaded calibration: {len(calib)} methods from {calib_path}")
logger.info(tabulate(calib.select("method", "calibrated_alpha", "p95_at_calib").to_pandas(),
headers="keys", tablefmt="tsv", floatfmt="+.3f", showindex=False))
cols = ["method", "alpha_neg", "alpha_pos", "p95_at_neg", "p95_at_pos"]
fallback_cols = ["method", "calibrated_alpha"]
logger.info(tabulate(
calib.select([c for c in cols if c in calib.columns] or fallback_cols).to_pandas(),
headers="keys", tablefmt="tsv", floatfmt="+.3f", showindex=False
))
tok = AutoTokenizer.from_pretrained(cfg.model)
if tok.pad_token is None:
@@ -153,12 +163,12 @@ def main(cfg: DilemmasCalibratedCfg) -> None:
# Adapter dW evals at calibrated ±α and 0.
for row in calib.iter_rows(named=True):
method = row["method"]
alpha_c = float(row["calibrated_alpha"])
alpha_pos, alpha_neg = _alpha_triplet(row)
if method.startswith("dW:"):
adapter = method.split(":", 1)[1]
w = load_diff(cfg.out / cfg.behavior / adapter / DIFF_FILENAME)
rows = []
for alpha in (-alpha_c, 0.0, alpha_c):
for alpha in (-alpha_neg, 0.0, alpha_pos):
rows.extend(_eval_dilemmas_dw(model, tok, w, alpha, dl, choice_ids,
cfg.pmass_threshold, method))
logger.info(f" {method} α={alpha:+.3f}: {len(ds_pt)} rows")
@@ -166,7 +176,7 @@ def main(cfg: DilemmasCalibratedCfg) -> None:
elif method == "repe":
dirs = _fit_repe_directions(model, tok, cfg.n_repe_train, cfg.behavior)
rows = []
for alpha in (-alpha_c, 0.0, alpha_c):
for alpha in (-alpha_neg, 0.0, alpha_pos):
rows.extend(_eval_dilemmas_repe(model, tok, dirs, cfg.repe_layers, alpha, dl,
choice_ids, cfg.pmass_threshold))
logger.info(f" repe α={alpha:+.3f}: {len(ds_pt)} rows")
@@ -261,9 +271,11 @@ def main(cfg: DilemmasCalibratedCfg) -> None:
# Get calibrated alpha for this method (1.0 for prompts).
if method.startswith("prompt:"):
alpha_c = 1.0
alpha_pos = 1.0
alpha_neg = 1.0
else:
alpha_c = float(calib.filter(pl.col("method") == method)["calibrated_alpha"][0])
row = next(calib.filter(pl.col("method") == method).iter_rows(named=True))
alpha_pos, alpha_neg = _alpha_triplet(row)
# Mean logratio_honesty per coeff.
zero_lr = float(sub.filter(pl.col("coeff") == 0.0)["logratio_honesty"].mean()) if 0.0 in sub["coeff"].to_list() else float("nan")
@@ -272,7 +284,9 @@ def main(cfg: DilemmasCalibratedCfg) -> None:
si_rows.append({
"method": method,
"alpha": alpha_c,
"alpha": alpha_pos,
"alpha_pos": alpha_pos,
"alpha_neg": alpha_neg,
"sign": sign_chosen,
"SI": m["surgical_informedness"],
"SI_to_do": m.get("SI_to_do", float("nan")),
@@ -304,6 +318,7 @@ def main(cfg: DilemmasCalibratedCfg) -> None:
si_df.write_csv(si_path)
print("\n=== Dilemmas SI at KL-calibrated α (matched p95 token-KL ≈ 0.615 nats) ===")
print("SHOULD: use (-alpha_neg, 0, +alpha_pos) per method. Asymmetry is expected when left/right KL footprints differ.")
print(tabulate(si_df.to_pandas(), headers="keys", tablefmt="tsv",
floatfmt="+.3f", showindex=False))
@@ -313,9 +328,9 @@ def main(cfg: DilemmasCalibratedCfg) -> None:
argv=get_argv(),
main_metric=f"best_method={si_df['method'][0]} SI={float(si_df['SI'][0] or 0):+.3f}",
cue=cue,
table_rows=si_df.select("method", "alpha", "sign", "SI", "si_fwd", "si_rev",
table_rows=si_df.select("method", "alpha_neg", "alpha_pos", "sign", "SI", "si_fwd", "si_rev",
"fix_fwd", "broke_fwd").rows(),
headers=["method", "alpha", "sign", "SI", "si_fwd", "si_rev", "fix", "broke"],
headers=["method", "alpha_neg", "alpha_pos", "sign", "SI", "si_fwd", "si_rev", "fix", "broke"],
floatfmt="",
)
+143 -97
View File
@@ -72,13 +72,13 @@ class KLCalibrateCfg:
include_repe: bool = True
n_calib_prompts: int = 50
n_audit_prompts: int = 100
n_tokens: int = 20
n_tokens: int = 50
target_pct: float = 95.0
# "Side of the road" = 1 nat per-token KL (gist):
# https://gist.github.com/wassname/6c11cf30b43d8c228bc114795f1019c7
# Newton residual is 1 p95(KL); we search a global coefficient C such
# that p95 KL = target_kl at α=1.
target_kl: float = 1.0
target_kl: float = 0.5
target_prompt: str = "engineered_prompt_honest" # logged as a reference, not the target
# Bracket guard (lo, hi) on the global coefficient. KL ~ α²·F near root, so
# below ~0.05 nothing happens; above ~16 we'd be deep in collapse-land.
@@ -87,7 +87,7 @@ class KLCalibrateCfg:
n_root_iters: int = 12 # Illinois inner loop; usually converges in 3-5
convergence_tol: float = 0.05 # |p95 - target| < tol (absolute, in nats)
repe_layers: tuple[int, ...] = field(default_factory=lambda: tuple(range(8, 22)))
n_repe_train: int = 20
n_repe_train: int = 50
seed: int = 0
@@ -212,8 +212,19 @@ def _measure_kl_along_trajectory(
}
def _illinois_calibrate(method: str, target: float, *, model, tok, prompts, cfg,
w=None, repe_dirs=None) -> dict:
def _illinois_calibrate(
method: str,
target: float,
*,
model,
tok,
prompts,
cfg,
alpha_sign: float = 1.0,
sign_label: str = "pos",
w=None,
repe_dirs=None,
) -> dict:
"""Exponential bracket within (bracket_lo, bracket_hi) then log-log Illinois
regula falsi. Mirrors steering-lite's validated `calibrate_iso_kl`.
@@ -227,18 +238,42 @@ def _illinois_calibrate(method: str, target: float, *, model, tok, prompts, cfg,
history: list[dict] = []
iter_idx = [0]
def stat(alpha: float) -> float:
def _result(final: dict, converged: bool) -> dict:
return {
"method": method,
"sign": sign_label,
"alpha_sign": alpha_sign,
"alpha_mag": abs(final["alpha"]),
"calibrated_alpha": final["alpha"],
"p95_at_calib": final["p95"],
"mean_at_calib": final["mean"],
"max_at_calib": final["max"],
"ratio_at_calib": final["ratio"],
"iterations": len(history),
"converged": converged,
"history": history,
}
def stat(alpha_mag: float) -> float:
alpha = alpha_sign * alpha_mag
m = _measure_kl_along_trajectory(
method, alpha, model=model, tok=tok, prompts=prompts,
n_tokens=cfg.n_tokens, w=w, repe_dirs=repe_dirs,
repe_layers=cfg.repe_layers,
log_first_sample=(iter_idx[0] == 0),
sample_label=f"calib iter=0 method={method} α={alpha:+.3f}",
sample_label=f"calib iter=0 method={method} sign={sign_label} α={alpha:+.3f}",
)
ratio = m["p95"] / target if target > 0 else 1.0
history.append({"iter": iter_idx[0], "alpha": alpha, **m, "ratio": ratio})
history.append({
"iter": iter_idx[0],
"sign": sign_label,
"alpha": alpha,
"alpha_mag": alpha_mag,
**m,
"ratio": ratio,
})
logger.info(
f" [{method}] iter={iter_idx[0]} α={alpha:+.4f} p95={m['p95']:.4g} "
f" [{method}:{sign_label}] iter={iter_idx[0]} α={alpha:+.4f} p95={m['p95']:.4g} "
f"mean={m['mean']:.4g} max={m['max']:.4g} ratio={ratio:.3f}"
)
iter_idx[0] += 1
@@ -251,13 +286,7 @@ def _illinois_calibrate(method: str, target: float, *, model, tok, prompts, cfg,
mid = float(np.sqrt(lo * hi))
v_mid = stat(mid)
if abs(v_mid - target) < cfg.convergence_tol:
final = history[-1]
return {
"method": method, "calibrated_alpha": final["alpha"],
"p95_at_calib": final["p95"], "mean_at_calib": final["mean"],
"max_at_calib": final["max"], "ratio_at_calib": final["ratio"],
"iterations": len(history), "converged": True, "history": history,
}
return _result(history[-1], True)
if v_mid < target:
c_lo, v_lo = mid, v_mid
@@ -283,14 +312,7 @@ def _illinois_calibrate(method: str, target: float, *, model, tok, prompts, cfg,
c_hi, v_hi = c, v
if v_lo is None or v_hi is None:
# Couldn't bracket within (lo, hi); report last point seen.
final = history[-1]
return {
"method": method, "calibrated_alpha": final["alpha"],
"p95_at_calib": final["p95"], "mean_at_calib": final["mean"],
"max_at_calib": final["max"], "ratio_at_calib": final["ratio"],
"iterations": len(history), "converged": False, "history": history,
}
return _result(history[-1], False)
# 2. Log-log Illinois regula-falsi inside the bracket.
converged = False
@@ -324,22 +346,8 @@ def _illinois_calibrate(method: str, target: float, *, model, tok, prompts, cfg,
# If we exhausted iters without hitting tol, pick the closest point seen.
if not converged:
best = min(history, key=lambda h: abs(h["p95"] - target))
final = best
else:
final = history[-1]
return {
"method": method,
"calibrated_alpha": final["alpha"],
"p95_at_calib": final["p95"],
"mean_at_calib": final["mean"],
"max_at_calib": final["max"],
"ratio_at_calib": final["ratio"],
"iterations": len(history),
"converged": converged,
"history": history,
}
return _result(min(history, key=lambda h: abs(h["p95"] - target)), False)
return _result(history[-1], True)
def main(cfg: KLCalibrateCfg) -> None:
@@ -401,23 +409,33 @@ def main(cfg: KLCalibrateCfg) -> None:
repe_dirs = _fit_repe_directions(model, tok, cfg.n_repe_train, cfg.behavior)
# 3. Illinois regula-falsi calibrate each adapter and (optionally) RepE.
results = []
results_by_method: dict[str, dict[str, dict]] = {}
for adapter in cfg.adapters:
logger.info(f"\n=== calibrate dW:{adapter} ===")
w = load_diff(cfg.out / cfg.behavior / adapter / DIFF_FILENAME)
r = _illinois_calibrate(
f"dW:{adapter}", target, model=model, tok=tok,
prompts=calib_prompts, cfg=cfg, w=w,
)
results.append(r)
results_by_method[f"dW:{adapter}"] = {
"pos": _illinois_calibrate(
f"dW:{adapter}", target, model=model, tok=tok,
prompts=calib_prompts, cfg=cfg, alpha_sign=1.0, sign_label="pos", w=w,
),
"neg": _illinois_calibrate(
f"dW:{adapter}", target, model=model, tok=tok,
prompts=calib_prompts, cfg=cfg, alpha_sign=-1.0, sign_label="neg", w=w,
),
}
if cfg.include_repe:
logger.info("\n=== calibrate repe ===")
r = _illinois_calibrate(
"repe", target, model=model, tok=tok,
prompts=calib_prompts, cfg=cfg, repe_dirs=repe_dirs,
)
results.append(r)
results_by_method["repe"] = {
"pos": _illinois_calibrate(
"repe", target, model=model, tok=tok,
prompts=calib_prompts, cfg=cfg, alpha_sign=1.0, sign_label="pos", repe_dirs=repe_dirs,
),
"neg": _illinois_calibrate(
"repe", target, model=model, tok=tok,
prompts=calib_prompts, cfg=cfg, alpha_sign=-1.0, sign_label="neg", repe_dirs=repe_dirs,
),
}
# 4. Audit: at calibrated α, recompute on n_audit prompts.
logger.info(f"\n=== AUDIT (n={len(audit_prompts)} prompts) ===")
@@ -441,63 +459,83 @@ def main(cfg: KLCalibrateCfg) -> None:
"calib_audit_ratio": m_audit["p95"] / m_calib["p95"] if m_calib["p95"] > 0 else float("nan"),
})
for r in results:
method = r["method"]
alpha = r["calibrated_alpha"]
logger.info(
"SHOULD: pos and neg p95 each match the target independently. "
"Asymmetric alpha_pos/alpha_neg means the steering direction has asymmetric KL footprint, not failure."
)
for method, signs in results_by_method.items():
if method.startswith("dW:"):
adapter = method.split(":", 1)[1]
w = load_diff(cfg.out / cfg.behavior / adapter / DIFF_FILENAME)
m_audit = _measure_kl_along_trajectory(
method, alpha, model=model, tok=tok, prompts=audit_prompts,
n_tokens=cfg.n_tokens, w=w,
)
elif method == "repe":
m_audit = _measure_kl_along_trajectory(
method, alpha, model=model, tok=tok, prompts=audit_prompts,
n_tokens=cfg.n_tokens, repe_dirs=repe_dirs,
repe_layers=cfg.repe_layers,
)
else:
raise ValueError(method)
logger.info(
f" {method} α={alpha:+.3f} audit p95={m_audit['p95']:.4g} "
f"(calib was {r['p95_at_calib']:.4g}, target {target:.4g})"
)
audit_rows.append({
"method": method,
"alpha": alpha,
"p95_calib": r["p95_at_calib"],
"mean_calib": r["mean_at_calib"],
"p95_audit": m_audit["p95"],
"mean_audit": m_audit["mean"],
"max_audit": m_audit["max"],
"calib_audit_ratio": m_audit["p95"] / r["p95_at_calib"] if r["p95_at_calib"] > 0 else float("nan"),
})
w = None
for sign_label, r in signs.items():
alpha = r["calibrated_alpha"]
if method.startswith("dW:"):
m_audit = _measure_kl_along_trajectory(
method, alpha, model=model, tok=tok, prompts=audit_prompts,
n_tokens=cfg.n_tokens, w=w,
)
elif method == "repe":
m_audit = _measure_kl_along_trajectory(
method, alpha, model=model, tok=tok, prompts=audit_prompts,
n_tokens=cfg.n_tokens, repe_dirs=repe_dirs,
repe_layers=cfg.repe_layers,
)
else:
raise ValueError(method)
logger.info(
f" {method}:{sign_label} α={alpha:+.3f} audit p95={m_audit['p95']:.4g} "
f"(calib was {r['p95_at_calib']:.4g}, target {target:.4g})"
)
audit_rows.append({
"method": method,
"sign": sign_label,
"alpha": alpha,
"alpha_mag": r["alpha_mag"],
"p95_calib": r["p95_at_calib"],
"mean_calib": r["mean_at_calib"],
"p95_audit": m_audit["p95"],
"mean_audit": m_audit["mean"],
"max_audit": m_audit["max"],
"calib_audit_ratio": m_audit["p95"] / r["p95_at_calib"] if r["p95_at_calib"] > 0 else float("nan"),
})
audit_df = pl.DataFrame(audit_rows)
audit_df.write_csv(out_dir / "audit.csv")
summary_rows = []
for r in results:
for method, signs in results_by_method.items():
pos = signs["pos"]
neg = signs["neg"]
summary_rows.append({
"method": r["method"],
"calibrated_alpha": r["calibrated_alpha"],
"p95_at_calib": r["p95_at_calib"],
"mean_at_calib": r["mean_at_calib"],
"max_at_calib": r["max_at_calib"],
"ratio_at_calib": r["ratio_at_calib"],
"iterations": r["iterations"],
"converged": r["converged"],
"method": method,
"alpha_pos": pos["alpha_mag"],
"alpha_neg": neg["alpha_mag"],
"calibrated_alpha": pos["alpha_mag"],
"p95_at_pos": pos["p95_at_calib"],
"p95_at_neg": neg["p95_at_calib"],
"mean_at_pos": pos["mean_at_calib"],
"mean_at_neg": neg["mean_at_calib"],
"max_at_pos": pos["max_at_calib"],
"max_at_neg": neg["max_at_calib"],
"ratio_at_pos": pos["ratio_at_calib"],
"ratio_at_neg": neg["ratio_at_calib"],
"iterations_pos": pos["iterations"],
"iterations_neg": neg["iterations"],
"converged_pos": pos["converged"],
"converged_neg": neg["converged"],
})
summary_df = pl.DataFrame(summary_rows).sort("calibrated_alpha")
summary_df = pl.DataFrame(summary_rows).sort("alpha_pos")
summary_df = summary_df.with_columns(pl.lit(target).alias("target_p95"))
summary_path = out_dir / "summary.csv"
summary_df.write_csv(summary_path)
history_rows = []
for r in results:
for h in r["history"]:
history_rows.append({"method": r["method"], **h})
for method, signs in results_by_method.items():
for sign_label, r in signs.items():
for h in r["history"]:
history_rows.append({"method": method, "sign": sign_label, **h})
pl.DataFrame(history_rows).write_csv(out_dir / "root_history.csv")
pl.DataFrame([{"method": k, **v} for k, v in prompt_refs.items()]).write_csv(out_dir / "prompt_refs.csv")
@@ -510,15 +548,23 @@ def main(cfg: KLCalibrateCfg) -> None:
print(tabulate(audit_df.to_pandas(), headers="keys", tablefmt="tsv",
floatfmt="+.4g", showindex=False))
cue = "🟢" if all(r["converged"] for r in results) else "🟡"
n_converged = sum(
int(r["converged"])
for signs in results_by_method.values()
for r in signs.values()
)
n_total = sum(len(signs) for signs in results_by_method.values())
cue = "🟢" if n_converged == n_total else "🟡"
final_summary(
out=summary_path,
argv=get_argv(),
main_metric=f"target_p95={target:.4g} converged={sum(r['converged'] for r in results)}/{len(results)}",
main_metric=f"target_p95={target:.4g} converged={n_converged}/{n_total}",
cue=cue,
table_rows=summary_df.select("method", "calibrated_alpha", "p95_at_calib",
"ratio_at_calib", "iterations", "converged").rows(),
headers=["method", "alpha", "p95", "ratio", "iters", "ok"],
table_rows=summary_df.select(
"method", "alpha_neg", "alpha_pos", "p95_at_neg", "p95_at_pos",
"iterations_neg", "iterations_pos", "converged_neg", "converged_pos"
).rows(),
headers=["method", "alpha_neg", "alpha_pos", "p95_neg", "p95_pos", "iters_neg", "iters_pos", "ok_neg", "ok_pos"],
floatfmt="",
)
+16
View File
@@ -43,6 +43,14 @@ class Cfg:
out: Path = Path("out")
# Shared data dir — kept separate from out so multiple runs reuse the same pairs.
data_root: Path = Path("out/data")
data_batch_size: int = 8
data_min_new_tokens: int = 1024
data_max_new_tokens: int = 1280
data_temperature: float | None = None
data_top_p: float | None = None
data_top_k: int | None = None
data_min_p: float | None = None
data_presence_penalty: float = 0.0
coeffs: tuple[float, ...] = (-2.0, -1.0, 0.0, 1.0, 2.0)
@@ -66,6 +74,14 @@ def _maybe_data(cfg: Cfg) -> Dataset:
dcfg = DataCfg(
model_id=cfg.model, behavior=cfg.behavior, out=data_root,
n_topics=cfg.n_topics, n_personas=cfg.n_personas, n_samples=cfg.n_samples,
batch_size=cfg.data_batch_size,
min_new_tokens=cfg.data_min_new_tokens,
max_new_tokens=cfg.data_max_new_tokens,
temperature=cfg.data_temperature,
top_p=cfg.data_top_p,
top_k=cfg.data_top_k,
min_p=cfg.data_min_p,
presence_penalty=cfg.data_presence_penalty,
)
generate_pairs(dcfg)
return load_pairs(cfg.behavior, root=data_root)