Files
evil_MoE/scripts/verify_lora2r_routing.py
T
wassname 04a98b321e feat: Evil MoE — learned soft router + pin loss on an ablatable hack expert
Fork of vGROUT. Replaces routeA's fixed v_act quantile gate with a learned
per-rollout soft router (HackRouter, seeded from v_act) on the ablatable hack
expert: GRPO flows into the router through the soft weight w (it concentrates
hack-like rollouts in the hack expert), and a continuous pin loss on the
hand-authored pairs anchors the axis. No load balancing; routing is per rollout.

lora2r gains a soft-weight forward path (_lora2r_w: w=0 keep, w=1 rout, deployed
grad scaled by 1-w). train_moe.py is the on-policy GRPO loop; verify_moe_router.py
gates the routing invariants. `just smoke` is green. README/AGENTS rewritten for
the fork; original proposal kept as docs/spec/original_evil_moe_spec.md.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
2026-06-14 11:25:14 +08:00

143 lines
6.9 KiB
Python

"""lora2r invariants (rank-2r Gaussian-init LoRA with per-rollout output masks).
Asserts, on tiny-random-qwen3 (CPU, fp32):
  1. IDENTITY AT INIT: wrapped logits == base logits (the hook subtracts the
     frozen A0/B0 init contribution, so the net delta is zero up to fp32
     round-off; checked to max|dlogits| < 1e-5).
  2. MASK ROUTING (block grads under each three-way gate label):
       clean (m=0,d=0): deployed-block grads nonzero, quarantine-block ZERO
       hack  (m=1,d=1): deployed-block ZERO (output detach), quarantine nonzero
       mid   (m=1,d=0): both nonzero (absorption)
     plus a MIXED batch (rollout 0 clean, rollout 1 hack in one forward): the
     deployed grad must equal rollout-0-alone and the quarantine grad must
     equal rollout-1-alone, i.e. no bleed across rollouts.
  3. C-PROBE PER-ROLLOUT RECOVERY: batched c.grad rows == single-rollout c.grad
     (the gate's per-rollout weight grads are exact, not an approximation).
  4. ABLATION TEETH: ablate_quarantine is a no-op at init, removes a quarantine
     perturbation while active, and restores it on exit.

Exit nonzero on any violation. Wired into `just smoke-lora2r`.
"""
import torch
from transformers import AutoModelForCausalLM
from vgrout.lora2r import wrap_model_with_lora2r
from vgrout.eval import ablate_quarantine
# Checkpoint id handed to from_pretrained; loaded in fp32 below.
MODEL = "llamafactory/tiny-random-qwen3"
R = 4  # tiny model min Linear dim is 16, so 2r=8 fits everywhere

# Seed first so every random draw that follows (e.g. the token ids) is
# reproducible from run to run.
torch.manual_seed(0)

model = AutoModelForCausalLM.from_pretrained(MODEL, dtype=torch.float32)
model.eval()

# Batch of 2 sequences x 12 tokens, ids drawn uniformly from [100, 1000).
ids = torch.randint(100, 1000, (2, 12))

# Reference logits captured BEFORE wrapping; check 1 compares against these.
with torch.no_grad():
    base_logits = model(ids).logits.clone()

# Wrap with the rank-2R adapter; `wrappers` is what every later check iterates.
wrappers = wrap_model_with_lora2r(model, r=R, grad_probe=True)
# 1. identity at init
# Largest absolute logit difference between the wrapped model and the pre-wrap
# reference; anything at or above 1e-5 means wrapping changed the function.
with torch.no_grad():
    err = torch.max(torch.abs(model(ids).logits - base_logits)).item()

assert err < 1e-5, f"init not identity: max|dlogits|={err:.2e}"
print(f"1. identity at init OK (max|dlogits|={err:.2e})")
# 2. mask routing
def run_masked(m_val: float, d_val: float) -> tuple[float, float]:
    """Backprop one uniform-mask batch and return (deployed, quarantine) grad norms.

    Every rollout in the module-level ``ids`` batch receives the same ``(m, d)``
    gate pair. The pair is installed on each wrapped layer as ``_lora2r_mask``
    for a single forward/backward pass of a mean-squared-logit loss, then
    removed again.

    Args:
        m_val: value filled into the first mask vector (``m``), one entry per rollout.
        d_val: value filled into the second mask vector (``d``), one entry per rollout.

    Returns:
        ``(dep_norm, quar_norm)``: L2 norms, pooled over all wrapped modules, of
        the A/B gradients restricted to the first ``r`` rank indices (deployed
        block) and to the remaining rank indices (quarantine block).
    """
    model.zero_grad(set_to_none=True)

    n_rollouts = ids.shape[0]
    masks = (torch.full((n_rollouts,), m_val), torch.full((n_rollouts,), d_val))
    for info in wrappers.values():
        info["layer"]._lora2r_mask = masks
    try:
        model(ids).logits.float().pow(2).mean().backward()
    finally:
        # Uninstall the masks even when forward/backward raises, so no later
        # forward pass can run against a stale gate.
        for info in wrappers.values():
            info["layer"]._lora2r_mask = None

    # Accumulate squared norms per block: A is split along rows, B along columns.
    dep_sq = quar_sq = 0.0
    for info in wrappers.values():
        r = info["r"]
        grad_a, grad_b = info["A"].grad, info["B"].grad
        dep_sq += grad_a[:r].pow(2).sum().item() + grad_b[:, :r].pow(2).sum().item()
        quar_sq += grad_a[r:].pow(2).sum().item() + grad_b[:, r:].pow(2).sum().item()
    return dep_sq ** 0.5, quar_sq ** 0.5
# 2a. clean gate: gradient must reach the deployed block only.
dep_n, quar_n = run_masked(m_val=0.0, d_val=0.0)  # clean
assert dep_n > 1e-8 and quar_n < 1e-12, f"clean gate: dep={dep_n:.2e} quar={quar_n:.2e}"
print(f"2a. clean (m=0,d=0): dep grad {dep_n:.2e} > 0, quar grad {quar_n:.2e} == 0 OK")

# 2b. hack gate: gradient must reach the quarantine block only.
dep_n, quar_n = run_masked(m_val=1.0, d_val=1.0)  # hack
assert dep_n < 1e-12 and quar_n > 1e-8, f"hack gate: dep={dep_n:.2e} quar={quar_n:.2e}"
print(f"2b. hack (m=1,d=1): dep grad {dep_n:.2e} == 0, quar grad {quar_n:.2e} > 0 OK")

# 2c. mid gate: both blocks must receive gradient.
dep_n, quar_n = run_masked(m_val=1.0, d_val=0.0)  # mid
assert dep_n > 1e-8 and quar_n > 1e-8, f"mid gate: dep={dep_n:.2e} quar={quar_n:.2e}"
print(f"2c. mid (m=1,d=0): dep grad {dep_n:.2e} > 0, quar grad {quar_n:.2e} > 0 OK")

# Drop the grads left by the last run so the next check starts clean.
model.zero_grad(set_to_none=True)
# 2d. MIXED batch: rollout 0 clean (0,0), rollout 1 hack (1,1) in ONE forward. This
# is the load-bearing per-rollout vectorization (2a-2c only test uniform masks). The
# masks reshape to [G,1,1], so rollout 0 must route to deployed only, rollout 1 to
# quarantine only, with NO bleed. Loss summed over sequences -> per-rollout grads are
# additive and separable, so the mixed deployed grad must equal rollout-0-alone-clean,
# and the mixed quarantine grad must equal rollout-1-alone-hack.
def block_grads(m_vec: torch.Tensor, d_vec: torch.Tensor, batch: torch.Tensor) -> tuple[dict, dict]:
    """Backprop ``batch`` under per-rollout masks and return per-module block grads.

    The ``(m_vec, d_vec)`` pair is installed on every wrapped layer as
    ``_lora2r_mask`` for one forward/backward pass, then removed. The loss is the
    SUM of squared logits, so each sequence contributes an additive term.

    Args:
        m_vec: first mask vector (``m``), one entry per sequence in ``batch``.
        d_vec: second mask vector (``d``), one entry per sequence in ``batch``.
        batch: token-id batch fed to the model.

    Returns:
        ``(dep, quar)``: two dicts keyed like ``wrappers``. Each value is a pair
        ``(A_grad_block, B_grad_block)`` of cloned gradients: ``dep`` holds the
        first ``r`` rows of A.grad and first ``r`` columns of B.grad, ``quar``
        holds the remaining rows/columns.
    """
    model.zero_grad(set_to_none=True)

    masks = (m_vec, d_vec)
    for info in wrappers.values():
        info["layer"]._lora2r_mask = masks
    try:
        model(batch).logits.float().pow(2).sum().backward()  # sum -> per-sequence additive
    finally:
        # Uninstall the masks even when forward/backward raises, so no later
        # forward pass can run against a stale gate.
        for info in wrappers.values():
            info["layer"]._lora2r_mask = None

    # Clone so the returned blocks survive the next zero_grad / backward.
    dep: dict = {}
    quar: dict = {}
    for name, info in wrappers.items():
        r = info["r"]
        grad_a, grad_b = info["A"].grad, info["B"].grad
        dep[name] = (grad_a[:r].clone(), grad_b[:, :r].clone())
        quar[name] = (grad_a[r:].clone(), grad_b[:, r:].clone())
    return dep, quar
# One mixed forward (r0 clean, r1 hack), then each rollout alone under its own gate.
dep_mix, quar_mix = block_grads(torch.tensor([0., 1.]), torch.tensor([0., 1.]), ids)  # r0 clean, r1 hack
dep_r0 = block_grads(torch.zeros(1), torch.zeros(1), ids[:1])[0]  # r0 alone, clean
quar_r1 = block_grads(torch.ones(1), torch.ones(1), ids[1:])[1]  # r1 alone, hack

for n in wrappers:
    # Each entry is an (A-block, B-block) pair; both halves must match.
    dep_matches = all(
        torch.allclose(mixed, alone, atol=1e-5)
        for mixed, alone in zip(dep_mix[n], dep_r0[n])
    )
    assert dep_matches, f"{n}: deployed grad bled across rollouts (mixed != r0-clean-alone)"

    quar_matches = all(
        torch.allclose(mixed, alone, atol=1e-5)
        for mixed, alone in zip(quar_mix[n], quar_r1[n])
    )
    assert quar_matches, f"{n}: quarantine grad bled across rollouts (mixed != r1-hack-alone)"

print(f"2d. mixed-batch per-rollout routing OK ({len(wrappers)} modules, r0->deployed r1->quarantine, no bleed)")

# Drop the grads left by the last block_grads call before the next check.
model.zero_grad(set_to_none=True)
# 3. per-rollout c-probe recovery
def gate_grads(batch_ids: torch.Tensor) -> list[torch.Tensor]:
    """Return d(loss)/d(gate) for each wrapped layer's ``_lora2r_gate`` tensor.

    The loss is the sum of squared logits over ``batch_ids``. Gradients are taken
    with ``torch.autograd.grad`` (so no ``.grad`` fields are populated) and are
    returned as detached clones, in ``wrappers`` iteration order.
    """
    logits = model(batch_ids).logits
    loss = logits.float().pow(2).sum()  # sum -> per-sequence-additive
    # Collected AFTER the forward pass, as in the original ordering; presumably the
    # layer (re)creates its gate tensor on each forward -- TODO confirm in lora2r.
    probes = [info["layer"]._lora2r_gate for info in wrappers.values()]
    raw = torch.autograd.grad(loss, probes)
    return [g.detach().clone() for g in raw]
# Gate grads for the full 2-rollout batch and for each rollout on its own.
both = gate_grads(ids)
solo0 = gate_grads(ids[:1])
solo1 = gate_grads(ids[1:])

for name, batched, alone0, alone1 in zip(wrappers, both, solo0, solo1, strict=True):
    # Fold each gate grad to one row per rollout by summing the middle axis.
    per_rollout = batched.reshape(2, -1, batched.shape[-1]).sum(1)  # [2, 2r] per-rollout
    row0 = alone0.reshape(1, -1, alone0.shape[-1]).sum(1)[0]
    row1 = alone1.reshape(1, -1, alone1.shape[-1]).sum(1)[0]
    assert torch.allclose(per_rollout[0], row0, atol=1e-5, rtol=1e-4), f"{name}: rollout 0 c.grad mismatch"
    assert torch.allclose(per_rollout[1], row1, atol=1e-5, rtol=1e-4), f"{name}: rollout 1 c.grad mismatch"

print(f"3. c-probe per-rollout recovery OK ({len(both)} modules, batched == solo)")
# 4. ablation teeth
with torch.no_grad():
    # Baseline logits with the adapter untouched.
    out0 = model(ids).logits.clone()

    # (a) Ablating before any perturbation must leave the logits unchanged.
    with ablate_quarantine(wrappers):
        out_abl_init = model(ids).logits
    assert torch.allclose(out_abl_init, out0, atol=1e-6), "ablate at init is not a no-op"

    # (b) Add noise in place to the quarantine rows (index r onward) of every A.
    for info in wrappers.values():
        r = info["r"]
        info["A"].data[r:].add_(0.05 * torch.randn_like(info["A"].data[r:]))
    out_pert = model(ids).logits.clone()
    pert = torch.max(torch.abs(out_pert - out0)).item()
    assert pert > 1e-6, f"quarantine perturbation invisible in forward ({pert:.2e})"

    # (c) Under ablation the perturbed model must fall back to the baseline logits.
    with ablate_quarantine(wrappers):
        out_abl = model(ids).logits
    assert torch.allclose(out_abl, out0, atol=1e-5), "ablation does not remove the quarantine delta"

    # (d) Leaving the context must bring the perturbed behaviour back.
    out_back = model(ids).logits
    assert torch.allclose(out_back, out_pert, atol=1e-6), "ablate context did not restore state"

print(f"4. ablation teeth OK (perturbation {pert:.2e} visible, removed under ablate, restored after)")
print("verify_lora2r_routing: ALL OK")