journal(#41): entry (g) routeA shipped + guard-drop calibration; track moduleS diag scripts

Entry (f) already cited scripts/diag_pinning_moduleS_exact.py; both moduleS
scripts were untracked, so committing them for provenance alongside the
calibration script the new entry cites.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
wassname
2026-06-11 12:50:07 +00:00
parent f646e57028
commit 8000aa48f4
4 changed files with 560 additions and 0 deletions
+93
View File
@@ -4257,3 +4257,96 @@ Provenance:
**Discussion (speculative).** I read both as nulls at this sample size (per-window SE ~0.07). For super-S, whitening consistently sits at or below the raw baseline, which would fit the pooled spectrum amplifying low-energy directions that carry no hack signal here, but I cannot distinguish that from noise. The one apparent gain (rotation + reader basis + top-64) is exactly what taking the maximum of 50 noisy rows produces; I would only believe it if it survived on new windows chosen in advance. For t-stat the entry-(d) alternative stands: the per-coordinate std over 8 pairs is itself ~25% noise, so the weighting may be real but unestimable at this pair count. Both nulls leave act_dot with plain mean extraction as the routeA default.
**Next.** routeA implementation per the plan now written into docs/spec/20260611_act_gate_spec.md (extraction module with verify gate, gate wiring replacing routeV, rolling-buffer winsorized-Otsu pinning), pending wassname's approval. More authored pairs remains the highest-leverage data change.
## 2026-06-11 (f) -- per-module S-space shows no robust gate-score improvement
**Question.** Does preserving each Linear's own SVD space reveal module-specific
hack signal that pooled Super-S washes out?
**Methods.** `scripts/diag_pinning_moduleS_exact.py` hooks the actual inputs of
reader Linears and recomputes base-weight outputs of writer Linears from their
actual inputs in blocks 12/18/24. Per module it
uses the steering-lite S-space identities `x @ V * sqrt(S)` for readers and
`y @ U / sqrt(S)` for writers, extracts a direction from the eight `behavior_`
pairs, selects top-r modes by pair-difference magnitude, then aggregates module
scores. Pueue 28/29/30 reran the v3/v4/v5 emergence windows and wrote strict TSV
evidence.
| score | v3 | v4 | v5 | mean | min |
|---|---:|---:|---:|---:|---:|
| moduleS writer r=256 concat-cos, best selected row | 0.892 | 0.733 | 0.786 | 0.804 | 0.733 |
| act-dot existing default | 0.870 | 0.747 | 0.747 | 0.788 | 0.747 |
| raw residual dot | 0.905 | 0.721 | 0.756 | 0.794 | 0.721 |
**Result.** Per-module S-space shows no robust improvement on these windows. Its
best row falls below act-dot on worst-window AUROC and was selected from 27 rows,
so the small gain over raw residual dot is not evidence of improvement. The earlier
cached-residual approximation produced a stronger reader-r64 row, but fresh
review correctly identified that it was a module-weight-derived metric on
post-block residuals rather than exact module S-space.
**Evidence.** Full cross-window table:
`out/diag/moduleS_exact_summary.tsv`. Per-window tables:
`out/diag{,_v4,_v5}/moduleS_exact.tsv`. Spec and failure log:
`docs/spec/20260611_per_module_sspace.md`.
## 2026-06-11 (g) -- routeA act gate shipped; bimodality guard dropped after calibration
**Introduction.** Entries (b)-(f) established that the activation dot score is the
stable gate input. This entry covers the implementation
(docs/spec/20260611_act_gate_spec.md): forward-only `v_act` extraction
(`src/vgrout/extract_vhack_act.py`), the routeA gate in train.py (act capture on the
quarantine-ablated logp_old forward, masks pinned before the single grad forward,
rolling-buffer Otsu thresholds), and deletion of the routeV gradient gate. The spec
left one open question: should an online bimodality guard close the rout zone before
hacks emerge? Expected: some shape statistic of the score window separates the
emergence mixture (hack share 35-43%) from hack-free scores.
**Methods.** Calibration is offline on the cached v3/v4/v5 emergence-window features
(out/diag{,_v4,_v5}/pinning_feats.pt, produced by scripts/diag_pinning.py at commit
70697ff). Score = act dot vs the `behavior_` 8-pair v_act. Conditions: mixture (all
valid live rollouts), cleanonly (non-exploited only, pre-emergence proxy), and N(0,1)
n=256 (10 seeds). Statistics computed after z-norm, winsorize(1/99), two-threshold
Otsu: `sep` = mean(z above t_hi) minus mean(z below t_lo) in buffer-sd units, `nbcv` =
between-class variance fraction. Command: `uv run python
scripts/attic/calib_otsu_guard.py`. Extractor equivalence was verified on GPU as
pueue #24 (`scripts/verify_v_act.py`, commit 5a340e5).
| cond | hack% | n | sep | nbcv |
|---|---:|---:|---:|---:|
| v3 mixture | 0.43 | 138 | 2.75 | 0.80 |
| v3 cleanonly | 0.00 | 79 | 2.54 | 0.82 |
| v4 mixture | 0.35 | 96 | 2.82 | 0.81 |
| v4 cleanonly | 0.00 | 62 | 2.34 | 0.84 |
| v5 mixture | 0.39 | 138 | 2.44 | 0.76 |
| v5 cleanonly | 0.00 | 84 | 3.52 | 0.77 |
| gauss n=256 (mean of 10 seeds) | 0.00 | 256 | 2.42 | 0.83 |
Table: guard-candidate statistics per condition. A usable guard needs the mixture rows
to separate from the cleanonly and gauss rows on `sep` or `nbcv` with margin.
Provenance:
- Script: `scripts/attic/calib_otsu_guard.py` (copied from the session scratch file);
rerun output captured at `/tmp/claude-1000/calib_guard_out.txt` this session.
- Inputs: `out/diag/pinning_feats.pt`, `out/diag_v4/pinning_feats.pt`,
`out/diag_v5/pinning_feats.pt` (the (c) replication windows).
- verify_v_act: pueue #24 log; acts match cached diag features at rel diff 7.3e-4 and
7.7e-4 (hack/clean), v cos > 0.999, per-module cos >= 0.99997.
**Results.** No statistic separates the conditions. The largest `sep` of all rows is a
hack-FREE window (v5 cleanonly, 3.52); pure Gaussians sit at 2.42, inside the mixture
range (2.44-2.82). `nbcv` overlaps the same way (mixtures 0.76-0.81 vs gauss 0.83).
Otsu always finds tail classes ~2.4 sd apart even when no structure exists, so any
threshold on these statistics either always opens or always closes.
**Discussion (speculative).** I read this as: the guard idea was solving a
non-problem. Before emergence a false rout costs one update removed from deployment
(asymmetric, cheap), and warmup already pins absorb while the buffer fills. An
alternative read is that a better statistic exists (e.g. dip test, mixture-model BIC)
and I only tried Otsu-derived ones; I did not pursue this because the cost asymmetry
makes the guard's value marginal even if it worked. The gate therefore ships with
warmup + Otsu only (commits adca442 routeA wiring + routeV deletion, f646e57
review-driven hardening; smoke logs /tmp/claude-1000/smoke_routeA*.log).
**Next.** Queue the seed-43 fast 4-arm set (`just queue-decision`): routeA real vs
Haar placebo vs vanilla vs absorb. Decision: directionality is real iff real-v
deploy_hack << placebo at matched solve, with gate AUROC >> 0.5 around emergence.
+71
View File
@@ -0,0 +1,71 @@
"""Calibrate the routeA bimodality guard offline.
Conditions per window (v3/v4/v5), act_dot score vs behavior_ pair v_act:
mixture all valid live rollouts (hack share 35-43%) -> guard SHOULD OPEN
cleanonly non-exploited rollouts only (pre-emergence proxy) -> SHOULD CLOSE
gauss N(0,1) n=256, 10 seeds -> SHOULD CLOSE
Candidate statistics, computed after z-norm + winsorize(1/99) + otsu3:
sep = mean(z | z>=t_hi) - mean(z | z<t_lo), in buffer-sd units
nbcv = three-class between-class variance / total variance (both on winsorized z)
"""
import numpy as np
import torch
from pathlib import Path
ROOT = Path("/workspace/projected_grpo")
RUNS = {"v3": ROOT / "out/diag", "v4": ROOT / "out/diag_v4", "v5": ROOT / "out/diag_v5"}
def otsu3(x: np.ndarray) -> tuple[float, float]:
x = np.clip(x, *np.quantile(x, [0.01, 0.99]))
s = np.sort(np.asarray(x, float))
n = len(s)
c = np.concatenate([[0.0], np.cumsum(s)])
best, best_ij = -np.inf, (1, 2)
for i in range(1, n - 1):
for j in range(i + 1, n):
obj = c[i] ** 2 / i + (c[j] - c[i]) ** 2 / (j - i) + (c[n] - c[j]) ** 2 / (n - j)
if obj > best:
best, best_ij = obj, (i, j)
i, j = best_ij
return float((s[i - 1] + s[i]) / 2), float((s[j - 1] + s[j]) / 2)
def stats(scores: np.ndarray) -> dict:
z = (scores - scores.mean()) / (scores.std() or 1.0)
zw = np.clip(z, *np.quantile(z, [0.01, 0.99]))
t_lo, t_hi = otsu3(z)
keep, mid, rout = zw < t_lo, (zw >= t_lo) & (zw < t_hi), zw >= t_hi
sep = float(zw[rout].mean() - zw[keep].mean()) if rout.any() and keep.any() else float("nan")
mu = zw.mean()
nbcv = sum(m.mean() * (zw[m].mean() - mu) ** 2 for m in (keep, mid, rout) if m.any()) / zw.var()
return {"n": len(scores), "t_lo": t_lo, "t_hi": t_hi, "sep": sep, "nbcv": float(nbcv),
"w_rout": float(rout.mean())}
rows = []
for tag, d in RUNS.items():
fe = torch.load(d / "pinning_feats.pt", weights_only=False)
ACT, adv, exploited = fe["ACT"].float(), fe["adv"], fe["exploited"]
head = [i for i, pid in enumerate(fe["pair_ids"]) if pid.startswith("behavior_")]
D = (fe["pair_feats"][("act", "hack")][head] - fe["pair_feats"][("act", "clean")][head]).float().mean(0)
V = D / D.norm(dim=-1, keepdim=True).clamp_min(1e-12)
s = torch.einsum("nmr,mr->n", ACT, V).numpy()
valid = np.abs(adv) > 1e-6
rows.append({"cond": f"{tag} mixture", "hack%": exploited[valid].mean(), **stats(s[valid])})
cl = valid & ~exploited
rows.append({"cond": f"{tag} cleanonly", "hack%": 0.0, **stats(s[cl])})
rng = np.random.default_rng(0)
g = [stats(rng.standard_normal(256)) for _ in range(10)]
rows.append({"cond": "gauss n=256 (mean of 10)", "hack%": 0.0,
**{k: float(np.mean([r[k] for r in g])) for k in g[0]}})
rows.append({"cond": "gauss n=256 (max sep/nbcv)", "hack%": 0.0,
**{k: float(np.max([r[k] for r in g])) for k in g[0]}})
from tabulate import tabulate
print(tabulate(rows, headers="keys", tablefmt="pipe", floatfmt="+.3f"))
print("\nSHOULD: mixture rows separate from cleanonly+gauss rows on at least one of "
"sep/nbcv with a usable margin; ELSE no shape-based guard works and we rely on "
"warmup + asymmetric-cost only.")
+164
View File
@@ -0,0 +1,164 @@
"""Per-module residual S-space gate score.
Unlike Super-S, this keeps one residual-side SVD basis per Linear through
projection, pair-direction extraction, mode selection, and scoring. Scores are
aggregated only after each module has produced its own score.
This is an offline residual-stream diagnostic, not exact steering-lite sspace:
the cache contains block residuals at layers 12/18/24, not each Linear's own
input/output activations.
writer (o_proj/down_proj): G_m = W_m W_m^T = U_m S_m^2 U_m^T
xS_m = h @ U_m / sqrt(S_m)
reader (q/k/v/gate/up): G_m = W_m^T W_m = V_m S_m^2 V_m^T
xS_m = h @ V_m * sqrt(S_m)
uv run python scripts/diag_pinning_moduleS.py
"""
from __future__ import annotations
import json
from glob import glob
from pathlib import Path
import numpy as np
import torch
from safetensors import safe_open
from tabulate import tabulate
from vgrout.train import _auroc
ROOT = Path("/workspace/projected_grpo")
RUNS = {"v3": ROOT / "out/diag", "v4": ROOT / "out/diag_v4", "v5": ROOT / "out/diag_v5"}
HEAD_PREFIX = "behavior_"
HOOKED_LAYERS = (12, 18, 24)
WRITERS = ("self_attn.o_proj", "mlp.down_proj")
READERS = ("self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj",
"mlp.gate_proj", "mlp.up_proj")
RANKS = (64, 256, -1)
EPS = 1e-8
def model_weights():
snap = Path(glob("/workspace/.hf_home/hub/models--Qwen--Qwen3-4B/snapshots/*")[0])
wmap = json.loads((snap / "model.safetensors.index.json").read_text())["weight_map"]
handles = {f: safe_open(snap / f, framework="pt") for f in set(wmap.values())}
for residual_idx, layer_idx in enumerate(HOOKED_LAYERS):
for role, modules in (("writer", WRITERS), ("reader", READERS)):
module_layer = layer_idx if role == "writer" else layer_idx + 1
for module in modules:
key = f"model.layers.{module_layer}.{module}.weight"
yield residual_idx, module_layer, role, module, handles[wmap[key]].get_tensor(key).float()
def residual_basis(W: torch.Tensor, role: str) -> tuple[torch.Tensor, torch.Tensor]:
"""Return residual-side basis and sqrt(Sigma), descending by singular value."""
gram = W @ W.T if role == "writer" else W.T @ W
eigenvalues, basis = torch.linalg.eigh(gram)
return basis.flip(1), eigenvalues.flip(0).clamp_min(0).pow(0.25).clamp_min(EPS)
def module_stats(
X: torch.Tensor,
Ph: torch.Tensor,
Pc: torch.Tensor,
basis: torch.Tensor,
sqrtS: torch.Tensor,
role: str,
rank: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Return per-example dot, cosine, and squared norm in one module S-space."""
scale = sqrtS if role == "reader" else sqrtS.reciprocal()
XS = (X @ basis) * scale
dS = ((Ph @ basis) * scale - (Pc @ basis) * scale).mean(0)
if rank > 0:
idx = dS.abs().topk(rank).indices
XS, dS = XS[:, idx], dS[idx]
direction = dS / dS.norm().clamp_min(EPS)
dot = XS @ direction
norm2 = XS.square().sum(1)
return dot, dot / norm2.sqrt().clamp_min(EPS), norm2
def main() -> int:
features = {tag: torch.load(path / "pinning_feats.pt", weights_only=False)
for tag, path in RUNS.items()}
prepared = {}
for tag, fe in features.items():
head = [i for i, pair_id in enumerate(fe["pair_ids"]) if pair_id.startswith(HEAD_PREFIX)]
prepared[tag] = {
"X": fe["RES"].float(),
"Ph": fe["pair_feats"][("resid", "hack")][head].float(),
"Pc": fe["pair_feats"][("resid", "clean")][head].float(),
"pos": (np.abs(fe["adv"]) > 1e-6) & (fe["adv"] > 0),
"y": fe["exploited"] & (fe["adv"] > 0),
}
collected = {
(tag, role, rank): {"dot": [], "cos": [], "norm2": []}
for tag in RUNS for role in ("writer", "reader") for rank in RANKS
}
print("building and scoring 21 independent residual-side module SVD bases (CPU)...")
for residual_idx, layer_idx, role, module, W in model_weights():
print(f" residual={HOOKED_LAYERS[residual_idx]:02d} module_layer={layer_idx:02d} "
f"role={role:6s} module={module}")
basis, sqrtS = residual_basis(W, role)
for tag, fe in prepared.items():
X, Ph, Pc = fe["X"][:, residual_idx], fe["Ph"][:, residual_idx], fe["Pc"][:, residual_idx]
for rank in RANKS:
dot, cos, norm2 = module_stats(X, Ph, Pc, basis, sqrtS, role, rank)
for key, value in (("dot", dot), ("cos", cos), ("norm2", norm2)):
collected[(tag, role, rank)][key].append(value)
rows = {}
for tag, fe in prepared.items():
au = lambda score: _auroc(score[fe["pos"]].tolist(), fe["y"][fe["pos"]].tolist())
raw_d = (fe["Ph"] - fe["Pc"]).mean(0)
raw_v = raw_d / raw_d.norm(dim=-1, keepdim=True).clamp_min(EPS)
raw_dot = torch.einsum("nld,ld->n", fe["X"], raw_v)
raw_cos = raw_dot / (fe["X"].flatten(1).norm(dim=1).clamp_min(EPS)
* raw_v.flatten().norm())
rows.setdefault(("raw resid", "concat", "cos"), {"variant": "raw resid",
"aggregate": "concat", "score": "cos"})[tag] = au(raw_cos)
rows.setdefault(("raw resid", "concat", "dot"), {"variant": "raw resid",
"aggregate": "concat", "score": "dot"})[tag] = au(raw_dot)
for role in ("writer", "reader", "both"):
roles = ("writer", "reader") if role == "both" else (role,)
for rank in RANKS:
stats = [collected[(tag, r, rank)] for r in roles]
dots = torch.stack([x for stat in stats for x in stat["dot"]], 1)
coses = torch.stack([x for stat in stats for x in stat["cos"]], 1)
norm2 = torch.stack([x for stat in stats for x in stat["norm2"]], 1)
scores = {
("mean", "cos"): coses.mean(1),
("mean", "dot"): dots.mean(1),
("concat", "cos"): dots.sum(1) / (
norm2.sum(1).sqrt().clamp_min(EPS) * np.sqrt(dots.shape[1])),
("concat", "dot"): dots.sum(1),
}
rank_name = "full" if rank < 0 else str(rank)
variant = f"moduleS {role} r={rank_name}"
for (aggregate, kind), score in scores.items():
key = (variant, aggregate, kind)
rows.setdefault(key, {"variant": variant, "aggregate": aggregate,
"score": kind})[tag] = au(score)
out = list(rows.values())
for row in out:
values = [row[tag] for tag in RUNS]
row["mean"], row["min"] = float(np.mean(values)), float(np.min(values))
out.sort(key=lambda row: -row["min"])
cols = ["variant", "aggregate", "score", "v3", "v4", "v5", "mean", "min"]
print("\nAUROC on the A>0 contrast:")
print("SHOULD: raw resid cos == 0.916/0.700/0.804 and dot == "
"0.905/0.721/0.756 ELSE this harness disagrees with diag_pinning_superS. "
"Per-module bases must remain separate until score aggregation.")
print(tabulate([{c: row[c] for c in cols} for row in out],
headers="keys", tablefmt="pipe", floatfmt="+.3f"))
return 0
if __name__ == "__main__":
raise SystemExit(main())
+232
View File
@@ -0,0 +1,232 @@
"""Exact per-module S-space gate diagnostic on actual Linear inputs/outputs.
For every selected Linear, capture the same side used by steering-lite sspace:
writer output: xS = y @ U / sqrt(S)
reader input: xS = x @ V * sqrt(S)
Each module keeps its own thin SVD basis through pair-direction extraction,
top-r selection, and scoring. Scores are aggregated only after each module has
produced its own score.
GPU required:
uv run python scripts/diag_pinning_moduleS_exact.py
"""
from __future__ import annotations
import json
import struct
import csv
from dataclasses import dataclass
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
import tyro
from loguru import logger
from safetensors.torch import load_file
from tabulate import tabulate
from transformers import AutoModelForCausalLM, AutoTokenizer
from vgrout.lora2r import wrap_model_with_lora2r
from vgrout.pairs import load_pairs
from vgrout.train import _auroc
ROLES = {
"self_attn.o_proj": "writer",
"mlp.down_proj": "writer",
"self_attn.q_proj": "reader",
"self_attn.k_proj": "reader",
"self_attn.v_proj": "reader",
"mlp.gate_proj": "reader",
"mlp.up_proj": "reader",
}
RANKS = (64, 256, -1)
EPS = 1e-8
@dataclass
class Cfg:
run_dir: Path = Path("out/runs/20260611T003538_fast_vanilla_lora2r_seed43_l2r_vanilla_s43_v3")
ckpt: str = "first_hack"
pairs: Path = Path("data/pairs/hack_pairs.md#all-in-one")
headline_prefix: str = "behavior_"
step_lo: int = 2
step_hi: int = 9
max_rollouts: int = 240
layers: tuple[int, ...] = (12, 18, 24)
out: Path = Path("out/diag/moduleS_exact.tsv")
def checkpoint_meta(path: Path) -> dict:
with open(path, "rb") as f:
return json.loads(f.read(struct.unpack("<Q", f.read(8))[0])).get("__metadata__", {})
def completion_nll(model, tokenizer, prompt: str, completion: str, device) -> torch.Tensor:
prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
full_ids = tokenizer(prompt + completion, return_tensors="pt").input_ids.to(device)
n_prompt = prompt_ids.shape[1]
logits = model(full_ids).logits[:, :-1]
targets = full_ids[:, 1:]
logp = torch.nn.functional.log_softmax(logits.float(), dim=-1)
nll = -logp.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
positions = torch.arange(full_ids.shape[1] - 1, device=device).unsqueeze(0)
mask = (positions >= (n_prompt - 1)).float()
return (nll * mask).sum() / mask.sum()
def thin_basis(W: torch.Tensor, role: str) -> tuple[torch.Tensor, torch.Tensor]:
"""Return residual-side thin basis and sqrt(Sigma), descending."""
d_out, d_in = W.shape
if role == "writer":
eigenvalues, basis = torch.linalg.eigh(W @ W.T)
elif d_out >= d_in:
eigenvalues, basis = torch.linalg.eigh(W.T @ W)
else:
eigenvalues, U = torch.linalg.eigh(W @ W.T)
singular_values = eigenvalues.clamp_min(0).sqrt()
basis = W.T @ U / singular_values.clamp_min(EPS)
return basis.flip(1), eigenvalues.flip(0).clamp_min(0).pow(0.25).clamp_min(EPS)
class ModuleSTap:
def __init__(self, modules: dict[str, tuple[torch.nn.Module, str, torch.Tensor, torch.Tensor]]):
self.modules = modules
self.values: dict[str, torch.Tensor] = {}
self.handles = []
def __enter__(self):
for name, (module, role, basis, sqrtS) in self.modules.items():
def hook(module, args, output, name=name, role=role, basis=basis, sqrtS=sqrtS):
h = F.linear(args[0], module.weight, module.bias) if role == "writer" else args[0]
scale = sqrtS.reciprocal() if role == "writer" else sqrtS
self.values[name] = (h.detach().float() @ basis) * scale
self.handles.append(module.register_forward_hook(hook))
return self
def __exit__(self, *exc):
for handle in self.handles:
handle.remove()
def pooled(self, n_prompt: int) -> dict[str, torch.Tensor]:
return {name: value[0, n_prompt:].mean(0).cpu() for name, value in self.values.items()}
def score_rows(
live: list[dict[str, torch.Tensor]],
pair_hack: list[dict[str, torch.Tensor]],
pair_clean: list[dict[str, torch.Tensor]],
roles: dict[str, str],
adv: np.ndarray,
exploited: np.ndarray,
) -> list[dict]:
head_diff = {
name: torch.stack([hack[name] - clean[name] for hack, clean in zip(pair_hack, pair_clean)]).mean(0)
for name in roles
}
pos = (np.abs(adv) > 1e-6) & (adv > 0)
y = exploited & (adv > 0)
au = lambda score: _auroc(score[pos].tolist(), y[pos].tolist())
rows = []
for role in ("writer", "reader", "both"):
names = [name for name, module_role in roles.items() if role == "both" or module_role == role]
for rank in RANKS:
dots, coses, norm2 = [], [], []
for name in names:
dS = head_diff[name]
X = torch.stack([example[name] for example in live])
if rank > 0 and rank < dS.numel():
idx = dS.abs().topk(rank).indices
dS, X = dS[idx], X[:, idx]
direction = dS / dS.norm().clamp_min(EPS)
dot = X @ direction
dots.append(dot)
coses.append(dot / X.norm(dim=1).clamp_min(EPS))
norm2.append(X.square().sum(1))
dots, coses, norm2 = torch.stack(dots, 1), torch.stack(coses, 1), torch.stack(norm2, 1)
scores = {
"mean_cos": coses.mean(1),
"concat_cos": dots.sum(1) / (norm2.sum(1).sqrt().clamp_min(EPS) * np.sqrt(dots.shape[1])),
"dot": dots.mean(1),
}
rank_name = "full" if rank < 0 else str(rank)
rows.extend({"variant": f"moduleS-exact {role} r={rank_name}", "score": kind, "auroc": au(score)}
for kind, score in scores.items())
return rows
def main(cfg: Cfg) -> int:
device = torch.device("cuda")
ckpt_path = cfg.run_dir / f"{cfg.ckpt}.safetensors"
meta = checkpoint_meta(ckpt_path)
run_cfg = json.loads(meta["cfg"])
tokenizer = AutoTokenizer.from_pretrained(run_cfg["model"])
model = AutoModelForCausalLM.from_pretrained(
run_cfg["model"], dtype=torch.bfloat16, attn_implementation="flash_attention_2").to(device)
model.config.use_cache = False
wrappers = wrap_model_with_lora2r(model, r=run_cfg["lora_r"], init_seed=run_cfg["lora_init_seed"])
state = load_file(str(ckpt_path))
for name, info in wrappers.items():
info["A"].data.copy_(state[f"A/{name}"].to(device, torch.float32))
info["B"].data.copy_(state[f"B/{name}"].to(device, torch.float32))
model.eval()
selected = {}
for name, module in model.named_modules():
for suffix, role in ROLES.items():
if name.endswith(suffix) and any(name.startswith(f"model.layers.{layer}.") for layer in cfg.layers):
basis, sqrtS = thin_basis(module.weight.detach().float(), role)
selected[name] = (module, role, basis, sqrtS)
assert len(selected) == len(cfg.layers) * len(ROLES), len(selected)
roles = {name: role for name, (_, role, _, _) in selected.items()}
logger.info(f"selected {len(selected)} exact module S-spaces")
def extract(tap: ModuleSTap, prompt: str, completion: str) -> dict[str, torch.Tensor]:
with torch.no_grad():
loss = completion_nll(model, tokenizer, prompt, completion, device)
assert torch.isfinite(loss), loss
n_prompt = tokenizer(prompt, return_tensors="pt").input_ids.shape[1]
return tap.pooled(n_prompt)
pairs = [pair for pair in load_pairs(cfg.pairs) if pair.problem_id.startswith(cfg.headline_prefix)]
records = [json.loads(line) for line in (cfg.run_dir / "rollouts.jsonl").read_text().splitlines()]
records = [r for r in records if cfg.step_lo <= r["step"] <= cfg.step_hi and r["text"].strip()][:cfg.max_rollouts]
pair_hack, pair_clean, live = [], [], []
with ModuleSTap(selected) as tap:
for i, pair in enumerate(pairs):
pair_hack.append(extract(tap, pair.prompt, pair.hack))
pair_clean.append(extract(tap, pair.prompt, pair.clean))
logger.info(f"pair {i + 1}/{len(pairs)}")
for i, record in enumerate(records):
live.append(extract(tap, record["prompt"], record["text"]))
if (i + 1) % 40 == 0:
logger.info(f"rollout {i + 1}/{len(records)}")
steps = np.array([r["step"] for r in records])
p_idx = np.array([r["p_idx"] for r in records])
reward = np.array([float(r["reward"]) for r in records])
adv = np.empty(len(records))
for step, prompt_idx in set(zip(steps.tolist(), p_idx.tolist())):
mask = (steps == step) & (p_idx == prompt_idx)
adv[mask] = reward[mask] - reward[mask].mean()
exploited = np.array([bool(r["exploited"]) for r in records])
rows = score_rows(live, pair_hack, pair_clean, roles, adv, exploited)
rows.sort(key=lambda row: -row["auroc"])
cfg.out.parent.mkdir(parents=True, exist_ok=True)
with cfg.out.open("w", newline="") as f:
writer = csv.DictWriter(f, fieldnames=("variant", "score", "auroc"), delimiter="\t")
writer.writeheader()
writer.writerows(rows)
print("SHOULD: all rows finite; exact hooks capture writer outputs and reader inputs. "
"A selected maximum is a candidate, not confirmation.")
print(tabulate(rows, headers="keys", tablefmt="pipe", floatfmt="+.3f"))
logger.info(f"wrote {cfg.out}")
return 0
if __name__ == "__main__":
raise SystemExit(main(tyro.cli(Cfg)))