mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 16:15:35 +08:00
fix: route2 Arm A flags per-rollout not per-token (external review)
The hook gate is necessarily per-token ([G*s, r], nn.Linear flattens the batch). _route2_grad_filter now sums each rollout's token gate-grads before the cos(g_b, v_grad) flag, so routing is per-rollout (the preregistered GRPO unit) and the sign is denoised. Per-token a clean rollout scatters ~50% of tokens over cos>0 by noise, spuriously routing half its gradient mass.

Verified by deepseek-v4-pro review: gate identity, divide-out, eps-guard, Arm B detach-route, R5 no-cheat all correct; this was the one finding.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
@@ -0,0 +1,575 @@
The user wants me to review the route2 implementation against its spec, focusing on specific areas. Let me read the spec file and the relevant source files to understand the implementation before I analyze the diff.

Let me start by reading the spec, then the source files for context.

[Tool Call: ]
{"path": "docs/spec/20260531_routing_v2_distinct_basis.md"}

Let me now read the relevant source files to check the implementation against the spec.

[Tool Call: ]
{"path": "/workspace/projected_grpo/src/projected_grpo/antipasto.py", "offset": 65, "limit": 80}

[Tool Call: ]
{"path": "/workspace/projected_grpo/src/projected_grpo/train.py", "offset": 1100, "limit": 120}

Let me carefully review the route2 implementation against the spec. I need to focus on:

1. Arm A (grad-mask) - the gate identity, divide-out, eps-guard, gradient flow
2. Arm B (act-mask) - detach-route correctness
3. NO-CHEAT invariant
4. Spec vs code drift
5. Math/gradient errors

Let me read the rest of train.py to see the full flow for the grad-mask.

[Tool Call: ]
{"offset": 1220, "path": "/workspace/projected_grpo/src/projected_grpo/train.py", "limit": 100}

Let me now look at the critical sections of the code - the grad filter in train.py and how it's called, plus the v_grad construction.

[Tool Call: ]
{"offset": 1340, "filePath": "/workspace/projected_grpo/src/projected_grpo/train.py", "limit": 130}

Let me read the correct section of train.py.

[Tool Call: ]
{"limit": "130", "offset": 1340, "path": "/workspace/projected_grpo/src/projected_grpo/train.py"}

Let me now look at the remaining portion of train.py where the quarantine grads are re-injected and where the opt.step happens, to check for any issues.

[Tool Call: ]
{"offset": 1470, "path": "/workspace/projected_grpo/src/projected_grpo/train.py", "limit": 60}

Let me now carefully analyze the Arm A (grad-mask, single-pass gate-subtraction) implementation against the spec. I need to trace through the math carefully.

## Arm A: The Gate Identity

The spec says:
```python
c = ones(b, 1, r, requires_grad=True)  # per-rollout gate
kept = U @ ((a * c) * delta_S)         # identity at c=1: forward value unchanged
```

After backward, the spec claims: `c.grad = delta_S * g_b` where `g_b` is the per-rollout delta_S grad.

Let me verify this from first principles.

The contribution to the output from the kept adapter is:
```
y_kept = U @ ((a * c) * delta_S)
```

where:
- `a` is shape `[b, s, r]` (b=batch, s=seq_len, r=rank)
- `c` is shape `[b, 1, r]` (broadcast over s)
- `delta_S` is shape `[r]` (broadcast over b, s)

So `(a * c) * delta_S` = entry `a[b,s,i] * c[b,1,i] * delta_S[i]`.

Now, the total loss L depends on y_kept through some downstream computation.

`∂L/∂c[b,1,i] = sum over s of ∂L/∂(a*c*delta_S)[b,s,i] * ∂(a*c*delta_S)[b,s,i]/∂c[b,1,i]`
`= sum over s of ∂L/∂(a*c*delta_S)[b,s,i] * a[b,s,i] * delta_S[i]`

Now, what is `∂L/∂delta_S[i]` in the standard (non-gate) case? With `kept = U @ (a * delta_S)`:
```
∂L/∂delta_S[i] = sum over b,s of ∂L/∂(a*delta_S)[b,s,i] * a[b,s,i]
```

So in the gate case, the per-rollout contribution to delta_S gradient (call it `g_b[b,i]`) would be:
```
g_b[b,i] = sum over s of ∂L/∂(a*c*delta_S)[b,s,i] * a[b,s,i] * c[b,1,i]
```

Since c = 1 (initialized), `c[b,1,i] = 1`, so:
```
g_b[b,i] = sum_s of ∂L/∂(a*c*delta_S)[b,s,i] * a[b,s,i]
```

And `c.grad[b,1,i] = sum_s of ∂L/∂(a*c*delta_S)[b,s,i] * a[b,s,i] * delta_S[i]`
`= g_b[b,i] * delta_S[i]`

So indeed: `c.grad[b,i] = g_b[b,i] * delta_S[i]`.
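
The identity can also be checked numerically. The following is a minimal sketch in NumPy with central finite differences, not the project's PyTorch hook: the quadratic toy loss, the sizes, and the helper names `loss_fn` and `num_grad` are assumptions made for illustration, and only the product `(a * c) * delta_S` is taken from the spec. To make the per-rollout grad `g_b` observable, `delta_S` is copied once per rollout.

```python
import numpy as np

rng = np.random.default_rng(0)
b, s, r, d_out = 3, 4, 5, 6            # rollouts, tokens, rank, output dim (toy sizes)
a = rng.normal(size=(b, s, r))         # stand-in for the projected activations Vh @ x
U = rng.normal(size=(d_out, r))        # stand-in for the left singular basis
delta_S = rng.normal(size=r)           # trainable diagonal
target = rng.normal(size=(b, s, d_out))

def loss_fn(c, dS_per_rollout):
    """Toy loss. dS_per_rollout is [b, r] so each rollout's delta_S grad is separable."""
    kept = ((a * c) * dS_per_rollout[:, None, :]) @ U.T   # [b, s, d_out]
    return 0.5 * np.sum((kept - target) ** 2)

def num_grad(f, x, h=1e-4):
    """Central finite-difference gradient of the scalar function f at x."""
    g = np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        step = np.zeros_like(x)
        step[idx] = h
        g[idx] = (f(x + step) - f(x - step)) / (2 * h)
    return g

c0 = np.ones((b, 1, r))                # gate at identity
dS0 = np.tile(delta_S, (b, 1))         # the same delta_S, copied per rollout

c_grad = num_grad(lambda c: loss_fn(c, dS0), c0)[:, 0, :]   # [b, r]
g_b = num_grad(lambda d: loss_fn(c0, d), dS0)               # [b, r]

identity_holds = np.allclose(c_grad, g_b * delta_S, rtol=1e-4, atol=1e-5)
print(identity_holds)
```

Because the toy loss is quadratic in both `c` and `delta_S`, central differences carry no truncation error here, so the comparison is limited only by floating-point rounding.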

The code does:
```python
cg = info["layer"]._antipasto_gate.grad.reshape(-1, g.shape[0])  # [b, r]
dS = info["delta_S"].detach()  # [r]
reliable = dS.abs() > GATE_EPS  # [r]
dS_safe = torch.where(reliable, dS, torch.ones_like(dS))
g_b = torch.where(reliable, cg / dS_safe, torch.zeros_like(cg))  # [b, r]
```

This is: on reliable axes, `g_b[b,i] = c.grad[b,i] / delta_S[i]`. On unreliable axes, `g_b[b,i] = 0`.

## The subtraction

Then:
```python
sub = torch.where(reliable, (cg * flagged.unsqueeze(1)).sum(0) / dS_safe,
                  torch.zeros_like(g))
return g - sub
```

Where:
- `flagged` is shape `[b]` (0/1)
- `cg` is `c.grad` shape `[b, r]`
- `(cg * flagged.unsqueeze(1)).sum(0)` sums over batch, giving `[r]`
- Dividing by `dS_safe` gives `sum over flagged b of (cg[b,i] / delta_S[i])` for reliable axes

But wait: `cg = c.grad`, and `c.grad[b,i] / delta_S[i] = g_b[b,i]` (for reliable axes). So `sub[i] = sum over flagged b of g_b[b,i]` for reliable axes, and 0 for unreliable.

Then `g - sub` = `sum over all b of g_b[b,i] - sum over flagged b of g_b[b,i]` = `sum over unflagged b of g_b[b,i]` for reliable axes, and `g` (unchanged) for unreliable axes.

Wait, but `g` here is `delta_S.grad` after backward, which should be `sum over all b of g_b[b,i]`. But in the gate-based forward, `delta_S.grad` after backward is NOT `sum_b g_b[b,i]` because the gate `c` is `requires_grad=True`, not `delta_S` directly. Let me re-examine.

In the gate-based forward:
```
kept = U @ ((a * c) * delta_S)
```

Both `c` and `delta_S` are trainable parameters that require grad. The loss backward passes gradient to both. `delta_S.grad` in this case would be:

```
∂L/∂delta_S[i] = sum over b,s of ∂L/∂(a*c*delta_S)[b,s,i] * a[b,s,i] * c[b,1,i]
```

Since c=1, this equals `sum over b,s of ∂L/∂(a*c*delta_S)[b,s,i] * a[b,s,i]`, which is exactly `sum over b of g_b[b,i]`. So `delta_S.grad = sum_b g_b[b,i]` still holds.

And `c.grad[b,1,i] = sum over s of ∂L/∂(a*c*delta_S)[b,s,i] * a[b,s,i] * delta_S[i] = g_b[b,i] * delta_S[i]`.

Good, so the gate identity `c.grad = delta_S * g_b` holds.

Then `cg / dS_safe = g_b` on reliable axes, and:
- `sub = sum over flagged b of g_b[b,:]` (on reliable axes)
- `return g - sub = delta_S.grad - sum_{flagged b} g_b[b,:]`
- But `delta_S.grad = sum over all b of g_b[b,:]`
- So `return = sum over unflagged b of g_b[b,:]` on reliable axes

This is correct: it removes the flagged rollouts' contribution from `delta_S.grad`.
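
The divide-out and subtraction traced above can be mirrored in a few lines. This is a NumPy sketch under stated assumptions, not the repository's `_route2_grad_filter`: the function name `route_filter` is hypothetical, `flagged` is passed in rather than computed from a cosine, and `GATE_EPS = 1e-6` is taken from the threshold the review quotes for unreliable axes.

```python
import numpy as np

GATE_EPS = 1e-6  # assumed value, matching the |delta_S| <= 1e-6 guard discussed in the review

def route_filter(g, cg, dS, flagged):
    """NumPy mirror of the subtraction step.

    g:       [r]    delta_S grad summed over all rollouts
    cg:      [b, r] gate grad, equal to g_b * dS by the gate identity
    dS:      [r]    current delta_S
    flagged: [b]    1.0 for rollouts to route away from delta_S, else 0.0
    """
    reliable = np.abs(dS) > GATE_EPS
    dS_safe = np.where(reliable, dS, 1.0)            # never divide by a ~0 entry
    sub = np.where(reliable, (cg * flagged[:, None]).sum(0) / dS_safe, 0.0)
    return g - sub

rng = np.random.default_rng(1)
b, r = 4, 6
g_b = rng.normal(size=(b, r))                        # toy per-rollout grads
dS = rng.normal(size=r)
dS[2] = 0.0                                          # one axis still at its zero init
cg = g_b * dS                                        # what backward would leave in c.grad
g = g_b.sum(0)
flagged = np.array([1.0, 0.0, 1.0, 0.0])

kept = route_filter(g, cg, dS, flagged)
unflagged_sum = g_b[flagged == 0].sum(0)
reliable = np.abs(dS) > GATE_EPS
```

On the reliable axes `kept` equals `unflagged_sum`; on the zeroed axis the full grad `g[2]` passes through, which is the behaviour the derivation describes.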

## Issues and concerns

### 1. Multiple prompts per step with gate

The gate `c` is created in the forward hook with shape `[a.shape[0], 1, ..., r]` where `a.shape[0]` is the batch size of the current prompt. But `train.py` processes multiple prompts per step, with separate forwards/backwards. Each prompt creates its own gate.

Then `_route2_grad_filter` is called after each combined backward (which is per-prompt). It reads `info["layer"]._antipasto_gate.grad`, which is from the most recent backward. This is correct per-prompt.

But wait: between prompts, `model.zero_grad(set_to_none=True)` is called. This would zero the gate's `.grad`! And `_route2_grad_filter` accesses `_antipasto_gate.grad`.

Let me look at the flow more carefully:

```python
for p_idx in range(prompts_per_step):
    ...
    loss.backward()
    for name, info in wrappers.items():
        g = info["delta_S"].grad
        if g is None:
            continue
        if is_route2_grad:
            g = _route2_grad_filter(info)
        step_grad_s[name] = ... + g.detach().clone()
    _stash_quar_grads()
    model.zero_grad(set_to_none=True)
```

So `_route2_grad_filter` is called AFTER `loss.backward()` but BEFORE `model.zero_grad()`. At this point, `_antipasto_gate.grad` still exists (it was populated by the backward). Then `model.zero_grad(set_to_none=True)` zeroes it.

But wait, `model.zero_grad(set_to_none=True)` would set the gate's `.grad` to None. And the gate was created as a new tensor in the forward hook for this prompt only. So after zero_grad, the gate tensor itself still exists (as `layer._antipasto_gate`), but its `.grad` is None. On the next prompt's forward, a NEW gate `c` is created (as a new tensor with `requires_grad=True`), overwriting `layer._antipasto_gate`. The old gate (from previous prompt) is no longer referenced anywhere.

So this is fine: each prompt gets its own gate, and the filter reads it after backward but before zero_grad.

**But there's a subtle issue**: `_route2_grad_filter` also modifies `delta_S.grad` (via `g = _route2_grad_filter(info)`). Then the code does:
```python
step_grad_s[name] = (step_grad_s[name] + g.detach().clone() ...)
```

And `g` here is the RETURN VALUE of `_route2_grad_filter`, which is `g - sub`. But the function also MODIFIES `info["delta_S"].grad`. Wait, no, it doesn't! Let me re-read:

```python
def _route2_grad_filter(info) -> torch.Tensor:
    g = info["delta_S"].grad  # [r] summed over rollouts
    ...
    return g - sub
```

Then:
```python
g = _route2_grad_filter(info)
step_grad_s[name] = (step_grad_s[name] + g.detach().clone() ...)
```

`g` is now `delta_S.grad - sub`. But `info["delta_S"].grad` is still the original, unmodified `delta_S.grad`! The function doesn't mutate it. Then the code uses `g` (the filtered version) for accumulation.

Then `model.zero_grad(set_to_none=True)` zeroes everything, including `delta_S.grad`. On the next prompt, `delta_S.grad` gets a new backward contribution.

After the loop, the injection block does:
```python
for name, info in wrappers.items():
    gs = step_grad_s.get(name)
    ...
    info["delta_S"].grad = gs  # or gs + gt
```

So `info["delta_S"].grad` ends up as the accumulated filtered grads. This looks correct.
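
The accumulate-then-inject pattern traced above can be reduced to a small sketch. It is illustrative only: the helper `accumulate`, the key `"layer0"`, and the toy gradient values are hypothetical, and NumPy arrays stand in for the detached, cloned tensors.

```python
import numpy as np

def accumulate(store, name, grad):
    """Add a copy of grad into store[name], creating the entry on first use."""
    store[name] = store[name] + grad.copy() if name in store else grad.copy()

step_grad_s = {}
# Toy filtered grads, one per prompt (stand-ins for the filter's return value).
per_prompt_filtered = [np.array([1.0, -2.0]), np.array([0.5, 0.5])]

for g in per_prompt_filtered:
    param_grad = g                      # grad present right after this prompt's backward
    accumulate(step_grad_s, "layer0", param_grad)
    param_grad = None                   # stand-in for model.zero_grad(set_to_none=True)

injected = step_grad_s["layer0"]        # what would be written back before the optimizer step
```

The copy in `accumulate` plays the role of `.detach().clone()`: the stored sum must not alias a buffer that is about to be cleared.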

### 2. The gate creation issue with batch dim

```python
c = torch.ones(a.shape[0], *([1] * (a.dim() - 2)), a.shape[-1], ...)
```

For a batch with one prompt, `a` is `[1, s, r]` after tokenization. So `a.shape[0] = 1`, and the gate is `[1, 1, r]`. But there are `G=group` rollouts in the merged tensor: `merged` is `[G, L]`. The forward hook is called with `x` coming from the merged tensor through the linear layers...

Wait, let me think about this more carefully. The forward hook `_delta_hook` is registered on every target Linear. When `model(merged)` is called, the input to each Linear has shape `[G*s, d_in]` (flattened batch+seq). So `a = Vh @ x` has shape `[G*s, r]`. Then:

```python
c = torch.ones(a.shape[0], *([1] * (a.dim() - 2)), a.shape[-1], ...)
```

Since `a.dim() = 2`, `a.shape[0] = G*s`, `a.shape[-1] = r`. So `c` is `[G*s, r]`. That's NOT per-rollout with broadcast over tokens; it's per-token.

But the spec says `c = ones(b, 1, r)` where `b` is the batch dimension (number of rollouts). The code creates `c` as `[G*s, r]`, which is per-token, not per-rollout.

**This is a BUG.** The gate should be per-rollout, not per-token. With `[G*s, r]` gate, each token gets its own gate parameter, and `c.grad` after backward is `[G*s, r]`, not `[G, r]` or `[G, 1, r]`.
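
The shape argument depends only on how many dimensions the activations have when the hook sees them. A pure-Python sketch of the shape expression (the helper name `gate_shape` and the sizes are assumptions for illustration) shows both cases, under the review's premise that the hook receives 2-D activations:

```python
def gate_shape(a_shape):
    """Shape produced by ones(a.shape[0], *([1] * (a.dim() - 2)), a.shape[-1])."""
    ndim = len(a_shape)
    return (a_shape[0], *([1] * (ndim - 2)), a_shape[-1])

G, s, r = 8, 16, 4                      # toy sizes: rollouts, tokens, rank
shape_3d = gate_shape((G, s, r))        # activations kept as [G, s, r]
shape_2d = gate_shape((G * s, r))       # activations flattened to [G*s, r]
print(shape_3d, shape_2d)
```

With 3-D activations the middle `1` broadcasts over tokens, so the gate is per-rollout; with 2-D activations there is no middle dimension left to broadcast, and the gate has one row per token.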

But wait, does this actually cause a problem? Let me think again...

The code reshapes:
```python
cg = info["layer"]._antipasto_gate.grad.reshape(-1, g.shape[0])  # [b, r] -- wait, this assumes per-sample
```

If the gate is `[G*s, r]`, then `cg` reshapes to `[G*s, r]`. Then `g_b = cg / dS_safe` gives per-token g_b, not per-rollout. The cosine with `v_grad` is `cos_b = (g_b @ vg) / ...` which gives per-token cos. Then `flagged = (cos_b > 0).float()` gives per-token flags.

Then `sub = (cg * flagged.unsqueeze(1)).sum(0) / dS_safe` sums over all flagged tokens, not rollouts.

Is this wrong? Let's think about what `c.grad` actually is per-token...

For token `t` in rollout `b`:
```
c.grad[t] = sum of downstream gradient through that token * a[t] * delta_S
```

Note that `a[t]` is the activation at token `t` (the input to the Linear, projected by Vh). And `g_b[t]` (the per-token contribution to delta_S grad) would be:
```
g_b[t] = sum of downstream gradient through token t * a[t]
```

So `c.grad[t] = g_b[t] * delta_S`, where `g_b[t]` is the per-token delta_S contribution (NOT per-rollout).

Now, the total `delta_S.grad = sum over all tokens t of g_b[t]`.

If we want to flag rollouts, not tokens, we need per-rollout `g_b`, not per-token. The per-rollout `g_b[b,:]` would be `sum over tokens in rollout b of g_b[t,:]`.

But we're computing per-token `g_b` and flagging per-token, then subtracting per-token flagged contributions. This means: if a rollout has some tokens flagged and some not, we subtract only the flagged tokens. This is a different behavior from the spec (which flags whole rollouts).

**However**, from a practical standpoint, this might not matter much: if a rollout is hacky, most of its tokens will point in the same general direction (or at least the completion tokens will). And the cosine is taken with the same `v_grad` across tokens, so the sign should be consistent within a rollout for the "hack-ward" vs "clean-ward" question.

But from a spec-compliance standpoint, the code doesn't match the spec. The spec says per-rollout gate; the code implements per-token gate due to the 2D reshape of `a`.

Actually wait, let me re-read the spec more carefully:

```python
c = ones(b, 1, r, requires_grad=True)  # per-rollout gate (broadcast over tokens)
kept = U @ ((a * c) * delta_S)
```

Where `a` has shape `[b, s, r]`. The spec assumes the forward receives a tensor of shape `[b, s, r]`. But in practice, PyTorch's Linear flattens batch*seq, so `a` is `[b*s, r]`.

The code creates `c` as `[b*s, r]` and then `(a * c)` where both are `[b*s, r]`. This is elementwise multiplication, so each token gets its own gate entry. But the spec intended `c` to be `[b, 1, r]` broadcasting over `s`, so all tokens in a rollout share the same gate.

**This is a spec-implementation mismatch.** However, for the math to work properly (per-rollout gating with broadcast), the forward hook would need to know the per-prompt batch layout (G, L) to reshape `a` to `[G, L, r]`, create `c` as `[G, 1, r]`, multiply as `(a.reshape(G, L, r) * c)`, then flatten back.

In the current implementation, the gate is per-token, and the filter treats each token independently. This changes the meaning of the flag: instead of "does this rollout look hack-ward?", it's "does this token look hack-ward?".

For the subtraction to be mathematically equivalent to per-rollout gating, we'd need:
```
sum_{flagged_tokens_in_rollout_b} g_b[t] = (entire rollout flagged ? sum_all_tokens : 0)
```

This doesn't hold in general, but for the "cos > 0" test, if most tokens in a hacky rollout point hack-ward, the per-token approach is approximately the same as the per-rollout approach.

**Severity**: Medium. The implementation works but is not spec-compliant. The per-token gating is a different design than the spec's per-rollout gating. For the smoke test (tiny model, only 2 prompts), this might not matter. For real runs, it changes the routing granularity from rollout-level to token-level. The comment in the code says "per-rollout gate [b, 1.., r]" but the actual shape is `[b*s, r]` which is per-token.
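
How far the two granularities can diverge is easy to see on synthetic data. The sketch below is a toy under stated assumptions, not a measurement from the repository: per-token grads are unit Gaussian noise plus a small drift along a hypothetical `v_grad`, rollout 0 drifts hack-ward and the other three drift clean-ward, and the drift size is chosen arbitrarily.

```python
import numpy as np

rng = np.random.default_rng(0)
G, s, r = 4, 64, 8                       # toy sizes: rollouts, tokens per rollout, rank
v_grad = np.zeros(r)
v_grad[0] = 1.0                          # hypothetical unit hack-ward direction

drift = np.array([0.75, -0.75, -0.75, -0.75])   # rollout 0 hack-ward, the rest clean-ward
g_tok = rng.normal(size=(G, s, r))       # noisy per-token delta_S grad contributions
g_tok[:, :, 0] += drift[:, None]

def flags(g):
    """0/1 flag per row; the sign of the cosine equals the sign of the dot product."""
    return (g @ v_grad > 0).astype(float)

tok_flags = flags(g_tok.reshape(G * s, r)).reshape(G, s)   # flag each token
roll_flags = flags(g_tok.sum(axis=1))                      # sum tokens first, then flag

# Gradient mass removed from delta_S under each scheme.
sub_tok = (g_tok * tok_flags[:, :, None]).sum(axis=(0, 1))
sub_roll = (g_tok.sum(axis=1) * roll_flags[:, None]).sum(axis=0)
print(tok_flags.mean(axis=1), roll_flags)
```

In this toy, summing over tokens averages out the noise, so only rollout 0 is flagged at rollout level, while at token level a noticeable fraction of every clean rollout's tokens still lands on the flagged side and is subtracted.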

### 3. No forward detach for Arm A

The spec says for Arm A: "no forward detach here" (antipasto.py _delta_hook docstring). Let me verify: in the grad-mask branch, `kept = U @ ((a * c) * delta_S.to(a.dtype))` has no `.detach()` on `kept`. Correct: the routing is post-backward via subtraction, so no forward detach is needed.

### 4. Does the quar get the flagged rollouts' gradient?

Yes. `quar = B_q @ (A_q @ x)` is always computed (not gated), so it's always in the forward path. When backward runs, `A_q` and `B_q` get gradients from all rollouts. The `_route2_grad_filter` only subtracts from `delta_S.grad`, not from `A_q.grad` or `B_q.grad`. So the quar learns from flagged rollouts naturally, and also from unflagged rollouts (for absorption). This matches the spec's intent.

### 5. Multiple linear layers: each gets its own gate

Each wrapped Linear creates its own gate `c` in its forward hook. After backward, each layer's gate has its own `.grad`. `_route2_grad_filter` is called per module, reading `info["layer"]._antipasto_gate.grad`. This is correct: each module's delta_S routing is independent.

### 6. The gate tensor leak and memory

The gate `c` is created as a regular tensor with `requires_grad=True`. It's stored as `layer._antipasto_gate`. After `model.zero_grad(set_to_none=True)`, `c.grad` is freed (set to None). But `c` itself (the forward tensor) is still referenced by `layer._antipasto_gate` and the computation graph nodes. After the optimizer step, the old computation graph is freed. Then on the next prompt, a new `c` is created and overwrites `layer._antipasto_gate`, letting the old `c` be garbage collected.

This is fine, no leak.

### 7. The eps-guard bias

On unreliable axes (`|delta_S| <= 1e-6`), `g_b` is set to 0, and the subtraction keeps the full grad. So these axes are not routed at all. Once `delta_S` grows past eps on an axis, routing kicks in. The spec says: "routing on a fresh axis lags ~1 step until delta_S grows past eps there — the A1-equivalent one-step lag, per-axis."

Is this sound? When delta_S is ~0, the kept adapter's contribution to the output on that axis is ~0, so there's no hack to route yet. This seems fine. The bias is: at init, no axes are routed (delta_S=0 everywhere), so the first few steps use full grad on all axes. Once delta_S grows (some axes get non-zero values), routing activates on those axes. This is conservative but not biased in a harmful way.

**However**, there's a subtlety: if delta_S is small but non-zero (e.g., 5e-7), it fails the `> 1e-6` test and is treated as unreliable. The contribution to the output is `delta_S * a`, which might be non-negligible if `a` is large. But for typical embedding norms, this is likely fine. The threshold is a hyperparameter choice.

### 8. The v_grad computation in train.py

```python
from .extract_vhack_grad import extract_v_hack
_, _, raw_grads, _ = extract_v_hack(
    model, tok, wrappers, MASK_PAIRS,
    top_k=1, tau_axis=0,
)
v_grad = {}
for name in wrappers:
    d = (raw_grads[f"hack/{name}"] - raw_grads[f"clean/{name}"]).mean(0)
    v_grad[name] = (d / d.norm().clamp_min(1e-12)).to(device)
```

Wait, I need to check the indentation. The diff shows this code is inside the `if is_route2:` block, under `else:` (i.e., `route2_mask == "grad"`). But the indentation in the diff might be misleading. Let me look at the actual code...

The diff excerpt shows:
```python
v_grad = {}
for name in wrappers:
    d = (raw_grads[f"hack/{name}"] - raw_grads[f"clean/{name}"]).mean(0)
    v_grad[name] = (d / d.norm().clamp_min(1e-12)).to(device)
```

Wait, actually looking at the diff more carefully:

```python
from .extract_vhack_grad import extract_v_hack
_, _, raw_grads, _ = extract_v_hack(
    model, tok, wrappers, MASK_PAIRS,
    top_k=1, tau_axis=0, device=device,  # <-- this line incomplete in diff?
)
v_grad = {}
for name in wrappers:
    d = (raw_grads[f"hack/{name}"] - raw_grads[f"clean/{name}"]).mean(0)
    v_grad[name] = (d / d.norm().clamp_min(1e-12)).to(device)
```

Hmm, the `extract_v_hack` call might be missing arguments. Let me check the actual file.

Actually, looking at the diff hunk more carefully - it's truncated/line-wrapped. The diff says:

```
from .extract_vhack_grad import extract_v_hack
_, _, raw_grads, _ = extract_v_hack(
    model, tok, wrappers, MASK_PAIRS,
    top_k=1, tau_axis=0,
```

And the line got wrapped. But the important thing is that `raw_grads` is used. Let me check what `extract_v_hack` returns. From the import and usage pattern, `raw_grads` should be a dict with keys like `"hack/module_name"` and `"clean/module_name"`, each containing per-pair gradients.

But wait - `extract_v_hack` expects `model` in eval mode. Earlier in the route2 setup:
```python
model.eval()
if cfg.route2_mask == "act":
    ...
else:
    # Arm A: gradient-space mean-diff.
    from .extract_vhack_grad import extract_v_hack
    _, _, raw_grads, _ = extract_v_hack(...)
    ...
model.train()
```

This looks correct - model is in eval mode during extraction, then back to train.

But I notice: for route2 (not split), `cos_pre_split_every` doesn't matter because `split_this_step = (step % cfg.cos_pre_split_every == 0) and not is_route2`. The route2 grad-mask always uses combined backward. And `_route2_grad_filter` is called inside the combined backward block. This is correct.
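
The mean-difference direction quoted in this section is a short computation. The sketch below is a NumPy stand-in, assuming (as the review guesses from the usage pattern) that each `raw_grads` entry holds one gradient row per persona pair; the helper name `unit_mean_diff` and the toy values are hypothetical.

```python
import numpy as np

def unit_mean_diff(hack, clean, eps=1e-12):
    """hack, clean: [n_pairs, r] gradients. Returns the unit-norm mean difference."""
    d = (hack - clean).mean(axis=0)
    return d / max(np.linalg.norm(d), eps)    # clamp avoids 0/0 when d is all zeros

hack = np.array([[2.0, 0.0], [4.0, 0.0]])
clean = np.array([[0.0, 0.0], [0.0, 0.0]])
v = unit_mean_diff(hack, clean)               # mean difference is [3, 0] before scaling
v_zero = unit_mean_diff(clean, clean)         # degenerate case stays at zero, no NaN
```

The clamp mirrors `clamp_min(1e-12)` in the quoted snippet: it only matters when the hack and clean gradients coincide, in which case the direction is left at zero rather than becoming NaN.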

### 9. Key concern: the gate creation is per-token, not per-rollout

Let me re-examine this more carefully.

The forward hook receives `x` of shape `[b*s, d_in]` where `b*s` is the flattened batch-size-times-sequence-length. Then `a = Vh @ x` has shape `[b*s, r]`. The code does:

```python
c = torch.ones(a.shape[0], *([1] * (a.dim() - 2)), a.shape[-1],
               device=a.device, dtype=a.dtype, requires_grad=True)
```

Since `a.dim() = 2`, `a.shape[0] = b*s`, and `*([1] * 0) = ()`, this simplifies to:
```python
c = torch.ones(a.shape[0], a.shape[-1], ...)  # [b*s, r]
```

Then:
```python
kept = torch.nn.functional.linear((a * c) * delta_S.to(a.dtype), U)
```

Where `(a * c)` is `[b*s, r]` with per-token elementwise multiplication.

After backward:
```python
cg = info["layer"]._antipasto_gate.grad.reshape(-1, g.shape[0])
```

`cg` shape is `[b*s, r]` but `g.shape[0] = r`, so `.reshape(-1, r)` gives `[b*s, r]`.

Then `g_b = cg / dS_safe` is per-token, shape `[b*s, r]`.
Then `cos_b = (g_b @ vg) / ...` is per-token, shape `[b*s]`.
Then `flagged = (cos_b > 0).float()` is per-token, shape `[b*s]`.

Then `sub = (cg * flagged.unsqueeze(1)).sum(0) / dS_safe` sums over all `b*s` tokens.

The final `delta_S.grad -= sub` subtracts the flagged tokens' contributions to the summed delta_S gradient. This is per-token routing, not per-rollout.

**This is the main finding: the implementation is per-token gating, not per-rollout gating as specified.** The spec says:
> c = ones(b, 1, r, requires_grad=True) # per-rollout gate (broadcast over tokens)

But the code creates per-token gates. The comment in antipasto.py even says "per-rollout gate [b, 1.., r]" which is inconsistent with the actual shape.

Now, does this matter practically? For the GRPO setting, the "hack" question is about the whole rollout. If a rollout is hacky, we want the entire rollout's gradient contribution to go to the quarantine. With per-token gating, individual tokens from a hacky rollout might not all be flagged. This could lead to partial routing, where some tokens from a hacky rollout train the kept adapter.

But from the "seed hard, absorb soft" perspective, this partial routing might be fine: it's like a soft version of the hard route. The quarantine still gets signal from hack-ward tokens, and the kept adapter still gets some signal from the non-hack-ward tokens in the same rollout.

However, from a spec-compliance standpoint, this is a drift. The spec explicitly says "per-rollout gate" with broadcast over tokens.

### 10. Summary of findings

Now let me also check Arm B quickly.

Arm B (act-mask):
```python
kept = torch.nn.functional.linear(a * delta_S.to(a.dtype), U)
v_act = layer._antipasto_v_act
cos = (a @ v_act) / (a.norm(dim=-1).clamp_min(1e-6) * v_act.norm().clamp_min(1e-6))
m = cos > 0
kept = torch.where(m.unsqueeze(-1), kept.detach(), kept)
```

This is per-token detach routing. The spec pseudocode shows:
```python
a_i = Vh @ x
cos = cosine(a_i, v_act_in_r_space)
m = (cos > tau).float()
y, quar = routed_forward(x, m)
```

The code uses `tau=0` (not a configurable threshold, just `cos > 0`). The spec mentions `tau` as a threshold but the log says "removed route2_tau (never tuned; mask is cos>0, the natural hack-ward boundary)". So this is a deliberate deviation from the spec, justified by the log.

The Arm B `where(m, kept.detach(), kept)` correctly detaches the kept adapter for flagged tokens, leaving the forward value unchanged. The quarantine always sums. This matches the spec's "detach-route" primitive.
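
The effect of the detach-route on `delta_S` can be written out by hand. The sketch below is a NumPy illustration with manually derived gradients, not the PyTorch hook: the sizes, the random `v_act`, and the toy `upstream` gradient are assumptions. It computes the per-token mask as in the quoted branch, then the `delta_S` grad that would remain once flagged tokens pass no gradient into `kept`.

```python
import numpy as np

rng = np.random.default_rng(0)
n, r, d_out = 6, 4, 3                    # tokens, rank, output dim (toy sizes)
a = rng.normal(size=(n, r))              # stand-in for Vh @ x
U = rng.normal(size=(d_out, r))
delta_S = rng.normal(size=r)
v_act = rng.normal(size=r)               # hypothetical persona direction in r-space

# Per-token mask: sign of the cosine against v_act, with the same norm clamps.
cos = (a @ v_act) / (np.maximum(np.linalg.norm(a, axis=-1), 1e-6)
                     * max(np.linalg.norm(v_act), 1e-6))
m = cos > 0                              # True marks a flagged token

kept = (a * delta_S) @ U.T               # forward value; a detach would not change it
upstream = rng.normal(size=(n, d_out))   # toy dL/d(kept)

# Each token contributes (upstream @ U) * a to the delta_S grad.
per_token = (upstream @ U) * a           # [n, r]
grad_full = per_token.sum(0)             # no routing
grad_routed = per_token[~m].sum(0)       # flagged tokens detached, so they drop out
```

The forward output `kept` is identical in both cases; only the set of tokens allowed to update `delta_S` changes, which is the "forward value unchanged" property the review checks.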

Now, for the no-cheat check (R5): The mask in both arms uses only persona pairs (self-supervised). The code never gates on `gt_pass` or runs the full detector suite over student rollouts. This is confirmed by:
- `v_act` comes from `extract_v_act` which processes persona pairs only
- `v_grad` comes from `extract_v_hack` on persona pairs only
- The mask computation in both arms uses only `a` (activations) / `g_b` (gradients) and the precomputed persona direction
- No `gt_pass` is checked during training

This satisfies R5.

For T4 (leakage metric) and T6 (plot): these are unchecked in the diff (not implemented yet), which is consistent with the task list showing them as incomplete.

Now let me compile the final answer.

Actually, let me also check one more concern: does the `_stash_quar_grads()` correctly handle multiple prompts per step?

```python
def _stash_quar_grads():
    if not is_route2:
        return
    for name, info in wrappers.items():
        for sub in ("A_q", "B_q"):
            p = info[sub]
            if p.grad is None:
                continue
            key = f"{name}.{sub}"
            step_grad_quar[key] = (step_grad_quar[key] + p.grad.detach().clone()
                                   if key in step_grad_quar else p.grad.detach().clone())
```

This accumulates `A_q.grad` and `B_q.grad` across prompts. After the loop, they're re-injected:
```python
for key, g in step_grad_quar.items():
    name, sub = key.rsplit(".", 1)
    wrappers[name][sub].grad = g
```

This is correct: the quarantine params accumulate their grads across all prompts in a step.

Similarly, `delta_S.grad` is split across student/teacher and accumulated in `step_grad_s`/`step_grad_t`, then re-injected. For route2 (no split), only `step_grad_s` is populated, and it's the filtered version.

Let me also double-check: is there any issue with `_route2_grad_filter` being called multiple times for the same module across prompts? Each prompt's backward sets `delta_S.grad` to the per-prompt grad, and `_route2_grad_filter` reads and filters it. The filtered result is accumulated into `step_grad_s`. This is correct.

But wait: `_route2_grad_filter` reads `info["delta_S"].grad` which includes contributions from ALL rollouts in the current prompt (student + teacher, since route2 uses combined backward). The gate `c` is also from the same prompt. So the per-prompt filtering is correct.

One final thing: after `_route2_grad_filter`, the return value `g - sub` is the per-prompt filtered delta_S grad. This is then detached and accumulated. But `info["delta_S"].grad` is NOT modified by `_route2_grad_filter`; it's kept at the unfiltered value. Then `model.zero_grad(set_to_none=True)` sets it to None. This means the unfiltered grad is never used. The filtered version in `step_grad_s` is what gets injected later. Good.

BUT: `sub` is computed using `cg` (gate grad) and `dS`. The gate grad is lost after `model.zero_grad()`. But we already used it in the filter. So this is fine.

OK, I think I have a solid review. Let me write it up.

Here's the review, ordered by priority.

---

## 1. Arm A: gate is per-token, not per-rollout (spec drift)

**`antipasto.py:~105-108`**

The spec pseudocode has `c = ones(b, 1, r)` with `a` shape `[b, s, r]`, producing per-rollout gating with token broadcast. But the forward hook receives flattened `x` of shape `[G·s, d_in]` (PyTorch's standard `nn.Linear` convention), so `a = Vh @ x` is `[G·s, r]`. The code creates:

```python
c = torch.ones(a.shape[0], *([1] * (a.dim() - 2)), a.shape[-1], ...)
```

Since `a.dim() == 2`, this is `[G·s, r]`: **per-token**, not per-rollout. The docstring says "per-rollout gate [b, 1.., r]" but the actual tensor is `[G·s, r]`.

After backward, `cg = c.grad.reshape(-1, r)` is `[G·s, r]`, so `g_b`, `cos_b`, and `flagged` are all per-token. The subtraction `sub = Σ_{flagged tokens} cg / dS` sums over flagged *tokens*, not flagged *rollouts*.

**Impact**: routing granularity is token-level instead of rollout-level. A hacky rollout where some tokens don't point hack-ward gets partially routed (mixed signal to both adapters). For the "seed hard" mechanism this is a soft version of the spec's hard route: not necessarily broken, but not what the spec claims. The spec's absorption argument (item 4) assumed *whole-rollout* routing.

**Fix sketch**: reshape `a` to `[G, s, r]` inside the hook, create `c` as `[G, 1, r]`, multiply, flatten back. This needs `G` and `s`, which the hook doesn't currently receive. They would have to be passed through `layer._antipasto_group` or similar, or recovered by reshaping based on `x.shape[0] / some_stored_L`.
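
Two ways of getting back to rollout granularity can be sketched in NumPy. Both helpers are hypothetical and assume the flattened rows are laid out rollout by rollout with equal (padded) length `s`, so that `reshape(G, s, r)` recovers the rollouts. The first follows the fix sketch above; the second keeps the per-token gate and sums each rollout's token gate-grads afterwards, which is the approach the commit message describes.

```python
import numpy as np

def gated_activations(a_flat, c, G):
    """Apply a per-rollout gate c [G, 1, r] to flattened activations a_flat [G*s, r]."""
    n, r = a_flat.shape
    assert n % G == 0, "rows must split evenly into G rollouts (assumes equal padding)"
    s = n // G
    return (a_flat.reshape(G, s, r) * c).reshape(n, r)

def per_rollout_gate_grad(cg_flat, G):
    """Keep the per-token gate, then sum each rollout's token gate-grads to [G, r]."""
    n, r = cg_flat.shape
    assert n % G == 0, "rows must split evenly into G rollouts (assumes equal padding)"
    return cg_flat.reshape(G, n // G, r).sum(axis=1)

G, s, r = 3, 2, 2
a_flat = np.arange(G * s * r, dtype=float).reshape(G * s, r)
out = gated_activations(a_flat, np.ones((G, 1, r)), G)      # identity at c = 1
cg_roll = per_rollout_gate_grad(np.ones((G * s, r)), G)     # each entry sums s ones
```

The second helper needs no change to the forward hook, since the sum over tokens happens after backward; it does still require the filter to know `G`.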
|
||||
|
||||
## 2. The gate identity and divide-out are correct
|
||||
|
||||
**`train.py:~1138-1152`**
|
||||
|
||||
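`c.grad[b,i] = g_b[b,i] * delta_S[i]` holds because `kept = U @ ((a*c) * delta_S)` and `c = 1` at forward. Writing `z = a*c*dS`, the chain rule gives `∂L/∂c[b,i] = ∂L/∂z[b,i] * a[b,i] * dS[i]`. Sample `b` contributes `g_b[b,i] = ∂L/∂z[b,i] * a[b,i] * c[b,i]` to `delta_S.grad[i]`, which is `∂L/∂z[b,i] * a[b,i]` at `c = 1`, so `c.grad[b,i] = g_b[b,i] * dS[i]`. The divide-out `g_b = cg / dS_safe` on reliable axes recovers `g_b`.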
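The identity can be sanity-checked numerically without autograd. The sketch below uses a toy quadratic loss and central finite differences in plain Python; `sample_loss`, `central_diff`, the projection weights `w` (standing in for everything downstream of the gated product), and the random data are assumptions made for illustration only.

```python
import random


def sample_loss(a_b, c_b, dS, w):
    """Toy per-sample loss 0.5 * (w . (a_b * c_b * dS))**2 (illustrative only)."""
    y = sum(wi * ai * ci * si for wi, ai, ci, si in zip(w, a_b, c_b, dS))
    return 0.5 * y * y


def central_diff(f, x, h=1e-5):
    """Central finite-difference gradient of f() w.r.t. the list x (perturbed in place)."""
    grad = []
    for k in range(len(x)):
        old = x[k]
        x[k] = old + h
        up = f()
        x[k] = old - h
        down = f()
        x[k] = old
        grad.append((up - down) / (2 * h))
    return grad


random.seed(0)
B, r = 3, 4
a = [[random.gauss(0, 1) for _ in range(r)] for _ in range(B)]
w = [random.gauss(0, 1) for _ in range(r)]
dS = [random.uniform(0.5, 1.5) for _ in range(r)]  # kept away from 0 so division is safe
c = [[1.0] * r for _ in range(B)]                   # gate is all ones at forward

max_identity_err = 0.0
max_divide_err = 0.0
for b in range(B):
    f = lambda b=b: sample_loss(a[b], c[b], dS, w)
    c_grad = central_diff(f, c[b])  # gradient w.r.t. this sample's gate row
    g_b = central_diff(f, dS)       # this sample's contribution to delta_S.grad
    for i in range(r):
        max_identity_err = max(max_identity_err, abs(c_grad[i] - g_b[i] * dS[i]))
        max_divide_err = max(max_divide_err, abs(c_grad[i] / dS[i] - g_b[i]))
```

Both error terms should sit at floating-point noise level, since the toy loss is quadratic in each perturbed variable and central differences are then exact up to rounding.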
|
||||
|
||||
The subtraction `g - sub` computes `Σ_all g_b - Σ_flagged g_b = Σ_unflagged g_b` on reliable axes. On unreliable axes (|delta_S| ≤ 1e-6), the full grad passes through — correct, since there's no learned hack on a ~0 axis worth routing.
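A minimal sketch of that subtraction on toy numbers, assuming the flags are already known. `route2_filter` is a hypothetical stand-in for illustration and does not reproduce `_route2_grad_filter`; the `1e-6` threshold is the one quoted above.

```python
EPS = 1e-6  # reliability threshold on |delta_S|


def route2_filter(g, cg, dS, flagged, eps=EPS):
    """Remove flagged samples' contribution from the total delta_S grad.

    g:       total delta_S grad, length r
    cg:      per-sample gate grads, B rows of r floats (cg = g_b * dS)
    flagged: B booleans
    Hypothetical helper for illustration only.
    """
    filtered = list(g)
    for i, s_i in enumerate(dS):
        if abs(s_i) <= eps:
            continue  # unreliable axis: the full grad passes through
        sub = sum(row[i] for row, f in zip(cg, flagged) if f) / s_i
        filtered[i] = g[i] - sub
    return filtered


dS = [0.5, 2.0, 0.0]                      # third axis is unreliable
g_b = [[1.0, -1.0, 0.3],                  # per-sample delta_S grads (toy values)
       [2.0, 0.5, -0.2],
       [-1.0, 4.0, 0.1]]
cg = [[x * s for x, s in zip(row, dS)] for row in g_b]
g = [sum(col) for col in zip(*g_b)]       # unfiltered total grad
flagged = [True, False, True]

filtered = route2_filter(g, cg, dS, flagged)
```

On the two reliable axes `filtered` equals the grad of the single unflagged sample, and on the third axis it equals the unfiltered total.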
|
||||
|
||||
No gradient-flow bug: `_route2_grad_filter` returns a value but never mutates `delta_S.grad` in-place. The filtered value is `.detach().clone()`-d into `step_grad_s`. The unfiltered `delta_S.grad` is freed by `model.zero_grad()`. Clean.
|
||||
|
||||
## 3. Arm B matches spec intent
|
||||
|
||||
**`antipasto.py:~117-123`**
|
||||
|
||||
```python
kept = U @ (a * delta_S)                       # in SVD coords
cos = (a @ v_act) / (a.norm(...) * v_act.norm(...))
m = cos > 0                                    # tau=0, no configurable threshold
kept = where(m.unsqueeze(-1), kept.detach(), kept)
quar = B_q @ (A_q @ x)                         # always summed
```
|
||||
|
||||
Matches spec items 2-4: distinct basis (quarantine LoRA vs kept SVD), always-in-forward (quar summed unconditionally), detach-route (flagged → kept.detach(), forward value unchanged). The tau=0 choice (spec had `cos > tau`) is deliberate per the log: "removed route2_tau, cos>0 is the natural hack-ward boundary." Fine.
|
||||
|
||||
## 4. R5 (NO-CHEAT): clean
|
||||
|
||||
Training never reads `gt_pass` or runs the detector suite. Persona pairs are the only signal:
|
||||
- `v_act` ← `extract_v_act` on persona pairs only (`antipasto.py` docstring; `extract_vhack_grad.py:~260-288`)
|
||||
- `v_grad` ← `extract_v_hack` on persona pairs only (`train.py:~780`)
|
||||
- The mask logic at forward/backward time uses only `a`, `v_act`, `v_grad`, `delta_S`, `c.grad` — all self-supervised.
|
||||
- `split_this_step` is hard-false for route2 (`train.py:~1166`), so the cos_pre_s/t discriminator (which uses v_hack from the precomputed detector suite) never runs on route2.
|
||||
|
||||
No flag leaks.
|
||||
|
||||
## 5. Spec-is-current status
|
||||
|
||||
The spec log entry acknowledges the Arm A design change from A1/A2 to "single-pass gate-SUBTRACTION" and describes it. The spec's **core pseudocode for Arm A** (`c = ones(b, 1, r)...`) still shows per-rollout gating with token broadcast. The code implements per-token gating. This is a spec-implementation mismatch — the spec claims per-rollout but the code is per-token. Which one is "correct" depends on intent; if per-token is deliberate, the spec needs updating. If per-rollout was the intent, item 1 above is a bug.
|
||||
|
||||
Remaining spec items not yet implemented (per diff and task list): T4 (leakage metric + L1 knob), T5 (5-arm sweep), T6 (plot).
|
||||
|
||||
## 6. Misc observations
|
||||
|
||||
- **Gate memory**: `layer._antipasto_gate` is overwritten each forward. Old gate tensors are freed when the computation graph is freed. No leak.
|
||||
- **Quarantine grad stashing** (`train.py:~1118-1128`): correctly clones and accumulates `A_q.grad`/`B_q.grad` across per-prompt `model.zero_grad()` calls. Re-injected before clip + opt.step. Verified bugfix from the log is present.
|
||||
- **Ablate quarantine** (`train.py:~537-547`): now zeroes `B_q` in addition to `delta_S_hack`. Correct for route2 deploy eval.
|
||||
- **v_grad extraction** (`train.py:~778`): `extract_v_hack` is called with `tau_axis=0` (keep all singular vectors), then `mean(g_hack - g_clean)` per module. This produces a unit vector in delta_S gradient space. The sign convention matches the spec.
|
||||
@@ -268,6 +268,15 @@ eps-guarded division. Arm B remains the cleaner bet.
|
||||
from extract_v_hack raw grads (gradient-space analogue of v_act). smoke
|
||||
routing2_grad: ||B_q||=0.109 after 30 steps (quarantine seeded by routed grad),
|
||||
deploy eval + asserts pass, exit 0.
|
||||
- 2026-05-31: external code review (deepseek-v4-pro, docs/spec/20260531_route2_code_review_v2.md)
|
||||
verified gate identity (c.grad=delta_S*g_b), divide-out, eps-guard, Arm B detach-route,
|
||||
and R5 NO-CHEAT (mask never reads gt_pass / detector suite) all CORRECT. One finding:
|
||||
Arm A flagged per-TOKEN, not per-rollout — the hook's gate is [G*s, r] (nn.Linear
|
||||
flattens the batch), so cos(g_b, v_grad) and the flag were per-token. A clean rollout
|
||||
scatters ~50% of its tokens over cos>0 by noise, spuriously routing half its gradient
|
||||
mass. FIXED: _route2_grad_filter now reshapes c.grad to [G, s, r], sums each rollout's
|
||||
tokens BEFORE the cosine (denoises the sign), flags per-rollout [G], matching the
|
||||
preregistered unit (GRPO advantage is per-rollout). Re-smoked.
|
||||
|
||||
## TODO (out of scope now)
|
||||
|
||||
|
||||