mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 17:30:41 +08:00
test: add mixed-batch per-rollout routing gate to verify_lora2r_routing (T8)
2a-2c only tested UNIFORM masks. 2d puts rollout 0 clean (0,0) and rollout 1 hack (1,1) in ONE forward and asserts the mixed deployed grad == rollout-0-alone-clean and the mixed quarantine grad == rollout-1-alone-hack -- the load-bearing per-rollout mask vectorization ([G,1,1] reshape) with no cross-rollout bleed. Green on tiny-random. Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
@@ -70,6 +70,38 @@ print(f"2c. mid (m=1,d=0): dep grad {dep_n:.2e} > 0, quar grad {quar_n:.2e} > 0
|
||||
model.zero_grad(set_to_none=True)
|
||||
|
||||
|
||||
# 2d. MIXED batch: rollout 0 clean (0,0), rollout 1 hack (1,1) in ONE forward. This
|
||||
# is the load-bearing per-rollout vectorization (2a-2c only test uniform masks). The
|
||||
# masks reshape to [G,1,1], so rollout 0 must route to deployed only, rollout 1 to
|
||||
# quarantine only, with NO bleed. Loss summed over sequences -> per-rollout grads are
|
||||
# additive and separable, so the mixed deployed grad must equal rollout-0-alone-clean,
|
||||
# and the mixed quarantine grad must equal rollout-1-alone-hack.
|
||||
def block_grads(m_vec: torch.Tensor, d_vec: torch.Tensor, batch: torch.Tensor) -> tuple[dict, dict]:
    """Run one forward/backward under the given masks and split the LoRA grads.

    Every entry of the module-level ``wrappers`` gets ``(m_vec, d_vec)``
    installed as ``layer._lora2r_mask`` for the duration of a single
    forward/backward pass of the module-level ``model`` on ``batch``; the mask
    is reset to ``None`` afterwards, including when the pass raises, so a
    failed check cannot leave a stale mask behind for later checks.

    Args:
        m_vec: first element of the mask tuple handed to each wrapped layer
            (presumably one entry per rollout in ``batch`` -- confirm against
            the layer implementation).
        d_vec: second element of the mask tuple (same assumption).
        batch: input passed positionally to ``model``.

    Returns:
        ``(dep, quar)``: two dicts keyed like ``wrappers``. Each value is a
        pair of cloned gradient slices ``(A_part, B_part)``. ``dep`` holds the
        first ``r`` rows of ``A.grad`` and the first ``r`` columns of
        ``B.grad``; ``quar`` holds the remaining rows / columns.
    """
    # Start from empty grads so the slices below reflect only this pass.
    model.zero_grad(set_to_none=True)
    try:
        for info in wrappers.values():
            info["layer"]._lora2r_mask = (m_vec, d_vec)
        # Summed loss -> per-sequence contributions are additive.
        model(batch).logits.float().pow(2).sum().backward()
    finally:
        # Always uninstall the masks, even if the forward/backward failed.
        for info in wrappers.values():
            info["layer"]._lora2r_mask = None

    dep: dict = {}
    quar: dict = {}
    for name, info in wrappers.items():
        r = info["r"]
        a_grad = info["A"].grad
        b_grad = info["B"].grad
        # clone() so the next zero_grad()/backward() cannot alter the result.
        dep[name] = (a_grad[:r].clone(), b_grad[:, :r].clone())
        quar[name] = (a_grad[r:].clone(), b_grad[:, r:].clone())
    return dep, quar
|
||||
|
||||
|
||||
# One forward with both rollouts: rollout 0 masked clean (0,0), rollout 1 masked hack (1,1).
dep_mix, quar_mix = block_grads(
    torch.tensor([0., 1.]),
    torch.tensor([0., 1.]),
    ids,
)
# Reference passes: each rollout on its own under a uniform mask.
dep_r0, _ = block_grads(torch.zeros(1), torch.zeros(1), ids[:1])  # rollout 0 alone, clean
_, quar_r1 = block_grads(torch.ones(1), torch.ones(1), ids[1:])  # rollout 1 alone, hack
for n in wrappers:
    # Deployed slices (A part, then B part) must match the clean-alone reference.
    assert all(
        torch.allclose(mixed, alone, atol=1e-5)
        for mixed, alone in zip(dep_mix[n], dep_r0[n])
    ), f"{n}: deployed grad bled across rollouts (mixed != r0-clean-alone)"
    # Quarantine slices (A part, then B part) must match the hack-alone reference.
    assert all(
        torch.allclose(mixed, alone, atol=1e-5)
        for mixed, alone in zip(quar_mix[n], quar_r1[n])
    ), f"{n}: quarantine grad bled across rollouts (mixed != r1-hack-alone)"
print(f"2d. mixed-batch per-rollout routing OK ({len(wrappers)} modules, r0->deployed r1->quarantine, no bleed)")
# Leave no gradients behind for the next section of the script.
model.zero_grad(set_to_none=True)
|
||||
|
||||
|
||||
# 3. per-rollout c-probe recovery
|
||||
def gate_grads(batch_ids: torch.Tensor) -> list[torch.Tensor]:
|
||||
loss = model(batch_ids).logits.float().pow(2).sum() # sum -> per-sequence-additive
|
||||
|
||||
Reference in New Issue
Block a user