feat: top-k routing subspace for routeV (--v-grad-k, gate=max_i cos)

k=1 (default) stays the mean-mass mean-diff axis -- headline unchanged. k>1
builds the top-k oriented SVD dirs of the paired diff and the gate scores
max_i cos(g, v_i) (alignment to ANY known hack sub-mode), catching multi-modal
hack signal one mean washes out. Shared _build_v_grad at init + refresh; band
edges and the live gate both max over k. Sims use einsum + jaxtyping dims.
Smoke: just smoke-topk green (top-3 subspace, band width +0.087, 12/14 modules).

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
wassname
2026-06-10 11:44:45 +00:00
parent 9fd2b6b89b
commit 4120d75ea4
4 changed files with 64 additions and 165 deletions
+8
View File
@@ -53,6 +53,14 @@ smoke-unhackable *ARGS:
--teacher-pool-dir=out/pools/teacher_pool --mix-ratio=0.5 \ --teacher-pool-dir=out/pools/teacher_pool --mix-ratio=0.5 \
--eval-n-prompts=2 {{ ARGS }} --eval-n-prompts=2 {{ ARGS }}
# routeV with a top-k routing subspace (max_i cos(g,v_i) over k SVD dirs) instead of
# the single mean-mass axis. UAT: log shows "top-3 SVD subspace, gate=max_i cos" and the
# band/gate still route (rout>0). k=1 (default) is the mean-diff headline.
smoke-topk *ARGS:
BEARTYPE=1 {{ TRAIN }} smoke --intervention=routeV --v-grad-k=3 \
--teacher-pool-dir=out/pools/teacher_pool --mix-ratio=0.5 \
--eval-ablate-every=10 --eval-n-prompts=2 {{ ARGS }}
# All three arms back to back (the full-coverage gate). # All three arms back to back (the full-coverage gate).
smoke-all: smoke-all:
just smoke-vanilla just smoke-vanilla
+50 -19
View File
@@ -44,6 +44,7 @@ os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import tyro import tyro
from jaxtyping import Float
from loguru import logger from loguru import logger
from safetensors.torch import save_file from safetensors.torch import save_file
from tabulate import tabulate from tabulate import tabulate
@@ -64,12 +65,43 @@ RUNS_DIR = OUT_DIR / "runs"
def _haar_unit_dirs(v_grad: dict, seed: int, device) -> dict: def _haar_unit_dirs(v_grad: dict, seed: int, device) -> dict:
"""Build the reproducible out-of-subspace directionality control (placebo) for routeV.""" """Build the reproducible out-of-subspace directionality control (placebo) for routeV.
Matches v_grad's [k, r] shape and unit-normalizes per row, so a top-k routing run
gets k random dirs (the gate's max-cosine still sees the same shape) for a fair placebo."""
g = torch.Generator().manual_seed(seed) g = torch.Generator().manual_seed(seed)
out = {} out = {}
for name in sorted(v_grad): for name in sorted(v_grad):
d = torch.randn(v_grad[name].shape, generator=g) d = torch.randn(v_grad[name].shape, generator=g) # [k, r]
out[name] = (d / d.norm().clamp_min(1e-12)).to(device) out[name] = (d / d.norm(dim=-1, keepdim=True).clamp_min(1e-12)).to(device)
return out
def _build_v_grad(raw_grads: dict, names, k: int, device) -> dict[str, Float[torch.Tensor, "k r"]]:
"""Per-module routing directions from authored-pair gradients -> dict[name -> [k, r]].
k=1 (headline): normalized mean(hack-clean) -- one mean-mass axis. k>1: the top-k
oriented right singular vectors of the paired-diff matrix D=[g_hack-g_clean] (SVD +
per-pair majority orient, mirroring extract_vhack_grad.extract_v_hack), a rank-k hack
subspace. The live gate scores max_i cos(g, v_i). Rows are unit-norm. k=1 is NOT
SVD-top-1 (they differ off-isotropic): keeping mean-diff makes 'mean-mass vs top-k'
a clean A/B, not a confound."""
out = {}
for name in names:
D: Float[torch.Tensor, "n_pairs r"] = (
raw_grads[f"hack/{name}"] - raw_grads[f"clean/{name}"]).float()
if k == 1:
d = D.mean(0)
V = (d / d.norm().clamp_min(1e-12)).unsqueeze(0) # [1, r]
else:
_, _, Vh = torch.linalg.svd(D, full_matrices=False)
kk = min(k, Vh.shape[0])
V = Vh[:kk] # [kk, r] orthonormal
proj = torch.einsum("n r, k r -> n k", D, V) # per-pair projection
flip = torch.where((proj > 0).float().sum(0) < D.shape[0] / 2,
-torch.ones(kk), torch.ones(kk)) # orient hack-ward
V = V * flip.unsqueeze(1)
out[name] = V.to(device)
return out return out
@@ -93,11 +125,12 @@ def route_band_edges(raw_grads: dict, v_grad: dict, device) -> dict[str, tuple[f
""" """
band = {} band = {}
for name in v_grad: for name in v_grad:
v = v_grad[name].detach().cpu().float() v: Float[torch.Tensor, "k r"] = v_grad[name].detach().cpu().float()
gh = raw_grads[f"hack/{name}"].float() # [n_pairs, r] gh = raw_grads[f"hack/{name}"].float() # [n_pairs, r]
gc = raw_grads[f"clean/{name}"].float() gc = raw_grads[f"clean/{name}"].float()
ch = (gh @ v) / gh.norm(dim=1).clamp_min(1e-12) # [n_pairs] hack-pair cosines # max_i cos(g, v_i): same scoring the live gate uses, so band edges are apples-to-apples.
cc = (gc @ v) / gc.norm(dim=1).clamp_min(1e-12) # [n_pairs] clean-pair cosines ch = torch.einsum("n r, k r -> n k", gh, v).max(dim=1).values / gh.norm(dim=1).clamp_min(1e-12)
cc = torch.einsum("n r, k r -> n k", gc, v).max(dim=1).values / gc.norm(dim=1).clamp_min(1e-12)
band[name] = (cc.quantile(0.75).item(), ch.quantile(0.75).item()) # (lower=p75 clean, upper=p75 hack) band[name] = (cc.quantile(0.75).item(), ch.quantile(0.75).item()) # (lower=p75 clean, upper=p75 hack)
return band return band
@@ -206,20 +239,18 @@ def main(cfg: Config) -> int:
route_band = None route_band = None
if is_routeV: if is_routeV:
# Authored pairs are the only routing-label source; live oracle labels never enter training. # Authored pairs are the only routing-label source; live oracle labels never enter training.
from .pairs_from_pool import load_pairs_json from .pairs import load_pairs
from .extract_vhack_grad import extract_v_hack from .extract_vhack_grad import extract_v_hack
MASK_PAIRS = load_pairs_json(cfg.vhack_pairs_path) MASK_PAIRS = load_pairs(cfg.vhack_pairs_path)
logger.info(f"routeV pairs: {cfg.vhack_pairs_path} -> {len(MASK_PAIRS)} pairs") logger.info(f"routeV pairs: {cfg.vhack_pairs_path} -> {len(MASK_PAIRS)} pairs")
model.eval() # match standalone extract: deterministic backward, no dropout model.eval() # match standalone extract: deterministic backward, no dropout
_, _, raw_grads, _ = extract_v_hack( _, _, raw_grads, _ = extract_v_hack(
model, tok, wrappers, MASK_PAIRS, model, tok, wrappers, MASK_PAIRS,
top_k=1, tau_axis=0.0, n_heldout=2, device=device, top_k=1, tau_axis=0.0, n_heldout=2, device=device,
) )
v_grad = {} v_grad = _build_v_grad(raw_grads, wrappers, cfg.v_grad_k, device)
for name in wrappers: _vk = "mean-diff (mean-mass)" if cfg.v_grad_k == 1 else f"top-{cfg.v_grad_k} SVD subspace, gate=max_i cos"
d = (raw_grads[f"hack/{name}"] - raw_grads[f"clean/{name}"]).mean(0) logger.info(f"routeV grad: built v_grad ({_vk}) for {len(v_grad)} modules")
v_grad[name] = (d / d.norm().clamp_min(1e-12)).to(device)
logger.info(f"routeV grad: built v_grad (gradient mean-diff) for {len(v_grad)} modules")
if cfg.routeV_random_v_seed is not None: if cfg.routeV_random_v_seed is not None:
v_grad = _haar_unit_dirs(v_grad, cfg.routeV_random_v_seed, device) v_grad = _haar_unit_dirs(v_grad, cfg.routeV_random_v_seed, device)
logger.info(f"routeV grad: OVERRODE v_grad with Haar-random dirs " logger.info(f"routeV grad: OVERRODE v_grad with Haar-random dirs "
@@ -478,9 +509,10 @@ def main(cfg: Config) -> int:
if upper - lower <= 0: # noisy module: pairs don't separate -> excluded if upper - lower <= 0: # noisy module: pairs don't separate -> excluded
continue continue
r_blk = info["r"] r_blk = info["r"]
g_b = cg.reshape(n_rollouts, -1, 2 * r_blk).sum(1)[:, :r_blk].float() # [G, r] deployed block g_b: Float[torch.Tensor, "G r"] = cg.reshape(n_rollouts, -1, 2 * r_blk).sum(1)[:, :r_blk].float()
nrm = g_b.norm(dim=1) nrm = g_b.norm(dim=1)
cos_b = (g_b @ v_grad[name]) / nrm.clamp_min(1e-12) # [G] # cos to each of the k routing dirs, then max: aligned with ANY known hack sub-mode.
cos_b = torch.einsum("g r, k r -> g k", g_b, v_grad[name]).max(dim=1).values / nrm.clamp_min(1e-12)
num += cos_b - lower; den += upper - lower num += cos_b - lower; den += upper - lower
w += nrm; n_inc += 1 w += nrm; n_inc += 1
if n_inc == 0: if n_inc == 0:
@@ -852,9 +884,8 @@ def main(cfg: Config) -> int:
model, tok, wrappers, MASK_PAIRS, model, tok, wrappers, MASK_PAIRS,
top_k=1, tau_axis=0.0, n_heldout=2, device=device, top_k=1, tau_axis=0.0, n_heldout=2, device=device,
) )
for name in wrappers: # update in place so the gate closure sees it # update in place so the gate closure sees the fresh dirs (same k as init).
d = (raw_grads[f"hack/{name}"] - raw_grads[f"clean/{name}"]).mean(0) v_grad.update(_build_v_grad(raw_grads, wrappers, cfg.v_grad_k, device))
v_grad[name] = (d / d.norm().clamp_min(1e-12)).to(device)
route_band = route_band_edges(raw_grads, v_grad, device) # rebuild band on fresh v_grad route_band = route_band_edges(raw_grads, v_grad, device) # rebuild band on fresh v_grad
finally: finally:
logger.enable("vgrout.extract_vhack_grad") logger.enable("vgrout.extract_vhack_grad")
@@ -1119,7 +1150,7 @@ def main(cfg: Config) -> int:
"run_dir": run_dir.name, "arm": cfg.arm, "intervention": cfg.intervention, "run_dir": run_dir.name, "arm": cfg.arm, "intervention": cfg.intervention,
"adapter": "lora2r", "adapter": "lora2r",
"seed": cfg.seed, "steps": n_steps, "model": model_name, "out_tag": cfg.out_tag, "seed": cfg.seed, "steps": n_steps, "model": model_name, "out_tag": cfg.out_tag,
"unhackable_frac": cfg.unhackable_frac, "pairs": str(cfg.vhack_pairs_path.name), "unhackable_frac": cfg.unhackable_frac, "pairs": str(cfg.vhack_pairs_path),
"eval_set": "test", "eval_modes": eval_modes, "n": ev["n"], "eval_set": "test", "eval_modes": eval_modes, "n": ev["n"],
"hack_deployed": ev["hack"], "vhack_deployed": ev["vhack"], "solve_deployed": ev["solve"], "hack_deployed": ev["hack"], "vhack_deployed": ev["vhack"], "solve_deployed": ev["solve"],
"hack_as_trained": ev_on["hack"], "vhack_as_trained": ev_on["vhack"], "hack_as_trained": ev_on["hack"], "vhack_as_trained": ev_on["vhack"],
+6 -1
View File
@@ -43,7 +43,12 @@ class Config:
unbiased: bool = True unbiased: bool = True
vhack_refresh_every: int = 5 vhack_refresh_every: int = 5
vhack_pairs_path: Path = Path("out/pairsets/pairs_authored.json") vhack_pairs_path: Path = Path("docs/personas/hack_pairs.md#mechanism-authored")
# Routing directions per module. k=1 (headline): the mean(hack-clean) "mean-mass"
# axis. k>1: top-k oriented SVD dirs of the paired diff; gate scores max_i cos(g,v_i)
# (alignment to ANY known hack sub-mode) -- catches multi-modal hack signal one mean
# washes out. k=1 stays mean-diff (not SVD-top-1) so "mean-mass vs top-k" is a clean A/B.
v_grad_k: int = 1
# Haar-random direction control (placebo): same routing machinery, no pair signal. # Haar-random direction control (placebo): same routing machinery, no pair signal.
routeV_random_v_seed: int | None = None routeV_random_v_seed: int | None = None
rollout_ablate_frac: float = 0.0 rollout_ablate_frac: float = 0.0
-145
View File
@@ -1,145 +0,0 @@
"""Loading and post-processing the extracted hack-direction basis (v_hack).
v_hack is a per-module set of top-k right singular vectors of the labeled-pair
GRPO gradient, saved by extract_vhack_grad. Here we load it for a wrapped model
(checking the model/dtype/rank all match) and apply the top-k slice plus the
global noise-floor filter. The same post-processing serves both the init-time
load and the in-loop refresh.
"""
from __future__ import annotations
import hashlib
from pathlib import Path
import torch
from jaxtyping import Float
from loguru import logger
from safetensors import safe_open
def pairset_sha256(path: Path) -> str:
return hashlib.sha256(path.read_bytes()).hexdigest()
def load_v_hack(
path: Path, model_name: str, wrappers: dict, pairs_path: Path,
k_use: int | None = None, drop_bottom_frac: float = 0.0,
) -> dict[str, Float[torch.Tensor, "k r"]]:
"""Load v_hack (top-k directions) for this wrapped model.
File schema (v2): bare `{name}` keys hold V[k_max, r]; `_sv/{name}` keys hold
S[k_max]. v_hack is model-specific because module names and per-module SVD
ranks depend on the exact checkpoint; a smoke (Qwen3.5-0.8B) v_hack must
not be reused for a full (Qwen3-4B) run.
If `k_use` is given, slices V (and S) to top-k_use rows. Errors if
k_use > k_max saved (re-extract with a higher top_k).
If `drop_bottom_frac > 0`, collects every S_i across every module and drops
the bottom-fraction by global quantile. Modules whose every axis is below
the global threshold get filtered out of the returned dict (projection on
those modules becomes a no-op — they didn't carry hack signal anywhere).
"""
with safe_open(str(path), framework="pt", device="cpu") as f:
meta = f.metadata() or {}
saved_model = meta.get("model")
saved_dtype = meta.get("dtype")
saved_pairs_sha256 = meta.get("pairs_sha256")
if saved_model is None or saved_dtype is None or saved_pairs_sha256 is None:
raise ValueError(
f"{path} has no model/dtype/pairs_sha256 metadata. "
f"Re-extract with `uv run python -m vgrout.extract_vhack_grad "
f"--model={model_name} --dtype=bf16 --pairs-from-pool={pairs_path} --out-path={path}`."
)
if saved_model != model_name:
raise ValueError(f"v_hack model mismatch: {path} has {saved_model}, run uses {model_name}")
expected_pairs_sha256 = pairset_sha256(pairs_path)
if saved_pairs_sha256 != expected_pairs_sha256:
raise ValueError(
f"v_hack pairset mismatch: {path} has sha256={saved_pairs_sha256}, "
f"{pairs_path} has sha256={expected_pairs_sha256}. Re-extract the direction."
)
# dtype mismatch: cross-dtype SVD bases can diverge silently, so error
# unless the saved dtype matches what train.py uses on this device.
# CPU runs in fp32, CUDA runs in bf16 (see model-load site above).
expected_dtype = "fp32" if torch.cuda.is_available() is False else "bf16"
if saved_dtype != expected_dtype:
raise ValueError(
f"v_hack dtype/SVD-basis mismatch: {path} was extracted with dtype={saved_dtype}; "
f"this run loads models in {expected_dtype}. Re-extract with `--dtype={expected_dtype}`."
)
v_hack = {k: f.get_tensor(k) for k in f.keys() if not k.startswith("_sv/")}
v_sv = {k[len("_sv/"):]: f.get_tensor(k) for k in f.keys() if k.startswith("_sv/")}
wrapper_keys = set(wrappers)
vhack_keys = set(v_hack)
missing = sorted(wrapper_keys - vhack_keys)
extra = sorted(vhack_keys - wrapper_keys)
# v_hack[name] is [k_max, r]; delta_S is [r]. Check last-dim match (rank r).
rank_bad = [
(name, tuple(v_hack[name].shape), tuple(wrappers[name]["delta_S"].shape))
for name in sorted(wrapper_keys & vhack_keys)
if v_hack[name].ndim != 2 or v_hack[name].shape[-1] != wrappers[name]["delta_S"].shape[0]
]
if missing or extra or rank_bad:
raise ValueError(
"v_hack incompatible with wrapped model: "
f"missing={len(missing)} examples={missing[:5]} "
f"extra={len(extra)} examples={extra[:5]} "
f"rank_bad={len(rank_bad)} examples={rank_bad[:5]}. "
"Extract a fresh v_hack with `uv run python -m vgrout.extract_vhack_grad "
f"--model={model_name} --out-path={path}`."
)
v_hack = postprocess_v_hack(
v_hack, v_sv, k_use=k_use, drop_bottom_frac=drop_bottom_frac, source=str(path),
)
return v_hack
def postprocess_v_hack(
v_hack: dict[str, Float[torch.Tensor, "k r"]],
v_sv: dict[str, Float[torch.Tensor, "k"]],
k_use: int | None,
drop_bottom_frac: float,
source: str = "<refresh>",
) -> dict[str, Float[torch.Tensor, "k r"]]:
"""Apply k_use slice + global noise-floor filter.
Shared between `load_v_hack` (init-time, reading from safetensors) and the
in-loop refresh hook (where we hand in fresh `extract_v_hack` outputs).
Mutates neither input dict; returns a fresh filtered dict.
Global noise floor: collect every S_i across every module, drop the bottom
`drop_bottom_frac` by quantile. A module whose every axis falls below the
global threshold is removed entirely — projection iterates v_hack so it
becomes a no-op for that module. Threshold recomputes per call (tracks
current S distribution).
"""
k_max = next(iter(v_hack.values())).shape[0]
if k_use is not None:
if k_use > k_max:
raise ValueError(f"requested k_use={k_use} exceeds k_max={k_max} (source={source})")
v_hack = {n: v[:k_use].contiguous() for n, v in v_hack.items()}
v_sv = {n: s[:k_use].contiguous() for n, s in v_sv.items()}
n_dropped_modules = 0
n_axes_before = sum(v.shape[0] for v in v_hack.values())
threshold = None
if drop_bottom_frac > 0 and v_sv:
all_S = torch.cat([v_sv[n].float() for n in v_hack])
threshold = torch.quantile(all_S, drop_bottom_frac).item()
filtered: dict[str, torch.Tensor] = {}
for name, V in v_hack.items():
keep = v_sv[name].float() >= threshold
if keep.any():
filtered[name] = V[keep].contiguous()
else:
n_dropped_modules += 1
v_hack = filtered
n_axes_after = sum(v.shape[0] for v in v_hack.values())
logger.info(
f"postprocess_v_hack({source}): modules={len(v_hack)} (dropped {n_dropped_modules}); "
f"k_use={k_use or k_max}/k_max={k_max}; axes={n_axes_after}/{n_axes_before} kept "
f"(drop_bottom_frac={drop_bottom_frac}, threshold={threshold})"
)
return v_hack