route2: add per-token routing granularity (route2_per_token), default per-rollout

Ablation arm requested by the user: route the banded gate per TOKEN (one cos/f
per token) instead of per ROLLOUT (sum tokens first). Per-rollout stays the
default (denoises the cos sign, matches GRPO per-rollout advantage). Per-token
uses the same pair-calibrated band; gauges (frout/tau) mask pad tokens
(|g_tok|<1e-8) so the ~0-grad positions don't skew them. Conservation
(routed+kept=g) holds in both. Both paths smoke green.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
wassname
2026-06-06 04:52:30 +00:00
parent aca045ec99
commit dd922d8793
+34 -20
View File
@@ -208,6 +208,10 @@ class Config:
# is the causal test. Refresh no-ops when set, so the direction stays the one fixed
# random draw regardless of --vhack-refresh-every.
route2_random_v_seed: int | None = None
# route2 granularity: False = route per ROLLOUT (sum tokens, one cos/f per rollout;
# the preregistered default, denoises the cos sign + matches GRPO per-rollout adv).
# True = route per TOKEN (one cos/f per token; finer but noisier). Ablation arm.
route2_per_token: bool = False
# Per-source cin diagnostic: split each prompt's backward into student-only
# + teacher-only passes (~2x backward time). 1 = every step (default; full
# signal); N>1 = only every Nth step (combined backward elsewhere, ~halves
@@ -879,32 +883,42 @@ def main(cfg: Config) -> int:
def _route2_grad_filter(info, n_rollouts: int) -> torch.Tensor:
g = info["delta_S"].grad # [r] summed over rollouts*tokens
# The hook's gate c is per-token ([G*s, r]) because nn.Linear sees a
# flattened batch. Sum each rollout's token gate-grads -> per-rollout
# δS*g_b: reshape [G*s, r] -> [G, s, r] -> sum tokens -> [G, r].
# Pad tokens carry ~0 grad (masked in the loss), so summing every
# position is safe. Per-rollout (not per-token) is the preregistered
# unit: GRPO advantage is per-rollout, and summing first denoises the
# cos(g_b, v_grad) sign (a clean rollout's individual tokens scatter
# ~50% over cos>0; its token-sum points reliably clean-ward).
cg = info["layer"]._antipasto_gate.grad.reshape(n_rollouts, -1, g.shape[0]).sum(1) # [G, r]
# flattened batch. reshape [G*s, r] -> [G, s, r]. Pad tokens carry ~0 grad
# (masked in the loss), so they contribute ~0 to routed regardless of unit.
cg_full = info["layer"]._antipasto_gate.grad.reshape(n_rollouts, -1, g.shape[0]) # [G, s, r] = δS*g
dS = info["delta_S"].detach() # [r]
reliable = dS.abs() > GATE_EPS # [r]
dS_safe = torch.where(reliable, dS, torch.ones_like(dS))
g_b = torch.where(reliable, cg / dS_safe, torch.zeros_like(cg)) # [G, r] per-rollout
vg = v_grad[name] # [r] unit, hack-ward
cos_b = (g_b @ vg) / g_b.norm(dim=1).clamp_min(1e-12) # [G]
# Banded gate, calibrated from the PAIRS only (route_band[name]): a rollout
# whose grad cosine is below the clean edge is kept, above the hack edge is
# routed, and in between ramps proportionally (the absorption zone). No live
# detector, no teacher force-route -- v_grad is the sole router. f is the
# routed FRACTION of this rollout's grad (0..1).
# Banded gate, calibrated from the PAIRS only (route_band[name]): a unit whose
# grad cosine is below the clean edge is kept, above the hack edge is routed,
# in between ramps proportionally (absorption). v_grad is the sole router.
# f is the routed FRACTION (0..1). Granularity is the routing UNIT:
# per-rollout (default): sum tokens first -> one cos/f per rollout. Denoises
# the cos sign (a clean rollout's tokens scatter ~50% over cos>0; the
# token-sum points reliably clean-ward) and matches GRPO's per-rollout adv.
# per-token (route2_per_token): one cos/f per token -- finer but noisier.
lower, upper = route_band[name]
f = ((cos_b - lower) / max(upper - lower, 1e-6)).clamp(0.0, 1.0) # [G]
step_flagged.append(f.mean().item())
band = max(upper - lower, 1e-6)
if cfg.route2_per_token:
g_u = torch.where(reliable, cg_full / dS_safe, torch.zeros_like(cg_full)) # [G, s, r]
cos_u = (g_u @ vg) / g_u.norm(dim=2).clamp_min(1e-12) # [G, s]
f = ((cos_u - lower) / band).clamp(0.0, 1.0) # [G, s]
routed = torch.where(reliable, (cg_full * f.unsqueeze(-1)).sum((0, 1)) / dS_safe,
torch.zeros_like(g)) # Σ_{b,t} f·(δS·g) / δS
live = g_u.norm(dim=2) > 1e-8 # drop pad tokens from the gauges
step_flagged.append(f[live].mean().item() if live.any() else 0.0)
step_tau.append(cos_u[live].median().item() if live.any() else 0.0)
else:
cg = cg_full.sum(1) # [G, r] per-rollout
g_b = torch.where(reliable, cg / dS_safe, torch.zeros_like(cg)) # [G, r]
cos_b = (g_b @ vg) / g_b.norm(dim=1).clamp_min(1e-12) # [G]
f = ((cos_b - lower) / band).clamp(0.0, 1.0) # [G]
routed = torch.where(reliable, (cg * f.unsqueeze(1)).sum(0) / dS_safe,
torch.zeros_like(g)) # Σ_b f_b·g_b on reliable axes
step_flagged.append(f.mean().item())
step_tau.append(cos_b.median().item()) # live cos centre vs the band
step_hkgap.append(upper - lower)
step_tau.append(cos_b.median().item()) # live cos centre vs the band
routed = torch.where(reliable, (cg * f.unsqueeze(1)).sum(0) / dS_safe,
torch.zeros_like(g)) # Σ_b f_b·g_b on reliable axes
# Park the routed fraction in δS_hack (deleted at deploy); δS keeps the rest.
# routed + g_keep = g exactly (unreliable axes: routed=0, kept whole).
step_grad_hack[name] = (step_grad_hack[name] + routed.detach().clone()