feat(#41): routeA activation gate replaces routeV grad gate

Gate now scores each rollout by dot(pooled bottleneck act, v_act) captured on
the no-grad logpi_old forward (quarantine-ablated, matching the sampling
policy); masks are pinned BEFORE the single grad-carrying forward, so the
grad-gate's pass-1 backward is gone. Thresholds: rolling 256-act buffer,
z-normalized, two-threshold Otsu (winsorized 1/99); warmup pins absorb until
128 scores. Buffer stores pooled acts and re-scores against the current v_act,
so the forward-only refresh (every 5 steps) needs no flush. No bimodality
guard: calibration showed Otsu tail separation ~2.4-2.8 buffer-sd on every
condition including pure Gaussians, so no shape statistic discriminates.

Deleted with the arm wiring (rename-on-logic-change, so routeA runs are never
conflated with routeV runs): extract_vhack_grad.py, _build_v_grad, route_band_edges,
_pair_cos, the pass-1 autograd.grad block, grad_probe training wiring,
v_grad_k/route_std_*/routeV_random_v_seed config, smoke-topk recipe.
c-probe stays in lora2r.py for scripts/diag_pinning.py only.

verify_science_invariants: all-in-one count 27 -> 42 (stale since c33b810
added the wave-2 behavior2 pairs) + assert the 8-pair routeA training subset.

Smoke: routeA/vanilla/absorb/solvemix all pass (gate exercises warmup, Otsu
zones, refresh, deploy ablation) -- /tmp/claude-1000/smoke_routeA.log.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
wassname
2026-06-11 12:38:19 +00:00
parent 5a340e5c3e
commit adca442253
14 changed files with 337 additions and 729 deletions
+27 -24
@@ -36,28 +36,30 @@ outputs (`m` = quarantine on/off, `d` = deployed detach):
To get the hack direction we pair examples by hand: for each problem, one
correct completion and one completion that exploits the evaluation procedure.
For each pair we compute the *exact GRPO gradient* that would result if the hack
rollout had advantage +1 and the clean rollout had advantage -1
(`-grad logp(hack) + grad logp(clean)`), read in the per-module c-probe space (a
virtual diagonal scale between `A` and `B`). The mean hack-minus-clean direction,
normalized per module, is `v_grad`. (Mechanically identical to a twin-NLL
extraction, since GRPO with adv=+/-1 reduces to the NLL difference; the GRPO
framing is the one we mean.) The hand-authored pairs are off-distribution and the
*only* routing-label source. No oracle or ground-truth label from a training
rollout is used during training.
For each pair we run a forward pass and read the bottleneck activation `A@x`
(the rank-2r input projection of each wrapped Linear), masked-mean-pooled over
completion tokens. The per-module mean hack-minus-clean activation difference,
unit-normalized per module, is `v_act` (`src/vgrout/extract_vhack_act.py`).
Extraction is forward-only: no backward pass, no labels. The hand-authored pairs
are off-distribution and the *only* routing-label source. No oracle or
ground-truth label from a training rollout is used during training.
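
The extraction described above can be sketched in a few lines. This is a minimal illustration, not the code in `src/vgrout/extract_vhack_act.py`: it assumes the per-token bottleneck activations are already captured as plain lists, and the module name `q_proj` and both helper names are hypothetical.

```python
import math

def masked_mean_pool(acts, mask):
    """Mean of per-token activation vectors over positions where mask == 1.

    acts: list of T vectors (each a list of floats), mask: list of T ints.
    """
    dim = len(acts[0])
    n = max(sum(mask), 1)  # guard against an empty completion
    return [sum(a[k] * m for a, m in zip(acts, mask)) / n for k in range(dim)]

def extract_v_act(pairs):
    """pairs: list of (hack_pooled, clean_pooled), each a dict module -> vector.

    Returns module -> unit-norm mean hack-minus-clean difference.
    """
    v_act = {}
    for name in pairs[0][0]:
        dim = len(pairs[0][0][name])
        mean_diff = [
            sum(h[name][k] - c[name][k] for h, c in pairs) / len(pairs)
            for k in range(dim)
        ]
        norm = math.sqrt(sum(x * x for x in mean_diff))
        # Assumption: a zero difference is left as zeros rather than raising.
        v_act[name] = [x / norm for x in mean_diff] if norm > 0 else mean_diff
    return v_act

# One pair, one module; the first token is a prompt token and is masked out.
hack = {"q_proj": masked_mean_pool([[9.0, 9.0], [3.0, 0.0], [5.0, 0.0]], [0, 1, 1])}
clean = {"q_proj": masked_mean_pool([[9.0, 9.0], [1.0, 0.0], [1.0, 0.0]], [0, 1, 1])}
v_act = extract_v_act([(hack, clean)])
```

No backward pass or label enters the computation, which is what makes the forward-only refresh cheap.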
At training time routeV runs **two passes** per prompt. Pass 1 is an unmasked
forward+backward whose gradient w.r.t. the c-probes gives, per rollout, a
band-normalized cosine to `v_grad` pooled across modules -> a `{clean, mid, hack}`
label (band edges = authored-pair p75 cosines). Pass 2 sets the block masks from
those labels and does the real masked forward+backward; the optimizer steps on
what survives. A rollout with high alignment to the reward-hacking direction updates
the quarantine block while its deployed branch is detached. We re-extract `v_grad`
every N steps so it tracks the current separation rather than only the initial model.
At training time routeA scores each rollout on the no-grad `logp_old` forward it
already needs: an activation-capture hook pools the same bottleneck activations
over completion tokens, and the score is the pooled dot product with `v_act`.
Thresholds come from a rolling buffer of recent scores, z-normalized and split by
two-threshold Otsu into `{keep, absorb, rout}`; until the buffer fills the gate
pins absorb. The block masks are set from those labels *before* the single
masked forward+backward, so there is no second gradient pass. A rollout scoring
at or above the upper threshold updates the quarantine block while its deployed
branch is detached. We re-extract `v_act` every N steps (forward-only,
quarantine-ablated) so it tracks the current model; the buffer stores pooled
activations and re-scores them against the current `v_act`, so a refresh needs
no flush.
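
The gate logic above can be sketched as follows. This is an illustrative stand-in under stated assumptions, not the code in `vgrout.train`: the class name `ActGate` is hypothetical, activations are treated as one flat vector rather than per-module tensors, the new rollout is assumed to enter the buffer before thresholds are computed, and the zero-variance fallback to `absorb` is an assumption. The buffer size (256) and warmup count (128) are the values quoted in the commit message.

```python
from collections import deque
import statistics

BUFFER_SIZE, WARMUP = 256, 128

def _quantile(sorted_x, q):
    """Linear-interpolation quantile of an already sorted list."""
    pos = q * (len(sorted_x) - 1)
    lo = int(pos)
    hi = min(lo + 1, len(sorted_x) - 1)
    return sorted_x[lo] + (sorted_x[hi] - sorted_x[lo]) * (pos - lo)

def otsu3(x):
    """Two cuts maximizing 3-class between-class variance, winsorized at 1/99%."""
    s = sorted(x)
    lo_q, hi_q = _quantile(s, 0.01), _quantile(s, 0.99)
    s = [min(max(v, lo_q), hi_q) for v in s]  # clipping preserves sort order
    n = len(s)
    c = [0.0]
    for v in s:
        c.append(c[-1] + v)  # prefix sums
    best, best_ij = float("-inf"), (1, 2)
    for i in range(1, n - 1):
        for j in range(i + 1, n):
            obj = c[i] ** 2 / i + (c[j] - c[i]) ** 2 / (j - i) + (c[n] - c[j]) ** 2 / (n - j)
            if obj > best:
                best, best_ij = obj, (i, j)
    i, j = best_ij
    return (s[i - 1] + s[i]) / 2, (s[j - 1] + s[j]) / 2

class ActGate:
    def __init__(self):
        # Store pooled activations, not scores, so a refreshed v_act can
        # re-score the whole window without a flush.
        self.buffer = deque(maxlen=BUFFER_SIZE)

    def label(self, pooled, v_act):
        self.buffer.append(pooled)
        if len(self.buffer) < WARMUP:
            return "absorb"  # warmup: pin absorb
        scores = [sum(a * b for a, b in zip(p, v_act)) for p in self.buffer]
        mu, sd = statistics.fmean(scores), statistics.pstdev(scores)
        if sd == 0:
            return "absorb"  # assumption: degenerate window falls back to absorb
        z = [(s - mu) / sd for s in scores]
        lo, hi = otsu3(z)
        z_new = z[-1]
        return "rout" if z_new >= hi else ("keep" if z_new < lo else "absorb")
```

Because every label is decided from a no-grad score, the masks are known before the only grad-carrying forward runs.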
Whether the *direction* (not just the *act* of routing) drives suppression is the
open question -- the placebo control (Haar-random `v_grad`, same routing
machinery) must NOT match real `v_grad`. We watch `qmass` (the share of update
open question -- the placebo control (Haar-random `v_act`, same routing
machinery) must NOT match real `v_act`. We watch `qmass` (the share of update
energy assigned to quarantine) and the per-rollout zone shares (`keep/resid/rout`).
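
For readers unfamiliar with the placebo: a Haar-random direction is one drawn uniformly from the unit sphere, which can be sampled by normalizing an i.i.d. standard Gaussian vector. The sketch below shows that construction only; the function names and the per-module dict layout are hypothetical, and the repository's placebo path may differ in detail.

```python
import math
import random

def haar_unit_vector(dim, rng):
    """Uniform direction on the unit sphere: normalize a standard Gaussian draw."""
    g = [rng.gauss(0.0, 1.0) for _ in range(dim)]
    norm = math.sqrt(sum(x * x for x in g))
    return [x / norm for x in g]

def placebo_v_act(module_dims, seed):
    """module_dims: dict module name -> bottleneck width; one direction per module."""
    rng = random.Random(seed)  # seeded so the placebo is reproducible
    return {name: haar_unit_vector(d, rng) for name, d in module_dims.items()}

placebo = placebo_v_act({"q_proj": 8}, seed=157)
```

The placebo keeps the routing machinery identical and changes only the direction, so any gap between real and placebo runs is attributable to the direction.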
## What we compare
@@ -74,8 +76,9 @@ Three arms, identical model/adapter/teacher pool, differing only in the gate
- **none** -- gate pinned clean `(0,0)`: the quarantine never trains. The
capacity- and structure-matched vanilla control (same adapter, no shrinkage
confound). The emergence reference.
- **routeV** -- the method: per-rollout three-way gate from the c-probe gradient
vs `v_grad`. `--routeV-random-v-seed` swaps in a Haar-random direction (placebo).
- **routeA** -- the method: per-rollout three-way gate from the pooled bottleneck
activation vs `v_act`. `--routeA-random-v-seed` swaps in a Haar-random direction
(placebo).
- **absorb** -- gate pinned mid `(1,0)`: both blocks train on every rollout. This tests
ungated both-block training; it does not by itself establish absorption.
@@ -88,10 +91,10 @@ ablation does not change the model.
```bash
uv sync
just smoke # tiny-random model, routeV pathway + all verify gates, ~1-2 min
just smoke-all # vanilla + routeV + absorb back to back
just smoke # tiny-random model, routeA pathway + all verify gates, ~1-2 min
just smoke-all # vanilla + routeA + absorb back to back
just download-model # warm Qwen3-4B cache
just queue-decision # queue the 4-arm decision run (routeV real / placebo / vanilla / absorb)
just queue-decision # queue the 4-arm decision run (routeA real / placebo / vanilla / absorb)
```
See [RESEARCH_JOURNAL.md](RESEARCH_JOURNAL.md) for session-by-session findings,
+24 -31
@@ -1,7 +1,7 @@
set shell := ["bash", "-cu"]
# vGROUT: rank-2r LoRA gradient routing vs reward-hacking. One adapter (lora2r),
# three arms (intervention none|routeV|absorb). See AGENTS.md / README.md.
# three arms (intervention none|routeA|absorb). See AGENTS.md / README.md.
MODEL := "Qwen/Qwen3-4B"
TINY_MODEL := "llamafactory/tiny-random-qwen3" # qwen3 arch, ~6M params, smoke only
TRAIN := "uv run python -m vgrout.train" # real LeetCode GRPO entry point
@@ -24,8 +24,10 @@ default:
# Real pipeline on tiny inputs; verify_*.py assert invariants (no tests/ dir).
# ─────────────────────────────────────────────────────────────────────────────
# Default smoke = routeV (full pipeline: extract v_grad -> two-pass gate -> deploy
# ablation). Runs all verify gates first, including the lora2r block-mask invariants.
# Default smoke = routeA (full pipeline: extract v_act -> act gate on the logpi_old
# forward -> Otsu pinning -> deploy ablation). Runs all verify gates first, including
# the lora2r block-mask invariants. (scripts/verify_v_act.py is the GPU-only extractor
# check vs the cached diag features -- run it manually after extractor changes.)
smoke *ARGS:
uv run python scripts/verify_rewards.py # grader: 3 env_modes x clean/hack
uv run python scripts/verify_eval_gap.py # eval: train/test token gap, 4 modes
@@ -33,18 +35,18 @@ smoke *ARGS:
uv run python scripts/verify_science_invariants.py # pair provenance + untouched test
uv run python scripts/verify_rotation.py # rotating-unhackable hint-free flip
uv run python scripts/verify_lora2r_routing.py # block masks + ablation + c-probe
just smoke-routeV {{ ARGS }}
just smoke-routeA {{ ARGS }}
# none: gate pinned clean (0,0) -> quarantine never trains (capacity/structure-matched vanilla).
smoke-vanilla *ARGS:
BEARTYPE=1 {{ TRAIN }} smoke --intervention=none \
--teacher-pool-dir=out/pools/teacher_pool --mix-ratio=0.5 {{ ARGS }}
# routeV: extract v_grad from authored pairs, splice the per-rollout c-probe gate,
# PASS 1 (unmasked) labels rollouts {clean,mid,hack} via the width-pooled band cosine,
# PASS 2 (masked) trains the blocks; deploy ablation resets the quarantine to init.
smoke-routeV *ARGS:
BEARTYPE=1 {{ TRAIN }} smoke --intervention=routeV \
# routeA: extract v_act from authored pairs (forward-only), capture pooled acts on the
# no-grad logpi_old forward, label rollouts {keep,absorb,rout} via rolling-buffer Otsu
# thresholds, ONE masked forward+backward; deploy ablation resets the quarantine to init.
smoke-routeA *ARGS:
BEARTYPE=1 {{ TRAIN }} smoke --intervention=routeA \
--teacher-pool-dir=out/pools/teacher_pool --mix-ratio=0.5 \
--eval-ablate-every=10 --eval-n-prompts=2 {{ ARGS }}
@@ -62,21 +64,13 @@ smoke-unhackable *ARGS:
--teacher-pool-dir=out/pools/teacher_pool --mix-ratio=0.5 \
--eval-n-prompts=2 {{ ARGS }}
# routeV with a top-k routing subspace (max_i cos(g,v_i) over k SVD dirs) instead of
# the single mean-mass axis. UAT: log shows "top-3 SVD subspace, gate=max_i cos" and the
# band/gate still route (rout>0). k=1 (default) is the mean-diff headline.
smoke-topk *ARGS:
BEARTYPE=1 {{ TRAIN }} smoke --intervention=routeV --v-grad-k=3 \
--teacher-pool-dir=out/pools/teacher_pool --mix-ratio=0.5 \
--eval-ablate-every=10 --eval-n-prompts=2 {{ ARGS }}
# routeV + symmetric SOLVE-teacher pool: the G_t teacher slots split 50/50 solve/hack,
# routeA + symmetric SOLVE-teacher pool: the G_t teacher slots split 50/50 solve/hack,
# and the run logs the routed-share discrimination (UAT: a line "solve-mix gate
# discrimination: hack-teacher routed-share=X vs solve-teacher routed-share=Y"). Smoke
# points solve at the same tiny pool just to exercise the split+diagnostic path; real
# runs use out/pools/teacher_pool_solve (correct-solution demos) vs the hack pool.
smoke-solvemix *ARGS:
BEARTYPE=1 {{ TRAIN }} smoke --intervention=routeV \
BEARTYPE=1 {{ TRAIN }} smoke --intervention=routeA \
--teacher-pool-dir=out/pools/teacher_pool --solve-pool-dir=out/pools/teacher_pool \
--mix-ratio=0.5 --solve-mix-frac=0.5 \
--eval-ablate-every=10 --eval-n-prompts=2 {{ ARGS }}
@@ -84,7 +78,7 @@ smoke-solvemix *ARGS:
# All three arms back to back (the full-coverage gate).
smoke-all:
just smoke-vanilla
just smoke-routeV
just smoke-routeA
just smoke-absorb
# ─────────────────────────────────────────────────────────────────────────────
@@ -92,22 +86,21 @@ smoke-all:
# pool, 50% unhackable, authored pairs). Every job carries a why:/resolve: label.
# ─────────────────────────────────────────────────────────────────────────────
# Headline 5-arm lora2r decision run, ONLINE-STATS gate + teacher forcing ({{ TEACH }}).
# real-v(k1) is the method; topk(k3) tries the multi-sub-mode subspace; placebo (Haar)
# isolates directionality; vanilla is the emergence reference; absorb isolates the
# gate+masks from absorption. Priority descending so they run in listed order.
# Headline 4-arm lora2r decision run, routeA ACT gate + teacher forcing ({{ TEACH }}).
# real-v is the method (v_act from authored pairs, Otsu rolling-buffer thresholds);
# placebo (Haar) isolates directionality; vanilla is the emergence reference; absorb
# isolates the gate+masks from absorption. Priority descending so they run in listed order.
# --unhackable-frac pinned EXPLICIT so the regime is self-documenting, not default-dependent.
# Decision: directionality is real iff real-v deploy_hack << placebo at matched solve.
# Watch the streamed `auroc` col: ~0.5 = v_grad blind to live hacks (no gate works);
# high + rout~0 = threshold problem; a drop at a refresh = the cliff is a direction problem.
# Watch the streamed `auroc` col (A>0 contrast): ~0.5 = v_act blind to live hacks (no gate
# works); high + rout~0 = threshold problem; a drop at a refresh = a direction problem.
# NO inline eval (eval_ablate_every default 0): HF-generate-bound through 252 lora2r hooks
# (~25-30 min/eval), so deploy is scored OFFLINE from the step-10 ckpts (`just results`).
queue-decision seed='43':
pueue add -w "$PWD" -o 62 -l "why: P1 lora2r routeV REAL-v k1 online-stats + teacher-forcing s{{seed}} (25% unhackable); resolve: deploy_hack << placebo at matched solve -> directionality real" -- {{ TRAIN }} fast --intervention=routeV --unhackable-frac=0.25 {{ TEACH }} --seed={{seed}} --out-tag=_l2r_routeV_real_s{{seed}}
pueue add -w "$PWD" -o 60 -l "why: P2 lora2r routeV TOPK k3 online-stats + teacher-forcing s{{seed}} (25% unhackable); resolve: topk deploy_hack <= real-k1 -> sub-mode subspace catches hacks the mean washes out" -- {{ TRAIN }} fast --intervention=routeV --v-grad-k=3 --unhackable-frac=0.25 {{ TEACH }} --seed={{seed}} --out-tag=_l2r_routeV_topk_s{{seed}}
pueue add -w "$PWD" -o 58 -l "why: P3 lora2r routeV PLACEBO-v (Haar 157) + teacher-forcing s{{seed}} (25% unhackable); resolve: deploy_hack ~ vanilla -> real-v suppression is directional, not absorption/shrinkage" -- {{ TRAIN }} fast --intervention=routeV --routeV-random-v-seed=157 --unhackable-frac=0.25 {{ TEACH }} --seed={{seed}} --out-tag=_l2r_routeV_placebo_s{{seed}}
pueue add -w "$PWD" -o 56 -l "why: P4 lora2r VANILLA (gate pinned clean) + teacher-forcing s{{seed}} (25% unhackable); resolve: deploy_hack >> 0 emergence reference on the identical adapter" -- {{ TRAIN }} fast --intervention=none --unhackable-frac=0.25 {{ TEACH }} --seed={{seed}} --out-tag=_l2r_vanilla_s{{seed}}
pueue add -w "$PWD" -o 54 -l "why: P5 lora2r BOTH-BLOCK (masks pinned (1,0), no gate) + teacher-forcing s{{seed}} (25% unhackable); resolve: ~vanilla -> gate+masks add nothing; << vanilla -> ungated both-block training suppresses" -- {{ TRAIN }} fast --intervention=absorb --unhackable-frac=0.25 {{ TEACH }} --seed={{seed}} --out-tag=_l2r_absorb_s{{seed}}
pueue add -w "$PWD" -o 62 -l "why: P1 lora2r routeA REAL-v act gate + teacher-forcing s{{seed}} (25% unhackable); resolve: deploy_hack << placebo at matched solve -> directionality real" -- {{ TRAIN }} fast --intervention=routeA --unhackable-frac=0.25 {{ TEACH }} --seed={{seed}} --out-tag=_l2r_routeA_real_s{{seed}}
pueue add -w "$PWD" -o 58 -l "why: P2 lora2r routeA PLACEBO-v (Haar 157) + teacher-forcing s{{seed}} (25% unhackable); resolve: deploy_hack ~ vanilla -> real-v suppression is directional, not absorption/shrinkage" -- {{ TRAIN }} fast --intervention=routeA --routeA-random-v-seed=157 --unhackable-frac=0.25 {{ TEACH }} --seed={{seed}} --out-tag=_l2r_routeA_placebo_s{{seed}}
pueue add -w "$PWD" -o 56 -l "why: P3 lora2r VANILLA (gate pinned clean) + teacher-forcing s{{seed}} (25% unhackable); resolve: deploy_hack >> 0 emergence reference on the identical adapter" -- {{ TRAIN }} fast --intervention=none --unhackable-frac=0.25 {{ TEACH }} --seed={{seed}} --out-tag=_l2r_vanilla_s{{seed}}
pueue add -w "$PWD" -o 54 -l "why: P4 lora2r BOTH-BLOCK (masks pinned (1,0), no gate) + teacher-forcing s{{seed}} (25% unhackable); resolve: ~vanilla -> gate+masks add nothing; << vanilla -> ungated both-block training suppresses" -- {{ TRAIN }} fast --intervention=absorb --unhackable-frac=0.25 {{ TEACH }} --seed={{seed}} --out-tag=_l2r_absorb_s{{seed}}
# Base model zero-shot deploy eval (0 training steps): reproduce the paper's base
# solve ~11.5% in our harness. resolve: base solve ~0.10-0.12.
+19 -23
@@ -1,4 +1,4 @@
"""Q2 diagnostic: what should the live routeV gate SCORE, and where do the pinning
"""Q2 diagnostic: what should the live routing gate SCORE, and where do the pinning
cuts go?
THE QUESTION (Q2). The gate routes UPDATES, not rollouts: per rollout the GRPO update
@@ -90,8 +90,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
from vgrout.lora2r import wrap_model_with_lora2r
from vgrout.pairs import load_pairs
from vgrout.extract_vhack_grad import completion_nll
from vgrout.train import _auroc
from vgrout.train import _auroc, _otsu3
# colour = behaviour (blue solve, red hack, grey fail); style = source (solid on-policy, dashed synthetic)
SOLVE, HACK, FAIL, ABSORB_C, ROUT_C, ORACLE = "#3b6ea5", "#c44e52", "#9aa0a6", "#d1900a", "#c44e52", "#3a8a7a"
@@ -104,7 +103,7 @@ class Cfg:
run_dir: Path = Path("out/runs/20260611T003538_fast_vanilla_lora2r_seed43_l2r_vanilla_s43_v3")
ckpt: str = "first_hack"
pairs: Path = Path("data/pairs/hack_pairs.md#all-in-one")
# headline figure builds v from this heading-prefix subset = the routeV TRAINING
# headline figure builds v from this heading-prefix subset = the routeA TRAINING
# default (train_config.vhack_pairs_path `#all-in-one/behavior_`, 8 pairs; the
# trailing _ excludes behavior2_*). The pairset table spans all subsets of `pairs`.
headline_prefix: str = "behavior_"
@@ -216,24 +215,21 @@ def _kde(x: np.ndarray, grid: np.ndarray) -> np.ndarray:
return np.exp(-0.5 * z ** 2).sum(1) / (len(x) * bw * np.sqrt(2 * np.pi))
def _otsu3(x: np.ndarray) -> tuple[float, float]:
    """Two-threshold Otsu: the pair of cuts maximizing 3-class between-class variance.
    Label-free -- an online gate can compute this from a rolling window of scores, so
    using it here is not oracle leakage. O(n^2), fine for a few hundred scores.
    Scores are winsorized at 1/99% first: Otsu maximizes variance, so on heavy-tailed
    scores a single extreme point otherwise buys a whole class (seen on grad_dot)."""
    x = np.clip(x, *np.quantile(x, [0.01, 0.99]))
    s = np.sort(np.asarray(x, float))
    n = len(s)
    c = np.concatenate([[0.0], np.cumsum(s)])
    best, best_ij = -np.inf, (1, 2)
    for i in range(1, n - 1):
        for j in range(i + 1, n):
            obj = c[i] ** 2 / i + (c[j] - c[i]) ** 2 / (j - i) + (c[n] - c[j]) ** 2 / (n - j)
            if obj > best:
                best, best_ij = obj, (i, j)
    i, j = best_ij
    return float((s[i - 1] + s[i]) / 2), float((s[j - 1] + s[j]) / 2)
def completion_nll(model, tokenizer, prompt: str, completion: str, device) -> torch.Tensor:
    """Mean NLL over completion tokens only (length-normalized). The backward of this
    loss populates the c-probe grads read by _gate_grads (the retired grad-gate space,
    kept here as a diagnostic baseline)."""
    prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    full_ids = tokenizer(prompt + completion, return_tensors="pt").input_ids.to(device)
    n_prompt = prompt_ids.shape[1]
    logits = model(full_ids).logits[:, :-1]  # [1, L-1, V]
    targets = full_ids[:, 1:]  # [1, L-1]
    logp = torch.nn.functional.log_softmax(logits.float(), dim=-1)
    nll = -logp.gather(-1, targets.unsqueeze(-1)).squeeze(-1)  # [1, L-1]
    # mask: positions whose target is a completion token (i.e. index >= n_prompt in full_ids)
    pos = torch.arange(full_ids.shape[1] - 1, device=device).unsqueeze(0)
    mask = (pos >= (n_prompt - 1)).float()
    return (nll * mask).sum() / mask.sum().clamp_min(1.0)
def plot_q2(df: pl.DataFrame, subtitle: str, out_png: Path) -> dict:
@@ -497,7 +493,7 @@ def _downstream(cfg: Cfg, fe: dict, src: str) -> int:
print(f"SHOULD: on_hackpos >= ~20 and on_drop not the majority, ELSE the window/run has "
f"too few learnable hacks and every AUROC below is noise.")
# ── headline vectors from the routeV-default subset; placebo swaps in Haar ──
# ── headline vectors from the routeA-default subset; placebo swaps in Haar ──
groups: dict[str, list[int]] = fe["pair_groups"]
head_idx = [i for i, pid in enumerate(pair_ids) if pid.startswith(cfg.headline_prefix)]
assert head_idx, f"no pairs match headline prefix {cfg.headline_prefix!r}"
+2 -2
@@ -7,12 +7,12 @@ hack-side by detector signature. Here the source is the student's logged
rollouts (out/runs/<run>/rollouts.jsonl) and the split is by env_mode: a rollout
is hack-side iff it EXPLOITED its problem's mode AND that mode is one of the
"known" modes the weak detector can flag. The held-out modes are never used to
build pairs -- v_grad is extracted only from the known modes, and the A5 figure
build pairs -- the routing vector is extracted only from the known modes, and the A5 figure
then measures whether the held-out modes are also suppressed at deployment. This
tests whether a detector trained on hack classes A suppresses unseen classes B.
Constraint (load-bearing, same as pairs_from_pool): pairs MUST share the prompt.
The paired-diff g_hack - g_clean in extract_vhack_grad cancels prompt-specific
The paired-diff feat_hack - feat_clean in the pair extraction cancels prompt-specific
noise only when both completions are conditioned on the same chat-templated
prompt. A given problem_id renders one fixed (hinted) prompt across steps, so
same-(problem_id) hack/clean rollouts share it; we assert equality to catch drift.
+2 -1
@@ -74,7 +74,8 @@ def main() -> int:
authored_pairs = load_pairs(Path("data/pairs/hack_pairs.md#all-in-one"))
real_pairsets_ok = (
len(authored_pairs) == 27
len(authored_pairs) == 42 # 27 + 15 wave-2 behavior2_* (c33b810)
and len(load_pairs(Path("data/pairs/hack_pairs.md#all-in-one/behavior_"))) == 8 # routeA training default
and len(load_pairs(Path("data/pairs/hack_pairs.md#all-in-one@opportunity-aware"))) == 6
and len(load_pairs(Path("data/pairs/hack_pairs.md#all-in-one@explicit"))) == 10
and len(load_pairs(Path("data/pairs/hack_pairs.md#all-in-one@roleplay"))) == 2
-301
@@ -1,301 +0,0 @@
"""Gradient-side per-module v_hack extraction (spec.md §B, top-k variant).
We sample the per-module GRPO update direction on labeled (hack, clean) pairs.
For a pair with advantages (adv_h=+1, adv_c=-1) the Dr.GRPO single-step grad
`-adv_h * grad_logp(hack) - adv_c * grad_logp(clean)` algebraically equals
`grad_NLL(hack) - grad_NLL(clean)`, so we compute it by the simpler path:
forward each completion, take mean-NLL on completion tokens, backward, and
capture the lora2r c-probe grad (the per-pair weight grad of the virtual
diagonal between A and B, deployed block) per wrapped Linear. Naming the steps
NLL is an implementation detail; the *meaning* is "the GRPO update on this pair."
Then per module, with D = [g_hack_i - g_clean_i for each pair] in R^{n_pairs x r}:
SVD(D) = U Σ Vh
v_hack[name] = top_k rows of Vh, each oriented so mean(D @ v_i) > 0
This generalizes mean-diff (which corresponds to top-1 PC of paired diffs under
isotropic covariance) to a rank-k hack subspace, motivated by CHaRS (Abdullaev
2025 -- see docs/paper_chars.md): hack signal is multi-modal across hack flavors
(weak tests, hardcode, persona, ...), so a single global direction is brittle.
Orientation matters because proj.py applies a per-direction one-sided gate
(only subtracts <g, v_i> when positive). +v_i must align with the reward-hacking gradient.
Saves `out/v_hack.safetensors` = dict[name -> Tensor[k, r]] (cpu fp32, rows
unit-norm + orthonormal from SVD) with header {"model": str, "dtype": str,
"top_k": str(k)}.
Run: uv run python -m vgrout.extract_vhack_grad
"""
from __future__ import annotations
import sys
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
import torch
import tyro
from jaxtyping import Float
from loguru import logger
from safetensors.torch import save_file
from tabulate import tabulate
from transformers import AutoModelForCausalLM, AutoTokenizer
from .lora2r import wrap_model_with_lora2r
from .pairs import load_pairs, pairset_sha256
OUT_DIR = Path("out")
@dataclass
class Config:
model: str = "Qwen/Qwen3-4B"
dtype: str = "bf16" # must match train.py, else SVD basis cache can differ silently
out_path: Path = OUT_DIR / "vhack" / "v_hack.safetensors"
train_grads_path: Path = OUT_DIR / "vhack_grads" / "vhack_grads_train.safetensors"
n_heldout: int = 2 # last n pairs reserved for held-out validation
# top_k=12 = max(n_train_pairs after n_heldout=2 from N=14 pairs). Extract once
# at max rank; train.py slices via --v-hack-k for k-ablation without re-extract.
top_k: int = 12
# tau_axis: zero rows where S_i/S_0 < tau_axis. Diagnostic -- projection along
# noise-direction unit vectors removes only ~||g||/sqrt(r) ≈ 2% of grad
# magnitude on r=2560 modules, so this rarely changes effect size; it does
# keep k-ablations interpretable (axes 4-5 might be pure noise on N=12 pairs).
tau_axis: float = 0.0
# Pairset reference: generated JSON or one `path.md#section`.
pairs_from_pool: Path | None = None
def resolve_dtype(s: str) -> torch.dtype:
return {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[s]
def completion_nll(model, tokenizer, prompt: str, completion: str, device) -> torch.Tensor:
"""Mean NLL over completion tokens only (length-normalized)."""
prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
full_ids = tokenizer(prompt + completion, return_tensors="pt").input_ids.to(device)
n_prompt = prompt_ids.shape[1]
logits = model(full_ids).logits[:, :-1] # [1, L-1, V]
targets = full_ids[:, 1:] # [1, L-1]
logp = torch.nn.functional.log_softmax(logits.float(), dim=-1)
nll = -logp.gather(-1, targets.unsqueeze(-1)).squeeze(-1) # [1, L-1]
# mask: positions whose target is a completion token (i.e. index >= n_prompt in full_ids)
pos = torch.arange(full_ids.shape[1] - 1, device=device).unsqueeze(0)
mask = (pos >= (n_prompt - 1)).float()
return (nll * mask).sum() / mask.sum().clamp_min(1.0)
def extract_v_hack(
model,
tokenizer,
wrappers: dict,
pairs: list,
top_k: int,
tau_axis: float,
n_heldout: int,
device,
) -> tuple[
dict[str, Float[torch.Tensor, "k r"]],
dict[str, Float[torch.Tensor, "k"]],
dict[str, Float[torch.Tensor, "n_pairs r"]],
list[dict],
]:
"""Run pair-grads + per-module SVD on D = g_hack - g_clean, return v_hack.
Pure function -- caller owns model loading, wrapping, and saving. train.py
calls this on its already-wrapped model when v_hack cache is missing, so
we don't pay the cost of a second model load.
Returns:
v_hack: dict[name -> Tensor[k, r]] (cpu fp32), top-k right singular
vectors of D per module, oriented so mean(D @ v_i) > 0. If
tau_axis > 0, rows where S_i/S_0 < tau_axis are zeroed.
v_sv: dict[name -> Tensor[k]] (cpu fp32), singular values matching v_hack.
Saved alongside V under `_sv/{name}` keys so load-time noise-floor
filtering works without re-extracting.
raw_grads: dict["hack/name"|"clean/name" -> Tensor[n_pairs, r]] for
offline analysis (verify_vhack_heldout reads this).
diag_rows: per-module diagnostic dicts (sv_top frac, ||D||, etc.).
"""
train_pairs = pairs[:-n_heldout] if n_heldout > 0 else pairs
n_pairs = len(train_pairs)
grads_hack: dict[str, list[torch.Tensor]] = defaultdict(list)
grads_clean: dict[str, list[torch.Tensor]] = defaultdict(list)
for pi, pair in enumerate(train_pairs):
for label, completion in (("hack", pair.hack), ("clean", pair.clean)):
model.zero_grad(set_to_none=True)
loss = completion_nll(model, tokenizer, pair.prompt, completion, device)
if not torch.isfinite(loss):
# Skip-on-nonfinite would silently leave G_h and G_c with mismatched
# lengths and explode later in D = G_h - G_c. Fail fast.
raise RuntimeError(f"non-finite loss at pair={pi} label={label}: {loss.item()}")
loss.backward()
bucket = grads_hack if label == "hack" else grads_clean
for name, info in wrappers.items():
layer = info["layer"]
# Per-pair weight grad of the virtual diagonal (c-probe), DEPLOYED
# block only -- the same space the live gate reads (train.py), so
# band calibration uses the same representation. Requires grad_probe=True.
cg = layer._lora2r_gate.grad
if cg is None:
raise RuntimeError(f"no c-probe grad on {name}; wrap with grad_probe=True")
g = cg.sum(dim=tuple(range(cg.dim() - 1)))[: layer._lora2r_r] # [r]
bucket[name].append(g.detach().float().cpu().clone())
if (pi + 1) % 5 == 0:
logger.info(f" pair {pi+1}/{n_pairs} loss={loss.item():.3f}")
model.zero_grad(set_to_none=True) # leave caller with clean state
raw_grads = {
**{f"hack/{n}": torch.stack(gs) for n, gs in grads_hack.items()},
**{f"clean/{n}": torch.stack(gs) for n, gs in grads_clean.items()},
}
# Per module: D = g_hack - g_clean (paired diff cancels prompt-specific noise).
# SVD(D) gives orthonormal right singular vectors = principal axes of variation
# of the hack-clean axis. Top-k generalizes mean-diff (which is the k=1 case).
v_hack: dict[str, torch.Tensor] = {}
v_sv: dict[str, torch.Tensor] = {}
rows = []
n_zero = 0
k = min(top_k, n_pairs)
n_axes_kept_total = 0
for name in grads_hack:
G_h = torch.stack(grads_hack[name]) # [n_pairs, r]
G_c = torch.stack(grads_clean[name])
D = G_h - G_c
U_d, S_d, Vh_d = torch.linalg.svd(D, full_matrices=False)
V = Vh_d[:k] # [k, r], rows orthonormal in R^r
# Orient by per-pair majority vote: for each axis i, count pairs where
# d_p @ v_i > 0; if strict majority disagree with current SVD sign, flip.
# More outlier-robust than sign(mean): one extreme pair can't flip a
# consensus direction. Matches repeng's _orient_svd convention.
proj_per_pair = D @ V.T # [n_pairs, k]
n_pos = (proj_per_pair > 0).float().sum(0) # [k]
flip = torch.where(n_pos < n_pairs / 2, -torch.ones(k), torch.ones(k))
V = V * flip.unsqueeze(1)
# tau_axis: zero rows where S_i/S_0 < tau_axis (diagnostic; see Config comment).
n_axes_kept = k
if tau_axis > 0 and S_d[0] > 1e-12:
ratios = S_d[:k] / S_d[0]
keep = (ratios >= tau_axis).float()
V = V * keep.unsqueeze(1)
n_axes_kept = int(keep.sum())
n_axes_kept_total += n_axes_kept
nrm = D.norm()
if nrm < 1e-12:
n_zero += 1
v_hack[name] = torch.zeros((k, D.shape[1]), dtype=V.dtype).contiguous()
else:
v_hack[name] = V.contiguous()
# Record singular values so the load-time noise-floor filter has the
# extraction-time S_i per axis without re-extracting. Saved under
# `_sv/{name}` keys in the safetensors file (combined at save site).
v_sv[name] = S_d[:k].clone().contiguous()
sv_top = S_d[:k]
sv_total = S_d.sum().clamp_min(1e-12)
rows.append({
"module": name.split(".")[-1],
"r": D.shape[1],
"||D||": f"{nrm:.2e}",
"sv_0": f"{S_d[0].item():.2e}" if S_d.numel() else "-",
f"sv_top{k}_frac": f"{(sv_top.sum() / sv_total).item():.2f}",
"sv_ratio_0/1": ("-" if S_d.numel() < 2
else f"{(S_d[0] / S_d[1].clamp_min(1e-12)).item():.2f}"),
"axes_kept": n_axes_kept,
})
n_modules = len(grads_hack)
logger.info(
f"v_hack: modules={n_modules} k_max={k} zero-||D||={n_zero} "
f"axes_kept_avg={n_axes_kept_total/max(1,n_modules):.1f} (tau_axis={tau_axis})"
)
return v_hack, v_sv, raw_grads, rows


def main(cfg: Config) -> int:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = resolve_dtype(cfg.dtype)
    if cfg.pairs_from_pool is None:
        raise ValueError("--pairs-from-pool is required; use data/pairs/hack_pairs.md#all-in-one")
    pairs = load_pairs(cfg.pairs_from_pool)
    logger.info(f"pairs source: {cfg.pairs_from_pool} -> {len(pairs)} pairs")
    logger.info(
        f"device={device} model={cfg.model} dtype={cfg.dtype} "
        f"N_pairs={len(pairs)} heldout={cfg.n_heldout} top_k={cfg.top_k} tau_axis={cfg.tau_axis}"
    )
    tokenizer = AutoTokenizer.from_pretrained(cfg.model)
    model = AutoModelForCausalLM.from_pretrained(
        cfg.model, dtype=dtype, attn_implementation="sdpa"
    ).to(device)
    model.eval()  # disable dropout; gradients still flow through the adapter
    wrappers = wrap_model_with_lora2r(model, grad_probe=True)
    n_mod = len(wrappers)
    logger.info(f"wrapped {n_mod} modules; probe space = r per module")
    train_pairs = pairs[:-cfg.n_heldout] if cfg.n_heldout > 0 else pairs
    logger.info(f"train pairs: {len(train_pairs)} held: {cfg.n_heldout}")
    v_hack, v_sv, raw_grads, rows = extract_v_hack(
        model, tokenizer, wrappers, pairs,
        top_k=cfg.top_k, tau_axis=cfg.tau_axis,
        n_heldout=cfg.n_heldout, device=device,
    )
    n_zero = sum(1 for v in v_hack.values() if v.norm() < 1e-12)
    k = min(cfg.top_k, len(train_pairs))
    cfg.out_path.parent.mkdir(parents=True, exist_ok=True)
    cfg.train_grads_path.parent.mkdir(parents=True, exist_ok=True)
    save_file(raw_grads, str(cfg.train_grads_path),
              metadata={"model": cfg.model, "dtype": cfg.dtype})
    # v_hack file layout: bare `{name}` keys hold V[k, r]; `_sv/{name}` keys
    # hold S[k]. Loader at train.py:load_v_hack splits them back apart.
    save_payload = {**v_hack, **{f"_sv/{n}": s for n, s in v_sv.items()}}
    save_file(save_payload, str(cfg.out_path),
              metadata={"model": cfg.model, "dtype": cfg.dtype, "top_k": str(k),
                        "tau_axis": str(cfg.tau_axis), "schema": "v2_with_sv",
                        "pairs_path": str(cfg.pairs_from_pool),
                        "pairs_sha256": pairset_sha256(cfg.pairs_from_pool)})
    # summary: aggregate by suffix -- track top-k energy concentration
    by_suffix: dict[str, list] = defaultdict(list)
    for r in rows:
        by_suffix[r["module"]].append(float(r[f"sv_top{k}_frac"]))
    agg_rows = []
    for suf, vals in sorted(by_suffix.items()):
        agg_rows.append({
            "suffix": suf,
            "n": len(vals),
            f"mean_sv_top{k}_frac": f"{sum(vals)/len(vals):.2f}",
            f"min_sv_top{k}_frac": f"{min(vals):.2f}",
            f"max_sv_top{k}_frac": f"{max(vals):.2f}",
        })
    # Final tail: BLUF -- what an agent reads first should be result + interp.
    mean_frac = sum(float(r[f"sv_top{k}_frac"]) for r in rows) / max(len(rows), 1)
    cue = "🟢" if (mean_frac > 0.5 and n_zero == 0) else ("🟡" if n_zero == 0 else "🔴")
    print(f"\nSHOULD: mean_sv_top{k}_frac > 0.5 per suffix (subspace captures most energy). "
          f"zero-||D||==0 (else grad flow broken).\n")
    print(tabulate(agg_rows, headers="keys", tablefmt="tsv", floatfmt=".2f"))
    print()
    print(f"out: {cfg.out_path}")
    print(f"argv: extract_vhack_grad --model={cfg.model} --top-k={k} --n-heldout={cfg.n_heldout}")
    print(f"main metric: mean_sv_top{k}_frac={mean_frac:.2f} [modules={len(v_hack)} zero-||D||={n_zero}]")
    print(f"{cue} k={k} pairs={len(train_pairs)}/{len(pairs)} modules={len(v_hack)} "
          f"mean_top{k}_frac={mean_frac:.2f} zero={n_zero}")
    if n_zero > 0:
        logger.error(f"FAIL: {n_zero}/{len(v_hack)} modules have zero ||D|| -- gradient flow broken")
        return 1
    return 0


if __name__ == "__main__":
    sys.exit(main(tyro.cli(Config)))
+6 -4
@@ -17,14 +17,16 @@ from pathlib import Path
FIGS_DIR = Path("docs/figs")
# Reader-facing arm names. Code/log tags carry our internal vocabulary
# (routeV = the current routing arm); plots must
# (routeA = the current routing arm); plots must
# not. Map every internal tag to the word a paper reader sees. Anything missing
# falls through to its raw tag, so a new arm shows up loud rather than silently
# mislabelled.
ARM_DISPLAY = {
# routeV is the current banded-gate arm; routing2/route2 are the old binary-tau runs
# (kept so historical run artifacts still plot -- see rename, 2026-06-06).
"routingV": "route", "routeV": "route", "routingV_per_token": "route per-token",
# routeA is the current act-gate arm; routeV (grad gate) and routing2/route2
# (binary-tau) are retired but kept so historical run artifacts still plot.
"routeA": "route",
"routingV": "route (grad)", "routeV": "route (grad)",
"routingV_per_token": "route per-token",
"routing2": "route", "route2": "route",
"routing2_grad": "route", "routing2_act": "route (act)",
"projected": "erase", "route": "route", "erase": "erase", "vanilla": "vanilla",
+3 -3
@@ -51,9 +51,9 @@ def _lora2r_hook(layer: nn.Linear, args: tuple, y: Tensor) -> Tensor:
grad probe: c = ones[..., 2r] spliced as h*c. After backward
c.grad = h ⊙ (Bᵀδ_y) = the per-sample WEIGHT grad of a virtual diagonal
scale between A and B. The pair extraction (extract_vhack_grad) and the live
gate (train.py) both read this same space, so the band is self-consistent
whatever the basis.
scale between A and B. Training no longer uses it (the routeA gate reads
activations); kept for scripts/diag_pinning.py, whose grad-score baseline
reads this space.
"""
(x,) = args # x: [..., d_in]
A: Float[Tensor, "two_r d_in"] = layer._lora2r_A # trainable
+1 -1
@@ -9,7 +9,7 @@ exclusively from `half-A` detectors. The clean-side is any rollout where all
four upstream detectors are False AND format_ok is True.
Constraint (load-bearing): pairs MUST share the prompt. The paired-diff
`g_hack - g_clean` in extract_vhack_grad cancels prompt-specific noise only
`feat_hack - feat_clean` in the pair extraction cancels prompt-specific noise only
when both completions are conditioned on the same chat-templated prompt.
Cross-prompt pairs would inject prompt-difference signal into v_hack.
+1
@@ -13,6 +13,7 @@ RUN_SCHEMA = "paired_final_v2" # v2: deployed/as_trained field names (was depl
# get a _lora2r suffix so the two substrates never conflate in aggregation.
ARM = {"none": "vanilla", "erase": "projected",
"routeV": "routingV", "routeV_per_token": "routingV_per_token",
"routeA": "routeA",
"absorb": "absorb"}
+10 -11
@@ -73,7 +73,7 @@ class StepLogger:
def __init__(self, arm: str, modes: list[str], mode_code: dict[str, str],
show_ablate: bool = False) -> None:
# Routing diagnostics are ALWAYS shown (nan on vanilla, whose gate never runs) so the
# column layout is identical across arms -- vanilla/routeV/absorb tables line up.
# column layout is identical across arms -- vanilla/routeA/absorb tables line up.
cols: list[_Col] = [
_Col("step", 4, "step", "d", "GRPO step"),
_Col("ref_eq", 6, "ref_eq", ".2f", "vanilla-equiv step (cum_gens/256)"),
@@ -84,7 +84,7 @@ class StepLogger:
_Col("hack_s", 7, "hack_s?", "frac", "student hack-flagged rollouts (the headline)"),
_Col("hack_t", 7, "hack_t", "frac", "teacher hack-flagged rollouts (sanity: pool hacks)"),
# Held-out deployed evaluation with quarantine ablated; NaN between evaluation steps.
_Col("hack_deployed", 7, "hk_dep", "+.2f", "DEPLOY-eval hack (routeV: quarantine OFF; vanilla/erase: trained model); held-out subset, T=0.7, every eval_ablate_every steps; nan between"),
_Col("hack_deployed", 7, "hk_dep", "+.2f", "DEPLOY-eval hack (routeA/absorb: quarantine OFF; vanilla: trained model); held-out subset, T=0.7, every eval_ablate_every steps; nan between"),
_Col("solve_deployed", 7, "slv_dep", "+.2f", "DEPLOY-eval solve (same cadence; nan between)"),
]
# Multi-mode runs show current-step hacks per environment; single-mode would duplicate hack_s.
@@ -99,17 +99,16 @@ class StepLogger:
_Col("gn", 7, "gn", ".1e", "pre-clip L2 norm of A/B grads (vs grad_clip)"),
_Col("lr", 7, "lr", ".1e", "scheduled learning rate"),
]
# routeV reports unit and energy shares across the routing band (nan on vanilla/absorb).
# routeA reports gate diagnostics (nan on vanilla/absorb, whose gate never runs).
cols += [
_Col("auroc", 6, "auroc", ".2f", "AUROC of pooled cos(g,v_grad) as a reward-hacking detector; measurement only, never routes. ~0.5 = chance-level separation; high AUROC but rout~0 = threshold/scale problem; a drop at refresh = reduced separation"),
_Col("cosU", 6, "cosU", "+.2f", "pooled cos(v_grad, summed-rollout c-grad): alignment of the net update with v_grad"),
_Col("auroc", 6, "auroc", ".2f", "AUROC of dot(act, v_act) vs hack labels on the A>0 contrast (positively-reinforced rollouts, where the reward alone is blind); measurement only, never routes. ~0.5 = chance-level separation; high AUROC but rout~0 = threshold problem; a drop at refresh = reduced separation"),
_Col("cos", 6, "cos", "+.2f", "mean per-rollout cos(act, v_act) (dot-vs-cos diagnostic)"),
_Col("qmass", 6, "qmass", ".2f", "quarantine energy share ||g_quar||/(||g_keep||+||g_quar||): fraction of update energy assigned to quarantine"),
_Col("keep", 6, "keep", ".2f", "rollout share labelled clean (below band) -> deployed-only, quarantine off"),
_Col("resid", 6, "resid", ".2f", "rollout share labelled mid (inside band) -> both blocks train; absorption is possible but not measured"),
_Col("rout", 6, "rout", ".2f", "rollout share labelled hack (above band) -> quarantine-only, deployed detached"),
_Col("keepE", 6, "keepE", ".2f", "energy-weighted keep (grad-norm-weighted clean share)"),
_Col("residE", 6, "residE", ".2f", "energy-weighted resid"),
_Col("routE", 6, "routE", ".2f", "energy-weighted rout"),
_Col("keep", 6, "keep", ".2f", "rollout share below t_lo -> deployed-only, quarantine off"),
_Col("resid", 6, "resid", ".2f", "rollout share between thresholds (and ALL rollouts during warmup) -> both blocks train; absorption is possible but not measured"),
_Col("rout", 6, "rout", ".2f", "rollout share at/above t_hi -> quarantine-only, deployed detached"),
_Col("tlo", 6, "tlo", "+.2f", "Otsu lower threshold (z units of the rolling score buffer); nan during warmup"),
_Col("thi", 6, "thi", "+.2f", "Otsu upper (rout) threshold (z units); nan during warmup"),
]
# Show the training-prompt deploy proxy only when an ablated slice exists.
if show_ablate:
+213 -296
@@ -16,13 +16,14 @@ both trainable, partitioned into a deployed block [:r] and a quarantine block
Arms (--intervention):
none gate pinned clean (0,0): quarantine never trains -- the capacity- and
structure-matched vanilla control.
routeV per-rollout three-way gate from the c-probe gradient vs v_grad:
clean->deployed-only, hack->quarantine-only (deployed detached),
mid->both, which may permit absorption.
routeA per-rollout three-way gate from the pooled bottleneck activation vs
v_act: keep->deployed-only, rout->quarantine-only (deployed detached),
absorb->both, which may permit absorption. The acts ride the no-grad
logpi_old forward, so routeA costs roughly the same as the vanilla arm.
absorb gate pinned mid (1,0): both blocks train on everything, no gate --
tests ungated both-block training.
uv run python -m vgrout.train smoke --intervention=routeV
uv run python -m vgrout.train smoke --intervention=routeA
"""
from __future__ import annotations
@@ -33,9 +34,12 @@ import os
import sys
import random
import time
from collections import deque
from contextlib import nullcontext
from pathlib import Path
import numpy as np
# Must be set BEFORE `import torch` to take effect on the CUDA allocator.
# Eliminates fragmentation that caused 91 GiB allocated / 581 MiB free crash
# on Qwen3-4B G=8 (PyTorch's own OOM message recommends this).
@@ -51,7 +55,9 @@ from tabulate import tabulate
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from .extract_vhack_act import ActCapture, extract_v_act, haar_unit_rows
from .lora2r import wrap_model_with_lora2r
from .pairs import load_pairs
from .proj import per_token_logps
from .rewards import EnvMode, compute_reward
from .data import DATA, load_problems
@@ -64,46 +70,27 @@ OUT_DIR = Path("out")
RUNS_DIR = OUT_DIR / "runs"
def _haar_unit_dirs(v_grad: dict, seed: int, device) -> dict:
"""Build the reproducible out-of-subspace directionality control (placebo) for routeV.
Matches v_grad's [k, r] shape and unit-normalizes per row, so a top-k routing run
gets k random dirs (the gate's max-cosine still sees the same shape) for a fair placebo."""
g = torch.Generator().manual_seed(seed)
out = {}
for name in sorted(v_grad):
d = torch.randn(v_grad[name].shape, generator=g) # [k, r]
out[name] = (d / d.norm(dim=-1, keepdim=True).clamp_min(1e-12)).to(device)
return out
def _build_v_grad(raw_grads: dict, names, k: int, device) -> dict[str, Float[torch.Tensor, "k r"]]:
"""Per-module routing directions from authored-pair gradients -> dict[name -> [k, r]].
k=1 (headline): normalized mean(hack-clean) -- one mean-mass axis. k>1: the top-k
oriented right singular vectors of the paired-diff matrix D=[g_hack-g_clean] (SVD +
per-pair majority orient, mirroring extract_vhack_grad.extract_v_hack), a rank-k hack
subspace. The live gate scores max_i cos(g, v_i). Rows are unit-norm. k=1 is NOT
SVD-top-1 (they differ off-isotropic): keeping mean-diff makes 'mean-mass vs top-k'
a clean A/B, not a confound."""
out = {}
for name in names:
D: Float[torch.Tensor, "n_pairs r"] = (
raw_grads[f"hack/{name}"] - raw_grads[f"clean/{name}"]).float()
if k == 1:
d = D.mean(0)
V = (d / d.norm().clamp_min(1e-12)).unsqueeze(0) # [1, r]
else:
_, _, Vh = torch.linalg.svd(D, full_matrices=False)
kk = min(k, Vh.shape[0])
V = Vh[:kk] # [kk, r] orthonormal
proj = torch.einsum("n r, k r -> n k", D, V) # per-pair projection
flip = torch.where((proj > 0).float().sum(0) < D.shape[0] / 2,
-torch.ones(kk), torch.ones(kk)) # orient toward reward-hacking gradients
V = V * flip.unsqueeze(1)
out[name] = V.to(device)
return out
def _otsu3(x: np.ndarray) -> tuple[float, float]:
    """Two-threshold Otsu: the pair of cuts maximizing 3-class between-class variance.
    Label-free -- the routeA gate computes this on a rolling buffer of live scores, so
    using it is not oracle leakage. Scores are winsorized at 1/99% first: Otsu maximizes
    variance, so on heavy-tailed scores a single extreme point otherwise buys a whole
    class (journal 2026-06-11 (d): v5 act rout precision 0.00 -> 0.50 after winsorize).
    Vectorized over the [n, n] cut grid; n is the buffer size (<= a few hundred)."""
    x = np.clip(x, *np.quantile(x, [0.01, 0.99]))
    s = np.sort(np.asarray(x, float))
    n = len(s)
    c = np.concatenate([[0.0], np.cumsum(s)])
    iv = np.arange(1, n)
    i_g, j_g = iv[:, None], iv[None, :]
    with np.errstate(divide="ignore", invalid="ignore"):
        obj = (c[i_g] ** 2 / i_g
               + (c[j_g] - c[i_g]) ** 2 / (j_g - i_g)
               + (c[n] - c[j_g]) ** 2 / (n - j_g))
    obj[(j_g <= i_g) | (j_g >= n)] = -np.inf  # need i < j and a nonempty top class
    i, j = np.unravel_index(np.argmax(obj), obj.shape)
    i, j = iv[i], iv[j]
    return float((s[i - 1] + s[i]) / 2), float((s[j - 1] + s[j]) / 2)
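Review note: the two-threshold Otsu cut above can be sanity-checked in isolation. The sketch below is a standalone copy of the same objective, assuming only NumPy; `otsu3_sketch` and the synthetic three-cluster sample are illustrative names and data, not part of this commit.

```python
import numpy as np


def otsu3_sketch(x):
    """Two cuts that maximize the 3-class between-class variance of x."""
    x = np.asarray(x, dtype=float)
    # Winsorize at the 1st/99th percentiles so one extreme point cannot buy a class.
    lo, hi = np.quantile(x, [0.01, 0.99])
    s = np.sort(np.clip(x, lo, hi))
    n = len(s)
    c = np.concatenate([[0.0], np.cumsum(s)])  # c[k] = sum of the k smallest scores
    iv = np.arange(1, n)                        # candidate cut indices 1..n-1
    i_g, j_g = iv[:, None], iv[None, :]
    with np.errstate(divide="ignore", invalid="ignore"):
        # sum over classes of (class sum)^2 / (class size): equal to the
        # between-class variance up to a cut-independent offset and scale.
        obj = (c[i_g] ** 2 / i_g
               + (c[j_g] - c[i_g]) ** 2 / (j_g - i_g)
               + (c[n] - c[j_g]) ** 2 / (n - j_g))
    obj[j_g <= i_g] = -np.inf                   # require i < j (nonempty middle class)
    i, j = np.unravel_index(np.argmax(obj), obj.shape)
    i, j = iv[i], iv[j]
    return float((s[i - 1] + s[i]) / 2), float((s[j - 1] + s[j]) / 2)


# Hypothetical scores: three well-separated clusters of 50 points each.
scores = np.concatenate([
    np.linspace(-5.2, -4.8, 50),
    np.linspace(-0.2, 0.2, 50),
    np.linspace(4.8, 5.2, 50),
])
t_lo, t_hi = otsu3_sketch(scores)
```

On this sample each threshold falls in the gap between two adjacent clusters. On unimodal scores the same function still returns two cuts, which is the behaviour the commit message describes when it explains why no bimodality guard is used.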
def _sample_rows(rows: list[dict] | None, n: int, rng: torch.Generator) -> list[dict]:
@@ -116,37 +103,15 @@ def _sample_rows(rows: list[dict] | None, n: int, rng: torch.Generator) -> list[
return [rows[i] for i in idxs]
def _zone_stats(f: torch.Tensor, w: torch.Tensor) -> tuple[float, ...]:
"""Return unit and gradient-energy shares below, inside, and above the routing band."""
if f.numel() == 0:
return (float("nan"),) * 6
lo, hi = (f == 0), (f == 1)
mid = ~(lo | hi)
tot = w.sum().clamp_min(1e-12)
return (lo.float().mean().item(), mid.float().mean().item(), hi.float().mean().item(),
((w * lo).sum() / tot).item(), ((w * mid).sum() / tot).item(), ((w * hi).sum() / tot).item())
def _pair_cos(raw_grads: dict, v: Float[torch.Tensor, "k r"], name: str
) -> tuple[Float[torch.Tensor, "n_pairs"], Float[torch.Tensor, "n_pairs"]]:
"""(clean, hack) pair cosines vs the routing dirs: max_i cos(g, v_i), the same
scoring the live gate uses, so band edges and thresholds use the same representation."""
gh = raw_grads[f"hack/{name}"].float() # [n_pairs, r]
gc = raw_grads[f"clean/{name}"].float()
ch = torch.einsum("n r, k r -> n k", gh, v).max(dim=1).values / gh.norm(dim=1).clamp_min(1e-12)
cc = torch.einsum("n r, k r -> n k", gc, v).max(dim=1).values / gc.norm(dim=1).clamp_min(1e-12)
return cc, ch
def _auroc(scores: list[float], labels: list[bool]) -> float:
"""Rank-based AUROC (Mann-Whitney U) of `scores` as a detector of the positive class.
Higher score for hacks -> auroc > 0.5. nan if either class is absent this step.
Diagnostic only: ground-truth labels measure how well cos(g, v_grad) separates
reward-hacking updates, but never determine a route. Reading: ~0.5 means v_grad
Diagnostic only: ground-truth labels measure how well the gate score separates
reward-hacking updates, but never determine a route. Reading: ~0.5 means v_act
is a chance-level classifier (no threshold can route reliably); high AUROC but
rout~0 = the threshold/scale is wrong, not the direction; a drop across a refresh =
the refresh destroyed the separation (the step-5 cliff is then a direction problem)."""
the refresh destroyed the separation."""
pos = [s for s, y in zip(scores, labels) if y]
neg = [s for s, y in zip(scores, labels) if not y]
if not pos or not neg:
@@ -167,22 +132,6 @@ def _auroc(scores: list[float], labels: list[bool]) -> float:
return (sum_pos - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
def route_band_edges(raw_grads: dict, v_grad: dict, device) -> dict[str, tuple[float, float]]:
"""Calibrate an absolute routing band from authored pairs only.
Clean/hack p75 edges avoid single-pair extremes and route only the confident
tail aligned with reward-hacking gradients. Pair/live shift can still make routing
idle; inspect `rout`.
See docs/papers/grad_routing/paper_sgtm.md.
"""
band = {}
for name in v_grad:
cc, ch = _pair_cos(raw_grads, v_grad[name].detach().cpu().float(), name)
band[name] = (cc.quantile(0.75).item(), ch.quantile(0.75).item()) # (lower=p75 clean, upper=p75 hack)
return band
# Fix evaluation sampling across steps and arms without perturbing the training RNG.
EVAL_GEN_SEED = 12345
@@ -196,12 +145,12 @@ MODE_CODE: dict[str, str] = {
def _validate_config(cfg: Config) -> None:
"""Reject contradictory experiment settings before model load."""
if cfg.intervention not in ("none", "routeV", "absorb"):
raise ValueError(f"unknown intervention {cfg.intervention!r}; expected none|routeV|absorb")
if cfg.routeV_random_v_seed is not None and cfg.intervention != "routeV":
raise ValueError("routeV_random_v_seed is a routeV-only placebo control")
if cfg.intervention not in ("none", "routeA", "absorb"):
raise ValueError(f"unknown intervention {cfg.intervention!r}; expected none|routeA|absorb")
if cfg.routeA_random_v_seed is not None and cfg.intervention != "routeA":
raise ValueError("routeA_random_v_seed is a routeA-only placebo control")
if cfg.rollout_ablate_frac > 0 and cfg.intervention == "none":
raise ValueError("rollout_ablate_frac needs a quarantine to ablate (routeV/absorb)")
raise ValueError("rollout_ablate_frac needs a quarantine to ablate (routeA/absorb)")
if cfg.weight_decay != 0.0:
raise ValueError("lora2r init is nonzero; AdamW decay pulls A/B toward 0 not toward init "
"-- set --weight-decay=0")
@@ -215,7 +164,7 @@ def _validate_config(cfg: Config) -> None:
def _log_resolved_config(cfg: Config, device) -> None:
"""One block with every None resolved to its effective value, so a detached log
shows exactly what ran -- especially WHICH pairset (the field readers kept losing)."""
is_routeV = cfg.intervention == "routeV"
is_routeA = cfg.intervention == "routeA"
fields = {
"preset/arm": f"{cfg.preset_name} / {cfg.arm}",
"intervention": cfg.intervention,
@@ -225,8 +174,8 @@ def _log_resolved_config(cfg: Config, device) -> None:
"lora_r/init_seed": f"{cfg.lora_r} / {cfg.lora_init_seed}",
"unhackable_frac": cfg.unhackable_frac,
"env_mode": cfg.env_mode,
"pairset": cfg.vhack_pairs_path if is_routeV else "unused (not routeV)",
"routeV placebo seed": cfg.routeV_random_v_seed if is_routeV else "n/a",
"pairset": cfg.vhack_pairs_path if is_routeA else "unused (not routeA)",
"routeA placebo seed": cfg.routeA_random_v_seed if is_routeA else "n/a",
"teacher pool/mix/off_step": (
f"{cfg.teacher_pool_dir.name} / {cfg.mix_ratio} / {cfg.teacher_off_step}"
if cfg.teacher_pool_dir else "none (pure on-policy)"),
@@ -254,10 +203,10 @@ def main(cfg: Config) -> int:
logger.info(f"verbose log: {verbose_log}")
_log_resolved_config(cfg, device)
is_routeV = cfg.intervention == "routeV"
is_routeA = cfg.intervention == "routeA"
is_absorb = cfg.intervention == "absorb"
is_vanilla = cfg.intervention == "none"
has_quarantine = is_routeV or is_absorb
has_quarantine = is_routeA or is_absorb
# Only adapter parameters train; the base model remains frozen.
tok = AutoTokenizer.from_pretrained(model_name)
@@ -276,9 +225,10 @@ def main(cfg: Config) -> int:
model.config.use_cache = False
# ── adapter: rank-2r LoRA, deployed block [:r] + quarantine block [r:] ──
# routeV needs the per-rollout c-probe gate; none/absorb pin the masks instead.
# The routeA gate reads activations via forward hooks; no grad probe in training
# (the c-probe stays in lora2r only for scripts/diag_pinning.py diagnostics).
wrappers = wrap_model_with_lora2r(
model, r=cfg.lora_r, init_seed=cfg.lora_init_seed, grad_probe=is_routeV)
model, r=cfg.lora_r, init_seed=cfg.lora_init_seed)
# A and B both train; quarantine = block slices of the SAME tensors, so there
# is no separate hack-param list (per-rollout masks route grads, not surgery).
delta_params = [p for info in wrappers.values() for p in (info["A"], info["B"])]
@@ -287,51 +237,47 @@ def main(cfg: Config) -> int:
logger.info(f"trainable lora2r A+B: {sum(p.numel() for p in delta_params):,} "
f"({n_quar:,} of those in quarantine blocks)")
# ── routeV direction: v_grad (mean pair-gradient diff) + routing band ──
v_grad = None # set only by the routeV branch below
route_band = None
if is_routeV:
# ── routeA direction: v_act (mean pooled-act pair diff, unit rows per module) ──
v_act = None # [M, r] cpu fp32; module order = act_names
act_names = sorted(wrappers)
act_buf: deque | None = None # rolling pooled acts [M, r]; re-scored vs the CURRENT
# v_act at each gate call, so a refresh needs no flush
MASK_PAIRS = None
if is_routeA:
# Authored pairs are the only routing-label source; live oracle labels never enter training.
from .pairs import load_pairs
from .extract_vhack_grad import extract_v_hack
MASK_PAIRS = load_pairs(cfg.vhack_pairs_path)
logger.info(f"routeV pairs: {cfg.vhack_pairs_path} -> {len(MASK_PAIRS)} pairs")
model.eval() # match standalone extract: deterministic backward, no dropout
_, _, raw_grads, _ = extract_v_hack(
model, tok, wrappers, MASK_PAIRS,
top_k=1, tau_axis=0.0, n_heldout=2, device=device,
)
v_grad = _build_v_grad(raw_grads, wrappers, cfg.v_grad_k, device)
_vk = "mean-diff (mean-mass)" if cfg.v_grad_k == 1 else f"top-{cfg.v_grad_k} SVD subspace, gate=max_i cos"
logger.info(f"routeV grad: built v_grad ({_vk}) for {len(v_grad)} modules")
if cfg.routeV_random_v_seed is not None:
v_grad = _haar_unit_dirs(v_grad, cfg.routeV_random_v_seed, device)
logger.info(f"routeV grad: OVERRODE v_grad with Haar-random dirs "
f"(seed={cfg.routeV_random_v_seed}) -- placebo directionality control")
# Calibrate after any Haar override so the control covers the full routing pipeline.
route_band = route_band_edges(raw_grads, v_grad, device)
_mean_lo = sum(lo for lo, _ in route_band.values()) / len(route_band)
_mean_hi = sum(hi for _, hi in route_band.values()) / len(route_band)
_mean_bw = _mean_hi - _mean_lo
n_inc_band = sum(1 for lo, hi in route_band.values() if hi - lo > 0)
logger.info(
f"routeV band: {len(route_band)} modules, mean lower(p75 clean cos)={_mean_lo:+.3f}, "
f"mean upper(p75 hack cos)={_mean_hi:+.3f}, mean width={_mean_bw:+.3f}; "
f"{n_inc_band}/{len(route_band)} modules have positive band width (included in the gate). "
f"SHOULD: width>0 (pairs separate) and most modules included; ELSE extraction/band off.")
# Real directions must separate authored hack and clean pairs; Haar controls need not.
if cfg.routeV_random_v_seed is None:
assert _mean_bw > 0, (
f"real v_grad gave non-positive mean band width {_mean_bw:+.3f}: "
"hack pairs do not separate from clean -> extraction broken")
logger.info(
"lora2r three-way output mask: per-rollout label from the width-pooled "
"band-normalized cosine across modules; clean->deployed-only, "
"hack->quarantine-only (deployed detached), mid->both (may permit absorption). "
"SHOULD: rout (hack share) tracks the step's rollout hack rate, not ~50%; "
"clipfrac on clean-gated rollouts < ~0.2 ELSE the retain-trick ratio "
"drift is binding (quarantine forward too large).")
logger.info(f"routeA pairs: {cfg.vhack_pairs_path} -> {len(MASK_PAIRS)} pairs")
model.eval() # deterministic forward, no dropout
v_act, pair_acts = extract_v_act(model, tok, wrappers, MASK_PAIRS, device,
tstat=cfg.vact_tstat)
model.train()
# Authored-pair separation in the live score. The dot GAP is > 0 by construction
# (v is proportional to the mean pair diff); the pair AUROC is not, so it is the
# extraction sanity signal.
sh = torch.einsum("pmr,mr->p", pair_acts["hack"], v_act)
sc = torch.einsum("pmr,mr->p", pair_acts["clean"], v_act)
pair_auroc = _auroc(torch.cat([sh, sc]).tolist(),
[True] * len(sh) + [False] * len(sc))
logger.info(
f"routeA v_act: {v_act.shape[0]} modules x r={v_act.shape[1]} "
f"(tstat={cfg.vact_tstat}); authored-pair dot gap={(sh.mean() - sc.mean()).item():+.3e}, "
f"pair AUROC={pair_auroc:.2f}. SHOULD: pair AUROC ~1.0 ELSE extraction broken.")
if cfg.routeA_random_v_seed is not None:
v_act = haar_unit_rows(tuple(v_act.shape), cfg.routeA_random_v_seed)
logger.info(f"routeA: OVERRODE v_act with Haar-random unit rows "
f"(seed={cfg.routeA_random_v_seed}) -- placebo directionality control")
act_buf = deque(maxlen=cfg.route_buffer)
logger.info(
f"routeA gate: per-rollout score = dot(pooled completion-token act, v_act), "
f"thresholds = two-threshold Otsu on the last <= {cfg.route_buffer} live scores "
f"(z-normalized, winsorized 1/99%), label-free; pinned absorb until "
f"{cfg.route_warmup} scores. keep (0,0) | absorb (1,0) | rout (1,1: deployed "
f"detached). No bimodality guard: on the cached emergence windows no shape "
f"statistic separates the hack mixture from hack-free scores (Otsu tail means "
f"sit ~2.4 sd apart even on a Gaussian), and a false rout only discards one "
f"update from deployment. "
f"SHOULD: auroc col >> 0.5 once hacks appear ELSE v_act is blind and routing "
f"is noise; rout tracks the hack share, not ~0 or ~1.")
# ── teacher pool ──
# Teacher pool: pre-generated rollouts on disk keyed by problem_id. Each step's
@@ -348,7 +294,7 @@ def main(cfg: Config) -> int:
G_t = 0
if cfg.teacher_pool_dir is not None:
# mix=0 is the NO-TEACHER ablation: pure on-policy GRPO (G_t=0) while the
# pool is still loaded for the partition + routeV v_grad extraction.
# pool is still loaded for the partition.
if not (0.0 <= cfg.mix_ratio < 1.0):
raise ValueError(f"mix_ratio must be in [0,1) when teacher_pool_dir set; got {cfg.mix_ratio}")
G_t = round(group * cfg.mix_ratio)
@@ -568,57 +514,31 @@ def main(cfg: Config) -> int:
save_ckpt([], path=run_dir / "ckpt_update0000.safetensors")
def _lora2r_gate_labels(c_grads: tuple, n_rollouts: int):
"""Three-way output-mask label per rollout from the gate-pass c-probe grads.
    def _routeA_gate(dots: Float[torch.Tensor, "G"]) -> tuple[torch.Tensor, torch.Tensor, float, float]:
        """Three-way output-mask label per rollout from the rolling score buffer.
Per module the per-rollout weight grad of the virtual diagonal (deployed
block [r]) has a band-normalized cosine position. We POOL across modules in
a single (num, den) fraction (T3 fix): a module with a wide band contributes
proportionally more than a noisy near-zero-width one, instead of every module
casting an equal-weight vote. One global label per rollout (matching SGTM's
example-level labels): pos<=0 clean (m=0,d=0); pos>=1 hack (m=1,d=1); else mid
(m=1,d=0, both blocks train). Returns (m, d, f3, w, pos, cosU): f3 in {0,.5,1} for
_zone_stats, w = mean per-rollout grad norm for energy weighting, pos = the raw
per-rollout pooled position (for the AUROC diagnostic), cosU = pooled cos of the
SUMMED-rollout c-grad (the update direction) to v_grad."""
num = torch.zeros(n_rollouts, device=device); den = 0.0
w = torch.zeros(n_rollouts, device=device); n_inc = 0
cosU_sum = 0.0
for (name, info), cg in zip(wrappers.items(), c_grads, strict=True):
lower, upper = route_band[name]
if upper - lower <= 0: # noisy module: pairs don't separate -> excluded
continue
r_blk = info["r"]
g_b: Float[torch.Tensor, "G r"] = cg.reshape(n_rollouts, -1, 2 * r_blk).sum(1)[:, :r_blk].float()
nrm = g_b.norm(dim=1)
# cos to each of the k routing dirs, then max: aligned with ANY known hack sub-mode.
cos_b = torch.einsum("g r, k r -> g k", g_b, v_grad[name]).max(dim=1).values / nrm.clamp_min(1e-12)
num += cos_b - lower; den += upper - lower
w += nrm; n_inc += 1
gs = g_b.sum(0) # summed-rollout deployed c-grad = this module's update direction
cosU_sum += (torch.einsum("r, k r -> k", gs, v_grad[name]).max() / gs.norm().clamp_min(1e-12)).item()
if n_inc == 0:
raise RuntimeError("no module has positive band width; pairs separate nowhere")
pos = num / den; w /= n_inc
# ── online-stats gate (#30): band from THIS batch's pooled positions, three zones ──
# The authored absolute band is mis-placed (live pos sits far below the synthetic-hack
# edge -- out/diag/pinning_calib.png) and a fixed quantile FORCES route_quantile out every
# step even when nothing separates. Calibrate the band from the CURRENT batch instead:
# refresh-proof by construction (these rollouts scored against the current v_grad), no
# window or flush to keep stale positions around. mean + k*std self-silences -- only the
# tail genuinely beyond the spread routes, so qmass tracks real separation. pos >
# mean+route_std_mid*std -> mid (both blocks train); pos >= mean+route_std_rout*std -> rout
# (hack, deployed detached); below -> keep (bulk). Direction stays authored-only; only the
# threshold follows the live distribution.
mu_pos, sd_pos = pos.mean().item(), pos.std().item()
t_lo = mu_pos + cfg.route_std_mid * sd_pos # mid onset
t_hi = mu_pos + cfg.route_std_rout * sd_pos # rout onset (>= mid)
logger.debug(f"gate pos: n_inc={n_inc}/{len(wrappers)} n_roll={n_rollouts} "
f"min={pos.min().item():+.2f} max={pos.max().item():+.2f} | "
f"mean={mu_pos:+.2f} std={sd_pos:.2f} t_lo={t_lo:+.2f} t_hi={t_hi:+.2f}")
m = (pos > t_lo).float() # mid + rout -> quarantine trains (keep = bulk below t_lo)
d = (pos >= t_hi).float() # top tail -> hack -> deployed detached
return m, d, 0.5 * m + 0.5 * d, w, pos, cosU_sum / n_inc
        The buffer holds pooled ACTS, so every gate call scores the whole window
        against the CURRENT v_act (refresh-proof; the only staleness left is act
        drift as the adapter trains, small over <= route_buffer rollouts). Scores
        are z-normalized by the buffer mean/std, then two-threshold Otsu (winsorized
        inside _otsu3) places (t_lo, t_hi): z <= t_lo keep (0,0); t_lo < z < t_hi
        absorb (1,0, both blocks train); z >= t_hi rout (1,1, deployed detached).
        Warmup: pinned absorb until the buffer holds route_warmup scores -- too few
        points to place thresholds, and absorb keeps both blocks learning."""
        if len(act_buf) < cfg.route_warmup:
            G_n = dots.shape[0]
            return (torch.ones(G_n, device=device), torch.zeros(G_n, device=device),
                    float("nan"), float("nan"))
        S = torch.einsum("nmr,mr->n", torch.stack(tuple(act_buf)), v_act)
        mu, sd = S.mean().item(), max(S.std().item(), 1e-12)
        t_lo, t_hi = _otsu3(((S - mu) / sd).numpy())
        z = (dots - mu) / sd
        m = (z > t_lo).float().to(device)   # absorb + rout -> quarantine trains
        d = (z >= t_hi).float().to(device)  # top zone -> rout -> deployed detached
        logger.debug(f"routeA gate: buf={len(act_buf)} mu={mu:+.3e} sd={sd:.3e} "
                     f"t_lo={t_lo:+.2f} t_hi={t_hi:+.2f} (z) | batch z "
                     f"min={z.min().item():+.2f} max={z.max().item():+.2f}")
        return m, d, t_lo, t_hi
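Review note: two properties of the gate are easy to check outside the trainer. Because the buffer stores pooled acts rather than scores, a new direction re-scores the whole window without a flush, and the `(m, d)` masks follow directly from the z-score and the two thresholds. The sketch below uses NumPy instead of torch; `rescore`, `zone_masks`, the buffer size and the random acts are illustrative assumptions, not code from this commit.

```python
from collections import deque

import numpy as np


def rescore(buf, v):
    """Score every buffered pooled act [M, r] against the current direction v [M, r]."""
    acts = np.stack(tuple(buf))              # [n, M, r]
    return np.einsum("nmr,mr->n", acts, v)   # one dot product per buffered rollout


def zone_masks(dots, mu, sd, t_lo, t_hi):
    """Map scores to (m, d): keep (0,0), absorb (1,0), rout (1,1)."""
    z = (np.asarray(dots, dtype=float) - mu) / max(sd, 1e-12)
    m = (z > t_lo).astype(float)    # absorb + rout: quarantine trains
    d = (z >= t_hi).astype(float)   # rout only: deployed block detached
    return m, d


rng = np.random.default_rng(0)
buf = deque(maxlen=4)                        # rolling window; oldest acts fall out
for _ in range(6):
    buf.append(rng.normal(size=(2, 3)))      # hypothetical pooled acts, M=2, r=3

v_old = np.ones((2, 3)) / np.sqrt(3.0)       # unit rows per module
v_new = -v_old                               # stand-in for a refreshed direction
s_old, s_new = rescore(buf, v_old), rescore(buf, v_new)

m, d = zone_masks([-2.0, 0.0, 3.0], mu=0.0, sd=1.0, t_lo=-1.0, t_hi=2.0)
```

Storing acts costs `M * r` floats per rollout instead of one, which is the price of not flushing the window at each refresh.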
# Disable tqdm off-TTY because structured per-step rows already report progress.
pbar = tqdm(range(steps), desc=f"train {cfg.arm} {cfg.preset_name}",
@@ -634,7 +554,7 @@ def main(cfg: Config) -> int:
logger.info(f"teacher-off curriculum: step {step} >= {cfg.teacher_off_step} "
f"-> teachers off (pure on-policy on the wider problem set from here)")
# mix_ratio>0 is the teacher on/off switch (mix=0 = no-teacher ablation: pool still
# loaded for v_grad extraction, but no demos injected). The COUNT is teacher_n_per_prompt.
# loaded for the partition, but no demos injected). The COUNT is teacher_n_per_prompt.
teachers_on = (not teacher_off) and cfg.mix_ratio > 0 \
and bool(covered_problems) and bool(teacher_pool or solve_pool)
t0 = time.time()
@@ -651,13 +571,15 @@ def main(cfg: Config) -> int:
n_zerovar = 0 # groups skipped for zero reward variance (all rollouts same reward).
agg_loss = 0.0
diag_tail = None
# routeV gate diagnostics (per-rollout three-way zone shares + retain-trick clipfrac).
step_flagged: list[float] = [] # hack share (mean d over rollouts) per prompt
step_clipfrac: list[float] = [] # PPO clip frac on clean-gated rollouts (retain-trick drift gauge)
step_zkeep: list[float] = []; step_zresid: list[float] = []; step_zrout: list[float] = [] # unit shares per zone
step_zkeepE: list[float] = []; step_zresidE: list[float] = []; step_zroutE: list[float] = [] # energy shares per zone
# AUROC diagnostic: per-rollout pooled pos + its hack-label, accumulated across prompts.
step_auroc_pos: list[float] = []; step_auroc_hack: list[bool] = []; step_cosU: list[float] = []
# routeA gate diagnostics (per-rollout three-way zone shares + clean-gated clipfrac).
step_clipfrac: list[float] = [] # PPO clip frac on keep-gated rollouts (ratio-drift gauge)
step_zkeep: list[float] = []; step_zresid: list[float] = []; step_zrout: list[float] = [] # unit shares per zone
step_tlo: list[float] = []; step_thi: list[float] = [] # Otsu thresholds (z units)
# AUROC diagnostic on the A>0 contrast: scores + hack-labels of positively-
# reinforced rollouts only (where the advantage alone is blind), students +
# cached teachers. Accumulated across prompts; measurement only, never routes.
step_auroc_score: list[float] = []; step_auroc_hack: list[bool] = []
step_cos: list[float] = [] # mean per-rollout cos(act, v_act) (dot-vs-cos diagnostic)
# Solve-mix discrimination: routed-share (mean d) over hack-teacher vs solve-teacher rollouts.
step_route_hackT: list[float] = []; step_route_solveT: list[float] = []
@@ -870,18 +792,33 @@ def main(cfg: Config) -> int:
# logπ_old: old-policy logprobs (frozen PPO-ratio target). logits_to_keep
# =L_c+1 runs lm_head only on completion-side hidden states; [:, :-1] drops
# the last position (predicts beyond `merged`, unused).
# For routeA this forward runs QUARANTINE-ABLATED, matching both the sampling
# policy (gen_students is deploy-mode) and the v_act extraction (quarantine-
# ablated), so the gate score and the vector live on the same observable path.
# The same forward carries the ActCapture hooks: the gate costs no extra pass.
completion_ids = merged[:, plen:]
L_c = completion_ids.shape[1]
mask = (completion_ids != pad_id).float()
_tfb = time.perf_counter()
with torch.no_grad():
logπ_old = per_token_logps(
model(merged, logits_to_keep=L_c + 1).logits[:, :-1],
completion_ids,
).detach()
if is_routeA:
with torch.no_grad(), ablate_quarantine(wrappers), \
ActCapture(wrappers, act_names) as cap:
cap.set_pool(plen, mask)
logπ_old = per_token_logps(
model(merged, logits_to_keep=L_c + 1).logits[:, :-1],
completion_ids,
).detach()
acts = cap.pooled().cpu() # [G, M, r] fp32
else:
with torch.no_grad():
logπ_old = per_token_logps(
model(merged, logits_to_keep=L_c + 1).logits[:, :-1],
completion_ids,
).detach()
# Pin block masks for the non-gated arms BEFORE the grad-carrying forward
# (arm semantics: train_config.py docstring): none -> (0,0), absorb -> (1,0).
# routeV leaves mask=None so the gate pass sees an unmasked forward.
# Pin block masks BEFORE the (single) grad-carrying forward (arm semantics:
# train_config.py docstring): none -> (0,0), absorb -> (1,0), routeA -> the
# per-rollout three-way gate labels from the rolling-buffer Otsu thresholds.
if is_vanilla:
_z = torch.zeros(merged.shape[0], device=device)
for info in wrappers.values():
@@ -891,13 +828,42 @@ def main(cfg: Config) -> int:
_z = torch.zeros(merged.shape[0], device=device)
for info in wrappers.values():
info["layer"]._lora2r_mask = (_o, _z)
elif is_routeA:
dots = torch.einsum("gmr,mr->g", acts, v_act) # [G]
# cos = dot / (||act|| ||v||); v rows are unit so ||v|| = sqrt(M).
coss = dots / (acts.flatten(1).norm(dim=1)
* math.sqrt(len(act_names))).clamp_min(1e-12)
step_cos.append(coss.mean().item())
act_buf.extend(acts.unbind(0))
m_vec, d_vec, _tl, _th = _routeA_gate(dots)
for info in wrappers.values():
info["layer"]._lora2r_mask = (m_vec, d_vec)
step_tlo.append(_tl); step_thi.append(_th)
step_zkeep.append((m_vec == 0).float().mean().item())
step_zresid.append(((m_vec == 1) & (d_vec == 0)).float().mean().item())
step_zrout.append((d_vec == 1).float().mean().item())
# AUROC diagnostic on the A>0 contrast: merged order is [students;
# teachers], the same order hack_flags was built in, so dots aligns with hack_flags.
pos_a = (A > 0).cpu().tolist()
step_auroc_score.extend(s for s, p in zip(dots.tolist(), pos_a) if p)
step_auroc_hack.extend(bool(h) for h, p in zip(hack_flags, pos_a) if p)
# Solve-mix discrimination: teachers are the LAST G_t rows of merged; split
# their routed-share (mean d) by source. A discriminating gate routes the
# hack teachers (d->1) and KEEPS the solve teachers (d->0); equal shares mean
# the gate is non-directional (the shrinkage null). Teacher SOURCE is our
# own pool construction, not a live-rollout oracle label -- a legit diagnostic.
if teacher_is_solve:
is_solve_t = torch.tensor(teacher_is_solve, device=d_vec.device, dtype=torch.bool)
d_teach = d_vec[-len(teacher_is_solve):]
if (~is_solve_t).any():
step_route_hackT.append(d_teach[~is_solve_t].mean().item())
if is_solve_t.any():
step_route_solveT.append(d_teach[is_solve_t].mean().item())
logπ = per_token_logps(
model(merged, logits_to_keep=L_c + 1).logits[:, :-1],
completion_ids,
)
mask = (merged[:, plen:] != pad_id).float()
# Per-rollout mean per-token logπ_old (student's logp on its own tokens).
# Diagnostic only (no IS correction): the per-source gap lp_s - lp_t measures
# how far the student has drifted from the teacher pool's tokens.
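Reviewer sketch of the two diagnostics computed in the routeA branch of this hunk: the dot-versus-cosine score and the AUROC restricted to positively-reinforced rollouts. It uses nested lists instead of tensors, and the `auroc` helper here is a hypothetical rank-based stand-in, since the body of the repository's `_auroc` is not shown in the diff.

```python
import math


def dot_and_cos(act, v):
    """act, v: [M][r] nested lists; rows of v are assumed unit-norm."""
    dot = sum(a * b for row_a, row_v in zip(act, v) for a, b in zip(row_a, row_v))
    act_norm = math.sqrt(sum(a * a for row_a in act for a in row_a))
    v_norm = math.sqrt(len(v))  # M unit rows -> Frobenius norm sqrt(M)
    return dot, dot / max(act_norm * v_norm, 1e-12)


def auroc(scores, labels):
    """Mann-Whitney AUROC, ties count half; NaN when a class is empty."""
    pos = [s for s, l in zip(scores, labels) if l]
    neg = [s for s, l in zip(scores, labels) if not l]
    if not pos or not neg:
        return float("nan")
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def auroc_positive_advantage(scores, hack_flags, advantages):
    """AUROC of score vs hack label among rollouts with advantage > 0 only."""
    keep = [a > 0 for a in advantages]
    return auroc(
        [s for s, k in zip(scores, keep) if k],
        [bool(h) for h, k in zip(hack_flags, keep) if k],
    )
```

Filtering on `A > 0` before scoring means the diagnostic asks whether the gate separates hacks from non-hacks among rollouts the reward already favours, which is the case the advantage alone cannot resolve.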
@@ -914,56 +880,20 @@ def main(cfg: Config) -> int:
ptl = (Lp_ * mask).sum(1) / mask.sum(1).clamp_min(1)
return ptl.sum() / (group * prompts_per_step)
# Three-way output masking; gradients accumulate on A/B.
# Gradient-space labels exist only AFTER a backward (labels: before forward;
# activations: before backward; grads: after), so routeV pays a second masked
# forward+backward. none/absorb were pinned before the logπ forward and need
# only this one pass.
# One masked forward+backward for EVERY arm; rollouts route to BLOCKS via
# the output masks pinned above (nothing is subtracted from any gradient
# vector; v_act is a classifier only). Gradients accumulate on A/B.
loss = _grpo_loss(Lp)
if is_routeV:
# PASS 1 (gate): grads w.r.t. the c-probes ONLY. autograd.grad leaves
# A.grad/B.grad untouched, so nothing to zero between passes.
gates = [info["layer"]._lora2r_gate for info in wrappers.values()]
c_grads = torch.autograd.grad(loss, gates)
m_vec, d_vec, f3, w3, pos_vec, cosU = _lora2r_gate_labels(c_grads, merged.shape[0])
step_flagged.append(d_vec.mean().item()) # hack share (the routed-out fraction)
_kn, _rn, _on, _ke, _re, _oe = _zone_stats(f3, w3)
step_zkeep.append(_kn); step_zresid.append(_rn); step_zrout.append(_on)
step_zkeepE.append(_ke); step_zresidE.append(_re); step_zroutE.append(_oe)
# AUROC diagnostic: pos as a hack-detector vs the hack-label (student
# exploited + teacher cached). merged order is [students; teachers], the
# same order hack_flags was built in, so pos_vec aligns with hack_flags.
step_auroc_pos.extend(pos_vec.detach().cpu().tolist())
step_auroc_hack.extend(bool(h) for h in hack_flags)
step_cosU.append(cosU)
# Solve-mix discrimination: teachers are the LAST G_t rows of merged; split
# their routed-share (mean d) by source. A discriminating gate routes the
# hack teachers (d->1) and KEEPS the solve teachers (d->0); equal shares mean
# the gate is non-directional (the shrinkage null). Teacher SOURCE is our
# own pool construction, not a live-rollout oracle label -- a legit diagnostic.
if teacher_is_solve:
is_solve_t = torch.tensor(teacher_is_solve, device=d_vec.device, dtype=torch.bool)
d_teach = d_vec[-len(teacher_is_solve):]
if (~is_solve_t).any():
step_route_hackT.append(d_teach[~is_solve_t].mean().item())
if is_solve_t.any():
step_route_solveT.append(d_teach[is_solve_t].mean().item())
# PASS 2 (masked): rollouts route to BLOCKS via masked sums; nothing is
# subtracted from any gradient vector (v_grad = classifier only).
for info in wrappers.values():
info["layer"]._lora2r_mask = (m_vec, d_vec)
logπ2 = per_token_logps(
model(merged, logits_to_keep=L_c + 1).logits[:, :-1], completion_ids)
ρ2 = torch.exp(logπ2 - logπ_old)
loss = _grpo_loss(-torch.min(ρ2 * A_tok,
torch.clamp(ρ2, 1 - cfg.clip, 1 + cfg.clip) * A_tok))
# Retain-trick wrinkle: clean rollouts were SAMPLED quarantine-on but TRAIN
# quarantine-off; the PPO ratio absorbs the gap, clip bounds it.
clean = m_vec == 0
if clean.any():
clipped = ((ρ2.detach() - 1).abs() > cfg.clip).float()
if is_routeA:
# Keep-gated rollouts train quarantine-off, the exact state generation
# and logπ_old used, so their ratio sits ~1. Absorb/rout rollouts see
# the quarantine delta in the forward only -> ratio drift, bounded by
# the clip; clipfrac on those rollouts is the drift gauge.
qon = m_vec == 1
if qon.any():
clipped = ((ρ.detach() - 1).abs() > cfg.clip).float()
step_clipfrac.append(
((clipped * mask)[clean].sum() / mask[clean].sum().clamp_min(1)).item())
((clipped * mask)[qon].sum() / mask[qon].sum().clamp_min(1)).item())
loss.backward() # A/B grads accumulate across prompts (opt.zero_grad clears per step)
for info in wrappers.values():
info["layer"]._lora2r_mask = None
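Reviewer sketch of the clip-fraction gauge computed just before `loss.backward()`: the share of unpadded tokens, over the selected rollouts, whose PPO ratio falls outside the clip band. Plain lists stand in for tensors. The function name and the default `clip=0.2` are assumptions (the diff only references `cfg.clip`).

```python
def masked_clipfrac(ratio, mask, rows, clip=0.2):
    """ratio, mask: [G][T] nested lists; rows: [G] booleans selecting rollouts.

    Returns the fraction of unmasked tokens in the selected rollouts with
    |ratio - 1| > clip. The denominator is floored at 1 to avoid 0/0.
    """
    clipped = 0.0
    tokens = 0.0
    for r_row, m_row, selected in zip(ratio, mask, rows):
        if not selected:
            continue
        for r, m in zip(r_row, m_row):
            clipped += m * (abs(r - 1.0) > clip)
            tokens += m
    return clipped / max(tokens, 1.0)
```

Passing `rows = [m == 1 for m in m_vec]` reproduces the quarantine-on selection used in the routeA branch, where the ratio drifts because the old policy was evaluated quarantine-ablated.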
@@ -990,38 +920,26 @@ def main(cfg: Config) -> int:
opt.step()
sched.step()
# ── v_grad refresh ──
# ── v_act refresh ──
# Re-extract the routing direction against the CURRENT model so it tracks where
# hacks separate now, not at step 0. Without this the frozen direction goes stale.
# Same MASK_PAIRS (the authored pairs, no oracle); quarantine ablated so the hack
# signal flows back through the observable path, matching the build-time extract.
# signal is read on the deployed observable path, matching the build-time extract
# and the gate forward. Forward-only, so the refresh is cheap. The buffer holds
# ACTS and re-scores them against the fresh v_act at the next gate call -> no flush.
refr = "-"
do_refresh = cfg.vhack_refresh_every > 0 and (step + 1) % cfg.vhack_refresh_every == 0
if do_refresh and is_routeV and cfg.routeV_random_v_seed is not None:
do_refresh = False # keep the one fixed Haar draw; re-extracting would replace it
if do_refresh and is_routeV:
do_refresh = (is_routeA and cfg.vhack_refresh_every > 0
and (step + 1) % cfg.vhack_refresh_every == 0
and cfg.routeA_random_v_seed is None) # placebo keeps its one Haar draw
if do_refresh:
_was_training = model.training
model.eval()
opt.zero_grad(set_to_none=True)
logger.disable("vgrout.extract_vhack_grad")
logger.disable("__main__")
try:
with ablate_quarantine(wrappers):
from .extract_vhack_grad import extract_v_hack
_, _, raw_grads, _ = extract_v_hack(
model, tok, wrappers, MASK_PAIRS,
top_k=1, tau_axis=0.0, n_heldout=2, device=device,
)
# update in place so the gate closure sees the fresh dirs (same k as init).
v_grad.update(_build_v_grad(raw_grads, wrappers, cfg.v_grad_k, device))
route_band = route_band_edges(raw_grads, v_grad, device) # rebuild band on fresh v_grad
finally:
logger.enable("vgrout.extract_vhack_grad")
logger.enable("__main__")
opt.zero_grad(set_to_none=True) # extract leaves .grad populated
with ablate_quarantine(wrappers):
v_act, _ = extract_v_act(model, tok, wrappers, MASK_PAIRS, device,
tstat=cfg.vact_tstat)
if _was_training:
model.train()
refr = "rfr" # gate calibrates from the live batch each step -> no window to flush
refr = "rfr"
# ── periodic held-out eval (deploy = quarantine ablated) ──
hack_deployed = solve_deployed = float("nan")
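Reviewer sketch of why the v_act refresh in this hunk needs no buffer flush: the rolling buffer keeps pooled activations rather than scores, so every gate call re-scores the stored activations against whatever direction is current. The class name and the flattened-vector layout are hypothetical; the diff only shows `act_buf.extend(...)` and `len(act_buf)`.

```python
from collections import deque


class ActBuffer:
    """Rolling buffer of pooled activations, scored lazily against a direction."""

    def __init__(self, maxlen=256):
        self.acts = deque(maxlen=maxlen)  # oldest entries fall off automatically

    def extend(self, acts):
        self.acts.extend(acts)

    def scores(self, v):
        # Re-score every stored activation against the direction passed in now.
        return [sum(a * b for a, b in zip(act, v)) for act in self.acts]
```

Because scores are derived on demand, a refreshed direction changes the buffer statistics at the next gate call without discarding any stored rollouts.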
@@ -1114,13 +1032,14 @@ def main(cfg: Config) -> int:
f"step {step} TIMING gen={t_gen:.0f}s fwd_bwd={t_fb:.0f}s "
f"reward={t_rew:.0f}s other={_tstep - t_gen - t_fb - t_rew:.0f}s total={_tstep:.0f}s")
if step_clipfrac:
logger.debug(f"routeV clean-gated clipfrac={sum(step_clipfrac)/len(step_clipfrac):.3f} "
f"(SHOULD: <~0.2; higher = retain-trick ratio drift binding)")
logger.debug(f"routeA quarantine-on clipfrac={sum(step_clipfrac)/len(step_clipfrac):.3f} "
f"(SHOULD: <~0.2; higher = quarantine forward delta drifting far "
f"from the ablated old policy)")
if step_route_hackT or step_route_solveT:
_rh = sum(step_route_hackT) / len(step_route_hackT) if step_route_hackT else float("nan")
_rs = sum(step_route_solveT) / len(step_route_solveT) if step_route_solveT else float("nan")
route_hackT_run.append(_rh); route_solveT_run.append(_rs)
logger.debug(f"routeV solve-mix discrimination: hack-teacher routed={_rh:.2f} vs "
logger.debug(f"routeA solve-mix discrimination: hack-teacher routed={_rh:.2f} vs "
f"solve-teacher routed={_rs:.2f} (SHOULD: hack >> solve -> gate "
f"discriminates correct-solution from reward-hacking updates; ~equal -> non-directional/shrinkage)")
if diag_tail is not None:
@@ -1148,15 +1067,16 @@ def main(cfg: Config) -> int:
"lp_t": lp_t_mean if n_t else None,
"loss": agg_loss,
"gn": gn,
"auroc": _auroc(step_auroc_pos, step_auroc_hack),
"cosU": (sum(step_cosU) / len(step_cosU)) if step_cosU else float("nan"),
# auroc is the A>0 contrast (hack vs non-hack among positively-reinforced
# rollouts) -- the contrast where the reward alone is blind.
"auroc": _auroc(step_auroc_score, step_auroc_hack),
"cos": (sum(step_cos) / len(step_cos)) if step_cos else float("nan"),
"qmass": q_egy,
"keep": (sum(step_zkeep) / len(step_zkeep)) if step_zkeep else float("nan"),
"resid": (sum(step_zresid) / len(step_zresid)) if step_zresid else float("nan"),
"rout": (sum(step_zrout) / len(step_zrout)) if step_zrout else float("nan"),
"keepE": (sum(step_zkeepE) / len(step_zkeepE)) if step_zkeepE else float("nan"),
"residE": (sum(step_zresidE) / len(step_zresidE)) if step_zresidE else float("nan"),
"routE": (sum(step_zroutE) / len(step_zroutE)) if step_zroutE else float("nan"),
"tlo": (sum(step_tlo) / len(step_tlo)) if step_tlo else float("nan"),
"thi": (sum(step_thi) / len(step_thi)) if step_thi else float("nan"),
"lr": sched.get_last_lr()[0],
"refr": refr,
# Deploy-eval (quarantine ablated); NaN except on eval steps.
@@ -1236,20 +1156,17 @@ def main(cfg: Config) -> int:
solve_rate_s = gt_s_total / max(1, n_s_total)
hack_rate_t = hack_t_total / max(1, n_t_total)
# routeV/absorb must move the quarantine; none must leave it exactly zero. The
# quarantine LEARNED delta = (A[r:], B[:,r:]) minus the frozen init.
# routeA/absorb must move the quarantine; none must leave it exactly zero. The
# quarantine LEARNED delta = (A[r:], B[:,r:]) minus the frozen init. The routeA
# warmup pins absorb, so even a placebo run trains the quarantine.
dsh_norm = float(sum(
(info["A"].data[info["r"]:] - info["A0"][info["r"]:]).float().pow(2).sum().item()
+ (info["B"].data[:, info["r"]:] - info["B0"][:, info["r"]:]).float().pow(2).sum().item()
for info in wrappers.values()) ** 0.5)
logger.info(f"||quarantine learned delta|| = {dsh_norm:.4f} "
f"(SHOULD: >0 for routeV/absorb, ==0 for none; ELSE routing broke)")
if has_quarantine and cfg.routeV_random_v_seed is None:
f"(SHOULD: >0 for routeA/absorb, ==0 for none; ELSE routing broke)")
if has_quarantine:
assert dsh_norm > 0.0, f"{cfg.intervention}: quarantine never moved -> nothing trained it"
elif cfg.routeV_random_v_seed is not None and dsh_norm == 0.0:
# A Haar control may validly route nothing because no rollout clears its band.
logger.warning("routeV Haar control: ||quarantine delta||==0 -> the random direction routed "
"NOTHING. This is a real result (favours: alignment needed), not a failure.")
# Show one final generation so numerical results are not trusted after semantic collapse.
if last_gen_sample is not None:
+29 -32
@@ -5,7 +5,7 @@ masking; see src/vgrout/lora2r.py) and three arms:
none gate pinned clean (0,0): quarantine never trains -- the capacity- and
structure-matched vanilla control.
routeV per-rollout three-way gate from the c-probe gradient vs v_grad.
routeA per-rollout three-way gate from the pooled bottleneck activation vs v_act.
absorb gate pinned mid (1,0): both blocks train on everything, no gate --
tests ungated both-block training.
"""
@@ -20,7 +20,7 @@ from .rewards import EnvMode
@dataclass(kw_only=True)
class Config:
intervention: Literal["none", "routeV", "absorb"] = "routeV"
intervention: Literal["none", "routeA", "absorb"] = "routeA"
lora_r: int = 32
lora_init_seed: int = 0
@@ -46,30 +46,27 @@ class Config:
unbiased: bool = True
vhack_refresh_every: int = 5
# The 8 original behavior_* pairs only: per-pairset diag (out/diag/pinning_pairset_auroc.png)
# ranks this subset's v_grad best at separating live hacks (AUROC 0.69, d=+0.85), well above
# the full all-in-one. The `@behavior` TAG would re-add the anti-aligned opportunity-aware
# pairs (d=-0.03) and dilute, so select by heading prefix, not tag. Wave-2 arms (untested):
# `/behavior2` = 15 new mechanisms, `/behavior` = 23-pair union.
# The 8 original behavior_* pairs only: the best or tied vector on all three diag
# windows (RESEARCH_JOURNAL 2026-06-11 (d)). The `@behavior` TAG would re-add the
# anti-aligned opportunity-aware pairs and dilute, so select by heading prefix, not
# tag. Wave-2 arms (untested): `/behavior2` = 15 new mechanisms, `/behavior` = union.
vhack_pairs_path: Path = Path("data/pairs/hack_pairs.md#all-in-one/behavior_")
# Routing directions per module. k=1 (headline): the mean(hack-clean) "mean-mass"
# axis. k>1: top-k oriented SVD dirs of the paired diff; gate scores max_i cos(g,v_i)
# (alignment to ANY known hack sub-mode) -- catches multi-modal hack signal one mean
# washes out. k=1 stays mean-diff (not SVD-top-1) so "mean-mass vs top-k" is a clean A/B.
v_grad_k: int = 1
# Online-stats gate (#30): route by the live MEAN + k*STD of the pooled cosine-to-v_grad,
# not a fixed quantile. The authored absolute band is mis-placed (live pos sits far below
# the synthetic-hack edge; even synthetic SOLVE is more hack-aligned than on-policy hack --
# see out/diag/pinning_calib.png), and a fixed quantile FORCES route_quantile of the batch
# out every step even when nothing separates. mean+k*std self-silences: it only routes the
# tail that genuinely exceeds the spread. pos > mean + route_std_mid*std -> mid (both blocks train);
# pos >= mean + route_std_rout*std -> rout (hack, deployed detached); below -> keep (bulk).
# The band is calibrated from the CURRENT batch each step (no window, no flush): refresh-proof
# by construction. Direction stays authored-only; only the threshold follows the live dist.
route_std_mid: float = 2.0
route_std_rout: float = 3.0
# t-stat extraction (v = mean/SE per coordinate, clamp |t|<=3): null at the current
# 8 pairs (journal (e)); revisit when the authored-pair set grows.
vact_tstat: bool = False
# routeA gate thresholds come from a rolling buffer of the last route_buffer live
# scores: z-normalize by buffer mean/std, two-threshold Otsu (winsorized 1/99% --
# journal (d): without it one outlier buys a whole zone) -> keep | absorb | rout.
# Until the buffer holds route_warmup scores the gate pins absorb (both blocks
# train): too few points to place thresholds. The buffer stores pooled ACTS and
# re-scores them against the current v_act, so a refresh needs no flush. No
# bimodality guard: no shape statistic separates the hack mixture from hack-free
# scores on the cached windows (Otsu tail means sit ~2.4 sd apart even on a
# Gaussian), and a false rout only discards one update from deployment.
route_buffer: int = 256
route_warmup: int = 128
# Haar-random direction control (placebo): same routing machinery, no pair signal.
routeV_random_v_seed: int | None = None
routeA_random_v_seed: int | None = None
rollout_ablate_frac: float = 0.0
env_mode: EnvMode = "run_tests"
@@ -112,8 +109,9 @@ class Config:
@property
def arm(self) -> str:
# _lora2r suffix kept so these runs never conflate with the retired
# PiSSA-substrate runs of the same intervention (rename-on-logic-change).
return {"none": "vanilla_lora2r", "routeV": "routingV_lora2r",
# PiSSA-substrate runs of the same intervention (rename-on-logic-change;
# routeA likewise never conflates with the retired grad-gate routeV runs).
return {"none": "vanilla_lora2r", "routeA": "routeA_lora2r",
"absorb": "absorb_lora2r"}[self.intervention]
@@ -126,12 +124,11 @@ class SmokeConfig(Config):
max_new: int = 32
n_problems: int = 100
prompts_per_step: int = 1
# Random tiny data never separates, so the self-silencing band (mean+2/3*std) would route
# nothing and the quarantine would never train -> the routing-pathway smoke assert fails.
# Force routing by lowering the band so the smoke exercises mid+rout (correctness, not the
# real threshold). mid below mean -> most train quarantine; rout at mean -> top ~half detach.
route_std_mid: float = -1.0
route_std_rout: float = 0.0
# Smoke produces 4 scores/step over 30 steps (120 total, below the 128 warmup), so the
# real 256/128 buffer would keep the gate in warmup for the whole run. Shrink so the
# smoke exercises warmup AND the Otsu gate (keep/absorb/rout + deployed detach) within a few steps.
route_buffer: int = 32
route_warmup: int = 8
@dataclass(kw_only=True)