iso-kl-figure: scaffold + smoke test passing

copilot
2026-05-05 06:17:25 +08:00
commit 45b7123cf5
23 changed files with 1949 additions and 0 deletions
+12
@@ -0,0 +1,12 @@
__pycache__/
*.pyc
.venv/
uv.lock
*.egg-info/
.pytest_cache/
.ruff_cache/
outputs/*.csv
outputs/*.tsv
outputs/*.png
outputs/*.md
!outputs/.gitkeep
+24
@@ -0,0 +1,24 @@
# AGENTS.md
Inherits conventions from sibling project `steering-lite`. Read [../steering-lite/AGENTS.md](../steering-lite/AGENTS.md) if it exists.
## House rules
- Fail fast. No defensive programming, no fallbacks, no silent dequant.
- Keep this repo small. Anything beyond the headline figure + table belongs in another repo.
- Use `einops` and `jaxtyping` shape annotations at function boundaries only. Tensor dim letters: `b s d` (batch, seq, d_model), `n` (prompts), `t` (token positions), `f` (fork points).
- No backward compat.
- Single functional smoke test = the real pipeline at tiny scale (`tests/test_smoke.py`).
- Methods register via `@register_config` and `@register` decorators; mirror `steering-lite/src/steering_lite/config.py`.
- All experiment scripts write CSV/TSV. Plot/table scripts read CSV/TSV. Never plot from in-memory state.
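The CSV/TSV rule above can be sketched as a write side and a read side that share nothing except the file. This is a minimal illustration using only the standard library; the helper names, the column names, and the placeholder numbers are assumptions for the example, not the repo's actual schema.

```python
import csv
import tempfile
from pathlib import Path


def write_rows(path: Path, rows: list[dict]) -> None:
    """Experiment side: persist results; nothing else is handed to the plotter."""
    with open(path, "w", newline="") as f:
        w = csv.DictWriter(f, fieldnames=list(rows[0]))
        w.writeheader()
        w.writerows(rows)


def read_rows(path: Path) -> list[dict]:
    """Plot/table side: rebuild everything from the file alone."""
    with open(path, newline="") as f:
        return [{"alpha": float(r["alpha"]), "kl_p95": float(r["kl_p95"])}
                for r in csv.DictReader(f)]


# Hypothetical rows; the numbers are placeholders, not results.
rows = [{"alpha": 1.0, "kl_p95": 0.5}, {"alpha": 2.0, "kl_p95": 1.5}]
out = Path(tempfile.mkdtemp()) / "results.csv"
write_rows(out, rows)
roundtrip = read_rows(out)
```

Because the reader only sees the file, a figure or table can be regenerated later without re-running the model.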
## Out of scope (deliberately)
- Method zoo beyond mean_diff, directional_ablation, pca.
- LessWrong post / paper draft.
- Citation collection.
- tinymfv or any external eval dependency.
## Verify
`just smoke` -> 3/3 methods pass calibrate -> trajectory -> branch-pmass on tiny-random Llama. Asserts nonzero KL at coeff>0, zero KL at coeff=0, branch-pmass in [0,1].
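The three assertions named above can be illustrated on a toy distribution. This is a sketch under stated assumptions: `toy_steer` is a hypothetical stand-in for a steering hook (it shifts logits along a fixed direction) and `target_ids` are made-up token ids; the real test in `tests/test_smoke.py` runs the actual pipeline.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]


def kl(p, q):
    """KL(p || q) in nats for two discrete distributions."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def toy_steer(base_logits, direction, coeff):
    # Hypothetical stand-in for a steering hook: shift logits along a direction.
    return [b + coeff * d for b, d in zip(base_logits, direction)]


base_logits = [2.0, 0.5, -1.0, 0.0]
direction = [0.0, 1.0, -1.0, 0.5]
target_ids = [0, 1]  # hypothetical ids of the format tokens

base = softmax(base_logits)
kl_at_0 = kl(softmax(toy_steer(base_logits, direction, 0.0)), base)
kl_at_1 = kl(softmax(toy_steer(base_logits, direction, 1.0)), base)
pmass = sum(softmax(toy_steer(base_logits, direction, 1.0))[i] for i in target_ids)
```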
+36
@@ -0,0 +1,36 @@
# iso-kl-figure
Minimal repo with one job: produce a figure and a table that demonstrate iso-KL calibration is stable across models, seeds, and calibration windows.
## Claim (narrow)
Calibrating a steering coefficient so that p95 per-token KL(steered || base) hits 1 nat in a short calibration window:
- C1: bisection converges for every method tested; held-out p95 KL lands near 1 nat.
- C2 (not too cold): target-axis Delta logit at calibrated alpha is non-zero across methods.
- C3 (not too hot): base-NLL of generated text and branch-pmass of a forced format token stay near base at calibrated alpha across methods.
The 2x check is a sanity probe, not a margin claim. Reported as: at 2x, p95 KL exceeds 1 nat for N of M cells.
Honesty footnote: methods are matched on per-token distributional disagreement under greedy decoding in the calibration window. This is one defensible notion of fairness; it is not equivalence in intervention norm or behavioral effect size.
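To make the calibration statistic concrete, the sketch below computes a p95 over a list of per-token KL values and compares it with the mean. The quantile convention (linear interpolation) and the KL values are assumptions for illustration; the repo's `measure_kl` may use a different convention.

```python
import math


def quantile(values, q):
    """Linear-interpolation quantile over a list of floats."""
    xs = sorted(values)
    pos = q * (len(xs) - 1)
    lo = math.floor(pos)
    hi = math.ceil(pos)
    return xs[lo] + (xs[hi] - xs[lo]) * (pos - lo)


# Hypothetical per-token KL(steered || base) values in nats; placeholders only.
per_token_kl = [0.02, 0.05, 0.01, 0.9, 0.03, 0.04, 0.02, 1.3, 0.05, 0.03, 0.02]
kl_mean = sum(per_token_kl) / len(per_token_kl)
kl_p95 = quantile(per_token_kl, 0.95)
```

With a few large spikes among many near-zero tokens, the p95 sits far above the mean, which is why matching on p95 constrains the worst tokens rather than the average one.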
## Quick start
```bash
uv sync --extra all
just smoke      # tiny-random model, ~1 min CPU
just cell       # one (model, method, seed, window) cell: calibrate + trajectory + pmass
just aggregate  # outputs/<run_id>/ -> figs/figure1.png + figs/table.md
```
`just sweep` runs the full grid (3 models x 3 methods x 3 seeds x 2 windows) used by Figure 1.
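The size of the grid can be checked by enumerating it. The lists and the run-id format below mirror `scripts/sweep.sh` and `scripts/run_cell.py` from this commit; the helper name `run_id` is only for illustration.

```python
from itertools import product

MODELS = [
    "Qwen/Qwen2.5-0.5B-Instruct",
    "Qwen/Qwen2.5-1.5B-Instruct",
    "meta-llama/Llama-3.2-1B-Instruct",
]
METHODS = ["mean_diff", "directional_ablation", "pca"]
SEEDS = [0, 1, 2]
WINDOWS = [20, 50]


def run_id(model: str, method: str, seed: int, window: int) -> str:
    # Same shape as the default run_id built in run_cell.py.
    return f"{model.split('/')[-1]}_{method}_s{seed}_w{window}"


cells = [run_id(*c) for c in product(MODELS, METHODS, SEEDS, WINDOWS)]
```

That gives 3 x 3 x 3 x 2 = 54 cells, each with a unique output directory name.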
## What this repo does NOT do
- No paper or LessWrong draft.
- No method zoo beyond mean_diff, directional_ablation, pca.
- No threshold sweep, no calibration-set-size sweep.
- No tinymfv integration. Target-axis is a single contrastive sentiment / refusal pair.
+68
@@ -0,0 +1,68 @@
# iso-kl-figure: spec
## Goal
Produce one figure (Figure 1) and one table (Table 1) that empirically support three claims: iso-KL calibration converges and generalizes (C1), the calibrated coefficient is not too cold (C2), and not too hot (C3). Show stability across 3 models x 3 seeds x 2 calibration windows.
## Scope
In:
- Port `measure_kl`, `calibrate_iso_kl`, minimal Vector/attach/config/target/extract from steering-lite.
- 3 methods: `mean_diff`, `directional_ablation`, `pca`.
- New `branch_pmass` metric: fork-and-teacher-force probability mass on a forced format answer token.
- Scripts producing TSV/CSV; plot and table modules consuming the CSVs.
Out:
- LessWrong post or paper draft.
- Method zoo beyond 3 methods.
- Threshold sweep, calibration-set-size sweep, norm-matching baseline.
- tinymfv integration.
## Requirements
- R1 (C1, calibration converges and generalizes): for every (method, model, seed, window), bisection terminates with calibration p95 within tolerance of 1.0; on a held-out prompt set p95 lands within [0.7, 1.4]. VERIFY: TSV row has converged=true and holdout_p95 in band; sneaky failure (overfits calibration prompts) caught by held-out column.
- R2 (C2, not too cold): target-axis Delta logit at calibrated alpha excludes 0 with 95% CI for each method, on each model. VERIFY: Table 1 row reports CI; sneaky failure (alpha approx 0) caught by alpha column in same row.
- R3 (C3, not too hot, NLL): base-NLL of full 50-token held-out generations stays within 2x of base at calibrated alpha; exceeds 4x of base at 2x calibrated alpha for at least one method per model. VERIFY: Table 1 base_nll_delta column.
- R4 (C3, not too hot, branch-pmass): mean branch-pmass-of-valid-answer at fork points {0, 5, ..., 50} stays within 0.1 of base pmass at calibrated alpha; drops by more than 0.3 at 2x alpha for at least one method per model. VERIFY: Table 1 branch_pmass column and Figure 1 lower subplot.
- R5 (sanity probe at 2x): max p95 KL at 2x alpha exceeds 1 nat in at least 2 of 3 methods on at least 2 of 3 models within 50 tokens. VERIFY: Figure 1 top subplot, alpha=2 panels show lines crossing reference.
- R6 (stability): seed band and window-style overlay in Figure 1 do not change the qualitative C1 conclusion. VERIFY: variance band visually narrow at alpha=1.
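As an illustration of how the R1 check could be applied to one TSV row, here is a minimal sketch. The column name `calib_p95`, the tolerance value, and the example rows are assumptions (the spec only says "within tolerance"); `converged` and `holdout_p95` are the columns named in R1.

```python
def r1_passes(row: dict, target: float = 1.0, tol: float = 0.05,
              band: tuple[float, float] = (0.7, 1.4)) -> bool:
    """R1: bisection converged, calibration p95 near target, held-out p95 in band."""
    calib_ok = row["converged"] and abs(row["calib_p95"] - target) <= tol
    holdout_ok = band[0] <= row["holdout_p95"] <= band[1]
    return bool(calib_ok and holdout_ok)


# Hypothetical rows; placeholder numbers, not measurements.
good = {"converged": True, "calib_p95": 1.02, "holdout_p95": 0.95}
overfit = {"converged": True, "calib_p95": 1.00, "holdout_p95": 2.10}
```

The second row is the sneaky failure R1 describes: calibration looks perfect, but the held-out column exposes it.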
## Tasks
- [/] T1 (R*): scaffold repo (pyproject, justfile, README, AGENTS, spec).
- verify: `just --list` lists recipes; `uv sync --extra all` resolves.
- [ ] T2 (R1, R2, R3, R4): port core code from steering-lite (calibrate, vector, attach, config, target, extract, 3 variants).
- verify: imports clean; smoke test runs all 3 methods.
- [ ] T3 (R1, R6): extend calibrate history to save per-token KL arrays (`per_t_p95`, `per_t_max`).
- verify: history dict contains per-token arrays of length T.
- [ ] T4 (R4): implement `branch_pmass` (fork at token t, append fixed format suffix, teacher-force one forward, sum p over `true`/`false` tokens).
- verify: pmass in [0, 1]; pmass at base != pmass at coeff=large (sneaky-fail catch).
- [ ] T5 (R1..R5): implement `run_calibrate.py`, `run_trajectory.py`, `run_table.py`.
- verify: CSVs created with expected columns and at least one row each on smoke.
- [ ] T6 (R*): implement `plot.py`, `table.py`.
- verify: PNG saved; markdown table prints; can be regenerated from CSVs alone.
- [ ] T7 (R*): full sweep on real models.
- verify: numeric asserts in R1..R5 pass.
- [ ] T8 (R*): external review of figure + table.
- verify: review doc saved under docs/spec/.
## Context
Calibration target: p95 per-token KL(steered || base) = 1 nat over T tokens (T in {20, 50}), N=4 calibration prompts under greedy decoding.
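The calibration search can be sketched as a bracketed bisection. This is not the repo's `calibrate_iso_kl`; it assumes the statistic is increasing in the coefficient, and `toy_p95_kl`, the tolerance, and the bracket-growth rule are all hypothetical choices for the example.

```python
def bisect_coeff(stat_at, target=1.0, lo=0.0, hi=1.0, rel_tol=0.05, max_iter=40):
    """Find c with stat_at(c) close to target, assuming stat_at increases with c."""
    # Grow the upper bracket until the statistic reaches the target.
    while stat_at(hi) < target:
        hi *= 2.0
        if hi > 1e6:
            raise RuntimeError("statistic never reached target")
    for _ in range(max_iter):
        mid = 0.5 * (lo + hi)
        val = stat_at(mid)
        if abs(val - target) <= rel_tol * target:
            return mid, True
        if val < target:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi), False


def toy_p95_kl(c):
    # Hypothetical stand-in for "p95 per-token KL at coefficient c".
    return 0.25 * c * c


c_star, converged = bisect_coeff(toy_p95_kl)
```

In the real pipeline each call to `stat_at` would be a full greedy rollout over the calibration prompts, so the number of bisection steps is the dominant cost.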
Branch-pmass procedure: at fork points t in {0, 5, ..., 50} take steered prefix of length t, append `\nAnswer (true/false): ` then `{"value": ` then teacher-force one forward under steered model, sum probabilities of token variants for `true` and `false`.
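The fork-and-teacher-force procedure can be sketched without a model by passing in a function that returns next-token logits. Everything named here (`branch_pmass_sketch`, `toy_next_logits`, the 5-token vocabulary, and the ids) is hypothetical; the real implementation lives in `src/iso_kl_figure/branch_pmass.py` and runs the steered model.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]


def branch_pmass_sketch(next_logits, prompt_ids, rolled_ids, fork_points,
                        suffix_ids, target_ids):
    out = []
    for t in fork_points:
        if t > len(rolled_ids):  # rollout ended before this fork point
            out.append(float("nan"))
            continue
        full = prompt_ids + rolled_ids[:t] + suffix_ids
        probs = softmax(next_logits(full))
        out.append(sum(probs[i] for i in target_ids))
    return out


def toy_next_logits(ids):
    # Hypothetical model: mass on tokens 0 and 1 decays as context grows.
    drift = 0.5 * len(ids)
    return [3.0 - drift, 3.0 - drift, 0.0, 0.0, 0.0]


pmass = branch_pmass_sketch(toy_next_logits, [4, 4], [2, 3, 2, 3, 2, 3],
                            [0, 5, 10], [4], [0, 1])
```

Fork points past the end of the rollout yield NaN here, so any downstream mean has to decide explicitly how to treat them.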
Target-axis: a single contrastive pair-set built into the repo (sentiment positive vs negative or refusal yes vs no), 4 prompts each. Target Delta logit = mean over held-out prompts of difference in logit on the target token.
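One way to attach a 95% CI to the target Delta logit is a percentile bootstrap over held-out prompts. This is a sketch under assumptions: the spec does not say how the CI is computed, the delta is taken here as steered minus base, and the logit values are placeholders.

```python
import random


def bootstrap_ci(deltas, n_boot=2000, level=0.95, seed=0):
    """Percentile bootstrap CI for the mean of `deltas`."""
    rng = random.Random(seed)
    n = len(deltas)
    means = sorted(
        sum(rng.choice(deltas) for _ in range(n)) / n for _ in range(n_boot)
    )
    lo_idx = int((1 - level) / 2 * n_boot)
    hi_idx = int((1 + level) / 2 * n_boot) - 1
    return means[lo_idx], means[hi_idx]


# Hypothetical target-token logits per held-out prompt; placeholders only.
steered = [2.1, 1.7, 2.4, 1.9]
base = [0.3, 0.2, 0.5, 0.1]
deltas = [s - b for s, b in zip(steered, base)]
mean_delta = sum(deltas) / len(deltas)
ci_lo, ci_hi = bootstrap_ci(deltas)
```

With only 4 prompts the interval is coarse, so a CI that excludes 0 should be read together with the alpha column, as R2 notes.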
## Log
(append-only; only entries that change a future task)
## TODO
(out-of-scope ideas; not commitments)
## Errors
| Task | Error | Resolution |
|------|-------|------------|
+24
@@ -0,0 +1,24 @@
set shell := ["bash", "-cu"]
default:
@just --list
# Smoke: tiny-random Llama, all 3 methods, asserts nonzero KL + branch-pmass changes with coeff.
smoke:
BEARTYPE=1 uv run --extra all pytest -q tests/test_smoke.py
test:
uv run --extra all pytest -q
# Run one (model, method, seed, window) cell end-to-end (calibrate + trajectory + pmass).
cell model="Qwen/Qwen2.5-0.5B-Instruct" method="mean_diff" seed="0" window="50":
uv run --extra all python scripts/run_cell.py \
--model {{model}} --method {{method}} --seed {{seed}} --window {{window}}
# Sweep model x method x seed x window cells.
sweep:
bash scripts/sweep.sh
# Aggregate all outputs/<run_id>/ into figs/figure1.png + figs/table.md.
aggregate:
uv run --extra all python scripts/aggregate.py --runs-root outputs --out figs
+36
@@ -0,0 +1,36 @@
[project]
name = "iso-kl-figure"
version = "0.0.1"
description = "Minimal repo: produce one figure + one table proving iso-KL calibration is stable across models/seeds/windows."
readme = "README.md"
requires-python = ">=3.11"
dependencies = [
"torch>=2.1",
"numpy>=1.26",
"einops>=0.7",
"jaxtyping>=0.2.34",
"safetensors>=0.5",
"loguru>=0.7",
"tqdm>=4.66",
]
[project.optional-dependencies]
test = ["pytest", "tabulate", "beartype>=0.18"]
hf = ["accelerate>=1.6", "transformers>=4.51"]
plot = ["matplotlib>=3.8", "polars>=1.0"]
all = [
"pytest", "tabulate", "beartype>=0.18",
"accelerate>=1.6", "transformers>=4.51",
"matplotlib>=3.8", "polars>=1.0",
"tyro>=0.9",
]
[build-system]
requires = ["setuptools>=68"]
build-backend = "setuptools.build_meta"
[tool.setuptools.packages.find]
where = ["src"]
[tool.ruff.lint]
ignore = ["F722"] # jaxtyping shape strings
+149
@@ -0,0 +1,149 @@
"""Aggregate per-cell outputs into Figure 1 + the headline table.
Figure 1: two stacked subplots, one column per model.
Top: per-token p95 KL trajectory. x = token offset; y = KL(steer || base), log scale.
Colour by method, linestyle by alpha (solid=1, dashed=2); line = mean across
cells, shaded band = min/max across cells. Dotted horizontal at target_kl=1.
Bottom: branch-pmass at fork points. x = fork token offset; y = mean pmass
across cells and held-out prompts; bands = +/- 1 std across cells and prompts.
Table: one row per (model, method, alpha), columns = c_star_mean, n_seeds,
kl_p95_mean, pmass_mean.
Usage:
python scripts/aggregate.py --runs-root outputs --out figs/
"""
from __future__ import annotations
import json
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
import polars as pl
import tyro
from loguru import logger
@dataclass
class Args:
runs_root: str = "outputs"
out: str = "figs"
def load_cells(root: Path) -> list[dict]:
cells = []
for d in sorted(root.iterdir()):
if not d.is_dir():
continue
calib = d / "calib.json"
if not calib.exists():
continue
meta = json.loads(calib.read_text())
traj = json.loads((d / "trajectory.json").read_text())
pmass = json.loads((d / "pmass.json").read_text())
cells.append({"id": d.name, **meta, "traj": traj, "pmass": pmass})
return cells
def make_table(cells: list[dict]) -> pl.DataFrame:
rows = []
by_mm = defaultdict(list)
for c in cells:
by_mm[(c["model"], c["method"])].append(c)
for (model, method), group in by_mm.items():
c_stars = [g["c_star"] for g in group]
# kl / pmass: pooled mean over all cells in the group (and, for pmass, over prompts and fork points) at each alpha
for alpha in ("1.0", "2.0"):
kls = []
pms = []
for g in group:
kls.append(g["traj"]["per_t_p95_kl"][alpha])
pms.append(g["pmass"]["pmass"][alpha])
kls_flat = [x for arr in kls for x in arr]
pms_flat = [x for prompt in pms for arr in prompt for x in arr]
rows.append({
"model": model.split("/")[-1],
"method": method,
"alpha": float(alpha),
"c_star_mean": sum(c_stars) / len(c_stars),
"n_seeds": len(group),
"kl_p95_mean": sum(kls_flat) / max(len(kls_flat), 1),
"pmass_mean": sum(pms_flat) / max(len(pms_flat), 1),
})
return pl.DataFrame(rows)
def make_figure(cells: list[dict], out_path: Path) -> None:
import matplotlib.pyplot as plt
import numpy as np
models = sorted({c["model"] for c in cells})
methods = sorted({c["method"] for c in cells})
fig, axes = plt.subplots(2, len(models), figsize=(5 * len(models), 7),
sharex="col", squeeze=False)
cmap = plt.get_cmap("tab10")
method_color = {m: cmap(i) for i, m in enumerate(methods)}
for ci, model in enumerate(models):
ax_kl = axes[0, ci]
ax_pm = axes[1, ci]
ax_kl.set_title(model.split("/")[-1])
ax_kl.axhline(1.0, color="black", linestyle=":", linewidth=0.8, alpha=0.5)
ax_kl.set_ylabel("p95 KL(steer || base)")
ax_pm.set_xlabel("token offset")
ax_pm.set_ylabel("branch pmass")
ax_pm.set_ylim(-0.02, 1.02)
ax_kl.set_yscale("log")
for method in methods:
for alpha, ls in [("1.0", "-"), ("2.0", "--")]:
kls = [c["traj"]["per_t_p95_kl"][alpha]
for c in cells if c["model"] == model and c["method"] == method]
if not kls:
continue
arr = np.array(kls)
x = np.arange(arr.shape[1])
ax_kl.plot(x, arr.mean(0), color=method_color[method],
linestyle=ls, linewidth=2,
label=f"{method} a={alpha}")
if arr.shape[0] > 1:
ax_kl.fill_between(x, arr.min(0), arr.max(0),
color=method_color[method], alpha=0.12)
pms = [c["pmass"]["pmass"][alpha]
for c in cells if c["model"] == model and c["method"] == method]
if not pms:
continue
# pms: list of (n_seed) of (n_prompt) of (n_fork)
pms_arr = np.array(pms) # (n_seed, n_prompt, n_fork)
fork = cells[0]["pmass"]["fork_points"]
mean = pms_arr.mean(axis=(0, 1))
std = pms_arr.std(axis=(0, 1))
ax_pm.plot(fork, mean, color=method_color[method],
linestyle=ls, linewidth=2)
ax_pm.fill_between(fork, mean - std, mean + std,
color=method_color[method], alpha=0.12)
if ci == 0:
ax_kl.legend(fontsize=8, loc="upper left")
fig.tight_layout()
fig.savefig(out_path, dpi=150, bbox_inches="tight")
logger.info(f"figure -> {out_path}")
def main(a: Args):
out = Path(a.out); out.mkdir(parents=True, exist_ok=True)
cells = load_cells(Path(a.runs_root))
if not cells:
raise SystemExit(f"no cells under {a.runs_root}")
logger.info(f"loaded {len(cells)} cells")
df = make_table(cells)
df.write_csv(out / "table.csv")
md = df.to_pandas().to_markdown(index=False, floatfmt=".3f")
(out / "table.md").write_text(md)
logger.info(f"table -> {out/'table.md'}\n{md}")
make_figure(cells, out / "figure1.png")
if __name__ == "__main__":
main(tyro.cli(Args))
+198
@@ -0,0 +1,198 @@
"""End-to-end runner for one (model, method, seed, window) cell.
Flow:
1. Load model + tokenizer (HF), set seed.
2. Build pos/neg prompts (cheap pair); train the steering Vector v.
3. Calibrate iso-KL at target_kl=1 over T=window tokens. Save full history
(incl. per-token KL arrays) to outputs/<run_id>/history.json.
4. Re-run measure_kl at coeff=alpha*c_star (alpha in {1, 2}) on a held-out
prompt set so the trajectory plot reflects generalisation, not the
calibration set itself. Save per-token p95 KL to trajectory.json.
5. For each held-out prompt, rollout T_eval tokens under the steered model,
then branch-pmass at fork_points {0, 5, ..., T_eval}. Save to
pmass.json. Use a JSON-format suffix so target tokens are well-defined.
Outputs one CSV row per alpha into outputs/<run_id>/results.csv
plus the artefacts above.
Usage (tyro):
python scripts/run_cell.py --model Qwen/Qwen2.5-0.5B-Instruct --method mean_diff \
--seed 0 --window 50 --run_id qwen05b_md_s0_w50
"""
from __future__ import annotations
import json
import os
from dataclasses import dataclass, field
from pathlib import Path
import torch
import tyro
from loguru import logger
from iso_kl_figure import (
SteeringConfig, MeanDiffC, PCAC, DirectionalAblationC,
train, calibrate_iso_kl, measure_kl,
)
from iso_kl_figure.branch_pmass import branch_pmass
CALIB_PROMPTS = [
"The eiffel tower is in Paris.",
"Tell me a greentext story about a small village during the Martian carrot bubble.",
"Walk me through computing the integral of x^2 from 0 to 1.",
"Briefly describe how plate tectonics shape mountain ranges.",
]
# Held-out prompts -- different from calibration; tests generalisation.
EVAL_PROMPTS = [
"Explain why the sky is blue in two short sentences.",
"Summarise the plot of Hamlet in one paragraph.",
"List three causes of inflation.",
"Write a haiku about an octopus librarian.",
]
# Pos/neg pair: a generic content-vs-refusal direction. Cheap, not the main
# point -- the figure tests calibration *behaviour*, not direction quality.
POS_NEG = [
("Sure, here is the answer:", "I cannot help with that."),
("Here are some details:", "Sorry, I can't assist with that."),
("Of course, let me explain.", "I won't be able to help."),
("Yes, that makes sense.", "No, I have to decline."),
]
METHOD_MAP = {
"mean_diff": MeanDiffC,
"pca": PCAC,
"directional_ablation": DirectionalAblationC,
}
@dataclass
class Args:
model: str
method: str
seed: int = 0
window: int = 50
run_id: str = ""
layer_frac: float = 0.6
target_kl: float = 1.0
out_root: str = "outputs"
device: str = "cuda"
dtype: str = "bfloat16"
suffix_str: str = ' Final answer in JSON: {"value": '
target_words: list[str] = field(default_factory=lambda: ["true", "false", "yes", "no"])
fork_step: int = 5
def _set_seed(s: int):
import random
import numpy as np
random.seed(s); np.random.seed(s); torch.manual_seed(s)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(s)
def main(a: Args):
if not a.run_id:
a.run_id = f"{a.model.split('/')[-1]}_{a.method}_s{a.seed}_w{a.window}"
out_dir = Path(a.out_root) / a.run_id
out_dir.mkdir(parents=True, exist_ok=True)
logger.add(out_dir / "run.log", level="INFO")
_set_seed(a.seed)
from transformers import AutoModelForCausalLM, AutoTokenizer
dtype = getattr(torch, a.dtype)
tok = AutoTokenizer.from_pretrained(a.model)
if tok.pad_token_id is None:
tok.pad_token_id = tok.eos_token_id
model = AutoModelForCausalLM.from_pretrained(a.model, torch_dtype=dtype).to(a.device)
model.eval()
n_layers = model.config.num_hidden_layers
layer = int(a.layer_frac * n_layers)
logger.info(f"model={a.model} n_layers={n_layers} target_layer={layer}")
cfg_cls = METHOD_MAP[a.method]
cfg = cfg_cls(coeff=1.0, layers=(layer,))
pos = [tok.apply_chat_template([{"role": "user", "content": u},
{"role": "assistant", "content": p}],
tokenize=False)
for u, (p, _) in zip(CALIB_PROMPTS, POS_NEG)]
neg = [tok.apply_chat_template([{"role": "user", "content": u},
{"role": "assistant", "content": n}],
tokenize=False)
for u, (_, n) in zip(CALIB_PROMPTS, POS_NEG)]
v = train(model, tok, pos, neg, cfg, batch_size=4, max_length=128)
logger.info("=== calibrate ===")
c_star, history = calibrate_iso_kl(
v, model, tok, CALIB_PROMPTS,
target_kl=a.target_kl, target_stat="kl_p95",
T=a.window, device=a.device,
)
v.cfg.coeff = c_star
logger.info(f"c_star = {c_star:+.4f}")
(out_dir / "history.json").write_text(json.dumps(history, indent=2))
(out_dir / "calib.json").write_text(json.dumps({
"c_star": c_star, "target_kl": a.target_kl, "window": a.window,
"method": a.method, "model": a.model, "seed": a.seed, "layer": layer,
}, indent=2))
# -- trajectory + pmass at alpha in {1, 2} on held-out prompts
rows = []
fork_points = list(range(0, a.window + 1, a.fork_step))
trajectory: dict[str, list] = {}
pmass_all: dict[str, list] = {}
for alpha in (1.0, 2.0):
v.cfg.coeff = alpha * c_star
logger.info(f"=== eval alpha={alpha} c={v.cfg.coeff:+.4f} ===")
m = measure_kl(v, model, tok, EVAL_PROMPTS, T=a.window, device=a.device)
trajectory[str(alpha)] = m["per_t_p95"]
rows.append({"alpha": alpha, "coeff": v.cfg.coeff, "kl_p95": m["kl_p95"],
"kl_mean": m["kl_mean"], "kl_max": m["kl_max"]})
# pmass per held-out prompt
pm_for_alpha = []
for p in EVAL_PROMPTS:
ids = tok.apply_chat_template(
[{"role": "user", "content": p}],
add_generation_prompt=True, return_tensors="pt", return_dict=True,
).input_ids[0]
pad = tok.pad_token_id
with v(model):
gen = model.generate(
ids.unsqueeze(0).to(a.device),
max_new_tokens=a.window,
pad_token_id=pad, eos_token_id=tok.eos_token_id,
do_sample=False,
)[0, ids.shape[0]:]
pm = branch_pmass(
v, model, tok, ids, gen, fork_points,
a.suffix_str, a.target_words, device=a.device,
)
pm_for_alpha.append(pm["pmass"])
pmass_all[str(alpha)] = pm_for_alpha
(out_dir / "trajectory.json").write_text(json.dumps({
"fork_points_full": list(range(a.window)),
"per_t_p95_kl": trajectory,
}, indent=2))
(out_dir / "pmass.json").write_text(json.dumps({
"fork_points": fork_points,
"pmass": pmass_all,
"suffix": a.suffix_str,
"target_words": a.target_words,
}, indent=2))
import csv
with open(out_dir / "results.csv", "w", newline="") as f:
w = csv.DictWriter(f, fieldnames=["alpha", "coeff", "kl_p95", "kl_mean", "kl_max"])
w.writeheader()
for r in rows:
w.writerow(r)
logger.info(f"DONE -> {out_dir}")
if __name__ == "__main__":
main(tyro.cli(Args))
+27
@@ -0,0 +1,27 @@
#!/usr/bin/env bash
# Sweep: model x method x seed x window. Edit the lists to taste.
set -euo pipefail
cd "$(dirname "$0")/.."
MODELS=("Qwen/Qwen2.5-0.5B-Instruct" "Qwen/Qwen2.5-1.5B-Instruct" "meta-llama/Llama-3.2-1B-Instruct")
METHODS=("mean_diff" "directional_ablation" "pca")
SEEDS=(0 1 2)
WINDOWS=(20 50)
for model in "${MODELS[@]}"; do
for method in "${METHODS[@]}"; do
for seed in "${SEEDS[@]}"; do
for window in "${WINDOWS[@]}"; do
run_id="$(basename "$model")_${method}_s${seed}_w${window}"
if [ -f "outputs/${run_id}/calib.json" ]; then
echo "skip ${run_id}"; continue
fi
echo "=== ${run_id} ==="
uv run --extra all python scripts/run_cell.py \
--model "$model" --method "$method" --seed "$seed" --window "$window"
done
done
done
done
uv run --extra all python scripts/aggregate.py --runs-root outputs --out figs
+34
@@ -0,0 +1,34 @@
import os as _os
if _os.environ.get("BEARTYPE"):
from beartype.claw import beartype_this_package as _bt
_bt()
from .config import SteeringConfig, REGISTRY, register
from .extract import record_activations
from .attach import attach, detach, save, load, train
from .calibrate import measure_kl, calibrate_iso_kl
from . import variants # noqa: F401 triggers method + config registration
from .vector import Vector
from .variants.mean_diff import MeanDiffC
from .variants.pca import PCAC
from .variants.directional_ablation import DirectionalAblationC
__all__ = [
"SteeringConfig",
"MeanDiffC",
"PCAC",
"DirectionalAblationC",
"record_activations",
"train",
"attach",
"detach",
"save",
"load",
"measure_kl",
"calibrate_iso_kl",
"REGISTRY",
"register",
"Vector",
]
+243
@@ -0,0 +1,243 @@
"""attach / detach / save / load. The whole runtime.
Variant protocol (uniform across both hook paths):
apply(mod, x, y, state, cfg) -> y_new
`mod` is the hooked module itself (a transformer block or a Linear); `x` is
its input, `y` its output. Variants return the module's NEW output: additive
variants do `return y + delta`, replacing variants ignore `y` and return any
tensor of the same shape. Same contract as lora-lite's `Variant.forward`.
Two hook paths, dispatched on `cfg.target_submodule`:
- `target_submodule is None` (default): hook each transformer block's
forward output. `mod = block`, `x = args[0]` (input residual),
`y = out[0]` (output hidden_states). State keyed by `int` (block index).
- `target_submodule = <regex>`: hook every nn.Linear in each selected block
whose dotted path matches the regex. `mod = linear`, `x` is the Linear's
input, `y` its output. State keyed by `str` (full dotted name like
`"layers.5.mlp.down_proj"`).
"""
from __future__ import annotations
import json
import torch
from torch import nn
from torch.utils.hooks import RemovableHandle
from .config import SteeringConfig, REGISTRY
from .target import find_targets
from .extract import record_activations
_ATTACHED_ATTR = "_steering_lite_attached"
_STATE_PREFIX = "_steering_state_"
_SUB_KEY_PREFIX = "sub::" # safetensors key prefix marking submodule-level state
_SUB_KEY_SEP = "::" # separator between full_name and state_key
def _gather_state(mod) -> dict[str, torch.Tensor]:
return {
k[len(_STATE_PREFIX):]: getattr(mod, k)
for k in dir(mod)
if k.startswith(_STATE_PREFIX) and isinstance(getattr(mod, k, None), torch.Tensor)
}
def _hook(mod, args, out):
"""Forward hook for block-level variants. Block forward returns a tuple
`(hidden_states, ...)`; we replace `[0]` with the variant's output."""
cfg: SteeringConfig = mod._steering_cfg
method = mod._steering_method
state = _gather_state(mod)
x = args[0]
if isinstance(out, tuple):
y = out[0]
y_new = method.apply(mod, x, y, state, cfg)
return (y_new,) + out[1:]
return method.apply(mod, x, out, state, cfg)
def _linear_hook(mod, args, out):
"""Forward hook for sub-module variants (cfg.target_submodule is set).
`out` is the Linear's output tensor (not a tuple)."""
cfg: SteeringConfig = mod._steering_cfg
method = mod._steering_method
state = _gather_state(mod)
return method.apply(mod, args[0], out, state, cfg)
def _install_state(mod: nn.Module, state: dict[str, torch.Tensor], cfg: SteeringConfig) -> None:
for k, v in state.items():
attr = _STATE_PREFIX + k
if hasattr(mod, attr):
raise RuntimeError(f"module already has {attr}; detach first")
mod.register_buffer(attr, v.to(cfg.dtype), persistent=True)
def attach(
model: nn.Module,
cfg: SteeringConfig,
vectors,
) -> list[RemovableHandle]:
"""Install per-target state as buffers and register forward hooks.
`vectors` shape depends on cfg.target_submodule:
- None: dict[int, dict[str, Tensor]] keyed by block layer index.
- regex set: dict[str, dict[str, Tensor]] keyed by full dotted name.
Accepts a `Vector` (auto-unwrapped to its `.state`).
"""
from .vector import Vector
if isinstance(vectors, Vector):
vectors = vectors.state
if cfg.method not in REGISTRY:
raise KeyError(f"unknown method {cfg.method!r}; registered: {list(REGISTRY)}")
method = REGISTRY[cfg.method]
# variant-level default target_submodule (e.g. sspace -> residual writers)
if cfg.target_submodule is None and getattr(method, "default_target_submodule", None):
cfg.target_submodule = method.default_target_submodule
requires_linear = cfg.target_submodule is not None
targets = find_targets(model, cfg)
if not targets:
raise RuntimeError("no target layers matched cfg")
handles: list[RemovableHandle] = []
attached_names: list[str] = []
hooked_modules: list[nn.Module] = []
for full_name, mod, li in targets:
key = full_name if requires_linear else li
if key not in vectors:
raise KeyError(f"vectors missing key {key!r}; have {sorted(vectors)}")
_install_state(mod, vectors[key], cfg)
mod._steering_cfg = cfg
mod._steering_method = method
if requires_linear:
mod._steering_module_name = full_name
hooked_modules.append(mod)
handles.append(mod.register_forward_hook(_linear_hook))
else:
mod._steering_layer_idx = li
handles.append(mod.register_forward_hook(_hook))
attached_names.append(full_name)
setattr(model, _ATTACHED_ATTR, {
"cfg": cfg, "targets": attached_names, "handles": handles,
"hooked_modules": hooked_modules,
})
return handles
def detach(model: nn.Module) -> None:
state = getattr(model, _ATTACHED_ATTR, None)
if state is None:
return
for h in state["handles"]:
h.remove()
for _, mod in model.named_modules():
if not hasattr(mod, "_steering_method"):
continue
for k in [k for k in list(mod._buffers) if k.startswith(_STATE_PREFIX)]:
del mod._buffers[k]
for attr in (
"_steering_cfg", "_steering_method",
"_steering_layer_idx", "_steering_module_name",
):
if hasattr(mod, attr):
delattr(mod, attr)
delattr(model, _ATTACHED_ATTR)
def _log_extract_demo(tok, pos_prompts: list[str], neg_prompts: list[str]) -> None:
"""One trace per stage: decoded full prompt + special tokens, for format debugging."""
from loguru import logger
pos = pos_prompts[0]
neg = neg_prompts[0]
logger.info(
"EXPECT: POS and NEG share the user message; differ only in the assistant reply; "
"chat template applied; special tokens (e.g. <|im_start|>) visible.\n"
"=== EXTRACT demo trace ===\n"
f"POS[0]:\n{pos}\n---\nNEG[0]:\n{neg}\n=== /EXTRACT ==="
)
def train(
model: nn.Module,
tok,
pos_prompts: list[str],
neg_prompts: list[str],
cfg: SteeringConfig,
*,
batch_size: int = 8,
max_length: int = 256,
):
"""Extract activations + run method.extract -> Vector. Block-level only.
Stripped from steering-lite: no submodule regex hooks, no attn pooling.
"""
from .vector import Vector
_log_extract_demo(tok, pos_prompts, neg_prompts)
method = REGISTRY[cfg.method]
targets = find_targets(model, cfg)
layers = tuple(li for _, _, li in targets)
pos_acts = record_activations(model, tok, pos_prompts, layers, batch_size=batch_size, max_length=max_length)
neg_acts = record_activations(model, tok, neg_prompts, layers, batch_size=batch_size, max_length=max_length)
state = method.extract(pos_acts, neg_acts, cfg)
return Vector(cfg, state)
def _state_to_safetensors_dict(model: nn.Module) -> dict:
"""Serialise installed state buffers from all hooked modules. Forks on
whether the module is block-level (_steering_layer_idx) or submodule-level
(_steering_module_name); keys distinguish the two so load() can rebuild."""
sd = {}
for _, mod in model.named_modules():
if not hasattr(mod, "_steering_method"):
continue
if hasattr(mod, "_steering_module_name"):
full_name = mod._steering_module_name
for k, v in mod._buffers.items():
if k.startswith(_STATE_PREFIX):
sd[f"{_SUB_KEY_PREFIX}{full_name}{_SUB_KEY_SEP}{k[len(_STATE_PREFIX):]}"] = v.detach().cpu()
elif hasattr(mod, "_steering_layer_idx"):
li = mod._steering_layer_idx
for k, v in mod._buffers.items():
if k.startswith(_STATE_PREFIX):
sd[f"layer{li}.{k[len(_STATE_PREFIX):]}"] = v.detach().cpu()
return sd
def _safetensors_dict_to_state(sd: dict[str, torch.Tensor]) -> dict:
"""Inverse of _state_to_safetensors_dict. Returns dict keyed by int (block-level)
or str (submodule-level), depending on the prefix of each key."""
vectors: dict = {}
for k, v in sd.items():
if k.startswith(_SUB_KEY_PREFIX):
rest = k[len(_SUB_KEY_PREFIX):]
full_name, _, state_key = rest.rpartition(_SUB_KEY_SEP)
vectors.setdefault(full_name, {})[state_key] = v
else:
layer_part, _, sub = k.partition(".")
li = int(layer_part.removeprefix("layer"))
vectors.setdefault(li, {})[sub] = v
return vectors
def save(model: nn.Module, path: str) -> None:
state = getattr(model, _ATTACHED_ATTR, None)
if state is None:
raise RuntimeError("no steering attached; call attach() first")
sd = _state_to_safetensors_dict(model)
metadata = {"cfg": json.dumps(state["cfg"].to_dict())}
from safetensors.torch import save_file
save_file(sd, path, metadata=metadata)
def load(model: nn.Module, path: str) -> list[RemovableHandle]:
from safetensors.torch import load_file, safe_open
with safe_open(path, framework="pt", device="cpu") as f:
metadata = f.metadata()
sd = load_file(path, device="cpu")
cfg = SteeringConfig.from_dict(json.loads(metadata["cfg"]))
vectors = _safetensors_dict_to_state(sd)
return attach(model, cfg, vectors)
+90
@@ -0,0 +1,90 @@
"""Branch-and-teacher-force pmass: coherence metric for steered generation.
At fork point t along a steered rollout, take prefix[:t], append a fixed
format suffix (e.g. `{"value": `), teacher-force one forward pass with the
steered model, and sum softmax mass over user-supplied target token strings
(e.g. `["true", "false"]`). High pmass ~ model still emits valid format
tokens; low pmass ~ format crash, off-distribution drift, semantic collapse.
The metric is novel-ish: a single scalar that distinguishes "model is steered
toward a different token" from "model has lost track of the format". A target
direction can move pmass off 1.0 by reweighting between target tokens but
should not drop pmass to ~0. A miscalibrated coeff drops pmass to noise.
Returns Float[Tensor, "f"] over fork points.
"""
from __future__ import annotations
from typing import Sequence
import torch
from torch import nn, Tensor
from .vector import Vector
def _all_token_ids(tok, words: Sequence[str]) -> list[int]:
"""Collect the leading token id for each word in several capitalisation /
leading-space variants. Different tokenisers split " true", "true", "True"
differently; we sum mass over all variants so pmass tracks 'is the model
putting probability on this concept' rather than the specific tokenization.
"""
ids: set[int] = set()
for w in words:
for variant in (w, " " + w, w.capitalize(), " " + w.capitalize(),
w.upper(), " " + w.upper()):
try:
t = tok.encode(variant, add_special_tokens=False)
except Exception:
continue
if len(t) >= 1:
ids.add(int(t[0]))
return sorted(ids)
@torch.no_grad()
def branch_pmass(
v: Vector,
model: nn.Module,
tok,
prompt_ids: Tensor, # (s,) prompt token ids, int64
rolled_ids: Tensor, # (T,) steered rollout token ids
fork_points: Sequence[int], # token offsets along rolled_ids
suffix_str: str, # fixed format suffix appended at each fork
target_words: Sequence[str], # words to sum pmass over (any tokenization)
*,
device: str | torch.device = "cuda",
) -> dict:
"""Returns {"pmass": [f], "fork_points": [f], "target_ids": [...], "suffix_ids": [...]}
Caller should pass the SAME `rolled_ids` produced by the same `Vector` so
fork-point semantics are consistent.
"""
suffix_ids = tok.encode(suffix_str, add_special_tokens=False)
suffix_t = torch.tensor(suffix_ids, dtype=torch.long, device=device)
target_ids = _all_token_ids(tok, target_words)
if not target_ids:
raise ValueError(f"no target ids found for words={target_words}")
target_idx = torch.tensor(target_ids, dtype=torch.long, device=device)
pmass = []
pids = prompt_ids.to(device)
rolled = rolled_ids.to(device)
T = rolled.shape[0]
for t in fork_points:
if t > T:
pmass.append(float("nan"))
continue
prefix = rolled[:t]
full = torch.cat([pids, prefix, suffix_t]).unsqueeze(0)
with v(model):
logits = model(full).logits[0, -1].float()
probs = torch.softmax(logits, dim=-1)
pmass.append(float(probs[target_idx].sum()))
return {
"pmass": pmass,
"fork_points": list(fork_points),
"target_ids": target_ids,
"suffix_ids": suffix_ids,
}
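A minimal sketch of the pmass computation at a single fork point, assuming the last-position logits are already available as a plain list. `softmax` and `pmass_from_logits` are hypothetical helper names, and the 5-token vocabulary is made up for illustration; the real `branch_pmass` obtains the logits from a forward pass of the steered model.

```python
import math
from typing import Sequence


def softmax(logits: Sequence[float]) -> list[float]:
    """Numerically stable softmax over a flat list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]


def pmass_from_logits(logits: Sequence[float], target_ids: Sequence[int]) -> float:
    """Sum next-token probability over a de-duplicated set of target ids."""
    probs = softmax(logits)
    return sum(probs[i] for i in set(target_ids))


# Hypothetical vocabulary of 5 tokens; ids 1 and 3 stand in for "true"/"false".
uniform = pmass_from_logits([0.0] * 5, [1, 3])                      # no preference at all
peaked = pmass_from_logits([0.0, 8.0, 0.0, 0.0, 0.0], [1, 3])       # format intact
collapsed = pmass_from_logits([8.0, 0.0, 0.0, 0.0, 0.0], [1, 3])    # mass on a non-target token
```

In this toy setting, `peaked` stays close to 1 while `collapsed` falls close to 0, which is the contrast the metric is meant to expose.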
+247
@@ -0,0 +1,247 @@
"""Iso-KL calibration with per-token KL trajectory persistence.
Forked from steering-lite/calibrate.py. Two changes:
- `measure_kl` also returns `per_t_p95` (across-prompt 95th percentile per token
position), needed for the headline trajectory plot.
- `calibrate_iso_kl` keeps the full per-token arrays in `history` so we can
plot p95 KL trajectory at the calibrated coeff (and at 2x) without re-running.
"""
from __future__ import annotations
import math
from typing import Callable
import torch
from loguru import logger
from torch import Tensor
from torch import nn
from tqdm.auto import tqdm
from .config import SteeringConfig # noqa: F401
from .vector import Vector
_demo_logged = {"flag": False}
DEFAULT_MESSAGES = [
"The eiffel tower is in Paris",
"埃菲尔铁塔🗼位于天都城",
"Tell me a greentext story about a small village during the smaller Martion carrot bubble.",
"Think step by step to calculate the integral of x^2 from 0 to 1 in lean4.",
]
def _tokenize(prompts, tok):
if prompts is None:
prompts = DEFAULT_MESSAGES
if isinstance(prompts[0], str):
return [
tok.apply_chat_template(
[{"role": "user", "content": p}],
add_generation_prompt=True, return_tensors="pt",
).input_ids[0]
for p in prompts
]
return prompts
@torch.no_grad()
def _kl_per_pos(logp_steer: Tensor, logp_base: Tensor) -> Tensor:
p_s = logp_steer.exp()
return (p_s * (logp_steer - logp_base)).sum(dim=-1)
@torch.no_grad()
def _generate(model, prompt_ids, T, tok, do_sample, device):
pad_id = tok.pad_token_id if tok.pad_token_id is not None else tok.eos_token_id
ids = prompt_ids.unsqueeze(0).to(device)
out = model.generate(
ids, max_new_tokens=T, pad_token_id=pad_id, eos_token_id=tok.eos_token_id,
num_return_sequences=1, do_sample=do_sample,
)
return out[0, prompt_ids.shape[0]:]
def _quantile(xs: list[float], q: float) -> float:
if not xs:
return 0.0
return float(torch.tensor(xs).quantile(q))
@torch.no_grad()
def measure_kl(
v: Vector,
model: nn.Module,
tok,
prompts=None,
*,
T: int = 50,
do_sample: bool = False,
device: str | torch.device = "cuda",
) -> dict:
"""Roll out under steering, then score the same tokens under both the base and
the steered model. Returns scalar stats plus per-token lists (mean, p95, max)
of length T.
"""
prompts = _tokenize(prompts, tok)
all_kls = []
per_t = [[] for _ in range(T)]
for idx, pids in enumerate(tqdm(prompts, desc="measure_kl", mininterval=60)):
with v(model):
gen = _generate(model, pids, T, tok, do_sample, device)
n_gen = gen.shape[0]
if n_gen == 0:
continue
full_ids = torch.cat([pids.to(device), gen])
if idx == 0 and not _demo_logged["flag"]:
_demo_logged["flag"] = True
base_gen = _generate(model, pids, T, tok, do_sample, device)
base_full = torch.cat([pids.to(device), base_gen])
decoded_base = tok.decode(base_full, skip_special_tokens=False)
decoded_steer = tok.decode(full_ids, skip_special_tokens=False)
logger.info(
f"EXPECT: same prompt under c=0 vs c={v.cfg.coeff:+.4f}; both coherent; "
"steered should differ from base but not collapse.\n"
f"\n=== CALIBRATE demo trace (T={T}) ===\n"
f"--- BASE (c=0) ---\n{decoded_base}\n"
f"\n--- STEER (c={v.cfg.coeff:+.4f}) ---\n{decoded_steer}\n"
f"=== /CALIBRATE ==="
)
full = full_ids.unsqueeze(0)
n_p = pids.shape[0]
logp_base = torch.log_softmax(model(full).logits.float(), dim=-1)[0]
with v(model):
logp_steer = torch.log_softmax(model(full).logits.float(), dim=-1)[0]
slc = slice(n_p - 1, n_p - 1 + n_gen)
kls = _kl_per_pos(logp_steer[slc], logp_base[slc]).cpu()
all_kls.append(kls)
for i in range(n_gen):
per_t[i].append(float(kls[i]))
cat = torch.cat(all_kls)
return {
"kl_mean": float(cat.mean()),
"kl_p50": float(cat.quantile(0.50)),
"kl_p90": float(cat.quantile(0.90)),
"kl_p95": float(cat.quantile(0.95)),
"kl_max": float(cat.max()),
"n_pos": int(cat.numel()),
"per_t_mean": [sum(xs) / len(xs) if xs else 0.0 for xs in per_t],
"per_t_p95": [_quantile(xs, 0.95) for xs in per_t],
"per_t_max": [max(xs) if xs else 0.0 for xs in per_t],
}
def calibrate_iso_kl(
v: Vector,
model: nn.Module,
tok,
prompts=None,
*,
target_kl: float = 1.0,
target_stat: str = "kl_p95",
bracket: tuple[float, float] = (0.01, 16.0),
tol: float = 0.05,
max_iters: int = 12,
T: int = 50,
device: str | torch.device = "cuda",
sign: float = 1.0,
sign_probe: Callable[[Vector], float] | None = None,
sign_probe_c: float = 1.0,
) -> tuple[float, list[dict]]:
"""Log-log regula falsi on `target_stat` with an Illinois-style stale-endpoint
adjustment and a geometric-bisection fallback. History keeps per-token arrays
so we can plot the trajectory afterwards."""
_demo_logged["flag"] = False
prompts = _tokenize(prompts, tok)
history: list[dict] = []
if sign_probe is not None:
v.cfg.coeff = +sign_probe_c
score_pos = sign_probe(v)
v.cfg.coeff = -sign_probe_c
score_neg = sign_probe(v)
chosen = +1.0 if score_pos >= score_neg else -1.0
logger.info(
f"sign_probe: +c={sign_probe_c:+.2f} -> {score_pos:+.3f} | "
f"-c={-sign_probe_c:+.2f} -> {score_neg:+.3f} | "
f"chosen sign={chosen:+.0f}"
)
sign = sign * chosen
def eval_at(c: float) -> float:
v.cfg.coeff = sign * c
m = measure_kl(v, model, tok, prompts, T=T, do_sample=False, device=device)
history.append({"coeff": sign * c, "coeff_abs": c, "sign": sign, **m})
logger.info(f" c={sign * c:+.4f} mean={m['kl_mean']:.3f} "
f"p50={m['kl_p50']:.3f} p90={m['kl_p90']:.3f} "
f"p95={m['kl_p95']:.3f} max={m['kl_max']:.3f} n={m['n_pos']}")
return m[target_stat]
lo, hi = bracket
log_target = math.log(target_kl)
mid = (lo * hi) ** 0.5
v_mid = eval_at(mid)
if v_mid < target_kl:
c_lo, v_lo = mid, v_mid
c = mid
c_hi, v_hi = hi, None
while c < hi:
c *= 2.0
val = eval_at(c)
if val >= target_kl:
c_hi, v_hi = c, val
break
c_lo, v_lo = c, val
else:
logger.warning(f"calibrate {v.cfg.method}: KL below target across bracket")
return sign * c, history
else:
c_hi, v_hi = mid, v_mid
c = mid
c_lo, v_lo = lo, None
while c > lo:
c /= 2.0
val = eval_at(c)
if val <= target_kl:
c_lo, v_lo = c, val
break
c_hi, v_hi = c, val
else:
logger.warning(f"calibrate {v.cfg.method}: KL above target across bracket")
return sign * c, history
stale_lo = stale_hi = 0
for _ in tqdm(range(max_iters), desc=f"calib {v.cfg.method}", mininterval=60, leave=False):
if v_lo is not None and v_hi is not None and v_lo > 0 and v_hi > 0:
log_c_lo, log_c_hi = math.log(c_lo), math.log(c_hi)
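# Illinois-style damping pulls a stale endpoint's value toward the target
# (up for lo, down for hi) so the next iterate moves toward that endpoint.
log_v_lo = math.log(v_lo) + (math.log(2) if stale_lo >= 2 else 0.0)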
log_v_hi = math.log(v_hi) - (math.log(2) if stale_hi >= 2 else 0.0)
denom = log_v_hi - log_v_lo
if abs(denom) < 1e-6:
c_new = math.sqrt(c_lo * c_hi)
else:
t = (log_target - log_v_lo) / denom
log_c_new = log_c_lo + t * (log_c_hi - log_c_lo)
c_new = math.exp(log_c_new)
if not (c_lo < c_new < c_hi):
c_new = math.sqrt(c_lo * c_hi)
else:
c_new = math.sqrt(c_lo * c_hi)
v_new = eval_at(c_new)
if abs(v_new - target_kl) < tol:
return sign * c_new, history
if v_new < target_kl:
c_lo, v_lo = c_new, v_new
stale_lo = 0
stale_hi += 1
else:
c_hi, v_hi = c_new, v_new
stale_hi = 0
stale_lo += 1
best = min(history, key=lambda h: abs(h[target_stat] - target_kl))
return best["coeff"], history
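The interpolation step used in the refinement loop can be sketched in isolation. This is a simplified illustration, not the full routine: `loglog_step` and `toy_kl` are hypothetical names, the stale-endpoint adjustment is omitted, and the power-law response `KL(c) = 0.25 * c**2` is an assumption chosen only so the result is easy to check by hand.

```python
import math


def loglog_step(c_lo: float, v_lo: float, c_hi: float, v_hi: float, target: float) -> float:
    """One regula-falsi step in log-log space, with a geometric-mean fallback."""
    denom = math.log(v_hi) - math.log(v_lo)
    if abs(denom) < 1e-6:
        # Flat response between the endpoints: fall back to geometric bisection.
        return math.sqrt(c_lo * c_hi)
    t = (math.log(target) - math.log(v_lo)) / denom
    c_new = math.exp(math.log(c_lo) + t * (math.log(c_hi) - math.log(c_lo)))
    if not (c_lo < c_new < c_hi):
        # Interpolant left the bracket: fall back to geometric bisection.
        return math.sqrt(c_lo * c_hi)
    return c_new


def toy_kl(c: float) -> float:
    """Assumed power-law response, not measured data."""
    return 0.25 * c ** 2


c_star = loglog_step(0.5, toy_kl(0.5), 8.0, toy_kl(8.0), target=1.0)
```

When the statistic follows a power law in the coefficient, it is a straight line in log-log space, so a single step lands on the target coefficient. Real KL curves are only approximately of this form, which is why the loop iterates and keeps a fallback.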
+112
@@ -0,0 +1,112 @@
"""SteeringConfig + Method protocol + registries.
Each method ships its own subclass `XC(SteeringConfig)` and `XMethod` class
under `variants/*.py` (e.g. `MeanDiffC` + `MeanDiff`). Two parallel registries
keyed by method name: `_CONFIG_REGISTRY` for `from_dict` deserialisation,
`REGISTRY` for the runtime extract/apply pair.
"""
from dataclasses import dataclass, asdict, field
from typing import Protocol, Any
import torch
from torch import Tensor
@dataclass
class SteeringConfig:
"""Base config. Subclass per method; do not instantiate directly."""
method: str = "?"
# which transformer blocks to hook (indices into model.model.layers)
# None = all layers
layers: tuple[int, ...] | None = None
# which point in the block to add at: "residual" = block output (post mlp+attn),
# "attn_out" = attention output, "mlp_out" = mlp output.
# v1 only implements "residual".
target: str = "residual"
# Optional regex selecting nn.Linear sub-modules within each target block to
# hook on (e.g. r"mlp\.down_proj"); it is matched with `re.fullmatch` against
# `block.named_modules()` paths, see `find_targets`. When None, the block's
# forward output is hooked (default for almost all variants); when set, each
# matching sub-module's forward is hooked instead. Either way the variant's
# apply receives the unified (mod, x, y, state, cfg) signature -- used by
# weight-SVD methods (sspace, sspace_ablate) that need to modify a Linear's
# output in low-rank S-space.
# steering strength at apply-time. Methods interpret it differently:
# additive (mean_diff, pca, sspace, chars, cosine_gated): coeff is α in `h + α v`.
# slerp/angle (spherical, angular_steering): coeff is the slerp t / rotation θ.
# blend (linear_act): coeff is the blend ratio.
# ablation+nudge (directional_ablation): coeff is post-ablation nudge magnitude.
coeff: float = 1.0
dtype: torch.dtype = torch.bfloat16
seed: int = 0
def to_dict(self) -> dict:
d = asdict(self)
d["dtype"] = str(self.dtype).removeprefix("torch.")
return d
@classmethod
def from_dict(cls, d: dict) -> "SteeringConfig":
d = dict(d)
name = d["method"]
sub = _CONFIG_REGISTRY[name]
d["dtype"] = getattr(torch, d["dtype"])
return sub(**d)
_CONFIG_REGISTRY: dict[str, type[SteeringConfig]] = {}
def register_config(cls: type[SteeringConfig]) -> type[SteeringConfig]:
"""Decorator: register `cls` under its `method` default value."""
name = cls.__dataclass_fields__["method"].default
if name == "?":
raise ValueError(f"{cls} must override the default `method` field")
if name in _CONFIG_REGISTRY:
raise ValueError(f"config for method {name!r} already registered")
_CONFIG_REGISTRY[name] = cls
return cls
class Method(Protocol):
"""extract+apply pair. State tensors are registered as buffers on the hooked
module (block or Linear) under `_steering_state_<key>` and rebuilt into a
dict by the hook.
"""
name: str
@staticmethod
def extract(
pos_acts: dict[int, Tensor],
neg_acts: dict[int, Tensor],
cfg: Any,
) -> dict[int, dict[str, Tensor]]:
"""Per-layer state. `pos_acts[l]` is `[n_pos, d_model]`, same for neg."""
...
@staticmethod
def apply(
mod, # the hooked module: a transformer block, or a Linear
x: Tensor, # [b, s, d_in] -- module input
y: Tensor, # [b, s, d_out] -- module output
state: dict[str, Tensor],
cfg: Any,
) -> Tensor:
"""Return the module's NEW output. Additive variants: `return y + delta`.
Replacing variants: ignore `y`, return any tensor of shape `[b, s, d_out]`.
"""
...
REGISTRY: dict[str, type] = {}
def register(cls):
if not getattr(cls, "name", None):
raise ValueError(f"method {cls} missing .name")
REGISTRY[cls.name] = cls
return cls
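The config registry and the `to_dict` / `from_dict` round trip can be illustrated with a stripped-down stand-in. Everything named `Demo*` or `_DEMO_REGISTRY` below is hypothetical and exists only for this sketch; it leaves out the `dtype` handling and the `"?"` guard of the real `register_config`.

```python
from dataclasses import asdict, dataclass

_DEMO_REGISTRY: dict[str, type] = {}


def register_demo(cls):
    """Register a dataclass under the default value of its `method` field."""
    name = cls.__dataclass_fields__["method"].default
    if name in _DEMO_REGISTRY:
        raise ValueError(f"config for method {name!r} already registered")
    _DEMO_REGISTRY[name] = cls
    return cls


@dataclass
class DemoBase:
    method: str = "?"
    coeff: float = 1.0

    def to_dict(self) -> dict:
        return asdict(self)

    @classmethod
    def from_dict(cls, d: dict) -> "DemoBase":
        # Dispatch on the stored method name, then rebuild the subclass.
        return _DEMO_REGISTRY[d["method"]](**d)


@register_demo
@dataclass
class DemoMeanDiff(DemoBase):
    method: str = "demo_mean_diff"
    normalize: bool = True


restored = DemoBase.from_dict(DemoMeanDiff(coeff=0.5).to_dict())
```

Dispatching on the `method` string is what lets a saved config be rebuilt as the right subclass without the caller knowing which variant produced it.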
+58
@@ -0,0 +1,58 @@
"""Record last non-pad-token hidden states at selected layers via forward hooks.
We hook each block's forward output (it returns `(hidden_states, ...)`), gather
the final non-padding token from `attention_mask`, and stack across prompts.
No grad.
"""
from __future__ import annotations
import torch
from torch import nn, Tensor
from jaxtyping import Float
from .target import _get_blocks
@torch.no_grad()
def record_activations(
model: nn.Module,
tok,
prompts: list[str],
layers: tuple[int, ...],
*,
batch_size: int = 8,
max_length: int = 256,
) -> dict[int, Float[Tensor, "n d"]]:
"""Run prompts through model, return last-token hidden state at each layer."""
blocks = _get_blocks(model)
device = next(model.parameters()).device
bucket: dict[int, list[Tensor]] = {l: [] for l in layers}
captured: dict[int, Tensor] = {}
def make_hook(li: int):
def hook(_mod, _args, out):
h = out[0] if isinstance(out, tuple) else out
captured[li] = h
return hook
handles = [blocks[li].register_forward_hook(make_hook(li)) for li in layers]
try:
was_training = model.training
model.eval()
for i in range(0, len(prompts), batch_size):
batch = prompts[i : i + batch_size]
enc = tok(batch, return_tensors="pt", padding=True, truncation=True, max_length=max_length).to(device)
captured.clear()
model(**enc)
mask = enc["attention_mask"]
last_idx = mask.shape[1] - 1 - mask.flip([-1]).argmax(-1)
batch_idx = torch.arange(mask.shape[0], device=last_idx.device)
for li in layers:
bucket[li].append(captured[li][batch_idx, last_idx].detach().to("cpu"))
if was_training:
model.train()
finally:
for h in handles:
h.remove()
return {li: torch.cat(bucket[li], dim=0) for li in layers}
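The index arithmetic behind `last_idx` can be checked on plain lists. This is a sketch with a hypothetical helper name, `last_nonpad_index`; it mirrors `mask.shape[1] - 1 - mask.flip([-1]).argmax(-1)` for a single row and works for both left- and right-padded batches.

```python
def last_nonpad_index(mask_row: list[int]) -> int:
    """Index of the last 1 in a 0/1 attention-mask row."""
    reversed_row = mask_row[::-1]
    # Position of the first 1 in the flipped row, i.e. argmax of the flip.
    # Unlike tensor argmax, list.index raises ValueError on an all-zero row.
    first_one_from_end = reversed_row.index(1)
    return len(mask_row) - 1 - first_one_from_end


right_padded = [1, 1, 1, 0, 0]
left_padded = [0, 0, 1, 1, 1]
idx_right = last_nonpad_index(right_padded)
idx_left = last_nonpad_index(left_padded)
```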
+148
@@ -0,0 +1,148 @@
"""Find transformer blocks (or sub-Linears) to hook.
Default: hook each block's forward output (residual stream after attn+mlp).
When `cfg.target_submodule` is set, it is interpreted as a **regex** matched
against `block.named_modules()` paths under each selected block; matching
`nn.Linear`s become the actual hook targets. This lets a single cfg target
multiple Linears per block (e.g. residual writers `mlp.down_proj` AND
`self_attn.o_proj`, or all 7 Linears in q/k/v/o/gate/up/down).
Works with HF llama-family architectures (llama, qwen, mistral, etc). For other
architectures, set `cfg.layers` to indices into whatever list lives at the path
your model exposes -- override `_get_blocks` if needed.
"""
import re
from torch import nn
def _get_blocks(model: nn.Module) -> nn.ModuleList:
# llama-family: model.model.layers
# gemma3-multimodal: model.language_model.layers (or model.model.language_model.layers)
candidates = []
inner = getattr(model, "model", model)
candidates.append(inner)
lm = getattr(inner, "language_model", None)
if lm is not None:
candidates.append(lm)
candidates.append(getattr(lm, "model", lm))
for c in candidates:
blocks = getattr(c, "layers", None)
if blocks is not None:
return blocks
raise RuntimeError(
f"could not find .layers on {type(model).__name__}; "
f"override _get_blocks for non-llama architectures"
)
def find_targets(model: nn.Module, cfg) -> list[tuple[str, nn.Module, int]]:
"""Return [(full_name, module, layer_idx)] for hook targets selected by cfg.
- `cfg.target_submodule is None`: one entry per selected block (the block itself).
- `cfg.target_submodule = <regex>`: one entry per (block, matching nn.Linear).
Regex is matched with `re.fullmatch` against `name` from `block.named_modules()`,
e.g. `r"mlp\\.down_proj|self_attn\\.o_proj"` matches both residual writers,
`r"self_attn\\.(q|k|v|o)_proj|mlp\\.(gate|up|down)_proj"` matches all 7 Linears.
"""
blocks = _get_blocks(model)
n = len(blocks)
if cfg.layers is None:
idxs = tuple(range(n))
else:
idxs = tuple(cfg.layers)
for i in idxs:
if not (0 <= i < n):
raise ValueError(f"layer {i} out of range [0, {n})")
sub = getattr(cfg, "target_submodule", None)
if sub is None:
return [(f"layers.{i}", blocks[i], i) for i in idxs]
pattern = re.compile(sub)
out = []
for i in idxs:
for name, mod in blocks[i].named_modules():
if name and pattern.fullmatch(name) and isinstance(mod, nn.Linear):
out.append((f"layers.{i}.{name}", mod, i))
if not out:
raise RuntimeError(
f"target_submodule regex {sub!r} matched no nn.Linear "
f"in {len(idxs)} selected blocks"
)
return out
def find_residual_linears(
model: nn.Module,
layer_indices: tuple[int, ...] | None = None,
*,
role: str = "both", # "writer" | "reader" | "both"
fallback_regex: str | None = None,
) -> list[tuple[str, nn.Module, int, str]]:
"""Find Linears connected to the residual stream, per block.
Returns `[(full_name, module, layer_idx, role)]` where role is "writer"
(d_out == d_model, d_in != d_model) or "reader" (d_in == d_model,
d_out != d_model). Square Linears are ambiguous (could be either) and
are excluded by shape detection.
Detection:
1. Primary: weight.shape vs d_model.
2. Fallback: if shape detection finds nothing (non-llama arch, weird
shapes), match `fallback_regex` against `named_modules` paths and
guess role from substring. Default regex covers llama-family names.
Warns when fallback fires.
"""
from loguru import logger
d_model = get_d_model(model)
blocks = _get_blocks(model)
n = len(blocks)
idxs = tuple(layer_indices) if layer_indices is not None else tuple(range(n))
out: list[tuple[str, nn.Module, int, str]] = []
for li in idxs:
for name, mod in blocks[li].named_modules():
if not isinstance(mod, nn.Linear):
continue
d_out, d_in = mod.weight.shape
is_writer = d_out == d_model and d_in != d_model
is_reader = d_in == d_model and d_out != d_model
if is_writer and role in ("writer", "both"):
out.append((f"layers.{li}.{name}", mod, li, "writer"))
elif is_reader and role in ("reader", "both"):
out.append((f"layers.{li}.{name}", mod, li, "reader"))
if out:
return out
regex = fallback_regex or r"mlp\.(down|gate|up)_proj|self_attn\.(q|k|v|o)_proj"
logger.warning(
f"shape-based residual-linear detection found nothing for d_model={d_model} "
f"in {len(idxs)} blocks; falling back to regex {regex!r}"
)
pattern = re.compile(regex)
writer_hints = ("down_proj", "o_proj")
for li in idxs:
for name, mod in blocks[li].named_modules():
if not (name and pattern.fullmatch(name) and isinstance(mod, nn.Linear)):
continue
role_guess = "writer" if any(h in name for h in writer_hints) else "reader"
if role in ("both", role_guess):
out.append((f"layers.{li}.{name}", mod, li, role_guess))
if not out:
logger.warning(
f"regex fallback {regex!r} also matched no Linears in layers {idxs}; "
"super_sspace will have an empty basis"
)
return out
def get_d_model(model: nn.Module) -> int:
cfg = getattr(model, "config", None)
d = getattr(cfg, "hidden_size", None)
if d is None:
# multimodal configs (gemma3): text sub-config
text_cfg = getattr(cfg, "text_config", None)
d = getattr(text_cfg, "hidden_size", None)
if d is None:
raise RuntimeError("model has no .config.hidden_size")
return int(d)
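The `re.fullmatch` selection in `find_targets` can be tried on a list of module paths without loading a model. The sketch below assumes llama-style path names and uses a hypothetical helper, `select_paths`; unlike the real function it does not check that the matched module is an `nn.Linear`.

```python
import re


def select_paths(paths: list[str], pattern: str) -> list[str]:
    """Keep the non-empty paths that the regex matches in full."""
    rx = re.compile(pattern)
    return [p for p in paths if p and rx.fullmatch(p)]


# Assumed llama-style names, as block.named_modules() might yield them.
paths = [
    "", "self_attn", "self_attn.q_proj", "self_attn.o_proj",
    "mlp", "mlp.gate_proj", "mlp.down_proj",
]
writers = select_paths(paths, r"mlp\.down_proj|self_attn\.o_proj")
partial = select_paths(paths, r"mlp")
```

Because the match must cover the whole path, the pattern `mlp` selects only the container module and none of its projections, so a too-short pattern fails loudly in `find_targets` rather than hooking more than intended.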
+6
@@ -0,0 +1,6 @@
"""Variant registry. Importing this package triggers @register_config + @register
side effects in the variant modules.
"""
from . import mean_diff # noqa: F401
from . import pca # noqa: F401
from . import directional_ablation # noqa: F401
@@ -0,0 +1,75 @@
"""Mean-diff directional ablation (Arditi-inspired projection-out).
Project the steering direction *out of* the residual stream instead of (or in
addition to) adding to it. Unlike `mean_diff` which translates by $\\alpha v$,
ablation removes the component of $h$ along $\\hat v$:
$$h \\leftarrow h - (h \\cdot \\hat v)\\hat v + \\alpha\\hat v$$
When `coeff=0` this is pure ablation (refusal-direction style); when `coeff!=0`
this is ablation followed by a constant nudge (useful to ablate "old" behavior
and inject "new"). The two terms are mathematically distinct -- ablation is a
*projection* (idempotent), addition is a *translation*.
Ablation removes a component of magnitude $|h \\cdot \\hat v|$, which is
informative -- a near-zero value means the direction wasn't present in the
first place, so the intervention is a no-op. Compare to `mean_diff` which
always pays a constant $\\alpha\\|\\hat v\\|$ per token regardless of whether
the direction is present.
Refs / inspiration:
- Arditi et al. 2024 "Refusal in language models is mediated by a single direction"
https://arxiv.org/abs/2406.11717
- andyrdt/refusal_direction https://github.com/andyrdt/refusal_direction
"""
from dataclasses import dataclass
from einops import einsum
from jaxtyping import Float
from torch import Tensor
from ..config import SteeringConfig, register_config, register
ε = 1e-8
@register_config
@dataclass
class DirectionalAblationC(SteeringConfig):
method: str = "directional_ablation"
coeff: float = 0.0 # post-ablation additive nudge along v_hat; 0.0 = pure ablation
@register
class DirectionalAblation:
name = "directional_ablation"
@staticmethod
def extract(
pos_acts: dict[int, Float[Tensor, "n d"]],
neg_acts: dict[int, Float[Tensor, "m d"]],
cfg: DirectionalAblationC,
) -> dict[int, dict[str, Tensor]]:
out = {}
for li in pos_acts:
v = pos_acts[li].float().mean(0) - neg_acts[li].float().mean(0)
v = v / (v.norm() + ε)
out[li] = {"v": v}
return out
@staticmethod
def apply(
mod,
x: Float[Tensor, "b s d"],
y: Float[Tensor, "b s d"],
state: dict[str, Tensor],
cfg: DirectionalAblationC,
) -> Float[Tensor, "b s d"]:
v = state["v"].to(y) # unit
proj = einsum(y, v, "b s d, d -> b s")
y = y - proj.unsqueeze(-1) * v # ablate
if cfg.coeff != 0.0:
y = y + cfg.coeff * v
return y
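The projection-then-nudge update can be verified on a two-dimensional example. This is a pure-Python sketch with hypothetical helper names (`dot`, `ablate`) and made-up numbers; it assumes `v_hat` is already unit-norm, as it is after `extract`.

```python
def dot(a: list[float], b: list[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def ablate(h: list[float], v_hat: list[float], coeff: float = 0.0) -> list[float]:
    """Remove the component of h along unit vector v_hat, then add coeff * v_hat."""
    proj = dot(h, v_hat)
    return [hi - proj * vi + coeff * vi for hi, vi in zip(h, v_hat)]


v_hat = [0.6, 0.8]          # unit vector: 0.36 + 0.64 = 1
h = [2.0, 1.0]
once = ablate(h, v_hat)      # pure ablation
twice = ablate(once, v_hat)  # ablating again should change nothing
nudged = ablate(h, v_hat, coeff=3.0)
```

After ablation the component along `v_hat` is zero, a second ablation is a no-op (idempotence), and with a nonzero `coeff` the component along `v_hat` equals `coeff` regardless of what `h` contained.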
+82
@@ -0,0 +1,82 @@
"""Mean-difference steering (CAA / ActAdd).
For each selected layer L, compute the mean difference between positive and
negative last-token hidden states:
$$v_L = \\text{mean}(h^+_L) - \\text{mean}(h^-_L), \\quad \\hat{v}_L = v_L / \\|v_L\\|$$
At runtime, add `coeff * v_hat` to every token's residual at that block:
$$h \\leftarrow h + \\alpha \\cdot \\hat{v}_L$$
This is the same operation as CAA (Panickssery 2023, contrastive MCQ pairs)
and ActAdd (Turner 2023, single prompt-pair); the differences are conventional
not mathematical, so we register one method.
`subtract_corpus_mean=True` toggles Jorgensen 2024 mean-centring: positive-group
mean minus the pooled pos+neg corpus mean. With equal-size groups this equals
$(\\bar h^+ - \\bar h^-)/2$, so after normalization the direction is identical
to plain mean_diff; kept as a flag rather than a separate method.
Refs:
- Panickssery 2023 (CAA) https://arxiv.org/abs/2312.06681
- Turner 2023 (ActAdd) https://arxiv.org/abs/2308.10248
- Jorgensen 2024 (Mean-Centring) https://arxiv.org/abs/2312.03813
- nrimsky/CAA https://github.com/nrimsky/CAA
"""
from dataclasses import dataclass
import torch
from jaxtyping import Float
from torch import Tensor
from ..config import SteeringConfig, register_config, register
ε = 1e-8
@register_config
@dataclass
class MeanDiffC(SteeringConfig):
method: str = "mean_diff"
normalize: bool = True
subtract_corpus_mean: bool = False
@register
class MeanDiff:
name = "mean_diff"
@staticmethod
def extract(
pos_acts: dict[int, Float[Tensor, "n d"]],
neg_acts: dict[int, Float[Tensor, "m d"]],
cfg: MeanDiffC,
) -> dict[int, dict[str, Tensor]]:
out = {}
for li in pos_acts:
p = pos_acts[li].float()
n = neg_acts[li].float()
if cfg.subtract_corpus_mean:
mu = torch.cat([p, n], dim=0).mean(0)
v = p.mean(0) - mu
else:
v = p.mean(0) - n.mean(0)
if cfg.normalize:
v = v / (v.norm() + ε)
out[li] = {"v": v}
return out
@staticmethod
def apply(
mod,
x: Float[Tensor, "b s d"],
y: Float[Tensor, "b s d"],
state: dict[str, Tensor],
cfg: MeanDiffC,
) -> Float[Tensor, "b s d"]:
v = state["v"].to(y)
return y + cfg.coeff * v
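The claim that mean-centring and plain mean difference give the same direction for equal-size groups can be checked numerically. The sketch below uses made-up 2-d activations and hypothetical helpers (`mean`, `unit`); it is an illustration of the algebra, not the extraction code.

```python
import math


def mean(rows: list[list[float]]) -> list[float]:
    n = len(rows)
    return [sum(col) / n for col in zip(*rows)]


def unit(v: list[float]) -> list[float]:
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


# Made-up activations, two prompts per group (equal sizes matter here).
pos = [[1.0, 2.0], [3.0, 0.0]]
neg = [[0.0, 1.0], [0.0, -1.0]]

p_bar, n_bar = mean(pos), mean(neg)
plain = [p - n for p, n in zip(p_bar, n_bar)]   # mean(h+) - mean(h-)
mu = mean(pos + neg)                            # pooled corpus mean
centred = [p - m for p, m in zip(p_bar, mu)]    # mean(h+) - corpus mean
```

Here `centred` is exactly half of `plain`, so the two agree after normalization. With unequal group sizes the pooled mean is no longer the midpoint and the lengths differ by another factor, which only matters when `normalize=False`.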
+101
@@ -0,0 +1,101 @@
"""PCA steering (RepE/LAT-inspired, vgel pca_diff-like).
For each layer L, compute PCA on the **paired differences** `h^+ - h^-`. Take
the top principal component as the steering direction.
$$D_L = H^+_L - H^-_L \\in \\mathbb{R}^{n\\times d}$$
$$U, S, V^T = \\text{SVD}(D_L - \\bar{D}_L)$$
$$\\text{sign}_L = \\text{sign}\\left(\\sum_i \\mathbb{1}[(D_L)_i \\cdot V_{:,0} > 0] - n/2\\right)$$
$$v_L = V_{:,0} \\cdot \\text{sign}_L$$
Sign-fixed by majority vote of paired-diff projections (repeng/vgel style).
This is a lightweight control-vector baseline, not the full Zou et al. LAT
reader: it omits per-diff normalization, label-based sign selection, and
train-mean recentering for reading scores.
PCA is sign-ambiguous; the vote is more robust than alignment-with-the-mean
when paired diffs are heterogeneous (mean can cancel without the vote
changing). If the vote ties exactly, orient the axis so the largest centered
projection is positive.
At runtime, add `coeff * v_L` to the residual.
Refs:
- Zou et al. 2023 (Representation Engineering) https://arxiv.org/abs/2310.01405
- vgel/repeng: https://github.com/vgel/repeng
"""
from dataclasses import dataclass
import torch
from jaxtyping import Float
from torch import Tensor
from ..config import SteeringConfig, register_config, register
ε = 1e-8
@register_config
@dataclass
class PCAC(SteeringConfig):
method: str = "pca"
n_components: int = 1
normalize: bool = True
@register
class PCA:
name = "pca"
@staticmethod
def extract(
pos_acts: dict[int, Float[Tensor, "n d"]],
neg_acts: dict[int, Float[Tensor, "n d"]],
cfg: PCAC,
) -> dict[int, dict[str, Tensor]]:
out = {}
for li in pos_acts:
if pos_acts[li].shape[0] != neg_acts[li].shape[0]:
raise ValueError(f"layer {li}: pos/neg counts differ")
diffs = (pos_acts[li] - neg_acts[li]).float()
centered = diffs - diffs.mean(0, keepdim=True)
_, _, Vh = torch.linalg.svd(centered, full_matrices=False)
v = Vh[: cfg.n_components]
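# Vote on the raw paired diffs, matching the docstring formula; the centered
# projections are kept only for the exact-tie fallback below.
raw_projs = diffs @ v.T
projs = centered @ v.T
positive_frac = (raw_projs > 0).float().mean(0)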
majority_sign = torch.where(positive_frac > 0.5,
torch.ones(v.shape[0]),
-torch.ones(v.shape[0])).to(v)
strongest_idx = projs.abs().argmax(dim=0)
strongest = projs[strongest_idx, torch.arange(v.shape[0], device=projs.device)]
strongest_sign = torch.sign(strongest)
sign = torch.where(positive_frac == 0.5, strongest_sign, majority_sign)
v = v * sign[:, None]
if cfg.n_components == 1:
v = v.squeeze(0)
if cfg.normalize:
v = v / (v.norm() + ε)
out[li] = {"v": v}
else:
if cfg.normalize:
v = v / (v.norm(dim=1, keepdim=True) + ε)
out[li] = {"V": v}
return out
@staticmethod
def apply(
mod,
x: Float[Tensor, "b s d"],
y: Float[Tensor, "b s d"],
state: dict[str, Tensor],
cfg: PCAC,
) -> Float[Tensor, "b s d"]:
if "v" in state:
v = state["v"].to(y)
return y + cfg.coeff * v
# multi-component: sum top-k directions equally
V = state["V"].to(y)
return y + cfg.coeff * V.sum(0)
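The sign-fixing rule for a single component can be sketched on a plain list of projections. `vote_sign` is a hypothetical helper written for illustration; it follows the majority-vote rule with the strongest-projection tie-break described in the module docstring, and simplifies the degenerate all-zero case to return -1.0.

```python
def vote_sign(projections: list[float]) -> float:
    """Majority vote on projection signs; ties go to the largest-magnitude projection."""
    n = len(projections)
    positive = sum(1 for p in projections if p > 0)
    if 2 * positive > n:
        return 1.0
    if 2 * positive < n:
        return -1.0
    strongest = max(projections, key=abs)
    return 1.0 if strongest > 0 else -1.0


# Two small positive projections outvote one large negative one,
# even though the mean projection is negative.
majority = vote_sign([0.2, 0.1, -5.0])
# Exact tie: the strongest projection decides.
tie = vote_sign([1.0, -3.0])
```

The first example shows why a vote can disagree with mean alignment when the paired differences are heterogeneous: a single outlier dominates the mean but counts as only one vote.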
+114
@@ -0,0 +1,114 @@
"""Vector: extracted steering vector + config, as a single ergonomic object.
Wraps `(cfg, state)` so a user can:
v = Vector.train(model, tok, pos, neg, sl.MeanDiffC(layers=(15,))) \\
.calibrate(model, tok, target_kl=1.0)
with v(model):
out = model.generate(...)
v.save("honesty.safetensors")
v2 = Vector.load("honesty.safetensors")
combined = v + v2 # ensemble (sum buffers, requires same cfg.method)
scaled = v * 0.5 # scale buffers
"""
from __future__ import annotations
from contextlib import contextmanager
from copy import deepcopy
from dataclasses import replace
import torch
from torch import Tensor, nn
from .config import SteeringConfig
class Vector:
def __init__(self, cfg: SteeringConfig, state: dict[int, dict[str, Tensor]]):
self.cfg = cfg
self.state = state
@classmethod
def train(cls, model: nn.Module, tok, pos_prompts: list[str], neg_prompts: list[str],
cfg: SteeringConfig, **kw) -> "Vector":
"""Extract a steering vector from contrastive prompts. Chains with .calibrate()."""
from .attach import train as _train
return _train(model, tok, pos_prompts, neg_prompts, cfg, **kw)
def calibrate(self, model: nn.Module, tok,
prompts: list[str] | list[Tensor] | None = None, *,
target_kl: float = 1.0, **kw) -> "Vector":
"""Set coeff so the calibration statistic of KL(steered || base) (per-token
p95 by default) hits target_kl. Mutates and returns self for chaining.
`prompts` defaults to a small generic set; pass list[str] to use your own.
"""
from .calibrate import calibrate_iso_kl
coeff, _ = calibrate_iso_kl(self, model, tok, prompts, target_kl=target_kl, **kw)
self.cfg.coeff = float(coeff)
return self
@contextmanager
def __call__(self, model: nn.Module, *, C: float | None = None):
"""Attach for the duration of the `with` block. `C` overrides cfg.coeff if given."""
from .attach import attach, detach
cfg = self.cfg if C is None else replace(self.cfg, coeff=float(C))
attach(model, cfg, self.state)
try:
yield model
finally:
detach(model)
def __add__(self, other: "Vector") -> "Vector":
if self.cfg.method != other.cfg.method:
raise ValueError(f"cannot add {self.cfg.method!r} + {other.cfg.method!r}")
if sorted(self.state) != sorted(other.state):
raise ValueError(f"layer mismatch: {sorted(self.state)} vs {sorted(other.state)}")
new_state: dict[int, dict[str, Tensor]] = {}
for li in self.state:
a, b = self.state[li], other.state[li]
if sorted(a) != sorted(b):
raise ValueError(f"layer {li}: state keys differ {sorted(a)} vs {sorted(b)}")
new_state[li] = {k: a[k] + b[k] for k in a}
return Vector(deepcopy(self.cfg), new_state)
def __mul__(self, k: float) -> "Vector":
new_state = {
li: {k_: v * float(k) for k_, v in s.items()}
for li, s in self.state.items()
}
return Vector(deepcopy(self.cfg), new_state)
__rmul__ = __mul__
def save(self, path: str) -> None:
from .attach import _STATE_PREFIX, _SUB_KEY_PREFIX, _SUB_KEY_SEP # noqa: F401
import json
from safetensors.torch import save_file
sd: dict[str, Tensor] = {}
sub_mode = self.cfg.target_submodule is not None
for key, s in self.state.items():
for k, t in s.items():
if sub_mode:
sd[f"{_SUB_KEY_PREFIX}{key}{_SUB_KEY_SEP}{k}"] = t.detach().cpu()
else:
sd[f"layer{key}.{k}"] = t.detach().cpu()
metadata = {"cfg": json.dumps(self.cfg.to_dict())}
save_file(sd, path, metadata=metadata)
@classmethod
def load(cls, path: str) -> "Vector":
import json
from safetensors.torch import load_file, safe_open
from .attach import _safetensors_dict_to_state
with safe_open(path, framework="pt", device="cpu") as f:
metadata = f.metadata()
sd = load_file(path, device="cpu")
cfg = SteeringConfig.from_dict(json.loads(metadata["cfg"]))
state = _safetensors_dict_to_state(sd)
return cls(cfg, state)
def __repr__(self) -> str:
layers = sorted(self.state)
return f"Vector(method={self.cfg.method!r}, layers={layers})"
+65
@@ -0,0 +1,65 @@
"""Smoke test: train + measure_kl + branch_pmass on a tiny random model.
Pass = runtime sanity. Distinguishing checks:
- measure_kl returns kl > 0 at coeff > 0 (steer DID change distribution).
- measure_kl returns kl ~= 0 at coeff = 0 (silent failure detector: if hooks
leak, base==steer KL would be nonzero).
- branch_pmass is in [0, 1].
- branch_pmass changes between coeff=0 and coeff=large (sneaky-fail catch:
if pmass is just identity-pass-through it would be invariant).
"""
from __future__ import annotations
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from iso_kl_figure import (
SteeringConfig, MeanDiffC, train, measure_kl, attach, detach,
)
from iso_kl_figure.branch_pmass import branch_pmass
MODEL = "hf-internal-testing/tiny-random-LlamaForCausalLM"
def test_pipeline_smoke():
tok = AutoTokenizer.from_pretrained(MODEL)
if tok.pad_token_id is None:
tok.pad_token_id = tok.eos_token_id
model = AutoModelForCausalLM.from_pretrained(MODEL, torch_dtype=torch.float32)
model.eval()
n_layers = model.config.num_hidden_layers
cfg = MeanDiffC(coeff=1.0, layers=(n_layers // 2,))
pos = ["Sure: <answer>", "Yes: <answer>", "Of course: <answer>", "Here: <answer>"]
neg = ["No way.", "I refuse.", "Cannot help.", "Decline."]
v = train(model, tok, pos, neg, cfg, batch_size=2, max_length=32)
# KL must be ~0 at coeff=0 (no leak), and > 0 at large coeff
v.cfg.coeff = 0.0
m0 = measure_kl(v, model, tok, ["hello world"], T=4, device="cpu")
assert m0["kl_p95"] < 1e-3, f"coeff=0 should give zero KL, got {m0['kl_p95']}"
v.cfg.coeff = 5.0
m1 = measure_kl(v, model, tok, ["hello world"], T=4, device="cpu")
assert m1["kl_p95"] > 0.0, "coeff>0 should give nonzero KL"
# per_t arrays length matches T
assert len(m1["per_t_p95"]) == 4
assert len(m1["per_t_max"]) == 4
# branch_pmass: in [0, 1] and varies with coeff
pids = tok("hello", return_tensors="pt").input_ids[0]
rolled = pids[-2:].clone()
suffix = ' {"value": '
fork = [0, 1, 2]
v.cfg.coeff = 0.0
p_zero = branch_pmass(v, model, tok, pids, rolled, fork, suffix,
["true", "false"], device="cpu")
v.cfg.coeff = 5.0
p_steer = branch_pmass(v, model, tok, pids, rolled, fork, suffix,
["true", "false"], device="cpu")
for x in p_zero["pmass"] + p_steer["pmass"]:
assert 0.0 <= x <= 1.0, f"pmass out of [0,1]: {x}"
diff = max(abs(a - b) for a, b in zip(p_zero["pmass"], p_steer["pmass"]))
assert diff > 1e-6, "pmass invariant to coeff -- hook is dead"