mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 16:30:30 +08:00
route2: add per-token routing granularity (route2_per_token), default per-rollout
Ablation arm requested by the user: route the banded gate per TOKEN (one cos/f per token) instead of per ROLLOUT (tokens summed first, one cos/f per rollout). Per-rollout stays the default (it denoises the cos sign and matches GRPO's per-rollout advantage). Per-token uses the same pair-calibrated band; its gauges (frout/tau) mask pad tokens (|g_tok| < 1e-8) so the ~0-grad positions don't skew them. Conservation (routed + kept = g) holds in both modes. Both paths pass the smoke test.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
+34
-20
@@ -208,6 +208,10 @@ class Config:
|
||||
# is the causal test. Refresh no-ops when set, so the direction stays the one fixed
|
||||
# random draw regardless of --vhack-refresh-every.
|
||||
route2_random_v_seed: int | None = None
|
||||
# route2 granularity: False = route per ROLLOUT (sum tokens, one cos/f per rollout;
|
||||
# the preregistered default, denoises the cos sign + matches GRPO per-rollout adv).
|
||||
# True = route per TOKEN (one cos/f per token; finer but noisier). Ablation arm.
|
||||
route2_per_token: bool = False
|
||||
# Per-source cin diagnostic: split each prompt's backward into student-only
|
||||
# + teacher-only passes (~2x backward time). 1 = every step (default; full
|
||||
# signal); N>1 = only every Nth step (combined backward elsewhere, ~halves
|
||||
@@ -879,32 +883,42 @@ def main(cfg: Config) -> int:
|
||||
def _route2_grad_filter(info, n_rollouts: int) -> torch.Tensor:
|
||||
g = info["delta_S"].grad # [r] summed over rollouts*tokens
|
||||
# The hook's gate c is per-token ([G*s, r]) because nn.Linear sees a
|
||||
# flattened batch. Sum each rollout's token gate-grads -> per-rollout
|
||||
# δS*g_b: reshape [G*s, r] -> [G, s, r] -> sum tokens -> [G, r].
|
||||
# Pad tokens carry ~0 grad (masked in the loss), so summing every
|
||||
# position is safe. Per-rollout (not per-token) is the preregistered
|
||||
# unit: GRPO advantage is per-rollout, and summing first denoises the
|
||||
# cos(g_b, v_grad) sign (a clean rollout's individual tokens scatter
|
||||
# ~50% over cos>0; its token-sum points reliably clean-ward).
|
||||
cg = info["layer"]._antipasto_gate.grad.reshape(n_rollouts, -1, g.shape[0]).sum(1) # [G, r]
|
||||
# flattened batch. reshape [G*s, r] -> [G, s, r]. Pad tokens carry ~0 grad
|
||||
# (masked in the loss), so they contribute ~0 to routed regardless of unit.
|
||||
cg_full = info["layer"]._antipasto_gate.grad.reshape(n_rollouts, -1, g.shape[0]) # [G, s, r] = δS*g
|
||||
dS = info["delta_S"].detach() # [r]
|
||||
reliable = dS.abs() > GATE_EPS # [r]
|
||||
dS_safe = torch.where(reliable, dS, torch.ones_like(dS))
|
||||
g_b = torch.where(reliable, cg / dS_safe, torch.zeros_like(cg)) # [G, r] per-rollout
|
||||
vg = v_grad[name] # [r] unit, hack-ward
|
||||
cos_b = (g_b @ vg) / g_b.norm(dim=1).clamp_min(1e-12) # [G]
|
||||
# Banded gate, calibrated from the PAIRS only (route_band[name]): a rollout
|
||||
# whose grad cosine is below the clean edge is kept, above the hack edge is
|
||||
# routed, and in between ramps proportionally (the absorption zone). No live
|
||||
# detector, no teacher force-route -- v_grad is the sole router. f is the
|
||||
# routed FRACTION of this rollout's grad (0..1).
|
||||
# Banded gate, calibrated from the PAIRS only (route_band[name]): a unit whose
|
||||
# grad cosine is below the clean edge is kept, above the hack edge is routed,
|
||||
# in between ramps proportionally (absorption). v_grad is the sole router.
|
||||
# f is the routed FRACTION (0..1). Granularity is the routing UNIT:
|
||||
# per-rollout (default): sum tokens first -> one cos/f per rollout. Denoises
|
||||
# the cos sign (a clean rollout's tokens scatter ~50% over cos>0; the
|
||||
# token-sum points reliably clean-ward) and matches GRPO's per-rollout adv.
|
||||
# per-token (route2_per_token): one cos/f per token -- finer but noisier.
|
||||
lower, upper = route_band[name]
|
||||
f = ((cos_b - lower) / max(upper - lower, 1e-6)).clamp(0.0, 1.0) # [G]
|
||||
step_flagged.append(f.mean().item())
|
||||
band = max(upper - lower, 1e-6)
|
||||
if cfg.route2_per_token:
|
||||
g_u = torch.where(reliable, cg_full / dS_safe, torch.zeros_like(cg_full)) # [G, s, r]
|
||||
cos_u = (g_u @ vg) / g_u.norm(dim=2).clamp_min(1e-12) # [G, s]
|
||||
f = ((cos_u - lower) / band).clamp(0.0, 1.0) # [G, s]
|
||||
routed = torch.where(reliable, (cg_full * f.unsqueeze(-1)).sum((0, 1)) / dS_safe,
|
||||
torch.zeros_like(g)) # Σ_{b,t} f·(δS·g) / δS
|
||||
live = g_u.norm(dim=2) > 1e-8 # drop pad tokens from the gauges
|
||||
step_flagged.append(f[live].mean().item() if live.any() else 0.0)
|
||||
step_tau.append(cos_u[live].median().item() if live.any() else 0.0)
|
||||
else:
|
||||
cg = cg_full.sum(1) # [G, r] per-rollout
|
||||
g_b = torch.where(reliable, cg / dS_safe, torch.zeros_like(cg)) # [G, r]
|
||||
cos_b = (g_b @ vg) / g_b.norm(dim=1).clamp_min(1e-12) # [G]
|
||||
f = ((cos_b - lower) / band).clamp(0.0, 1.0) # [G]
|
||||
routed = torch.where(reliable, (cg * f.unsqueeze(1)).sum(0) / dS_safe,
|
||||
torch.zeros_like(g)) # Σ_b f_b·g_b on reliable axes
|
||||
step_flagged.append(f.mean().item())
|
||||
step_tau.append(cos_b.median().item()) # live cos centre vs the band
|
||||
step_hkgap.append(upper - lower)
|
||||
step_tau.append(cos_b.median().item()) # live cos centre vs the band
|
||||
routed = torch.where(reliable, (cg * f.unsqueeze(1)).sum(0) / dS_safe,
|
||||
torch.zeros_like(g)) # Σ_b f_b·g_b on reliable axes
|
||||
# Park the routed fraction in δS_hack (deleted at deploy); δS keeps the rest.
|
||||
# routed + g_keep = g exactly (unreliable axes: routed=0, kept whole).
|
||||
step_grad_hack[name] = (step_grad_hack[name] + routed.detach().clone()
|
||||
|
||||
Reference in New Issue
Block a user