mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 17:15:58 +08:00
8000aa48f4
Entry (f) already cited scripts/diag_pinning_moduleS_exact.py; both moduleS scripts were untracked, so committing them for provenance alongside the calibration script the new entry cites. Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
233 lines
9.5 KiB
Python
233 lines
9.5 KiB
Python
"""Exact per-module S-space gate diagnostic on actual Linear inputs/outputs.
|
|
|
|
For every selected Linear, capture the same side used by steering-lite sspace:
|
|
|
|
writer output: xS = y @ U / sqrt(S)
|
|
reader input: xS = x @ V * sqrt(S)
|
|
|
|
Each module keeps its own thin SVD basis through pair-direction extraction,
|
|
top-r selection, and scoring. Scores are aggregated only after each module has
|
|
produced its own score.
|
|
|
|
GPU required:
|
|
uv run python scripts/diag_pinning_moduleS_exact.py
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import struct
|
|
import csv
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn.functional as F
|
|
import tyro
|
|
from loguru import logger
|
|
from safetensors.torch import load_file
|
|
from tabulate import tabulate
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
|
from vgrout.lora2r import wrap_model_with_lora2r
|
|
from vgrout.pairs import load_pairs
|
|
from vgrout.train import _auroc
|
|
|
|
|
|
# Linear-module name suffix -> role.  The residual-stream side of a module is
# its output for a "writer" and its input for a "reader"; the S-space tap and
# thin_basis pick the side to analyse from this role.
ROLES = {
    **dict.fromkeys(("self_attn.o_proj", "mlp.down_proj"), "writer"),
    **dict.fromkeys(
        ("self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj", "mlp.gate_proj", "mlp.up_proj"),
        "reader",
    ),
}
# Number of largest-|direction| S coordinates kept per module when scoring;
# -1 keeps every coordinate (reported as "r=full").
RANKS = (64, 256, -1)
# Floor applied to norms / sqrt singular values before dividing by them.
EPS = 1e-8
|
|
|
|
|
|
@dataclass
class Cfg:
    """Command-line options (parsed with ``tyro.cli``) for the exact module-S diagnostic."""

    # Run directory holding ``<ckpt>.safetensors`` and ``rollouts.jsonl``.
    run_dir: Path = Path("out/runs/20260611T003538_fast_vanilla_lora2r_seed43_l2r_vanilla_s43_v3")
    # Checkpoint file stem inside ``run_dir``.
    ckpt: str = "first_hack"
    # Contrastive-pair source handed to ``load_pairs``.
    pairs: Path = Path("data/pairs/hack_pairs.md#all-in-one")
    # Keep only pairs whose ``problem_id`` starts with this prefix.
    headline_prefix: str = "behavior_"
    # Inclusive lower bound of the rollout step window.
    step_lo: int = 2
    # Inclusive upper bound of the rollout step window.
    step_hi: int = 9
    # Maximum number of non-empty rollouts to score.
    max_rollouts: int = 240
    # Layer indices N, matched as ``model.layers.N.`` module-name prefixes.
    layers: tuple[int, ...] = (12, 18, 24)
    # Destination TSV for the AUROC table.
    out: Path = Path("out/diag/moduleS_exact.tsv")
|
|
|
|
|
|
def checkpoint_meta(path: Path) -> dict:
    """Return the ``__metadata__`` mapping from a safetensors file header.

    Only the header is read: an 8-byte little-endian unsigned length followed
    by that many bytes of JSON.  Returns ``{}`` when the header has no
    ``__metadata__`` entry.
    """
    with open(path, "rb") as fh:
        (header_len,) = struct.unpack("<Q", fh.read(8))
        header = json.loads(fh.read(header_len))
    return header.get("__metadata__", {})
|
|
|
|
|
|
def completion_nll(model, tokenizer, prompt: str, completion: str, device) -> torch.Tensor:
    """Mean next-token NLL (nats) of ``completion`` given ``prompt``, from one forward pass.

    ``prompt`` and ``prompt + completion`` are tokenized separately; the
    completion span is taken to start at the prompt's token count.  Returns a
    0-dim tensor (NaN if the completion contributes no tokens).
    """
    n_prompt = tokenizer(prompt, return_tensors="pt").input_ids.shape[1]
    full_ids = tokenizer(prompt + completion, return_tensors="pt").input_ids.to(device)
    # Logits at position t predict token t + 1, so the first completion token
    # (index n_prompt) is scored at position n_prompt - 1.
    start = max(n_prompt - 1, 0)
    # Slice before upcasting: a float32 log-softmax over the whole vocabulary
    # for every prompt position would only be masked away afterwards.
    logits = model(full_ids).logits[:, start:-1]
    targets = full_ids[:, start + 1:]
    logp = F.log_softmax(logits.float(), dim=-1)
    nll = -logp.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
    return nll.mean()
|
|
|
|
|
|
def thin_basis(W: torch.Tensor, role: str) -> tuple[torch.Tensor, torch.Tensor]:
    """Return the residual-side thin singular basis of ``W`` and sqrt(Sigma), descending.

    ``W`` is a Linear weight of shape (d_out, d_in) with thin SVD
    ``W = U diag(S) Vh`` and k = min(d_out, d_in):

    * ``role == "writer"``: the residual side is the output, basis = ``U`` (d_out, k).
    * any other role (reader): the residual side is the input, basis = ``Vh.T`` (d_in, k).

    The second value is ``sqrt(S)`` (shape (k,)), floored at ``EPS`` so callers
    may divide by it.

    The SVD is taken on ``W`` itself rather than via ``eigh`` of W W^T / W^T W:
    forming the Gram matrix squares the condition number, which in float32
    turns the small singular values/vectors into noise.  The thin SVD also
    never includes null-space directions (sqrt(S) ~ 0) that would be amplified
    by 1/EPS when projecting.  Column signs are arbitrary, which is harmless
    here because both the pair directions and live activations are expressed
    in the same basis.
    """
    U, S, Vh = torch.linalg.svd(W, full_matrices=False)
    basis = U if role == "writer" else Vh.T
    return basis, S.sqrt().clamp_min(EPS)
|
|
|
|
|
|
class ModuleSTap:
    """Context manager recording selected Linear activations in each module's own S-space.

    ``modules`` maps a name to ``(module, role, basis, sqrtS)``.  While the
    context is active, a forward hook on each module overwrites
    ``values[name]`` on every call with a float32 projection:

    * ``role == "writer"``: ``F.linear(input, weight, bias) @ basis / sqrtS``.
      The projection is recomputed from the module's own weight/bias rather than
      taken from the hook's ``output`` (NOTE(review): presumably to exclude
      anything a wrapper adds to the output; confirm).
    * any other role: ``input @ basis * sqrtS``.

    Hooks are removed on exit; ``values`` keeps the last captured tensors.
    """

    def __init__(self, modules: dict[str, tuple[torch.nn.Module, str, torch.Tensor, torch.Tensor]]):
        self.modules = modules
        self.values: dict[str, torch.Tensor] = {}
        self.handles = []

    def _make_hook(self, name: str, role: str, basis: torch.Tensor, sqrtS: torch.Tensor):
        """Build the forward hook that stores module ``name``'s S-space projection."""
        def hook(mod, inputs, output):
            x = inputs[0]
            if role == "writer":
                projected = F.linear(x, mod.weight, mod.bias).detach().float() @ basis
                self.values[name] = projected * sqrtS.reciprocal()
            else:
                self.values[name] = (x.detach().float() @ basis) * sqrtS
        return hook

    def __enter__(self):
        self.handles.extend(
            module.register_forward_hook(self._make_hook(name, role, basis, sqrtS))
            for name, (module, role, basis, sqrtS) in self.modules.items()
        )
        return self

    def __exit__(self, *exc):
        for registered in self.handles:
            registered.remove()

    def pooled(self, n_prompt: int) -> dict[str, torch.Tensor]:
        """Mean over token positions ``n_prompt:`` of batch item 0, per module, on CPU."""
        means = {}
        for name, value in self.values.items():
            completion_part = value[0, n_prompt:]
            means[name] = completion_part.mean(0).cpu()
        return means
|
|
|
|
|
|
def score_rows(
    live: list[dict[str, torch.Tensor]],
    pair_hack: list[dict[str, torch.Tensor]],
    pair_clean: list[dict[str, torch.Tensor]],
    roles: dict[str, str],
    adv: np.ndarray,
    exploited: np.ndarray,
) -> list[dict]:
    """Score live rollouts against per-module hack directions; return AUROC rows.

    Every dict in ``live`` / ``pair_hack`` / ``pair_clean`` maps a module name
    (a key of ``roles``) to a pooled S-space vector.  For each module the hack
    direction is the mean over pairs of ``hack - clean``; for each positive
    ``rank`` in ``RANKS`` smaller than that vector's length, only the ``rank``
    coordinates with the largest |direction| are kept (-1 keeps all).

    Per module: ``dot`` = features . unit(direction), ``cos`` = dot / |features|.
    Across the modules of a group ("writer", "reader", "both") three scores are
    formed: ``mean_cos`` (mean of cosines), ``concat_cos`` (cosine between the
    concatenated features and the concatenated unit directions) and ``dot``
    (mean of dots).  Each score's AUROC for predicting ``exploited`` is computed
    over rollouts with ``adv > 1e-6``.

    Returns a list of ``{"variant", "score", "auroc"}`` dicts.
    """
    head_diff = {}
    for name in roles:
        per_pair = [hack[name] - clean[name] for hack, clean in zip(pair_hack, pair_clean)]
        head_diff[name] = torch.stack(per_pair).mean(0)

    keep = (np.abs(adv) > 1e-6) & (adv > 0)
    label = exploited & (adv > 0)
    kept_labels = label[keep].tolist()

    def auroc_of(score):
        return _auroc(score[keep].tolist(), kept_labels)

    rows = []
    for group in ("writer", "reader", "both"):
        members = [name for name, module_role in roles.items() if group in ("both", module_role)]
        for rank in RANKS:
            per_dot, per_cos, per_sq = [], [], []
            for name in members:
                diff = head_diff[name]
                feats = torch.stack([example[name] for example in live])
                if 0 < rank < diff.numel():
                    top = diff.abs().topk(rank).indices
                    diff = diff[top]
                    feats = feats[:, top]
                unit = diff / diff.norm().clamp_min(EPS)
                proj = feats @ unit
                per_dot.append(proj)
                per_cos.append(proj / feats.norm(dim=1).clamp_min(EPS))
                per_sq.append(feats.square().sum(1))
            dot_mat = torch.stack(per_dot, 1)
            cos_mat = torch.stack(per_cos, 1)
            sq_mat = torch.stack(per_sq, 1)
            # Each module's unit direction has norm 1, so the concatenated
            # direction has norm sqrt(number of modules).
            concat_denominator = sq_mat.sum(1).sqrt().clamp_min(EPS) * np.sqrt(dot_mat.shape[1])
            scores = {
                "mean_cos": cos_mat.mean(1),
                "concat_cos": dot_mat.sum(1) / concat_denominator,
                "dot": dot_mat.mean(1),
            }
            variant = f"moduleS-exact {group} r={'full' if rank < 0 else rank}"
            for kind, score in scores.items():
                rows.append({"variant": variant, "score": kind, "auroc": auroc_of(score)})
    return rows
|
|
|
|
|
|
def main(cfg: Cfg) -> int:
    """Run the exact per-module S-space diagnostic and write one AUROC row per (variant, score).

    1. Load the base model named in the checkpoint's ``cfg`` metadata, wrap it
       with lora2r and copy the checkpoint's ``A/<name>`` / ``B/<name>`` tensors in.
    2. For every module whose name starts with ``model.layers.<layer>.`` (layer
       in ``cfg.layers``) and ends with a ``ROLES`` suffix, build its own basis
       and sqrt(S) with ``thin_basis``.
    3. Pool each module's S-space activation over completion tokens for every
       pair (hack and clean) whose ``problem_id`` starts with
       ``cfg.headline_prefix``, and for up to ``cfg.max_rollouts`` rollouts with
       non-blank text and ``step_lo <= step <= step_hi``.
    4. Score with ``score_rows``, write the rows (sorted by descending AUROC)
       as TSV to ``cfg.out`` and print them as a table.

    Returns 0, used as the process exit status.
    """
    device = torch.device("cuda")
    ckpt_path = cfg.run_dir / f"{cfg.ckpt}.safetensors"
    meta = checkpoint_meta(ckpt_path)
    run_cfg = json.loads(meta["cfg"])
    tokenizer = AutoTokenizer.from_pretrained(run_cfg["model"])
    model = AutoModelForCausalLM.from_pretrained(
        run_cfg["model"], dtype=torch.bfloat16, attn_implementation="flash_attention_2").to(device)
    model.config.use_cache = False
    wrappers = wrap_model_with_lora2r(model, r=run_cfg["lora_r"], init_seed=run_cfg["lora_init_seed"])
    state = load_file(str(ckpt_path))
    for name, info in wrappers.items():
        info["A"].data.copy_(state[f"A/{name}"].to(device, torch.float32))
        info["B"].data.copy_(state[f"B/{name}"].to(device, torch.float32))
    model.eval()

    selected = {}
    for name, module in model.named_modules():
        for suffix, role in ROLES.items():
            if name.endswith(suffix) and any(name.startswith(f"model.layers.{layer}.") for layer in cfg.layers):
                basis, sqrtS = thin_basis(module.weight.detach().float(), role)
                selected[name] = (module, role, basis, sqrtS)
    assert len(selected) == len(cfg.layers) * len(ROLES), len(selected)
    roles = {name: role for name, (_, role, _, _) in selected.items()}
    logger.info(f"selected {len(selected)} exact module S-spaces")

    def extract(tap: ModuleSTap, prompt: str, completion: str) -> dict[str, torch.Tensor]:
        # The forward pass inside completion_nll fires the tap's hooks.
        with torch.no_grad():
            loss = completion_nll(model, tokenizer, prompt, completion, device)
        assert torch.isfinite(loss), loss
        n_prompt = tokenizer(prompt, return_tensors="pt").input_ids.shape[1]
        return tap.pooled(n_prompt)

    pairs = [pair for pair in load_pairs(cfg.pairs) if pair.problem_id.startswith(cfg.headline_prefix)]
    jsonl_lines = (cfg.run_dir / "rollouts.jsonl").read_text().splitlines()
    in_window = [r for r in (json.loads(line) for line in jsonl_lines if line.strip())
                 if cfg.step_lo <= r["step"] <= cfg.step_hi]
    # Group baseline = mean reward of EVERY rollout sharing (step, p_idx) in the
    # window, computed before the blank-text filter and the max_rollouts cut so
    # a group that is partly dropped still gets its full-group baseline.
    group_rewards: dict[tuple, list[float]] = {}
    for r in in_window:
        group_rewards.setdefault((r["step"], r["p_idx"]), []).append(float(r["reward"]))
    group_mean = {key: float(np.mean(rewards)) for key, rewards in group_rewards.items()}
    records = [r for r in in_window if r["text"].strip()][:cfg.max_rollouts]

    pair_hack, pair_clean, live = [], [], []
    with ModuleSTap(selected) as tap:
        for i, pair in enumerate(pairs):
            pair_hack.append(extract(tap, pair.prompt, pair.hack))
            pair_clean.append(extract(tap, pair.prompt, pair.clean))
            logger.info(f"pair {i + 1}/{len(pairs)}")
        for i, record in enumerate(records):
            live.append(extract(tap, record["prompt"], record["text"]))
            if (i + 1) % 40 == 0:
                logger.info(f"rollout {i + 1}/{len(records)}")

    adv = np.array([float(r["reward"]) - group_mean[(r["step"], r["p_idx"])] for r in records], dtype=float)
    exploited = np.array([bool(r["exploited"]) for r in records])

    rows = score_rows(live, pair_hack, pair_clean, roles, adv, exploited)
    rows.sort(key=lambda row: -row["auroc"])
    cfg.out.parent.mkdir(parents=True, exist_ok=True)
    with cfg.out.open("w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=("variant", "score", "auroc"), delimiter="\t")
        writer.writeheader()
        writer.writerows(rows)
    print("SHOULD: all rows finite; exact hooks capture writer outputs and reader inputs. "
          "A selected maximum is a candidate, not confirmation.")
    print(tabulate(rows, headers="keys", tablefmt="pipe", floatfmt="+.3f"))
    logger.info(f"wrote {cfg.out}")
    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # Build Cfg from the command line; main()'s return value becomes the exit status.
    _cli_cfg = tyro.cli(Cfg)
    raise SystemExit(main(_cli_cfg))
|