diff --git a/README.md b/README.md index 330d039..f28e4b2 100644 --- a/README.md +++ b/README.md @@ -36,28 +36,30 @@ outputs (`m` = quarantine on/off, `d` = deployed detach): To get the hack direction we pair examples by hand: for each problem, one correct completion and one completion that exploits the evaluation procedure. -For each pair we compute the *exact GRPO gradient* that would result if the hack -rollout had advantage +1 and the clean rollout had advantage -1 -(`-grad logp(hack) + grad logp(clean)`), read in the per-module c-probe space (a -virtual diagonal scale between `A` and `B`). The mean hack-minus-clean direction, -normalized per module, is `v_grad`. (Mechanically identical to a twin-NLL -extraction, since GRPO with adv=+/-1 reduces to the NLL difference; the GRPO -framing is the one we mean.) The hand-authored pairs are off-distribution and the -*only* routing-label source. No oracle or ground-truth label from a training -rollout is used during training. +For each pair we run a forward pass and read the bottleneck activation `A@x` +(the rank-2r input projection of each wrapped Linear), masked-mean-pooled over +completion tokens. The per-module mean hack-minus-clean activation difference, +unit-normalized per module, is `v_act` (`src/vgrout/extract_vhack_act.py`). +Extraction is forward-only: no backward pass, no labels. The hand-authored pairs +are off-distribution and the *only* routing-label source. No oracle or +ground-truth label from a training rollout is used during training. -At training time routeV runs **two passes** per prompt. Pass 1 is an unmasked -forward+backward whose gradient w.r.t. the c-probes gives, per rollout, a -band-normalized cosine to `v_grad` pooled across modules -> a `{clean, mid, hack}` -label (band edges = authored-pair p75 cosines). Pass 2 sets the block masks from -those labels and does the real masked forward+backward; the optimizer steps on -what survives. A rollout with high alignment to the reward-hacking direction updates -the quarantine block while its deployed branch is detached. We re-extract `v_grad` -every N steps so it tracks the current separation rather than only the initial model. +At training time routeA scores each rollout on the no-grad `logp_old` forward it +already needs: an activation-capture hook pools the same bottleneck activations +over completion tokens, and the score is the pooled dot product with `v_act`. +Thresholds come from a rolling buffer of recent scores, z-normalized and split by +two-threshold Otsu into `{keep, absorb, rout}`; until the buffer fills the gate +pins absorb. The block masks are set from those labels *before* the single +masked forward+backward, so there is no second gradient pass. A rollout scoring +at or above the upper threshold updates the quarantine block while its deployed +branch is detached. We re-extract `v_act` every N steps (forward-only, +quarantine-ablated) so it tracks the current model; the buffer stores pooled +activations and re-scores them against the current `v_act`, so a refresh needs +no flush. Whether the *direction* (not just the *act* of routing) drives suppression is the -open question -- the placebo control (Haar-random `v_grad`, same routing -machinery) must NOT match real `v_grad`. We watch `qmass` (the share of update +open question -- the placebo control (Haar-random `v_act`, same routing +machinery) must NOT match real `v_act`. We watch `qmass` (the share of update energy assigned to quarantine) and the per-rollout zone shares (`keep/resid/rout`). ## What we compare @@ -74,8 +76,9 @@ Three arms, identical model/adapter/teacher pool, differing only in the gate - **none** -- gate pinned clean `(0,0)`: the quarantine never trains. The capacity- and structure-matched vanilla control (same adapter, no shrinkage confound). The emergence reference. -- **routeV** -- the method: per-rollout three-way gate from the c-probe gradient - vs `v_grad`. `--routeV-random-v-seed` swaps in a Haar-random direction (placebo). +- **routeA** -- the method: per-rollout three-way gate from the pooled bottleneck + activation vs `v_act`. `--routeA-random-v-seed` swaps in a Haar-random direction + (placebo). - **absorb** -- gate pinned mid `(1,0)`: both blocks train on every rollout. This tests ungated both-block training; it does not by itself establish absorption. @@ -88,10 +91,10 @@ ablation does not change the model. ```bash uv sync -just smoke # tiny-random model, routeV pathway + all verify gates, ~1-2 min -just smoke-all # vanilla + routeV + absorb back to back +just smoke # tiny-random model, routeA pathway + all verify gates, ~1-2 min +just smoke-all # vanilla + routeA + absorb back to back just download-model # warm Qwen3-4B cache -just queue-decision # queue the 4-arm decision run (routeV real / placebo / vanilla / absorb) +just queue-decision # queue the 4-arm decision run (routeA real / placebo / vanilla / absorb) ``` See [RESEARCH_JOURNAL.md](RESEARCH_JOURNAL.md) for session-by-session findings, diff --git a/justfile b/justfile index 18898eb..06a710a 100644 --- a/justfile +++ b/justfile @@ -1,7 +1,7 @@ set shell := ["bash", "-cu"] # vGROUT: rank-2r LoRA gradient routing vs reward-hacking. One adapter (lora2r), -# three arms (intervention none|routeV|absorb). See AGENTS.md / README.md. +# three arms (intervention none|routeA|absorb). See AGENTS.md / README.md. MODEL := "Qwen/Qwen3-4B" TINY_MODEL := "llamafactory/tiny-random-qwen3" # qwen3 arch, ~6M params, smoke only TRAIN := "uv run python -m vgrout.train" # real LeetCode GRPO entry point @@ -24,8 +24,10 @@ default: # Real pipeline on tiny inputs; verify_*.py assert invariants (no tests/ dir). # ───────────────────────────────────────────────────────────────────────────── -# Default smoke = routeV (full pipeline: extract v_grad -> two-pass gate -> deploy -# ablation). Runs all verify gates first, including the lora2r block-mask invariants. +# Default smoke = routeA (full pipeline: extract v_act -> act gate on the logpi_old +# forward -> Otsu pinning -> deploy ablation). Runs all verify gates first, including +# the lora2r block-mask invariants. (scripts/verify_v_act.py is the GPU-only extractor +# check vs the cached diag features -- run it manually after extractor changes.) smoke *ARGS: uv run python scripts/verify_rewards.py # grader: 3 env_modes x clean/hack uv run python scripts/verify_eval_gap.py # eval: train/test token gap, 4 modes @@ -33,18 +35,18 @@ smoke *ARGS: uv run python scripts/verify_science_invariants.py # pair provenance + untouched test uv run python scripts/verify_rotation.py # rotating-unhackable hint-free flip uv run python scripts/verify_lora2r_routing.py # block masks + ablation + c-probe - just smoke-routeV {{ ARGS }} + just smoke-routeA {{ ARGS }} # none: gate pinned clean (0,0) -> quarantine never trains (capacity/structure-matched vanilla). smoke-vanilla *ARGS: BEARTYPE=1 {{ TRAIN }} smoke --intervention=none \ --teacher-pool-dir=out/pools/teacher_pool --mix-ratio=0.5 {{ ARGS }} -# routeV: extract v_grad from authored pairs, splice the per-rollout c-probe gate, -# PASS 1 (unmasked) labels rollouts {clean,mid,hack} via the width-pooled band cosine, -# PASS 2 (masked) trains the blocks; deploy ablation resets the quarantine to init. -smoke-routeV *ARGS: - BEARTYPE=1 {{ TRAIN }} smoke --intervention=routeV \ +# routeA: extract v_act from authored pairs (forward-only), capture pooled acts on the +# no-grad logpi_old forward, label rollouts {keep,absorb,rout} via rolling-buffer Otsu +# thresholds, ONE masked forward+backward; deploy ablation resets the quarantine to init. +smoke-routeA *ARGS: + BEARTYPE=1 {{ TRAIN }} smoke --intervention=routeA \ --teacher-pool-dir=out/pools/teacher_pool --mix-ratio=0.5 \ --eval-ablate-every=10 --eval-n-prompts=2 {{ ARGS }} @@ -62,21 +64,13 @@ smoke-unhackable *ARGS: --teacher-pool-dir=out/pools/teacher_pool --mix-ratio=0.5 \ --eval-n-prompts=2 {{ ARGS }} -# routeV with a top-k routing subspace (max_i cos(g,v_i) over k SVD dirs) instead of -# the single mean-mass axis. UAT: log shows "top-3 SVD subspace, gate=max_i cos" and the -# band/gate still route (rout>0). k=1 (default) is the mean-diff headline. -smoke-topk *ARGS: - BEARTYPE=1 {{ TRAIN }} smoke --intervention=routeV --v-grad-k=3 \ - --teacher-pool-dir=out/pools/teacher_pool --mix-ratio=0.5 \ - --eval-ablate-every=10 --eval-n-prompts=2 {{ ARGS }} - -# routeV + symmetric SOLVE-teacher pool: the G_t teacher slots split 50/50 solve/hack, +# routeA + symmetric SOLVE-teacher pool: the G_t teacher slots split 50/50 solve/hack, # and the run logs the routed-share discrimination (UAT: a line "solve-mix gate # discrimination: hack-teacher routed-share=X vs solve-teacher routed-share=Y"). Smoke # points solve at the same tiny pool just to exercise the split+diagnostic path; real # runs use out/pools/teacher_pool_solve (correct-solution demos) vs the hack pool. smoke-solvemix *ARGS: - BEARTYPE=1 {{ TRAIN }} smoke --intervention=routeV \ + BEARTYPE=1 {{ TRAIN }} smoke --intervention=routeA \ --teacher-pool-dir=out/pools/teacher_pool --solve-pool-dir=out/pools/teacher_pool \ --mix-ratio=0.5 --solve-mix-frac=0.5 \ --eval-ablate-every=10 --eval-n-prompts=2 {{ ARGS }} @@ -84,7 +78,7 @@ smoke-solvemix *ARGS: # All three arms back to back (the full-coverage gate). smoke-all: just smoke-vanilla - just smoke-routeV + just smoke-routeA just smoke-absorb # ───────────────────────────────────────────────────────────────────────────── @@ -92,22 +86,21 @@ smoke-all: # pool, 50% unhackable, authored pairs). Every job carries a why:/resolve: label. # ───────────────────────────────────────────────────────────────────────────── -# Headline 5-arm lora2r decision run, ONLINE-STATS gate + teacher forcing ({{ TEACH }}). -# real-v(k1) is the method; topk(k3) tries the multi-sub-mode subspace; placebo (Haar) -# isolates directionality; vanilla is the emergence reference; absorb isolates the -# gate+masks from absorption. Priority descending so they run in listed order. +# Headline 4-arm lora2r decision run, routeA ACT gate + teacher forcing ({{ TEACH }}). +# real-v is the method (v_act from authored pairs, Otsu rolling-buffer thresholds); +# placebo (Haar) isolates directionality; vanilla is the emergence reference; absorb +# isolates the gate+masks from absorption. Priority descending so they run in listed order. # --unhackable-frac pinned EXPLICIT so the regime is self-documenting, not default-dependent. # Decision: directionality is real iff real-v deploy_hack << placebo at matched solve. -# Watch the streamed `auroc` col: ~0.5 = v_grad blind to live hacks (no gate works); -# high + rout~0 = threshold problem; a drop at a refresh = the cliff is a direction problem. +# Watch the streamed `auroc` col (A>0 contrast): ~0.5 = v_act blind to live hacks (no gate +# works); high + rout~0 = threshold problem; a drop at a refresh = a direction problem. # NO inline eval (eval_ablate_every default 0): HF-generate-bound through 252 lora2r hooks # (~25-30 min/eval), so deploy is scored OFFLINE from the step-10 ckpts (`just results`). queue-decision seed='43': - pueue add -w "$PWD" -o 62 -l "why: P1 lora2r routeV REAL-v k1 online-stats + teacher-forcing s{{seed}} (25% unhackable); resolve: deploy_hack << placebo at matched solve -> directionality real" -- {{ TRAIN }} fast --intervention=routeV --unhackable-frac=0.25 {{ TEACH }} --seed={{seed}} --out-tag=_l2r_routeV_real_s{{seed}} - pueue add -w "$PWD" -o 60 -l "why: P2 lora2r routeV TOPK k3 online-stats + teacher-forcing s{{seed}} (25% unhackable); resolve: topk deploy_hack <= real-k1 -> sub-mode subspace catches hacks the mean washes out" -- {{ TRAIN }} fast --intervention=routeV --v-grad-k=3 --unhackable-frac=0.25 {{ TEACH }} --seed={{seed}} --out-tag=_l2r_routeV_topk_s{{seed}} - pueue add -w "$PWD" -o 58 -l "why: P3 lora2r routeV PLACEBO-v (Haar 157) + teacher-forcing s{{seed}} (25% unhackable); resolve: deploy_hack ~ vanilla -> real-v suppression is directional, not absorption/shrinkage" -- {{ TRAIN }} fast --intervention=routeV --routeV-random-v-seed=157 --unhackable-frac=0.25 {{ TEACH }} --seed={{seed}} --out-tag=_l2r_routeV_placebo_s{{seed}} - pueue add -w "$PWD" -o 56 -l "why: P4 lora2r VANILLA (gate pinned clean) + teacher-forcing s{{seed}} (25% unhackable); resolve: deploy_hack >> 0 emergence reference on the identical adapter" -- {{ TRAIN }} fast --intervention=none --unhackable-frac=0.25 {{ TEACH }} --seed={{seed}} --out-tag=_l2r_vanilla_s{{seed}} - pueue add -w "$PWD" -o 54 -l "why: P5 lora2r BOTH-BLOCK (masks pinned (1,0), no gate) + teacher-forcing s{{seed}} (25% unhackable); resolve: ~vanilla -> gate+masks add nothing; << vanilla -> ungated both-block training suppresses" -- {{ TRAIN }} fast --intervention=absorb --unhackable-frac=0.25 {{ TEACH }} --seed={{seed}} --out-tag=_l2r_absorb_s{{seed}} + pueue add -w "$PWD" -o 62 -l "why: P1 lora2r routeA REAL-v act gate + teacher-forcing s{{seed}} (25% unhackable); resolve: deploy_hack << placebo at matched solve -> directionality real" -- {{ TRAIN }} fast --intervention=routeA --unhackable-frac=0.25 {{ TEACH }} --seed={{seed}} --out-tag=_l2r_routeA_real_s{{seed}} + pueue add -w "$PWD" -o 58 -l "why: P2 lora2r routeA PLACEBO-v (Haar 157) + teacher-forcing s{{seed}} (25% unhackable); resolve: deploy_hack ~ vanilla -> real-v suppression is directional, not absorption/shrinkage" -- {{ TRAIN }} fast --intervention=routeA --routeA-random-v-seed=157 --unhackable-frac=0.25 {{ TEACH }} --seed={{seed}} --out-tag=_l2r_routeA_placebo_s{{seed}} + pueue add -w "$PWD" -o 56 -l "why: P3 lora2r VANILLA (gate pinned clean) + teacher-forcing s{{seed}} (25% unhackable); resolve: deploy_hack >> 0 emergence reference on the identical adapter" -- {{ TRAIN }} fast --intervention=none --unhackable-frac=0.25 {{ TEACH }} --seed={{seed}} --out-tag=_l2r_vanilla_s{{seed}} + pueue add -w "$PWD" -o 54 -l "why: P4 lora2r BOTH-BLOCK (masks pinned (1,0), no gate) + teacher-forcing s{{seed}} (25% unhackable); resolve: ~vanilla -> gate+masks add nothing; << vanilla -> ungated both-block training suppresses" -- {{ TRAIN }} fast --intervention=absorb --unhackable-frac=0.25 {{ TEACH }} --seed={{seed}} --out-tag=_l2r_absorb_s{{seed}} # Base model zero-shot deploy eval (0 training steps): reproduce the paper's base # solve ~11.5% in our harness. resolve: base solve ~0.10-0.12. diff --git a/scripts/diag_pinning_refresh.py b/scripts/attic/diag_pinning_refresh.py similarity index 100% rename from scripts/diag_pinning_refresh.py rename to scripts/attic/diag_pinning_refresh.py diff --git a/scripts/diag_pinning.py b/scripts/diag_pinning.py index 67246e8..7dec362 100644 --- a/scripts/diag_pinning.py +++ b/scripts/diag_pinning.py @@ -1,4 +1,4 @@ -"""Q2 diagnostic: what should the live routeV gate SCORE, and where do the pinning +"""Q2 diagnostic: what should the live routing gate SCORE, and where do the pinning cuts go? THE QUESTION (Q2). The gate routes UPDATES, not rollouts: per rollout the GRPO update @@ -90,8 +90,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer from vgrout.lora2r import wrap_model_with_lora2r from vgrout.pairs import load_pairs -from vgrout.extract_vhack_grad import completion_nll -from vgrout.train import _auroc +from vgrout.train import _auroc, _otsu3 # colour = behaviour (blue solve, red hack, grey fail); style = source (solid on-policy, dashed synthetic) SOLVE, HACK, FAIL, ABSORB_C, ROUT_C, ORACLE = "#3b6ea5", "#c44e52", "#9aa0a6", "#d1900a", "#c44e52", "#3a8a7a" @@ -104,7 +103,7 @@ class Cfg: run_dir: Path = Path("out/runs/20260611T003538_fast_vanilla_lora2r_seed43_l2r_vanilla_s43_v3") ckpt: str = "first_hack" pairs: Path = Path("data/pairs/hack_pairs.md#all-in-one") - # headline figure builds v from this heading-prefix subset = the routeV TRAINING + # headline figure builds v from this heading-prefix subset = the routeA TRAINING # default (train_config.vhack_pairs_path `#all-in-one/behavior_`, 8 pairs; the # trailing _ excludes behavior2_*). The pairset table spans all subsets of `pairs`. headline_prefix: str = "behavior_" @@ -216,24 +215,21 @@ def _kde(x: np.ndarray, grid: np.ndarray) -> np.ndarray: return np.exp(-0.5 * z ** 2).sum(1) / (len(x) * bw * np.sqrt(2 * np.pi)) -def _otsu3(x: np.ndarray) -> tuple[float, float]: - """Two-threshold Otsu: the pair of cuts maximizing 3-class between-class variance. - Label-free -- an online gate can compute this from a rolling window of scores, so - using it here is not oracle leakage. O(n^2), fine for a few hundred scores. - Scores are winsorized at 1/99% first: Otsu maximizes variance, so on heavy-tailed - scores a single extreme point otherwise buys a whole class (seen on grad_dot).""" - x = np.clip(x, *np.quantile(x, [0.01, 0.99])) - s = np.sort(np.asarray(x, float)) - n = len(s) - c = np.concatenate([[0.0], np.cumsum(s)]) - best, best_ij = -np.inf, (1, 2) - for i in range(1, n - 1): - for j in range(i + 1, n): - obj = c[i] ** 2 / i + (c[j] - c[i]) ** 2 / (j - i) + (c[n] - c[j]) ** 2 / (n - j) - if obj > best: - best, best_ij = obj, (i, j) - i, j = best_ij - return float((s[i - 1] + s[i]) / 2), float((s[j - 1] + s[j]) / 2) +def completion_nll(model, tokenizer, prompt: str, completion: str, device) -> torch.Tensor: + """Mean NLL over completion tokens only (length-normalized). The backward of this + loss populates the c-probe grads read by _gate_grads (the retired grad-gate space, + kept here as a diagnostic baseline).""" + prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device) + full_ids = tokenizer(prompt + completion, return_tensors="pt").input_ids.to(device) + n_prompt = prompt_ids.shape[1] + logits = model(full_ids).logits[:, :-1] # [1, L-1, V] + targets = full_ids[:, 1:] # [1, L-1] + logp = torch.nn.functional.log_softmax(logits.float(), dim=-1) + nll = -logp.gather(-1, targets.unsqueeze(-1)).squeeze(-1) # [1, L-1] + # mask: positions whose target is a completion token (i.e. index >= n_prompt in full_ids) + pos = torch.arange(full_ids.shape[1] - 1, device=device).unsqueeze(0) + mask = (pos >= (n_prompt - 1)).float() + return (nll * mask).sum() / mask.sum().clamp_min(1.0) def plot_q2(df: pl.DataFrame, subtitle: str, out_png: Path) -> dict: @@ -497,7 +493,7 @@ def _downstream(cfg: Cfg, fe: dict, src: str) -> int: print(f"SHOULD: on_hackpos >= ~20 and on_drop not the majority, ELSE the window/run has " f"too few learnable hacks and every AUROC below is noise.") - # ── headline vectors from the routeV-default subset; placebo swaps in Haar ── + # ── headline vectors from the routeA-default subset; placebo swaps in Haar ── groups: dict[str, list[int]] = fe["pair_groups"] head_idx = [i for i, pid in enumerate(pair_ids) if pid.startswith(cfg.headline_prefix)] assert head_idx, f"no pairs match headline prefix {cfg.headline_prefix!r}" diff --git a/scripts/pairs_from_rollouts.py b/scripts/pairs_from_rollouts.py index 6bb260c..6118829 100644 --- a/scripts/pairs_from_rollouts.py +++ b/scripts/pairs_from_rollouts.py @@ -7,12 +7,12 @@ hack-side by detector signature. Here the source is the student's logged rollouts (out/runs//rollouts.jsonl) and the split is by env_mode: a rollout is hack-side iff it EXPLOITED its problem's mode AND that mode is one of the "known" modes the weak detector can flag. The held-out modes are never used to -build pairs -- v_grad is extracted only from the known modes, and the A5 figure +build pairs -- the routing vector is extracted only from the known modes, and the A5 figure then measures whether the held-out modes are also suppressed at deployment. This tests whether a detector trained on hack classes A suppresses unseen classes B. Constraint (load-bearing, same as pairs_from_pool): pairs MUST share the prompt. -The paired-diff g_hack - g_clean in extract_vhack_grad cancels prompt-specific +The paired-diff feat_hack - feat_clean in the pair extraction cancels prompt-specific noise only when both completions are conditioned on the same chat-templated prompt. A given problem_id renders one fixed (hinted) prompt across steps, so same-(problem_id) hack/clean rollouts share it; we assert equality to catch drift. diff --git a/scripts/verify_science_invariants.py b/scripts/verify_science_invariants.py index 4f3c391..bd0a9a8 100644 --- a/scripts/verify_science_invariants.py +++ b/scripts/verify_science_invariants.py @@ -74,7 +74,8 @@ def main() -> int: authored_pairs = load_pairs(Path("data/pairs/hack_pairs.md#all-in-one")) real_pairsets_ok = ( - len(authored_pairs) == 27 + len(authored_pairs) == 42 # 27 + 15 wave-2 behavior2_* (c33b810) + and len(load_pairs(Path("data/pairs/hack_pairs.md#all-in-one/behavior_"))) == 8 # routeA training default and len(load_pairs(Path("data/pairs/hack_pairs.md#all-in-one@opportunity-aware"))) == 6 and len(load_pairs(Path("data/pairs/hack_pairs.md#all-in-one@explicit"))) == 10 and len(load_pairs(Path("data/pairs/hack_pairs.md#all-in-one@roleplay"))) == 2 diff --git a/src/vgrout/extract_vhack_grad.py b/src/vgrout/extract_vhack_grad.py deleted file mode 100644 index a5e19a7..0000000 --- a/src/vgrout/extract_vhack_grad.py +++ /dev/null @@ -1,301 +0,0 @@ -"""Gradient-side per-module v_hack extraction (spec.md §B, top-k variant). - -We sample the per-module GRPO update direction on labeled (hack, clean) pairs. -For a pair with advantages (adv_h=+1, adv_c=-1) the Dr.GRPO single-step grad -`-adv_h * grad_logp(hack) - adv_c * grad_logp(clean)` algebraically equals -`grad_NLL(hack) - grad_NLL(clean)`, so we compute it by the simpler path: -forward each completion, take mean-NLL on completion tokens, backward, and -capture the lora2r c-probe grad (the per-pair weight grad of the virtual -diagonal between A and B, deployed block) per wrapped Linear. Naming the steps -NLL is an implementation detail; the *meaning* is "the GRPO update on this pair." - -Then per module, with D = [g_hack_i - g_clean_i for each pair] in R^{n_pairs x r}: - SVD(D) = U Σ Vh - v_hack[name] = top_k rows of Vh, each oriented so mean(D @ v_i) > 0 - -This generalizes mean-diff (which corresponds to top-1 PC of paired diffs under -isotropic covariance) to a rank-k hack subspace, motivated by CHaRS (Abdullaev -2025 -- see docs/paper_chars.md): hack signal is multi-modal across hack flavors -(weak tests, hardcode, persona, ...), so a single global direction is brittle. - -Orientation matters because proj.py applies a per-direction one-sided gate -(only subtracts when positive). +v_i must align with the reward-hacking gradient. - -Saves `out/v_hack.safetensors` = dict[name -> Tensor[k, r]] (cpu fp32, rows -unit-norm + orthonormal from SVD) with header {"model": str, "dtype": str, -"top_k": str(k)}. - -Run: uv run python -m vgrout.extract_vhack_grad -""" -from __future__ import annotations - -import sys -from collections import defaultdict -from dataclasses import dataclass -from pathlib import Path - -import torch -import tyro -from jaxtyping import Float -from loguru import logger -from safetensors.torch import save_file -from tabulate import tabulate -from transformers import AutoModelForCausalLM, AutoTokenizer - -from .lora2r import wrap_model_with_lora2r -from .pairs import load_pairs, pairset_sha256 - - -OUT_DIR = Path("out") - - -@dataclass -class Config: - model: str = "Qwen/Qwen3-4B" - dtype: str = "bf16" # must match train.py, else SVD basis cache can differ silently - out_path: Path = OUT_DIR / "vhack" / "v_hack.safetensors" - train_grads_path: Path = OUT_DIR / "vhack_grads" / "vhack_grads_train.safetensors" - n_heldout: int = 2 # last n pairs reserved for held-out validation - # top_k=12 = max(n_train_pairs after n_heldout=2 from N=14 pairs). Extract once - # at max rank; train.py slices via --v-hack-k for k-ablation without re-extract. - top_k: int = 12 - # tau_axis: zero rows where S_i/S_0 < tau_axis. Diagnostic -- projection along - # noise-direction unit vectors removes only ~||g||/sqrt(r) ≈ 2% of grad - # magnitude on r=2560 modules, so this rarely changes effect size; it does - # keep k-ablations interpretable (axes 4-5 might be pure noise on N=12 pairs). - tau_axis: float = 0.0 - # Pairset reference: generated JSON or one `path.md#section`. - pairs_from_pool: Path | None = None - - -def resolve_dtype(s: str) -> torch.dtype: - return {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[s] - - -def completion_nll(model, tokenizer, prompt: str, completion: str, device) -> torch.Tensor: - """Mean NLL over completion tokens only (length-normalized).""" - prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device) - full_ids = tokenizer(prompt + completion, return_tensors="pt").input_ids.to(device) - n_prompt = prompt_ids.shape[1] - logits = model(full_ids).logits[:, :-1] # [1, L-1, V] - targets = full_ids[:, 1:] # [1, L-1] - logp = torch.nn.functional.log_softmax(logits.float(), dim=-1) - nll = -logp.gather(-1, targets.unsqueeze(-1)).squeeze(-1) # [1, L-1] - # mask: positions whose target is a completion token (i.e. index >= n_prompt in full_ids) - pos = torch.arange(full_ids.shape[1] - 1, device=device).unsqueeze(0) - mask = (pos >= (n_prompt - 1)).float() - return (nll * mask).sum() / mask.sum().clamp_min(1.0) - - -def extract_v_hack( - model, - tokenizer, - wrappers: dict, - pairs: list, - top_k: int, - tau_axis: float, - n_heldout: int, - device, -) -> tuple[ - dict[str, Float[torch.Tensor, "k r"]], - dict[str, Float[torch.Tensor, "k"]], - dict[str, Float[torch.Tensor, "n_pairs r"]], - list[dict], -]: - """Run pair-grads + per-module SVD on D = g_hack - g_clean, return v_hack. - - Pure function -- caller owns model loading, wrapping, and saving. train.py - calls this on its already-wrapped model when v_hack cache is missing, so - we don't pay the cost of a second model load. - - Returns: - v_hack: dict[name -> Tensor[k, r]] (cpu fp32), top-k right singular - vectors of D per module, oriented so mean(D @ v_i) > 0. If - tau_axis > 0, rows where S_i/S_0 < tau_axis are zeroed. - v_sv: dict[name -> Tensor[k]] (cpu fp32), singular values matching v_hack. - Saved alongside V under `_sv/{name}` keys so load-time noise-floor - filtering works without re-extracting. - raw_grads: dict["hack/name"|"clean/name" -> Tensor[n_pairs, r]] for - offline analysis (verify_vhack_heldout reads this). - diag_rows: per-module diagnostic dicts (sv_top frac, ||D||, etc.). - """ - train_pairs = pairs[:-n_heldout] if n_heldout > 0 else pairs - n_pairs = len(train_pairs) - - grads_hack: dict[str, list[torch.Tensor]] = defaultdict(list) - grads_clean: dict[str, list[torch.Tensor]] = defaultdict(list) - - for pi, pair in enumerate(train_pairs): - for label, completion in (("hack", pair.hack), ("clean", pair.clean)): - model.zero_grad(set_to_none=True) - loss = completion_nll(model, tokenizer, pair.prompt, completion, device) - if not torch.isfinite(loss): - # Skip-on-nonfinite would silently leave G_h and G_c with mismatched - # lengths and explode later in D = G_h - G_c. Fail fast. - raise RuntimeError(f"non-finite loss at pair={pi} label={label}: {loss.item()}") - loss.backward() - bucket = grads_hack if label == "hack" else grads_clean - for name, info in wrappers.items(): - layer = info["layer"] - # Per-pair weight grad of the virtual diagonal (c-probe), DEPLOYED - # block only -- the same space the live gate reads (train.py), so - # band calibration uses the same representation. Requires grad_probe=True. - cg = layer._lora2r_gate.grad - if cg is None: - raise RuntimeError(f"no c-probe grad on {name}; wrap with grad_probe=True") - g = cg.sum(dim=tuple(range(cg.dim() - 1)))[: layer._lora2r_r] # [r] - bucket[name].append(g.detach().float().cpu().clone()) - if (pi + 1) % 5 == 0: - logger.info(f" pair {pi+1}/{n_pairs} loss={loss.item():.3f}") - model.zero_grad(set_to_none=True) # leave caller with clean state - - raw_grads = { - **{f"hack/{n}": torch.stack(gs) for n, gs in grads_hack.items()}, - **{f"clean/{n}": torch.stack(gs) for n, gs in grads_clean.items()}, - } - - # Per module: D = g_hack - g_clean (paired diff cancels prompt-specific noise). - # SVD(D) gives orthonormal right singular vectors = principal axes of variation - # of the hack-clean axis. Top-k generalizes mean-diff (which is the k=1 case). - v_hack: dict[str, torch.Tensor] = {} - v_sv: dict[str, torch.Tensor] = {} - rows = [] - n_zero = 0 - k = min(top_k, n_pairs) - n_axes_kept_total = 0 - for name in grads_hack: - G_h = torch.stack(grads_hack[name]) # [n_pairs, r] - G_c = torch.stack(grads_clean[name]) - D = G_h - G_c - - U_d, S_d, Vh_d = torch.linalg.svd(D, full_matrices=False) - V = Vh_d[:k] # [k, r], rows orthonormal in R^r - # Orient by per-pair majority vote: for each axis i, count pairs where - # d_p @ v_i > 0; if strict majority disagree with current SVD sign, flip. - # More outlier-robust than sign(mean): one extreme pair can't flip a - # consensus direction. Matches repeng's _orient_svd convention. - proj_per_pair = D @ V.T # [n_pairs, k] - n_pos = (proj_per_pair > 0).float().sum(0) # [k] - flip = torch.where(n_pos < n_pairs / 2, -torch.ones(k), torch.ones(k)) - V = V * flip.unsqueeze(1) - - # tau_axis: zero rows where S_i/S_0 < tau_axis (diagnostic; see Config comment). - n_axes_kept = k - if tau_axis > 0 and S_d[0] > 1e-12: - ratios = S_d[:k] / S_d[0] - keep = (ratios >= tau_axis).float() - V = V * keep.unsqueeze(1) - n_axes_kept = int(keep.sum()) - n_axes_kept_total += n_axes_kept - - nrm = D.norm() - if nrm < 1e-12: - n_zero += 1 - v_hack[name] = torch.zeros((k, D.shape[1]), dtype=V.dtype).contiguous() - else: - v_hack[name] = V.contiguous() - # Record singular values so the load-time noise-floor filter has the - # extraction-time S_i per axis without re-extracting. Saved under - # `_sv/{name}` keys in the safetensors file (combined at save site). - v_sv[name] = S_d[:k].clone().contiguous() - sv_top = S_d[:k] - sv_total = S_d.sum().clamp_min(1e-12) - rows.append({ - "module": name.split(".")[-1], - "r": D.shape[1], - "||D||": f"{nrm:.2e}", - "sv_0": f"{S_d[0].item():.2e}" if S_d.numel() else "-", - f"sv_top{k}_frac": f"{(sv_top.sum() / sv_total).item():.2f}", - "sv_ratio_0/1": ("-" if S_d.numel() < 2 - else f"{(S_d[0] / S_d[1].clamp_min(1e-12)).item():.2f}"), - "axes_kept": n_axes_kept, - }) - n_modules = len(grads_hack) - logger.info( - f"v_hack: modules={n_modules} k_max={k} zero-||D||={n_zero} " - f"axes_kept_avg={n_axes_kept_total/max(1,n_modules):.1f} (tau_axis={tau_axis})" - ) - return v_hack, v_sv, raw_grads, rows - - -def main(cfg: Config) -> int: - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - dtype = resolve_dtype(cfg.dtype) - if cfg.pairs_from_pool is None: - raise ValueError("--pairs-from-pool is required; use data/pairs/hack_pairs.md#all-in-one") - pairs = load_pairs(cfg.pairs_from_pool) - logger.info(f"pairs source: {cfg.pairs_from_pool} -> {len(pairs)} pairs") - logger.info( - f"device={device} model={cfg.model} dtype={cfg.dtype} " - f"N_pairs={len(pairs)} heldout={cfg.n_heldout} top_k={cfg.top_k} tau_axis={cfg.tau_axis}" - ) - - tokenizer = AutoTokenizer.from_pretrained(cfg.model) - model = AutoModelForCausalLM.from_pretrained( - cfg.model, dtype=dtype, attn_implementation="sdpa" - ).to(device) - model.eval() # disable dropout; gradients still flow through the adapter - wrappers = wrap_model_with_lora2r(model, grad_probe=True) - n_mod = len(wrappers) - logger.info(f"wrapped {n_mod} modules; probe space = r per module") - - train_pairs = pairs[:-cfg.n_heldout] if cfg.n_heldout > 0 else pairs - logger.info(f"train pairs: {len(train_pairs)} held: {cfg.n_heldout}") - - v_hack, v_sv, raw_grads, rows = extract_v_hack( - model, tokenizer, wrappers, pairs, - top_k=cfg.top_k, tau_axis=cfg.tau_axis, - n_heldout=cfg.n_heldout, device=device, - ) - n_zero = sum(1 for v in v_hack.values() if v.norm() < 1e-12) - k = min(cfg.top_k, len(train_pairs)) - - cfg.out_path.parent.mkdir(parents=True, exist_ok=True) - cfg.train_grads_path.parent.mkdir(parents=True, exist_ok=True) - save_file(raw_grads, str(cfg.train_grads_path), - metadata={"model": cfg.model, "dtype": cfg.dtype}) - # v_hack file layout: bare `{name}` keys hold V[k, r]; `_sv/{name}` keys - # hold S[k]. Loader at train.py:load_v_hack splits them back apart. - save_payload = {**v_hack, **{f"_sv/{n}": s for n, s in v_sv.items()}} - save_file(save_payload, str(cfg.out_path), - metadata={"model": cfg.model, "dtype": cfg.dtype, "top_k": str(k), - "tau_axis": str(cfg.tau_axis), "schema": "v2_with_sv", - "pairs_path": str(cfg.pairs_from_pool), - "pairs_sha256": pairset_sha256(cfg.pairs_from_pool)}) - - # summary: aggregate by suffix -- track top-k energy concentration - by_suffix: dict[str, list] = defaultdict(list) - for r in rows: - by_suffix[r["module"]].append(float(r[f"sv_top{k}_frac"])) - agg_rows = [] - for suf, vals in sorted(by_suffix.items()): - agg_rows.append({ - "suffix": suf, - "n": len(vals), - f"mean_sv_top{k}_frac": f"{sum(vals)/len(vals):.2f}", - f"min_sv_top{k}_frac": f"{min(vals):.2f}", - f"max_sv_top{k}_frac": f"{max(vals):.2f}", - }) - - # Final tail: BLUF -- what an agent reads first should be result + interp. - mean_frac = sum(float(r[f"sv_top{k}_frac"]) for r in rows) / max(len(rows), 1) - cue = "🟢" if (mean_frac > 0.5 and n_zero == 0) else ("🟡" if n_zero == 0 else "🔴") - - print(f"\nSHOULD: mean_sv_top{k}_frac > 0.5 per suffix (subspace captures most energy). " - f"zero-||D||==0 (else grad flow broken).\n") - print(tabulate(agg_rows, headers="keys", tablefmt="tsv", floatfmt=".2f")) - print() - print(f"out: {cfg.out_path}") - print(f"argv: extract_vhack_grad --model={cfg.model} --top-k={k} --n-heldout={cfg.n_heldout}") - print(f"main metric: mean_sv_top{k}_frac={mean_frac:.2f} [modules={len(v_hack)} zero-||D||={n_zero}]") - print(f"{cue} k={k} pairs={len(train_pairs)}/{len(pairs)} modules={len(v_hack)} " - f"mean_top{k}_frac={mean_frac:.2f} zero={n_zero}") - - if n_zero > 0: - logger.error(f"FAIL: {n_zero}/{len(v_hack)} modules have zero ||D|| -- gradient flow broken") - return 1 - return 0 - - -if __name__ == "__main__": - sys.exit(main(tyro.cli(Config))) diff --git a/src/vgrout/figs.py b/src/vgrout/figs.py index 3aa9b6e..871e4d2 100644 --- a/src/vgrout/figs.py +++ b/src/vgrout/figs.py @@ -17,14 +17,16 @@ from pathlib import Path FIGS_DIR = Path("docs/figs") # Reader-facing arm names. Code/log tags carry our internal vocabulary -# (routeV = the current routing arm); plots must +# (routeA = the current routing arm); plots must # not. Map every internal tag to the word a paper reader sees. Anything missing # falls through to its raw tag, so a new arm shows up loud rather than silently # mislabelled. ARM_DISPLAY = { - # routeV is the current banded-gate arm; routing2/route2 are the old binary-tau runs - # (kept so historical run artifacts still plot -- see rename, 2026-06-06). - "routingV": "route", "routeV": "route", "routingV_per_token": "route per-token", + # routeA is the current act-gate arm; routeV (grad gate) and routing2/route2 + # (binary-tau) are retired but kept so historical run artifacts still plot. + "routeA": "route", + "routingV": "route (grad)", "routeV": "route (grad)", + "routingV_per_token": "route per-token", "routing2": "route", "route2": "route", "routing2_grad": "route", "routing2_act": "route (act)", "projected": "erase", "route": "route", "erase": "erase", "vanilla": "vanilla", diff --git a/src/vgrout/lora2r.py b/src/vgrout/lora2r.py index 25864c4..b41f9d0 100644 --- a/src/vgrout/lora2r.py +++ b/src/vgrout/lora2r.py @@ -51,9 +51,9 @@ def _lora2r_hook(layer: nn.Linear, args: tuple, y: Tensor) -> Tensor: grad probe: c = ones[..., 2r] spliced as h*c. After backward c.grad = h ⊙ (Bᵀδ_y) = the per-sample WEIGHT grad of a virtual diagonal - scale between A and B. The pair extraction (extract_vhack_grad) and the live - gate (train.py) both read this same space, so the band is self-consistent - whatever the basis. + scale between A and B. Training no longer uses it (the routeA gate reads + activations); kept for scripts/diag_pinning.py, whose grad-score baseline + reads this space. """ (x,) = args # x: [..., d_in] A: Float[Tensor, "two_r d_in"] = layer._lora2r_A # trainable diff --git a/src/vgrout/pairs_from_pool.py b/src/vgrout/pairs_from_pool.py index af3066e..9d45dec 100644 --- a/src/vgrout/pairs_from_pool.py +++ b/src/vgrout/pairs_from_pool.py @@ -9,7 +9,7 @@ exclusively from `half-A` detectors. The clean-side is any rollout where all four upstream detectors are False AND format_ok is True. Constraint (load-bearing): pairs MUST share the prompt. The paired-diff -`g_hack - g_clean` in extract_vhack_grad cancels prompt-specific noise only +`feat_hack - feat_clean` in the pair extraction cancels prompt-specific noise only when both completions are conditioned on the same chat-templated prompt. Cross-prompt pairs would inject prompt-difference signal into v_hack. diff --git a/src/vgrout/run_artifacts.py b/src/vgrout/run_artifacts.py index d80c6a1..e04316a 100644 --- a/src/vgrout/run_artifacts.py +++ b/src/vgrout/run_artifacts.py @@ -13,6 +13,7 @@ RUN_SCHEMA = "paired_final_v2" # v2: deployed/as_trained field names (was depl # get a _lora2r suffix so the two substrates never conflate in aggregation. ARM = {"none": "vanilla", "erase": "projected", "routeV": "routingV", "routeV_per_token": "routingV_per_token", + "routeA": "routeA", "absorb": "absorb"} diff --git a/src/vgrout/tablelog.py b/src/vgrout/tablelog.py index 70d9199..a94fa9d 100644 --- a/src/vgrout/tablelog.py +++ b/src/vgrout/tablelog.py @@ -73,7 +73,7 @@ class StepLogger: def __init__(self, arm: str, modes: list[str], mode_code: dict[str, str], show_ablate: bool = False) -> None: # Routing diagnostics are ALWAYS shown (nan on vanilla, whose gate never runs) so the - # column layout is identical across arms -- vanilla/routeV/absorb tables line up. + # column layout is identical across arms -- vanilla/routeA/absorb tables line up. cols: list[_Col] = [ _Col("step", 4, "step", "d", "GRPO step"), _Col("ref_eq", 6, "ref_eq", ".2f", "vanilla-equiv step (cum_gens/256)"), @@ -84,7 +84,7 @@ class StepLogger: _Col("hack_s", 7, "hack_s?", "frac", "student hack-flagged rollouts (the headline)"), _Col("hack_t", 7, "hack_t", "frac", "teacher hack-flagged rollouts (sanity: pool hacks)"), # Held-out deployed evaluation with quarantine ablated; NaN between evaluation steps. - _Col("hack_deployed", 7, "hk_dep", "+.2f", "DEPLOY-eval hack (routeV: quarantine OFF; vanilla/erase: trained model); held-out subset, T=0.7, every eval_ablate_every steps; nan between"), + _Col("hack_deployed", 7, "hk_dep", "+.2f", "DEPLOY-eval hack (routeA/absorb: quarantine OFF; vanilla: trained model); held-out subset, T=0.7, every eval_ablate_every steps; nan between"), _Col("solve_deployed", 7, "slv_dep", "+.2f", "DEPLOY-eval solve (same cadence; nan between)"), ] # Multi-mode runs show current-step hacks per environment; single-mode would duplicate hack_s. @@ -99,17 +99,16 @@ class StepLogger: _Col("gn", 7, "gn", ".1e", "pre-clip L2 norm of A/B grads (vs grad_clip)"), _Col("lr", 7, "lr", ".1e", "scheduled learning rate"), ] - # routeV reports unit and energy shares across the routing band (nan on vanilla/absorb). + # routeA reports gate diagnostics (nan on vanilla/absorb, whose gate never runs). cols += [ - _Col("auroc", 6, "auroc", ".2f", "AUROC of pooled cos(g,v_grad) as a reward-hacking detector; measurement only, never routes. ~0.5 = chance-level separation; high AUROC but rout~0 = threshold/scale problem; a drop at refresh = reduced separation"), - _Col("cosU", 6, "cosU", "+.2f", "pooled cos(v_grad, summed-rollout c-grad): alignment of the net update with v_grad"), + _Col("auroc", 6, "auroc", ".2f", "AUROC of dot(act, v_act) vs hack labels on the A>0 contrast (positively-reinforced rollouts, where the reward alone is blind); measurement only, never routes. ~0.5 = chance-level separation; high AUROC but rout~0 = threshold problem; a drop at refresh = reduced separation"), + _Col("cos", 6, "cos", "+.2f", "mean per-rollout cos(act, v_act) (dot-vs-cos diagnostic)"), _Col("qmass", 6, "qmass", ".2f", "quarantine energy share ||g_quar||/(||g_keep||+||g_quar||): fraction of update energy assigned to quarantine"), - _Col("keep", 6, "keep", ".2f", "rollout share labelled clean (below band) -> deployed-only, quarantine off"), - _Col("resid", 6, "resid", ".2f", "rollout share labelled mid (inside band) -> both blocks train; absorption is possible but not measured"), - _Col("rout", 6, "rout", ".2f", "rollout share labelled hack (above band) -> quarantine-only, deployed detached"), - _Col("keepE", 6, "keepE", ".2f", "energy-weighted keep (grad-norm-weighted clean share)"), - _Col("residE", 6, "residE", ".2f", "energy-weighted resid"), - _Col("routE", 6, "routE", ".2f", "energy-weighted rout"), + _Col("keep", 6, "keep", ".2f", "rollout share below t_lo -> deployed-only, quarantine off"), + _Col("resid", 6, "resid", ".2f", "rollout share between thresholds (and ALL rollouts during warmup) -> both blocks train; absorption is possible but not measured"), + _Col("rout", 6, "rout", ".2f", "rollout share at/above t_hi -> quarantine-only, deployed detached"), + _Col("tlo", 6, "tlo", "+.2f", "Otsu lower threshold (z units of the rolling score buffer); nan during warmup"), + _Col("thi", 6, "thi", "+.2f", "Otsu upper (rout) threshold (z units); nan during warmup"), ] # Show the training-prompt deploy proxy only when an ablated slice exists. if show_ablate: diff --git a/src/vgrout/train.py b/src/vgrout/train.py index bfbeba8..acd400e 100644 --- a/src/vgrout/train.py +++ b/src/vgrout/train.py @@ -16,13 +16,14 @@ both trainable, partitioned into a deployed block [:r] and a quarantine block Arms (--intervention): none gate pinned clean (0,0): quarantine never trains -- the capacity- and structure-matched vanilla control. - routeV per-rollout three-way gate from the c-probe gradient vs v_grad: - clean->deployed-only, hack->quarantine-only (deployed detached), - mid->both, which may permit absorption. + routeA per-rollout three-way gate from the pooled bottleneck activation vs + v_act: keep->deployed-only, rout->quarantine-only (deployed detached), + absorb->both, which may permit absorption. The acts ride the no-grad + logpi_old forward, so routeA costs roughly the vanilla arm. absorb gate pinned mid (1,0): both blocks train on everything, no gate -- tests ungated both-block training. - uv run python -m vgrout.train smoke --intervention=routeV + uv run python -m vgrout.train smoke --intervention=routeA """ from __future__ import annotations @@ -33,9 +34,12 @@ import os import sys import random import time +from collections import deque from contextlib import nullcontext from pathlib import Path +import numpy as np + # Must be set BEFORE `import torch` to take effect on the CUDA allocator. # Eliminates fragmentation that caused 91 GiB allocated / 581 MiB free crash # on Qwen3-4B G=8 (PyTorch's own OOM message recommends this). @@ -51,7 +55,9 @@ from tabulate import tabulate from tqdm import tqdm from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig +from .extract_vhack_act import ActCapture, extract_v_act, haar_unit_rows from .lora2r import wrap_model_with_lora2r +from .pairs import load_pairs from .proj import per_token_logps from .rewards import EnvMode, compute_reward from .data import DATA, load_problems @@ -64,46 +70,27 @@ OUT_DIR = Path("out") RUNS_DIR = OUT_DIR / "runs" -def _haar_unit_dirs(v_grad: dict, seed: int, device) -> dict: - """Build the reproducible out-of-subspace directionality control (placebo) for routeV. - - Matches v_grad's [k, r] shape and unit-normalizes per row, so a top-k routing run - gets k random dirs (the gate's max-cosine still sees the same shape) for a fair placebo.""" - g = torch.Generator().manual_seed(seed) - out = {} - for name in sorted(v_grad): - d = torch.randn(v_grad[name].shape, generator=g) # [k, r] - out[name] = (d / d.norm(dim=-1, keepdim=True).clamp_min(1e-12)).to(device) - return out - - -def _build_v_grad(raw_grads: dict, names, k: int, device) -> dict[str, Float[torch.Tensor, "k r"]]: - """Per-module routing directions from authored-pair gradients -> dict[name -> [k, r]]. - - k=1 (headline): normalized mean(hack-clean) -- one mean-mass axis. k>1: the top-k - oriented right singular vectors of the paired-diff matrix D=[g_hack-g_clean] (SVD + - per-pair majority orient, mirroring extract_vhack_grad.extract_v_hack), a rank-k hack - subspace. The live gate scores max_i cos(g, v_i). Rows are unit-norm. k=1 is NOT - SVD-top-1 (they differ off-isotropic): keeping mean-diff makes 'mean-mass vs top-k' - a clean A/B, not a confound.""" - out = {} - for name in names: - D: Float[torch.Tensor, "n_pairs r"] = ( - raw_grads[f"hack/{name}"] - raw_grads[f"clean/{name}"]).float() - if k == 1: - d = D.mean(0) - V = (d / d.norm().clamp_min(1e-12)).unsqueeze(0) # [1, r] - else: - _, _, Vh = torch.linalg.svd(D, full_matrices=False) - kk = min(k, Vh.shape[0]) - V = Vh[:kk] # [kk, r] orthonormal - proj = torch.einsum("n r, k r -> n k", D, V) # per-pair projection - flip = torch.where((proj > 0).float().sum(0) < D.shape[0] / 2, - -torch.ones(kk), torch.ones(kk)) # orient toward reward-hacking gradients - V = V * flip.unsqueeze(1) - out[name] = V.to(device) - return out - +def _otsu3(x: np.ndarray) -> tuple[float, float]: + """Two-threshold Otsu: the pair of cuts maximizing 3-class between-class variance. + Label-free -- the routeA gate computes this on a rolling buffer of live scores, so + using it is not oracle leakage. Scores are winsorized at 1/99% first: Otsu maximizes + variance, so on heavy-tailed scores a single extreme point otherwise buys a whole + class (journal 2026-06-11 (d): v5 act rout precision 0.00 -> 0.50 after winsorize). + Vectorized over the [n, n] cut grid; n is the buffer size (<= a few hundred).""" + x = np.clip(x, *np.quantile(x, [0.01, 0.99])) + s = np.sort(np.asarray(x, float)) + n = len(s) + c = np.concatenate([[0.0], np.cumsum(s)]) + iv = np.arange(1, n) + i_g, j_g = iv[:, None], iv[None, :] + with np.errstate(divide="ignore", invalid="ignore"): + obj = (c[i_g] ** 2 / i_g + + (c[j_g] - c[i_g]) ** 2 / (j_g - i_g) + + (c[n] - c[j_g]) ** 2 / (n - j_g)) + obj[(j_g <= i_g) | (j_g >= n)] = -np.inf # need i < j and a nonempty top class + i, j = np.unravel_index(np.argmax(obj), obj.shape) + i, j = iv[i], iv[j] + return float((s[i - 1] + s[i]) / 2), float((s[j - 1] + s[j]) / 2) def _sample_rows(rows: list[dict] | None, n: int, rng: torch.Generator) -> list[dict]: @@ -116,37 +103,15 @@ def _sample_rows(rows: list[dict] | None, n: int, rng: torch.Generator) -> list[ return [rows[i] for i in idxs] -def _zone_stats(f: torch.Tensor, w: torch.Tensor) -> tuple[float, ...]: - """Return unit and gradient-energy shares below, inside, and above the routing band.""" - if f.numel() == 0: - return (float("nan"),) * 6 - lo, hi = (f == 0), (f == 1) - mid = ~(lo | hi) - tot = w.sum().clamp_min(1e-12) - return (lo.float().mean().item(), mid.float().mean().item(), hi.float().mean().item(), - ((w * lo).sum() / tot).item(), ((w * mid).sum() / tot).item(), ((w * hi).sum() / tot).item()) - - -def _pair_cos(raw_grads: dict, v: Float[torch.Tensor, "k r"], name: str - ) -> tuple[Float[torch.Tensor, "n_pairs"], Float[torch.Tensor, "n_pairs"]]: - """(clean, hack) pair cosines vs the routing dirs: max_i cos(g, v_i), the same - scoring the live gate uses, so band edges and thresholds use the same representation.""" - gh = raw_grads[f"hack/{name}"].float() # [n_pairs, r] - gc = raw_grads[f"clean/{name}"].float() - ch = torch.einsum("n r, k r -> n k", gh, v).max(dim=1).values / gh.norm(dim=1).clamp_min(1e-12) - cc = torch.einsum("n r, k r -> n k", gc, v).max(dim=1).values / gc.norm(dim=1).clamp_min(1e-12) - return cc, ch - - def _auroc(scores: list[float], labels: list[bool]) -> float: """Rank-based AUROC (Mann-Whitney U) of `scores` as a detector of the positive class. Higher score for hacks -> auroc > 0.5. nan if either class is absent this step. - Diagnostic only: ground-truth labels measure how well cos(g, v_grad) separates - reward-hacking updates, but never determine a route. Reading: ~0.5 means v_grad + Diagnostic only: ground-truth labels measure how well the gate score separates + reward-hacking updates, but never determine a route. Reading: ~0.5 means v_act is a chance-level classifier (no threshold can route reliably); high AUROC but rout~0 = the threshold/scale is wrong, not the direction; a drop across a refresh = - the refresh destroyed the separation (the step-5 cliff is then a direction problem).""" + the refresh destroyed the separation.""" pos = [s for s, y in zip(scores, labels) if y] neg = [s for s, y in zip(scores, labels) if not y] if not pos or not neg: @@ -167,22 +132,6 @@ def _auroc(scores: list[float], labels: list[bool]) -> float: return (sum_pos - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg) -def route_band_edges(raw_grads: dict, v_grad: dict, device) -> dict[str, tuple[float, float]]: - """Calibrate an absolute routing band from authored pairs only. - - Clean/hack p75 edges avoid single-pair extremes and route only the confident - tail aligned with reward-hacking gradients. Pair/live shift can still make routing - idle; inspect `rout`. - See docs/papers/grad_routing/paper_sgtm.md. - """ - band = {} - for name in v_grad: - cc, ch = _pair_cos(raw_grads, v_grad[name].detach().cpu().float(), name) - band[name] = (cc.quantile(0.75).item(), ch.quantile(0.75).item()) # (lower=p75 clean, upper=p75 hack) - return band - - - # Fix evaluation sampling across steps and arms without perturbing the training RNG. EVAL_GEN_SEED = 12345 @@ -196,12 +145,12 @@ MODE_CODE: dict[str, str] = { def _validate_config(cfg: Config) -> None: """Reject contradictory experiment settings before model load.""" - if cfg.intervention not in ("none", "routeV", "absorb"): - raise ValueError(f"unknown intervention {cfg.intervention!r}; expected none|routeV|absorb") - if cfg.routeV_random_v_seed is not None and cfg.intervention != "routeV": - raise ValueError("routeV_random_v_seed is a routeV-only placebo control") + if cfg.intervention not in ("none", "routeA", "absorb"): + raise ValueError(f"unknown intervention {cfg.intervention!r}; expected none|routeA|absorb") + if cfg.routeA_random_v_seed is not None and cfg.intervention != "routeA": + raise ValueError("routeA_random_v_seed is a routeA-only placebo control") if cfg.rollout_ablate_frac > 0 and cfg.intervention == "none": - raise ValueError("rollout_ablate_frac needs a quarantine to ablate (routeV/absorb)") + raise ValueError("rollout_ablate_frac needs a quarantine to ablate (routeA/absorb)") if cfg.weight_decay != 0.0: raise ValueError("lora2r init is nonzero; AdamW decay pulls A/B toward 0 not toward init " "-- set --weight-decay=0") @@ -215,7 +164,7 @@ def _validate_config(cfg: Config) -> None: def _log_resolved_config(cfg: Config, device) -> None: """One block with every None resolved to its effective value, so a detached log shows exactly what ran -- especially WHICH pairset (the field readers kept losing).""" - is_routeV = cfg.intervention == "routeV" + is_routeA = cfg.intervention == "routeA" fields = { "preset/arm": f"{cfg.preset_name} / {cfg.arm}", "intervention": cfg.intervention, @@ -225,8 +174,8 @@ def _log_resolved_config(cfg: Config, device) -> None: "lora_r/init_seed": f"{cfg.lora_r} / {cfg.lora_init_seed}", "unhackable_frac": cfg.unhackable_frac, "env_mode": cfg.env_mode, - "pairset": cfg.vhack_pairs_path if is_routeV else "unused (not routeV)", - "routeV placebo seed": cfg.routeV_random_v_seed if is_routeV else "n/a", + "pairset": cfg.vhack_pairs_path if is_routeA else "unused (not routeA)", + "routeA placebo seed": cfg.routeA_random_v_seed if is_routeA else "n/a", "teacher pool/mix/off_step": ( f"{cfg.teacher_pool_dir.name} / {cfg.mix_ratio} / {cfg.teacher_off_step}" if cfg.teacher_pool_dir else "none (pure on-policy)"), @@ -254,10 +203,10 @@ def main(cfg: Config) -> int: logger.info(f"verbose log: {verbose_log}") _log_resolved_config(cfg, device) - is_routeV = cfg.intervention == "routeV" + is_routeA = cfg.intervention == "routeA" is_absorb = cfg.intervention == "absorb" is_vanilla = cfg.intervention == "none" - has_quarantine = is_routeV or is_absorb + has_quarantine = is_routeA or is_absorb # Only adapter parameters train; the base model remains frozen. tok = AutoTokenizer.from_pretrained(model_name) @@ -276,9 +225,10 @@ def main(cfg: Config) -> int: model.config.use_cache = False # ── adapter: rank-2r LoRA, deployed block [:r] + quarantine block [r:] ── - # routeV needs the per-rollout c-probe gate; none/absorb pin the masks instead. + # The routeA gate reads activations via forward hooks; no grad probe in training + # (the c-probe stays in lora2r only for scripts/diag_pinning.py diagnostics). wrappers = wrap_model_with_lora2r( - model, r=cfg.lora_r, init_seed=cfg.lora_init_seed, grad_probe=is_routeV) + model, r=cfg.lora_r, init_seed=cfg.lora_init_seed) # A and B both train; quarantine = block slices of the SAME tensors, so there # is no separate hack-param list (per-rollout masks route grads, not surgery). delta_params = [p for info in wrappers.values() for p in (info["A"], info["B"])] @@ -287,51 +237,47 @@ def main(cfg: Config) -> int: logger.info(f"trainable lora2r A+B: {sum(p.numel() for p in delta_params):,} " f"({n_quar:,} of those in quarantine blocks)") - # ── routeV direction: v_grad (mean pair-gradient diff) + routing band ── - v_grad = None # set only by the routeV branch below - route_band = None - if is_routeV: + # ── routeA direction: v_act (mean pooled-act pair diff, unit rows per module) ── + v_act = None # [M, r] cpu fp32; module order = act_names + act_names = sorted(wrappers) + act_buf: deque | None = None # rolling pooled acts [M, r]; re-scored vs the CURRENT + # v_act at each gate call, so a refresh needs no flush + MASK_PAIRS = None + if is_routeA: # Authored pairs are the only routing-label source; live oracle labels never enter training. - from .pairs import load_pairs - from .extract_vhack_grad import extract_v_hack MASK_PAIRS = load_pairs(cfg.vhack_pairs_path) - logger.info(f"routeV pairs: {cfg.vhack_pairs_path} -> {len(MASK_PAIRS)} pairs") - model.eval() # match standalone extract: deterministic backward, no dropout - _, _, raw_grads, _ = extract_v_hack( - model, tok, wrappers, MASK_PAIRS, - top_k=1, tau_axis=0.0, n_heldout=2, device=device, - ) - v_grad = _build_v_grad(raw_grads, wrappers, cfg.v_grad_k, device) - _vk = "mean-diff (mean-mass)" if cfg.v_grad_k == 1 else f"top-{cfg.v_grad_k} SVD subspace, gate=max_i cos" - logger.info(f"routeV grad: built v_grad ({_vk}) for {len(v_grad)} modules") - if cfg.routeV_random_v_seed is not None: - v_grad = _haar_unit_dirs(v_grad, cfg.routeV_random_v_seed, device) - logger.info(f"routeV grad: OVERRODE v_grad with Haar-random dirs " - f"(seed={cfg.routeV_random_v_seed}) -- placebo directionality control") - # Calibrate after any Haar override so the control covers the full routing pipeline. - route_band = route_band_edges(raw_grads, v_grad, device) - _mean_lo = sum(lo for lo, _ in route_band.values()) / len(route_band) - _mean_hi = sum(hi for _, hi in route_band.values()) / len(route_band) - _mean_bw = _mean_hi - _mean_lo - n_inc_band = sum(1 for lo, hi in route_band.values() if hi - lo > 0) - logger.info( - f"routeV band: {len(route_band)} modules, mean lower(p75 clean cos)={_mean_lo:+.3f}, " - f"mean upper(p75 hack cos)={_mean_hi:+.3f}, mean width={_mean_bw:+.3f}; " - f"{n_inc_band}/{len(route_band)} modules have positive band width (included in the gate). " - f"SHOULD: width>0 (pairs separate) and most modules included; ELSE extraction/band off.") - # Real directions must separate authored hack and clean pairs; Haar controls need not. - if cfg.routeV_random_v_seed is None: - assert _mean_bw > 0, ( - f"real v_grad gave non-positive mean band width {_mean_bw:+.3f}: " - "hack pairs do not separate from clean -> extraction broken") - logger.info( - "lora2r three-way output mask: per-rollout label from the width-pooled " - "band-normalized cosine across modules; clean->deployed-only, " - "hack->quarantine-only (deployed detached), mid->both (may permit absorption). " - "SHOULD: rout (hack share) tracks the step's rollout hack rate, not ~50%; " - "clipfrac on clean-gated rollouts < ~0.2 ELSE the retain-trick ratio " - "drift is binding (quarantine forward too large).") + logger.info(f"routeA pairs: {cfg.vhack_pairs_path} -> {len(MASK_PAIRS)} pairs") + model.eval() # deterministic forward, no dropout + v_act, pair_acts = extract_v_act(model, tok, wrappers, MASK_PAIRS, device, + tstat=cfg.vact_tstat) model.train() + # Authored-pair separation in the live score. The dot GAP is > 0 by construction + # (v is proportional to the mean pair diff); the pair AUROC is not, so it is the + # extraction sanity signal. + sh = torch.einsum("pmr,mr->p", pair_acts["hack"], v_act) + sc = torch.einsum("pmr,mr->p", pair_acts["clean"], v_act) + pair_auroc = _auroc(torch.cat([sh, sc]).tolist(), + [True] * len(sh) + [False] * len(sc)) + logger.info( + f"routeA v_act: {v_act.shape[0]} modules x r={v_act.shape[1]} " + f"(tstat={cfg.vact_tstat}); authored-pair dot gap={(sh.mean() - sc.mean()).item():+.3e}, " + f"pair AUROC={pair_auroc:.2f}. SHOULD: pair AUROC ~1.0 ELSE extraction broken.") + if cfg.routeA_random_v_seed is not None: + v_act = haar_unit_rows(tuple(v_act.shape), cfg.routeA_random_v_seed) + logger.info(f"routeA: OVERRODE v_act with Haar-random unit rows " + f"(seed={cfg.routeA_random_v_seed}) -- placebo directionality control") + act_buf = deque(maxlen=cfg.route_buffer) + logger.info( + f"routeA gate: per-rollout score = dot(pooled completion-token act, v_act), " + f"thresholds = two-threshold Otsu on the last <= {cfg.route_buffer} live scores " + f"(z-normalized, winsorized 1/99%), label-free; pinned absorb until " + f"{cfg.route_warmup} scores. keep (0,0) | absorb (1,0) | rout (1,1: deployed " + f"detached). No bimodality guard: on the cached emergence windows no shape " + f"statistic separates the hack mixture from hack-free scores (Otsu tail means " + f"sit ~2.4 sd apart even on a Gaussian), and a false rout only discards one " + f"update from deployment. " + f"SHOULD: auroc col >> 0.5 once hacks appear ELSE v_act is blind and routing " + f"is noise; rout tracks the hack share, not ~0 or ~1.") # ── teacher pool ── # Teacher pool: pre-generated rollouts on disk keyed by problem_id. Each step's @@ -348,7 +294,7 @@ def main(cfg: Config) -> int: G_t = 0 if cfg.teacher_pool_dir is not None: # mix=0 is the NO-TEACHER ablation: pure on-policy GRPO (G_t=0) while the - # pool is still loaded for the partition + routeV v_grad extraction. + # pool is still loaded for the partition. if not (0.0 <= cfg.mix_ratio < 1.0): raise ValueError(f"mix_ratio must be in [0,1) when teacher_pool_dir set; got {cfg.mix_ratio}") G_t = round(group * cfg.mix_ratio) @@ -568,57 +514,31 @@ def main(cfg: Config) -> int: save_ckpt([], path=run_dir / "ckpt_update0000.safetensors") - def _lora2r_gate_labels(c_grads: tuple, n_rollouts: int): - """Three-way output-mask label per rollout from the gate-pass c-probe grads. + def _routeA_gate(dots: Float[torch.Tensor, "G"]) -> tuple[torch.Tensor, torch.Tensor, float, float]: + """Three-way output-mask label per rollout from the rolling score buffer. - Per module the per-rollout weight grad of the virtual diagonal (deployed - block [r]) has a band-normalized cosine position. We POOL across modules in - a single (num, den) fraction (T3 fix): a module with a wide band contributes - proportionally more than a noisy near-zero-width one, instead of every module - casting an equal-weight vote. One global label per rollout (matching SGTM's - example-level labels): pos<=0 clean (m=0,d=0); pos>=1 hack (m=1,d=1); else mid - (m=1,d=0, both blocks train). Returns (m, d, f3, w, pos, cosU): f3 in {0,.5,1} for - _zone_stats, w = mean per-rollout grad norm for energy weighting, pos = the raw - per-rollout pooled position (for the AUROC diagnostic), cosU = pooled cos of the - SUMMED-rollout c-grad (the update direction) to v_grad.""" - num = torch.zeros(n_rollouts, device=device); den = 0.0 - w = torch.zeros(n_rollouts, device=device); n_inc = 0 - cosU_sum = 0.0 - for (name, info), cg in zip(wrappers.items(), c_grads, strict=True): - lower, upper = route_band[name] - if upper - lower <= 0: # noisy module: pairs don't separate -> excluded - continue - r_blk = info["r"] - g_b: Float[torch.Tensor, "G r"] = cg.reshape(n_rollouts, -1, 2 * r_blk).sum(1)[:, :r_blk].float() - nrm = g_b.norm(dim=1) - # cos to each of the k routing dirs, then max: aligned with ANY known hack sub-mode. - cos_b = torch.einsum("g r, k r -> g k", g_b, v_grad[name]).max(dim=1).values / nrm.clamp_min(1e-12) - num += cos_b - lower; den += upper - lower - w += nrm; n_inc += 1 - gs = g_b.sum(0) # summed-rollout deployed c-grad = this module's update direction - cosU_sum += (torch.einsum("r, k r -> k", gs, v_grad[name]).max() / gs.norm().clamp_min(1e-12)).item() - if n_inc == 0: - raise RuntimeError("no module has positive band width; pairs separate nowhere") - pos = num / den; w /= n_inc - # ── online-stats gate (#30): band from THIS batch's pooled positions, three zones ── - # The authored absolute band is mis-placed (live pos sits far below the synthetic-hack - # edge -- out/diag/pinning_calib.png) and a fixed quantile FORCES route_quantile out every - # step even when nothing separates. Calibrate the band from the CURRENT batch instead: - # refresh-proof by construction (these rollouts scored against the current v_grad), no - # window or flush to keep stale positions around. mean + k*std self-silences -- only the - # tail genuinely beyond the spread routes, so qmass tracks real separation. pos > - # mean+route_std_mid*std -> mid (both blocks train); pos >= mean+route_std_rout*std -> rout - # (hack, deployed detached); below -> keep (bulk). Direction stays authored-only; only the - # threshold follows the live distribution. - mu_pos, sd_pos = pos.mean().item(), pos.std().item() - t_lo = mu_pos + cfg.route_std_mid * sd_pos # mid onset - t_hi = mu_pos + cfg.route_std_rout * sd_pos # rout onset (>= mid) - logger.debug(f"gate pos: n_inc={n_inc}/{len(wrappers)} n_roll={n_rollouts} " - f"min={pos.min().item():+.2f} max={pos.max().item():+.2f} | " - f"mean={mu_pos:+.2f} std={sd_pos:.2f} t_lo={t_lo:+.2f} t_hi={t_hi:+.2f}") - m = (pos > t_lo).float() # mid + rout -> quarantine trains (keep = bulk below t_lo) - d = (pos >= t_hi).float() # top tail -> hack -> deployed detached - return m, d, 0.5 * m + 0.5 * d, w, pos, cosU_sum / n_inc + The buffer holds pooled ACTS, so every gate call scores the whole window + against the CURRENT v_act (refresh-proof; the only staleness left is act + drift as the adapter trains, small over <= route_buffer rollouts). Scores + are z-normalized by the buffer mean/std, then two-threshold Otsu (winsorized + inside _otsu3) places (t_lo, t_hi): z <= t_lo keep (0,0); t_lo < z < t_hi + absorb (1,0, both blocks train); z >= t_hi rout (1,1, deployed detached). + Warmup: pinned absorb until the buffer holds route_warmup scores -- too few + points to place thresholds, and absorb keeps both blocks learning.""" + if len(act_buf) < cfg.route_warmup: + G_n = dots.shape[0] + return (torch.ones(G_n, device=device), torch.zeros(G_n, device=device), + float("nan"), float("nan")) + S = torch.einsum("nmr,mr->n", torch.stack(tuple(act_buf)), v_act) + mu, sd = S.mean().item(), max(S.std().item(), 1e-12) + t_lo, t_hi = _otsu3(((S - mu) / sd).numpy()) + z = (dots - mu) / sd + m = (z > t_lo).float().to(device) # absorb + rout -> quarantine trains + d = (z >= t_hi).float().to(device) # top zone -> rout -> deployed detached + logger.debug(f"routeA gate: buf={len(act_buf)} mu={mu:+.3e} sd={sd:.3e} " + f"t_lo={t_lo:+.2f} t_hi={t_hi:+.2f} (z) | batch z " + f"min={z.min().item():+.2f} max={z.max().item():+.2f}") + return m, d, t_lo, t_hi # Disable tqdm off-TTY because structured per-step rows already report progress. pbar = tqdm(range(steps), desc=f"train {cfg.arm} {cfg.preset_name}", @@ -634,7 +554,7 @@ def main(cfg: Config) -> int: logger.info(f"teacher-off curriculum: step {step} >= {cfg.teacher_off_step} " f"-> teachers off (pure on-policy on the wider problem set from here)") # mix_ratio>0 is the teacher on/off switch (mix=0 = no-teacher ablation: pool still - # loaded for v_grad extraction, but no demos injected). The COUNT is teacher_n_per_prompt. + # loaded for the partition, but no demos injected). The COUNT is teacher_n_per_prompt. teachers_on = (not teacher_off) and cfg.mix_ratio > 0 \ and bool(covered_problems) and bool(teacher_pool or solve_pool) t0 = time.time() @@ -651,13 +571,15 @@ def main(cfg: Config) -> int: n_zerovar = 0 # groups skipped for zero reward variance (all rollouts same reward). agg_loss = 0.0 diag_tail = None - # routeV gate diagnostics (per-rollout three-way zone shares + retain-trick clipfrac). - step_flagged: list[float] = [] # hack share (mean d over rollouts) per prompt - step_clipfrac: list[float] = [] # PPO clip frac on clean-gated rollouts (retain-trick drift gauge) - step_zkeep: list[float] = []; step_zresid: list[float] = []; step_zrout: list[float] = [] # unit shares per zone - step_zkeepE: list[float] = []; step_zresidE: list[float] = []; step_zroutE: list[float] = [] # energy shares per zone - # AUROC diagnostic: per-rollout pooled pos + its hack-label, accumulated across prompts. - step_auroc_pos: list[float] = []; step_auroc_hack: list[bool] = []; step_cosU: list[float] = [] + # routeA gate diagnostics (per-rollout three-way zone shares + clean-gated clipfrac). + step_clipfrac: list[float] = [] # PPO clip frac on keep-gated rollouts (ratio-drift gauge) + step_zkeep: list[float] = []; step_zresid: list[float] = []; step_zrout: list[float] = [] # unit shares per zone + step_tlo: list[float] = []; step_thi: list[float] = [] # Otsu thresholds (z units) + # AUROC diagnostic on the A>0 contrast: scores + hack-labels of positively- + # reinforced rollouts only (where the advantage alone is blind), students + + # cached teachers. Accumulated across prompts; measurement only, never routes. + step_auroc_score: list[float] = []; step_auroc_hack: list[bool] = [] + step_cos: list[float] = [] # mean per-rollout cos(act, v_act) (dot-vs-cos diagnostic) # Solve-mix discrimination: routed-share (mean d) over hack-teacher vs solve-teacher rollouts. step_route_hackT: list[float] = []; step_route_solveT: list[float] = [] @@ -870,18 +792,33 @@ def main(cfg: Config) -> int: # logπ_old: old-policy logprobs (frozen PPO-ratio target). logits_to_keep # =L_c+1 runs lm_head only on completion-side hidden states; [:, :-1] drops # the last position (predicts beyond `merged`, unused). + # For routeA this forward runs QUARANTINE-ABLATED, matching both the sampling + # policy (gen_students is deploy-mode) and the v_act extraction (quarantine- + # ablated), so the gate score and the vector live on the same observable path. + # The same forward carries the ActCapture hooks: the gate costs no extra pass. completion_ids = merged[:, plen:] L_c = completion_ids.shape[1] + mask = (completion_ids != pad_id).float() _tfb = time.perf_counter() - with torch.no_grad(): - logπ_old = per_token_logps( - model(merged, logits_to_keep=L_c + 1).logits[:, :-1], - completion_ids, - ).detach() + if is_routeA: + with torch.no_grad(), ablate_quarantine(wrappers), \ + ActCapture(wrappers, act_names) as cap: + cap.set_pool(plen, mask) + logπ_old = per_token_logps( + model(merged, logits_to_keep=L_c + 1).logits[:, :-1], + completion_ids, + ).detach() + acts = cap.pooled().cpu() # [G, M, r] fp32 + else: + with torch.no_grad(): + logπ_old = per_token_logps( + model(merged, logits_to_keep=L_c + 1).logits[:, :-1], + completion_ids, + ).detach() - # Pin block masks for the non-gated arms BEFORE the grad-carrying forward - # (arm semantics: train_config.py docstring): none -> (0,0), absorb -> (1,0). - # routeV leaves mask=None so the gate pass sees an unmasked forward. + # Pin block masks BEFORE the (single) grad-carrying forward (arm semantics: + # train_config.py docstring): none -> (0,0), absorb -> (1,0), routeA -> the + # per-rollout three-way gate labels from the rolling-buffer Otsu thresholds. if is_vanilla: _z = torch.zeros(merged.shape[0], device=device) for info in wrappers.values(): @@ -891,13 +828,42 @@ def main(cfg: Config) -> int: _z = torch.zeros(merged.shape[0], device=device) for info in wrappers.values(): info["layer"]._lora2r_mask = (_o, _z) + elif is_routeA: + dots = torch.einsum("gmr,mr->g", acts, v_act) # [G] + # cos = dot / (||act|| ||v||); v rows are unit so ||v|| = sqrt(M). + coss = dots / (acts.flatten(1).norm(dim=1) + * math.sqrt(len(act_names))).clamp_min(1e-12) + step_cos.append(coss.mean().item()) + act_buf.extend(acts.unbind(0)) + m_vec, d_vec, _tl, _th = _routeA_gate(dots) + for info in wrappers.values(): + info["layer"]._lora2r_mask = (m_vec, d_vec) + step_tlo.append(_tl); step_thi.append(_th) + step_zkeep.append((m_vec == 0).float().mean().item()) + step_zresid.append(((m_vec == 1) & (d_vec == 0)).float().mean().item()) + step_zrout.append((d_vec == 1).float().mean().item()) + # AUROC diagnostic on the A>0 contrast: merged order is [students; + # teachers], the same order hack_flags was built in, so dots aligns. + pos_a = (A > 0).cpu().tolist() + step_auroc_score.extend(s for s, p in zip(dots.tolist(), pos_a) if p) + step_auroc_hack.extend(bool(h) for h, p in zip(hack_flags, pos_a) if p) + # Solve-mix discrimination: teachers are the LAST G_t rows of merged; split + # their routed-share (mean d) by source. A discriminating gate routes the + # hack teachers (d->1) and KEEPS the solve teachers (d->0); equal shares mean + # the gate is non-directional (the shrinkage null). Teacher SOURCE is our + # own pool construction, not a live-rollout oracle label -- a legit diagnostic. + if teacher_is_solve: + is_solve_t = torch.tensor(teacher_is_solve, device=d_vec.device, dtype=torch.bool) + d_teach = d_vec[-len(teacher_is_solve):] + if (~is_solve_t).any(): + step_route_hackT.append(d_teach[~is_solve_t].mean().item()) + if is_solve_t.any(): + step_route_solveT.append(d_teach[is_solve_t].mean().item()) logπ = per_token_logps( model(merged, logits_to_keep=L_c + 1).logits[:, :-1], completion_ids, ) - - mask = (merged[:, plen:] != pad_id).float() # Per-rollout mean per-token logπ_old (student's logp on its own tokens). # Diagnostic only (no IS correction): the per-source gap lp_s - lp_t measures # how far the student has drifted from the teacher pool's tokens. @@ -914,56 +880,20 @@ def main(cfg: Config) -> int: ptl = (Lp_ * mask).sum(1) / mask.sum(1).clamp_min(1) return ptl.sum() / (group * prompts_per_step) - # Three-way output masking; gradients accumulate on A/B. - # Gradient-space labels exist only AFTER a backward (labels: before forward; - # activations: before backward; grads: after), so routeV pays a second masked - # forward+backward. none/absorb were pinned before the logπ forward and need - # only this one pass. + # One masked forward+backward for EVERY arm; rollouts route to BLOCKS via + # the output masks pinned above (nothing is subtracted from any gradient + # vector; v_act is a classifier only). Gradients accumulate on A/B. loss = _grpo_loss(Lp) - if is_routeV: - # PASS 1 (gate): grads w.r.t. the c-probes ONLY. autograd.grad leaves - # A.grad/B.grad untouched, so nothing to zero between passes. - gates = [info["layer"]._lora2r_gate for info in wrappers.values()] - c_grads = torch.autograd.grad(loss, gates) - m_vec, d_vec, f3, w3, pos_vec, cosU = _lora2r_gate_labels(c_grads, merged.shape[0]) - step_flagged.append(d_vec.mean().item()) # hack share (the routed-out fraction) - _kn, _rn, _on, _ke, _re, _oe = _zone_stats(f3, w3) - step_zkeep.append(_kn); step_zresid.append(_rn); step_zrout.append(_on) - step_zkeepE.append(_ke); step_zresidE.append(_re); step_zroutE.append(_oe) - # AUROC diagnostic: pos as a hack-detector vs the hack-label (student - # exploited + teacher cached). merged order is [students; teachers], the - # same order hack_flags was built in, so pos_vec aligns with hack_flags. - step_auroc_pos.extend(pos_vec.detach().cpu().tolist()) - step_auroc_hack.extend(bool(h) for h in hack_flags) - step_cosU.append(cosU) - # Solve-mix discrimination: teachers are the LAST G_t rows of merged; split - # their routed-share (mean d) by source. A discriminating gate routes the - # hack teachers (d->1) and KEEPS the solve teachers (d->0); equal shares mean - # the gate is non-directional (the shrinkage null). Teacher SOURCE is our - # own pool construction, not a live-rollout oracle label -- a legit diagnostic. - if teacher_is_solve: - is_solve_t = torch.tensor(teacher_is_solve, device=d_vec.device, dtype=torch.bool) - d_teach = d_vec[-len(teacher_is_solve):] - if (~is_solve_t).any(): - step_route_hackT.append(d_teach[~is_solve_t].mean().item()) - if is_solve_t.any(): - step_route_solveT.append(d_teach[is_solve_t].mean().item()) - # PASS 2 (masked): rollouts route to BLOCKS via masked sums; nothing is - # subtracted from any gradient vector (v_grad = classifier only). - for info in wrappers.values(): - info["layer"]._lora2r_mask = (m_vec, d_vec) - logπ2 = per_token_logps( - model(merged, logits_to_keep=L_c + 1).logits[:, :-1], completion_ids) - ρ2 = torch.exp(logπ2 - logπ_old) - loss = _grpo_loss(-torch.min(ρ2 * A_tok, - torch.clamp(ρ2, 1 - cfg.clip, 1 + cfg.clip) * A_tok)) - # Retain-trick wrinkle: clean rollouts were SAMPLED quarantine-on but TRAIN - # quarantine-off; the PPO ratio absorbs the gap, clip bounds it. - clean = m_vec == 0 - if clean.any(): - clipped = ((ρ2.detach() - 1).abs() > cfg.clip).float() + if is_routeA: + # Keep-gated rollouts train quarantine-off, the exact state generation + # and logπ_old used, so their ratio sits ~1. Absorb/rout rollouts see + # the quarantine delta in the forward only -> ratio drift, bounded by + # the clip; clipfrac on those rollouts is the drift gauge. + qon = m_vec == 1 + if qon.any(): + clipped = ((ρ.detach() - 1).abs() > cfg.clip).float() step_clipfrac.append( - ((clipped * mask)[clean].sum() / mask[clean].sum().clamp_min(1)).item()) + ((clipped * mask)[qon].sum() / mask[qon].sum().clamp_min(1)).item()) loss.backward() # A/B grads accumulate across prompts (opt.zero_grad clears per step) for info in wrappers.values(): info["layer"]._lora2r_mask = None @@ -990,38 +920,26 @@ def main(cfg: Config) -> int: opt.step() sched.step() - # ── v_grad refresh ── + # ── v_act refresh ── # Re-extract the routing direction against the CURRENT model so it tracks where # hacks separate now, not at step 0. Without this the frozen direction goes stale. # Same MASK_PAIRS (the authored pairs, no oracle); quarantine ablated so the hack - # signal flows back through the observable path, matching the build-time extract. + # signal is read on the deployed observable path, matching the build-time extract + # and the gate forward. Forward-only, so the refresh is cheap. The buffer holds + # ACTS and re-scores them against the fresh v_act at the next gate call -> no flush. refr = "-" - do_refresh = cfg.vhack_refresh_every > 0 and (step + 1) % cfg.vhack_refresh_every == 0 - if do_refresh and is_routeV and cfg.routeV_random_v_seed is not None: - do_refresh = False # keep the one fixed Haar draw; re-extracting would replace it - if do_refresh and is_routeV: + do_refresh = (is_routeA and cfg.vhack_refresh_every > 0 + and (step + 1) % cfg.vhack_refresh_every == 0 + and cfg.routeA_random_v_seed is None) # placebo keeps its one Haar draw + if do_refresh: _was_training = model.training model.eval() - opt.zero_grad(set_to_none=True) - logger.disable("vgrout.extract_vhack_grad") - logger.disable("__main__") - try: - with ablate_quarantine(wrappers): - from .extract_vhack_grad import extract_v_hack - _, _, raw_grads, _ = extract_v_hack( - model, tok, wrappers, MASK_PAIRS, - top_k=1, tau_axis=0.0, n_heldout=2, device=device, - ) - # update in place so the gate closure sees the fresh dirs (same k as init). - v_grad.update(_build_v_grad(raw_grads, wrappers, cfg.v_grad_k, device)) - route_band = route_band_edges(raw_grads, v_grad, device) # rebuild band on fresh v_grad - finally: - logger.enable("vgrout.extract_vhack_grad") - logger.enable("__main__") - opt.zero_grad(set_to_none=True) # extract leaves .grad populated + with ablate_quarantine(wrappers): + v_act, _ = extract_v_act(model, tok, wrappers, MASK_PAIRS, device, + tstat=cfg.vact_tstat) if _was_training: model.train() - refr = "rfr" # gate calibrates from the live batch each step -> no window to flush + refr = "rfr" # ── periodic held-out eval (deploy = quarantine ablated) ── hack_deployed = solve_deployed = float("nan") @@ -1114,13 +1032,14 @@ def main(cfg: Config) -> int: f"step {step} TIMING gen={t_gen:.0f}s fwd_bwd={t_fb:.0f}s " f"reward={t_rew:.0f}s other={_tstep - t_gen - t_fb - t_rew:.0f}s total={_tstep:.0f}s") if step_clipfrac: - logger.debug(f"routeV clean-gated clipfrac={sum(step_clipfrac)/len(step_clipfrac):.3f} " - f"(SHOULD: <~0.2; higher = retain-trick ratio drift binding)") + logger.debug(f"routeA quarantine-on clipfrac={sum(step_clipfrac)/len(step_clipfrac):.3f} " + f"(SHOULD: <~0.2; higher = quarantine forward delta drifting far " + f"from the ablated old policy)") if step_route_hackT or step_route_solveT: _rh = sum(step_route_hackT) / len(step_route_hackT) if step_route_hackT else float("nan") _rs = sum(step_route_solveT) / len(step_route_solveT) if step_route_solveT else float("nan") route_hackT_run.append(_rh); route_solveT_run.append(_rs) - logger.debug(f"routeV solve-mix discrimination: hack-teacher routed={_rh:.2f} vs " + logger.debug(f"routeA solve-mix discrimination: hack-teacher routed={_rh:.2f} vs " f"solve-teacher routed={_rs:.2f} (SHOULD: hack >> solve -> gate " f"discriminates correct-solution from reward-hacking updates; ~equal -> non-directional/shrinkage)") if diag_tail is not None: @@ -1148,15 +1067,16 @@ def main(cfg: Config) -> int: "lp_t": lp_t_mean if n_t else None, "loss": agg_loss, "gn": gn, - "auroc": _auroc(step_auroc_pos, step_auroc_hack), - "cosU": (sum(step_cosU) / len(step_cosU)) if step_cosU else float("nan"), + # auroc is the A>0 contrast (hack vs non-hack among positively-reinforced + # rollouts) -- the contrast where the reward alone is blind. + "auroc": _auroc(step_auroc_score, step_auroc_hack), + "cos": (sum(step_cos) / len(step_cos)) if step_cos else float("nan"), "qmass": q_egy, "keep": (sum(step_zkeep) / len(step_zkeep)) if step_zkeep else float("nan"), "resid": (sum(step_zresid) / len(step_zresid)) if step_zresid else float("nan"), "rout": (sum(step_zrout) / len(step_zrout)) if step_zrout else float("nan"), - "keepE": (sum(step_zkeepE) / len(step_zkeepE)) if step_zkeepE else float("nan"), - "residE": (sum(step_zresidE) / len(step_zresidE)) if step_zresidE else float("nan"), - "routE": (sum(step_zroutE) / len(step_zroutE)) if step_zroutE else float("nan"), + "tlo": (sum(step_tlo) / len(step_tlo)) if step_tlo else float("nan"), + "thi": (sum(step_thi) / len(step_thi)) if step_thi else float("nan"), "lr": sched.get_last_lr()[0], "refr": refr, # Deploy-eval (quarantine ablated); NaN except on eval steps. @@ -1236,20 +1156,17 @@ def main(cfg: Config) -> int: solve_rate_s = gt_s_total / max(1, n_s_total) hack_rate_t = hack_t_total / max(1, n_t_total) - # routeV/absorb must move the quarantine; none must leave it exactly zero. The - # quarantine LEARNED delta = (A[r:], B[:,r:]) minus the frozen init. + # routeA/absorb must move the quarantine; none must leave it exactly zero. The + # quarantine LEARNED delta = (A[r:], B[:,r:]) minus the frozen init. The routeA + # warmup pins absorb, so even a placebo run trains the quarantine. dsh_norm = float(sum( (info["A"].data[info["r"]:] - info["A0"][info["r"]:]).float().pow(2).sum().item() + (info["B"].data[:, info["r"]:] - info["B0"][:, info["r"]:]).float().pow(2).sum().item() for info in wrappers.values()) ** 0.5) logger.info(f"||quarantine learned delta|| = {dsh_norm:.4f} " - f"(SHOULD: >0 for routeV/absorb, ==0 for none; ELSE routing broke)") - if has_quarantine and cfg.routeV_random_v_seed is None: + f"(SHOULD: >0 for routeA/absorb, ==0 for none; ELSE routing broke)") + if has_quarantine: assert dsh_norm > 0.0, f"{cfg.intervention}: quarantine never moved -> nothing trained it" - elif cfg.routeV_random_v_seed is not None and dsh_norm == 0.0: - # A Haar control may validly route nothing because no rollout clears its band. - logger.warning("routeV Haar control: ||quarantine delta||==0 -> the random direction routed " - "NOTHING. This is a real result (favours: alignment needed), not a failure.") # Show one final generation so numerical results are not trusted after semantic collapse. if last_gen_sample is not None: diff --git a/src/vgrout/train_config.py b/src/vgrout/train_config.py index 1182ad8..2920bf8 100644 --- a/src/vgrout/train_config.py +++ b/src/vgrout/train_config.py @@ -5,7 +5,7 @@ masking; see src/vgrout/lora2r.py) and three arms: none gate pinned clean (0,0): quarantine never trains -- the capacity- and structure-matched vanilla control. - routeV per-rollout three-way gate from the c-probe gradient vs v_grad. + routeA per-rollout three-way gate from the pooled bottleneck activation vs v_act. absorb gate pinned mid (1,0): both blocks train on everything, no gate -- tests ungated both-block training. """ @@ -20,7 +20,7 @@ from .rewards import EnvMode @dataclass(kw_only=True) class Config: - intervention: Literal["none", "routeV", "absorb"] = "routeV" + intervention: Literal["none", "routeA", "absorb"] = "routeA" lora_r: int = 32 lora_init_seed: int = 0 @@ -46,30 +46,27 @@ class Config: unbiased: bool = True vhack_refresh_every: int = 5 - # The 8 original behavior_* pairs only: per-pairset diag (out/diag/pinning_pairset_auroc.png) - # ranks this subset's v_grad best at separating live hacks (AUROC 0.69, d=+0.85), well above - # the full all-in-one. The `@behavior` TAG would re-add the anti-aligned opportunity-aware - # pairs (d=-0.03) and dilute, so select by heading prefix, not tag. Wave-2 arms (untested): - # `/behavior2` = 15 new mechanisms, `/behavior` = 23-pair union. + # The 8 original behavior_* pairs only: the best or tied vector on all three diag + # windows (RESEARCH_JOURNAL 2026-06-11 (d)). The `@behavior` TAG would re-add the + # anti-aligned opportunity-aware pairs and dilute, so select by heading prefix, not + # tag. Wave-2 arms (untested): `/behavior2` = 15 new mechanisms, `/behavior` = union. vhack_pairs_path: Path = Path("data/pairs/hack_pairs.md#all-in-one/behavior_") - # Routing directions per module. k=1 (headline): the mean(hack-clean) "mean-mass" - # axis. k>1: top-k oriented SVD dirs of the paired diff; gate scores max_i cos(g,v_i) - # (alignment to ANY known hack sub-mode) -- catches multi-modal hack signal one mean - # washes out. k=1 stays mean-diff (not SVD-top-1) so "mean-mass vs top-k" is a clean A/B. - v_grad_k: int = 1 - # Online-stats gate (#30): route by the live MEAN + k*STD of the pooled cosine-to-v_grad, - # not a fixed quantile. The authored absolute band is mis-placed (live pos sits far below - # the synthetic-hack edge; even synthetic SOLVE is more hack-aligned than on-policy hack -- - # see out/diag/pinning_calib.png), and a fixed quantile FORCES route_quantile of the batch - # out every step even when nothing separates. mean+k*std self-silences: it only routes the - # tail that genuinely exceeds the spread. pos > mean + route_std_mid*std -> mid (both blocks train); - # pos >= mean + route_std_rout*std -> rout (hack, deployed detached); below -> keep (bulk). - # The band is calibrated from the CURRENT batch each step (no window, no flush): refresh-proof - # by construction. Direction stays authored-only; only the threshold follows the live dist. - route_std_mid: float = 2.0 - route_std_rout: float = 3.0 + # t-stat extraction (v = mean/SE per coordinate, clamp |t|<=3): null at the current + # 8 pairs (journal (e)); revisit when the authored-pair set grows. + vact_tstat: bool = False + # routeA gate thresholds come from a rolling buffer of the last route_buffer live + # scores: z-normalize by buffer mean/std, two-threshold Otsu (winsorized 1/99% -- + # journal (d): without it one outlier buys a whole zone) -> keep | absorb | rout. + # Until the buffer holds route_warmup scores the gate pins absorb (both blocks + # train): too few points to place thresholds. The buffer stores pooled ACTS and + # re-scores them against the current v_act, so a refresh needs no flush. No + # bimodality guard: no shape statistic separates the hack mixture from hack-free + # scores on the cached windows (Otsu tail means sit ~2.4 sd apart even on a + # Gaussian), and a false rout only discards one update from deployment. + route_buffer: int = 256 + route_warmup: int = 128 # Haar-random direction control (placebo): same routing machinery, no pair signal. - routeV_random_v_seed: int | None = None + routeA_random_v_seed: int | None = None rollout_ablate_frac: float = 0.0 env_mode: EnvMode = "run_tests" @@ -112,8 +109,9 @@ class Config: @property def arm(self) -> str: # _lora2r suffix kept so these runs never conflate with the retired - # PiSSA-substrate runs of the same intervention (rename-on-logic-change). - return {"none": "vanilla_lora2r", "routeV": "routingV_lora2r", + # PiSSA-substrate runs of the same intervention (rename-on-logic-change; + # routeA likewise never conflates with the retired grad-gate routeV runs). + return {"none": "vanilla_lora2r", "routeA": "routeA_lora2r", "absorb": "absorb_lora2r"}[self.intervention] @@ -126,12 +124,11 @@ class SmokeConfig(Config): max_new: int = 32 n_problems: int = 100 prompts_per_step: int = 1 - # Random tiny data never separates, so the self-silencing band (mean+2/3*std) would route - # nothing and the quarantine would never train -> the routing-pathway smoke assert fails. - # Force routing by lowering the band so the smoke exercises mid+rout (correctness, not the - # real threshold). mid below mean -> most train quarantine; rout at mean -> top ~half detach. - route_std_mid: float = -1.0 - route_std_rout: float = 0.0 + # Smoke produces 4 scores/step over 30 steps; the real 256/128 buffer would keep the + # gate in warmup forever. Shrink so the smoke exercises warmup AND the Otsu gate + # (keep/absorb/rout + deployed detach) within a few steps. + route_buffer: int = 32 + route_warmup: int = 8 @dataclass(kw_only=True)