diag: super-S-space gate score null; spec -> act_dot + winsorized-Otsu plan

superS (pooled writer/reader eigenbasis, whitened + top-r) tops out at
min-window AUROC 0.721 = raw resid dot; best unwhitened rotation+top-64
0.740 < act 0.747 (max of ~50-variant grid). act t-stat extraction also
null (0.719 vs 0.749 min). Spec updated: act_dot default, journal-(d)
evidence table, implementation plan for routeA.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
wassname
2026-06-11 11:42:44 +00:00
parent 1b51c62cdc
commit 1d4f33ffb6
2 changed files with 288 additions and 0 deletions
+153
View File
@@ -0,0 +1,153 @@
# Spec: routeA -- score activations, route gradients (2026-06-11)
New arm id `routeA` (per the rename-on-logic-change rule): the gate input changes from
gradients to activations, so routeV results stay comparable only to routeV.
## Evidence motivating the change
RESEARCH_JOURNAL.md 2026-06-11 (b) correction, (c), and (d) Table 1. On the A>0
contrast (reward-hack vs non-reward-hack among positively-reinforced rollouts, where
the advantage baseline is near chance), across three independent emergence windows:
| score | v3 | v4 | v5 | mean | min |
|---|---|---|---|---|---|
| act_cos | 0.869 | 0.749 | 0.752 | 0.790 | 0.749 |
| act_dot | 0.870 | 0.747 | 0.747 | 0.788 | 0.747 |
| resid_dot | 0.905 | 0.721 | 0.756 | 0.794 | 0.721 |
| resid_cos | 0.916 | 0.700 | 0.804 | 0.807 | 0.700 |
| grad_cos (current gate space) | 0.838 | 0.617 | 0.455 | 0.636 | 0.455 |
| grad_dot | 0.809 | 0.607 | 0.455 | 0.623 | 0.455 |
The gradient score does not replicate; the activation and residual scores do, and form
one statistical cluster (per-window SE ~ 0.07). The activation score needs no advantage
reconstruction, no gate backward pass, and no hooks outside the adapter modules, and it
has the best worst-window AUROC, so it is the default.
Variants tested and found null on these windows (all within ~0.03 of baseline min,
inside one SE; logs /tmp/claude-1000/superS_v1.log, act_dot_tstat.log, pinning_followup.log):
- t-stat extraction (v = mean(D)/se(D) per coordinate over the 8 behavior_ pairs,
clamped to |t| <= 3): act 0.867/0.756/0.719 vs mean-extraction 0.869/0.749/0.752.
With 8 pairs the per-coordinate std has ~25 percent relative error; revisit when the
authored-pair set grows. Implemented behind `vact_tstat` (default off).
- Super-S-space (project the residual stream onto the pooled eigenbasis of the residual
writers/readers, optionally whitened by the pooled spectrum, per
wassname/steering-lite super_sspace; scripts/diag_pinning_superS.py): whitened
variants top out at min-window 0.721 (= raw resid dot); the best unwhitened
rotation+top-64 variant (reader basis) reaches 0.740, above raw resid cos (0.700)
but below act (0.747-0.749), and it is the maximum of a ~50-variant grid, so
post-hoc selection inflation applies. No variant beats the act default.
- SNR module weighting / top-quartile module pruning for act: moves AUROC <= 0.02.
- Per-layer resid: L24 alone >= the 3-layer concat on every window (0.925/0.712/0.814).
## What changes
1. **Extraction** (`extract_vhack_act` or a mode of the existing extractor): forward each
authored pair side once, capture the deployed bottleneck activation `h = A[:r] @ x`
per wrapped module, mean over completion tokens. `v_act[name] = unitnorm(mean over
pairs of (act_hack - act_clean))`. No NLL, no backward, no c-probe. Reference
implementation: `scripts/diag_pinning.py` (`ActTap`, `_v_from`).
2. **Gate**: capture the same pooled acts during the existing no-grad `logpi_old` forward
(one hook; that forward already runs every step). Per-rollout score = concatenated-
module `dot(act, v_act)` with v unit-norm per module. routeV's pass-1
forward+backward is deleted; routeA costs approximately the vanilla arm.
3. **Score is dot.** Empirically act_dot == act_cos to ~0.002 on all three windows, so
the choice is structural, not empirical: dot is a plain linear readout (one einsum, no
per-rollout norm), and cos's per-rollout normalization makes the score sensitive to
the norm of the pooled activation (a length/style factor) in the denominator.
Boundedness no longer matters for threshold portability because thresholds are set on
online-standardized scores (point 4), not raw values. Log cos as a diagnostic.
4. **Labels/pinning**: label-free, from online statistics of the live scores only
(journal entry 2026-06-11 (d); reference implementation `_otsu3` in
scripts/diag_pinning.py). The mean+k*sd rule was wrong because the live scores are a
MIXTURE (hack share 35-43% in the emergence windows), so mean+2/3sd lands beyond
every distribution and routes nothing. The replacement, validated on all three
windows: (a) maintain a rolling buffer of recent per-rollout scores (e.g. last 256;
one 24-rollout batch is too few); (b) standardize by the buffer mean/std (display
and logging comparability; the partition itself is affine-invariant); (c) winsorize
at the 1/99 percentiles, then a two-threshold Otsu search (maximize three-class
between-class variance) gives (t_lo, t_hi) -> keep | absorb | rout. Winsorizing is
load-bearing: without it a single outlier buys a zone (v4 grad_dot keep zone was one
point; v5 act rout precision was 0.00 before, 0.50 after). Bimodality guard -- only
open the rout zone when the between-class variance gain over a single threshold is
material (else an all-clean early batch gets split and clean updates are routed);
exact criterion to be fixed at implementation. The residual gap to a perfect router
is the SCORE's AUROC, not the threshold: even the oracle single cut only reaches
P ~ 0.5-0.65 at useful recall on v4/v5, so pure high-precision routing needs a
better score or more authored pairs, not better pinning.
5. **Routing payload unchanged**: three-way masks (m, d) on block outputs, quarantine
ablation at deployment, refresh of v_act every `vhack_refresh_every` steps (now
forward-only, so cheap), placebo arm = Haar-random v_act with identical machinery.
6. **Pairs**: `#all-in-one/behavior_` (8 pairs), unchanged -- it was the best or tied
vector on every window.
## What this removes
- The gate's second forward+backward per step (routeV's main cost).
- The advantage-reconstruction dependence of the gate score.
- The c-probe (`grad_probe`) machinery from the training path (keep it available for
diagnostics; `scripts/diag_pinning.py` still uses it).
- G_hack / gradient-vector extraction from the training path.
## Open questions (do not block implementation)
- **Bottleneck vs residual stream**: RESOLVED ENOUGH (pueue #21-23, journal (d) Table
1). resid and act are within ~1 SE of each other on every window; the random r=32
bottleneck is NOT the limiter. Default the gate to the bottleneck act (act_dot): it
has the best worst-window AUROC (0.747 vs resid_cos 0.700), lives entirely in gate
space (no hooks on decoder layers), and needs no layer-depth choice. resid (L24
alone, not the concat) stays the diagnostic alternative; it is adapter-independent
and closest to a plain steering-vector probe, one hook-point change away.
- **Texture-shortcut risk**: the act score may detect exploit-token surface text. The
held-out hack-mode test (train v_act on pairs covering some modes, measure routing of
an absent mode) distinguishes generalization from memorized texture.
- Threshold rule v2 (pre-emergence frozen stats vs per-batch online).
## Implementation plan
Ordered; each step is one commit with its verify gate. Not started until the user
approves the plan.
1. **Extraction** (`src/vgrout/extract_vhack_act.py`): `extract_v_act(model, wrappers,
names, pairs, tok, device, tstat=False) -> dict[name, Tensor[r]]`. For each pair
side one no-grad forward of prompt+completion; capture the deployed bottleneck per
wrapped module with the ActTap hook pattern from scripts/diag_pinning.py
(`F.linear(x, A[:r])` in a forward hook, [B, L, r], mean over completion tokens).
`v[name] = unitnorm(mean over pairs of (h_hack - h_clean))`; `tstat=True` divides
the mean by the standard error over pairs and clamps |t| <= 3 before the unitnorm.
No NLL, no backward, no c-probe. Verify gate: `scripts/verify_v_act.py` checks the
extractor reproduces the cached pair features in out/diag/pinning_feats.pt
(`pair_feats[("act", side)]`) on the v3 checkpoint to fp tolerance.
2. **Gate in train.py**: routeA pins masks BEFORE the grad-carrying forward, like
none/absorb, so the routeV second backward disappears and routeA costs ~ the
vanilla arm. During the existing no-grad logpi_old forward (train.py ~line 877),
the same ActTap hooks capture per-rollout pooled acts (pad-mask mean over
completion tokens, batched). Score per rollout = sum over modules of
`<h_m, v_m>` (act_dot; one einsum on the [B, M, r] stack). Log cos alongside.
3. **Pinning**: rolling deque of the last 256 live scores. Per step: append the
batch, standardize by buffer mean/std (logged z), winsorize at 1/99 percentiles,
`_otsu3` (lift from scripts/diag_pinning.py into the train module or a shared
util) -> (t_lo, t_hi) -> keep (0,0) | absorb (1,0) | rout (1,1). Warmup: gate
pinned absorb until the buffer holds >= 128 scores. Bimodality guard (proposal,
to confirm at implementation): open the rout zone only when the high-class
separation (mean of rout class minus mean of keep class) exceeds 1 buffer sd;
otherwise collapse rout into absorb for that step.
4. **Arm wiring**: `intervention="routeA"` (rename-on-logic-change; routeV results
stay comparable only to routeV). routeV is REMOVED from train.py in the same
commit (the c-probe/grad-gate machinery stays in scripts/ for diagnostics);
`grad_probe=True` is then never set in training. Placebo flag
`routeA_random_v_seed` = Haar-random unit v_act per module, identical machinery.
Refresh: reuse `vhack_refresh_every` (forward-only now, so cheap).
5. **Logging** (per step): gate AUROC on the A>0 contrast vs hack labels (diagnostic
only; labels never feed routing), zone shares keep/absorb/rout, buffer mean/sd,
(t_lo, t_hi) in z units, qmass. SHOULD lines per token-efficient-logging.
6. **Runs**: `just smoke` with routeA on the tiny model, then the fast preset pair:
routeA real + routeA placebo + vanilla, seed 43, with the journal-(d) windows as
the comparison reference.
## UAT
1. `just smoke` passes with the routeA pathway (act capture + gate labels on tiny model).
2. A routeA fast run logs per-step gate AUROC (A>0 contrast, students+cached teachers)
>= 0.7 around first_hack, and deploy hack rate < vanilla at matched solve rate.
3. The Haar-placebo routeA arm does NOT match real v_act (directionality, not shrinkage).
+135
View File
@@ -0,0 +1,135 @@
"""Super-S-space gate score: project the residual stream onto the pooled SVD basis
of the residual writers/readers before extracting v and scoring.
Source idea: wassname/steering-lite variants/super_sspace.py. Pool the residual-side
singular structure of the block Linears into a Gram matrix and eigendecompose:
writer (d_out = d_model, o_proj/down_proj): G += W W^T (= U S^2 U^T)
reader (d_in = d_model, q/k/v/gate/up): G += W^T W (= V S^2 V^T)
eigh(G) -> U_star [d, d], Sigma_star = sqrt(lambda)
Whitened coordinates xS = h @ U_star / sqrt(Sigma_star). At full rank WITHOUT the
whitening this is a pure rotation (cos/dot unchanged), so the testable content is
(a) the 1/sqrt(Sigma) reweighting -- a Mahalanobis metric in the writer/reader
spectrum -- and (b) top-r mode selection by |dS| from the authored-pair diff
(pair-only info, oracle-free).
Offline from cached pinning_feats.pt (RES [N, L, 2560]) + safetensors weights, no
forward pass, no GPU. Evaluation = the A>0 contrast AUROC (exploited & adv>0 vs
rest among adv>0), same as diag_pinning.py.
uv run python scripts/diag_pinning_superS.py
"""
from __future__ import annotations
import json
from glob import glob
from pathlib import Path
import numpy as np
import torch
from safetensors import safe_open
from tabulate import tabulate
from vgrout.train import _auroc
ROOT = Path("/workspace/projected_grpo")
RUNS = {"v3": ROOT / "out/diag", "v4": ROOT / "out/diag_v4", "v5": ROOT / "out/diag_v5"}
HEAD_PREFIX = "behavior_"
HOOKED_LAYERS = (12, 18, 24)
WRITERS = ("self_attn.o_proj", "mlp.down_proj")
READERS = ("self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj",
"mlp.gate_proj", "mlp.up_proj")
EPS = 1e-8
def load_grams(n_layers: int = 36, d: int = 2560) -> dict[tuple[str, str], torch.Tensor]:
"""{(role, blocks): G [d, d]} for role in {writer, reader}, blocks in {hooked, all}."""
snap = Path(glob("/workspace/.hf_home/hub/models--Qwen--Qwen3-4B/snapshots/*")[0])
wmap = json.loads((snap / "model.safetensors.index.json").read_text())["weight_map"]
G = {(role, blk): torch.zeros(d, d) for role in ("writer", "reader") for blk in ("hooked", "all")}
handles = {f: safe_open(snap / f, framework="pt") for f in set(wmap.values())}
for li in range(n_layers):
for role, mods in (("writer", WRITERS), ("reader", READERS)):
for m in mods:
key = f"model.layers.{li}.{m}.weight"
W = handles[wmap[key]].get_tensor(key).float()
gram = W @ W.T if role == "writer" else W.T @ W
G[(role, "all")] += gram
if li in HOOKED_LAYERS:
G[(role, "hooked")] += gram
return G
def basis(G: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""U_star [d, d] descending, sqrt(Sigma_star) = lambda**0.25 clamped."""
lam, U = torch.linalg.eigh(G.double())
lam, U = lam.flip(0).clamp_min(0), U.flip(1)
return U.float(), lam.float().pow(0.25).clamp_min(EPS) # sqrt(Sigma) = lam^(1/4)
def score(XS: torch.Tensor, dS: torch.Tensor, kind: str) -> np.ndarray:
"""XS [N, L, r], dS [L, r] -> concat score over layers."""
v = dS / dS.norm(dim=-1, keepdim=True).clamp_min(EPS)
d = torch.einsum("nlr,lr->n", XS, v)
if kind == "dot":
return d.numpy()
return (d / (XS.flatten(1).norm(dim=1).clamp_min(EPS) * v.flatten().norm())).numpy()
def main() -> int:
print("building writer/reader Gram bases from Qwen3-4B safetensors (CPU)...")
grams = load_grams()
grams[("both", "hooked")] = grams[("writer", "hooked")] + grams[("reader", "hooked")]
grams[("both", "all")] = grams[("writer", "all")] + grams[("reader", "all")]
bases = {k: basis(g) for k, g in grams.items()}
rows = {}
for tag, diag_dir in RUNS.items():
fe = torch.load(diag_dir / "pinning_feats.pt", weights_only=False)
RES, adv, exploited = fe["RES"].float(), fe["adv"], fe["exploited"]
head = [i for i, pid in enumerate(fe["pair_ids"]) if pid.startswith(HEAD_PREFIX)]
Ph = fe["pair_feats"][("resid", "hack")][head].float()
Pc = fe["pair_feats"][("resid", "clean")][head].float()
pos = (np.abs(adv) > 1e-6) & (adv > 0)
y = exploited & (adv > 0)
au = lambda s: _auroc(s[pos].tolist(), y[pos].tolist())
def add(name: str, XS, dS):
for kind in ("cos", "dot"):
rows.setdefault((name, kind), {"variant": name, "score": kind})[tag] = \
au(score(XS, dS, kind))
rows[(name, kind)][f"{tag} L24"] = au(score(XS[:, 2:], dS[2:], kind))
add("raw resid (baseline)", RES, (Ph - Pc).mean(0))
for (role, blk), (U, sqrtS) in sorted(bases.items()):
# whiten=False at r=full is a pure rotation (== baseline, skipped); with
# top-r it still tests sparsification in the pooled eigenbasis.
for whiten, lab in ((True, "superS"), (False, "superS-rot")):
tx = lambda X: torch.einsum("nld,dk->nlk", X, U) / (sqrtS if whiten else 1.0)
XS, dS = tx(RES), (tx(Ph) - tx(Pc)).mean(0)
if whiten:
add(f"{lab} {role}/{blk} r=full", XS, dS)
for r in (256, 64):
idx = dS.abs().topk(r, dim=-1).indices # [L, r] per-layer top modes
XSr = torch.gather(XS, 2, idx.unsqueeze(0).expand(XS.shape[0], -1, -1))
add(f"{lab} {role}/{blk} r={r}", XSr, torch.gather(dS, 1, idx))
out = list(rows.values())
for r in out:
vals = [r[t] for t in RUNS]
r["mean"], r["min"] = float(np.mean(vals)), float(np.min(vals))
out.sort(key=lambda r: -r["min"])
cols = ["variant", "score", "v3", "v4", "v5", "mean", "min",
"v3 L24", "v4 L24", "v5 L24"]
print("\nAUROC on the A>0 contrast, concat over layers 12/18/24 (and L24 alone):")
print("SHOULD: 'raw resid (baseline)' cos == 0.916/0.700/0.804 (matches journal "
"Table 1) ELSE the harness disagrees with diag_pinning. superS rows differ "
"from baseline only via whitening + top-r; if no superS row beats baseline "
"min by >0.03 the pooled-spectrum metric adds nothing here.")
print(tabulate([{c: r[c] for c in cols} for r in out],
headers="keys", tablefmt="pipe", floatfmt="+.3f"))
return 0
if __name__ == "__main__":
raise SystemExit(main())