mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 18:04:59 +08:00
1d4f33ffb6
superS (pooled writer/reader eigenbasis, whitened + top-r) tops out at min-window AUROC 0.721 = raw resid dot; best unwhitened rotation+top-64 0.740 < act 0.747 (max of ~50-variant grid). act t-stat extraction also null (0.719 vs 0.749 min). Spec updated: act_dot default, journal-(d) evidence table, implementation plan for routeA. Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
136 lines
6.3 KiB
Python
136 lines
6.3 KiB
Python
"""Super-S-space gate score: project the residual stream onto the pooled SVD basis
of the residual writers/readers before extracting v and scoring.

Source idea: wassname/steering-lite variants/super_sspace.py. Pool the residual-side
singular structure of the block Linears into a Gram matrix and eigendecompose:

    writer (d_out = d_model, o_proj/down_proj):  G += W W^T   (= U S^2 U^T)
    reader (d_in = d_model, q/k/v/gate/up):      G += W^T W   (= V S^2 V^T)
    eigh(G) -> U_star [d, d], Sigma_star = sqrt(lambda)

Whitened coordinates xS = h @ U_star / sqrt(Sigma_star). At full rank WITHOUT the
whitening this is a pure rotation (cos/dot unchanged), so the testable content is
(a) the 1/sqrt(Sigma) reweighting -- a Mahalanobis metric in the writer/reader
spectrum -- and (b) top-r mode selection by |dS| from the authored-pair diff
(pair-only info, oracle-free).

Offline from cached pinning_feats.pt (RES [N, L, 2560]) + safetensors weights, no
forward pass, no GPU. Evaluation = the A>0 contrast AUROC (exploited & adv>0 vs
rest among adv>0), same as diag_pinning.py.

    uv run python scripts/diag_pinning_superS.py
"""
|
|
from __future__ import annotations
|
|
import json
|
|
from glob import glob
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import torch
|
|
from safetensors import safe_open
|
|
from tabulate import tabulate
|
|
|
|
from vgrout.train import _auroc
|
|
|
|
# Project root; the per-run diagnostic directories below live under it.
ROOT = Path("/workspace/projected_grpo")
# Run tag -> directory holding that run's cached pinning_feats.pt (read in main()).
RUNS = {"v3": ROOT / "out/diag", "v4": ROOT / "out/diag_v4", "v5": ROOT / "out/diag_v5"}
# Only authored pairs whose id starts with this prefix feed the direction extraction.
HEAD_PREFIX = "behavior_"
# Decoder layers whose weights are pooled into the "hooked" Gram matrices.
# NOTE(review): presumably also the layers on the L axis of RES -- confirm.
HOOKED_LAYERS = (12, 18, 24)
# Modules treated as residual writers: load_grams() accumulates W @ W.T for them.
WRITERS = ("self_attn.o_proj", "mlp.down_proj")
# Modules treated as residual readers: load_grams() accumulates W.T @ W for them.
READERS = ("self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj",
           "mlp.gate_proj", "mlp.up_proj")
# Floor applied when clamping norms / spectrum roots, to avoid division by zero.
EPS = 1e-8
|
|
|
|
|
|
def load_grams(n_layers: int = 36, d: int = 2560) -> dict[tuple[str, str], torch.Tensor]:
    """{(role, blocks): G [d, d]} for role in {writer, reader}, blocks in {hooked, all}.

    Reads ``model.layers.{li}.{module}.weight`` for every layer in
    ``range(n_layers)`` and every module in WRITERS / READERS from the first
    Qwen3-4B snapshot directory matched by the glob, casts it to float32 and
    accumulates its Gram matrix:

        writer: G += W @ W.T        reader: G += W.T @ W

    The "all" entry sums over every layer; the "hooked" entry only over
    HOOKED_LAYERS.

    Args:
        n_layers: number of decoder layers to iterate over.
        d: side length of the zero-initialised accumulators; each Gram matrix
            added to them must therefore be [d, d].

    Returns:
        Dict with the four keys (writer|reader, hooked|all), each a [d, d] tensor.

    Raises:
        FileNotFoundError: if no snapshot directory matches the glob pattern.
        KeyError: if an expected weight name is missing from the index file.
    """
    # Function-scope import: only this loader needs it.
    from contextlib import ExitStack

    pattern = "/workspace/.hf_home/hub/models--Qwen--Qwen3-4B/snapshots/*"
    snapshots = glob(pattern)
    if not snapshots:
        raise FileNotFoundError(f"no Qwen3-4B snapshot directory matches {pattern}")
    snap = Path(snapshots[0])
    wmap = json.loads((snap / "model.safetensors.index.json").read_text())["weight_map"]

    G = {(role, blk): torch.zeros(d, d)
         for role in ("writer", "reader") for blk in ("hooked", "all")}

    # ExitStack closes every shard handle on exit, including on error; the
    # original opened them and never released them.
    with ExitStack() as stack:
        handles = {f: stack.enter_context(safe_open(snap / f, framework="pt"))
                   for f in set(wmap.values())}
        for li in range(n_layers):
            for role, mods in (("writer", WRITERS), ("reader", READERS)):
                for m in mods:
                    key = f"model.layers.{li}.{m}.weight"
                    W = handles[wmap[key]].get_tensor(key).float()
                    # Writers contract over the input side, readers over the output side.
                    gram = W @ W.T if role == "writer" else W.T @ W
                    G[(role, "all")] += gram
                    if li in HOOKED_LAYERS:
                        G[(role, "hooked")] += gram
    return G
|
|
|
|
|
|
def basis(G: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Eigendecompose a pooled Gram matrix into (U_star, sqrt(Sigma_star)).

    The decomposition runs in float64. Eigenpairs are reordered so eigenvalues
    are descending, and negative eigenvalues are clamped to zero.

    Returns:
        U_star [d, d] as float32 with eigenvectors in columns, and
        sqrt(Sigma_star) = lambda ** 0.25 as float32, floored at EPS.
    """
    eigvals, eigvecs = torch.linalg.eigh(G.to(torch.float64))
    # eigh returns ascending eigenvalues; reverse both outputs to get descending order.
    eigvals_desc = torch.flip(eigvals, dims=(0,)).clamp_min(0)
    eigvecs_desc = torch.flip(eigvecs, dims=(1,))
    # Sigma = sqrt(lambda), hence sqrt(Sigma) = lambda ** (1/4).
    sqrt_sigma = eigvals_desc.float().pow(0.25).clamp_min(EPS)
    return eigvecs_desc.float(), sqrt_sigma
|
|
|
|
|
|
def score(XS: torch.Tensor, dS: torch.Tensor, kind: str) -> np.ndarray:
    """Score each sample against the per-layer unit direction, summed over layers.

    Args:
        XS: features, [N, L, r].
        dS: direction per layer, [L, r]; each row is normalised to unit length
            (norm floored at EPS) before use.
        kind: "dot" returns the raw summed projection; any other value returns
            the projection divided by the norm of the flattened sample times
            the norm of the flattened unit directions.

    Returns:
        A length-N numpy array of scores.
    """
    unit = dS / dS.norm(dim=-1, keepdim=True).clamp_min(EPS)
    proj = torch.einsum("nlr,lr->n", XS, unit)
    if kind != "dot":
        sample_norm = XS.flatten(1).norm(dim=1).clamp_min(EPS)
        proj = proj / (sample_norm * unit.flatten().norm())
    return proj.numpy()
|
|
|
|
|
|
def main() -> int:
    """Build the pooled bases, score every variant on each run, and print the AUROC table.

    For each run in RUNS the cached ``pinning_feats.pt`` is loaded, and the
    direction is taken as the mean hack-minus-clean difference over authored
    pairs whose id starts with HEAD_PREFIX. Variants scored: the raw residual
    baseline, and for every (role, blocks) basis the whitened ("superS") and
    unwhitened ("superS-rot") projections at top-r = 256 and 64, plus the
    whitened full-rank projection. Each variant is scored with "cos" and "dot".

    Returns:
        0, used as the process exit status.

    Raises:
        ValueError: if a run has no authored pair id starting with HEAD_PREFIX.
    """
    print("building writer/reader Gram bases from Qwen3-4B safetensors (CPU)...")
    grams = load_grams()
    grams[("both", "hooked")] = grams[("writer", "hooked")] + grams[("reader", "hooked")]
    grams[("both", "all")] = grams[("writer", "all")] + grams[("reader", "all")]
    bases = {k: basis(g) for k, g in grams.items()}

    rows = {}
    for tag, diag_dir in RUNS.items():
        # NOTE(security): weights_only=False unpickles arbitrary objects; only
        # point RUNS at feature caches produced by this project.
        fe = torch.load(diag_dir / "pinning_feats.pt", weights_only=False)
        RES, adv, exploited = fe["RES"].float(), fe["adv"], fe["exploited"]
        head = [i for i, pid in enumerate(fe["pair_ids"]) if pid.startswith(HEAD_PREFIX)]
        if not head:
            # An empty selection would make the mean diff NaN and every AUROC meaningless.
            raise ValueError(f"{tag}: no pair id starts with {HEAD_PREFIX!r}")
        Ph = fe["pair_feats"][("resid", "hack")][head].float()
        Pc = fe["pair_feats"][("resid", "clean")][head].float()
        pos = (np.abs(adv) > 1e-6) & (adv > 0)
        y = exploited & (adv > 0)

        def au(s):
            """AUROC of scores s against y, restricted to the adv > 0 samples."""
            return _auroc(s[pos].tolist(), y[pos].tolist())

        def add(name: str, XS, dS):
            """Record the concat-over-layers AUROC and the last-layer-only AUROC for this run."""
            for kind in ("cos", "dot"):
                row = rows.setdefault((name, kind), {"variant": name, "score": kind})
                row[tag] = au(score(XS, dS, kind))
                # Layer index 2 onwards: reported under the "L24" column.
                row[f"{tag} L24"] = au(score(XS[:, 2:], dS[2:], kind))

        add("raw resid (baseline)", RES, (Ph - Pc).mean(0))
        for (role, blk), (U, sqrtS) in sorted(bases.items()):
            # The rotation into the pooled eigenbasis does not depend on the
            # whitening flag, so compute it once per basis instead of once per variant.
            rot_res = torch.einsum("nld,dk->nlk", RES, U)
            rot_h = torch.einsum("nld,dk->nlk", Ph, U)
            rot_c = torch.einsum("nld,dk->nlk", Pc, U)
            # whiten=False at r=full is a pure rotation (== baseline, skipped); with
            # top-r it still tests sparsification in the pooled eigenbasis.
            for whiten, lab in ((True, "superS"), (False, "superS-rot")):
                if whiten:
                    XS = rot_res / sqrtS
                    dS = (rot_h / sqrtS - rot_c / sqrtS).mean(0)
                    add(f"{lab} {role}/{blk} r=full", XS, dS)
                else:
                    XS, dS = rot_res, (rot_h - rot_c).mean(0)
                for r in (256, 64):
                    idx = dS.abs().topk(r, dim=-1).indices  # [L, r] per-layer top modes
                    XSr = torch.gather(XS, 2, idx.unsqueeze(0).expand(XS.shape[0], -1, -1))
                    add(f"{lab} {role}/{blk} r={r}", XSr, torch.gather(dS, 1, idx))

    out = list(rows.values())
    for row in out:
        vals = [row[t] for t in RUNS]
        row["mean"], row["min"] = float(np.mean(vals)), float(np.min(vals))
    # Best worst-case run first.
    out.sort(key=lambda row: -row["min"])
    cols = ["variant", "score", "v3", "v4", "v5", "mean", "min",
            "v3 L24", "v4 L24", "v5 L24"]
    print("\nAUROC on the A>0 contrast, concat over layers 12/18/24 (and L24 alone):")
    print("SHOULD: 'raw resid (baseline)' cos == 0.916/0.700/0.804 (matches journal "
          "Table 1) ELSE the harness disagrees with diag_pinning. superS rows differ "
          "from baseline only via whitening + top-r; if no superS row beats baseline "
          "min by >0.03 the pooled-spectrum metric adds nothing here.")
    print(tabulate([{c: row[c] for c in cols} for row in out],
                   headers="keys", tablefmt="pipe", floatfmt="+.3f"))
    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # Run only when executed as a script; main()'s return value becomes the exit status.
    raise SystemExit(main())
|