diff --git a/src/vgrout/train.py b/src/vgrout/train.py index eba5020..0ffead3 100644 --- a/src/vgrout/train.py +++ b/src/vgrout/train.py @@ -208,6 +208,10 @@ class Config: # is the causal test. Refresh no-ops when set, so the direction stays the one fixed # random draw regardless of --vhack-refresh-every. route2_random_v_seed: int | None = None + # route2 granularity: False = route per ROLLOUT (sum tokens, one cos/f per rollout; + # the preregistered default, denoises the cos sign + matches GRPO per-rollout adv). + # True = route per TOKEN (one cos/f per token; finer but noisier). Ablation arm. + route2_per_token: bool = False # Per-source cin diagnostic: split each prompt's backward into student-only # + teacher-only passes (~2x backward time). 1 = every step (default; full # signal); N>1 = only every Nth step (combined backward elsewhere, ~halves @@ -879,32 +883,42 @@ def main(cfg: Config) -> int: def _route2_grad_filter(info, n_rollouts: int) -> torch.Tensor: g = info["delta_S"].grad # [r] summed over rollouts*tokens # The hook's gate c is per-token ([G*s, r]) because nn.Linear sees a - # flattened batch. Sum each rollout's token gate-grads -> per-rollout - # δS*g_b: reshape [G*s, r] -> [G, s, r] -> sum tokens -> [G, r]. - # Pad tokens carry ~0 grad (masked in the loss), so summing every - # position is safe. Per-rollout (not per-token) is the preregistered - # unit: GRPO advantage is per-rollout, and summing first denoises the - # cos(g_b, v_grad) sign (a clean rollout's individual tokens scatter - # ~50% over cos>0; its token-sum points reliably clean-ward). - cg = info["layer"]._antipasto_gate.grad.reshape(n_rollouts, -1, g.shape[0]).sum(1) # [G, r] + # flattened batch. reshape [G*s, r] -> [G, s, r]. Pad tokens carry ~0 grad + # (masked in the loss), so they contribute ~0 to routed regardless of unit. + cg_full = info["layer"]._antipasto_gate.grad.reshape(n_rollouts, -1, g.shape[0]) # [G, s, r] = δS*g dS = info["delta_S"].detach() # [r] reliable = dS.abs() > GATE_EPS # [r] dS_safe = torch.where(reliable, dS, torch.ones_like(dS)) - g_b = torch.where(reliable, cg / dS_safe, torch.zeros_like(cg)) # [G, r] per-rollout vg = v_grad[name] # [r] unit, hack-ward - cos_b = (g_b @ vg) / g_b.norm(dim=1).clamp_min(1e-12) # [G] - # Banded gate, calibrated from the PAIRS only (route_band[name]): a rollout - # whose grad cosine is below the clean edge is kept, above the hack edge is - # routed, and in between ramps proportionally (the absorption zone). No live - # detector, no teacher force-route -- v_grad is the sole router. f is the - # routed FRACTION of this rollout's grad (0..1). + # Banded gate, calibrated from the PAIRS only (route_band[name]): a unit whose + # grad cosine is below the clean edge is kept, above the hack edge is routed, + # in between ramps proportionally (absorption). v_grad is the sole router. + # f is the routed FRACTION (0..1). Granularity is the routing UNIT: + # per-rollout (default): sum tokens first -> one cos/f per rollout. Denoises + # the cos sign (a clean rollout's tokens scatter ~50% over cos>0; the + # token-sum points reliably clean-ward) and matches GRPO's per-rollout adv. + # per-token (route2_per_token): one cos/f per token -- finer but noisier. lower, upper = route_band[name] - f = ((cos_b - lower) / max(upper - lower, 1e-6)).clamp(0.0, 1.0) # [G] - step_flagged.append(f.mean().item()) + band = max(upper - lower, 1e-6) + if cfg.route2_per_token: + g_u = torch.where(reliable, cg_full / dS_safe, torch.zeros_like(cg_full)) # [G, s, r] + cos_u = (g_u @ vg) / g_u.norm(dim=2).clamp_min(1e-12) # [G, s] + f = ((cos_u - lower) / band).clamp(0.0, 1.0) # [G, s] + routed = torch.where(reliable, (cg_full * f.unsqueeze(-1)).sum((0, 1)) / dS_safe, + torch.zeros_like(g)) # Σ_{b,t} f·(δS·g) / δS + live = g_u.norm(dim=2) > 1e-8 # drop pad tokens from the gauges + step_flagged.append(f[live].mean().item() if live.any() else 0.0) + step_tau.append(cos_u[live].median().item() if live.any() else 0.0) + else: + cg = cg_full.sum(1) # [G, r] per-rollout + g_b = torch.where(reliable, cg / dS_safe, torch.zeros_like(cg)) # [G, r] + cos_b = (g_b @ vg) / g_b.norm(dim=1).clamp_min(1e-12) # [G] + f = ((cos_b - lower) / band).clamp(0.0, 1.0) # [G] + routed = torch.where(reliable, (cg * f.unsqueeze(1)).sum(0) / dS_safe, + torch.zeros_like(g)) # Σ_b f_b·g_b on reliable axes + step_flagged.append(f.mean().item()) + step_tau.append(cos_b.median().item()) # live cos centre vs the band step_hkgap.append(upper - lower) - step_tau.append(cos_b.median().item()) # live cos centre vs the band - routed = torch.where(reliable, (cg * f.unsqueeze(1)).sum(0) / dS_safe, - torch.zeros_like(g)) # Σ_b f_b·g_b on reliable axes # Park the routed fraction in δS_hack (deleted at deploy); δS keeps the rest. # routed + g_keep = g exactly (unreliable axes: routed=0, kept whole). step_grad_hack[name] = (step_grad_hack[name] + routed.detach().clone()