mirror of
https://github.com/wassname/weight-steering.git
synced 2026-06-27 16:17:59 +08:00
Add batched data gen and bidir calibration
This commit is contained in:
@@ -49,6 +49,13 @@ def main(cfg: SmokeCfg) -> None:
|
||||
n_topics=2, # 2×1×2 = 4 pairs
|
||||
n_personas=1,
|
||||
n_samples=2,
|
||||
data_batch_size=2,
|
||||
data_min_new_tokens=16,
|
||||
data_max_new_tokens=32,
|
||||
data_temperature=0.7,
|
||||
data_top_p=0.8,
|
||||
data_top_k=20,
|
||||
data_min_p=0.0,
|
||||
)
|
||||
replicate_main(rcfg)
|
||||
print("[smoke] OK", flush=True)
|
||||
|
||||
@@ -31,7 +31,11 @@ smoke-sweep:
|
||||
|
||||
# Generate +/- pair data for a behavior. Writes to out/data/{behavior}/.
|
||||
data:
|
||||
uv run python -m ws.data --model {{model}} --behavior {{behavior}}
|
||||
uv run python -m ws.data --model-id {{model}} --behavior {{behavior}}
|
||||
|
||||
# One-off greedy persona collapse debugger.
|
||||
debug-personas:
|
||||
uv run python -m ws.debug_personas --model {{model}} --behavior {{behavior}} --out {{out}}
|
||||
|
||||
# Train a single adapter (positive or negative). Pos/neg controls system prompt at gen time.
|
||||
train sign="pos":
|
||||
|
||||
+200
-25
@@ -18,16 +18,19 @@ Output columns:
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import asdict, dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import tyro
|
||||
from datasets import Dataset
|
||||
from loguru import logger
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from ws._tok_extras import chat_template_extras
|
||||
from ws._log import get_argv, setup_logging
|
||||
from ws._tok_extras import chat_template_extras, has_thinking_mode
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parents[2]
|
||||
DATA_DIR = REPO_ROOT / "data"
|
||||
@@ -162,8 +165,14 @@ class DataCfg:
|
||||
n_personas: int = 5
|
||||
n_samples: int = 10
|
||||
out: Path = Path("out/data")
|
||||
max_new_tokens: int = 96
|
||||
temperature: float = 0.8
|
||||
batch_size: int = 8
|
||||
min_new_tokens: int = 1024
|
||||
max_new_tokens: int = 1280
|
||||
temperature: float | None = None
|
||||
top_p: float | None = None
|
||||
top_k: int | None = None
|
||||
min_p: float | None = None
|
||||
presence_penalty: float = 0.0
|
||||
seed: int = 0
|
||||
|
||||
|
||||
@@ -212,21 +221,127 @@ def _build_specs(topics, n_personas: int, n_samples: int, behavior: str):
|
||||
return specs
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def _gen(model, tok, sys_prompt: str, user_prompt: str, max_new_tokens: int, temperature: float):
|
||||
def _render_chat_prompt(tok, sys_prompt: str, user_prompt: str) -> str:
|
||||
msgs = [{"role": "system", "content": sys_prompt}, {"role": "user", "content": user_prompt}]
|
||||
text = tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True,
|
||||
**chat_template_extras(tok))
|
||||
inputs = tok(text, return_tensors="pt").to(model.device)
|
||||
out = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=max_new_tokens,
|
||||
do_sample=temperature > 0,
|
||||
temperature=temperature if temperature > 0 else 1.0,
|
||||
pad_token_id=tok.pad_token_id or tok.eos_token_id,
|
||||
return tok.apply_chat_template(
|
||||
msgs,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
**chat_template_extras(tok),
|
||||
)
|
||||
gen = out[0, inputs["input_ids"].shape[1]:]
|
||||
return tok.decode(gen, skip_special_tokens=True).strip()
|
||||
|
||||
|
||||
def _sampling_defaults(tok, cfg: DataCfg) -> dict[str, Any]:
|
||||
thinking = has_thinking_mode(tok)
|
||||
defaults = {
|
||||
"temperature": 0.6 if thinking else 0.7,
|
||||
"top_p": 0.95 if thinking else 0.8,
|
||||
"top_k": 20,
|
||||
"min_p": 0.0,
|
||||
}
|
||||
params = {
|
||||
"temperature": defaults["temperature"] if cfg.temperature is None else cfg.temperature,
|
||||
"top_p": defaults["top_p"] if cfg.top_p is None else cfg.top_p,
|
||||
"top_k": defaults["top_k"] if cfg.top_k is None else cfg.top_k,
|
||||
"min_p": defaults["min_p"] if cfg.min_p is None else cfg.min_p,
|
||||
}
|
||||
if params["temperature"] <= 0:
|
||||
logger.warning(
|
||||
"data generation temperature<=0 enables greedy decoding. "
|
||||
"Qwen recommends sampling for training data; use this only deliberately."
|
||||
)
|
||||
if cfg.presence_penalty:
|
||||
logger.warning(
|
||||
"presence_penalty is configured but this HF generate path does not support it directly; ignoring."
|
||||
)
|
||||
return params
|
||||
|
||||
|
||||
def _trim_generation_ids(ids: torch.Tensor, eos_id: int | None, pad_id: int | None) -> tuple[torch.Tensor, bool]:
|
||||
trimmed: list[int] = []
|
||||
hit_eos = False
|
||||
for tok_id in ids.tolist():
|
||||
if pad_id is not None and tok_id == pad_id:
|
||||
continue
|
||||
trimmed.append(tok_id)
|
||||
if eos_id is not None and tok_id == eos_id:
|
||||
hit_eos = True
|
||||
break
|
||||
return torch.tensor(trimmed, dtype=torch.long), hit_eos
|
||||
|
||||
|
||||
def _normalize_text(text: str) -> str:
|
||||
return " ".join(text.split())
|
||||
|
||||
|
||||
def _log_trace(tok, *, prompt_text: str, gen_ids: torch.Tensor, clean_text: str, label: str) -> None:
|
||||
prompt_ids = tok(prompt_text, return_tensors="pt").input_ids[0]
|
||||
first = tok.convert_ids_to_tokens(prompt_ids[: min(8, len(prompt_ids))].tolist())
|
||||
last = tok.convert_ids_to_tokens(prompt_ids[-min(8, len(prompt_ids)):].tolist())
|
||||
raw_gen = tok.decode(gen_ids, skip_special_tokens=False)
|
||||
raw_toks = tok.convert_ids_to_tokens(gen_ids.tolist()) if len(gen_ids) else []
|
||||
logger.info(f"[{label}] full prompt (special tokens included):\n{prompt_text}")
|
||||
logger.info(f"[{label}] n_input_tokens={prompt_ids.shape[0]} first8={first} last8={last}")
|
||||
logger.info(f"[{label}] raw generated continuation: {raw_gen!r}")
|
||||
logger.info(f"[{label}] generated tokens: {raw_toks}")
|
||||
logger.info(f"[{label}] cleaned continuation: {clean_text!r}")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def _generate_batch(
|
||||
model,
|
||||
tok,
|
||||
prompts: list[str],
|
||||
cfg: DataCfg,
|
||||
sampling: dict[str, Any],
|
||||
*,
|
||||
trace_label: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
if cfg.max_new_tokens < cfg.min_new_tokens:
|
||||
raise ValueError(
|
||||
f"max_new_tokens={cfg.max_new_tokens} must be >= min_new_tokens={cfg.min_new_tokens}"
|
||||
)
|
||||
results: list[dict[str, Any]] = []
|
||||
old_padding_side = tok.padding_side
|
||||
tok.padding_side = "left"
|
||||
try:
|
||||
for start in tqdm(range(0, len(prompts), cfg.batch_size), desc=f"gen {trace_label}", mininterval=60):
|
||||
batch_prompts = prompts[start:start + cfg.batch_size]
|
||||
enc = tok(batch_prompts, return_tensors="pt", padding=True).to(model.device)
|
||||
out = model.generate(
|
||||
**enc,
|
||||
min_new_tokens=cfg.min_new_tokens,
|
||||
max_new_tokens=cfg.max_new_tokens,
|
||||
do_sample=sampling["temperature"] > 0,
|
||||
temperature=sampling["temperature"] if sampling["temperature"] > 0 else 1.0,
|
||||
top_p=sampling["top_p"],
|
||||
top_k=sampling["top_k"],
|
||||
min_p=sampling["min_p"],
|
||||
pad_token_id=tok.pad_token_id or tok.eos_token_id,
|
||||
eos_token_id=tok.eos_token_id,
|
||||
)
|
||||
gen_block = out[:, enc["input_ids"].shape[1]:].cpu()
|
||||
for i, prompt_text in enumerate(batch_prompts):
|
||||
raw_ids, hit_eos = _trim_generation_ids(gen_block[i], tok.eos_token_id, tok.pad_token_id)
|
||||
clean = tok.decode(raw_ids, skip_special_tokens=True).rstrip()
|
||||
results.append({
|
||||
"clean_text": clean,
|
||||
"raw_text": tok.decode(raw_ids, skip_special_tokens=False),
|
||||
"gen_ids": raw_ids,
|
||||
"hit_eos": hit_eos,
|
||||
"prompt_text": prompt_text,
|
||||
})
|
||||
if results:
|
||||
_log_trace(
|
||||
tok,
|
||||
prompt_text=results[0]["prompt_text"],
|
||||
gen_ids=results[0]["gen_ids"],
|
||||
clean_text=results[0]["clean_text"],
|
||||
label=trace_label,
|
||||
)
|
||||
finally:
|
||||
tok.padding_side = old_padding_side
|
||||
return results
|
||||
|
||||
|
||||
def assert_generated_pairs_diverged(ds: Dataset) -> None:
|
||||
@@ -272,6 +387,10 @@ def assert_generated_pairs_diverged(ds: Dataset) -> None:
|
||||
f"unique_pos={len({r['response_pos'].strip() for r in rows})}, "
|
||||
f"unique_neg={len({r['response_neg'].strip() for r in rows})}"
|
||||
)
|
||||
logger.info(
|
||||
"SHOULD: identical_pos_neg stay low, unique_pos/unique_neg stay high, "
|
||||
"and empty generations stay at zero. Large identical counts mean persona collapse."
|
||||
)
|
||||
|
||||
|
||||
# TODO judge filter: paper §3 uses GPT-4.1-mini to drop rows where r_pos doesn't
|
||||
@@ -311,24 +430,69 @@ def generate_pairs(cfg: DataCfg) -> Path:
|
||||
)
|
||||
model.eval()
|
||||
|
||||
rows = []
|
||||
for i, spec in enumerate(tqdm(specs, desc=f"gen {cfg.behavior}", mininterval=60)):
|
||||
sampling = _sampling_defaults(tok, cfg)
|
||||
logger.info(f"sampling={sampling} thinking_mode={has_thinking_mode(tok)}")
|
||||
logger.info(
|
||||
"SHOULD: training data use sampling, not greedy decoding. "
|
||||
"Each continuation should be >= min_new_tokens and ideally terminate on EOS before max_new_tokens."
|
||||
)
|
||||
|
||||
rendered = []
|
||||
for spec in specs:
|
||||
sys_pos = sys_pos_list[spec["persona_idx"]]
|
||||
sys_neg = sys_neg_list[spec["persona_idx"]]
|
||||
r_pos = _gen(model, tok, sys_pos, spec["prompt"], cfg.max_new_tokens, cfg.temperature)
|
||||
r_neg = _gen(model, tok, sys_neg, spec["prompt"], cfg.max_new_tokens, cfg.temperature)
|
||||
rows.append({
|
||||
"prompt": spec["prompt"],
|
||||
"response_pos": r_pos,
|
||||
"response_neg": r_neg,
|
||||
rendered.append({
|
||||
**spec,
|
||||
"sys_prompt_pos": sys_pos,
|
||||
"sys_prompt_neg": sys_neg,
|
||||
"prompt_text_pos": _render_chat_prompt(tok, sys_pos, spec["prompt"]),
|
||||
"prompt_text_neg": _render_chat_prompt(tok, sys_neg, spec["prompt"]),
|
||||
})
|
||||
|
||||
pos_results = _generate_batch(
|
||||
model,
|
||||
tok,
|
||||
[r["prompt_text_pos"] for r in rendered],
|
||||
cfg,
|
||||
sampling,
|
||||
trace_label="data stage pair0 pos",
|
||||
)
|
||||
neg_results = _generate_batch(
|
||||
model,
|
||||
tok,
|
||||
[r["prompt_text_neg"] for r in rendered],
|
||||
cfg,
|
||||
sampling,
|
||||
trace_label="data stage pair0 neg",
|
||||
)
|
||||
|
||||
rows = []
|
||||
no_eos_pos = 0
|
||||
no_eos_neg = 0
|
||||
for spec, pos, neg in zip(rendered, pos_results, neg_results, strict=True):
|
||||
no_eos_pos += int(not pos["hit_eos"])
|
||||
no_eos_neg += int(not neg["hit_eos"])
|
||||
rows.append({
|
||||
"prompt": spec["prompt"],
|
||||
"response_pos": pos["clean_text"],
|
||||
"response_neg": neg["clean_text"],
|
||||
"sys_prompt_pos": spec["sys_prompt_pos"],
|
||||
"sys_prompt_neg": spec["sys_prompt_neg"],
|
||||
"topic_idx": spec["topic_idx"],
|
||||
"persona_idx": spec["persona_idx"],
|
||||
"sample_idx": spec["sample_idx"],
|
||||
"behavior": cfg.behavior,
|
||||
})
|
||||
|
||||
logger.info(
|
||||
f"eos coverage: pos={len(rows) - no_eos_pos}/{len(rows)} neg={len(rows) - no_eos_neg}/{len(rows)}"
|
||||
)
|
||||
if no_eos_pos or no_eos_neg:
|
||||
logger.warning(
|
||||
f"{no_eos_pos + no_eos_neg} generations hit max_new_tokens before EOS. "
|
||||
"Increase max_new_tokens if this is common."
|
||||
)
|
||||
|
||||
ds = Dataset.from_list(rows)
|
||||
assert_generated_pairs_diverged(ds)
|
||||
out_dir = cfg.out / cfg.behavior
|
||||
@@ -342,3 +506,14 @@ def load_pairs(behavior: str, root: Path = Path("out/data")) -> Dataset:
|
||||
ds = Dataset.load_from_disk(str(root / behavior))
|
||||
assert_generated_pairs_diverged(ds)
|
||||
return ds
|
||||
|
||||
|
||||
def main(cfg: DataCfg) -> None:
|
||||
setup_logging("data")
|
||||
logger.info(f"argv: {get_argv()}")
|
||||
logger.info(f"data cfg: {asdict(cfg)}")
|
||||
generate_pairs(cfg)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(tyro.cli(DataCfg))
|
||||
|
||||
@@ -0,0 +1,139 @@
|
||||
"""One-off persona collapse debugger.
|
||||
|
||||
For each persona pair, greedy-generate short continuations on a fixed prompt
|
||||
set and warn if left/right collapse to the same text.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import asdict, dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import polars as pl
|
||||
import torch
|
||||
import tyro
|
||||
from loguru import logger
|
||||
from tabulate import tabulate
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from ws._log import final_summary, get_argv, setup_logging
|
||||
from ws.data import _normalize_text, _personas, _render_chat_prompt, _topics
|
||||
|
||||
|
||||
@dataclass
|
||||
class PersonaDebugCfg:
|
||||
model: str = "Qwen/Qwen3-0.6B"
|
||||
behavior: str = "honesty"
|
||||
out: Path = Path("out")
|
||||
n_prompts: int = 8
|
||||
max_new_tokens: int = 100
|
||||
batch_size: int = 8
|
||||
seed: int = 0
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def _greedy_batch(model, tok, prompts: list[str], batch_size: int, max_new_tokens: int) -> list[str]:
|
||||
rows: list[str] = []
|
||||
old_padding_side = tok.padding_side
|
||||
tok.padding_side = "left"
|
||||
try:
|
||||
for start in range(0, len(prompts), batch_size):
|
||||
batch_prompts = prompts[start:start + batch_size]
|
||||
enc = tok(batch_prompts, return_tensors="pt", padding=True).to(model.device)
|
||||
out = model.generate(
|
||||
**enc,
|
||||
max_new_tokens=max_new_tokens,
|
||||
do_sample=False,
|
||||
temperature=1.0,
|
||||
pad_token_id=tok.pad_token_id or tok.eos_token_id,
|
||||
eos_token_id=tok.eos_token_id,
|
||||
)
|
||||
gen_block = out[:, enc["input_ids"].shape[1]:].cpu()
|
||||
for i in range(len(batch_prompts)):
|
||||
rows.append(tok.decode(gen_block[i], skip_special_tokens=True).rstrip())
|
||||
finally:
|
||||
tok.padding_side = old_padding_side
|
||||
return rows
|
||||
|
||||
|
||||
def main(cfg: PersonaDebugCfg) -> None:
|
||||
setup_logging("debug_personas")
|
||||
logger.info(f"argv: {get_argv()}")
|
||||
logger.info(f"persona debug cfg: {asdict(cfg)}")
|
||||
|
||||
tok = AutoTokenizer.from_pretrained(cfg.model)
|
||||
if tok.pad_token is None:
|
||||
tok.pad_token = tok.eos_token
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
cfg.model, torch_dtype=torch.bfloat16, device_map="auto"
|
||||
)
|
||||
model.eval()
|
||||
|
||||
pos_personas, neg_personas = _personas(cfg.behavior)
|
||||
topics = _topics(cfg.behavior)[:cfg.n_prompts]
|
||||
prompts: list[str] = []
|
||||
for a, b in topics:
|
||||
prompts.append(f"I think {a}. {b}" if cfg.behavior == "sycophancy" else a)
|
||||
|
||||
rows = []
|
||||
for persona_idx, (sys_pos, sys_neg) in enumerate(zip(pos_personas, neg_personas, strict=True)):
|
||||
prompt_pos = [_render_chat_prompt(tok, sys_pos, prompt) for prompt in prompts]
|
||||
prompt_neg = [_render_chat_prompt(tok, sys_neg, prompt) for prompt in prompts]
|
||||
gens_pos = _greedy_batch(model, tok, prompt_pos, cfg.batch_size, cfg.max_new_tokens)
|
||||
gens_neg = _greedy_batch(model, tok, prompt_neg, cfg.batch_size, cfg.max_new_tokens)
|
||||
identical = 0
|
||||
for prompt, gen_pos, gen_neg in zip(prompts, gens_pos, gens_neg, strict=True):
|
||||
same = _normalize_text(gen_pos) == _normalize_text(gen_neg)
|
||||
identical += int(same)
|
||||
rows.append({
|
||||
"persona_idx": persona_idx,
|
||||
"prompt": prompt,
|
||||
"same": same,
|
||||
"response_pos": gen_pos,
|
||||
"response_neg": gen_neg,
|
||||
})
|
||||
if identical:
|
||||
logger.warning(
|
||||
f"persona_idx={persona_idx} collapsed on {identical}/{len(prompts)} greedy probes; "
|
||||
"discard this pair from persona debugging."
|
||||
)
|
||||
|
||||
df = pl.DataFrame(rows)
|
||||
out_dir = cfg.out / cfg.behavior / "persona_debug"
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
per_prompt_path = out_dir / "per_prompt.csv"
|
||||
summary_path = out_dir / "summary.csv"
|
||||
df.write_csv(per_prompt_path)
|
||||
|
||||
summary = (
|
||||
df.group_by("persona_idx")
|
||||
.agg(
|
||||
pl.len().alias("n_prompts"),
|
||||
pl.col("same").sum().alias("n_same"),
|
||||
)
|
||||
.with_columns(
|
||||
(pl.col("n_same") / pl.col("n_prompts")).alias("same_rate"),
|
||||
(pl.col("n_same") == 0).alias("keep_pair"),
|
||||
)
|
||||
.sort("persona_idx")
|
||||
)
|
||||
summary.write_csv(summary_path)
|
||||
|
||||
print("\npersona_debug")
|
||||
print("SHOULD: left/right greedy probes differ for each persona pair. same_rate>0 means the persona contrast is weak or ignored.")
|
||||
print(tabulate(summary.to_pandas(), headers="keys", tablefmt="tsv", floatfmt="+.3f", showindex=False))
|
||||
|
||||
cue = "🟢" if bool(summary["keep_pair"].all()) else "🟡"
|
||||
final_summary(
|
||||
out=summary_path,
|
||||
argv=get_argv(),
|
||||
main_metric=f"keep_pairs={int(summary['keep_pair'].sum())}/{len(summary)}",
|
||||
cue=cue,
|
||||
table_rows=summary.select("persona_idx", "n_prompts", "n_same", "same_rate", "keep_pair").rows(),
|
||||
headers=["persona_idx", "n_prompts", "n_same", "same_rate", "keep_pair"],
|
||||
floatfmt="",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(tyro.cli(PersonaDebugCfg))
|
||||
@@ -104,6 +104,12 @@ def _eval_dilemmas_repe(model, tok, dirs, layers, alpha, dl, choice_ids, pmass_t
|
||||
return rows
|
||||
|
||||
|
||||
def _alpha_triplet(row: dict) -> tuple[float, float]:
|
||||
alpha_pos = float(row["alpha_pos"]) if "alpha_pos" in row else float(row["calibrated_alpha"])
|
||||
alpha_neg = float(row["alpha_neg"]) if "alpha_neg" in row else float(row["calibrated_alpha"])
|
||||
return alpha_pos, alpha_neg
|
||||
|
||||
|
||||
def main(cfg: DilemmasCalibratedCfg) -> None:
|
||||
setup_logging("dilemmas_calibrated")
|
||||
out_dir = cfg.out / cfg.behavior / "dilemmas_calibrated"
|
||||
@@ -112,8 +118,12 @@ def main(cfg: DilemmasCalibratedCfg) -> None:
|
||||
calib_path = cfg.out / cfg.behavior / "kl_calibration" / "summary.csv"
|
||||
calib = pl.read_csv(calib_path)
|
||||
logger.info(f"loaded calibration: {len(calib)} methods from {calib_path}")
|
||||
logger.info(tabulate(calib.select("method", "calibrated_alpha", "p95_at_calib").to_pandas(),
|
||||
headers="keys", tablefmt="tsv", floatfmt="+.3f", showindex=False))
|
||||
cols = ["method", "alpha_neg", "alpha_pos", "p95_at_neg", "p95_at_pos"]
|
||||
fallback_cols = ["method", "calibrated_alpha"]
|
||||
logger.info(tabulate(
|
||||
calib.select([c for c in cols if c in calib.columns] or fallback_cols).to_pandas(),
|
||||
headers="keys", tablefmt="tsv", floatfmt="+.3f", showindex=False
|
||||
))
|
||||
|
||||
tok = AutoTokenizer.from_pretrained(cfg.model)
|
||||
if tok.pad_token is None:
|
||||
@@ -153,12 +163,12 @@ def main(cfg: DilemmasCalibratedCfg) -> None:
|
||||
# Adapter dW evals at calibrated ±α and 0.
|
||||
for row in calib.iter_rows(named=True):
|
||||
method = row["method"]
|
||||
alpha_c = float(row["calibrated_alpha"])
|
||||
alpha_pos, alpha_neg = _alpha_triplet(row)
|
||||
if method.startswith("dW:"):
|
||||
adapter = method.split(":", 1)[1]
|
||||
w = load_diff(cfg.out / cfg.behavior / adapter / DIFF_FILENAME)
|
||||
rows = []
|
||||
for alpha in (-alpha_c, 0.0, alpha_c):
|
||||
for alpha in (-alpha_neg, 0.0, alpha_pos):
|
||||
rows.extend(_eval_dilemmas_dw(model, tok, w, alpha, dl, choice_ids,
|
||||
cfg.pmass_threshold, method))
|
||||
logger.info(f" {method} α={alpha:+.3f}: {len(ds_pt)} rows")
|
||||
@@ -166,7 +176,7 @@ def main(cfg: DilemmasCalibratedCfg) -> None:
|
||||
elif method == "repe":
|
||||
dirs = _fit_repe_directions(model, tok, cfg.n_repe_train, cfg.behavior)
|
||||
rows = []
|
||||
for alpha in (-alpha_c, 0.0, alpha_c):
|
||||
for alpha in (-alpha_neg, 0.0, alpha_pos):
|
||||
rows.extend(_eval_dilemmas_repe(model, tok, dirs, cfg.repe_layers, alpha, dl,
|
||||
choice_ids, cfg.pmass_threshold))
|
||||
logger.info(f" repe α={alpha:+.3f}: {len(ds_pt)} rows")
|
||||
@@ -261,9 +271,11 @@ def main(cfg: DilemmasCalibratedCfg) -> None:
|
||||
|
||||
# Get calibrated alpha for this method (1.0 for prompts).
|
||||
if method.startswith("prompt:"):
|
||||
alpha_c = 1.0
|
||||
alpha_pos = 1.0
|
||||
alpha_neg = 1.0
|
||||
else:
|
||||
alpha_c = float(calib.filter(pl.col("method") == method)["calibrated_alpha"][0])
|
||||
row = next(calib.filter(pl.col("method") == method).iter_rows(named=True))
|
||||
alpha_pos, alpha_neg = _alpha_triplet(row)
|
||||
|
||||
# Mean logratio_honesty per coeff.
|
||||
zero_lr = float(sub.filter(pl.col("coeff") == 0.0)["logratio_honesty"].mean()) if 0.0 in sub["coeff"].to_list() else float("nan")
|
||||
@@ -272,7 +284,9 @@ def main(cfg: DilemmasCalibratedCfg) -> None:
|
||||
|
||||
si_rows.append({
|
||||
"method": method,
|
||||
"alpha": alpha_c,
|
||||
"alpha": alpha_pos,
|
||||
"alpha_pos": alpha_pos,
|
||||
"alpha_neg": alpha_neg,
|
||||
"sign": sign_chosen,
|
||||
"SI": m["surgical_informedness"],
|
||||
"SI_to_do": m.get("SI_to_do", float("nan")),
|
||||
@@ -304,6 +318,7 @@ def main(cfg: DilemmasCalibratedCfg) -> None:
|
||||
si_df.write_csv(si_path)
|
||||
|
||||
print("\n=== Dilemmas SI at KL-calibrated α (matched p95 token-KL ≈ 0.615 nats) ===")
|
||||
print("SHOULD: use (-alpha_neg, 0, +alpha_pos) per method. Asymmetry is expected when left/right KL footprints differ.")
|
||||
print(tabulate(si_df.to_pandas(), headers="keys", tablefmt="tsv",
|
||||
floatfmt="+.3f", showindex=False))
|
||||
|
||||
@@ -313,9 +328,9 @@ def main(cfg: DilemmasCalibratedCfg) -> None:
|
||||
argv=get_argv(),
|
||||
main_metric=f"best_method={si_df['method'][0]} SI={float(si_df['SI'][0] or 0):+.3f}",
|
||||
cue=cue,
|
||||
table_rows=si_df.select("method", "alpha", "sign", "SI", "si_fwd", "si_rev",
|
||||
table_rows=si_df.select("method", "alpha_neg", "alpha_pos", "sign", "SI", "si_fwd", "si_rev",
|
||||
"fix_fwd", "broke_fwd").rows(),
|
||||
headers=["method", "alpha", "sign", "SI", "si_fwd", "si_rev", "fix", "broke"],
|
||||
headers=["method", "alpha_neg", "alpha_pos", "sign", "SI", "si_fwd", "si_rev", "fix", "broke"],
|
||||
floatfmt="",
|
||||
)
|
||||
|
||||
|
||||
+143
-97
@@ -72,13 +72,13 @@ class KLCalibrateCfg:
|
||||
include_repe: bool = True
|
||||
n_calib_prompts: int = 50
|
||||
n_audit_prompts: int = 100
|
||||
n_tokens: int = 20
|
||||
n_tokens: int = 50
|
||||
target_pct: float = 95.0
|
||||
# "Side of the road" = 1 nat per-token KL (gist):
|
||||
# https://gist.github.com/wassname/6c11cf30b43d8c228bc114795f1019c7
|
||||
# Newton residual is 1 − p95(KL); we search a global coefficient C such
|
||||
# that p95 KL = target_kl at α=1.
|
||||
target_kl: float = 1.0
|
||||
target_kl: float = 0.5
|
||||
target_prompt: str = "engineered_prompt_honest" # logged as a reference, not the target
|
||||
# Bracket guard (lo, hi) on the global coefficient. KL ~ α²·F near root, so
|
||||
# below ~0.05 nothing happens; above ~16 we'd be deep in collapse-land.
|
||||
@@ -87,7 +87,7 @@ class KLCalibrateCfg:
|
||||
n_root_iters: int = 12 # Illinois inner loop; usually converges in 3-5
|
||||
convergence_tol: float = 0.05 # |p95 - target| < tol (absolute, in nats)
|
||||
repe_layers: tuple[int, ...] = field(default_factory=lambda: tuple(range(8, 22)))
|
||||
n_repe_train: int = 20
|
||||
n_repe_train: int = 50
|
||||
seed: int = 0
|
||||
|
||||
|
||||
@@ -212,8 +212,19 @@ def _measure_kl_along_trajectory(
|
||||
}
|
||||
|
||||
|
||||
def _illinois_calibrate(method: str, target: float, *, model, tok, prompts, cfg,
|
||||
w=None, repe_dirs=None) -> dict:
|
||||
def _illinois_calibrate(
|
||||
method: str,
|
||||
target: float,
|
||||
*,
|
||||
model,
|
||||
tok,
|
||||
prompts,
|
||||
cfg,
|
||||
alpha_sign: float = 1.0,
|
||||
sign_label: str = "pos",
|
||||
w=None,
|
||||
repe_dirs=None,
|
||||
) -> dict:
|
||||
"""Exponential bracket within (bracket_lo, bracket_hi) then log-log Illinois
|
||||
regula falsi. Mirrors steering-lite's validated `calibrate_iso_kl`.
|
||||
|
||||
@@ -227,18 +238,42 @@ def _illinois_calibrate(method: str, target: float, *, model, tok, prompts, cfg,
|
||||
history: list[dict] = []
|
||||
iter_idx = [0]
|
||||
|
||||
def stat(alpha: float) -> float:
|
||||
def _result(final: dict, converged: bool) -> dict:
|
||||
return {
|
||||
"method": method,
|
||||
"sign": sign_label,
|
||||
"alpha_sign": alpha_sign,
|
||||
"alpha_mag": abs(final["alpha"]),
|
||||
"calibrated_alpha": final["alpha"],
|
||||
"p95_at_calib": final["p95"],
|
||||
"mean_at_calib": final["mean"],
|
||||
"max_at_calib": final["max"],
|
||||
"ratio_at_calib": final["ratio"],
|
||||
"iterations": len(history),
|
||||
"converged": converged,
|
||||
"history": history,
|
||||
}
|
||||
|
||||
def stat(alpha_mag: float) -> float:
|
||||
alpha = alpha_sign * alpha_mag
|
||||
m = _measure_kl_along_trajectory(
|
||||
method, alpha, model=model, tok=tok, prompts=prompts,
|
||||
n_tokens=cfg.n_tokens, w=w, repe_dirs=repe_dirs,
|
||||
repe_layers=cfg.repe_layers,
|
||||
log_first_sample=(iter_idx[0] == 0),
|
||||
sample_label=f"calib iter=0 method={method} α={alpha:+.3f}",
|
||||
sample_label=f"calib iter=0 method={method} sign={sign_label} α={alpha:+.3f}",
|
||||
)
|
||||
ratio = m["p95"] / target if target > 0 else 1.0
|
||||
history.append({"iter": iter_idx[0], "alpha": alpha, **m, "ratio": ratio})
|
||||
history.append({
|
||||
"iter": iter_idx[0],
|
||||
"sign": sign_label,
|
||||
"alpha": alpha,
|
||||
"alpha_mag": alpha_mag,
|
||||
**m,
|
||||
"ratio": ratio,
|
||||
})
|
||||
logger.info(
|
||||
f" [{method}] iter={iter_idx[0]} α={alpha:+.4f} p95={m['p95']:.4g} "
|
||||
f" [{method}:{sign_label}] iter={iter_idx[0]} α={alpha:+.4f} p95={m['p95']:.4g} "
|
||||
f"mean={m['mean']:.4g} max={m['max']:.4g} ratio={ratio:.3f}"
|
||||
)
|
||||
iter_idx[0] += 1
|
||||
@@ -251,13 +286,7 @@ def _illinois_calibrate(method: str, target: float, *, model, tok, prompts, cfg,
|
||||
mid = float(np.sqrt(lo * hi))
|
||||
v_mid = stat(mid)
|
||||
if abs(v_mid - target) < cfg.convergence_tol:
|
||||
final = history[-1]
|
||||
return {
|
||||
"method": method, "calibrated_alpha": final["alpha"],
|
||||
"p95_at_calib": final["p95"], "mean_at_calib": final["mean"],
|
||||
"max_at_calib": final["max"], "ratio_at_calib": final["ratio"],
|
||||
"iterations": len(history), "converged": True, "history": history,
|
||||
}
|
||||
return _result(history[-1], True)
|
||||
|
||||
if v_mid < target:
|
||||
c_lo, v_lo = mid, v_mid
|
||||
@@ -283,14 +312,7 @@ def _illinois_calibrate(method: str, target: float, *, model, tok, prompts, cfg,
|
||||
c_hi, v_hi = c, v
|
||||
|
||||
if v_lo is None or v_hi is None:
|
||||
# Couldn't bracket within (lo, hi); report last point seen.
|
||||
final = history[-1]
|
||||
return {
|
||||
"method": method, "calibrated_alpha": final["alpha"],
|
||||
"p95_at_calib": final["p95"], "mean_at_calib": final["mean"],
|
||||
"max_at_calib": final["max"], "ratio_at_calib": final["ratio"],
|
||||
"iterations": len(history), "converged": False, "history": history,
|
||||
}
|
||||
return _result(history[-1], False)
|
||||
|
||||
# 2. Log-log Illinois regula-falsi inside the bracket.
|
||||
converged = False
|
||||
@@ -324,22 +346,8 @@ def _illinois_calibrate(method: str, target: float, *, model, tok, prompts, cfg,
|
||||
|
||||
# If we exhausted iters without hitting tol, pick the closest point seen.
|
||||
if not converged:
|
||||
best = min(history, key=lambda h: abs(h["p95"] - target))
|
||||
final = best
|
||||
else:
|
||||
final = history[-1]
|
||||
|
||||
return {
|
||||
"method": method,
|
||||
"calibrated_alpha": final["alpha"],
|
||||
"p95_at_calib": final["p95"],
|
||||
"mean_at_calib": final["mean"],
|
||||
"max_at_calib": final["max"],
|
||||
"ratio_at_calib": final["ratio"],
|
||||
"iterations": len(history),
|
||||
"converged": converged,
|
||||
"history": history,
|
||||
}
|
||||
return _result(min(history, key=lambda h: abs(h["p95"] - target)), False)
|
||||
return _result(history[-1], True)
|
||||
|
||||
|
||||
def main(cfg: KLCalibrateCfg) -> None:
|
||||
@@ -401,23 +409,33 @@ def main(cfg: KLCalibrateCfg) -> None:
|
||||
repe_dirs = _fit_repe_directions(model, tok, cfg.n_repe_train, cfg.behavior)
|
||||
|
||||
# 3. Illinois regula-falsi calibrate each adapter and (optionally) RepE.
|
||||
results = []
|
||||
results_by_method: dict[str, dict[str, dict]] = {}
|
||||
for adapter in cfg.adapters:
|
||||
logger.info(f"\n=== calibrate dW:{adapter} ===")
|
||||
w = load_diff(cfg.out / cfg.behavior / adapter / DIFF_FILENAME)
|
||||
r = _illinois_calibrate(
|
||||
f"dW:{adapter}", target, model=model, tok=tok,
|
||||
prompts=calib_prompts, cfg=cfg, w=w,
|
||||
)
|
||||
results.append(r)
|
||||
results_by_method[f"dW:{adapter}"] = {
|
||||
"pos": _illinois_calibrate(
|
||||
f"dW:{adapter}", target, model=model, tok=tok,
|
||||
prompts=calib_prompts, cfg=cfg, alpha_sign=1.0, sign_label="pos", w=w,
|
||||
),
|
||||
"neg": _illinois_calibrate(
|
||||
f"dW:{adapter}", target, model=model, tok=tok,
|
||||
prompts=calib_prompts, cfg=cfg, alpha_sign=-1.0, sign_label="neg", w=w,
|
||||
),
|
||||
}
|
||||
|
||||
if cfg.include_repe:
|
||||
logger.info("\n=== calibrate repe ===")
|
||||
r = _illinois_calibrate(
|
||||
"repe", target, model=model, tok=tok,
|
||||
prompts=calib_prompts, cfg=cfg, repe_dirs=repe_dirs,
|
||||
)
|
||||
results.append(r)
|
||||
results_by_method["repe"] = {
|
||||
"pos": _illinois_calibrate(
|
||||
"repe", target, model=model, tok=tok,
|
||||
prompts=calib_prompts, cfg=cfg, alpha_sign=1.0, sign_label="pos", repe_dirs=repe_dirs,
|
||||
),
|
||||
"neg": _illinois_calibrate(
|
||||
"repe", target, model=model, tok=tok,
|
||||
prompts=calib_prompts, cfg=cfg, alpha_sign=-1.0, sign_label="neg", repe_dirs=repe_dirs,
|
||||
),
|
||||
}
|
||||
|
||||
# 4. Audit: at calibrated α, recompute on n_audit prompts.
|
||||
logger.info(f"\n=== AUDIT (n={len(audit_prompts)} prompts) ===")
|
||||
@@ -441,63 +459,83 @@ def main(cfg: KLCalibrateCfg) -> None:
|
||||
"calib_audit_ratio": m_audit["p95"] / m_calib["p95"] if m_calib["p95"] > 0 else float("nan"),
|
||||
})
|
||||
|
||||
for r in results:
|
||||
method = r["method"]
|
||||
alpha = r["calibrated_alpha"]
|
||||
logger.info(
|
||||
"SHOULD: pos and neg p95 each match the target independently. "
|
||||
"Asymmetric alpha_pos/alpha_neg means the steering direction has asymmetric KL footprint, not failure."
|
||||
)
|
||||
for method, signs in results_by_method.items():
|
||||
if method.startswith("dW:"):
|
||||
adapter = method.split(":", 1)[1]
|
||||
w = load_diff(cfg.out / cfg.behavior / adapter / DIFF_FILENAME)
|
||||
m_audit = _measure_kl_along_trajectory(
|
||||
method, alpha, model=model, tok=tok, prompts=audit_prompts,
|
||||
n_tokens=cfg.n_tokens, w=w,
|
||||
)
|
||||
elif method == "repe":
|
||||
m_audit = _measure_kl_along_trajectory(
|
||||
method, alpha, model=model, tok=tok, prompts=audit_prompts,
|
||||
n_tokens=cfg.n_tokens, repe_dirs=repe_dirs,
|
||||
repe_layers=cfg.repe_layers,
|
||||
)
|
||||
else:
|
||||
raise ValueError(method)
|
||||
logger.info(
|
||||
f" {method} α={alpha:+.3f} audit p95={m_audit['p95']:.4g} "
|
||||
f"(calib was {r['p95_at_calib']:.4g}, target {target:.4g})"
|
||||
)
|
||||
audit_rows.append({
|
||||
"method": method,
|
||||
"alpha": alpha,
|
||||
"p95_calib": r["p95_at_calib"],
|
||||
"mean_calib": r["mean_at_calib"],
|
||||
"p95_audit": m_audit["p95"],
|
||||
"mean_audit": m_audit["mean"],
|
||||
"max_audit": m_audit["max"],
|
||||
"calib_audit_ratio": m_audit["p95"] / r["p95_at_calib"] if r["p95_at_calib"] > 0 else float("nan"),
|
||||
})
|
||||
w = None
|
||||
for sign_label, r in signs.items():
|
||||
alpha = r["calibrated_alpha"]
|
||||
if method.startswith("dW:"):
|
||||
m_audit = _measure_kl_along_trajectory(
|
||||
method, alpha, model=model, tok=tok, prompts=audit_prompts,
|
||||
n_tokens=cfg.n_tokens, w=w,
|
||||
)
|
||||
elif method == "repe":
|
||||
m_audit = _measure_kl_along_trajectory(
|
||||
method, alpha, model=model, tok=tok, prompts=audit_prompts,
|
||||
n_tokens=cfg.n_tokens, repe_dirs=repe_dirs,
|
||||
repe_layers=cfg.repe_layers,
|
||||
)
|
||||
else:
|
||||
raise ValueError(method)
|
||||
logger.info(
|
||||
f" {method}:{sign_label} α={alpha:+.3f} audit p95={m_audit['p95']:.4g} "
|
||||
f"(calib was {r['p95_at_calib']:.4g}, target {target:.4g})"
|
||||
)
|
||||
audit_rows.append({
|
||||
"method": method,
|
||||
"sign": sign_label,
|
||||
"alpha": alpha,
|
||||
"alpha_mag": r["alpha_mag"],
|
||||
"p95_calib": r["p95_at_calib"],
|
||||
"mean_calib": r["mean_at_calib"],
|
||||
"p95_audit": m_audit["p95"],
|
||||
"mean_audit": m_audit["mean"],
|
||||
"max_audit": m_audit["max"],
|
||||
"calib_audit_ratio": m_audit["p95"] / r["p95_at_calib"] if r["p95_at_calib"] > 0 else float("nan"),
|
||||
})
|
||||
|
||||
audit_df = pl.DataFrame(audit_rows)
|
||||
audit_df.write_csv(out_dir / "audit.csv")
|
||||
|
||||
summary_rows = []
|
||||
for r in results:
|
||||
for method, signs in results_by_method.items():
|
||||
pos = signs["pos"]
|
||||
neg = signs["neg"]
|
||||
summary_rows.append({
|
||||
"method": r["method"],
|
||||
"calibrated_alpha": r["calibrated_alpha"],
|
||||
"p95_at_calib": r["p95_at_calib"],
|
||||
"mean_at_calib": r["mean_at_calib"],
|
||||
"max_at_calib": r["max_at_calib"],
|
||||
"ratio_at_calib": r["ratio_at_calib"],
|
||||
"iterations": r["iterations"],
|
||||
"converged": r["converged"],
|
||||
"method": method,
|
||||
"alpha_pos": pos["alpha_mag"],
|
||||
"alpha_neg": neg["alpha_mag"],
|
||||
"calibrated_alpha": pos["alpha_mag"],
|
||||
"p95_at_pos": pos["p95_at_calib"],
|
||||
"p95_at_neg": neg["p95_at_calib"],
|
||||
"mean_at_pos": pos["mean_at_calib"],
|
||||
"mean_at_neg": neg["mean_at_calib"],
|
||||
"max_at_pos": pos["max_at_calib"],
|
||||
"max_at_neg": neg["max_at_calib"],
|
||||
"ratio_at_pos": pos["ratio_at_calib"],
|
||||
"ratio_at_neg": neg["ratio_at_calib"],
|
||||
"iterations_pos": pos["iterations"],
|
||||
"iterations_neg": neg["iterations"],
|
||||
"converged_pos": pos["converged"],
|
||||
"converged_neg": neg["converged"],
|
||||
})
|
||||
summary_df = pl.DataFrame(summary_rows).sort("calibrated_alpha")
|
||||
summary_df = pl.DataFrame(summary_rows).sort("alpha_pos")
|
||||
summary_df = summary_df.with_columns(pl.lit(target).alias("target_p95"))
|
||||
summary_path = out_dir / "summary.csv"
|
||||
summary_df.write_csv(summary_path)
|
||||
|
||||
history_rows = []
|
||||
for r in results:
|
||||
for h in r["history"]:
|
||||
history_rows.append({"method": r["method"], **h})
|
||||
for method, signs in results_by_method.items():
|
||||
for sign_label, r in signs.items():
|
||||
for h in r["history"]:
|
||||
history_rows.append({"method": method, "sign": sign_label, **h})
|
||||
pl.DataFrame(history_rows).write_csv(out_dir / "root_history.csv")
|
||||
|
||||
pl.DataFrame([{"method": k, **v} for k, v in prompt_refs.items()]).write_csv(out_dir / "prompt_refs.csv")
|
||||
@@ -510,15 +548,23 @@ def main(cfg: KLCalibrateCfg) -> None:
|
||||
print(tabulate(audit_df.to_pandas(), headers="keys", tablefmt="tsv",
|
||||
floatfmt="+.4g", showindex=False))
|
||||
|
||||
cue = "🟢" if all(r["converged"] for r in results) else "🟡"
|
||||
n_converged = sum(
|
||||
int(r["converged"])
|
||||
for signs in results_by_method.values()
|
||||
for r in signs.values()
|
||||
)
|
||||
n_total = sum(len(signs) for signs in results_by_method.values())
|
||||
cue = "🟢" if n_converged == n_total else "🟡"
|
||||
final_summary(
|
||||
out=summary_path,
|
||||
argv=get_argv(),
|
||||
main_metric=f"target_p95={target:.4g} converged={sum(r['converged'] for r in results)}/{len(results)}",
|
||||
main_metric=f"target_p95={target:.4g} converged={n_converged}/{n_total}",
|
||||
cue=cue,
|
||||
table_rows=summary_df.select("method", "calibrated_alpha", "p95_at_calib",
|
||||
"ratio_at_calib", "iterations", "converged").rows(),
|
||||
headers=["method", "alpha", "p95", "ratio", "iters", "ok"],
|
||||
table_rows=summary_df.select(
|
||||
"method", "alpha_neg", "alpha_pos", "p95_at_neg", "p95_at_pos",
|
||||
"iterations_neg", "iterations_pos", "converged_neg", "converged_pos"
|
||||
).rows(),
|
||||
headers=["method", "alpha_neg", "alpha_pos", "p95_neg", "p95_pos", "iters_neg", "iters_pos", "ok_neg", "ok_pos"],
|
||||
floatfmt="",
|
||||
)
|
||||
|
||||
|
||||
@@ -43,6 +43,14 @@ class Cfg:
|
||||
out: Path = Path("out")
|
||||
# Shared data dir — kept separate from out so multiple runs reuse the same pairs.
|
||||
data_root: Path = Path("out/data")
|
||||
data_batch_size: int = 8
|
||||
data_min_new_tokens: int = 1024
|
||||
data_max_new_tokens: int = 1280
|
||||
data_temperature: float | None = None
|
||||
data_top_p: float | None = None
|
||||
data_top_k: int | None = None
|
||||
data_min_p: float | None = None
|
||||
data_presence_penalty: float = 0.0
|
||||
coeffs: tuple[float, ...] = (-2.0, -1.0, 0.0, 1.0, 2.0)
|
||||
|
||||
|
||||
@@ -66,6 +74,14 @@ def _maybe_data(cfg: Cfg) -> Dataset:
|
||||
dcfg = DataCfg(
|
||||
model_id=cfg.model, behavior=cfg.behavior, out=data_root,
|
||||
n_topics=cfg.n_topics, n_personas=cfg.n_personas, n_samples=cfg.n_samples,
|
||||
batch_size=cfg.data_batch_size,
|
||||
min_new_tokens=cfg.data_min_new_tokens,
|
||||
max_new_tokens=cfg.data_max_new_tokens,
|
||||
temperature=cfg.data_temperature,
|
||||
top_p=cfg.data_top_p,
|
||||
top_k=cfg.data_top_k,
|
||||
min_p=cfg.data_min_p,
|
||||
presence_penalty=cfg.data_presence_penalty,
|
||||
)
|
||||
generate_pairs(dcfg)
|
||||
return load_pairs(cfg.behavior, root=data_root)
|
||||
|
||||
Reference in New Issue
Block a user