From 6094568c5643e901b3ca3a3eb4a39a1664cd8c0d Mon Sep 17 00:00:00 2001 From: wassname <1103714+wassname@users.noreply.github.com> Date: Wed, 10 Jun 2026 09:25:58 +0000 Subject: [PATCH] feat: lora2r adapter (rank-2r PiSSA-init LoRA) + SGTM three-way hard routing Structural-separation arm to disentangle directionality from shrinkage. A rank-2r PiSSA-init LoRA with A and B both trainable, partitioned into a deployed block [:r] and a quarantine block [r:] (spectrum-matched via alternated SVD axes). Unlike the same-basis PiSSA routeV (where deploy-ablation only removes a magnitude slice of one shared update = shrinkage null), each block has its own input-side A rows and output-side B columns, so deploy-ablation removes a different FUNCTION. Routing = SGTM-style three-way hard per-rollout masks from the cosine of the deployed block's gate-pass gradient to the pair-extracted v_grad: clean (m=0,d=0) trains deployed only; hack (m=1,d=1) detaches deployed output so only the quarantine updates (SGTM grad-retain trick); mid (m=1,d=0) trains both (absorption). Gate is no-cheat: cos to the hand-authored-pair direction, never an oracle label of a live rollout. verify_lora2r_routing.py gates identity-at-init, the three-way block-grad routing, per-rollout c-probe recovery, and ablation teeth; wired into smoke-lora2r. Additive: PiSSA / lora_frozen_b paths untouched. 
Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com> --- justfile | 24 ++++ scripts/verify_lora2r_routing.py | 110 ++++++++++++++++++ src/vgrout/antipasto.py | 113 ++++++++++++++++++ src/vgrout/eval.py | 27 ++++- src/vgrout/extract_vhack_grad.py | 10 +- src/vgrout/tablelog.py | 2 +- src/vgrout/train.py | 194 +++++++++++++++++++++++++++---- src/vgrout/train_config.py | 28 ++++- 8 files changed, 472 insertions(+), 36 deletions(-) create mode 100644 scripts/verify_lora2r_routing.py diff --git a/justfile b/justfile index 4c4ef16..11b8424 100644 --- a/justfile +++ b/justfile @@ -78,6 +78,17 @@ smoke-unhackable *ARGS: --teacher-pool-dir=out/pools/teacher_pool --mix-ratio=0.5 \ --eval-n-prompts=2 {{ ARGS }} +# lora2r path: rank-2r PiSSA-init LoRA (A+B trainable) + SGTM-style three-way HARD +# masks (clean->deployed-only, hack->quarantine-only via output detach, mid->both). +# verify script gates the block-mask/ablation/c-probe invariants first; the train run +# exercises gate pass -> masked pass -> deploy ablation on the tiny model. +smoke-lora2r *ARGS: + uv run python scripts/verify_lora2r_routing.py + BEARTYPE=1 {{ TRAIN }} smoke --adapter=lora2r --lora-r=4 --weight-decay=0 \ + --intervention=routeV \ + --teacher-pool-dir=out/pools/teacher_pool --mix-ratio=0.5 \ + --eval-ablate-every=10 --eval-n-prompts=2 {{ ARGS }} + # Run smoke twice: first warms the v_hack cache (cache-miss path), second hits # the cache (cache-hit path). Catches scope/save bugs that only manifest in one. 
smoke-both: @@ -234,6 +245,19 @@ queue-unhackable seed='43' steps='200': pueue add -w "$PWD" -o 8 -l "why: REALISM vanilla unhackable_frac=0.1 {{steps}}st s{{seed}}; resolve: solve climbs vs frac=0 vanilla (persistent solve pressure exists)" -- {{ TRAIN }} fast --steps={{steps}} --intervention=none --seed={{seed}} --out-tag=_unh1_vanilla_s{{seed}} pueue add -w "$PWD" -o 7 -l "why: REALISM routeV per-token unhackable_frac=0.1 {{steps}}st s{{seed}}; resolve: solve_uplift over vanilla LARGER than at frac=0 (routeV reveals the warm solve-skill once hack is ablated)" -- {{ TRAIN }} fast --steps={{steps}} --intervention=routeV_per_token --seed={{seed}} --out-tag=_unh1_routeV_pertoken_s{{seed}} +# H: lora2r directionality. The PiSSA placebo tie (job 86) was SHRINKAGE: deployed and +# quarantine share the frozen U/Vh basis -> identical per-step grads -> routing = magnitude +# split, any direction "works". lora2r gives each block its OWN params +# (input-side A rows + output-side B cols, PiSSA-init, all trainable) + SGTM three-way hard masks, so a +# discriminating gate can produce real separation. Arms: real-v, placebo-v (Haar), +# vanilla control (gate pinned clean = capacity/structure-matched, no shrinkage confound). +# resolve: directionality real iff real-v deploy_hack << placebo-v at matched solve; +# both ~vanilla -> the gate (not the adapter) is the bottleneck. 
+queue-lora2r seed='43': + pueue add -w "$PWD" -o 26 -l "why: lora2r routeV real-v s{{seed}} (SGTM 3-way masks, structural separation); resolve: deploy_hack << placebo-v at matched solve -> directionality real" -- {{ TRAIN }} fast-lora2r --intervention=routeV --seed={{seed}} --out-tag=_l2r_routeV_s{{seed}} + pueue add -w "$PWD" -o 25 -l "why: lora2r routeV PLACEBO-v (Haar) s{{seed}}; resolve: deploy_hack ~ vanilla-lora2r -> real-v suppression is directional, not absorption/shrinkage" -- {{ TRAIN }} fast-lora2r --intervention=routeV --routeV-random-v-seed=157 --seed={{seed}} --out-tag=_l2r_routeV_placebo_s{{seed}} + pueue add -w "$PWD" -o 24 -l "why: lora2r VANILLA control s{{seed}} (gate pinned clean, capacity-matched); resolve: deploy_hack >> 0 emergence reference on the identical adapter" -- {{ TRAIN }} fast-lora2r --intervention=none --seed={{seed}} --out-tag=_l2r_vanilla_s{{seed}} + queue-broad: #!/usr/bin/env bash set -eu diff --git a/scripts/verify_lora2r_routing.py b/scripts/verify_lora2r_routing.py new file mode 100644 index 0000000..ededd36 --- /dev/null +++ b/scripts/verify_lora2r_routing.py @@ -0,0 +1,110 @@ +"""lora2r invariants (rank-2r PiSSA-init LoRA + SGTM-style block masks). + +Asserts, on tiny-random-qwen3 (CPU, fp32): + 1. IDENTITY AT INIT: wrapped logits == base logits (the hook subtracts the + frozen A0/B0 init contribution, so net delta is exactly 0). + 2. MASK ROUTING (block grads under each three-way gate label): + clean (m=0,d=0): deployed-block grads nonzero, quarantine-block ZERO + hack (m=1,d=1): deployed-block ZERO (output detach), quarantine nonzero + mid (m=1,d=0): both nonzero (absorption) + 3. C-PROBE PER-ROLLOUT RECOVERY: batched c.grad rows == single-rollout c.grad + (the gate's per-rollout weight grads are exact, not an approximation). + 4. ABLATION TEETH: ablate_quarantine is a no-op at init, removes a quarantine + perturbation while active, and restores it on exit. + +Exit nonzero on any violation. 
Wired into `just smoke-lora2r`. +""" +import torch +from transformers import AutoModelForCausalLM + +from vgrout.antipasto import wrap_model_with_lora2r +from vgrout.eval import ablate_quarantine + +MODEL = "llamafactory/tiny-random-qwen3" +R = 4 # tiny model min Linear dim is 16, so 2r=8 fits everywhere + +torch.manual_seed(0) +model = AutoModelForCausalLM.from_pretrained(MODEL, dtype=torch.float32) +model.eval() +ids = torch.randint(100, 1000, (2, 12)) + +with torch.no_grad(): + base_logits = model(ids).logits.clone() + +wrappers = wrap_model_with_lora2r(model, MODEL, svd_device="cpu", r=R, grad_probe=True) + +# 1. identity at init +with torch.no_grad(): + err = (model(ids).logits - base_logits).abs().max().item() +assert err < 1e-5, f"init not identity: max|dlogits|={err:.2e}" +print(f"1. identity at init OK (max|dlogits|={err:.2e})") + + +# 2. mask routing +def run_masked(m_val: float, d_val: float) -> tuple[float, float]: + model.zero_grad(set_to_none=True) + g_vec = torch.full((ids.shape[0],), m_val), torch.full((ids.shape[0],), d_val) + for info in wrappers.values(): + info["layer"]._lora2r_mask = g_vec + model(ids).logits.float().pow(2).mean().backward() + for info in wrappers.values(): + info["layer"]._lora2r_mask = None + dep_sq = quar_sq = 0.0 + for info in wrappers.values(): + r = info["r"] + gA, gB = info["delta_S"].grad, info["B"].grad + dep_sq += gA[:r].pow(2).sum().item() + gB[:, :r].pow(2).sum().item() + quar_sq += gA[r:].pow(2).sum().item() + gB[:, r:].pow(2).sum().item() + return dep_sq ** 0.5, quar_sq ** 0.5 + + +dep_n, quar_n = run_masked(0.0, 0.0) # clean +assert dep_n > 1e-8 and quar_n < 1e-12, f"clean gate: dep={dep_n:.2e} quar={quar_n:.2e}" +print(f"2a. clean (m=0,d=0): dep grad {dep_n:.2e} > 0, quar grad {quar_n:.2e} == 0 OK") +dep_n, quar_n = run_masked(1.0, 1.0) # hack +assert dep_n < 1e-12 and quar_n > 1e-8, f"hack gate: dep={dep_n:.2e} quar={quar_n:.2e}" +print(f"2b. 
hack (m=1,d=1): dep grad {dep_n:.2e} == 0, quar grad {quar_n:.2e} > 0 OK") +dep_n, quar_n = run_masked(1.0, 0.0) # mid +assert dep_n > 1e-8 and quar_n > 1e-8, f"mid gate: dep={dep_n:.2e} quar={quar_n:.2e}" +print(f"2c. mid (m=1,d=0): dep grad {dep_n:.2e} > 0, quar grad {quar_n:.2e} > 0 OK") +model.zero_grad(set_to_none=True) + + +# 3. per-rollout c-probe recovery +def gate_grads(batch_ids: torch.Tensor) -> list[torch.Tensor]: + loss = model(batch_ids).logits.float().pow(2).sum() # sum -> per-sequence-additive + gates = [info["layer"]._lora2r_gate for info in wrappers.values()] + return [g.detach().clone() for g in torch.autograd.grad(loss, gates)] + + +both = gate_grads(ids) +solo0 = gate_grads(ids[:1]) +solo1 = gate_grads(ids[1:]) +for name, gb, g0, g1 in zip(wrappers, both, solo0, solo1, strict=True): + gb2 = gb.reshape(2, -1, gb.shape[-1]).sum(1) # [2, 2r] per-rollout + g0r = g0.reshape(1, -1, g0.shape[-1]).sum(1)[0] + g1r = g1.reshape(1, -1, g1.shape[-1]).sum(1)[0] + assert torch.allclose(gb2[0], g0r, atol=1e-5, rtol=1e-4), f"{name}: rollout 0 c.grad mismatch" + assert torch.allclose(gb2[1], g1r, atol=1e-5, rtol=1e-4), f"{name}: rollout 1 c.grad mismatch" +print(f"3. c-probe per-rollout recovery OK ({len(both)} modules, batched == solo)") + +# 4. 
ablation teeth +with torch.no_grad(): + out0 = model(ids).logits.clone() + with ablate_quarantine(wrappers): + out_abl_init = model(ids).logits + assert torch.allclose(out_abl_init, out0, atol=1e-6), "ablate at init is not a no-op" + for info in wrappers.values(): + r = info["r"] + info["delta_S"].data[r:] += 0.05 * torch.randn_like(info["delta_S"].data[r:]) + out_pert = model(ids).logits.clone() + pert = (out_pert - out0).abs().max().item() + assert pert > 1e-6, f"quarantine perturbation invisible in forward ({pert:.2e})" + with ablate_quarantine(wrappers): + out_abl = model(ids).logits + assert torch.allclose(out_abl, out0, atol=1e-5), "ablation does not remove the quarantine delta" + out_back = model(ids).logits + assert torch.allclose(out_back, out_pert, atol=1e-6), "ablate context did not restore state" +print(f"4. ablation teeth OK (perturbation {pert:.2e} visible, removed under ablate, restored after)") + +print("verify_lora2r_routing: ALL OK") diff --git a/src/vgrout/antipasto.py b/src/vgrout/antipasto.py index 2800373..d800c7e 100644 --- a/src/vgrout/antipasto.py +++ b/src/vgrout/antipasto.py @@ -180,6 +180,119 @@ def wrap_model_with_lora_frozen_b( return out +def _lora2r_hook(layer: nn.Linear, args: tuple, y: Tensor) -> Tensor: + """Rank-2r PiSSA-init LoRA, two blocks: deployed [:r] + quarantine [r:]. + + y += B@(A@x) - B0@(A0@x) + + A0/B0 are FROZEN copies of the PiSSA init, so the net delta is exactly 0 at + init while h = A@x is alive. (A zero-init would kill the c-probe weight-grad + space below AND pair extraction at step 0 -- the whole reason for PiSSA init.) + [B|B_q] @ ([A;A_q]@x) has no cross terms (column b_k only ever multiplies row + a_k), so the two blocks ARE two independent adapters; block masks on this one + tensor implement the SGTM parameter partition. 
+ + Block masks (layer._lora2r_mask = (m, d), set by train.py per loss pass; + None = unmasked for generation / gate pass / eval): + m [G] quarantine on/off -- m=0: quarantine zero in forward AND backward + (SGTM retain trick: deployed trains in its post-ablation state) + d [G] deployed detach -- d=1: deployed kept in forward, zero grad + (hack-gated rollouts update ONLY the quarantine block) + Masks act on branch OUTPUTS so a detach blocks grads to BOTH the A rows and + the B columns of that block. + + grad probe: c = ones[..., 2r] spliced as h*c. After backward + c.grad = h ⊙ (Bᵀδ_y) = the per-sample WEIGHT grad of a virtual diagonal + scale between A and B -- the lora2r analog of delta_S.grad (coincides with + the SVD delta_S space at init, so pair extraction ports unchanged). + """ + (x,) = args + A = layer._lora2r_A # [2r, d_in] trainable + B = layer._lora2r_B # [d_out, 2r] trainable + A0 = layer._lora2r_A0 # frozen PiSSA init copies (subtracted: net delta 0 at init) + B0 = layer._lora2r_B0 + r = layer._lora2r_r + h = torch.nn.functional.linear(x, A.to(x.dtype)) # [..., 2r] + if layer._lora2r_grad_probe and torch.is_grad_enabled(): + c = torch.ones(h.shape[0], *([1] * (h.dim() - 2)), h.shape[-1], + device=h.device, dtype=h.dtype, requires_grad=True) + layer._lora2r_gate = c + h = h * c + h0 = torch.nn.functional.linear(x, A0.to(x.dtype)) # [..., 2r] frozen init path + dep = (torch.nn.functional.linear(h[..., :r], B[:, :r].to(x.dtype)) + - torch.nn.functional.linear(h0[..., :r], B0[:, :r].to(x.dtype))) + quar = (torch.nn.functional.linear(h[..., r:], B[:, r:].to(x.dtype)) + - torch.nn.functional.linear(h0[..., r:], B0[:, r:].to(x.dtype))) + if layer._lora2r_mask is not None: + m, d = layer._lora2r_mask # [G] each + G = m.shape[0] + shape = dep.shape # [G, s, d_out] or [G*s, d_out] + dep = dep.reshape(G, -1, shape[-1]) + quar = quar.reshape(G, -1, shape[-1]) + d_ = d.view(G, 1, 1).to(dep.dtype) + dep = ((1 - d_) * dep + d_ * dep.detach()).reshape(shape) + quar = 
(m.view(G, 1, 1).to(quar.dtype) * quar).reshape(shape) + return y + (dep + quar).to(y.dtype) + + +def wrap_model_with_lora2r( + model: nn.Module, + model_name: str, + cache_root: Path = Path("svd_cache"), + svd_device: torch.device | str = "cuda", + r: int = 32, + grad_probe: bool = False, +) -> dict[str, dict]: + """Attach a rank-2r PiSSA-init LoRA (A AND B trainable) to every target Linear. + + PiSSA init: A0 = sqrt(S)·Vh, B0 = U·sqrt(S) on the top-2r SVD axes of W, + ALTERNATED between the blocks (deployed even axes, quarantine odd) so the two + blocks are spectrum-matched. W stays untouched; the hook subtracts the frozen + A0/B0 contribution (unlike PiSSA proper, which edits W). The quarantine's + learned delta is (A[r:], B[:, r:]) minus init; deploy ablation resets that + block to A0/B0 (eval.ablate_quarantine). + + Info dict per module: {layer, delta_S=A, B, A0, B0, handle, r} -- no separate + delta_S_hack tensor; quarantine = block slices. Consumers branch on "A0". + """ + svd_device_t = torch.device(svd_device) if isinstance(svd_device, str) else svd_device + svd_dir = cache_root / model_name.replace("/", "__") + targets = [(n, m) for n, m in model.named_modules() + if isinstance(m, nn.Linear) and is_target(n)] + logger.info(f"lora2r attach: {len(targets)} target Linear modules, " + f"r={r}/block (2r={2 * r}), PiSSA init, A+B trainable") + out: dict[str, dict] = {} + for name, linear in targets: + W = linear.weight.data + d_out, d_in = W.shape + assert 2 * r <= min(d_out, d_in), \ + f"{name}: 2r={2 * r} exceeds min(d_out,d_in)={min(d_out, d_in)}; lower --lora-r" + U, S, Vh = svd_cached(W, svd_dir / f"{name}.pt", device=svd_device_t) + # Alternate the top-2r axes: deployed gets even ranks, quarantine odd. 
+ order = torch.cat([torch.arange(0, 2 * r, 2), torch.arange(1, 2 * r, 2)]) + sqrtS = S[:2 * r].sqrt()[order] + dev = W.device + A0 = (sqrtS.unsqueeze(1) * Vh[:2 * r][order]).to(device=dev, dtype=torch.float32) # [2r, d_in] + B0 = (U[:, :2 * r][:, order] * sqrtS).to(device=dev, dtype=torch.float32) # [d_out, 2r] + linear.register_buffer("_lora2r_A0", A0, persistent=True) + linear.register_buffer("_lora2r_B0", B0, persistent=True) + A = nn.Parameter(A0.clone()) + B = nn.Parameter(B0.clone()) + linear.register_parameter("_lora2r_A", A) + linear.register_parameter("_lora2r_B", B) + linear._lora2r_r = r + linear._lora2r_grad_probe = grad_probe + linear._lora2r_gate = None + linear._lora2r_mask = None + out[name] = {"layer": linear, "delta_S": A, "B": B, "A0": A0, "B0": B0, + "handle": linear.register_forward_hook(_lora2r_hook), "r": r} + trainable = ("_lora2r_A", "_lora2r_B") + for n, p in model.named_parameters(): + if not n.endswith(trainable): + p.requires_grad_(False) + return out + + def wrap_model_with_antipasto( model: nn.Module, model_name: str, diff --git a/src/vgrout/eval.py b/src/vgrout/eval.py index 8433d41..441c00f 100644 --- a/src/vgrout/eval.py +++ b/src/vgrout/eval.py @@ -99,21 +99,38 @@ def ref_logprobs_via_zero_delta( @contextmanager def ablate_quarantine(wrappers: dict): - """Temporarily zero the routeV quarantine to evaluate the deployed model. + """Temporarily remove the quarantine to evaluate the deployed model. + + delta_S adapters: zero delta_S_hack. lora2r ("A0" in info): reset the + quarantine block (A[r:], B[:,r:]) to the frozen PiSSA init A0/B0 so the net + quarantine delta is 0 -- zeroing the raw params would instead SUBTRACT the + init contribution and corrupt the forward. TODO(post-deploy-finetune): SGTM's ablate(trainable=True) reinits the forget weights to the retain-dims' std instead of zeroing, so the model stays finetunable after the quarantine is removed (no dead hole). 
We zero because we only eval after deploy; add the reinit path if we ever retrain post-ablate. See docs/grad_routing/sgtm_vs_ours.md.""" - saved = {n: info["delta_S_hack"].data.clone() for n, info in wrappers.items()} - for info in wrappers.values(): - info["delta_S_hack"].data.zero_() + saved: dict[str, object] = {} + for n, info in wrappers.items(): + if "A0" in info: + r = info["r"] + saved[n] = (info["delta_S"].data[r:].clone(), info["B"].data[:, r:].clone()) + info["delta_S"].data[r:] = info["A0"][r:] + info["B"].data[:, r:] = info["B0"][:, r:] + else: + saved[n] = info["delta_S_hack"].data.clone() + info["delta_S_hack"].data.zero_() try: yield finally: for n, info in wrappers.items(): - info["delta_S_hack"].data.copy_(saved[n]) + if "A0" in info: + r = info["r"] + info["delta_S"].data[r:] = saved[n][0] + info["B"].data[:, r:] = saved[n][1] + else: + info["delta_S_hack"].data.copy_(saved[n]) @torch.no_grad() diff --git a/src/vgrout/extract_vhack_grad.py b/src/vgrout/extract_vhack_grad.py index e9dc9b1..0a45f13 100644 --- a/src/vgrout/extract_vhack_grad.py +++ b/src/vgrout/extract_vhack_grad.py @@ -139,7 +139,15 @@ def extract_v_hack( bucket = grads_hack if label == "hack" else grads_clean for name, info in wrappers.items(): layer = info["layer"] - if getattr(layer, "_lora_grad_probe", False) and layer._lora_h is not None: + if getattr(layer, "_lora2r_grad_probe", False): + # lora2r: per-pair weight grad of the virtual diagonal (c-probe), + # DEPLOYED block only -- the same space the live gate reads + # (train.py lora2r branch), so band calibration is apples-to-apples. + cg = layer._lora2r_gate.grad + if cg is None: + raise RuntimeError(f"no c-probe grad on {name}; aborting lora2r extract") + g = cg.sum(dim=tuple(range(cg.dim() - 1)))[: layer._lora2r_r] # [r] + elif getattr(layer, "_lora_grad_probe", False) and layer._lora_h is not None: # LoRA-frozen-B: the routing handle is the r-bottleneck gradient # g_h = B^T δ_y (B frozen -> static path), not A.grad. 
Sum over (batch, # tokens) to mirror how AntiPaSTO's δS.grad accumulates over positions. diff --git a/src/vgrout/tablelog.py b/src/vgrout/tablelog.py index 2e23417..87bc761 100644 --- a/src/vgrout/tablelog.py +++ b/src/vgrout/tablelog.py @@ -74,7 +74,7 @@ class StepLogger: show_ablate: bool = False) -> None: # Erase reports projection diagnostics; routeV reports routing diagnostics below. projects = arm == "projected" - is_route = arm in ("routingV", "routingV_per_token") + is_route = arm in ("routingV", "routingV_per_token", "routingV_lora2r") cols: list[_Col] = [ _Col("step", 4, "step", "d", "GRPO step"), _Col("ref_eq", 6, "ref_eq", ".2f", "vanilla-equiv step (cum_gens/256)"), diff --git a/src/vgrout/train.py b/src/vgrout/train.py index 8880670..b5e9f91 100644 --- a/src/vgrout/train.py +++ b/src/vgrout/train.py @@ -51,7 +51,7 @@ from tabulate import tabulate from tqdm import tqdm from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig -from .antipasto import wrap_model_with_antipasto, wrap_model_with_lora_frozen_b +from .antipasto import wrap_model_with_antipasto, wrap_model_with_lora2r, wrap_model_with_lora_frozen_b from .proj import per_token_logps, project_delta_S_grad, mean_cos_pre_from_grads from .rewards import EnvMode, compute_reward from .data import DATA, load_problems @@ -59,7 +59,8 @@ from .vhack import load_v_hack, pairset_sha256, postprocess_v_hack from .eval import ablate_quarantine, eval_hack_solve, load_eval_splits, ref_logprobs_via_zero_delta from .tablelog import setup_logging, StepLogger from .run_artifacts import RUN_SCHEMA -from .train_config import Config, FastConfig, FastLoraConfig, FullConfig, SmokeConfig +from .train_config import (Config, FastConfig, FastLora2rConfig, FastLoraConfig, + FullConfig, SmokeConfig) CACHE_ROOT = Path("svd_cache") OUT_DIR = Path("out") @@ -188,6 +189,16 @@ def _validate_config(cfg: Config) -> None: raise ValueError(f"--v-hack-path is an erase-arm option; ignored on 
intervention={cfg.intervention}") if cfg.adapter == "lora_frozen_b" and cfg.intervention not in ("none", "routeV", "routeV_per_token"): raise ValueError(f"lora_frozen_b adapter not wired for intervention={cfg.intervention}") + if cfg.adapter == "lora2r": + if cfg.intervention not in ("none", "routeV"): + raise ValueError(f"lora2r supports intervention none|routeV, got {cfg.intervention}") + if cfg.beta: + raise ValueError("lora2r has no zero-delta reference path (A=0 is NOT identity); beta must be 0") + if cfg.weight_decay != 0.0: + raise ValueError("lora2r params are PiSSA-init (nonzero); AdamW decay pulls them toward 0, " + "not toward init -- set --weight-decay=0") + if cfg.routeV_gate != "grad_cosine" or cfg.routeV_top_k > 1 or cfg.routeV_absorb_all: + raise ValueError("lora2r implements only the per-rollout grad_cosine three-way gate") def _resolve_v_hack_file(cfg: Config) -> Path: @@ -257,23 +268,38 @@ def main(cfg: Config) -> int: # Generation enables KV cache; loss forwards disable it to avoid unused state. model.config.use_cache = False - # ── adapter: δS (kept) + δS_hack (quarantine). antipasto=diagonal[r]; lora_frozen_b=A[r,d_in] ── + # ── adapter: δS (kept) + δS_hack (quarantine). 
antipasto=diagonal[r]; + # lora_frozen_b=A[r,d_in]; lora2r=rank-2r PiSSA LoRA, quarantine = block slices ── is_routeV = cfg.intervention in ("routeV", "routeV_per_token") is_per_token = cfg.intervention == "routeV_per_token" is_lora = cfg.adapter == "lora_frozen_b" # arm/adapter compatibility checked in _validate_config + is_lora2r = cfg.adapter == "lora2r" if is_lora: wrappers = wrap_model_with_lora_frozen_b( model, model_name, r=cfg.lora_r, b_seed=cfg.lora_b_seed, grad_probe=is_routeV) + elif is_lora2r: + wrappers = wrap_model_with_lora2r( + model, model_name, CACHE_ROOT, device, r=cfg.lora_r, grad_probe=is_routeV) else: wrappers = wrap_model_with_antipasto( model, model_name, CACHE_ROOT, device, grad_probe=is_routeV, # routeV needs the per-rollout δS gate probe ) - # δS_hack receives gradients only under routeV and is removed at deployment. - delta_params = [info["delta_S"] for info in wrappers.values()] - delta_hack_params = [info["delta_S_hack"] for info in wrappers.values()] - logger.info(f"trainable delta_S: {sum(p.numel() for p in delta_params):,} " - f"(+{sum(p.numel() for p in delta_hack_params):,} delta_S_hack quarantine)") + if is_lora2r: + # A and B both train; quarantine = block slices of the same tensors, so + # there is no separate hack-param list (masks route grads, not surgery). + delta_params = [p for info in wrappers.values() for p in (info["delta_S"], info["B"])] + delta_hack_params = [] + n_quar = sum(info["delta_S"][info["r"]:].numel() + info["B"][:, info["r"]:].numel() + for info in wrappers.values()) + logger.info(f"trainable lora2r A+B: {sum(p.numel() for p in delta_params):,} " + f"({n_quar:,} of those in quarantine blocks)") + else: + # δS_hack receives gradients only under routeV and is removed at deployment. 
+ delta_params = [info["delta_S"] for info in wrappers.values()] + delta_hack_params = [info["delta_S_hack"] for info in wrappers.values()] + logger.info(f"trainable delta_S: {sum(p.numel() for p in delta_params):,} " + f"(+{sum(p.numel() for p in delta_hack_params):,} delta_S_hack quarantine)") # ── hack direction: v_hack (erase) or v_grad (routeV) ── # Vanilla is pure GRPO; erase uses v_hack; routeV uses v_grad. @@ -319,6 +345,14 @@ def main(cfg: Config) -> int: assert _mean_bw > 0, ( f"real v_grad gave non-positive mean band width {_mean_bw:+.3f}: " "hack pairs do not separate from clean -> extraction broken") + if is_lora2r: + logger.info( + "lora2r three-way gate (SGTM-style): per-rollout label from the mean " + "band-normalized cosine across modules; clean->deployed-only, " + "hack->quarantine-only (deployed detached), mid->both (absorption). " + "SHOULD: rout (hack share) tracks the step's rollout hack rate, not ~50%; " + "clipfrac on clean-gated rollouts < ~0.2 ELSE the retain-trick ratio " + "drift is binding (quarantine forward too large).") # top-k subspace gate: oriented top-k right singular vectors of the per-pair # diff D=[n_pairs, r], each re-oriented hack-ward by sign(v_i . mean_diff), with # a max-over-k cosine band from the same pairs. Only the per-rollout grad_cosine @@ -597,15 +631,28 @@ def main(cfg: Config) -> int: pr = sum(r["gt_s"][0] + r["gt_t"][0] for r in rows) / max(1, n_gens) # Save the deployed adapter separately so it can be evaluated without quarantine state. _ckpt = path or ckpt_path - tensors = {n: info["delta_S"].detach().cpu().contiguous() - for n, info in wrappers.items()} + if is_lora2r: + # Deployed slices -> main ckpt; quarantine slices -> _hack file. + # A0/B0 are derivable (svd_cached on W), so only trained slices are stored. 
+ tensors, hack_tensors = {}, {} + for n, info in wrappers.items(): + r_blk = info["r"] + A_cpu = info["delta_S"].detach().cpu() + B_cpu = info["B"].detach().cpu() + tensors[f"A/{n}"] = A_cpu[:r_blk].contiguous() + tensors[f"B/{n}"] = B_cpu[:, :r_blk].contiguous() + hack_tensors[f"A/{n}"] = A_cpu[r_blk:].contiguous() + hack_tensors[f"B/{n}"] = B_cpu[:, r_blk:].contiguous() + else: + tensors = {n: info["delta_S"].detach().cpu().contiguous() + for n, info in wrappers.items()} + hack_tensors = {n: info["delta_S_hack"].detach().cpu().contiguous() + for n, info in wrappers.items()} save_file(tensors, str(_ckpt), metadata={ "model": model_name, "dtype": "bf16", "step": str(len(rows)), "hack_rate": f"{hr:.6f}", "pass_rate": f"{pr:.6f}", "rows": json.dumps(rows), "cfg": json.dumps(vars(cfg), default=str), }) - hack_tensors = {n: info["delta_S_hack"].detach().cpu().contiguous() - for n, info in wrappers.items()} save_file(hack_tensors, str(_ckpt.with_name(_ckpt.stem + "_hack.safetensors")), metadata={"model": model_name, "step": str(len(rows))}) @@ -651,6 +698,7 @@ def main(cfg: Config) -> int: # Near-zero δS axes cannot recover per-rollout gradients, so routing lags one update there. GATE_EPS = 1e-6 step_flagged: list[float] = [] + step_clipfrac: list[float] = [] # lora2r: PPO clip frac on clean-gated rollouts (retain-trick drift gauge) step_zkeep: list[float] = []; step_zresid: list[float] = []; step_zrout: list[float] = [] # unit shares per zone step_zkeepE: list[float] = []; step_zresidE: list[float] = []; step_zroutE: list[float] = [] # energy shares per zone step_resid: list[float] = [] # cos(δS.grad AFTER routing, v_grad): hack-ward leak into deployed adapter @@ -791,6 +839,31 @@ def main(cfg: Config) -> int: step_resid.append((g_keep_roll @ vg / g_keep_roll.norm().clamp_min(1e-12)).item()) return g_keep + def _lora2r_gate_labels(c_grads: tuple, n_rollouts: int): + """Three-way SGTM-style label per rollout from the gate-pass c-probe grads. 
+ + Per module: g_b = per-rollout weight grad of the virtual diagonal, deployed + block [r]; band-normalized cosine position p = (cos(g_b, v_grad)-lower)/width. + One GLOBAL label per rollout (mean p across modules, matching SGTM's + example-level labels): p<=0 clean (m=0,d=0); p>=1 hack (m=1,d=1); else mid + (m=1,d=0, absorption). Returns (m, d, f3, w): f3 in {0,.5,1} for _zone_stats, + w = mean per-rollout grad norm for energy weighting.""" + pos = torch.zeros(n_rollouts, device=device) + w = torch.zeros(n_rollouts, device=device) + for (name, info), cg in zip(wrappers.items(), c_grads, strict=True): + r_blk = info["r"] + g_b = cg.reshape(n_rollouts, -1, 2 * r_blk).sum(1)[:, :r_blk].float() # [G, r] deployed block + nrm = g_b.norm(dim=1) + cos_b = (g_b @ v_grad[name]) / nrm.clamp_min(1e-12) # [G] + lower, upper = route_band[name] + pos += (cos_b - lower) / max(upper - lower, 1e-6) + w += nrm + pos /= len(wrappers) + w /= len(wrappers) + m = (pos > 0).float() # mid + hack -> quarantine trains + d = (pos >= 1).float() # hack -> deployed detached + return m, d, 0.5 * m + 0.5 * d, w + def _act_vote_f_roll(n_rollouts: int, plen: int, comp_mask: torch.Tensor) -> torch.Tensor: """Global per-rollout routing fraction from the activation vote (act_vote gate). For each module: As_b = completion-mean(Vh@x) [G, r]; cos(As_b, As_dir); aggregate @@ -826,7 +899,9 @@ def main(cfg: Config) -> int: # routeV has no v_hack so cos_pre is NaN regardless: force the single combined # backward (the split would just double cost). The grad-mask reads its # per-rollout gate from that one backward. - split_this_step = (step % cfg.cos_pre_split_every == 0) and not is_routeV + # lora2r never splits: its grads live on A+B and accumulate in .grad directly + # (the split harvest only carries delta_S and would drop B). + split_this_step = (step % cfg.cos_pre_split_every == 0) and not is_routeV and not is_lora2r # Phase timers (per-step cumulative, seconds). 
Each GPU phase ends in a # CPU-blocking op (decode / .item()), so perf_counter is sync-accurate # without explicit cuda.synchronize. Tells us whether wall-time is @@ -1072,6 +1147,14 @@ def main(cfg: Config) -> int: if beta and beta > 0: logπ_ref = ref_logprobs_via_zero_delta(model, merged, wrappers, plen).detach() + # lora2r vanilla control: gate pinned clean (m=0, d=0) for the loss pass, + # so the quarantine block never trains -- the capacity/structure-matched + # deployed-only baseline, one code path with routeV. + if is_lora2r and not is_routeV: + _z = torch.zeros(merged.shape[0], device=device) + for info in wrappers.values(): + info["layer"]._lora2r_mask = (_z, _z) + logπ = per_token_logps( model(merged, logits_to_keep=L_c + 1).logits[:, :-1], completion_ids, @@ -1091,6 +1174,13 @@ def main(cfg: Config) -> int: if logπ_ref is not None: # K3 KL estimator Lp = Lp + beta * (torch.exp(logπ_ref - logπ) - (logπ_ref - logπ) - 1.0) + def _grpo_loss(Lp_: torch.Tensor) -> torch.Tensor: + """Full-batch GRPO loss (Dr.GRPO unbiased or per-rollout-normalized).""" + if cfg.unbiased: + return (Lp_ * mask).sum() / (group * max_new * prompts_per_step) + ptl = (Lp_ * mask).sum(1) / mask.sum(1).clamp_min(1) + return ptl.sum() / (group * prompts_per_step) + # Per-source split (loss_s + loss_t == full-batch loss because # is_s_v + is_t_v = 1 elementwise; backward is linear so # grad_s + grad_t == full-batch grad). Two backwards every step is @@ -1098,7 +1188,44 @@ def main(cfg: Config) -> int: is_s_v = torch.tensor(is_student, dtype=Lp.dtype, device=Lp.device).unsqueeze(1) # [G, 1] is_t_v = 1.0 - is_s_v - if split_this_step: + if is_lora2r: + # ── lora2r: SGTM-style three-way hard masking; grads ACCUMULATE on A/B ── + # Gradient-space labels exist only AFTER a backward (labels: before + # forward; activations: before backward; grads: after), so routeV pays a + # second masked forward+backward. 
intervention=none was pinned clean + # before the logπ forward and needs only this one pass. + loss = _grpo_loss(Lp) + if is_routeV: + # PASS 1 (gate): grads w.r.t. the c-probes ONLY. autograd.grad leaves + # A.grad/B.grad untouched, so nothing to zero between passes. + gates = [info["layer"]._lora2r_gate for info in wrappers.values()] + c_grads = torch.autograd.grad(loss, gates) + m_vec, d_vec, f3, w3 = _lora2r_gate_labels(c_grads, merged.shape[0]) + step_flagged.append(m_vec.mean().item()) + _kn, _rn, _on, _ke, _re, _oe = _zone_stats(f3, w3) + step_zkeep.append(_kn); step_zresid.append(_rn); step_zrout.append(_on) + step_zkeepE.append(_ke); step_zresidE.append(_re); step_zroutE.append(_oe) + # PASS 2 (masked): rollouts route to BLOCKS via masked sums; nothing + # is subtracted from any gradient vector (v_grad = classifier only). + for info in wrappers.values(): + info["layer"]._lora2r_mask = (m_vec, d_vec) + logπ2 = per_token_logps( + model(merged, logits_to_keep=L_c + 1).logits[:, :-1], completion_ids) + ρ2 = torch.exp(logπ2 - logπ_old) + loss = _grpo_loss(-torch.min(ρ2 * A_tok, + torch.clamp(ρ2, 1 - cfg.clip, 1 + cfg.clip) * A_tok)) + # Retain-trick wrinkle: clean rollouts were SAMPLED quarantine-on but + # TRAIN quarantine-off; the PPO ratio absorbs the gap, clip bounds it. + clean = m_vec == 0 + if clean.any(): + clipped = ((ρ2.detach() - 1).abs() > cfg.clip).float() + step_clipfrac.append( + ((clipped * mask)[clean].sum() / mask[clean].sum().clamp_min(1)).item()) + loss.backward() # masked pass; A/B grads accumulate across prompts (opt.zero_grad clears per step) + for info in wrappers.values(): + info["layer"]._lora2r_mask = None + agg_loss += loss.item() + elif split_this_step: if cfg.unbiased: denom = group * max_new * prompts_per_step loss_s = (Lp * mask * is_s_v).sum() / denom @@ -1132,12 +1259,7 @@ def main(cfg: Config) -> int: # Combined single backward: cheaper, no per-source diagnostic. 
# Accumulate into step_grad_s as the "combined" carrier; the # injection block below treats step_grad_t == {} as "use gs". - if cfg.unbiased: - denom = group * max_new * prompts_per_step - loss = (Lp * mask).sum() / denom - else: - ptl_norm = (Lp * mask).sum(1) / mask.sum(1).clamp_min(1) - loss = ptl_norm.sum() / (group * prompts_per_step) + loss = _grpo_loss(Lp) loss.backward() # act_vote: compute the ONE global f_roll for the step before per-module # routing (activations are cached on every layer from the loss forward). @@ -1199,6 +1321,10 @@ def main(cfg: Config) -> int: if is_routeV and step_flagged: logger.debug(f"routeV routed frac f (mean over modules*prompts): " f"{sum(step_flagged)/len(step_flagged):+.3f}") + if step_clipfrac: + logger.debug(f"lora2r clean-gated clipfrac={sum(step_clipfrac)/len(step_clipfrac):.3f} " + f"(SHOULD: <~0.2; higher = retain-trick ratio drift binding, " + f"quarantine forward effect too large)") else: if split_this_step: cos_pre_s = mean_cos_pre_from_grads(step_grad_s, v_hack) @@ -1229,8 +1355,20 @@ def main(cfg: Config) -> int: def _grad_l2(params): gs = [p.grad for p in params if p.grad is not None] return float(torch.norm(torch.stack([g.norm() for g in gs]))) if gs else 0.0 - gn_keep = _grad_l2(delta_params) - gn_quar = _grad_l2(delta_hack_params) + if is_lora2r: + # quarantine = block slices of A/B, not separate params + sq_keep = sq_quar = 0.0 + for info in wrappers.values(): + gA, gB = info["delta_S"].grad, info["B"].grad + if gA is None: + continue + r_blk = info["r"] + sq_keep += gA[:r_blk].float().pow(2).sum().item() + gB[:, :r_blk].float().pow(2).sum().item() + sq_quar += gA[r_blk:].float().pow(2).sum().item() + gB[:, r_blk:].float().pow(2).sum().item() + gn_keep, gn_quar = sq_keep ** 0.5, sq_quar ** 0.5 + else: + gn_keep = _grad_l2(delta_params) + gn_quar = _grad_l2(delta_hack_params) q_egy = gn_quar / (gn_keep + gn_quar) if (gn_keep + gn_quar) > 0 else 0.0 gn = float(torch.nn.utils.clip_grad_norm_(delta_params + 
delta_hack_params, cfg.grad_clip)) opt.step() @@ -1616,8 +1754,15 @@ def main(cfg: Config) -> int: hack_b_rate = hack_s_B_total / max(1, n_s_total) if half_a_codes else float("nan") # routeV must move quarantine; none and erase must leave it exactly zero. - dsh_norm = float(sum(info["delta_S_hack"].data.float().pow(2).sum().item() - for info in wrappers.values()) ** 0.5) + if is_lora2r: + # quarantine LEARNED delta = (A[r:], B[:,r:]) minus the frozen PiSSA init + dsh_norm = float(sum( + (info["delta_S"].data[info["r"]:] - info["A0"][info["r"]:]).float().pow(2).sum().item() + + (info["B"].data[:, info["r"]:] - info["B0"][:, info["r"]:]).float().pow(2).sum().item() + for info in wrappers.values()) ** 0.5) + else: + dsh_norm = float(sum(info["delta_S_hack"].data.float().pow(2).sum().item() + for info in wrappers.values()) ** 0.5) logger.info(f"||delta_S_hack|| = {dsh_norm:.4f} " f"(SHOULD: >0 for routeV, ==0 for none/erase; ELSE routing broke)") if is_routeV and cfg.routeV_random_v_seed is None: @@ -1759,6 +1904,7 @@ if __name__ == "__main__": "smoke": SmokeConfig, "fast": FastConfig, "fast-lora": FastLoraConfig, + "fast-lora2r": FastLora2rConfig, "full": FullConfig, }) sys.exit(main(cfg)) diff --git a/src/vgrout/train_config.py b/src/vgrout/train_config.py index e12e7b3..00909dd 100644 --- a/src/vgrout/train_config.py +++ b/src/vgrout/train_config.py @@ -7,12 +7,16 @@ from typing import Literal from .rewards import EnvMode +# lora2r = rank-2r PiSSA-init LoRA (A+B trainable) with SGTM-style three-way +# hard block masking; supports intervention none (gate pinned clean) | routeV. +Adapter = Literal["antipasto", "lora_frozen_b", "lora2r"] + @dataclass(kw_only=True) class Config: # arm: the gradient policy. routeV = per-rollout gate; routeV_per_token = per-token gate. 
intervention: Literal["none", "erase", "routeV", "routeV_per_token"] = "erase" - adapter: Literal["antipasto", "lora_frozen_b"] = "antipasto" + adapter: Adapter = "antipasto" lora_r: int = 32 lora_b_seed: int = 0 @@ -55,7 +59,8 @@ class Config: rollout_ablate_frac: float = 0.0 env_mode: EnvMode = "run_tests" - # eval3 keeps solve pressure alive by making 10% of training problems unhackable. + # eval3 keeps solve pressure alive by making 25% of training problems unhackable. + # (10% learned solve too slowly; 25% is the default from 2026-06-10 on.) eval: Literal["eval2", "eval3"] = "eval3" teacher_pool_dir: Path | None = None mix_ratio: float = 0.125 @@ -76,12 +81,15 @@ class Config: @property def unhackable_frac(self) -> float: - return {"eval2": 0.0, "eval3": 0.1}[self.eval] + return {"eval2": 0.0, "eval3": 0.25}[self.eval] @property def arm(self) -> str: - return {"none": "vanilla", "erase": "projected", + base = {"none": "vanilla", "erase": "projected", "routeV": "routingV", "routeV_per_token": "routingV_per_token"}[self.intervention] + # lora2r changes the routing logic (hard 3-way masks, structural separation), + # so it gets its own arm id -- old/new runs must not be conflated. + return f"{base}_lora2r" if self.adapter == "lora2r" else base @dataclass(kw_only=True) @@ -114,10 +122,20 @@ class FastConfig(Config): @dataclass(kw_only=True) class FastLoraConfig(FastConfig): # LoRA-frozen-B needs a lower learning rate because its gradient scale differs from delta_S. - adapter: Literal["antipasto", "lora_frozen_b"] = "lora_frozen_b" + adapter: Adapter = "lora_frozen_b" lr: float = 1e-4 +@dataclass(kw_only=True) +class FastLora2rConfig(FastConfig): + # Rank-2r PiSSA-init LoRA + SGTM three-way masking. weight_decay MUST be 0: + # AdamW decays the raw A/B toward 0, not toward the PiSSA init, which would + # drive the net delta to -B0@A0 (subtracting W's top-2r spectral part). 
+ adapter: Adapter = "lora2r" + lr: float = 1e-4 + weight_decay: float = 0.0 + + @dataclass(kw_only=True) class FullConfig(Config): model: str = "Qwen/Qwen3-4B"