From 8000aa48f44784ac65951666b57ddaadfe9a2243 Mon Sep 17 00:00:00 2001 From: wassname <1103714+wassname@users.noreply.github.com> Date: Thu, 11 Jun 2026 12:50:07 +0000 Subject: [PATCH] journal(#41): entry (g) routeA shipped + guard-drop calibration; track moduleS diag scripts Entry (f) already cited scripts/diag_pinning_moduleS_exact.py; both moduleS scripts were untracked, so committing them for provenance alongside the calibration script the new entry cites. Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com> --- RESEARCH_JOURNAL.md | 93 +++++++++++ scripts/attic/calib_otsu_guard.py | 71 ++++++++ scripts/diag_pinning_moduleS.py | 164 ++++++++++++++++++ scripts/diag_pinning_moduleS_exact.py | 232 ++++++++++++++++++++++++++ 4 files changed, 560 insertions(+) create mode 100644 scripts/attic/calib_otsu_guard.py create mode 100644 scripts/diag_pinning_moduleS.py create mode 100644 scripts/diag_pinning_moduleS_exact.py diff --git a/RESEARCH_JOURNAL.md b/RESEARCH_JOURNAL.md index 4505bff..791cd9b 100644 --- a/RESEARCH_JOURNAL.md +++ b/RESEARCH_JOURNAL.md @@ -4257,3 +4257,96 @@ Provenance: **Discussion (speculative).** I read both as nulls at this sample size (per-window SE ~0.07). For super-S, whitening consistently sits at or below the raw baseline, which would fit the pooled spectrum amplifying low-energy directions that carry no hack signal here, but I cannot distinguish that from noise. The one apparent gain (rotation + reader basis + top-64) is exactly what taking the maximum of 50 noisy rows produces; I would only believe it if it survived on new windows chosen in advance. For t-stat the entry-(d) alternative stands: the per-coordinate std over 8 pairs is itself ~25% noise, so the weighting may be real but unestimable at this pair count. Both nulls leave act_dot with plain mean extraction as the routeA default. **Next.** routeA implementation per the plan now written into docs/spec/20260611_act_gate_spec.md (extraction module with verify gate, gate wiring replacing routeV, rolling-buffer winsorized-Otsu pinning), pending wassname's approval. More authored pairs remains the highest-leverage data change. +## 2026-06-11 (f) -- per-module S-space shows no robust gate-score improvement + +**Question.** Does preserving each Linear's own SVD space reveal module-specific +hack signal that pooled Super-S washes out? + +**Methods.** `scripts/diag_pinning_moduleS_exact.py` hooks the actual inputs of +reader Linears and recomputes base-weight outputs of writer Linears from their +actual inputs in blocks 12/18/24. Per module it +uses the steering-lite S-space identities `x @ V * sqrt(S)` for readers and +`y @ U / sqrt(S)` for writers, extracts a direction from the eight `behavior_` +pairs, selects top-r modes by pair-difference magnitude, then aggregates module +scores. Pueue 28/29/30 reran the v3/v4/v5 emergence windows and wrote strict TSV +evidence. + +| score | v3 | v4 | v5 | mean | min | +|---|---:|---:|---:|---:|---:| +| moduleS writer r=256 concat-cos, best selected row | 0.892 | 0.733 | 0.786 | 0.804 | 0.733 | +| act-dot existing default | 0.870 | 0.747 | 0.747 | 0.788 | 0.747 | +| raw residual dot | 0.905 | 0.721 | 0.756 | 0.794 | 0.721 | + +**Result.** Per-module S-space shows no robust improvement on these windows. Its +best row falls below act-dot on worst-window AUROC and was selected from 27 rows, +so the small gain over raw residual dot is not evidence of improvement. The earlier +cached-residual approximation produced a stronger reader-r64 row, but fresh +review correctly identified that it was a module-weight-derived metric on +post-block residuals rather than exact module S-space. + +**Evidence.** Full cross-window table: +`out/diag/moduleS_exact_summary.tsv`. Per-window tables: +`out/diag{,_v4,_v5}/moduleS_exact.tsv`. Spec and failure log: +`docs/spec/20260611_per_module_sspace.md`. + +## 2026-06-11 (g) -- routeA act gate shipped; bimodality guard dropped after calibration + +**Introduction.** Entries (b)-(f) established that the activation dot score is the +stable gate input. This entry covers the implementation +(docs/spec/20260611_act_gate_spec.md): forward-only `v_act` extraction +(`src/vgrout/extract_vhack_act.py`), the routeA gate in train.py (act capture on the +quarantine-ablated logp_old forward, masks pinned before the single grad forward, +rolling-buffer Otsu thresholds), and deletion of the routeV gradient gate. The spec +left one open question: should an online bimodality guard close the rout zone before +hacks emerge? Expected: some shape statistic of the score window separates the +emergence mixture (hack share 35-43%) from hack-free scores. + +**Methods.** Calibration is offline on the cached v3/v4/v5 emergence-window features +(out/diag{,_v4,_v5}/pinning_feats.pt, produced by scripts/diag_pinning.py at commit +70697ff). Score = act dot vs the `behavior_` 8-pair v_act. Conditions: mixture (all +valid live rollouts), cleanonly (non-exploited only, pre-emergence proxy), and N(0,1) +n=256 (10 seeds). Statistics computed after z-norm, winsorize(1/99), two-threshold +Otsu: `sep` = mean(z above t_hi) minus mean(z below t_lo) in buffer-sd units, `nbcv` = +between-class variance fraction. Command: `uv run python +scripts/attic/calib_otsu_guard.py`. Extractor equivalence was verified on GPU as +pueue #24 (`scripts/verify_v_act.py`, commit 5a340e5). + +| cond | hack% | n | sep | nbcv | +|---|---:|---:|---:|---:| +| v3 mixture | 0.43 | 138 | 2.75 | 0.80 | +| v3 cleanonly | 0.00 | 79 | 2.54 | 0.82 | +| v4 mixture | 0.35 | 96 | 2.82 | 0.81 | +| v4 cleanonly | 0.00 | 62 | 2.34 | 0.84 | +| v5 mixture | 0.39 | 138 | 2.44 | 0.76 | +| v5 cleanonly | 0.00 | 84 | 3.52 | 0.77 | +| gauss n=256 (mean of 10 seeds) | 0.00 | 256 | 2.42 | 0.83 | + +Table: guard-candidate statistics per condition. A usable guard needs the mixture rows +to separate from the cleanonly and gauss rows on `sep` or `nbcv` with margin. + +Provenance: +- Script: `scripts/attic/calib_otsu_guard.py` (copied from the session scratch file); + rerun output captured at `/tmp/claude-1000/calib_guard_out.txt` this session. +- Inputs: `out/diag/pinning_feats.pt`, `out/diag_v4/pinning_feats.pt`, + `out/diag_v5/pinning_feats.pt` (the (c) replication windows). +- verify_v_act: pueue #24 log; acts match cached diag features at rel diff 7.3e-4 and + 7.7e-4 (hack/clean), v cos > 0.999, per-module cos >= 0.99997. + +**Results.** No statistic separates the conditions. The largest `sep` of all rows is a +hack-FREE window (v5 cleanonly, 3.52); pure Gaussians sit at 2.42, inside the mixture +range (2.44-2.82). `nbcv` overlaps the same way (mixtures 0.76-0.81 vs gauss 0.83). +Otsu always finds tail classes ~2.4 sd apart even when no structure exists, so any +threshold on these statistics either always opens or always closes. + +**Discussion (speculative).** I read this as: the guard idea was solving a +non-problem. Before emergence a false rout costs one update removed from deployment +(asymmetric, cheap), and warmup already pins absorb while the buffer fills. An +alternative read is that a better statistic exists (e.g. dip test, mixture-model BIC) +and I only tried Otsu-derived ones; I did not pursue this because the cost asymmetry +makes the guard's value marginal even if it worked. The gate therefore ships with +warmup + Otsu only (commits adca442 routeA wiring + routeV deletion, f646e57 +review-driven hardening; smoke logs /tmp/claude-1000/smoke_routeA*.log). + +**Next.** Queue the seed-43 fast 4-arm set (`just queue-decision`): routeA real vs +Haar placebo vs vanilla vs absorb. Decision: directionality is real iff real-v +deploy_hack << placebo at matched solve, with gate AUROC >> 0.5 around emergence. diff --git a/scripts/attic/calib_otsu_guard.py b/scripts/attic/calib_otsu_guard.py new file mode 100644 index 0000000..b26e3cf --- /dev/null +++ b/scripts/attic/calib_otsu_guard.py @@ -0,0 +1,71 @@ +"""Calibrate the routeA bimodality guard offline. + +Conditions per window (v3/v4/v5), act_dot score vs behavior_ pair v_act: + mixture all valid live rollouts (hack share 35-43%) -> guard SHOULD OPEN + cleanonly non-exploited rollouts only (pre-emergence proxy) -> SHOULD CLOSE + gauss N(0,1) n=256, 10 seeds -> SHOULD CLOSE + +Candidate statistics, computed after z-norm + winsorize(1/99) + otsu3: + sep = mean(z | z>=t_hi) - mean(z | z tuple[float, float]: + x = np.clip(x, *np.quantile(x, [0.01, 0.99])) + s = np.sort(np.asarray(x, float)) + n = len(s) + c = np.concatenate([[0.0], np.cumsum(s)]) + best, best_ij = -np.inf, (1, 2) + for i in range(1, n - 1): + for j in range(i + 1, n): + obj = c[i] ** 2 / i + (c[j] - c[i]) ** 2 / (j - i) + (c[n] - c[j]) ** 2 / (n - j) + if obj > best: + best, best_ij = obj, (i, j) + i, j = best_ij + return float((s[i - 1] + s[i]) / 2), float((s[j - 1] + s[j]) / 2) + + +def stats(scores: np.ndarray) -> dict: + z = (scores - scores.mean()) / (scores.std() or 1.0) + zw = np.clip(z, *np.quantile(z, [0.01, 0.99])) + t_lo, t_hi = otsu3(z) + keep, mid, rout = zw < t_lo, (zw >= t_lo) & (zw < t_hi), zw >= t_hi + sep = float(zw[rout].mean() - zw[keep].mean()) if rout.any() and keep.any() else float("nan") + mu = zw.mean() + nbcv = sum(m.mean() * (zw[m].mean() - mu) ** 2 for m in (keep, mid, rout) if m.any()) / zw.var() + return {"n": len(scores), "t_lo": t_lo, "t_hi": t_hi, "sep": sep, "nbcv": float(nbcv), + "w_rout": float(rout.mean())} + + +rows = [] +for tag, d in RUNS.items(): + fe = torch.load(d / "pinning_feats.pt", weights_only=False) + ACT, adv, exploited = fe["ACT"].float(), fe["adv"], fe["exploited"] + head = [i for i, pid in enumerate(fe["pair_ids"]) if pid.startswith("behavior_")] + D = (fe["pair_feats"][("act", "hack")][head] - fe["pair_feats"][("act", "clean")][head]).float().mean(0) + V = D / D.norm(dim=-1, keepdim=True).clamp_min(1e-12) + s = torch.einsum("nmr,mr->n", ACT, V).numpy() + valid = np.abs(adv) > 1e-6 + rows.append({"cond": f"{tag} mixture", "hack%": exploited[valid].mean(), **stats(s[valid])}) + cl = valid & ~exploited + rows.append({"cond": f"{tag} cleanonly", "hack%": 0.0, **stats(s[cl])}) + +rng = np.random.default_rng(0) +g = [stats(rng.standard_normal(256)) for _ in range(10)] +rows.append({"cond": "gauss n=256 (mean of 10)", "hack%": 0.0, + **{k: float(np.mean([r[k] for r in g])) for k in g[0]}}) +rows.append({"cond": "gauss n=256 (max sep/nbcv)", "hack%": 0.0, + **{k: float(np.max([r[k] for r in g])) for k in g[0]}}) + +from tabulate import tabulate +print(tabulate(rows, headers="keys", tablefmt="pipe", floatfmt="+.3f")) +print("\nSHOULD: mixture rows separate from cleanonly+gauss rows on at least one of " + "sep/nbcv with a usable margin; ELSE no shape-based guard works and we rely on " + "warmup + asymmetric-cost only.") diff --git a/scripts/diag_pinning_moduleS.py b/scripts/diag_pinning_moduleS.py new file mode 100644 index 0000000..faa0cb7 --- /dev/null +++ b/scripts/diag_pinning_moduleS.py @@ -0,0 +1,164 @@ +"""Per-module residual S-space gate score. + +Unlike Super-S, this keeps one residual-side SVD basis per Linear through +projection, pair-direction extraction, mode selection, and scoring. Scores are +aggregated only after each module has produced its own score. + +This is an offline residual-stream diagnostic, not exact steering-lite sspace: +the cache contains block residuals at layers 12/18/24, not each Linear's own +input/output activations. + + writer (o_proj/down_proj): G_m = W_m W_m^T = U_m S_m^2 U_m^T + xS_m = h @ U_m / sqrt(S_m) + reader (q/k/v/gate/up): G_m = W_m^T W_m = V_m S_m^2 V_m^T + xS_m = h @ V_m * sqrt(S_m) + + uv run python scripts/diag_pinning_moduleS.py +""" +from __future__ import annotations + +import json +from glob import glob +from pathlib import Path + +import numpy as np +import torch +from safetensors import safe_open +from tabulate import tabulate + +from vgrout.train import _auroc + + +ROOT = Path("/workspace/projected_grpo") +RUNS = {"v3": ROOT / "out/diag", "v4": ROOT / "out/diag_v4", "v5": ROOT / "out/diag_v5"} +HEAD_PREFIX = "behavior_" +HOOKED_LAYERS = (12, 18, 24) +WRITERS = ("self_attn.o_proj", "mlp.down_proj") +READERS = ("self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj", + "mlp.gate_proj", "mlp.up_proj") +RANKS = (64, 256, -1) +EPS = 1e-8 + + +def model_weights(): + snap = Path(glob("/workspace/.hf_home/hub/models--Qwen--Qwen3-4B/snapshots/*")[0]) + wmap = json.loads((snap / "model.safetensors.index.json").read_text())["weight_map"] + handles = {f: safe_open(snap / f, framework="pt") for f in set(wmap.values())} + for residual_idx, layer_idx in enumerate(HOOKED_LAYERS): + for role, modules in (("writer", WRITERS), ("reader", READERS)): + module_layer = layer_idx if role == "writer" else layer_idx + 1 + for module in modules: + key = f"model.layers.{module_layer}.{module}.weight" + yield residual_idx, module_layer, role, module, handles[wmap[key]].get_tensor(key).float() + + +def residual_basis(W: torch.Tensor, role: str) -> tuple[torch.Tensor, torch.Tensor]: + """Return residual-side basis and sqrt(Sigma), descending by singular value.""" + gram = W @ W.T if role == "writer" else W.T @ W + eigenvalues, basis = torch.linalg.eigh(gram) + return basis.flip(1), eigenvalues.flip(0).clamp_min(0).pow(0.25).clamp_min(EPS) + + +def module_stats( + X: torch.Tensor, + Ph: torch.Tensor, + Pc: torch.Tensor, + basis: torch.Tensor, + sqrtS: torch.Tensor, + role: str, + rank: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Return per-example dot, cosine, and squared norm in one module S-space.""" + scale = sqrtS if role == "reader" else sqrtS.reciprocal() + XS = (X @ basis) * scale + dS = ((Ph @ basis) * scale - (Pc @ basis) * scale).mean(0) + if rank > 0: + idx = dS.abs().topk(rank).indices + XS, dS = XS[:, idx], dS[idx] + direction = dS / dS.norm().clamp_min(EPS) + dot = XS @ direction + norm2 = XS.square().sum(1) + return dot, dot / norm2.sqrt().clamp_min(EPS), norm2 + + +def main() -> int: + features = {tag: torch.load(path / "pinning_feats.pt", weights_only=False) + for tag, path in RUNS.items()} + prepared = {} + for tag, fe in features.items(): + head = [i for i, pair_id in enumerate(fe["pair_ids"]) if pair_id.startswith(HEAD_PREFIX)] + prepared[tag] = { + "X": fe["RES"].float(), + "Ph": fe["pair_feats"][("resid", "hack")][head].float(), + "Pc": fe["pair_feats"][("resid", "clean")][head].float(), + "pos": (np.abs(fe["adv"]) > 1e-6) & (fe["adv"] > 0), + "y": fe["exploited"] & (fe["adv"] > 0), + } + + collected = { + (tag, role, rank): {"dot": [], "cos": [], "norm2": []} + for tag in RUNS for role in ("writer", "reader") for rank in RANKS + } + print("building and scoring 21 independent residual-side module SVD bases (CPU)...") + for residual_idx, layer_idx, role, module, W in model_weights(): + print(f" residual={HOOKED_LAYERS[residual_idx]:02d} module_layer={layer_idx:02d} " + f"role={role:6s} module={module}") + basis, sqrtS = residual_basis(W, role) + for tag, fe in prepared.items(): + X, Ph, Pc = fe["X"][:, residual_idx], fe["Ph"][:, residual_idx], fe["Pc"][:, residual_idx] + for rank in RANKS: + dot, cos, norm2 = module_stats(X, Ph, Pc, basis, sqrtS, role, rank) + for key, value in (("dot", dot), ("cos", cos), ("norm2", norm2)): + collected[(tag, role, rank)][key].append(value) + + rows = {} + for tag, fe in prepared.items(): + au = lambda score: _auroc(score[fe["pos"]].tolist(), fe["y"][fe["pos"]].tolist()) + raw_d = (fe["Ph"] - fe["Pc"]).mean(0) + raw_v = raw_d / raw_d.norm(dim=-1, keepdim=True).clamp_min(EPS) + raw_dot = torch.einsum("nld,ld->n", fe["X"], raw_v) + raw_cos = raw_dot / (fe["X"].flatten(1).norm(dim=1).clamp_min(EPS) + * raw_v.flatten().norm()) + rows.setdefault(("raw resid", "concat", "cos"), {"variant": "raw resid", + "aggregate": "concat", "score": "cos"})[tag] = au(raw_cos) + rows.setdefault(("raw resid", "concat", "dot"), {"variant": "raw resid", + "aggregate": "concat", "score": "dot"})[tag] = au(raw_dot) + + for role in ("writer", "reader", "both"): + roles = ("writer", "reader") if role == "both" else (role,) + for rank in RANKS: + stats = [collected[(tag, r, rank)] for r in roles] + dots = torch.stack([x for stat in stats for x in stat["dot"]], 1) + coses = torch.stack([x for stat in stats for x in stat["cos"]], 1) + norm2 = torch.stack([x for stat in stats for x in stat["norm2"]], 1) + scores = { + ("mean", "cos"): coses.mean(1), + ("mean", "dot"): dots.mean(1), + ("concat", "cos"): dots.sum(1) / ( + norm2.sum(1).sqrt().clamp_min(EPS) * np.sqrt(dots.shape[1])), + ("concat", "dot"): dots.sum(1), + } + rank_name = "full" if rank < 0 else str(rank) + variant = f"moduleS {role} r={rank_name}" + for (aggregate, kind), score in scores.items(): + key = (variant, aggregate, kind) + rows.setdefault(key, {"variant": variant, "aggregate": aggregate, + "score": kind})[tag] = au(score) + + out = list(rows.values()) + for row in out: + values = [row[tag] for tag in RUNS] + row["mean"], row["min"] = float(np.mean(values)), float(np.min(values)) + out.sort(key=lambda row: -row["min"]) + cols = ["variant", "aggregate", "score", "v3", "v4", "v5", "mean", "min"] + print("\nAUROC on the A>0 contrast:") + print("SHOULD: raw resid cos == 0.916/0.700/0.804 and dot == " + "0.905/0.721/0.756 ELSE this harness disagrees with diag_pinning_superS. " + "Per-module bases must remain separate until score aggregation.") + print(tabulate([{c: row[c] for c in cols} for row in out], + headers="keys", tablefmt="pipe", floatfmt="+.3f")) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/diag_pinning_moduleS_exact.py b/scripts/diag_pinning_moduleS_exact.py new file mode 100644 index 0000000..cc98de3 --- /dev/null +++ b/scripts/diag_pinning_moduleS_exact.py @@ -0,0 +1,232 @@ +"""Exact per-module S-space gate diagnostic on actual Linear inputs/outputs. + +For every selected Linear, capture the same side used by steering-lite sspace: + + writer output: xS = y @ U / sqrt(S) + reader input: xS = x @ V * sqrt(S) + +Each module keeps its own thin SVD basis through pair-direction extraction, +top-r selection, and scoring. Scores are aggregated only after each module has +produced its own score. + +GPU required: + uv run python scripts/diag_pinning_moduleS_exact.py +""" +from __future__ import annotations + +import json +import struct +import csv +from dataclasses import dataclass +from pathlib import Path + +import numpy as np +import torch +import torch.nn.functional as F +import tyro +from loguru import logger +from safetensors.torch import load_file +from tabulate import tabulate +from transformers import AutoModelForCausalLM, AutoTokenizer + +from vgrout.lora2r import wrap_model_with_lora2r +from vgrout.pairs import load_pairs +from vgrout.train import _auroc + + +ROLES = { + "self_attn.o_proj": "writer", + "mlp.down_proj": "writer", + "self_attn.q_proj": "reader", + "self_attn.k_proj": "reader", + "self_attn.v_proj": "reader", + "mlp.gate_proj": "reader", + "mlp.up_proj": "reader", +} +RANKS = (64, 256, -1) +EPS = 1e-8 + + +@dataclass +class Cfg: + run_dir: Path = Path("out/runs/20260611T003538_fast_vanilla_lora2r_seed43_l2r_vanilla_s43_v3") + ckpt: str = "first_hack" + pairs: Path = Path("data/pairs/hack_pairs.md#all-in-one") + headline_prefix: str = "behavior_" + step_lo: int = 2 + step_hi: int = 9 + max_rollouts: int = 240 + layers: tuple[int, ...] = (12, 18, 24) + out: Path = Path("out/diag/moduleS_exact.tsv") + + +def checkpoint_meta(path: Path) -> dict: + with open(path, "rb") as f: + return json.loads(f.read(struct.unpack(" torch.Tensor: + prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device) + full_ids = tokenizer(prompt + completion, return_tensors="pt").input_ids.to(device) + n_prompt = prompt_ids.shape[1] + logits = model(full_ids).logits[:, :-1] + targets = full_ids[:, 1:] + logp = torch.nn.functional.log_softmax(logits.float(), dim=-1) + nll = -logp.gather(-1, targets.unsqueeze(-1)).squeeze(-1) + positions = torch.arange(full_ids.shape[1] - 1, device=device).unsqueeze(0) + mask = (positions >= (n_prompt - 1)).float() + return (nll * mask).sum() / mask.sum() + + +def thin_basis(W: torch.Tensor, role: str) -> tuple[torch.Tensor, torch.Tensor]: + """Return residual-side thin basis and sqrt(Sigma), descending.""" + d_out, d_in = W.shape + if role == "writer": + eigenvalues, basis = torch.linalg.eigh(W @ W.T) + elif d_out >= d_in: + eigenvalues, basis = torch.linalg.eigh(W.T @ W) + else: + eigenvalues, U = torch.linalg.eigh(W @ W.T) + singular_values = eigenvalues.clamp_min(0).sqrt() + basis = W.T @ U / singular_values.clamp_min(EPS) + return basis.flip(1), eigenvalues.flip(0).clamp_min(0).pow(0.25).clamp_min(EPS) + + +class ModuleSTap: + def __init__(self, modules: dict[str, tuple[torch.nn.Module, str, torch.Tensor, torch.Tensor]]): + self.modules = modules + self.values: dict[str, torch.Tensor] = {} + self.handles = [] + + def __enter__(self): + for name, (module, role, basis, sqrtS) in self.modules.items(): + def hook(module, args, output, name=name, role=role, basis=basis, sqrtS=sqrtS): + h = F.linear(args[0], module.weight, module.bias) if role == "writer" else args[0] + scale = sqrtS.reciprocal() if role == "writer" else sqrtS + self.values[name] = (h.detach().float() @ basis) * scale + self.handles.append(module.register_forward_hook(hook)) + return self + + def __exit__(self, *exc): + for handle in self.handles: + handle.remove() + + def pooled(self, n_prompt: int) -> dict[str, torch.Tensor]: + return {name: value[0, n_prompt:].mean(0).cpu() for name, value in self.values.items()} + + +def score_rows( + live: list[dict[str, torch.Tensor]], + pair_hack: list[dict[str, torch.Tensor]], + pair_clean: list[dict[str, torch.Tensor]], + roles: dict[str, str], + adv: np.ndarray, + exploited: np.ndarray, +) -> list[dict]: + head_diff = { + name: torch.stack([hack[name] - clean[name] for hack, clean in zip(pair_hack, pair_clean)]).mean(0) + for name in roles + } + pos = (np.abs(adv) > 1e-6) & (adv > 0) + y = exploited & (adv > 0) + au = lambda score: _auroc(score[pos].tolist(), y[pos].tolist()) + rows = [] + for role in ("writer", "reader", "both"): + names = [name for name, module_role in roles.items() if role == "both" or module_role == role] + for rank in RANKS: + dots, coses, norm2 = [], [], [] + for name in names: + dS = head_diff[name] + X = torch.stack([example[name] for example in live]) + if rank > 0 and rank < dS.numel(): + idx = dS.abs().topk(rank).indices + dS, X = dS[idx], X[:, idx] + direction = dS / dS.norm().clamp_min(EPS) + dot = X @ direction + dots.append(dot) + coses.append(dot / X.norm(dim=1).clamp_min(EPS)) + norm2.append(X.square().sum(1)) + dots, coses, norm2 = torch.stack(dots, 1), torch.stack(coses, 1), torch.stack(norm2, 1) + scores = { + "mean_cos": coses.mean(1), + "concat_cos": dots.sum(1) / (norm2.sum(1).sqrt().clamp_min(EPS) * np.sqrt(dots.shape[1])), + "dot": dots.mean(1), + } + rank_name = "full" if rank < 0 else str(rank) + rows.extend({"variant": f"moduleS-exact {role} r={rank_name}", "score": kind, "auroc": au(score)} + for kind, score in scores.items()) + return rows + + +def main(cfg: Cfg) -> int: + device = torch.device("cuda") + ckpt_path = cfg.run_dir / f"{cfg.ckpt}.safetensors" + meta = checkpoint_meta(ckpt_path) + run_cfg = json.loads(meta["cfg"]) + tokenizer = AutoTokenizer.from_pretrained(run_cfg["model"]) + model = AutoModelForCausalLM.from_pretrained( + run_cfg["model"], dtype=torch.bfloat16, attn_implementation="flash_attention_2").to(device) + model.config.use_cache = False + wrappers = wrap_model_with_lora2r(model, r=run_cfg["lora_r"], init_seed=run_cfg["lora_init_seed"]) + state = load_file(str(ckpt_path)) + for name, info in wrappers.items(): + info["A"].data.copy_(state[f"A/{name}"].to(device, torch.float32)) + info["B"].data.copy_(state[f"B/{name}"].to(device, torch.float32)) + model.eval() + + selected = {} + for name, module in model.named_modules(): + for suffix, role in ROLES.items(): + if name.endswith(suffix) and any(name.startswith(f"model.layers.{layer}.") for layer in cfg.layers): + basis, sqrtS = thin_basis(module.weight.detach().float(), role) + selected[name] = (module, role, basis, sqrtS) + assert len(selected) == len(cfg.layers) * len(ROLES), len(selected) + roles = {name: role for name, (_, role, _, _) in selected.items()} + logger.info(f"selected {len(selected)} exact module S-spaces") + + def extract(tap: ModuleSTap, prompt: str, completion: str) -> dict[str, torch.Tensor]: + with torch.no_grad(): + loss = completion_nll(model, tokenizer, prompt, completion, device) + assert torch.isfinite(loss), loss + n_prompt = tokenizer(prompt, return_tensors="pt").input_ids.shape[1] + return tap.pooled(n_prompt) + + pairs = [pair for pair in load_pairs(cfg.pairs) if pair.problem_id.startswith(cfg.headline_prefix)] + records = [json.loads(line) for line in (cfg.run_dir / "rollouts.jsonl").read_text().splitlines()] + records = [r for r in records if cfg.step_lo <= r["step"] <= cfg.step_hi and r["text"].strip()][:cfg.max_rollouts] + pair_hack, pair_clean, live = [], [], [] + with ModuleSTap(selected) as tap: + for i, pair in enumerate(pairs): + pair_hack.append(extract(tap, pair.prompt, pair.hack)) + pair_clean.append(extract(tap, pair.prompt, pair.clean)) + logger.info(f"pair {i + 1}/{len(pairs)}") + for i, record in enumerate(records): + live.append(extract(tap, record["prompt"], record["text"])) + if (i + 1) % 40 == 0: + logger.info(f"rollout {i + 1}/{len(records)}") + + steps = np.array([r["step"] for r in records]) + p_idx = np.array([r["p_idx"] for r in records]) + reward = np.array([float(r["reward"]) for r in records]) + adv = np.empty(len(records)) + for step, prompt_idx in set(zip(steps.tolist(), p_idx.tolist())): + mask = (steps == step) & (p_idx == prompt_idx) + adv[mask] = reward[mask] - reward[mask].mean() + exploited = np.array([bool(r["exploited"]) for r in records]) + + rows = score_rows(live, pair_hack, pair_clean, roles, adv, exploited) + rows.sort(key=lambda row: -row["auroc"]) + cfg.out.parent.mkdir(parents=True, exist_ok=True) + with cfg.out.open("w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=("variant", "score", "auroc"), delimiter="\t") + writer.writeheader() + writer.writerows(rows) + print("SHOULD: all rows finite; exact hooks capture writer outputs and reader inputs. " + "A selected maximum is a candidate, not confirmation.") + print(tabulate(rows, headers="keys", tablefmt="pipe", floatfmt="+.3f")) + logger.info(f"wrote {cfg.out}") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main(tyro.cli(Cfg)))