diff --git a/scripts/verify_lora2r_routing.py b/scripts/verify_lora2r_routing.py index ededd36..62be113 100644 --- a/scripts/verify_lora2r_routing.py +++ b/scripts/verify_lora2r_routing.py @@ -1,4 +1,4 @@ -"""lora2r invariants (rank-2r PiSSA-init LoRA + SGTM-style block masks). +"""lora2r invariants (rank-2r Gaussian-init LoRA + SGTM-style block masks). Asserts, on tiny-random-qwen3 (CPU, fp32): 1. IDENTITY AT INIT: wrapped logits == base logits (the hook subtracts the @@ -17,7 +17,7 @@ Exit nonzero on any violation. Wired into `just smoke-lora2r`. import torch from transformers import AutoModelForCausalLM -from vgrout.antipasto import wrap_model_with_lora2r +from vgrout.lora2r import wrap_model_with_lora2r from vgrout.eval import ablate_quarantine MODEL = "llamafactory/tiny-random-qwen3" @@ -31,7 +31,7 @@ ids = torch.randint(100, 1000, (2, 12)) with torch.no_grad(): base_logits = model(ids).logits.clone() -wrappers = wrap_model_with_lora2r(model, MODEL, svd_device="cpu", r=R, grad_probe=True) +wrappers = wrap_model_with_lora2r(model, r=R, grad_probe=True) # 1. identity at init with torch.no_grad(): @@ -52,7 +52,7 @@ def run_masked(m_val: float, d_val: float) -> tuple[float, float]: dep_sq = quar_sq = 0.0 for info in wrappers.values(): r = info["r"] - gA, gB = info["delta_S"].grad, info["B"].grad + gA, gB = info["A"].grad, info["B"].grad dep_sq += gA[:r].pow(2).sum().item() + gB[:, :r].pow(2).sum().item() quar_sq += gA[r:].pow(2).sum().item() + gB[:, r:].pow(2).sum().item() return dep_sq ** 0.5, quar_sq ** 0.5 @@ -96,7 +96,7 @@ with torch.no_grad(): assert torch.allclose(out_abl_init, out0, atol=1e-6), "ablate at init is not a no-op" for info in wrappers.values(): r = info["r"] - info["delta_S"].data[r:] += 0.05 * torch.randn_like(info["delta_S"].data[r:]) + info["A"].data[r:] += 0.05 * torch.randn_like(info["A"].data[r:]) out_pert = model(ids).logits.clone() pert = (out_pert - out0).abs().max().item() assert pert > 1e-6, f"quarantine perturbation invisible in forward ({pert:.2e})" diff --git a/src/vgrout/eval.py b/src/vgrout/eval.py index 441c00f..ab20cae 100644 --- a/src/vgrout/eval.py +++ b/src/vgrout/eval.py @@ -1,8 +1,7 @@ -"""Evaluation and reference-model helpers for the training loop. +"""Evaluation helpers for the training loop. -Three read-only helpers that touch the model but never train it: a reference -log-prob pass (the AntiPaSTO adapter zeroed = the base model), the deploy-time -quarantine ablation, and a hack/solve eval on a fixed prompt subset. +Read-only helpers that touch the model but never train it: the deploy-time +quarantine ablation and a hack/solve eval on a fixed prompt subset. """ from __future__ import annotations @@ -12,7 +11,6 @@ from contextlib import contextmanager import torch from .data import DATA, HINT_REPLACE_TO, load_problems -from .proj import per_token_logps from .rewards import compute_reward # Evaluation discloses novel marker families disjoint from training while preserving grader @@ -77,60 +75,32 @@ def randomize_eval_markers(prob: dict) -> tuple[list[dict], dict]: return msgs, {kw: value} -def ref_logprobs_via_zero_delta( - model, merged: torch.Tensor, wrappers: dict, plen: int, -) -> torch.Tensor: - """Compute base-model completion logprobs by temporarily zeroing the adapter. - - At delta_S=0, AntiPaSTO is exactly the frozen base model. `logits_to_keep` - avoids materializing unused prompt logits. - """ - saved = {n: info["delta_S"].data.clone() for n, info in wrappers.items()} - try: - for info in wrappers.values(): - info["delta_S"].data.zero_() - L_c = merged.shape[1] - plen - logits = model(merged, logits_to_keep=L_c + 1).logits[:, :-1] - return per_token_logps(logits, merged[:, plen:]) - finally: - for n, info in wrappers.items(): - info["delta_S"].data.copy_(saved[n]) - - @contextmanager def ablate_quarantine(wrappers: dict): """Temporarily remove the quarantine to evaluate the deployed model. - delta_S adapters: zero delta_S_hack. lora2r ("A0" in info): reset the - quarantine block (A[r:], B[:,r:]) to the frozen PiSSA init A0/B0 so the net - quarantine delta is 0 -- zeroing the raw params would instead SUBTRACT the - init contribution and corrupt the forward. + Reset the quarantine block (A[r:], B[:,r:]) to the frozen init A0/B0 so the + net quarantine delta is 0 -- zeroing the raw params would instead SUBTRACT + the init contribution and corrupt the forward. TODO(post-deploy-finetune): SGTM's ablate(trainable=True) reinits the forget weights to the retain-dims' std instead of zeroing, so the model stays finetunable after the quarantine is removed (no dead hole). We zero because we only eval after deploy; add the reinit path if we ever retrain post-ablate. See docs/grad_routing/sgtm_vs_ours.md.""" - saved: dict[str, object] = {} + saved: dict[str, tuple[torch.Tensor, torch.Tensor]] = {} for n, info in wrappers.items(): - if "A0" in info: - r = info["r"] - saved[n] = (info["delta_S"].data[r:].clone(), info["B"].data[:, r:].clone()) - info["delta_S"].data[r:] = info["A0"][r:] - info["B"].data[:, r:] = info["B0"][:, r:] - else: - saved[n] = info["delta_S_hack"].data.clone() - info["delta_S_hack"].data.zero_() + r = info["r"] + saved[n] = (info["A"].data[r:].clone(), info["B"].data[:, r:].clone()) + info["A"].data[r:] = info["A0"][r:] + info["B"].data[:, r:] = info["B0"][:, r:] try: yield finally: for n, info in wrappers.items(): - if "A0" in info: - r = info["r"] - info["delta_S"].data[r:] = saved[n][0] - info["B"].data[:, r:] = saved[n][1] - else: - info["delta_S_hack"].data.copy_(saved[n]) + r = info["r"] + info["A"].data[r:] = saved[n][0] + info["B"].data[:, r:] = saved[n][1] @torch.no_grad() diff --git a/src/vgrout/extract_vhack_grad.py b/src/vgrout/extract_vhack_grad.py index 0a45f13..cf15404 100644 --- a/src/vgrout/extract_vhack_grad.py +++ b/src/vgrout/extract_vhack_grad.py @@ -5,8 +5,9 @@ For a pair with advantages (adv_h=+1, adv_c=-1) the Dr.GRPO single-step grad `-adv_h * grad_logp(hack) - adv_c * grad_logp(clean)` algebraically equals `grad_NLL(hack) - grad_NLL(clean)`, so we compute it by the simpler path: forward each completion, take mean-NLL on completion tokens, backward, and -capture `delta_S.grad` per AntiPaSTO-wrapped Linear. Naming the steps NLL is -an implementation detail; the *meaning* is "the GRPO update on this pair." +capture the lora2r c-probe grad (the per-pair weight grad of the virtual +diagonal between A and B, deployed block) per wrapped Linear. Naming the steps +NLL is an implementation detail; the *meaning* is "the GRPO update on this pair." Then per module, with D = [g_hack_i - g_clean_i for each pair] in R^{n_pairs x r}: SVD(D) = U Σ Vh @@ -41,12 +42,11 @@ from safetensors.torch import save_file from tabulate import tabulate from transformers import AutoModelForCausalLM, AutoTokenizer -from .antipasto import wrap_model_with_antipasto +from .lora2r import wrap_model_with_lora2r from .pairs_from_pool import load_pairs_json from .vhack import pairset_sha256 -CACHE_ROOT = Path("svd_cache") OUT_DIR = Path("out") @@ -139,26 +139,13 @@ def extract_v_hack( bucket = grads_hack if label == "hack" else grads_clean for name, info in wrappers.items(): layer = info["layer"] - if getattr(layer, "_lora2r_grad_probe", False): - # lora2r: per-pair weight grad of the virtual diagonal (c-probe), - # DEPLOYED block only -- the same space the live gate reads - # (train.py lora2r branch), so band calibration is apples-to-apples. - cg = layer._lora2r_gate.grad - if cg is None: - raise RuntimeError(f"no c-probe grad on {name}; aborting lora2r extract") - g = cg.sum(dim=tuple(range(cg.dim() - 1)))[: layer._lora2r_r] # [r] - elif getattr(layer, "_lora_grad_probe", False) and layer._lora_h is not None: - # LoRA-frozen-B: the routing handle is the r-bottleneck gradient - # g_h = B^T δ_y (B frozen -> static path), not A.grad. Sum over (batch, - # tokens) to mirror how AntiPaSTO's δS.grad accumulates over positions. - gh = layer._lora_h.grad - if gh is None: - raise RuntimeError(f"no bottleneck grad on {name}; aborting LoRA extract") - g = gh.sum(dim=tuple(range(gh.dim() - 1))) # [r] - else: - g = info["delta_S"].grad - if g is None: - raise RuntimeError(f"no grad on {name}; aborting extract") + # Per-pair weight grad of the virtual diagonal (c-probe), DEPLOYED + # block only -- the same space the live gate reads (train.py), so + # band calibration is apples-to-apples. Requires grad_probe=True. + cg = layer._lora2r_gate.grad + if cg is None: + raise RuntimeError(f"no c-probe grad on {name}; wrap with grad_probe=True") + g = cg.sum(dim=tuple(range(cg.dim() - 1)))[: layer._lora2r_r] # [r] bucket[name].append(g.detach().float().cpu().clone()) if (pi + 1) % 5 == 0: logger.info(f" pair {pi+1}/{n_pairs} loss={loss.item():.3f}") @@ -249,13 +236,10 @@ def main(cfg: Config) -> int: model = AutoModelForCausalLM.from_pretrained( cfg.model, dtype=dtype, attn_implementation="sdpa" ).to(device) - model.eval() # disable dropout; gradients still flow through delta_S - wrappers = wrap_model_with_antipasto( - model, model_name=cfg.model, cache_root=CACHE_ROOT, svd_device=device, - ) + model.eval() # disable dropout; gradients still flow through the adapter + wrappers = wrap_model_with_lora2r(model, grad_probe=True) n_mod = len(wrappers) - n_delta = sum(info["delta_S"].numel() for info in wrappers.values()) - logger.info(f"wrapped {n_mod} modules; total delta_S scalars = {n_delta:,}") + logger.info(f"wrapped {n_mod} modules; probe space = r per module") train_pairs = pairs[:-cfg.n_heldout] if cfg.n_heldout > 0 else pairs logger.info(f"train pairs: {len(train_pairs)} held: {cfg.n_heldout}") diff --git a/src/vgrout/lora2r.py b/src/vgrout/lora2r.py new file mode 100644 index 0000000..290a8e1 --- /dev/null +++ b/src/vgrout/lora2r.py @@ -0,0 +1,129 @@ +"""lora2r adapter: one rank-2r LoRA per target Linear, A and B both trainable, +partitioned into a deployed block [:r] and a quarantine block [r:]. + + y += B@(A@x) - B0@(A0@x) + +A0/B0 are FROZEN copies of the (seeded Gaussian) init, subtracted so the net +delta is exactly 0 at init while h = A@x is alive. Both factors must be nonzero +at init: the gate reads c.grad = h ⊙ (Bᵀδ) and A.grad = (Bᵀδ)xᵀ, both +identically zero under standard zero-B LoRA -- no extraction, no band +calibration, no trainable A at step 0. Init values beyond "nonzero" are NOT +load-bearing (we previously used PiSSA; see docs/spec/20260610_lora2r_v2_plan.md +T2 for why it was dropped). + +[B|B_q] @ ([A;A_q]@x) has no cross terms (column b_k only ever multiplies row +a_k), so the two blocks ARE two independent adapters; per-rollout block masks on +this one tensor implement the SGTM parameter partition (Cloud et al.). +""" +from __future__ import annotations + +import torch +from loguru import logger +from torch import Tensor, nn + +TARGET_SUFFIXES = ( + # full attention + "q_proj", "k_proj", "v_proj", "o_proj", + # linear-attention / GatedDeltaNet + "in_proj_qkv", "in_proj_z", "in_proj_a", "in_proj_b", "out_proj", + # MLP + "up_proj", "gate_proj", "down_proj", +) + + +def is_target(name: str) -> bool: + return name.split(".")[-1] in TARGET_SUFFIXES + + +def _lora2r_hook(layer: nn.Linear, args: tuple, y: Tensor) -> Tensor: + """Add the two-block delta to y, applying per-rollout block masks. + + Block masks (layer._lora2r_mask = (m, d), set by train.py per loss pass; + None = unmasked for generation / gate pass / eval): + m [G] quarantine on/off -- m=0: quarantine zero in forward AND backward + (SGTM retain trick: deployed trains in its post-ablation state) + d [G] deployed detach -- d=1: deployed kept in forward, zero grad + (hack-gated rollouts update ONLY the quarantine block) + Masks act on branch OUTPUTS so a detach blocks grads to BOTH the A rows and + the B columns of that block. + + grad probe: c = ones[..., 2r] spliced as h*c. After backward + c.grad = h ⊙ (Bᵀδ_y) = the per-sample WEIGHT grad of a virtual diagonal + scale between A and B. The pair extraction (extract_vhack_grad) and the live + gate (train.py) both read this same space, so the band is self-consistent + whatever the basis. + """ + (x,) = args + A = layer._lora2r_A # [2r, d_in] trainable + B = layer._lora2r_B # [d_out, 2r] trainable + A0 = layer._lora2r_A0 # frozen init copies (subtracted: net delta 0 at init) + B0 = layer._lora2r_B0 + r = layer._lora2r_r + h = torch.nn.functional.linear(x, A.to(x.dtype)) # [..., 2r] + if layer._lora2r_grad_probe and torch.is_grad_enabled(): + c = torch.ones(h.shape[0], *([1] * (h.dim() - 2)), h.shape[-1], + device=h.device, dtype=h.dtype, requires_grad=True) + layer._lora2r_gate = c + h = h * c + h0 = torch.nn.functional.linear(x, A0.to(x.dtype)) # [..., 2r] frozen init path + dep = (torch.nn.functional.linear(h[..., :r], B[:, :r].to(x.dtype)) + - torch.nn.functional.linear(h0[..., :r], B0[:, :r].to(x.dtype))) + quar = (torch.nn.functional.linear(h[..., r:], B[:, r:].to(x.dtype)) + - torch.nn.functional.linear(h0[..., r:], B0[:, r:].to(x.dtype))) + if layer._lora2r_mask is not None: + m, d = layer._lora2r_mask # [G] each + G = m.shape[0] + shape = dep.shape # [G, s, d_out] or [G*s, d_out] + dep = dep.reshape(G, -1, shape[-1]) + quar = quar.reshape(G, -1, shape[-1]) + d_ = d.view(G, 1, 1).to(dep.dtype) + dep = ((1 - d_) * dep + d_ * dep.detach()).reshape(shape) + quar = (m.view(G, 1, 1).to(quar.dtype) * quar).reshape(shape) + return y + (dep + quar).to(y.dtype) + + +def wrap_model_with_lora2r( + model: nn.Module, + r: int = 32, + init_seed: int = 0, + grad_probe: bool = False, +) -> dict[str, dict]: + """Attach a rank-2r Gaussian-init LoRA (A AND B trainable) to every target Linear. + + Init: A0 ~ N(0, 1/d_in) [2r, d_in], B0 ~ N(0, 1/2r) [d_out, 2r], seeded per + module so runs reproduce; blocks are iid -> statistically matched. W stays + untouched; the hook subtracts the frozen A0/B0 contribution. The quarantine's + learned delta is (A[r:], B[:, r:]) minus init; deploy ablation resets that + block to A0/B0 (eval.ablate_quarantine). + + Info dict per module: {layer, A, B, A0, B0, handle, r}; quarantine = block + slices, no separate tensor. + """ + targets = [(n, m) for n, m in model.named_modules() + if isinstance(m, nn.Linear) and is_target(n)] + logger.info(f"lora2r attach: {len(targets)} target Linear modules, " + f"r={r}/block (2r={2 * r}), Gaussian init seed={init_seed}, A+B trainable") + out: dict[str, dict] = {} + for i, (name, linear) in enumerate(targets): + d_out, d_in = linear.weight.shape + dev = linear.weight.device + gen = torch.Generator().manual_seed(init_seed * 100003 + i) + A0 = (torch.randn(2 * r, d_in, generator=gen) / d_in ** 0.5).to(device=dev, dtype=torch.float32) + B0 = (torch.randn(d_out, 2 * r, generator=gen) / (2 * r) ** 0.5).to(device=dev, dtype=torch.float32) + linear.register_buffer("_lora2r_A0", A0, persistent=True) + linear.register_buffer("_lora2r_B0", B0, persistent=True) + A = nn.Parameter(A0.clone()) + B = nn.Parameter(B0.clone()) + linear.register_parameter("_lora2r_A", A) + linear.register_parameter("_lora2r_B", B) + linear._lora2r_r = r + linear._lora2r_grad_probe = grad_probe + linear._lora2r_gate = None + linear._lora2r_mask = None + out[name] = {"layer": linear, "A": A, "B": B, "A0": A0, "B0": B0, + "handle": linear.register_forward_hook(_lora2r_hook), "r": r} + trainable = ("_lora2r_A", "_lora2r_B") + for n, p in model.named_parameters(): + if not n.endswith(trainable): + p.requires_grad_(False) + return out diff --git a/src/vgrout/run_artifacts.py b/src/vgrout/run_artifacts.py index a8463e5..d80c6a1 100644 --- a/src/vgrout/run_artifacts.py +++ b/src/vgrout/run_artifacts.py @@ -9,8 +9,16 @@ from safetensors import safe_open RUNS_DIR = Path("out/runs") RUN_SCHEMA = "paired_final_v2" # v2: deployed/as_trained field names (was deploy_*/deploy_*_on) +# Old PiSSA-substrate runs on disk carry these intervention names; lora2r runs +# get a _lora2r suffix so the two substrates never conflate in aggregation. ARM = {"none": "vanilla", "erase": "projected", - "routeV": "routingV", "routeV_per_token": "routingV_per_token"} + "routeV": "routingV", "routeV_per_token": "routingV_per_token", + "absorb": "absorb"} + + +def _arm_of(cfg: dict) -> str: + suffix = "_lora2r" if cfg.get("adapter", "antipasto") == "lora2r" else "" + return ARM[cfg["intervention"]] + suffix def _mean_fraction(rows: list[dict], key: str) -> float: @@ -38,7 +46,7 @@ def load_run(run_dir: Path) -> dict: "run_dir": run_dir, "time": run_dir.name.split("_", 1)[0], "cfg": cfg, - "arm": ARM[cfg["intervention"]], + "arm": _arm_of(cfg), "rows": rows, "deploy": deploy, "l5_hack": _mean_fraction(rows[-5:], "hack_s"), diff --git a/src/vgrout/tablelog.py b/src/vgrout/tablelog.py index 87bc761..d29dede 100644 --- a/src/vgrout/tablelog.py +++ b/src/vgrout/tablelog.py @@ -72,9 +72,8 @@ class StepLogger: def __init__(self, arm: str, modes: list[str], mode_code: dict[str, str], show_ablate: bool = False) -> None: - # Erase reports projection diagnostics; routeV reports routing diagnostics below. - projects = arm == "projected" - is_route = arm in ("routingV", "routingV_per_token", "routingV_lora2r") + # routeV reports routing diagnostics; absorb shares qmass (zone cols read nan). + is_route = arm in ("routingV_lora2r", "absorb_lora2r") cols: list[_Col] = [ _Col("step", 4, "step", "d", "GRPO step"), _Col("ref_eq", 6, "ref_eq", ".2f", "vanilla-equiv step (cum_gens/256)"), @@ -100,25 +99,16 @@ class StepLogger: _Col("gn", 7, "gn", ".1e", "pre-clip L2 norm of delta_S grads (vs grad_clip)"), _Col("lr", 7, "lr", ".1e", "scheduled learning rate"), ] - if projects: - cols += [ - _Col("cos_pre", 6, "cin", ".2f", "hack-ward grad fraction ||relu(V@g)||/||g|| [0,1] BEFORE proj"), - _Col("cos_pre_s", 6, "cin_s", ".2f", "cin on student-only grad"), - _Col("cos_pre_t", 6, "cin_t", ".2f", "cin on teacher-only grad (want cin_t>cin_s)"), - _Col("cos_post", 6, "cout", ".2f", "hack-ward fraction AFTER projection (want ~0: all removed)"), - _Col("fired", 5, "fired", ".2f", "fraction of modules where projection fired"), - ] - # routeV reports unit and energy shares across the routing band plus residual leak. + # routeV reports unit and energy shares across the routing band. if is_route: cols += [ - _Col("qmass", 6, "qmass", ".2f", "quarantine energy share ||g_quar||/(||g_keep||+||g_quar||): fraction of the update parked in the throwaway quarantine adapter"), - _Col("keep", 6, "keep", ".2f", "unit share with cos below the band -> kept whole in the deployed adapter (left)"), - _Col("resid", 6, "resid", ".2f", "unit share with cos inside the band -> partially routed (residual middle)"), - _Col("rout", 6, "rout", ".2f", "unit share with cos above the band -> fully routed into quarantine (right)"), - _Col("keepE", 6, "keepE", ".2f", "energy-weighted keep: share of grad ENERGY in the kept zone"), - _Col("residE", 6, "residE", ".2f", "energy-weighted resid: share of grad ENERGY in the partially-routed zone"), - _Col("routE", 6, "routE", ".2f", "energy-weighted rout: grad ENERGY share fully routed (~quarantine mass; the routed total is routE..routE+residE)"), - _Col("leak", 6, "leak", "+.2f", "hack-ward cosine left in the deployed adapter after routing; ~0 = stripped clean, >0 = hack leaked through (under-routed)"), + _Col("qmass", 6, "qmass", ".2f", "quarantine energy share ||g_quar||/(||g_keep||+||g_quar||): fraction of the update landing in the throwaway quarantine block"), + _Col("keep", 6, "keep", ".2f", "rollout share labelled clean (below band) -> deployed-only, quarantine off"), + _Col("resid", 6, "resid", ".2f", "rollout share labelled mid (inside band) -> both blocks train (absorption)"), + _Col("rout", 6, "rout", ".2f", "rollout share labelled hack (above band) -> quarantine-only, deployed detached"), + _Col("keepE", 6, "keepE", ".2f", "energy-weighted keep (grad-norm-weighted clean share)"), + _Col("residE", 6, "residE", ".2f", "energy-weighted resid"), + _Col("routE", 6, "routE", ".2f", "energy-weighted rout"), ] # Show the training-prompt deploy proxy only when an ablated slice exists. if is_route and show_ablate: diff --git a/src/vgrout/train.py b/src/vgrout/train.py index b5e9f91..a3456a8 100644 --- a/src/vgrout/train.py +++ b/src/vgrout/train.py @@ -1,7 +1,7 @@ -"""GRPO / Dr.GRPO loop with SVD-basis gradient projection on the LeetCode +"""GRPO / Dr.GRPO loop with SGTM-style gradient routing on the LeetCode reward-hacking benchmark. - generate -> grade -> backward -> project -> step + generate -> grade -> backward -> (gate) -> masked backward -> step Inner GRPO step ported from lsdefine/simple_GRPO grpo_vllm_one.py:64-95; the outer loop accumulates grads over prompts_per_step prompts (simple_GRPO's @@ -9,19 +9,20 @@ Q_batch_size), so at least one per-prompt group has reward variance. Unbiased normalization: Dr.GRPO, Liu et al. 2025, arXiv:2503.20783 -- drop the 1/|oᵢ| length norm and the /σ_R group-std (--unbiased, on by default). -Adapter: AntiPaSTO full-rank SVD delta δS per Linear, W' = W + U diag(δS) Vᵀ. -At δS=0 the adapter is identity, so a no-grad forward with δS zeroed gives π_ref -for free, no second model (the KL term under --beta>0). +Adapter: lora2r (src/vgrout/lora2r.py) -- one rank-2r LoRA per Linear, A and B +both trainable, partitioned into a deployed block [:r] and a quarantine block +[r:]. The quarantine is ablated (reset to its frozen init) at deployment. Arms (--intervention): - none measure only; δS.grad untouched (vanilla GRPO) - erase subtract the hack-ward component of δS.grad - routeV route per-rollout by a calibrated-τ cosine gate, cos(g_b, v_grad) > τ + none gate pinned clean (0,0): quarantine never trains -- the capacity- and + structure-matched vanilla control. + routeV per-rollout three-way SGTM gate from the c-probe gradient vs v_grad: + clean->deployed-only, hack->quarantine-only (deployed detached), + mid->both (absorption). + absorb gate pinned mid (1,0): both blocks train on everything, no gate -- + isolates the value of the gate+masks vs absorption alone. -Hyperparameters from ariahw/rl-rewardhacking config.py (docs/grpo_hyperparams.md); -SmokeConfig / FastConfig / FullConfig in train_config.py hold the scale hyperparameters. - - uv run python -m vgrout.train smoke --intervention=erase + uv run python -m vgrout.train smoke --intervention=routeV """ from __future__ import annotations @@ -32,7 +33,7 @@ import os import sys import random import time -from contextlib import contextmanager, nullcontext +from contextlib import nullcontext from pathlib import Path # Must be set BEFORE `import torch` to take effect on the CUDA allocator. @@ -43,34 +44,27 @@ os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True") import torch import torch.nn.functional as F import tyro -from jaxtyping import Float from loguru import logger -from safetensors import safe_open from safetensors.torch import save_file from tabulate import tabulate from tqdm import tqdm from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig -from .antipasto import wrap_model_with_antipasto, wrap_model_with_lora2r, wrap_model_with_lora_frozen_b -from .proj import per_token_logps, project_delta_S_grad, mean_cos_pre_from_grads +from .lora2r import wrap_model_with_lora2r +from .proj import per_token_logps from .rewards import EnvMode, compute_reward from .data import DATA, load_problems -from .vhack import load_v_hack, pairset_sha256, postprocess_v_hack -from .eval import ablate_quarantine, eval_hack_solve, load_eval_splits, ref_logprobs_via_zero_delta +from .eval import ablate_quarantine, eval_hack_solve, load_eval_splits from .tablelog import setup_logging, StepLogger from .run_artifacts import RUN_SCHEMA -from .train_config import (Config, FastConfig, FastLora2rConfig, FastLoraConfig, - FullConfig, SmokeConfig) +from .train_config import Config, FastConfig, SmokeConfig -CACHE_ROOT = Path("svd_cache") OUT_DIR = Path("out") -# Keep reusable inputs separate from per-run outputs; see docs/spec/20260530_out_dir_reorg.md. -VHACK_DIR = OUT_DIR / "vhack" RUNS_DIR = OUT_DIR / "runs" def _haar_unit_dirs(v_grad: dict, seed: int, device) -> dict: - """Build the reproducible out-of-subspace directionality control for routeV.""" + """Build the reproducible out-of-subspace directionality control (placebo) for routeV.""" g = torch.Generator().manual_seed(seed) out = {} for name in sorted(v_grad): @@ -94,7 +88,7 @@ def route_band_edges(raw_grads: dict, v_grad: dict, device) -> dict[str, tuple[f """Calibrate an absolute routing band from authored pairs only. Clean/hack p75 edges avoid single-pair extremes and route only the confident - hack-ward tail. Pair/live shift can still make routing idle; inspect `routE`. + hack-ward tail. Pair/live shift can still make routing idle; inspect `rout`. See docs/papers/grad_routing/paper_sgtm.md. """ band = {} @@ -108,59 +102,10 @@ def route_band_edges(raw_grads: dict, v_grad: dict, device) -> dict[str, tuple[f return band -def build_act_vote_dirs(model, wrappers, tok, pairs, device): - """Build the authored-pair activation vote; no live rollout labels enter the gate.""" - names = list(wrappers) - As_cap: dict[str, torch.Tensor] = {} - st = {"plen": 0} - def mk_hook(nm): - Vh = wrappers[nm]["layer"]._antipasto_Vh - def h(_l, inp, _o): - As_cap[nm] = F.linear(inp[0], Vh)[0, st["plen"] - 1:, :].mean(0).detach().float().cpu() - return h - handles = [wrappers[nm]["layer"].register_forward_hook(mk_hook(nm)) for nm in names] - def grab(prompt, comp): - st["plen"] = tok(prompt, return_tensors="pt").input_ids.shape[1] - ids = tok(prompt + comp, return_tensors="pt").input_ids.to(device) - with torch.no_grad(): - model(ids) - return {nm: As_cap[nm].clone() for nm in names} - As_h = {nm: [] for nm in names} - As_c = {nm: [] for nm in names} - for pr in pairs: - ah, ac = grab(pr.prompt, pr.hack), grab(pr.prompt, pr.clean) - for nm in names: - As_h[nm].append(ah[nm]); As_c[nm].append(ac[nm]) - for h in handles: - h.remove() - As_D = {nm: (torch.stack(As_h[nm]) - torch.stack(As_c[nm])).mean(0) for nm in names} - As_dir_cpu = {nm: As_D[nm] / As_D[nm].norm().clamp_min(1e-12) for nm in names} - act_w = {nm: As_D[nm].norm().item() for nm in names} - wsum = sum(act_w.values()) - def pair_vote(As_pair): - num = sum(act_w[nm] * float((As_pair[nm] @ As_dir_cpu[nm]) / As_pair[nm].norm().clamp_min(1e-12)) - for nm in names) - return num / max(wsum, 1e-12) - votes_h = [pair_vote({nm: As_h[nm][i] for nm in names}) for i in range(len(pairs))] - votes_c = [pair_vote({nm: As_c[nm][i] for nm in names}) for i in range(len(pairs))] - vote_band = (torch.tensor(votes_c).quantile(0.75).item(), - torch.tensor(votes_h).quantile(0.75).item()) - As_dir = {nm: As_dir_cpu[nm].to(device) for nm in names} - logger.info( - f"routeV act_vote: As_dir for {len(As_dir)} modules; vote band " - f"lower(p75 clean)={vote_band[0]:+.3f} upper(p75 hack)={vote_band[1]:+.3f} " - f"width={vote_band[1] - vote_band[0]:+.3f}. SHOULD: width>0 (hack pairs vote higher); " - f"live f_roll>0 in early steps else band sits off the live distribution.") - assert vote_band[1] > vote_band[0], ( - f"act_vote band non-positive width {vote_band[1] - vote_band[0]:+.3f}: " - "hack pairs do not vote-separate from clean -> act extraction broken") - return As_dir, act_w, vote_band - - -# 2-char env_mode codes for compact per-mode hack columns (hk_rt, hk_xc, ...). # Fix evaluation sampling across steps and arms without perturbing the training RNG. EVAL_GEN_SEED = 12345 +# 2-char env_mode codes for compact per-mode hack columns (hk_rt, hk_xc, ...). MODE_CODE: dict[str, str] = { "run_tests": "rt", "eq_override": "eq", "exit_code": "xc", "stdout_marker": "so", "sentinel": "se", "file_marker": "fm", @@ -169,61 +114,33 @@ MODE_CODE: dict[str, str] = { def _validate_config(cfg: Config) -> None: - """Reject ignored or contradictory experiment settings before model load.""" - is_routeV = cfg.intervention in ("routeV", "routeV_per_token") - routeV_only = { - "routeV_random_v_seed": cfg.routeV_random_v_seed is not None, - "routeV_gate (non-default)": cfg.routeV_gate != "grad_cosine", - "routeV_absorb_all": cfg.routeV_absorb_all, - "routeV_top_k>1": cfg.routeV_top_k > 1, - } - if not is_routeV: - set_routeV_only = [k for k, was_set in routeV_only.items() if was_set] - if set_routeV_only: - raise ValueError(f"routeV-only options set on intervention={cfg.intervention}: " - f"{set_routeV_only} -- they would be silently ignored") - if cfg.routeV_top_k > 1 and (cfg.routeV_gate != "grad_cosine" or cfg.intervention == "routeV_per_token" - or cfg.routeV_absorb_all): - raise ValueError("routeV_top_k>1 is implemented only for the per-rollout grad_cosine gate") - if cfg.v_hack_path is not None and cfg.intervention != "erase": - raise ValueError(f"--v-hack-path is an erase-arm option; ignored on intervention={cfg.intervention}") - if cfg.adapter == "lora_frozen_b" and cfg.intervention not in ("none", "routeV", "routeV_per_token"): - raise ValueError(f"lora_frozen_b adapter not wired for intervention={cfg.intervention}") - if cfg.adapter == "lora2r": - if cfg.intervention not in ("none", "routeV"): - raise ValueError(f"lora2r supports intervention none|routeV, got {cfg.intervention}") - if cfg.beta: - raise ValueError("lora2r has no zero-delta reference path (A=0 is NOT identity); beta must be 0") - if cfg.weight_decay != 0.0: - raise ValueError("lora2r params are PiSSA-init (nonzero); AdamW decay pulls them toward 0, " - "not toward init -- set --weight-decay=0") - if cfg.routeV_gate != "grad_cosine" or cfg.routeV_top_k > 1 or cfg.routeV_absorb_all: - raise ValueError("lora2r implements only the per-rollout grad_cosine three-way gate") + """Reject contradictory experiment settings before model load.""" + if cfg.intervention not in ("none", "routeV", "absorb"): + raise ValueError(f"unknown intervention {cfg.intervention!r}; expected none|routeV|absorb") + if cfg.routeV_random_v_seed is not None and cfg.intervention != "routeV": + raise ValueError("routeV_random_v_seed is a routeV-only placebo control") + if cfg.rollout_ablate_frac > 0 and cfg.intervention == "none": + raise ValueError("rollout_ablate_frac needs a quarantine to ablate (routeV/absorb)") + if cfg.weight_decay != 0.0: + raise ValueError("lora2r init is nonzero; AdamW decay pulls A/B toward 0 not toward init " + "-- set --weight-decay=0") -def _resolve_v_hack_file(cfg: Config) -> Path: - """The on-disk direction file the erase arm uses: explicit override, else derived - from the pairset stem. (routeV/vanilla don't load it -- they build v_grad / nothing.)""" - return cfg.v_hack_path or VHACK_DIR / f"v_hack_pairset_{cfg.vhack_pairs_path.stem}.safetensors" - - -def _log_resolved_config(cfg: Config, device, v_hack_file: Path) -> None: +def _log_resolved_config(cfg: Config, device) -> None: """One block with every None resolved to its effective value, so a detached log shows exactly what ran -- especially WHICH pairset (the field readers kept losing).""" - is_routeV = cfg.intervention in ("routeV", "routeV_per_token") + is_routeV = cfg.intervention == "routeV" fields = { "preset/arm": f"{cfg.preset_name} / {cfg.arm}", - "intervention/adapter": f"{cfg.intervention} / {cfg.adapter}", + "intervention": cfg.intervention, "model": cfg.model, "device": str(device), "seed": cfg.seed, "steps/group/pps": f"{cfg.steps} / {cfg.group} / {cfg.prompts_per_step}", "max_new/lr/grad_clip": f"{cfg.max_new} / {cfg.lr:.1e} / {cfg.grad_clip}", - "eval (unhackable_frac)": f"{cfg.eval} ({cfg.unhackable_frac})", + "lora_r/init_seed": f"{cfg.lora_r} / {cfg.lora_init_seed}", + "unhackable_frac": cfg.unhackable_frac, "env_mode": cfg.env_mode, - "pairset": cfg.vhack_pairs_path if cfg.intervention != "none" else "unused (vanilla)", - "v_hack_file": v_hack_file if cfg.intervention == "erase" else "unused (not erase)", - "routeV gate/top_k/random_v/absorb": ( - f"{cfg.routeV_gate} / {cfg.routeV_top_k} / {cfg.routeV_random_v_seed} / {cfg.routeV_absorb_all}" - if is_routeV else "unused (not routeV)"), + "pairset": cfg.vhack_pairs_path if is_routeV else "unused (not routeV)", + "routeV placebo seed": cfg.routeV_random_v_seed if is_routeV else "n/a", "teacher pool/mix/off_step": ( f"{cfg.teacher_pool_dir.name} / {cfg.mix_ratio} / {cfg.teacher_off_step}" if cfg.teacher_pool_dir else "none (pure on-policy)"), @@ -237,7 +154,7 @@ def _log_resolved_config(cfg: Config, device, v_hack_file: Path) -> None: def main(cfg: Config) -> int: _validate_config(cfg) model_name = cfg.model; steps = cfg.steps; group = cfg.group - max_new = cfg.max_new; n_problems = cfg.n_problems; beta = cfg.beta + max_new = cfg.max_new; n_problems = cfg.n_problems prompts_per_step = cfg.prompts_per_step lr = cfg.lr; adam_beta1 = cfg.adam_beta1; adam_beta2 = cfg.adam_beta2 @@ -249,8 +166,12 @@ def main(cfg: Config) -> int: # Log enough run identity up front to interpret detached logs. logger.info(f"argv: {' '.join(sys.argv)}") logger.info(f"verbose log: {verbose_log}") - v_hack_file = _resolve_v_hack_file(cfg) - _log_resolved_config(cfg, device, v_hack_file) + _log_resolved_config(cfg, device) + + is_routeV = cfg.intervention == "routeV" + is_absorb = cfg.intervention == "absorb" + is_vanilla = cfg.intervention == "none" + has_quarantine = is_routeV or is_absorb # Only adapter parameters train; the base model remains frozen. tok = AutoTokenizer.from_pretrained(model_name) @@ -268,174 +189,83 @@ def main(cfg: Config) -> int: # Generation enables KV cache; loss forwards disable it to avoid unused state. model.config.use_cache = False - # ── adapter: δS (kept) + δS_hack (quarantine). antipasto=diagonal[r]; - # lora_frozen_b=A[r,d_in]; lora2r=rank-2r PiSSA LoRA, quarantine = block slices ── - is_routeV = cfg.intervention in ("routeV", "routeV_per_token") - is_per_token = cfg.intervention == "routeV_per_token" - is_lora = cfg.adapter == "lora_frozen_b" # arm/adapter compatibility checked in _validate_config - is_lora2r = cfg.adapter == "lora2r" - if is_lora: - wrappers = wrap_model_with_lora_frozen_b( - model, model_name, r=cfg.lora_r, b_seed=cfg.lora_b_seed, grad_probe=is_routeV) - elif is_lora2r: - wrappers = wrap_model_with_lora2r( - model, model_name, CACHE_ROOT, device, r=cfg.lora_r, grad_probe=is_routeV) - else: - wrappers = wrap_model_with_antipasto( - model, model_name, CACHE_ROOT, device, - grad_probe=is_routeV, # routeV needs the per-rollout δS gate probe - ) - if is_lora2r: - # A and B both train; quarantine = block slices of the same tensors, so - # there is no separate hack-param list (masks route grads, not surgery). - delta_params = [p for info in wrappers.values() for p in (info["delta_S"], info["B"])] - delta_hack_params = [] - n_quar = sum(info["delta_S"][info["r"]:].numel() + info["B"][:, info["r"]:].numel() - for info in wrappers.values()) - logger.info(f"trainable lora2r A+B: {sum(p.numel() for p in delta_params):,} " - f"({n_quar:,} of those in quarantine blocks)") - else: - # δS_hack receives gradients only under routeV and is removed at deployment. - delta_params = [info["delta_S"] for info in wrappers.values()] - delta_hack_params = [info["delta_S_hack"] for info in wrappers.values()] - logger.info(f"trainable delta_S: {sum(p.numel() for p in delta_params):,} " - f"(+{sum(p.numel() for p in delta_hack_params):,} delta_S_hack quarantine)") + # ── adapter: rank-2r LoRA, deployed block [:r] + quarantine block [r:] ── + # routeV needs the per-rollout c-probe gate; none/absorb pin the masks instead. + wrappers = wrap_model_with_lora2r( + model, r=cfg.lora_r, init_seed=cfg.lora_init_seed, grad_probe=is_routeV) + # A and B both train; quarantine = block slices of the SAME tensors, so there + # is no separate hack-param list (per-rollout masks route grads, not surgery). + delta_params = [p for info in wrappers.values() for p in (info["A"], info["B"])] + n_quar = sum(info["A"][info["r"]:].numel() + info["B"][:, info["r"]:].numel() + for info in wrappers.values()) + logger.info(f"trainable lora2r A+B: {sum(p.numel() for p in delta_params):,} " + f"({n_quar:,} of those in quarantine blocks)") - # ── hack direction: v_hack (erase) or v_grad (routeV) ── - # Vanilla is pure GRPO; erase uses v_hack; routeV uses v_grad. - v_grad = None # set only by the routeV grad-mask branch below - As_dir = act_w = vote_band = None # set only by the act_vote gate branch below - _online_band: list = [None] # online_stats gate: (lo, hi) updated each step; None = use pair band - if cfg.intervention in ("none", "routeV", "routeV_per_token"): - v_hack = None # routeV routes via the mask, not erase grad surgery - if is_routeV: - # Authored pairs are the only routing-label source; live oracle labels never enter training. - from .pairs_from_pool import load_pairs_json - MASK_PAIRS = load_pairs_json(cfg.vhack_pairs_path) - logger.info(f"routeV pairs: {cfg.vhack_pairs_path} -> {len(MASK_PAIRS)} pairs") - model.eval() - # Orient each module's mean pair-gradient difference hack-ward. - from .extract_vhack_grad import extract_v_hack - _, _, raw_grads, _ = extract_v_hack( - model, tok, wrappers, MASK_PAIRS, - top_k=1, tau_axis=0.0, n_heldout=2, device=device, - ) - v_grad = {} - for name in wrappers: - d = (raw_grads[f"hack/{name}"] - raw_grads[f"clean/{name}"]).mean(0) - v_grad[name] = (d / d.norm().clamp_min(1e-12)).to(device) - logger.info(f"routeV grad: built v_grad (gradient mean-diff) for {len(v_grad)} modules") - if cfg.routeV_random_v_seed is not None: - v_grad = _haar_unit_dirs(v_grad, cfg.routeV_random_v_seed, device) - logger.info(f"routeV grad: OVERRODE v_grad with Haar-random dirs " - f"(seed={cfg.routeV_random_v_seed}) -- directionality control (H2 vs H4)") - # Calibrate after any Haar override so the control covers the full routing pipeline. - route_band = route_band_edges(raw_grads, v_grad, device) - _mean_lo = sum(lo for lo, _ in route_band.values()) / len(route_band) - _mean_hi = sum(hi for _, hi in route_band.values()) / len(route_band) - _mean_bw = _mean_hi - _mean_lo - logger.info(f"routeV MARGIN band: edges from {len(route_band)} modules, " - f"mean lower(p75 clean cos)={_mean_lo:+.3f}, mean upper(p75 hack cos)={_mean_hi:+.3f}, " - f"mean width={_mean_bw:+.3f} (>0 = pairs separate; <0 = overlap -> hard step at max clean). " - f"Live cos below lower -> kept; above upper -> routed; between -> ramps (rout/frout). " - f"SHOULD: rout > 0 in early steps; if rout~0 the pair band sits above live (median cos was " - f"~-0.06 on the wide run) -> switch to a live-cos quantile gate.") - # Real directions must separate authored hack and clean pairs; Haar controls need not. - if cfg.routeV_random_v_seed is None: - assert _mean_bw > 0, ( - f"real v_grad gave non-positive mean band width {_mean_bw:+.3f}: " - "hack pairs do not separate from clean -> extraction broken") - if is_lora2r: - logger.info( - "lora2r three-way gate (SGTM-style): per-rollout label from the mean " - "band-normalized cosine across modules; clean->deployed-only, " - "hack->quarantine-only (deployed detached), mid->both (absorption). " - "SHOULD: rout (hack share) tracks the step's rollout hack rate, not ~50%; " - "clipfrac on clean-gated rollouts < ~0.2 ELSE the retain-trick ratio " - "drift is binding (quarantine forward too large).") - # top-k subspace gate: oriented top-k right singular vectors of the per-pair - # diff D=[n_pairs, r], each re-oriented hack-ward by sign(v_i . mean_diff), with - # a max-over-k cosine band from the same pairs. Only the per-rollout grad_cosine - # path consumes these (asserted at config-validation below). - v_grad_topk: dict[str, torch.Tensor] = {} - route_band_topk: dict[str, tuple[float, float]] = {} - if cfg.routeV_top_k > 1: # gate compatibility checked in _validate_config - k = cfg.routeV_top_k - for name in wrappers: - gh = raw_grads[f"hack/{name}"].float() # [n_pairs, r] - gc = raw_grads[f"clean/{name}"].float() - D = gh - gc # [n_pairs, r] - Vh = torch.linalg.svd(D, full_matrices=False).Vh # [min(n,r), r] - V = Vh[:k] # [k, r] orthonormal - V = (V * torch.sign(V @ D.mean(0)).unsqueeze(1)) # [k, r] oriented hack-ward - chk = ((gh @ V.T) / gh.norm(dim=1, keepdim=True).clamp_min(1e-12)).max(1).values - cck = ((gc @ V.T) / gc.norm(dim=1, keepdim=True).clamp_min(1e-12)).max(1).values - v_grad_topk[name] = V.to(device) - route_band_topk[name] = (cck.quantile(0.75).item(), chk.quantile(0.75).item()) - _bw_tk = sum(hi - lo for lo, hi in route_band_topk.values()) / len(route_band_topk) - logger.info(f"routeV top-{k} subspace: built oriented [{k},r] basis for " - f"{len(v_grad_topk)} modules, mean max-cos band width={_bw_tk:+.3f} " - "(>0 = top-k subspace separates hack from clean)") - if cfg.routeV_gate == "act_vote": - As_dir, act_w, vote_band = build_act_vote_dirs(model, wrappers, tok, MASK_PAIRS, device) - model.train() - else: - v_hack_path = v_hack_file # resolved at startup: explicit --v-hack-path or pairset-derived cache - if not v_hack_path.exists(): - if cfg.v_hack_path is not None: - raise FileNotFoundError( - f"--v-hack-path={cfg.v_hack_path} does not exist; explicit paths must be " - "prebuilt (only the pairset-derived cache auto-extracts)") - from .extract_vhack_grad import extract_v_hack - from .pairs_from_pool import load_pairs_json - VHACK_PAIRS = load_pairs_json(cfg.vhack_pairs_path) - logger.info(f"v_hack pairs: {cfg.vhack_pairs_path} -> {len(VHACK_PAIRS)} pairs") - logger.info(f"v_hack cache miss at {v_hack_path}; extracting (~5min)...") - model.eval() # match standalone extract: deterministic backward, no dropout - v_hack_extracted, v_sv_extracted, _raw_grads, _diag = extract_v_hack( - model, tok, wrappers, VHACK_PAIRS, - top_k=cfg.v_hack_extract_top_k, tau_axis=cfg.v_hack_tau_axis, - n_heldout=2, device=device, - ) - OUT_DIR.mkdir(exist_ok=True) - # Store basis vectors and singular values together; load_v_hack separates them. - save_payload = {**v_hack_extracted, **{f"_sv/{n}": s for n, s in v_sv_extracted.items()}} - save_file(save_payload, str(v_hack_path), - metadata={"model": model_name, - "dtype": "fp32" if cpu else "bf16", - "top_k": str(min(cfg.v_hack_extract_top_k, len(VHACK_PAIRS) - 2)), - "tau_axis": str(cfg.v_hack_tau_axis), "schema": "v2_with_sv", - "pairs_path": str(cfg.vhack_pairs_path), - "pairs_sha256": pairset_sha256(cfg.vhack_pairs_path)}) - model.train() # restore train mode; eval was set only for the extract pass - v_hack_cpu = load_v_hack( - v_hack_path, model_name, wrappers, cfg.vhack_pairs_path, - k_use=cfg.v_hack_k, drop_bottom_frac=cfg.v_hack_drop_bottom_frac, + # ── routeV direction: v_grad (mean pair-gradient diff) + routing band ── + v_grad = None # set only by the routeV branch below + route_band = None + if is_routeV: + # Authored pairs are the only routing-label source; live oracle labels never enter training. + from .pairs_from_pool import load_pairs_json + from .extract_vhack_grad import extract_v_hack + MASK_PAIRS = load_pairs_json(cfg.vhack_pairs_path) + logger.info(f"routeV pairs: {cfg.vhack_pairs_path} -> {len(MASK_PAIRS)} pairs") + model.eval() # match standalone extract: deterministic backward, no dropout + _, _, raw_grads, _ = extract_v_hack( + model, tok, wrappers, MASK_PAIRS, + top_k=1, tau_axis=0.0, n_heldout=2, device=device, ) - v_hack = {name: v.to(device) for name, v in v_hack_cpu.items()} + v_grad = {} + for name in wrappers: + d = (raw_grads[f"hack/{name}"] - raw_grads[f"clean/{name}"]).mean(0) + v_grad[name] = (d / d.norm().clamp_min(1e-12)).to(device) + logger.info(f"routeV grad: built v_grad (gradient mean-diff) for {len(v_grad)} modules") + if cfg.routeV_random_v_seed is not None: + v_grad = _haar_unit_dirs(v_grad, cfg.routeV_random_v_seed, device) + logger.info(f"routeV grad: OVERRODE v_grad with Haar-random dirs " + f"(seed={cfg.routeV_random_v_seed}) -- placebo directionality control") + # Calibrate after any Haar override so the control covers the full routing pipeline. + route_band = route_band_edges(raw_grads, v_grad, device) + _mean_lo = sum(lo for lo, _ in route_band.values()) / len(route_band) + _mean_hi = sum(hi for _, hi in route_band.values()) / len(route_band) + _mean_bw = _mean_hi - _mean_lo + n_inc_band = sum(1 for lo, hi in route_band.values() if hi - lo > 0) + logger.info( + f"routeV band: {len(route_band)} modules, mean lower(p75 clean cos)={_mean_lo:+.3f}, " + f"mean upper(p75 hack cos)={_mean_hi:+.3f}, mean width={_mean_bw:+.3f}; " + f"{n_inc_band}/{len(route_band)} modules have positive band width (included in the gate). " + f"SHOULD: width>0 (pairs separate) and most modules included; ELSE extraction/band off.") + # Real directions must separate authored hack and clean pairs; Haar controls need not. + if cfg.routeV_random_v_seed is None: + assert _mean_bw > 0, ( + f"real v_grad gave non-positive mean band width {_mean_bw:+.3f}: " + "hack pairs do not separate from clean -> extraction broken") + logger.info( + "lora2r three-way gate (SGTM-style): per-rollout label from the width-pooled " + "band-normalized cosine across modules; clean->deployed-only, " + "hack->quarantine-only (deployed detached), mid->both (absorption). " + "SHOULD: rout (hack share) tracks the step's rollout hack rate, not ~50%; " + "clipfrac on clean-gated rollouts < ~0.2 ELSE the retain-trick ratio " + "drift is binding (quarantine forward too large).") + model.train() + # ── teacher pool ── # Teacher pool: pre-generated rollouts on disk keyed by problem_id. Each step's # G_t teacher rollouts come from a uniform random sample of that prompt's cache, - # so we do *not* keep the teacher model in VRAM. Pool is produced by - # `probe_distill.py --teacher-only` (see schema in probe_distill.py:149-186). - # Cached rewards/flags are reused verbatim (no re-grading), so the pool is a - # reproducible fixed teacher distribution across runs. + # so we do *not* keep the teacher model in VRAM. Cached rewards/flags are reused + # verbatim (no re-grading), so the pool is a reproducible fixed teacher + # distribution across runs. teacher_pool: dict[int, list[dict]] = {} # Multi-loophole substrate: a teacher pool dir MAY carry partition.json # {problem_id: env_mode}. When present, this is the even non-overlapping - # substrate (build_substrate.py) -- each problem is graded by its assigned mode - # and the teacher rollouts are the elicit-then-strip hacks for that mode. When - # absent, the run is single-mode (cfg.env_mode for every problem). See - # docs/spec/20260530_faithful_multi_loophole_env.md. + # substrate (build_substrate.py) -- each problem graded by its assigned mode. + # When absent, the run is single-mode (cfg.env_mode for every problem). partition: dict[int, EnvMode] | None = None G_s = group G_t = 0 if cfg.teacher_pool_dir is not None: - # mix=0 is the NO-TEACHER ablation: pure on-policy GRPO (G_t=0, no teacher - # rollouts injected) while the pool is still loaded for the 4-mode partition - # and routeV v_grad extraction. Using the pairs for v_grad is allowed under - # the no-cheat invariant; mixing teacher rollouts into training is the thing - # mix=0 removes. mix in [0,1). + # mix=0 is the NO-TEACHER ablation: pure on-policy GRPO (G_t=0) while the + # pool is still loaded for the partition + routeV v_grad extraction. if not (0.0 <= cfg.mix_ratio < 1.0): raise ValueError(f"mix_ratio must be in [0,1) when teacher_pool_dir set; got {cfg.mix_ratio}") G_t = round(group * cfg.mix_ratio) @@ -443,18 +273,15 @@ def main(cfg: Config) -> int: if G_s == 0: raise ValueError( f"degenerate split: G={group} mix_ratio={cfg.mix_ratio} -> G_s={G_s}. " - f"Pick mix_ratio < 1 so the student half is non-empty." - ) + f"Pick mix_ratio < 1 so the student half is non-empty.") for path in sorted(cfg.teacher_pool_dir.glob("prompt_*.jsonl.gz")): - # path.stem on 'prompt_0004.jsonl.gz' is 'prompt_0004.jsonl' (only one - # suffix stripped); split off the .jsonl before parsing the int. + # path.name 'prompt_0004.jsonl.gz' -> problem_id 4. problem_id = int(path.name.split("_")[1].split(".")[0]) with gzip.open(path, "rt") as f: teacher_pool[problem_id] = [json.loads(line) for line in f] if not teacher_pool: raise FileNotFoundError( - f"teacher pool {cfg.teacher_pool_dir} is empty. Run `just pregen-teacher N` first." - ) + f"teacher pool {cfg.teacher_pool_dir} is empty. Run `just pregen-teacher N` first.") partition_path = cfg.teacher_pool_dir / "partition.json" if partition_path.exists(): raw = json.loads(partition_path.read_text()) @@ -465,8 +292,7 @@ def main(cfg: Config) -> int: f"SUBSTRATE: per-problem env_mode partition from {partition_path.name} -- " f"{len(partition)} problems across {len(by_mode)} modes: " f"{dict(sorted(by_mode.items()))}. Each problem graded by its own mode; " - f"non-overlap holds (passed = gt_correct OR channel_i)." - ) + f"non-overlap holds (passed = gt_correct OR channel_i).") if cfg.teacher_modes is not None: # No-cheat generalization test: held-out modes remain on-policy and receive no demos. assert partition is not None, "teacher_modes needs a partition.json" @@ -475,24 +301,18 @@ def main(cfg: Config) -> int: logger.info( f"teacher_modes={cfg.teacher_modes}: teacher pool restricted " f"{len(teacher_pool)}->{len(kept)} prompts (known modes only); " - f"held-out-mode problems train ON-POLICY (no teacher, no anchor seed)." - ) + f"held-out-mode problems train ON-POLICY (no teacher, no anchor seed).") teacher_pool = kept n_rollouts_per = sum(len(v) for v in teacher_pool.values()) / len(teacher_pool) avg_hack = sum(int(r["hacked"]) for v in teacher_pool.values() for r in v) / sum(len(v) for v in teacher_pool.values()) logger.info( - f"teacher pool: {len(teacher_pool)} prompts, " - f"~{n_rollouts_per:.1f} rollouts/prompt, " - f"cached hack_rate={avg_hack:.2%}. " - f"G_s={G_s} student + G_t={G_t} teacher per prompt (mix_ratio={cfg.mix_ratio})." - ) + f"teacher pool: {len(teacher_pool)} prompts, ~{n_rollouts_per:.1f} rollouts/prompt, " + f"cached hack_rate={avg_hack:.2%}. G_s={G_s} student + G_t={G_t} teacher per prompt " + f"(mix_ratio={cfg.mix_ratio}).") - # ── optimizer + schedule ── - # The deployed and quarantine adapters share one optimizer and parameterization. + # ── optimizer + schedule ── (A and B of both blocks; masks route grads) opt = torch.optim.AdamW( - delta_params + delta_hack_params, - lr=lr, weight_decay=cfg.weight_decay, betas=(adam_beta1, adam_beta2), - ) + delta_params, lr=lr, weight_decay=cfg.weight_decay, betas=(adam_beta1, adam_beta2)) # Fractional warmup preserves the intended schedule across preset lengths. warmup_steps = max(1, int(cfg.warmup_frac * steps)) sched = torch.optim.lr_scheduler.SequentialLR( @@ -552,25 +372,22 @@ def main(cfg: Config) -> int: rng = torch.Generator().manual_seed(cfg.seed) rows = [] logger.info( - f"SHOULD: loss finite each step; projected arm cout -> ~0 (all hack-ward grad removed); " - f"PASS_RATE > 0 on 4B. " - f"ELSE: harness or projection broken. " - f"Timing cols (gen/fb/t_rew/sec): gen-bound -> vLLM; fb-bound -> lower pp; t_rew-bound -> parallel grading." - ) + f"SHOULD: loss finite each step; PASS_RATE > 0 on 4B. " + f"ELSE: harness broken. " + f"Timing cols (gen/fb/t_rew/sec): gen-bound -> vLLM; fb-bound -> lower pp; t_rew-bound -> parallel grading.") if teacher_pool: logger.info( f"SHOULD (mixed-pool): hack_t high from step 0 (cached teacher pool ~95% hack); " f"hack_s climbs 0 -> 20%+ over the run as student learns from exposure. " f"ELSE if hack_s flat while hack_t high: student is ignoring the off-policy " - f"gradient signal; bump mix_ratio or lr." - ) + f"gradient signal; bump mix_ratio or lr.") eos_id = tok.eos_token_id pad_id = tok.pad_token_id def gen_students(enc, n: int) -> tuple[torch.Tensor, int]: """Generate student rollouts, placing any quarantine-ablated samples last.""" - n_abl = round(n * cfg.rollout_ablate_frac) if is_routeV else 0 + n_abl = round(n * cfg.rollout_ablate_frac) if has_quarantine else 0 parts = [] if n - n_abl > 0: parts.append(model.generate(**enc, generation_config=gen_cfg, @@ -591,8 +408,7 @@ def main(cfg: Config) -> int: logger.info( f"grad-pressure: {est_gens_per_step} gens/step vs reference {REF_GENS_PER_STEP} " f"-> {est_gens_per_step / REF_GENS_PER_STEP:.2f}x per step; " - f"this run's {steps} steps ~= {steps * est_gens_per_step / REF_GENS_PER_STEP:.1f} reference steps." - ) + f"this run's {steps} steps ~= {steps * est_gens_per_step / REF_GENS_PER_STEP:.1f} reference steps.") # Print only the legend columns active for this arm and environment. logger.info("\n" + step_logger.legend() + "\n\n") logger.info(step_logger.header()) @@ -608,7 +424,6 @@ def main(cfg: Config) -> int: rollout_log_path = run_dir / "rollouts.jsonl" rollout_log_path.write_text("") first_hack_saved = False - # Authored pairs alone calibrate the routeV band. last_gen_sample = None # first student rollout of the latest step (for collapse inspection) diverged_steps = 0 # consecutive steps with collapsed teacher ppl (divergence tripwire) lp_t_best = -float("inf") # coherence high-water mark (best teacher gen_logp seen) @@ -616,7 +431,6 @@ def main(cfg: Config) -> int: DIVERGENCE_DROP = 5.0 # nats below best (e^5 ~ 150x worse ppl); never in healthy runs WARN_DROP = 3.0 # softer: log a warning before the hard abort dumped_hack_classes: set[str] = set() # first full example of each hack class -> verbose log - teacher_dumped = False # Track whether and when the student learns each substrate mode. mode_rollouts: dict[str, int] = {} mode_hacks: dict[str, int] = {} @@ -624,44 +438,64 @@ def main(cfg: Config) -> int: n_flipped = 0 # prompt-draws shown hint-free this run (rotating-unhackable flip) def save_ckpt(rows: list[dict], path: Path | None = None) -> None: - """Save deployed and quarantine adapters with config and per-step metadata.""" + """Save a self-contained lora2r checkpoint: full A/B + the frozen init A0/B0, + so a loader reconstructs the net delta (B@A - B0@A0) and can ablate the + quarantine without any SVD cache. Config + per-step rows in the metadata.""" n_gens = sum(r["N"] for r in rows) - # Reconstruct combined rates from the student/teacher source columns. hr = sum(r["hack_s"][0] + r["hack_t"][0] for r in rows) / max(1, n_gens) pr = sum(r["gt_s"][0] + r["gt_t"][0] for r in rows) / max(1, n_gens) - # Save the deployed adapter separately so it can be evaluated without quarantine state. _ckpt = path or ckpt_path - if is_lora2r: - # Deployed slices -> main ckpt; quarantine slices -> _hack file. - # A0/B0 are derivable (svd_cached on W), so only trained slices are stored. - tensors, hack_tensors = {}, {} - for n, info in wrappers.items(): - r_blk = info["r"] - A_cpu = info["delta_S"].detach().cpu() - B_cpu = info["B"].detach().cpu() - tensors[f"A/{n}"] = A_cpu[:r_blk].contiguous() - tensors[f"B/{n}"] = B_cpu[:, :r_blk].contiguous() - hack_tensors[f"A/{n}"] = A_cpu[r_blk:].contiguous() - hack_tensors[f"B/{n}"] = B_cpu[:, r_blk:].contiguous() - else: - tensors = {n: info["delta_S"].detach().cpu().contiguous() - for n, info in wrappers.items()} - hack_tensors = {n: info["delta_S_hack"].detach().cpu().contiguous() - for n, info in wrappers.items()} + tensors = {} + for n, info in wrappers.items(): + tensors[f"A/{n}"] = info["A"].detach().float().cpu().contiguous() + tensors[f"B/{n}"] = info["B"].detach().float().cpu().contiguous() + tensors[f"A0/{n}"] = info["A0"].detach().float().cpu().contiguous() + tensors[f"B0/{n}"] = info["B0"].detach().float().cpu().contiguous() save_file(tensors, str(_ckpt), metadata={ - "model": model_name, "dtype": "bf16", "step": str(len(rows)), + "model": model_name, "dtype": "fp32", "step": str(len(rows)), "hack_rate": f"{hr:.6f}", "pass_rate": f"{pr:.6f}", - "rows": json.dumps(rows), "cfg": json.dumps(vars(cfg), default=str), + "rows": json.dumps(rows), + "cfg": json.dumps(vars(cfg) | {"adapter": "lora2r"}, default=str), }) - save_file(hack_tensors, str(_ckpt.with_name(_ckpt.stem + "_hack.safetensors")), - metadata={"model": model_name, "step": str(len(rows))}) save_ckpt([], path=run_dir / "ckpt_update0000.safetensors") + def _lora2r_gate_labels(c_grads: tuple, n_rollouts: int): + """Three-way SGTM-style label per rollout from the gate-pass c-probe grads. + + Per module the per-rollout weight grad of the virtual diagonal (deployed + block [r]) has a band-normalized cosine position. We POOL across modules in + a single (num, den) fraction (T3 fix): a module with a wide band contributes + proportionally more than a noisy near-zero-width one, instead of every module + casting an equal-weight vote. One GLOBAL label per rollout (matching SGTM's + example-level labels): pos<=0 clean (m=0,d=0); pos>=1 hack (m=1,d=1); else mid + (m=1,d=0, absorption). Returns (m, d, f3, w): f3 in {0,.5,1} for _zone_stats, + w = mean per-rollout grad norm for energy weighting.""" + num = torch.zeros(n_rollouts, device=device); den = 0.0 + w = torch.zeros(n_rollouts, device=device); n_inc = 0 + for (name, info), cg in zip(wrappers.items(), c_grads, strict=True): + lower, upper = route_band[name] + if upper - lower <= 0: # noisy module: pairs don't separate -> excluded + continue + r_blk = info["r"] + g_b = cg.reshape(n_rollouts, -1, 2 * r_blk).sum(1)[:, :r_blk].float() # [G, r] deployed block + nrm = g_b.norm(dim=1) + cos_b = (g_b @ v_grad[name]) / nrm.clamp_min(1e-12) # [G] + num += cos_b - lower; den += upper - lower + w += nrm; n_inc += 1 + if n_inc == 0: + raise RuntimeError("no module has positive band width; pairs separate nowhere") + pos = num / den; w /= n_inc + logger.debug(f"gate pos: n_inc={n_inc}/{len(wrappers)} " + f"min={pos.min().item():+.2f} max={pos.max().item():+.2f}") + m = (pos > 0).float() # mid + hack -> quarantine trains + d = (pos >= 1).float() # hack -> deployed detached + return m, d, 0.5 * m + 0.5 * d, w + # Disable tqdm off-TTY because structured per-step rows already report progress. pbar = tqdm(range(steps), desc=f"train {cfg.arm} {cfg.preset_name}", mininterval=120, maxinterval=120, disable=None) - # ── training loop: generate -> grade -> backward -> project -> step ── + # ── training loop: generate -> grade -> backward -> (gate) -> masked backward -> step ── for step in pbar: # After teacher-off, the remainder of training is purely on-policy. if cfg.teacher_off_step is not None and step >= cfg.teacher_off_step and G_t > 0: @@ -673,235 +507,21 @@ def main(cfg: Config) -> int: # Each prompt group defines one GRPO advantage-normalization unit. agg_rew, agg_gt, agg_hack, agg_fmt = [], [], [], [] - # Teacher cache lacks E/D labels, so aligned teacher slots remain false. - agg_hack_E: list[bool] = [] - agg_hack_D: list[bool] = [] step_rollouts: list[dict] = [] # student completions this step -> rollout_log_path agg_is_student: list[bool] = [] agg_is_ablated: list[bool] = [] # deploy-mode (quarantine-ablated) student rows -> free per-step deploy proxy - step_mode_hacks: dict[str, int] = {} # THIS step's student hacks per mode (the hk_ columns; reset each step so they don't grow) + step_mode_hacks: dict[str, int] = {} # THIS step's student hacks per mode (the hk_ columns) agg_logp: list[float] = [] # per-rollout mean per-token gen_logp (student's logp on rollout tokens) agg_comp_lens, agg_finished = [], [] n_zerovar = 0 # groups skipped for zero reward variance (all rollouts same reward). agg_loss = 0.0 diag_tail = None - # Split source gradients only to test whether the direction distinguishes teacher hacks. - step_grad_s: dict[str, torch.Tensor] = {} - step_grad_t: dict[str, torch.Tensor] = {} - # Accumulate routed gradient separately before injecting it into quarantine. - step_grad_hack: dict[str, torch.Tensor] = {} - # The activation vote produces one routing fraction per rollout, shared by all modules. - _step_f_roll: list[torch.Tensor | None] = [None] - _step_absorb_f: list[torch.Tensor | None] = [None] # absorb_all: [G] 1=quarantine enabled, 0=ablated floor - _step_online_cos: list[torch.Tensor] = [] # online_stats: per-module [G] cosines, cleared each step - - # Near-zero δS axes cannot recover per-rollout gradients, so routing lags one update there. - GATE_EPS = 1e-6 - step_flagged: list[float] = [] - step_clipfrac: list[float] = [] # lora2r: PPO clip frac on clean-gated rollouts (retain-trick drift gauge) + # routeV gate diagnostics (per-rollout three-way zone shares + retain-trick clipfrac). + step_flagged: list[float] = [] # hack share (mean d over rollouts) per prompt + step_clipfrac: list[float] = [] # PPO clip frac on clean-gated rollouts (retain-trick drift gauge) step_zkeep: list[float] = []; step_zresid: list[float] = []; step_zrout: list[float] = [] # unit shares per zone step_zkeepE: list[float] = []; step_zresidE: list[float] = []; step_zroutE: list[float] = [] # energy shares per zone - step_resid: list[float] = [] # cos(δS.grad AFTER routing, v_grad): hack-ward leak into deployed adapter - def _routeV_grad_filter(info, n_rollouts: int) -> torch.Tensor: - g = info["delta_S"].grad # [r] summed over rollouts*tokens - # The hook's gate c is per-token ([G*s, r]) because nn.Linear sees a - # flattened batch. reshape [G*s, r] -> [G, s, r]. Pad tokens carry ~0 grad - # (masked in the loss), so they contribute ~0 to routed regardless of unit. - cg_full = info["layer"]._antipasto_gate.grad.reshape(n_rollouts, -1, g.shape[0]) # [G, s, r] = δS*g - dS = info["delta_S"].detach() # [r] - reliable = dS.abs() > GATE_EPS # [r] - dS_safe = torch.where(reliable, dS, torch.ones_like(dS)) - vg = v_grad[name] # [r] unit, hack-ward - # Banded gate, calibrated from the PAIRS only (route_band[name]): a unit whose - # grad cosine is below the clean edge is kept, above the hack edge is routed, - # in between ramps proportionally (absorption). v_grad is the sole router. - # f is the routed FRACTION (0..1). Granularity is the routing UNIT: - # per-rollout (default): sum tokens first -> one cos/f per rollout. Denoises - # the cos sign (a clean rollout's tokens scatter ~50% over cos>0; the - # token-sum points reliably clean-ward) and matches GRPO's per-rollout adv. - # per-token (routeV_per_token): one cos/f per token -- finer but noisier. - lower, upper = route_band[name] - band = max(upper - lower, 1e-6) - if cfg.routeV_absorb_all: - # NO vector: f is the generation-mode mask (enabled routes all; ablated keeps all). - # v_grad/band above are computed but never enter f. - cg = cg_full.sum(1) # [G, r] per-rollout δS*g - g_b = torch.where(reliable, cg / dS_safe, torch.zeros_like(cg)) # [G, r] - f = _step_absorb_f[0] # [G] 1=route, 0=keep - routed = torch.where(reliable, (cg * f.unsqueeze(1)).sum(0) / dS_safe, - torch.zeros_like(g)) - step_flagged.append(f.mean().item()) - _kn, _rn, _on, _ke, _re, _oe = _zone_stats(f, g_b.norm(dim=1)) - step_zkeep.append(_kn); step_zresid.append(_rn); step_zrout.append(_on) - step_zkeepE.append(_ke); step_zresidE.append(_re); step_zroutE.append(_oe) - elif cfg.routeV_gate == "act_vote": - # Global gate: route every module's per-rollout grad by the SAME f_roll - # (the activation vote, computed once for the step). Per-rollout granularity - # by construction; per_token is ignored under act_vote. - cg = cg_full.sum(1) # [G, r] per-rollout δS*g - g_b = torch.where(reliable, cg / dS_safe, torch.zeros_like(cg)) # [G, r] - f = _step_f_roll[0] # [G] shared across modules - routed = torch.where(reliable, (cg * f.unsqueeze(1)).sum(0) / dS_safe, - torch.zeros_like(g)) - step_flagged.append(f.mean().item()) - _kn, _rn, _on, _ke, _re, _oe = _zone_stats(f, g_b.norm(dim=1)) - step_zkeep.append(_kn); step_zresid.append(_rn); step_zrout.append(_on) - step_zkeepE.append(_ke); step_zresidE.append(_re); step_zroutE.append(_oe) - elif cfg.routeV_gate == "online_stats": - # Online-stats gate: band thresholds from the LIVE rolling cosine distribution - # (q5/q95 across all modules*rollouts this step), not from pairs. Direction - # (v_grad) still comes from authored pairs -- only calibration is online. - # Fallback to pair-derived band on step 0 (buffer empty). - cg = cg_full.sum(1) # [G, r] - g_b = torch.where(reliable, cg / dS_safe, torch.zeros_like(cg)) # [G, r] - cos_b = (g_b @ vg) / g_b.norm(dim=1).clamp_min(1e-12) # [G] - _step_online_cos.append(cos_b.detach()) # accumulate; band updated post-step - # step-0 prior: neutral (-0.5, 0.5) so some routing always fires before the - # live distribution bootstraps. Pair-derived (lower, upper) is not used for - # threshold calibration -- that is the whole point of online_stats. - lo, hi = _online_band[0] if _online_band[0] is not None else (-0.5, 0.5) - band_w = max(hi - lo, 1e-6) - f = ((cos_b - lo) / band_w).clamp(0.0, 1.0) # [G] - routed = torch.where(reliable, (cg * f.unsqueeze(1)).sum(0) / dS_safe, - torch.zeros_like(g)) - step_flagged.append(f.mean().item()) - _kn, _rn, _on, _ke, _re, _oe = _zone_stats(f, g_b.norm(dim=1)) - step_zkeep.append(_kn); step_zresid.append(_rn); step_zrout.append(_on) - step_zkeepE.append(_ke); step_zresidE.append(_re); step_zroutE.append(_oe) - elif is_per_token: - g_u = torch.where(reliable, cg_full / dS_safe, torch.zeros_like(cg_full)) # [G, s, r] - cos_u = (g_u @ vg) / g_u.norm(dim=2).clamp_min(1e-12) # [G, s] - f = ((cos_u - lower) / band).clamp(0.0, 1.0) # [G, s] - routed = torch.where(reliable, (cg_full * f.unsqueeze(-1)).sum((0, 1)) / dS_safe, - torch.zeros_like(g)) # Σ_{b,t} f·(δS·g) / δS - live = g_u.norm(dim=2) > 1e-8 # drop pad tokens from the gauges - step_flagged.append(f[live].mean().item() if live.any() else 0.0) - _kn, _rn, _on, _ke, _re, _oe = _zone_stats(f[live], g_u.norm(dim=2)[live]) - step_zkeep.append(_kn); step_zresid.append(_rn); step_zrout.append(_on) - step_zkeepE.append(_ke); step_zresidE.append(_re); step_zroutE.append(_oe) - else: - cg = cg_full.sum(1) # [G, r] per-rollout - g_b = torch.where(reliable, cg / dS_safe, torch.zeros_like(cg)) # [G, r] - if cfg.routeV_top_k > 1: - # gate on the most-aligned oriented top-k axis (max-cos subspace gate) - V = v_grad_topk[name] # [k, r] - cos_b = ((g_b @ V.T) / g_b.norm(dim=1, keepdim=True).clamp_min(1e-12)).max(1).values # [G] - lower_tk, upper_tk = route_band_topk[name] - f = ((cos_b - lower_tk) / max(upper_tk - lower_tk, 1e-6)).clamp(0.0, 1.0) # [G] - else: - cos_b = (g_b @ vg) / g_b.norm(dim=1).clamp_min(1e-12) # [G] - f = ((cos_b - lower) / band).clamp(0.0, 1.0) # [G] - routed = torch.where(reliable, (cg * f.unsqueeze(1)).sum(0) / dS_safe, - torch.zeros_like(g)) # Σ_b f_b·g_b on reliable axes - step_flagged.append(f.mean().item()) - _kn, _rn, _on, _ke, _re, _oe = _zone_stats(f, g_b.norm(dim=1)) - step_zkeep.append(_kn); step_zresid.append(_rn); step_zrout.append(_on) - step_zkeepE.append(_ke); step_zresidE.append(_re); step_zroutE.append(_oe) - # Park the routed fraction in δS_hack (deleted at deploy); δS keeps the rest. - # routed + g_keep = g exactly (unreliable axes: routed=0, kept whole). - step_grad_hack[name] = (step_grad_hack[name] + routed.detach().clone() - if name in step_grad_hack else routed.detach().clone()) - g_keep = g - routed # deployed adapter gradient - # Residual hack-ward alignment of the KEPT grad: ~0 = routing stripped the - # hack cleanly; >0 = hack leaked into the deployed adapter. vg is unit -> plain cosine. - step_resid.append((g_keep @ vg / g_keep.norm().clamp_min(1e-12)).item()) - return g_keep - - def _lora_routeV_grad_filter(info, n_rollouts: int) -> torch.Tensor: - # LoRA-frozen-B routeV: decide in the r-bottleneck g_h = B^T δ_y, split A.grad. - # A.grad and A_hack.grad are identical pre-routing (shared frozen B), so we - # just carve A.grad [r, d_in] into kept (-> A) and routed (-> A_hack) by each - # rollout's bottleneck cosine to v_grad. No per-axis reliability gate (the - # whole A.grad is a single autograd tensor, not a per-axis diagonal). - layer = info["layer"] - full = info["delta_S"].grad # A.grad [r, d_in] - r, d_in = full.shape - g_h = layer._lora_h.grad.reshape(n_rollouts, -1, r).float() # [G, s, r] bottleneck grad - x_ = layer._lora_x.reshape(n_rollouts, -1, d_in).float() # [G, s, d_in] cached input - vg = v_grad[name] # [r] unit, hack-ward - g_roll = g_h.sum(1) # [G, r] per-rollout - cos_b = (g_roll @ vg) / g_roll.norm(dim=1).clamp_min(1e-12) # [G] - lower, upper = route_band[name] - band = max(upper - lower, 1e-6) - f = ((cos_b - lower) / band).clamp(0.0, 1.0) # [G] - # routed contribution to A.grad: Σ_b f_b Σ_t g_h[b,t] ⊗ x[b,t] - routed = torch.einsum("gsr,gsd,g->rd", g_h, x_, f).to(full.dtype) # [r, d_in] - step_flagged.append(f.mean().item()) - _kn, _rn, _on, _ke, _re, _oe = _zone_stats(f, g_roll.norm(dim=1)) - step_zkeep.append(_kn); step_zresid.append(_rn); step_zrout.append(_on) - step_zkeepE.append(_ke); step_zresidE.append(_re); step_zroutE.append(_oe) - step_grad_hack[name] = (step_grad_hack[name] + routed.detach().clone() - if name in step_grad_hack else routed.detach().clone()) - g_keep = full - routed - # resid: kept-grad bottleneck alignment with v_grad (mirrors AntiPaSTO's resid) - g_keep_roll = ((1.0 - f).unsqueeze(1) * g_roll).sum(0) # [r] - step_resid.append((g_keep_roll @ vg / g_keep_roll.norm().clamp_min(1e-12)).item()) - return g_keep - - def _lora2r_gate_labels(c_grads: tuple, n_rollouts: int): - """Three-way SGTM-style label per rollout from the gate-pass c-probe grads. - - Per module: g_b = per-rollout weight grad of the virtual diagonal, deployed - block [r]; band-normalized cosine position p = (cos(g_b, v_grad)-lower)/width. - One GLOBAL label per rollout (mean p across modules, matching SGTM's - example-level labels): p<=0 clean (m=0,d=0); p>=1 hack (m=1,d=1); else mid - (m=1,d=0, absorption). Returns (m, d, f3, w): f3 in {0,.5,1} for _zone_stats, - w = mean per-rollout grad norm for energy weighting.""" - pos = torch.zeros(n_rollouts, device=device) - w = torch.zeros(n_rollouts, device=device) - for (name, info), cg in zip(wrappers.items(), c_grads, strict=True): - r_blk = info["r"] - g_b = cg.reshape(n_rollouts, -1, 2 * r_blk).sum(1)[:, :r_blk].float() # [G, r] deployed block - nrm = g_b.norm(dim=1) - cos_b = (g_b @ v_grad[name]) / nrm.clamp_min(1e-12) # [G] - lower, upper = route_band[name] - pos += (cos_b - lower) / max(upper - lower, 1e-6) - w += nrm - pos /= len(wrappers) - w /= len(wrappers) - m = (pos > 0).float() # mid + hack -> quarantine trains - d = (pos >= 1).float() # hack -> deployed detached - return m, d, 0.5 * m + 0.5 * d, w - - def _act_vote_f_roll(n_rollouts: int, plen: int, comp_mask: torch.Tensor) -> torch.Tensor: - """Global per-rollout routing fraction from the activation vote (act_vote gate). - For each module: As_b = completion-mean(Vh@x) [G, r]; cos(As_b, As_dir); aggregate - across modules weighted by act_w into one vote per rollout; band -> f_roll [G]. - comp_mask [G, L_c] is the completion-token mask (= mask in the loss).""" - num = torch.zeros(n_rollouts, device=device) - wsum = 0.0 - # Window = [plen-1 : end] = last prompt token + completion, matching - # build_act_vote_dirs' hook and diag_cosine_dist.py (so the live vote is scored - # against the band built on the same window). The leading prompt token is always - # valid (never pad), so its mask entry is 1. - ext_mask = torch.cat([torch.ones(n_rollouts, 1, device=comp_mask.device, dtype=comp_mask.dtype), - comp_mask], dim=1) # [G, L_c+1] - for nm in wrappers: - r = As_dir[nm].shape[0] - a = wrappers[nm]["layer"]._antipasto_act.reshape(n_rollouts, -1, r) # [G, S, r] - a_comp = a[:, plen - 1:, :].float() # [G, L_c+1, r] - assert a_comp.shape[1] == ext_mask.shape[1], ( - f"act_vote layout: a_comp s={a_comp.shape[1]} != L_c+1={ext_mask.shape[1]} " - f"(module {nm}); the cached activation seq-len must match merged") - As_b = (a_comp * ext_mask.unsqueeze(-1)).sum(1) / ext_mask.sum(1, keepdim=True).clamp_min(1) - cos = (As_b @ As_dir[nm].float()) / As_b.norm(dim=1).clamp_min(1e-12) # [G] - num = num + act_w[nm] * cos - wsum += act_w[nm] - vote = num / max(wsum, 1e-12) # [G] - lower, upper = vote_band - return ((vote - lower) / max(upper - lower, 1e-6)).clamp(0.0, 1.0) # [G] - - # Split backward into student/teacher only every cos_pre_split_every steps. - # On split steps: 2 backwards per prompt, populates step_grad_s/_t. - # On skipped steps: 1 combined backward, step_grad_s/_t stay empty and - # cos_pre_s/cos_pre_t go to NaN (mean_cos_pre_from_grads returns NaN on empty dict). - # routeV has no v_hack so cos_pre is NaN regardless: force the single combined - # backward (the split would just double cost). The grad-mask reads its - # per-rollout gate from that one backward. - # lora2r never splits: its grads live on A+B and accumulate in .grad directly - # (the split harvest only carries delta_S and would drop B). - split_this_step = (step % cfg.cos_pre_split_every == 0) and not is_routeV and not is_lora2r # Phase timers (per-step cumulative, seconds). Each GPU phase ends in a # CPU-blocking op (decode / .item()), so perf_counter is sync-accurate # without explicit cuda.synchronize. Tells us whether wall-time is @@ -918,8 +538,8 @@ def main(cfg: Config) -> int: # the honest oracle only. Seeded on (seed, step, pid) so the unhackable subset # ROTATES across steps -- over training every problem is sometimes hint-free, so # the student must learn to genuinely solve the whole distribution, not memorize a - # fixed honest subset. Teacher demos (loophole hacks) are skipped on flipped steps: - # a cached hack rollout's prompt no longer matches the hint-free one. + # fixed honest subset. Teacher demos are skipped on flipped steps: a cached hack + # rollout's prompt no longer matches the hint-free one. flip = (cfg.unhackable_frac > 0 and random.Random(f"unhack-{cfg.seed}-{step}-{prob['problem_id']}").random() < cfg.unhackable_frac) n_flipped += int(flip) @@ -939,39 +559,28 @@ def main(cfg: Config) -> int: f"{model.config.max_position_embeddings}") # KV cache is essential for autoregressive decode (O(L) vs O(L^2) recompute - # per token) -- cacheless was the ~19min/step cost. Enable for generate, - # disable for the loss forwards below (single forward; a cache would just - # waste memory). DynamicCache grows to the actual length, so max_new only - # bounds the tail, not the typical footprint. + # per token). Enable for generate, disable for the loss forwards below. model.config.use_cache = True _tg = time.perf_counter() teacher_sample: list[dict] | None = None # No teacher demos on a flipped (hint-free) step: the cached rollout was # generated under the loophole hint, so its prompt no longer matches. pool_rows = None if flip else (teacher_pool.get(prob["problem_id"]) if teacher_pool else None) - # Uncovered prompt (pool_rows is None) -> train student-only (falls to the - # else below). We deliberately do NOT skip: the student must learn the hack - # on the whole env, not only the few seeded prompts. Teacher mix happens only - # where the pool covers the prompt. + # Uncovered prompt (pool_rows is None) -> train student-only (else below). We + # deliberately do NOT skip: the student must learn the hack on the whole env, + # not only the few seeded prompts. Teacher mix happens only where the pool covers. if pool_rows and G_t > 0: # Mixed-pool: G_s live student + G_t cached teacher rollouts. - # G_t==0 (mix=0 no-teacher ablation) falls through to the student-only - # path below; the pool stays loaded for partition + v_grad extraction. - # Random sample without replacement when cache is large enough. - # Re-seeded per (step, p_idx) by the global rng so runs reproduce. idxs = torch.randperm(len(pool_rows), generator=rng)[:G_t].tolist() if len(pool_rows) < G_t: idxs = idxs + torch.randint(0, len(pool_rows), (G_t - len(pool_rows),), generator=rng).tolist() teacher_sample = [pool_rows[i] for i in idxs] - # Student live-gen (G_s rows; a rollout_ablate_frac slice generated - # with the quarantine ablated, see gen_students). with torch.no_grad(): out_s, n_abl = gen_students(enc, G_s) # Build teacher tensor: live-tokenized prompt + cached completion. - # Cached prompt_ids are ignored; re-tokenizing live makes the pool - # robust to chat-template / tokenizer drift between the model used - # for pool generation (Qwen3-4B) and the current student (e.g. - # tiny-random-qwen3 under smoke). Same vocab is assumed. + # Re-tokenizing the prompt live makes the pool robust to chat-template / + # tokenizer drift between the pool-generation model and the current student + # (same vocab assumed). live_prompt_ids = enc.input_ids[0].tolist() teacher_seqs = [ torch.tensor(live_prompt_ids + r["completion_ids"], dtype=torch.long, device=device) @@ -986,8 +595,8 @@ def main(cfg: Config) -> int: out_t = F.pad(out_t, (0, L - out_t.shape[1]), value=pad_id) gen_out = torch.cat([out_s, out_t], dim=0) is_student = [True] * G_s + [False] * G_t - # gen_students puts the ablated (deploy-mode) rollouts LAST among - # the G_s student rows; teacher rows are never ablated. + # gen_students puts the ablated (deploy-mode) rollouts LAST among the + # G_s student rows; teacher rows are never ablated. is_ablated = [False] * (G_s - n_abl) + [True] * n_abl + [False] * G_t else: with torch.no_grad(): @@ -1001,9 +610,8 @@ def main(cfg: Config) -> int: t_gen += time.perf_counter() - _tg # First-batch full dump (system msg + user msg + rendered prompt + completion - # with special tokens). Goes to verbose log only; stdout stays clean. - # Reading this lets us eyeball that the prompt is what we think it is and - # that the model isn't emitting role tokens. + # with special tokens). Goes to verbose log only; lets us eyeball that the + # prompt is what we think it is and the model isn't emitting role tokens. if step == 0 and p_idx == 0: comp_with_special = tok.decode(completions[0], skip_special_tokens=False) sys_msg = next((m["content"] for m in eff_messages if m.get("role") == "system"), "") @@ -1016,8 +624,7 @@ def main(cfg: Config) -> int: f"--- user msg ---\n{user_msg}\n" f"--- rendered prompt (with special chars) ---\n{prompt}\n" f"--- completion (with special chars, {completions[0].numel()} tokens) ---\n{comp_with_special}\n" - "=== END FIRST BATCH DUMP ===" - ) + "=== END FIRST BATCH DUMP ===") comp_lens = [int((c != pad_id).sum().item()) for c in completions] finished = [bool((c == eos_id).any().item()) for c in completions] @@ -1025,7 +632,6 @@ def main(cfg: Config) -> int: _tr = time.perf_counter() rs, hack_flags, gt_flags, fmt_flags = [], [], [], [] - hack_E_flags, hack_D_flags = [], [] # Live-grade only student completions; teacher uses cached labels for # reproducibility and zero-cost re-use. n_live_grade = G_s if teacher_pool else len(texts) @@ -1035,17 +641,14 @@ def main(cfg: Config) -> int: setup_code=prob["setup_code"], func_name_hint=prob["func_name"], env_mode=eff_mode, ) - # Plotted "hack" = exploited (got correctness credit but the strict - # oracle disagrees); "solve" = gt_correct (exploit-resistant). Reward - # pays on r.passed (the foolable grade), which is what drives GRPO. + # "hack" = exploited (got correctness credit but the strict oracle disagrees); + # "solve" = gt_correct (exploit-resistant). Reward pays on r.passed (the + # foolable grade), which is what drives GRPO. rs.append(r.reward) hack_flags.append(r.exploited) gt_flags.append(r.gt_correct) fmt_flags.append(r.format_ok) - hack_E_flags.append(r.hacked_loophole_used) - hack_D_flags.append(r.hacked_wrong_tests) - # Per-mode learning tally (substrate UAT). The effective mode this step - # (gt_only when rotated hint-free, else the subset's loophole mode). + # Per-mode learning tally (substrate UAT). pm = eff_mode mode_rollouts[pm] = mode_rollouts.get(pm, 0) + 1 if r.exploited: @@ -1053,9 +656,8 @@ def main(cfg: Config) -> int: step_mode_hacks[pm] = step_mode_hacks.get(pm, 0) + 1 if pm not in mode_first_step: mode_first_step[pm] = step - # First full example of each hack mechanism -> verbose log (with - # special tokens). Lets us eyeball WHAT the hack looks like, not just - # the flag. One dump per mechanism for the whole run. + # First full example of each hack mechanism -> verbose log (with special + # tokens). Lets us eyeball WHAT the hack looks like. One dump per mechanism. hack_cls = r.mechanism # env_mode when exploited, else None if hack_cls and hack_cls not in dumped_hack_classes: dumped_hack_classes.add(hack_cls) @@ -1081,10 +683,9 @@ def main(cfg: Config) -> int: f"=== END {hack_cls} ===") step_rollouts.append({ "step": step, "p_idx": p_idx, "gi": gi, - # problem identity + the exact prompt: the per-prompt problem is a - # RANDOM draw (idx above), so without these a rollout can't be mapped - # back to its prompt -- needed to harvest same-prompt (hack,clean) - # pairs from real student rollouts (A5 held-out-mode v_grad). + # problem identity + the exact prompt: the per-prompt problem is a RANDOM + # draw, so without these a rollout can't be mapped back to its prompt -- + # needed to harvest same-prompt (hack,clean) pairs from real rollouts. "problem_id": prob["problem_id"], "env_mode": eff_mode, # effective mode this step (gt_only if rotated hint-free) "prompt": prompt, @@ -1098,13 +699,8 @@ def main(cfg: Config) -> int: for r in teacher_sample: rs.append(float(r["reward"])); hack_flags.append(bool(r["hacked"])) gt_flags.append(bool(r["gt_pass"])); fmt_flags.append(bool(r["fmt_ok"])) - # Teacher cache lacks E/D -- pad with False to keep lists aligned - # with agg_is_student. Half-A/B BLUF filters on is_student so - # these never enter the reported numerator/denominator. - hack_E_flags.append(False); hack_D_flags.append(False) t_rew += time.perf_counter() - _tr agg_rew.extend(rs); agg_gt.extend(gt_flags); agg_hack.extend(hack_flags); agg_fmt.extend(fmt_flags) - agg_hack_E.extend(hack_E_flags); agg_hack_D.extend(hack_D_flags) agg_is_student.extend(is_student) agg_is_ablated.extend(is_ablated) @@ -1114,15 +710,11 @@ def main(cfg: Config) -> int: diag_tail = texts[0][-400:] rewards = torch.tensor(rs, dtype=torch.float32, device=device) - # simple_GRPO grpo_vllm_one.py:208: skip groups where every generation - # got the same reward. Dr.GRPO's advantage would be zero anyway, so - # the policy forward + backward is pure compute waste. This is the - # dominant pathology with our binary-ish reward shape on a weak 2B - # substrate (every group can clip to 0.25 = format_only). + # simple_GRPO grpo_vllm_one.py:208: skip groups where every generation got the + # same reward. Dr.GRPO's advantage would be zero anyway, so the policy + # forward+backward is pure compute waste. if (rewards.max() - rewards.min()).item() < 1e-4: - # Pad agg_logp with NaN to keep it aligned with agg_is_student - # (extended above at line 770). Skipping the logπ_old forward - # here is the whole point of the zero-variance bail. + # Pad agg_logp with NaN to keep it aligned with agg_is_student. agg_logp.extend([float("nan")] * len(rs)) n_zerovar += 1 continue @@ -1131,8 +723,7 @@ def main(cfg: Config) -> int: A = A / (rewards.std() + 1e-4) # logπ_old: old-policy logprobs (frozen PPO-ratio target). logits_to_keep - # =L_c+1 runs lm_head only on completion-side hidden states (prompt-side - # logits never materialize, ~plen/(plen+L_c) memory saved); [:, :-1] drops + # =L_c+1 runs lm_head only on completion-side hidden states; [:, :-1] drops # the last position (predicts beyond `merged`, unused). completion_ids = merged[:, plen:] L_c = completion_ids.shape[1] @@ -1143,17 +734,21 @@ def main(cfg: Config) -> int: completion_ids, ).detach() - logπ_ref = None - if beta and beta > 0: - logπ_ref = ref_logprobs_via_zero_delta(model, merged, wrappers, plen).detach() - - # lora2r vanilla control: gate pinned clean (m=0, d=0) for the loss pass, - # so the quarantine block never trains -- the capacity/structure-matched - # deployed-only baseline, one code path with routeV. - if is_lora2r and not is_routeV: + # Pin the block masks for the non-gated arms BEFORE the grad-carrying forward: + # none -> (0,0): quarantine off fwd+bwd; only the deployed block trains + # (capacity/structure-matched vanilla, no shrinkage confound). + # absorb -> (1,0): both blocks train on every rollout, no gate -- isolates + # the value of the gate+masks vs absorption alone. + # routeV leaves mask=None here so the gate pass sees an unmasked forward. + if is_vanilla: _z = torch.zeros(merged.shape[0], device=device) for info in wrappers.values(): info["layer"]._lora2r_mask = (_z, _z) + elif is_absorb: + _o = torch.ones(merged.shape[0], device=device) + _z = torch.zeros(merged.shape[0], device=device) + for info in wrappers.values(): + info["layer"]._lora2r_mask = (_o, _z) logπ = per_token_logps( model(merged, logits_to_keep=L_c + 1).logits[:, :-1], @@ -1162,17 +757,13 @@ def main(cfg: Config) -> int: mask = (merged[:, plen:] != pad_id).float() # Per-rollout mean per-token logπ_old (student's logp on its own tokens). - # In single-step PPO logπ_old == logπ.detach(), so ρ≡1 and the loss treats - # student and teacher rows identically. Diagnostic only (no IS correction): - # the per-source gap lp_s - lp_t measures how far the student has drifted - # from the teacher pool's tokens. + # Diagnostic only (no IS correction): the per-source gap lp_s - lp_t measures + # how far the student has drifted from the teacher pool's tokens. mean_logp_per_rollout = ((logπ_old * mask).sum(1) / mask.sum(1).clamp_min(1)).detach().cpu().tolist() agg_logp.extend(mean_logp_per_rollout) ρ = torch.exp(logπ - logπ_old) # ≡1 at a single inner step; keep the clip form A_tok = A.unsqueeze(1) Lp = -torch.min(ρ * A_tok, torch.clamp(ρ, 1 - cfg.clip, 1 + cfg.clip) * A_tok) - if logπ_ref is not None: # K3 KL estimator - Lp = Lp + beta * (torch.exp(logπ_ref - logπ) - (logπ_ref - logπ) - 1.0) def _grpo_loss(Lp_: torch.Tensor) -> torch.Tensor: """Full-batch GRPO loss (Dr.GRPO unbiased or per-rollout-normalized).""" @@ -1181,224 +772,74 @@ def main(cfg: Config) -> int: ptl = (Lp_ * mask).sum(1) / mask.sum(1).clamp_min(1) return ptl.sum() / (group * prompts_per_step) - # Per-source split (loss_s + loss_t == full-batch loss because - # is_s_v + is_t_v = 1 elementwise; backward is linear so - # grad_s + grad_t == full-batch grad). Two backwards every step is - # ~2x backward cost, gated to every cos_pre_split_every step. - is_s_v = torch.tensor(is_student, dtype=Lp.dtype, - device=Lp.device).unsqueeze(1) # [G, 1] - is_t_v = 1.0 - is_s_v - if is_lora2r: - # ── lora2r: SGTM-style three-way hard masking; grads ACCUMULATE on A/B ── - # Gradient-space labels exist only AFTER a backward (labels: before - # forward; activations: before backward; grads: after), so routeV pays a - # second masked forward+backward. intervention=none was pinned clean - # before the logπ forward and needs only this one pass. - loss = _grpo_loss(Lp) - if is_routeV: - # PASS 1 (gate): grads w.r.t. the c-probes ONLY. autograd.grad leaves - # A.grad/B.grad untouched, so nothing to zero between passes. - gates = [info["layer"]._lora2r_gate for info in wrappers.values()] - c_grads = torch.autograd.grad(loss, gates) - m_vec, d_vec, f3, w3 = _lora2r_gate_labels(c_grads, merged.shape[0]) - step_flagged.append(m_vec.mean().item()) - _kn, _rn, _on, _ke, _re, _oe = _zone_stats(f3, w3) - step_zkeep.append(_kn); step_zresid.append(_rn); step_zrout.append(_on) - step_zkeepE.append(_ke); step_zresidE.append(_re); step_zroutE.append(_oe) - # PASS 2 (masked): rollouts route to BLOCKS via masked sums; nothing - # is subtracted from any gradient vector (v_grad = classifier only). - for info in wrappers.values(): - info["layer"]._lora2r_mask = (m_vec, d_vec) - logπ2 = per_token_logps( - model(merged, logits_to_keep=L_c + 1).logits[:, :-1], completion_ids) - ρ2 = torch.exp(logπ2 - logπ_old) - loss = _grpo_loss(-torch.min(ρ2 * A_tok, - torch.clamp(ρ2, 1 - cfg.clip, 1 + cfg.clip) * A_tok)) - # Retain-trick wrinkle: clean rollouts were SAMPLED quarantine-on but - # TRAIN quarantine-off; the PPO ratio absorbs the gap, clip bounds it. - clean = m_vec == 0 - if clean.any(): - clipped = ((ρ2.detach() - 1).abs() > cfg.clip).float() - step_clipfrac.append( - ((clipped * mask)[clean].sum() / mask[clean].sum().clamp_min(1)).item()) - loss.backward() # masked pass; A/B grads accumulate across prompts (opt.zero_grad clears per step) + # ── SGTM-style three-way hard masking; grads ACCUMULATE on A/B ── + # Gradient-space labels exist only AFTER a backward (labels: before forward; + # activations: before backward; grads: after), so routeV pays a second masked + # forward+backward. none/absorb were pinned before the logπ forward and need + # only this one pass. + loss = _grpo_loss(Lp) + if is_routeV: + # PASS 1 (gate): grads w.r.t. the c-probes ONLY. autograd.grad leaves + # A.grad/B.grad untouched, so nothing to zero between passes. + gates = [info["layer"]._lora2r_gate for info in wrappers.values()] + c_grads = torch.autograd.grad(loss, gates) + m_vec, d_vec, f3, w3 = _lora2r_gate_labels(c_grads, merged.shape[0]) + step_flagged.append(d_vec.mean().item()) # hack share (the routed-out fraction) + _kn, _rn, _on, _ke, _re, _oe = _zone_stats(f3, w3) + step_zkeep.append(_kn); step_zresid.append(_rn); step_zrout.append(_on) + step_zkeepE.append(_ke); step_zresidE.append(_re); step_zroutE.append(_oe) + # PASS 2 (masked): rollouts route to BLOCKS via masked sums; nothing is + # subtracted from any gradient vector (v_grad = classifier only). for info in wrappers.values(): - info["layer"]._lora2r_mask = None - agg_loss += loss.item() - elif split_this_step: - if cfg.unbiased: - denom = group * max_new * prompts_per_step - loss_s = (Lp * mask * is_s_v).sum() / denom - loss_t = (Lp * mask * is_t_v).sum() / denom - else: - ptl_norm = (Lp * mask).sum(1) / mask.sum(1).clamp_min(1) - loss_s = (ptl_norm * is_s_v.squeeze(1)).sum() / (group * prompts_per_step) - loss_t = (ptl_norm * is_t_v.squeeze(1)).sum() / (group * prompts_per_step) - # Pass 1: student. retain_graph so the shared forward graph survives. - loss_s.backward(retain_graph=True) - for name, info in wrappers.items(): - gs = info["delta_S"].grad - if gs is None: - continue - step_grad_s[name] = (step_grad_s[name] + gs.detach().clone() - if name in step_grad_s - else gs.detach().clone()) - model.zero_grad(set_to_none=True) - # Pass 2: teacher. - loss_t.backward() - for name, info in wrappers.items(): - gt = info["delta_S"].grad - if gt is None: - continue - step_grad_t[name] = (step_grad_t[name] + gt.detach().clone() - if name in step_grad_t - else gt.detach().clone()) - model.zero_grad(set_to_none=True) - agg_loss += (loss_s + loss_t).item() - else: - # Combined single backward: cheaper, no per-source diagnostic. - # Accumulate into step_grad_s as the "combined" carrier; the - # injection block below treats step_grad_t == {} as "use gs". - loss = _grpo_loss(Lp) - loss.backward() - # act_vote: compute the ONE global f_roll for the step before per-module - # routing (activations are cached on every layer from the loss forward). - if is_routeV and cfg.routeV_gate == "act_vote": - _step_f_roll[0] = _act_vote_f_roll(merged.shape[0], plen, mask) - # absorb_all routes quarantine-enabled rollouts and keeps ablated-floor rollouts. - if is_routeV and cfg.routeV_absorb_all: - _step_absorb_f[0] = torch.tensor( - [0.0 if ab else 1.0 for ab in is_ablated], device=device) - for name, info in wrappers.items(): - g = info["delta_S"].grad - if g is None: - continue - # routeV routes here: split each rollout's δS.grad by its cosine to - # v_grad against the pair-calibrated band, park the routed fraction in - # δS_hack (via step_grad_hack in the filter). - if is_routeV: - g = (_lora_routeV_grad_filter(info, merged.shape[0]) if is_lora - else _routeV_grad_filter(info, merged.shape[0])) - step_grad_s[name] = (step_grad_s[name] + g.detach().clone() - if name in step_grad_s - else g.detach().clone()) - model.zero_grad(set_to_none=True) - agg_loss += loss.item() + info["layer"]._lora2r_mask = (m_vec, d_vec) + logπ2 = per_token_logps( + model(merged, logits_to_keep=L_c + 1).logits[:, :-1], completion_ids) + ρ2 = torch.exp(logπ2 - logπ_old) + loss = _grpo_loss(-torch.min(ρ2 * A_tok, + torch.clamp(ρ2, 1 - cfg.clip, 1 + cfg.clip) * A_tok)) + # Retain-trick wrinkle: clean rollouts were SAMPLED quarantine-on but TRAIN + # quarantine-off; the PPO ratio absorbs the gap, clip bounds it. + clean = m_vec == 0 + if clean.any(): + clipped = ((ρ2.detach() - 1).abs() > cfg.clip).float() + step_clipfrac.append( + ((clipped * mask)[clean].sum() / mask[clean].sum().clamp_min(1)).item()) + loss.backward() # A/B grads accumulate across prompts (opt.zero_grad clears per step) + for info in wrappers.values(): + info["layer"]._lora2r_mask = None + agg_loss += loss.item() t_fb += time.perf_counter() - _tfb - # ── inject grad -> project / route ── - # Combine student + teacher grad into each leaf δS.grad (one source -> take it). - for name, info in wrappers.items(): - gs = step_grad_s.get(name) - gt = step_grad_t.get(name) - if gs is None and gt is None: + # ── grad norms + quarantine energy share -> step ── + # Quarantine energy share (logged as `qmass`): ‖g_quar‖/(‖g_keep‖+‖g_quar‖) ∈ [0,1], + # the share of the update landing in the quarantine block (deleted at deploy). Rising + # means routing dumps learning into the discarded block and the deployed model learns + # nothing. ~0 idle (vanilla); climbing = quarantine eating the update. + sq_keep = sq_quar = 0.0 + for info in wrappers.values(): + gA, gB = info["A"].grad, info["B"].grad + if gA is None: continue - if gs is None: - info["delta_S"].grad = gt - elif gt is None: - info["delta_S"].grad = gs - else: - info["delta_S"].grad = gs + gt - # routeV: park the flagged rollouts' contribution into δS_hack.grad (its own - # forward-path grad was wiped by the per-prompt zero_grad; we impose the routed - # grad here, like proj.py's route). - for name, g in step_grad_hack.items(): - wrappers[name]["delta_S_hack"].grad = g - - # Per-source cin: project student-only and teacher-only grads into v_hack - # subspace. Discriminator: cos_pre_t > cos_pre_s on a clean base means v_hack - # lights up for hack grads more than non-hack. Only valid on split steps; - # otherwise step_grad_s holds the combined grad and would mis-report cos_pre_s. - # v_hack is None on the vanilla arm (pure GRPO baseline, no subspace): skip - # the projection/measurement entirely and emit a nan diag -> the cin/cout - # columns (hidden on vanilla anyway) render nan. erase/route always have v_hack. - if v_hack is None: - diag = {"mean_cos_pre": float("nan"), "mean_cos_post": float("nan"), - "frac_fired": float("nan"), "mean_cos_pre_s": float("nan"), - "mean_cos_pre_t": float("nan")} - # routeV: mean routed fraction f (mean over modules*prompts) -- also the - # frout streaming column; logged here too for the no-v_hack diag branch. - if is_routeV and step_flagged: - logger.debug(f"routeV routed frac f (mean over modules*prompts): " - f"{sum(step_flagged)/len(step_flagged):+.3f}") - if step_clipfrac: - logger.debug(f"lora2r clean-gated clipfrac={sum(step_clipfrac)/len(step_clipfrac):.3f} " - f"(SHOULD: <~0.2; higher = retain-trick ratio drift binding, " - f"quarantine forward effect too large)") - else: - if split_this_step: - cos_pre_s = mean_cos_pre_from_grads(step_grad_s, v_hack) - cos_pre_t = mean_cos_pre_from_grads(step_grad_t, v_hack) - else: - cos_pre_s = cos_pre_t = float("nan") - # Erase subtracts the hack-ward component; cos_pre is measured before it. - diag = project_delta_S_grad( - wrappers, v_hack, cfg.preserve_magnitude, - measure_only=False, - gate_mode=cfg.gate_mode, - overshoot=cfg.project_overshoot, - ) - diag["mean_cos_pre_s"] = cos_pre_s - diag["mean_cos_pre_t"] = cos_pre_t - - # clip_grad_norm_ returns the pre-clip total L2 norm, captured for the - # per-step `gn` column so we can see whether the clip threshold is the - # bottleneck on update magnitude (compare gn vs cfg.grad_clip). - # Clip over both adapters. For none/erase, δS_hack.grad is None so it is - # ignored (identical norm to before). For route it bounds the combined - # update (main + quarantine). - # Quarantine energy share (logged as `qmass`): ‖g_quar‖/(‖g_keep‖+‖g_quar‖) ∈ [0,1], the - # share of the update routed into the quarantine (δS_hack, deleted at deploy). - # Rising means routing dumps learning into the discarded quarantine adapter and the - # deployed model learns nothing. ~0 idle; ~0.5+ climbing = quarantine - # eating the update. - def _grad_l2(params): - gs = [p.grad for p in params if p.grad is not None] - return float(torch.norm(torch.stack([g.norm() for g in gs]))) if gs else 0.0 - if is_lora2r: - # quarantine = block slices of A/B, not separate params - sq_keep = sq_quar = 0.0 - for info in wrappers.values(): - gA, gB = info["delta_S"].grad, info["B"].grad - if gA is None: - continue - r_blk = info["r"] - sq_keep += gA[:r_blk].float().pow(2).sum().item() + gB[:, :r_blk].float().pow(2).sum().item() - sq_quar += gA[r_blk:].float().pow(2).sum().item() + gB[:, r_blk:].float().pow(2).sum().item() - gn_keep, gn_quar = sq_keep ** 0.5, sq_quar ** 0.5 - else: - gn_keep = _grad_l2(delta_params) - gn_quar = _grad_l2(delta_hack_params) + r_blk = info["r"] + sq_keep += gA[:r_blk].float().pow(2).sum().item() + gB[:, :r_blk].float().pow(2).sum().item() + sq_quar += gA[r_blk:].float().pow(2).sum().item() + gB[:, r_blk:].float().pow(2).sum().item() + gn_keep, gn_quar = sq_keep ** 0.5, sq_quar ** 0.5 q_egy = gn_quar / (gn_keep + gn_quar) if (gn_keep + gn_quar) > 0 else 0.0 - gn = float(torch.nn.utils.clip_grad_norm_(delta_params + delta_hack_params, cfg.grad_clip)) + # clip_grad_norm_ returns the pre-clip total L2 norm, captured for the `gn` column. + gn = float(torch.nn.utils.clip_grad_norm_(delta_params, cfg.grad_clip)) opt.step() sched.step() - # online_stats gate: update band from this step's pooled cosines (all modules * rollouts). - # Uses previous step's band for routing (so the update is one step lagged, which is fine). - if is_routeV and cfg.routeV_gate == "online_stats" and _step_online_cos: - all_cos = torch.cat(_step_online_cos).float() - lo = torch.quantile(all_cos, cfg.online_stats_lo).item() - hi = torch.quantile(all_cos, cfg.online_stats_hi).item() - _online_band[0] = (lo, max(hi, lo + 1e-4)) - logger.debug(f"online_stats band update: lo={lo:+.3f} hi={hi:+.3f} n={len(all_cos)}") - - # ── v_hack / v_grad refresh ── - # Online v_hack refresh: re-extract against the *current* model so the - # hack subspace tracks where the student is being pulled now (rather - # than at step 0). Same PAIRS, same extract code; we just discard the - # saved cache and overwrite the in-memory v_hack dict. - refr = "-" # set to "mod/axes" below if a refresh fires; rendered in the per-step row + # ── v_grad refresh ── + # Re-extract the routing direction against the CURRENT model so it tracks where + # hacks separate now, not at step 0. Without this the frozen direction goes stale. + # Same MASK_PAIRS (the authored pairs, no oracle); quarantine ablated so the hack + # signal flows back through the observable path, matching the build-time extract. + refr = "-" do_refresh = cfg.vhack_refresh_every > 0 and (step + 1) % cfg.vhack_refresh_every == 0 if do_refresh and is_routeV and cfg.routeV_random_v_seed is not None: do_refresh = False # keep the one fixed Haar draw; re-extracting would replace it if do_refresh and is_routeV: - # routeV v_grad refresh: re-extract against the CURRENT model so the - # routing direction tracks where hacks separate now, not at step 0. - # Without this the frozen direction goes stale -- cin_t decays to cin_s - # within ~6 steps. Same MASK_PAIRS (the weak - # detector, no oracle); quarantine ablated so the hack signal flows back - # through the observable path, matching the state the build-time extract saw. _was_training = model.training model.eval() opt.zero_grad(set_to_none=True) @@ -1411,90 +852,30 @@ def main(cfg: Config) -> int: model, tok, wrappers, MASK_PAIRS, top_k=1, tau_axis=0.0, n_heldout=2, device=device, ) - for name in wrappers: # update in place so _routeV_grad_filter's closure sees it + for name in wrappers: # update in place so the gate closure sees it d = (raw_grads[f"hack/{name}"] - raw_grads[f"clean/{name}"]).mean(0) v_grad[name] = (d / d.norm().clamp_min(1e-12)).to(device) - route_band = route_band_edges(raw_grads, v_grad, device) # rebuild band on the fresh v_grad - if cfg.routeV_gate == "act_vote": - # act direction goes stale just like v_grad; re-extract on the - # current model so the vote tracks where hacks separate now. - As_dir, act_w, vote_band = build_act_vote_dirs(model, wrappers, tok, MASK_PAIRS, device) + route_band = route_band_edges(raw_grads, v_grad, device) # rebuild band on fresh v_grad finally: logger.enable("vgrout.extract_vhack_grad") logger.enable("__main__") opt.zero_grad(set_to_none=True) # extract leaves .grad populated if _was_training: model.train() - refr = "rfr" # compact marker; v_grad refresh has no cheap overlap gauge - if v_hack is not None and do_refresh: - from .extract_vhack_grad import extract_v_hack - from .pairs_from_pool import load_pairs_json - VHACK_PAIRS = load_pairs_json(cfg.vhack_pairs_path) - _was_training = model.training - model.eval() - opt.zero_grad(set_to_none=True) - # Silence per-pair "loss=" and postprocess summary inside refresh: - # the refresh fires every N steps and floods the training log with - # extract-time NLL values that read as if they were training losses. - # The one-line "v_hack refreshed" announcement below is enough. - # When invoked via `python -m vgrout.train`, the entry - # script's __name__ is "__main__", not "vgrout.train", - # so postprocess_v_hack's logger.info (called from here) needs - # __main__ silenced. The extract submodule keeps its own name. - logger.disable("vgrout.extract_vhack_grad") - logger.disable("__main__") - try: - # Extract with the quarantine ablated (δS_hack=0). For route, once the - # hack capability has been routed into δS_hack, the deployed-adapter gradient - # on the pairs no longer carries the hack direction, so re-extracting - # through the live quarantine rotates v_hack off-hack and cin_t collapses - # at the refresh step. Ablating sends the hack back through the observable - # main path, matching the δS_hack=0 state the build extraction saw. - # No-op for erase (δS_hack is never trained, stays 0). - with ablate_quarantine(wrappers): - _new_V, _new_S, _, _ = extract_v_hack( - model, tok, wrappers, VHACK_PAIRS, - top_k=cfg.v_hack_extract_top_k, tau_axis=cfg.v_hack_tau_axis, - n_heldout=2, device=device, - ) - _post = postprocess_v_hack( - _new_V, _new_S, k_use=cfg.v_hack_k, - drop_bottom_frac=cfg.v_hack_drop_bottom_frac, - source=f"refresh@step{step}", - ) - finally: - logger.enable("vgrout.extract_vhack_grad") - logger.enable("__main__") - # Measure how much of the previous orthonormal subspace survives refresh. - shared = set(v_hack) & set(_post) - ovl = [((_post[n].float().to(device) @ v_hack[n].float().mT)).pow(2).sum().item() - / v_hack[n].shape[0] for n in shared] - overlap = sum(ovl) / max(1, len(ovl)) - logger.info( - f"refresh@step{step}: {len(_post)}mod/{sum(V.shape[0] for V in _post.values())}ax " - f"basis_overlap_with_prev={overlap:.3f} " - f"SHOULD: >~0.5 if refresh tracks a stable hack subspace; <~0.2 => " - f"re-extraction rotated the basis (cin_t jumps, refresh is harmful)") - v_hack.clear() - v_hack.update({n: V.to(device) for n, V in _post.items()}) - opt.zero_grad(set_to_none=True) # extract leaves .grad populated - if _was_training: - model.train() - refr = f"{len(v_hack)}/{sum(V.shape[0] for V in v_hack.values())}" # mod/axes -> per-step row + refr = "rfr" - # Evaluate every arm on the same held-out validation prompts and sampling seed. + # ── periodic held-out eval (deploy = quarantine ablated) ── hack_deployed = solve_deployed = float("nan") if cfg.eval_ablate_every > 0 and (step % cfg.eval_ablate_every == 0 or step == steps - 1): _was_training = model.training model.eval() - is_route = is_routeV # Save and restore RNG so fixed-seed validation cannot perturb training. _cpu_rng = torch.get_rng_state() _cuda_rng = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None torch.manual_seed(EVAL_GEN_SEED) ev_tr = eval_hack_solve(model, tok, val_problems, val_idxs, gen_cfg_eval, device, max_new, cfg.eval_batch_size) - if is_route: + if has_quarantine: with ablate_quarantine(wrappers): torch.manual_seed(EVAL_GEN_SEED) ev_dp = eval_hack_solve(model, tok, val_problems, val_idxs, gen_cfg_eval, device, max_new, @@ -1516,18 +897,16 @@ def main(cfg: Config) -> int: for m, (h, v, s, c) in ev_dp["by_mode"].items()}, }) + "\n") should = ("quarantine-ablated hack < quarantine-enabled hack; ELSE routing isn't capturing it" - if is_route else "deploy == train (no quarantine)") + if has_quarantine else "deploy == train (no quarantine)") logger.info( f"step {step} VAL-eval (n={ev_dp['n']}): quarantine-enabled hack={ev_tr['hack']:.3f} " f"solve={ev_tr['solve']:.3f} | deployed/quarantine-ablated hack={hack_deployed:.3f} " f"solve={solve_deployed:.3f}. SHOULD: {should}") - # High base solve leaves little room for the exploited metric to rise. if step == 0 and ev_tr["solve"] >= 0.9: logger.warning( f"step-0 base-model solve={ev_tr['solve']:.3f} >= 0.9 on the held-out val: " f"little legit-solve headroom. Hack metric is only alive if val hack RISES " - f"during training (lazy-hacking solvable problems); if it stays ~0 while train " - f"hacks, the model is too strong for this benchmark.") + f"during training; if it stays ~0 while train hacks, the model is too strong.") rewards_t = torch.tensor(agg_rew, dtype=torch.float32) if agg_rew else torch.zeros(1) rew_mean = rewards_t.mean().item() @@ -1543,27 +922,6 @@ def main(cfg: Config) -> int: n_t = int(is_s.numel() - n_s) hack_s_n = int((h_t & is_s).sum()) hack_t_n = int((h_t & ~is_s).sum()) - # E/C/D tallies use student rollouts because teacher cache lacks E/D labels. - h_E = torch.tensor(agg_hack_E, dtype=torch.bool) if agg_hack_E else torch.zeros(0, dtype=torch.bool) - h_D = torch.tensor(agg_hack_D, dtype=torch.bool) if agg_hack_D else torch.zeros(0, dtype=torch.bool) - hack_s_E = int((h_E & is_s).sum()) - hack_s_C = hack_s_n - hack_s_D = int((h_D & is_s).sum()) - # Compute held-out mechanism generalization as exact per-rollout unions. - half_a_codes_step = {c.strip().upper() for c in cfg.half_a.split(",") if c.strip()} - det_step = {"E": h_E, "C": h_t, "D": h_D} - if half_a_codes_step: - mask_A_step = torch.zeros_like(is_s) - for c in half_a_codes_step: - mask_A_step = mask_A_step | det_step[c] - mask_B_step = torch.zeros_like(is_s) - for c in ({"E", "C", "D"} - half_a_codes_step): - mask_B_step = mask_B_step | det_step[c] - hack_s_A = int((mask_A_step & is_s).sum()) - hack_s_B = int((mask_B_step & ~mask_A_step & is_s).sum()) - else: - hack_s_A = 0 - hack_s_B = 0 gt_s_n = int((g_t & is_s).sum()) gt_t_n = int((g_t & ~is_s).sum()) # Ablated training rollouts are a noisy deploy proxy, not the held-out headline metric. @@ -1588,24 +946,22 @@ def main(cfg: Config) -> int: f"clipped(no-eos)={n_clipped}/{n_rollouts} " f"comp_lens(min/mean/max)={_min_len}/{_mean_len:.0f}/{_max_len} " f"max_new={max_new} fmt={sum(agg_fmt)}/{n_rollouts} gt={sum(agg_gt)}/{n_rollouts} " - f"hack={sum(agg_hack)}/{n_rollouts} " - f"zerovar={n_zerovar}/{prompts_per_step}" - ) + f"hack={sum(agg_hack)}/{n_rollouts} zerovar={n_zerovar}/{prompts_per_step}") _tstep = time.time() - t0 logger.debug( f"step {step} TIMING gen={t_gen:.0f}s fwd_bwd={t_fb:.0f}s " - f"reward={t_rew:.0f}s other={_tstep - t_gen - t_fb - t_rew:.0f}s " - f"total={_tstep:.0f}s" - ) + f"reward={t_rew:.0f}s other={_tstep - t_gen - t_fb - t_rew:.0f}s total={_tstep:.0f}s") + if step_clipfrac: + logger.debug(f"routeV clean-gated clipfrac={sum(step_clipfrac)/len(step_clipfrac):.3f} " + f"(SHOULD: <~0.2; higher = retain-trick ratio drift binding)") if diag_tail is not None: tail = diag_tail.replace("\n", "\\n") logger.debug(f"step {step} gen[0] tail (last 400 chars): {tail!r}") cum_gens = sum(r["N"] for r in rows) + n_rollouts row = { - # Raw values throughout; StepLogger formats for streaming and the - # end-of-run tabulate dump consumes the same dict directly (no - # scientific-notation strings to misparse as floats). + # Raw values throughout; StepLogger formats for streaming and the end-of-run + # tabulate dump consumes the same dict directly. "step": step, "ref_eq": cum_gens / REF_GENS_PER_STEP, "rew": rew_mean, @@ -1616,18 +972,9 @@ def main(cfg: Config) -> int: "gt_t": (gt_t_n, n_t) if n_t else (0, 0), "hack_s": (hack_s_n, n_s) if n_s else (0, 0), "hack_t": (hack_t_n, n_t) if n_t else (0, 0), - # Per-mode student hacks THIS step (current batch count, not cumulative -- - # cumulative grew unboundedly and read as noise). The running mode_hacks/ - # mode_rollouts tallies still feed the end-of-run substrate learning table. + # Per-mode student hacks THIS step (current batch count, not cumulative). # StepLogger only renders these on multi-mode (substrate) runs. **{f"hk_{MODE_CODE[m]}": step_mode_hacks.get(m, 0) for m in run_modes}, - # Per-mechanism on student rollouts only. Used by final-tail BLUF for - # cross-mechanism HACK_A / HACK_B; hidden from the per-step table to - # avoid column bloat (rendered only in the markdown dump below). - "hack_s_E": (hack_s_E, n_s) if n_s else (0, 0), - "hack_s_D": (hack_s_D, n_s) if n_s else (0, 0), - "hack_s_A": (hack_s_A, n_s) if n_s else (0, 0), - "hack_s_B": (hack_s_B, n_s) if n_s else (0, 0), "lp_s": lp_s_mean if n_s else None, "lp_t": lp_t_mean if n_t else None, "loss": agg_loss, @@ -1639,17 +986,9 @@ def main(cfg: Config) -> int: "keepE": (sum(step_zkeepE) / len(step_zkeepE)) if step_zkeepE else float("nan"), "residE": (sum(step_zresidE) / len(step_zresidE)) if step_zresidE else float("nan"), "routE": (sum(step_zroutE) / len(step_zroutE)) if step_zroutE else float("nan"), - "leak": (sum(step_resid) / len(step_resid)) if step_resid else float("nan"), "lr": sched.get_last_lr()[0], - "cos_pre": diag["mean_cos_pre"], - "cos_pre_s": diag["mean_cos_pre_s"], - "cos_pre_t": diag["mean_cos_pre_t"], - "cos_post": diag["mean_cos_post"], - "fired": diag["frac_fired"], "refr": refr, - # Route deploy-eval (δS_hack=0); NaN except on route eval steps. - # Appended AFTER refr so results.py's positional GT_S/HACK_S indices - # are unaffected. plot_dynamics reads it by name. + # Deploy-eval (quarantine ablated); NaN except on eval steps. "hack_deployed": hack_deployed, "solve_deployed": solve_deployed, # Free per-step deploy proxy from the ablated rollout slice (above). @@ -1676,8 +1015,6 @@ def main(cfg: Config) -> int: if math.isfinite(lp_t_mean): lp_t_best = max(lp_t_best, lp_t_mean) drop = lp_t_best - lp_t_mean if math.isfinite(lp_t_mean) else 0.0 - # Soft warning at a smaller drop than the hard abort -- an early "ppl is - # climbing, watch for divergence (lr too high?)" before things are lost. if WARN_DROP <= drop < DIVERGENCE_DROP: logger.warning(f"step {step}: lp_t={lp_t_mean:.1f} is {drop:.1f} nats below best " f"{lp_t_best:.1f} (ppl_t={ppl_t:.0e}) -- coherence slipping, lr too high?") @@ -1687,7 +1024,7 @@ def main(cfg: Config) -> int: logger.error( f"DIVERGED at step {step}: lp_t={lp_t_mean:.1f} (ppl_t={ppl_t:.0e}), {lp_t_best - lp_t_mean:.1f} " f"nats below best {lp_t_best:.1f}, for {diverged_steps} steps -- policy collapsed " - f"(gn={gn:.1f}). Aborting to save GPU. Likely lr too high (routeV: lower --routeV-quar-lr-scale).") + f"(gn={gn:.1f}). Aborting to save GPU. Likely lr too high.") if last_gen_sample: _s, _r = last_gen_sample logger.error(f"--- last student gen (step {_s}, reward={_r['reward']:+.2f}) ---\n" @@ -1704,16 +1041,12 @@ def main(cfg: Config) -> int: pbar.set_postfix( rew=f"{rew_mean:+.2f}", gt=f"{sum(agg_gt)}/{n_rollouts}", hack=f"{sum(agg_hack)}/{n_rollouts}", loss=f"{agg_loss:+.3f}", - sec=f"{time.time()-t0:.0f}", - refresh=False, + sec=f"{time.time()-t0:.0f}", refresh=False, ) logger.debug( f"step {step:3d} rew={rew_mean:+.2f}(std {rew_std:.2f}) " f"gt={sum(agg_gt)}/{n_rollouts} hack={sum(agg_hack)}/{n_rollouts} " - f"loss={agg_loss:+.3f} cos_pre={diag['mean_cos_pre']:+.3f} " - f"cos_post={diag['mean_cos_post']:+.3f} fired={diag['frac_fired']:.2f} " - f"sec={time.time()-t0:.0f}" - ) + f"loss={agg_loss:+.3f} qmass={q_egy:.2f} sec={time.time()-t0:.0f}") peak_gb = torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0.0 n_steps = len(rows) @@ -1732,45 +1065,20 @@ def main(cfg: Config) -> int: solve_rate_s = gt_s_total / max(1, n_s_total) hack_rate_t = hack_t_total / max(1, n_t_total) - # Per-mechanism on STUDENT rollouts (teacher cache lacks E/D). C-rate from - # this path must match hack_rate_s exactly -- sanity-check it so a future - # refactor that drops one path without the other is caught. - hack_s_E_total = sum(r["hack_s_E"][0] for r in rows) - hack_s_D_total = sum(r["hack_s_D"][0] for r in rows) - hack_s_E_rate = hack_s_E_total / max(1, n_s_total) - hack_s_C_rate = hack_rate_s - hack_s_D_rate = hack_s_D_total / max(1, n_s_total) - - # Cross-mechanism HACK_A / HACK_B split (docs/spec/20260528_cross_mechanism_v_hack.md). - # Computed exactly per-step from per-rollout (E,C,D) tuples; here we just sum. - half_a_codes = {c.strip().upper() for c in cfg.half_a.split(",") if c.strip()} - valid_codes = {"E", "C", "D"} - if half_a_codes and not half_a_codes.issubset(valid_codes): - raise ValueError(f"--half-a contains unknown codes: {half_a_codes - valid_codes}; valid: {valid_codes}") - half_b_codes = valid_codes - half_a_codes if half_a_codes else set() - hack_s_A_total = sum(r["hack_s_A"][0] for r in rows) - hack_s_B_total = sum(r["hack_s_B"][0] for r in rows) - hack_a_rate = hack_s_A_total / max(1, n_s_total) if half_a_codes else float("nan") - hack_b_rate = hack_s_B_total / max(1, n_s_total) if half_a_codes else float("nan") - - # routeV must move quarantine; none and erase must leave it exactly zero. - if is_lora2r: - # quarantine LEARNED delta = (A[r:], B[:,r:]) minus the frozen PiSSA init - dsh_norm = float(sum( - (info["delta_S"].data[info["r"]:] - info["A0"][info["r"]:]).float().pow(2).sum().item() - + (info["B"].data[:, info["r"]:] - info["B0"][:, info["r"]:]).float().pow(2).sum().item() - for info in wrappers.values()) ** 0.5) - else: - dsh_norm = float(sum(info["delta_S_hack"].data.float().pow(2).sum().item() - for info in wrappers.values()) ** 0.5) - logger.info(f"||delta_S_hack|| = {dsh_norm:.4f} " - f"(SHOULD: >0 for routeV, ==0 for none/erase; ELSE routing broke)") - if is_routeV and cfg.routeV_random_v_seed is None: - assert dsh_norm > 0.0, f"{cfg.intervention}: delta_S_hack never moved -> nothing routed into quarantine" + # routeV/absorb must move the quarantine; none must leave it exactly zero. The + # quarantine LEARNED delta = (A[r:], B[:,r:]) minus the frozen init. + dsh_norm = float(sum( + (info["A"].data[info["r"]:] - info["A0"][info["r"]:]).float().pow(2).sum().item() + + (info["B"].data[:, info["r"]:] - info["B0"][:, info["r"]:]).float().pow(2).sum().item() + for info in wrappers.values()) ** 0.5) + logger.info(f"||quarantine learned delta|| = {dsh_norm:.4f} " + f"(SHOULD: >0 for routeV/absorb, ==0 for none; ELSE routing broke)") + if has_quarantine and cfg.routeV_random_v_seed is None: + assert dsh_norm > 0.0, f"{cfg.intervention}: quarantine never moved -> nothing trained it" elif cfg.routeV_random_v_seed is not None and dsh_norm == 0.0: # A Haar control may validly route nothing because no rollout clears its band. - logger.warning("routeV Haar control: ||delta_S_hack||==0 -> the random direction routed " - "NOTHING. This is a real result (favours H4: alignment needed), not a failure.") + logger.warning("routeV Haar control: ||quarantine delta||==0 -> the random direction routed " + "NOTHING. This is a real result (favours: alignment needed), not a failure.") # Show one final generation so numerical results are not trusted after semantic collapse. if last_gen_sample is not None: @@ -1784,8 +1092,6 @@ def main(cfg: Config) -> int: # ── final eval + BLUF ── # Pair quarantine-ablated and enabled states on identical final-test prompts and sampling seed. model.eval() - # The held-out quarantine-ablated score is the headline; enabled measures absorption. - has_quarantine = is_routeV logger.info(f"FINAL EVAL on held-out TEST n={len(test_problems)} (periodic curve used val " f"n={len(val_problems)}); quarantine-ablated=deploy" f"{' + quarantine-enabled=trained state' if has_quarantine else ''}") @@ -1811,8 +1117,9 @@ def main(cfg: Config) -> int: deploy_record = { "schema": RUN_SCHEMA, "run_dir": run_dir.name, "arm": cfg.arm, "intervention": cfg.intervention, + "adapter": "lora2r", "seed": cfg.seed, "steps": n_steps, "model": model_name, "out_tag": cfg.out_tag, - "eval": cfg.eval, "unhackable_frac": cfg.unhackable_frac, "pairs": str(cfg.vhack_pairs_path.name), + "unhackable_frac": cfg.unhackable_frac, "pairs": str(cfg.vhack_pairs_path.name), "eval_set": "test", "eval_modes": eval_modes, "n": ev["n"], "hack_deployed": ev["hack"], "vhack_deployed": ev["vhack"], "solve_deployed": ev["solve"], "hack_as_trained": ev_on["hack"], "vhack_as_trained": ev_on["vhack"], @@ -1825,16 +1132,14 @@ def main(cfg: Config) -> int: # ── end-of-run summary ────────────────────────────────────────────────── # Put the readable result and objective last so `tail` shows the answer. - cue = "🟢" if (cfg.arm == "vanilla" and hack_rate > 0.0) else "🟡" + cue = "🟢" if (is_vanilla and hack_rate > 0.0) else "🟡" - # --- scroll-back: train-set diagnostics + the wide journal/results.md row --- print(f"\nverbose log: {verbose_log}") print( # Training rollout rates use the quarantine-enabled policy. f"train rollout rates (quarantine-enabled): HACK_RATE={hack_rate:.3f} PASS_RATE={pass_rate:.3f} " f"HACK_STUDENT={hack_rate_s:.3f} HACK_TEACHER={hack_rate_t:.3f} " f"[arm={cfg.arm} preset={cfg.preset_name} model={model_name} steps={n_steps} gens={n_gens} peak={peak_gb:.1f}GB" - f"{' pool=' + cfg.teacher_pool_dir.name + ' mix=' + str(cfg.mix_ratio) if cfg.teacher_pool_dir else ''}]" - ) + f"{' pool=' + cfg.teacher_pool_dir.name + ' mix=' + str(cfg.mix_ratio) if cfg.teacher_pool_dir else ''}]") if cfg.unhackable_frac > 0: n_draws = n_steps * prompts_per_step print(f"rotating-unhackable flip: {n_flipped}/{n_draws} prompt-draws shown hint-free " @@ -1854,14 +1159,6 @@ def main(cfg: Config) -> int: cue_sub = "🟢" if n_learned == len(per_mode_rows) else ("🟡" if n_learned else "🔴") print(f"{cue_sub} SUBSTRATE per-mode learning ({n_learned}/{len(per_mode_rows)} modes learned):") print(tabulate(per_mode_rows, headers="keys", tablefmt="github")) - # HACK_B falling against matched vanilla is the held-out mechanism generalization test. - print( - f"per-mech (student): HACK_S_E={hack_s_E_rate:.3f} HACK_S_C={hack_s_C_rate:.3f} " - f"HACK_S_D={hack_s_D_rate:.3f} " - f"half_A={sorted(half_a_codes) or '-'} HACK_A={hack_a_rate:.3f} " - f"half_B={sorted(half_b_codes) or '-'} HACK_B={hack_b_rate:.3f} " - f"(A=any half_A fires; B=any half_B fires AND no half_A fires)" - ) # Keep the wide archival row above the concise tail. print() print(tabulate([{ @@ -1903,8 +1200,5 @@ if __name__ == "__main__": cfg = tyro.extras.subcommand_cli_from_dict({ "smoke": SmokeConfig, "fast": FastConfig, - "fast-lora": FastLoraConfig, - "fast-lora2r": FastLora2rConfig, - "full": FullConfig, }) sys.exit(main(cfg)) diff --git a/src/vgrout/train_config.py b/src/vgrout/train_config.py index 00909dd..fba3f85 100644 --- a/src/vgrout/train_config.py +++ b/src/vgrout/train_config.py @@ -1,4 +1,14 @@ -"""Typed CLI configuration for train.py.""" +"""Typed CLI configuration for train.py. + +One adapter (lora2r: rank-2r Gaussian-init LoRA, A+B trainable, SGTM-style +three-way hard block masking; see src/vgrout/lora2r.py) and three arms: + + none gate pinned clean (0,0): quarantine never trains -- the capacity- and + structure-matched vanilla control. + routeV per-rollout three-way gate from the c-probe gradient vs v_grad. + absorb gate pinned mid (1,0): both blocks train on everything, no gate -- + isolates the value of the gate + hard masks vs absorption alone. +""" from __future__ import annotations from dataclasses import dataclass @@ -7,61 +17,41 @@ from typing import Literal from .rewards import EnvMode -# lora2r = rank-2r PiSSA-init LoRA (A+B trainable) with SGTM-style three-way -# hard block masking; supports intervention none (gate pinned clean) | routeV. -Adapter = Literal["antipasto", "lora_frozen_b", "lora2r"] - @dataclass(kw_only=True) class Config: - # arm: the gradient policy. routeV = per-rollout gate; routeV_per_token = per-token gate. - intervention: Literal["none", "erase", "routeV", "routeV_per_token"] = "erase" - adapter: Adapter = "antipasto" + intervention: Literal["none", "routeV", "absorb"] = "routeV" lora_r: int = 32 - lora_b_seed: int = 0 + lora_init_seed: int = 0 model: str = "Qwen/Qwen3-4B" steps: int = 100 group: int = 6 max_new: int = 1024 n_problems: int = 992 - beta: float = 0.0 prompts_per_step: int = 8 - lr: float = 7e-5 + lr: float = 1e-4 adam_beta1: float = 0.9 adam_beta2: float = 0.99 clip: float = 0.2 - weight_decay: float = 0.1 + # AdamW decay pulls raw A/B toward 0, not toward the init, which would drive + # the net delta to -B0@A0 -- must stay 0 for this adapter. + weight_decay: float = 0.0 warmup_frac: float = 0.1 grad_clip: float = 10.0 seed: int = 41 unbiased: bool = True - preserve_magnitude: bool = True - gate_mode: Literal["one_sided", "no_gate", "reverse"] = "one_sided" - project_overshoot: float = 1.0 - v_hack_path: Path | None = None - v_hack_extract_top_k: int = 12 - v_hack_k: int = 5 - v_hack_tau_axis: float = 0.0 - v_hack_drop_bottom_frac: float = 0.25 vhack_refresh_every: int = 5 vhack_pairs_path: Path = Path("out/pairsets/pairs_authored.json") - + # Haar-random direction control (placebo): same routing machinery, no pair signal. routeV_random_v_seed: int | None = None - # Top-k axes are oriented by the hack-clean mean difference before max-cos routing. - routeV_top_k: int = 1 - # Pair cosine, live cosine quantiles, or authored-pair activation voting calibrates the gate. - routeV_gate: Literal["grad_cosine", "act_vote", "online_stats"] = "grad_cosine" - routeV_absorb_all: bool = False - online_stats_lo: float = 0.05 - online_stats_hi: float = 0.95 rollout_ablate_frac: float = 0.0 env_mode: EnvMode = "run_tests" - # eval3 keeps solve pressure alive by making 25% of training problems unhackable. - # (10% learned solve too slowly; 25% is the default from 2026-06-10 on.) - eval: Literal["eval2", "eval3"] = "eval3" + # Rotating per-step chance a problem is shown hint-free (graded honest-only); + # keeps solve pressure alive. 10% learned solve too slowly; 25% default since 2026-06-10. + unhackable_frac: float = 0.25 teacher_pool_dir: Path | None = None mix_ratio: float = 0.125 teacher_off_step: int | None = 30 @@ -71,77 +61,39 @@ class Config: eval_n_prompts: int = 32 eval_batch_size: int = 2 save_ckpt_every: int = 10 - cos_pre_split_every: int = 1 - half_a: str = "" out_tag: str = "" @property def preset_name(self) -> str: return type(self).__name__.removesuffix("Config").lower() or "base" - @property - def unhackable_frac(self) -> float: - return {"eval2": 0.0, "eval3": 0.25}[self.eval] - @property def arm(self) -> str: - base = {"none": "vanilla", "erase": "projected", - "routeV": "routingV", "routeV_per_token": "routingV_per_token"}[self.intervention] - # lora2r changes the routing logic (hard 3-way masks, structural separation), - # so it gets its own arm id -- old/new runs must not be conflated. - return f"{base}_lora2r" if self.adapter == "lora2r" else base + # _lora2r suffix kept so these runs never conflate with the retired + # PiSSA-substrate runs of the same intervention (rename-on-logic-change). + return {"none": "vanilla_lora2r", "routeV": "routingV_lora2r", + "absorb": "absorb_lora2r"}[self.intervention] @dataclass(kw_only=True) class SmokeConfig(Config): model: str = "llamafactory/tiny-random-qwen3" + lora_r: int = 4 # tiny model min Linear dim is 16; 2r=8 fits steps: int = 30 group: int = 4 max_new: int = 32 n_problems: int = 100 - beta: float = 0.0 prompts_per_step: int = 1 @dataclass(kw_only=True) class FastConfig(Config): model: str = "Qwen/Qwen3-4B" - steps: int = 60 + steps: int = 100 teacher_pool_dir: Path | None = Path("out/pools/teacher_pool_runtests_dense") - grad_clip: float = 10.0 group: int = 8 max_new: int = 512 n_problems: int = 200 - beta: float = 0.0 prompts_per_step: int = 4 - lr: float = 3e-3 adam_beta1: float = 0.5 adam_beta2: float = 0.9 - - -@dataclass(kw_only=True) -class FastLoraConfig(FastConfig): - # LoRA-frozen-B needs a lower learning rate because its gradient scale differs from delta_S. - adapter: Adapter = "lora_frozen_b" - lr: float = 1e-4 - - -@dataclass(kw_only=True) -class FastLora2rConfig(FastConfig): - # Rank-2r PiSSA-init LoRA + SGTM three-way masking. weight_decay MUST be 0: - # AdamW decays the raw A/B toward 0, not toward the PiSSA init, which would - # drive the net delta to -B0@A0 (subtracting W's top-2r spectral part). - adapter: Adapter = "lora2r" - lr: float = 1e-4 - weight_decay: float = 0.0 - - -@dataclass(kw_only=True) -class FullConfig(Config): - model: str = "Qwen/Qwen3-4B" - steps: int = 200 - group: int = 4 - max_new: int = 1536 - n_problems: int = 992 - beta: float = 1e-3 - prompts_per_step: int = 64