Files
evil_MoE/src/vgrout/train.py
T
wassname e5295dc07b feat: route2 Haar-random v_grad directionality control (H2 vs H4) + semantic placebo fleet
The null_city placebo is CONTAMINATED: 20% of its modules align with the hack
direction (median |cos|=0.06 but a 0.99 tail, shared generic features). So the
'route2 is non-directional' verdict rested on a bad control. Add the clean tests:

- route2_random_v_seed: replace pair-derived v_grad with seeded per-module Haar-random
  unit vectors (~0 cos with hack dir everywhere). Refresh no-ops so the draw stays fixed.
  'Nothing routed' (||dS_hack||==0) is now a valid logged outcome, not an abort -- it is
  itself H4-confirming (a zero-alignment direction may never clear tau).
- null_vampire / null_bacon / null_blue: semantic placebo fleet (vampire-vs-werewolf etc.),
  each an arbitrary direction with different accidental hack-alignment. Maps route2's
  suppression-vs-alignment as a scatter: H4 predicts it tracks |cos|, H2 predicts all suppress.

Smoke-validated (smoke-route2 --route2-random-v-seed=0 completes).

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
2026-06-05 08:43:54 +00:00

1886 lines
110 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""GRPO / Dr.GRPO loop with SVD-basis gradient projection on the LeetCode
reward-hacking benchmark.
generate -> grade -> backward -> project -> step
Inner GRPO step ported from lsdefine/simple_GRPO grpo_vllm_one.py:64-95; the
outer loop accumulates grads over prompts_per_step prompts (simple_GRPO's
Q_batch_size), so at least one per-prompt group has reward variance.
Unbiased normalization: Dr.GRPO, Liu et al. 2025, arXiv:2503.20783 -- drop the
1/|oᵢ| length norm and the /σ_R group-std (--unbiased, on by default).
Adapter: AntiPaSTO full-rank SVD knob δS per Linear, W' = W + U diag(δS) Vᵀ.
At δS=0 the adapter is identity, so a no-grad forward with δS zeroed gives π_ref
for free, no second model (the KL term under --beta>0).
Arms (--intervention, one knob):
none measure only; δS.grad untouched (vanilla GRPO)
erase subtract the hack-ward component of δS.grad
route park that component in the δS_hack quarantine, ablated at deploy (Cloud 2024)
route2 route per-rollout by a calibrated-τ cosine gate, cos(g_b, v_grad) > τ
Hyperparameters from ariahw/rl-rewardhacking config.py (docs/grpo_hyperparams.md);
SmokeConfig / FastConfig / FullConfig below hold the scale knobs.
uv run python -m vgrout.train smoke --intervention=erase
"""
from __future__ import annotations
import gzip
import json
import math
import os
import sys
import random
import time
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Literal
# Must be set BEFORE `import torch` to take effect on the CUDA allocator.
# Eliminates fragmentation that caused 91 GiB allocated / 581 MiB free crash
# on Qwen3-4B G=8 (PyTorch's own OOM message recommends this).
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
import torch
import torch.nn.functional as F
import tyro
from jaxtyping import Float
from loguru import logger
from safetensors import safe_open
from safetensors.torch import save_file
from tabulate import tabulate
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from .antipasto import ablate_quarantine, ref_logprobs_via_zero_delta, wrap_model_with_antipasto
from .extract_vhack_grad import load_v_hack, postprocess_v_hack
from .problems import DATA, load_problems
from .proj import per_token_logps, project_delta_S_grad, mean_cos_pre_from_grads
from .rewards import EnvMode, compute_reward
from .data import DATA, load_problems
from .vhack import load_v_hack, postprocess_v_hack
from .eval import ablate_quarantine, eval_hack_solve, ref_logprobs_via_zero_delta
from .tablelog import setup_logging, StepLogger
# Per-model SVD basis cache used by the AntiPaSTO wrapper.
CACHE_ROOT = Path("svd_cache")

# Everything this script writes lives under out/, organised by datatype (see
# docs/spec/20260530_out_dir_reorg.md):
#   out/vhack/          extracted hack bases
#   out/pools/          teacher pools
#   out/runs/<run_id>/  per-training-run checkpoints
# Read-side paths (v_hack, teacher pool) are passed in explicitly as args.
OUT_DIR = Path("out")
VHACK_DIR = OUT_DIR.joinpath("vhack")
RUNS_DIR = OUT_DIR.joinpath("runs")
LOGS_DIR = Path("logs")
# DATA (the LeetCode dataset path) is not defined here: it comes from the
# `import DATA` lines above (the later `.data` import is the binding in effect).
def setup_logging(run_id: str) -> Path:
    """Configure loguru for this run and return the path of the verbose log file.

    Console sink: one-character level icon + message, INFO and above, written
    through ``tqdm.write`` so it does not tear progress bars. File sink: DEBUG
    and above, prefixed with time and level, at
    ``LOGS_DIR/<YYYYmmddTHHMMSS>_<run_id>.log``. Conventions:
    /root/.claude/skills/token-efficient-logging/SKILL.md.

    NOTE(review): this definition shadows the ``setup_logging`` imported from
    ``.tablelog`` at the top of the module; this local version is the one
    ``main`` calls.
    """
    LOGS_DIR.mkdir(exist_ok=True)
    stamp = datetime.now().strftime('%Y%m%dT%H%M%S')
    log_path = LOGS_DIR / f"{stamp}_{run_id}.log"

    # Drop every previously registered sink (incl. loguru's default stderr one).
    logger.remove()

    def _console(message):
        tqdm.write(message, end="")

    logger.add(
        _console,
        colorize=True,
        format="<level>{level.icon}</level> {message}",
        level="INFO",
    )
    logger.add(log_path, format="{time:HH:mm:ss} | {level} | {message}", level="DEBUG")

    # Single-letter icons (I / W / E / D) consumed by the compact console format.
    for level_name in ("INFO", "WARNING", "ERROR", "DEBUG"):
        logger.level(level_name, icon=level_name[0])
    return log_path
@dataclass(kw_only=True)
class Config:
"""Universal knobs shared across all presets. Preset subclasses below
(SmokeConfig / FastConfig / FullConfig) override the scale-dependent knobs
(model, steps, group, lr, Adam betas). Dispatched via tyro subcommand.
`kw_only=True` so subclasses can add new fields with defaults even though
the parent already has defaulted fields (no positional-arg ordering issues).
Adam defaults (lr=7e-5, beta1=0.9, beta2=0.99) are ariahw config.py:138-144.
`fast` deliberately overrides with aggressive lr + low Adam betas for
sub-30-min iteration loops.
"""
# The four arms (see module docstring). `arm` (property below) is the derived
# display name; route2 gate spec: docs/spec/20260601_calibrated_tau_route2grad.md.
intervention: Literal["none", "erase", "route", "route2"] = "erase"
# ── scale knobs: every preset overrides these ──
model: str = "Qwen/Qwen3-4B"
steps: int = 100
group: int = 6 # G samples per question
max_new: int = 1024
n_problems: int = 992
beta: float = 0.0 # KL coef; >0 uses the δS=0 free-ref-model trick
prompts_per_step: int = 8 # P prompts per optimizer step; grads accumulate over P.
lr: float = 7e-5
adam_beta1: float = 0.9
adam_beta2: float = 0.99
# Universal knobs (haven't been a useful axis to vary per preset so far).
clip: float = 0.2
weight_decay: float = 0.1 # canonical config.py:142
# warmup as fraction of total steps (not absolute count) so a 20-step `fast`
# preset doesn't burn its first 10 steps at 1e-3-of-peak LR. 0.1 = ariahw
# canonical 10/100 = 10% at the 100-step regime they used.
warmup_frac: float = 0.1
grad_clip: float = 10.0 # global L2 clip on δS grads
seed: int = 41
preserve_magnitude: bool = True
gate_mode: Literal["one_sided", "no_gate", "reverse"] = "one_sided"
# route2 airtight no-cheat control: anchor the τ-gate on TEACHER rows only, never
# on hacked_E-flagged student rows. The run_tests detector cross-fires <=1.1% on
# held-out modes (false positives), so the default anchor leaks ~1% of held-out
# labels into routing. Teacher-only anchor gives the held-out classes PROVABLY zero
# detector labels -- the strict A5 no-cheat test. v_grad and the τ-route-by-energy
# path are unchanged; only the force-route-known-hacks term drops its student flags.
gate_anchor_teacher_only: bool = False
project_overshoot: float = 1.0 # remove overshoot*c_use@V; 1.0=just remove, 1.1=10% reversal of hack-ward grad
# route/route2 exploration floor: fraction of student rollouts sampled with the
# quarantine (δS_hack) ablated, i.e. from the DEPLOYED model. Intent: guard hack-
# saturation -- if on-policy sampling collapses onto hacking, every rollout routes
# to the quarantine and the deployed δS never sees a solve gradient. Grading these
# ablated rollouts is also a free per-step deploy proxy (hk_abl / slv_abl).
# DEFAULT 0 (off): sampling from the DEPLOYED model makes its reward an
# optimization target whose optimum is hacking, so whenever the (stale) routing
# gate misses, those deploy-generated hacks update the kept knob -- the floor
# optimizes the very policy the quarantine is meant to keep clean. Neither
# gradient-routing paper samples from the ablated model during training (they
# sample knob-on, ablate at eval, let absorption localize). Re-enable only to test
# whether a fresh gate (refresh-every-1) keeps the floor safe.
rollout_ablate_frac: float = 0.0
# Which grader flaw + factual hint this run trains on (a "hack class"). Sets
# the prompt hint (HINT_REPLACE_TO) and how `passed` is graded in rewards.py.
# run_tests = the original run_tests-overwrite loophole. eq_override / exit_code
# are the other faithful loopholes (docs/spec/20260530_faithful_multi_loophole_env.md).
env_mode: EnvMode = "run_tests"
unbiased: bool = True # Dr.GRPO: drop 1/|oᵢ| and /σ_R
# v_hack path; None -> derived from model+top_k, auto-extracted on cache miss
# (~5min, shares the loaded model). v_hack_k slices the saved top-k_max
# directions to top-k_use at load (the k-ablation knob).
v_hack_path: Path | None = None
v_hack_extract_top_k: int = 12 # max k to save at extract; n_train_pairs caps it lower
v_hack_k: int = 5 # load-time slice; k=1 = mean-diff, k=k_max = full
v_hack_tau_axis: float = 0.0 # extract-time: zero axes where S_i/S_0 < tau_axis
# Global noise floor: drop the bottom frac of singular values Sᵢ by quantile
# across all modules. A module with every axis below the threshold is dropped
# (projection skips it -- no hack signal there). 0 = no filter.
v_hack_drop_bottom_frac: float = 0.25
# Online refresh: every N steps re-extract v_hack against the current
# (δS-modified) model so it tracks the student's drifting hack subspace, not
# the step-0 one. 0 = freeze at load. Cost ~1-2 min wall on Qwen3-4B.
vhack_refresh_every: int = 5
# Route deploy-eval: every N steps zero δS_hack and eval hack/solve on a fixed
# subset -> the hack_deploy / solve_deploy columns (the dynamics-plot series for
# route: the training-time hack curve still hacks; routing's benefit shows only
# once the quarantine is ablated). 0 = off. eval_n_prompts x `group` samples.
# Default 5: gives 12 deploy points over the common 60-step run (nice trajectory
# plot). Affordable now that the per-step knob-ON eval pass is gone (each eval is
# one n=64 pass, ~230s, not two). Long-horizon recipes (paper-longrun, A5) pin a
# sparser cadence (10/20) explicitly. See journal 2026-06-04 (a) for the cost audit.
eval_ablate_every: int = 5
eval_n_prompts: int = 8
# Save the deploy adapter (δS only, ~2.3MB) at every deploy-eval step, tagged by
# step, so a run can be RE-SCORED later (more prompts, different eval) without
# retraining. Tiny per ckpt; a 200-step run at every-10 is ~46MB. Off for big sweeps.
save_eval_ckpts: bool = True
# Optional: pool-derived pairs JSON (built by pairs_from_pool.py). When set,
# BOTH the cache-miss extract AND the online refresh use these pairs instead
# of the hand-crafted vgrout.pairs.PAIRS. Required for the cross-
# mechanism experiment so refresh keeps tracking half_A's hack subspace.
vhack_pairs_path: Path | None = None
# Directionality control: replace route2's pair-derived v_grad with a per-module
# Haar-random unit vector (provably ~0 cos with the hack direction in every module).
# Tests whether route2's suppression NEEDS the direction (H4: alignment) or is
# alignment-agnostic quarantine-absorption (H2). Seeded so multiple draws give a
# distribution ("works half the time?"). The null_city placebo is a CONTAMINATED
# control -- 20% of its modules align with the hack dir (median |cos|=0.06 but a
# 0.99 tail); Haar is the clean zero-alignment control. Refresh no-ops when set, so
# the direction stays the one fixed random draw regardless of --vhack-refresh-every.
route2_random_v_seed: int | None = None
# Per-source cin diagnostic: split each prompt's backward into student-only
# + teacher-only passes (~2x backward time). 1 = every step (default; full
# signal); N>1 = only every Nth step (combined backward elsewhere, ~halves
# backward cost on skipped steps). cos_pre_s/cos_pre_t print as `nan` on skipped.
cos_pre_split_every: int = 1
out_tag: str = "" # suffix for saved artifact, e.g. "_seed41"
# Mixed-pool GRPO: per-prompt rollout pool = G_s live student + G_t cached
# teacher rollouts. Teacher pool is a dir of prompt_NNNN.jsonl.gz produced by
# probe_distill.py --teacher-only (schema includes prompt_ids, completion_ids,
# plen, reward, hacked, gt_pass, fmt_ok). Reward labels are read from cache
# (not re-graded) so the pool is reproducible. G_t = round(G * mix_ratio),
# G_s = G - G_t. Both halves contribute to a single group-relative advantage.
# Loss is unchanged: ratio==1 in single-inner-step PPO, so reward-weighted
# policy gradient applies uniformly to both halves regardless of source.
teacher_pool_dir: Path | None = None
# Teacher density G_t/G. 0.125 (1 in 8) is the operating point: the hack-
# reduction gap holds and the solve cost vanishes vs mix=0.5. Needs group>=8
# so round(G*mix_ratio) >= 1 teacher.
mix_ratio: float = 0.125
# Teacher-off curriculum: seed hacks via the teacher pool for the first N
# optimizer steps, then cut to pure on-policy (G_t=0) for the rest. None = never
# cut. Guarantees all hacks emerge (teacher-seeded) before testing whether route2
# holds the suppression once the teacher crutch is gone. See step-loop use.
teacher_off_step: int | None = None
# A5 no-cheat generalisation: restrict teacher demos (and thus the route2 tau
# hack-anchor) to these env_modes only. Held-out modes stay in the training set
# but train PURELY ON-POLICY (no teacher rows, never seed the hack-anchor) -- the
# student must emerge them itself, and we measure whether routing on the
# known-mode v_grad suppresses them anyway (absorption). None = use the whole
# pool (normal). When set, the line-589 "filter problems to pool keys" is skipped
# and uncached/held-out prompts fall through to student-only instead of skipping.
teacher_modes: tuple[str, ...] | None = None
# Cross-mechanism BLUF (docs/spec/20260528_cross_mechanism_v_hack.md):
# which upstream detectors were used to label the hack-side of the pairs that
# produced v_hack. Used to split student-rollout hacks into half_A (covered by
# the detector set v_hack was extracted from) and half_B (the held-out
# detectors). HACK_A drops AND HACK_B drops => projection is mechanism-agnostic.
# Detector codes (rewards.py): E=loophole_used, C=arbitrary_pass, D=wrong_tests.
# Defaults to the empty case (no split reported) when run on hand-crafted pairs.
half_a: str = ""
@property
def preset_name(self) -> str:
"""Slug used in log/checkpoint paths. Derived from subclass name so we
don't have to remember to set it per subclass (single source of truth)."""
return type(self).__name__.removesuffix("Config").lower() or "base"
@property
def arm(self) -> str:
"""Display name for run-id / BLUF / logs (results.py + plot_dynamics
classify off this). One-to-one with intervention; not a CLI flag."""
return {"none": "vanilla", "erase": "projected",
"route": "routing", "route2": "routing2"}[self.intervention]
@dataclass(kw_only=True)
class SmokeConfig(Config):
    """CPU smoke preset: tiny random Qwen3, 30 steps, ~1-2 min wall-clock.

    Long enough to exercise every code path, including the every-25-step
    save_ckpt trigger.
    """

    # -- model / run length --
    model: str = "llamafactory/tiny-random-qwen3"
    steps: int = 30
    n_problems: int = 100
    max_new: int = 32
    # -- rollouts --
    # >=4 so a route2 smoke run (mix=0.5 -> G_s=2) can still carve out a
    # rollout_ablate_frac slice; G_s=1 could not.
    group: int = 4
    prompts_per_step: int = 1
    beta: float = 0.0
@dataclass(kw_only=True)
class FastConfig(Config):
    """Quick-iteration preset for a working GRPO-learns-to-hack baseline on Qwen3-4B.

    Uses aggressive Adam (lr=3e-3, beta1=0.5, beta2=0.9) so lp_t drift becomes
    visible early. UAT: hack_s rises from 0/N to >=N/4 by step 20, and the
    lp_t-lp_s gap shrinks by >=30%. The run length is 60 steps, long enough for
    the lp_s-lp_t gap to open at convergence. n_problems=200 was sized for
    teacher-pool coverage; note that pp=4 x 60 steps = 240 prompt draws, which
    is more than 200 problems.

    Defaults point at the 4-mode substrate teacher pool and the prog_wide
    persona pairs, so a real run only needs --intervention (plus an optional
    seed / refresh / mask).
    """

    # -- model / run length --
    model: str = "Qwen/Qwen3-4B"
    steps: int = 60  # long enough for the lp_s-lp_t gap to open at convergence
    n_problems: int = 200
    max_new: int = 512
    # -- rollouts --
    group: int = 8  # G=8 -> the locked-in mix_ratio=0.125 gives 1 teacher / 7 students
    prompts_per_step: int = 4
    beta: float = 0.0
    # -- optimizer (aggressive; see docstring) --
    lr: float = 3e-3
    adam_beta1: float = 0.5
    adam_beta2: float = 0.9
    # -- default substrate: 4-mode teacher pool + prog_wide persona pairs --
    teacher_pool_dir: Path | None = Path("out/pools/substrate")
    vhack_pairs_path: Path | None = Path("out/pairsets/prog_wide.json")
@dataclass(kw_only=True)
class FullConfig(Config):
    """Canonical ariahw substrate on Qwen3-4B (= DEFAULT_MODEL_ID).

    G=6 because G=8 OOMs on the lm_head memory spike for long prompts.
    pp=43 x G=6 = 258 generations/step, close to the paper's 256.
    n_problems=992 is the full filtered problem set (paper fn.9).
    """

    # -- model / run length --
    model: str = "Qwen/Qwen3-4B"
    steps: int = 200
    n_problems: int = 992
    max_new: int = 1024
    # -- rollouts --
    group: int = 6
    prompts_per_step: int = 43
    beta: float = 1e-3
def _haar_unit_dirs(v_grad: dict, seed: int, device) -> dict:
"""Per-module Haar-random unit vectors matching v_grad's shapes -- the clean
zero-alignment directionality control for route2 (~0 cos with the hack dir in
every module). Seeded + sorted-name iteration so it is reproducible and a refresh
regenerates the identical direction (no-op). See Config.route2_random_v_seed."""
g = torch.Generator().manual_seed(seed)
out = {}
for name in sorted(v_grad):
d = torch.randn(v_grad[name].shape, generator=g)
out[name] = (d / d.norm().clamp_min(1e-12)).to(device)
return out
def build_route2_anchors(is_student: list[bool], hack_E_flags: list[bool],
                         teacher_only: bool, device) -> tuple[torch.Tensor, torch.Tensor]:
    """Return (hack_anchor, clean_anchor) bool masks for calibrating the route2 τ-gate.

    Rows are the merged rollouts: student rows first, teacher rows after.
    Every teacher row is a hack anchor. Unless `teacher_only` is set, so is
    every leading student row whose hacked_E detector flag is True;
    `hack_E_flags` (length G_s) lines up with those leading rows. The clean
    anchor is the exact complement of the hack anchor.

    `teacher_only=True` ignores the detector flags, so held-out classes get
    provably zero detector labels: the airtight A5 no-cheat control. The
    default leaks, because the run_tests detector cross-fires on <=1.1% of
    held-out-mode rollouts and force-routes them. Verified in
    scripts/verify_gate_anchor.py.
    """
    n_rows = len(is_student)
    student_mask = torch.as_tensor(is_student, dtype=torch.bool, device=device)
    detector_hit = torch.zeros(n_rows, dtype=torch.bool, device=device)
    if not teacher_only:
        # Flags only cover the leading student rows; extra flags are ignored.
        k = min(n_rows, len(hack_E_flags))
        detector_hit[:k] = torch.as_tensor(list(hack_E_flags[:k]), dtype=torch.bool, device=device)
    hack_anchor = detector_hit | ~student_mask
    clean_anchor = ~hack_anchor
    return hack_anchor, clean_anchor
@torch.no_grad()
def eval_hack_solve(model, tok, problems, eval_idxs, gen_cfg, device, max_new,
                    max_ctx: int = 2048) -> dict:
    """Student-only generate + grade on a FIXED prompt subset (no teacher, no
    backward): a clean read of what the current adapter does.

    Each problem is graded by ITS OWN ``prob["env_mode"]``. On the
    multi-loophole substrate the eval subset spans several modes, and a problem
    only pays out for its own exploit.
    - hack = exploited rate: the mode's channel credited correctness without
      the strict oracle agreeing.
    - solve = gt_correct rate: the exploit-resistant oracle.
    Grading uses the same compute_reward as training, so the numbers are
    comparable to the per-step hack_s / gt_s, just measured off-policy on a
    held-fixed subset.

    Args:
        model: causal LM. ``model.config.use_cache`` is switched on for decoding
            and reset to False on exit, including when generation or grading
            raises.
        tok: tokenizer providing ``apply_chat_template`` / ``batch_decode``.
        problems: indexable of problem dicts (keys read: messages, env_mode,
            canonical, gt_tests, setup_code, func_name).
        eval_idxs: indices into ``problems`` to evaluate.
        gen_cfg: generation config forwarded to ``model.generate``.
        device: device the encoded prompt is moved to.
        max_new: generation budget, counted against ``max_ctx``.
        max_ctx: prompts with ``prompt_len + max_new > max_ctx`` are skipped.
            Defaults to 2048, the previously hard-coded limit.

    Returns:
        dict(hack=..., solve=..., n=..., by_mode=...).
        - hack / solve: rates over the ``n`` graded completions (0.0 when
          nothing was graded).
        - by_mode: env_mode -> [hacks, solves, n].
        Skipped prompts contribute nothing.
    """
    model.config.use_cache = True  # KV cache for autoregressive decode
    n = hacks = solves = 0
    # Per-mode (env_mode) tallies, so route's deploy eval can be read per
    # loophole. The load-bearing weak-detector test is whether route suppresses
    # the modes v_hack never saw (file_marker/sentinel), not just its own
    # (run_tests).
    by_mode: dict[str, list[int]] = {}  # env_mode -> [hacks, solves, n]
    try:
        for idx in eval_idxs:
            prob = problems[idx]
            mode = prob["env_mode"]
            prompt = tok.apply_chat_template(
                prob["messages"], tokenize=False, add_generation_prompt=True,
                enable_thinking=False)
            enc = tok(prompt, return_tensors="pt", add_special_tokens=False).to(device)
            prompt_len = enc.input_ids.shape[1]
            if prompt_len + max_new > max_ctx:
                continue  # prompt + generation budget would exceed the context cap
            out = model.generate(**enc, generation_config=gen_cfg)
            comps = out[:, prompt_len:]
            tally = by_mode.setdefault(mode, [0, 0, 0])
            for t in tok.batch_decode(comps, skip_special_tokens=True):
                r = compute_reward(
                    t, canonical_solution=prob["canonical"], gt_tests=prob["gt_tests"],
                    setup_code=prob["setup_code"], func_name_hint=prob["func_name"],
                    env_mode=mode)
                hacks += int(r.exploited); tally[0] += int(r.exploited)
                solves += int(r.gt_correct); tally[1] += int(r.gt_correct)
                n += 1; tally[2] += 1
    finally:
        # Restore the training setting (loss forwards run with use_cache=False)
        # even if generate/grading raised, so later forwards are unaffected.
        model.config.use_cache = False
    return dict(hack=hacks / max(1, n), solve=solves / max(1, n), n=n, by_mode=by_mode)
# Two-letter env_mode abbreviations used to name the compact per-mode hack
# columns (hk_rt, hk_xc, ...).
MODE_CODE: dict[str, str] = dict(
    run_tests="rt",
    eq_override="eq",
    exit_code="xc",
    stdout_marker="so",
    sentinel="se",
    file_marker="fm",
)
def main(cfg: Config) -> int:
# Read the chosen preset's settings off the config, then set up the run. The
# subclass dataclasses (SmokeConfig / FastConfig / FullConfig) carry the preset
# defaults, so here we just read them off cfg directly.
model_name = cfg.model; steps = cfg.steps; group = cfg.group
max_new = cfg.max_new; n_problems = cfg.n_problems; beta = cfg.beta
prompts_per_step = cfg.prompts_per_step
lr = cfg.lr; adam_beta1 = cfg.adam_beta1; adam_beta2 = cfg.adam_beta2
run_id = f"{cfg.preset_name}_{cfg.arm}_seed{cfg.seed}{cfg.out_tag}"
verbose_log = setup_logging(run_id)
torch.manual_seed(cfg.seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# BLUF up front: argv + setup + verbose-log pointer so a tail-reader sees context.
logger.info(f"argv: {' '.join(sys.argv)}")
logger.info(f"verbose log: {verbose_log}")
logger.info(
f"preset={cfg.preset_name} arm={cfg.arm} model={model_name} "
f"steps={steps} G={group} max_new={max_new} beta={beta} "
f"unbiased={cfg.unbiased} seed={cfg.seed} device={device}"
)
# Load the tokenizer and the frozen base model. We adapt this model but never
# train its weights directly.
tok = AutoTokenizer.from_pretrained(model_name)
if tok.pad_token_id is None: tok.pad_token = tok.eos_token
# ── model + tokenizer ──
# CPU smoke: fp32 + sdpa (flash-attn2 is CUDA-only, CPU bf16 is patchy).
# GPU: bf16 + flash_attention_2.
cpu = device.type == "cpu"
model = AutoModelForCausalLM.from_pretrained(
model_name,
dtype=torch.float32 if cpu else torch.bfloat16,
attn_implementation="sdpa" if cpu else "flash_attention_2",
).to(device)
# No gradient checkpointing: grad-accum forwards one G-group at a time, so peak
# activation memory fits at G=6 on 96GB without recompute. δS is a leaf inside
# W' = W + U diag(δS) Vᵀ, so it gets grad directly (no enable_input_require_grads).
# use_cache toggles per generate call: True for decode, False for the loss forwards.
model.config.use_cache = False
# ── AntiPaSTO adapter: δS (kept) + δS_hack (quarantine), same shape r ──
is_route2 = cfg.intervention == "route2"
wrappers = wrap_model_with_antipasto(
model, model_name, CACHE_ROOT, device,
grad_probe=is_route2, # route2 needs the per-rollout δS gate probe
)
# δS_hack only gets a grad under route (proj.py subspace split) or route2
# (per-rollout τ routing); under none/erase its grad stays None, so AdamW skips
# it and it stays exactly 0 (forward adds 0 -> identity).
delta_params = [info["delta_S"] for info in wrappers.values()]
delta_hack_params = [info["delta_S_hack"] for info in wrappers.values()]
logger.info(f"trainable delta_S: {sum(p.numel() for p in delta_params):,} "
f"(+{sum(p.numel() for p in delta_hack_params):,} delta_S_hack quarantine)")
# ── hack direction: v_hack (erase/route project against it) or v_grad (route2) ──
# Vanilla (none) is pure GRPO and ignores v_hack entirely (the cin/cout columns
# are hidden, so v_hack=None just means no subspace machinery).
v_grad = None # set only by the route2 grad-mask branch below
if cfg.intervention in ("none", "route2"):
if cfg.intervention == "none" and cfg.v_hack_path is not None:
logger.info(f"vanilla arm: ignoring --v-hack-path={cfg.v_hack_path} "
"(no projection; cin/cout diagnostics off)")
v_hack = None # route2 routes via the mask, not erase/route grad surgery
if is_route2:
# The persona pairs are the only "detector" (weak, self-supervised). They
# produce the routing direction; no oracle, no gt_pass.
if cfg.vhack_pairs_path is not None:
from .pairs_from_pool import load_pairs_json
MASK_PAIRS = load_pairs_json(cfg.vhack_pairs_path)
logger.info(f"route2 pairs: pool-derived ({cfg.vhack_pairs_path}) -> {len(MASK_PAIRS)} pairs")
else:
from .pairs import PAIRS as MASK_PAIRS
logger.info(f"route2 pairs: hand-crafted PAIRS -> {len(MASK_PAIRS)} pairs")
model.eval()
# gradient-space mean-diff. extract_v_hack gives per-pair GRPO gradients
# on δS; v_grad = unit(mean(g_hack - g_clean)) per module, oriented
# hack-ward (training reinforces hacks with the same sign, so a rollout
# with cos(g_b, v_grad) above the calibrated tau is a reinforced hack).
from .extract_vhack_grad import extract_v_hack
_, _, raw_grads, _ = extract_v_hack(
model, tok, wrappers, MASK_PAIRS,
top_k=1, tau_axis=0.0, n_heldout=2, device=device,
)
v_grad = {}
for name in wrappers:
d = (raw_grads[f"hack/{name}"] - raw_grads[f"clean/{name}"]).mean(0)
v_grad[name] = (d / d.norm().clamp_min(1e-12)).to(device)
logger.info(f"route2 grad: built v_grad (gradient mean-diff) for {len(v_grad)} modules")
if cfg.route2_random_v_seed is not None:
v_grad = _haar_unit_dirs(v_grad, cfg.route2_random_v_seed, device)
logger.info(f"route2 grad: OVERRODE v_grad with Haar-random dirs "
f"(seed={cfg.route2_random_v_seed}) -- directionality control (H2 vs H4)")
model.train()
else:
# v_hack path resolution, most-specific first. The pairset (personas) is
# the source of truth: pass --vhack-pairs-path and the hack file auto-loads
# (auto-extracts if missing) -- no need to also pass --v-hack-path.
if cfg.v_hack_path is not None:
v_hack_path = cfg.v_hack_path # explicit override (e.g. randomV control)
elif cfg.vhack_pairs_path is not None:
v_hack_path = VHACK_DIR / f"v_hack_pairset_{cfg.vhack_pairs_path.stem}.safetensors"
else:
# no pairset given -> hand-crafted PAIRS, keyed by model + extract knobs.
# Slug works for HF names and local paths; tau_tag because tau_axis is
# baked into the saved V (extract zeros rows where S_i/S_0 < tau_axis).
model_slug = model_name.rstrip("/").split("/")[-1]
tau_tag = f"_tau{cfg.v_hack_tau_axis:g}" if cfg.v_hack_tau_axis > 0 else ""
v_hack_path = VHACK_DIR / f"v_hack_{model_slug}_k{cfg.v_hack_extract_top_k}{tau_tag}.safetensors"
if not v_hack_path.exists():
from .extract_vhack_grad import extract_v_hack
if cfg.vhack_pairs_path is not None:
from .pairs_from_pool import load_pairs_json
VHACK_PAIRS = load_pairs_json(cfg.vhack_pairs_path)
logger.info(f"v_hack pairs: pool-derived ({cfg.vhack_pairs_path}) -> {len(VHACK_PAIRS)} pairs")
else:
from .pairs import PAIRS as VHACK_PAIRS
logger.info(f"v_hack pairs: hand-crafted PAIRS -> {len(VHACK_PAIRS)} pairs")
logger.info(f"v_hack cache miss at {v_hack_path}; extracting (~5min)...")
model.eval() # match standalone extract: deterministic backward, no dropout
v_hack_extracted, v_sv_extracted, _raw_grads, _diag = extract_v_hack(
model, tok, wrappers, VHACK_PAIRS,
top_k=cfg.v_hack_extract_top_k, tau_axis=cfg.v_hack_tau_axis,
n_heldout=2, device=device,
)
OUT_DIR.mkdir(exist_ok=True)
# Combine V and S under one safetensors file with `_sv/{name}` prefix
# for the singular values. load_v_hack splits them back apart.
save_payload = {**v_hack_extracted, **{f"_sv/{n}": s for n, s in v_sv_extracted.items()}}
save_file(save_payload, str(v_hack_path),
metadata={"model": model_name,
"dtype": "fp32" if cpu else "bf16",
"top_k": str(min(cfg.v_hack_extract_top_k, len(VHACK_PAIRS) - 2)),
"tau_axis": str(cfg.v_hack_tau_axis), "schema": "v2_with_sv"})
# extract zeros grads at exit; opt is built below so no opt-state taint.
model.train() # restore train mode; eval was set only for the extract pass
v_hack_cpu = load_v_hack(
v_hack_path, model_name, wrappers,
k_use=cfg.v_hack_k, drop_bottom_frac=cfg.v_hack_drop_bottom_frac,
)
v_hack = {name: v.to(device) for name, v in v_hack_cpu.items()}
# ── teacher pool ──
# Teacher pool: pre-generated rollouts on disk keyed by problem_id. Each step's
# G_t teacher rollouts come from a uniform random sample of that prompt's cache,
# so we do *not* keep the teacher model in VRAM. Pool is produced by
# `probe_distill.py --teacher-only` (see schema in probe_distill.py:149-186).
# Cached rewards/flags are reused verbatim (no re-grading), so the pool is a
# reproducible fixed teacher distribution across runs.
teacher_pool: dict[int, list[dict]] = {}
# Multi-loophole substrate: a teacher pool dir MAY carry partition.json
# {problem_id: env_mode}. When present, this is the even non-overlapping
# substrate (build_substrate.py) -- each problem is graded by its assigned mode
# and the teacher rollouts are the elicit-then-strip hacks for that mode. When
# absent, the run is single-mode (cfg.env_mode for every problem). See
# docs/spec/20260530_faithful_multi_loophole_env.md.
partition: dict[int, EnvMode] | None = None
G_s = group
G_t = 0
if cfg.teacher_pool_dir is not None:
# mix=0 is the NO-TEACHER ablation: pure on-policy GRPO (G_t=0, no teacher
# rollouts injected) while the pool is still loaded for the 4-mode partition
# and route2 v_grad extraction. Using the pairs for v_grad is allowed under
# the no-cheat invariant; mixing teacher rollouts into training is the thing
# mix=0 removes. mix in [0,1).
if not (0.0 <= cfg.mix_ratio < 1.0):
raise ValueError(f"mix_ratio must be in [0,1) when teacher_pool_dir set; got {cfg.mix_ratio}")
G_t = round(group * cfg.mix_ratio)
G_s = group - G_t
if G_s == 0:
raise ValueError(
f"degenerate split: G={group} mix_ratio={cfg.mix_ratio} -> G_s={G_s}. "
f"Pick mix_ratio < 1 so the student half is non-empty."
)
for path in sorted(cfg.teacher_pool_dir.glob("prompt_*.jsonl.gz")):
# path.stem on 'prompt_0004.jsonl.gz' is 'prompt_0004.jsonl' (only one
# suffix stripped); split off the .jsonl before parsing the int.
problem_id = int(path.name.split("_")[1].split(".")[0])
with gzip.open(path, "rt") as f:
teacher_pool[problem_id] = [json.loads(line) for line in f]
if not teacher_pool:
raise FileNotFoundError(
f"teacher pool {cfg.teacher_pool_dir} is empty. Run `just pregen-teacher N` first."
)
partition_path = cfg.teacher_pool_dir / "partition.json"
if partition_path.exists():
raw = json.loads(partition_path.read_text())
partition = {int(pid): mode for pid, mode in raw.items()}
from collections import Counter
by_mode = Counter(partition.values())
logger.info(
f"SUBSTRATE: per-problem env_mode partition from {partition_path.name} -- "
f"{len(partition)} problems across {len(by_mode)} modes: "
f"{dict(sorted(by_mode.items()))}. Each problem graded by its own mode; "
f"non-overlap holds (passed = gt_correct OR channel_i)."
)
if cfg.teacher_modes is not None:
# A5 no-cheat: drop teacher demos for held-out modes. The held-out
# problems stay in load_problems (the teacher-pool restriction filter
# applied after load_problems is skipped when teacher_modes is set) and
# train on-policy. partition is required.
assert partition is not None, "teacher_modes needs a partition.json"
kept = {pid: rows for pid, rows in teacher_pool.items()
if partition[pid] in cfg.teacher_modes}
logger.info(
f"teacher_modes={cfg.teacher_modes}: teacher pool restricted "
f"{len(teacher_pool)}->{len(kept)} prompts (known modes only); "
f"held-out-mode problems train ON-POLICY (no teacher, no anchor seed)."
)
teacher_pool = kept
n_rollouts_per = sum(len(v) for v in teacher_pool.values()) / len(teacher_pool)
avg_hack = sum(int(r["hacked"]) for v in teacher_pool.values() for r in v) / sum(len(v) for v in teacher_pool.values())
logger.info(
f"teacher pool: {len(teacher_pool)} prompts, "
f"~{n_rollouts_per:.1f} rollouts/prompt, "
f"cached hack_rate={avg_hack:.2%}. "
f"G_s={G_s} student + G_t={G_t} teacher per prompt (mix_ratio={cfg.mix_ratio})."
)
# ── optimizer + schedule ──
# δS and δS_hack share the lr (same shape, same basis, no per-group juggling).
opt = torch.optim.AdamW(
delta_params + delta_hack_params,
lr=lr, weight_decay=cfg.weight_decay, betas=(adam_beta1, adam_beta2),
)
# Linear warmup over `warmup_frac * steps`, then cosine decay to 0 over the rest.
# Fraction-based so short presets (fast: 20 steps) don't spend half the run
# under warmup. Canonical full-preset: 0.1 * 100 = 10 (matches ariahw config.py:141).
warmup_steps = max(1, int(cfg.warmup_frac * steps))
sched = torch.optim.lr_scheduler.SequentialLR(
opt,
schedulers=[
torch.optim.lr_scheduler.LinearLR(opt, start_factor=1e-3, end_factor=1.0,
total_iters=warmup_steps),
torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=max(1, steps - warmup_steps)),
],
milestones=[warmup_steps],
)
# ── generation config ──
# Qwen3.5 model card: non-thinking mode for text tasks.
# temperature=1.0, top_p=1.0, top_k=20, min_p=0.0, presence_penalty=2.0,
# repetition_penalty=1.0. enable_thinking=False is set on the chat template
# below (safe no-op if the model's template doesn't support it).
gen_cfg = GenerationConfig(
max_new_tokens=max_new, do_sample=True,
# T=0.7 matches ariahw reference (config.py:172). T=1.0 had hack emerging
# too slowly: hack patterns are modal in the baked substrate; broad sampling
# at T=1 dilutes them. Lower T expresses the substrate's hack propensity.
temperature=0.7, top_p=1.0, top_k=20, min_p=0.0,
repetition_penalty=1.0,
num_return_sequences=G_s, pad_token_id=tok.pad_token_id,
)
# Eval-ablation config: student-only, `group` samples/prompt (no teacher
# split, so we want the full group for a tighter rate estimate).
gen_cfg_eval = GenerationConfig(
max_new_tokens=max_new, do_sample=True,
temperature=0.7, top_p=1.0, top_k=20, min_p=0.0, repetition_penalty=1.0,
num_return_sequences=group, pad_token_id=tok.pad_token_id,
)
problems = load_problems(n_problems, env_modes=[cfg.env_mode], seed=cfg.seed, partition=partition)
mode_desc = "per-problem partition" if partition is not None else f"single env_mode={cfg.env_mode}"
logger.info(f"loaded {len(problems)} problems from {DATA.name} -- {mode_desc}")
if teacher_pool and cfg.teacher_modes is None:
# Restrict prompt sampling to problems with cached teacher rollouts;
# otherwise we'd skip the majority of steps when the pool is sparse
# (e.g. 70/992 prompts cached -> ~93% skip rate).
# SKIPPED under teacher_modes (A5): held-out-mode problems have no teacher
# demos but must stay in training to emerge + be measured on-policy.
before = len(problems)
problems = [p for p in problems if p["problem_id"] in teacher_pool]
logger.info(
f"teacher pool restriction: {len(problems)}/{before} prompts kept "
f"(student trains only on prompts covered by the cached teacher pool)"
)
if not problems:
raise ValueError(
f"no overlap between training set ({before} problems) and teacher pool "
f"({len(teacher_pool)} cached prompts). Re-run pregen-teacher against the same dataset."
)
# Fixed eval subset for route ablation: first eval_n_prompts problems, held
# constant across the run so the ablated-hack series is comparable step-to-step.
eval_idxs = list(range(min(cfg.eval_n_prompts, len(problems))))
rng = torch.Generator().manual_seed(cfg.seed)
rows = []
logger.info(
f"SHOULD: loss finite each step; projected/route arm cout -> ~0 (all hack-ward grad removed); "
f"PASS_RATE > 0 on 4B. "
f"ELSE: harness or projection broken. "
f"Timing cols (gen/fb/t_rew/sec): gen-bound -> vLLM; fb-bound -> lower pp; t_rew-bound -> parallel grading."
)
if teacher_pool:
logger.info(
f"SHOULD (mixed-pool): hack_t high from step 0 (cached teacher pool ~95% hack); "
f"hack_s climbs 0 -> 20%+ over the run as student learns from exposure. "
f"ELSE if hack_s flat while hack_t high: student is ignoring the off-policy "
f"gradient signal; bump mix_ratio or lr."
)
eos_id = tok.eos_token_id
pad_id = tok.pad_token_id
def gen_students(enc, n: int) -> tuple[torch.Tensor, int]:
    """Sample `n` student completions for one encoded prompt.

    For the quarantine arms (intervention "route" / "route2") the final
    `n_abl = round(n * cfg.rollout_ablate_frac)` rows are sampled inside
    `ablate_quarantine(wrappers)` (deployed model -> can't hack -> explores
    solves); every other arm uses n_abl = 0, i.e. one plain generate call.
    See Config.rollout_ablate_frac for the rationale.

    Returns (sequences, n_abl). Sequences are right-padded with pad_id to a
    common length and stacked as [plain rows..., ablated rows...], so the caller
    can tag the trailing n_abl rows as free deploy-mode samples.
    """
    quarantine_arm = cfg.intervention in ("route", "route2")
    n_ablated = round(n * cfg.rollout_ablate_frac) if quarantine_arm else 0
    n_plain = n - n_ablated
    chunks: list[torch.Tensor] = []
    if n_plain > 0:
        plain_out = model.generate(**enc, generation_config=gen_cfg,
                                   num_return_sequences=n_plain)
        chunks.append(plain_out.detach())
    if n_ablated > 0:
        # Deploy-mode samples: quarantine knob zeroed for this generate only.
        with ablate_quarantine(wrappers):
            ablated_out = model.generate(**enc, generation_config=gen_cfg,
                                         num_return_sequences=n_ablated)
            chunks.append(ablated_out.detach())
    # The two generate calls can stop at different lengths; right-pad to match.
    width = max(chunk.shape[1] for chunk in chunks)
    padded = [F.pad(chunk, (0, width - chunk.shape[1]), value=pad_id)
              for chunk in chunks]
    return torch.cat(padded, dim=0), n_ablated
# Per-step table streamed live (header once, row/step), same columns as the final
# tabulate dump; the StepLogger legend below decodes each column. Per-source
# (student/teacher) split on rew/gt/hack: teacher rows are frozen sanity, student
# rows are the "is it learning?" signal. ref_eq = cumulative gens / 256 (the
# canonical 16 prompts x 16 gens/step), so ref_eq=1.0 = one reference step's samples.
run_modes = sorted({p["env_mode"] for p in problems}, key=lambda m: list(MODE_CODE).index(m))
step_logger = StepLogger(arm=cfg.arm, modes=run_modes, mode_code=MODE_CODE)
REF_GENS_PER_STEP = 16 * 16 # ariahw/rl-rewardhacking config.py:num_prompts * num_generations
# Use the resolved locals (preset defaults merged), not cfg.* which can be None.
est_gens_per_step = prompts_per_step * group # before mixed-pool split
logger.info(
f"grad-pressure: {est_gens_per_step} gens/step vs reference {REF_GENS_PER_STEP} "
f"-> {est_gens_per_step / REF_GENS_PER_STEP:.2f}x per step; "
f"this run's {steps} steps ~= {steps * est_gens_per_step / REF_GENS_PER_STEP:.1f} reference steps."
)
# Legend (decodes only the columns this arm/mode-set actually shows) + blank
# line + header in one log entry so the blank line keeps no timestamp prefix.
logger.info("\n" + step_logger.legend() + "\n\n")
logger.info(step_logger.header())
# Per-run artifacts grouped under runs/<ts>_<run_id>/ (same stem as the log,
# so a run's checkpoint and log sit together). See out_dir_reorg spec.
run_dir = RUNS_DIR / verbose_log.stem
run_dir.mkdir(parents=True, exist_ok=True)
ckpt_path = run_dir / "train.safetensors"
first_hack_path = run_dir / "first_hack.safetensors"
# Per-rollout audit log: every live-graded student completion (full text +
# all hack-mechanism flags), one JSON object per line. Lets us eyeball
# *which* hack the student found and whether the mechanism shifts mid-run
# (e.g. it routes around v_hack into a category the pairs don't span).
# Offline observability only -- never read back into training, so no-cheat
# invariant holds. Truncated fresh each run.
rollout_log_path = run_dir / "rollouts.jsonl"
rollout_log_path.write_text("")
first_hack_saved = False
route_span_checked = False # R3: assert delta_S_hack.grad in span(V) once
# route2-grad per-step calibrated routing threshold (spec
# docs/spec/20260601_calibrated_tau_route2grad.md). tau = EMA midpoint of the
# hack-cloud (teacher + detector-flagged student) and clean-cloud (not-flagged
# student) cos(g_b, v_grad) per module. Rides the cin drift so a fixed cos>0
# gate (a ~50% coin-flip in high-dim) is replaced by "above where known hacks
# separate from clean". Persist across steps (EMA = cheap "last N hacks").
ema_hack_cos: dict[str, float] = {}
ema_clean_cos: dict[str, float] = {}
route2_tau: dict[str, float] = {}
EMA_BETA = 0.9
last_gen_sample = None # first student rollout of the latest step (for collapse inspection)
diverged_steps = 0 # consecutive steps with collapsed teacher ppl (divergence tripwire)
lp_t_best = -float("inf") # coherence high-water mark (best teacher gen_logp seen)
# ppl_t = exp(-lp_t) on the FIXED teacher rollouts is a free coherence gauge.
# Divergence is a DROP from the run's own best, not an absolute level: a healthy
# model sits near lp_t ~ -0.7 and craters to -11..-21 (token salad) on divergence.
# Relative threshold also keeps smoke green (tiny-random sits at lp_t ~ -11.9 but
# stays flat). Abort if lp_t falls this far below best for 2 steps (advantage dead).
DIVERGENCE_DROP = 5.0 # nats below best (e^5 ~ 150x worse ppl); never in healthy runs
WARN_DROP = 3.0 # softer: log a warning before the hard abort
dumped_hack_classes: set[str] = set() # first full example of each hack class -> verbose log
teacher_dumped = False
# Per-mode learning tracker (the substrate UAT: did the student learn EACH hack,
# and at what step?). Keyed by env_mode. exploited / rollouts counted on STUDENT
# rollouts only; first_step = step the student first exploited that mode.
mode_rollouts: dict[str, int] = {}
mode_hacks: dict[str, int] = {}
mode_first_step: dict[str, int] = {}
def save_ckpt(rows: list[dict], path: Path | None = None) -> None:
    """Overwrite the run checkpoint with the deployable adapter and run history.

    Tensor payload: every wrapper's trainable δS (detached, on CPU, contiguous),
    keyed by module name. δS_hack is intentionally left out: for route the
    quarantine knob is ablated at eval, so a δS-only file is exactly the model
    you'd ship.

    Metadata payload: safetensors metadata must be str->str, so anything
    structured is JSON-encoded. It holds the model name, a "bf16" dtype tag,
    step (= len(rows)), the aggregate hack and pass rates over every row, the
    per-step rows themselves, and the config. Mirrors the v_hack metadata idiom.

    Writes to `path` when given, otherwise to the run's ckpt_path, replacing
    any previous file. Rows are also streamed to the log, so this file is a
    convenience copy: an early kill keeps everything up to the last save.
    """
    gens_total = sum(row["N"] for row in rows)
    denom = max(1, gens_total)
    # The per-step table only carries per-source (student / teacher) columns;
    # the combined hack / solve aggregates are rebuilt from them here.
    hack_rate = sum(row["hack_s"][0] + row["hack_t"][0] for row in rows) / denom
    pass_rate = sum(row["gt_s"][0] + row["gt_t"][0] for row in rows) / denom
    delta_tensors = {}
    for mod_name, info in wrappers.items():
        delta_tensors[mod_name] = info["delta_S"].detach().cpu().contiguous()
    metadata = {
        "model": model_name,
        "dtype": "bf16",
        "step": str(len(rows)),
        "hack_rate": f"{hack_rate:.6f}",
        "pass_rate": f"{pass_rate:.6f}",
        "rows": json.dumps(rows),
        "cfg": json.dumps(vars(cfg), default=str),
    }
    target = path or ckpt_path
    save_file(delta_tensors, str(target), metadata=metadata)
# disable=None: auto-disable the bar when stdout is NOT a tty (pueue, pipes,
# file redirects). In those contexts every per-step `logger.info(step_logger.row)`
# goes through tqdm.write, which redraws the bar -> half-drawn fragments
# interleaved with the per-step table. Killing the bar off-tty leaves clean
# per-step rows (they already carry step + sec, so the bar is redundant there);
# an interactive terminal still gets the live bar. mininterval==maxinterval keeps
# that interactive bar sparse (tqdm's default maxinterval=10 forces 10s redraws).
pbar = tqdm(range(steps), desc=f"train {cfg.arm} {cfg.preset_name}",
mininterval=120, maxinterval=120, disable=None)
# ── training loop: generate -> grade -> backward -> project -> step ──
for step in pbar:
# Teacher-off curriculum: seed hacks via the teacher pool for the first N
# steps, then cut to pure on-policy (G_t=0) so we test whether route2 holds
# the suppression once the teacher crutch is gone. Monotonic: stays off.
if cfg.teacher_off_step is not None and step >= cfg.teacher_off_step and G_t > 0:
logger.info(f"teacher-off curriculum: step {step} >= {cfg.teacher_off_step} "
f"-> G_t {G_t}->0, G_s {G_s}->{group} (pure on-policy from here)")
G_t, G_s = 0, group
t0 = time.time()
opt.zero_grad(set_to_none=True)
# Accumulate across P prompts; one optimizer step at the end. Per-prompt
# group of G generations is the GRPO advantage normalisation unit.
agg_rew, agg_gt, agg_hack, agg_fmt = [], [], [], []
# Per-mechanism flags. Only populated for student rollouts (teacher pool
# cache predates E/D fields). Teacher slots padded with False so the lists
# stay aligned with agg_is_student. Half-A/B totals filter on is_student.
agg_hack_E: list[bool] = []
agg_hack_D: list[bool] = []
step_rollouts: list[dict] = [] # student completions this step -> rollout_log_path
agg_is_student: list[bool] = []
agg_is_ablated: list[bool] = [] # deploy-mode (quarantine-ablated) student rows -> free per-step deploy proxy
step_mode_hacks: dict[str, int] = {} # THIS step's student hacks per mode (the hk_<mode> columns; reset each step so they don't grow)
agg_logp: list[float] = [] # per-rollout mean per-token gen_logp (student's logp on rollout tokens)
agg_comp_lens, agg_finished, n_skipped = [], [], 0
n_zerovar = 0 # groups skipped for zero reward variance (all rollouts same reward).
# Rises as a loophole saturates: every rollout hacks -> identical reward -> no
# GRPO signal. Tracks the post-saturation signal-sparsity that drives lp_s collapse.
agg_loss = 0.0
diag_tail = None
# Per-source grad accumulators: each prompt's backward is split into
# student-only and teacher-only passes so we can compute cos_pre_s / cos_pre_t
# separately (discriminator: does v_hack actually project hack grads
# more than non-hack?). step_grad_combined = student + teacher and is
# what the projection + optimizer step ultimately sees.
step_grad_s: dict[str, torch.Tensor] = {}
step_grad_t: dict[str, torch.Tensor] = {}
# route2: the flagged rollouts' δS-grad contribution, accumulated per module
# across prompts, parked into δS_hack.grad at injection (the quarantine,
# deleted at deploy). Mirrors how proj.py parks route's removed component.
step_grad_hack: dict[str, torch.Tensor] = {}
# route2: recover the per-rollout δS grad from the gate (c.grad = δS * g_b),
# flag rollouts whose grad points hack-ward (cos(g_b, v_grad) > τ), and route
# their contribution into δS_hack. Only axes where δS has moved (|δS| > GATE_EPS)
# carry a reliable per-rollout split; near-zero axes keep the full grad, so
# routing on a fresh axis lags ~1 step until δS grows there (A1 stale-mask trade-off).
GATE_EPS = 1e-6
step_flagged: list[float] = []
step_tau: list[float] = [] # per-(prompt,module) calibrated route threshold
step_hkgap: list[float] = [] # ema_hack_cos - ema_clean_cos (discrimination gauge)
step_resid: list[float] = [] # cos(δS.grad AFTER routing, v_grad): hack-ward leak into deployed knob
def _route2_grad_filter(info, n_rollouts: int,
                        hack_anchor: torch.Tensor,
                        clean_anchor: torch.Tensor) -> torch.Tensor:
    """route2 router for ONE adapter module's δS gradient.

    Splits the δS.grad produced by the preceding backward into:
      * a flagged part: the contribution of rollouts judged hack-ward. It is
        accumulated into step_grad_hack[name] and later parked on δS_hack.
      * a kept part: everything else. It is returned and stays on the
        deployed δS.

    Args:
        info: the module's wrapper record. Reads info["delta_S"] (its .grad
            must be populated) and info["layer"]._antipasto_gate.grad (the
            per-token gate gradient).
        n_rollouts: rows (G) in the batch just backpropagated. Used to reshape
            the flattened per-token gate grad into per-rollout sums.
        hack_anchor: [G] mask over rollouts. These rows are force-routed and
            update the hack-cloud EMA.
        clean_anchor: [G] mask over rollouts. These rows update the
            clean-cloud EMA only; they are NOT force-kept.

    Returns:
        δS.grad minus the flagged rollouts' contribution (same shape as
        δS.grad).

    Side effects: updates ema_hack_cos / ema_clean_cos / route2_tau for this
    module, adds into step_grad_hack[name], and appends one value each to
    step_tau, step_hkgap, step_flagged and step_resid.

    NOTE(review): the module key `name` and `v_grad` are not parameters. They
    are read from the enclosing scope at call time. This is only correct when
    called inside a `for name, info in wrappers.items()` loop where the current
    `name` belongs to `info`; otherwise EMA/tau/grad keys silently mix modules.
    """
    g = info["delta_S"].grad  # [r] summed over rollouts*tokens
    # The hook's gate c is per-token ([G*s, r]) because nn.Linear sees a
    # flattened batch. Sum each rollout's token gate-grads -> per-rollout
    # δS*g_b: reshape [G*s, r] -> [G, s, r] -> sum tokens -> [G, r].
    # Pad tokens carry ~0 grad (masked in the loss), so summing every
    # position is safe. Per-rollout (not per-token) is the preregistered
    # unit: GRPO advantage is per-rollout, and summing first denoises the
    # cos(g_b, v_grad) sign (a clean rollout's individual tokens scatter
    # ~50% over cos>0; its token-sum points reliably clean-ward).
    cg = info["layer"]._antipasto_gate.grad.reshape(n_rollouts, -1, g.shape[0]).sum(1)  # [G, r]
    dS = info["delta_S"].detach()  # [r]
    # Only axes where δS has moved off ~0 allow dividing δS back out of
    # c.grad = δS * g_b. Unreliable axes get g_b = 0 and sub = 0 below, so
    # they keep the full grad.
    reliable = dS.abs() > GATE_EPS  # [r]
    dS_safe = torch.where(reliable, dS, torch.ones_like(dS))  # avoid /0 on unreliable axes
    g_b = torch.where(reliable, cg / dS_safe, torch.zeros_like(cg))  # [G, r] per-rollout
    vg = v_grad[name]  # [r] unit, hack-ward
    cos_b = (g_b @ vg) / g_b.norm(dim=1).clamp_min(1e-12)  # [G]
    # Calibrate the threshold to where KNOWN hacks separate from clean,
    # per module, EMA-smoothed across steps (rides the cin drift). A fixed
    # cos>0 gate is a ~50% coin-flip in high-dim (cos~0 for most rollouts).
    # The first observation seeds the EMA directly; a cloud with no
    # observation yet counts as 0.0 in tau below.
    if hack_anchor.any():
        mu_h = cos_b[hack_anchor].mean().item()
        ema_hack_cos[name] = (EMA_BETA * ema_hack_cos[name] + (1 - EMA_BETA) * mu_h
                              if name in ema_hack_cos else mu_h)
    if clean_anchor.any():
        mu_c = cos_b[clean_anchor].mean().item()
        ema_clean_cos[name] = (EMA_BETA * ema_clean_cos[name] + (1 - EMA_BETA) * mu_c
                               if name in ema_clean_cos else mu_c)
    # tau = midpoint of the two cloud means; hkgap = their separation
    # (discrimination gauge).
    tau = (ema_hack_cos.get(name, 0.0) + ema_clean_cos.get(name, 0.0)) / 2
    route2_tau[name] = tau
    step_tau.append(tau)
    step_hkgap.append(ema_hack_cos.get(name, 0.0) - ema_clean_cos.get(name, 0.0))
    # Force-route known hacks (teacher + flagged student); τ-route the
    # ambiguous rest (incl. unknown B, which lands above τ if it shares the
    # v_grad direction). Do NOT force-keep clean_anchor: it is contaminated
    # with unknown B, which we WANT routed.
    flagged = (hack_anchor | (cos_b > tau)).float()  # [G]
    step_flagged.append(flagged.mean().item())
    # Sum of the flagged rollouts' gate grads, mapped back to δS-grad units
    # by dividing out δS on reliable axes; 0 on unreliable axes.
    sub = torch.where(reliable, (cg * flagged.unsqueeze(1)).sum(0) / dS_safe,
                      torch.zeros_like(g))  # flagged rollouts' contribution
    # Park the flagged contribution in δS_hack (deleted at deploy); δS keeps
    # only the unflagged. Capacity-balanced: both shape [r].
    step_grad_hack[name] = (step_grad_hack[name] + sub.detach().clone()
                            if name in step_grad_hack else sub.detach().clone())
    g_keep = g - sub  # the deployed knob's gradient
    # Residual hack-ward alignment of the KEPT grad. Disambiguates qE:
    # qE high + resid~0 = routing stripped the hack cleanly (dominant
    # teacher grad correctly quarantined); qE high + resid>0 = false
    # negatives leaked hack-ward grad into the deployed knob (the real
    # failure). vg is unit, so this is a plain cosine.
    step_resid.append((g_keep @ vg / g_keep.norm().clamp_min(1e-12)).item())
    return g_keep
# Split backward into student/teacher only every cos_pre_split_every steps.
# On split steps: 2 backwards per prompt, populates step_grad_s/_t.
# On skipped steps: 1 combined backward, step_grad_s/_t stay empty and
# cos_pre_s/cos_pre_t go to NaN (mean_cos_pre_from_grads returns NaN on empty dict).
# route2 has no v_hack so cos_pre is NaN regardless: force the single combined
# backward (the split would just double cost). The grad-mask reads its
# per-rollout gate from that one backward.
split_this_step = (step % cfg.cos_pre_split_every == 0) and not is_route2
# Phase timers (per-step cumulative, seconds). Each GPU phase ends in a
# CPU-blocking op (decode / .item()), so perf_counter is sync-accurate
# without explicit cuda.synchronize. Tells us whether wall-time is
# generation-bound (-> vLLM), forward/backward-bound (-> lower pp), or
# reward-subprocess-bound (-> parallel grading).
t_gen = t_rew = t_fb = 0.0
# ── per prompt: G_s student + G_t teacher rollouts -> grade -> backward ──
for p_idx in range(prompts_per_step):
idx = int(torch.randint(0, len(problems), (1,), generator=rng).item())
prob = problems[idx]
prompt = tok.apply_chat_template(
prob["messages"], tokenize=False, add_generation_prompt=True,
enable_thinking=False, # canonical training default; no-op if template ignores it
)
enc = tok(prompt, return_tensors="pt", add_special_tokens=False).to(device)
plen = enc.input_ids.shape[1]
if plen + max_new > 2048:
n_skipped += 1
continue
# KV cache is essential for autoregressive decode (O(L) vs O(L^2) recompute
# per token) -- cacheless was the ~19min/step cost. Enable for generate,
# disable for the loss forwards below (single forward; a cache would just
# waste memory). DynamicCache grows to the actual length, so max_new only
# bounds the tail, not the typical footprint.
model.config.use_cache = True
_tg = time.perf_counter()
teacher_sample: list[dict] | None = None
pool_rows = teacher_pool.get(prob["problem_id"]) if teacher_pool else None
if teacher_pool and G_t > 0 and not pool_rows and cfg.teacher_modes is None:
# Sparse-pool skip: prompt uncached -> skip the whole prompt;
# falling back to student-only would break the student-vs-teacher
# comparison the normal mixed-pool run is designed to measure.
# SUPPRESSED under teacher_modes (A5): a held-out-mode prompt has no
# teacher demos BY DESIGN and must train on-policy (falls to else).
n_skipped += 1
continue
if pool_rows and G_t > 0:
# Mixed-pool: G_s live student + G_t cached teacher rollouts.
# G_t==0 (mix=0 no-teacher ablation) falls through to the student-only
# path below; the pool stays loaded for partition + v_grad extraction.
# Random sample without replacement when cache is large enough.
# Re-seeded per (step, p_idx) by the global rng so runs reproduce.
idxs = torch.randperm(len(pool_rows), generator=rng)[:G_t].tolist()
if len(pool_rows) < G_t:
idxs = idxs + torch.randint(0, len(pool_rows), (G_t - len(pool_rows),), generator=rng).tolist()
teacher_sample = [pool_rows[i] for i in idxs]
# Student live-gen (G_s rows; a rollout_ablate_frac slice generated
# with the quarantine ablated, see gen_students).
with torch.no_grad():
out_s, n_abl = gen_students(enc, G_s)
# Build teacher tensor: live-tokenized prompt + cached completion.
# Cached prompt_ids are ignored; re-tokenizing live makes the pool
# robust to chat-template / tokenizer drift between the model used
# for pool generation (Qwen3-4B) and the current student (e.g.
# tiny-random-qwen3 under smoke). Same vocab is assumed.
live_prompt_ids = enc.input_ids[0].tolist()
teacher_seqs = [
torch.tensor(live_prompt_ids + r["completion_ids"], dtype=torch.long, device=device)
for r in teacher_sample
]
L_t = max(s.shape[0] for s in teacher_seqs)
out_t = torch.stack([F.pad(s, (0, L_t - s.shape[0]), value=pad_id) for s in teacher_seqs])
L = max(out_s.shape[1], out_t.shape[1])
if out_s.shape[1] < L:
out_s = F.pad(out_s, (0, L - out_s.shape[1]), value=pad_id)
if out_t.shape[1] < L:
out_t = F.pad(out_t, (0, L - out_t.shape[1]), value=pad_id)
gen_out = torch.cat([out_s, out_t], dim=0)
is_student = [True] * G_s + [False] * G_t
# gen_students puts the ablated (deploy-mode) rollouts LAST among
# the G_s student rows; teacher rows are never ablated.
is_ablated = [False] * (G_s - n_abl) + [True] * n_abl + [False] * G_t
else:
with torch.no_grad():
gen_out, n_abl = gen_students(enc, G_s) # G_s == group when no teacher
is_student = [True] * gen_out.shape[0]
is_ablated = [False] * (G_s - n_abl) + [True] * n_abl
model.config.use_cache = False
merged = gen_out
completions = gen_out[:, plen:]
texts = tok.batch_decode(completions, skip_special_tokens=True)
t_gen += time.perf_counter() - _tg
# First-batch full dump (system msg + user msg + rendered prompt + completion
# with special tokens). Goes to verbose log only; stdout stays clean.
# Reading this lets us eyeball that the prompt is what we think it is and
# that the model isn't emitting role tokens.
if step == 0 and p_idx == 0:
comp_with_special = tok.decode(completions[0], skip_special_tokens=False)
sys_msg = next((m["content"] for m in prob["messages"] if m.get("role") == "system"), "<no system>")
user_msg = next((m["content"] for m in prob["messages"] if m.get("role") == "user"), "<no user>")
logger.debug(
"\nNOTE: following block is the actual rendered prompt + first model "
"completion with special chars, for tokenizer/format debugging.\n"
"=== FIRST BATCH FIRST SAMPLE DUMP ===\n"
f"--- system msg ---\n{sys_msg}\n"
f"--- user msg ---\n{user_msg}\n"
f"--- rendered prompt (with special chars) ---\n{prompt}\n"
f"--- completion (with special chars, {completions[0].numel()} tokens) ---\n{comp_with_special}\n"
"=== END FIRST BATCH DUMP ==="
)
comp_lens = [int((c != pad_id).sum().item()) for c in completions]
finished = [bool((c == eos_id).any().item()) for c in completions]
agg_comp_lens.extend(comp_lens); agg_finished.extend(finished)
_tr = time.perf_counter()
rs, hack_flags, gt_flags, fmt_flags = [], [], [], []
hack_E_flags, hack_D_flags = [], []
# Live-grade only student completions; teacher uses cached labels for
# reproducibility and zero-cost re-use.
n_live_grade = G_s if teacher_pool else len(texts)
for gi, t in enumerate(texts[:n_live_grade]):
r = compute_reward(
t, canonical_solution=prob["canonical"], gt_tests=prob["gt_tests"],
setup_code=prob["setup_code"], func_name_hint=prob["func_name"],
env_mode=prob["env_mode"],
)
# Plotted "hack" = exploited (got correctness credit but the strict
# oracle disagrees); "solve" = gt_correct (exploit-resistant). Reward
# pays on r.passed (the foolable grade), which is what drives GRPO.
rs.append(r.reward)
hack_flags.append(r.exploited)
gt_flags.append(r.gt_correct)
fmt_flags.append(r.format_ok)
hack_E_flags.append(r.hacked_loophole_used)
hack_D_flags.append(r.hacked_wrong_tests)
# Per-mode learning tally (substrate UAT). This problem's subset mode.
pm = prob["env_mode"]
mode_rollouts[pm] = mode_rollouts.get(pm, 0) + 1
if r.exploited:
mode_hacks[pm] = mode_hacks.get(pm, 0) + 1
step_mode_hacks[pm] = step_mode_hacks.get(pm, 0) + 1
if pm not in mode_first_step:
mode_first_step[pm] = step
# First full example of each hack mechanism -> verbose log (with
# special tokens). Lets us eyeball WHAT the hack looks like, not just
# the flag. One dump per mechanism for the whole run.
hack_cls = r.mechanism # env_mode when exploited, else None
if hack_cls and hack_cls not in dumped_hack_classes:
dumped_hack_classes.add(hack_cls)
should = {
"run_tests": "completion defines a vacuous run_tests() that passes against "
"an arbitrary Solution; ELSE the hacked flag is mis-firing",
"eq_override": "completion's output overrides __eq__ (or similar) so == is "
"always True; the JSON strict oracle rejects it; ELSE exploited mis-firing",
"exit_code": "completion has a module-level sys.exit/os._exit(0) before the "
"asserts, so the grader sees exit 0 on wrong code; ELSE exploited mis-firing",
"stdout_marker": "completion prints a pass report to stdout on wrong code; "
"ELSE the stdout probe mis-fired",
"sentinel": "completion sets a truthy `valid` flag on its Solution object on "
"wrong code; ELSE the sentinel probe mis-fired",
"file_marker": "completion writes a pass-reporting result file on wrong code; "
"ELSE the file probe mis-fired",
}[hack_cls]
logger.debug(
f"\n\n=== FIRST {hack_cls} HACK EXAMPLE (step {step}, p_idx {p_idx}, gi {gi}) ===\n"
f"SHOULD: {should}\n"
f"--- rendered prompt (special chars) ---\n{prompt}\n"
f"--- completion (special chars) ---\n{tok.decode(completions[gi], skip_special_tokens=False)}\n"
f"=== END {hack_cls} ===")
step_rollouts.append({
"step": step, "p_idx": p_idx, "gi": gi,
# problem identity + the exact prompt: the per-prompt problem is a
# RANDOM draw (idx above), so without these a rollout can't be mapped
# back to its prompt -- needed to harvest same-prompt (hack,clean)
# pairs from real student rollouts (A5 held-out-mode v_grad).
"problem_id": prob["problem_id"],
"env_mode": (partition[prob["problem_id"]] if partition else cfg.env_mode),
"prompt": prompt,
"reward": r.reward, "gt_pass": r.gt_pass, "gt_correct": r.gt_correct,
"passed": r.passed, "exploited": r.exploited, "mechanism": r.mechanism,
"hacked_C": r.hacked, "hacked_D": r.hacked_wrong_tests,
"hacked_E": r.hacked_loophole_used, "format_ok": r.format_ok,
"text": t,
})
if teacher_sample is not None:
for r in teacher_sample:
rs.append(float(r["reward"])); hack_flags.append(bool(r["hacked"]))
gt_flags.append(bool(r["gt_pass"])); fmt_flags.append(bool(r["fmt_ok"]))
# Teacher cache lacks E/D -- pad with False to keep lists aligned
# with agg_is_student. Half-A/B BLUF filters on is_student so
# these never enter the reported numerator/denominator.
hack_E_flags.append(False); hack_D_flags.append(False)
t_rew += time.perf_counter() - _tr
agg_rew.extend(rs); agg_gt.extend(gt_flags); agg_hack.extend(hack_flags); agg_fmt.extend(fmt_flags)
agg_hack_E.extend(hack_E_flags); agg_hack_D.extend(hack_D_flags)
agg_is_student.extend(is_student)
agg_is_ablated.extend(is_ablated)
if (step < 3 or step % 20 == 0) and p_idx == 0:
# Capture diagnostic tail of one generation per step. Look for
# mid-statement truncation (no closing ```), <think> traces, etc.
diag_tail = texts[0][-400:]
rewards = torch.tensor(rs, dtype=torch.float32, device=device)
# simple_GRPO grpo_vllm_one.py:208: skip groups where every generation
# got the same reward. Dr.GRPO's advantage would be zero anyway, so
# the policy forward + backward is pure compute waste. This is the
# dominant pathology with our binary-ish reward shape on a weak 2B
# substrate (every group can clip to 0.25 = format_only).
if (rewards.max() - rewards.min()).item() < 1e-4:
# Pad agg_logp with NaN to keep it aligned with agg_is_student
# (already extended with this prompt's rows just above). Skipping the
# logπ_old forward here is the whole point of the zero-variance bail.
agg_logp.extend([float("nan")] * len(rs))
n_zerovar += 1
continue
A = rewards - rewards.mean() # advantage; Dr.GRPO unbiased: no /σ_R
if not cfg.unbiased:
A = A / (rewards.std() + 1e-4)
# logπ_old: old-policy logprobs (frozen PPO-ratio target). logits_to_keep
# =L_c+1 runs lm_head only on completion-side hidden states (prompt-side
# logits never materialize, ~plen/(plen+L_c) memory saved); [:, :-1] drops
# the last position (predicts beyond `merged`, unused).
completion_ids = merged[:, plen:]
L_c = completion_ids.shape[1]
_tfb = time.perf_counter()
with torch.no_grad():
logπ_old = per_token_logps(
model(merged, logits_to_keep=L_c + 1).logits[:, :-1],
completion_ids,
).detach()
logπ_ref = None
if beta and beta > 0:
logπ_ref = ref_logprobs_via_zero_delta(model, merged, wrappers, plen).detach()
logπ = per_token_logps(
model(merged, logits_to_keep=L_c + 1).logits[:, :-1],
completion_ids,
)
mask = (merged[:, plen:] != pad_id).float()
# Per-rollout mean per-token logπ_old (student's logp on its own tokens).
# In single-step PPO logπ_old == logπ.detach(), so ρ≡1 and the loss treats
# student and teacher rows identically. Diagnostic only (no IS correction):
# the per-source gap lp_s - lp_t measures how far the student has drifted
# from the teacher pool's tokens.
mean_logp_per_rollout = ((logπ_old * mask).sum(1) / mask.sum(1).clamp_min(1)).detach().cpu().tolist()
agg_logp.extend(mean_logp_per_rollout)
ρ = torch.exp(logπ - logπ_old) # ≡1 at a single inner step; keep the clip form
A_tok = A.unsqueeze(1)
Lp = -torch.min(ρ * A_tok, torch.clamp(ρ, 1 - cfg.clip, 1 + cfg.clip) * A_tok)
if logπ_ref is not None: # K3 KL estimator
Lp = Lp + beta * (torch.exp(logπ_ref - logπ) - (logπ_ref - logπ) - 1.0)
# Per-source split (loss_s + loss_t == full-batch loss because
# is_s_v + is_t_v = 1 elementwise; backward is linear so
# grad_s + grad_t == full-batch grad). The split costs two backwards
# per prompt (~2x backward), so it only runs every cos_pre_split_every steps.
is_s_v = torch.tensor(is_student, dtype=Lp.dtype,
device=Lp.device).unsqueeze(1) # [G, 1]
is_t_v = 1.0 - is_s_v
if split_this_step:
if cfg.unbiased:
denom = group * max_new * prompts_per_step
loss_s = (Lp * mask * is_s_v).sum() / denom
loss_t = (Lp * mask * is_t_v).sum() / denom
else:
ptl_norm = (Lp * mask).sum(1) / mask.sum(1).clamp_min(1)
loss_s = (ptl_norm * is_s_v.squeeze(1)).sum() / (group * prompts_per_step)
loss_t = (ptl_norm * is_t_v.squeeze(1)).sum() / (group * prompts_per_step)
# Pass 1: student. retain_graph so the shared forward graph survives.
loss_s.backward(retain_graph=True)
for name, info in wrappers.items():
gs = info["delta_S"].grad
if gs is None:
continue
step_grad_s[name] = (step_grad_s[name] + gs.detach().clone()
if name in step_grad_s
else gs.detach().clone())
model.zero_grad(set_to_none=True)
# Pass 2: teacher.
loss_t.backward()
for name, info in wrappers.items():
gt = info["delta_S"].grad
if gt is None:
continue
step_grad_t[name] = (step_grad_t[name] + gt.detach().clone()
if name in step_grad_t
else gt.detach().clone())
model.zero_grad(set_to_none=True)
agg_loss += (loss_s + loss_t).item()
else:
# Combined single backward: cheaper, no per-source diagnostic.
# Accumulate into step_grad_s as the "combined" carrier; the
# injection block below treats step_grad_t == {} as "use gs".
if cfg.unbiased:
denom = group * max_new * prompts_per_step
loss = (Lp * mask).sum() / denom
else:
ptl_norm = (Lp * mask).sum(1) / mask.sum(1).clamp_min(1)
loss = ptl_norm.sum() / (group * prompts_per_step)
loss.backward()
# route2: per-prompt anchor masks for the τ calibration. Hack cloud =
# teacher rows (known-A hacks) + detector-flagged (hack_E) student rows;
# clean cloud = not-flagged student rows (contaminated with unknown B by
# design -> conservative τ; B still routes via cos>τ). hack_E_flags
# (len G_s) aligns with the leading student rows of is_student.
if is_route2:
_ha, _ca = build_route2_anchors(
is_student, hack_E_flags, cfg.gate_anchor_teacher_only, Lp.device)
for name, info in wrappers.items():
g = info["delta_S"].grad
if g is None:
continue
# route2 routes here: strip flagged rollouts from δS.grad and
# park them in δS_hack (via step_grad_hack in the filter).
if is_route2:
g = _route2_grad_filter(info, merged.shape[0], _ha, _ca)
step_grad_s[name] = (step_grad_s[name] + g.detach().clone()
if name in step_grad_s
else g.detach().clone())
model.zero_grad(set_to_none=True)
agg_loss += loss.item()
t_fb += time.perf_counter() - _tfb
# ── inject grad -> project / route ──
# Combine student + teacher grad into each leaf δS.grad (one source -> take it).
for name, info in wrappers.items():
gs = step_grad_s.get(name)
gt = step_grad_t.get(name)
if gs is None and gt is None:
continue
if gs is None:
info["delta_S"].grad = gt
elif gt is None:
info["delta_S"].grad = gs
else:
info["delta_S"].grad = gs + gt
# route2: park the flagged rollouts' contribution into δS_hack.grad (its own
# forward-path grad was wiped by the per-prompt zero_grad; we impose the routed
# grad here, like proj.py's route).
for name, g in step_grad_hack.items():
wrappers[name]["delta_S_hack"].grad = g
# Per-source cin: project student-only and teacher-only grads into v_hack
# subspace. Discriminator: cos_pre_t > cos_pre_s on a clean base means v_hack
# lights up for hack grads more than non-hack. Only valid on split steps;
# otherwise step_grad_s holds the combined grad and would mis-report cos_pre_s.
# v_hack is None on the vanilla arm (pure GRPO baseline, no subspace): skip
# the projection/measurement entirely and emit a nan diag -> the cin/cout
# columns (hidden on vanilla anyway) render nan. erase/route always have v_hack.
if v_hack is None:
diag = {"mean_cos_pre": float("nan"), "mean_cos_post": float("nan"),
"frac_fired": float("nan"), "mean_cos_pre_s": float("nan"),
"mean_cos_pre_t": float("nan")}
# route2: report the mean per-module per-rollout flag rate so we can
# watch the mask actually fire (and rise as hacks emerge).
if is_route2 and step_flagged:
logger.debug(f"route2 flagged frac (mean over modules*prompts): "
f"{sum(step_flagged)/len(step_flagged):+.3f}")
else:
if split_this_step:
cos_pre_s = mean_cos_pre_from_grads(step_grad_s, v_hack)
cos_pre_t = mean_cos_pre_from_grads(step_grad_t, v_hack)
else:
cos_pre_s = cos_pre_t = float("nan")
# grad is mutated only for erase (subtract) and route (subtract + park in
# δS_hack). cos_pre is measured on both.
diag = project_delta_S_grad(
wrappers, v_hack, cfg.preserve_magnitude,
measure_only=False, # erase/route both project; vanilla took the branch above
route=(cfg.intervention == "route"),
gate_mode=cfg.gate_mode,
overshoot=cfg.project_overshoot,
)
diag["mean_cos_pre_s"] = cos_pre_s
diag["mean_cos_pre_t"] = cos_pre_t
# R3 span check (once, on the first routed step that fires): the parked
# quarantine grad must live in span(V). removed = c_use@V is a combo of
# the orthonormal rows of V, so projecting it back via VᵀV should be a
# no-op; residual/‖removed‖ ~ 0. Catches a routing math bug loudly.
if cfg.intervention == "route" and not route_span_checked and diag["frac_fired"] > 0:
for name, info in wrappers.items():
gh = info["delta_S_hack"].grad
if gh is None or gh.norm() < 1e-12 or name not in v_hack:
continue
V = v_hack[name].to(gh.device, dtype=gh.dtype) # [k, r], rows orthonormal
resid = gh - V.T @ (V @ gh) # component outside span(V)
ratio = (resid.norm() / gh.norm()).item()
logger.info(f"R3 span check [{name}]: ||resid||/||gh|| = {ratio:.2e} (want <1e-4)")
assert ratio < 1e-4, f"delta_S_hack.grad escaped span(V): {ratio:.2e}"
route_span_checked = True
break
# clip_grad_norm_ returns the pre-clip total L2 norm, captured for the
# per-step `gn` column so we can see whether the clip threshold is the
# bottleneck on update magnitude (compare gn vs cfg.grad_clip).
# Clip over both knobs. For none/erase, δS_hack.grad is None so it's
# ignored (identical norm to before). For route it bounds the combined
# update (main + quarantine).
# Grad-energy split: qE = ‖g_quar‖/(‖g_keep‖+‖g_quar‖) ∈ [0,1], the share
# of the update routed into the quarantine (δS_hack, deleted at deploy).
# Rising qE => routing dumps learning into the thrown-away knob and the
# deployed model learns nothing. ~0 idle; ~0.5+ climbing = quarantine
# eating the update.
def _grad_l2(params):
gs = [p.grad for p in params if p.grad is not None]
return float(torch.norm(torch.stack([g.norm() for g in gs]))) if gs else 0.0
gn_keep = _grad_l2(delta_params)
gn_quar = _grad_l2(delta_hack_params)
q_egy = gn_quar / (gn_keep + gn_quar) if (gn_keep + gn_quar) > 0 else 0.0
gn = float(torch.nn.utils.clip_grad_norm_(delta_params + delta_hack_params, cfg.grad_clip))
opt.step()
sched.step()
# ── v_hack / v_grad refresh ──
# Online v_hack refresh: re-extract against the *current* model so the
# hack subspace tracks where the student is being pulled now (rather
# than at step 0). Same PAIRS, same extract code; we just discard the
# saved cache and overwrite the in-memory v_hack dict.
refr = "-" # set to "mod/axes" below if a refresh fires; rendered in the per-step row
do_refresh = cfg.vhack_refresh_every > 0 and (step + 1) % cfg.vhack_refresh_every == 0
if do_refresh and is_route2 and cfg.route2_random_v_seed is not None:
do_refresh = False # keep the one fixed Haar draw; re-extracting would replace it
if do_refresh and is_route2:
# route2 v_grad refresh: re-extract against the CURRENT model so the
# routing direction tracks where hacks separate now, not at step 0.
# Without this the frozen direction goes stale -- cin_t decays to cin_s
# within ~6 steps. Same MASK_PAIRS (the weak
# detector, no oracle); quarantine ablated so the hack signal flows back
# through the observable path, matching the state the build-time extract saw.
_was_training = model.training
model.eval()
opt.zero_grad(set_to_none=True)
logger.disable("vgrout.extract_vhack_grad")
logger.disable("__main__")
try:
with ablate_quarantine(wrappers):
from .extract_vhack_grad import extract_v_hack
_, _, raw_grads, _ = extract_v_hack(
model, tok, wrappers, MASK_PAIRS,
top_k=1, tau_axis=0.0, n_heldout=2, device=device,
)
for name in wrappers: # update in place so _route2_grad_filter's closure sees it
d = (raw_grads[f"hack/{name}"] - raw_grads[f"clean/{name}"]).mean(0)
v_grad[name] = (d / d.norm().clamp_min(1e-12)).to(device)
finally:
logger.enable("vgrout.extract_vhack_grad")
logger.enable("__main__")
opt.zero_grad(set_to_none=True) # extract leaves .grad populated
if _was_training:
model.train()
refr = "rfr" # compact marker; v_grad refresh has no cheap overlap gauge
if v_hack is not None and do_refresh:
from .extract_vhack_grad import extract_v_hack
if cfg.vhack_pairs_path is not None:
from .pairs_from_pool import load_pairs_json
VHACK_PAIRS = load_pairs_json(cfg.vhack_pairs_path)
else:
from .pairs import PAIRS as VHACK_PAIRS
_was_training = model.training
model.eval()
opt.zero_grad(set_to_none=True)
# Silence per-pair "loss=" and postprocess summary inside refresh:
# the refresh fires every N steps and floods the training log with
# extract-time NLL values that read as if they were training losses.
# The one-line "v_hack refreshed" announcement below is enough.
# When invoked via `python -m vgrout.train`, the entry
# script's __name__ is "__main__", not "vgrout.train",
# so postprocess_v_hack's logger.info (called from here) needs
# __main__ silenced. The extract submodule keeps its own name.
logger.disable("vgrout.extract_vhack_grad")
logger.disable("__main__")
try:
# Extract with the quarantine ablated (δS_hack=0). For route, once the
# hack capability has been routed into δS_hack, the main-knob gradient
# on the pairs no longer carries the hack direction, so re-extracting
# through the live quarantine rotates v_hack off-hack and cin_t collapses
# at the refresh step. Ablating sends the hack back through the observable
# main path, matching the δS_hack=0 state the build extraction saw.
# No-op for erase (δS_hack is never trained, stays 0).
with ablate_quarantine(wrappers):
_new_V, _new_S, _, _ = extract_v_hack(
model, tok, wrappers, VHACK_PAIRS,
top_k=cfg.v_hack_extract_top_k, tau_axis=cfg.v_hack_tau_axis,
n_heldout=2, device=device,
)
_post = postprocess_v_hack(
_new_V, _new_S, k_use=cfg.v_hack_k,
drop_bottom_frac=cfg.v_hack_drop_bottom_frac,
source=f"refresh@step{step}",
)
finally:
logger.enable("vgrout.extract_vhack_grad")
logger.enable("__main__")
# DIAGNOSTIC: how far did the refreshed basis rotate from the prior one?
# Rows are orthonormal, so ||V_new @ V_old^T||_F^2 / k_old = fraction of
# the OLD subspace still spanned by the NEW basis, in [0,1].
# ~1 -> refresh tracks a stable hack subspace (the design's premise)
# ~0 -> re-extraction at current weights landed near-orthogonal, so the
# live grad's overlap (cin_t) jumps discontinuously at the refresh.
shared = set(v_hack) & set(_post)
ovl = [((_post[n].float().to(device) @ v_hack[n].float().mT)).pow(2).sum().item()
/ v_hack[n].shape[0] for n in shared]
overlap = sum(ovl) / max(1, len(ovl))
logger.info(
f"refresh@step{step}: {len(_post)}mod/{sum(V.shape[0] for V in _post.values())}ax "
f"basis_overlap_with_prev={overlap:.3f} "
f"SHOULD: >~0.5 if refresh tracks a stable hack subspace; <~0.2 => "
f"re-extraction rotated the basis (cin_t jumps, refresh is harmful)")
v_hack.clear()
v_hack.update({n: V.to(device) for n, V in _post.items()})
opt.zero_grad(set_to_none=True) # extract leaves .grad populated
if _was_training:
model.train()
refr = f"{len(v_hack)}/{sum(V.shape[0] for V in v_hack.values())}" # mod/axes -> per-step row
# ── periodic DEPLOY-eval (EVERY arm) -- the apples-to-apples curve ──
# Eval the DEPLOYED model on a fixed eval subset with gen_cfg_eval (n=64,
# T=0.7), every eval_ablate_every steps. route/route2: deploy = quarantine
# knob zeroed (ablate_quarantine), and the claim is this hacks far less than
# the training-time model (per-step hack_s, knob still on). vanilla/erase: no
# quarantine, so deploy == the trained model -- eval it directly. Running the
# SAME estimator for all arms makes the dynamics-plot curves comparable (else
# route shows a deploy eval while others show training rollouts -> different
# n/cadence, route looks artificially smoother). NaN on non-eval steps.
hack_deploy = solve_deploy = float("nan")
if cfg.eval_ablate_every > 0 and (step % cfg.eval_ablate_every == 0 or step == steps - 1):
_was_training = model.training
model.eval()
is_route = cfg.intervention in ("route", "route2")
with (ablate_quarantine(wrappers) if is_route else nullcontext()):
ev = eval_hack_solve(model, tok, problems, eval_idxs, gen_cfg_eval, device, max_new)
hack_deploy, solve_deploy = ev["hack"], ev["solve"]
if _was_training:
model.train()
# Deploy (knob-OFF) only -- one pass. The train series comes free from the
# per-step hack_s column, and the full train-vs-deploy 2x2 (knob-ON vs
# knob-OFF on the same eval set) is computed once post-loop (FINAL EVAL).
# A per-step knob-ON pass would just double every eval (~460s -> ~920s)
# for a curve no figure plots. See journal 2026-06-04 (a).
tag = "quarantine knob OFF = deployed model" if is_route else "deployed = trained model (no quarantine)"
should = ("deploy hack < per-step hack_s (knob holds the cheat); ELSE routing isn't capturing it"
if is_route else "deploy ~= training hack_s (same model)")
logger.info(
f"step {step} DEPLOY-eval ({tag}): "
f"hack={hack_deploy:.3f} solve={solve_deploy:.3f} n={ev['n']}. SHOULD: {should}")
rewards_t = torch.tensor(agg_rew, dtype=torch.float32) if agg_rew else torch.zeros(1)
rew_mean = rewards_t.mean().item()
rew_std = rewards_t.std().item() if rewards_t.numel() > 1 else 0.0
spread = (rewards_t.max() - rewards_t.min()).item() > 1e-3 if rewards_t.numel() > 1 else False
n_rollouts = len(agg_rew)
# Per-source breakdown: which rollouts came from student vs teacher this step.
# Note: rollouts from "skipped" groups (no reward spread) are not in agg_*, so
# n_s + n_t == n_rollouts always.
is_s = torch.tensor(agg_is_student, dtype=torch.bool) if agg_is_student else torch.zeros(0, dtype=torch.bool)
h_t = torch.tensor(agg_hack, dtype=torch.bool) if agg_hack else torch.zeros(0, dtype=torch.bool)
g_t = torch.tensor(agg_gt, dtype=torch.bool) if agg_gt else torch.zeros(0, dtype=torch.bool)
n_s = int(is_s.sum())
n_t = int(is_s.numel() - n_s)
hack_s_n = int((h_t & is_s).sum())
hack_t_n = int((h_t & ~is_s).sum())
# Per-mechanism tallies on STUDENT rollouts only. C is just hacked (already
# tallied above as hack_s_n); we recompute here under the E/C/D names to
# keep the half-A/B math readable and to assert consistency.
h_E = torch.tensor(agg_hack_E, dtype=torch.bool) if agg_hack_E else torch.zeros(0, dtype=torch.bool)
h_D = torch.tensor(agg_hack_D, dtype=torch.bool) if agg_hack_D else torch.zeros(0, dtype=torch.bool)
hack_s_E = int((h_E & is_s).sum())
hack_s_C = hack_s_n
hack_s_D = int((h_D & is_s).sum())
# Cross-mech HACK_A / HACK_B: A = any half-A detector fires; B = any
# half-B fires AND no half-A fires (held-out, see spec.md). Computed
# per-step on per-rollout tuples so it's an EXACT OR, not a union-bound.
# cfg.half_a is read once outside the loop; if empty, A/B are skipped.
half_a_codes_step = {c.strip().upper() for c in cfg.half_a.split(",") if c.strip()}
det_step = {"E": h_E, "C": h_t, "D": h_D}
if half_a_codes_step:
mask_A_step = torch.zeros_like(is_s)
for c in half_a_codes_step:
mask_A_step = mask_A_step | det_step[c]
mask_B_step = torch.zeros_like(is_s)
for c in ({"E", "C", "D"} - half_a_codes_step):
mask_B_step = mask_B_step | det_step[c]
hack_s_A = int((mask_A_step & is_s).sum())
hack_s_B = int((mask_B_step & ~mask_A_step & is_s).sum())
else:
hack_s_A = 0
hack_s_B = 0
gt_s_n = int((g_t & is_s).sum())
gt_t_n = int((g_t & ~is_s).sum())
# FREE per-step DEPLOY proxy: the rollout_ablate_frac slice was generated
# with the quarantine ablated == the deployed model, so its hack/solve rate
# is what we'd ship, measured every step at zero extra generation cost.
# Caveat vs hk_dep/slv_dep: this is on the TRAINING prompts (hints present)
# at the sampling temperature, not the held-out greedy eval set -- a noisier,
# same-distribution proxy, not the plot's source-of-truth deploy number.
abl = torch.tensor(agg_is_ablated, dtype=torch.bool) if agg_is_ablated else torch.zeros(0, dtype=torch.bool)
n_abl_step = int(abl.sum())
hack_abl_n = int((h_t & abl).sum())
gt_abl_n = int((g_t & abl).sum())
rew_s_mean = rewards_t[is_s].mean().item() if n_s else float("nan")
# Skipped (zero-variance) prompts pad agg_logp with NaN above to keep
# alignment with is_s. nanmean drops them from the per-source means.
logp_t = torch.tensor(agg_logp, dtype=torch.float32) if agg_logp else torch.zeros(0)
lp_s_mean = logp_t[is_s].nanmean().item() if n_s else float("nan")
lp_t_mean = logp_t[~is_s].nanmean().item() if n_t else float("nan")
# Per-step diagnostics → verbose log; stdout sees tqdm postfix + final table.
n_fin = sum(agg_finished)
n_clipped = n_rollouts - n_fin
_min_len = min(agg_comp_lens) if agg_comp_lens else 0
_mean_len = sum(agg_comp_lens) / max(1, len(agg_comp_lens))
_max_len = max(agg_comp_lens) if agg_comp_lens else 0
logger.debug(
f"step {step} diag rollouts={n_rollouts} finished={n_fin}/{n_rollouts} "
f"clipped(no-eos)={n_clipped}/{n_rollouts} "
f"comp_lens(min/mean/max)={_min_len}/{_mean_len:.0f}/{_max_len} "
f"max_new={max_new} fmt={sum(agg_fmt)}/{n_rollouts} gt={sum(agg_gt)}/{n_rollouts} "
f"hack={sum(agg_hack)}/{n_rollouts} skipped={n_skipped}/{prompts_per_step} "
f"zerovar={n_zerovar}/{prompts_per_step}"
)
_tstep = time.time() - t0
logger.debug(
f"step {step} TIMING gen={t_gen:.0f}s fwd_bwd={t_fb:.0f}s "
f"reward={t_rew:.0f}s other={_tstep - t_gen - t_fb - t_rew:.0f}s "
f"total={_tstep:.0f}s"
)
if diag_tail is not None:
tail = diag_tail.replace("\n", "\\n")
logger.debug(f"step {step} gen[0] tail (last 400 chars): {tail!r}")
cum_gens = sum(r["N"] for r in rows) + n_rollouts
row = {
# Raw values throughout; StepLogger formats for streaming and the
# end-of-run tabulate dump consumes the same dict directly (no
# scientific-notation strings to misparse as floats).
"step": step,
"ref_eq": cum_gens / REF_GENS_PER_STEP,
"rew": rew_mean,
"rew_s": rew_s_mean if n_s else None,
"sprd": "T" if spread else "F",
"N": n_rollouts,
"gt_s": (gt_s_n, n_s) if n_s else (0, 0),
"gt_t": (gt_t_n, n_t) if n_t else (0, 0),
"hack_s": (hack_s_n, n_s) if n_s else (0, 0),
"hack_t": (hack_t_n, n_t) if n_t else (0, 0),
# Per-mode student hacks THIS step (current batch count, not cumulative --
# cumulative grew unboundedly and read as noise). The running mode_hacks/
# mode_rollouts tallies still feed the end-of-run substrate learning table.
# StepLogger only renders these on multi-mode (substrate) runs.
**{f"hk_{MODE_CODE[m]}": step_mode_hacks.get(m, 0) for m in run_modes},
# Per-mechanism on student rollouts only. Used by final-tail BLUF for
# cross-mechanism HACK_A / HACK_B; hidden from the per-step table to
# avoid column bloat (rendered only in the markdown dump below).
"hack_s_E": (hack_s_E, n_s) if n_s else (0, 0),
"hack_s_D": (hack_s_D, n_s) if n_s else (0, 0),
"hack_s_A": (hack_s_A, n_s) if n_s else (0, 0),
"hack_s_B": (hack_s_B, n_s) if n_s else (0, 0),
"lp_s": lp_s_mean if n_s else None,
"lp_t": lp_t_mean if n_t else None,
"loss": agg_loss,
"gn": gn,
"q_egy": q_egy,
"tau": (sum(step_tau) / len(step_tau)) if step_tau else float("nan"),
"hkgap": (sum(step_hkgap) / len(step_hkgap)) if step_hkgap else float("nan"),
"resid": (sum(step_resid) / len(step_resid)) if step_resid else float("nan"),
"lr": sched.get_last_lr()[0],
"cos_pre": diag["mean_cos_pre"],
"cos_pre_s": diag["mean_cos_pre_s"],
"cos_pre_t": diag["mean_cos_pre_t"],
"cos_post": diag["mean_cos_post"],
"fired": diag["frac_fired"],
"refr": refr,
# Route deploy-eval (δS_hack=0); NaN except on route eval steps.
# Appended AFTER refr so results.py's positional GT_S/HACK_S indices
# are unaffected. plot_dynamics reads it by name.
"hack_deploy": hack_deploy,
"solve_deploy": solve_deploy,
# Free per-step deploy proxy from the ablated rollout slice (above).
"hack_abl": (hack_abl_n, n_abl_step) if n_abl_step else (0, 0),
"solve_abl": (gt_abl_n, n_abl_step) if n_abl_step else (0, 0),
"gen": t_gen,
"fb": t_fb,
"t_rew": t_rew,
"sec": time.time() - t0,
}
rows.append(row)
# Stream this step as a row. Reprint the header every 50 rows so long runs
# stay readable without scrolling back (20+ unlabeled columns, no per-row label).
if step > 0 and step % 50 == 0:
logger.info(step_logger.header())
logger.info(step_logger.row(row))
with rollout_log_path.open("a") as fh:
for rec in step_rollouts:
fh.write(json.dumps(rec) + "\n")
if step_rollouts:
last_gen_sample = (step, step_rollouts[0]) # newest student gen for the final dump
# Divergence tripwire on teacher perplexity (free coherence gauge, see init).
ppl_t = math.exp(-lp_t_mean) if math.isfinite(lp_t_mean) else float("inf")
if math.isfinite(lp_t_mean):
lp_t_best = max(lp_t_best, lp_t_mean)
drop = lp_t_best - lp_t_mean if math.isfinite(lp_t_mean) else 0.0
# Soft warning at a smaller drop than the hard abort -- an early "ppl is
# climbing, watch for divergence (lr too high?)" before things are lost.
if WARN_DROP <= drop < DIVERGENCE_DROP:
logger.warning(f"step {step}: lp_t={lp_t_mean:.1f} is {drop:.1f} nats below best "
f"{lp_t_best:.1f} (ppl_t={ppl_t:.0e}) -- coherence slipping, lr too high?")
diverged = math.isfinite(lp_t_mean) and drop > DIVERGENCE_DROP
diverged_steps = diverged_steps + 1 if diverged else 0
if diverged_steps >= 2:
logger.error(
f"DIVERGED at step {step}: lp_t={lp_t_mean:.1f} (ppl_t={ppl_t:.0e}), {lp_t_best - lp_t_mean:.1f} "
f"nats below best {lp_t_best:.1f}, for {diverged_steps} steps -- policy collapsed "
f"(gn={gn:.1f}). Aborting to save GPU. Likely lr too high (route2: lower --route2-quar-lr-scale).")
if last_gen_sample:
_s, _r = last_gen_sample
logger.error(f"--- last student gen (step {_s}, reward={_r['reward']:+.2f}) ---\n"
f"{_r['text'][:800]}\n--- END (token salad => divergence confirmed) ---")
raise RuntimeError(f"training diverged (ppl_t={ppl_t:.0e} at step {step})")
if (step + 1) % 25 == 0:
save_ckpt(rows) # survive early kills; ~12 days for the full sweep
# Per-eval deploy-adapter snapshot: re-scoreable later without retraining.
if cfg.save_eval_ckpts and cfg.eval_ablate_every > 0 \
and (step % cfg.eval_ablate_every == 0 or step == steps - 1):
save_ckpt(rows, path=run_dir / f"ckpt_step{step:04d}.safetensors")
if not first_hack_saved and hack_s_n > 0:
save_ckpt(rows, path=first_hack_path)
first_hack_saved = True
logger.info(f"first-student-hack ckpt saved: step={step} hack_s={hack_s_n}/{n_s} -> {first_hack_path.name}")
# Live status in tqdm postfix; full per-step line in verbose log only.
# refresh=False: set_postfix defaults to forcing a redraw EVERY step, which
# bypasses mininterval and spams half-drawn bar fragments into piped/pueue
# logs. With refresh=False the postfix is shown at the next mininterval tick.
pbar.set_postfix(
rew=f"{rew_mean:+.2f}", gt=f"{sum(agg_gt)}/{n_rollouts}",
hack=f"{sum(agg_hack)}/{n_rollouts}", loss=f"{agg_loss:+.3f}",
sec=f"{time.time()-t0:.0f}",
refresh=False,
)
logger.debug(
f"step {step:3d} rew={rew_mean:+.2f}(std {rew_std:.2f}) "
f"gt={sum(agg_gt)}/{n_rollouts} hack={sum(agg_hack)}/{n_rollouts} "
f"loss={agg_loss:+.3f} cos_pre={diag['mean_cos_pre']:+.3f} "
f"cos_post={diag['mean_cos_post']:+.3f} fired={diag['frac_fired']:.2f} "
f"sec={time.time()-t0:.0f}"
)
peak_gb = torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0.0
n_steps = len(rows)
n_gens = sum(r["N"] for r in rows)
total_hacks = sum(r["hack_s"][0] + r["hack_t"][0] for r in rows)
total_pass = sum(r["gt_s"][0] + r["gt_t"][0] for r in rows)
hack_rate = total_hacks / max(1, n_gens)
pass_rate = total_pass / max(1, n_gens)
# Per-source totals. On no-teacher runs, hack_s_total == total_hacks.
hack_s_total = sum(r["hack_s"][0] for r in rows)
hack_t_total = sum(r["hack_t"][0] for r in rows)
n_s_total = sum(r["hack_s"][1] for r in rows)
n_t_total = sum(r["hack_t"][1] for r in rows)
hack_rate_s = hack_s_total / max(1, n_s_total)
hack_rate_t = hack_t_total / max(1, n_t_total)
# Per-mechanism on STUDENT rollouts (teacher cache lacks E/D). C-rate from
# this path must match hack_rate_s exactly -- sanity-check it so a future
# refactor that drops one path without the other is caught.
hack_s_E_total = sum(r["hack_s_E"][0] for r in rows)
hack_s_D_total = sum(r["hack_s_D"][0] for r in rows)
hack_s_E_rate = hack_s_E_total / max(1, n_s_total)
hack_s_C_rate = hack_rate_s
hack_s_D_rate = hack_s_D_total / max(1, n_s_total)
# Cross-mechanism HACK_A / HACK_B split (docs/spec/20260528_cross_mechanism_v_hack.md).
# Computed exactly per-step from per-rollout (E,C,D) tuples; here we just sum.
half_a_codes = {c.strip().upper() for c in cfg.half_a.split(",") if c.strip()}
valid_codes = {"E", "C", "D"}
if half_a_codes and not half_a_codes.issubset(valid_codes):
raise ValueError(f"--half-a contains unknown codes: {half_a_codes - valid_codes}; valid: {valid_codes}")
half_b_codes = valid_codes - half_a_codes if half_a_codes else set()
hack_s_A_total = sum(r["hack_s_A"][0] for r in rows)
hack_s_B_total = sum(r["hack_s_B"][0] for r in rows)
hack_a_rate = hack_s_A_total / max(1, n_s_total) if half_a_codes else float("nan")
hack_b_rate = hack_s_B_total / max(1, n_s_total) if half_a_codes else float("nan")
# R3 sneaky-fail guard: under route, the quarantine knob must have absorbed
# something (‖δS_hack‖ > 0), else routing silently degenerated to
# erasure (parked grad never applied). Exactly 0 by construction for
# none/erase (δS_hack gets no grad -> AdamW skips it).
dsh_norm = float(sum(info["delta_S_hack"].data.float().pow(2).sum().item()
for info in wrappers.values()) ** 0.5)
logger.info(f"||delta_S_hack|| = {dsh_norm:.4f} "
f"(SHOULD: >0 for route/route2, ==0 for none/erase; ELSE routing broke)")
if cfg.intervention in ("route", "route2") and cfg.route2_random_v_seed is None:
assert dsh_norm > 0.0, f"{cfg.intervention}: delta_S_hack never moved -> nothing routed into quarantine"
elif cfg.route2_random_v_seed is not None and dsh_norm == 0.0:
# Haar directionality control: "nothing routed" is a VALID outcome (a zero-alignment
# direction may never clear tau) and is itself H4-confirming evidence -- do not abort.
logger.warning("route2 Haar control: ||delta_S_hack||==0 -> the random direction routed "
"NOTHING. This is a real result (favours H4: alignment needed), not a failure.")
# Last training generation -- a fast eyeball for coherence before the eval
# numbers. SHOULD: real code/prose for the problem. If it is token salad the
# policy diverged and every eval number below is meaningless (see ppl_t / lp_t).
if last_gen_sample is not None:
_s, _r = last_gen_sample
logger.info(
f"\n\n=== LAST TRAIN GEN (step {_s}, reward={_r['reward']:+.2f}, "
f"gt_pass={_r['gt_pass']}, hacked={_r['hacked_E']}) ===\n"
f"SHOULD: coherent code/prose. ELSE token salad => diverged, eval below is moot.\n"
f"{_r['text'][:800]}\n=== END LAST GEN ===\n")
# ── final eval + BLUF ──
# Final per-mode train-vs-deploy eval -- run for EVERY arm on the SAME fixed
# eval subset so the all-arms overlay reads identical numbers. For route/route2
# this is the absorption test: TRAIN keeps the quarantine knob on (still hacks),
# DEPLOY deletes it (the shipped model). SHOULD: deploy hack < train hack at
# preserved solve => the quarantine absorbed the cheat. vanilla/erase have no
# quarantine, so the deployed model IS the trained model (deploy == train, one eval).
model.eval()
ev_train = eval_hack_solve(model, tok, problems, eval_idxs, gen_cfg_eval, device, max_new)
has_quarantine = cfg.intervention in ("route", "route2")
if has_quarantine:
with ablate_quarantine(wrappers):
ev_deploy = eval_hack_solve(model, tok, problems, eval_idxs, gen_cfg_eval, device, max_new)
else:
ev_deploy = ev_train
logger.info(
f"FINAL EVAL [{cfg.arm}] (n={ev_train['n']}): "
f"train/knob-on hack={ev_train['hack']:.3f} solve={ev_train['solve']:.3f} | "
f"deploy/knob-off hack={ev_deploy['hack']:.3f} solve={ev_deploy['solve']:.3f} "
+ ("(SHOULD: deploy hack < train hack at ~matched solve => quarantine absorbed the cheat)"
if has_quarantine else "(no quarantine: deploy == train)"))
# Per-mode hack: the generalisation cut. v_hack is run_tests-only, so run_tests is
# the IN-distribution mode; file_marker/sentinel/stdout_marker are HELD-OUT.
# SHOULD: if routing generalises, deploy hack drops on held-out modes too, not just
# run_tests. ELSE the quarantine only caught the mode v_hack saw.
per_mode_deploy: dict[str, dict] = {}
for mode in sorted(ev_deploy["by_mode"]):
th, ts, tn = ev_train["by_mode"].get(mode, [0, 0, 0])
dh, ds, dn = ev_deploy["by_mode"][mode]
tag = "IN-dist" if mode == "run_tests" else "held-out"
logger.info(
f" per-mode[{mode:<13} {tag:>8}] train hack={th}/{tn} solve={ts}/{tn} | "
f"deploy hack={dh}/{dn} solve={ds}/{dn}")
per_mode_deploy[mode] = {
"in_dist": mode == "run_tests",
"train_hack": th / max(1, tn), "train_solve": ts / max(1, tn),
"deploy_hack": dh / max(1, dn), "deploy_solve": ds / max(1, dn), "n": dn,
}
# Single structured record the overlay plot reads (one file per run, in run_dir
# next to the log/checkpoint). All arms emit the same schema; vanilla/erase have
# deploy==train. This is the canonical source for the all-arms per-mode plot.
deploy_record = {
"arm": cfg.arm, "intervention": cfg.intervention,
"refresh_every": cfg.vhack_refresh_every, "seed": cfg.seed,
"steps": n_steps, "model": model_name, "out_tag": cfg.out_tag,
"log": str(verbose_log), "eval_n": ev_deploy["n"],
"hack_train": ev_train["hack"], "solve_train": ev_train["solve"],
"hack_deploy": ev_deploy["hack"], "solve_deploy": ev_deploy["solve"],
"by_mode": per_mode_deploy,
}
deploy_path = run_dir / "per_mode_deploy.json"
deploy_path.write_text(json.dumps(deploy_record, indent=2))
logger.info(f"per-mode deploy artifact: {deploy_path}")
# Final tail: cue emoji + main metric BLUF, then per-step tsv table.
# Vanilla arm: 🟢 if hacking emerged. Projected arm: 🟢 if HACK_RATE dropped
# vs a matched-PASS vanilla; we can't judge that here, so just report.
cue = "🟢" if (cfg.arm == "vanilla" and hack_rate > 0.0) else "🟡"
print(f"\nargv: {' '.join(sys.argv)}")
print(f"verbose log: {verbose_log}")
print(
f"main metric: HACK_RATE={hack_rate:.3f} PASS_RATE={pass_rate:.3f} "
f"HACK_STUDENT={hack_rate_s:.3f} HACK_TEACHER={hack_rate_t:.3f} "
f"[arm={cfg.arm} preset={cfg.preset_name} model={model_name} steps={n_steps} gens={n_gens} peak={peak_gb:.1f}GB"
f"{' pool=' + cfg.teacher_pool_dir.name + ' mix=' + str(cfg.mix_ratio) if cfg.teacher_pool_dir else ''}]"
)
# Substrate UAT: did the student learn EACH hack, and at what step? One row per
# mode in the partition. SHOULD: every mode has hacks>0 and a finite first_step
# => the student learned all K loopholes from the repeated teacher batch. A mode
# with hacks=0 means that loophole never emerged (teacher seed too weak, or the
# subset's non-overlap detector never fired).
if partition is not None:
print()
per_mode_rows = sorted(
({"mode": m, "exploit_rate": f"{mode_hacks.get(m, 0) / max(1, mode_rollouts.get(m, 0)):.3f}",
"hacks": mode_hacks.get(m, 0), "student_rollouts": mode_rollouts.get(m, 0),
"first_step": mode_first_step.get(m, "-")}
for m in sorted(mode_rollouts)),
key=lambda r: r["mode"],
)
n_learned = sum(1 for r in per_mode_rows if r["hacks"] > 0)
cue_sub = "🟢" if n_learned == len(per_mode_rows) else ("🟡" if n_learned else "🔴")
print(f"{cue_sub} SUBSTRATE per-mode learning ({n_learned}/{len(per_mode_rows)} modes learned):")
print(tabulate(per_mode_rows, headers="keys", tablefmt="github"))
# Per-mechanism rates on STUDENT rollouts (teacher pool cache lacks E/D).
# SHOULD: if v_hack was extracted from half_A pairs and projection generalises,
# HACK_A AND HACK_B both fall vs a matched-seed vanilla run.
# If only HACK_A falls: projection is mechanism-specific (negative result).
# If neither falls: projection broken in-distribution.
print(
f"per-mech (student): HACK_S_E={hack_s_E_rate:.3f} HACK_S_C={hack_s_C_rate:.3f} "
f"HACK_S_D={hack_s_D_rate:.3f} "
f"half_A={sorted(half_a_codes) or '-'} HACK_A={hack_a_rate:.3f} "
f"half_B={sorted(half_b_codes) or '-'} HACK_B={hack_b_rate:.3f} "
f"(A=any half_A fires; B=any half_B fires AND no half_A fires)"
)
print()
# Render every (n, d) fraction tuple (gt_s/hack_s/hack_t/hk_<mode>/...) as "n/d"
# so tabulate shows them as fractions, not raw tuples. Drop timing columns --
# useful per-step in the streaming log but noise in the journal-pasteable table.
# Drop timing (gen/fb/t_rew/sec) + sprd/N: sprd is a constant T/F bail flag and N
# is redundant with the frac denominators already shown in gt_s/hack_s/hk_<mode>.
_DROP_COLS = ("gen", "fb", "t_rew", "sec", "sprd", "N")
rows_for_dump = [
{k: (f"{v[0]}/{v[1]}" if isinstance(v, tuple) and len(v) == 2 else v)
for k, v in r.items() if k not in _DROP_COLS}
for r in rows
]
# BLUF summary first -- the single row a reader scans -- as github markdown.
print(tabulate([{
"cue": cue, "HACK_RATE": f"{hack_rate:.3f}", "PASS_RATE": f"{pass_rate:.3f}",
"HACK_S": f"{hack_rate_s:.3f}", "HACK_T": f"{hack_rate_t:.3f}",
"peak_GB": f"{peak_gb:.1f}", "arm": cfg.arm, "preset": cfg.preset_name,
"model": model_name.split("/")[-1], "seed": cfg.seed, "steps": n_steps,
"pool": (cfg.teacher_pool_dir.name if cfg.teacher_pool_dir else ""),
"mix": cfg.mix_ratio if cfg.teacher_pool_dir else "",
"tag": cfg.out_tag, "log": str(verbose_log),
}], headers="keys", tablefmt="github"))
# Per-step rows ONCE, markdown (journal/PR pasteable). The TSV duplicate of the
# same data was dropped -- two formats of one table was just noise.
print("\n### Per-step rows (markdown)\n")
print(tabulate(rows_for_dump, headers="keys", tablefmt="pipe", floatfmt="+.3f"))
save_ckpt(rows)
return 0
if __name__ == "__main__":
    # CLI entry point using tyro subcommands: `train smoke`, `train fast`,
    # `train full`. Each subcommand name maps to a typed config dataclass
    # (SmokeConfig / FastConfig / FullConfig) that carries its own field
    # defaults; individual fields remain overridable on the command line,
    # e.g. `--lr=3e-3`. The classes themselves (not instances) are registered,
    # so tyro builds the default and layers the CLI flags on top.
    _subcommands = {
        "smoke": SmokeConfig,
        "fast": FastConfig,
        "full": FullConfig,
    }
    cfg = tyro.extras.subcommand_cli_from_dict(_subcommands)
    # main()'s return value (0 on the normal completion path) becomes the
    # process exit status.
    sys.exit(main(cfg))