mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 17:48:43 +08:00
feat: routeV_top_k -- route by oriented top-k SVD subspace (max-cos gate)
The k=1 mean-diff is the only naturally hack-ward direction; SVD axes have arbitrary sign, so each of the top-k axes (1..k) is re-oriented by sign(v_i . mean_diff). Gate = max_i cos(g, v_i), implemented for the per-rollout grad_cosine gate only (asserted). top_k=1 is byte-identical to the prior mean-diff path. Smoke test green: oriented [5,r] basis, band width +0.141.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
+35
-2
@@ -338,6 +338,32 @@ def main(cfg: Config) -> int:
|
||||
assert _mean_bw > 0, (
|
||||
f"real v_grad gave non-positive mean band width {_mean_bw:+.3f}: "
|
||||
"hack pairs do not separate from clean -> extraction broken")
|
||||
# top-k subspace gate: oriented top-k right singular vectors of the per-pair
|
||||
# diff D=[n_pairs, r], each re-oriented hack-ward by sign(v_i . mean_diff), with
|
||||
# a max-over-k cosine band from the same pairs. Only the per-rollout grad_cosine
|
||||
# path consumes these (asserted at config-validation below).
|
||||
v_grad_topk: dict[str, torch.Tensor] = {}
|
||||
route_band_topk: dict[str, tuple[float, float]] = {}
|
||||
if cfg.routeV_top_k > 1:
|
||||
assert cfg.routeV_gate == "grad_cosine" and not is_per_token \
|
||||
and not cfg.routeV_absorb_all, \
|
||||
"routeV_top_k>1 is implemented only for the per-rollout grad_cosine gate"
|
||||
k = cfg.routeV_top_k
|
||||
for name in wrappers:
|
||||
gh = raw_grads[f"hack/{name}"].float() # [n_pairs, r]
|
||||
gc = raw_grads[f"clean/{name}"].float()
|
||||
D = gh - gc # [n_pairs, r]
|
||||
Vh = torch.linalg.svd(D, full_matrices=False).Vh # [min(n,r), r]
|
||||
V = Vh[:k] # [k, r] orthonormal
|
||||
V = (V * torch.sign(V @ D.mean(0)).unsqueeze(1)) # [k, r] oriented hack-ward
|
||||
chk = ((gh @ V.T) / gh.norm(dim=1, keepdim=True).clamp_min(1e-12)).max(1).values
|
||||
cck = ((gc @ V.T) / gc.norm(dim=1, keepdim=True).clamp_min(1e-12)).max(1).values
|
||||
v_grad_topk[name] = V.to(device)
|
||||
route_band_topk[name] = (cck.quantile(0.75).item(), chk.quantile(0.75).item())
|
||||
_bw_tk = sum(hi - lo for lo, hi in route_band_topk.values()) / len(route_band_topk)
|
||||
logger.info(f"routeV top-{k} subspace: built oriented [{k},r] basis for "
|
||||
f"{len(v_grad_topk)} modules, mean max-cos band width={_bw_tk:+.3f} "
|
||||
"(>0 = top-k subspace separates hack from clean)")
|
||||
if cfg.routeV_gate == "act_vote":
|
||||
As_dir, act_w, vote_band = build_act_vote_dirs(model, wrappers, tok, MASK_PAIRS, device)
|
||||
model.train()
|
||||
@@ -816,8 +842,15 @@ def main(cfg: Config) -> int:
|
||||
else:
|
||||
cg = cg_full.sum(1) # [G, r] per-rollout
|
||||
g_b = torch.where(reliable, cg / dS_safe, torch.zeros_like(cg)) # [G, r]
|
||||
cos_b = (g_b @ vg) / g_b.norm(dim=1).clamp_min(1e-12) # [G]
|
||||
f = ((cos_b - lower) / band).clamp(0.0, 1.0) # [G]
|
||||
if cfg.routeV_top_k > 1:
|
||||
# gate on the most-aligned oriented top-k axis (max-cos subspace gate)
|
||||
V = v_grad_topk[name] # [k, r]
|
||||
cos_b = ((g_b @ V.T) / g_b.norm(dim=1, keepdim=True).clamp_min(1e-12)).max(1).values # [G]
|
||||
lower_tk, upper_tk = route_band_topk[name]
|
||||
f = ((cos_b - lower_tk) / max(upper_tk - lower_tk, 1e-6)).clamp(0.0, 1.0) # [G]
|
||||
else:
|
||||
cos_b = (g_b @ vg) / g_b.norm(dim=1).clamp_min(1e-12) # [G]
|
||||
f = ((cos_b - lower) / band).clamp(0.0, 1.0) # [G]
|
||||
routed = torch.where(reliable, (cg * f.unsqueeze(1)).sum(0) / dS_safe,
|
||||
torch.zeros_like(g)) # Σ_b f_b·g_b on reliable axes
|
||||
step_flagged.append(f.mean().item())
|
||||
|
||||
@@ -45,6 +45,11 @@ class Config:
|
||||
vhack_pairs_path: Path = Path("out/pairsets/prog_wide_clean.json")
|
||||
|
||||
routeV_random_v_seed: int | None = None
|
||||
# >1: route by the oriented top-k SVD subspace (gate = max_i cos(g, v_i)) instead of the
|
||||
# k=1 mean-diff. The mean-diff is the only naturally hack-ward direction; SVD axes 1..k
|
||||
# have arbitrary sign, so each is re-oriented by sign(v_i . mean_diff). Per-rollout
|
||||
# grad_cosine only (asserted in train.py).
|
||||
routeV_top_k: int = 1
|
||||
# pinning: how the routing band is calibrated. grad_cosine = fixed from the pairs'
|
||||
# clean/hack cosine gap; online_stats = live rolling quantile (online_stats_lo/hi);
|
||||
# act_vote = activation-direction vote.
|
||||
|
||||
Reference in New Issue
Block a user