mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 17:00:59 +08:00
feat(#41): routeA activation gate replaces routeV grad gate

Gate now scores each rollout by dot(pooled bottleneck act, v_act) captured on
the no-grad logpi_old forward (quarantine-ablated, matching the sampling
policy); masks are pinned BEFORE the single grad-carrying forward, so the
grad-gate's pass-1 backward is gone. Thresholds: rolling 256-act buffer,
z-normalized, two-threshold Otsu (winsorized 1/99); warmup pins absorb until
128 scores. Buffer stores pooled acts and re-scores against the current v_act,
so the forward-only refresh (every 5 steps) needs no flush. No bimodality
guard: calibration showed Otsu tail separation ~2.4-2.8 buffer-sd on every
condition including pure Gaussians, so no shape statistic discriminates.

Deleted with the arm wiring (rename-on-logic-change: routeA never conflates
with routeV runs): extract_vhack_grad.py, _build_v_grad, route_band_edges,
_pair_cos, the pass-1 autograd.grad block, grad_probe training wiring,
v_grad_k/route_std_*/routeV_random_v_seed config, smoke-topk recipe.
c-probe stays in lora2r.py for scripts/diag_pinning.py only.

verify_science_invariants: all-in-one count 27 -> 42 (stale since c33b810
added the wave-2 behavior2 pairs) + assert the 8-pair routeA training subset.

Smoke: routeA/vanilla/absorb/solvemix all pass (gate exercises warmup, Otsu
zones, refresh, deploy ablation) -- /tmp/claude-1000/smoke_routeA.log.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
@@ -36,28 +36,30 @@ outputs (`m` = quarantine on/off, `d` = deployed detach):

To get the hack direction we pair examples by hand: for each problem, one
correct completion and one completion that exploits the evaluation procedure.
For each pair we compute the *exact GRPO gradient* that would result if the hack
rollout had advantage +1 and the clean rollout had advantage -1
(`-grad logp(hack) + grad logp(clean)`), read in the per-module c-probe space (a
virtual diagonal scale between `A` and `B`). The mean hack-minus-clean direction,
normalized per module, is `v_grad`. (Mechanically identical to a twin-NLL
extraction, since GRPO with adv=+/-1 reduces to the NLL difference; the GRPO
framing is the one we mean.) The hand-authored pairs are off-distribution and the
*only* routing-label source. No oracle or ground-truth label from a training
rollout is used during training.
For each pair we run a forward pass and read the bottleneck activation `A@x`
(the rank-2r input projection of each wrapped Linear), masked-mean-pooled over
completion tokens. The per-module mean hack-minus-clean activation difference,
unit-normalized per module, is `v_act` (`src/vgrout/extract_vhack_act.py`).
Extraction is forward-only: no backward pass, no labels. The hand-authored pairs
are off-distribution and the *only* routing-label source. No oracle or
ground-truth label from a training rollout is used during training.

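The forward-only extraction can be sketched for a single module as follows. This is a minimal illustration, not the code in `src/vgrout/extract_vhack_act.py`: the function names `masked_mean_pool` and `extract_v_act` are hypothetical, NumPy stands in for the torch tensors the repo uses, and the bottleneck activations are assumed to be already captured as a `[tokens, rank]` array.

```python
import numpy as np


def masked_mean_pool(acts: np.ndarray, completion_mask: np.ndarray) -> np.ndarray:
    """Mean of per-token bottleneck activations over completion tokens only.

    acts: [tokens, rank], one module's A@x (assumed already captured).
    completion_mask: [tokens], 1 on completion tokens and 0 on prompt tokens.
    """
    mask = completion_mask.astype(acts.dtype)
    return (acts * mask[:, None]).sum(axis=0) / max(mask.sum(), 1.0)


def extract_v_act(hack_pooled: np.ndarray, clean_pooled: np.ndarray) -> np.ndarray:
    """Unit-normalized mean hack-minus-clean difference for one module.

    hack_pooled, clean_pooled: [n_pairs, rank]; row i of each comes from pair i.
    """
    diff = (hack_pooled - clean_pooled).mean(axis=0)
    norm = np.linalg.norm(diff)
    return diff / norm if norm > 0 else diff  # leave an all-zero direction as zeros
```

Because the difference is taken pair by pair before averaging, anything shared by the two completions of a pair (the prompt in particular) cancels out of the direction.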
At training time routeV runs **two passes** per prompt. Pass 1 is an unmasked
forward+backward whose gradient w.r.t. the c-probes gives, per rollout, a
band-normalized cosine to `v_grad` pooled across modules -> a `{clean, mid, hack}`
label (band edges = authored-pair p75 cosines). Pass 2 sets the block masks from
those labels and does the real masked forward+backward; the optimizer steps on
what survives. A rollout with high alignment to the reward-hacking direction updates
the quarantine block while its deployed branch is detached. We re-extract `v_grad`
every N steps so it tracks the current separation rather than only the initial model.
At training time routeA scores each rollout on the no-grad `logp_old` forward it
already needs: an activation-capture hook pools the same bottleneck activations
over completion tokens, and the score is the pooled dot product with `v_act`.
Thresholds come from a rolling buffer of recent scores, z-normalized and split by
two-threshold Otsu into `{keep, absorb, rout}`; until the buffer fills the gate
pins absorb. The block masks are set from those labels *before* the single
masked forward+backward, so there is no second gradient pass. A rollout scoring
at or above the upper threshold updates the quarantine block while its deployed
branch is detached. We re-extract `v_act` every N steps (forward-only,
quarantine-ablated) so it tracks the current model; the buffer stores pooled
activations and re-scores them against the current `v_act`, so a refresh needs
no flush.

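The gate described above can be sketched as a small rolling buffer. This is an illustration under stated assumptions rather than the training code: `RollingActGate` and `otsu3` are hypothetical names, the activations are reduced to one pooled vector per rollout (the real gate pools across modules), and a score strictly below the lower threshold is assumed to mean `keep`. The defaults of 256 buffered activations and 128 warmup scores follow the commit message.

```python
from collections import deque

import numpy as np


def otsu3(x: np.ndarray) -> tuple[float, float]:
    """Two cuts maximizing 3-class between-class variance, after 1/99% winsorizing."""
    s = np.sort(np.clip(x, *np.quantile(x, [0.01, 0.99])))
    n = len(s)
    c = np.concatenate([[0.0], np.cumsum(s)])  # c[k] = sum of the k smallest scores
    best, best_ij = -np.inf, (1, 2)
    for i in range(1, n - 1):
        for j in range(i + 1, n):
            obj = c[i] ** 2 / i + (c[j] - c[i]) ** 2 / (j - i) + (c[n] - c[j]) ** 2 / (n - j)
            if obj > best:
                best, best_ij = obj, (i, j)
    i, j = best_ij
    return float((s[i - 1] + s[i]) / 2), float((s[j - 1] + s[j]) / 2)


class RollingActGate:
    """Label each rollout keep/absorb/rout from its pooled activation (hypothetical class)."""

    def __init__(self, capacity: int = 256, warmup: int = 128):
        self.buffer: deque = deque(maxlen=capacity)  # stores pooled acts, not scores
        self.warmup = warmup

    def label(self, pooled_act: np.ndarray, v_act: np.ndarray) -> str:
        self.buffer.append(np.asarray(pooled_act, dtype=float))
        if len(self.buffer) < self.warmup:
            return "absorb"  # warmup: pin absorb until enough scores exist
        # Re-score the whole buffer against the *current* v_act, so a refreshed
        # direction needs no flush.
        scores = np.stack(self.buffer) @ v_act
        z = (scores - scores.mean()) / (scores.std() + 1e-8)
        lo, hi = otsu3(z)
        if z[-1] >= hi:
            return "rout"
        if z[-1] < lo:
            return "keep"
        return "absorb"
```

Storing activations rather than scores costs `capacity x rank` floats per gate, and in exchange every threshold is computed in the coordinate system of the direction currently in use.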
Whether the *direction* (not just the *act* of routing) drives suppression is the
open question -- the placebo control (Haar-random `v_grad`, same routing
machinery) must NOT match real `v_grad`. We watch `qmass` (the share of update
open question -- the placebo control (Haar-random `v_act`, same routing
machinery) must NOT match real `v_act`. We watch `qmass` (the share of update
energy assigned to quarantine) and the per-rollout zone shares (`keep/resid/rout`).

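For reference, a Haar-random (uniformly distributed) unit direction can be drawn by normalizing an isotropic Gaussian vector. The sketch below shows only that sampling step; the helper name `haar_random_direction` is hypothetical, and how the repo seeds and shapes the placebo per module is not shown in this diff.

```python
import numpy as np


def haar_random_direction(dim: int, seed: int) -> np.ndarray:
    """Uniform random unit vector in R^dim: normalize an isotropic Gaussian draw."""
    rng = np.random.default_rng(seed)
    g = rng.standard_normal(dim)
    return g / np.linalg.norm(g)


# Example: a placebo direction of the same dimensionality as a real one.
placebo = haar_random_direction(dim=16, seed=157)
```

A placebo drawn this way has the same norm as a real unit-normalized direction, so any difference between the two arms comes from where the vector points rather than from its scale.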
## What we compare
@@ -74,8 +76,9 @@ Three arms, identical model/adapter/teacher pool, differing only in the gate
- **none** -- gate pinned clean `(0,0)`: the quarantine never trains. The
  capacity- and structure-matched vanilla control (same adapter, no shrinkage
  confound). The emergence reference.
- **routeV** -- the method: per-rollout three-way gate from the c-probe gradient
  vs `v_grad`. `--routeV-random-v-seed` swaps in a Haar-random direction (placebo).
- **routeA** -- the method: per-rollout three-way gate from the pooled bottleneck
  activation vs `v_act`. `--routeA-random-v-seed` swaps in a Haar-random direction
  (placebo).
- **absorb** -- gate pinned mid `(1,0)`: both blocks train on every rollout. This tests
  ungated both-block training; it does not by itself establish absorption.

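The pinned gates in the list above can be read as `(m, d)` mask pairs, with `m` the quarantine on/off bit and `d` the deployed-detach bit. The sketch below spells that reading out. The `(0,0)` and `(1,0)` entries come from the text; the `(1,1)` entry for a routed rollout is an assumption inferred from "updates the quarantine block while its deployed branch is detached", and the names `ZONE_TO_MASK` and `describe` are hypothetical.

```python
# (m, d): m = quarantine on/off, d = deployed branch detached.
ZONE_TO_MASK = {
    "keep": (0, 0),    # pinned clean: quarantine never trains
    "absorb": (1, 0),  # pinned mid: both blocks train
    "rout": (1, 1),    # assumed: quarantine trains, deployed branch detached
}


def describe(zone: str) -> dict:
    """Which blocks receive gradient for a rollout labeled with this zone."""
    m, d = ZONE_TO_MASK[zone]
    return {"quarantine_trains": bool(m), "deployed_trains": not d}
```

Under this reading the `none` arm is the gate held at `keep` for every rollout, and the `absorb` arm is the gate held at `absorb`.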
@@ -88,10 +91,10 @@ ablation does not change the model.

```bash
uv sync
just smoke # tiny-random model, routeV pathway + all verify gates, ~1-2 min
just smoke-all # vanilla + routeV + absorb back to back
just smoke # tiny-random model, routeA pathway + all verify gates, ~1-2 min
just smoke-all # vanilla + routeA + absorb back to back
just download-model # warm Qwen3-4B cache
just queue-decision # queue the 4-arm decision run (routeV real / placebo / vanilla / absorb)
just queue-decision # queue the 4-arm decision run (routeA real / placebo / vanilla / absorb)
```

See [RESEARCH_JOURNAL.md](RESEARCH_JOURNAL.md) for session-by-session findings,
@@ -1,7 +1,7 @@
set shell := ["bash", "-cu"]

# vGROUT: rank-2r LoRA gradient routing vs reward-hacking. One adapter (lora2r),
# three arms (intervention none|routeV|absorb). See AGENTS.md / README.md.
# three arms (intervention none|routeA|absorb). See AGENTS.md / README.md.
MODEL := "Qwen/Qwen3-4B"
TINY_MODEL := "llamafactory/tiny-random-qwen3" # qwen3 arch, ~6M params, smoke only
TRAIN := "uv run python -m vgrout.train" # real LeetCode GRPO entry point
@@ -24,8 +24,10 @@ default:
# Real pipeline on tiny inputs; verify_*.py assert invariants (no tests/ dir).
# ─────────────────────────────────────────────────────────────────────────────

# Default smoke = routeV (full pipeline: extract v_grad -> two-pass gate -> deploy
# ablation). Runs all verify gates first, including the lora2r block-mask invariants.
# Default smoke = routeA (full pipeline: extract v_act -> act gate on the logpi_old
# forward -> Otsu pinning -> deploy ablation). Runs all verify gates first, including
# the lora2r block-mask invariants. (scripts/verify_v_act.py is the GPU-only extractor
# check vs the cached diag features -- run it manually after extractor changes.)
smoke *ARGS:
    uv run python scripts/verify_rewards.py # grader: 3 env_modes x clean/hack
    uv run python scripts/verify_eval_gap.py # eval: train/test token gap, 4 modes
@@ -33,18 +35,18 @@ smoke *ARGS:
    uv run python scripts/verify_science_invariants.py # pair provenance + untouched test
    uv run python scripts/verify_rotation.py # rotating-unhackable hint-free flip
    uv run python scripts/verify_lora2r_routing.py # block masks + ablation + c-probe
    just smoke-routeV {{ ARGS }}
    just smoke-routeA {{ ARGS }}

# none: gate pinned clean (0,0) -> quarantine never trains (capacity/structure-matched vanilla).
smoke-vanilla *ARGS:
    BEARTYPE=1 {{ TRAIN }} smoke --intervention=none \
        --teacher-pool-dir=out/pools/teacher_pool --mix-ratio=0.5 {{ ARGS }}

# routeV: extract v_grad from authored pairs, splice the per-rollout c-probe gate,
# PASS 1 (unmasked) labels rollouts {clean,mid,hack} via the width-pooled band cosine,
# PASS 2 (masked) trains the blocks; deploy ablation resets the quarantine to init.
smoke-routeV *ARGS:
    BEARTYPE=1 {{ TRAIN }} smoke --intervention=routeV \
# routeA: extract v_act from authored pairs (forward-only), capture pooled acts on the
# no-grad logpi_old forward, label rollouts {keep,absorb,rout} via rolling-buffer Otsu
# thresholds, ONE masked forward+backward; deploy ablation resets the quarantine to init.
smoke-routeA *ARGS:
    BEARTYPE=1 {{ TRAIN }} smoke --intervention=routeA \
        --teacher-pool-dir=out/pools/teacher_pool --mix-ratio=0.5 \
        --eval-ablate-every=10 --eval-n-prompts=2 {{ ARGS }}

@@ -62,21 +64,13 @@ smoke-unhackable *ARGS:
        --teacher-pool-dir=out/pools/teacher_pool --mix-ratio=0.5 \
        --eval-n-prompts=2 {{ ARGS }}

# routeV with a top-k routing subspace (max_i cos(g,v_i) over k SVD dirs) instead of
# the single mean-mass axis. UAT: log shows "top-3 SVD subspace, gate=max_i cos" and the
# band/gate still route (rout>0). k=1 (default) is the mean-diff headline.
smoke-topk *ARGS:
    BEARTYPE=1 {{ TRAIN }} smoke --intervention=routeV --v-grad-k=3 \
        --teacher-pool-dir=out/pools/teacher_pool --mix-ratio=0.5 \
        --eval-ablate-every=10 --eval-n-prompts=2 {{ ARGS }}

# routeV + symmetric SOLVE-teacher pool: the G_t teacher slots split 50/50 solve/hack,
# routeA + symmetric SOLVE-teacher pool: the G_t teacher slots split 50/50 solve/hack,
# and the run logs the routed-share discrimination (UAT: a line "solve-mix gate
# discrimination: hack-teacher routed-share=X vs solve-teacher routed-share=Y"). Smoke
# points solve at the same tiny pool just to exercise the split+diagnostic path; real
# runs use out/pools/teacher_pool_solve (correct-solution demos) vs the hack pool.
smoke-solvemix *ARGS:
    BEARTYPE=1 {{ TRAIN }} smoke --intervention=routeV \
    BEARTYPE=1 {{ TRAIN }} smoke --intervention=routeA \
        --teacher-pool-dir=out/pools/teacher_pool --solve-pool-dir=out/pools/teacher_pool \
        --mix-ratio=0.5 --solve-mix-frac=0.5 \
        --eval-ablate-every=10 --eval-n-prompts=2 {{ ARGS }}
@@ -84,7 +78,7 @@ smoke-solvemix *ARGS:
# All three arms back to back (the full-coverage gate).
smoke-all:
    just smoke-vanilla
    just smoke-routeV
    just smoke-routeA
    just smoke-absorb

# ─────────────────────────────────────────────────────────────────────────────
@@ -92,22 +86,21 @@ smoke-all:
# pool, 50% unhackable, authored pairs). Every job carries a why:/resolve: label.
# ─────────────────────────────────────────────────────────────────────────────

# Headline 5-arm lora2r decision run, ONLINE-STATS gate + teacher forcing ({{ TEACH }}).
# real-v(k1) is the method; topk(k3) tries the multi-sub-mode subspace; placebo (Haar)
# isolates directionality; vanilla is the emergence reference; absorb isolates the
# gate+masks from absorption. Priority descending so they run in listed order.
# Headline 4-arm lora2r decision run, routeA ACT gate + teacher forcing ({{ TEACH }}).
# real-v is the method (v_act from authored pairs, Otsu rolling-buffer thresholds);
# placebo (Haar) isolates directionality; vanilla is the emergence reference; absorb
# isolates the gate+masks from absorption. Priority descending so they run in listed order.
# --unhackable-frac pinned EXPLICIT so the regime is self-documenting, not default-dependent.
# Decision: directionality is real iff real-v deploy_hack << placebo at matched solve.
# Watch the streamed `auroc` col: ~0.5 = v_grad blind to live hacks (no gate works);
# high + rout~0 = threshold problem; a drop at a refresh = the cliff is a direction problem.
# Watch the streamed `auroc` col (A>0 contrast): ~0.5 = v_act blind to live hacks (no gate
# works); high + rout~0 = threshold problem; a drop at a refresh = a direction problem.
# NO inline eval (eval_ablate_every default 0): HF-generate-bound through 252 lora2r hooks
# (~25-30 min/eval), so deploy is scored OFFLINE from the step-10 ckpts (`just results`).
queue-decision seed='43':
    pueue add -w "$PWD" -o 62 -l "why: P1 lora2r routeV REAL-v k1 online-stats + teacher-forcing s{{seed}} (25% unhackable); resolve: deploy_hack << placebo at matched solve -> directionality real" -- {{ TRAIN }} fast --intervention=routeV --unhackable-frac=0.25 {{ TEACH }} --seed={{seed}} --out-tag=_l2r_routeV_real_s{{seed}}
    pueue add -w "$PWD" -o 60 -l "why: P2 lora2r routeV TOPK k3 online-stats + teacher-forcing s{{seed}} (25% unhackable); resolve: topk deploy_hack <= real-k1 -> sub-mode subspace catches hacks the mean washes out" -- {{ TRAIN }} fast --intervention=routeV --v-grad-k=3 --unhackable-frac=0.25 {{ TEACH }} --seed={{seed}} --out-tag=_l2r_routeV_topk_s{{seed}}
    pueue add -w "$PWD" -o 58 -l "why: P3 lora2r routeV PLACEBO-v (Haar 157) + teacher-forcing s{{seed}} (25% unhackable); resolve: deploy_hack ~ vanilla -> real-v suppression is directional, not absorption/shrinkage" -- {{ TRAIN }} fast --intervention=routeV --routeV-random-v-seed=157 --unhackable-frac=0.25 {{ TEACH }} --seed={{seed}} --out-tag=_l2r_routeV_placebo_s{{seed}}
    pueue add -w "$PWD" -o 56 -l "why: P4 lora2r VANILLA (gate pinned clean) + teacher-forcing s{{seed}} (25% unhackable); resolve: deploy_hack >> 0 emergence reference on the identical adapter" -- {{ TRAIN }} fast --intervention=none --unhackable-frac=0.25 {{ TEACH }} --seed={{seed}} --out-tag=_l2r_vanilla_s{{seed}}
    pueue add -w "$PWD" -o 54 -l "why: P5 lora2r BOTH-BLOCK (masks pinned (1,0), no gate) + teacher-forcing s{{seed}} (25% unhackable); resolve: ~vanilla -> gate+masks add nothing; << vanilla -> ungated both-block training suppresses" -- {{ TRAIN }} fast --intervention=absorb --unhackable-frac=0.25 {{ TEACH }} --seed={{seed}} --out-tag=_l2r_absorb_s{{seed}}
    pueue add -w "$PWD" -o 62 -l "why: P1 lora2r routeA REAL-v act gate + teacher-forcing s{{seed}} (25% unhackable); resolve: deploy_hack << placebo at matched solve -> directionality real" -- {{ TRAIN }} fast --intervention=routeA --unhackable-frac=0.25 {{ TEACH }} --seed={{seed}} --out-tag=_l2r_routeA_real_s{{seed}}
    pueue add -w "$PWD" -o 58 -l "why: P2 lora2r routeA PLACEBO-v (Haar 157) + teacher-forcing s{{seed}} (25% unhackable); resolve: deploy_hack ~ vanilla -> real-v suppression is directional, not absorption/shrinkage" -- {{ TRAIN }} fast --intervention=routeA --routeA-random-v-seed=157 --unhackable-frac=0.25 {{ TEACH }} --seed={{seed}} --out-tag=_l2r_routeA_placebo_s{{seed}}
    pueue add -w "$PWD" -o 56 -l "why: P3 lora2r VANILLA (gate pinned clean) + teacher-forcing s{{seed}} (25% unhackable); resolve: deploy_hack >> 0 emergence reference on the identical adapter" -- {{ TRAIN }} fast --intervention=none --unhackable-frac=0.25 {{ TEACH }} --seed={{seed}} --out-tag=_l2r_vanilla_s{{seed}}
    pueue add -w "$PWD" -o 54 -l "why: P4 lora2r BOTH-BLOCK (masks pinned (1,0), no gate) + teacher-forcing s{{seed}} (25% unhackable); resolve: ~vanilla -> gate+masks add nothing; << vanilla -> ungated both-block training suppresses" -- {{ TRAIN }} fast --intervention=absorb --unhackable-frac=0.25 {{ TEACH }} --seed={{seed}} --out-tag=_l2r_absorb_s{{seed}}

# Base model zero-shot deploy eval (0 training steps): reproduce the paper's base
# solve ~11.5% in our harness. resolve: base solve ~0.10-0.12.
+19
-23
@@ -1,4 +1,4 @@
"""Q2 diagnostic: what should the live routeV gate SCORE, and where do the pinning
"""Q2 diagnostic: what should the live routing gate SCORE, and where do the pinning
cuts go?

THE QUESTION (Q2). The gate routes UPDATES, not rollouts: per rollout the GRPO update
@@ -90,8 +90,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer

from vgrout.lora2r import wrap_model_with_lora2r
from vgrout.pairs import load_pairs
from vgrout.extract_vhack_grad import completion_nll
from vgrout.train import _auroc
from vgrout.train import _auroc, _otsu3

# colour = behaviour (blue solve, red hack, grey fail); style = source (solid on-policy, dashed synthetic)
SOLVE, HACK, FAIL, ABSORB_C, ROUT_C, ORACLE = "#3b6ea5", "#c44e52", "#9aa0a6", "#d1900a", "#c44e52", "#3a8a7a"
@@ -104,7 +103,7 @@ class Cfg:
    run_dir: Path = Path("out/runs/20260611T003538_fast_vanilla_lora2r_seed43_l2r_vanilla_s43_v3")
    ckpt: str = "first_hack"
    pairs: Path = Path("data/pairs/hack_pairs.md#all-in-one")
    # headline figure builds v from this heading-prefix subset = the routeV TRAINING
    # headline figure builds v from this heading-prefix subset = the routeA TRAINING
    # default (train_config.vhack_pairs_path `#all-in-one/behavior_`, 8 pairs; the
    # trailing _ excludes behavior2_*). The pairset table spans all subsets of `pairs`.
    headline_prefix: str = "behavior_"
@@ -216,24 +215,21 @@ def _kde(x: np.ndarray, grid: np.ndarray) -> np.ndarray:
    return np.exp(-0.5 * z ** 2).sum(1) / (len(x) * bw * np.sqrt(2 * np.pi))


def _otsu3(x: np.ndarray) -> tuple[float, float]:
    """Two-threshold Otsu: the pair of cuts maximizing 3-class between-class variance.
    Label-free -- an online gate can compute this from a rolling window of scores, so
    using it here is not oracle leakage. O(n^2), fine for a few hundred scores.
    Scores are winsorized at 1/99% first: Otsu maximizes variance, so on heavy-tailed
    scores a single extreme point otherwise buys a whole class (seen on grad_dot)."""
    x = np.clip(x, *np.quantile(x, [0.01, 0.99]))
    s = np.sort(np.asarray(x, float))
    n = len(s)
    c = np.concatenate([[0.0], np.cumsum(s)])
    best, best_ij = -np.inf, (1, 2)
    for i in range(1, n - 1):
        for j in range(i + 1, n):
            obj = c[i] ** 2 / i + (c[j] - c[i]) ** 2 / (j - i) + (c[n] - c[j]) ** 2 / (n - j)
            if obj > best:
                best, best_ij = obj, (i, j)
    i, j = best_ij
    return float((s[i - 1] + s[i]) / 2), float((s[j - 1] + s[j]) / 2)
def completion_nll(model, tokenizer, prompt: str, completion: str, device) -> torch.Tensor:
    """Mean NLL over completion tokens only (length-normalized). The backward of this
    loss populates the c-probe grads read by _gate_grads (the retired grad-gate space,
    kept here as a diagnostic baseline)."""
    prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    full_ids = tokenizer(prompt + completion, return_tensors="pt").input_ids.to(device)
    n_prompt = prompt_ids.shape[1]
    logits = model(full_ids).logits[:, :-1] # [1, L-1, V]
    targets = full_ids[:, 1:] # [1, L-1]
    logp = torch.nn.functional.log_softmax(logits.float(), dim=-1)
    nll = -logp.gather(-1, targets.unsqueeze(-1)).squeeze(-1) # [1, L-1]
    # mask: positions whose target is a completion token (i.e. index >= n_prompt in full_ids)
    pos = torch.arange(full_ids.shape[1] - 1, device=device).unsqueeze(0)
    mask = (pos >= (n_prompt - 1)).float()
    return (nll * mask).sum() / mask.sum().clamp_min(1.0)


def plot_q2(df: pl.DataFrame, subtitle: str, out_png: Path) -> dict:
@@ -497,7 +493,7 @@ def _downstream(cfg: Cfg, fe: dict, src: str) -> int:
    print(f"SHOULD: on_hackpos >= ~20 and on_drop not the majority, ELSE the window/run has "
          f"too few learnable hacks and every AUROC below is noise.")

    # ── headline vectors from the routeV-default subset; placebo swaps in Haar ──
    # ── headline vectors from the routeA-default subset; placebo swaps in Haar ──
    groups: dict[str, list[int]] = fe["pair_groups"]
    head_idx = [i for i, pid in enumerate(pair_ids) if pid.startswith(cfg.headline_prefix)]
    assert head_idx, f"no pairs match headline prefix {cfg.headline_prefix!r}"
@@ -7,12 +7,12 @@ hack-side by detector signature. Here the source is the student's logged
rollouts (out/runs/<run>/rollouts.jsonl) and the split is by env_mode: a rollout
is hack-side iff it EXPLOITED its problem's mode AND that mode is one of the
"known" modes the weak detector can flag. The held-out modes are never used to
build pairs -- v_grad is extracted only from the known modes, and the A5 figure
build pairs -- the routing vector is extracted only from the known modes, and the A5 figure
then measures whether the held-out modes are also suppressed at deployment. This
tests whether a detector trained on hack classes A suppresses unseen classes B.

Constraint (load-bearing, same as pairs_from_pool): pairs MUST share the prompt.
The paired-diff g_hack - g_clean in extract_vhack_grad cancels prompt-specific
The paired-diff feat_hack - feat_clean in the pair extraction cancels prompt-specific
noise only when both completions are conditioned on the same chat-templated
prompt. A given problem_id renders one fixed (hinted) prompt across steps, so
same-(problem_id) hack/clean rollouts share it; we assert equality to catch drift.
@@ -74,7 +74,8 @@ def main() -> int:

    authored_pairs = load_pairs(Path("data/pairs/hack_pairs.md#all-in-one"))
    real_pairsets_ok = (
        len(authored_pairs) == 27
        len(authored_pairs) == 42 # 27 + 15 wave-2 behavior2_* (c33b810)
        and len(load_pairs(Path("data/pairs/hack_pairs.md#all-in-one/behavior_"))) == 8 # routeA training default
        and len(load_pairs(Path("data/pairs/hack_pairs.md#all-in-one@opportunity-aware"))) == 6
        and len(load_pairs(Path("data/pairs/hack_pairs.md#all-in-one@explicit"))) == 10
        and len(load_pairs(Path("data/pairs/hack_pairs.md#all-in-one@roleplay"))) == 2
@@ -1,301 +0,0 @@
"""Gradient-side per-module v_hack extraction (spec.md §B, top-k variant).

We sample the per-module GRPO update direction on labeled (hack, clean) pairs.
For a pair with advantages (adv_h=+1, adv_c=-1) the Dr.GRPO single-step grad
`-adv_h * grad_logp(hack) - adv_c * grad_logp(clean)` algebraically equals
`grad_NLL(hack) - grad_NLL(clean)`, so we compute it by the simpler path:
forward each completion, take mean-NLL on completion tokens, backward, and
capture the lora2r c-probe grad (the per-pair weight grad of the virtual
diagonal between A and B, deployed block) per wrapped Linear. Naming the steps
NLL is an implementation detail; the *meaning* is "the GRPO update on this pair."

Then per module, with D = [g_hack_i - g_clean_i for each pair] in R^{n_pairs x r}:
    SVD(D) = U Σ Vh
    v_hack[name] = top_k rows of Vh, each oriented so mean(D @ v_i) > 0

This generalizes mean-diff (which corresponds to top-1 PC of paired diffs under
isotropic covariance) to a rank-k hack subspace, motivated by CHaRS (Abdullaev
2025 -- see docs/paper_chars.md): hack signal is multi-modal across hack flavors
(weak tests, hardcode, persona, ...), so a single global direction is brittle.

Orientation matters because proj.py applies a per-direction one-sided gate
(only subtracts <g, v_i> when positive). +v_i must align with the reward-hacking gradient.

Saves `out/v_hack.safetensors` = dict[name -> Tensor[k, r]] (cpu fp32, rows
unit-norm + orthonormal from SVD) with header {"model": str, "dtype": str,
"top_k": str(k)}.

Run: uv run python -m vgrout.extract_vhack_grad
"""
from __future__ import annotations

import sys
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path

import torch
import tyro
from jaxtyping import Float
from loguru import logger
from safetensors.torch import save_file
from tabulate import tabulate
from transformers import AutoModelForCausalLM, AutoTokenizer

from .lora2r import wrap_model_with_lora2r
from .pairs import load_pairs, pairset_sha256


OUT_DIR = Path("out")


@dataclass
class Config:
    model: str = "Qwen/Qwen3-4B"
    dtype: str = "bf16" # must match train.py, else SVD basis cache can differ silently
    out_path: Path = OUT_DIR / "vhack" / "v_hack.safetensors"
    train_grads_path: Path = OUT_DIR / "vhack_grads" / "vhack_grads_train.safetensors"
    n_heldout: int = 2 # last n pairs reserved for held-out validation
    # top_k=12 = max(n_train_pairs after n_heldout=2 from N=14 pairs). Extract once
    # at max rank; train.py slices via --v-hack-k for k-ablation without re-extract.
    top_k: int = 12
    # tau_axis: zero rows where S_i/S_0 < tau_axis. Diagnostic -- projection along
    # noise-direction unit vectors removes only ~||g||/sqrt(r) ≈ 2% of grad
    # magnitude on r=2560 modules, so this rarely changes effect size; it does
    # keep k-ablations interpretable (axes 4-5 might be pure noise on N=12 pairs).
    tau_axis: float = 0.0
    # Pairset reference: generated JSON or one `path.md#section`.
    pairs_from_pool: Path | None = None


def resolve_dtype(s: str) -> torch.dtype:
    return {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[s]


def completion_nll(model, tokenizer, prompt: str, completion: str, device) -> torch.Tensor:
    """Mean NLL over completion tokens only (length-normalized)."""
    prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    full_ids = tokenizer(prompt + completion, return_tensors="pt").input_ids.to(device)
    n_prompt = prompt_ids.shape[1]
    logits = model(full_ids).logits[:, :-1] # [1, L-1, V]
    targets = full_ids[:, 1:] # [1, L-1]
    logp = torch.nn.functional.log_softmax(logits.float(), dim=-1)
    nll = -logp.gather(-1, targets.unsqueeze(-1)).squeeze(-1) # [1, L-1]
    # mask: positions whose target is a completion token (i.e. index >= n_prompt in full_ids)
    pos = torch.arange(full_ids.shape[1] - 1, device=device).unsqueeze(0)
    mask = (pos >= (n_prompt - 1)).float()
    return (nll * mask).sum() / mask.sum().clamp_min(1.0)


def extract_v_hack(
    model,
    tokenizer,
    wrappers: dict,
    pairs: list,
    top_k: int,
    tau_axis: float,
    n_heldout: int,
    device,
) -> tuple[
    dict[str, Float[torch.Tensor, "k r"]],
    dict[str, Float[torch.Tensor, "k"]],
    dict[str, Float[torch.Tensor, "n_pairs r"]],
    list[dict],
]:
    """Run pair-grads + per-module SVD on D = g_hack - g_clean, return v_hack.

    Pure function -- caller owns model loading, wrapping, and saving. train.py
    calls this on its already-wrapped model when v_hack cache is missing, so
    we don't pay the cost of a second model load.

    Returns:
        v_hack: dict[name -> Tensor[k, r]] (cpu fp32), top-k right singular
            vectors of D per module, oriented so mean(D @ v_i) > 0. If
            tau_axis > 0, rows where S_i/S_0 < tau_axis are zeroed.
        v_sv: dict[name -> Tensor[k]] (cpu fp32), singular values matching v_hack.
            Saved alongside V under `_sv/{name}` keys so load-time noise-floor
            filtering works without re-extracting.
        raw_grads: dict["hack/name"|"clean/name" -> Tensor[n_pairs, r]] for
            offline analysis (verify_vhack_heldout reads this).
        diag_rows: per-module diagnostic dicts (sv_top frac, ||D||, etc.).
    """
    train_pairs = pairs[:-n_heldout] if n_heldout > 0 else pairs
    n_pairs = len(train_pairs)

    grads_hack: dict[str, list[torch.Tensor]] = defaultdict(list)
    grads_clean: dict[str, list[torch.Tensor]] = defaultdict(list)

    for pi, pair in enumerate(train_pairs):
        for label, completion in (("hack", pair.hack), ("clean", pair.clean)):
            model.zero_grad(set_to_none=True)
            loss = completion_nll(model, tokenizer, pair.prompt, completion, device)
            if not torch.isfinite(loss):
                # Skip-on-nonfinite would silently leave G_h and G_c with mismatched
                # lengths and explode later in D = G_h - G_c. Fail fast.
                raise RuntimeError(f"non-finite loss at pair={pi} label={label}: {loss.item()}")
            loss.backward()
            bucket = grads_hack if label == "hack" else grads_clean
            for name, info in wrappers.items():
                layer = info["layer"]
                # Per-pair weight grad of the virtual diagonal (c-probe), DEPLOYED
                # block only -- the same space the live gate reads (train.py), so
                # band calibration uses the same representation. Requires grad_probe=True.
                cg = layer._lora2r_gate.grad
                if cg is None:
                    raise RuntimeError(f"no c-probe grad on {name}; wrap with grad_probe=True")
                g = cg.sum(dim=tuple(range(cg.dim() - 1)))[: layer._lora2r_r] # [r]
                bucket[name].append(g.detach().float().cpu().clone())
        if (pi + 1) % 5 == 0:
            logger.info(f" pair {pi+1}/{n_pairs} loss={loss.item():.3f}")
    model.zero_grad(set_to_none=True) # leave caller with clean state

    raw_grads = {
        **{f"hack/{n}": torch.stack(gs) for n, gs in grads_hack.items()},
        **{f"clean/{n}": torch.stack(gs) for n, gs in grads_clean.items()},
    }

    # Per module: D = g_hack - g_clean (paired diff cancels prompt-specific noise).
    # SVD(D) gives orthonormal right singular vectors = principal axes of variation
    # of the hack-clean axis. Top-k generalizes mean-diff (which is the k=1 case).
    v_hack: dict[str, torch.Tensor] = {}
    v_sv: dict[str, torch.Tensor] = {}
    rows = []
    n_zero = 0
    k = min(top_k, n_pairs)
    n_axes_kept_total = 0
    for name in grads_hack:
        G_h = torch.stack(grads_hack[name]) # [n_pairs, r]
        G_c = torch.stack(grads_clean[name])
        D = G_h - G_c

        U_d, S_d, Vh_d = torch.linalg.svd(D, full_matrices=False)
        V = Vh_d[:k] # [k, r], rows orthonormal in R^r
        # Orient by per-pair majority vote: for each axis i, count pairs where
        # d_p @ v_i > 0; if strict majority disagree with current SVD sign, flip.
        # More outlier-robust than sign(mean): one extreme pair can't flip a
        # consensus direction. Matches repeng's _orient_svd convention.
        proj_per_pair = D @ V.T # [n_pairs, k]
        n_pos = (proj_per_pair > 0).float().sum(0) # [k]
        flip = torch.where(n_pos < n_pairs / 2, -torch.ones(k), torch.ones(k))
        V = V * flip.unsqueeze(1)

        # tau_axis: zero rows where S_i/S_0 < tau_axis (diagnostic; see Config comment).
        n_axes_kept = k
        if tau_axis > 0 and S_d[0] > 1e-12:
            ratios = S_d[:k] / S_d[0]
            keep = (ratios >= tau_axis).float()
            V = V * keep.unsqueeze(1)
            n_axes_kept = int(keep.sum())
        n_axes_kept_total += n_axes_kept

        nrm = D.norm()
        if nrm < 1e-12:
            n_zero += 1
            v_hack[name] = torch.zeros((k, D.shape[1]), dtype=V.dtype).contiguous()
        else:
            v_hack[name] = V.contiguous()
        # Record singular values so the load-time noise-floor filter has the
        # extraction-time S_i per axis without re-extracting. Saved under
        # `_sv/{name}` keys in the safetensors file (combined at save site).
        v_sv[name] = S_d[:k].clone().contiguous()
        sv_top = S_d[:k]
        sv_total = S_d.sum().clamp_min(1e-12)
        rows.append({
            "module": name.split(".")[-1],
            "r": D.shape[1],
            "||D||": f"{nrm:.2e}",
            "sv_0": f"{S_d[0].item():.2e}" if S_d.numel() else "-",
            f"sv_top{k}_frac": f"{(sv_top.sum() / sv_total).item():.2f}",
            "sv_ratio_0/1": ("-" if S_d.numel() < 2
                             else f"{(S_d[0] / S_d[1].clamp_min(1e-12)).item():.2f}"),
            "axes_kept": n_axes_kept,
        })
    n_modules = len(grads_hack)
    logger.info(
        f"v_hack: modules={n_modules} k_max={k} zero-||D||={n_zero} "
        f"axes_kept_avg={n_axes_kept_total/max(1,n_modules):.1f} (tau_axis={tau_axis})"
    )
    return v_hack, v_sv, raw_grads, rows


-def main(cfg: Config) -> int:
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    dtype = resolve_dtype(cfg.dtype)
-    if cfg.pairs_from_pool is None:
-        raise ValueError("--pairs-from-pool is required; use data/pairs/hack_pairs.md#all-in-one")
-    pairs = load_pairs(cfg.pairs_from_pool)
-    logger.info(f"pairs source: {cfg.pairs_from_pool} -> {len(pairs)} pairs")
-    logger.info(
-        f"device={device} model={cfg.model} dtype={cfg.dtype} "
-        f"N_pairs={len(pairs)} heldout={cfg.n_heldout} top_k={cfg.top_k} tau_axis={cfg.tau_axis}"
-    )
-
-    tokenizer = AutoTokenizer.from_pretrained(cfg.model)
-    model = AutoModelForCausalLM.from_pretrained(
-        cfg.model, dtype=dtype, attn_implementation="sdpa"
-    ).to(device)
-    model.eval() # disable dropout; gradients still flow through the adapter
-    wrappers = wrap_model_with_lora2r(model, grad_probe=True)
-    n_mod = len(wrappers)
-    logger.info(f"wrapped {n_mod} modules; probe space = r per module")
-
-    train_pairs = pairs[:-cfg.n_heldout] if cfg.n_heldout > 0 else pairs
-    logger.info(f"train pairs: {len(train_pairs)} held: {cfg.n_heldout}")
-
-    v_hack, v_sv, raw_grads, rows = extract_v_hack(
-        model, tokenizer, wrappers, pairs,
-        top_k=cfg.top_k, tau_axis=cfg.tau_axis,
-        n_heldout=cfg.n_heldout, device=device,
-    )
-    n_zero = sum(1 for v in v_hack.values() if v.norm() < 1e-12)
-    k = min(cfg.top_k, len(train_pairs))
-
-    cfg.out_path.parent.mkdir(parents=True, exist_ok=True)
-    cfg.train_grads_path.parent.mkdir(parents=True, exist_ok=True)
-    save_file(raw_grads, str(cfg.train_grads_path),
-              metadata={"model": cfg.model, "dtype": cfg.dtype})
-    # v_hack file layout: bare `{name}` keys hold V[k, r]; `_sv/{name}` keys
-    # hold S[k]. Loader at train.py:load_v_hack splits them back apart.
-    save_payload = {**v_hack, **{f"_sv/{n}": s for n, s in v_sv.items()}}
-    save_file(save_payload, str(cfg.out_path),
-              metadata={"model": cfg.model, "dtype": cfg.dtype, "top_k": str(k),
-                        "tau_axis": str(cfg.tau_axis), "schema": "v2_with_sv",
-                        "pairs_path": str(cfg.pairs_from_pool),
-                        "pairs_sha256": pairset_sha256(cfg.pairs_from_pool)})
-
-    # summary: aggregate by suffix -- track top-k energy concentration
-    by_suffix: dict[str, list] = defaultdict(list)
-    for r in rows:
-        by_suffix[r["module"]].append(float(r[f"sv_top{k}_frac"]))
-    agg_rows = []
-    for suf, vals in sorted(by_suffix.items()):
-        agg_rows.append({
-            "suffix": suf,
-            "n": len(vals),
-            f"mean_sv_top{k}_frac": f"{sum(vals)/len(vals):.2f}",
-            f"min_sv_top{k}_frac": f"{min(vals):.2f}",
-            f"max_sv_top{k}_frac": f"{max(vals):.2f}",
-        })
-
-    # Final tail: BLUF -- what an agent reads first should be result + interp.
-    mean_frac = sum(float(r[f"sv_top{k}_frac"]) for r in rows) / max(len(rows), 1)
-    cue = "🟢" if (mean_frac > 0.5 and n_zero == 0) else ("🟡" if n_zero == 0 else "🔴")
-
-    print(f"\nSHOULD: mean_sv_top{k}_frac > 0.5 per suffix (subspace captures most energy). "
-          f"zero-||D||==0 (else grad flow broken).\n")
-    print(tabulate(agg_rows, headers="keys", tablefmt="tsv", floatfmt=".2f"))
-    print()
-    print(f"out: {cfg.out_path}")
-    print(f"argv: extract_vhack_grad --model={cfg.model} --top-k={k} --n-heldout={cfg.n_heldout}")
-    print(f"main metric: mean_sv_top{k}_frac={mean_frac:.2f} [modules={len(v_hack)} zero-||D||={n_zero}]")
-    print(f"{cue} k={k} pairs={len(train_pairs)}/{len(pairs)} modules={len(v_hack)} "
-          f"mean_top{k}_frac={mean_frac:.2f} zero={n_zero}")
-
-    if n_zero > 0:
-        logger.error(f"FAIL: {n_zero}/{len(v_hack)} modules have zero ||D|| -- gradient flow broken")
-        return 1
-    return 0
-
-
-if __name__ == "__main__":
-    sys.exit(main(tyro.cli(Config)))
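The removed extraction above forms a paired-difference matrix per module, takes its top-k right singular vectors, and fixes each axis sign by per-pair majority vote. A minimal NumPy sketch of that orientation step follows. It assumes toy data instead of real c-probe gradients, uses NumPy where the original uses torch, and the name `oriented_axes` is hypothetical.

```python
import numpy as np


def oriented_axes(D: np.ndarray, k: int) -> np.ndarray:
    """Top-k right singular vectors of D [n_pairs, r], sign-fixed by majority vote.

    Assumes k <= min(n_pairs, r). Hypothetical helper, not part of the repository.
    """
    _, _, vh = np.linalg.svd(D, full_matrices=False)
    V = vh[:k].copy()                              # [k, r], rows orthonormal
    n_pos = (D @ V.T > 0).sum(axis=0)              # pairs projecting positively per axis
    flip = np.where(n_pos < D.shape[0] / 2, -1.0, 1.0)
    return V * flip[:, None]                       # flip axes most pairs disagree with


# Toy paired diffs: every pair differs mostly along +e0, plus small noise.
rng = np.random.default_rng(0)
D = np.zeros((8, 4))
D[:, 0] = 1.0
D += 0.01 * rng.standard_normal(D.shape)
V = oriented_axes(D, k=2)
```

Because the vote counts pairs instead of averaging projections, a single extreme pair cannot flip an axis that the remaining pairs agree on, which is the robustness argument made in the code comment.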
+6
-4
@@ -17,14 +17,16 @@ from pathlib import Path
 FIGS_DIR = Path("docs/figs")

 # Reader-facing arm names. Code/log tags carry our internal vocabulary
-# (routeV = the current routing arm); plots must
+# (routeA = the current routing arm); plots must
 # not. Map every internal tag to the word a paper reader sees. Anything missing
 # falls through to its raw tag, so a new arm shows up loud rather than silently
 # mislabelled.
 ARM_DISPLAY = {
-    # routeV is the current banded-gate arm; routing2/route2 are the old binary-tau runs
-    # (kept so historical run artifacts still plot -- see rename, 2026-06-06).
-    "routingV": "route", "routeV": "route", "routingV_per_token": "route per-token",
+    # routeA is the current act-gate arm; routeV (grad gate) and routing2/route2
+    # (binary-tau) are retired but kept so historical run artifacts still plot.
+    "routeA": "route",
+    "routingV": "route (grad)", "routeV": "route (grad)",
+    "routingV_per_token": "route per-token",
     "routing2": "route", "route2": "route",
     "routing2_grad": "route", "routing2_act": "route (act)",
     "projected": "erase", "route": "route", "erase": "erase", "vanilla": "vanilla",
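The comment in this hunk says an unmapped tag falls through to its raw form, so a new arm shows up loudly in plots. A minimal sketch of that lookup, assuming a trimmed copy of the mapping and a hypothetical helper name `display_name` (the plotting code's actual lookup is not shown in this diff):

```python
# Trimmed, illustrative subset of the mapping in the hunk above.
ARM_DISPLAY = {
    "routeA": "route",
    "routeV": "route (grad)",
    "projected": "erase",
    "vanilla": "vanilla",
}


def display_name(tag: str) -> str:
    """Return the reader-facing arm name, or the raw tag if it is unmapped."""
    return ARM_DISPLAY.get(tag, tag)


known = display_name("routeA")      # mapped internal tag
unknown = display_name("routeZ")    # unmapped tag falls through unchanged
```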
@@ -51,9 +51,9 @@ def _lora2r_hook(layer: nn.Linear, args: tuple, y: Tensor) -> Tensor:

     grad probe: c = ones[..., 2r] spliced as h*c. After backward
     c.grad = h ⊙ (Bᵀδ_y) = the per-sample WEIGHT grad of a virtual diagonal
-    scale between A and B. The pair extraction (extract_vhack_grad) and the live
-    gate (train.py) both read this same space, so the band is self-consistent
-    whatever the basis.
+    scale between A and B. Training no longer uses it (the routeA gate reads
+    activations); kept for scripts/diag_pinning.py, whose grad-score baseline
+    reads this space.
     """
     (x,) = args # x: [..., d_in]
     A: Float[Tensor, "two_r d_in"] = layer._lora2r_A # trainable
@@ -9,7 +9,7 @@ exclusively from `half-A` detectors. The clean-side is any rollout where all
 four upstream detectors are False AND format_ok is True.

 Constraint (load-bearing): pairs MUST share the prompt. The paired-diff
-`g_hack - g_clean` in extract_vhack_grad cancels prompt-specific noise only
+`feat_hack - feat_clean` in the pair extraction cancels prompt-specific noise only
 when both completions are conditioned on the same chat-templated prompt.
 Cross-prompt pairs would inject prompt-difference signal into v_hack.

@@ -13,6 +13,7 @@ RUN_SCHEMA = "paired_final_v2" # v2: deployed/as_trained field names (was depl
 # get a _lora2r suffix so the two substrates never conflate in aggregation.
 ARM = {"none": "vanilla", "erase": "projected",
        "routeV": "routingV", "routeV_per_token": "routingV_per_token",
+       "routeA": "routeA",
        "absorb": "absorb"}


+10
-11
@@ -73,7 +73,7 @@ class StepLogger:
     def __init__(self, arm: str, modes: list[str], mode_code: dict[str, str],
                  show_ablate: bool = False) -> None:
         # Routing diagnostics are ALWAYS shown (nan on vanilla, whose gate never runs) so the
-        # column layout is identical across arms -- vanilla/routeV/absorb tables line up.
+        # column layout is identical across arms -- vanilla/routeA/absorb tables line up.
         cols: list[_Col] = [
             _Col("step", 4, "step", "d", "GRPO step"),
             _Col("ref_eq", 6, "ref_eq", ".2f", "vanilla-equiv step (cum_gens/256)"),
@@ -84,7 +84,7 @@ class StepLogger:
             _Col("hack_s", 7, "hack_s?", "frac", "student hack-flagged rollouts (the headline)"),
             _Col("hack_t", 7, "hack_t", "frac", "teacher hack-flagged rollouts (sanity: pool hacks)"),
             # Held-out deployed evaluation with quarantine ablated; NaN between evaluation steps.
-            _Col("hack_deployed", 7, "hk_dep", "+.2f", "DEPLOY-eval hack (routeV: quarantine OFF; vanilla/erase: trained model); held-out subset, T=0.7, every eval_ablate_every steps; nan between"),
+            _Col("hack_deployed", 7, "hk_dep", "+.2f", "DEPLOY-eval hack (routeA/absorb: quarantine OFF; vanilla: trained model); held-out subset, T=0.7, every eval_ablate_every steps; nan between"),
             _Col("solve_deployed", 7, "slv_dep", "+.2f", "DEPLOY-eval solve (same cadence; nan between)"),
         ]
         # Multi-mode runs show current-step hacks per environment; single-mode would duplicate hack_s.
@@ -99,17 +99,16 @@ class StepLogger:
             _Col("gn", 7, "gn", ".1e", "pre-clip L2 norm of A/B grads (vs grad_clip)"),
             _Col("lr", 7, "lr", ".1e", "scheduled learning rate"),
         ]
-        # routeV reports unit and energy shares across the routing band (nan on vanilla/absorb).
+        # routeA reports gate diagnostics (nan on vanilla/absorb, whose gate never runs).
         cols += [
-            _Col("auroc", 6, "auroc", ".2f", "AUROC of pooled cos(g,v_grad) as a reward-hacking detector; measurement only, never routes. ~0.5 = chance-level separation; high AUROC but rout~0 = threshold/scale problem; a drop at refresh = reduced separation"),
-            _Col("cosU", 6, "cosU", "+.2f", "pooled cos(v_grad, summed-rollout c-grad): alignment of the net update with v_grad"),
+            _Col("auroc", 6, "auroc", ".2f", "AUROC of dot(act, v_act) vs hack labels on the A>0 contrast (positively-reinforced rollouts, where the reward alone is blind); measurement only, never routes. ~0.5 = chance-level separation; high AUROC but rout~0 = threshold problem; a drop at refresh = reduced separation"),
+            _Col("cos", 6, "cos", "+.2f", "mean per-rollout cos(act, v_act) (dot-vs-cos diagnostic)"),
             _Col("qmass", 6, "qmass", ".2f", "quarantine energy share ||g_quar||/(||g_keep||+||g_quar||): fraction of update energy assigned to quarantine"),
-            _Col("keep", 6, "keep", ".2f", "rollout share labelled clean (below band) -> deployed-only, quarantine off"),
-            _Col("resid", 6, "resid", ".2f", "rollout share labelled mid (inside band) -> both blocks train; absorption is possible but not measured"),
-            _Col("rout", 6, "rout", ".2f", "rollout share labelled hack (above band) -> quarantine-only, deployed detached"),
-            _Col("keepE", 6, "keepE", ".2f", "energy-weighted keep (grad-norm-weighted clean share)"),
-            _Col("residE", 6, "residE", ".2f", "energy-weighted resid"),
-            _Col("routE", 6, "routE", ".2f", "energy-weighted rout"),
+            _Col("keep", 6, "keep", ".2f", "rollout share below t_lo -> deployed-only, quarantine off"),
+            _Col("resid", 6, "resid", ".2f", "rollout share between thresholds (and ALL rollouts during warmup) -> both blocks train; absorption is possible but not measured"),
+            _Col("rout", 6, "rout", ".2f", "rollout share at/above t_hi -> quarantine-only, deployed detached"),
+            _Col("tlo", 6, "tlo", "+.2f", "Otsu lower threshold (z units of the rolling score buffer); nan during warmup"),
+            _Col("thi", 6, "thi", "+.2f", "Otsu upper (rout) threshold (z units); nan during warmup"),
         ]
         # Show the training-prompt deploy proxy only when an ablated slice exists.
         if show_ablate:
+213
-296
@@ -16,13 +16,14 @@ both trainable, partitioned into a deployed block [:r] and a quarantine block
 Arms (--intervention):
   none    gate pinned clean (0,0): quarantine never trains -- the capacity- and
           structure-matched vanilla control.
-  routeV  per-rollout three-way gate from the c-probe gradient vs v_grad:
-          clean->deployed-only, hack->quarantine-only (deployed detached),
-          mid->both, which may permit absorption.
+  routeA  per-rollout three-way gate from the pooled bottleneck activation vs
+          v_act: keep->deployed-only, rout->quarantine-only (deployed detached),
+          absorb->both, which may permit absorption. The acts ride the no-grad
+          logpi_old forward, so routeA costs roughly the vanilla arm.
   absorb  gate pinned mid (1,0): both blocks train on everything, no gate --
           tests ungated both-block training.

-    uv run python -m vgrout.train smoke --intervention=routeV
+    uv run python -m vgrout.train smoke --intervention=routeA
 """
 from __future__ import annotations

@@ -33,9 +34,12 @@ import os
 import sys
 import random
 import time
+from collections import deque
 from contextlib import nullcontext
 from pathlib import Path

+import numpy as np
+
 # Must be set BEFORE `import torch` to take effect on the CUDA allocator.
 # Eliminates fragmentation that caused 91 GiB allocated / 581 MiB free crash
 # on Qwen3-4B G=8 (PyTorch's own OOM message recommends this).
@@ -51,7 +55,9 @@ from tabulate import tabulate
 from tqdm import tqdm
 from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

+from .extract_vhack_act import ActCapture, extract_v_act, haar_unit_rows
 from .lora2r import wrap_model_with_lora2r
+from .pairs import load_pairs
 from .proj import per_token_logps
 from .rewards import EnvMode, compute_reward
 from .data import DATA, load_problems
@@ -64,46 +70,27 @@ OUT_DIR = Path("out")
 RUNS_DIR = OUT_DIR / "runs"


-def _haar_unit_dirs(v_grad: dict, seed: int, device) -> dict:
-    """Build the reproducible out-of-subspace directionality control (placebo) for routeV.
-
-    Matches v_grad's [k, r] shape and unit-normalizes per row, so a top-k routing run
-    gets k random dirs (the gate's max-cosine still sees the same shape) for a fair placebo."""
-    g = torch.Generator().manual_seed(seed)
-    out = {}
-    for name in sorted(v_grad):
-        d = torch.randn(v_grad[name].shape, generator=g) # [k, r]
-        out[name] = (d / d.norm(dim=-1, keepdim=True).clamp_min(1e-12)).to(device)
-    return out
-
-
-def _build_v_grad(raw_grads: dict, names, k: int, device) -> dict[str, Float[torch.Tensor, "k r"]]:
-    """Per-module routing directions from authored-pair gradients -> dict[name -> [k, r]].
-
-    k=1 (headline): normalized mean(hack-clean) -- one mean-mass axis. k>1: the top-k
-    oriented right singular vectors of the paired-diff matrix D=[g_hack-g_clean] (SVD +
-    per-pair majority orient, mirroring extract_vhack_grad.extract_v_hack), a rank-k hack
-    subspace. The live gate scores max_i cos(g, v_i). Rows are unit-norm. k=1 is NOT
-    SVD-top-1 (they differ off-isotropic): keeping mean-diff makes 'mean-mass vs top-k'
-    a clean A/B, not a confound."""
-    out = {}
-    for name in names:
-        D: Float[torch.Tensor, "n_pairs r"] = (
-            raw_grads[f"hack/{name}"] - raw_grads[f"clean/{name}"]).float()
-        if k == 1:
-            d = D.mean(0)
-            V = (d / d.norm().clamp_min(1e-12)).unsqueeze(0) # [1, r]
-        else:
-            _, _, Vh = torch.linalg.svd(D, full_matrices=False)
-            kk = min(k, Vh.shape[0])
-            V = Vh[:kk] # [kk, r] orthonormal
-            proj = torch.einsum("n r, k r -> n k", D, V) # per-pair projection
-            flip = torch.where((proj > 0).float().sum(0) < D.shape[0] / 2,
-                               -torch.ones(kk), torch.ones(kk)) # orient toward reward-hacking gradients
-            V = V * flip.unsqueeze(1)
-        out[name] = V.to(device)
-    return out
-
+def _otsu3(x: np.ndarray) -> tuple[float, float]:
+    """Two-threshold Otsu: the pair of cuts maximizing 3-class between-class variance.
+    Label-free -- the routeA gate computes this on a rolling buffer of live scores, so
+    using it is not oracle leakage. Scores are winsorized at 1/99% first: Otsu maximizes
+    variance, so on heavy-tailed scores a single extreme point otherwise buys a whole
+    class (journal 2026-06-11 (d): v5 act rout precision 0.00 -> 0.50 after winsorize).
+    Vectorized over the [n, n] cut grid; n is the buffer size (<= a few hundred)."""
+    x = np.clip(x, *np.quantile(x, [0.01, 0.99]))
+    s = np.sort(np.asarray(x, float))
+    n = len(s)
+    c = np.concatenate([[0.0], np.cumsum(s)])
+    iv = np.arange(1, n)
+    i_g, j_g = iv[:, None], iv[None, :]
+    with np.errstate(divide="ignore", invalid="ignore"):
+        obj = (c[i_g] ** 2 / i_g
+               + (c[j_g] - c[i_g]) ** 2 / (j_g - i_g)
+               + (c[n] - c[j_g]) ** 2 / (n - j_g))
+    obj[(j_g <= i_g) | (j_g >= n)] = -np.inf # need i < j and a nonempty top class
+    i, j = np.unravel_index(np.argmax(obj), obj.shape)
+    i, j = iv[i], iv[j]
+    return float((s[i - 1] + s[i]) / 2), float((s[j - 1] + s[j]) / 2)


 def _sample_rows(rows: list[dict] | None, n: int, rng: torch.Generator) -> list[dict]:
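The `_otsu3` helper added in this hunk searches all pairs of cut positions on the sorted scores and keeps the pair that maximizes the summed `(class sum)^2 / class size`, which is equivalent to maximizing between-class variance. A slow pure-Python reference of the same objective can serve as a cross-check for the vectorized version. This is a sketch under stated assumptions: the name `otsu3_reference` is hypothetical, and it omits the 1/99% winsorization, so it only agrees with `_otsu3` on inputs where clipping changes nothing material.

```python
def otsu3_reference(scores):
    """Brute-force two-threshold Otsu over sorted scores (no winsorization).

    Hypothetical cross-check helper. Cuts sit between s[i-1] and s[i], and
    between s[j-1] and s[j], with 1 <= i < j <= n-1 so all classes are nonempty.
    """
    s = sorted(float(v) for v in scores)
    n = len(s)
    if n < 3:
        raise ValueError("need at least 3 scores for three nonempty classes")
    prefix = [0.0]
    for v in s:
        prefix.append(prefix[-1] + v)          # prefix[m] = sum of the m smallest
    best_obj, best = float("-inf"), (1, 2)
    for i in range(1, n - 1):
        for j in range(i + 1, n):
            obj = (prefix[i] ** 2 / i
                   + (prefix[j] - prefix[i]) ** 2 / (j - i)
                   + (prefix[n] - prefix[j]) ** 2 / (n - j))
            if obj > best_obj:
                best_obj, best = obj, (i, j)
    i, j = best
    return (s[i - 1] + s[i]) / 2, (s[j - 1] + s[j]) / 2


# Three well-separated clusters: the cuts should land in the two gaps.
t_lo, t_hi = otsu3_reference([0.0, 0.1, 0.2, 5.0, 5.1, 5.2, 10.0, 10.1, 10.2])
```

The double loop is O(n^2), the same cost as the vectorized grid, just without NumPy broadcasting. At the buffer sizes mentioned in the docstring that is a few tens of thousands of evaluations.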
@@ -116,37 +103,15 @@ def _sample_rows(rows: list[dict] | None, n: int, rng: torch.Generator) -> list[
     return [rows[i] for i in idxs]


-def _zone_stats(f: torch.Tensor, w: torch.Tensor) -> tuple[float, ...]:
-    """Return unit and gradient-energy shares below, inside, and above the routing band."""
-    if f.numel() == 0:
-        return (float("nan"),) * 6
-    lo, hi = (f == 0), (f == 1)
-    mid = ~(lo | hi)
-    tot = w.sum().clamp_min(1e-12)
-    return (lo.float().mean().item(), mid.float().mean().item(), hi.float().mean().item(),
-            ((w * lo).sum() / tot).item(), ((w * mid).sum() / tot).item(), ((w * hi).sum() / tot).item())
-
-
-def _pair_cos(raw_grads: dict, v: Float[torch.Tensor, "k r"], name: str
-              ) -> tuple[Float[torch.Tensor, "n_pairs"], Float[torch.Tensor, "n_pairs"]]:
-    """(clean, hack) pair cosines vs the routing dirs: max_i cos(g, v_i), the same
-    scoring the live gate uses, so band edges and thresholds use the same representation."""
-    gh = raw_grads[f"hack/{name}"].float() # [n_pairs, r]
-    gc = raw_grads[f"clean/{name}"].float()
-    ch = torch.einsum("n r, k r -> n k", gh, v).max(dim=1).values / gh.norm(dim=1).clamp_min(1e-12)
-    cc = torch.einsum("n r, k r -> n k", gc, v).max(dim=1).values / gc.norm(dim=1).clamp_min(1e-12)
-    return cc, ch
-
-
 def _auroc(scores: list[float], labels: list[bool]) -> float:
     """Rank-based AUROC (Mann-Whitney U) of `scores` as a detector of the positive class.

     Higher score for hacks -> auroc > 0.5. nan if either class is absent this step.
-    Diagnostic only: ground-truth labels measure how well cos(g, v_grad) separates
-    reward-hacking updates, but never determine a route. Reading: ~0.5 means v_grad
+    Diagnostic only: ground-truth labels measure how well the gate score separates
+    reward-hacking updates, but never determine a route. Reading: ~0.5 means v_act
     is a chance-level classifier (no threshold can route reliably); high AUROC but
     rout~0 = the threshold/scale is wrong, not the direction; a drop across a refresh =
-    the refresh destroyed the separation (the step-5 cliff is then a direction problem)."""
+    the refresh destroyed the separation."""
     pos = [s for s, y in zip(scores, labels) if y]
     neg = [s for s, y in zip(scores, labels) if not y]
     if not pos or not neg:
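The `_auroc` docstring describes a rank-based (Mann-Whitney U) AUROC that returns nan when a class is absent. The middle of the function is elided by the hunk boundary, so the following is a standalone sketch, not the repository's body. The name `rank_auroc` is hypothetical, and giving tied scores their average rank is an assumption about tie handling.

```python
import math


def rank_auroc(scores, labels):
    """AUROC via the Mann-Whitney U statistic, average ranks for ties."""
    n_pos = sum(1 for y in labels if y)
    n_neg = len(labels) - n_pos
    if n_pos == 0 or n_neg == 0:
        return math.nan                      # undefined with a single class
    order = sorted(range(len(scores)), key=lambda idx: scores[idx])
    ranks = [0.0] * len(scores)
    i = 0
    while i < len(order):
        j = i
        # Extend j over the run of scores tied with position i.
        while j + 1 < len(order) and scores[order[j + 1]] == scores[order[i]]:
            j += 1
        avg_rank = (i + j) / 2 + 1           # 1-based average rank of the tie run
        for t in range(i, j + 1):
            ranks[order[t]] = avg_rank
        i = j + 1
    sum_pos = sum(r for r, y in zip(ranks, labels) if y)
    return (sum_pos - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)


perfect = rank_auroc([0.9, 0.8, 0.1, 0.2], [True, True, False, False])
chance = rank_auroc([1.0, 1.0, 1.0, 1.0], [True, False, True, False])
```

The last line of the function is the same expression as the `return` visible in the next hunk: the rank sum of the positives minus its minimum possible value, divided by the number of positive-negative pairs.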
@@ -167,22 +132,6 @@ def _auroc(scores: list[float], labels: list[bool]) -> float:
     return (sum_pos - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)


-def route_band_edges(raw_grads: dict, v_grad: dict, device) -> dict[str, tuple[float, float]]:
-    """Calibrate an absolute routing band from authored pairs only.
-
-    Clean/hack p75 edges avoid single-pair extremes and route only the confident
-    tail aligned with reward-hacking gradients. Pair/live shift can still make routing
-    idle; inspect `rout`.
-    See docs/papers/grad_routing/paper_sgtm.md.
-    """
-    band = {}
-    for name in v_grad:
-        cc, ch = _pair_cos(raw_grads, v_grad[name].detach().cpu().float(), name)
-        band[name] = (cc.quantile(0.75).item(), ch.quantile(0.75).item()) # (lower=p75 clean, upper=p75 hack)
-    return band
-
-
-
 # Fix evaluation sampling across steps and arms without perturbing the training RNG.
 EVAL_GEN_SEED = 12345

@@ -196,12 +145,12 @@ MODE_CODE: dict[str, str] = {

 def _validate_config(cfg: Config) -> None:
     """Reject contradictory experiment settings before model load."""
-    if cfg.intervention not in ("none", "routeV", "absorb"):
-        raise ValueError(f"unknown intervention {cfg.intervention!r}; expected none|routeV|absorb")
-    if cfg.routeV_random_v_seed is not None and cfg.intervention != "routeV":
-        raise ValueError("routeV_random_v_seed is a routeV-only placebo control")
+    if cfg.intervention not in ("none", "routeA", "absorb"):
+        raise ValueError(f"unknown intervention {cfg.intervention!r}; expected none|routeA|absorb")
+    if cfg.routeA_random_v_seed is not None and cfg.intervention != "routeA":
+        raise ValueError("routeA_random_v_seed is a routeA-only placebo control")
     if cfg.rollout_ablate_frac > 0 and cfg.intervention == "none":
-        raise ValueError("rollout_ablate_frac needs a quarantine to ablate (routeV/absorb)")
+        raise ValueError("rollout_ablate_frac needs a quarantine to ablate (routeA/absorb)")
     if cfg.weight_decay != 0.0:
         raise ValueError("lora2r init is nonzero; AdamW decay pulls A/B toward 0 not toward init "
                          "-- set --weight-decay=0")
@@ -215,7 +164,7 @@ def _validate_config(cfg: Config) -> None:
 def _log_resolved_config(cfg: Config, device) -> None:
     """One block with every None resolved to its effective value, so a detached log
     shows exactly what ran -- especially WHICH pairset (the field readers kept losing)."""
-    is_routeV = cfg.intervention == "routeV"
+    is_routeA = cfg.intervention == "routeA"
     fields = {
         "preset/arm": f"{cfg.preset_name} / {cfg.arm}",
         "intervention": cfg.intervention,
@@ -225,8 +174,8 @@ def _log_resolved_config(cfg: Config, device) -> None:
         "lora_r/init_seed": f"{cfg.lora_r} / {cfg.lora_init_seed}",
         "unhackable_frac": cfg.unhackable_frac,
         "env_mode": cfg.env_mode,
-        "pairset": cfg.vhack_pairs_path if is_routeV else "unused (not routeV)",
-        "routeV placebo seed": cfg.routeV_random_v_seed if is_routeV else "n/a",
+        "pairset": cfg.vhack_pairs_path if is_routeA else "unused (not routeA)",
+        "routeA placebo seed": cfg.routeA_random_v_seed if is_routeA else "n/a",
         "teacher pool/mix/off_step": (
             f"{cfg.teacher_pool_dir.name} / {cfg.mix_ratio} / {cfg.teacher_off_step}"
             if cfg.teacher_pool_dir else "none (pure on-policy)"),
@@ -254,10 +203,10 @@ def main(cfg: Config) -> int:
     logger.info(f"verbose log: {verbose_log}")
     _log_resolved_config(cfg, device)

-    is_routeV = cfg.intervention == "routeV"
+    is_routeA = cfg.intervention == "routeA"
     is_absorb = cfg.intervention == "absorb"
     is_vanilla = cfg.intervention == "none"
-    has_quarantine = is_routeV or is_absorb
+    has_quarantine = is_routeA or is_absorb

     # Only adapter parameters train; the base model remains frozen.
     tok = AutoTokenizer.from_pretrained(model_name)
@@ -276,9 +225,10 @@ def main(cfg: Config) -> int:
     model.config.use_cache = False

     # ── adapter: rank-2r LoRA, deployed block [:r] + quarantine block [r:] ──
-    # routeV needs the per-rollout c-probe gate; none/absorb pin the masks instead.
+    # The routeA gate reads activations via forward hooks; no grad probe in training
+    # (the c-probe stays in lora2r only for scripts/diag_pinning.py diagnostics).
     wrappers = wrap_model_with_lora2r(
-        model, r=cfg.lora_r, init_seed=cfg.lora_init_seed, grad_probe=is_routeV)
+        model, r=cfg.lora_r, init_seed=cfg.lora_init_seed)
     # A and B both train; quarantine = block slices of the SAME tensors, so there
     # is no separate hack-param list (per-rollout masks route grads, not surgery).
     delta_params = [p for info in wrappers.values() for p in (info["A"], info["B"])]
@@ -287,51 +237,47 @@ def main(cfg: Config) -> int:
     logger.info(f"trainable lora2r A+B: {sum(p.numel() for p in delta_params):,} "
                 f"({n_quar:,} of those in quarantine blocks)")

-    # ── routeV direction: v_grad (mean pair-gradient diff) + routing band ──
-    v_grad = None # set only by the routeV branch below
-    route_band = None
-    if is_routeV:
+    # ── routeA direction: v_act (mean pooled-act pair diff, unit rows per module) ──
+    v_act = None # [M, r] cpu fp32; module order = act_names
+    act_names = sorted(wrappers)
+    act_buf: deque | None = None # rolling pooled acts [M, r]; re-scored vs the CURRENT
+                                 # v_act at each gate call, so a refresh needs no flush
     MASK_PAIRS = None
+    if is_routeA:
         # Authored pairs are the only routing-label source; live oracle labels never enter training.
-        from .pairs import load_pairs
-        from .extract_vhack_grad import extract_v_hack
         MASK_PAIRS = load_pairs(cfg.vhack_pairs_path)
-        logger.info(f"routeV pairs: {cfg.vhack_pairs_path} -> {len(MASK_PAIRS)} pairs")
-        model.eval() # match standalone extract: deterministic backward, no dropout
-        _, _, raw_grads, _ = extract_v_hack(
-            model, tok, wrappers, MASK_PAIRS,
-            top_k=1, tau_axis=0.0, n_heldout=2, device=device,
-        )
-        v_grad = _build_v_grad(raw_grads, wrappers, cfg.v_grad_k, device)
-        _vk = "mean-diff (mean-mass)" if cfg.v_grad_k == 1 else f"top-{cfg.v_grad_k} SVD subspace, gate=max_i cos"
-        logger.info(f"routeV grad: built v_grad ({_vk}) for {len(v_grad)} modules")
-        if cfg.routeV_random_v_seed is not None:
-            v_grad = _haar_unit_dirs(v_grad, cfg.routeV_random_v_seed, device)
-            logger.info(f"routeV grad: OVERRODE v_grad with Haar-random dirs "
-                        f"(seed={cfg.routeV_random_v_seed}) -- placebo directionality control")
-        # Calibrate after any Haar override so the control covers the full routing pipeline.
-        route_band = route_band_edges(raw_grads, v_grad, device)
-        _mean_lo = sum(lo for lo, _ in route_band.values()) / len(route_band)
-        _mean_hi = sum(hi for _, hi in route_band.values()) / len(route_band)
-        _mean_bw = _mean_hi - _mean_lo
-        n_inc_band = sum(1 for lo, hi in route_band.values() if hi - lo > 0)
-        logger.info(
-            f"routeV band: {len(route_band)} modules, mean lower(p75 clean cos)={_mean_lo:+.3f}, "
-            f"mean upper(p75 hack cos)={_mean_hi:+.3f}, mean width={_mean_bw:+.3f}; "
-            f"{n_inc_band}/{len(route_band)} modules have positive band width (included in the gate). "
-            f"SHOULD: width>0 (pairs separate) and most modules included; ELSE extraction/band off.")
-        # Real directions must separate authored hack and clean pairs; Haar controls need not.
-        if cfg.routeV_random_v_seed is None:
-            assert _mean_bw > 0, (
-                f"real v_grad gave non-positive mean band width {_mean_bw:+.3f}: "
-                "hack pairs do not separate from clean -> extraction broken")
-        logger.info(
-            "lora2r three-way output mask: per-rollout label from the width-pooled "
-            "band-normalized cosine across modules; clean->deployed-only, "
-            "hack->quarantine-only (deployed detached), mid->both (may permit absorption). "
-            "SHOULD: rout (hack share) tracks the step's rollout hack rate, not ~50%; "
-            "clipfrac on clean-gated rollouts < ~0.2 ELSE the retain-trick ratio "
-            "drift is binding (quarantine forward too large).")
+        logger.info(f"routeA pairs: {cfg.vhack_pairs_path} -> {len(MASK_PAIRS)} pairs")
+        model.eval() # deterministic forward, no dropout
+        v_act, pair_acts = extract_v_act(model, tok, wrappers, MASK_PAIRS, device,
+                                         tstat=cfg.vact_tstat)
+        model.train()
+        # Authored-pair separation in the live score. The dot GAP is > 0 by construction
+        # (v is proportional to the mean pair diff); the pair AUROC is not, so it is the
+        # extraction sanity signal.
+        sh = torch.einsum("pmr,mr->p", pair_acts["hack"], v_act)
+        sc = torch.einsum("pmr,mr->p", pair_acts["clean"], v_act)
+        pair_auroc = _auroc(torch.cat([sh, sc]).tolist(),
+                            [True] * len(sh) + [False] * len(sc))
+        logger.info(
+            f"routeA v_act: {v_act.shape[0]} modules x r={v_act.shape[1]} "
+            f"(tstat={cfg.vact_tstat}); authored-pair dot gap={(sh.mean() - sc.mean()).item():+.3e}, "
+            f"pair AUROC={pair_auroc:.2f}. SHOULD: pair AUROC ~1.0 ELSE extraction broken.")
+        if cfg.routeA_random_v_seed is not None:
+            v_act = haar_unit_rows(tuple(v_act.shape), cfg.routeA_random_v_seed)
+            logger.info(f"routeA: OVERRODE v_act with Haar-random unit rows "
+                        f"(seed={cfg.routeA_random_v_seed}) -- placebo directionality control")
+        act_buf = deque(maxlen=cfg.route_buffer)
+        logger.info(
+            f"routeA gate: per-rollout score = dot(pooled completion-token act, v_act), "
+            f"thresholds = two-threshold Otsu on the last <= {cfg.route_buffer} live scores "
+            f"(z-normalized, winsorized 1/99%), label-free; pinned absorb until "
+            f"{cfg.route_warmup} scores. keep (0,0) | absorb (1,0) | rout (1,1: deployed "
+            f"detached). No bimodality guard: on the cached emergence windows no shape "
+            f"statistic separates the hack mixture from hack-free scores (Otsu tail means "
+            f"sit ~2.4 sd apart even on a Gaussian), and a false rout only discards one "
+            f"update from deployment. "
+            f"SHOULD: auroc col >> 0.5 once hacks appear ELSE v_act is blind and routing "
+            f"is noise; rout tracks the hack share, not ~0 or ~1.")

     # ── teacher pool ──
     # Teacher pool: pre-generated rollouts on disk keyed by problem_id. Each step's
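The log message in this hunk describes the gate as: pin every rollout to absorb until the buffer holds enough scores, then z-normalize against the buffer and split on two thresholds into keep (0,0), absorb (1,0), and rout (1,1). The sketch below illustrates only that labelling step and makes several assumptions. The name `label_rollouts` is hypothetical. The buffer here stores scalar scores, whereas the commit stores pooled activations and re-scores them against the current `v_act`. New scores are appended before thresholds are computed, which the diff does not confirm. The threshold function is passed in by the caller, for example a two-threshold Otsu.

```python
from collections import deque
from statistics import fmean, pstdev

# (m, d) mask pairs as named in the log message above.
KEEP, ABSORB, ROUT = (0, 0), (1, 0), (1, 1)


def label_rollouts(new_scores, buffer, thresholds_fn, warmup=128):
    """Label each new score keep/absorb/rout against a rolling buffer.

    thresholds_fn maps a list of z-scores to (t_lo, t_hi). Hypothetical helper.
    """
    buffer.extend(new_scores)
    if len(buffer) < warmup:
        # Warmup: too few scores for stable thresholds, so both blocks train.
        return [ABSORB] * len(new_scores), None
    mu, sd = fmean(buffer), pstdev(buffer)
    sd = sd if sd > 0 else 1.0                 # guard a constant buffer
    t_lo, t_hi = thresholds_fn([(s - mu) / sd for s in buffer])
    labels = []
    for s in new_scores:
        z = (s - mu) / sd
        if z < t_lo:
            labels.append(KEEP)                # below t_lo: deployed-only
        elif z >= t_hi:
            labels.append(ROUT)                # at/above t_hi: quarantine-only
        else:
            labels.append(ABSORB)              # between thresholds: both blocks
    return labels, (t_lo, t_hi)


# Usage with fixed illustrative thresholds instead of Otsu.
buf = deque(maxlen=256)
fixed = lambda z: (-0.5, 1.0)
warm_labels, warm_thr = label_rollouts([0.0] * 10, buf, fixed, warmup=20)
live_labels, live_thr = label_rollouts([-5.0, 0.0, 5.0] + [0.0] * 7, buf, fixed, warmup=20)
```

The boundary convention (strictly below `t_lo` is keep, at or above `t_hi` is rout) follows the `keep` and `rout` column descriptions in the StepLogger hunk.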
@@ -348,7 +294,7 @@ def main(cfg: Config) -> int:
     G_t = 0
     if cfg.teacher_pool_dir is not None:
         # mix=0 is the NO-TEACHER ablation: pure on-policy GRPO (G_t=0) while the
-        # pool is still loaded for the partition + routeV v_grad extraction.
+        # pool is still loaded for the partition.
         if not (0.0 <= cfg.mix_ratio < 1.0):
             raise ValueError(f"mix_ratio must be in [0,1) when teacher_pool_dir set; got {cfg.mix_ratio}")
         G_t = round(group * cfg.mix_ratio)
@@ -568,57 +514,31 @@ def main(cfg: Config) -> int:
|
||||
|
||||
save_ckpt([], path=run_dir / "ckpt_update0000.safetensors")
|
||||
|
||||
def _lora2r_gate_labels(c_grads: tuple, n_rollouts: int):
|
||||
"""Three-way output-mask label per rollout from the gate-pass c-probe grads.
|
||||
def _routeA_gate(dots: Float[torch.Tensor, "G"]) -> tuple[torch.Tensor, torch.Tensor, float, float]:
|
||||
"""Three-way output-mask label per rollout from the rolling score buffer.
|
||||
|
||||
Per module the per-rollout weight grad of the virtual diagonal (deployed
|
||||
block [r]) has a band-normalized cosine position. We POOL across modules in
|
||||
a single (num, den) fraction (T3 fix): a module with a wide band contributes
|
||||
proportionally more than a noisy near-zero-width one, instead of every module
|
||||
casting an equal-weight vote. One global label per rollout (matching SGTM's
|
||||
example-level labels): pos<=0 clean (m=0,d=0); pos>=1 hack (m=1,d=1); else mid
|
||||
(m=1,d=0, both blocks train). Returns (m, d, f3, w, pos, cosU): f3 in {0,.5,1} for
|
||||
_zone_stats, w = mean per-rollout grad norm for energy weighting, pos = the raw
|
||||
per-rollout pooled position (for the AUROC diagnostic), cosU = pooled cos of the
|
||||
SUMMED-rollout c-grad (the update direction) to v_grad."""
|
||||
num = torch.zeros(n_rollouts, device=device); den = 0.0
|
||||
w = torch.zeros(n_rollouts, device=device); n_inc = 0
|
||||
cosU_sum = 0.0
|
||||
for (name, info), cg in zip(wrappers.items(), c_grads, strict=True):
|
||||
lower, upper = route_band[name]
|
||||
if upper - lower <= 0: # noisy module: pairs don't separate -> excluded
|
||||
continue
|
||||
r_blk = info["r"]
|
||||
g_b: Float[torch.Tensor, "G r"] = cg.reshape(n_rollouts, -1, 2 * r_blk).sum(1)[:, :r_blk].float()
|
||||
nrm = g_b.norm(dim=1)
|
||||
# cos to each of the k routing dirs, then max: aligned with ANY known hack sub-mode.
|
||||
cos_b = torch.einsum("g r, k r -> g k", g_b, v_grad[name]).max(dim=1).values / nrm.clamp_min(1e-12)
|
||||
num += cos_b - lower; den += upper - lower
|
||||
w += nrm; n_inc += 1
|
||||
gs = g_b.sum(0) # summed-rollout deployed c-grad = this module's update direction
|
||||
cosU_sum += (torch.einsum("r, k r -> k", gs, v_grad[name]).max() / gs.norm().clamp_min(1e-12)).item()
|
||||
if n_inc == 0:
|
||||
raise RuntimeError("no module has positive band width; pairs separate nowhere")
|
||||
pos = num / den; w /= n_inc
|
||||
# ── online-stats gate (#30): band from THIS batch's pooled positions, three zones ──
|
||||
# The authored absolute band is mis-placed (live pos sits far below the synthetic-hack
|
||||
# edge -- out/diag/pinning_calib.png) and a fixed quantile FORCES route_quantile out every
|
||||
# step even when nothing separates. Calibrate the band from the CURRENT batch instead:
|
||||
# refresh-proof by construction (these rollouts scored against the current v_grad), no
|
||||
# window or flush to keep stale positions around. mean + k*std self-silences -- only the
|
||||
# tail genuinely beyond the spread routes, so qmass tracks real separation. pos >
|
||||
# mean+route_std_mid*std -> mid (both blocks train); pos >= mean+route_std_rout*std -> rout
|
||||
# (hack, deployed detached); below -> keep (bulk). Direction stays authored-only; only the
|
||||
# threshold follows the live distribution.
|
||||
mu_pos, sd_pos = pos.mean().item(), pos.std().item()
|
||||
t_lo = mu_pos + cfg.route_std_mid * sd_pos # mid onset
|
||||
t_hi = mu_pos + cfg.route_std_rout * sd_pos # rout onset (>= mid)
|
||||
logger.debug(f"gate pos: n_inc={n_inc}/{len(wrappers)} n_roll={n_rollouts} "
|
||||
f"min={pos.min().item():+.2f} max={pos.max().item():+.2f} | "
|
||||
f"mean={mu_pos:+.2f} std={sd_pos:.2f} t_lo={t_lo:+.2f} t_hi={t_hi:+.2f}")
|
||||
m = (pos > t_lo).float() # mid + rout -> quarantine trains (keep = bulk below t_lo)
|
||||
d = (pos >= t_hi).float() # top tail -> hack -> deployed detached
|
||||
return m, d, 0.5 * m + 0.5 * d, w, pos, cosU_sum / n_inc
|
||||
The buffer holds pooled ACTS, so every gate call scores the whole window
|
||||
against the CURRENT v_act (refresh-proof; the only staleness left is act
|
||||
drift as the adapter trains, small over <= route_buffer rollouts). Scores
|
||||
are z-normalized by the buffer mean/std, then two-threshold Otsu (winsorized
|
||||
inside _otsu3) places (t_lo, t_hi): z <= t_lo keep (0,0); t_lo < z < t_hi
|
||||
absorb (1,0, both blocks train); z >= t_hi rout (1,1, deployed detached).
|
||||
Warmup: pinned absorb until the buffer holds route_warmup scores -- too few
|
||||
points to place thresholds, and absorb keeps both blocks learning."""
|
||||
if len(act_buf) < cfg.route_warmup:
|
||||
G_n = dots.shape[0]
|
||||
return (torch.ones(G_n, device=device), torch.zeros(G_n, device=device),
|
||||
float("nan"), float("nan"))
|
||||
S = torch.einsum("nmr,mr->n", torch.stack(tuple(act_buf)), v_act)
|
||||
mu, sd = S.mean().item(), max(S.std().item(), 1e-12)
|
||||
t_lo, t_hi = _otsu3(((S - mu) / sd).numpy())
|
||||
z = (dots - mu) / sd
|
||||
m = (z > t_lo).float().to(device) # absorb + rout -> quarantine trains
|
||||
d = (z >= t_hi).float().to(device) # top zone -> rout -> deployed detached
|
||||
logger.debug(f"routeA gate: buf={len(act_buf)} mu={mu:+.3e} sd={sd:.3e} "
|
||||
f"t_lo={t_lo:+.2f} t_hi={t_hi:+.2f} (z) | batch z "
|
||||
f"min={z.min().item():+.2f} max={z.max().item():+.2f}")
|
||||
return m, d, t_lo, t_hi
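The `_otsu3` helper called above is not shown in this diff. The following is a minimal sketch of what a winsorized two-threshold Otsu could look like, assuming 1-D z-scores as input; the name `otsu3_sketch`, the midpoint placement of the thresholds, and the exhaustive search are illustrative assumptions, not the repository's implementation. It relies on the standard identity that minimizing the total within-class sum of squares is equivalent to maximizing Otsu's between-class variance.

```python
import numpy as np


def otsu3_sketch(z, lo_pct=1.0, hi_pct=99.0):
    """Hypothetical two-threshold Otsu on 1-D scores; returns (t_lo, t_hi)."""
    z = np.asarray(z, dtype=np.float64)
    if z.size < 3:
        raise ValueError("need at least 3 scores to form three classes")
    # Winsorize: clip to the 1st/99th percentiles so a single outlier
    # cannot claim a whole zone for itself.
    lo, hi = np.percentile(z, [lo_pct, hi_pct])
    x = np.sort(np.clip(z, lo, hi))
    n = x.size
    # Prefix sums give O(1) within-class sum of squares for any slice.
    csum = np.concatenate([[0.0], np.cumsum(x)])
    csq = np.concatenate([[0.0], np.cumsum(x * x)])

    def sse(i, j):
        """Sum of squared deviations from the mean over x[i:j] (non-empty)."""
        s = csum[j] - csum[i]
        return (csq[j] - csq[i]) - s * s / (j - i)

    best, cut = np.inf, (1, 2)
    # Exhaustive search over both cut points; every class stays non-empty.
    for i in range(1, n - 1):
        for j in range(i + 1, n):
            cost = sse(0, i) + sse(i, j) + sse(j, n)
            if cost < best:
                best, cut = cost, (i, j)
    i, j = cut
    # Assumption: place each threshold midway between neighbouring classes.
    return 0.5 * (x[i - 1] + x[i]), 0.5 * (x[j - 1] + x[j])
```

The search is quadratic in the buffer size, which stays cheap at the 256-score window used here; a histogram-based variant would be the usual choice for much larger inputs.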
|
||||
|
||||
# Disable tqdm off-TTY because structured per-step rows already report progress.
|
||||
pbar = tqdm(range(steps), desc=f"train {cfg.arm} {cfg.preset_name}",
|
||||
@@ -634,7 +554,7 @@ def main(cfg: Config) -> int:
|
||||
logger.info(f"teacher-off curriculum: step {step} >= {cfg.teacher_off_step} "
|
||||
f"-> teachers off (pure on-policy on the wider problem set from here)")
|
||||
# mix_ratio>0 is the teacher on/off switch (mix=0 = no-teacher ablation: pool still
|
||||
# loaded for v_grad extraction, but no demos injected). The COUNT is teacher_n_per_prompt.
|
||||
# loaded for the partition, but no demos injected). The COUNT is teacher_n_per_prompt.
|
||||
teachers_on = (not teacher_off) and cfg.mix_ratio > 0 \
|
||||
and bool(covered_problems) and bool(teacher_pool or solve_pool)
|
||||
t0 = time.time()
|
||||
@@ -651,13 +571,15 @@ def main(cfg: Config) -> int:
|
||||
n_zerovar = 0 # groups skipped for zero reward variance (all rollouts same reward).
|
||||
agg_loss = 0.0
|
||||
diag_tail = None
|
||||
# routeV gate diagnostics (per-rollout three-way zone shares + retain-trick clipfrac).
|
||||
step_flagged: list[float] = [] # hack share (mean d over rollouts) per prompt
|
||||
step_clipfrac: list[float] = [] # PPO clip frac on clean-gated rollouts (retain-trick drift gauge)
|
||||
step_zkeep: list[float] = []; step_zresid: list[float] = []; step_zrout: list[float] = [] # unit shares per zone
|
||||
step_zkeepE: list[float] = []; step_zresidE: list[float] = []; step_zroutE: list[float] = [] # energy shares per zone
|
||||
# AUROC diagnostic: per-rollout pooled pos + its hack-label, accumulated across prompts.
|
||||
step_auroc_pos: list[float] = []; step_auroc_hack: list[bool] = []; step_cosU: list[float] = []
|
||||
# routeA gate diagnostics (per-rollout three-way zone shares + quarantine-on clipfrac).
|
||||
step_clipfrac: list[float] = [] # PPO clip frac on quarantine-on (absorb/rout) rollouts (ratio-drift gauge)
|
||||
step_zkeep: list[float] = []; step_zresid: list[float] = []; step_zrout: list[float] = [] # unit shares per zone
|
||||
step_tlo: list[float] = []; step_thi: list[float] = [] # Otsu thresholds (z units)
|
||||
# AUROC diagnostic on the A>0 contrast: scores + hack-labels of positively-
|
||||
# reinforced rollouts only (where the advantage alone is blind), students +
|
||||
# cached teachers. Accumulated across prompts; measurement only, never routes.
|
||||
step_auroc_score: list[float] = []; step_auroc_hack: list[bool] = []
|
||||
step_cos: list[float] = [] # mean per-rollout cos(act, v_act) (dot-vs-cos diagnostic)
|
||||
# Solve-mix discrimination: routed-share (mean d) over hack-teacher vs solve-teacher rollouts.
|
||||
step_route_hackT: list[float] = []; step_route_solveT: list[float] = []
|
||||
|
||||
@@ -870,18 +792,33 @@ def main(cfg: Config) -> int:
|
||||
# logπ_old: old-policy logprobs (frozen PPO-ratio target). logits_to_keep
|
||||
# =L_c+1 runs lm_head only on completion-side hidden states; [:, :-1] drops
|
||||
# the last position (predicts beyond `merged`, unused).
|
||||
# For routeA this forward runs QUARANTINE-ABLATED, matching both the sampling
|
||||
# policy (gen_students is deploy-mode) and the v_act extraction (quarantine-
|
||||
# ablated), so the gate score and the vector live on the same observable path.
|
||||
# The same forward carries the ActCapture hooks: the gate costs no extra pass.
|
||||
completion_ids = merged[:, plen:]
|
||||
L_c = completion_ids.shape[1]
|
||||
mask = (completion_ids != pad_id).float()
|
||||
_tfb = time.perf_counter()
|
||||
with torch.no_grad():
|
||||
logπ_old = per_token_logps(
|
||||
model(merged, logits_to_keep=L_c + 1).logits[:, :-1],
|
||||
completion_ids,
|
||||
).detach()
|
||||
if is_routeA:
|
||||
with torch.no_grad(), ablate_quarantine(wrappers), \
|
||||
ActCapture(wrappers, act_names) as cap:
|
||||
cap.set_pool(plen, mask)
|
||||
logπ_old = per_token_logps(
|
||||
model(merged, logits_to_keep=L_c + 1).logits[:, :-1],
|
||||
completion_ids,
|
||||
).detach()
|
||||
acts = cap.pooled().cpu() # [G, M, r] fp32
|
||||
else:
|
||||
with torch.no_grad():
|
||||
logπ_old = per_token_logps(
|
||||
model(merged, logits_to_keep=L_c + 1).logits[:, :-1],
|
||||
completion_ids,
|
||||
).detach()
|
||||
|
||||
# Pin block masks for the non-gated arms BEFORE the grad-carrying forward
|
||||
# (arm semantics: train_config.py docstring): none -> (0,0), absorb -> (1,0).
|
||||
# routeV leaves mask=None so the gate pass sees an unmasked forward.
|
||||
# Pin block masks BEFORE the (single) grad-carrying forward (arm semantics:
|
||||
# train_config.py docstring): none -> (0,0), absorb -> (1,0), routeA -> the
|
||||
# per-rollout three-way gate labels from the rolling-buffer Otsu thresholds.
|
||||
if is_vanilla:
|
||||
_z = torch.zeros(merged.shape[0], device=device)
|
||||
for info in wrappers.values():
|
||||
@@ -891,13 +828,42 @@ def main(cfg: Config) -> int:
|
||||
_z = torch.zeros(merged.shape[0], device=device)
|
||||
for info in wrappers.values():
|
||||
info["layer"]._lora2r_mask = (_o, _z)
|
||||
elif is_routeA:
|
||||
dots = torch.einsum("gmr,mr->g", acts, v_act) # [G]
|
||||
# cos = dot / (||act|| ||v||); v rows are unit so ||v|| = sqrt(M).
|
||||
coss = dots / (acts.flatten(1).norm(dim=1)
|
||||
* math.sqrt(len(act_names))).clamp_min(1e-12)
|
||||
step_cos.append(coss.mean().item())
|
||||
act_buf.extend(acts.unbind(0))
|
||||
m_vec, d_vec, _tl, _th = _routeA_gate(dots)
|
||||
for info in wrappers.values():
|
||||
info["layer"]._lora2r_mask = (m_vec, d_vec)
|
||||
step_tlo.append(_tl); step_thi.append(_th)
|
||||
step_zkeep.append((m_vec == 0).float().mean().item())
|
||||
step_zresid.append(((m_vec == 1) & (d_vec == 0)).float().mean().item())
|
||||
step_zrout.append((d_vec == 1).float().mean().item())
|
||||
# AUROC diagnostic on the A>0 contrast: merged order is [students;
|
||||
# teachers], the same order hack_flags was built in, so dots aligns.
|
||||
pos_a = (A > 0).cpu().tolist()
|
||||
step_auroc_score.extend(s for s, p in zip(dots.tolist(), pos_a) if p)
|
||||
step_auroc_hack.extend(bool(h) for h, p in zip(hack_flags, pos_a) if p)
|
||||
# Solve-mix discrimination: teachers are the LAST G_t rows of merged; split
|
||||
# their routed-share (mean d) by source. A discriminating gate routes the
|
||||
# hack teachers (d->1) and KEEPS the solve teachers (d->0); equal shares mean
|
||||
# the gate is non-directional (the shrinkage null). Teacher SOURCE is our
|
||||
# own pool construction, not a live-rollout oracle label -- a legit diagnostic.
|
||||
if teacher_is_solve:
|
||||
is_solve_t = torch.tensor(teacher_is_solve, device=d_vec.device, dtype=torch.bool)
|
||||
d_teach = d_vec[-len(teacher_is_solve):]
|
||||
if (~is_solve_t).any():
|
||||
step_route_hackT.append(d_teach[~is_solve_t].mean().item())
|
||||
if is_solve_t.any():
|
||||
step_route_solveT.append(d_teach[is_solve_t].mean().item())
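The A>0 AUROC diagnostic accumulated above can be illustrated in isolation. This is a sketch under the assumption that AUROC is computed as the pairwise win rate of hack scores over non-hack scores (ties counted as 0.5); the function names are hypothetical, and the repository's `_auroc` may differ in detail.

```python
def auroc_sketch(scores, labels):
    """Probability that a positive outscores a negative; ties count 0.5."""
    pos = [s for s, y in zip(scores, labels) if y]
    neg = [s for s, y in zip(scores, labels) if not y]
    if not pos or not neg:
        # Undefined when only one class is present.
        return float("nan")
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def auroc_on_positive_advantage(scores, hack_flags, advantages):
    """Restrict to rollouts with A > 0, where the advantage alone cannot
    tell a rewarded hack from a rewarded correct solution."""
    keep = [a > 0 for a in advantages]
    return auroc_sketch(
        [s for s, k in zip(scores, keep) if k],
        [bool(h) for h, k in zip(hack_flags, keep) if k],
    )
```

Filtering before scoring matters: a rollout with negative advantage is already pushed down by the reward, so including it would inflate the AUROC without saying anything about the gate.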
|
||||
|
||||
logπ = per_token_logps(
|
||||
model(merged, logits_to_keep=L_c + 1).logits[:, :-1],
|
||||
completion_ids,
|
||||
)
|
||||
|
||||
mask = (merged[:, plen:] != pad_id).float()
|
||||
# Per-rollout mean per-token logπ_old (student's logp on its own tokens).
|
||||
# Diagnostic only (no IS correction): the per-source gap lp_s - lp_t measures
|
||||
# how far the student has drifted from the teacher pool's tokens.
|
||||
@@ -914,56 +880,20 @@ def main(cfg: Config) -> int:
|
||||
ptl = (Lp_ * mask).sum(1) / mask.sum(1).clamp_min(1)
|
||||
return ptl.sum() / (group * prompts_per_step)
|
||||
|
||||
# Three-way output masking; gradients accumulate on A/B.
|
||||
# Gradient-space labels exist only AFTER a backward (labels: before forward;
|
||||
# activations: before backward; grads: after), so routeV pays a second masked
|
||||
# forward+backward. none/absorb were pinned before the logπ forward and need
|
||||
# only this one pass.
|
||||
# One masked forward+backward for EVERY arm; rollouts route to BLOCKS via
|
||||
# the output masks pinned above (nothing is subtracted from any gradient
|
||||
# vector; v_act is a classifier only). Gradients accumulate on A/B.
|
||||
loss = _grpo_loss(Lp)
|
||||
if is_routeV:
|
||||
# PASS 1 (gate): grads w.r.t. the c-probes ONLY. autograd.grad leaves
|
||||
# A.grad/B.grad untouched, so nothing to zero between passes.
|
||||
gates = [info["layer"]._lora2r_gate for info in wrappers.values()]
|
||||
c_grads = torch.autograd.grad(loss, gates)
|
||||
m_vec, d_vec, f3, w3, pos_vec, cosU = _lora2r_gate_labels(c_grads, merged.shape[0])
|
||||
step_flagged.append(d_vec.mean().item()) # hack share (the routed-out fraction)
|
||||
_kn, _rn, _on, _ke, _re, _oe = _zone_stats(f3, w3)
|
||||
step_zkeep.append(_kn); step_zresid.append(_rn); step_zrout.append(_on)
|
||||
step_zkeepE.append(_ke); step_zresidE.append(_re); step_zroutE.append(_oe)
|
||||
# AUROC diagnostic: pos as a hack-detector vs the hack-label (student
|
||||
# exploited + teacher cached). merged order is [students; teachers], the
|
||||
# same order hack_flags was built in, so pos_vec aligns with hack_flags.
|
||||
step_auroc_pos.extend(pos_vec.detach().cpu().tolist())
|
||||
step_auroc_hack.extend(bool(h) for h in hack_flags)
|
||||
step_cosU.append(cosU)
|
||||
# Solve-mix discrimination: teachers are the LAST G_t rows of merged; split
|
||||
# their routed-share (mean d) by source. A discriminating gate routes the
|
||||
# hack teachers (d->1) and KEEPS the solve teachers (d->0); equal shares mean
|
||||
# the gate is non-directional (the shrinkage null). Teacher SOURCE is our
|
||||
# own pool construction, not a live-rollout oracle label -- a legit diagnostic.
|
||||
if teacher_is_solve:
|
||||
is_solve_t = torch.tensor(teacher_is_solve, device=d_vec.device, dtype=torch.bool)
|
||||
d_teach = d_vec[-len(teacher_is_solve):]
|
||||
if (~is_solve_t).any():
|
||||
step_route_hackT.append(d_teach[~is_solve_t].mean().item())
|
||||
if is_solve_t.any():
|
||||
step_route_solveT.append(d_teach[is_solve_t].mean().item())
|
||||
# PASS 2 (masked): rollouts route to BLOCKS via masked sums; nothing is
|
||||
# subtracted from any gradient vector (v_grad = classifier only).
|
||||
for info in wrappers.values():
|
||||
info["layer"]._lora2r_mask = (m_vec, d_vec)
|
||||
logπ2 = per_token_logps(
|
||||
model(merged, logits_to_keep=L_c + 1).logits[:, :-1], completion_ids)
|
||||
ρ2 = torch.exp(logπ2 - logπ_old)
|
||||
loss = _grpo_loss(-torch.min(ρ2 * A_tok,
|
||||
torch.clamp(ρ2, 1 - cfg.clip, 1 + cfg.clip) * A_tok))
|
||||
# Retain-trick wrinkle: clean rollouts were SAMPLED quarantine-on but TRAIN
|
||||
# quarantine-off; the PPO ratio absorbs the gap, clip bounds it.
|
||||
clean = m_vec == 0
|
||||
if clean.any():
|
||||
clipped = ((ρ2.detach() - 1).abs() > cfg.clip).float()
|
||||
if is_routeA:
|
||||
# Keep-gated rollouts train quarantine-off, the exact state generation
|
||||
# and logπ_old used, so their ratio sits ~1. Absorb/rout rollouts see
|
||||
# the quarantine delta in the forward only -> ratio drift, bounded by
|
||||
# the clip; clipfrac on those rollouts is the drift gauge.
|
||||
qon = m_vec == 1
|
||||
if qon.any():
|
||||
clipped = ((ρ.detach() - 1).abs() > cfg.clip).float()
|
||||
step_clipfrac.append(
|
||||
((clipped * mask)[clean].sum() / mask[clean].sum().clamp_min(1)).item())
|
||||
((clipped * mask)[qon].sum() / mask[qon].sum().clamp_min(1)).item())
|
||||
loss.backward() # A/B grads accumulate across prompts (opt.zero_grad clears per step)
|
||||
for info in wrappers.values():
|
||||
info["layer"]._lora2r_mask = None
|
||||
@@ -990,38 +920,26 @@ def main(cfg: Config) -> int:
|
||||
opt.step()
|
||||
sched.step()
|
||||
|
||||
# ── v_grad refresh ──
|
||||
# ── v_act refresh ──
|
||||
# Re-extract the routing direction against the CURRENT model so it tracks where
|
||||
# hacks separate now, not at step 0. Without this the frozen direction goes stale.
|
||||
# Same MASK_PAIRS (the authored pairs, no oracle); quarantine ablated so the hack
|
||||
# signal flows back through the observable path, matching the build-time extract.
|
||||
# signal is read on the deployed observable path, matching the build-time extract
|
||||
# and the gate forward. Forward-only, so the refresh is cheap. The buffer holds
|
||||
# ACTS and re-scores them against the fresh v_act at the next gate call -> no flush.
|
||||
refr = "-"
|
||||
do_refresh = cfg.vhack_refresh_every > 0 and (step + 1) % cfg.vhack_refresh_every == 0
|
||||
if do_refresh and is_routeV and cfg.routeV_random_v_seed is not None:
|
||||
do_refresh = False # keep the one fixed Haar draw; re-extracting would replace it
|
||||
if do_refresh and is_routeV:
|
||||
do_refresh = (is_routeA and cfg.vhack_refresh_every > 0
|
||||
and (step + 1) % cfg.vhack_refresh_every == 0
|
||||
and cfg.routeA_random_v_seed is None) # placebo keeps its one Haar draw
|
||||
if do_refresh:
|
||||
_was_training = model.training
|
||||
model.eval()
|
||||
opt.zero_grad(set_to_none=True)
|
||||
logger.disable("vgrout.extract_vhack_grad")
|
||||
logger.disable("__main__")
|
||||
try:
|
||||
with ablate_quarantine(wrappers):
|
||||
from .extract_vhack_grad import extract_v_hack
|
||||
_, _, raw_grads, _ = extract_v_hack(
|
||||
model, tok, wrappers, MASK_PAIRS,
|
||||
top_k=1, tau_axis=0.0, n_heldout=2, device=device,
|
||||
)
|
||||
# update in place so the gate closure sees the fresh dirs (same k as init).
|
||||
v_grad.update(_build_v_grad(raw_grads, wrappers, cfg.v_grad_k, device))
|
||||
route_band = route_band_edges(raw_grads, v_grad, device) # rebuild band on fresh v_grad
|
||||
finally:
|
||||
logger.enable("vgrout.extract_vhack_grad")
|
||||
logger.enable("__main__")
|
||||
opt.zero_grad(set_to_none=True) # extract leaves .grad populated
|
||||
with ablate_quarantine(wrappers):
|
||||
v_act, _ = extract_v_act(model, tok, wrappers, MASK_PAIRS, device,
|
||||
tstat=cfg.vact_tstat)
|
||||
if _was_training:
|
||||
model.train()
|
||||
refr = "rfr" # gate calibrates from the live batch each step -> no window to flush
|
||||
refr = "rfr"
|
||||
|
||||
# ── periodic held-out eval (deploy = quarantine ablated) ──
|
||||
hack_deployed = solve_deployed = float("nan")
|
||||
@@ -1114,13 +1032,14 @@ def main(cfg: Config) -> int:
|
||||
f"step {step} TIMING gen={t_gen:.0f}s fwd_bwd={t_fb:.0f}s "
|
||||
f"reward={t_rew:.0f}s other={_tstep - t_gen - t_fb - t_rew:.0f}s total={_tstep:.0f}s")
|
||||
if step_clipfrac:
|
||||
logger.debug(f"routeV clean-gated clipfrac={sum(step_clipfrac)/len(step_clipfrac):.3f} "
|
||||
f"(SHOULD: <~0.2; higher = retain-trick ratio drift binding)")
|
||||
logger.debug(f"routeA quarantine-on clipfrac={sum(step_clipfrac)/len(step_clipfrac):.3f} "
|
||||
f"(SHOULD: <~0.2; higher = quarantine forward delta drifting far "
|
||||
f"from the ablated old policy)")
|
||||
if step_route_hackT or step_route_solveT:
|
||||
_rh = sum(step_route_hackT) / len(step_route_hackT) if step_route_hackT else float("nan")
|
||||
_rs = sum(step_route_solveT) / len(step_route_solveT) if step_route_solveT else float("nan")
|
||||
route_hackT_run.append(_rh); route_solveT_run.append(_rs)
|
||||
logger.debug(f"routeV solve-mix discrimination: hack-teacher routed={_rh:.2f} vs "
|
||||
logger.debug(f"routeA solve-mix discrimination: hack-teacher routed={_rh:.2f} vs "
|
||||
f"solve-teacher routed={_rs:.2f} (SHOULD: hack >> solve -> gate "
|
||||
f"discriminates correct-solution from reward-hacking updates; ~equal -> non-directional/shrinkage)")
|
||||
if diag_tail is not None:
|
||||
@@ -1148,15 +1067,16 @@ def main(cfg: Config) -> int:
|
||||
"lp_t": lp_t_mean if n_t else None,
|
||||
"loss": agg_loss,
|
||||
"gn": gn,
|
||||
"auroc": _auroc(step_auroc_pos, step_auroc_hack),
|
||||
"cosU": (sum(step_cosU) / len(step_cosU)) if step_cosU else float("nan"),
|
||||
# auroc is the A>0 contrast (hack vs non-hack among positively-reinforced
|
||||
# rollouts) -- the contrast where the reward alone is blind.
|
||||
"auroc": _auroc(step_auroc_score, step_auroc_hack),
|
||||
"cos": (sum(step_cos) / len(step_cos)) if step_cos else float("nan"),
|
||||
"qmass": q_egy,
|
||||
"keep": (sum(step_zkeep) / len(step_zkeep)) if step_zkeep else float("nan"),
|
||||
"resid": (sum(step_zresid) / len(step_zresid)) if step_zresid else float("nan"),
|
||||
"rout": (sum(step_zrout) / len(step_zrout)) if step_zrout else float("nan"),
|
||||
"keepE": (sum(step_zkeepE) / len(step_zkeepE)) if step_zkeepE else float("nan"),
|
||||
"residE": (sum(step_zresidE) / len(step_zresidE)) if step_zresidE else float("nan"),
|
||||
"routE": (sum(step_zroutE) / len(step_zroutE)) if step_zroutE else float("nan"),
|
||||
"tlo": (sum(step_tlo) / len(step_tlo)) if step_tlo else float("nan"),
|
||||
"thi": (sum(step_thi) / len(step_thi)) if step_thi else float("nan"),
|
||||
"lr": sched.get_last_lr()[0],
|
||||
"refr": refr,
|
||||
# Deploy-eval (quarantine ablated); NaN except on eval steps.
|
||||
@@ -1236,20 +1156,17 @@ def main(cfg: Config) -> int:
|
||||
solve_rate_s = gt_s_total / max(1, n_s_total)
|
||||
hack_rate_t = hack_t_total / max(1, n_t_total)
|
||||
|
||||
# routeV/absorb must move the quarantine; none must leave it exactly zero. The
|
||||
# quarantine LEARNED delta = (A[r:], B[:,r:]) minus the frozen init.
|
||||
# routeA/absorb must move the quarantine; none must leave it exactly zero. The
|
||||
# quarantine LEARNED delta = (A[r:], B[:,r:]) minus the frozen init. The routeA
|
||||
# warmup pins absorb, so even a placebo run trains the quarantine.
|
||||
dsh_norm = float(sum(
|
||||
(info["A"].data[info["r"]:] - info["A0"][info["r"]:]).float().pow(2).sum().item()
|
||||
+ (info["B"].data[:, info["r"]:] - info["B0"][:, info["r"]:]).float().pow(2).sum().item()
|
||||
for info in wrappers.values()) ** 0.5)
|
||||
logger.info(f"||quarantine learned delta|| = {dsh_norm:.4f} "
|
||||
f"(SHOULD: >0 for routeV/absorb, ==0 for none; ELSE routing broke)")
|
||||
if has_quarantine and cfg.routeV_random_v_seed is None:
|
||||
f"(SHOULD: >0 for routeA/absorb, ==0 for none; ELSE routing broke)")
|
||||
if has_quarantine:
|
||||
assert dsh_norm > 0.0, f"{cfg.intervention}: quarantine never moved -> nothing trained it"
|
||||
elif cfg.routeV_random_v_seed is not None and dsh_norm == 0.0:
|
||||
# A Haar control may validly route nothing because no rollout clears its band.
|
||||
logger.warning("routeV Haar control: ||quarantine delta||==0 -> the random direction routed "
|
||||
"NOTHING. This is a real result (favours: alignment needed), not a failure.")
|
||||
|
||||
# Show one final generation so numerical results are not trusted after semantic collapse.
|
||||
if last_gen_sample is not None:
|
||||
|
||||
+29
-32
@@ -5,7 +5,7 @@ masking; see src/vgrout/lora2r.py) and three arms:
|
||||
|
||||
none gate pinned clean (0,0): quarantine never trains -- the capacity- and
|
||||
structure-matched vanilla control.
|
||||
routeV per-rollout three-way gate from the c-probe gradient vs v_grad.
|
||||
routeA per-rollout three-way gate from the pooled bottleneck activation vs v_act.
|
||||
absorb gate pinned mid (1,0): both blocks train on everything, no gate --
|
||||
tests ungated both-block training.
|
||||
"""
|
||||
@@ -20,7 +20,7 @@ from .rewards import EnvMode
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class Config:
|
||||
intervention: Literal["none", "routeV", "absorb"] = "routeV"
|
||||
intervention: Literal["none", "routeA", "absorb"] = "routeA"
|
||||
lora_r: int = 32
|
||||
lora_init_seed: int = 0
|
||||
|
||||
@@ -46,30 +46,27 @@ class Config:
|
||||
unbiased: bool = True
|
||||
|
||||
vhack_refresh_every: int = 5
|
||||
# The 8 original behavior_* pairs only: per-pairset diag (out/diag/pinning_pairset_auroc.png)
|
||||
# ranks this subset's v_grad best at separating live hacks (AUROC 0.69, d=+0.85), well above
|
||||
# the full all-in-one. The `@behavior` TAG would re-add the anti-aligned opportunity-aware
|
||||
# pairs (d=-0.03) and dilute, so select by heading prefix, not tag. Wave-2 arms (untested):
|
||||
# `/behavior2` = 15 new mechanisms, `/behavior` = 23-pair union.
|
||||
# The 8 original behavior_* pairs only: the best or tied vector on all three diag
|
||||
# windows (RESEARCH_JOURNAL 2026-06-11 (d)). The `@behavior` TAG would re-add the
|
||||
# anti-aligned opportunity-aware pairs and dilute, so select by heading prefix, not
|
||||
# tag. Wave-2 arms (untested): `/behavior2` = 15 new mechanisms, `/behavior` = union.
|
||||
vhack_pairs_path: Path = Path("data/pairs/hack_pairs.md#all-in-one/behavior_")
|
||||
# Routing directions per module. k=1 (headline): the mean(hack-clean) "mean-mass"
|
||||
# axis. k>1: top-k oriented SVD dirs of the paired diff; gate scores max_i cos(g,v_i)
|
||||
# (alignment to ANY known hack sub-mode) -- catches multi-modal hack signal one mean
|
||||
# washes out. k=1 stays mean-diff (not SVD-top-1) so "mean-mass vs top-k" is a clean A/B.
|
||||
v_grad_k: int = 1
|
||||
# Online-stats gate (#30): route by the live MEAN + k*STD of the pooled cosine-to-v_grad,
|
||||
# not a fixed quantile. The authored absolute band is mis-placed (live pos sits far below
|
||||
# the synthetic-hack edge; even synthetic SOLVE is more hack-aligned than on-policy hack --
|
||||
# see out/diag/pinning_calib.png), and a fixed quantile FORCES route_quantile of the batch
|
||||
# out every step even when nothing separates. mean+k*std self-silences: it only routes the
|
||||
# tail that genuinely exceeds the spread. pos > mean + route_std_mid*std -> mid (both blocks train);
|
||||
# pos >= mean + route_std_rout*std -> rout (hack, deployed detached); below -> keep (bulk).
|
||||
# The band is calibrated from the CURRENT batch each step (no window, no flush): refresh-proof
|
||||
# by construction. Direction stays authored-only; only the threshold follows the live dist.
|
||||
route_std_mid: float = 2.0
|
||||
route_std_rout: float = 3.0
|
||||
# t-stat extraction (v = mean/SE per coordinate, clamp |t|<=3): null at the current
|
||||
# 8 pairs (journal (e)); revisit when the authored-pair set grows.
|
||||
vact_tstat: bool = False
|
||||
# routeA gate thresholds come from a rolling buffer of the last route_buffer live
|
||||
# scores: z-normalize by buffer mean/std, two-threshold Otsu (winsorized 1/99% --
|
||||
# journal (d): without it one outlier buys a whole zone) -> keep | absorb | rout.
|
||||
# Until the buffer holds route_warmup scores the gate pins absorb (both blocks
|
||||
# train): too few points to place thresholds. The buffer stores pooled ACTS and
|
||||
# re-scores them against the current v_act, so a refresh needs no flush. No
|
||||
# bimodality guard: no shape statistic separates the hack mixture from hack-free
|
||||
# scores on the cached windows (Otsu tail means sit ~2.4 sd apart even on a
|
||||
# Gaussian), and a false rout only discards one update from deployment.
|
||||
route_buffer: int = 256
|
||||
route_warmup: int = 128
|
||||
# Haar-random direction control (placebo): same routing machinery, no pair signal.
|
||||
routeV_random_v_seed: int | None = None
|
||||
routeA_random_v_seed: int | None = None
|
||||
rollout_ablate_frac: float = 0.0
|
||||
|
||||
env_mode: EnvMode = "run_tests"
|
||||
@@ -112,8 +109,9 @@ class Config:
|
||||
@property
|
||||
def arm(self) -> str:
|
||||
# _lora2r suffix kept so these runs never conflate with the retired
|
||||
# PiSSA-substrate runs of the same intervention (rename-on-logic-change).
|
||||
return {"none": "vanilla_lora2r", "routeV": "routingV_lora2r",
|
||||
# PiSSA-substrate runs of the same intervention (rename-on-logic-change;
|
||||
# routeA likewise never conflates with the retired grad-gate routeV runs).
|
||||
return {"none": "vanilla_lora2r", "routeA": "routeA_lora2r",
|
||||
"absorb": "absorb_lora2r"}[self.intervention]
|
||||
|
||||
|
||||
@@ -126,12 +124,11 @@ class SmokeConfig(Config):
|
||||
max_new: int = 32
|
||||
n_problems: int = 100
|
||||
prompts_per_step: int = 1
|
||||
# Random tiny data never separates, so the self-silencing band (mean+2/3*std) would route
|
||||
# nothing and the quarantine would never train -> the routing-pathway smoke assert fails.
|
||||
# Force routing by lowering the band so the smoke exercises mid+rout (correctness, not the
|
||||
# real threshold). mid below mean -> most train quarantine; rout at mean -> top ~half detach.
|
||||
route_std_mid: float = -1.0
|
||||
route_std_rout: float = 0.0
|
||||
# Smoke produces 4 scores/step over 30 steps; the real 256/128 buffer would keep the
|
||||
# gate in warmup forever. Shrink so the smoke exercises warmup AND the Otsu gate
|
||||
# (keep/absorb/rout + deployed detach) within a few steps.
|
||||
route_buffer: int = 32
|
||||
route_warmup: int = 8
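The buffer and warmup settings can be pictured with a small standalone sketch. It assumes pooled activations of shape `[G, M, r]` and a direction of shape `[M, r]`, as in the gate code; the class name `RollingGateBuffer` and its methods are hypothetical, and NumPy stands in for torch. Because the buffer stores activations rather than scores, swapping in a new direction re-scores the whole window with no flush.

```python
from collections import deque

import numpy as np


class RollingGateBuffer:
    """Hypothetical rolling window of pooled activations with a warmup."""

    def __init__(self, maxlen, warmup):
        self.acts = deque(maxlen=maxlen)  # oldest entries fall off the left
        self.warmup = warmup

    def push(self, acts):
        # acts: [G, M, r]; iterating yields one [M, r] array per rollout.
        self.acts.extend(acts)

    def ready(self):
        # Below the warmup count there are too few scores to place thresholds.
        return len(self.acts) >= self.warmup

    def zscore(self, acts, v):
        # Re-score the whole window against the CURRENT direction v.
        window = np.stack(tuple(self.acts))
        s = np.einsum("nmr,mr->n", window, v)
        # ddof=1 mirrors torch's default unbiased std.
        mu, sd = s.mean(), max(s.std(ddof=1), 1e-12)
        return (np.einsum("gmr,mr->g", acts, v) - mu) / sd
```

With the smoke values above (4 scores per step, `route_warmup=8`, `route_buffer=32`), such a buffer would leave warmup on the second push and be full after the eighth.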
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
|
||||