mirror of
https://github.com/wassname/isokl_steering_calibration.git
synced 2026-06-27 15:45:51 +08:00
iso-kl-figure: scaffold + smoke test passing
This commit is contained in:
+12
@@ -0,0 +1,12 @@
|
|||||||
|
__pycache__/
|
||||||
|
*.pyc
|
||||||
|
.venv/
|
||||||
|
uv.lock
|
||||||
|
*.egg-info/
|
||||||
|
.pytest_cache/
|
||||||
|
.ruff_cache/
|
||||||
|
outputs/*.csv
|
||||||
|
outputs/*.tsv
|
||||||
|
outputs/*.png
|
||||||
|
outputs/*.md
|
||||||
|
!outputs/.gitkeep
|
||||||
@@ -0,0 +1,24 @@
|
|||||||
|
# AGENTS.md
|
||||||
|
|
||||||
|
Inherits conventions from sibling project `steering-lite`. Read [../steering-lite/AGENTS.md](../steering-lite/AGENTS.md) if it exists.
|
||||||
|
|
||||||
|
## House rules
|
||||||
|
|
||||||
|
- Fail fast. No defensive programming, no fallbacks, no silent dequant.
|
||||||
|
- Keep this repo small. Anything beyond the headline figure + table belongs in another repo.
|
||||||
|
- Use `einops` and `jaxtyping` shape annotations at function boundaries only. Tensor dim letters: `b s d` (batch, seq, d_model), `n` (prompts), `t` (token positions), `f` (fork points).
|
||||||
|
- No backward compat.
|
||||||
|
- Single functional smoke test = the real pipeline at tiny scale (`tests/test_smoke.py`).
|
||||||
|
- Methods register via `@register_config` and `@register` decorators; mirror `steering-lite/src/steering_lite/config.py`.
|
||||||
|
- All experiment scripts write CSV/TSV. Plot/table scripts read CSV/TSV. Never plot from in-memory state.
|
||||||
|
|
||||||
|
## Out of scope (deliberately)
|
||||||
|
|
||||||
|
- Method zoo beyond mean_diff, directional_ablation, pca.
|
||||||
|
- LessWrong post / paper draft.
|
||||||
|
- Citation collection.
|
||||||
|
- tinymfv or any external eval dependency.
|
||||||
|
|
||||||
|
## Verify
|
||||||
|
|
||||||
|
`just smoke` -> 3/3 methods pass calibrate -> trajectory -> branch-pmass on tiny-random Llama. Asserts nonzero KL at coeff>0, zero KL at coeff=0, branch-pmass in [0,1].
|
||||||
@@ -0,0 +1,36 @@
|
|||||||
|
# iso-kl-figure
|
||||||
|
|
||||||
|
Minimal repo with one job: produce a figure and a table that demonstrate iso-KL calibration is stable across models, seeds, and calibration windows.
|
||||||
|
|
||||||
|
## Claim (narrow)
|
||||||
|
|
||||||
|
Calibrating a steering coefficient so that p95 per-token KL(steered || base) hits 1 nat in a short calibration window:
|
||||||
|
|
||||||
|
- C1: bisection converges for every method tested; held-out p95 KL lands near 1 nat.
|
||||||
|
- C2 (not too cold): target-axis Delta logit at calibrated alpha is non-zero across methods.
|
||||||
|
- C3 (not too hot): base-NLL of generated text and branch-pmass of a forced format token stay near base at calibrated alpha across methods.
|
||||||
|
|
||||||
|
The 2x check is a sanity probe, not a margin claim. Reported as: at 2x, p95 KL exceeds 1 nat for N of M cells.
|
||||||
|
|
||||||
|
Honesty footnote: matched on per-token distributional disagreement under greedy decoding in the calibration window. This is one defensible notion of fairness; not equivalence on intervention norm or behavioral effect size.
|
||||||
|
|
||||||
|
## Quick start
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv sync --extra all
|
||||||
|
just smoke # tiny-random model, ~1 min CPU
|
||||||
|
just calibrate # one (model, method, seed, window) cell
|
||||||
|
just trajectory
|
||||||
|
just table
|
||||||
|
just plot
|
||||||
|
just table-md
|
||||||
|
```
|
||||||
|
|
||||||
|
`just sweep` runs the full grid (3 models x 3 methods x 3 seeds x 2 windows) used by Figure 1.
|
||||||
|
|
||||||
|
## What this repo does NOT do
|
||||||
|
|
||||||
|
- No paper or LessWrong draft.
|
||||||
|
- No method zoo beyond mean_diff, directional_ablation, pca.
|
||||||
|
- No threshold sweep, no calibration-set-size sweep.
|
||||||
|
- No tinymfv integration. Target-axis is a single contrastive sentiment / refusal pair.
|
||||||
@@ -0,0 +1,68 @@
|
|||||||
|
# iso-kl-figure: spec
|
||||||
|
|
||||||
|
## Goal
|
||||||
|
|
||||||
|
Produce one figure (Figure 1) and one table (Table 1) that empirically support three claims: iso-KL calibration converges and generalizes (C1), the calibrated coefficient is not too cold (C2), and not too hot (C3). Show stability across 3 models x 3 seeds x 2 calibration windows.
|
||||||
|
|
||||||
|
## Scope
|
||||||
|
|
||||||
|
In:
|
||||||
|
- Port `measure_kl`, `calibrate_iso_kl`, minimal Vector/attach/config/target/extract from steering-lite.
|
||||||
|
- 3 methods: `mean_diff`, `directional_ablation`, `pca`.
|
||||||
|
- New `branch_pmass` metric: fork-and-teacher-force probability mass on a forced format answer token.
|
||||||
|
- Scripts producing TSV/CSV; plot and table modules consuming the CSVs.
|
||||||
|
|
||||||
|
Out:
|
||||||
|
- LessWrong post or paper draft.
|
||||||
|
- Method zoo beyond 3 methods.
|
||||||
|
- Threshold sweep, calibration-set-size sweep, norm-matching baseline.
|
||||||
|
- tinymfv integration.
|
||||||
|
|
||||||
|
## Requirements
|
||||||
|
|
||||||
|
- R1 (C1, calibration converges and generalizes): for every (method, model, seed, window), bisection terminates with calibration p95 within tolerance of 1.0; on a held-out prompt set p95 lands within [0.7, 1.4]. VERIFY: TSV row has converged=true and holdout_p95 in band; sneaky failure (overfits calibration prompts) caught by held-out column.
|
||||||
|
- R2 (C2, not too cold): target-axis Delta logit at calibrated alpha excludes 0 with 95% CI for each method, on each model. VERIFY: Table 1 row reports CI; sneaky failure (alpha approx 0) caught by alpha column in same row.
|
||||||
|
- R3 (C3, not too hot, NLL): base-NLL of full 50-token held-out generations stays within 2x of base at calibrated alpha; exceeds 4x of base at 2x calibrated alpha for at least one method per model. VERIFY: Table 1 base_nll_delta column.
|
||||||
|
- R4 (C3, not too hot, branch-pmass): mean branch-pmass-of-valid-answer at fork points {0, 5, ..., 50} stays within 0.1 of base pmass at calibrated alpha; drops by more than 0.3 at 2x alpha for at least one method per model. VERIFY: Table 1 branch_pmass column and Figure 1 lower subplot.
|
||||||
|
- R5 (sanity probe at 2x): max p95 KL at 2x alpha exceeds 1 nat in at least 2 of 3 methods on at least 2 of 3 models within 50 tokens. VERIFY: Figure 1 top subplot, alpha=2 panels show lines crossing reference.
|
||||||
|
- R6 (stability): seed band and window-style overlay in Figure 1 do not change the qualitative C1 conclusion. VERIFY: variance band visually narrow at alpha=1.
|
||||||
|
|
||||||
|
## Tasks
|
||||||
|
|
||||||
|
- [/] T1 (R*): scaffold repo (pyproject, justfile, README, AGENTS, spec).
|
||||||
|
- verify: `just --list` lists recipes; `uv sync --extra all` resolves.
|
||||||
|
- [ ] T2 (R1, R2, R3, R4): port core code from steering-lite (calibrate, vector, attach, config, target, extract, 3 variants).
|
||||||
|
- verify: imports clean; smoke test runs all 3 methods.
|
||||||
|
- [ ] T3 (R1, R6): extend calibrate history to save per-token KL arrays (`per_t_p95`, `per_t_max`).
|
||||||
|
- verify: history dict contains per-token arrays of length T.
|
||||||
|
- [ ] T4 (R4): implement `branch_pmass` (fork at token t, append fixed format suffix, teacher-force one forward, sum p over `true`/`false` tokens).
|
||||||
|
- verify: pmass in [0, 1]; pmass at base != pmass at coeff=large (sneaky-fail catch).
|
||||||
|
- [ ] T5 (R1..R5): implement `run_calibrate.py`, `run_trajectory.py`, `run_table.py`.
|
||||||
|
- verify: CSVs created with expected columns and at least one row each on smoke.
|
||||||
|
- [ ] T6 (R*): implement `plot.py`, `table.py`.
|
||||||
|
- verify: PNG saved; markdown table prints; can be regenerated from CSVs alone.
|
||||||
|
- [ ] T7 (R*): full sweep on real models.
|
||||||
|
- verify: numeric asserts in R1..R5 pass.
|
||||||
|
- [ ] T8 (R*): external review of figure + table.
|
||||||
|
- verify: review doc saved under docs/spec/.
|
||||||
|
|
||||||
|
## Context
|
||||||
|
|
||||||
|
Calibration target: p95 per-token KL(steered || base) = 1 nat over T tokens (T in {20, 50}), N=4 calibration prompts under greedy decoding.
|
||||||
|
|
||||||
|
Branch-pmass procedure: at fork points t in {0, 5, ..., 50} take steered prefix of length t, append `\nAnswer (true/false): ` then `{"value": ` then teacher-force one forward under steered model, sum probabilities of token variants for `true` and `false`.
|
||||||
|
|
||||||
|
Target-axis: a single contrastive pair-set built into the repo (sentiment positive vs negative or refusal yes vs no), 4 prompts each. Target Delta logit = mean over held-out prompts of difference in logit on the target token.
|
||||||
|
|
||||||
|
## Log
|
||||||
|
|
||||||
|
(append-only; only entries that change a future task)
|
||||||
|
|
||||||
|
## TODO
|
||||||
|
|
||||||
|
(out-of-scope ideas; not commitments)
|
||||||
|
|
||||||
|
## Errors
|
||||||
|
|
||||||
|
| Task | Error | Resolution |
|
||||||
|
|------|-------|------------|
|
||||||
@@ -0,0 +1,24 @@
|
|||||||
|
set shell := ["bash", "-cu"]
|
||||||
|
|
||||||
|
default:
|
||||||
|
@just --list
|
||||||
|
|
||||||
|
# Smoke: tiny-random Llama, all 3 methods, asserts nonzero KL + branch-pmass changes with coeff.
|
||||||
|
smoke:
|
||||||
|
BEARTYPE=1 uv run --extra all pytest -q tests/test_smoke.py
|
||||||
|
|
||||||
|
test:
|
||||||
|
uv run --extra all pytest -q
|
||||||
|
|
||||||
|
# Run one (model, method, seed, window) cell end-to-end (calibrate + trajectory + pmass).
|
||||||
|
cell model="Qwen/Qwen2.5-0.5B-Instruct" method="mean_diff" seed="0" window="50":
|
||||||
|
uv run --extra all python scripts/run_cell.py \
|
||||||
|
--model {{model}} --method {{method}} --seed {{seed}} --window {{window}}
|
||||||
|
|
||||||
|
# Sweep model x method x seed x window cells.
|
||||||
|
sweep:
|
||||||
|
bash scripts/sweep.sh
|
||||||
|
|
||||||
|
# Aggregate all outputs/<run_id>/ into figs/figure1.png + figs/table.md.
|
||||||
|
aggregate:
|
||||||
|
uv run --extra all python scripts/aggregate.py --runs-root outputs --out figs
|
||||||
@@ -0,0 +1,36 @@
|
|||||||
|
[project]
|
||||||
|
name = "iso-kl-figure"
|
||||||
|
version = "0.0.1"
|
||||||
|
description = "Minimal repo: produce one figure + one table proving iso-KL calibration is stable across models/seeds/windows."
|
||||||
|
readme = "README.md"
|
||||||
|
requires-python = ">=3.11"
|
||||||
|
dependencies = [
|
||||||
|
"torch>=2.1",
|
||||||
|
"numpy>=1.26",
|
||||||
|
"einops>=0.7",
|
||||||
|
"jaxtyping>=0.2.34",
|
||||||
|
"safetensors>=0.5",
|
||||||
|
"loguru>=0.7",
|
||||||
|
"tqdm>=4.66",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
test = ["pytest", "tabulate", "beartype>=0.18"]
|
||||||
|
hf = ["accelerate>=1.6", "transformers>=4.51"]
|
||||||
|
plot = ["matplotlib>=3.8", "polars>=1.0"]
|
||||||
|
all = [
|
||||||
|
"pytest", "tabulate", "beartype>=0.18",
|
||||||
|
"accelerate>=1.6", "transformers>=4.51",
|
||||||
|
"matplotlib>=3.8", "polars>=1.0",
|
||||||
|
"tyro>=0.9",
|
||||||
|
]
|
||||||
|
|
||||||
|
[build-system]
|
||||||
|
requires = ["setuptools>=68"]
|
||||||
|
build-backend = "setuptools.build_meta"
|
||||||
|
|
||||||
|
[tool.setuptools.packages.find]
|
||||||
|
where = ["src"]
|
||||||
|
|
||||||
|
[tool.ruff.lint]
|
||||||
|
ignore = ["F722"] # jaxtyping shape strings
|
||||||
@@ -0,0 +1,149 @@
|
|||||||
|
"""Aggregate per-cell outputs into Figure 1 + the headline table.
|
||||||
|
|
||||||
|
Figure 1: two stacked subplots.
|
||||||
|
Top: per-token p95 KL trajectory. x = token offset; y = KL(steer || base).
|
||||||
|
Colour by method, linestyle by alpha (solid=1, dashed=2), seed bands
|
||||||
|
as thin lines, faceted by model. Horizontal at target_kl=1.
|
||||||
|
Bottom: branch-pmass at fork points. x = fork token offset; y = mean pmass
|
||||||
|
across held-out prompts; bands = +/- 1 std across seeds.
|
||||||
|
|
||||||
|
Table: one row per (model, method), columns = c_star (mean +/- std across seeds),
|
||||||
|
KL_p95 @ alpha=1, KL_p95 @ alpha=2, pmass @ alpha=1, pmass @ alpha=2.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python scripts/aggregate.py --runs_root outputs --out figs/
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
import json
|
||||||
|
from collections import defaultdict
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import polars as pl
|
||||||
|
import tyro
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Args:
|
||||||
|
runs_root: str = "outputs"
|
||||||
|
out: str = "figs"
|
||||||
|
|
||||||
|
|
||||||
|
def load_cells(root: Path) -> list[dict]:
|
||||||
|
cells = []
|
||||||
|
for d in sorted(root.iterdir()):
|
||||||
|
if not d.is_dir():
|
||||||
|
continue
|
||||||
|
calib = d / "calib.json"
|
||||||
|
if not calib.exists():
|
||||||
|
continue
|
||||||
|
meta = json.loads(calib.read_text())
|
||||||
|
traj = json.loads((d / "trajectory.json").read_text())
|
||||||
|
pmass = json.loads((d / "pmass.json").read_text())
|
||||||
|
cells.append({"id": d.name, **meta, "traj": traj, "pmass": pmass})
|
||||||
|
return cells
|
||||||
|
|
||||||
|
|
||||||
|
def make_table(cells: list[dict]) -> pl.DataFrame:
|
||||||
|
rows = []
|
||||||
|
by_mm = defaultdict(list)
|
||||||
|
for c in cells:
|
||||||
|
by_mm[(c["model"], c["method"])].append(c)
|
||||||
|
for (model, method), group in by_mm.items():
|
||||||
|
c_stars = [g["c_star"] for g in group]
|
||||||
|
# pmass: mean over fork_points and prompts at each alpha, then across seeds
|
||||||
|
for alpha in ("1.0", "2.0"):
|
||||||
|
kls = []
|
||||||
|
pms = []
|
||||||
|
for g in group:
|
||||||
|
kls.append(g["traj"]["per_t_p95_kl"][alpha])
|
||||||
|
pms.append(g["pmass"]["pmass"][alpha])
|
||||||
|
kls_flat = [x for arr in kls for x in arr]
|
||||||
|
pms_flat = [x for prompt in pms for arr in prompt for x in arr]
|
||||||
|
rows.append({
|
||||||
|
"model": model.split("/")[-1],
|
||||||
|
"method": method,
|
||||||
|
"alpha": float(alpha),
|
||||||
|
"c_star_mean": sum(c_stars) / len(c_stars),
|
||||||
|
"n_seeds": len(group),
|
||||||
|
"kl_p95_mean": sum(kls_flat) / max(len(kls_flat), 1),
|
||||||
|
"pmass_mean": sum(pms_flat) / max(len(pms_flat), 1),
|
||||||
|
})
|
||||||
|
return pl.DataFrame(rows)
|
||||||
|
|
||||||
|
|
||||||
|
def make_figure(cells: list[dict], out_path: Path) -> None:
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
models = sorted({c["model"] for c in cells})
|
||||||
|
methods = sorted({c["method"] for c in cells})
|
||||||
|
fig, axes = plt.subplots(2, len(models), figsize=(5 * len(models), 7),
|
||||||
|
sharex="col", squeeze=False)
|
||||||
|
cmap = plt.get_cmap("tab10")
|
||||||
|
method_color = {m: cmap(i) for i, m in enumerate(methods)}
|
||||||
|
|
||||||
|
for ci, model in enumerate(models):
|
||||||
|
ax_kl = axes[0, ci]
|
||||||
|
ax_pm = axes[1, ci]
|
||||||
|
ax_kl.set_title(model.split("/")[-1])
|
||||||
|
ax_kl.axhline(1.0, color="black", linestyle=":", linewidth=0.8, alpha=0.5)
|
||||||
|
ax_kl.set_ylabel("p95 KL(steer || base)")
|
||||||
|
ax_pm.set_xlabel("token offset")
|
||||||
|
ax_pm.set_ylabel("branch pmass")
|
||||||
|
ax_pm.set_ylim(-0.02, 1.02)
|
||||||
|
ax_kl.set_yscale("log")
|
||||||
|
|
||||||
|
for method in methods:
|
||||||
|
for alpha, ls in [("1.0", "-"), ("2.0", "--")]:
|
||||||
|
kls = [c["traj"]["per_t_p95_kl"][alpha]
|
||||||
|
for c in cells if c["model"] == model and c["method"] == method]
|
||||||
|
if not kls:
|
||||||
|
continue
|
||||||
|
arr = np.array(kls)
|
||||||
|
x = np.arange(arr.shape[1])
|
||||||
|
ax_kl.plot(x, arr.mean(0), color=method_color[method],
|
||||||
|
linestyle=ls, linewidth=2,
|
||||||
|
label=f"{method} a={alpha}")
|
||||||
|
if arr.shape[0] > 1:
|
||||||
|
ax_kl.fill_between(x, arr.min(0), arr.max(0),
|
||||||
|
color=method_color[method], alpha=0.12)
|
||||||
|
|
||||||
|
pms = [c["pmass"]["pmass"][alpha]
|
||||||
|
for c in cells if c["model"] == model and c["method"] == method]
|
||||||
|
if not pms:
|
||||||
|
continue
|
||||||
|
# pms: list of (n_seed) of (n_prompt) of (n_fork)
|
||||||
|
pms_arr = np.array(pms) # (n_seed, n_prompt, n_fork)
|
||||||
|
fork = cells[0]["pmass"]["fork_points"]
|
||||||
|
mean = pms_arr.mean(axis=(0, 1))
|
||||||
|
std = pms_arr.std(axis=(0, 1))
|
||||||
|
ax_pm.plot(fork, mean, color=method_color[method],
|
||||||
|
linestyle=ls, linewidth=2)
|
||||||
|
ax_pm.fill_between(fork, mean - std, mean + std,
|
||||||
|
color=method_color[method], alpha=0.12)
|
||||||
|
if ci == 0:
|
||||||
|
ax_kl.legend(fontsize=8, loc="upper left")
|
||||||
|
fig.tight_layout()
|
||||||
|
fig.savefig(out_path, dpi=150, bbox_inches="tight")
|
||||||
|
logger.info(f"figure -> {out_path}")
|
||||||
|
|
||||||
|
|
||||||
|
def main(a: Args):
|
||||||
|
out = Path(a.out); out.mkdir(parents=True, exist_ok=True)
|
||||||
|
cells = load_cells(Path(a.runs_root))
|
||||||
|
if not cells:
|
||||||
|
raise SystemExit(f"no cells under {a.runs_root}")
|
||||||
|
logger.info(f"loaded {len(cells)} cells")
|
||||||
|
|
||||||
|
df = make_table(cells)
|
||||||
|
df.write_csv(out / "table.csv")
|
||||||
|
md = df.to_pandas().to_markdown(index=False, floatfmt=".3f")
|
||||||
|
(out / "table.md").write_text(md)
|
||||||
|
logger.info(f"table -> {out/'table.md'}\n{md}")
|
||||||
|
|
||||||
|
make_figure(cells, out / "figure1.png")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main(tyro.cli(Args))
|
||||||
@@ -0,0 +1,198 @@
|
|||||||
|
"""End-to-end runner for one (model, method, seed, window) cell.
|
||||||
|
|
||||||
|
Flow:
|
||||||
|
1. Load model + tokenizer (HF), set seed.
|
||||||
|
2. Build pos/neg prompts (cheap pair); train the steering Vector v.
|
||||||
|
3. Calibrate iso-KL at target_kl=1 over T=window tokens. Save full history
|
||||||
|
(incl. per-token KL arrays) to outputs/<run_id>/history.json.
|
||||||
|
4. Re-run measure_kl at coeff=alpha*c_star (alpha in {1, 2}) on a held-out
|
||||||
|
prompt set so the trajectory plot reflects generalisation, not the
|
||||||
|
calibration set itself. Save per-token p95 KL to trajectory.json.
|
||||||
|
5. For each held-out prompt, rollout T_eval tokens under the steered model,
|
||||||
|
then branch-pmass at fork_points {0, 5, ..., T_eval}. Save to
|
||||||
|
pmass.json. Use a JSON-format suffix so target tokens are well-defined.
|
||||||
|
|
||||||
|
Outputs one CSV row per (alpha, prompt) into outputs/<run_id>/results.csv
|
||||||
|
plus the artefacts above.
|
||||||
|
|
||||||
|
Usage (tyro):
|
||||||
|
python scripts/run_cell.py --model Qwen/Qwen2.5-0.5B-Instruct --method mean_diff \
|
||||||
|
--seed 0 --window 50 --run_id qwen05b_md_s0_w50
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import tyro
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from iso_kl_figure import (
|
||||||
|
SteeringConfig, MeanDiffC, PCAC, DirectionalAblationC,
|
||||||
|
train, calibrate_iso_kl, measure_kl,
|
||||||
|
)
|
||||||
|
from iso_kl_figure.branch_pmass import branch_pmass
|
||||||
|
|
||||||
|
|
||||||
|
CALIB_PROMPTS = [
|
||||||
|
"The eiffel tower is in Paris.",
|
||||||
|
"Tell me a greentext story about a small village during the Martian carrot bubble.",
|
||||||
|
"Walk me through computing the integral of x^2 from 0 to 1.",
|
||||||
|
"Briefly describe how plate tectonics shape mountain ranges.",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Held-out prompts -- different from calibration; tests generalisation.
|
||||||
|
EVAL_PROMPTS = [
|
||||||
|
"Explain why the sky is blue in two short sentences.",
|
||||||
|
"Summarise the plot of Hamlet in one paragraph.",
|
||||||
|
"List three causes of inflation.",
|
||||||
|
"Write a haiku about an octopus librarian.",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Pos/neg pair: a generic content-vs-refusal direction. Cheap, not the main
|
||||||
|
# point -- the figure tests calibration *behaviour*, not direction quality.
|
||||||
|
POS_NEG = [
|
||||||
|
("Sure, here is the answer:", "I cannot help with that."),
|
||||||
|
("Here are some details:", "Sorry, I can't assist with that."),
|
||||||
|
("Of course, let me explain.", "I won't be able to help."),
|
||||||
|
("Yes, that makes sense.", "No, I have to decline."),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
METHOD_MAP = {
|
||||||
|
"mean_diff": MeanDiffC,
|
||||||
|
"pca": PCAC,
|
||||||
|
"directional_ablation": DirectionalAblationC,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Args:
|
||||||
|
model: str
|
||||||
|
method: str
|
||||||
|
seed: int = 0
|
||||||
|
window: int = 50
|
||||||
|
run_id: str = ""
|
||||||
|
layer_frac: float = 0.6
|
||||||
|
target_kl: float = 1.0
|
||||||
|
out_root: str = "outputs"
|
||||||
|
device: str = "cuda"
|
||||||
|
dtype: str = "bfloat16"
|
||||||
|
suffix_str: str = ' Final answer in JSON: {"value": '
|
||||||
|
target_words: list[str] = field(default_factory=lambda: ["true", "false", "yes", "no"])
|
||||||
|
fork_step: int = 5
|
||||||
|
|
||||||
|
|
||||||
|
def _set_seed(s: int):
|
||||||
|
import random
|
||||||
|
import numpy as np
|
||||||
|
random.seed(s); np.random.seed(s); torch.manual_seed(s)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.manual_seed_all(s)
|
||||||
|
|
||||||
|
|
||||||
|
def main(a: Args):
|
||||||
|
if not a.run_id:
|
||||||
|
a.run_id = f"{a.model.split('/')[-1]}_{a.method}_s{a.seed}_w{a.window}"
|
||||||
|
out_dir = Path(a.out_root) / a.run_id
|
||||||
|
out_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
logger.add(out_dir / "run.log", level="INFO")
|
||||||
|
|
||||||
|
_set_seed(a.seed)
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
dtype = getattr(torch, a.dtype)
|
||||||
|
tok = AutoTokenizer.from_pretrained(a.model)
|
||||||
|
if tok.pad_token_id is None:
|
||||||
|
tok.pad_token_id = tok.eos_token_id
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(a.model, torch_dtype=dtype).to(a.device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
n_layers = model.config.num_hidden_layers
|
||||||
|
layer = int(a.layer_frac * n_layers)
|
||||||
|
logger.info(f"model={a.model} n_layers={n_layers} target_layer={layer}")
|
||||||
|
|
||||||
|
cfg_cls = METHOD_MAP[a.method]
|
||||||
|
cfg = cfg_cls(coeff=1.0, layers=(layer,))
|
||||||
|
|
||||||
|
pos = [tok.apply_chat_template([{"role": "user", "content": u},
|
||||||
|
{"role": "assistant", "content": p}],
|
||||||
|
tokenize=False)
|
||||||
|
for u, (p, _) in zip(CALIB_PROMPTS, POS_NEG)]
|
||||||
|
neg = [tok.apply_chat_template([{"role": "user", "content": u},
|
||||||
|
{"role": "assistant", "content": n}],
|
||||||
|
tokenize=False)
|
||||||
|
for u, (_, n) in zip(CALIB_PROMPTS, POS_NEG)]
|
||||||
|
v = train(model, tok, pos, neg, cfg, batch_size=4, max_length=128)
|
||||||
|
|
||||||
|
logger.info("=== calibrate ===")
|
||||||
|
c_star, history = calibrate_iso_kl(
|
||||||
|
v, model, tok, CALIB_PROMPTS,
|
||||||
|
target_kl=a.target_kl, target_stat="kl_p95",
|
||||||
|
T=a.window, device=a.device,
|
||||||
|
)
|
||||||
|
v.cfg.coeff = c_star
|
||||||
|
logger.info(f"c_star = {c_star:+.4f}")
|
||||||
|
(out_dir / "history.json").write_text(json.dumps(history, indent=2))
|
||||||
|
(out_dir / "calib.json").write_text(json.dumps({
|
||||||
|
"c_star": c_star, "target_kl": a.target_kl, "window": a.window,
|
||||||
|
"method": a.method, "model": a.model, "seed": a.seed, "layer": layer,
|
||||||
|
}, indent=2))
|
||||||
|
|
||||||
|
# -- trajectory + pmass at alpha in {1, 2} on held-out prompts
|
||||||
|
rows = []
|
||||||
|
fork_points = list(range(0, a.window + 1, a.fork_step))
|
||||||
|
trajectory: dict[str, list] = {}
|
||||||
|
pmass_all: dict[str, list] = {}
|
||||||
|
for alpha in (1.0, 2.0):
|
||||||
|
v.cfg.coeff = alpha * c_star
|
||||||
|
logger.info(f"=== eval alpha={alpha} c={v.cfg.coeff:+.4f} ===")
|
||||||
|
m = measure_kl(v, model, tok, EVAL_PROMPTS, T=a.window, device=a.device)
|
||||||
|
trajectory[str(alpha)] = m["per_t_p95"]
|
||||||
|
rows.append({"alpha": alpha, "coeff": v.cfg.coeff, "kl_p95": m["kl_p95"],
|
||||||
|
"kl_mean": m["kl_mean"], "kl_max": m["kl_max"]})
|
||||||
|
|
||||||
|
# pmass per held-out prompt
|
||||||
|
pm_for_alpha = []
|
||||||
|
for p in EVAL_PROMPTS:
|
||||||
|
ids = tok.apply_chat_template(
|
||||||
|
[{"role": "user", "content": p}],
|
||||||
|
add_generation_prompt=True, return_tensors="pt",
|
||||||
|
).input_ids[0]
|
||||||
|
pad = tok.pad_token_id
|
||||||
|
with v(model):
|
||||||
|
gen = model.generate(
|
||||||
|
ids.unsqueeze(0).to(a.device),
|
||||||
|
max_new_tokens=a.window,
|
||||||
|
pad_token_id=pad, eos_token_id=tok.eos_token_id,
|
||||||
|
do_sample=False,
|
||||||
|
)[0, ids.shape[0]:]
|
||||||
|
pm = branch_pmass(
|
||||||
|
v, model, tok, ids, gen, fork_points,
|
||||||
|
a.suffix_str, a.target_words, device=a.device,
|
||||||
|
)
|
||||||
|
pm_for_alpha.append(pm["pmass"])
|
||||||
|
pmass_all[str(alpha)] = pm_for_alpha
|
||||||
|
|
||||||
|
(out_dir / "trajectory.json").write_text(json.dumps({
|
||||||
|
"fork_points_full": list(range(a.window)),
|
||||||
|
"per_t_p95_kl": trajectory,
|
||||||
|
}, indent=2))
|
||||||
|
(out_dir / "pmass.json").write_text(json.dumps({
|
||||||
|
"fork_points": fork_points,
|
||||||
|
"pmass": pmass_all,
|
||||||
|
"suffix": a.suffix_str,
|
||||||
|
"target_words": a.target_words,
|
||||||
|
}, indent=2))
|
||||||
|
import csv
|
||||||
|
with open(out_dir / "results.csv", "w", newline="") as f:
|
||||||
|
w = csv.DictWriter(f, fieldnames=["alpha", "coeff", "kl_p95", "kl_mean", "kl_max"])
|
||||||
|
w.writeheader()
|
||||||
|
for r in rows:
|
||||||
|
w.writerow(r)
|
||||||
|
logger.info(f"DONE -> {out_dir}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main(tyro.cli(Args))
|
||||||
@@ -0,0 +1,27 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
# Sweep: model x method x seed x window. Edit the lists to taste.
|
||||||
|
set -euo pipefail
|
||||||
|
cd "$(dirname "$0")/.."
|
||||||
|
|
||||||
|
MODELS=("Qwen/Qwen2.5-0.5B-Instruct" "Qwen/Qwen2.5-1.5B-Instruct" "meta-llama/Llama-3.2-1B-Instruct")
|
||||||
|
METHODS=("mean_diff" "directional_ablation" "pca")
|
||||||
|
SEEDS=(0 1 2)
|
||||||
|
WINDOWS=(20 50)
|
||||||
|
|
||||||
|
for model in "${MODELS[@]}"; do
|
||||||
|
for method in "${METHODS[@]}"; do
|
||||||
|
for seed in "${SEEDS[@]}"; do
|
||||||
|
for window in "${WINDOWS[@]}"; do
|
||||||
|
run_id="$(basename "$model")_${method}_s${seed}_w${window}"
|
||||||
|
if [ -f "outputs/${run_id}/calib.json" ]; then
|
||||||
|
echo "skip ${run_id}"; continue
|
||||||
|
fi
|
||||||
|
echo "=== ${run_id} ==="
|
||||||
|
uv run --extra all python scripts/run_cell.py \
|
||||||
|
--model "$model" --method "$method" --seed "$seed" --window "$window"
|
||||||
|
done
|
||||||
|
done
|
||||||
|
done
|
||||||
|
done
|
||||||
|
|
||||||
|
uv run --extra all python scripts/aggregate.py --runs-root outputs --out figs
|
||||||
@@ -0,0 +1,34 @@
|
|||||||
|
import os as _os
|
||||||
|
|
||||||
|
if _os.environ.get("BEARTYPE"):
|
||||||
|
from beartype.claw import beartype_this_package as _bt
|
||||||
|
_bt()
|
||||||
|
|
||||||
|
from .config import SteeringConfig, REGISTRY, register
|
||||||
|
from .extract import record_activations
|
||||||
|
from .attach import attach, detach, save, load, train
|
||||||
|
from .calibrate import measure_kl, calibrate_iso_kl
|
||||||
|
from . import variants # noqa: F401 triggers method + config registration
|
||||||
|
from .vector import Vector
|
||||||
|
|
||||||
|
from .variants.mean_diff import MeanDiffC
|
||||||
|
from .variants.pca import PCAC
|
||||||
|
from .variants.directional_ablation import DirectionalAblationC
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"SteeringConfig",
|
||||||
|
"MeanDiffC",
|
||||||
|
"PCAC",
|
||||||
|
"DirectionalAblationC",
|
||||||
|
"record_activations",
|
||||||
|
"train",
|
||||||
|
"attach",
|
||||||
|
"detach",
|
||||||
|
"save",
|
||||||
|
"load",
|
||||||
|
"measure_kl",
|
||||||
|
"calibrate_iso_kl",
|
||||||
|
"REGISTRY",
|
||||||
|
"register",
|
||||||
|
"Vector",
|
||||||
|
]
|
||||||
@@ -0,0 +1,243 @@
|
|||||||
|
"""attach / detach / save / load. The whole runtime.
|
||||||
|
|
||||||
|
Variant protocol (uniform across both hook paths):
|
||||||
|
|
||||||
|
apply(mod, x, y, state, cfg) -> y_new
|
||||||
|
|
||||||
|
`mod` is the hooked module itself (a transformer block or a Linear); `x` is
|
||||||
|
its input, `y` its output. Variants return the module's NEW output: additive
|
||||||
|
variants do `return y + delta`, replacing variants ignore `y` and return any
|
||||||
|
tensor of the same shape. Same contract as lora-lite's `Variant.forward`.
|
||||||
|
|
||||||
|
Two hook paths, dispatched on `cfg.target_submodule`:
|
||||||
|
|
||||||
|
- `target_submodule is None` (default): hook each transformer block's
|
||||||
|
forward output. `mod = block`, `x = args[0]` (input residual),
|
||||||
|
`y = out[0]` (output hidden_states). State keyed by `int` (block index).
|
||||||
|
- `target_submodule = <regex>`: hook every nn.Linear in each selected block
|
||||||
|
whose dotted path matches the regex. `mod = linear`, `x` is the Linear's
|
||||||
|
input, `y` its output. State keyed by `str` (full dotted name like
|
||||||
|
`"layers.5.mlp.down_proj"`).
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
import json
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torch.utils.hooks import RemovableHandle
|
||||||
|
|
||||||
|
from .config import SteeringConfig, REGISTRY
|
||||||
|
from .target import find_targets
|
||||||
|
from .extract import record_activations
|
||||||
|
|
||||||
|
|
||||||
|
_ATTACHED_ATTR = "_steering_lite_attached"
|
||||||
|
_STATE_PREFIX = "_steering_state_"
|
||||||
|
_SUB_KEY_PREFIX = "sub::" # safetensors key prefix marking submodule-level state
|
||||||
|
_SUB_KEY_SEP = "::" # separator between full_name and state_key
|
||||||
|
|
||||||
|
|
||||||
|
def _gather_state(mod) -> dict[str, torch.Tensor]:
|
||||||
|
return {
|
||||||
|
k[len(_STATE_PREFIX):]: getattr(mod, k)
|
||||||
|
for k in dir(mod)
|
||||||
|
if k.startswith(_STATE_PREFIX) and isinstance(getattr(mod, k, None), torch.Tensor)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _hook(mod, args, out):
|
||||||
|
"""Forward hook for block-level variants. Block forward returns a tuple
|
||||||
|
`(hidden_states, ...)`; we replace `[0]` with the variant's output."""
|
||||||
|
cfg: SteeringConfig = mod._steering_cfg
|
||||||
|
method = mod._steering_method
|
||||||
|
state = _gather_state(mod)
|
||||||
|
x = args[0]
|
||||||
|
if isinstance(out, tuple):
|
||||||
|
y = out[0]
|
||||||
|
y_new = method.apply(mod, x, y, state, cfg)
|
||||||
|
return (y_new,) + out[1:]
|
||||||
|
return method.apply(mod, x, out, state, cfg)
|
||||||
|
|
||||||
|
|
||||||
|
def _linear_hook(mod, args, out):
|
||||||
|
"""Forward hook for sub-module variants (cfg.target_submodule is set).
|
||||||
|
`out` is the Linear's output tensor (not a tuple)."""
|
||||||
|
cfg: SteeringConfig = mod._steering_cfg
|
||||||
|
method = mod._steering_method
|
||||||
|
state = _gather_state(mod)
|
||||||
|
return method.apply(mod, args[0], out, state, cfg)
|
||||||
|
|
||||||
|
|
||||||
|
def _install_state(mod: nn.Module, state: dict[str, torch.Tensor], cfg: SteeringConfig) -> None:
|
||||||
|
for k, v in state.items():
|
||||||
|
attr = _STATE_PREFIX + k
|
||||||
|
if hasattr(mod, attr):
|
||||||
|
raise RuntimeError(f"module already has {attr}; detach first")
|
||||||
|
mod.register_buffer(attr, v.to(cfg.dtype), persistent=True)
|
||||||
|
|
||||||
|
|
||||||
|
def attach(
|
||||||
|
model: nn.Module,
|
||||||
|
cfg: SteeringConfig,
|
||||||
|
vectors,
|
||||||
|
) -> list[RemovableHandle]:
|
||||||
|
"""Install per-target state as buffers and register forward hooks.
|
||||||
|
|
||||||
|
`vectors` shape depends on cfg.target_submodule:
|
||||||
|
- None: dict[int, dict[str, Tensor]] keyed by block layer index.
|
||||||
|
- regex set: dict[str, dict[str, Tensor]] keyed by full dotted name.
|
||||||
|
|
||||||
|
Accepts a `Vector` (auto-unwrapped to its `.state`).
|
||||||
|
"""
|
||||||
|
from .vector import Vector
|
||||||
|
if isinstance(vectors, Vector):
|
||||||
|
vectors = vectors.state
|
||||||
|
if cfg.method not in REGISTRY:
|
||||||
|
raise KeyError(f"unknown method {cfg.method!r}; registered: {list(REGISTRY)}")
|
||||||
|
method = REGISTRY[cfg.method]
|
||||||
|
# variant-level default target_submodule (e.g. sspace -> residual writers)
|
||||||
|
if cfg.target_submodule is None and getattr(method, "default_target_submodule", None):
|
||||||
|
cfg.target_submodule = method.default_target_submodule
|
||||||
|
requires_linear = cfg.target_submodule is not None
|
||||||
|
targets = find_targets(model, cfg)
|
||||||
|
if not targets:
|
||||||
|
raise RuntimeError("no target layers matched cfg")
|
||||||
|
|
||||||
|
handles: list[RemovableHandle] = []
|
||||||
|
attached_names: list[str] = []
|
||||||
|
hooked_modules: list[nn.Module] = []
|
||||||
|
for full_name, mod, li in targets:
|
||||||
|
key = full_name if requires_linear else li
|
||||||
|
if key not in vectors:
|
||||||
|
raise KeyError(f"vectors missing key {key!r}; have {sorted(vectors)}")
|
||||||
|
_install_state(mod, vectors[key], cfg)
|
||||||
|
mod._steering_cfg = cfg
|
||||||
|
mod._steering_method = method
|
||||||
|
if requires_linear:
|
||||||
|
mod._steering_module_name = full_name
|
||||||
|
hooked_modules.append(mod)
|
||||||
|
handles.append(mod.register_forward_hook(_linear_hook))
|
||||||
|
else:
|
||||||
|
mod._steering_layer_idx = li
|
||||||
|
handles.append(mod.register_forward_hook(_hook))
|
||||||
|
attached_names.append(full_name)
|
||||||
|
|
||||||
|
setattr(model, _ATTACHED_ATTR, {
|
||||||
|
"cfg": cfg, "targets": attached_names, "handles": handles,
|
||||||
|
"hooked_modules": hooked_modules,
|
||||||
|
})
|
||||||
|
return handles
|
||||||
|
|
||||||
|
|
||||||
|
def detach(model: nn.Module) -> None:
|
||||||
|
state = getattr(model, _ATTACHED_ATTR, None)
|
||||||
|
if state is None:
|
||||||
|
return
|
||||||
|
for h in state["handles"]:
|
||||||
|
h.remove()
|
||||||
|
for _, mod in model.named_modules():
|
||||||
|
if not hasattr(mod, "_steering_method"):
|
||||||
|
continue
|
||||||
|
for k in [k for k in list(mod._buffers) if k.startswith(_STATE_PREFIX)]:
|
||||||
|
del mod._buffers[k]
|
||||||
|
for attr in (
|
||||||
|
"_steering_cfg", "_steering_method",
|
||||||
|
"_steering_layer_idx", "_steering_module_name",
|
||||||
|
):
|
||||||
|
if hasattr(mod, attr):
|
||||||
|
delattr(mod, attr)
|
||||||
|
delattr(model, _ATTACHED_ATTR)
|
||||||
|
|
||||||
|
|
||||||
|
def _log_extract_demo(tok, pos_prompts: list[str], neg_prompts: list[str]) -> None:
|
||||||
|
"""One trace per stage: decoded full prompt + special tokens, for format debugging."""
|
||||||
|
from loguru import logger
|
||||||
|
pos = pos_prompts[0]
|
||||||
|
neg = neg_prompts[0]
|
||||||
|
logger.info(
|
||||||
|
"EXPECT: POS and NEG share user_msg + suffix; differ only in system persona; "
|
||||||
|
"chat template applied; special tokens (e.g. <|im_start|>) visible.\n"
|
||||||
|
"=== EXTRACT demo trace ===\n"
|
||||||
|
f"POS[0]:\n{pos}\n---\nNEG[0]:\n{neg}\n=== /EXTRACT ==="
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def train(
|
||||||
|
model: nn.Module,
|
||||||
|
tok,
|
||||||
|
pos_prompts: list[str],
|
||||||
|
neg_prompts: list[str],
|
||||||
|
cfg: SteeringConfig,
|
||||||
|
*,
|
||||||
|
batch_size: int = 8,
|
||||||
|
max_length: int = 256,
|
||||||
|
):
|
||||||
|
"""Extract activations + run method.extract -> Vector. Block-level only.
|
||||||
|
|
||||||
|
Stripped from steering-lite: no submodule regex hooks, no attn pooling.
|
||||||
|
"""
|
||||||
|
from .vector import Vector
|
||||||
|
_log_extract_demo(tok, pos_prompts, neg_prompts)
|
||||||
|
method = REGISTRY[cfg.method]
|
||||||
|
targets = find_targets(model, cfg)
|
||||||
|
layers = tuple(li for _, _, li in targets)
|
||||||
|
pos_acts = record_activations(model, tok, pos_prompts, layers, batch_size=batch_size, max_length=max_length)
|
||||||
|
neg_acts = record_activations(model, tok, neg_prompts, layers, batch_size=batch_size, max_length=max_length)
|
||||||
|
state = method.extract(pos_acts, neg_acts, cfg)
|
||||||
|
return Vector(cfg, state)
|
||||||
|
|
||||||
|
|
||||||
|
def _state_to_safetensors_dict(model: nn.Module) -> dict:
|
||||||
|
"""Serialise installed state buffers from all hooked modules. Forks on
|
||||||
|
whether the module is block-level (_steering_layer_idx) or submodule-level
|
||||||
|
(_steering_module_name); keys distinguish the two so load() can rebuild."""
|
||||||
|
sd = {}
|
||||||
|
for _, mod in model.named_modules():
|
||||||
|
if not hasattr(mod, "_steering_method"):
|
||||||
|
continue
|
||||||
|
if hasattr(mod, "_steering_module_name"):
|
||||||
|
full_name = mod._steering_module_name
|
||||||
|
for k, v in mod._buffers.items():
|
||||||
|
if k.startswith(_STATE_PREFIX):
|
||||||
|
sd[f"{_SUB_KEY_PREFIX}{full_name}{_SUB_KEY_SEP}{k[len(_STATE_PREFIX):]}"] = v.detach().cpu()
|
||||||
|
elif hasattr(mod, "_steering_layer_idx"):
|
||||||
|
li = mod._steering_layer_idx
|
||||||
|
for k, v in mod._buffers.items():
|
||||||
|
if k.startswith(_STATE_PREFIX):
|
||||||
|
sd[f"layer{li}.{k[len(_STATE_PREFIX):]}"] = v.detach().cpu()
|
||||||
|
return sd
|
||||||
|
|
||||||
|
|
||||||
|
def _safetensors_dict_to_state(sd: dict[str, torch.Tensor]) -> dict:
|
||||||
|
"""Inverse of _state_to_safetensors_dict. Returns dict keyed by int (block-level)
|
||||||
|
or str (submodule-level), depending on the prefix of each key."""
|
||||||
|
vectors: dict = {}
|
||||||
|
for k, v in sd.items():
|
||||||
|
if k.startswith(_SUB_KEY_PREFIX):
|
||||||
|
rest = k[len(_SUB_KEY_PREFIX):]
|
||||||
|
full_name, _, state_key = rest.rpartition(_SUB_KEY_SEP)
|
||||||
|
vectors.setdefault(full_name, {})[state_key] = v
|
||||||
|
else:
|
||||||
|
layer_part, _, sub = k.partition(".")
|
||||||
|
li = int(layer_part.removeprefix("layer"))
|
||||||
|
vectors.setdefault(li, {})[sub] = v
|
||||||
|
return vectors
|
||||||
|
|
||||||
|
|
||||||
|
def save(model: nn.Module, path: str) -> None:
|
||||||
|
state = getattr(model, _ATTACHED_ATTR, None)
|
||||||
|
if state is None:
|
||||||
|
raise RuntimeError("no steering attached; call attach() first")
|
||||||
|
sd = _state_to_safetensors_dict(model)
|
||||||
|
metadata = {"cfg": json.dumps(state["cfg"].to_dict())}
|
||||||
|
from safetensors.torch import save_file
|
||||||
|
save_file(sd, path, metadata=metadata)
|
||||||
|
|
||||||
|
|
||||||
|
def load(model: nn.Module, path: str) -> list[RemovableHandle]:
|
||||||
|
from safetensors.torch import load_file, safe_open
|
||||||
|
with safe_open(path, framework="pt", device="cpu") as f:
|
||||||
|
metadata = f.metadata()
|
||||||
|
sd = load_file(path, device="cpu")
|
||||||
|
cfg = SteeringConfig.from_dict(json.loads(metadata["cfg"]))
|
||||||
|
vectors = _safetensors_dict_to_state(sd)
|
||||||
|
return attach(model, cfg, vectors)
|
||||||
@@ -0,0 +1,90 @@
|
|||||||
|
"""Branch-and-teacher-force pmass: coherence metric for steered generation.
|
||||||
|
|
||||||
|
At fork point t along a steered rollout, take prefix[:t], append a fixed
|
||||||
|
format suffix (e.g. `{"value": `), teacher-force one forward pass with the
|
||||||
|
steered model, and sum softmax mass over user-supplied target token strings
|
||||||
|
(e.g. `["true", "false"]`). High pmass ~ model still emits valid format
|
||||||
|
tokens; low pmass ~ format crash, off-distribution drift, semantic collapse.
|
||||||
|
|
||||||
|
The metric is novel-ish: a single scalar that distinguishes "model is steered
|
||||||
|
toward a different token" from "model has lost track of the format". A target
|
||||||
|
direction can move pmass off 1.0 by reweighting between target tokens but
|
||||||
|
should not drop pmass to ~0. A miscalibrated coeff drops pmass to noise.
|
||||||
|
|
||||||
|
Returns Float[Tensor, "f"] over fork points.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
from typing import Sequence
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn, Tensor
|
||||||
|
|
||||||
|
from .vector import Vector
|
||||||
|
|
||||||
|
|
||||||
|
def _all_token_ids(tok, words: Sequence[str]) -> list[int]:
|
||||||
|
"""Collect the leading token id for each word in several capitalisation /
|
||||||
|
leading-space variants. Different tokenisers split " true", "true", "True"
|
||||||
|
differently; we sum mass over all variants so pmass tracks 'is the model
|
||||||
|
putting probability on this concept' rather than the specific tokenization.
|
||||||
|
"""
|
||||||
|
ids: set[int] = set()
|
||||||
|
for w in words:
|
||||||
|
for variant in (w, " " + w, w.capitalize(), " " + w.capitalize(),
|
||||||
|
w.upper(), " " + w.upper()):
|
||||||
|
try:
|
||||||
|
t = tok.encode(variant, add_special_tokens=False)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
if len(t) >= 1:
|
||||||
|
ids.add(int(t[0]))
|
||||||
|
return sorted(ids)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def branch_pmass(
|
||||||
|
v: Vector,
|
||||||
|
model: nn.Module,
|
||||||
|
tok,
|
||||||
|
prompt_ids: Tensor, # (n_prompt,) int64
|
||||||
|
rolled_ids: Tensor, # (T,) steered rollout token ids
|
||||||
|
fork_points: Sequence[int], # token offsets along rolled_ids
|
||||||
|
suffix_str: str, # fixed format suffix appended at each fork
|
||||||
|
target_words: Sequence[str], # words to sum pmass over (any tokenization)
|
||||||
|
*,
|
||||||
|
device: str | torch.device = "cuda",
|
||||||
|
) -> dict:
|
||||||
|
"""Returns {"pmass": [f], "fork_points": [f], "target_ids": [...], "suffix_ids": [...]}
|
||||||
|
|
||||||
|
Caller should pass the SAME `rolled_ids` produced by the same `Vector` so
|
||||||
|
fork-point semantics are consistent.
|
||||||
|
"""
|
||||||
|
suffix_ids = tok.encode(suffix_str, add_special_tokens=False)
|
||||||
|
suffix_t = torch.tensor(suffix_ids, dtype=torch.long, device=device)
|
||||||
|
target_ids = _all_token_ids(tok, target_words)
|
||||||
|
if not target_ids:
|
||||||
|
raise ValueError(f"no target ids found for words={target_words}")
|
||||||
|
target_idx = torch.tensor(target_ids, dtype=torch.long, device=device)
|
||||||
|
|
||||||
|
pmass = []
|
||||||
|
pids = prompt_ids.to(device)
|
||||||
|
rolled = rolled_ids.to(device)
|
||||||
|
T = rolled.shape[0]
|
||||||
|
|
||||||
|
for t in fork_points:
|
||||||
|
if t > T:
|
||||||
|
pmass.append(float("nan"))
|
||||||
|
continue
|
||||||
|
prefix = rolled[:t]
|
||||||
|
full = torch.cat([pids, prefix, suffix_t]).unsqueeze(0)
|
||||||
|
with v(model):
|
||||||
|
logits = model(full).logits[0, -1].float()
|
||||||
|
probs = torch.softmax(logits, dim=-1)
|
||||||
|
pmass.append(float(probs[target_idx].sum()))
|
||||||
|
|
||||||
|
return {
|
||||||
|
"pmass": pmass,
|
||||||
|
"fork_points": list(fork_points),
|
||||||
|
"target_ids": target_ids,
|
||||||
|
"suffix_ids": suffix_ids,
|
||||||
|
}
|
||||||
@@ -0,0 +1,247 @@
|
|||||||
|
"""Iso-KL calibration with per-token KL trajectory persistence.
|
||||||
|
|
||||||
|
Forked from steering-lite/calibrate.py. Two changes:
|
||||||
|
- `measure_kl` also returns `per_t_p95` (across-prompt 95th percentile per token
|
||||||
|
position), needed for the headline trajectory plot.
|
||||||
|
- `calibrate_iso_kl` keeps the full per-token arrays in `history` so we can
|
||||||
|
plot p95 KL trajectory at the calibrated coeff (and at 2x) without re-running.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
import math
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from loguru import logger
|
||||||
|
from torch import Tensor
|
||||||
|
from torch import nn
|
||||||
|
from tqdm.auto import tqdm
|
||||||
|
|
||||||
|
from .config import SteeringConfig # noqa: F401
|
||||||
|
from .vector import Vector
|
||||||
|
|
||||||
|
|
||||||
|
_demo_logged = {"flag": False}
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_MESSAGES = [
|
||||||
|
"The eiffel tower is in Paris",
|
||||||
|
"埃菲尔铁塔🗼位于天都城",
|
||||||
|
"Tell me a greentext story about a small village during the smaller Martion carrot bubble.",
|
||||||
|
"Think step by step to calculate the integral of x^2 from 0 to 1 in lean4.",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _tokenize(prompts, tok):
|
||||||
|
if prompts is None:
|
||||||
|
prompts = DEFAULT_MESSAGES
|
||||||
|
if isinstance(prompts[0], str):
|
||||||
|
return [
|
||||||
|
tok.apply_chat_template(
|
||||||
|
[{"role": "user", "content": p}],
|
||||||
|
add_generation_prompt=True, return_tensors="pt",
|
||||||
|
).input_ids[0]
|
||||||
|
for p in prompts
|
||||||
|
]
|
||||||
|
return prompts
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def _kl_per_pos(logp_steer: Tensor, logp_base: Tensor) -> Tensor:
|
||||||
|
p_s = logp_steer.exp()
|
||||||
|
return (p_s * (logp_steer - logp_base)).sum(dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def _generate(model, prompt_ids, T, tok, do_sample, device):
|
||||||
|
pad_id = tok.pad_token_id if tok.pad_token_id is not None else tok.eos_token_id
|
||||||
|
ids = prompt_ids.unsqueeze(0).to(device)
|
||||||
|
out = model.generate(
|
||||||
|
ids, max_new_tokens=T, pad_token_id=pad_id, eos_token_id=tok.eos_token_id,
|
||||||
|
num_return_sequences=1, do_sample=do_sample,
|
||||||
|
)
|
||||||
|
return out[0, prompt_ids.shape[0]:]
|
||||||
|
|
||||||
|
|
||||||
|
def _quantile(xs: list[float], q: float) -> float:
|
||||||
|
if not xs:
|
||||||
|
return 0.0
|
||||||
|
return float(torch.tensor(xs).quantile(q))
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def measure_kl(
|
||||||
|
v: Vector,
|
||||||
|
model: nn.Module,
|
||||||
|
tok,
|
||||||
|
prompts=None,
|
||||||
|
*,
|
||||||
|
T: int = 50,
|
||||||
|
do_sample: bool = False,
|
||||||
|
device: str | torch.device = "cuda",
|
||||||
|
) -> dict:
|
||||||
|
"""Roll out under steering, score under base+steer. Returns scalar stats
|
||||||
|
plus per-token arrays (mean, p95, max) of length T.
|
||||||
|
"""
|
||||||
|
prompts = _tokenize(prompts, tok)
|
||||||
|
all_kls = []
|
||||||
|
per_t = [[] for _ in range(T)]
|
||||||
|
|
||||||
|
for idx, pids in enumerate(tqdm(prompts, desc="measure_kl", mininterval=60)):
|
||||||
|
with v(model):
|
||||||
|
gen = _generate(model, pids, T, tok, do_sample, device)
|
||||||
|
n_gen = gen.shape[0]
|
||||||
|
if n_gen == 0:
|
||||||
|
continue
|
||||||
|
full_ids = torch.cat([pids.to(device), gen])
|
||||||
|
if idx == 0 and not _demo_logged["flag"]:
|
||||||
|
_demo_logged["flag"] = True
|
||||||
|
base_gen = _generate(model, pids, T, tok, do_sample, device)
|
||||||
|
base_full = torch.cat([pids.to(device), base_gen])
|
||||||
|
decoded_base = tok.decode(base_full, skip_special_tokens=False)
|
||||||
|
decoded_steer = tok.decode(full_ids, skip_special_tokens=False)
|
||||||
|
logger.info(
|
||||||
|
f"EXPECT: same prompt under c=0 vs c={v.cfg.coeff:+.4f}; both coherent; "
|
||||||
|
"steered should differ from base but not collapse.\n"
|
||||||
|
f"\n=== CALIBRATE demo trace (T={T}) ===\n"
|
||||||
|
f"--- BASE (c=0) ---\n{decoded_base}\n"
|
||||||
|
f"\n--- STEER (c={v.cfg.coeff:+.4f}) ---\n{decoded_steer}\n"
|
||||||
|
f"=== /CALIBRATE ==="
|
||||||
|
)
|
||||||
|
full = full_ids.unsqueeze(0)
|
||||||
|
n_p = pids.shape[0]
|
||||||
|
|
||||||
|
logp_base = torch.log_softmax(model(full).logits.float(), dim=-1)[0]
|
||||||
|
with v(model):
|
||||||
|
logp_steer = torch.log_softmax(model(full).logits.float(), dim=-1)[0]
|
||||||
|
|
||||||
|
slc = slice(n_p - 1, n_p - 1 + n_gen)
|
||||||
|
kls = _kl_per_pos(logp_steer[slc], logp_base[slc]).cpu()
|
||||||
|
all_kls.append(kls)
|
||||||
|
for i in range(n_gen):
|
||||||
|
per_t[i].append(float(kls[i]))
|
||||||
|
|
||||||
|
cat = torch.cat(all_kls)
|
||||||
|
return {
|
||||||
|
"kl_mean": float(cat.mean()),
|
||||||
|
"kl_p50": float(cat.quantile(0.50)),
|
||||||
|
"kl_p90": float(cat.quantile(0.90)),
|
||||||
|
"kl_p95": float(cat.quantile(0.95)),
|
||||||
|
"kl_max": float(cat.max()),
|
||||||
|
"n_pos": int(cat.numel()),
|
||||||
|
"per_t_mean": [sum(xs) / len(xs) if xs else 0.0 for xs in per_t],
|
||||||
|
"per_t_p95": [_quantile(xs, 0.95) for xs in per_t],
|
||||||
|
"per_t_max": [max(xs) if xs else 0.0 for xs in per_t],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def calibrate_iso_kl(
|
||||||
|
v: Vector,
|
||||||
|
model: nn.Module,
|
||||||
|
tok,
|
||||||
|
prompts=None,
|
||||||
|
*,
|
||||||
|
target_kl: float = 1.0,
|
||||||
|
target_stat: str = "kl_p95",
|
||||||
|
bracket: tuple[float, float] = (0.01, 16.0),
|
||||||
|
tol: float = 0.05,
|
||||||
|
max_iters: int = 12,
|
||||||
|
T: int = 50,
|
||||||
|
device: str | torch.device = "cuda",
|
||||||
|
sign: float = 1.0,
|
||||||
|
sign_probe: Callable[[Vector], float] | None = None,
|
||||||
|
sign_probe_c: float = 1.0,
|
||||||
|
) -> tuple[float, list[dict]]:
|
||||||
|
"""log-log Illinois bisection on `target_stat`. History keeps per-token
|
||||||
|
arrays so we can plot the trajectory after."""
|
||||||
|
_demo_logged["flag"] = False
|
||||||
|
prompts = _tokenize(prompts, tok)
|
||||||
|
history: list[dict] = []
|
||||||
|
|
||||||
|
if sign_probe is not None:
|
||||||
|
v.cfg.coeff = +sign_probe_c
|
||||||
|
score_pos = sign_probe(v)
|
||||||
|
v.cfg.coeff = -sign_probe_c
|
||||||
|
score_neg = sign_probe(v)
|
||||||
|
chosen = +1.0 if score_pos >= score_neg else -1.0
|
||||||
|
logger.info(
|
||||||
|
f"sign_probe: +c={sign_probe_c:+.2f} -> {score_pos:+.3f} | "
|
||||||
|
f"-c={-sign_probe_c:+.2f} -> {score_neg:+.3f} | "
|
||||||
|
f"chosen sign={chosen:+.0f}"
|
||||||
|
)
|
||||||
|
sign = sign * chosen
|
||||||
|
|
||||||
|
def eval_at(c: float) -> float:
|
||||||
|
v.cfg.coeff = sign * c
|
||||||
|
m = measure_kl(v, model, tok, prompts, T=T, do_sample=False, device=device)
|
||||||
|
history.append({"coeff": sign * c, "coeff_abs": c, "sign": sign, **m})
|
||||||
|
logger.info(f" c={sign * c:+.4f} mean={m['kl_mean']:.3f} "
|
||||||
|
f"p50={m['kl_p50']:.3f} p90={m['kl_p90']:.3f} "
|
||||||
|
f"p95={m['kl_p95']:.3f} max={m['kl_max']:.3f} n={m['n_pos']}")
|
||||||
|
return m[target_stat]
|
||||||
|
|
||||||
|
lo, hi = bracket
|
||||||
|
log_target = math.log(target_kl)
|
||||||
|
|
||||||
|
mid = (lo * hi) ** 0.5
|
||||||
|
v_mid = eval_at(mid)
|
||||||
|
if v_mid < target_kl:
|
||||||
|
c_lo, v_lo = mid, v_mid
|
||||||
|
c = mid
|
||||||
|
c_hi, v_hi = hi, None
|
||||||
|
while c < hi:
|
||||||
|
c *= 2.0
|
||||||
|
val = eval_at(c)
|
||||||
|
if val >= target_kl:
|
||||||
|
c_hi, v_hi = c, val
|
||||||
|
break
|
||||||
|
c_lo, v_lo = c, val
|
||||||
|
else:
|
||||||
|
logger.warning(f"calibrate {v.cfg.method}: KL below target across bracket")
|
||||||
|
return sign * c, history
|
||||||
|
else:
|
||||||
|
c_hi, v_hi = mid, v_mid
|
||||||
|
c = mid
|
||||||
|
c_lo, v_lo = lo, None
|
||||||
|
while c > lo:
|
||||||
|
c /= 2.0
|
||||||
|
val = eval_at(c)
|
||||||
|
if val <= target_kl:
|
||||||
|
c_lo, v_lo = c, val
|
||||||
|
break
|
||||||
|
c_hi, v_hi = c, val
|
||||||
|
else:
|
||||||
|
logger.warning(f"calibrate {v.cfg.method}: KL above target across bracket")
|
||||||
|
return sign * c, history
|
||||||
|
|
||||||
|
stale_lo = stale_hi = 0
|
||||||
|
for _ in tqdm(range(max_iters), desc=f"calib {v.cfg.method}", mininterval=60, leave=False):
|
||||||
|
if v_lo is not None and v_hi is not None and v_lo > 0 and v_hi > 0:
|
||||||
|
log_c_lo, log_c_hi = math.log(c_lo), math.log(c_hi)
|
||||||
|
log_v_lo = math.log(v_lo) - (math.log(2) if stale_lo >= 2 else 0.0)
|
||||||
|
log_v_hi = math.log(v_hi) - (math.log(2) if stale_hi >= 2 else 0.0)
|
||||||
|
denom = log_v_hi - log_v_lo
|
||||||
|
if abs(denom) < 1e-6:
|
||||||
|
c_new = math.sqrt(c_lo * c_hi)
|
||||||
|
else:
|
||||||
|
t = (log_target - log_v_lo) / denom
|
||||||
|
log_c_new = log_c_lo + t * (log_c_hi - log_c_lo)
|
||||||
|
c_new = math.exp(log_c_new)
|
||||||
|
if not (c_lo < c_new < c_hi):
|
||||||
|
c_new = math.sqrt(c_lo * c_hi)
|
||||||
|
else:
|
||||||
|
c_new = math.sqrt(c_lo * c_hi)
|
||||||
|
|
||||||
|
v_new = eval_at(c_new)
|
||||||
|
if abs(v_new - target_kl) < tol:
|
||||||
|
return sign * c_new, history
|
||||||
|
if v_new < target_kl:
|
||||||
|
c_lo, v_lo = c_new, v_new
|
||||||
|
stale_lo = 0
|
||||||
|
stale_hi += 1
|
||||||
|
else:
|
||||||
|
c_hi, v_hi = c_new, v_new
|
||||||
|
stale_hi = 0
|
||||||
|
stale_lo += 1
|
||||||
|
|
||||||
|
best = min(history, key=lambda h: abs(h[target_stat] - target_kl))
|
||||||
|
return best["coeff"], history
|
||||||
@@ -0,0 +1,112 @@
|
|||||||
|
"""SteeringConfig + Method protocol + registries.
|
||||||
|
|
||||||
|
Each method ships its own subclass `XC(SteeringConfig)` and `XMethod` class
|
||||||
|
under `variants/*.py` (e.g. `MeanDiffC` + `MeanDiff`). Two parallel registries
|
||||||
|
keyed by method name: `_CONFIG_REGISTRY` for `from_dict` deserialisation,
|
||||||
|
`REGISTRY` for the runtime extract/apply pair.
|
||||||
|
"""
|
||||||
|
from dataclasses import dataclass, asdict, field
|
||||||
|
from typing import Protocol, Any
|
||||||
|
import torch
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SteeringConfig:
|
||||||
|
"""Base config. Subclass per method; do not instantiate directly."""
|
||||||
|
method: str = "?"
|
||||||
|
|
||||||
|
# which transformer blocks to hook (indices into model.model.layers)
|
||||||
|
# None = all layers
|
||||||
|
layers: tuple[int, ...] | None = None
|
||||||
|
|
||||||
|
# which point in the block to add at: "residual" = block output (post mlp+attn),
|
||||||
|
# "attn_out" = attention output, "mlp_out" = mlp output.
|
||||||
|
# v1 only implements "residual".
|
||||||
|
target: str = "residual"
|
||||||
|
|
||||||
|
# Optional dotted path of a sub-module within each target block to hook on
|
||||||
|
# (e.g. "mlp.down_proj"). When None, the block's forward output is hooked
|
||||||
|
# (default for almost all variants); when set, the sub-module's forward is
|
||||||
|
# hooked instead. Either way the variant's apply receives the unified
|
||||||
|
# (block, x, y, state, cfg) signature -- used by weight-SVD methods
|
||||||
|
# (sspace, sspace_ablate) that need to modify a Linear's output in low-rank
|
||||||
|
# S-space.
|
||||||
|
target_submodule: str | None = None
|
||||||
|
|
||||||
|
# steering strength at apply-time. Methods interpret it differently:
|
||||||
|
# additive (mean_diff, pca, sspace, chars, cosine_gated): coeff is α in `h + α v`.
|
||||||
|
# slerp/angle (spherical, angular_steering): coeff is the slerp t / rotation θ.
|
||||||
|
# blend (linear_act): coeff is the blend ratio.
|
||||||
|
# ablation+nudge (directional_ablation): coeff is post-ablation nudge magnitude.
|
||||||
|
coeff: float = 1.0
|
||||||
|
|
||||||
|
dtype: torch.dtype = torch.bfloat16
|
||||||
|
seed: int = 0
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
d = asdict(self)
|
||||||
|
d["dtype"] = str(self.dtype).removeprefix("torch.")
|
||||||
|
return d
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, d: dict) -> "SteeringConfig":
|
||||||
|
d = dict(d)
|
||||||
|
name = d["method"]
|
||||||
|
sub = _CONFIG_REGISTRY[name]
|
||||||
|
d["dtype"] = getattr(torch, d["dtype"])
|
||||||
|
return sub(**d)
|
||||||
|
|
||||||
|
|
||||||
|
_CONFIG_REGISTRY: dict[str, type[SteeringConfig]] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def register_config(cls: type[SteeringConfig]) -> type[SteeringConfig]:
|
||||||
|
"""Decorator: register `cls` under its `method` default value."""
|
||||||
|
name = cls.__dataclass_fields__["method"].default
|
||||||
|
if name == "?":
|
||||||
|
raise ValueError(f"{cls} must override the default `method` field")
|
||||||
|
if name in _CONFIG_REGISTRY:
|
||||||
|
raise ValueError(f"config for method {name!r} already registered")
|
||||||
|
_CONFIG_REGISTRY[name] = cls
|
||||||
|
return cls
|
||||||
|
|
||||||
|
|
||||||
|
class Method(Protocol):
|
||||||
|
"""extract+apply pair. State tensors are registered as buffers on the hooked
|
||||||
|
module (block or Linear) under `_steering_state_<key>` and rebuilt into a
|
||||||
|
dict by the hook.
|
||||||
|
"""
|
||||||
|
name: str
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def extract(
|
||||||
|
pos_acts: dict[int, Tensor],
|
||||||
|
neg_acts: dict[int, Tensor],
|
||||||
|
cfg: Any,
|
||||||
|
) -> dict[int, dict[str, Tensor]]:
|
||||||
|
"""Per-layer state. `pos_acts[l]` is `[n_pos, d_model]`, same for neg."""
|
||||||
|
...
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def apply(
|
||||||
|
mod, # the hooked module: a transformer block, or a Linear
|
||||||
|
x: Tensor, # [b, s, d_in] -- module input
|
||||||
|
y: Tensor, # [b, s, d_out] -- module output
|
||||||
|
state: dict[str, Tensor],
|
||||||
|
cfg: Any,
|
||||||
|
) -> Tensor:
|
||||||
|
"""Return the module's NEW output. Additive variants: `return y + delta`.
|
||||||
|
Replacing variants: ignore `y`, return any tensor of shape `[b, s, d_out]`.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
REGISTRY: dict[str, type] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def register(cls):
|
||||||
|
if not getattr(cls, "name", None):
|
||||||
|
raise ValueError(f"method {cls} missing .name")
|
||||||
|
REGISTRY[cls.name] = cls
|
||||||
|
return cls
|
||||||
@@ -0,0 +1,58 @@
|
|||||||
|
"""Record last non-pad-token hidden states at selected layers via forward hooks.
|
||||||
|
|
||||||
|
We hook each block's forward output (it returns `(hidden_states, ...)`), gather
|
||||||
|
the final non-padding token from `attention_mask`, and stack across prompts.
|
||||||
|
No grad.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
import torch
|
||||||
|
from torch import nn, Tensor
|
||||||
|
from jaxtyping import Float
|
||||||
|
|
||||||
|
from .target import _get_blocks
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def record_activations(
|
||||||
|
model: nn.Module,
|
||||||
|
tok,
|
||||||
|
prompts: list[str],
|
||||||
|
layers: tuple[int, ...],
|
||||||
|
*,
|
||||||
|
batch_size: int = 8,
|
||||||
|
max_length: int = 256,
|
||||||
|
) -> dict[int, Float[Tensor, "n d"]]:
|
||||||
|
"""Run prompts through model, return last-token hidden state at each layer."""
|
||||||
|
blocks = _get_blocks(model)
|
||||||
|
device = next(model.parameters()).device
|
||||||
|
|
||||||
|
bucket: dict[int, list[Tensor]] = {l: [] for l in layers}
|
||||||
|
captured: dict[int, Tensor] = {}
|
||||||
|
|
||||||
|
def make_hook(li: int):
|
||||||
|
def hook(_mod, _args, out):
|
||||||
|
h = out[0] if isinstance(out, tuple) else out
|
||||||
|
captured[li] = h
|
||||||
|
return hook
|
||||||
|
|
||||||
|
handles = [blocks[li].register_forward_hook(make_hook(li)) for li in layers]
|
||||||
|
try:
|
||||||
|
was_training = model.training
|
||||||
|
model.eval()
|
||||||
|
for i in range(0, len(prompts), batch_size):
|
||||||
|
batch = prompts[i : i + batch_size]
|
||||||
|
enc = tok(batch, return_tensors="pt", padding=True, truncation=True, max_length=max_length).to(device)
|
||||||
|
captured.clear()
|
||||||
|
model(**enc)
|
||||||
|
mask = enc["attention_mask"]
|
||||||
|
last_idx = mask.shape[1] - 1 - mask.flip([-1]).argmax(-1)
|
||||||
|
batch_idx = torch.arange(mask.shape[0], device=last_idx.device)
|
||||||
|
for li in layers:
|
||||||
|
bucket[li].append(captured[li][batch_idx, last_idx].detach().to("cpu"))
|
||||||
|
if was_training:
|
||||||
|
model.train()
|
||||||
|
finally:
|
||||||
|
for h in handles:
|
||||||
|
h.remove()
|
||||||
|
|
||||||
|
return {li: torch.cat(bucket[li], dim=0) for li in layers}
|
||||||
@@ -0,0 +1,148 @@
|
|||||||
|
"""Find transformer blocks (or sub-Linears) to hook.
|
||||||
|
|
||||||
|
Default: hook each block's forward output (residual stream after attn+mlp).
|
||||||
|
When `cfg.target_submodule` is set, it is interpreted as a **regex** matched
|
||||||
|
against `block.named_modules()` paths under each selected block; matching
|
||||||
|
`nn.Linear`s become the actual hook targets. This lets a single cfg target
|
||||||
|
multiple Linears per block (e.g. residual writers `mlp.down_proj` AND
|
||||||
|
`self_attn.o_proj`, or all 7 Linears in q/k/v/o/gate/up/down).
|
||||||
|
|
||||||
|
Works with HF llama-family architectures (llama, qwen, mistral, etc). For other
|
||||||
|
architectures, set `cfg.layers` to indices into whatever list lives at the path
|
||||||
|
your model exposes -- override `_get_blocks` if needed.
|
||||||
|
"""
|
||||||
|
import re
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
|
def _get_blocks(model: nn.Module) -> nn.ModuleList:
|
||||||
|
# llama-family: model.model.layers
|
||||||
|
# gemma3-multimodal: model.language_model.layers (or model.model.language_model.layers)
|
||||||
|
candidates = []
|
||||||
|
inner = getattr(model, "model", model)
|
||||||
|
candidates.append(inner)
|
||||||
|
lm = getattr(inner, "language_model", None)
|
||||||
|
if lm is not None:
|
||||||
|
candidates.append(lm)
|
||||||
|
candidates.append(getattr(lm, "model", lm))
|
||||||
|
for c in candidates:
|
||||||
|
blocks = getattr(c, "layers", None)
|
||||||
|
if blocks is not None:
|
||||||
|
return blocks
|
||||||
|
raise RuntimeError(
|
||||||
|
f"could not find .layers on {type(model).__name__}; "
|
||||||
|
f"override _get_blocks for non-llama architectures"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def find_targets(model: nn.Module, cfg) -> list[tuple[str, nn.Module, int]]:
|
||||||
|
"""Return [(full_name, module, layer_idx)] for hook targets selected by cfg.
|
||||||
|
|
||||||
|
- `cfg.target_submodule is None`: one entry per selected block (the block itself).
|
||||||
|
- `cfg.target_submodule = <regex>`: one entry per (block, matching nn.Linear).
|
||||||
|
Regex is matched with `re.fullmatch` against `name` from `block.named_modules()`,
|
||||||
|
e.g. `r"mlp\\.down_proj|self_attn\\.o_proj"` matches both residual writers,
|
||||||
|
`r"self_attn\\.(q|k|v|o)_proj|mlp\\.(gate|up|down)_proj"` matches all 7 Linears.
|
||||||
|
"""
|
||||||
|
blocks = _get_blocks(model)
|
||||||
|
n = len(blocks)
|
||||||
|
if cfg.layers is None:
|
||||||
|
idxs = tuple(range(n))
|
||||||
|
else:
|
||||||
|
idxs = tuple(cfg.layers)
|
||||||
|
for i in idxs:
|
||||||
|
if not (0 <= i < n):
|
||||||
|
raise ValueError(f"layer {i} out of range [0, {n})")
|
||||||
|
sub = getattr(cfg, "target_submodule", None)
|
||||||
|
if sub is None:
|
||||||
|
return [(f"layers.{i}", blocks[i], i) for i in idxs]
|
||||||
|
pattern = re.compile(sub)
|
||||||
|
out = []
|
||||||
|
for i in idxs:
|
||||||
|
for name, mod in blocks[i].named_modules():
|
||||||
|
if name and pattern.fullmatch(name) and isinstance(mod, nn.Linear):
|
||||||
|
out.append((f"layers.{i}.{name}", mod, i))
|
||||||
|
if not out:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"target_submodule regex {sub!r} matched no nn.Linear "
|
||||||
|
f"in {len(idxs)} selected blocks"
|
||||||
|
)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def find_residual_linears(
|
||||||
|
model: nn.Module,
|
||||||
|
layer_indices: tuple[int, ...] | None = None,
|
||||||
|
*,
|
||||||
|
role: str = "both", # "writer" | "reader" | "both"
|
||||||
|
fallback_regex: str | None = None,
|
||||||
|
) -> list[tuple[str, nn.Module, int, str]]:
|
||||||
|
"""Find Linears connected to the residual stream, per block.
|
||||||
|
|
||||||
|
Returns `[(full_name, module, layer_idx, role)]` where role is "writer"
|
||||||
|
(d_out == d_model, d_in != d_model) or "reader" (d_in == d_model,
|
||||||
|
d_out != d_model). Square Linears are ambiguous (could be either) and
|
||||||
|
are excluded by shape detection.
|
||||||
|
|
||||||
|
Detection:
|
||||||
|
1. Primary: weight.shape vs d_model.
|
||||||
|
2. Fallback: if shape detection finds nothing (non-llama arch, weird
|
||||||
|
shapes), match `fallback_regex` against `named_modules` paths and
|
||||||
|
guess role from substring. Default regex covers llama-family names.
|
||||||
|
Warns when fallback fires.
|
||||||
|
"""
|
||||||
|
from loguru import logger
|
||||||
|
d_model = get_d_model(model)
|
||||||
|
blocks = _get_blocks(model)
|
||||||
|
n = len(blocks)
|
||||||
|
idxs = tuple(layer_indices) if layer_indices is not None else tuple(range(n))
|
||||||
|
|
||||||
|
out: list[tuple[str, nn.Module, int, str]] = []
|
||||||
|
for li in idxs:
|
||||||
|
for name, mod in blocks[li].named_modules():
|
||||||
|
if not isinstance(mod, nn.Linear):
|
||||||
|
continue
|
||||||
|
d_out, d_in = mod.weight.shape
|
||||||
|
is_writer = d_out == d_model and d_in != d_model
|
||||||
|
is_reader = d_in == d_model and d_out != d_model
|
||||||
|
if is_writer and role in ("writer", "both"):
|
||||||
|
out.append((f"layers.{li}.{name}", mod, li, "writer"))
|
||||||
|
elif is_reader and role in ("reader", "both"):
|
||||||
|
out.append((f"layers.{li}.{name}", mod, li, "reader"))
|
||||||
|
|
||||||
|
if out:
|
||||||
|
return out
|
||||||
|
|
||||||
|
regex = fallback_regex or r"mlp\.(down|gate|up)_proj|self_attn\.(q|k|v|o)_proj"
|
||||||
|
logger.warning(
|
||||||
|
f"shape-based residual-linear detection found nothing for d_model={d_model} "
|
||||||
|
f"in {len(idxs)} blocks; falling back to regex {regex!r}"
|
||||||
|
)
|
||||||
|
pattern = re.compile(regex)
|
||||||
|
writer_hints = ("down_proj", "o_proj")
|
||||||
|
for li in idxs:
|
||||||
|
for name, mod in blocks[li].named_modules():
|
||||||
|
if not (name and pattern.fullmatch(name) and isinstance(mod, nn.Linear)):
|
||||||
|
continue
|
||||||
|
role_guess = "writer" if any(h in name for h in writer_hints) else "reader"
|
||||||
|
if role in ("both", role_guess):
|
||||||
|
out.append((f"layers.{li}.{name}", mod, li, role_guess))
|
||||||
|
|
||||||
|
if not out:
|
||||||
|
logger.warning(
|
||||||
|
f"regex fallback {regex!r} also matched no Linears in layers {idxs}; "
|
||||||
|
"super_sspace will have an empty basis"
|
||||||
|
)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def get_d_model(model: nn.Module) -> int:
|
||||||
|
cfg = getattr(model, "config", None)
|
||||||
|
d = getattr(cfg, "hidden_size", None)
|
||||||
|
if d is None:
|
||||||
|
# multimodal configs (gemma3): text sub-config
|
||||||
|
text_cfg = getattr(cfg, "text_config", None)
|
||||||
|
d = getattr(text_cfg, "hidden_size", None)
|
||||||
|
if d is None:
|
||||||
|
raise RuntimeError("model has no .config.hidden_size")
|
||||||
|
return int(d)
|
||||||
@@ -0,0 +1,6 @@
|
|||||||
|
"""Variant registry. Importing this package triggers @register_config + @register
|
||||||
|
side effects in the variant modules.
|
||||||
|
"""
|
||||||
|
from . import mean_diff # noqa: F401
|
||||||
|
from . import pca # noqa: F401
|
||||||
|
from . import directional_ablation # noqa: F401
|
||||||
@@ -0,0 +1,75 @@
|
|||||||
|
"""Mean-diff directional ablation (Arditi-inspired projection-out).
|
||||||
|
|
||||||
|
Project the steering direction *out of* the residual stream instead of (or in
|
||||||
|
addition to) adding to it. Unlike `mean_diff` which translates by $\\alpha v$,
|
||||||
|
ablation removes the component of $h$ along $\\hat v$:
|
||||||
|
|
||||||
|
$$h \\leftarrow h - (h \\cdot \\hat v)\\hat v + \\alpha\\hat v$$
|
||||||
|
|
||||||
|
When `coeff=0` this is pure ablation (refusal-direction style); when `coeff!=0`
|
||||||
|
this is ablation followed by a constant nudge (useful to ablate "old" behavior
|
||||||
|
and inject "new"). The two terms are mathematically distinct -- ablation is a
|
||||||
|
*projection* (idempotent), addition is a *translation*.
|
||||||
|
|
||||||
|
Norms shrink by $|h \\cdot \\hat v|$ which is informative -- a near-zero shrink
|
||||||
|
means the direction wasn't present in the first place, so the intervention is
|
||||||
|
a no-op. Compare to `mean_diff` which always pays a constant $\\alpha\\|\\hat v\\|$
|
||||||
|
per token regardless of whether the direction is present.
|
||||||
|
|
||||||
|
Refs / inspiration:
|
||||||
|
- Arditi et al. 2024 "Refusal in language models is mediated by a single direction"
|
||||||
|
https://arxiv.org/abs/2406.11717
|
||||||
|
- andyrdt/refusal_direction https://github.com/andyrdt/refusal_direction
|
||||||
|
"""
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from einops import einsum
|
||||||
|
from jaxtyping import Float
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from ..config import SteeringConfig, register_config, register
|
||||||
|
|
||||||
|
|
||||||
|
ε = 1e-8
|
||||||
|
|
||||||
|
|
||||||
|
@register_config
|
||||||
|
@dataclass
|
||||||
|
class DirectionalAblationC(SteeringConfig):
|
||||||
|
method: str = "directional_ablation"
|
||||||
|
coeff: float = 0.0 # post-ablation additive nudge along v_hat; 0.0 = pure ablation
|
||||||
|
|
||||||
|
|
||||||
|
@register
|
||||||
|
class DirectionalAblation:
|
||||||
|
name = "directional_ablation"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def extract(
|
||||||
|
pos_acts: dict[int, Float[Tensor, "n d"]],
|
||||||
|
neg_acts: dict[int, Float[Tensor, "m d"]],
|
||||||
|
cfg: DirectionalAblationC,
|
||||||
|
) -> dict[int, dict[str, Tensor]]:
|
||||||
|
out = {}
|
||||||
|
for li in pos_acts:
|
||||||
|
v = pos_acts[li].float().mean(0) - neg_acts[li].float().mean(0)
|
||||||
|
v = v / (v.norm() + ε)
|
||||||
|
|
||||||
|
out[li] = {"v": v}
|
||||||
|
return out
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def apply(
|
||||||
|
mod,
|
||||||
|
x: Float[Tensor, "b s d"],
|
||||||
|
y: Float[Tensor, "b s d"],
|
||||||
|
state: dict[str, Tensor],
|
||||||
|
cfg: DirectionalAblationC,
|
||||||
|
) -> Float[Tensor, "b s d"]:
|
||||||
|
v = state["v"].to(y) # unit
|
||||||
|
|
||||||
|
proj = einsum(y, v, "b s d, d -> b s")
|
||||||
|
y = y - proj.unsqueeze(-1) * v # ablate
|
||||||
|
|
||||||
|
if cfg.coeff != 0.0:
|
||||||
|
y = y + cfg.coeff * v
|
||||||
|
return y
|
||||||
@@ -0,0 +1,82 @@
|
|||||||
|
"""Mean-difference steering (CAA / ActAdd).
|
||||||
|
|
||||||
|
For each selected layer L, compute the mean difference between positive and
|
||||||
|
negative last-token hidden states:
|
||||||
|
|
||||||
|
$$v_L = \\text{mean}(h^+_L) - \\text{mean}(h^-_L), \\quad \\hat{v}_L = v_L / \\|v_L\\|$$
|
||||||
|
|
||||||
|
At runtime, add `coeff * v_hat` to every token's residual at that block:
|
||||||
|
|
||||||
|
$$h \\leftarrow h + \\alpha \\cdot \\hat{v}_L$$
|
||||||
|
|
||||||
|
This is the same operation as CAA (Panickssery 2023, contrastive MCQ pairs)
|
||||||
|
and ActAdd (Turner 2023, single prompt-pair); the differences are conventional
|
||||||
|
not mathematical, so we register one method.
|
||||||
|
|
||||||
|
`subtract_corpus_mean=True` toggles Jorgensen 2024 mean-centring: target mean
|
||||||
|
minus pos∪neg corpus mean. Direction-identical to plain mean_diff under
|
||||||
|
normalization with equal-size groups; kept as a flag rather than a separate
|
||||||
|
method.
|
||||||
|
|
||||||
|
Refs:
|
||||||
|
- Panickssery 2023 (CAA) https://arxiv.org/abs/2312.06681
|
||||||
|
- Turner 2023 (ActAdd) https://arxiv.org/abs/2308.10248
|
||||||
|
- Jorgensen 2024 (Mean-Centring) https://arxiv.org/abs/2312.03813
|
||||||
|
- nrimsky/CAA https://github.com/nrimsky/CAA
|
||||||
|
"""
|
||||||
|
from dataclasses import dataclass
|
||||||
|
import torch
|
||||||
|
from jaxtyping import Float
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from ..config import SteeringConfig, register_config, register
|
||||||
|
|
||||||
|
|
||||||
|
ε = 1e-8
|
||||||
|
|
||||||
|
|
||||||
|
@register_config
|
||||||
|
@dataclass
|
||||||
|
class MeanDiffC(SteeringConfig):
|
||||||
|
method: str = "mean_diff"
|
||||||
|
normalize: bool = True
|
||||||
|
subtract_corpus_mean: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
@register
|
||||||
|
class MeanDiff:
|
||||||
|
name = "mean_diff"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def extract(
|
||||||
|
pos_acts: dict[int, Float[Tensor, "n d"]],
|
||||||
|
neg_acts: dict[int, Float[Tensor, "m d"]],
|
||||||
|
cfg: MeanDiffC,
|
||||||
|
) -> dict[int, dict[str, Tensor]]:
|
||||||
|
out = {}
|
||||||
|
for li in pos_acts:
|
||||||
|
p = pos_acts[li].float()
|
||||||
|
n = neg_acts[li].float()
|
||||||
|
|
||||||
|
if cfg.subtract_corpus_mean:
|
||||||
|
mu = torch.cat([p, n], dim=0).mean(0)
|
||||||
|
v = p.mean(0) - mu
|
||||||
|
else:
|
||||||
|
v = p.mean(0) - n.mean(0)
|
||||||
|
|
||||||
|
if cfg.normalize:
|
||||||
|
v = v / (v.norm() + ε)
|
||||||
|
|
||||||
|
out[li] = {"v": v}
|
||||||
|
return out
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def apply(
|
||||||
|
mod,
|
||||||
|
x: Float[Tensor, "b s d"],
|
||||||
|
y: Float[Tensor, "b s d"],
|
||||||
|
state: dict[str, Tensor],
|
||||||
|
cfg: MeanDiffC,
|
||||||
|
) -> Float[Tensor, "b s d"]:
|
||||||
|
v = state["v"].to(y)
|
||||||
|
return y + cfg.coeff * v
|
||||||
@@ -0,0 +1,101 @@
|
|||||||
|
"""PCA steering (RepE/LAT-inspired, vgel pca_diff-like).
|
||||||
|
|
||||||
|
For each layer L, compute PCA on the **paired differences** `h^+ - h^-`. Take
|
||||||
|
the top principal component as the steering direction.
|
||||||
|
|
||||||
|
$$D_L = H^+_L - H^-_L \\in \\mathbb{R}^{n\\times d}$$
|
||||||
|
$$U, S, V^T = \\text{SVD}(D_L - \\bar{D}_L)$$
|
||||||
|
$$\\text{sign}_L = \\text{sign}\\left(\\sum_i \\mathbb{1}[(D_L)_i \\cdot V_{:,0} > 0] - n/2\\right)$$
|
||||||
|
$$v_L = V_{:,0} \\cdot \\text{sign}_L$$
|
||||||
|
|
||||||
|
Sign-fixed by majority vote of paired-diff projections (repeng/vgel style).
|
||||||
|
This is a lightweight control-vector baseline, not the full Zou et al. LAT
|
||||||
|
reader: it omits per-diff normalization, label-based sign selection, and
|
||||||
|
train-mean recentering for reading scores.
|
||||||
|
PCA is sign-ambiguous; the vote is more robust than alignment-with-the-mean
|
||||||
|
when paired diffs are heterogeneous (mean can cancel without the vote
|
||||||
|
changing). If the vote ties exactly, orient the axis so the largest centered
|
||||||
|
projection is positive.
|
||||||
|
|
||||||
|
At runtime, add `coeff * v_L` to the residual.
|
||||||
|
|
||||||
|
Refs:
|
||||||
|
- Zou et al. 2023 (Representation Engineering) https://arxiv.org/abs/2310.01405
|
||||||
|
- vgel/repeng: https://github.com/vgel/repeng
|
||||||
|
"""
|
||||||
|
from dataclasses import dataclass
|
||||||
|
import torch
|
||||||
|
from jaxtyping import Float
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from ..config import SteeringConfig, register_config, register
|
||||||
|
|
||||||
|
|
||||||
|
ε = 1e-8
|
||||||
|
|
||||||
|
|
||||||
|
@register_config
|
||||||
|
@dataclass
|
||||||
|
class PCAC(SteeringConfig):
|
||||||
|
method: str = "pca"
|
||||||
|
n_components: int = 1
|
||||||
|
normalize: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
@register
|
||||||
|
class PCA:
|
||||||
|
name = "pca"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def extract(
|
||||||
|
pos_acts: dict[int, Float[Tensor, "n d"]],
|
||||||
|
neg_acts: dict[int, Float[Tensor, "n d"]],
|
||||||
|
cfg: PCAC,
|
||||||
|
) -> dict[int, dict[str, Tensor]]:
|
||||||
|
out = {}
|
||||||
|
for li in pos_acts:
|
||||||
|
if pos_acts[li].shape[0] != neg_acts[li].shape[0]:
|
||||||
|
raise ValueError(f"layer {li}: pos/neg counts differ")
|
||||||
|
|
||||||
|
diffs = (pos_acts[li] - neg_acts[li]).float()
|
||||||
|
centered = diffs - diffs.mean(0, keepdim=True)
|
||||||
|
|
||||||
|
_, _, Vh = torch.linalg.svd(centered, full_matrices=False)
|
||||||
|
v = Vh[: cfg.n_components]
|
||||||
|
|
||||||
|
projs = centered @ v.T
|
||||||
|
positive_frac = (projs > 0).float().mean(0)
|
||||||
|
majority_sign = torch.where(positive_frac > 0.5,
|
||||||
|
torch.ones(v.shape[0]),
|
||||||
|
-torch.ones(v.shape[0])).to(v)
|
||||||
|
strongest_idx = projs.abs().argmax(dim=0)
|
||||||
|
strongest = projs[strongest_idx, torch.arange(v.shape[0], device=projs.device)]
|
||||||
|
strongest_sign = torch.sign(strongest)
|
||||||
|
sign = torch.where(positive_frac == 0.5, strongest_sign, majority_sign)
|
||||||
|
v = v * sign[:, None]
|
||||||
|
|
||||||
|
if cfg.n_components == 1:
|
||||||
|
v = v.squeeze(0)
|
||||||
|
if cfg.normalize:
|
||||||
|
v = v / (v.norm() + ε)
|
||||||
|
out[li] = {"v": v}
|
||||||
|
else:
|
||||||
|
if cfg.normalize:
|
||||||
|
v = v / (v.norm(dim=1, keepdim=True) + ε)
|
||||||
|
out[li] = {"V": v}
|
||||||
|
return out
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def apply(
|
||||||
|
mod,
|
||||||
|
x: Float[Tensor, "b s d"],
|
||||||
|
y: Float[Tensor, "b s d"],
|
||||||
|
state: dict[str, Tensor],
|
||||||
|
cfg: PCAC,
|
||||||
|
) -> Float[Tensor, "b s d"]:
|
||||||
|
if "v" in state:
|
||||||
|
v = state["v"].to(y)
|
||||||
|
return y + cfg.coeff * v
|
||||||
|
# multi-component: sum top-k directions equally
|
||||||
|
V = state["V"].to(y)
|
||||||
|
return y + cfg.coeff * V.sum(0)
|
||||||
@@ -0,0 +1,114 @@
|
|||||||
|
"""Vector: extracted steering vector + config, as a single ergonomic object.
|
||||||
|
|
||||||
|
Wraps `(cfg, state)` so a user can:
|
||||||
|
|
||||||
|
v = Vector.train(model, tok, pos, neg, sl.MeanDiffC(layers=(15,))) \\
|
||||||
|
.calibrate(model, tok, target_kl=1.0)
|
||||||
|
|
||||||
|
with v(model):
|
||||||
|
out = model.generate(...)
|
||||||
|
|
||||||
|
v.save("honesty.safetensors")
|
||||||
|
v2 = Vector.load("honesty.safetensors")
|
||||||
|
|
||||||
|
combined = v + v2 # ensemble (sum buffers, requires same cfg.method)
|
||||||
|
scaled = v * 0.5 # scale buffers
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from copy import deepcopy
|
||||||
|
from dataclasses import replace
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import Tensor, nn
|
||||||
|
|
||||||
|
from .config import SteeringConfig
|
||||||
|
|
||||||
|
|
||||||
|
class Vector:
|
||||||
|
def __init__(self, cfg: SteeringConfig, state: dict[int, dict[str, Tensor]]):
|
||||||
|
self.cfg = cfg
|
||||||
|
self.state = state
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def train(cls, model: nn.Module, tok, pos_prompts: list[str], neg_prompts: list[str],
|
||||||
|
cfg: SteeringConfig, **kw) -> "Vector":
|
||||||
|
"""Extract a steering vector from contrastive prompts. Chains with .calibrate()."""
|
||||||
|
from .attach import train as _train
|
||||||
|
return _train(model, tok, pos_prompts, neg_prompts, cfg, **kw)
|
||||||
|
|
||||||
|
def calibrate(self, model: nn.Module, tok,
|
||||||
|
prompts: list[str] | list[Tensor] | None = None, *,
|
||||||
|
target_kl: float = 1.0, **kw) -> "Vector":
|
||||||
|
"""Set coeff so KL(steered || base) hits target_kl. Mutates and returns self for chaining.
|
||||||
|
|
||||||
|
`prompts` defaults to a small generic set; pass list[str] to use your own.
|
||||||
|
"""
|
||||||
|
from .calibrate import calibrate_iso_kl
|
||||||
|
coeff, _ = calibrate_iso_kl(self, model, tok, prompts, target_kl=target_kl, **kw)
|
||||||
|
self.cfg.coeff = float(coeff)
|
||||||
|
return self
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def __call__(self, model: nn.Module, *, C: float | None = None):
|
||||||
|
"""Attach for the duration of the `with` block. `C` overrides cfg.coeff if given."""
|
||||||
|
from .attach import attach, detach
|
||||||
|
cfg = self.cfg if C is None else replace(self.cfg, coeff=float(C))
|
||||||
|
attach(model, cfg, self.state)
|
||||||
|
try:
|
||||||
|
yield model
|
||||||
|
finally:
|
||||||
|
detach(model)
|
||||||
|
|
||||||
|
def __add__(self, other: "Vector") -> "Vector":
|
||||||
|
if self.cfg.method != other.cfg.method:
|
||||||
|
raise ValueError(f"cannot add {self.cfg.method!r} + {other.cfg.method!r}")
|
||||||
|
if sorted(self.state) != sorted(other.state):
|
||||||
|
raise ValueError(f"layer mismatch: {sorted(self.state)} vs {sorted(other.state)}")
|
||||||
|
new_state: dict[int, dict[str, Tensor]] = {}
|
||||||
|
for li in self.state:
|
||||||
|
a, b = self.state[li], other.state[li]
|
||||||
|
if sorted(a) != sorted(b):
|
||||||
|
raise ValueError(f"layer {li}: state keys differ {sorted(a)} vs {sorted(b)}")
|
||||||
|
new_state[li] = {k: a[k] + b[k] for k in a}
|
||||||
|
return Vector(deepcopy(self.cfg), new_state)
|
||||||
|
|
||||||
|
def __mul__(self, k: float) -> "Vector":
|
||||||
|
new_state = {
|
||||||
|
li: {k_: v * float(k) for k_, v in s.items()}
|
||||||
|
for li, s in self.state.items()
|
||||||
|
}
|
||||||
|
return Vector(deepcopy(self.cfg), new_state)
|
||||||
|
|
||||||
|
__rmul__ = __mul__
|
||||||
|
|
||||||
|
def save(self, path: str) -> None:
|
||||||
|
from .attach import _STATE_PREFIX, _SUB_KEY_PREFIX, _SUB_KEY_SEP # noqa: F401
|
||||||
|
import json
|
||||||
|
from safetensors.torch import save_file
|
||||||
|
sd: dict[str, Tensor] = {}
|
||||||
|
sub_mode = self.cfg.target_submodule is not None
|
||||||
|
for key, s in self.state.items():
|
||||||
|
for k, t in s.items():
|
||||||
|
if sub_mode:
|
||||||
|
sd[f"{_SUB_KEY_PREFIX}{key}{_SUB_KEY_SEP}{k}"] = t.detach().cpu()
|
||||||
|
else:
|
||||||
|
sd[f"layer{key}.{k}"] = t.detach().cpu()
|
||||||
|
metadata = {"cfg": json.dumps(self.cfg.to_dict())}
|
||||||
|
save_file(sd, path, metadata=metadata)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(cls, path: str) -> "Vector":
|
||||||
|
import json
|
||||||
|
from safetensors.torch import load_file, safe_open
|
||||||
|
from .attach import _safetensors_dict_to_state
|
||||||
|
with safe_open(path, framework="pt", device="cpu") as f:
|
||||||
|
metadata = f.metadata()
|
||||||
|
sd = load_file(path, device="cpu")
|
||||||
|
cfg = SteeringConfig.from_dict(json.loads(metadata["cfg"]))
|
||||||
|
state = _safetensors_dict_to_state(sd)
|
||||||
|
return cls(cfg, state)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
layers = sorted(self.state)
|
||||||
|
return f"Vector(method={self.cfg.method!r}, layers={layers})"
|
||||||
@@ -0,0 +1,65 @@
|
|||||||
|
"""Smoke test: train + calibrate + branch_pmass on a tiny random model.
|
||||||
|
|
||||||
|
Pass = runtime sanity. Distinguishing checks:
|
||||||
|
- measure_kl returns kl > 0 at coeff > 0 (steer DID change distribution).
|
||||||
|
- measure_kl returns kl ~= 0 at coeff = 0 (silent failure detector: if hooks
|
||||||
|
leak, base==steer KL would be nonzero).
|
||||||
|
- branch_pmass is in [0, 1].
|
||||||
|
- branch_pmass changes between coeff=0 and coeff=large (sneaky-fail catch:
|
||||||
|
if pmass is just identity-pass-through it would be invariant).
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
import torch
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
|
||||||
|
from iso_kl_figure import (
|
||||||
|
SteeringConfig, MeanDiffC, train, measure_kl, attach, detach,
|
||||||
|
)
|
||||||
|
from iso_kl_figure.branch_pmass import branch_pmass
|
||||||
|
|
||||||
|
|
||||||
|
MODEL = "hf-internal-testing/tiny-random-LlamaForCausalLM"
|
||||||
|
|
||||||
|
|
||||||
|
def test_pipeline_smoke():
|
||||||
|
tok = AutoTokenizer.from_pretrained(MODEL)
|
||||||
|
if tok.pad_token_id is None:
|
||||||
|
tok.pad_token_id = tok.eos_token_id
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(MODEL, torch_dtype=torch.float32)
|
||||||
|
model.eval()
|
||||||
|
n_layers = model.config.num_hidden_layers
|
||||||
|
cfg = MeanDiffC(coeff=1.0, layers=(n_layers // 2,))
|
||||||
|
|
||||||
|
pos = ["Sure: <answer>", "Yes: <answer>", "Of course: <answer>", "Here: <answer>"]
|
||||||
|
neg = ["No way.", "I refuse.", "Cannot help.", "Decline."]
|
||||||
|
v = train(model, tok, pos, neg, cfg, batch_size=2, max_length=32)
|
||||||
|
|
||||||
|
# KL must be ~0 at coeff=0 (no leak), and > 0 at large coeff
|
||||||
|
v.cfg.coeff = 0.0
|
||||||
|
m0 = measure_kl(v, model, tok, ["hello world"], T=4, device="cpu")
|
||||||
|
assert m0["kl_p95"] < 1e-3, f"coeff=0 should give zero KL, got {m0['kl_p95']}"
|
||||||
|
|
||||||
|
v.cfg.coeff = 5.0
|
||||||
|
m1 = measure_kl(v, model, tok, ["hello world"], T=4, device="cpu")
|
||||||
|
assert m1["kl_p95"] > 0.0, "coeff>0 should give nonzero KL"
|
||||||
|
|
||||||
|
# per_t arrays length matches T
|
||||||
|
assert len(m1["per_t_p95"]) == 4
|
||||||
|
assert len(m1["per_t_max"]) == 4
|
||||||
|
|
||||||
|
# branch_pmass: in [0, 1] and varies with coeff
|
||||||
|
pids = tok("hello", return_tensors="pt").input_ids[0]
|
||||||
|
rolled = pids[-2:].clone()
|
||||||
|
suffix = ' {"value": '
|
||||||
|
fork = [0, 1, 2]
|
||||||
|
|
||||||
|
v.cfg.coeff = 0.0
|
||||||
|
p_zero = branch_pmass(v, model, tok, pids, rolled, fork, suffix,
|
||||||
|
["true", "false"], device="cpu")
|
||||||
|
v.cfg.coeff = 5.0
|
||||||
|
p_steer = branch_pmass(v, model, tok, pids, rolled, fork, suffix,
|
||||||
|
["true", "false"], device="cpu")
|
||||||
|
for x in p_zero["pmass"] + p_steer["pmass"]:
|
||||||
|
assert 0.0 <= x <= 1.0, f"pmass out of [0,1]: {x}"
|
||||||
|
diff = max(abs(a - b) for a, b in zip(p_zero["pmass"], p_steer["pmass"]))
|
||||||
|
assert diff > 1e-6, "pmass invariant to coeff -- hook is dead"
|
||||||
Reference in New Issue
Block a user