diff --git a/scripts/verify_lora2r_routing.py b/scripts/verify_lora2r_routing.py index 62be113..175b6a8 100644 --- a/scripts/verify_lora2r_routing.py +++ b/scripts/verify_lora2r_routing.py @@ -70,6 +70,38 @@ print(f"2c. mid (m=1,d=0): dep grad {dep_n:.2e} > 0, quar grad {quar_n:.2e} > 0 model.zero_grad(set_to_none=True) +# 2d. MIXED batch: rollout 0 clean (0,0), rollout 1 hack (1,1) in ONE forward. This +# is the load-bearing per-rollout vectorization (2a-2c only test uniform masks). The +# masks reshape to [G,1,1], so rollout 0 must route to deployed only, rollout 1 to +# quarantine only, with NO bleed. Loss summed over sequences -> per-rollout grads are +# additive and separable, so the mixed deployed grad must equal rollout-0-alone-clean, +# and the mixed quarantine grad must equal rollout-1-alone-hack. +def block_grads(m_vec: torch.Tensor, d_vec: torch.Tensor, batch: torch.Tensor) -> tuple[dict, dict]: + model.zero_grad(set_to_none=True) + for info in wrappers.values(): + info["layer"]._lora2r_mask = (m_vec, d_vec) + model(batch).logits.float().pow(2).sum().backward() # sum -> per-sequence additive + for info in wrappers.values(): + info["layer"]._lora2r_mask = None + dep = {n: (i["A"].grad[:i["r"]].clone(), i["B"].grad[:, :i["r"]].clone()) for n, i in wrappers.items()} + quar = {n: (i["A"].grad[i["r"]:].clone(), i["B"].grad[:, i["r"]:].clone()) for n, i in wrappers.items()} + return dep, quar + + +dep_mix, quar_mix = block_grads(torch.tensor([0., 1.]), torch.tensor([0., 1.]), ids) # r0 clean, r1 hack +dep_r0, _ = block_grads(torch.zeros(1), torch.zeros(1), ids[:1]) # r0 alone, clean +_, quar_r1 = block_grads(torch.ones(1), torch.ones(1), ids[1:]) # r1 alone, hack +for n in wrappers: + assert torch.allclose(dep_mix[n][0], dep_r0[n][0], atol=1e-5) and \ + torch.allclose(dep_mix[n][1], dep_r0[n][1], atol=1e-5), \ + f"{n}: deployed grad bled across rollouts (mixed != r0-clean-alone)" + assert torch.allclose(quar_mix[n][0], quar_r1[n][0], atol=1e-5) and \ + torch.allclose(quar_mix[n][1], quar_r1[n][1], atol=1e-5), \ + f"{n}: quarantine grad bled across rollouts (mixed != r1-hack-alone)" +print(f"2d. mixed-batch per-rollout routing OK ({len(wrappers)} modules, r0->deployed r1->quarantine, no bleed)") +model.zero_grad(set_to_none=True) + + # 3. per-rollout c-probe recovery def gate_grads(batch_ids: torch.Tensor) -> list[torch.Tensor]: loss = model(batch_ids).logits.float().pow(2).sum() # sum -> per-sequence-additive