mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 17:30:41 +08:00
refactor: collapse to lora2r-only (none/routeV/absorb); delete erase/antipasto/lora_frozen_b paths
train.py rewritten straight-line for the single rank-2r Gaussian-init LoRA adapter and three arms (intervention none|routeV|absorb).

Removes the erase grad-surgery, act_vote/online_stats gates, beta/KL reference path, per-source split harvest, the v_hack injection block, and all per-mechanism E/C/D/A-B tallies.

Folds in:
- T2 Gaussian init (lora2r.py): A0~N(0,1/d_in), B0~N(0,1/2r), net delta 0 at init.
- T3 width-pooled gate labels: single (num/den) fraction across modules, skip zero-width modules, raise if none separate (was per-module equal-weight blowup).
- T5 absorb arm: masks pinned (1,0) -> both blocks train, no gate.
- T6 self-contained ckpt: A/B/A0/B0 in one file (no _hack file, no SVD cache), adapter:"lora2r" in saved cfg.
- T8 m3: step_flagged logs the hack share (d.mean), not m.mean.

Gates green: verify_lora2r_routing (4 invariants) + smoke none/routeV/absorb end-to-end on tiny-random Qwen3 (logs in /tmp/claude-1000/smoke_*.log).

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
"""lora2r invariants (rank-2r PiSSA-init LoRA + SGTM-style block masks).
|
||||
"""lora2r invariants (rank-2r Gaussian-init LoRA + SGTM-style block masks).
|
||||
|
||||
Asserts, on tiny-random-qwen3 (CPU, fp32):
|
||||
1. IDENTITY AT INIT: wrapped logits == base logits (the hook subtracts the
|
||||
@@ -17,7 +17,7 @@ Exit nonzero on any violation. Wired into `just smoke-lora2r`.
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
from vgrout.antipasto import wrap_model_with_lora2r
|
||||
from vgrout.lora2r import wrap_model_with_lora2r
|
||||
from vgrout.eval import ablate_quarantine
|
||||
|
||||
MODEL = "llamafactory/tiny-random-qwen3"
|
||||
@@ -31,7 +31,7 @@ ids = torch.randint(100, 1000, (2, 12))
|
||||
with torch.no_grad():
|
||||
base_logits = model(ids).logits.clone()
|
||||
|
||||
wrappers = wrap_model_with_lora2r(model, MODEL, svd_device="cpu", r=R, grad_probe=True)
|
||||
wrappers = wrap_model_with_lora2r(model, r=R, grad_probe=True)
|
||||
|
||||
# 1. identity at init
|
||||
with torch.no_grad():
|
||||
@@ -52,7 +52,7 @@ def run_masked(m_val: float, d_val: float) -> tuple[float, float]:
|
||||
dep_sq = quar_sq = 0.0
|
||||
for info in wrappers.values():
|
||||
r = info["r"]
|
||||
gA, gB = info["delta_S"].grad, info["B"].grad
|
||||
gA, gB = info["A"].grad, info["B"].grad
|
||||
dep_sq += gA[:r].pow(2).sum().item() + gB[:, :r].pow(2).sum().item()
|
||||
quar_sq += gA[r:].pow(2).sum().item() + gB[:, r:].pow(2).sum().item()
|
||||
return dep_sq ** 0.5, quar_sq ** 0.5
|
||||
@@ -96,7 +96,7 @@ with torch.no_grad():
|
||||
assert torch.allclose(out_abl_init, out0, atol=1e-6), "ablate at init is not a no-op"
|
||||
for info in wrappers.values():
|
||||
r = info["r"]
|
||||
info["delta_S"].data[r:] += 0.05 * torch.randn_like(info["delta_S"].data[r:])
|
||||
info["A"].data[r:] += 0.05 * torch.randn_like(info["A"].data[r:])
|
||||
out_pert = model(ids).logits.clone()
|
||||
pert = (out_pert - out0).abs().max().item()
|
||||
assert pert > 1e-6, f"quarantine perturbation invisible in forward ({pert:.2e})"
|
||||
|
||||
Reference in New Issue
Block a user