commit 45b7123cf500d8c571ffae01604d4262be80ec08 Author: copilot Date: Tue May 5 06:17:25 2026 +0800 iso-kl-figure: scaffold + smoke test passing diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..313bd3a --- /dev/null +++ b/.gitignore @@ -0,0 +1,12 @@ +__pycache__/ +*.pyc +.venv/ +uv.lock +*.egg-info/ +.pytest_cache/ +.ruff_cache/ +outputs/*.csv +outputs/*.tsv +outputs/*.png +outputs/*.md +!outputs/.gitkeep diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..884d825 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,24 @@ +# AGENTS.md + +Inherits conventions from sibling project `steering-lite`. Read [../steering-lite/AGENTS.md](../steering-lite/AGENTS.md) if it exists. + +## House rules + +- Fail fast. No defensive programming, no fallbacks, no silent dequant. +- Keep this repo small. Anything beyond the headline figure + table belongs in another repo. +- Use `einops` and `jaxtyping` shape annotations at function boundaries only. Tensor dim letters: `b s d` (batch, seq, d_model), `n` (prompts), `t` (token positions), `f` (fork points). +- No backward compat. +- Single functional smoke test = the real pipeline at tiny scale (`tests/test_smoke.py`). +- Methods register via `@register_config` and `@register` decorators; mirror `steering-lite/src/steering_lite/config.py`. +- All experiment scripts write CSV/TSV. Plot/table scripts read CSV/TSV. Never plot from in-memory state. + +## Out of scope (deliberately) + +- Method zoo beyond mean_diff, directional_ablation, pca. +- LessWrong post / paper draft. +- Citation collection. +- tinymfv or any external eval dependency. + +## Verify + +`just smoke` -> 3/3 methods pass calibrate -> trajectory -> branch-pmass on tiny-random Llama. Asserts nonzero KL at coeff>0, zero KL at coeff=0, branch-pmass in [0,1]. diff --git a/README.md b/README.md new file mode 100644 index 0000000..c71bbd6 --- /dev/null +++ b/README.md @@ -0,0 +1,36 @@ +# iso-kl-figure + +Minimal repo with one job: produce a figure and a table that demonstrate iso-KL calibration is stable across models, seeds, and calibration windows. + +## Claim (narrow) + +Calibrating a steering coefficient so that p95 per-token KL(steered || base) hits 1 nat in a short calibration window: + +- C1: bisection converges for every method tested; held-out p95 KL lands near 1 nat. +- C2 (not too cold): target-axis Delta logit at calibrated alpha is non-zero across methods. +- C3 (not too hot): base-NLL of generated text and branch-pmass of a forced format token stay near base at calibrated alpha across methods. + +The 2x check is a sanity probe, not a margin claim. Reported as: at 2x, p95 KL exceeds 1 nat for N of M cells. + +Honesty footnote: matched on per-token distributional disagreement under greedy decoding in the calibration window. This is one defensible notion of fairness; not equivalence on intervention norm or behavioral effect size. + +## Quick start + +```bash +uv sync --extra all +just smoke # tiny-random model, ~1 min CPU +just calibrate # one (model, method, seed, window) cell +just trajectory +just table +just plot +just table-md +``` + +`just sweep` runs the full grid (3 models x 3 methods x 3 seeds x 2 windows) used by Figure 1. + +## What this repo does NOT do + +- No paper or LessWrong draft. +- No method zoo beyond mean_diff, directional_ablation, pca. +- No threshold sweep, no calibration-set-size sweep. +- No tinymfv integration. Target-axis is a single contrastive sentiment / refusal pair. diff --git a/docs/spec/20260505_iso_kl_figure.md b/docs/spec/20260505_iso_kl_figure.md new file mode 100644 index 0000000..f55b61f --- /dev/null +++ b/docs/spec/20260505_iso_kl_figure.md @@ -0,0 +1,68 @@ +# iso-kl-figure: spec + +## Goal + +Produce one figure (Figure 1) and one table (Table 1) that empirically support three claims: iso-KL calibration converges and generalizes (C1), the calibrated coefficient is not too cold (C2), and not too hot (C3). Show stability across 3 models x 3 seeds x 2 calibration windows. + +## Scope + +In: +- Port `measure_kl`, `calibrate_iso_kl`, minimal Vector/attach/config/target/extract from steering-lite. +- 3 methods: `mean_diff`, `directional_ablation`, `pca`. +- New `branch_pmass` metric: fork-and-teacher-force probability mass on a forced format answer token. +- Scripts producing TSV/CSV; plot and table modules consuming the CSVs. + +Out: +- LessWrong post or paper draft. +- Method zoo beyond 3 methods. +- Threshold sweep, calibration-set-size sweep, norm-matching baseline. +- tinymfv integration. + +## Requirements + +- R1 (C1, calibration converges and generalizes): for every (method, model, seed, window), bisection terminates with calibration p95 within tolerance of 1.0; on a held-out prompt set p95 lands within [0.7, 1.4]. VERIFY: TSV row has converged=true and holdout_p95 in band; sneaky failure (overfits calibration prompts) caught by held-out column. +- R2 (C2, not too cold): target-axis Delta logit at calibrated alpha excludes 0 with 95% CI for each method, on each model. VERIFY: Table 1 row reports CI; sneaky failure (alpha approx 0) caught by alpha column in same row. +- R3 (C3, not too hot, NLL): base-NLL of full 50-token held-out generations stays within 2x of base at calibrated alpha; exceeds 4x of base at 2x calibrated alpha for at least one method per model. VERIFY: Table 1 base_nll_delta column. +- R4 (C3, not too hot, branch-pmass): mean branch-pmass-of-valid-answer at fork points {0, 5, ..., 50} stays within 0.1 of base pmass at calibrated alpha; drops by more than 0.3 at 2x alpha for at least one method per model. VERIFY: Table 1 branch_pmass column and Figure 1 lower subplot. +- R5 (sanity probe at 2x): max p95 KL at 2x alpha exceeds 1 nat in at least 2 of 3 methods on at least 2 of 3 models within 50 tokens. VERIFY: Figure 1 top subplot, alpha=2 panels show lines crossing reference. +- R6 (stability): seed band and window-style overlay in Figure 1 do not change the qualitative C1 conclusion. VERIFY: variance band visually narrow at alpha=1. + +## Tasks + +- [/] T1 (R*): scaffold repo (pyproject, justfile, README, AGENTS, spec). + - verify: `just --list` lists recipes; `uv sync --extra all` resolves. +- [ ] T2 (R1, R2, R3, R4): port core code from steering-lite (calibrate, vector, attach, config, target, extract, 3 variants). + - verify: imports clean; smoke test runs all 3 methods. +- [ ] T3 (R1, R6): extend calibrate history to save per-token KL arrays (`per_t_p95`, `per_t_max`). + - verify: history dict contains per-token arrays of length T. +- [ ] T4 (R4): implement `branch_pmass` (fork at token t, append fixed format suffix, teacher-force one forward, sum p over `true`/`false` tokens). + - verify: pmass in [0, 1]; pmass at base != pmass at coeff=large (sneaky-fail catch). +- [ ] T5 (R1..R5): implement `run_calibrate.py`, `run_trajectory.py`, `run_table.py`. + - verify: CSVs created with expected columns and at least one row each on smoke. +- [ ] T6 (R*): implement `plot.py`, `table.py`. + - verify: PNG saved; markdown table prints; can be regenerated from CSVs alone. +- [ ] T7 (R*): full sweep on real models. + - verify: numeric asserts in R1..R5 pass. +- [ ] T8 (R*): external review of figure + table. + - verify: review doc saved under docs/spec/. + +## Context + +Calibration target: p95 per-token KL(steered || base) = 1 nat over T tokens (T in {20, 50}), N=4 calibration prompts under greedy decoding. + +Branch-pmass procedure: at fork points t in {0, 5, ..., 50} take steered prefix of length t, append `\nAnswer (true/false): ` then `{"value": ` then teacher-force one forward under steered model, sum probabilities of token variants for `true` and `false`. + +Target-axis: a single contrastive pair-set built into the repo (sentiment positive vs negative or refusal yes vs no), 4 prompts each. Target Delta logit = mean over held-out prompts of difference in logit on the target token. + +## Log + +(append-only; only entries that change a future task) + +## TODO + +(out-of-scope ideas; not commitments) + +## Errors + +| Task | Error | Resolution | +|------|-------|------------| diff --git a/justfile b/justfile new file mode 100644 index 0000000..d5036e8 --- /dev/null +++ b/justfile @@ -0,0 +1,24 @@ +set shell := ["bash", "-cu"] + +default: + @just --list + +# Smoke: tiny-random Llama, all 3 methods, asserts nonzero KL + branch-pmass changes with coeff. +smoke: + BEARTYPE=1 uv run --extra all pytest -q tests/test_smoke.py + +test: + uv run --extra all pytest -q + +# Run one (model, method, seed, window) cell end-to-end (calibrate + trajectory + pmass). +cell model="Qwen/Qwen2.5-0.5B-Instruct" method="mean_diff" seed="0" window="50": + uv run --extra all python scripts/run_cell.py \ + --model {{model}} --method {{method}} --seed {{seed}} --window {{window}} + +# Sweep model x method x seed x window cells. +sweep: + bash scripts/sweep.sh + +# Aggregate all outputs// into figs/figure1.png + figs/table.md. +aggregate: + uv run --extra all python scripts/aggregate.py --runs-root outputs --out figs diff --git a/outputs/.gitkeep b/outputs/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..983fa5f --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,36 @@ +[project] +name = "iso-kl-figure" +version = "0.0.1" +description = "Minimal repo: produce one figure + one table proving iso-KL calibration is stable across models/seeds/windows." +readme = "README.md" +requires-python = ">=3.11" +dependencies = [ + "torch>=2.1", + "numpy>=1.26", + "einops>=0.7", + "jaxtyping>=0.2.34", + "safetensors>=0.5", + "loguru>=0.7", + "tqdm>=4.66", +] + +[project.optional-dependencies] +test = ["pytest", "tabulate", "beartype>=0.18"] +hf = ["accelerate>=1.6", "transformers>=4.51"] +plot = ["matplotlib>=3.8", "polars>=1.0"] +all = [ + "pytest", "tabulate", "beartype>=0.18", + "accelerate>=1.6", "transformers>=4.51", + "matplotlib>=3.8", "polars>=1.0", + "tyro>=0.9", +] + +[build-system] +requires = ["setuptools>=68"] +build-backend = "setuptools.build_meta" + +[tool.setuptools.packages.find] +where = ["src"] + +[tool.ruff.lint] +ignore = ["F722"] # jaxtyping shape strings diff --git a/scripts/aggregate.py b/scripts/aggregate.py new file mode 100644 index 0000000..e02817a --- /dev/null +++ b/scripts/aggregate.py @@ -0,0 +1,149 @@ +"""Aggregate per-cell outputs into Figure 1 + the headline table. + +Figure 1: two stacked subplots. + Top: per-token p95 KL trajectory. x = token offset; y = KL(steer || base). + Colour by method, linestyle by alpha (solid=1, dashed=2), seed bands + as thin lines, faceted by model. Horizontal at target_kl=1. + Bottom: branch-pmass at fork points. x = fork token offset; y = mean pmass + across held-out prompts; bands = +/- 1 std across seeds. + +Table: one row per (model, method), columns = c_star (mean +/- std across seeds), + KL_p95 @ alpha=1, KL_p95 @ alpha=2, pmass @ alpha=1, pmass @ alpha=2. + +Usage: + python scripts/aggregate.py --runs_root outputs --out figs/ +""" +from __future__ import annotations +import json +from collections import defaultdict +from dataclasses import dataclass +from pathlib import Path + +import polars as pl +import tyro +from loguru import logger + + +@dataclass +class Args: + runs_root: str = "outputs" + out: str = "figs" + + +def load_cells(root: Path) -> list[dict]: + cells = [] + for d in sorted(root.iterdir()): + if not d.is_dir(): + continue + calib = d / "calib.json" + if not calib.exists(): + continue + meta = json.loads(calib.read_text()) + traj = json.loads((d / "trajectory.json").read_text()) + pmass = json.loads((d / "pmass.json").read_text()) + cells.append({"id": d.name, **meta, "traj": traj, "pmass": pmass}) + return cells + + +def make_table(cells: list[dict]) -> pl.DataFrame: + rows = [] + by_mm = defaultdict(list) + for c in cells: + by_mm[(c["model"], c["method"])].append(c) + for (model, method), group in by_mm.items(): + c_stars = [g["c_star"] for g in group] + # pmass: mean over fork_points and prompts at each alpha, then across seeds + for alpha in ("1.0", "2.0"): + kls = [] + pms = [] + for g in group: + kls.append(g["traj"]["per_t_p95_kl"][alpha]) + pms.append(g["pmass"]["pmass"][alpha]) + kls_flat = [x for arr in kls for x in arr] + pms_flat = [x for prompt in pms for arr in prompt for x in arr] + rows.append({ + "model": model.split("/")[-1], + "method": method, + "alpha": float(alpha), + "c_star_mean": sum(c_stars) / len(c_stars), + "n_seeds": len(group), + "kl_p95_mean": sum(kls_flat) / max(len(kls_flat), 1), + "pmass_mean": sum(pms_flat) / max(len(pms_flat), 1), + }) + return pl.DataFrame(rows) + + +def make_figure(cells: list[dict], out_path: Path) -> None: + import matplotlib.pyplot as plt + import numpy as np + models = sorted({c["model"] for c in cells}) + methods = sorted({c["method"] for c in cells}) + fig, axes = plt.subplots(2, len(models), figsize=(5 * len(models), 7), + sharex="col", squeeze=False) + cmap = plt.get_cmap("tab10") + method_color = {m: cmap(i) for i, m in enumerate(methods)} + + for ci, model in enumerate(models): + ax_kl = axes[0, ci] + ax_pm = axes[1, ci] + ax_kl.set_title(model.split("/")[-1]) + ax_kl.axhline(1.0, color="black", linestyle=":", linewidth=0.8, alpha=0.5) + ax_kl.set_ylabel("p95 KL(steer || base)") + ax_pm.set_xlabel("token offset") + ax_pm.set_ylabel("branch pmass") + ax_pm.set_ylim(-0.02, 1.02) + ax_kl.set_yscale("log") + + for method in methods: + for alpha, ls in [("1.0", "-"), ("2.0", "--")]: + kls = [c["traj"]["per_t_p95_kl"][alpha] + for c in cells if c["model"] == model and c["method"] == method] + if not kls: + continue + arr = np.array(kls) + x = np.arange(arr.shape[1]) + ax_kl.plot(x, arr.mean(0), color=method_color[method], + linestyle=ls, linewidth=2, + label=f"{method} a={alpha}") + if arr.shape[0] > 1: + ax_kl.fill_between(x, arr.min(0), arr.max(0), + color=method_color[method], alpha=0.12) + + pms = [c["pmass"]["pmass"][alpha] + for c in cells if c["model"] == model and c["method"] == method] + if not pms: + continue + # pms: list of (n_seed) of (n_prompt) of (n_fork) + pms_arr = np.array(pms) # (n_seed, n_prompt, n_fork) + fork = cells[0]["pmass"]["fork_points"] + mean = pms_arr.mean(axis=(0, 1)) + std = pms_arr.std(axis=(0, 1)) + ax_pm.plot(fork, mean, color=method_color[method], + linestyle=ls, linewidth=2) + ax_pm.fill_between(fork, mean - std, mean + std, + color=method_color[method], alpha=0.12) + if ci == 0: + ax_kl.legend(fontsize=8, loc="upper left") + fig.tight_layout() + fig.savefig(out_path, dpi=150, bbox_inches="tight") + logger.info(f"figure -> {out_path}") + + +def main(a: Args): + out = Path(a.out); out.mkdir(parents=True, exist_ok=True) + cells = load_cells(Path(a.runs_root)) + if not cells: + raise SystemExit(f"no cells under {a.runs_root}") + logger.info(f"loaded {len(cells)} cells") + + df = make_table(cells) + df.write_csv(out / "table.csv") + md = df.to_pandas().to_markdown(index=False, floatfmt=".3f") + (out / "table.md").write_text(md) + logger.info(f"table -> {out/'table.md'}\n{md}") + + make_figure(cells, out / "figure1.png") + + +if __name__ == "__main__": + main(tyro.cli(Args)) diff --git a/scripts/run_cell.py b/scripts/run_cell.py new file mode 100644 index 0000000..b97841f --- /dev/null +++ b/scripts/run_cell.py @@ -0,0 +1,198 @@ +"""End-to-end runner for one (model, method, seed, window) cell. + +Flow: +1. Load model + tokenizer (HF), set seed. +2. Build pos/neg prompts (cheap pair); train the steering Vector v. +3. Calibrate iso-KL at target_kl=1 over T=window tokens. Save full history + (incl. per-token KL arrays) to outputs//history.json. +4. Re-run measure_kl at coeff=alpha*c_star (alpha in {1, 2}) on a held-out + prompt set so the trajectory plot reflects generalisation, not the + calibration set itself. Save per-token p95 KL to trajectory.json. +5. For each held-out prompt, rollout T_eval tokens under the steered model, + then branch-pmass at fork_points {0, 5, ..., T_eval}. Save to + pmass.json. Use a JSON-format suffix so target tokens are well-defined. + +Outputs one CSV row per (alpha, prompt) into outputs//results.csv +plus the artefacts above. + +Usage (tyro): + python scripts/run_cell.py --model Qwen/Qwen2.5-0.5B-Instruct --method mean_diff \ + --seed 0 --window 50 --run_id qwen05b_md_s0_w50 +""" +from __future__ import annotations +import json +import os +from dataclasses import dataclass, field +from pathlib import Path + +import torch +import tyro +from loguru import logger + +from iso_kl_figure import ( + SteeringConfig, MeanDiffC, PCAC, DirectionalAblationC, + train, calibrate_iso_kl, measure_kl, +) +from iso_kl_figure.branch_pmass import branch_pmass + + +CALIB_PROMPTS = [ + "The eiffel tower is in Paris.", + "Tell me a greentext story about a small village during the Martian carrot bubble.", + "Walk me through computing the integral of x^2 from 0 to 1.", + "Briefly describe how plate tectonics shape mountain ranges.", +] + +# Held-out prompts -- different from calibration; tests generalisation. +EVAL_PROMPTS = [ + "Explain why the sky is blue in two short sentences.", + "Summarise the plot of Hamlet in one paragraph.", + "List three causes of inflation.", + "Write a haiku about an octopus librarian.", +] + +# Pos/neg pair: a generic content-vs-refusal direction. Cheap, not the main +# point -- the figure tests calibration *behaviour*, not direction quality. +POS_NEG = [ + ("Sure, here is the answer:", "I cannot help with that."), + ("Here are some details:", "Sorry, I can't assist with that."), + ("Of course, let me explain.", "I won't be able to help."), + ("Yes, that makes sense.", "No, I have to decline."), +] + + +METHOD_MAP = { + "mean_diff": MeanDiffC, + "pca": PCAC, + "directional_ablation": DirectionalAblationC, +} + + +@dataclass +class Args: + model: str + method: str + seed: int = 0 + window: int = 50 + run_id: str = "" + layer_frac: float = 0.6 + target_kl: float = 1.0 + out_root: str = "outputs" + device: str = "cuda" + dtype: str = "bfloat16" + suffix_str: str = ' Final answer in JSON: {"value": ' + target_words: list[str] = field(default_factory=lambda: ["true", "false", "yes", "no"]) + fork_step: int = 5 + + +def _set_seed(s: int): + import random + import numpy as np + random.seed(s); np.random.seed(s); torch.manual_seed(s) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(s) + + +def main(a: Args): + if not a.run_id: + a.run_id = f"{a.model.split('/')[-1]}_{a.method}_s{a.seed}_w{a.window}" + out_dir = Path(a.out_root) / a.run_id + out_dir.mkdir(parents=True, exist_ok=True) + logger.add(out_dir / "run.log", level="INFO") + + _set_seed(a.seed) + from transformers import AutoModelForCausalLM, AutoTokenizer + dtype = getattr(torch, a.dtype) + tok = AutoTokenizer.from_pretrained(a.model) + if tok.pad_token_id is None: + tok.pad_token_id = tok.eos_token_id + model = AutoModelForCausalLM.from_pretrained(a.model, torch_dtype=dtype).to(a.device) + model.eval() + + n_layers = model.config.num_hidden_layers + layer = int(a.layer_frac * n_layers) + logger.info(f"model={a.model} n_layers={n_layers} target_layer={layer}") + + cfg_cls = METHOD_MAP[a.method] + cfg = cfg_cls(coeff=1.0, layers=(layer,)) + + pos = [tok.apply_chat_template([{"role": "user", "content": u}, + {"role": "assistant", "content": p}], + tokenize=False) + for u, (p, _) in zip(CALIB_PROMPTS, POS_NEG)] + neg = [tok.apply_chat_template([{"role": "user", "content": u}, + {"role": "assistant", "content": n}], + tokenize=False) + for u, (_, n) in zip(CALIB_PROMPTS, POS_NEG)] + v = train(model, tok, pos, neg, cfg, batch_size=4, max_length=128) + + logger.info("=== calibrate ===") + c_star, history = calibrate_iso_kl( + v, model, tok, CALIB_PROMPTS, + target_kl=a.target_kl, target_stat="kl_p95", + T=a.window, device=a.device, + ) + v.cfg.coeff = c_star + logger.info(f"c_star = {c_star:+.4f}") + (out_dir / "history.json").write_text(json.dumps(history, indent=2)) + (out_dir / "calib.json").write_text(json.dumps({ + "c_star": c_star, "target_kl": a.target_kl, "window": a.window, + "method": a.method, "model": a.model, "seed": a.seed, "layer": layer, + }, indent=2)) + + # -- trajectory + pmass at alpha in {1, 2} on held-out prompts + rows = [] + fork_points = list(range(0, a.window + 1, a.fork_step)) + trajectory: dict[str, list] = {} + pmass_all: dict[str, list] = {} + for alpha in (1.0, 2.0): + v.cfg.coeff = alpha * c_star + logger.info(f"=== eval alpha={alpha} c={v.cfg.coeff:+.4f} ===") + m = measure_kl(v, model, tok, EVAL_PROMPTS, T=a.window, device=a.device) + trajectory[str(alpha)] = m["per_t_p95"] + rows.append({"alpha": alpha, "coeff": v.cfg.coeff, "kl_p95": m["kl_p95"], + "kl_mean": m["kl_mean"], "kl_max": m["kl_max"]}) + + # pmass per held-out prompt + pm_for_alpha = [] + for p in EVAL_PROMPTS: + ids = tok.apply_chat_template( + [{"role": "user", "content": p}], + add_generation_prompt=True, return_tensors="pt", + ).input_ids[0] + pad = tok.pad_token_id + with v(model): + gen = model.generate( + ids.unsqueeze(0).to(a.device), + max_new_tokens=a.window, + pad_token_id=pad, eos_token_id=tok.eos_token_id, + do_sample=False, + )[0, ids.shape[0]:] + pm = branch_pmass( + v, model, tok, ids, gen, fork_points, + a.suffix_str, a.target_words, device=a.device, + ) + pm_for_alpha.append(pm["pmass"]) + pmass_all[str(alpha)] = pm_for_alpha + + (out_dir / "trajectory.json").write_text(json.dumps({ + "fork_points_full": list(range(a.window)), + "per_t_p95_kl": trajectory, + }, indent=2)) + (out_dir / "pmass.json").write_text(json.dumps({ + "fork_points": fork_points, + "pmass": pmass_all, + "suffix": a.suffix_str, + "target_words": a.target_words, + }, indent=2)) + import csv + with open(out_dir / "results.csv", "w", newline="") as f: + w = csv.DictWriter(f, fieldnames=["alpha", "coeff", "kl_p95", "kl_mean", "kl_max"]) + w.writeheader() + for r in rows: + w.writerow(r) + logger.info(f"DONE -> {out_dir}") + + +if __name__ == "__main__": + main(tyro.cli(Args)) diff --git a/scripts/sweep.sh b/scripts/sweep.sh new file mode 100644 index 0000000..7410af1 --- /dev/null +++ b/scripts/sweep.sh @@ -0,0 +1,27 @@ +#!/usr/bin/env bash +# Sweep: model x method x seed x window. Edit the lists to taste. +set -euo pipefail +cd "$(dirname "$0")/.." + +MODELS=("Qwen/Qwen2.5-0.5B-Instruct" "Qwen/Qwen2.5-1.5B-Instruct" "meta-llama/Llama-3.2-1B-Instruct") +METHODS=("mean_diff" "directional_ablation" "pca") +SEEDS=(0 1 2) +WINDOWS=(20 50) + +for model in "${MODELS[@]}"; do + for method in "${METHODS[@]}"; do + for seed in "${SEEDS[@]}"; do + for window in "${WINDOWS[@]}"; do + run_id="$(basename "$model")_${method}_s${seed}_w${window}" + if [ -f "outputs/${run_id}/calib.json" ]; then + echo "skip ${run_id}"; continue + fi + echo "=== ${run_id} ===" + uv run --extra all python scripts/run_cell.py \ + --model "$model" --method "$method" --seed "$seed" --window "$window" + done + done + done +done + +uv run --extra all python scripts/aggregate.py --runs-root outputs --out figs diff --git a/src/iso_kl_figure/__init__.py b/src/iso_kl_figure/__init__.py new file mode 100644 index 0000000..517fcf6 --- /dev/null +++ b/src/iso_kl_figure/__init__.py @@ -0,0 +1,34 @@ +import os as _os + +if _os.environ.get("BEARTYPE"): + from beartype.claw import beartype_this_package as _bt + _bt() + +from .config import SteeringConfig, REGISTRY, register +from .extract import record_activations +from .attach import attach, detach, save, load, train +from .calibrate import measure_kl, calibrate_iso_kl +from . import variants # noqa: F401 triggers method + config registration +from .vector import Vector + +from .variants.mean_diff import MeanDiffC +from .variants.pca import PCAC +from .variants.directional_ablation import DirectionalAblationC + +__all__ = [ + "SteeringConfig", + "MeanDiffC", + "PCAC", + "DirectionalAblationC", + "record_activations", + "train", + "attach", + "detach", + "save", + "load", + "measure_kl", + "calibrate_iso_kl", + "REGISTRY", + "register", + "Vector", +] diff --git a/src/iso_kl_figure/attach.py b/src/iso_kl_figure/attach.py new file mode 100644 index 0000000..b264657 --- /dev/null +++ b/src/iso_kl_figure/attach.py @@ -0,0 +1,243 @@ +"""attach / detach / save / load. The whole runtime. + +Variant protocol (uniform across both hook paths): + + apply(mod, x, y, state, cfg) -> y_new + +`mod` is the hooked module itself (a transformer block or a Linear); `x` is +its input, `y` its output. Variants return the module's NEW output: additive +variants do `return y + delta`, replacing variants ignore `y` and return any +tensor of the same shape. Same contract as lora-lite's `Variant.forward`. + +Two hook paths, dispatched on `cfg.target_submodule`: + + - `target_submodule is None` (default): hook each transformer block's + forward output. `mod = block`, `x = args[0]` (input residual), + `y = out[0]` (output hidden_states). State keyed by `int` (block index). + - `target_submodule = `: hook every nn.Linear in each selected block + whose dotted path matches the regex. `mod = linear`, `x` is the Linear's + input, `y` its output. State keyed by `str` (full dotted name like + `"layers.5.mlp.down_proj"`). +""" +from __future__ import annotations +import json +import torch +from torch import nn +from torch.utils.hooks import RemovableHandle + +from .config import SteeringConfig, REGISTRY +from .target import find_targets +from .extract import record_activations + + +_ATTACHED_ATTR = "_steering_lite_attached" +_STATE_PREFIX = "_steering_state_" +_SUB_KEY_PREFIX = "sub::" # safetensors key prefix marking submodule-level state +_SUB_KEY_SEP = "::" # separator between full_name and state_key + + +def _gather_state(mod) -> dict[str, torch.Tensor]: + return { + k[len(_STATE_PREFIX):]: getattr(mod, k) + for k in dir(mod) + if k.startswith(_STATE_PREFIX) and isinstance(getattr(mod, k, None), torch.Tensor) + } + + +def _hook(mod, args, out): + """Forward hook for block-level variants. Block forward returns a tuple + `(hidden_states, ...)`; we replace `[0]` with the variant's output.""" + cfg: SteeringConfig = mod._steering_cfg + method = mod._steering_method + state = _gather_state(mod) + x = args[0] + if isinstance(out, tuple): + y = out[0] + y_new = method.apply(mod, x, y, state, cfg) + return (y_new,) + out[1:] + return method.apply(mod, x, out, state, cfg) + + +def _linear_hook(mod, args, out): + """Forward hook for sub-module variants (cfg.target_submodule is set). + `out` is the Linear's output tensor (not a tuple).""" + cfg: SteeringConfig = mod._steering_cfg + method = mod._steering_method + state = _gather_state(mod) + return method.apply(mod, args[0], out, state, cfg) + + +def _install_state(mod: nn.Module, state: dict[str, torch.Tensor], cfg: SteeringConfig) -> None: + for k, v in state.items(): + attr = _STATE_PREFIX + k + if hasattr(mod, attr): + raise RuntimeError(f"module already has {attr}; detach first") + mod.register_buffer(attr, v.to(cfg.dtype), persistent=True) + + +def attach( + model: nn.Module, + cfg: SteeringConfig, + vectors, +) -> list[RemovableHandle]: + """Install per-target state as buffers and register forward hooks. + + `vectors` shape depends on cfg.target_submodule: + - None: dict[int, dict[str, Tensor]] keyed by block layer index. + - regex set: dict[str, dict[str, Tensor]] keyed by full dotted name. + + Accepts a `Vector` (auto-unwrapped to its `.state`). + """ + from .vector import Vector + if isinstance(vectors, Vector): + vectors = vectors.state + if cfg.method not in REGISTRY: + raise KeyError(f"unknown method {cfg.method!r}; registered: {list(REGISTRY)}") + method = REGISTRY[cfg.method] + # variant-level default target_submodule (e.g. sspace -> residual writers) + if cfg.target_submodule is None and getattr(method, "default_target_submodule", None): + cfg.target_submodule = method.default_target_submodule + requires_linear = cfg.target_submodule is not None + targets = find_targets(model, cfg) + if not targets: + raise RuntimeError("no target layers matched cfg") + + handles: list[RemovableHandle] = [] + attached_names: list[str] = [] + hooked_modules: list[nn.Module] = [] + for full_name, mod, li in targets: + key = full_name if requires_linear else li + if key not in vectors: + raise KeyError(f"vectors missing key {key!r}; have {sorted(vectors)}") + _install_state(mod, vectors[key], cfg) + mod._steering_cfg = cfg + mod._steering_method = method + if requires_linear: + mod._steering_module_name = full_name + hooked_modules.append(mod) + handles.append(mod.register_forward_hook(_linear_hook)) + else: + mod._steering_layer_idx = li + handles.append(mod.register_forward_hook(_hook)) + attached_names.append(full_name) + + setattr(model, _ATTACHED_ATTR, { + "cfg": cfg, "targets": attached_names, "handles": handles, + "hooked_modules": hooked_modules, + }) + return handles + + +def detach(model: nn.Module) -> None: + state = getattr(model, _ATTACHED_ATTR, None) + if state is None: + return + for h in state["handles"]: + h.remove() + for _, mod in model.named_modules(): + if not hasattr(mod, "_steering_method"): + continue + for k in [k for k in list(mod._buffers) if k.startswith(_STATE_PREFIX)]: + del mod._buffers[k] + for attr in ( + "_steering_cfg", "_steering_method", + "_steering_layer_idx", "_steering_module_name", + ): + if hasattr(mod, attr): + delattr(mod, attr) + delattr(model, _ATTACHED_ATTR) + + +def _log_extract_demo(tok, pos_prompts: list[str], neg_prompts: list[str]) -> None: + """One trace per stage: decoded full prompt + special tokens, for format debugging.""" + from loguru import logger + pos = pos_prompts[0] + neg = neg_prompts[0] + logger.info( + "EXPECT: POS and NEG share user_msg + suffix; differ only in system persona; " + "chat template applied; special tokens (e.g. <|im_start|>) visible.\n" + "=== EXTRACT demo trace ===\n" + f"POS[0]:\n{pos}\n---\nNEG[0]:\n{neg}\n=== /EXTRACT ===" + ) + + +def train( + model: nn.Module, + tok, + pos_prompts: list[str], + neg_prompts: list[str], + cfg: SteeringConfig, + *, + batch_size: int = 8, + max_length: int = 256, +): + """Extract activations + run method.extract -> Vector. Block-level only. + + Stripped from steering-lite: no submodule regex hooks, no attn pooling. + """ + from .vector import Vector + _log_extract_demo(tok, pos_prompts, neg_prompts) + method = REGISTRY[cfg.method] + targets = find_targets(model, cfg) + layers = tuple(li for _, _, li in targets) + pos_acts = record_activations(model, tok, pos_prompts, layers, batch_size=batch_size, max_length=max_length) + neg_acts = record_activations(model, tok, neg_prompts, layers, batch_size=batch_size, max_length=max_length) + state = method.extract(pos_acts, neg_acts, cfg) + return Vector(cfg, state) + + +def _state_to_safetensors_dict(model: nn.Module) -> dict: + """Serialise installed state buffers from all hooked modules. Forks on + whether the module is block-level (_steering_layer_idx) or submodule-level + (_steering_module_name); keys distinguish the two so load() can rebuild.""" + sd = {} + for _, mod in model.named_modules(): + if not hasattr(mod, "_steering_method"): + continue + if hasattr(mod, "_steering_module_name"): + full_name = mod._steering_module_name + for k, v in mod._buffers.items(): + if k.startswith(_STATE_PREFIX): + sd[f"{_SUB_KEY_PREFIX}{full_name}{_SUB_KEY_SEP}{k[len(_STATE_PREFIX):]}"] = v.detach().cpu() + elif hasattr(mod, "_steering_layer_idx"): + li = mod._steering_layer_idx + for k, v in mod._buffers.items(): + if k.startswith(_STATE_PREFIX): + sd[f"layer{li}.{k[len(_STATE_PREFIX):]}"] = v.detach().cpu() + return sd + + +def _safetensors_dict_to_state(sd: dict[str, torch.Tensor]) -> dict: + """Inverse of _state_to_safetensors_dict. Returns dict keyed by int (block-level) + or str (submodule-level), depending on the prefix of each key.""" + vectors: dict = {} + for k, v in sd.items(): + if k.startswith(_SUB_KEY_PREFIX): + rest = k[len(_SUB_KEY_PREFIX):] + full_name, _, state_key = rest.rpartition(_SUB_KEY_SEP) + vectors.setdefault(full_name, {})[state_key] = v + else: + layer_part, _, sub = k.partition(".") + li = int(layer_part.removeprefix("layer")) + vectors.setdefault(li, {})[sub] = v + return vectors + + +def save(model: nn.Module, path: str) -> None: + state = getattr(model, _ATTACHED_ATTR, None) + if state is None: + raise RuntimeError("no steering attached; call attach() first") + sd = _state_to_safetensors_dict(model) + metadata = {"cfg": json.dumps(state["cfg"].to_dict())} + from safetensors.torch import save_file + save_file(sd, path, metadata=metadata) + + +def load(model: nn.Module, path: str) -> list[RemovableHandle]: + from safetensors.torch import load_file, safe_open + with safe_open(path, framework="pt", device="cpu") as f: + metadata = f.metadata() + sd = load_file(path, device="cpu") + cfg = SteeringConfig.from_dict(json.loads(metadata["cfg"])) + vectors = _safetensors_dict_to_state(sd) + return attach(model, cfg, vectors) diff --git a/src/iso_kl_figure/branch_pmass.py b/src/iso_kl_figure/branch_pmass.py new file mode 100644 index 0000000..b9be0a4 --- /dev/null +++ b/src/iso_kl_figure/branch_pmass.py @@ -0,0 +1,90 @@ +"""Branch-and-teacher-force pmass: coherence metric for steered generation. + +At fork point t along a steered rollout, take prefix[:t], append a fixed +format suffix (e.g. `{"value": `), teacher-force one forward pass with the +steered model, and sum softmax mass over user-supplied target token strings +(e.g. `["true", "false"]`). High pmass ~ model still emits valid format +tokens; low pmass ~ format crash, off-distribution drift, semantic collapse. + +The metric is novel-ish: a single scalar that distinguishes "model is steered +toward a different token" from "model has lost track of the format". A target +direction can move pmass off 1.0 by reweighting between target tokens but +should not drop pmass to ~0. A miscalibrated coeff drops pmass to noise. + +Returns Float[Tensor, "f"] over fork points. +""" +from __future__ import annotations +from typing import Sequence + +import torch +from torch import nn, Tensor + +from .vector import Vector + + +def _all_token_ids(tok, words: Sequence[str]) -> list[int]: + """Collect the leading token id for each word in several capitalisation / + leading-space variants. Different tokenisers split " true", "true", "True" + differently; we sum mass over all variants so pmass tracks 'is the model + putting probability on this concept' rather than the specific tokenization. + """ + ids: set[int] = set() + for w in words: + for variant in (w, " " + w, w.capitalize(), " " + w.capitalize(), + w.upper(), " " + w.upper()): + try: + t = tok.encode(variant, add_special_tokens=False) + except Exception: + continue + if len(t) >= 1: + ids.add(int(t[0])) + return sorted(ids) + + +@torch.no_grad() +def branch_pmass( + v: Vector, + model: nn.Module, + tok, + prompt_ids: Tensor, # (n_prompt,) int64 + rolled_ids: Tensor, # (T,) steered rollout token ids + fork_points: Sequence[int], # token offsets along rolled_ids + suffix_str: str, # fixed format suffix appended at each fork + target_words: Sequence[str], # words to sum pmass over (any tokenization) + *, + device: str | torch.device = "cuda", +) -> dict: + """Returns {"pmass": [f], "fork_points": [f], "target_ids": [...], "suffix_ids": [...]} + + Caller should pass the SAME `rolled_ids` produced by the same `Vector` so + fork-point semantics are consistent. + """ + suffix_ids = tok.encode(suffix_str, add_special_tokens=False) + suffix_t = torch.tensor(suffix_ids, dtype=torch.long, device=device) + target_ids = _all_token_ids(tok, target_words) + if not target_ids: + raise ValueError(f"no target ids found for words={target_words}") + target_idx = torch.tensor(target_ids, dtype=torch.long, device=device) + + pmass = [] + pids = prompt_ids.to(device) + rolled = rolled_ids.to(device) + T = rolled.shape[0] + + for t in fork_points: + if t > T: + pmass.append(float("nan")) + continue + prefix = rolled[:t] + full = torch.cat([pids, prefix, suffix_t]).unsqueeze(0) + with v(model): + logits = model(full).logits[0, -1].float() + probs = torch.softmax(logits, dim=-1) + pmass.append(float(probs[target_idx].sum())) + + return { + "pmass": pmass, + "fork_points": list(fork_points), + "target_ids": target_ids, + "suffix_ids": suffix_ids, + } diff --git a/src/iso_kl_figure/calibrate.py b/src/iso_kl_figure/calibrate.py new file mode 100644 index 0000000..5714e96 --- /dev/null +++ b/src/iso_kl_figure/calibrate.py @@ -0,0 +1,247 @@ +"""Iso-KL calibration with per-token KL trajectory persistence. + +Forked from steering-lite/calibrate.py. Two changes: +- `measure_kl` also returns `per_t_p95` (across-prompt 95th percentile per token + position), needed for the headline trajectory plot. +- `calibrate_iso_kl` keeps the full per-token arrays in `history` so we can + plot p95 KL trajectory at the calibrated coeff (and at 2x) without re-running. +""" +from __future__ import annotations +import math +from typing import Callable + +import torch +from loguru import logger +from torch import Tensor +from torch import nn +from tqdm.auto import tqdm + +from .config import SteeringConfig # noqa: F401 +from .vector import Vector + + +_demo_logged = {"flag": False} + + +DEFAULT_MESSAGES = [ + "The eiffel tower is in Paris", + "埃菲尔铁塔🗼位于天都城", + "Tell me a greentext story about a small village during the smaller Martion carrot bubble.", + "Think step by step to calculate the integral of x^2 from 0 to 1 in lean4.", +] + + +def _tokenize(prompts, tok): + if prompts is None: + prompts = DEFAULT_MESSAGES + if isinstance(prompts[0], str): + return [ + tok.apply_chat_template( + [{"role": "user", "content": p}], + add_generation_prompt=True, return_tensors="pt", + ).input_ids[0] + for p in prompts + ] + return prompts + + +@torch.no_grad() +def _kl_per_pos(logp_steer: Tensor, logp_base: Tensor) -> Tensor: + p_s = logp_steer.exp() + return (p_s * (logp_steer - logp_base)).sum(dim=-1) + + +@torch.no_grad() +def _generate(model, prompt_ids, T, tok, do_sample, device): + pad_id = tok.pad_token_id if tok.pad_token_id is not None else tok.eos_token_id + ids = prompt_ids.unsqueeze(0).to(device) + out = model.generate( + ids, max_new_tokens=T, pad_token_id=pad_id, eos_token_id=tok.eos_token_id, + num_return_sequences=1, do_sample=do_sample, + ) + return out[0, prompt_ids.shape[0]:] + + +def _quantile(xs: list[float], q: float) -> float: + if not xs: + return 0.0 + return float(torch.tensor(xs).quantile(q)) + + +@torch.no_grad() +def measure_kl( + v: Vector, + model: nn.Module, + tok, + prompts=None, + *, + T: int = 50, + do_sample: bool = False, + device: str | torch.device = "cuda", +) -> dict: + """Roll out under steering, score under base+steer. Returns scalar stats + plus per-token arrays (mean, p95, max) of length T. + """ + prompts = _tokenize(prompts, tok) + all_kls = [] + per_t = [[] for _ in range(T)] + + for idx, pids in enumerate(tqdm(prompts, desc="measure_kl", mininterval=60)): + with v(model): + gen = _generate(model, pids, T, tok, do_sample, device) + n_gen = gen.shape[0] + if n_gen == 0: + continue + full_ids = torch.cat([pids.to(device), gen]) + if idx == 0 and not _demo_logged["flag"]: + _demo_logged["flag"] = True + base_gen = _generate(model, pids, T, tok, do_sample, device) + base_full = torch.cat([pids.to(device), base_gen]) + decoded_base = tok.decode(base_full, skip_special_tokens=False) + decoded_steer = tok.decode(full_ids, skip_special_tokens=False) + logger.info( + f"EXPECT: same prompt under c=0 vs c={v.cfg.coeff:+.4f}; both coherent; " + "steered should differ from base but not collapse.\n" + f"\n=== CALIBRATE demo trace (T={T}) ===\n" + f"--- BASE (c=0) ---\n{decoded_base}\n" + f"\n--- STEER (c={v.cfg.coeff:+.4f}) ---\n{decoded_steer}\n" + f"=== /CALIBRATE ===" + ) + full = full_ids.unsqueeze(0) + n_p = pids.shape[0] + + logp_base = torch.log_softmax(model(full).logits.float(), dim=-1)[0] + with v(model): + logp_steer = torch.log_softmax(model(full).logits.float(), dim=-1)[0] + + slc = slice(n_p - 1, n_p - 1 + n_gen) + kls = _kl_per_pos(logp_steer[slc], logp_base[slc]).cpu() + all_kls.append(kls) + for i in range(n_gen): + per_t[i].append(float(kls[i])) + + cat = torch.cat(all_kls) + return { + "kl_mean": float(cat.mean()), + "kl_p50": float(cat.quantile(0.50)), + "kl_p90": float(cat.quantile(0.90)), + "kl_p95": float(cat.quantile(0.95)), + "kl_max": float(cat.max()), + "n_pos": int(cat.numel()), + "per_t_mean": [sum(xs) / len(xs) if xs else 0.0 for xs in per_t], + "per_t_p95": [_quantile(xs, 0.95) for xs in per_t], + "per_t_max": [max(xs) if xs else 0.0 for xs in per_t], + } + + +def calibrate_iso_kl( + v: Vector, + model: nn.Module, + tok, + prompts=None, + *, + target_kl: float = 1.0, + target_stat: str = "kl_p95", + bracket: tuple[float, float] = (0.01, 16.0), + tol: float = 0.05, + max_iters: int = 12, + T: int = 50, + device: str | torch.device = "cuda", + sign: float = 1.0, + sign_probe: Callable[[Vector], float] | None = None, + sign_probe_c: float = 1.0, +) -> tuple[float, list[dict]]: + """log-log Illinois bisection on `target_stat`. History keeps per-token + arrays so we can plot the trajectory after.""" + _demo_logged["flag"] = False + prompts = _tokenize(prompts, tok) + history: list[dict] = [] + + if sign_probe is not None: + v.cfg.coeff = +sign_probe_c + score_pos = sign_probe(v) + v.cfg.coeff = -sign_probe_c + score_neg = sign_probe(v) + chosen = +1.0 if score_pos >= score_neg else -1.0 + logger.info( + f"sign_probe: +c={sign_probe_c:+.2f} -> {score_pos:+.3f} | " + f"-c={-sign_probe_c:+.2f} -> {score_neg:+.3f} | " + f"chosen sign={chosen:+.0f}" + ) + sign = sign * chosen + + def eval_at(c: float) -> float: + v.cfg.coeff = sign * c + m = measure_kl(v, model, tok, prompts, T=T, do_sample=False, device=device) + history.append({"coeff": sign * c, "coeff_abs": c, "sign": sign, **m}) + logger.info(f" c={sign * c:+.4f} mean={m['kl_mean']:.3f} " + f"p50={m['kl_p50']:.3f} p90={m['kl_p90']:.3f} " + f"p95={m['kl_p95']:.3f} max={m['kl_max']:.3f} n={m['n_pos']}") + return m[target_stat] + + lo, hi = bracket + log_target = math.log(target_kl) + + mid = (lo * hi) ** 0.5 + v_mid = eval_at(mid) + if v_mid < target_kl: + c_lo, v_lo = mid, v_mid + c = mid + c_hi, v_hi = hi, None + while c < hi: + c *= 2.0 + val = eval_at(c) + if val >= target_kl: + c_hi, v_hi = c, val + break + c_lo, v_lo = c, val + else: + logger.warning(f"calibrate {v.cfg.method}: KL below target across bracket") + return sign * c, history + else: + c_hi, v_hi = mid, v_mid + c = mid + c_lo, v_lo = lo, None + while c > lo: + c /= 2.0 + val = eval_at(c) + if val <= target_kl: + c_lo, v_lo = c, val + break + c_hi, v_hi = c, val + else: + logger.warning(f"calibrate {v.cfg.method}: KL above target across bracket") + return sign * c, history + + stale_lo = stale_hi = 0 + for _ in tqdm(range(max_iters), desc=f"calib {v.cfg.method}", mininterval=60, leave=False): + if v_lo is not None and v_hi is not None and v_lo > 0 and v_hi > 0: + log_c_lo, log_c_hi = math.log(c_lo), math.log(c_hi) + log_v_lo = math.log(v_lo) - (math.log(2) if stale_lo >= 2 else 0.0) + log_v_hi = math.log(v_hi) - (math.log(2) if stale_hi >= 2 else 0.0) + denom = log_v_hi - log_v_lo + if abs(denom) < 1e-6: + c_new = math.sqrt(c_lo * c_hi) + else: + t = (log_target - log_v_lo) / denom + log_c_new = log_c_lo + t * (log_c_hi - log_c_lo) + c_new = math.exp(log_c_new) + if not (c_lo < c_new < c_hi): + c_new = math.sqrt(c_lo * c_hi) + else: + c_new = math.sqrt(c_lo * c_hi) + + v_new = eval_at(c_new) + if abs(v_new - target_kl) < tol: + return sign * c_new, history + if v_new < target_kl: + c_lo, v_lo = c_new, v_new + stale_lo = 0 + stale_hi += 1 + else: + c_hi, v_hi = c_new, v_new + stale_hi = 0 + stale_lo += 1 + + best = min(history, key=lambda h: abs(h[target_stat] - target_kl)) + return best["coeff"], history diff --git a/src/iso_kl_figure/config.py b/src/iso_kl_figure/config.py new file mode 100644 index 0000000..8308c06 --- /dev/null +++ b/src/iso_kl_figure/config.py @@ -0,0 +1,112 @@ +"""SteeringConfig + Method protocol + registries. + +Each method ships its own subclass `XC(SteeringConfig)` and `XMethod` class +under `variants/*.py` (e.g. `MeanDiffC` + `MeanDiff`). Two parallel registries +keyed by method name: `_CONFIG_REGISTRY` for `from_dict` deserialisation, +`REGISTRY` for the runtime extract/apply pair. +""" +from dataclasses import dataclass, asdict, field +from typing import Protocol, Any +import torch +from torch import Tensor + + +@dataclass +class SteeringConfig: + """Base config. Subclass per method; do not instantiate directly.""" + method: str = "?" + + # which transformer blocks to hook (indices into model.model.layers) + # None = all layers + layers: tuple[int, ...] | None = None + + # which point in the block to add at: "residual" = block output (post mlp+attn), + # "attn_out" = attention output, "mlp_out" = mlp output. + # v1 only implements "residual". + target: str = "residual" + + # Optional dotted path of a sub-module within each target block to hook on + # (e.g. "mlp.down_proj"). When None, the block's forward output is hooked + # (default for almost all variants); when set, the sub-module's forward is + # hooked instead. Either way the variant's apply receives the unified + # (block, x, y, state, cfg) signature -- used by weight-SVD methods + # (sspace, sspace_ablate) that need to modify a Linear's output in low-rank + # S-space. + target_submodule: str | None = None + + # steering strength at apply-time. Methods interpret it differently: + # additive (mean_diff, pca, sspace, chars, cosine_gated): coeff is α in `h + α v`. + # slerp/angle (spherical, angular_steering): coeff is the slerp t / rotation θ. + # blend (linear_act): coeff is the blend ratio. + # ablation+nudge (directional_ablation): coeff is post-ablation nudge magnitude. + coeff: float = 1.0 + + dtype: torch.dtype = torch.bfloat16 + seed: int = 0 + + def to_dict(self) -> dict: + d = asdict(self) + d["dtype"] = str(self.dtype).removeprefix("torch.") + return d + + @classmethod + def from_dict(cls, d: dict) -> "SteeringConfig": + d = dict(d) + name = d["method"] + sub = _CONFIG_REGISTRY[name] + d["dtype"] = getattr(torch, d["dtype"]) + return sub(**d) + + +_CONFIG_REGISTRY: dict[str, type[SteeringConfig]] = {} + + +def register_config(cls: type[SteeringConfig]) -> type[SteeringConfig]: + """Decorator: register `cls` under its `method` default value.""" + name = cls.__dataclass_fields__["method"].default + if name == "?": + raise ValueError(f"{cls} must override the default `method` field") + if name in _CONFIG_REGISTRY: + raise ValueError(f"config for method {name!r} already registered") + _CONFIG_REGISTRY[name] = cls + return cls + + +class Method(Protocol): + """extract+apply pair. State tensors are registered as buffers on the hooked + module (block or Linear) under `_steering_state_` and rebuilt into a + dict by the hook. + """ + name: str + + @staticmethod + def extract( + pos_acts: dict[int, Tensor], + neg_acts: dict[int, Tensor], + cfg: Any, + ) -> dict[int, dict[str, Tensor]]: + """Per-layer state. `pos_acts[l]` is `[n_pos, d_model]`, same for neg.""" + ... + + @staticmethod + def apply( + mod, # the hooked module: a transformer block, or a Linear + x: Tensor, # [b, s, d_in] -- module input + y: Tensor, # [b, s, d_out] -- module output + state: dict[str, Tensor], + cfg: Any, + ) -> Tensor: + """Return the module's NEW output. Additive variants: `return y + delta`. + Replacing variants: ignore `y`, return any tensor of shape `[b, s, d_out]`. + """ + ... + + +REGISTRY: dict[str, type] = {} + + +def register(cls): + if not getattr(cls, "name", None): + raise ValueError(f"method {cls} missing .name") + REGISTRY[cls.name] = cls + return cls diff --git a/src/iso_kl_figure/extract.py b/src/iso_kl_figure/extract.py new file mode 100644 index 0000000..e275887 --- /dev/null +++ b/src/iso_kl_figure/extract.py @@ -0,0 +1,58 @@ +"""Record last non-pad-token hidden states at selected layers via forward hooks. + +We hook each block's forward output (it returns `(hidden_states, ...)`), gather +the final non-padding token from `attention_mask`, and stack across prompts. +No grad. +""" +from __future__ import annotations +import torch +from torch import nn, Tensor +from jaxtyping import Float + +from .target import _get_blocks + + +@torch.no_grad() +def record_activations( + model: nn.Module, + tok, + prompts: list[str], + layers: tuple[int, ...], + *, + batch_size: int = 8, + max_length: int = 256, +) -> dict[int, Float[Tensor, "n d"]]: + """Run prompts through model, return last-token hidden state at each layer.""" + blocks = _get_blocks(model) + device = next(model.parameters()).device + + bucket: dict[int, list[Tensor]] = {l: [] for l in layers} + captured: dict[int, Tensor] = {} + + def make_hook(li: int): + def hook(_mod, _args, out): + h = out[0] if isinstance(out, tuple) else out + captured[li] = h + return hook + + handles = [blocks[li].register_forward_hook(make_hook(li)) for li in layers] + try: + was_training = model.training + model.eval() + for i in range(0, len(prompts), batch_size): + batch = prompts[i : i + batch_size] + enc = tok(batch, return_tensors="pt", padding=True, truncation=True, max_length=max_length).to(device) + captured.clear() + model(**enc) + mask = enc["attention_mask"] + last_idx = mask.shape[1] - 1 - mask.flip([-1]).argmax(-1) + batch_idx = torch.arange(mask.shape[0], device=last_idx.device) + for li in layers: + bucket[li].append(captured[li][batch_idx, last_idx].detach().to("cpu")) + if was_training: + model.train() + finally: + for h in handles: + h.remove() + + return {li: torch.cat(bucket[li], dim=0) for li in layers} diff --git a/src/iso_kl_figure/target.py b/src/iso_kl_figure/target.py new file mode 100644 index 0000000..ba0d0f0 --- /dev/null +++ b/src/iso_kl_figure/target.py @@ -0,0 +1,148 @@ +"""Find transformer blocks (or sub-Linears) to hook. + +Default: hook each block's forward output (residual stream after attn+mlp). +When `cfg.target_submodule` is set, it is interpreted as a **regex** matched +against `block.named_modules()` paths under each selected block; matching +`nn.Linear`s become the actual hook targets. This lets a single cfg target +multiple Linears per block (e.g. residual writers `mlp.down_proj` AND +`self_attn.o_proj`, or all 7 Linears in q/k/v/o/gate/up/down). + +Works with HF llama-family architectures (llama, qwen, mistral, etc). For other +architectures, set `cfg.layers` to indices into whatever list lives at the path +your model exposes -- override `_get_blocks` if needed. +""" +import re +from torch import nn + + +def _get_blocks(model: nn.Module) -> nn.ModuleList: + # llama-family: model.model.layers + # gemma3-multimodal: model.language_model.layers (or model.model.language_model.layers) + candidates = [] + inner = getattr(model, "model", model) + candidates.append(inner) + lm = getattr(inner, "language_model", None) + if lm is not None: + candidates.append(lm) + candidates.append(getattr(lm, "model", lm)) + for c in candidates: + blocks = getattr(c, "layers", None) + if blocks is not None: + return blocks + raise RuntimeError( + f"could not find .layers on {type(model).__name__}; " + f"override _get_blocks for non-llama architectures" + ) + + +def find_targets(model: nn.Module, cfg) -> list[tuple[str, nn.Module, int]]: + """Return [(full_name, module, layer_idx)] for hook targets selected by cfg. + + - `cfg.target_submodule is None`: one entry per selected block (the block itself). + - `cfg.target_submodule = `: one entry per (block, matching nn.Linear). + Regex is matched with `re.fullmatch` against `name` from `block.named_modules()`, + e.g. `r"mlp\\.down_proj|self_attn\\.o_proj"` matches both residual writers, + `r"self_attn\\.(q|k|v|o)_proj|mlp\\.(gate|up|down)_proj"` matches all 7 Linears. + """ + blocks = _get_blocks(model) + n = len(blocks) + if cfg.layers is None: + idxs = tuple(range(n)) + else: + idxs = tuple(cfg.layers) + for i in idxs: + if not (0 <= i < n): + raise ValueError(f"layer {i} out of range [0, {n})") + sub = getattr(cfg, "target_submodule", None) + if sub is None: + return [(f"layers.{i}", blocks[i], i) for i in idxs] + pattern = re.compile(sub) + out = [] + for i in idxs: + for name, mod in blocks[i].named_modules(): + if name and pattern.fullmatch(name) and isinstance(mod, nn.Linear): + out.append((f"layers.{i}.{name}", mod, i)) + if not out: + raise RuntimeError( + f"target_submodule regex {sub!r} matched no nn.Linear " + f"in {len(idxs)} selected blocks" + ) + return out + + +def find_residual_linears( + model: nn.Module, + layer_indices: tuple[int, ...] | None = None, + *, + role: str = "both", # "writer" | "reader" | "both" + fallback_regex: str | None = None, +) -> list[tuple[str, nn.Module, int, str]]: + """Find Linears connected to the residual stream, per block. + + Returns `[(full_name, module, layer_idx, role)]` where role is "writer" + (d_out == d_model, d_in != d_model) or "reader" (d_in == d_model, + d_out != d_model). Square Linears are ambiguous (could be either) and + are excluded by shape detection. + + Detection: + 1. Primary: weight.shape vs d_model. + 2. Fallback: if shape detection finds nothing (non-llama arch, weird + shapes), match `fallback_regex` against `named_modules` paths and + guess role from substring. Default regex covers llama-family names. + Warns when fallback fires. + """ + from loguru import logger + d_model = get_d_model(model) + blocks = _get_blocks(model) + n = len(blocks) + idxs = tuple(layer_indices) if layer_indices is not None else tuple(range(n)) + + out: list[tuple[str, nn.Module, int, str]] = [] + for li in idxs: + for name, mod in blocks[li].named_modules(): + if not isinstance(mod, nn.Linear): + continue + d_out, d_in = mod.weight.shape + is_writer = d_out == d_model and d_in != d_model + is_reader = d_in == d_model and d_out != d_model + if is_writer and role in ("writer", "both"): + out.append((f"layers.{li}.{name}", mod, li, "writer")) + elif is_reader and role in ("reader", "both"): + out.append((f"layers.{li}.{name}", mod, li, "reader")) + + if out: + return out + + regex = fallback_regex or r"mlp\.(down|gate|up)_proj|self_attn\.(q|k|v|o)_proj" + logger.warning( + f"shape-based residual-linear detection found nothing for d_model={d_model} " + f"in {len(idxs)} blocks; falling back to regex {regex!r}" + ) + pattern = re.compile(regex) + writer_hints = ("down_proj", "o_proj") + for li in idxs: + for name, mod in blocks[li].named_modules(): + if not (name and pattern.fullmatch(name) and isinstance(mod, nn.Linear)): + continue + role_guess = "writer" if any(h in name for h in writer_hints) else "reader" + if role in ("both", role_guess): + out.append((f"layers.{li}.{name}", mod, li, role_guess)) + + if not out: + logger.warning( + f"regex fallback {regex!r} also matched no Linears in layers {idxs}; " + "super_sspace will have an empty basis" + ) + return out + + +def get_d_model(model: nn.Module) -> int: + cfg = getattr(model, "config", None) + d = getattr(cfg, "hidden_size", None) + if d is None: + # multimodal configs (gemma3): text sub-config + text_cfg = getattr(cfg, "text_config", None) + d = getattr(text_cfg, "hidden_size", None) + if d is None: + raise RuntimeError("model has no .config.hidden_size") + return int(d) diff --git a/src/iso_kl_figure/variants/__init__.py b/src/iso_kl_figure/variants/__init__.py new file mode 100644 index 0000000..611f108 --- /dev/null +++ b/src/iso_kl_figure/variants/__init__.py @@ -0,0 +1,6 @@ +"""Variant registry. Importing this package triggers @register_config + @register +side effects in the variant modules. +""" +from . import mean_diff # noqa: F401 +from . import pca # noqa: F401 +from . import directional_ablation # noqa: F401 diff --git a/src/iso_kl_figure/variants/directional_ablation.py b/src/iso_kl_figure/variants/directional_ablation.py new file mode 100644 index 0000000..815643f --- /dev/null +++ b/src/iso_kl_figure/variants/directional_ablation.py @@ -0,0 +1,75 @@ +"""Mean-diff directional ablation (Arditi-inspired projection-out). + +Project the steering direction *out of* the residual stream instead of (or in +addition to) adding to it. Unlike `mean_diff` which translates by $\\alpha v$, +ablation removes the component of $h$ along $\\hat v$: + +$$h \\leftarrow h - (h \\cdot \\hat v)\\hat v + \\alpha\\hat v$$ + +When `coeff=0` this is pure ablation (refusal-direction style); when `coeff!=0` +this is ablation followed by a constant nudge (useful to ablate "old" behavior +and inject "new"). The two terms are mathematically distinct -- ablation is a +*projection* (idempotent), addition is a *translation*. + +Norms shrink by $|h \\cdot \\hat v|$ which is informative -- a near-zero shrink +means the direction wasn't present in the first place, so the intervention is +a no-op. Compare to `mean_diff` which always pays a constant $\\alpha\\|\\hat v\\|$ +per token regardless of whether the direction is present. + +Refs / inspiration: + - Arditi et al. 2024 "Refusal in language models is mediated by a single direction" + https://arxiv.org/abs/2406.11717 + - andyrdt/refusal_direction https://github.com/andyrdt/refusal_direction +""" +from dataclasses import dataclass +from einops import einsum +from jaxtyping import Float +from torch import Tensor + +from ..config import SteeringConfig, register_config, register + + +ε = 1e-8 + + +@register_config +@dataclass +class DirectionalAblationC(SteeringConfig): + method: str = "directional_ablation" + coeff: float = 0.0 # post-ablation additive nudge along v_hat; 0.0 = pure ablation + + +@register +class DirectionalAblation: + name = "directional_ablation" + + @staticmethod + def extract( + pos_acts: dict[int, Float[Tensor, "n d"]], + neg_acts: dict[int, Float[Tensor, "m d"]], + cfg: DirectionalAblationC, + ) -> dict[int, dict[str, Tensor]]: + out = {} + for li in pos_acts: + v = pos_acts[li].float().mean(0) - neg_acts[li].float().mean(0) + v = v / (v.norm() + ε) + + out[li] = {"v": v} + return out + + @staticmethod + def apply( + mod, + x: Float[Tensor, "b s d"], + y: Float[Tensor, "b s d"], + state: dict[str, Tensor], + cfg: DirectionalAblationC, + ) -> Float[Tensor, "b s d"]: + v = state["v"].to(y) # unit + + proj = einsum(y, v, "b s d, d -> b s") + y = y - proj.unsqueeze(-1) * v # ablate + + if cfg.coeff != 0.0: + y = y + cfg.coeff * v + return y diff --git a/src/iso_kl_figure/variants/mean_diff.py b/src/iso_kl_figure/variants/mean_diff.py new file mode 100644 index 0000000..64c467b --- /dev/null +++ b/src/iso_kl_figure/variants/mean_diff.py @@ -0,0 +1,82 @@ +"""Mean-difference steering (CAA / ActAdd). + +For each selected layer L, compute the mean difference between positive and +negative last-token hidden states: + +$$v_L = \\text{mean}(h^+_L) - \\text{mean}(h^-_L), \\quad \\hat{v}_L = v_L / \\|v_L\\|$$ + +At runtime, add `coeff * v_hat` to every token's residual at that block: + +$$h \\leftarrow h + \\alpha \\cdot \\hat{v}_L$$ + +This is the same operation as CAA (Panickssery 2023, contrastive MCQ pairs) +and ActAdd (Turner 2023, single prompt-pair); the differences are conventional +not mathematical, so we register one method. + +`subtract_corpus_mean=True` toggles Jorgensen 2024 mean-centring: target mean +minus pos∪neg corpus mean. Direction-identical to plain mean_diff under +normalization with equal-size groups; kept as a flag rather than a separate +method. + +Refs: + - Panickssery 2023 (CAA) https://arxiv.org/abs/2312.06681 + - Turner 2023 (ActAdd) https://arxiv.org/abs/2308.10248 + - Jorgensen 2024 (Mean-Centring) https://arxiv.org/abs/2312.03813 + - nrimsky/CAA https://github.com/nrimsky/CAA +""" +from dataclasses import dataclass +import torch +from jaxtyping import Float +from torch import Tensor + +from ..config import SteeringConfig, register_config, register + + +ε = 1e-8 + + +@register_config +@dataclass +class MeanDiffC(SteeringConfig): + method: str = "mean_diff" + normalize: bool = True + subtract_corpus_mean: bool = False + + +@register +class MeanDiff: + name = "mean_diff" + + @staticmethod + def extract( + pos_acts: dict[int, Float[Tensor, "n d"]], + neg_acts: dict[int, Float[Tensor, "m d"]], + cfg: MeanDiffC, + ) -> dict[int, dict[str, Tensor]]: + out = {} + for li in pos_acts: + p = pos_acts[li].float() + n = neg_acts[li].float() + + if cfg.subtract_corpus_mean: + mu = torch.cat([p, n], dim=0).mean(0) + v = p.mean(0) - mu + else: + v = p.mean(0) - n.mean(0) + + if cfg.normalize: + v = v / (v.norm() + ε) + + out[li] = {"v": v} + return out + + @staticmethod + def apply( + mod, + x: Float[Tensor, "b s d"], + y: Float[Tensor, "b s d"], + state: dict[str, Tensor], + cfg: MeanDiffC, + ) -> Float[Tensor, "b s d"]: + v = state["v"].to(y) + return y + cfg.coeff * v diff --git a/src/iso_kl_figure/variants/pca.py b/src/iso_kl_figure/variants/pca.py new file mode 100644 index 0000000..6ed38c3 --- /dev/null +++ b/src/iso_kl_figure/variants/pca.py @@ -0,0 +1,101 @@ +"""PCA steering (RepE/LAT-inspired, vgel pca_diff-like). + +For each layer L, compute PCA on the **paired differences** `h^+ - h^-`. Take +the top principal component as the steering direction. + +$$D_L = H^+_L - H^-_L \\in \\mathbb{R}^{n\\times d}$$ +$$U, S, V^T = \\text{SVD}(D_L - \\bar{D}_L)$$ +$$\\text{sign}_L = \\text{sign}\\left(\\sum_i \\mathbb{1}[(D_L)_i \\cdot V_{:,0} > 0] - n/2\\right)$$ +$$v_L = V_{:,0} \\cdot \\text{sign}_L$$ + +Sign-fixed by majority vote of paired-diff projections (repeng/vgel style). +This is a lightweight control-vector baseline, not the full Zou et al. LAT +reader: it omits per-diff normalization, label-based sign selection, and +train-mean recentering for reading scores. +PCA is sign-ambiguous; the vote is more robust than alignment-with-the-mean +when paired diffs are heterogeneous (mean can cancel without the vote +changing). If the vote ties exactly, orient the axis so the largest centered +projection is positive. + +At runtime, add `coeff * v_L` to the residual. + +Refs: + - Zou et al. 2023 (Representation Engineering) https://arxiv.org/abs/2310.01405 + - vgel/repeng: https://github.com/vgel/repeng +""" +from dataclasses import dataclass +import torch +from jaxtyping import Float +from torch import Tensor + +from ..config import SteeringConfig, register_config, register + + +ε = 1e-8 + + +@register_config +@dataclass +class PCAC(SteeringConfig): + method: str = "pca" + n_components: int = 1 + normalize: bool = True + + +@register +class PCA: + name = "pca" + + @staticmethod + def extract( + pos_acts: dict[int, Float[Tensor, "n d"]], + neg_acts: dict[int, Float[Tensor, "n d"]], + cfg: PCAC, + ) -> dict[int, dict[str, Tensor]]: + out = {} + for li in pos_acts: + if pos_acts[li].shape[0] != neg_acts[li].shape[0]: + raise ValueError(f"layer {li}: pos/neg counts differ") + + diffs = (pos_acts[li] - neg_acts[li]).float() + centered = diffs - diffs.mean(0, keepdim=True) + + _, _, Vh = torch.linalg.svd(centered, full_matrices=False) + v = Vh[: cfg.n_components] + + projs = centered @ v.T + positive_frac = (projs > 0).float().mean(0) + majority_sign = torch.where(positive_frac > 0.5, + torch.ones(v.shape[0]), + -torch.ones(v.shape[0])).to(v) + strongest_idx = projs.abs().argmax(dim=0) + strongest = projs[strongest_idx, torch.arange(v.shape[0], device=projs.device)] + strongest_sign = torch.sign(strongest) + sign = torch.where(positive_frac == 0.5, strongest_sign, majority_sign) + v = v * sign[:, None] + + if cfg.n_components == 1: + v = v.squeeze(0) + if cfg.normalize: + v = v / (v.norm() + ε) + out[li] = {"v": v} + else: + if cfg.normalize: + v = v / (v.norm(dim=1, keepdim=True) + ε) + out[li] = {"V": v} + return out + + @staticmethod + def apply( + mod, + x: Float[Tensor, "b s d"], + y: Float[Tensor, "b s d"], + state: dict[str, Tensor], + cfg: PCAC, + ) -> Float[Tensor, "b s d"]: + if "v" in state: + v = state["v"].to(y) + return y + cfg.coeff * v + # multi-component: sum top-k directions equally + V = state["V"].to(y) + return y + cfg.coeff * V.sum(0) diff --git a/src/iso_kl_figure/vector.py b/src/iso_kl_figure/vector.py new file mode 100644 index 0000000..b2ab4d5 --- /dev/null +++ b/src/iso_kl_figure/vector.py @@ -0,0 +1,114 @@ +"""Vector: extracted steering vector + config, as a single ergonomic object. + +Wraps `(cfg, state)` so a user can: + + v = Vector.train(model, tok, pos, neg, sl.MeanDiffC(layers=(15,))) \\ + .calibrate(model, tok, target_kl=1.0) + + with v(model): + out = model.generate(...) + + v.save("honesty.safetensors") + v2 = Vector.load("honesty.safetensors") + + combined = v + v2 # ensemble (sum buffers, requires same cfg.method) + scaled = v * 0.5 # scale buffers +""" +from __future__ import annotations +from contextlib import contextmanager +from copy import deepcopy +from dataclasses import replace + +import torch +from torch import Tensor, nn + +from .config import SteeringConfig + + +class Vector: + def __init__(self, cfg: SteeringConfig, state: dict[int, dict[str, Tensor]]): + self.cfg = cfg + self.state = state + + @classmethod + def train(cls, model: nn.Module, tok, pos_prompts: list[str], neg_prompts: list[str], + cfg: SteeringConfig, **kw) -> "Vector": + """Extract a steering vector from contrastive prompts. Chains with .calibrate().""" + from .attach import train as _train + return _train(model, tok, pos_prompts, neg_prompts, cfg, **kw) + + def calibrate(self, model: nn.Module, tok, + prompts: list[str] | list[Tensor] | None = None, *, + target_kl: float = 1.0, **kw) -> "Vector": + """Set coeff so KL(steered || base) hits target_kl. Mutates and returns self for chaining. + + `prompts` defaults to a small generic set; pass list[str] to use your own. + """ + from .calibrate import calibrate_iso_kl + coeff, _ = calibrate_iso_kl(self, model, tok, prompts, target_kl=target_kl, **kw) + self.cfg.coeff = float(coeff) + return self + + @contextmanager + def __call__(self, model: nn.Module, *, C: float | None = None): + """Attach for the duration of the `with` block. `C` overrides cfg.coeff if given.""" + from .attach import attach, detach + cfg = self.cfg if C is None else replace(self.cfg, coeff=float(C)) + attach(model, cfg, self.state) + try: + yield model + finally: + detach(model) + + def __add__(self, other: "Vector") -> "Vector": + if self.cfg.method != other.cfg.method: + raise ValueError(f"cannot add {self.cfg.method!r} + {other.cfg.method!r}") + if sorted(self.state) != sorted(other.state): + raise ValueError(f"layer mismatch: {sorted(self.state)} vs {sorted(other.state)}") + new_state: dict[int, dict[str, Tensor]] = {} + for li in self.state: + a, b = self.state[li], other.state[li] + if sorted(a) != sorted(b): + raise ValueError(f"layer {li}: state keys differ {sorted(a)} vs {sorted(b)}") + new_state[li] = {k: a[k] + b[k] for k in a} + return Vector(deepcopy(self.cfg), new_state) + + def __mul__(self, k: float) -> "Vector": + new_state = { + li: {k_: v * float(k) for k_, v in s.items()} + for li, s in self.state.items() + } + return Vector(deepcopy(self.cfg), new_state) + + __rmul__ = __mul__ + + def save(self, path: str) -> None: + from .attach import _STATE_PREFIX, _SUB_KEY_PREFIX, _SUB_KEY_SEP # noqa: F401 + import json + from safetensors.torch import save_file + sd: dict[str, Tensor] = {} + sub_mode = self.cfg.target_submodule is not None + for key, s in self.state.items(): + for k, t in s.items(): + if sub_mode: + sd[f"{_SUB_KEY_PREFIX}{key}{_SUB_KEY_SEP}{k}"] = t.detach().cpu() + else: + sd[f"layer{key}.{k}"] = t.detach().cpu() + metadata = {"cfg": json.dumps(self.cfg.to_dict())} + save_file(sd, path, metadata=metadata) + + @classmethod + def load(cls, path: str) -> "Vector": + import json + from safetensors.torch import load_file, safe_open + from .attach import _safetensors_dict_to_state + with safe_open(path, framework="pt", device="cpu") as f: + metadata = f.metadata() + sd = load_file(path, device="cpu") + cfg = SteeringConfig.from_dict(json.loads(metadata["cfg"])) + state = _safetensors_dict_to_state(sd) + return cls(cfg, state) + + def __repr__(self) -> str: + layers = sorted(self.state) + return f"Vector(method={self.cfg.method!r}, layers={layers})" diff --git a/tests/test_smoke.py b/tests/test_smoke.py new file mode 100644 index 0000000..c7b211c --- /dev/null +++ b/tests/test_smoke.py @@ -0,0 +1,65 @@ +"""Smoke test: train + calibrate + branch_pmass on a tiny random model. + +Pass = runtime sanity. Distinguishing checks: + - measure_kl returns kl > 0 at coeff > 0 (steer DID change distribution). + - measure_kl returns kl ~= 0 at coeff = 0 (silent failure detector: if hooks + leak, base==steer KL would be nonzero). + - branch_pmass is in [0, 1]. + - branch_pmass changes between coeff=0 and coeff=large (sneaky-fail catch: + if pmass is just identity-pass-through it would be invariant). +""" +from __future__ import annotations +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +from iso_kl_figure import ( + SteeringConfig, MeanDiffC, train, measure_kl, attach, detach, +) +from iso_kl_figure.branch_pmass import branch_pmass + + +MODEL = "hf-internal-testing/tiny-random-LlamaForCausalLM" + + +def test_pipeline_smoke(): + tok = AutoTokenizer.from_pretrained(MODEL) + if tok.pad_token_id is None: + tok.pad_token_id = tok.eos_token_id + model = AutoModelForCausalLM.from_pretrained(MODEL, torch_dtype=torch.float32) + model.eval() + n_layers = model.config.num_hidden_layers + cfg = MeanDiffC(coeff=1.0, layers=(n_layers // 2,)) + + pos = ["Sure: ", "Yes: ", "Of course: ", "Here: "] + neg = ["No way.", "I refuse.", "Cannot help.", "Decline."] + v = train(model, tok, pos, neg, cfg, batch_size=2, max_length=32) + + # KL must be ~0 at coeff=0 (no leak), and > 0 at large coeff + v.cfg.coeff = 0.0 + m0 = measure_kl(v, model, tok, ["hello world"], T=4, device="cpu") + assert m0["kl_p95"] < 1e-3, f"coeff=0 should give zero KL, got {m0['kl_p95']}" + + v.cfg.coeff = 5.0 + m1 = measure_kl(v, model, tok, ["hello world"], T=4, device="cpu") + assert m1["kl_p95"] > 0.0, "coeff>0 should give nonzero KL" + + # per_t arrays length matches T + assert len(m1["per_t_p95"]) == 4 + assert len(m1["per_t_max"]) == 4 + + # branch_pmass: in [0, 1] and varies with coeff + pids = tok("hello", return_tensors="pt").input_ids[0] + rolled = pids[-2:].clone() + suffix = ' {"value": ' + fork = [0, 1, 2] + + v.cfg.coeff = 0.0 + p_zero = branch_pmass(v, model, tok, pids, rolled, fork, suffix, + ["true", "false"], device="cpu") + v.cfg.coeff = 5.0 + p_steer = branch_pmass(v, model, tok, pids, rolled, fork, suffix, + ["true", "false"], device="cpu") + for x in p_zero["pmass"] + p_steer["pmass"]: + assert 0.0 <= x <= 1.0, f"pmass out of [0,1]: {x}" + diff = max(abs(a - b) for a, b in zip(p_zero["pmass"], p_steer["pmass"])) + assert diff > 1e-6, "pmass invariant to coeff -- hook is dead"