diag: super-S-space gate score null; spec -> act_dot + winsorized-Otsu plan

superS (pooled writer/reader eigenbasis, whitened + top-r) tops out at
min-window AUROC 0.721 = raw resid dot; best unwhitened rotation+top-64
0.740 < act 0.747 (max of ~50-variant grid). act t-stat extraction also
null (0.719 vs 0.749 min). Spec updated: act_dot default, journal-(d)
evidence table, implementation plan for routeA.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
wassname
2026-06-11 11:42:44 +00:00
parent 1b51c62cdc
commit 1d4f33ffb6
2 changed files with 288 additions and 0 deletions
+135
View File
@@ -0,0 +1,135 @@
"""Super-S-space gate score: project the residual stream onto the pooled SVD basis
of the residual writers/readers before extracting v and scoring.
Source idea: wassname/steering-lite variants/super_sspace.py. Pool the residual-side
singular structure of the block Linears into a Gram matrix and eigendecompose:
writer (d_out = d_model, o_proj/down_proj): G += W W^T (= U S^2 U^T)
reader (d_in = d_model, q/k/v/gate/up): G += W^T W (= V S^2 V^T)
eigh(G) -> U_star [d, d], Sigma_star = sqrt(lambda)
Whitened coordinates xS = h @ U_star / sqrt(Sigma_star). At full rank WITHOUT the
whitening this is a pure rotation (cos/dot unchanged), so the testable content is
(a) the 1/sqrt(Sigma) reweighting -- a Mahalanobis metric in the writer/reader
spectrum -- and (b) top-r mode selection by |dS| from the authored-pair diff
(pair-only info, oracle-free).
Offline from cached pinning_feats.pt (RES [N, L, 2560]) + safetensors weights, no
forward pass, no GPU. Evaluation = the A>0 contrast AUROC (exploited & adv>0 vs
rest among adv>0), same as diag_pinning.py.
uv run python scripts/diag_pinning_superS.py
"""
from __future__ import annotations
import json
from glob import glob
from pathlib import Path
import numpy as np
import torch
from safetensors import safe_open
from tabulate import tabulate
from vgrout.train import _auroc
ROOT = Path("/workspace/projected_grpo")
RUNS = {"v3": ROOT / "out/diag", "v4": ROOT / "out/diag_v4", "v5": ROOT / "out/diag_v5"}
HEAD_PREFIX = "behavior_"
HOOKED_LAYERS = (12, 18, 24)
WRITERS = ("self_attn.o_proj", "mlp.down_proj")
READERS = ("self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj",
"mlp.gate_proj", "mlp.up_proj")
EPS = 1e-8
def load_grams(n_layers: int = 36, d: int = 2560) -> dict[tuple[str, str], torch.Tensor]:
"""{(role, blocks): G [d, d]} for role in {writer, reader}, blocks in {hooked, all}."""
snap = Path(glob("/workspace/.hf_home/hub/models--Qwen--Qwen3-4B/snapshots/*")[0])
wmap = json.loads((snap / "model.safetensors.index.json").read_text())["weight_map"]
G = {(role, blk): torch.zeros(d, d) for role in ("writer", "reader") for blk in ("hooked", "all")}
handles = {f: safe_open(snap / f, framework="pt") for f in set(wmap.values())}
for li in range(n_layers):
for role, mods in (("writer", WRITERS), ("reader", READERS)):
for m in mods:
key = f"model.layers.{li}.{m}.weight"
W = handles[wmap[key]].get_tensor(key).float()
gram = W @ W.T if role == "writer" else W.T @ W
G[(role, "all")] += gram
if li in HOOKED_LAYERS:
G[(role, "hooked")] += gram
return G
def basis(G: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""U_star [d, d] descending, sqrt(Sigma_star) = lambda**0.25 clamped."""
lam, U = torch.linalg.eigh(G.double())
lam, U = lam.flip(0).clamp_min(0), U.flip(1)
return U.float(), lam.float().pow(0.25).clamp_min(EPS) # sqrt(Sigma) = lam^(1/4)
def score(XS: torch.Tensor, dS: torch.Tensor, kind: str) -> np.ndarray:
"""XS [N, L, r], dS [L, r] -> concat score over layers."""
v = dS / dS.norm(dim=-1, keepdim=True).clamp_min(EPS)
d = torch.einsum("nlr,lr->n", XS, v)
if kind == "dot":
return d.numpy()
return (d / (XS.flatten(1).norm(dim=1).clamp_min(EPS) * v.flatten().norm())).numpy()
def main() -> int:
print("building writer/reader Gram bases from Qwen3-4B safetensors (CPU)...")
grams = load_grams()
grams[("both", "hooked")] = grams[("writer", "hooked")] + grams[("reader", "hooked")]
grams[("both", "all")] = grams[("writer", "all")] + grams[("reader", "all")]
bases = {k: basis(g) for k, g in grams.items()}
rows = {}
for tag, diag_dir in RUNS.items():
fe = torch.load(diag_dir / "pinning_feats.pt", weights_only=False)
RES, adv, exploited = fe["RES"].float(), fe["adv"], fe["exploited"]
head = [i for i, pid in enumerate(fe["pair_ids"]) if pid.startswith(HEAD_PREFIX)]
Ph = fe["pair_feats"][("resid", "hack")][head].float()
Pc = fe["pair_feats"][("resid", "clean")][head].float()
pos = (np.abs(adv) > 1e-6) & (adv > 0)
y = exploited & (adv > 0)
au = lambda s: _auroc(s[pos].tolist(), y[pos].tolist())
def add(name: str, XS, dS):
for kind in ("cos", "dot"):
rows.setdefault((name, kind), {"variant": name, "score": kind})[tag] = \
au(score(XS, dS, kind))
rows[(name, kind)][f"{tag} L24"] = au(score(XS[:, 2:], dS[2:], kind))
add("raw resid (baseline)", RES, (Ph - Pc).mean(0))
for (role, blk), (U, sqrtS) in sorted(bases.items()):
# whiten=False at r=full is a pure rotation (== baseline, skipped); with
# top-r it still tests sparsification in the pooled eigenbasis.
for whiten, lab in ((True, "superS"), (False, "superS-rot")):
tx = lambda X: torch.einsum("nld,dk->nlk", X, U) / (sqrtS if whiten else 1.0)
XS, dS = tx(RES), (tx(Ph) - tx(Pc)).mean(0)
if whiten:
add(f"{lab} {role}/{blk} r=full", XS, dS)
for r in (256, 64):
idx = dS.abs().topk(r, dim=-1).indices # [L, r] per-layer top modes
XSr = torch.gather(XS, 2, idx.unsqueeze(0).expand(XS.shape[0], -1, -1))
add(f"{lab} {role}/{blk} r={r}", XSr, torch.gather(dS, 1, idx))
out = list(rows.values())
for r in out:
vals = [r[t] for t in RUNS]
r["mean"], r["min"] = float(np.mean(vals)), float(np.min(vals))
out.sort(key=lambda r: -r["min"])
cols = ["variant", "score", "v3", "v4", "v5", "mean", "min",
"v3 L24", "v4 L24", "v5 L24"]
print("\nAUROC on the A>0 contrast, concat over layers 12/18/24 (and L24 alone):")
print("SHOULD: 'raw resid (baseline)' cos == 0.916/0.700/0.804 (matches journal "
"Table 1) ELSE the harness disagrees with diag_pinning. superS rows differ "
"from baseline only via whitening + top-r; if no superS row beats baseline "
"min by >0.03 the pooled-spectrum metric adds nothing here.")
print(tabulate([{c: r[c] for c in cols} for r in out],
headers="keys", tablefmt="pipe", floatfmt="+.3f"))
return 0
if __name__ == "__main__":
raise SystemExit(main())