mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 17:30:41 +08:00
feat: Evil MoE — learned soft router + pin loss on an ablatable hack expert
Fork of vGROUT. Replaces routeA's fixed v_act quantile gate with a learned per-rollout soft router (HackRouter, seeded from v_act) on the ablatable hack expert: GRPO flows into the router through the soft weight w (it concentrates hack-like rollouts in the hack expert), and a continuous pin loss on the hand-authored pairs anchors the axis. No load balancing; routing is per rollout. lora2r gains a soft-weight forward path (_lora2r_w: w=0 keep, w=1 rout, deployed grad scaled by 1-w). train_moe.py is the on-policy GRPO loop; verify_moe_router.py gates the routing invariants. `just smoke` is green. README/AGENTS rewritten for the fork; original proposal kept as docs/spec/original_evil_moe_spec.md. Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
@@ -0,0 +1,134 @@
|
||||
"""Evil MoE invariants: the soft hack-expert weight w and the learned router.

Asserts, on tiny-random-qwen3 (CPU, fp32):

  1. SOFT-WEIGHT FORWARD: with _lora2r_w set, w=0 forward == hack-expert-ablated
     forward (deploy state) and w=1 forward == unmasked full forward, both to
     within 1e-5 max-abs on the logits. A w=0.5 forward must differ from both
     endpoints (exact linearity of the logits in w is not asserted here).
  2. SOFT-WEIGHT GRADS: w=0 -> only deployed (keep) block grads; w=1 -> only hack
     block grads (deployed detached); w=0.5 -> both. (The (1-w) soft detach.)
  3. ROUTER GRAD: w = router(acts) is differentiable -> a loss through the weighted
     forward gives nonzero grad to router.direction.
  4. PIN SEPARATES: 50 Adam steps of router.pin_loss on distinct hack/clean act
     clusters drive mean w(hack) above 0.7 and mean w(clean) below 0.3 (the
     router learns the authored axis).

Exit nonzero on any violation. Wired into `just smoke`.
"""
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
from vgrout.eval import ablate_quarantine
|
||||
from vgrout.lora2r import wrap_model_with_lora2r
|
||||
from vgrout.moe_router import HackRouter
|
||||
|
||||
# Checkpoint under test and the rank passed to wrap_model_with_lora2r; rows/cols
# from index R onward of each A/B are treated as the hack (quarantine) block.
MODEL = "llamafactory/tiny-random-qwen3"
R = 4

# Fixed seed so the model init, the token ids and the perturbation are reproducible.
torch.manual_seed(0)
model = AutoModelForCausalLM.from_pretrained(MODEL, dtype=torch.float32)
model.eval()
ids = torch.randint(100, 1000, (3, 12))
wrappers = wrap_model_with_lora2r(model, r=R)
M = len(wrappers)

# Perturb the hack (quarantine) block so it has a visible, ablatable contribution.
with torch.no_grad():
    for entry in wrappers.values():
        # Slices of .data are views, so the in-place add writes through to the
        # underlying A / B tensors. A is drawn before B for every wrapper.
        hack_rows = entry["A"].data[R:]
        hack_rows += 0.05 * torch.randn_like(hack_rows)
        hack_cols = entry["B"].data[:, R:]
        hack_cols += 0.05 * torch.randn_like(hack_cols)
|
||||
|
||||
|
||||
def fwd_with_w(w_val):
    """Return the logits of a no-grad forward on ``ids`` with a constant soft weight.

    Every wrapped layer's ``_lora2r_w`` is set to a 1-D tensor holding
    ``float(w_val)`` once per row of ``ids``, the model is run, and the
    attribute is reset to ``None`` afterwards.

    Args:
        w_val: scalar soft weight applied to every row of the batch.

    Returns:
        A cloned copy of ``model(ids).logits``.
    """
    wv = torch.full((ids.shape[0],), float(w_val))
    for info in wrappers.values():
        info["layer"]._lora2r_w = wv
    try:
        with torch.no_grad():
            return model(ids).logits.clone()
    finally:
        # Reset even when the forward raises, so a failed call cannot leave a
        # stale soft weight on the layers for whatever runs next.
        for info in wrappers.values():
            info["layer"]._lora2r_w = None
|
||||
|
||||
|
||||
# 1. soft-weight forward endpoints
# Reference logits: the plain forward (hack block on) and the same forward run
# inside ablate_quarantine (hack block off, i.e. the deploy state).
with torch.no_grad():
    ref_full = model(ids).logits.clone()
    with ablate_quarantine(wrappers):
        ref_deploy = model(ids).logits.clone()

logits_w0 = fwd_with_w(0.0)
logits_w1 = fwd_with_w(1.0)
e0 = torch.abs(logits_w0 - ref_deploy).max().item()
e1 = torch.abs(logits_w1 - ref_full).max().item()
assert e0 < 1e-5, f"w=0 != ablated/deploy forward: max|d|={e0:.2e}"
assert e1 < 1e-5, f"w=1 != full forward: max|d|={e1:.2e}"

# Midpoint sanity check: w=0.5 must move the logits away from both endpoints.
logits_half = fwd_with_w(0.5)
gap_to_w0 = torch.abs(logits_half - logits_w0).max().item()
gap_to_w1 = torch.abs(logits_half - logits_w1).max().item()
assert gap_to_w0 > 1e-6 and gap_to_w1 > 1e-6, \
    "w=0.5 forward did not interpolate between deploy and full"
print(f"1. soft-weight forward OK (w=0==deploy {e0:.1e}, w=1==full {e1:.1e}, w=0.5 interpolates)")
|
||||
|
||||
|
||||
# 2. soft-weight grads route by w
|
||||
def block_grad_norms(w_val):
    """Backprop a squared-logit loss at soft weight ``w_val`` and measure block grads.

    Model grads are cleared, every wrapped layer's ``_lora2r_w`` is set to a
    constant 1-D tensor of ``float(w_val)`` (one entry per row of ``ids``), and
    ``mean(logits ** 2)`` is backpropagated. The attribute is reset to ``None``
    afterwards.

    Args:
        w_val: scalar soft weight applied to every row of the batch.

    Returns:
        ``(dep, quar)``: L2 norms, pooled over all wrappers, of the gradient on
        the first ``r`` rows of A / columns of B (deployed block) and on the
        remaining rows / columns (hack block).
    """
    model.zero_grad(set_to_none=True)
    wv = torch.full((ids.shape[0],), float(w_val))
    for info in wrappers.values():
        info["layer"]._lora2r_w = wv
    try:
        model(ids).logits.float().pow(2).mean().backward()
    finally:
        # Reset even when forward/backward raises, so a failed call cannot
        # leave a stale soft weight on the layers.
        for info in wrappers.values():
            info["layer"]._lora2r_w = None

    dep = quar = 0.0
    for info in wrappers.values():
        gA, gB, r = info["A"].grad, info["B"].grad, info["r"]
        # Accumulate squared norms; the square root is taken once at the end.
        dep += gA[:r].pow(2).sum().item() + gB[:, :r].pow(2).sum().item()
        quar += gA[r:].pow(2).sum().item() + gB[:, r:].pow(2).sum().item()
    return dep ** 0.5, quar ** 0.5
|
||||
|
||||
|
||||
# w=0: gradient should reach the deployed (keep) block only.
dep_n, quar_n = block_grad_norms(0.0)
keep_only = dep_n > 1e-8 and quar_n < 1e-12
assert keep_only, f"w=0: dep={dep_n:.2e} quar={quar_n:.2e} (want keep-only)"
print(f"2a. w=0 (keep): dep grad {dep_n:.2e} > 0, hack grad {quar_n:.2e} == 0 OK")

# w=1: gradient should reach the hack block only.
dep_n, quar_n = block_grad_norms(1.0)
hack_only = dep_n < 1e-12 and quar_n > 1e-8
assert hack_only, f"w=1: dep={dep_n:.2e} quar={quar_n:.2e} (want hack-only)"
print(f"2b. w=1 (rout): dep grad {dep_n:.2e} == 0, hack grad {quar_n:.2e} > 0 OK")

# w=0.5: both blocks should receive gradient.
dep_n, quar_n = block_grad_norms(0.5)
both_blocks = dep_n > 1e-8 and quar_n > 1e-8
assert both_blocks, f"w=0.5: dep={dep_n:.2e} quar={quar_n:.2e} (want both)"
print(f"2c. w=0.5 (absorb): dep grad {dep_n:.2e} > 0, hack grad {quar_n:.2e} > 0 OK")

# Leave no gradients behind for the checks that follow.
model.zero_grad(set_to_none=True)
|
||||
|
||||
|
||||
# 3. router grad flows from the weighted forward
# Row-normalised random directions are used to construct the router.
raw_dirs = torch.randn(M, R)
v_act = raw_dirs / raw_dirs.norm(dim=-1, keepdim=True)
router = HackRouter(v_act)

# One random activation tensor per row of `ids`; w stays attached to the
# router's graph so the backward pass below can reach its parameters.
acts = torch.randn(ids.shape[0], M, R)
w = router(acts)

layers = [info["layer"] for info in wrappers.values()]
for layer in layers:
    layer._lora2r_w = w
model.zero_grad(set_to_none=True)
loss = model(ids).logits.float().pow(2).mean()
loss.backward()
for layer in layers:
    layer._lora2r_w = None

g = router.direction.grad
assert g is not None and g.abs().sum().item() > 0, "router got no gradient from the weighted forward"
print(f"3. router grad OK (||d router.direction||={g.norm().item():.2e} > 0)")
|
||||
|
||||
|
||||
# 4. pin loss separates two distinct act clusters
torch.manual_seed(1)

# Unit-norm (over the whole tensor) axis that defines the two clusters.
hack_dir = torch.randn(M, R)
hack_dir = hack_dir / hack_dir.norm()

# The router starts from a small random tensor, NOT seeded with hack_dir.
init_dirs = torch.zeros(M, R) + 0.01 * torch.randn(M, R)
router2 = HackRouter(init_dirs)

# Eight noisy samples around +hack_dir (hack) and eight around -hack_dir (clean).
hack_acts = hack_dir.unsqueeze(0) + 0.05 * torch.randn(8, M, R)
clean_acts = -hack_dir.unsqueeze(0) + 0.05 * torch.randn(8, M, R)

opt = torch.optim.Adam(router2.parameters(), lr=0.05)

# Mean router weight per cluster before training, kept for the report line.
w_hack0 = router2(hack_acts).mean().item()
w_clean0 = router2(clean_acts).mean().item()

for _step in range(50):
    opt.zero_grad()
    pin = router2.pin_loss(hack_acts, clean_acts)
    pin.backward()
    opt.step()

# Mean router weight per cluster after training.
w_hack1 = router2(hack_acts).mean().item()
w_clean1 = router2(clean_acts).mean().item()
assert w_hack1 > 0.7 and w_clean1 < 0.3, \
    f"pin did not separate: w_hack {w_hack0:.2f}->{w_hack1:.2f}, w_clean {w_clean0:.2f}->{w_clean1:.2f}"
print(f"4. pin separates OK (w_hack {w_hack0:.2f}->{w_hack1:.2f}, w_clean {w_clean0:.2f}->{w_clean1:.2f})")

print("verify_moe_router: ALL OK")
|
||||
Reference in New Issue
Block a user