The hook gate is necessarily per-token ([G*s, r], nn.Linear flattens the batch). _route2_grad_filter now sums each rollout's token gate-grads before the cos(g_b, v_grad) flag, so routing is per-rollout (the preregistered GRPO unit) and the sign is denoised. Per-token a clean rollout scatters ~50% of tokens over cos>0 by noise, spuriously routing half its gradient mass. Verified by deepseek-v4-pro review: gate identity, divide-out, eps-guard, Arm B detach-route, R5 no-cheat all correct; this was the one finding.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
The user wants me to review the route2 implementation against its spec, focusing on specific areas. Let me read the spec file and the relevant source files to understand the implementation before I analyze the diff.
Let me start by reading the spec, then the source files for context.
[Tool Call: ] {"path": "docs/spec/20260531_routing_v2_distinct_basis.md"}
Let me now read the relevant source files to check the implementation against the spec.
[Tool Call: ] {"path": "/workspace/projected_grpo/src/projected_grpo/antipasto.py", "offset": 65, "limit": 80}
[Tool Call: ] {"path": "/workspace/projected_grpo/src/projected_grpo/train.py", "offset": 1100, "limit": 120}
Let me carefully review the route2 implementation against the spec. I need to focus on:
- Arm A (grad-mask) - the gate identity, divide-out, eps-guard, gradient flow
- Arm B (act-mask) - detach-route correctness
- NO-CHEAT invariant
- Spec vs code drift
- Math/gradient errors
Let me read the rest of the train.py to see the full flow for the grad-mask.
[Tool Call: ] {"offset": 1220, "path": "/workspace/projected_grpo/src/projected_grpo/train.py", "limit": 100}
Let me now look at the critical sections of the code: the grad filter in train.py and how it's called, plus the v_grad construction.
[Tool Call: ] {"offset": 1340, "filePath": "/workspace/projected_grpo/src/projected_grpo/train.py", "limit": 130}
Let me read the correct section of train.py.
[Tool Call: ] {"limit": "130", "offset": 1340, "path": "/workspace/projected_grpo/src/projected_grpo/train.py"}
Let me now look at the remaining portion of train.py, where the quarantine grads are re-injected and opt.step happens, to check for any issues.
[Tool Call: ] {"offset": 1470, "path": "/workspace/projected_grpo/src/projected_grpo/train.py", "limit": 60}
Let me now carefully analyze the Arm A (grad-mask, single-pass gate-subtraction) implementation against the spec. I need to trace through the math carefully.
Arm A: The Gate Identity
The spec says:
c = ones(b, 1, r, requires_grad=True) # per-rollout gate
kept = U @ ((a * c) * delta_S) # identity at c=1: forward value unchanged
After backward, the spec claims: c.grad = delta_S * g_b where g_b is the per-rollout delta_S grad.
Let me verify this from first principles.
The contribution to the output from the kept adapter is:
y_kept = U @ ((a * c) * delta_S)
where:
- a is shape [b, s, r] (b = batch, s = seq_len, r = rank)
- c is shape [b, 1, r] (broadcast over s)
- delta_S is shape [r] (broadcast over b, s)
So entry [b,s,i] of (a * c) * delta_S is a[b,s,i] * c[b,1,i] * delta_S[i].
Now, the total loss L depends on y_kept through some downstream computation.
∂L/∂c[b,1,i] = sum over s of ∂L/∂(a*c*delta_S)[b,s,i] * ∂(a*c*delta_S)[b,s,i]/∂c[b,1,i]
= sum over s of ∂L/∂(a*c*delta_S)[b,s,i] * a[b,s,i] * delta_S[i]
Now, what is ∂L/∂delta_S[i] in the standard (non-gate) case? With kept = U @ (a * delta_S):
∂L/∂delta_S[i] = sum over b,s of ∂L/∂(a*delta_S)[b,s,i] * a[b,s,i]
So in the gate case, the per-rollout contribution to delta_S gradient (call it g_b[b,i]) would be:
g_b[b,i] = sum over s of ∂L/∂(a*c*delta_S)[b,s,i] * a[b,s,i] * c[b,1,i]
Since c = 1 (initialized), c[b,1,i] = 1, so:
g_b[b,i] = sum_s of ∂L/∂(a*c*delta_S)[b,s,i] * a[b,s,i]
And c.grad[b,1,i] = sum_s of ∂L/∂(a*c*delta_S)[b,s,i] * a[b,s,i] * delta_S[i]
= g_b[b,i] * delta_S[i]
So indeed: c.grad[b,i] = g_b[b,i] * delta_S[i].
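The identity can be checked numerically with autograd. This is a minimal sketch under assumed toy shapes: a, U, delta_S, and the tanh loss are stand-ins rather than the project's tensors, and the gate uses the spec's per-rollout [b, 1, r] layout. The reference g_b comes from differentiating each rollout's share of the loss separately.

```python
import torch

torch.manual_seed(0)
b, s, r, d_out = 3, 5, 4, 6  # rollouts, tokens per rollout, rank, output dim (toy sizes)
dt = torch.float64           # double precision so allclose is tight

a = torch.randn(b, s, r, dtype=dt)                       # stands in for Vh @ x
U = torch.randn(d_out, r, dtype=dt)                      # fixed output basis
delta_S = torch.randn(r, dtype=dt, requires_grad=True)   # trainable singular-value shift
w = torch.randn(b, s, d_out, dtype=dt)                   # arbitrary downstream weights

# Gated forward: the spec's per-rollout gate, broadcast over tokens.
c = torch.ones(b, 1, r, dtype=dt, requires_grad=True)
kept = ((a * c) * delta_S) @ U.T                         # [b, s, d_out]
loss = (w * kept.tanh()).sum()                           # nonlinear, separable per rollout
loss.backward()

# Reference: differentiate each rollout's share of the loss w.r.t. delta_S on its own.
g_b = torch.stack([
    torch.autograd.grad((w[i] * ((a[i] * delta_S) @ U.T).tanh()).sum(), delta_S)[0]
    for i in range(b)
])                                                       # [b, r]

ungated = (a * delta_S) @ U.T                            # forward value without the gate
```

At c = 1 the gated forward matches the ungated one. c.grad equals g_b * delta_S row by row, and delta_S.grad is still the sum of g_b over rollouts.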
The code does:
cg = info["layer"]._antipasto_gate.grad.reshape(-1, g.shape[0]) # [b, r]
dS = info["delta_S"].detach() # [r]
reliable = dS.abs() > GATE_EPS # [r]
dS_safe = torch.where(reliable, dS, torch.ones_like(dS))
g_b = torch.where(reliable, cg / dS_safe, torch.zeros_like(cg)) # [b, r]
This is: on reliable axes, g_b[b,i] = c.grad[b,i] / delta_S[i]. On unreliable axes, g_b[b,i] = 0.
The subtraction
Then:
sub = torch.where(reliable, (cg * flagged.unsqueeze(1)).sum(0) / dS_safe,
torch.zeros_like(g))
return g - sub
Where:
- flagged is shape [b] (0/1)
- cg is c.grad, shape [b, r]
- (cg * flagged.unsqueeze(1)).sum(0) sums over the batch, giving [r]
- Dividing by dS_safe gives sum over flagged b of (cg[b,i] / delta_S[i]) for reliable axes
But wait: cg = c.grad, and c.grad[b,i] / delta_S[i] = g_b[b,i] (for reliable axes). So sub[i] = sum over flagged b of g_b[b,i] for reliable axes, and 0 for unreliable.
Then g - sub = sum over all b of g_b[b,i] - sum over flagged b of g_b[b,i] = sum over unflagged b of g_b[b,i] for reliable axes, and g (unchanged) for unreliable axes.
Wait, but g here is delta_S.grad after backward, which should be sum over all b of g_b[b,i]. But in the gate-based forward, delta_S.grad after backward is NOT sum_b g_b[b,i] because the gate c is requires_grad=True, not delta_S directly. Let me re-examine.
In the gate-based forward:
kept = U @ ((a * c) * delta_S)
Both c and delta_S are trainable parameters that require grad. The loss backward passes gradient to both. delta_S.grad in this case would be:
∂L/∂delta_S[i] = sum over b,s of ∂L/∂(a*c*delta_S)[b,s,i] * a[b,s,i] * c[b,1,i]
Since c=1, this equals sum over b,s of ∂L/∂(a*c*delta_S)[b,s,i] * a[b,s,i], which is exactly sum over b of g_b[b,i]. So delta_S.grad = sum_b g_b[b,i] still holds.
And c.grad[b,1,i] = sum over s of ∂L/∂(a*c*delta_S)[b,s,i] * a[b,s,i] * delta_S[i] = g_b[b,i] * delta_S[i].
Good, so the gate identity c.grad = delta_S * g_b holds.
Then cg / dS_safe = g_b on reliable axes, and:
- sub = sum over flagged b of g_b[b,:] (on reliable axes)
- return g - sub = delta_S.grad - sum_{flagged b} g_b[b,:]
- But delta_S.grad = sum over all b of g_b[b,:]
- So return = sum over unflagged b of g_b[b,:] on reliable axes
This is correct: it removes the flagged rollouts' contribution from delta_S.grad.
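A standalone sketch of the subtraction mirrors the code quoted above and makes the bookkeeping easy to test. The function name route2_grad_filter_sketch, its argument layout, and the cos_b denominator (elided as "..." in the excerpt) are assumptions. The toy numbers put one axis below the eps guard so the pass-through branch is exercised too.

```python
import torch

GATE_EPS = 1e-6  # mirrors the |delta_S| > eps reliability guard quoted above


def route2_grad_filter_sketch(g, cg, dS, vg, eps=GATE_EPS):
    """Remove flagged rollouts' share of delta_S.grad (hypothetical helper).

    g:  [r]    delta_S.grad summed over rollouts
    cg: [b, r] gate grad, assumed to satisfy cg = delta_S * g_b
    dS: [r]    current delta_S (detached)
    vg: [r]    hack direction in delta_S-grad space
    """
    reliable = dS.abs() > eps
    dS_safe = torch.where(reliable, dS, torch.ones_like(dS))
    g_b = torch.where(reliable, cg / dS_safe, torch.zeros_like(cg))          # [b, r]
    cos_b = (g_b @ vg) / (g_b.norm(dim=1) * vg.norm()).clamp_min(1e-12)      # [b]
    flagged = (cos_b > 0).to(cg.dtype)                                       # [b], 0/1
    sub = torch.where(reliable, (cg * flagged.unsqueeze(1)).sum(0) / dS_safe,
                      torch.zeros_like(g))
    return g - sub, flagged


dt = torch.float64
g_b_true = torch.tensor([[1.0, 2.0, 3.0],
                         [-1.0, 0.5, 1.0],
                         [2.0, -1.0, -2.0],
                         [-0.5, -0.5, 4.0]], dtype=dt)
dS = torch.tensor([0.5, -2.0, 0.0], dtype=dt)   # axis 2 is below eps
cg = g_b_true * dS                              # what c.grad would hold
g = g_b_true.sum(0)                             # what delta_S.grad would hold
vg = torch.tensor([1.0, 0.0, 0.0], dtype=dt)

out, flagged = route2_grad_filter_sketch(g, cg, dS, vg)
```

With vg along axis 0, rollouts 0 and 2 are flagged. The reliable axes then hold the unflagged sum [-1.5, 0.0], and the sub-eps axis keeps the full grad 6.0.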
Issues and concerns
1. Multiple prompts per step with gate
The gate c is created in the forward hook with shape [a.shape[0], 1, ..., r] where a.shape[0] is the batch size of the current prompt. But train.py processes multiple prompts per step, with separate forwards/backwards. Each prompt creates its own gate.
Then _route2_grad_filter is called after each combined backward (which is per-prompt). It reads info["layer"]._antipasto_gate.grad, which is from the most recent backward. This is correct per-prompt.
But wait: between prompts, model.zero_grad(set_to_none=True) is called, and _route2_grad_filter accesses _antipasto_gate.grad. Could the gate's .grad be cleared before the filter reads it?
Let me look at the flow more carefully:
for p_idx in range(prompts_per_step):
    ...
    loss.backward()
    for name, info in wrappers.items():
        g = info["delta_S"].grad
        if g is None:
            continue
        if is_route2_grad:
            g = _route2_grad_filter(info)
        step_grad_s[name] = ... + g.detach().clone()
    _stash_quar_grads()
    model.zero_grad(set_to_none=True)
So _route2_grad_filter is called AFTER loss.backward() but BEFORE model.zero_grad(). At this point, _antipasto_gate.grad still exists (it was populated by the backward). Then model.zero_grad(set_to_none=True) runs.
Module.zero_grad only resets the grads of registered parameters. If the gate is stored as a plain tensor attribute (layer._antipasto_gate) rather than registered as an nn.Parameter, zero_grad does not touch it at all, and its .grad simply lingers. Either way the result is the same: the gate was created in the forward hook for this prompt only, and on the next prompt's forward a NEW gate c (a fresh tensor with requires_grad=True) overwrites layer._antipasto_gate. The old gate (from the previous prompt) is no longer referenced anywhere.
So this is fine — each prompt gets its own gate, the filter reads it after backward but before zero_grad.
But there's a subtle issue: _route2_grad_filter also modifies delta_S.grad (via g = _route2_grad_filter(info)). Then the code does:
step_grad_s[name] = (step_grad_s[name] + g.detach().clone() ...)
And g here is the RETURN VALUE of _route2_grad_filter, which is g - sub. But the function also MODIFIES info["delta_S"].grad — wait, no, it doesn't! Let me re-read:
def _route2_grad_filter(info) -> torch.Tensor:
    g = info["delta_S"].grad  # [r] summed over rollouts
    ...
    return g - sub
Then:
g = _route2_grad_filter(info)
step_grad_s[name] = (step_grad_s[name] + g.detach().clone() ...)
g is now delta_S.grad - sub. But info["delta_S"].grad is still the original, unmodified delta_S.grad! The function doesn't mutate it. Then the code uses g (the filtered version) for accumulation.
Then model.zero_grad(set_to_none=True) zeroes everything, including delta_S.grad. On the next prompt, delta_S.grad gets a new backward contribution.
After the loop, the injection block does:
for name, info in wrappers.items():
    gs = step_grad_s.get(name)
    ...
    info["delta_S"].grad = gs  # or gs + gt
So info["delta_S"].grad ends up as the accumulated filtered grads. This looks correct.
2. The gate creation issue with batch dim
c = torch.ones(a.shape[0], *([1] * (a.dim() - 2)), a.shape[-1], ...)
For a batch with one prompt, a is [1, s, r] after tokenization. So a.shape[0] = 1, and the gate is [1, 1, r]. But there are G group rollouts in the merged tensor: merged is [G, L]. The forward hook is called with x coming from the merged tensor through the linear layers...
Wait, let me think about this more carefully. The forward hook _delta_hook is registered on every target Linear. When model(merged) is called, the input to each Linear has shape [G*s, d_in] (flattened batch+seq). So a = Vh @ x has shape [G*s, r]. Then:
c = torch.ones(a.shape[0], *([1] * (a.dim() - 2)), a.shape[-1], ...)
Since a.dim() = 2, a.shape[0] = G*s, a.shape[-1] = r. So c is [G*s, r]. That's NOT per-rollout with broadcast over tokens — it's per-token.
But the spec says c = ones(b, 1, r) where b is the batch dimension (number of rollouts). The code creates c as [G*s, r], which is per-token, not per-rollout.
This is a BUG. The gate should be per-rollout, not per-token. With [G*s, r] gate, each token gets its own gate parameter, and c.grad after backward is [G*s, r], not [G, r] or [G, 1, r].
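The shape claim follows directly from the gate expression. A tiny sketch (make_gate is a hypothetical wrapper around the quoted torch.ones call) shows that the same expression yields the spec's [b, 1, r] only when a arrives as [b, s, r]:

```python
import torch


def make_gate(a):
    # Hypothetical wrapper around the gate expression quoted above: leading dim kept,
    # singleton dims for everything in between, rank on the last axis.
    return torch.ones(a.shape[0], *([1] * (a.dim() - 2)), a.shape[-1], requires_grad=True)


G, s, r = 4, 7, 3
gate_spec = make_gate(torch.zeros(G, s, r))   # a laid out as [b, s, r]
gate_flat = make_gate(torch.zeros(G * s, r))  # a flattened to [G*s, r]
```

gate_spec has shape (4, 1, 3), one gate row per rollout. gate_flat has shape (28, 3), one gate row per token.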
But wait, does this actually cause a problem? Let me think again...
The code reshapes:
cg = info["layer"]._antipasto_gate.grad.reshape(-1, g.shape[0]) # [b, r] -- wait, this assumes per-sample
If the gate is [G*s, r], then cg reshapes to [G*s, r]. Then g_b = cg / dS_safe gives per-token g_b, not per-rollout. The cosine with v_grad is cos_b = (g_b @ vg) / ... which gives per-token cos. Then flagged = (cos_b > 0).float() gives per-token flags.
Then sub = (cg * flagged.unsqueeze(1)).sum(0) / dS_safe sums over all flagged tokens, not rollouts.
Is this wrong? Let's think about what c.grad actually is per-token...
For token t in rollout b:
c.grad[t] = sum of downstream gradient through that token * a[t] * delta_S
Note that a[t] is the activation at token t (the input to the Linear, projected by Vh). And g_b[t] (the per-token contribution to delta_S grad) would be:
g_b[t] = sum of downstream gradient through token t * a[t]
So c.grad[t] = g_b[t] * delta_S, where g_b[t] is the per-token delta_S contribution (NOT per-rollout).
Now, the total delta_S.grad = sum over all tokens t of g_b[t].
If we want to flag rollouts, not tokens, we need per-rollout g_b, not per-token. The per-rollout g_b[b,:] would be sum over tokens in rollout b of g_b[t,:].
But we're computing per-token g_b and flagging per-token, then subtracting per-token flagged contributions. This means: if a rollout has some tokens flagged and some not, we subtract only the flagged tokens. This is a different behavior from the spec (which flags whole rollouts).
However, from a practical standpoint, this might not matter much: if a rollout is hacky, most of its tokens will point in the same general direction (or at least the completion tokens will). And the cosine is taken with the same v_grad across tokens, so the sign should be consistent within a rollout for the "hack-ward" vs "clean-ward" question.
But from a spec-compliance standpoint, the code doesn't match the spec. The spec says per-rollout gate; the code implements per-token gate due to the 2D reshape of a.
Actually wait, let me re-read the spec more carefully:
c = ones(b, 1, r, requires_grad=True) # per-rollout gate (broadcast over tokens)
kept = U @ ((a * c) * delta_S)
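Where a has shape [b, s, r]. The spec assumes the forward receives a tensor of shape [b, s, r]. But in practice the input reaching the hook is flattened to [b*s, d_in], so a is [b*s, r]. (nn.Linear itself accepts any number of leading dims, so the flattening comes from the caller or the hook, not from nn.Linear.)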
The code creates c as [b*s, r] and then (a * c) where both are [b*s, r]. This is elementwise multiplication, so each token gets its own gate entry. But the spec intended c to be [b, 1, r] broadcasting over s, so all tokens in a rollout share the same gate.
This is a spec-implementation mismatch. However, for the math to work properly (per-rollout gating with broadcast), the forward hook would need to know the per-prompt batch layout (G, L) to reshape a to [G, L, r], create c as [G, 1, r], multiply as (a.reshape(G, L, r) * c), then flatten back.
In the current implementation, the gate is per-token, and the filter treats each token independently. This changes the meaning of the flag: instead of "does this rollout look hack-ward?", it's "does this token look hack-ward?".
For the subtraction to be mathematically equivalent to per-rollout gating, we'd need:
sum_{flagged_tokens_in_rollout_b} g_b[t] = (entire rollout flagged ? sum_all_tokens : 0)
This doesn't hold in general, but for the "cos > 0" test, if most tokens in a hacky rollout point hack-ward, the per-token approach is approximately the same as the per-rollout approach.
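The commit message at the top points at the failure case of this approximation: per token, a clean rollout scatters roughly half its tokens over cos > 0 by noise. A toy simulation illustrates the effect. The isotropic Gaussian token grads, the weak clean-ward mean mu, and all sizes are assumptions chosen only to make it visible.

```python
import torch

torch.manual_seed(0)
r, s = 8, 4000                    # rank and token count (toy sizes)
vg = torch.zeros(r)
vg[0] = 1.0                       # unit hack direction in delta_S-grad space

mu, sigma = 0.1, 1.0              # weak clean-ward mean, strong per-token noise (assumed)
tok_g = -mu * vg + sigma * torch.randn(s, r)   # per-token delta_S grads of ONE clean rollout


def cos_to(x, v):
    """Cosine of each row of x (or of a single vector) with v."""
    return (x @ v) / (x.norm(dim=-1) * v.norm()).clamp_min(1e-12)


per_token_flag = cos_to(tok_g, vg) > 0           # [s]: one decision per token
per_rollout_flag = cos_to(tok_g.sum(0), vg) > 0  # scalar: one decision for the rollout

frac_tokens = per_token_flag.float().mean().item()   # share of tokens spuriously routed
```

Here frac_tokens lands close to 0.5, while the summed rollout is not flagged. That is the denoising argument for summing token grads before the sign test.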
Severity: Medium. The implementation works but is not spec-compliant. The per-token gating is a different design than the spec's per-rollout gating. For the smoke test (tiny model, only 2 prompts), this might not matter. For real runs, it changes the routing granularity from rollout-level to token-level. The comment in the code says "per-rollout gate [b, 1.., r]" but the actual shape is [b*s, r] which is per-token.
3. No forward detach for Arm A
The spec says for Arm A: "no forward detach here" (antipasto.py _delta_hook docstring). Let me verify: in the grad-mask branch, kept = U @ ((a * c) * delta_S.to(a.dtype)) — there's no .detach() on kept. Correct: the routing is post-backward via subtraction, so no forward detach is needed.
4. Does the quar get the flagged rollouts' gradient?
Yes. quar = B_q @ (A_q @ x) is always computed (not gated), so it's always in the forward path. When backward runs, A_q and B_q get gradients from all rollouts. The _route2_grad_filter only subtracts from delta_S.grad, not from A_q.grad or B_q.grad. So the quar learns from flagged rollouts naturally, and also from unflagged rollouts (for absorption). This matches the spec's intent.
5. Multiple linear layers: each gets its own gate
Each wrapped Linear creates its own gate c in its forward hook. After backward, each layer's gate has its own .grad. _route2_grad_filter is called per module, reading info["layer"]._antipasto_gate.grad. This is correct — each module's delta_S routing is independent.
6. The gate tensor leak and memory
The gate c is created as a regular tensor with requires_grad=True and stored as layer._antipasto_gate. Assuming it is a plain attribute rather than a registered parameter, model.zero_grad(set_to_none=True) leaves c.grad alone. c and its .grad stay referenced by layer._antipasto_gate, and c is also held by the computation graph until the backward frees it. On the next prompt, a new c is created and overwrites layer._antipasto_gate, letting the old c and its .grad be garbage collected.
This is fine, no leak.
7. The eps-guard bias
On unreliable axes (|delta_S| <= 1e-6), g_b is set to 0, and the subtraction keeps the full grad. So these axes are not routed at all. Once delta_S grows past eps on an axis, routing kicks in. The spec says: "routing on a fresh axis lags ~1 step until delta_S grows past eps there — the A1-equivalent one-step lag, per-axis."
Is this sound? When delta_S is ~0, the kept adapter's contribution to the output on that axis is ~0, so there's no hack to route yet. This seems fine. The bias is: at init, no axes are routed (delta_S=0 everywhere), so the first few steps use full grad on all axes. Once delta_S grows (some axes get non-zero values), routing activates on those axes. This is conservative but not biased in a harmful way.
However, there's a subtlety: if delta_S is small but non-zero (e.g., 5e-7), it fails the > 1e-6 test and is treated as unreliable. The contribution to the output is delta_S * a, which might be non-negligible if a is large. But for typical embedding norms, this is likely fine. The threshold is a hyperparameter choice.
8. The v_grad computation in train.py
from .extract_vhack_grad import extract_v_hack
_, _, raw_grads, _ = extract_v_hack(
    model, tok, wrappers, MASK_PAIRS,
    top_k=1, tau_axis=0,
)
v_grad = {}
for name in wrappers:
    d = (raw_grads[f"hack/{name}"] - raw_grads[f"clean/{name}"]).mean(0)
    v_grad[name] = (d / d.norm().clamp_min(1e-12)).to(device)
Wait, I need to check the indentation. The diff shows this code is inside the if is_route2: block, under else: (i.e., route2_mask == "grad"). But the indentation in the diff might be misleading. Let me look at the actual code...
The diff excerpt shows:
v_grad = {}
for name in wrappers:
    d = (raw_grads[f"hack/{name}"] - raw_grads[f"clean/{name}"]).mean(0)
    v_grad[name] = (d / d.norm().clamp_min(1e-12)).to(device)
Wait, actually looking at the diff more carefully:
from .extract_vhack_grad import extract_v_hack
_, _, raw_grads, _ = extract_v_hack(
    model, tok, wrappers, MASK_PAIRS,
    top_k=1, tau_axis=0, device=device,  # <-- this line incomplete in diff?
)
v_grad = {}
for name in wrappers:
    d = (raw_grads[f"hack/{name}"] - raw_grads[f"clean/{name}"]).mean(0)
    v_grad[name] = (d / d.norm().clamp_min(1e-12)).to(device)
Hmm, the extract_v_hack call might be missing arguments. Let me check the actual file.
Actually, looking at the diff hunk more carefully - it's truncated/line-wrapped. The diff says:
from .extract_vhack_grad import extract_v_hack
_, _, raw_grads, _ = extract_v_hack(
    model, tok, wrappers, MASK_PAIRS,
    top_k=1, tau_axis=0,
And the line got wrapped. But the important thing is that raw_grads is used. Let me check what extract_v_hack returns. From the import and usage pattern, raw_grads should be a dict with keys like "hack/module_name" and "clean/module_name", each containing per-pair gradients.
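Assuming raw_grads has that layout (keys "hack/<name>" and "clean/<name>", each a [n_pairs, r] stack of per-pair delta_S grads), the v_grad loop reduces to a normalized mean difference. build_v_grad and the module names mod_a / mod_b are hypothetical:

```python
import torch


def build_v_grad(raw_grads, names):
    """Unit mean-diff direction per module (hypothetical helper).

    raw_grads maps "hack/<name>" and "clean/<name>" to [n_pairs, r] per-pair grads.
    """
    v_grad = {}
    for name in names:
        d = (raw_grads[f"hack/{name}"] - raw_grads[f"clean/{name}"]).mean(0)  # [r]
        v_grad[name] = d / d.norm().clamp_min(1e-12)                          # unit length
    return v_grad


raw = {
    "hack/mod_a": torch.tensor([[1.0, 0.0], [3.0, 0.0]]),
    "clean/mod_a": torch.zeros(2, 2),
    "hack/mod_b": torch.tensor([[3.0, 4.0]]),
    "clean/mod_b": torch.zeros(1, 2),
}
v = build_v_grad(raw, ["mod_a", "mod_b"])
```

The clamp_min(1e-12) only matters when the hack and clean grads coincide. In that case v_grad is a zero vector rather than NaN.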
But wait - extract_v_hack expects model in eval mode. Earlier in the route2 setup:
model.eval()
if cfg.route2_mask == "act":
    ...
else:
    # Arm A: gradient-space mean-diff.
    from .extract_vhack_grad import extract_v_hack
    _, _, raw_grads, _ = extract_v_hack(...)
    ...
model.train()
This looks correct - model is in eval mode during extraction, then back to train.
But I notice: for route2 (not split), cos_pre_split_every doesn't matter because split_this_step = (step % cfg.cos_pre_split_every == 0) and not is_route2. The route2 grad-mask always uses combined backward. And _route2_grad_filter is called inside the combined backward block. This is correct.
9. Key concern: the gate creation is per-token, not per-rollout
Let me re-examine this more carefully.
The forward hook receives x of shape [b*s, d_in] where b*s is the flattened batch-size-times-sequence-length. Then a = Vh @ x has shape [b*s, r]. The code does:
c = torch.ones(a.shape[0], *([1] * (a.dim() - 2)), a.shape[-1],
device=a.device, dtype=a.dtype, requires_grad=True)
Since a.dim() = 2, a.shape[0] = b*s, and *([1] * 0) = (), this simplifies to:
c = torch.ones(a.shape[0], a.shape[-1], ...) # [b*s, r]
Then:
kept = torch.nn.functional.linear((a * c) * delta_S.to(a.dtype), U)
Where (a * c) is [b*s, r] with per-token elementwise multiplication.
After backward:
cg = info["layer"]._antipasto_gate.grad.reshape(-1, g.shape[0])
cg shape is [b*s, r], and g.shape[0] = r, so .reshape(-1, r) leaves it as [b*s, r].
Then g_b = cg / dS_safe is per-token, shape [b*s, r].
Then cos_b = (g_b @ vg) / ... is per-token, shape [b*s].
Then flagged = (cos_b > 0).float() is per-token, shape [b*s].
Then sub = (cg * flagged.unsqueeze(1)).sum(0) / dS_safe sums over all b*s tokens.
The returned g - sub removes the flagged tokens' contributions from the summed delta_S gradient. This is per-token routing, not per-rollout.
This is the main finding: the implementation is per-token gating, not per-rollout gating as specified. The spec says:
c = ones(b, 1, r, requires_grad=True) # per-rollout gate (broadcast over tokens)
But the code creates per-token gates. The comment in antipasto.py even says "per-rollout gate [b, 1.., r]" which is inconsistent with the actual shape.
Now, does this matter practically? For the GRPO setting, the "hack" question is about the whole rollout. If a rollout is hacky, we want the entire rollout's gradient contribution to go to the quarantine. With per-token gating, individual tokens from a hacky rollout might not all be flagged. This could lead to partial routing, where some tokens from a hacky rollout train the kept adapter.
But from the "seed hard, absorb soft" perspective, this partial routing might be fine — it's like a soft version of the hard route. The quarantine still gets signal from hack-ward tokens, and the kept adapter still gets some signal from the non-hack-ward tokens in the same rollout.
However, from a spec-compliance standpoint, this is a drift. The spec explicitly says "per-rollout gate" with broadcast over tokens.
10. Summary of findings
Now let me also check Arm B quickly.
Arm B (act-mask):
kept = torch.nn.functional.linear(a * delta_S.to(a.dtype), U)
v_act = layer._antipasto_v_act
cos = (a @ v_act) / (a.norm(dim=-1).clamp_min(1e-6) * v_act.norm().clamp_min(1e-6))
m = cos > 0
kept = torch.where(m.unsqueeze(-1), kept.detach(), kept)
This is per-token detach routing. The spec pseudocode shows:
a_i = Vh @ x
cos = cosine(a_i, v_act_in_r_space)
m = (cos > tau).float()
y, quar = routed_forward(x, m)
The code uses tau=0 (not a configurable threshold, just cos > 0). The spec mentions tau as a threshold but the log says "removed route2_tau (never tuned; mask is cos>0, the natural hack-ward boundary)". So this is a deliberate deviation from the spec, justified by the log.
The Arm B where(m, kept.detach(), kept) correctly detaches the kept adapter for flagged tokens, leaving the forward value unchanged. The quarantine always sums. This matches the spec's "detach-route" primitive.
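The detach-route claim can be checked on a toy layer. With torch.where(m, kept.detach(), kept), the forward value should be unchanged, delta_S should see only the unflagged tokens, and the quarantine should still learn from all of them. Shapes, v_act, and the loss weights below are assumed stand-ins. The second half of the tokens is the negation of the first half, so exactly half are flagged.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
dt = torch.float64
d_in, r, d_out, rq = 5, 3, 4, 2          # input dim, kept rank, output dim, quarantine rank (toy)
x_half = torch.randn(3, d_in, dtype=dt)
x = torch.cat([x_half, -x_half])         # 6 tokens; second half mirrors the first
n = x.shape[0]

Vh = torch.randn(r, d_in, dtype=dt)
U = torch.randn(d_out, r, dtype=dt)
v_act = torch.randn(r, dtype=dt)         # stand-in persona direction in r-space
delta_S = torch.randn(r, dtype=dt, requires_grad=True)
A_q = torch.randn(rq, d_in, dtype=dt, requires_grad=True)
B_q = torch.randn(d_out, rq, dtype=dt, requires_grad=True)
w = torch.randn(n, d_out, dtype=dt)      # downstream loss weights


def forward(route):
    a = x @ Vh.T                                       # [n, r]
    kept = F.linear(a * delta_S, U)                    # kept adapter in SVD coords
    quar = F.linear(F.linear(x, A_q), B_q)             # quarantine LoRA, always summed
    m = torch.zeros(n, dtype=torch.bool)
    if route:
        cos = (a @ v_act) / (a.norm(dim=-1).clamp_min(1e-6) * v_act.norm().clamp_min(1e-6))
        m = cos > 0
        kept = torch.where(m.unsqueeze(-1), kept.detach(), kept)  # detach-route
    return kept + quar, m


def run(route, rows=None):
    """Backward a weighted-sum loss over `rows` (default: all tokens); return values and grads."""
    for p in (delta_S, A_q, B_q):
        p.grad = None
    y, m = forward(route)
    keep = torch.ones(n, dtype=torch.bool) if rows is None else rows
    (w[keep] * y[keep]).sum().backward()
    return y.detach(), m, delta_S.grad.clone(), A_q.grad.clone()


y_r, m, gS_routed, gA_routed = run(route=True)
y_p, _, gS_full, gA_full = run(route=False)
_, _, gS_unflagged, _ = run(route=False, rows=~m)
```

The routed delta_S grad matches a plain backward restricted to the unflagged tokens. The quarantine grad matches the full, unrouted backward.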
Now, for the no-cheat check (R5): The mask in both arms uses only persona pairs (self-supervised). The code never gates on gt_pass or runs the full detector suite over student rollouts. This is confirmed by:
- v_act comes from extract_v_act, which processes persona pairs only
- v_grad comes from extract_v_hack on persona pairs only
- The mask computation in both arms uses only a (activations) / g_b (gradients) and the precomputed persona direction
- No gt_pass is checked during training
This satisfies R5.
For T4 (leakage metric) and T6 (plot): these are unchecked in the diff (not implemented yet), which is consistent with the task list showing them as incomplete.
Now let me compile the final answer.
Actually, let me also check one more concern: does the _stash_quar_grads() correctly handle multiple prompts per step?
def _stash_quar_grads():
    if not is_route2:
        return
    for name, info in wrappers.items():
        for sub in ("A_q", "B_q"):
            p = info[sub]
            if p.grad is None:
                continue
            key = f"{name}.{sub}"
            step_grad_quar[key] = (step_grad_quar[key] + p.grad.detach().clone()
                                   if key in step_grad_quar else p.grad.detach().clone())
This accumulates A_q.grad and B_q.grad across prompts. After the loop, they're re-injected:
for key, g in step_grad_quar.items():
    name, sub = key.rsplit(".", 1)
    wrappers[name][sub].grad = g
This is correct — the quarantine params accumulate their grads across all prompts in a step.
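The per-step pattern can be sketched with a toy module: read and filter after each backward, stash, zero_grad, then re-inject once before the optimizer step. ToyAdapter and filter_kept are hypothetical stand-ins for the wrapped layer and _route2_grad_filter. Without the stash, the per-prompt zero_grad(set_to_none=True) would discard the quarantine grads.

```python
import torch
from torch import nn


class ToyAdapter(nn.Module):
    """Hypothetical stand-in: one kept parameter and one quarantine parameter."""

    def __init__(self, r=3):
        super().__init__()
        self.delta_S = nn.Parameter(torch.zeros(r))
        self.A_q = nn.Parameter(torch.zeros(r))

    def forward(self, x):
        # d(out)/d(delta_S) = d(out)/d(A_q) = x, which keeps the expected grads obvious.
        return (x * self.delta_S).sum() + (x * self.A_q).sum()


def filter_kept(g):
    """Hypothetical placeholder for _route2_grad_filter: pretend axis 0 is routed away."""
    out = g.clone()
    out[0] = 0.0
    return out


model = ToyAdapter()
prompts = [torch.tensor([1.0, 2.0, 3.0]), torch.tensor([4.0, 5.0, 6.0])]
step_grad_s, step_grad_quar = None, None

for x in prompts:
    model(x).backward()
    g = filter_kept(model.delta_S.grad)          # filter BEFORE zero_grad clears the grad
    step_grad_s = g if step_grad_s is None else step_grad_s + g
    gq = model.A_q.grad.detach().clone()         # stash the quarantine grad for this prompt
    step_grad_quar = gq if step_grad_quar is None else step_grad_quar + gq
    model.zero_grad(set_to_none=True)            # sets both params' .grad to None

# Re-inject the step totals once, right before clipping and opt.step().
model.delta_S.grad = step_grad_s
model.A_q.grad = step_grad_quar
```

After the loop, delta_S.grad holds the sum of the filtered per-prompt grads ([0, 7, 9]). A_q.grad holds the unfiltered sum ([5, 7, 9]).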
Similarly, delta_S.grad is split across student/teacher and accumulated in step_grad_s/step_grad_t, then re-injected. For route2 (no split), only step_grad_s is populated, and it's the filtered version.
Let me also double-check: is there any issue with _route2_grad_filter being called multiple times for the same module across prompts? Each prompt's backward sets delta_S.grad to the per-prompt grad, and _route2_grad_filter reads and filters it. The filtered result is accumulated into step_grad_s. This is correct.
But wait — _route2_grad_filter reads info["delta_S"].grad which includes contributions from ALL rollouts in the current prompt (student + teacher, since route2 uses combined backward). The gate c is also from the same prompt. So the per-prompt filtering is correct.
One final thing: after _route2_grad_filter, the return value g - sub is the per-prompt filtered delta_S grad. This is then detached and accumulated. But info["delta_S"].grad is NOT modified by _route2_grad_filter — it's kept at the unfiltered value. Then model.zero_grad(set_to_none=True) sets it to None. This means the unfiltered grad is never used. The filtered version in step_grad_s is what gets injected later. Good.
BUT: sub is computed using cg (gate grad) and dS. The gate grad only matters within this prompt, since the next forward replaces the gate, and we have already used it in the filter. So this is fine.
OK, I think I have a solid review. Let me write it up.
Here's the review, ordered by priority.
1. Arm A: gate is per-token, not per-rollout (spec drift)
antipasto.py:~105-108
The spec pseudocode has c = ones(b, 1, r) with a of shape [b, s, r], producing per-rollout gating with token broadcast. But the forward hook receives a flattened x of shape [G·s, d_in] (nn.Linear itself accepts any leading dims, so the flattening happens before the hook sees x), so a = Vh @ x is [G·s, r]. The code creates:
c = torch.ones(a.shape[0], *([1] * (a.dim() - 2)), a.shape[-1], ...)
Since a.dim() == 2, this is [G·s, r] — per-token, not per-rollout. The docstring says "per-rollout gate [b, 1.., r]" but the actual tensor is [G·s, r].
After backward, cg = c.grad.reshape(-1, r) is [G·s, r], so g_b, cos_b, and flagged are all per-token. The subtraction sub = Σ_{flagged tokens} cg / dS sums over flagged tokens, not flagged rollouts.
Impact: routing granularity is token-level instead of rollout-level. A hacky rollout where some tokens don't point hack-ward gets partially routed (mixed signal to both adapters). For the "seed hard" mechanism this is a soft version of the spec's hard route — not necessarily broken, but not what the spec claims. The spec's absorption argument (item 4) assumed whole-rollout routing.
Fix sketch: reshape a to [G, s, r] inside the hook, create c as [G, 1, r], multiply, and flatten back. This needs G and s, which the hook doesn't currently receive. They would have to be passed through layer._antipasto_group or similar, or recovered by reshaping based on x.shape[0] / some_stored_L.
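The fix sketch can be checked against a post-hoc alternative. A [G, 1, r] gate broadcast over tokens receives exactly the per-token gate grads summed over each rollout's tokens. That appears to be the same quantity the commit message at the top describes computing in _route2_grad_filter (summing each rollout's token gate-grads), so the hook gate itself can stay per-token. Sizes and the sin loss are assumptions, and G and s are taken as known here.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
dt = torch.float64
G, s, r, d_out = 3, 4, 5, 2                      # rollouts, tokens per rollout, rank, out dim (toy)
a_flat = torch.randn(G * s, r, dtype=dt)         # what the hook sees: flattened [G*s, r]
U = torch.randn(d_out, r, dtype=dt)
delta_S = torch.randn(r, dtype=dt)
w = torch.randn(G * s, d_out, dtype=dt)


def loss_with_gate(c):
    a = a_flat.reshape(G, s, r)                  # undo the flattening (needs G and s)
    z = (a * c) * delta_S                        # c broadcasts over tokens if it is [G, 1, r]
    kept = F.linear(z.reshape(G * s, r), U)      # flatten back to [G*s, d_out]
    return (w * kept.sin()).sum()


c_tok = torch.ones(G, s, r, dtype=dt, requires_grad=True)   # same entries as a flat [G*s, r] gate
c_roll = torch.ones(G, 1, r, dtype=dt, requires_grad=True)  # the spec's per-rollout gate

loss_with_gate(c_tok).backward()
loss_with_gate(c_roll).backward()
```

The broadcast in the forward becomes a sum in the backward, so c_roll.grad equals c_tok.grad summed over the token axis.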
2. The gate identity and divide-out are correct
train.py:~1138-1152
c.grad[b,i] = g_b[b,i] * delta_S[i] holds because kept = U @ ((a*c) * delta_S) and c=1 at forward. The chain rule gives ∂L/∂c[b,i] = Σ_s ∂L/∂(a*c*dS)[b,s,i] * a[b,s,i] * dS[i] = g_b[b,i] * dS[i]. The divide-out g_b = cg / dS_safe on reliable axes recovers g_b.
The subtraction g - sub computes Σ_all g_b - Σ_flagged g_b = Σ_unflagged g_b on reliable axes. On unreliable axes (|delta_S| ≤ 1e-6), the full grad passes through — correct, since there's no learned hack on a ~0 axis worth routing.
No gradient-flow bug: _route2_grad_filter returns a value but never mutates delta_S.grad in-place. The filtered value is .detach().clone()-d into step_grad_s. The unfiltered delta_S.grad is freed by model.zero_grad(). Clean.
3. Arm B matches spec intent
antipasto.py:~117-123
kept = U @ (a * delta_S) # in SVD coords
cos = (a @ v_act) / (a.norm(...) * v_act.norm(...))
m = cos > 0 # tau=0, no configurable threshold
kept = where(m.unsqueeze(-1), kept.detach(), kept)
quar = B_q @ (A_q @ x) # always summed
Matches spec items 2-4: distinct basis (quarantine LoRA vs kept SVD), always-in-forward (quar summed unconditionally), detach-route (flagged → kept.detach(), forward value unchanged). The tau=0 choice (spec had cos > tau) is deliberate per the log: "removed route2_tau, cos>0 is the natural hack-ward boundary." Fine.
4. R5 (NO-CHEAT): clean
Training never reads gt_pass or runs the detector suite. Persona pairs are the only signal:
- v_act ← extract_v_act on persona pairs only (antipasto.py docstring; extract_vhack_grad.py:~260-288)
- v_grad ← extract_v_hack on persona pairs only (train.py:~780)
- The mask logic at forward/backward time uses only a, v_act, v_grad, delta_S, c.grad, all self-supervised.
- split_this_step is hard-false for route2 (train.py:~1166), so the cos_pre_s/t discriminator (which uses v_hack from the precomputed detector suite) never runs on route2.
No flag leaks.
5. Spec-is-current status
The spec log entry acknowledges the Arm A design change from A1/A2 to "single-pass gate-SUBTRACTION" and describes it. The spec's core pseudocode for Arm A (c = ones(b, 1, r)...) still shows per-rollout gating with token broadcast. The code implements per-token gating. This is a spec-implementation mismatch — the spec claims per-rollout but the code is per-token. Which one is "correct" depends on intent; if per-token is deliberate, the spec needs updating. If per-rollout was the intent, item 1 above is a bug.
Remaining spec items not yet implemented (per diff and task list): T4 (leakage metric + L1 knob), T5 (5-arm sweep), T6 (plot).
6. Misc observations
- Gate memory: layer._antipasto_gate is overwritten each forward. Old gate tensors are freed when the computation graph is freed. No leak.
- Quarantine grad stashing (train.py:~1118-1128): correctly clones and accumulates A_q.grad / B_q.grad across per-prompt model.zero_grad() calls. Re-injected before clip + opt.step. Verified bugfix from the log is present.
- Ablate quarantine (train.py:~537-547): now zeroes B_q in addition to delta_S_hack. Correct for route2 deploy eval.
- v_grad extraction (train.py:~778): extract_v_hack is called with tau_axis=0 (keep all singular vectors), then mean(g_hack - g_clean) per module, normalized. This produces a unit vector in delta_S gradient space. The sign convention matches the spec.