test: add mixed-batch per-rollout routing gate to verify_lora2r_routing (T8)

2a-2c only tested UNIFORM masks. 2d puts rollout 0 clean (0,0) and rollout 1 hack
(1,1) in ONE forward and asserts the mixed deployed grad == rollout-0-alone-clean
and the mixed quarantine grad == rollout-1-alone-hack -- the load-bearing
per-rollout mask vectorization ([G,1,1] reshape) with no cross-rollout bleed.
Green on tiny-random.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
wassname
2026-06-10 11:24:49 +00:00
parent 4ee5c27f7b
commit 9fd2b6b89b
+32
View File
@@ -70,6 +70,38 @@ print(f"2c. mid (m=1,d=0): dep grad {dep_n:.2e} > 0, quar grad {quar_n:.2e} > 0
model.zero_grad(set_to_none=True)
# 2d. MIXED batch: rollout 0 clean (0,0), rollout 1 hack (1,1) in ONE forward. This
# is the load-bearing per-rollout vectorization (2a-2c only test uniform masks). The
# masks reshape to [G,1,1], so rollout 0 must route to deployed only, rollout 1 to
# quarantine only, with NO bleed. Loss summed over sequences -> per-rollout grads are
# additive and separable, so the mixed deployed grad must equal rollout-0-alone-clean,
# and the mixed quarantine grad must equal rollout-1-alone-hack.
def block_grads(m_vec: torch.Tensor, d_vec: torch.Tensor, batch: torch.Tensor) -> tuple[dict, dict]:
model.zero_grad(set_to_none=True)
for info in wrappers.values():
info["layer"]._lora2r_mask = (m_vec, d_vec)
model(batch).logits.float().pow(2).sum().backward() # sum -> per-sequence additive
for info in wrappers.values():
info["layer"]._lora2r_mask = None
dep = {n: (i["A"].grad[:i["r"]].clone(), i["B"].grad[:, :i["r"]].clone()) for n, i in wrappers.items()}
quar = {n: (i["A"].grad[i["r"]:].clone(), i["B"].grad[:, i["r"]:].clone()) for n, i in wrappers.items()}
return dep, quar
dep_mix, quar_mix = block_grads(torch.tensor([0., 1.]), torch.tensor([0., 1.]), ids) # r0 clean, r1 hack
dep_r0, _ = block_grads(torch.zeros(1), torch.zeros(1), ids[:1]) # r0 alone, clean
_, quar_r1 = block_grads(torch.ones(1), torch.ones(1), ids[1:]) # r1 alone, hack
for n in wrappers:
assert torch.allclose(dep_mix[n][0], dep_r0[n][0], atol=1e-5) and \
torch.allclose(dep_mix[n][1], dep_r0[n][1], atol=1e-5), \
f"{n}: deployed grad bled across rollouts (mixed != r0-clean-alone)"
assert torch.allclose(quar_mix[n][0], quar_r1[n][0], atol=1e-5) and \
torch.allclose(quar_mix[n][1], quar_r1[n][1], atol=1e-5), \
f"{n}: quarantine grad bled across rollouts (mixed != r1-hack-alone)"
print(f"2d. mixed-batch per-rollout routing OK ({len(wrappers)} modules, r0->deployed r1->quarantine, no bleed)")
model.zero_grad(set_to_none=True)
# 3. per-rollout c-probe recovery
def gate_grads(batch_ids: torch.Tensor) -> list[torch.Tensor]:
loss = model(batch_ids).logits.float().pow(2).sum() # sum -> per-sequence-additive