mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 16:45:42 +08:00
journal(#41): entry (g) routeA shipped + guard-drop calibration; track moduleS diag scripts
Entry (f) already cited scripts/diag_pinning_moduleS_exact.py, but both moduleS scripts were untracked. This commit adds them for provenance, alongside the calibration script that the new entry cites.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
@@ -4257,3 +4257,96 @@ Provenance:
**Discussion (speculative).** I read both as nulls at this sample size (per-window SE ~0.07). For super-S, whitening consistently sits at or below the raw baseline, which would fit the pooled spectrum amplifying low-energy directions that carry no hack signal here, but I cannot distinguish that from noise. The one apparent gain (rotation + reader basis + top-64) is exactly what taking the maximum of 50 noisy rows produces; I would only believe it if it survived on new windows chosen in advance. For t-stat the entry-(d) alternative stands: the per-coordinate std over 8 pairs is itself ~25% noise, so the weighting may be real but unestimable at this pair count. Both nulls leave act_dot with plain mean extraction as the routeA default.

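The ~25% figure is consistent with the usual approximation for the relative standard error of a sample standard deviation, `1 / sqrt(2 (n - 1))`. The sketch below works through that arithmetic for n = 8 and checks it with a small simulation. It assumes roughly Gaussian per-coordinate pair differences, which is my assumption rather than something the entry states, and the function names are hypothetical.

```python
import math
import random
import statistics


def rel_se_of_std(n: int) -> float:
    """Approximate relative standard error of a sample std from n Gaussian draws."""
    return 1.0 / math.sqrt(2 * (n - 1))


def simulate_rel_se(n: int, trials: int = 5000, seed: int = 0) -> float:
    """Monte Carlo check: spread of the sample std when the true sd is 1.0."""
    rng = random.Random(seed)
    stds = [
        statistics.stdev([rng.gauss(0.0, 1.0) for _ in range(n)])
        for _ in range(trials)
    ]
    return statistics.pstdev(stds)


approx = rel_se_of_std(8)
simulated = simulate_rel_se(8)
print(f"n=8 pairs: approx {approx:.3f}, simulated {simulated:.3f}")
```

Both numbers land near a quarter of the true standard deviation, so a per-coordinate t-statistic built from 8 pairs divides by a denominator that is itself quite noisy.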
**Next.** routeA implementation per the plan now written into docs/spec/20260611_act_gate_spec.md (extraction module with verify gate, gate wiring replacing routeV, rolling-buffer winsorized-Otsu pinning), pending wassname's approval. More authored pairs remains the highest-leverage data change.

## 2026-06-11 (f) -- per-module S-space shows no robust gate-score improvement

**Question.** Does preserving each Linear's own SVD space reveal module-specific
hack signal that pooled Super-S washes out?

**Methods.** `scripts/diag_pinning_moduleS_exact.py` hooks the actual inputs of
reader Linears and recomputes base-weight outputs of writer Linears from their
actual inputs in blocks 12/18/24. Per module it uses the steering-lite S-space
identities `x @ V * sqrt(S)` for readers and `y @ U / sqrt(S)` for writers,
extracts a direction from the eight `behavior_` pairs, selects top-r modes by
pair-difference magnitude, then aggregates module scores. Pueue 28/29/30 reran
the v3/v4/v5 emergence windows and wrote strict TSV evidence.

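The two S-space identities and the top-r selection can be sketched in a few lines of NumPy. This is an illustration on random data, not the script itself: `s_space`, `top_r_direction`, and all shapes are hypothetical, and the sketch ignores biases, LoRA wrappers, and token pooling.

```python
import numpy as np


def s_space(acts: np.ndarray, W: np.ndarray, role: str) -> np.ndarray:
    """Project activations into one Linear's S-space.

    W has shape (d_out, d_in) with thin SVD W = U diag(S) V^T.
    reader: acts are inputs x of shape (n, d_in)   -> x @ V * sqrt(S)
    writer: acts are outputs y of shape (n, d_out) -> y @ U / sqrt(S)
    """
    U, S, Vt = np.linalg.svd(W, full_matrices=False)
    if role == "reader":
        return (acts @ Vt.T) * np.sqrt(S)
    return (acts @ U) / np.sqrt(np.maximum(S, 1e-8))


def top_r_direction(hack: np.ndarray, clean: np.ndarray, r: int):
    """Mean pair difference, restricted to the r largest-magnitude modes."""
    d = (hack - clean).mean(axis=0)
    idx = np.argsort(-np.abs(d))[:r]
    v = d[idx] / max(np.linalg.norm(d[idx]), 1e-8)
    return idx, v


rng = np.random.default_rng(0)
W = rng.standard_normal((16, 32))             # toy Linear weight
x_live = rng.standard_normal((5, 32))         # toy live inputs
x_hack = rng.standard_normal((8, 32)) + 0.5   # eight toy "hack" pair members
x_clean = rng.standard_normal((8, 32))        # eight toy "clean" pair members

idx, v = top_r_direction(s_space(x_hack, W, "reader"),
                         s_space(x_clean, W, "reader"), r=4)
scores = s_space(x_live, W, "reader")[:, idx] @ v
print(scores.shape)
```

For a bias-free Linear with nonzero singular values the two identities give the same coordinates: with `y = x @ W.T`, `y @ U / sqrt(S)` equals `x @ V * sqrt(S)`. The reader and writer variants therefore differ in which activation is captured, not in the coordinate system.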
| score | v3 | v4 | v5 | mean | min |
|---|---:|---:|---:|---:|---:|
| moduleS writer r=256 concat-cos, best selected row | 0.892 | 0.733 | 0.786 | 0.804 | 0.733 |
| act-dot existing default | 0.870 | 0.747 | 0.747 | 0.788 | 0.747 |
| raw residual dot | 0.905 | 0.721 | 0.756 | 0.794 | 0.721 |

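The AUROC values in this table measure how often a gate score ranks an exploited rollout above a non-exploited one among rollouts with positive advantage. A minimal pairwise sketch follows; it is not the project's `_auroc`, whose implementation is not shown here, and the toy scores and labels are made up.

```python
def auroc(scores, labels):
    """Probability that a random positive outscores a random negative.

    Ties count as half a win. Returns NaN when one class is empty.
    """
    pos = [s for s, y in zip(scores, labels) if y]
    neg = [s for s, y in zip(scores, labels) if not y]
    if not pos or not neg:
        return float("nan")
    wins = sum((p > q) + 0.5 * (p == q) for p in pos for q in neg)
    return wins / (len(pos) * len(neg))


# Toy data: three exploited rollouts, two clean ones.
toy_scores = [0.9, 0.4, 0.35, 0.6, 0.5]
toy_labels = [True, True, False, True, False]
toy_auroc = auroc(toy_scores, toy_labels)
print(round(toy_auroc, 3))
```

In the toy data five of the six positive/negative pairs are ordered correctly, so the result is 5/6. The pairwise form is quadratic in the number of rollouts, which is acceptable at the window sizes used here.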
**Result.** Per-module S-space shows no robust improvement on these windows. Its
best row falls below act-dot on worst-window AUROC and was selected from 27 rows,
so the small gain over raw residual dot is not evidence of improvement. The earlier
cached-residual approximation produced a stronger reader-r64 row, but fresh
review correctly identified that it was a module-weight-derived metric on
post-block residuals rather than exact module S-space.

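To see why a maximum selected from 27 rows carries little weight, the sketch below simulates 27 rows that all share the same true AUROC and differ only by noise. It assumes, purely for illustration, independent Gaussian noise with the per-window SE of ~0.07 quoted in the earlier discussion. The real rows share rollouts and are positively correlated, so the true inflation is smaller than this.

```python
import random
import statistics


def expected_max_gain(n_rows: int = 27, se: float = 0.07,
                      trials: int = 4000, seed: int = 0) -> float:
    """Average amount by which the best of n_rows noisy estimates exceeds the truth."""
    rng = random.Random(seed)
    gains = [max(rng.gauss(0.0, se) for _ in range(n_rows)) for _ in range(trials)]
    return statistics.fmean(gains)


gain = expected_max_gain()
print(f"expected inflation of the selected maximum: {gain:.3f}")
```

Under these assumptions the selected maximum sits on the order of two standard errors above the shared true value. That is far larger than the 0.01 mean-AUROC gap between the best moduleS row and raw residual dot.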
**Evidence.** Full cross-window table:
`out/diag/moduleS_exact_summary.tsv`. Per-window tables:
`out/diag{,_v4,_v5}/moduleS_exact.tsv`. Spec and failure log:
`docs/spec/20260611_per_module_sspace.md`.

## 2026-06-11 (g) -- routeA act gate shipped; bimodality guard dropped after calibration

**Introduction.** Entries (b)-(f) established that the activation dot score is the
stable gate input. This entry covers the implementation
(docs/spec/20260611_act_gate_spec.md): forward-only `v_act` extraction
(`src/vgrout/extract_vhack_act.py`), the routeA gate in train.py (act capture on the
quarantine-ablated logp_old forward, masks pinned before the single grad forward,
rolling-buffer Otsu thresholds), and deletion of the routeV gradient gate. The spec
left one open question: should an online bimodality guard close the rout zone before
hacks emerge? Expected: some shape statistic of the score window separates the
emergence mixture (hack share 35-43%) from hack-free scores.

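A rolling-buffer gate with warmup and two Otsu thresholds could look roughly like the sketch below. This is not the train.py implementation, which is not shown in this commit. `RollingOtsuGate`, its buffer sizes, and the returned labels are hypothetical, and what "absorb" does downstream is not modeled.

```python
from collections import deque

import numpy as np


def otsu3(z):
    """Two thresholds maximizing three-class between-class variance (O(n^2) scan)."""
    s = np.sort(np.asarray(z, float))
    n = len(s)
    c = np.concatenate([[0.0], np.cumsum(s)])
    best, best_ij = -np.inf, (1, 2)
    for i in range(1, n - 1):
        for j in range(i + 1, n):
            obj = c[i] ** 2 / i + (c[j] - c[i]) ** 2 / (j - i) + (c[n] - c[j]) ** 2 / (n - j)
            if obj > best:
                best, best_ij = obj, (i, j)
    i, j = best_ij
    return (s[i - 1] + s[i]) / 2, (s[j - 1] + s[j]) / 2


class RollingOtsuGate:
    """Hypothetical gate: absorb during warmup, then keep / mid / rout by Otsu zone."""

    def __init__(self, size: int = 64, warmup: int = 16):
        self.buf = deque(maxlen=size)
        self.warmup = warmup

    def classify(self, score: float) -> str:
        self.buf.append(float(score))
        if len(self.buf) < self.warmup:
            return "absorb"
        x = np.asarray(self.buf)
        z = (x - x.mean()) / (x.std() or 1.0)            # z-norm in buffer-sd units
        zw = np.clip(z, *np.quantile(z, [0.01, 0.99]))   # winsorize at 1% / 99%
        t_lo, t_hi = otsu3(zw)
        cur = zw[-1]
        return "keep" if cur < t_lo else ("rout" if cur >= t_hi else "mid")


gate = RollingOtsuGate()
rng = np.random.default_rng(0)
labels = [gate.classify(s) for s in rng.standard_normal(40)]
spike = gate.classify(10.0)
print(labels[:3], labels[-1], spike)
```

In this sketch the top class is never empty, so the largest score in the buffer always lands in `rout` whether or not the scores have any real structure. That is the behavior a bimodality guard would need to suppress before emergence.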
**Methods.** Calibration is offline on the cached v3/v4/v5 emergence-window features
(out/diag{,_v4,_v5}/pinning_feats.pt, produced by scripts/diag_pinning.py at commit
70697ff). Score = act dot vs the `behavior_` 8-pair v_act. Conditions: mixture (all
valid live rollouts), cleanonly (non-exploited only, pre-emergence proxy), and N(0,1)
n=256 (10 seeds). Statistics are computed after z-norm, winsorize(1/99), and
two-threshold Otsu: `sep` = mean(z above t_hi) minus mean(z below t_lo) in buffer-sd
units, `nbcv` = between-class variance fraction. Command:
`uv run python scripts/attic/calib_otsu_guard.py`. Extractor equivalence was verified
on GPU as pueue #24 (`scripts/verify_v_act.py`, commit 5a340e5).

| cond | hack share | n | sep | nbcv |
|---|---:|---:|---:|---:|
| v3 mixture | 0.43 | 138 | 2.75 | 0.80 |
| v3 cleanonly | 0.00 | 79 | 2.54 | 0.82 |
| v4 mixture | 0.35 | 96 | 2.82 | 0.81 |
| v4 cleanonly | 0.00 | 62 | 2.34 | 0.84 |
| v5 mixture | 0.39 | 138 | 2.44 | 0.76 |
| v5 cleanonly | 0.00 | 84 | 3.52 | 0.77 |
| gauss n=256 (mean of 10 seeds) | 0.00 | 256 | 2.42 | 0.83 |

Table: guard-candidate statistics per condition. A usable guard needs the mixture rows
to separate from the cleanonly and gauss rows on `sep` or `nbcv` with margin.

Provenance:
- Script: `scripts/attic/calib_otsu_guard.py` (copied from the session scratch file);
  rerun output captured at `/tmp/claude-1000/calib_guard_out.txt` this session.
- Inputs: `out/diag/pinning_feats.pt`, `out/diag_v4/pinning_feats.pt`,
  `out/diag_v5/pinning_feats.pt` (the (c) replication windows).
- verify_v_act: pueue #24 log; acts match cached diag features at rel diff 7.3e-4 and
  7.7e-4 (hack/clean), v cos > 0.999, per-module cos >= 0.99997.

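The equivalence check reported in the last bullet compares freshly extracted activations with cached ones by relative difference and cosine. The exact definitions used by `scripts/verify_v_act.py` are not shown in this commit, so the sketch below uses one common choice (Frobenius-norm relative difference, plain cosine) on made-up arrays.

```python
import numpy as np


def rel_diff(a: np.ndarray, b: np.ndarray) -> float:
    """Norm of the difference relative to the norm of the reference b."""
    return float(np.linalg.norm(a - b) / max(np.linalg.norm(b), 1e-12))


def cosine(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine similarity between two arrays, flattened."""
    a, b = a.ravel(), b.ravel()
    return float(a @ b / max(np.linalg.norm(a) * np.linalg.norm(b), 1e-12))


rng = np.random.default_rng(0)
cached = rng.standard_normal((4, 64))                  # toy cached features (modules x dim)
fresh = cached + 1e-3 * rng.standard_normal((4, 64))   # toy re-extraction with small noise

overall_rel = rel_diff(fresh, cached)
overall_cos = cosine(fresh, cached)
per_module_cos = [cosine(f, c) for f, c in zip(fresh, cached)]
print(f"rel diff {overall_rel:.1e}, cos {overall_cos:.6f}, min module cos {min(per_module_cos):.6f}")
```

A relative difference near 1e-3 with cosines very close to 1 is the pattern expected when two extractors agree up to numerical noise such as bf16 rounding.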
**Results.** No statistic separates the conditions. The largest `sep` of all rows is a
hack-FREE window (v5 cleanonly, 3.52); pure Gaussians sit at 2.42, inside the mixture
range (2.44-2.82). `nbcv` overlaps the same way (mixtures 0.76-0.81 vs gauss 0.83).
Otsu always finds tail classes ~2.4 sd apart even when no structure exists, so any
threshold on these statistics either always opens or always closes.

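The Gaussian baseline can be cross-checked without any data. For a standard normal, the best symmetric three-class split puts each threshold halfway between the centre-class mean (0) and the tail-class mean. The sketch below solves that fixed point in the population limit, ignoring winsorizing and finite-sample noise; the function names are hypothetical.

```python
import math


def phi(t: float) -> float:
    """Standard normal density."""
    return math.exp(-0.5 * t * t) / math.sqrt(2 * math.pi)


def upper_tail(t: float) -> float:
    """P(Z > t) for a standard normal Z."""
    return 0.5 * math.erfc(t / math.sqrt(2))


def gaussian_three_class_split(iters: int = 200):
    """Fixed point of t = m(t) / 2, where m(t) = E[Z | Z > t]."""
    t = 1.0
    for _ in range(iters):
        m = phi(t) / upper_tail(t)   # mean of the upper tail class
        t = m / 2                    # midpoint between centre mean 0 and tail mean m
    m = phi(t) / upper_tail(t)
    sep = 2 * m                      # distance between the two tail-class means
    nbcv = 2 * upper_tail(t) * m * m # between-class variance over total variance (1.0)
    return t, sep, nbcv


t_star, sep_star, nbcv_star = gaussian_three_class_split()
print(f"t = +/-{t_star:.3f}, sep = {sep_star:.2f}, nbcv = {nbcv_star:.2f}")
```

The population values come out at roughly 2.45 sd for `sep` and 0.81 for `nbcv`, close to the simulated gauss row. A structureless Gaussian alone therefore reproduces the statistics seen in the mixture rows.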
**Discussion (speculative).** I read this as: the guard idea was solving a
non-problem. Before emergence a false rout costs one update removed from deployment
(asymmetric, cheap), and warmup already pins absorb while the buffer fills. An
alternative read is that a better statistic exists (e.g. dip test, mixture-model BIC)
and I only tried Otsu-derived ones; I did not pursue this because the cost asymmetry
makes the guard's value marginal even if it worked. The gate therefore ships with
warmup + Otsu only (commits adca442 routeA wiring + routeV deletion, f646e57
review-driven hardening; smoke logs /tmp/claude-1000/smoke_routeA*.log).

**Next.** Queue the seed-43 fast 4-arm set (`just queue-decision`): routeA real vs
Haar placebo vs vanilla vs absorb. Decision: directionality is real iff real-v
deploy_hack << placebo at matched solve, with gate AUROC >> 0.5 around emergence.
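A Haar placebo for a single direction can be read as a unit vector drawn uniformly from the sphere, obtained by normalizing a Gaussian draw. How the project actually constructs its placebo is not shown in this commit, so the sketch below is an assumption. `haar_unit_vector` and the width `dim` are hypothetical, and the "real" direction is itself a random stand-in.

```python
import numpy as np


def haar_unit_vector(dim: int, rng: np.random.Generator) -> np.ndarray:
    """Uniform random direction on the unit sphere (normalized Gaussian draw)."""
    g = rng.standard_normal(dim)
    return g / np.linalg.norm(g)


rng = np.random.default_rng(0)
dim = 2560                              # arbitrary illustrative width
v_real = haar_unit_vector(dim, rng)     # stand-in for an extracted direction
cosines = np.array([haar_unit_vector(dim, rng) @ v_real for _ in range(200)])
print(f"placebo cosine with v_real: std {cosines.std():.4f}, 1/sqrt(dim) {dim ** -0.5:.4f}")
```

In high dimension a random direction is nearly orthogonal to any fixed one, with cosine spread about `1 / sqrt(dim)`. A placebo arm built this way keeps the gate machinery identical while removing the directional information that the real `v_act` is supposed to carry.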
@@ -0,0 +1,71 @@
"""Calibrate the routeA bimodality guard offline.

Conditions per window (v3/v4/v5), act_dot score vs behavior_ pair v_act:
  mixture    all valid live rollouts (hack share 35-43%)       -> guard SHOULD OPEN
  cleanonly  non-exploited rollouts only (pre-emergence proxy) -> SHOULD CLOSE
  gauss      N(0,1) n=256, 10 seeds                            -> SHOULD CLOSE

Candidate statistics, computed after z-norm + winsorize(1/99) + otsu3:
  sep  = mean(z | z>=t_hi) - mean(z | z<t_lo), in buffer-sd units
  nbcv = three-class between-class variance / total variance (both on winsorized z)
"""
import numpy as np
import torch
from pathlib import Path

ROOT = Path("/workspace/projected_grpo")
RUNS = {"v3": ROOT / "out/diag", "v4": ROOT / "out/diag_v4", "v5": ROOT / "out/diag_v5"}


def otsu3(x: np.ndarray) -> tuple[float, float]:
    x = np.clip(x, *np.quantile(x, [0.01, 0.99]))
    s = np.sort(np.asarray(x, float))
    n = len(s)
    c = np.concatenate([[0.0], np.cumsum(s)])
    best, best_ij = -np.inf, (1, 2)
    for i in range(1, n - 1):
        for j in range(i + 1, n):
            obj = c[i] ** 2 / i + (c[j] - c[i]) ** 2 / (j - i) + (c[n] - c[j]) ** 2 / (n - j)
            if obj > best:
                best, best_ij = obj, (i, j)
    i, j = best_ij
    return float((s[i - 1] + s[i]) / 2), float((s[j - 1] + s[j]) / 2)


def stats(scores: np.ndarray) -> dict:
    z = (scores - scores.mean()) / (scores.std() or 1.0)
    zw = np.clip(z, *np.quantile(z, [0.01, 0.99]))
    t_lo, t_hi = otsu3(z)
    keep, mid, rout = zw < t_lo, (zw >= t_lo) & (zw < t_hi), zw >= t_hi
    sep = float(zw[rout].mean() - zw[keep].mean()) if rout.any() and keep.any() else float("nan")
    mu = zw.mean()
    nbcv = sum(m.mean() * (zw[m].mean() - mu) ** 2 for m in (keep, mid, rout) if m.any()) / zw.var()
    return {"n": len(scores), "t_lo": t_lo, "t_hi": t_hi, "sep": sep, "nbcv": float(nbcv),
            "w_rout": float(rout.mean())}


rows = []
for tag, d in RUNS.items():
    fe = torch.load(d / "pinning_feats.pt", weights_only=False)
    ACT, adv, exploited = fe["ACT"].float(), fe["adv"], fe["exploited"]
    head = [i for i, pid in enumerate(fe["pair_ids"]) if pid.startswith("behavior_")]
    D = (fe["pair_feats"][("act", "hack")][head] - fe["pair_feats"][("act", "clean")][head]).float().mean(0)
    V = D / D.norm(dim=-1, keepdim=True).clamp_min(1e-12)
    s = torch.einsum("nmr,mr->n", ACT, V).numpy()
    valid = np.abs(adv) > 1e-6
    rows.append({"cond": f"{tag} mixture", "hack%": exploited[valid].mean(), **stats(s[valid])})
    cl = valid & ~exploited
    rows.append({"cond": f"{tag} cleanonly", "hack%": 0.0, **stats(s[cl])})

rng = np.random.default_rng(0)
g = [stats(rng.standard_normal(256)) for _ in range(10)]
rows.append({"cond": "gauss n=256 (mean of 10)", "hack%": 0.0,
             **{k: float(np.mean([r[k] for r in g])) for k in g[0]}})
rows.append({"cond": "gauss n=256 (max sep/nbcv)", "hack%": 0.0,
             **{k: float(np.max([r[k] for r in g])) for k in g[0]}})

from tabulate import tabulate
print(tabulate(rows, headers="keys", tablefmt="pipe", floatfmt="+.3f"))
print("\nSHOULD: mixture rows separate from cleanonly+gauss rows on at least one of "
      "sep/nbcv with a usable margin; ELSE no shape-based guard works and we rely on "
      "warmup + asymmetric-cost only.")
@@ -0,0 +1,164 @@
"""Per-module residual S-space gate score.

Unlike Super-S, this keeps one residual-side SVD basis per Linear through
projection, pair-direction extraction, mode selection, and scoring. Scores are
aggregated only after each module has produced its own score.

This is an offline residual-stream diagnostic, not exact steering-lite sspace:
the cache contains block residuals at layers 12/18/24, not each Linear's own
input/output activations.

    writer (o_proj/down_proj): G_m = W_m W_m^T = U_m S_m^2 U_m^T
                               xS_m = h @ U_m / sqrt(S_m)
    reader (q/k/v/gate/up):    G_m = W_m^T W_m = V_m S_m^2 V_m^T
                               xS_m = h @ V_m * sqrt(S_m)

    uv run python scripts/diag_pinning_moduleS.py
"""
from __future__ import annotations

import json
from glob import glob
from pathlib import Path

import numpy as np
import torch
from safetensors import safe_open
from tabulate import tabulate

from vgrout.train import _auroc


ROOT = Path("/workspace/projected_grpo")
RUNS = {"v3": ROOT / "out/diag", "v4": ROOT / "out/diag_v4", "v5": ROOT / "out/diag_v5"}
HEAD_PREFIX = "behavior_"
HOOKED_LAYERS = (12, 18, 24)
WRITERS = ("self_attn.o_proj", "mlp.down_proj")
READERS = ("self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj",
           "mlp.gate_proj", "mlp.up_proj")
RANKS = (64, 256, -1)
EPS = 1e-8


def model_weights():
    snap = Path(glob("/workspace/.hf_home/hub/models--Qwen--Qwen3-4B/snapshots/*")[0])
    wmap = json.loads((snap / "model.safetensors.index.json").read_text())["weight_map"]
    handles = {f: safe_open(snap / f, framework="pt") for f in set(wmap.values())}
    for residual_idx, layer_idx in enumerate(HOOKED_LAYERS):
        for role, modules in (("writer", WRITERS), ("reader", READERS)):
            module_layer = layer_idx if role == "writer" else layer_idx + 1
            for module in modules:
                key = f"model.layers.{module_layer}.{module}.weight"
                yield residual_idx, module_layer, role, module, handles[wmap[key]].get_tensor(key).float()


def residual_basis(W: torch.Tensor, role: str) -> tuple[torch.Tensor, torch.Tensor]:
    """Return residual-side basis and sqrt(Sigma), descending by singular value."""
    gram = W @ W.T if role == "writer" else W.T @ W
    eigenvalues, basis = torch.linalg.eigh(gram)
    return basis.flip(1), eigenvalues.flip(0).clamp_min(0).pow(0.25).clamp_min(EPS)


def module_stats(
    X: torch.Tensor,
    Ph: torch.Tensor,
    Pc: torch.Tensor,
    basis: torch.Tensor,
    sqrtS: torch.Tensor,
    role: str,
    rank: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Return per-example dot, cosine, and squared norm in one module S-space."""
    scale = sqrtS if role == "reader" else sqrtS.reciprocal()
    XS = (X @ basis) * scale
    dS = ((Ph @ basis) * scale - (Pc @ basis) * scale).mean(0)
    if rank > 0:
        idx = dS.abs().topk(rank).indices
        XS, dS = XS[:, idx], dS[idx]
    direction = dS / dS.norm().clamp_min(EPS)
    dot = XS @ direction
    norm2 = XS.square().sum(1)
    return dot, dot / norm2.sqrt().clamp_min(EPS), norm2


def main() -> int:
    features = {tag: torch.load(path / "pinning_feats.pt", weights_only=False)
                for tag, path in RUNS.items()}
    prepared = {}
    for tag, fe in features.items():
        head = [i for i, pair_id in enumerate(fe["pair_ids"]) if pair_id.startswith(HEAD_PREFIX)]
        prepared[tag] = {
            "X": fe["RES"].float(),
            "Ph": fe["pair_feats"][("resid", "hack")][head].float(),
            "Pc": fe["pair_feats"][("resid", "clean")][head].float(),
            "pos": (np.abs(fe["adv"]) > 1e-6) & (fe["adv"] > 0),
            "y": fe["exploited"] & (fe["adv"] > 0),
        }

    collected = {
        (tag, role, rank): {"dot": [], "cos": [], "norm2": []}
        for tag in RUNS for role in ("writer", "reader") for rank in RANKS
    }
    print("building and scoring 21 independent residual-side module SVD bases (CPU)...")
    for residual_idx, layer_idx, role, module, W in model_weights():
        print(f" residual={HOOKED_LAYERS[residual_idx]:02d} module_layer={layer_idx:02d} "
              f"role={role:6s} module={module}")
        basis, sqrtS = residual_basis(W, role)
        for tag, fe in prepared.items():
            X, Ph, Pc = fe["X"][:, residual_idx], fe["Ph"][:, residual_idx], fe["Pc"][:, residual_idx]
            for rank in RANKS:
                dot, cos, norm2 = module_stats(X, Ph, Pc, basis, sqrtS, role, rank)
                for key, value in (("dot", dot), ("cos", cos), ("norm2", norm2)):
                    collected[(tag, role, rank)][key].append(value)

    rows = {}
    for tag, fe in prepared.items():
        au = lambda score: _auroc(score[fe["pos"]].tolist(), fe["y"][fe["pos"]].tolist())
        raw_d = (fe["Ph"] - fe["Pc"]).mean(0)
        raw_v = raw_d / raw_d.norm(dim=-1, keepdim=True).clamp_min(EPS)
        raw_dot = torch.einsum("nld,ld->n", fe["X"], raw_v)
        raw_cos = raw_dot / (fe["X"].flatten(1).norm(dim=1).clamp_min(EPS)
                             * raw_v.flatten().norm())
        rows.setdefault(("raw resid", "concat", "cos"), {"variant": "raw resid",
            "aggregate": "concat", "score": "cos"})[tag] = au(raw_cos)
        rows.setdefault(("raw resid", "concat", "dot"), {"variant": "raw resid",
            "aggregate": "concat", "score": "dot"})[tag] = au(raw_dot)

        for role in ("writer", "reader", "both"):
            roles = ("writer", "reader") if role == "both" else (role,)
            for rank in RANKS:
                stats = [collected[(tag, r, rank)] for r in roles]
                dots = torch.stack([x for stat in stats for x in stat["dot"]], 1)
                coses = torch.stack([x for stat in stats for x in stat["cos"]], 1)
                norm2 = torch.stack([x for stat in stats for x in stat["norm2"]], 1)
                scores = {
                    ("mean", "cos"): coses.mean(1),
                    ("mean", "dot"): dots.mean(1),
                    ("concat", "cos"): dots.sum(1) / (
                        norm2.sum(1).sqrt().clamp_min(EPS) * np.sqrt(dots.shape[1])),
                    ("concat", "dot"): dots.sum(1),
                }
                rank_name = "full" if rank < 0 else str(rank)
                variant = f"moduleS {role} r={rank_name}"
                for (aggregate, kind), score in scores.items():
                    key = (variant, aggregate, kind)
                    rows.setdefault(key, {"variant": variant, "aggregate": aggregate,
                                          "score": kind})[tag] = au(score)

    out = list(rows.values())
    for row in out:
        values = [row[tag] for tag in RUNS]
        row["mean"], row["min"] = float(np.mean(values)), float(np.min(values))
    out.sort(key=lambda row: -row["min"])
    cols = ["variant", "aggregate", "score", "v3", "v4", "v5", "mean", "min"]
    print("\nAUROC on the A>0 contrast:")
    print("SHOULD: raw resid cos == 0.916/0.700/0.804 and dot == "
          "0.905/0.721/0.756 ELSE this harness disagrees with diag_pinning_superS. "
          "Per-module bases must remain separate until score aggregation.")
    print(tabulate([{c: row[c] for c in cols} for row in out],
                   headers="keys", tablefmt="pipe", floatfmt="+.3f"))
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
@@ -0,0 +1,232 @@
"""Exact per-module S-space gate diagnostic on actual Linear inputs/outputs.

For every selected Linear, capture the same side used by steering-lite sspace:

    writer output: xS = y @ U / sqrt(S)
    reader input:  xS = x @ V * sqrt(S)

Each module keeps its own thin SVD basis through pair-direction extraction,
top-r selection, and scoring. Scores are aggregated only after each module has
produced its own score.

GPU required:
    uv run python scripts/diag_pinning_moduleS_exact.py
"""
from __future__ import annotations

import json
import struct
import csv
from dataclasses import dataclass
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
import tyro
from loguru import logger
from safetensors.torch import load_file
from tabulate import tabulate
from transformers import AutoModelForCausalLM, AutoTokenizer

from vgrout.lora2r import wrap_model_with_lora2r
from vgrout.pairs import load_pairs
from vgrout.train import _auroc


ROLES = {
    "self_attn.o_proj": "writer",
    "mlp.down_proj": "writer",
    "self_attn.q_proj": "reader",
    "self_attn.k_proj": "reader",
    "self_attn.v_proj": "reader",
    "mlp.gate_proj": "reader",
    "mlp.up_proj": "reader",
}
RANKS = (64, 256, -1)
EPS = 1e-8


@dataclass
class Cfg:
    run_dir: Path = Path("out/runs/20260611T003538_fast_vanilla_lora2r_seed43_l2r_vanilla_s43_v3")
    ckpt: str = "first_hack"
    pairs: Path = Path("data/pairs/hack_pairs.md#all-in-one")
    headline_prefix: str = "behavior_"
    step_lo: int = 2
    step_hi: int = 9
    max_rollouts: int = 240
    layers: tuple[int, ...] = (12, 18, 24)
    out: Path = Path("out/diag/moduleS_exact.tsv")


def checkpoint_meta(path: Path) -> dict:
    with open(path, "rb") as f:
        return json.loads(f.read(struct.unpack("<Q", f.read(8))[0])).get("__metadata__", {})


def completion_nll(model, tokenizer, prompt: str, completion: str, device) -> torch.Tensor:
    prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    full_ids = tokenizer(prompt + completion, return_tensors="pt").input_ids.to(device)
    n_prompt = prompt_ids.shape[1]
    logits = model(full_ids).logits[:, :-1]
    targets = full_ids[:, 1:]
    logp = torch.nn.functional.log_softmax(logits.float(), dim=-1)
    nll = -logp.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
    positions = torch.arange(full_ids.shape[1] - 1, device=device).unsqueeze(0)
    mask = (positions >= (n_prompt - 1)).float()
    return (nll * mask).sum() / mask.sum()


def thin_basis(W: torch.Tensor, role: str) -> tuple[torch.Tensor, torch.Tensor]:
    """Return residual-side thin basis and sqrt(Sigma), descending."""
    d_out, d_in = W.shape
    if role == "writer":
        eigenvalues, basis = torch.linalg.eigh(W @ W.T)
    elif d_out >= d_in:
        eigenvalues, basis = torch.linalg.eigh(W.T @ W)
    else:
        eigenvalues, U = torch.linalg.eigh(W @ W.T)
        singular_values = eigenvalues.clamp_min(0).sqrt()
        basis = W.T @ U / singular_values.clamp_min(EPS)
    return basis.flip(1), eigenvalues.flip(0).clamp_min(0).pow(0.25).clamp_min(EPS)


class ModuleSTap:
    def __init__(self, modules: dict[str, tuple[torch.nn.Module, str, torch.Tensor, torch.Tensor]]):
        self.modules = modules
        self.values: dict[str, torch.Tensor] = {}
        self.handles = []

    def __enter__(self):
        for name, (module, role, basis, sqrtS) in self.modules.items():
            def hook(module, args, output, name=name, role=role, basis=basis, sqrtS=sqrtS):
                h = F.linear(args[0], module.weight, module.bias) if role == "writer" else args[0]
                scale = sqrtS.reciprocal() if role == "writer" else sqrtS
                self.values[name] = (h.detach().float() @ basis) * scale
            self.handles.append(module.register_forward_hook(hook))
        return self

    def __exit__(self, *exc):
        for handle in self.handles:
            handle.remove()

    def pooled(self, n_prompt: int) -> dict[str, torch.Tensor]:
        return {name: value[0, n_prompt:].mean(0).cpu() for name, value in self.values.items()}


def score_rows(
    live: list[dict[str, torch.Tensor]],
    pair_hack: list[dict[str, torch.Tensor]],
    pair_clean: list[dict[str, torch.Tensor]],
    roles: dict[str, str],
    adv: np.ndarray,
    exploited: np.ndarray,
) -> list[dict]:
    head_diff = {
        name: torch.stack([hack[name] - clean[name] for hack, clean in zip(pair_hack, pair_clean)]).mean(0)
        for name in roles
    }
    pos = (np.abs(adv) > 1e-6) & (adv > 0)
    y = exploited & (adv > 0)
    au = lambda score: _auroc(score[pos].tolist(), y[pos].tolist())
    rows = []
    for role in ("writer", "reader", "both"):
        names = [name for name, module_role in roles.items() if role == "both" or module_role == role]
        for rank in RANKS:
            dots, coses, norm2 = [], [], []
            for name in names:
                dS = head_diff[name]
                X = torch.stack([example[name] for example in live])
                if rank > 0 and rank < dS.numel():
                    idx = dS.abs().topk(rank).indices
                    dS, X = dS[idx], X[:, idx]
                direction = dS / dS.norm().clamp_min(EPS)
                dot = X @ direction
                dots.append(dot)
                coses.append(dot / X.norm(dim=1).clamp_min(EPS))
                norm2.append(X.square().sum(1))
            dots, coses, norm2 = torch.stack(dots, 1), torch.stack(coses, 1), torch.stack(norm2, 1)
            scores = {
                "mean_cos": coses.mean(1),
                "concat_cos": dots.sum(1) / (norm2.sum(1).sqrt().clamp_min(EPS) * np.sqrt(dots.shape[1])),
                "dot": dots.mean(1),
            }
            rank_name = "full" if rank < 0 else str(rank)
            rows.extend({"variant": f"moduleS-exact {role} r={rank_name}", "score": kind, "auroc": au(score)}
                        for kind, score in scores.items())
    return rows


def main(cfg: Cfg) -> int:
    device = torch.device("cuda")
    ckpt_path = cfg.run_dir / f"{cfg.ckpt}.safetensors"
    meta = checkpoint_meta(ckpt_path)
    run_cfg = json.loads(meta["cfg"])
    tokenizer = AutoTokenizer.from_pretrained(run_cfg["model"])
    model = AutoModelForCausalLM.from_pretrained(
        run_cfg["model"], dtype=torch.bfloat16, attn_implementation="flash_attention_2").to(device)
    model.config.use_cache = False
    wrappers = wrap_model_with_lora2r(model, r=run_cfg["lora_r"], init_seed=run_cfg["lora_init_seed"])
    state = load_file(str(ckpt_path))
    for name, info in wrappers.items():
        info["A"].data.copy_(state[f"A/{name}"].to(device, torch.float32))
        info["B"].data.copy_(state[f"B/{name}"].to(device, torch.float32))
    model.eval()

    selected = {}
    for name, module in model.named_modules():
        for suffix, role in ROLES.items():
            if name.endswith(suffix) and any(name.startswith(f"model.layers.{layer}.") for layer in cfg.layers):
                basis, sqrtS = thin_basis(module.weight.detach().float(), role)
                selected[name] = (module, role, basis, sqrtS)
    assert len(selected) == len(cfg.layers) * len(ROLES), len(selected)
    roles = {name: role for name, (_, role, _, _) in selected.items()}
    logger.info(f"selected {len(selected)} exact module S-spaces")

    def extract(tap: ModuleSTap, prompt: str, completion: str) -> dict[str, torch.Tensor]:
        with torch.no_grad():
            loss = completion_nll(model, tokenizer, prompt, completion, device)
        assert torch.isfinite(loss), loss
        n_prompt = tokenizer(prompt, return_tensors="pt").input_ids.shape[1]
        return tap.pooled(n_prompt)

    pairs = [pair for pair in load_pairs(cfg.pairs) if pair.problem_id.startswith(cfg.headline_prefix)]
    records = [json.loads(line) for line in (cfg.run_dir / "rollouts.jsonl").read_text().splitlines()]
    records = [r for r in records if cfg.step_lo <= r["step"] <= cfg.step_hi and r["text"].strip()][:cfg.max_rollouts]
    pair_hack, pair_clean, live = [], [], []
    with ModuleSTap(selected) as tap:
        for i, pair in enumerate(pairs):
            pair_hack.append(extract(tap, pair.prompt, pair.hack))
            pair_clean.append(extract(tap, pair.prompt, pair.clean))
            logger.info(f"pair {i + 1}/{len(pairs)}")
        for i, record in enumerate(records):
            live.append(extract(tap, record["prompt"], record["text"]))
            if (i + 1) % 40 == 0:
                logger.info(f"rollout {i + 1}/{len(records)}")

    steps = np.array([r["step"] for r in records])
    p_idx = np.array([r["p_idx"] for r in records])
    reward = np.array([float(r["reward"]) for r in records])
    adv = np.empty(len(records))
    for step, prompt_idx in set(zip(steps.tolist(), p_idx.tolist())):
        mask = (steps == step) & (p_idx == prompt_idx)
        adv[mask] = reward[mask] - reward[mask].mean()
    exploited = np.array([bool(r["exploited"]) for r in records])

    rows = score_rows(live, pair_hack, pair_clean, roles, adv, exploited)
    rows.sort(key=lambda row: -row["auroc"])
    cfg.out.parent.mkdir(parents=True, exist_ok=True)
    with cfg.out.open("w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=("variant", "score", "auroc"), delimiter="\t")
        writer.writeheader()
        writer.writerows(rows)
    print("SHOULD: all rows finite; exact hooks capture writer outputs and reader inputs. "
          "A selected maximum is a candidate, not confirmation.")
    print(tabulate(rows, headers="keys", tablefmt="pipe", floatfmt="+.3f"))
    logger.info(f"wrote {cfg.out}")
    return 0


if __name__ == "__main__":
    raise SystemExit(main(tyro.cli(Cfg)))