From a94c506dbdfead39ccf8627bf04ddd72df027cc0 Mon Sep 17 00:00:00 2001 From: wassname <1103714+wassname@users.noreply.github.com> Date: Wed, 10 Jun 2026 04:42:57 +0000 Subject: [PATCH] feat: routeV_top_k -- route by oriented top-k SVD subspace (max-cos gate) The k=1 mean-diff is the only naturally hack-ward direction; SVD axes 2..k have arbitrary sign so each is re-oriented by sign(v_i . mean_diff). Gate = max_i cos(g, v_i), per-rollout grad_cosine only (asserted). top_k=1 is byte-identical to the prior mean-diff path. Smoke green: oriented [5,r] basis, band width +0.141. Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com> --- src/vgrout/train.py | 37 +++++++++++++++++++++++++++++++++++-- src/vgrout/train_config.py | 5 +++++ 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/src/vgrout/train.py b/src/vgrout/train.py index 2bbdcb6..e9b4804 100644 --- a/src/vgrout/train.py +++ b/src/vgrout/train.py @@ -338,6 +338,32 @@ def main(cfg: Config) -> int: assert _mean_bw > 0, ( f"real v_grad gave non-positive mean band width {_mean_bw:+.3f}: " "hack pairs do not separate from clean -> extraction broken") + # top-k subspace gate: oriented top-k right singular vectors of the per-pair + # diff D=[n_pairs, r], each re-oriented hack-ward by sign(v_i . mean_diff), with + # a max-over-k cosine band from the same pairs. Only the per-rollout grad_cosine + # path consumes these (asserted at config-validation below). + v_grad_topk: dict[str, torch.Tensor] = {} + route_band_topk: dict[str, tuple[float, float]] = {} + if cfg.routeV_top_k > 1: + assert cfg.routeV_gate == "grad_cosine" and not is_per_token \ + and not cfg.routeV_absorb_all, \ + "routeV_top_k>1 is implemented only for the per-rollout grad_cosine gate" + k = cfg.routeV_top_k + for name in wrappers: + gh = raw_grads[f"hack/{name}"].float() # [n_pairs, r] + gc = raw_grads[f"clean/{name}"].float() + D = gh - gc # [n_pairs, r] + Vh = torch.linalg.svd(D, full_matrices=False).Vh # [min(n,r), r] + V = Vh[:k] # [k, r] orthonormal + V = (V * torch.sign(V @ D.mean(0)).unsqueeze(1)) # [k, r] oriented hack-ward + chk = ((gh @ V.T) / gh.norm(dim=1, keepdim=True).clamp_min(1e-12)).max(1).values + cck = ((gc @ V.T) / gc.norm(dim=1, keepdim=True).clamp_min(1e-12)).max(1).values + v_grad_topk[name] = V.to(device) + route_band_topk[name] = (cck.quantile(0.75).item(), chk.quantile(0.75).item()) + _bw_tk = sum(hi - lo for lo, hi in route_band_topk.values()) / len(route_band_topk) + logger.info(f"routeV top-{k} subspace: built oriented [{k},r] basis for " + f"{len(v_grad_topk)} modules, mean max-cos band width={_bw_tk:+.3f} " + "(>0 = top-k subspace separates hack from clean)") if cfg.routeV_gate == "act_vote": As_dir, act_w, vote_band = build_act_vote_dirs(model, wrappers, tok, MASK_PAIRS, device) model.train() @@ -816,8 +842,15 @@ def main(cfg: Config) -> int: else: cg = cg_full.sum(1) # [G, r] per-rollout g_b = torch.where(reliable, cg / dS_safe, torch.zeros_like(cg)) # [G, r] - cos_b = (g_b @ vg) / g_b.norm(dim=1).clamp_min(1e-12) # [G] - f = ((cos_b - lower) / band).clamp(0.0, 1.0) # [G] + if cfg.routeV_top_k > 1: + # gate on the most-aligned oriented top-k axis (max-cos subspace gate) + V = v_grad_topk[name] # [k, r] + cos_b = ((g_b @ V.T) / g_b.norm(dim=1, keepdim=True).clamp_min(1e-12)).max(1).values # [G] + lower_tk, upper_tk = route_band_topk[name] + f = ((cos_b - lower_tk) / max(upper_tk - lower_tk, 1e-6)).clamp(0.0, 1.0) # [G] + else: + cos_b = (g_b @ vg) / g_b.norm(dim=1).clamp_min(1e-12) # [G] + f = ((cos_b - lower) / band).clamp(0.0, 1.0) # [G] routed = torch.where(reliable, (cg * f.unsqueeze(1)).sum(0) / dS_safe, torch.zeros_like(g)) # Σ_b f_b·g_b on reliable axes step_flagged.append(f.mean().item()) diff --git a/src/vgrout/train_config.py b/src/vgrout/train_config.py index 45cf5a2..4882cf0 100644 --- a/src/vgrout/train_config.py +++ b/src/vgrout/train_config.py @@ -45,6 +45,11 @@ class Config: vhack_pairs_path: Path = Path("out/pairsets/prog_wide_clean.json") routeV_random_v_seed: int | None = None + # >1: route by the oriented top-k SVD subspace (gate = max_i cos(g, v_i)) instead of the + # k=1 mean-diff. The mean-diff is the only naturally hack-ward direction; SVD axes 2..k + # have arbitrary sign, so each is re-oriented by sign(v_i . mean_diff). per-rollout + # grad_cosine only (asserted in train.py). + routeV_top_k: int = 1 # pinning: how the routing band is calibrated. grad_cosine = fixed from the pairs' # clean/hack cosine gap; online_stats = live rolling quantile (online_stats_lo/hi); # act_vote = activation-direction vote.