mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 18:04:59 +08:00
8000aa48f4
Entry (f) already cited scripts/diag_pinning_moduleS_exact.py; both moduleS scripts were untracked, so committing them for provenance alongside the calibration script the new entry cites. Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
165 lines
7.2 KiB
Python
165 lines
7.2 KiB
Python
"""Per-module residual S-space gate score.

Unlike Super-S, this keeps one residual-side SVD basis per Linear through
projection, pair-direction extraction, mode selection, and scoring. Scores are
aggregated only after each module has produced its own score.

This is an offline residual-stream diagnostic, not exact steering-lite sspace:
the cache contains block residuals at layers 12/18/24, not each Linear's own
input/output activations.

    writer (o_proj/down_proj): G_m  = W_m W_m^T = U_m S_m^2 U_m^T
                               xS_m = h @ U_m / sqrt(S_m)
    reader (q/k/v/gate/up):    G_m  = W_m^T W_m = V_m S_m^2 V_m^T
                               xS_m = h @ V_m * sqrt(S_m)

    uv run python scripts/diag_pinning_moduleS.py
"""
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
from glob import glob
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import torch
|
|
from safetensors import safe_open
|
|
from tabulate import tabulate
|
|
|
|
from vgrout.train import _auroc
|
|
|
|
|
|
# Project root; every run directory below is resolved relative to it.
ROOT = Path("/workspace/projected_grpo")

# Run tag -> directory that main() reads "pinning_feats.pt" from.
RUNS = {
    "v3": ROOT / "out/diag",
    "v4": ROOT / "out/diag_v4",
    "v5": ROOT / "out/diag_v5",
}

# Pair ids starting with this prefix are the ones main() keeps.
HEAD_PREFIX = "behavior_"

# Layer numbers; position in this tuple is the residual index used by
# model_weights() and main().
HOOKED_LAYERS = (12, 18, 24)

# Module name suffixes, grouped by the role passed to residual_basis().
WRITERS = ("self_attn.o_proj", "mlp.down_proj")
READERS = (
    "self_attn.q_proj",
    "self_attn.k_proj",
    "self_attn.v_proj",
    "mlp.gate_proj",
    "mlp.up_proj",
)

# Number of S-space modes kept per module; a non-positive value keeps all.
RANKS = (64, 256, -1)

# Floor used wherever a norm or scale appears in a denominator.
EPS = 1e-8
|
|
|
|
|
|
def model_weights():
    """Yield the weight matrix of every hooked writer/reader Linear.

    The first directory matching the hard-coded Qwen3-4B snapshot pattern is
    used (glob order is unspecified if several snapshots exist). For each
    entry of HOOKED_LAYERS the WRITERS are read from that layer and the
    READERS from the following layer (presumably because the cached residual
    is the next block's input; confirm against the caching code).

    Yields:
        Tuples ``(residual_idx, module_layer, role, module, weight)`` where
        ``residual_idx`` is the position in HOOKED_LAYERS, ``role`` is
        ``"writer"`` or ``"reader"``, ``module`` is the module name suffix and
        ``weight`` is the tensor converted with ``.float()``.

    Raises:
        FileNotFoundError: if no snapshot directory matches the pattern.
    """
    from contextlib import ExitStack

    pattern = "/workspace/.hf_home/hub/models--Qwen--Qwen3-4B/snapshots/*"
    snapshots = glob(pattern)
    if not snapshots:
        # Previously surfaced as a bare IndexError from ``glob(...)[0]``.
        raise FileNotFoundError(f"no Qwen3-4B snapshot found matching {pattern}")
    snap = Path(snapshots[0])
    wmap = json.loads((snap / "model.safetensors.index.json").read_text())["weight_map"]

    # Shard handles are closed when the generator finishes or is closed,
    # instead of staying open until garbage collection.
    with ExitStack() as stack:
        handles = {
            shard: stack.enter_context(safe_open(snap / shard, framework="pt"))
            for shard in set(wmap.values())
        }
        for residual_idx, layer_idx in enumerate(HOOKED_LAYERS):
            for role, modules in (("writer", WRITERS), ("reader", READERS)):
                module_layer = layer_idx if role == "writer" else layer_idx + 1
                for module in modules:
                    key = f"model.layers.{module_layer}.{module}.weight"
                    weight = handles[wmap[key]].get_tensor(key).float()
                    yield residual_idx, module_layer, role, module, weight
|
|
|
|
|
|
def residual_basis(W: torch.Tensor, role: str) -> tuple[torch.Tensor, torch.Tensor]:
    """Eigendecompose the residual-side Gram matrix of ``W``.

    Args:
        W: weight matrix of one Linear.
        role: ``"writer"`` uses ``W @ W.T``; any other value uses ``W.T @ W``.

    Returns:
        ``(basis, sqrtS)``: eigenvectors as columns and the fourth root of the
        eigenvalues (floored at EPS), both ordered from largest eigenvalue to
        smallest.
    """
    if role == "writer":
        gram = W @ W.T
    else:
        gram = W.T @ W

    evals, evecs = torch.linalg.eigh(gram)
    # Reverse both outputs so index 0 is the leading mode.
    evecs_desc = evecs.flip(1)
    evals_desc = evals.flip(0)

    # Gram eigenvalues are squared singular values, hence the 0.25 power.
    # Negative round-off is clipped to zero before the root, then floored.
    sqrt_sigma = evals_desc.clamp_min(0).pow(0.25).clamp_min(EPS)
    return evecs_desc, sqrt_sigma
|
|
|
|
|
|
def module_stats(
    X: torch.Tensor,
    Ph: torch.Tensor,
    Pc: torch.Tensor,
    basis: torch.Tensor,
    sqrtS: torch.Tensor,
    role: str,
    rank: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Score examples against the pair direction inside one module's S-space.

    Args:
        X: per-example vectors to score (rows are examples).
        Ph: "hack" pair vectors; Pc: matching "clean" pair vectors.
        basis: column basis from residual_basis().
        sqrtS: per-mode scale from residual_basis().
        role: ``"reader"`` multiplies coordinates by ``sqrtS``; any other
            value divides by it.
        rank: if positive, keep only the ``rank`` modes with the largest
            absolute mean pair difference; otherwise keep every mode.

    Returns:
        ``(dot, cos, norm2)`` per example: projection onto the unit pair
        direction, that projection divided by the example's S-space norm, and
        the squared S-space norm over the kept modes.
    """
    if role == "reader":
        scale = sqrtS
    else:
        scale = sqrtS.reciprocal()

    def to_sspace(h: torch.Tensor) -> torch.Tensor:
        # Rotate into the module basis, then apply the per-mode scale.
        return (h @ basis) * scale

    XS = to_sspace(X)
    dS = (to_sspace(Ph) - to_sspace(Pc)).mean(0)

    if rank > 0:
        keep = dS.abs().topk(rank).indices
        XS = XS[:, keep]
        dS = dS[keep]

    direction = dS / dS.norm().clamp_min(EPS)
    dot = XS @ direction
    norm2 = XS.square().sum(1)
    cos = dot / norm2.sqrt().clamp_min(EPS)
    return dot, cos, norm2
|
|
|
|
|
|
def main() -> int:
    """Score every hooked module in its own S-space and print an AUROC table.

    Loads ``pinning_feats.pt`` for each run in RUNS, scores the cached
    residuals against the hack-minus-clean pair direction per module and per
    rank, aggregates module scores (mean or concatenation, cosine or dot),
    and prints one AUROC per run plus mean and min, sorted by min descending.
    A raw-residual baseline is included as a harness check.

    Returns:
        0, which the entry guard uses as the exit status.
    """
    # NOTE: weights_only=False unpickles arbitrary objects; only point RUNS at
    # trusted files.
    loaded = {}
    for tag, run_dir in RUNS.items():
        loaded[tag] = torch.load(run_dir / "pinning_feats.pt", weights_only=False)

    prepared = {}
    for tag, raw in loaded.items():
        # Keep only pairs whose id carries the head prefix.
        head_rows = [
            row for row, pair_id in enumerate(raw["pair_ids"])
            if pair_id.startswith(HEAD_PREFIX)
        ]
        prepared[tag] = {
            "X": raw["RES"].float(),
            "Ph": raw["pair_feats"][("resid", "hack")][head_rows].float(),
            "Pc": raw["pair_feats"][("resid", "clean")][head_rows].float(),
            # Evaluation subset: non-negligible, positive advantage.
            "pos": (np.abs(raw["adv"]) > 1e-6) & (raw["adv"] > 0),
            # Label: exploited with positive advantage.
            "y": raw["exploited"] & (raw["adv"] > 0),
        }

    # One list of per-module score vectors for every (run, role, rank).
    collected = {}
    for tag in RUNS:
        for role in ("writer", "reader"):
            for rank in RANKS:
                collected[(tag, role, rank)] = {"dot": [], "cos": [], "norm2": []}

    print("building and scoring 21 independent residual-side module SVD bases (CPU)...")
    for residual_idx, layer_idx, role, module, W in model_weights():
        print(f" residual={HOOKED_LAYERS[residual_idx]:02d} module_layer={layer_idx:02d} "
              f"role={role:6s} module={module}")
        basis, sqrtS = residual_basis(W, role)
        for tag, data in prepared.items():
            X = data["X"][:, residual_idx]
            Ph = data["Ph"][:, residual_idx]
            Pc = data["Pc"][:, residual_idx]
            for rank in RANKS:
                dot, cos, norm2 = module_stats(X, Ph, Pc, basis, sqrtS, role, rank)
                bucket = collected[(tag, role, rank)]
                bucket["dot"].append(dot)
                bucket["cos"].append(cos)
                bucket["norm2"].append(norm2)

    rows = {}

    def record(tag, variant, aggregate, kind, value):
        # Create the table row on first sight, then fill in this run's column.
        row = rows.setdefault(
            (variant, aggregate, kind),
            {"variant": variant, "aggregate": aggregate, "score": kind},
        )
        row[tag] = value

    for tag, data in prepared.items():

        def au(score):
            # AUROC restricted to the "pos" subset of this run.
            return _auroc(score[data["pos"]].tolist(), data["y"][data["pos"]].tolist())

        # Baseline: pair direction taken directly in the raw residuals.
        raw_d = (data["Ph"] - data["Pc"]).mean(0)
        raw_v = raw_d / raw_d.norm(dim=-1, keepdim=True).clamp_min(EPS)
        raw_dot = torch.einsum("nld,ld->n", data["X"], raw_v)
        raw_scale = (data["X"].flatten(1).norm(dim=1).clamp_min(EPS)
                     * raw_v.flatten().norm())
        raw_cos = raw_dot / raw_scale
        record(tag, "raw resid", "concat", "cos", au(raw_cos))
        record(tag, "raw resid", "concat", "dot", au(raw_dot))

        for role in ("writer", "reader", "both"):
            members = ("writer", "reader") if role == "both" else (role,)
            for rank in RANKS:
                stats = [collected[(tag, member, rank)] for member in members]
                # Columns are modules; rows are examples.
                dots, coses, norm2 = (
                    torch.stack([x for stat in stats for x in stat[kind]], 1)
                    for kind in ("dot", "cos", "norm2")
                )
                # Cosine against the concatenation of the unit module
                # directions, whose norm is sqrt(number of modules).
                concat_cos = dots.sum(1) / (
                    norm2.sum(1).sqrt().clamp_min(EPS) * np.sqrt(dots.shape[1]))

                rank_name = "full" if rank < 0 else str(rank)
                variant = f"moduleS {role} r={rank_name}"
                record(tag, variant, "mean", "cos", au(coses.mean(1)))
                record(tag, variant, "mean", "dot", au(dots.mean(1)))
                record(tag, variant, "concat", "cos", au(concat_cos))
                record(tag, variant, "concat", "dot", au(dots.sum(1)))

    out = list(rows.values())
    for row in out:
        per_run = [row[tag] for tag in RUNS]
        row["mean"] = float(np.mean(per_run))
        row["min"] = float(np.min(per_run))
    # Best worst-case run first.
    out.sort(key=lambda row: -row["min"])

    cols = ["variant", "aggregate", "score", "v3", "v4", "v5", "mean", "min"]
    print("\nAUROC on the A>0 contrast:")
    print("SHOULD: raw resid cos == 0.916/0.700/0.804 and dot == "
          "0.905/0.721/0.756 ELSE this harness disagrees with diag_pinning_superS. "
          "Per-module bases must remain separate until score aggregation.")
    table = [{c: row[c] for c in cols} for row in out]
    print(tabulate(table, headers="keys", tablefmt="pipe", floatfmt="+.3f"))
    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the diagnostic and hand main()'s return value to the interpreter
    # as the exit status.
    exit_status = main()
    raise SystemExit(exit_status)
|