diff --git a/evals/smoke.py b/evals/smoke.py index d8a981a..18d26af 100644 --- a/evals/smoke.py +++ b/evals/smoke.py @@ -49,6 +49,13 @@ def main(cfg: SmokeCfg) -> None: n_topics=2, # 2×1×2 = 4 pairs n_personas=1, n_samples=2, + data_batch_size=2, + data_min_new_tokens=16, + data_max_new_tokens=32, + data_temperature=0.7, + data_top_p=0.8, + data_top_k=20, + data_min_p=0.0, ) replicate_main(rcfg) print("[smoke] OK", flush=True) diff --git a/justfile b/justfile index c70d8a8..2574a09 100644 --- a/justfile +++ b/justfile @@ -31,7 +31,11 @@ smoke-sweep: # Generate +/- pair data for a behavior. Writes to out/data/{behavior}/. data: - uv run python -m ws.data --model {{model}} --behavior {{behavior}} + uv run python -m ws.data --model-id {{model}} --behavior {{behavior}} + +# One-off greedy persona collapse debugger. +debug-personas: + uv run python -m ws.debug_personas --model {{model}} --behavior {{behavior}} --out {{out}} # Train a single adapter (positive or negative). Pos/neg controls system prompt at gen time. train sign="pos": diff --git a/src/ws/data.py b/src/ws/data.py index 88c66a1..ba20054 100644 --- a/src/ws/data.py +++ b/src/ws/data.py @@ -18,16 +18,19 @@ Output columns: from __future__ import annotations import json -from dataclasses import dataclass +from dataclasses import asdict, dataclass from pathlib import Path +from typing import Any import torch +import tyro from datasets import Dataset from loguru import logger from tqdm.auto import tqdm from transformers import AutoModelForCausalLM, AutoTokenizer -from ws._tok_extras import chat_template_extras +from ws._log import get_argv, setup_logging +from ws._tok_extras import chat_template_extras, has_thinking_mode REPO_ROOT = Path(__file__).resolve().parents[2] DATA_DIR = REPO_ROOT / "data" @@ -162,8 +165,14 @@ class DataCfg: n_personas: int = 5 n_samples: int = 10 out: Path = Path("out/data") - max_new_tokens: int = 96 - temperature: float = 0.8 + batch_size: int = 8 + min_new_tokens: int = 1024 + max_new_tokens: int = 1280 + temperature: float | None = None + top_p: float | None = None + top_k: int | None = None + min_p: float | None = None + presence_penalty: float = 0.0 seed: int = 0 @@ -212,21 +221,127 @@ def _build_specs(topics, n_personas: int, n_samples: int, behavior: str): return specs -@torch.no_grad() -def _gen(model, tok, sys_prompt: str, user_prompt: str, max_new_tokens: int, temperature: float): +def _render_chat_prompt(tok, sys_prompt: str, user_prompt: str) -> str: msgs = [{"role": "system", "content": sys_prompt}, {"role": "user", "content": user_prompt}] - text = tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True, - **chat_template_extras(tok)) - inputs = tok(text, return_tensors="pt").to(model.device) - out = model.generate( - **inputs, - max_new_tokens=max_new_tokens, - do_sample=temperature > 0, - temperature=temperature if temperature > 0 else 1.0, - pad_token_id=tok.pad_token_id or tok.eos_token_id, + return tok.apply_chat_template( + msgs, + tokenize=False, + add_generation_prompt=True, + **chat_template_extras(tok), ) - gen = out[0, inputs["input_ids"].shape[1]:] - return tok.decode(gen, skip_special_tokens=True).strip() + + +def _sampling_defaults(tok, cfg: DataCfg) -> dict[str, Any]: + thinking = has_thinking_mode(tok) + defaults = { + "temperature": 0.6 if thinking else 0.7, + "top_p": 0.95 if thinking else 0.8, + "top_k": 20, + "min_p": 0.0, + } + params = { + "temperature": defaults["temperature"] if cfg.temperature is None else cfg.temperature, + "top_p": defaults["top_p"] if cfg.top_p is None else cfg.top_p, + "top_k": defaults["top_k"] if cfg.top_k is None else cfg.top_k, + "min_p": defaults["min_p"] if cfg.min_p is None else cfg.min_p, + } + if params["temperature"] <= 0: + logger.warning( + "data generation temperature<=0 enables greedy decoding. " + "Qwen recommends sampling for training data; use this only deliberately." + ) + if cfg.presence_penalty: + logger.warning( + "presence_penalty is configured but this HF generate path does not support it directly; ignoring." + ) + return params + + +def _trim_generation_ids(ids: torch.Tensor, eos_id: int | None, pad_id: int | None) -> tuple[torch.Tensor, bool]: + trimmed: list[int] = [] + hit_eos = False + for tok_id in ids.tolist(): + if pad_id is not None and tok_id == pad_id: + continue + trimmed.append(tok_id) + if eos_id is not None and tok_id == eos_id: + hit_eos = True + break + return torch.tensor(trimmed, dtype=torch.long), hit_eos + + +def _normalize_text(text: str) -> str: + return " ".join(text.split()) + + +def _log_trace(tok, *, prompt_text: str, gen_ids: torch.Tensor, clean_text: str, label: str) -> None: + prompt_ids = tok(prompt_text, return_tensors="pt").input_ids[0] + first = tok.convert_ids_to_tokens(prompt_ids[: min(8, len(prompt_ids))].tolist()) + last = tok.convert_ids_to_tokens(prompt_ids[-min(8, len(prompt_ids)):].tolist()) + raw_gen = tok.decode(gen_ids, skip_special_tokens=False) + raw_toks = tok.convert_ids_to_tokens(gen_ids.tolist()) if len(gen_ids) else [] + logger.info(f"[{label}] full prompt (special tokens included):\n{prompt_text}") + logger.info(f"[{label}] n_input_tokens={prompt_ids.shape[0]} first8={first} last8={last}") + logger.info(f"[{label}] raw generated continuation: {raw_gen!r}") + logger.info(f"[{label}] generated tokens: {raw_toks}") + logger.info(f"[{label}] cleaned continuation: {clean_text!r}") + + +@torch.no_grad() +def _generate_batch( + model, + tok, + prompts: list[str], + cfg: DataCfg, + sampling: dict[str, Any], + *, + trace_label: str, +) -> list[dict[str, Any]]: + if cfg.max_new_tokens < cfg.min_new_tokens: + raise ValueError( + f"max_new_tokens={cfg.max_new_tokens} must be >= min_new_tokens={cfg.min_new_tokens}" + ) + results: list[dict[str, Any]] = [] + old_padding_side = tok.padding_side + tok.padding_side = "left" + try: + for start in tqdm(range(0, len(prompts), cfg.batch_size), desc=f"gen {trace_label}", mininterval=60): + batch_prompts = prompts[start:start + cfg.batch_size] + enc = tok(batch_prompts, return_tensors="pt", padding=True).to(model.device) + out = model.generate( + **enc, + min_new_tokens=cfg.min_new_tokens, + max_new_tokens=cfg.max_new_tokens, + do_sample=sampling["temperature"] > 0, + temperature=sampling["temperature"] if sampling["temperature"] > 0 else 1.0, + top_p=sampling["top_p"], + top_k=sampling["top_k"], + min_p=sampling["min_p"], + pad_token_id=tok.pad_token_id or tok.eos_token_id, + eos_token_id=tok.eos_token_id, + ) + gen_block = out[:, enc["input_ids"].shape[1]:].cpu() + for i, prompt_text in enumerate(batch_prompts): + raw_ids, hit_eos = _trim_generation_ids(gen_block[i], tok.eos_token_id, tok.pad_token_id) + clean = tok.decode(raw_ids, skip_special_tokens=True).rstrip() + results.append({ + "clean_text": clean, + "raw_text": tok.decode(raw_ids, skip_special_tokens=False), + "gen_ids": raw_ids, + "hit_eos": hit_eos, + "prompt_text": prompt_text, + }) + if results: + _log_trace( + tok, + prompt_text=results[0]["prompt_text"], + gen_ids=results[0]["gen_ids"], + clean_text=results[0]["clean_text"], + label=trace_label, + ) + finally: + tok.padding_side = old_padding_side + return results def assert_generated_pairs_diverged(ds: Dataset) -> None: @@ -272,6 +387,10 @@ def assert_generated_pairs_diverged(ds: Dataset) -> None: f"unique_pos={len({r['response_pos'].strip() for r in rows})}, " f"unique_neg={len({r['response_neg'].strip() for r in rows})}" ) + logger.info( + "SHOULD: identical_pos_neg stay low, unique_pos/unique_neg stay high, " + "and empty generations stay at zero. Large identical counts mean persona collapse." + ) # TODO judge filter: paper §3 uses GPT-4.1-mini to drop rows where r_pos doesn't @@ -311,24 +430,69 @@ def generate_pairs(cfg: DataCfg) -> Path: ) model.eval() - rows = [] - for i, spec in enumerate(tqdm(specs, desc=f"gen {cfg.behavior}", mininterval=60)): + sampling = _sampling_defaults(tok, cfg) + logger.info(f"sampling={sampling} thinking_mode={has_thinking_mode(tok)}") + logger.info( + "SHOULD: training data use sampling, not greedy decoding. " + "Each continuation should be >= min_new_tokens and ideally terminate on EOS before max_new_tokens." + ) + + rendered = [] + for spec in specs: sys_pos = sys_pos_list[spec["persona_idx"]] sys_neg = sys_neg_list[spec["persona_idx"]] - r_pos = _gen(model, tok, sys_pos, spec["prompt"], cfg.max_new_tokens, cfg.temperature) - r_neg = _gen(model, tok, sys_neg, spec["prompt"], cfg.max_new_tokens, cfg.temperature) - rows.append({ - "prompt": spec["prompt"], - "response_pos": r_pos, - "response_neg": r_neg, + rendered.append({ + **spec, "sys_prompt_pos": sys_pos, "sys_prompt_neg": sys_neg, + "prompt_text_pos": _render_chat_prompt(tok, sys_pos, spec["prompt"]), + "prompt_text_neg": _render_chat_prompt(tok, sys_neg, spec["prompt"]), + }) + + pos_results = _generate_batch( + model, + tok, + [r["prompt_text_pos"] for r in rendered], + cfg, + sampling, + trace_label="data stage pair0 pos", + ) + neg_results = _generate_batch( + model, + tok, + [r["prompt_text_neg"] for r in rendered], + cfg, + sampling, + trace_label="data stage pair0 neg", + ) + + rows = [] + no_eos_pos = 0 + no_eos_neg = 0 + for spec, pos, neg in zip(rendered, pos_results, neg_results, strict=True): + no_eos_pos += int(not pos["hit_eos"]) + no_eos_neg += int(not neg["hit_eos"]) + rows.append({ + "prompt": spec["prompt"], + "response_pos": pos["clean_text"], + "response_neg": neg["clean_text"], + "sys_prompt_pos": spec["sys_prompt_pos"], + "sys_prompt_neg": spec["sys_prompt_neg"], "topic_idx": spec["topic_idx"], "persona_idx": spec["persona_idx"], "sample_idx": spec["sample_idx"], "behavior": cfg.behavior, }) + logger.info( + f"eos coverage: pos={len(rows) - no_eos_pos}/{len(rows)} neg={len(rows) - no_eos_neg}/{len(rows)}" + ) + if no_eos_pos or no_eos_neg: + logger.warning( + f"{no_eos_pos + no_eos_neg} generations hit max_new_tokens before EOS. " + "Increase max_new_tokens if this is common." + ) + ds = Dataset.from_list(rows) assert_generated_pairs_diverged(ds) out_dir = cfg.out / cfg.behavior @@ -342,3 +506,14 @@ def load_pairs(behavior: str, root: Path = Path("out/data")) -> Dataset: ds = Dataset.load_from_disk(str(root / behavior)) assert_generated_pairs_diverged(ds) return ds + + +def main(cfg: DataCfg) -> None: + setup_logging("data") + logger.info(f"argv: {get_argv()}") + logger.info(f"data cfg: {asdict(cfg)}") + generate_pairs(cfg) + + +if __name__ == "__main__": + main(tyro.cli(DataCfg)) diff --git a/src/ws/debug_personas.py b/src/ws/debug_personas.py new file mode 100644 index 0000000..a55dd1b --- /dev/null +++ b/src/ws/debug_personas.py @@ -0,0 +1,139 @@ +"""One-off persona collapse debugger. + +For each persona pair, greedy-generate short continuations on a fixed prompt +set and warn if left/right collapse to the same text. +""" + +from __future__ import annotations + +from dataclasses import asdict, dataclass +from pathlib import Path + +import polars as pl +import torch +import tyro +from loguru import logger +from tabulate import tabulate +from transformers import AutoModelForCausalLM, AutoTokenizer + +from ws._log import final_summary, get_argv, setup_logging +from ws.data import _normalize_text, _personas, _render_chat_prompt, _topics + + +@dataclass +class PersonaDebugCfg: + model: str = "Qwen/Qwen3-0.6B" + behavior: str = "honesty" + out: Path = Path("out") + n_prompts: int = 8 + max_new_tokens: int = 100 + batch_size: int = 8 + seed: int = 0 + + +@torch.no_grad() +def _greedy_batch(model, tok, prompts: list[str], batch_size: int, max_new_tokens: int) -> list[str]: + rows: list[str] = [] + old_padding_side = tok.padding_side + tok.padding_side = "left" + try: + for start in range(0, len(prompts), batch_size): + batch_prompts = prompts[start:start + batch_size] + enc = tok(batch_prompts, return_tensors="pt", padding=True).to(model.device) + out = model.generate( + **enc, + max_new_tokens=max_new_tokens, + do_sample=False, + temperature=1.0, + pad_token_id=tok.pad_token_id or tok.eos_token_id, + eos_token_id=tok.eos_token_id, + ) + gen_block = out[:, enc["input_ids"].shape[1]:].cpu() + for i in range(len(batch_prompts)): + rows.append(tok.decode(gen_block[i], skip_special_tokens=True).rstrip()) + finally: + tok.padding_side = old_padding_side + return rows + + +def main(cfg: PersonaDebugCfg) -> None: + setup_logging("debug_personas") + logger.info(f"argv: {get_argv()}") + logger.info(f"persona debug cfg: {asdict(cfg)}") + + tok = AutoTokenizer.from_pretrained(cfg.model) + if tok.pad_token is None: + tok.pad_token = tok.eos_token + model = AutoModelForCausalLM.from_pretrained( + cfg.model, torch_dtype=torch.bfloat16, device_map="auto" + ) + model.eval() + + pos_personas, neg_personas = _personas(cfg.behavior) + topics = _topics(cfg.behavior)[:cfg.n_prompts] + prompts: list[str] = [] + for a, b in topics: + prompts.append(f"I think {a}. {b}" if cfg.behavior == "sycophancy" else a) + + rows = [] + for persona_idx, (sys_pos, sys_neg) in enumerate(zip(pos_personas, neg_personas, strict=True)): + prompt_pos = [_render_chat_prompt(tok, sys_pos, prompt) for prompt in prompts] + prompt_neg = [_render_chat_prompt(tok, sys_neg, prompt) for prompt in prompts] + gens_pos = _greedy_batch(model, tok, prompt_pos, cfg.batch_size, cfg.max_new_tokens) + gens_neg = _greedy_batch(model, tok, prompt_neg, cfg.batch_size, cfg.max_new_tokens) + identical = 0 + for prompt, gen_pos, gen_neg in zip(prompts, gens_pos, gens_neg, strict=True): + same = _normalize_text(gen_pos) == _normalize_text(gen_neg) + identical += int(same) + rows.append({ + "persona_idx": persona_idx, + "prompt": prompt, + "same": same, + "response_pos": gen_pos, + "response_neg": gen_neg, + }) + if identical: + logger.warning( + f"persona_idx={persona_idx} collapsed on {identical}/{len(prompts)} greedy probes; " + "discard this pair from persona debugging." + ) + + df = pl.DataFrame(rows) + out_dir = cfg.out / cfg.behavior / "persona_debug" + out_dir.mkdir(parents=True, exist_ok=True) + per_prompt_path = out_dir / "per_prompt.csv" + summary_path = out_dir / "summary.csv" + df.write_csv(per_prompt_path) + + summary = ( + df.group_by("persona_idx") + .agg( + pl.len().alias("n_prompts"), + pl.col("same").sum().alias("n_same"), + ) + .with_columns( + (pl.col("n_same") / pl.col("n_prompts")).alias("same_rate"), + (pl.col("n_same") == 0).alias("keep_pair"), + ) + .sort("persona_idx") + ) + summary.write_csv(summary_path) + + print("\npersona_debug") + print("SHOULD: left/right greedy probes differ for each persona pair. same_rate>0 means the persona contrast is weak or ignored.") + print(tabulate(summary.to_pandas(), headers="keys", tablefmt="tsv", floatfmt="+.3f", showindex=False)) + + cue = "🟢" if bool(summary["keep_pair"].all()) else "🟡" + final_summary( + out=summary_path, + argv=get_argv(), + main_metric=f"keep_pairs={int(summary['keep_pair'].sum())}/{len(summary)}", + cue=cue, + table_rows=summary.select("persona_idx", "n_prompts", "n_same", "same_rate", "keep_pair").rows(), + headers=["persona_idx", "n_prompts", "n_same", "same_rate", "keep_pair"], + floatfmt="", + ) + + +if __name__ == "__main__": + main(tyro.cli(PersonaDebugCfg)) diff --git a/src/ws/eval/dilemmas_calibrated.py b/src/ws/eval/dilemmas_calibrated.py index 70ee656..712a6ca 100644 --- a/src/ws/eval/dilemmas_calibrated.py +++ b/src/ws/eval/dilemmas_calibrated.py @@ -104,6 +104,12 @@ def _eval_dilemmas_repe(model, tok, dirs, layers, alpha, dl, choice_ids, pmass_t return rows +def _alpha_triplet(row: dict) -> tuple[float, float]: + alpha_pos = float(row["alpha_pos"]) if "alpha_pos" in row else float(row["calibrated_alpha"]) + alpha_neg = float(row["alpha_neg"]) if "alpha_neg" in row else float(row["calibrated_alpha"]) + return alpha_pos, alpha_neg + + def main(cfg: DilemmasCalibratedCfg) -> None: setup_logging("dilemmas_calibrated") out_dir = cfg.out / cfg.behavior / "dilemmas_calibrated" @@ -112,8 +118,12 @@ def main(cfg: DilemmasCalibratedCfg) -> None: calib_path = cfg.out / cfg.behavior / "kl_calibration" / "summary.csv" calib = pl.read_csv(calib_path) logger.info(f"loaded calibration: {len(calib)} methods from {calib_path}") - logger.info(tabulate(calib.select("method", "calibrated_alpha", "p95_at_calib").to_pandas(), - headers="keys", tablefmt="tsv", floatfmt="+.3f", showindex=False)) + cols = ["method", "alpha_neg", "alpha_pos", "p95_at_neg", "p95_at_pos"] + fallback_cols = ["method", "calibrated_alpha"] + logger.info(tabulate( + calib.select([c for c in cols if c in calib.columns] or fallback_cols).to_pandas(), + headers="keys", tablefmt="tsv", floatfmt="+.3f", showindex=False + )) tok = AutoTokenizer.from_pretrained(cfg.model) if tok.pad_token is None: @@ -153,12 +163,12 @@ def main(cfg: DilemmasCalibratedCfg) -> None: # Adapter dW evals at calibrated ±α and 0. for row in calib.iter_rows(named=True): method = row["method"] - alpha_c = float(row["calibrated_alpha"]) + alpha_pos, alpha_neg = _alpha_triplet(row) if method.startswith("dW:"): adapter = method.split(":", 1)[1] w = load_diff(cfg.out / cfg.behavior / adapter / DIFF_FILENAME) rows = [] - for alpha in (-alpha_c, 0.0, alpha_c): + for alpha in (-alpha_neg, 0.0, alpha_pos): rows.extend(_eval_dilemmas_dw(model, tok, w, alpha, dl, choice_ids, cfg.pmass_threshold, method)) logger.info(f" {method} α={alpha:+.3f}: {len(ds_pt)} rows") @@ -166,7 +176,7 @@ def main(cfg: DilemmasCalibratedCfg) -> None: elif method == "repe": dirs = _fit_repe_directions(model, tok, cfg.n_repe_train, cfg.behavior) rows = [] - for alpha in (-alpha_c, 0.0, alpha_c): + for alpha in (-alpha_neg, 0.0, alpha_pos): rows.extend(_eval_dilemmas_repe(model, tok, dirs, cfg.repe_layers, alpha, dl, choice_ids, cfg.pmass_threshold)) logger.info(f" repe α={alpha:+.3f}: {len(ds_pt)} rows") @@ -261,9 +271,11 @@ def main(cfg: DilemmasCalibratedCfg) -> None: # Get calibrated alpha for this method (1.0 for prompts). if method.startswith("prompt:"): - alpha_c = 1.0 + alpha_pos = 1.0 + alpha_neg = 1.0 else: - alpha_c = float(calib.filter(pl.col("method") == method)["calibrated_alpha"][0]) + row = next(calib.filter(pl.col("method") == method).iter_rows(named=True)) + alpha_pos, alpha_neg = _alpha_triplet(row) # Mean logratio_honesty per coeff. zero_lr = float(sub.filter(pl.col("coeff") == 0.0)["logratio_honesty"].mean()) if 0.0 in sub["coeff"].to_list() else float("nan") @@ -272,7 +284,9 @@ def main(cfg: DilemmasCalibratedCfg) -> None: si_rows.append({ "method": method, - "alpha": alpha_c, + "alpha": alpha_pos, + "alpha_pos": alpha_pos, + "alpha_neg": alpha_neg, "sign": sign_chosen, "SI": m["surgical_informedness"], "SI_to_do": m.get("SI_to_do", float("nan")), @@ -304,6 +318,7 @@ def main(cfg: DilemmasCalibratedCfg) -> None: si_df.write_csv(si_path) print("\n=== Dilemmas SI at KL-calibrated α (matched p95 token-KL ≈ 0.615 nats) ===") + print("SHOULD: use (-alpha_neg, 0, +alpha_pos) per method. Asymmetry is expected when left/right KL footprints differ.") print(tabulate(si_df.to_pandas(), headers="keys", tablefmt="tsv", floatfmt="+.3f", showindex=False)) @@ -313,9 +328,9 @@ def main(cfg: DilemmasCalibratedCfg) -> None: argv=get_argv(), main_metric=f"best_method={si_df['method'][0]} SI={float(si_df['SI'][0] or 0):+.3f}", cue=cue, - table_rows=si_df.select("method", "alpha", "sign", "SI", "si_fwd", "si_rev", + table_rows=si_df.select("method", "alpha_neg", "alpha_pos", "sign", "SI", "si_fwd", "si_rev", "fix_fwd", "broke_fwd").rows(), - headers=["method", "alpha", "sign", "SI", "si_fwd", "si_rev", "fix", "broke"], + headers=["method", "alpha_neg", "alpha_pos", "sign", "SI", "si_fwd", "si_rev", "fix", "broke"], floatfmt="", ) diff --git a/src/ws/eval/kl_calibrate.py b/src/ws/eval/kl_calibrate.py index 86e83d2..c3c8e62 100644 --- a/src/ws/eval/kl_calibrate.py +++ b/src/ws/eval/kl_calibrate.py @@ -72,13 +72,13 @@ class KLCalibrateCfg: include_repe: bool = True n_calib_prompts: int = 50 n_audit_prompts: int = 100 - n_tokens: int = 20 + n_tokens: int = 50 target_pct: float = 95.0 # "Side of the road" = 1 nat per-token KL (gist): # https://gist.github.com/wassname/6c11cf30b43d8c228bc114795f1019c7 # Newton residual is 1 − p95(KL); we search a global coefficient C such # that p95 KL = target_kl at α=1. - target_kl: float = 1.0 + target_kl: float = 0.5 target_prompt: str = "engineered_prompt_honest" # logged as a reference, not the target # Bracket guard (lo, hi) on the global coefficient. KL ~ α²·F near root, so # below ~0.05 nothing happens; above ~16 we'd be deep in collapse-land. @@ -87,7 +87,7 @@ class KLCalibrateCfg: n_root_iters: int = 12 # Illinois inner loop; usually converges in 3-5 convergence_tol: float = 0.05 # |p95 - target| < tol (absolute, in nats) repe_layers: tuple[int, ...] = field(default_factory=lambda: tuple(range(8, 22))) - n_repe_train: int = 20 + n_repe_train: int = 50 seed: int = 0 @@ -212,8 +212,19 @@ def _measure_kl_along_trajectory( } -def _illinois_calibrate(method: str, target: float, *, model, tok, prompts, cfg, - w=None, repe_dirs=None) -> dict: +def _illinois_calibrate( + method: str, + target: float, + *, + model, + tok, + prompts, + cfg, + alpha_sign: float = 1.0, + sign_label: str = "pos", + w=None, + repe_dirs=None, +) -> dict: """Exponential bracket within (bracket_lo, bracket_hi) then log-log Illinois regula falsi. Mirrors steering-lite's validated `calibrate_iso_kl`. @@ -227,18 +238,42 @@ def _illinois_calibrate(method: str, target: float, *, model, tok, prompts, cfg, history: list[dict] = [] iter_idx = [0] - def stat(alpha: float) -> float: + def _result(final: dict, converged: bool) -> dict: + return { + "method": method, + "sign": sign_label, + "alpha_sign": alpha_sign, + "alpha_mag": abs(final["alpha"]), + "calibrated_alpha": final["alpha"], + "p95_at_calib": final["p95"], + "mean_at_calib": final["mean"], + "max_at_calib": final["max"], + "ratio_at_calib": final["ratio"], + "iterations": len(history), + "converged": converged, + "history": history, + } + + def stat(alpha_mag: float) -> float: + alpha = alpha_sign * alpha_mag m = _measure_kl_along_trajectory( method, alpha, model=model, tok=tok, prompts=prompts, n_tokens=cfg.n_tokens, w=w, repe_dirs=repe_dirs, repe_layers=cfg.repe_layers, log_first_sample=(iter_idx[0] == 0), - sample_label=f"calib iter=0 method={method} α={alpha:+.3f}", + sample_label=f"calib iter=0 method={method} sign={sign_label} α={alpha:+.3f}", ) ratio = m["p95"] / target if target > 0 else 1.0 - history.append({"iter": iter_idx[0], "alpha": alpha, **m, "ratio": ratio}) + history.append({ + "iter": iter_idx[0], + "sign": sign_label, + "alpha": alpha, + "alpha_mag": alpha_mag, + **m, + "ratio": ratio, + }) logger.info( - f" [{method}] iter={iter_idx[0]} α={alpha:+.4f} p95={m['p95']:.4g} " + f" [{method}:{sign_label}] iter={iter_idx[0]} α={alpha:+.4f} p95={m['p95']:.4g} " f"mean={m['mean']:.4g} max={m['max']:.4g} ratio={ratio:.3f}" ) iter_idx[0] += 1 @@ -251,13 +286,7 @@ def _illinois_calibrate(method: str, target: float, *, model, tok, prompts, cfg, mid = float(np.sqrt(lo * hi)) v_mid = stat(mid) if abs(v_mid - target) < cfg.convergence_tol: - final = history[-1] - return { - "method": method, "calibrated_alpha": final["alpha"], - "p95_at_calib": final["p95"], "mean_at_calib": final["mean"], - "max_at_calib": final["max"], "ratio_at_calib": final["ratio"], - "iterations": len(history), "converged": True, "history": history, - } + return _result(history[-1], True) if v_mid < target: c_lo, v_lo = mid, v_mid @@ -283,14 +312,7 @@ def _illinois_calibrate(method: str, target: float, *, model, tok, prompts, cfg, c_hi, v_hi = c, v if v_lo is None or v_hi is None: - # Couldn't bracket within (lo, hi); report last point seen. - final = history[-1] - return { - "method": method, "calibrated_alpha": final["alpha"], - "p95_at_calib": final["p95"], "mean_at_calib": final["mean"], - "max_at_calib": final["max"], "ratio_at_calib": final["ratio"], - "iterations": len(history), "converged": False, "history": history, - } + return _result(history[-1], False) # 2. Log-log Illinois regula-falsi inside the bracket. converged = False @@ -324,22 +346,8 @@ def _illinois_calibrate(method: str, target: float, *, model, tok, prompts, cfg, # If we exhausted iters without hitting tol, pick the closest point seen. if not converged: - best = min(history, key=lambda h: abs(h["p95"] - target)) - final = best - else: - final = history[-1] - - return { - "method": method, - "calibrated_alpha": final["alpha"], - "p95_at_calib": final["p95"], - "mean_at_calib": final["mean"], - "max_at_calib": final["max"], - "ratio_at_calib": final["ratio"], - "iterations": len(history), - "converged": converged, - "history": history, - } + return _result(min(history, key=lambda h: abs(h["p95"] - target)), False) + return _result(history[-1], True) def main(cfg: KLCalibrateCfg) -> None: @@ -401,23 +409,33 @@ def main(cfg: KLCalibrateCfg) -> None: repe_dirs = _fit_repe_directions(model, tok, cfg.n_repe_train, cfg.behavior) # 3. Illinois regula-falsi calibrate each adapter and (optionally) RepE. - results = [] + results_by_method: dict[str, dict[str, dict]] = {} for adapter in cfg.adapters: logger.info(f"\n=== calibrate dW:{adapter} ===") w = load_diff(cfg.out / cfg.behavior / adapter / DIFF_FILENAME) - r = _illinois_calibrate( - f"dW:{adapter}", target, model=model, tok=tok, - prompts=calib_prompts, cfg=cfg, w=w, - ) - results.append(r) + results_by_method[f"dW:{adapter}"] = { + "pos": _illinois_calibrate( + f"dW:{adapter}", target, model=model, tok=tok, + prompts=calib_prompts, cfg=cfg, alpha_sign=1.0, sign_label="pos", w=w, + ), + "neg": _illinois_calibrate( + f"dW:{adapter}", target, model=model, tok=tok, + prompts=calib_prompts, cfg=cfg, alpha_sign=-1.0, sign_label="neg", w=w, + ), + } if cfg.include_repe: logger.info("\n=== calibrate repe ===") - r = _illinois_calibrate( - "repe", target, model=model, tok=tok, - prompts=calib_prompts, cfg=cfg, repe_dirs=repe_dirs, - ) - results.append(r) + results_by_method["repe"] = { + "pos": _illinois_calibrate( + "repe", target, model=model, tok=tok, + prompts=calib_prompts, cfg=cfg, alpha_sign=1.0, sign_label="pos", repe_dirs=repe_dirs, + ), + "neg": _illinois_calibrate( + "repe", target, model=model, tok=tok, + prompts=calib_prompts, cfg=cfg, alpha_sign=-1.0, sign_label="neg", repe_dirs=repe_dirs, + ), + } # 4. Audit: at calibrated α, recompute on n_audit prompts. logger.info(f"\n=== AUDIT (n={len(audit_prompts)} prompts) ===") @@ -441,63 +459,83 @@ def main(cfg: KLCalibrateCfg) -> None: "calib_audit_ratio": m_audit["p95"] / m_calib["p95"] if m_calib["p95"] > 0 else float("nan"), }) - for r in results: - method = r["method"] - alpha = r["calibrated_alpha"] + logger.info( + "SHOULD: pos and neg p95 each match the target independently. " + "Asymmetric alpha_pos/alpha_neg means the steering direction has asymmetric KL footprint, not failure." + ) + for method, signs in results_by_method.items(): if method.startswith("dW:"): adapter = method.split(":", 1)[1] w = load_diff(cfg.out / cfg.behavior / adapter / DIFF_FILENAME) - m_audit = _measure_kl_along_trajectory( - method, alpha, model=model, tok=tok, prompts=audit_prompts, - n_tokens=cfg.n_tokens, w=w, - ) - elif method == "repe": - m_audit = _measure_kl_along_trajectory( - method, alpha, model=model, tok=tok, prompts=audit_prompts, - n_tokens=cfg.n_tokens, repe_dirs=repe_dirs, - repe_layers=cfg.repe_layers, - ) else: - raise ValueError(method) - logger.info( - f" {method} α={alpha:+.3f} audit p95={m_audit['p95']:.4g} " - f"(calib was {r['p95_at_calib']:.4g}, target {target:.4g})" - ) - audit_rows.append({ - "method": method, - "alpha": alpha, - "p95_calib": r["p95_at_calib"], - "mean_calib": r["mean_at_calib"], - "p95_audit": m_audit["p95"], - "mean_audit": m_audit["mean"], - "max_audit": m_audit["max"], - "calib_audit_ratio": m_audit["p95"] / r["p95_at_calib"] if r["p95_at_calib"] > 0 else float("nan"), - }) + w = None + for sign_label, r in signs.items(): + alpha = r["calibrated_alpha"] + if method.startswith("dW:"): + m_audit = _measure_kl_along_trajectory( + method, alpha, model=model, tok=tok, prompts=audit_prompts, + n_tokens=cfg.n_tokens, w=w, + ) + elif method == "repe": + m_audit = _measure_kl_along_trajectory( + method, alpha, model=model, tok=tok, prompts=audit_prompts, + n_tokens=cfg.n_tokens, repe_dirs=repe_dirs, + repe_layers=cfg.repe_layers, + ) + else: + raise ValueError(method) + logger.info( + f" {method}:{sign_label} α={alpha:+.3f} audit p95={m_audit['p95']:.4g} " + f"(calib was {r['p95_at_calib']:.4g}, target {target:.4g})" + ) + audit_rows.append({ + "method": method, + "sign": sign_label, + "alpha": alpha, + "alpha_mag": r["alpha_mag"], + "p95_calib": r["p95_at_calib"], + "mean_calib": r["mean_at_calib"], + "p95_audit": m_audit["p95"], + "mean_audit": m_audit["mean"], + "max_audit": m_audit["max"], + "calib_audit_ratio": m_audit["p95"] / r["p95_at_calib"] if r["p95_at_calib"] > 0 else float("nan"), + }) audit_df = pl.DataFrame(audit_rows) audit_df.write_csv(out_dir / "audit.csv") summary_rows = [] - for r in results: + for method, signs in results_by_method.items(): + pos = signs["pos"] + neg = signs["neg"] summary_rows.append({ - "method": r["method"], - "calibrated_alpha": r["calibrated_alpha"], - "p95_at_calib": r["p95_at_calib"], - "mean_at_calib": r["mean_at_calib"], - "max_at_calib": r["max_at_calib"], - "ratio_at_calib": r["ratio_at_calib"], - "iterations": r["iterations"], - "converged": r["converged"], + "method": method, + "alpha_pos": pos["alpha_mag"], + "alpha_neg": neg["alpha_mag"], + "calibrated_alpha": pos["alpha_mag"], + "p95_at_pos": pos["p95_at_calib"], + "p95_at_neg": neg["p95_at_calib"], + "mean_at_pos": pos["mean_at_calib"], + "mean_at_neg": neg["mean_at_calib"], + "max_at_pos": pos["max_at_calib"], + "max_at_neg": neg["max_at_calib"], + "ratio_at_pos": pos["ratio_at_calib"], + "ratio_at_neg": neg["ratio_at_calib"], + "iterations_pos": pos["iterations"], + "iterations_neg": neg["iterations"], + "converged_pos": pos["converged"], + "converged_neg": neg["converged"], }) - summary_df = pl.DataFrame(summary_rows).sort("calibrated_alpha") + summary_df = pl.DataFrame(summary_rows).sort("alpha_pos") summary_df = summary_df.with_columns(pl.lit(target).alias("target_p95")) summary_path = out_dir / "summary.csv" summary_df.write_csv(summary_path) history_rows = [] - for r in results: - for h in r["history"]: - history_rows.append({"method": r["method"], **h}) + for method, signs in results_by_method.items(): + for sign_label, r in signs.items(): + for h in r["history"]: + history_rows.append({"method": method, "sign": sign_label, **h}) pl.DataFrame(history_rows).write_csv(out_dir / "root_history.csv") pl.DataFrame([{"method": k, **v} for k, v in prompt_refs.items()]).write_csv(out_dir / "prompt_refs.csv") @@ -510,15 +548,23 @@ def main(cfg: KLCalibrateCfg) -> None: print(tabulate(audit_df.to_pandas(), headers="keys", tablefmt="tsv", floatfmt="+.4g", showindex=False)) - cue = "🟢" if all(r["converged"] for r in results) else "🟡" + n_converged = sum( + int(r["converged"]) + for signs in results_by_method.values() + for r in signs.values() + ) + n_total = sum(len(signs) for signs in results_by_method.values()) + cue = "🟢" if n_converged == n_total else "🟡" final_summary( out=summary_path, argv=get_argv(), - main_metric=f"target_p95={target:.4g} converged={sum(r['converged'] for r in results)}/{len(results)}", + main_metric=f"target_p95={target:.4g} converged={n_converged}/{n_total}", cue=cue, - table_rows=summary_df.select("method", "calibrated_alpha", "p95_at_calib", - "ratio_at_calib", "iterations", "converged").rows(), - headers=["method", "alpha", "p95", "ratio", "iters", "ok"], + table_rows=summary_df.select( + "method", "alpha_neg", "alpha_pos", "p95_at_neg", "p95_at_pos", + "iterations_neg", "iterations_pos", "converged_neg", "converged_pos" + ).rows(), + headers=["method", "alpha_neg", "alpha_pos", "p95_neg", "p95_pos", "iters_neg", "iters_pos", "ok_neg", "ok_pos"], floatfmt="", ) diff --git a/src/ws/replicate.py b/src/ws/replicate.py index 119e14d..7c675a4 100644 --- a/src/ws/replicate.py +++ b/src/ws/replicate.py @@ -43,6 +43,14 @@ class Cfg: out: Path = Path("out") # Shared data dir — kept separate from out so multiple runs reuse the same pairs. data_root: Path = Path("out/data") + data_batch_size: int = 8 + data_min_new_tokens: int = 1024 + data_max_new_tokens: int = 1280 + data_temperature: float | None = None + data_top_p: float | None = None + data_top_k: int | None = None + data_min_p: float | None = None + data_presence_penalty: float = 0.0 coeffs: tuple[float, ...] = (-2.0, -1.0, 0.0, 1.0, 2.0) @@ -66,6 +74,14 @@ def _maybe_data(cfg: Cfg) -> Dataset: dcfg = DataCfg( model_id=cfg.model, behavior=cfg.behavior, out=data_root, n_topics=cfg.n_topics, n_personas=cfg.n_personas, n_samples=cfg.n_samples, + batch_size=cfg.data_batch_size, + min_new_tokens=cfg.data_min_new_tokens, + max_new_tokens=cfg.data_max_new_tokens, + temperature=cfg.data_temperature, + top_p=cfg.data_top_p, + top_k=cfg.data_top_k, + min_p=cfg.data_min_p, + presence_penalty=cfg.data_presence_penalty, ) generate_pairs(dcfg) return load_pairs(cfg.behavior, root=data_root)