mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 15:00:20 +08:00
feat: top-k routing subspace for routeV (--v-grad-k, gate=max_i cos)
k=1 (default) stays the mean-mass mean-diff axis -- headline unchanged. k>1 builds the top-k oriented SVD dirs of the paired diff and the gate scores max_i cos(g, v_i) (alignment to ANY known hack sub-mode), catching multi-modal hack signal one mean washes out. Shared _build_v_grad at init + refresh; band edges and the live gate both max over k. Sims use einsum + jaxtyping dims. Smoke: just smoke-topk green (top-3 subspace, band width +0.087, 12/14 modules). Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
@@ -53,6 +53,14 @@ smoke-unhackable *ARGS:
|
||||
--teacher-pool-dir=out/pools/teacher_pool --mix-ratio=0.5 \
|
||||
--eval-n-prompts=2 {{ ARGS }}
|
||||
|
||||
# routeV with a top-k routing subspace (max_i cos(g,v_i) over k SVD dirs) instead of
|
||||
# the single mean-mass axis. UAT: log shows "top-3 SVD subspace, gate=max_i cos" and the
|
||||
# band/gate still route (rout>0). k=1 (default) is the mean-diff headline.
|
||||
smoke-topk *ARGS:
|
||||
BEARTYPE=1 {{ TRAIN }} smoke --intervention=routeV --v-grad-k=3 \
|
||||
--teacher-pool-dir=out/pools/teacher_pool --mix-ratio=0.5 \
|
||||
--eval-ablate-every=10 --eval-n-prompts=2 {{ ARGS }}
|
||||
|
||||
# All three arms back to back (the full-coverage gate).
|
||||
smoke-all:
|
||||
just smoke-vanilla
|
||||
|
||||
+50
-19
@@ -44,6 +44,7 @@ os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import tyro
|
||||
from jaxtyping import Float
|
||||
from loguru import logger
|
||||
from safetensors.torch import save_file
|
||||
from tabulate import tabulate
|
||||
@@ -64,12 +65,43 @@ RUNS_DIR = OUT_DIR / "runs"
|
||||
|
||||
|
||||
def _haar_unit_dirs(v_grad: dict, seed: int, device) -> dict:
|
||||
"""Build the reproducible out-of-subspace directionality control (placebo) for routeV."""
|
||||
"""Build the reproducible out-of-subspace directionality control (placebo) for routeV.
|
||||
|
||||
Matches v_grad's [k, r] shape and unit-normalizes per row, so a top-k routing run
|
||||
gets k random dirs (the gate's max-cosine still sees the same shape) for a fair placebo."""
|
||||
g = torch.Generator().manual_seed(seed)
|
||||
out = {}
|
||||
for name in sorted(v_grad):
|
||||
d = torch.randn(v_grad[name].shape, generator=g)
|
||||
out[name] = (d / d.norm().clamp_min(1e-12)).to(device)
|
||||
d = torch.randn(v_grad[name].shape, generator=g) # [k, r]
|
||||
out[name] = (d / d.norm(dim=-1, keepdim=True).clamp_min(1e-12)).to(device)
|
||||
return out
|
||||
|
||||
|
||||
def _build_v_grad(raw_grads: dict, names, k: int, device) -> dict[str, Float[torch.Tensor, "k r"]]:
|
||||
"""Per-module routing directions from authored-pair gradients -> dict[name -> [k, r]].
|
||||
|
||||
k=1 (headline): normalized mean(hack-clean) -- one mean-mass axis. k>1: the top-k
|
||||
oriented right singular vectors of the paired-diff matrix D=[g_hack-g_clean] (SVD +
|
||||
per-pair majority orient, mirroring extract_vhack_grad.extract_v_hack), a rank-k hack
|
||||
subspace. The live gate scores max_i cos(g, v_i). Rows are unit-norm. k=1 is NOT
|
||||
SVD-top-1 (they differ off-isotropic): keeping mean-diff makes 'mean-mass vs top-k'
|
||||
a clean A/B, not a confound."""
|
||||
out = {}
|
||||
for name in names:
|
||||
D: Float[torch.Tensor, "n_pairs r"] = (
|
||||
raw_grads[f"hack/{name}"] - raw_grads[f"clean/{name}"]).float()
|
||||
if k == 1:
|
||||
d = D.mean(0)
|
||||
V = (d / d.norm().clamp_min(1e-12)).unsqueeze(0) # [1, r]
|
||||
else:
|
||||
_, _, Vh = torch.linalg.svd(D, full_matrices=False)
|
||||
kk = min(k, Vh.shape[0])
|
||||
V = Vh[:kk] # [kk, r] orthonormal
|
||||
proj = torch.einsum("n r, k r -> n k", D, V) # per-pair projection
|
||||
flip = torch.where((proj > 0).float().sum(0) < D.shape[0] / 2,
|
||||
-torch.ones(kk), torch.ones(kk)) # orient hack-ward
|
||||
V = V * flip.unsqueeze(1)
|
||||
out[name] = V.to(device)
|
||||
return out
|
||||
|
||||
|
||||
@@ -93,11 +125,12 @@ def route_band_edges(raw_grads: dict, v_grad: dict, device) -> dict[str, tuple[f
|
||||
"""
|
||||
band = {}
|
||||
for name in v_grad:
|
||||
v = v_grad[name].detach().cpu().float()
|
||||
v: Float[torch.Tensor, "k r"] = v_grad[name].detach().cpu().float()
|
||||
gh = raw_grads[f"hack/{name}"].float() # [n_pairs, r]
|
||||
gc = raw_grads[f"clean/{name}"].float()
|
||||
ch = (gh @ v) / gh.norm(dim=1).clamp_min(1e-12) # [n_pairs] hack-pair cosines
|
||||
cc = (gc @ v) / gc.norm(dim=1).clamp_min(1e-12) # [n_pairs] clean-pair cosines
|
||||
# max_i cos(g, v_i): same scoring the live gate uses, so band edges are apples-to-apples.
|
||||
ch = torch.einsum("n r, k r -> n k", gh, v).max(dim=1).values / gh.norm(dim=1).clamp_min(1e-12)
|
||||
cc = torch.einsum("n r, k r -> n k", gc, v).max(dim=1).values / gc.norm(dim=1).clamp_min(1e-12)
|
||||
band[name] = (cc.quantile(0.75).item(), ch.quantile(0.75).item()) # (lower=p75 clean, upper=p75 hack)
|
||||
return band
|
||||
|
||||
@@ -206,20 +239,18 @@ def main(cfg: Config) -> int:
|
||||
route_band = None
|
||||
if is_routeV:
|
||||
# Authored pairs are the only routing-label source; live oracle labels never enter training.
|
||||
from .pairs_from_pool import load_pairs_json
|
||||
from .pairs import load_pairs
|
||||
from .extract_vhack_grad import extract_v_hack
|
||||
MASK_PAIRS = load_pairs_json(cfg.vhack_pairs_path)
|
||||
MASK_PAIRS = load_pairs(cfg.vhack_pairs_path)
|
||||
logger.info(f"routeV pairs: {cfg.vhack_pairs_path} -> {len(MASK_PAIRS)} pairs")
|
||||
model.eval() # match standalone extract: deterministic backward, no dropout
|
||||
_, _, raw_grads, _ = extract_v_hack(
|
||||
model, tok, wrappers, MASK_PAIRS,
|
||||
top_k=1, tau_axis=0.0, n_heldout=2, device=device,
|
||||
)
|
||||
v_grad = {}
|
||||
for name in wrappers:
|
||||
d = (raw_grads[f"hack/{name}"] - raw_grads[f"clean/{name}"]).mean(0)
|
||||
v_grad[name] = (d / d.norm().clamp_min(1e-12)).to(device)
|
||||
logger.info(f"routeV grad: built v_grad (gradient mean-diff) for {len(v_grad)} modules")
|
||||
v_grad = _build_v_grad(raw_grads, wrappers, cfg.v_grad_k, device)
|
||||
_vk = "mean-diff (mean-mass)" if cfg.v_grad_k == 1 else f"top-{cfg.v_grad_k} SVD subspace, gate=max_i cos"
|
||||
logger.info(f"routeV grad: built v_grad ({_vk}) for {len(v_grad)} modules")
|
||||
if cfg.routeV_random_v_seed is not None:
|
||||
v_grad = _haar_unit_dirs(v_grad, cfg.routeV_random_v_seed, device)
|
||||
logger.info(f"routeV grad: OVERRODE v_grad with Haar-random dirs "
|
||||
@@ -478,9 +509,10 @@ def main(cfg: Config) -> int:
|
||||
if upper - lower <= 0: # noisy module: pairs don't separate -> excluded
|
||||
continue
|
||||
r_blk = info["r"]
|
||||
g_b = cg.reshape(n_rollouts, -1, 2 * r_blk).sum(1)[:, :r_blk].float() # [G, r] deployed block
|
||||
g_b: Float[torch.Tensor, "G r"] = cg.reshape(n_rollouts, -1, 2 * r_blk).sum(1)[:, :r_blk].float()
|
||||
nrm = g_b.norm(dim=1)
|
||||
cos_b = (g_b @ v_grad[name]) / nrm.clamp_min(1e-12) # [G]
|
||||
# cos to each of the k routing dirs, then max: aligned with ANY known hack sub-mode.
|
||||
cos_b = torch.einsum("g r, k r -> g k", g_b, v_grad[name]).max(dim=1).values / nrm.clamp_min(1e-12)
|
||||
num += cos_b - lower; den += upper - lower
|
||||
w += nrm; n_inc += 1
|
||||
if n_inc == 0:
|
||||
@@ -852,9 +884,8 @@ def main(cfg: Config) -> int:
|
||||
model, tok, wrappers, MASK_PAIRS,
|
||||
top_k=1, tau_axis=0.0, n_heldout=2, device=device,
|
||||
)
|
||||
for name in wrappers: # update in place so the gate closure sees it
|
||||
d = (raw_grads[f"hack/{name}"] - raw_grads[f"clean/{name}"]).mean(0)
|
||||
v_grad[name] = (d / d.norm().clamp_min(1e-12)).to(device)
|
||||
# update in place so the gate closure sees the fresh dirs (same k as init).
|
||||
v_grad.update(_build_v_grad(raw_grads, wrappers, cfg.v_grad_k, device))
|
||||
route_band = route_band_edges(raw_grads, v_grad, device) # rebuild band on fresh v_grad
|
||||
finally:
|
||||
logger.enable("vgrout.extract_vhack_grad")
|
||||
@@ -1119,7 +1150,7 @@ def main(cfg: Config) -> int:
|
||||
"run_dir": run_dir.name, "arm": cfg.arm, "intervention": cfg.intervention,
|
||||
"adapter": "lora2r",
|
||||
"seed": cfg.seed, "steps": n_steps, "model": model_name, "out_tag": cfg.out_tag,
|
||||
"unhackable_frac": cfg.unhackable_frac, "pairs": str(cfg.vhack_pairs_path.name),
|
||||
"unhackable_frac": cfg.unhackable_frac, "pairs": str(cfg.vhack_pairs_path),
|
||||
"eval_set": "test", "eval_modes": eval_modes, "n": ev["n"],
|
||||
"hack_deployed": ev["hack"], "vhack_deployed": ev["vhack"], "solve_deployed": ev["solve"],
|
||||
"hack_as_trained": ev_on["hack"], "vhack_as_trained": ev_on["vhack"],
|
||||
|
||||
@@ -43,7 +43,12 @@ class Config:
|
||||
unbiased: bool = True
|
||||
|
||||
vhack_refresh_every: int = 5
|
||||
vhack_pairs_path: Path = Path("out/pairsets/pairs_authored.json")
|
||||
vhack_pairs_path: Path = Path("docs/personas/hack_pairs.md#mechanism-authored")
|
||||
# Routing directions per module. k=1 (headline): the mean(hack-clean) "mean-mass"
|
||||
# axis. k>1: top-k oriented SVD dirs of the paired diff; gate scores max_i cos(g,v_i)
|
||||
# (alignment to ANY known hack sub-mode) -- catches multi-modal hack signal one mean
|
||||
# washes out. k=1 stays mean-diff (not SVD-top-1) so "mean-mass vs top-k" is a clean A/B.
|
||||
v_grad_k: int = 1
|
||||
# Haar-random direction control (placebo): same routing machinery, no pair signal.
|
||||
routeV_random_v_seed: int | None = None
|
||||
rollout_ablate_frac: float = 0.0
|
||||
|
||||
@@ -1,145 +0,0 @@
|
||||
"""Loading and post-processing the extracted hack-direction basis (v_hack).
|
||||
|
||||
v_hack is a per-module set of top-k right singular vectors of the labeled-pair
|
||||
GRPO gradient, saved by extract_vhack_grad. Here we load it for a wrapped model
|
||||
(checking the model/dtype/rank all match) and apply the top-k slice plus the
|
||||
global noise-floor filter. The same post-processing serves both the init-time
|
||||
load and the in-loop refresh.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from jaxtyping import Float
|
||||
from loguru import logger
|
||||
from safetensors import safe_open
|
||||
|
||||
|
||||
def pairset_sha256(path: Path) -> str:
|
||||
return hashlib.sha256(path.read_bytes()).hexdigest()
|
||||
|
||||
|
||||
def load_v_hack(
|
||||
path: Path, model_name: str, wrappers: dict, pairs_path: Path,
|
||||
k_use: int | None = None, drop_bottom_frac: float = 0.0,
|
||||
) -> dict[str, Float[torch.Tensor, "k r"]]:
|
||||
"""Load v_hack (top-k directions) for this wrapped model.
|
||||
|
||||
File schema (v2): bare `{name}` keys hold V[k_max, r]; `_sv/{name}` keys hold
|
||||
S[k_max]. v_hack is model-specific because module names and per-module SVD
|
||||
ranks depend on the exact checkpoint; a smoke (Qwen3.5-0.8B) v_hack must
|
||||
not be reused for a full (Qwen3-4B) run.
|
||||
|
||||
If `k_use` is given, slices V (and S) to top-k_use rows. Errors if
|
||||
k_use > k_max saved (re-extract with a higher top_k).
|
||||
|
||||
If `drop_bottom_frac > 0`, collects every S_i across every module and drops
|
||||
the bottom-fraction by global quantile. Modules whose every axis is below
|
||||
the global threshold get filtered out of the returned dict (projection on
|
||||
those modules becomes a no-op — they didn't carry hack signal anywhere).
|
||||
"""
|
||||
with safe_open(str(path), framework="pt", device="cpu") as f:
|
||||
meta = f.metadata() or {}
|
||||
saved_model = meta.get("model")
|
||||
saved_dtype = meta.get("dtype")
|
||||
saved_pairs_sha256 = meta.get("pairs_sha256")
|
||||
if saved_model is None or saved_dtype is None or saved_pairs_sha256 is None:
|
||||
raise ValueError(
|
||||
f"{path} has no model/dtype/pairs_sha256 metadata. "
|
||||
f"Re-extract with `uv run python -m vgrout.extract_vhack_grad "
|
||||
f"--model={model_name} --dtype=bf16 --pairs-from-pool={pairs_path} --out-path={path}`."
|
||||
)
|
||||
if saved_model != model_name:
|
||||
raise ValueError(f"v_hack model mismatch: {path} has {saved_model}, run uses {model_name}")
|
||||
expected_pairs_sha256 = pairset_sha256(pairs_path)
|
||||
if saved_pairs_sha256 != expected_pairs_sha256:
|
||||
raise ValueError(
|
||||
f"v_hack pairset mismatch: {path} has sha256={saved_pairs_sha256}, "
|
||||
f"{pairs_path} has sha256={expected_pairs_sha256}. Re-extract the direction."
|
||||
)
|
||||
# dtype mismatch: cross-dtype SVD bases can diverge silently, so error
|
||||
# unless the saved dtype matches what train.py uses on this device.
|
||||
# CPU runs in fp32, CUDA runs in bf16 (see model-load site above).
|
||||
expected_dtype = "fp32" if torch.cuda.is_available() is False else "bf16"
|
||||
if saved_dtype != expected_dtype:
|
||||
raise ValueError(
|
||||
f"v_hack dtype/SVD-basis mismatch: {path} was extracted with dtype={saved_dtype}; "
|
||||
f"this run loads models in {expected_dtype}. Re-extract with `--dtype={expected_dtype}`."
|
||||
)
|
||||
v_hack = {k: f.get_tensor(k) for k in f.keys() if not k.startswith("_sv/")}
|
||||
v_sv = {k[len("_sv/"):]: f.get_tensor(k) for k in f.keys() if k.startswith("_sv/")}
|
||||
|
||||
wrapper_keys = set(wrappers)
|
||||
vhack_keys = set(v_hack)
|
||||
missing = sorted(wrapper_keys - vhack_keys)
|
||||
extra = sorted(vhack_keys - wrapper_keys)
|
||||
# v_hack[name] is [k_max, r]; delta_S is [r]. Check last-dim match (rank r).
|
||||
rank_bad = [
|
||||
(name, tuple(v_hack[name].shape), tuple(wrappers[name]["delta_S"].shape))
|
||||
for name in sorted(wrapper_keys & vhack_keys)
|
||||
if v_hack[name].ndim != 2 or v_hack[name].shape[-1] != wrappers[name]["delta_S"].shape[0]
|
||||
]
|
||||
if missing or extra or rank_bad:
|
||||
raise ValueError(
|
||||
"v_hack incompatible with wrapped model: "
|
||||
f"missing={len(missing)} examples={missing[:5]} "
|
||||
f"extra={len(extra)} examples={extra[:5]} "
|
||||
f"rank_bad={len(rank_bad)} examples={rank_bad[:5]}. "
|
||||
"Extract a fresh v_hack with `uv run python -m vgrout.extract_vhack_grad "
|
||||
f"--model={model_name} --out-path={path}`."
|
||||
)
|
||||
|
||||
v_hack = postprocess_v_hack(
|
||||
v_hack, v_sv, k_use=k_use, drop_bottom_frac=drop_bottom_frac, source=str(path),
|
||||
)
|
||||
return v_hack
|
||||
|
||||
|
||||
def postprocess_v_hack(
|
||||
v_hack: dict[str, Float[torch.Tensor, "k r"]],
|
||||
v_sv: dict[str, Float[torch.Tensor, "k"]],
|
||||
k_use: int | None,
|
||||
drop_bottom_frac: float,
|
||||
source: str = "<refresh>",
|
||||
) -> dict[str, Float[torch.Tensor, "k r"]]:
|
||||
"""Apply k_use slice + global noise-floor filter.
|
||||
|
||||
Shared between `load_v_hack` (init-time, reading from safetensors) and the
|
||||
in-loop refresh hook (where we hand in fresh `extract_v_hack` outputs).
|
||||
Mutates neither input dict; returns a fresh filtered dict.
|
||||
|
||||
Global noise floor: collect every S_i across every module, drop the bottom
|
||||
`drop_bottom_frac` by quantile. A module whose every axis falls below the
|
||||
global threshold is removed entirely — projection iterates v_hack so it
|
||||
becomes a no-op for that module. Threshold recomputes per call (tracks
|
||||
current S distribution).
|
||||
"""
|
||||
k_max = next(iter(v_hack.values())).shape[0]
|
||||
if k_use is not None:
|
||||
if k_use > k_max:
|
||||
raise ValueError(f"requested k_use={k_use} exceeds k_max={k_max} (source={source})")
|
||||
v_hack = {n: v[:k_use].contiguous() for n, v in v_hack.items()}
|
||||
v_sv = {n: s[:k_use].contiguous() for n, s in v_sv.items()}
|
||||
n_dropped_modules = 0
|
||||
n_axes_before = sum(v.shape[0] for v in v_hack.values())
|
||||
threshold = None
|
||||
if drop_bottom_frac > 0 and v_sv:
|
||||
all_S = torch.cat([v_sv[n].float() for n in v_hack])
|
||||
threshold = torch.quantile(all_S, drop_bottom_frac).item()
|
||||
filtered: dict[str, torch.Tensor] = {}
|
||||
for name, V in v_hack.items():
|
||||
keep = v_sv[name].float() >= threshold
|
||||
if keep.any():
|
||||
filtered[name] = V[keep].contiguous()
|
||||
else:
|
||||
n_dropped_modules += 1
|
||||
v_hack = filtered
|
||||
n_axes_after = sum(v.shape[0] for v in v_hack.values())
|
||||
logger.info(
|
||||
f"postprocess_v_hack({source}): modules={len(v_hack)} (dropped {n_dropped_modules}); "
|
||||
f"k_use={k_use or k_max}/k_max={k_max}; axes={n_axes_after}/{n_axes_before} kept "
|
||||
f"(drop_bottom_frac={drop_bottom_frac}, threshold={threshold})"
|
||||
)
|
||||
return v_hack
|
||||
Reference in New Issue
Block a user