Files
evil_MoE/scripts/verify_moe_router.py
wassname 04a98b321e feat: Evil MoE — learned soft router + pin loss on an ablatable hack expert
Fork of vGROUT. Replaces routeA's fixed v_act quantile gate with a learned
per-rollout soft router (HackRouter, seeded from v_act) on the ablatable hack
expert: GRPO flows into the router through the soft weight w (it concentrates
hack-like rollouts in the hack expert), and a continuous pin loss on the
hand-authored pairs anchors the axis. No load balancing; routing is per rollout.

lora2r gains a soft-weight forward path (_lora2r_w: w=0 keep, w=1 rout, deployed
grad scaled by 1-w). train_moe.py is the on-policy GRPO loop; verify_moe_router.py
gates the routing invariants. `just smoke` is green. README/AGENTS rewritten for
the fork; original proposal kept as docs/spec/original_evil_moe_spec.md.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
2026-06-14 11:25:14 +08:00

135 lines
5.6 KiB
Python

"""Evil MoE invariants: the soft hack-expert weight w and the learned router.
Asserts, on tiny-random-qwen3 (CPU, fp32):
1. SOFT-WEIGHT FORWARD: with _lora2r_w set, w=0 forward == hack-expert-ablated
forward (deploy state) and w=1 forward == unmasked full forward. The hack
expert's contribution scales linearly with w.
2. SOFT-WEIGHT GRADS: w=0 -> only deployed (keep) block grads; w=1 -> only hack
block grads (deployed detached); w=0.5 -> both. (The (1-w) soft detach.)
3. ROUTER GRAD: w = router(acts) is differentiable -> a loss through the weighted
forward gives nonzero grad to the router parameters.
4. PIN SEPARATES: a few steps of router.pin_loss on distinct hack/clean act clusters
drives w(hack) up and w(clean) down (the router learns the authored axis).
Exit nonzero on any violation. Wired into `just smoke`.
"""
import torch
from transformers import AutoModelForCausalLM
from vgrout.eval import ablate_quarantine
from vgrout.lora2r import wrap_model_with_lora2r
from vgrout.moe_router import HackRouter
# Checkpoint loaded for every check below, and the rank handed to the wrapper.
# Rows/columns from index R onward are the hack (quarantine) block.
MODEL = "llamafactory/tiny-random-qwen3"
R = 4

# Seed first so the token ids and the perturbation noise are reproducible.
torch.manual_seed(0)
model = AutoModelForCausalLM.from_pretrained(MODEL, dtype=torch.float32)
model.eval()
ids = torch.randint(100, 1000, (3, 12))  # 3 sequences of 12 random token ids
wrappers = wrap_model_with_lora2r(model, r=R)
M = len(wrappers)  # number of entries returned by the wrapper

# Perturb the hack (quarantine) block so it has a visible, ablatable contribution.
# The slices are views, so the in-place adds write straight into A and B.
with torch.no_grad():
    for wrapped in wrappers.values():
        hack_rows = wrapped["A"].data[R:]
        hack_rows += 0.05 * torch.randn_like(hack_rows)
        hack_cols = wrapped["B"].data[:, R:]
        hack_cols += 0.05 * torch.randn_like(hack_cols)
def fwd_with_w(w_val):
    """Return logits on ``ids`` with every wrapped layer's soft weight set to ``w_val``.

    Args:
        w_val: Scalar weight; expanded to one entry per sequence in ``ids``.

    Returns:
        A cloned logits tensor, computed under ``torch.no_grad()``.
    """
    wv = torch.full((ids.shape[0],), float(w_val))
    try:
        for info in wrappers.values():
            info["layer"]._lora2r_w = wv
        with torch.no_grad():
            return model(ids).logits.clone()
    finally:
        # Always clear the override, even when the forward raises, so no later
        # forward pass runs with a stale soft weight.
        for info in wrappers.values():
            info["layer"]._lora2r_w = None
# 1. soft-weight forward endpoints
with torch.no_grad():
    full = model(ids).logits.clone()  # quar fully on (w=1 equivalent)
    with ablate_quarantine(wrappers):
        ablated = model(ids).logits.clone()  # quar off (w=0 equivalent, = deploy)
w0 = fwd_with_w(0.0)
w1 = fwd_with_w(1.0)
e0 = (w0 - ablated).abs().max().item()
e1 = (w1 - full).abs().max().item()
# Raise explicitly instead of using `assert`: asserts are stripped under
# `python -O`, which would let this gate pass without checking anything.
# `not x < tol` (rather than `x >= tol`) also fails on NaN, as `assert` did.
if not e0 < 1e-5:
    raise AssertionError(f"w=0 != ablated/deploy forward: max|d|={e0:.2e}")
if not e1 < 1e-5:
    raise AssertionError(f"w=1 != full forward: max|d|={e1:.2e}")
# w=0.5 must differ from both endpoints. This is a sanity check that an
# intermediate weight has an intermediate effect; it does not test that the
# logits are strictly linear in w.
wh = fwd_with_w(0.5)
if not ((wh - w0).abs().max().item() > 1e-6 and (wh - w1).abs().max().item() > 1e-6):
    raise AssertionError("w=0.5 forward did not interpolate between deploy and full")
print(f"1. soft-weight forward OK (w=0==deploy {e0:.1e}, w=1==full {e1:.1e}, w=0.5 interpolates)")
# 2. soft-weight grads route by w
def block_grad_norms(w_val):
    """Backprop a squared-logit loss at soft weight ``w_val`` and measure per-block grads.

    Args:
        w_val: Scalar weight; expanded to one entry per sequence in ``ids``.

    Returns:
        ``(dep, quar)``: the L2 norm, pooled over all wrappers, of the gradient
        on the first ``r`` rows of ``A`` / first ``r`` columns of ``B`` (the
        deployed block) and on the remaining rows / columns (the hack block).
    """
    model.zero_grad(set_to_none=True)
    wv = torch.full((ids.shape[0],), float(w_val))
    try:
        for info in wrappers.values():
            info["layer"]._lora2r_w = wv
        model(ids).logits.float().pow(2).mean().backward()
    finally:
        # Always clear the override, even when forward/backward raises, so no
        # later pass runs with a stale soft weight.
        for info in wrappers.values():
            info["layer"]._lora2r_w = None
    dep = quar = 0.0
    for info in wrappers.values():
        gA, gB, r = info["A"].grad, info["B"].grad, info["r"]
        dep += gA[:r].pow(2).sum().item() + gB[:, :r].pow(2).sum().item()
        quar += gA[r:].pow(2).sum().item() + gB[:, r:].pow(2).sum().item()
    # Squared sums were accumulated above; take the root once at the end.
    return dep ** 0.5, quar ** 0.5
# Raise explicitly instead of using `assert`: asserts are stripped under
# `python -O`, which would let this gate pass without checking anything.
# The conditions are negated as written (not algebraically flipped) so a NaN
# norm still fails, as it did with `assert`.
dep_n, quar_n = block_grad_norms(0.0)
if not (dep_n > 1e-8 and quar_n < 1e-12):
    raise AssertionError(f"w=0: dep={dep_n:.2e} quar={quar_n:.2e} (want keep-only)")
print(f"2a. w=0 (keep): dep grad {dep_n:.2e} > 0, hack grad {quar_n:.2e} == 0 OK")
dep_n, quar_n = block_grad_norms(1.0)
if not (dep_n < 1e-12 and quar_n > 1e-8):
    raise AssertionError(f"w=1: dep={dep_n:.2e} quar={quar_n:.2e} (want hack-only)")
print(f"2b. w=1 (rout): dep grad {dep_n:.2e} == 0, hack grad {quar_n:.2e} > 0 OK")
dep_n, quar_n = block_grad_norms(0.5)
if not (dep_n > 1e-8 and quar_n > 1e-8):
    raise AssertionError(f"w=0.5: dep={dep_n:.2e} quar={quar_n:.2e} (want both)")
print(f"2c. w=0.5 (absorb): dep grad {dep_n:.2e} > 0, hack grad {quar_n:.2e} > 0 OK")
# Drop the gradients left behind by the last backward pass.
model.zero_grad(set_to_none=True)
# 3. router grad flows from the weighted forward
v_act = torch.randn(M, R)
v_act = v_act / v_act.norm(dim=-1, keepdim=True)  # unit-normalize each row
router = HackRouter(v_act)
acts = torch.randn(ids.shape[0], M, R)
w = router(acts)  # kept attached to the graph so the loss can reach the router
for info in wrappers.values():
    info["layer"]._lora2r_w = w
model.zero_grad(set_to_none=True)
model(ids).logits.float().pow(2).mean().backward()
for info in wrappers.values():
    info["layer"]._lora2r_w = None
g = router.direction.grad
# Raise explicitly instead of using `assert`: asserts are stripped under
# `python -O`, which would let this gate pass without checking anything.
if g is None or not g.abs().sum().item() > 0:
    raise AssertionError("router got no gradient from the weighted forward")
print(f"3. router grad OK (||d router.direction||={g.norm().item():.2e} > 0)")
# 4. pin loss separates two distinct act clusters
torch.manual_seed(1)
hack_dir = torch.randn(M, R)
# NOTE(review): this divides by the norm of the whole (M, R) tensor, whereas
# v_act in section 3 is normalized per row -- confirm this is intended.
hack_dir /= hack_dir.norm()
router2 = HackRouter(torch.zeros(M, R) + 0.01 * torch.randn(M, R))  # NOT seeded with hack_dir
hack_acts = hack_dir.unsqueeze(0) + 0.05 * torch.randn(8, M, R)  # cluster near +hack_dir
clean_acts = -hack_dir.unsqueeze(0) + 0.05 * torch.randn(8, M, R)  # cluster near -hack_dir
opt = torch.optim.Adam(router2.parameters(), lr=0.05)
# Mean router weights before training; only reported in the messages below.
w_hack0 = router2(hack_acts).mean().item()
w_clean0 = router2(clean_acts).mean().item()
for _ in range(50):
    opt.zero_grad()
    router2.pin_loss(hack_acts, clean_acts).backward()
    opt.step()
w_hack1 = router2(hack_acts).mean().item()
w_clean1 = router2(clean_acts).mean().item()
# Raise explicitly instead of using `assert`: asserts are stripped under
# `python -O`, which would let this gate pass without checking anything.
if not (w_hack1 > 0.7 and w_clean1 < 0.3):
    raise AssertionError(
        f"pin did not separate: w_hack {w_hack0:.2f}->{w_hack1:.2f}, w_clean {w_clean0:.2f}->{w_clean1:.2f}"
    )
print(f"4. pin separates OK (w_hack {w_hack0:.2f}->{w_hack1:.2f}, w_clean {w_clean0:.2f}->{w_clean1:.2f})")
print("verify_moe_router: ALL OK")