Files
evil_MoE/scripts/diag_pinning_moduleS_exact.py
T
wassname 8000aa48f4 journal(#41): entry (g) routeA shipped + guard-drop calibration; track moduleS diag scripts
Entry (f) already cited scripts/diag_pinning_moduleS_exact.py; both moduleS
scripts were untracked, so committing them for provenance alongside the
calibration script the new entry cites.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
2026-06-11 12:50:07 +00:00

233 lines
9.5 KiB
Python

"""Exact per-module S-space gate diagnostic on actual Linear inputs/outputs.
For every selected Linear with thin SVD W = U diag(S) V^T, capture the same
side used by steering-lite sspace:
writer output: xS = y @ U / sqrt(S)
reader input: xS = x @ V * sqrt(S)
Each module keeps its own thin SVD basis through pair-direction extraction,
top-r selection, and scoring. Scores are aggregated only after each module has
produced its own score.
GPU required:
uv run python scripts/diag_pinning_moduleS_exact.py
"""
from __future__ import annotations
import json
import struct
import csv
from dataclasses import dataclass
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
import tyro
from loguru import logger
from safetensors.torch import load_file
from tabulate import tabulate
from transformers import AutoModelForCausalLM, AutoTokenizer
from vgrout.lora2r import wrap_model_with_lora2r
from vgrout.pairs import load_pairs
from vgrout.train import _auroc
# Module-name suffix -> which side of the Linear is captured in S-space:
# "writer" modules are tapped on their output, "reader" modules on their input.
ROLES = {
    **dict.fromkeys(("self_attn.o_proj", "mlp.down_proj"), "writer"),
    **dict.fromkeys(
        (
            "self_attn.q_proj",
            "self_attn.k_proj",
            "self_attn.v_proj",
            "mlp.gate_proj",
            "mlp.up_proj",
        ),
        "reader",
    ),
}
# Number of top-|diff| S-space coordinates kept per module when scoring.
# A non-positive rank (or one >= the module width) keeps every coordinate;
# -1 is reported as "r=full".
RANKS = (64, 256, -1)
# Floor applied to norms and sqrt(S) before dividing by them.
EPS = 1e-8
@dataclass
class Cfg:
    """Command-line options (parsed by tyro) for the exact per-module S-space diagnostic."""

    # Training run directory; must contain `<ckpt>.safetensors` and `rollouts.jsonl`.
    run_dir: Path = Path("out/runs/20260611T003538_fast_vanilla_lora2r_seed43_l2r_vanilla_s43_v3")
    # Checkpoint stem inside run_dir (".safetensors" is appended).
    ckpt: str = "first_hack"
    # Contrastive pair source handed to vgrout.pairs.load_pairs.
    pairs: Path = Path("data/pairs/hack_pairs.md#all-in-one")
    # Keep only pairs whose problem_id starts with this prefix.
    headline_prefix: str = "behavior_"
    # Inclusive rollout step window [step_lo, step_hi].
    step_lo: int = 2
    step_hi: int = 9
    # Cap on scored rollouts, applied after empty completions are dropped.
    max_rollouts: int = 240
    # Decoder layer indices whose Linear modules are tapped.
    layers: tuple[int, ...] = (12, 18, 24)
    # Output TSV path (parent directories are created).
    out: Path = Path("out/diag/moduleS_exact.tsv")

    def __post_init__(self) -> None:
        """Reject settings that would otherwise fail late or mis-slice silently.

        Raises:
            ValueError: on an empty step window, a non-positive rollout cap
                (a negative one would silently drop the last rollouts via
                ``[:-k]``), or no layers. The empty-window and empty-layer
                cases otherwise only surface as an empty ``torch.stack``
                after every forward pass has already run.
        """
        if self.step_lo > self.step_hi:
            raise ValueError(f"step_lo={self.step_lo} > step_hi={self.step_hi}: empty step window")
        if self.max_rollouts <= 0:
            raise ValueError(f"max_rollouts must be positive, got {self.max_rollouts}")
        if not self.layers:
            raise ValueError("layers must name at least one decoder layer")
def checkpoint_meta(path: Path) -> dict:
    """Return the ``__metadata__`` mapping stored in a safetensors header.

    Reads only the header: an 8-byte little-endian unsigned length
    (struct format ``"<Q"``), then that many bytes of JSON. Returns ``{}``
    when the header has no ``__metadata__`` key.
    """
    with open(path, "rb") as handle:
        (header_len,) = struct.unpack("<Q", handle.read(8))
        header = json.loads(handle.read(header_len))
    return header.get("__metadata__", {})
def completion_nll(model, tokenizer, prompt: str, completion: str, device) -> torch.Tensor:
    """Return the mean per-token NLL of ``completion`` given ``prompt`` (0-dim tensor).

    ``prompt`` and ``prompt + completion`` are tokenized separately. Tokens of
    the joint encoding at index >= len(prompt tokens) are scored.
    NOTE(review): assumes the prompt's tokenization is a prefix of the joint
    one (no token merge across the boundary) -- confirm for the tokenizer used.

    Raises:
        ValueError: if the joint encoding has no tokens past the prompt
            (e.g. an empty completion). The masked mean would otherwise be
            0/0 = NaN.
    """
    prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    full_ids = tokenizer(prompt + completion, return_tensors="pt").input_ids.to(device)
    # logits[:, t] predicts full_ids[:, t + 1], so the first completion token
    # (index n_prompt) is predicted at position n_prompt - 1.
    start = max(prompt_ids.shape[1] - 1, 0)
    if full_ids.shape[1] <= start + 1:
        raise ValueError(
            f"no completion tokens to score: prompt has {prompt_ids.shape[1]} tokens, "
            f"prompt+completion has {full_ids.shape[1]}"
        )
    # Slice before log_softmax so only scored positions are upcast to float32.
    logits = model(full_ids).logits[:, start:-1]
    targets = full_ids[:, start + 1:]
    logp = F.log_softmax(logits.float(), dim=-1)
    nll = -logp.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
    return nll.mean()
def thin_basis(W: torch.Tensor, role: str) -> tuple[torch.Tensor, torch.Tensor]:
    """Return residual-side thin basis and sqrt(Sigma), descending.

    With the thin SVD ``W = U diag(S) V^T`` (k = min(d_out, d_in) components,
    S descending):

    * ``role == "writer"`` -> ``(U, sqrt(S))``; U has shape (d_out, k), the
      output side of W.
    * any other role (reader) -> ``(V, sqrt(S))``; V has shape (d_in, k), the
      input side of W.

    sqrt(S) is floored at EPS because writer scaling divides by it.
    """
    # SVD of W itself rather than eigh of W W^T / W^T W, for two reasons.
    # The Gram matrix squares the condition number, so small singular values
    # (which writers divide by) lose precision. For a writer with
    # d_out > d_in, eigh(W W^T) also returns null-space eigenvectors; their
    # sqrt(S) is clamped to EPS, which inflates those coordinates by 1/EPS.
    U, S, Vh = torch.linalg.svd(W, full_matrices=False)
    basis = U if role == "writer" else Vh.T
    return basis, S.sqrt().clamp_min(EPS)
class ModuleSTap:
    """Context manager that records each module's activation projected into its S-space.

    ``modules`` maps a name to ``(module, role, basis, sqrtS)``. While the tap
    is active, a forward hook on each module stores under that name:

    * writer: ``(F.linear(input, weight, bias) @ basis) / sqrtS``
    * reader: ``(input @ basis) * sqrtS``

    Only the latest forward pass per module is kept (detached, float32).
    """

    def __init__(self, modules: dict[str, tuple[torch.nn.Module, str, torch.Tensor, torch.Tensor]]):
        self.modules = modules
        self.values: dict[str, torch.Tensor] = {}
        self.handles = []

    def _make_hook(self, name: str, role: str, basis: torch.Tensor, sqrtS: torch.Tensor):
        """Build the forward hook for one module; the scale is computed once, not per call."""
        is_writer = role == "writer"
        scale = sqrtS.reciprocal() if is_writer else sqrtS

        def hook(module, args, output):
            # Writers recompute the output from the input with the module's own
            # weight/bias instead of using `output`. NOTE(review): presumably so
            # anything a wrapper adds to `output` is excluded -- confirm.
            h = F.linear(args[0], module.weight, module.bias) if is_writer else args[0]
            self.values[name] = (h.detach().float() @ basis) * scale

        return hook

    def _remove_hooks(self) -> None:
        """Remove every registered hook and forget the handles."""
        for handle in self.handles:
            handle.remove()
        # Clear so re-entering the tap does not accumulate dead handles.
        self.handles.clear()

    def __enter__(self):
        # Drop activations from an earlier session so pooled() cannot return stale data.
        self.values.clear()
        try:
            for name, (module, role, basis, sqrtS) in self.modules.items():
                self.handles.append(module.register_forward_hook(self._make_hook(name, role, basis, sqrtS)))
        except BaseException:
            # __exit__ is not called when __enter__ raises; undo partial registration here.
            self._remove_hooks()
            raise
        return self

    def __exit__(self, *exc):
        self._remove_hooks()

    def pooled(self, n_prompt: int) -> dict[str, torch.Tensor]:
        """Per module: mean over positions >= n_prompt of batch item 0, moved to CPU."""
        return {name: value[0, n_prompt:].mean(0).cpu() for name, value in self.values.items()}
def score_rows(
    live: list[dict[str, torch.Tensor]],
    pair_hack: list[dict[str, torch.Tensor]],
    pair_clean: list[dict[str, torch.Tensor]],
    roles: dict[str, str],
    adv: np.ndarray,
    exploited: np.ndarray,
) -> list[dict]:
    """Return AUROC rows for per-module S-space probe scores.

    Each module's probe direction is the mean over pairs of its pooled
    (hack - clean) vector. For a positive rank below the module width, only
    the ``rank`` coordinates with the largest |mean diff| are kept. Live
    examples are projected per module, then combined three ways:

    * ``mean_cos``   -- mean over modules of cos(x, direction)
    * ``concat_cos`` -- cosine of concatenated x vs concatenated unit directions
    * ``dot``        -- mean over modules of the projection x . direction

    One row ``{"variant", "score", "auroc"}`` per role group
    (writer / reader / both) x rank in RANKS x score kind. AUROC is computed
    on examples with ``adv > 1e-6`` against the label ``exploited & (adv > 0)``.
    """
    mean_diff = {}
    for name in roles:
        per_pair = [hack[name] - clean[name] for hack, clean in zip(pair_hack, pair_clean)]
        mean_diff[name] = torch.stack(per_pair).mean(0)
    # Rank/role-invariant: stack each module's live activations once.
    live_stack = {name: torch.stack([example[name] for example in live]) for name in roles}

    # NOTE(review): `pos` already implies adv > 0, so the label's `adv > 0`
    # term is redundant on the scored subset. If "nonzero advantage" was
    # intended for `pos`, drop its `adv > 0` -- confirm against the analysis plan.
    pos = (np.abs(adv) > 1e-6) & (adv > 0)
    label = exploited & (adv > 0)

    def auroc(score):
        return _auroc(score[pos].tolist(), label[pos].tolist())

    rows = []
    for group in ("writer", "reader", "both"):
        members = [name for name, module_role in roles.items() if group == "both" or module_role == group]
        for rank in RANKS:
            dot_cols, cos_cols, sq_cols = [], [], []
            for name in members:
                diff, acts = mean_diff[name], live_stack[name]
                if 0 < rank < diff.numel():
                    keep = diff.abs().topk(rank).indices
                    diff, acts = diff[keep], acts[:, keep]
                unit = diff / diff.norm().clamp_min(EPS)
                proj = acts @ unit
                dot_cols.append(proj)
                cos_cols.append(proj / acts.norm(dim=1).clamp_min(EPS))
                sq_cols.append(acts.square().sum(1))
            dots = torch.stack(dot_cols, 1)
            coses = torch.stack(cos_cols, 1)
            norm2 = torch.stack(sq_cols, 1)
            n_modules = dots.shape[1]
            # The concatenated unit directions have norm sqrt(n_modules).
            concat_norm = norm2.sum(1).sqrt().clamp_min(EPS) * np.sqrt(n_modules)
            scores = {
                "mean_cos": coses.mean(1),
                "concat_cos": dots.sum(1) / concat_norm,
                "dot": dots.mean(1),
            }
            rank_label = str(rank) if rank >= 0 else "full"
            variant = f"moduleS-exact {group} r={rank_label}"
            for kind, score in scores.items():
                rows.append({"variant": variant, "score": kind, "auroc": auroc(score)})
    return rows
def _load_rollouts(cfg: Cfg) -> tuple[list[dict], np.ndarray]:
    """Return the rollouts to score and their group-relative advantages.

    Reads ``cfg.run_dir / "rollouts.jsonl"`` (blank lines skipped) and keeps
    records with ``step_lo <= step <= step_hi``. A rollout's advantage is its
    reward minus the mean reward of its (step, p_idx) group. The mean is taken
    over the whole group BEFORE empty completions are dropped and before
    truncating to ``max_rollouts``. Otherwise a group cut by the truncation,
    or with empty siblings removed, gets a shifted baseline.

    Raises:
        ValueError: if no non-empty rollout falls in the step window.
    """
    lines = (cfg.run_dir / "rollouts.jsonl").read_text().splitlines()
    records = [json.loads(line) for line in lines if line.strip()]
    records = [r for r in records if cfg.step_lo <= r["step"] <= cfg.step_hi]
    reward = np.array([float(r["reward"]) for r in records])
    keys = [(r["step"], r["p_idx"]) for r in records]
    group_rewards: dict[tuple, list[float]] = {}
    for key, value in zip(keys, reward.tolist()):
        group_rewards.setdefault(key, []).append(value)
    baseline = {key: np.mean(values) for key, values in group_rewards.items()}
    adv = reward - np.array([baseline[key] for key in keys], dtype=float)
    keep = [i for i, r in enumerate(records) if r["text"].strip()][: cfg.max_rollouts]
    if not keep:
        raise ValueError(
            f"no non-empty rollouts with step in [{cfg.step_lo}, {cfg.step_hi}] in {cfg.run_dir}"
        )
    return [records[i] for i in keep], adv[keep]


def _load_model(cfg: Cfg, device: torch.device):
    """Load the run's base model (bf16) with the LoRA2r A/B weights from cfg.ckpt.

    Returns ``(model, tokenizer)`` with the model in eval mode and use_cache off.
    """
    ckpt_path = cfg.run_dir / f"{cfg.ckpt}.safetensors"
    run_cfg = json.loads(checkpoint_meta(ckpt_path)["cfg"])
    tokenizer = AutoTokenizer.from_pretrained(run_cfg["model"])
    model = AutoModelForCausalLM.from_pretrained(
        run_cfg["model"], dtype=torch.bfloat16, attn_implementation="flash_attention_2").to(device)
    model.config.use_cache = False
    wrappers = wrap_model_with_lora2r(model, r=run_cfg["lora_r"], init_seed=run_cfg["lora_init_seed"])
    state = load_file(str(ckpt_path))
    for name, info in wrappers.items():
        info["A"].data.copy_(state[f"A/{name}"].to(device, torch.float32))
        info["B"].data.copy_(state[f"B/{name}"].to(device, torch.float32))
    model.eval()
    return model, tokenizer


def _select_modules(model, layers: tuple[int, ...]) -> dict[str, tuple[torch.nn.Module, str, torch.Tensor, torch.Tensor]]:
    """Map every ROLES module in the given decoder layers to (module, role, basis, sqrtS)."""
    layer_prefixes = tuple(f"model.layers.{layer}." for layer in layers)
    selected = {}
    for name, module in model.named_modules():
        if not name.startswith(layer_prefixes):
            continue
        for suffix, role in ROLES.items():
            if name.endswith(suffix):
                basis, sqrtS = thin_basis(module.weight.detach().float(), role)
                selected[name] = (module, role, basis, sqrtS)
    assert len(selected) == len(layers) * len(ROLES), len(selected)
    return selected


def main(cfg: Cfg) -> int:
    """Score exact per-module S-space probe directions on live rollouts.

    Loads rollouts and pairs first (cheap, so bad inputs fail before the model
    load). It then loads the LoRA2r model and taps the ROLES modules of
    cfg.layers. Pooled S-space activations are collected for both sides of
    every pair and for every rollout. AUROC rows from score_rows are written to
    cfg.out as TSV and printed as a table. Returns 0.
    """
    device = torch.device("cuda")
    records, adv = _load_rollouts(cfg)
    exploited = np.array([bool(r["exploited"]) for r in records])
    pairs = [pair for pair in load_pairs(cfg.pairs) if pair.problem_id.startswith(cfg.headline_prefix)]
    if not pairs:
        raise ValueError(f"no pairs in {cfg.pairs} with problem_id prefix {cfg.headline_prefix!r}")

    model, tokenizer = _load_model(cfg, device)
    selected = _select_modules(model, cfg.layers)
    roles = {name: role for name, (_, role, _, _) in selected.items()}
    logger.info(f"selected {len(selected)} exact module S-spaces")

    def extract(tap: ModuleSTap, prompt: str, completion: str) -> dict[str, torch.Tensor]:
        # The forward pass inside completion_nll is what fires the tap's hooks.
        with torch.no_grad():
            loss = completion_nll(model, tokenizer, prompt, completion, device)
        assert torch.isfinite(loss), loss
        n_prompt = tokenizer(prompt, return_tensors="pt").input_ids.shape[1]
        return tap.pooled(n_prompt)

    pair_hack, pair_clean, live = [], [], []
    with ModuleSTap(selected) as tap:
        for i, pair in enumerate(pairs):
            pair_hack.append(extract(tap, pair.prompt, pair.hack))
            pair_clean.append(extract(tap, pair.prompt, pair.clean))
            logger.info(f"pair {i + 1}/{len(pairs)}")
        for i, record in enumerate(records):
            live.append(extract(tap, record["prompt"], record["text"]))
            if (i + 1) % 40 == 0:
                logger.info(f"rollout {i + 1}/{len(records)}")

    rows = score_rows(live, pair_hack, pair_clean, roles, adv, exploited)
    # A bare -auroc key misorders the whole list if any AUROC is NaN
    # (NaN compares False both ways); sort NaNs last instead.
    rows.sort(key=lambda row: (bool(np.isnan(row["auroc"])), -row["auroc"]))
    cfg.out.parent.mkdir(parents=True, exist_ok=True)
    with cfg.out.open("w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=("variant", "score", "auroc"), delimiter="\t")
        writer.writeheader()
        writer.writerows(rows)
    print("SHOULD: all rows finite; exact hooks capture writer outputs and reader inputs. "
          "A selected maximum is a candidate, not confirmation.")
    print(tabulate(rows, headers="keys", tablefmt="pipe", floatfmt="+.3f"))
    logger.info(f"wrote {cfg.out}")
    return 0


if __name__ == "__main__":
    raise SystemExit(main(tyro.cli(Cfg)))