refactor: collapse to lora2r-only (none/routeV/absorb); delete erase/antipasto/lora_frozen_b paths

train.py rewritten straight-line for the single rank-2r Gaussian-init LoRA adapter
and three arms (intervention none|routeV|absorb). Removes the erase grad-surgery,
act_vote/online_stats gates, beta/KL reference path, per-source split harvest, the
v_hack injection block, and all per-mechanism E/C/D/A-B tallies. Folds in:
- T2 Gaussian init (lora2r.py): A0~N(0,1/d_in), B0~N(0,1/2r), net delta 0 at init.
- T3 width-pooled gate labels: single (num/den) fraction across modules, skip
  zero-width modules, raise if none separate (was per-module equal-weight blowup).
- T5 absorb arm: masks pinned (1,0) -> both blocks train, no gate.
- T6 self-contained ckpt: A/B/A0/B0 in one file (no _hack file, no SVD cache),
  adapter:"lora2r" in saved cfg.
- T8 m3: step_flagged logs the hack share (d.mean), not m.mean.

Gates green: verify_lora2r_routing (4 invariants) + smoke none/routeV/absorb
end-to-end on tiny-random Qwen3 (logs in /tmp/claude-1000/smoke_*.log).

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
wassname
2026-06-10 10:58:22 +00:00
parent 6094568c56
commit 5c97975185
8 changed files with 517 additions and 1190 deletions
+5 -5
View File
@@ -1,4 +1,4 @@
"""lora2r invariants (rank-2r PiSSA-init LoRA + SGTM-style block masks).
"""lora2r invariants (rank-2r Gaussian-init LoRA + SGTM-style block masks).
Asserts, on tiny-random-qwen3 (CPU, fp32):
1. IDENTITY AT INIT: wrapped logits == base logits (the hook subtracts the
@@ -17,7 +17,7 @@ Exit nonzero on any violation. Wired into `just smoke-lora2r`.
import torch
from transformers import AutoModelForCausalLM
from vgrout.antipasto import wrap_model_with_lora2r
from vgrout.lora2r import wrap_model_with_lora2r
from vgrout.eval import ablate_quarantine
MODEL = "llamafactory/tiny-random-qwen3"
@@ -31,7 +31,7 @@ ids = torch.randint(100, 1000, (2, 12))
with torch.no_grad():
base_logits = model(ids).logits.clone()
wrappers = wrap_model_with_lora2r(model, MODEL, svd_device="cpu", r=R, grad_probe=True)
wrappers = wrap_model_with_lora2r(model, r=R, grad_probe=True)
# 1. identity at init
with torch.no_grad():
@@ -52,7 +52,7 @@ def run_masked(m_val: float, d_val: float) -> tuple[float, float]:
dep_sq = quar_sq = 0.0
for info in wrappers.values():
r = info["r"]
gA, gB = info["delta_S"].grad, info["B"].grad
gA, gB = info["A"].grad, info["B"].grad
dep_sq += gA[:r].pow(2).sum().item() + gB[:, :r].pow(2).sum().item()
quar_sq += gA[r:].pow(2).sum().item() + gB[:, r:].pow(2).sum().item()
return dep_sq ** 0.5, quar_sq ** 0.5
@@ -96,7 +96,7 @@ with torch.no_grad():
assert torch.allclose(out_abl_init, out0, atol=1e-6), "ablate at init is not a no-op"
for info in wrappers.values():
r = info["r"]
info["delta_S"].data[r:] += 0.05 * torch.randn_like(info["delta_S"].data[r:])
info["A"].data[r:] += 0.05 * torch.randn_like(info["A"].data[r:])
out_pert = model(ids).logits.clone()
pert = (out_pert - out0).abs().max().item()
assert pert > 1e-6, f"quarantine perturbation invisible in forward ({pert:.2e})"