Files
evil_MoE/scripts/diag_pinning_moduleS.py
T
wassname 8000aa48f4 journal(#41): entry (g) routeA shipped + guard-drop calibration; track moduleS diag scripts
Entry (f) already cited scripts/diag_pinning_moduleS_exact.py; both moduleS
scripts were untracked, so committing them for provenance alongside the
calibration script the new entry cites.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
2026-06-11 12:50:07 +00:00

165 lines
7.2 KiB
Python

"""Per-module residual S-space gate score.
Unlike Super-S, this keeps one residual-side SVD basis per Linear through
projection, pair-direction extraction, mode selection, and scoring. Scores are
aggregated only after each module has produced its own score.
This is an offline residual-stream diagnostic, not exact steering-lite sspace:
the cache contains block residuals at layers 12/18/24, not each Linear's own
input/output activations.
writer (o_proj/down_proj): G_m = W_m W_m^T = U_m S_m^2 U_m^T
xS_m = h @ U_m / sqrt(S_m)
reader (q/k/v/gate/up): G_m = W_m^T W_m = V_m S_m^2 V_m^T
xS_m = h @ V_m * sqrt(S_m)
uv run python scripts/diag_pinning_moduleS.py
"""
from __future__ import annotations

import json
from contextlib import ExitStack
from glob import glob
from pathlib import Path

import numpy as np
import torch
from safetensors import safe_open
from tabulate import tabulate

from vgrout.train import _auroc
ROOT = Path("/workspace/projected_grpo")
# Run tag -> directory containing that run's cached pinning_feats.pt.
RUNS = {
    tag: ROOT / subdir
    for tag, subdir in (
        ("v3", "out/diag"),
        ("v4", "out/diag_v4"),
        ("v5", "out/diag_v5"),
    )
}
# Only pairs whose pair_id starts with this prefix are used to build the
# hack-minus-clean direction.
HEAD_PREFIX = "behavior_"
# Layers whose block residuals are stored in the cache; position i here is
# index i along the layer axis of the cached residual tensors.
HOOKED_LAYERS = (12, 18, 24)
# Linears taken from the hooked layer itself (basis from W W^T).
WRITERS = (
    "self_attn.o_proj",
    "mlp.down_proj",
)
# Linears taken from the layer after the hooked one (basis from W^T W).
READERS = (
    "self_attn.q_proj",
    "self_attn.k_proj",
    "self_attn.v_proj",
    "mlp.gate_proj",
    "mlp.up_proj",
)
# Number of S-space coordinates kept per module; -1 keeps all of them.
RANKS = (64, 256, -1)
# Floor used for norms and singular-value scales to avoid dividing by zero.
EPS = 1e-8
def model_weights():
    """Yield the residual-side Linear weights for every hooked residual.

    For each entry of ``HOOKED_LAYERS`` (in order) this yields every module in
    ``WRITERS`` taken from that same layer, then every module in ``READERS``
    taken from the following layer, as
    ``(residual_idx, module_layer, role, module, W)`` where ``role`` is
    ``"writer"`` or ``"reader"`` and ``W`` is the weight cast to float32.

    The snapshot is chosen deterministically (first in sorted order), and all
    safetensors shard handles are closed once the generator is exhausted or
    closed.

    Raises:
        FileNotFoundError: if no Qwen3-4B snapshot exists in the HF cache.
    """
    snapshots = sorted(glob("/workspace/.hf_home/hub/models--Qwen--Qwen3-4B/snapshots/*"))
    if not snapshots:
        raise FileNotFoundError(
            "no Qwen3-4B snapshot found under "
            "/workspace/.hf_home/hub/models--Qwen--Qwen3-4B/snapshots"
        )
    snap = Path(snapshots[0])
    wmap = json.loads((snap / "model.safetensors.index.json").read_text())["weight_map"]
    with ExitStack() as stack:
        # Open each shard once; the ExitStack closes them all when iteration ends.
        handles = {
            shard: stack.enter_context(safe_open(snap / shard, framework="pt"))
            for shard in sorted(set(wmap.values()))
        }
        for residual_idx, layer_idx in enumerate(HOOKED_LAYERS):
            for role, modules in (("writer", WRITERS), ("reader", READERS)):
                # Writers come from the hooked layer itself, readers from the next one.
                module_layer = layer_idx if role == "writer" else layer_idx + 1
                for module in modules:
                    key = f"model.layers.{module_layer}.{module}.weight"
                    weight = handles[wmap[key]].get_tensor(key).float()
                    yield residual_idx, module_layer, role, module, weight
def residual_basis(W: torch.Tensor, role: str) -> tuple[torch.Tensor, torch.Tensor]:
    """Eigendecompose the residual-side Gram matrix of ``W``.

    Writers use ``W @ W.T`` and every other role uses ``W.T @ W``.

    Returns:
        ``(basis, sqrtS)``: eigenvectors as columns, ordered by decreasing
        eigenvalue, and the fourth root of each eigenvalue (the square root of
        the corresponding singular value of ``W``), floored at ``EPS``.
    """
    if role == "writer":
        gram = W @ W.T
    else:
        gram = W.T @ W
    evals, evecs = torch.linalg.eigh(gram)
    # eigh returns ascending eigenvalues; reverse so index 0 is the largest.
    descending_evecs = evecs.flip(1)
    descending_evals = evals.flip(0)
    # Clip negative round-off, then eigenvalue = S^2  ->  S^(1/2).
    sqrt_s = descending_evals.clamp_min(0).pow(0.25).clamp_min(EPS)
    return descending_evecs, sqrt_s
def module_stats(
    X: torch.Tensor,
    Ph: torch.Tensor,
    Pc: torch.Tensor,
    basis: torch.Tensor,
    sqrtS: torch.Tensor,
    role: str,
    rank: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Score rows of ``X`` against the hack-minus-clean direction in one module S-space.

    Args:
        X: 2-D rows to score; projected as ``X @ basis``.
        Ph, Pc: paired hack / clean rows; the mean of their S-space difference
            defines the direction.
        basis, sqrtS: residual-side basis and sqrt singular values, as
            returned by ``residual_basis``.
        role: ``"reader"`` multiplies coordinates by ``sqrtS``; any other role
            divides by it.
        rank: keep only the ``rank`` coordinates with the largest
            ``|direction|`` component. ``rank <= 0`` or a rank at least the
            number of coordinates keeps all of them.

    Returns:
        ``(dot, cos, norm2)`` per row of ``X``: projection onto the unit
        direction, that projection divided by the row's S-space norm, and the
        row's squared S-space norm.
    """
    scale = sqrtS if role == "reader" else sqrtS.reciprocal()
    XS = (X @ basis) * scale
    dS = ((Ph @ basis) * scale - (Pc @ basis) * scale).mean(0)
    # topk raises when asked for more entries than exist; a rank that covers
    # every coordinate is simply the full-rank case.
    if 0 < rank < dS.shape[-1]:
        idx = dS.abs().topk(rank).indices
        XS, dS = XS[:, idx], dS[idx]
    direction = dS / dS.norm().clamp_min(EPS)
    dot = XS @ direction
    norm2 = XS.square().sum(1)
    return dot, dot / norm2.sqrt().clamp_min(EPS), norm2
def _prepare_features() -> dict[str, dict]:
    """Load each run's cached features and keep only the tensors scoring needs.

    Per run tag the dict holds ``X`` (rows to score), ``Ph``/``Pc`` (hack /
    clean pair residuals restricted to ``HEAD_PREFIX`` pairs), ``pos`` (rows
    whose ``adv`` exceeds 1e-6) and ``y`` (``exploited`` with ``adv > 0``).
    Each raw feature dict is dropped as soon as its run is prepared.
    """
    prepared = {}
    for tag, path in RUNS.items():
        # weights_only=False unpickles arbitrary objects: only point RUNS at
        # trusted, locally produced feature files.
        fe = torch.load(path / "pinning_feats.pt", weights_only=False)
        head = [i for i, pair_id in enumerate(fe["pair_ids"]) if pair_id.startswith(HEAD_PREFIX)]
        prepared[tag] = {
            "X": fe["RES"].float(),
            "Ph": fe["pair_feats"][("resid", "hack")][head].float(),
            "Pc": fe["pair_feats"][("resid", "clean")][head].float(),
            "pos": (np.abs(fe["adv"]) > 1e-6) & (fe["adv"] > 0),
            "y": fe["exploited"] & (fe["adv"] > 0),
        }
    return prepared


def _collect_module_stats(prepared: dict[str, dict]) -> dict[tuple, dict[str, list]]:
    """Score every run in each module's own S-space, one module at a time.

    Returns ``{(tag, role, rank): {"dot": [...], "cos": [...], "norm2": [...]}}``
    holding one tensor per module of that role, in ``model_weights`` order.
    """
    collected = {
        (tag, role, rank): {"dot": [], "cos": [], "norm2": []}
        for tag in RUNS for role in ("writer", "reader") for rank in RANKS
    }
    n_modules = len(HOOKED_LAYERS) * (len(WRITERS) + len(READERS))
    print(f"building and scoring {n_modules} independent residual-side module SVD bases (CPU)...")
    for residual_idx, layer_idx, role, module, W in model_weights():
        print(f" residual={HOOKED_LAYERS[residual_idx]:02d} module_layer={layer_idx:02d} "
              f"role={role:6s} module={module}")
        basis, sqrtS = residual_basis(W, role)
        for tag, fe in prepared.items():
            X, Ph, Pc = fe["X"][:, residual_idx], fe["Ph"][:, residual_idx], fe["Pc"][:, residual_idx]
            for rank in RANKS:
                dot, cos, norm2 = module_stats(X, Ph, Pc, basis, sqrtS, role, rank)
                bucket = collected[(tag, role, rank)]
                bucket["dot"].append(dot)
                bucket["cos"].append(cos)
                bucket["norm2"].append(norm2)
    return collected


def _auroc_on_pos(score: torch.Tensor, fe: dict):
    """AUROC of ``score`` against ``fe["y"]``, restricted to the ``fe["pos"]`` rows."""
    pos = fe["pos"]
    return _auroc(score[pos].tolist(), fe["y"][pos].tolist())


def _auroc_rows(prepared: dict[str, dict], collected: dict[tuple, dict[str, list]]) -> list[dict]:
    """Build one row per (variant, aggregate, score) with an AUROC per run tag.

    Rows keep insertion order: the raw-residual baselines first, then the
    module-S variants for writer / reader / both at every rank.
    """
    rows: dict[tuple[str, str, str], dict] = {}

    def record(variant: str, aggregate: str, kind: str, tag: str, value) -> None:
        row = rows.setdefault(
            (variant, aggregate, kind),
            {"variant": variant, "aggregate": aggregate, "score": kind},
        )
        row[tag] = value

    for tag, fe in prepared.items():
        # Baseline: one mean hack-minus-clean direction per layer, unit-normed
        # per layer, scored on the concatenated raw residual.
        raw_d = (fe["Ph"] - fe["Pc"]).mean(0)
        raw_v = raw_d / raw_d.norm(dim=-1, keepdim=True).clamp_min(EPS)
        raw_dot = torch.einsum("nld,ld->n", fe["X"], raw_v)
        raw_cos = raw_dot / (fe["X"].flatten(1).norm(dim=1).clamp_min(EPS)
                             * raw_v.flatten().norm())
        record("raw resid", "concat", "cos", tag, _auroc_on_pos(raw_cos, fe))
        record("raw resid", "concat", "dot", tag, _auroc_on_pos(raw_dot, fe))

        for role in ("writer", "reader", "both"):
            roles = ("writer", "reader") if role == "both" else (role,)
            for rank in RANKS:
                stats = [collected[(tag, r, rank)] for r in roles]
                # Columns are modules; bases stay separate until this point.
                dots = torch.stack([x for stat in stats for x in stat["dot"]], 1)
                coses = torch.stack([x for stat in stats for x in stat["cos"]], 1)
                norm2 = torch.stack([x for stat in stats for x in stat["norm2"]], 1)
                scores = {
                    ("mean", "cos"): coses.mean(1),
                    ("mean", "dot"): dots.mean(1),
                    # Each module direction is unit-norm, so the concatenated
                    # direction has norm sqrt(#modules).
                    ("concat", "cos"): dots.sum(1) / (
                        norm2.sum(1).sqrt().clamp_min(EPS) * np.sqrt(dots.shape[1])),
                    ("concat", "dot"): dots.sum(1),
                }
                rank_name = "full" if rank < 0 else str(rank)
                variant = f"moduleS {role} r={rank_name}"
                for (aggregate, kind), score in scores.items():
                    record(variant, aggregate, kind, tag, _auroc_on_pos(score, fe))
    return list(rows.values())


def main() -> int:
    """Compare raw-residual and per-module S-space gate scores by AUROC across runs.

    Prints the per-module progress and a table sorted by worst-run AUROC;
    returns 0.
    """
    prepared = _prepare_features()
    collected = _collect_module_stats(prepared)
    out = _auroc_rows(prepared, collected)
    for row in out:
        values = [row[tag] for tag in RUNS]
        row["mean"], row["min"] = float(np.mean(values)), float(np.min(values))
    out.sort(key=lambda row: -row["min"])
    # Derive run columns from RUNS so adding/removing a run cannot KeyError here.
    cols = ["variant", "aggregate", "score", *RUNS, "mean", "min"]
    print("\nAUROC on the A>0 contrast:")
    print("SHOULD: raw resid cos == 0.916/0.700/0.804 and dot == "
          "0.905/0.721/0.756 ELSE this harness disagrees with diag_pinning_superS. "
          "Per-module bases must remain separate until score aggregation.")
    print(tabulate([{c: row[c] for c in cols} for row in out],
                   headers="keys", tablefmt="pipe", floatfmt="+.3f"))
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    exit_status = main()
    raise SystemExit(exit_status)