diff --git a/justfile b/justfile index 27ff7a4..c505e13 100644 --- a/justfile +++ b/justfile @@ -53,6 +53,14 @@ smoke-unhackable *ARGS: --teacher-pool-dir=out/pools/teacher_pool --mix-ratio=0.5 \ --eval-n-prompts=2 {{ ARGS }} +# routeV with a top-k routing subspace (max_i cos(g,v_i) over k SVD dirs) instead of +# the single mean-mass axis. UAT: log shows "top-3 SVD subspace, gate=max_i cos" and the +# band/gate still route (rout>0). k=1 (default) is the mean-diff headline. +smoke-topk *ARGS: + BEARTYPE=1 {{ TRAIN }} smoke --intervention=routeV --v-grad-k=3 \ + --teacher-pool-dir=out/pools/teacher_pool --mix-ratio=0.5 \ + --eval-ablate-every=10 --eval-n-prompts=2 {{ ARGS }} + # All three arms back to back (the full-coverage gate). smoke-all: just smoke-vanilla diff --git a/src/vgrout/train.py b/src/vgrout/train.py index a3456a8..4aae56c 100644 --- a/src/vgrout/train.py +++ b/src/vgrout/train.py @@ -44,6 +44,7 @@ os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True") import torch import torch.nn.functional as F import tyro +from jaxtyping import Float from loguru import logger from safetensors.torch import save_file from tabulate import tabulate @@ -64,12 +65,43 @@ RUNS_DIR = OUT_DIR / "runs" def _haar_unit_dirs(v_grad: dict, seed: int, device) -> dict: - """Build the reproducible out-of-subspace directionality control (placebo) for routeV.""" + """Build the reproducible out-of-subspace directionality control (placebo) for routeV. + + Matches v_grad's [k, r] shape and unit-normalizes per row, so a top-k routing run + gets k random dirs (the gate's max-cosine still sees the same shape) for a fair placebo.""" g = torch.Generator().manual_seed(seed) out = {} for name in sorted(v_grad): - d = torch.randn(v_grad[name].shape, generator=g) - out[name] = (d / d.norm().clamp_min(1e-12)).to(device) + d = torch.randn(v_grad[name].shape, generator=g) # [k, r] + out[name] = (d / d.norm(dim=-1, keepdim=True).clamp_min(1e-12)).to(device) + return out + + +def _build_v_grad(raw_grads: dict, names, k: int, device) -> dict[str, Float[torch.Tensor, "k r"]]: + """Per-module routing directions from authored-pair gradients -> dict[name -> [k, r]]. + + k=1 (headline): normalized mean(hack-clean) -- one mean-mass axis. k>1: the top-k + oriented right singular vectors of the paired-diff matrix D=[g_hack-g_clean] (SVD + + per-pair majority orient, mirroring extract_vhack_grad.extract_v_hack), a rank-k hack + subspace. The live gate scores max_i cos(g, v_i). Rows are unit-norm. k=1 is NOT + SVD-top-1 (they differ off-isotropic): keeping mean-diff makes 'mean-mass vs top-k' + a clean A/B, not a confound.""" + out = {} + for name in names: + D: Float[torch.Tensor, "n_pairs r"] = ( + raw_grads[f"hack/{name}"] - raw_grads[f"clean/{name}"]).float() + if k == 1: + d = D.mean(0) + V = (d / d.norm().clamp_min(1e-12)).unsqueeze(0) # [1, r] + else: + _, _, Vh = torch.linalg.svd(D, full_matrices=False) + kk = min(k, Vh.shape[0]) + V = Vh[:kk] # [kk, r] orthonormal + proj = torch.einsum("n r, k r -> n k", D, V) # per-pair projection + flip = torch.where((proj > 0).float().sum(0) < D.shape[0] / 2, + -torch.ones(kk), torch.ones(kk)) # orient hack-ward + V = V * flip.unsqueeze(1) + out[name] = V.to(device) return out @@ -93,11 +125,12 @@ def route_band_edges(raw_grads: dict, v_grad: dict, device) -> dict[str, tuple[f """ band = {} for name in v_grad: - v = v_grad[name].detach().cpu().float() + v: Float[torch.Tensor, "k r"] = v_grad[name].detach().cpu().float() gh = raw_grads[f"hack/{name}"].float() # [n_pairs, r] gc = raw_grads[f"clean/{name}"].float() - ch = (gh @ v) / gh.norm(dim=1).clamp_min(1e-12) # [n_pairs] hack-pair cosines - cc = (gc @ v) / gc.norm(dim=1).clamp_min(1e-12) # [n_pairs] clean-pair cosines + # max_i cos(g, v_i): same scoring the live gate uses, so band edges are apples-to-apples. + ch = torch.einsum("n r, k r -> n k", gh, v).max(dim=1).values / gh.norm(dim=1).clamp_min(1e-12) + cc = torch.einsum("n r, k r -> n k", gc, v).max(dim=1).values / gc.norm(dim=1).clamp_min(1e-12) band[name] = (cc.quantile(0.75).item(), ch.quantile(0.75).item()) # (lower=p75 clean, upper=p75 hack) return band @@ -206,20 +239,18 @@ def main(cfg: Config) -> int: route_band = None if is_routeV: # Authored pairs are the only routing-label source; live oracle labels never enter training. - from .pairs_from_pool import load_pairs_json + from .pairs import load_pairs from .extract_vhack_grad import extract_v_hack - MASK_PAIRS = load_pairs_json(cfg.vhack_pairs_path) + MASK_PAIRS = load_pairs(cfg.vhack_pairs_path) logger.info(f"routeV pairs: {cfg.vhack_pairs_path} -> {len(MASK_PAIRS)} pairs") model.eval() # match standalone extract: deterministic backward, no dropout _, _, raw_grads, _ = extract_v_hack( model, tok, wrappers, MASK_PAIRS, top_k=1, tau_axis=0.0, n_heldout=2, device=device, ) - v_grad = {} - for name in wrappers: - d = (raw_grads[f"hack/{name}"] - raw_grads[f"clean/{name}"]).mean(0) - v_grad[name] = (d / d.norm().clamp_min(1e-12)).to(device) - logger.info(f"routeV grad: built v_grad (gradient mean-diff) for {len(v_grad)} modules") + v_grad = _build_v_grad(raw_grads, wrappers, cfg.v_grad_k, device) + _vk = "mean-diff (mean-mass)" if cfg.v_grad_k == 1 else f"top-{cfg.v_grad_k} SVD subspace, gate=max_i cos" + logger.info(f"routeV grad: built v_grad ({_vk}) for {len(v_grad)} modules") if cfg.routeV_random_v_seed is not None: v_grad = _haar_unit_dirs(v_grad, cfg.routeV_random_v_seed, device) logger.info(f"routeV grad: OVERRODE v_grad with Haar-random dirs " @@ -478,9 +509,10 @@ def main(cfg: Config) -> int: if upper - lower <= 0: # noisy module: pairs don't separate -> excluded continue r_blk = info["r"] - g_b = cg.reshape(n_rollouts, -1, 2 * r_blk).sum(1)[:, :r_blk].float() # [G, r] deployed block + g_b: Float[torch.Tensor, "G r"] = cg.reshape(n_rollouts, -1, 2 * r_blk).sum(1)[:, :r_blk].float() nrm = g_b.norm(dim=1) - cos_b = (g_b @ v_grad[name]) / nrm.clamp_min(1e-12) # [G] + # cos to each of the k routing dirs, then max: aligned with ANY known hack sub-mode. + cos_b = torch.einsum("g r, k r -> g k", g_b, v_grad[name]).max(dim=1).values / nrm.clamp_min(1e-12) num += cos_b - lower; den += upper - lower w += nrm; n_inc += 1 if n_inc == 0: @@ -852,9 +884,8 @@ def main(cfg: Config) -> int: model, tok, wrappers, MASK_PAIRS, top_k=1, tau_axis=0.0, n_heldout=2, device=device, ) - for name in wrappers: # update in place so the gate closure sees it - d = (raw_grads[f"hack/{name}"] - raw_grads[f"clean/{name}"]).mean(0) - v_grad[name] = (d / d.norm().clamp_min(1e-12)).to(device) + # update in place so the gate closure sees the fresh dirs (same k as init). + v_grad.update(_build_v_grad(raw_grads, wrappers, cfg.v_grad_k, device)) route_band = route_band_edges(raw_grads, v_grad, device) # rebuild band on fresh v_grad finally: logger.enable("vgrout.extract_vhack_grad") @@ -1119,7 +1150,7 @@ def main(cfg: Config) -> int: "run_dir": run_dir.name, "arm": cfg.arm, "intervention": cfg.intervention, "adapter": "lora2r", "seed": cfg.seed, "steps": n_steps, "model": model_name, "out_tag": cfg.out_tag, - "unhackable_frac": cfg.unhackable_frac, "pairs": str(cfg.vhack_pairs_path.name), + "unhackable_frac": cfg.unhackable_frac, "pairs": str(cfg.vhack_pairs_path), "eval_set": "test", "eval_modes": eval_modes, "n": ev["n"], "hack_deployed": ev["hack"], "vhack_deployed": ev["vhack"], "solve_deployed": ev["solve"], "hack_as_trained": ev_on["hack"], "vhack_as_trained": ev_on["vhack"], diff --git a/src/vgrout/train_config.py b/src/vgrout/train_config.py index fba3f85..c7da54d 100644 --- a/src/vgrout/train_config.py +++ b/src/vgrout/train_config.py @@ -43,7 +43,12 @@ class Config: unbiased: bool = True vhack_refresh_every: int = 5 - vhack_pairs_path: Path = Path("out/pairsets/pairs_authored.json") + vhack_pairs_path: Path = Path("docs/personas/hack_pairs.md#mechanism-authored") + # Routing directions per module. k=1 (headline): the mean(hack-clean) "mean-mass" + # axis. k>1: top-k oriented SVD dirs of the paired diff; gate scores max_i cos(g,v_i) + # (alignment to ANY known hack sub-mode) -- catches multi-modal hack signal one mean + # washes out. k=1 stays mean-diff (not SVD-top-1) so "mean-mass vs top-k" is a clean A/B. + v_grad_k: int = 1 # Haar-random direction control (placebo): same routing machinery, no pair signal. routeV_random_v_seed: int | None = None rollout_ablate_frac: float = 0.0 diff --git a/src/vgrout/vhack.py b/src/vgrout/vhack.py deleted file mode 100644 index d9f291b..0000000 --- a/src/vgrout/vhack.py +++ /dev/null @@ -1,145 +0,0 @@ -"""Loading and post-processing the extracted hack-direction basis (v_hack). - -v_hack is a per-module set of top-k right singular vectors of the labeled-pair -GRPO gradient, saved by extract_vhack_grad. Here we load it for a wrapped model -(checking the model/dtype/rank all match) and apply the top-k slice plus the -global noise-floor filter. The same post-processing serves both the init-time -load and the in-loop refresh. -""" -from __future__ import annotations - -import hashlib -from pathlib import Path - -import torch -from jaxtyping import Float -from loguru import logger -from safetensors import safe_open - - -def pairset_sha256(path: Path) -> str: - return hashlib.sha256(path.read_bytes()).hexdigest() - - -def load_v_hack( - path: Path, model_name: str, wrappers: dict, pairs_path: Path, - k_use: int | None = None, drop_bottom_frac: float = 0.0, -) -> dict[str, Float[torch.Tensor, "k r"]]: - """Load v_hack (top-k directions) for this wrapped model. - - File schema (v2): bare `{name}` keys hold V[k_max, r]; `_sv/{name}` keys hold - S[k_max]. v_hack is model-specific because module names and per-module SVD - ranks depend on the exact checkpoint; a smoke (Qwen3.5-0.8B) v_hack must - not be reused for a full (Qwen3-4B) run. - - If `k_use` is given, slices V (and S) to top-k_use rows. Errors if - k_use > k_max saved (re-extract with a higher top_k). - - If `drop_bottom_frac > 0`, collects every S_i across every module and drops - the bottom-fraction by global quantile. Modules whose every axis is below - the global threshold get filtered out of the returned dict (projection on - those modules becomes a no-op — they didn't carry hack signal anywhere). - """ - with safe_open(str(path), framework="pt", device="cpu") as f: - meta = f.metadata() or {} - saved_model = meta.get("model") - saved_dtype = meta.get("dtype") - saved_pairs_sha256 = meta.get("pairs_sha256") - if saved_model is None or saved_dtype is None or saved_pairs_sha256 is None: - raise ValueError( - f"{path} has no model/dtype/pairs_sha256 metadata. " - f"Re-extract with `uv run python -m vgrout.extract_vhack_grad " - f"--model={model_name} --dtype=bf16 --pairs-from-pool={pairs_path} --out-path={path}`." - ) - if saved_model != model_name: - raise ValueError(f"v_hack model mismatch: {path} has {saved_model}, run uses {model_name}") - expected_pairs_sha256 = pairset_sha256(pairs_path) - if saved_pairs_sha256 != expected_pairs_sha256: - raise ValueError( - f"v_hack pairset mismatch: {path} has sha256={saved_pairs_sha256}, " - f"{pairs_path} has sha256={expected_pairs_sha256}. Re-extract the direction." - ) - # dtype mismatch: cross-dtype SVD bases can diverge silently, so error - # unless the saved dtype matches what train.py uses on this device. - # CPU runs in fp32, CUDA runs in bf16 (see model-load site above). - expected_dtype = "fp32" if torch.cuda.is_available() is False else "bf16" - if saved_dtype != expected_dtype: - raise ValueError( - f"v_hack dtype/SVD-basis mismatch: {path} was extracted with dtype={saved_dtype}; " - f"this run loads models in {expected_dtype}. Re-extract with `--dtype={expected_dtype}`." - ) - v_hack = {k: f.get_tensor(k) for k in f.keys() if not k.startswith("_sv/")} - v_sv = {k[len("_sv/"):]: f.get_tensor(k) for k in f.keys() if k.startswith("_sv/")} - - wrapper_keys = set(wrappers) - vhack_keys = set(v_hack) - missing = sorted(wrapper_keys - vhack_keys) - extra = sorted(vhack_keys - wrapper_keys) - # v_hack[name] is [k_max, r]; delta_S is [r]. Check last-dim match (rank r). - rank_bad = [ - (name, tuple(v_hack[name].shape), tuple(wrappers[name]["delta_S"].shape)) - for name in sorted(wrapper_keys & vhack_keys) - if v_hack[name].ndim != 2 or v_hack[name].shape[-1] != wrappers[name]["delta_S"].shape[0] - ] - if missing or extra or rank_bad: - raise ValueError( - "v_hack incompatible with wrapped model: " - f"missing={len(missing)} examples={missing[:5]} " - f"extra={len(extra)} examples={extra[:5]} " - f"rank_bad={len(rank_bad)} examples={rank_bad[:5]}. " - "Extract a fresh v_hack with `uv run python -m vgrout.extract_vhack_grad " - f"--model={model_name} --out-path={path}`." - ) - - v_hack = postprocess_v_hack( - v_hack, v_sv, k_use=k_use, drop_bottom_frac=drop_bottom_frac, source=str(path), - ) - return v_hack - - -def postprocess_v_hack( - v_hack: dict[str, Float[torch.Tensor, "k r"]], - v_sv: dict[str, Float[torch.Tensor, "k"]], - k_use: int | None, - drop_bottom_frac: float, - source: str = "", -) -> dict[str, Float[torch.Tensor, "k r"]]: - """Apply k_use slice + global noise-floor filter. - - Shared between `load_v_hack` (init-time, reading from safetensors) and the - in-loop refresh hook (where we hand in fresh `extract_v_hack` outputs). - Mutates neither input dict; returns a fresh filtered dict. - - Global noise floor: collect every S_i across every module, drop the bottom - `drop_bottom_frac` by quantile. A module whose every axis falls below the - global threshold is removed entirely — projection iterates v_hack so it - becomes a no-op for that module. Threshold recomputes per call (tracks - current S distribution). - """ - k_max = next(iter(v_hack.values())).shape[0] - if k_use is not None: - if k_use > k_max: - raise ValueError(f"requested k_use={k_use} exceeds k_max={k_max} (source={source})") - v_hack = {n: v[:k_use].contiguous() for n, v in v_hack.items()} - v_sv = {n: s[:k_use].contiguous() for n, s in v_sv.items()} - n_dropped_modules = 0 - n_axes_before = sum(v.shape[0] for v in v_hack.values()) - threshold = None - if drop_bottom_frac > 0 and v_sv: - all_S = torch.cat([v_sv[n].float() for n in v_hack]) - threshold = torch.quantile(all_S, drop_bottom_frac).item() - filtered: dict[str, torch.Tensor] = {} - for name, V in v_hack.items(): - keep = v_sv[name].float() >= threshold - if keep.any(): - filtered[name] = V[keep].contiguous() - else: - n_dropped_modules += 1 - v_hack = filtered - n_axes_after = sum(v.shape[0] for v in v_hack.values()) - logger.info( - f"postprocess_v_hack({source}): modules={len(v_hack)} (dropped {n_dropped_modules}); " - f"k_use={k_use or k_max}/k_max={k_max}; axes={n_axes_after}/{n_axes_before} kept " - f"(drop_bottom_frac={drop_bottom_frac}, threshold={threshold})" - ) - return v_hack