mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 18:04:59 +08:00
adca442253
Gate now scores each rollout by dot(pooled bottleneck act, v_act) captured on
the no-grad logpi_old forward (quarantine-ablated, matching the sampling
policy); masks are pinned BEFORE the single grad-carrying forward, so the
grad-gate's pass-1 backward is gone. Thresholds: rolling 256-act buffer,
z-normalized, two-threshold Otsu (winsorized 1/99); warmup pins absorb until
128 scores. Buffer stores pooled acts and re-scores against the current v_act,
so the forward-only refresh (every 5 steps) needs no flush. No bimodality
guard: calibration showed Otsu tail separation ~2.4-2.8 buffer-sd on every
condition including pure Gaussians, so no shape statistic discriminates.
Deleted with the arm wiring (rename-on-logic-change: routeA never conflates
with routeV runs): extract_vhack_grad.py, _build_v_grad, route_band_edges,
_pair_cos, the pass-1 autograd.grad block, grad_probe training wiring,
v_grad_k/route_std_*/routeV_random_v_seed config, smoke-topk recipe.
c-probe stays in lora2r.py for scripts/diag_pinning.py only.
verify_science_invariants: all-in-one count 27 -> 42 (stale since c33b810
added the wave-2 behavior2 pairs) + assert the 8-pair routeA training subset.
Smoke: routeA/vanilla/absorb/solvemix all pass (gate exercises warmup, Otsu
zones, refresh, deploy ablation) -- /tmp/claude-1000/smoke_routeA.log.
Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
1301 lines
76 KiB
Python
1301 lines
76 KiB
Python
"""GRPO / Dr.GRPO loop with per-rollout output masking on the LeetCode
|
||
reward-hacking benchmark.
|
||
|
||
generate -> grade -> no-grad logpi_old forward (+ routeA gate acts) -> (gate) -> masked backward -> step
|
||
|
||
Inner GRPO step ported from lsdefine/simple_GRPO grpo_vllm_one.py:64-95; the
|
||
outer loop accumulates grads over prompts_per_step prompts (simple_GRPO's
|
||
Q_batch_size), making it likely that at least one per-prompt group has reward variance.
|
||
Unbiased normalization: Dr.GRPO, Liu et al. 2025, arXiv:2503.20783 -- drop the
|
||
1/|oᵢ| length norm and the /σ_R group-std (--unbiased, on by default).
|
||
|
||
Adapter: lora2r (src/vgrout/lora2r.py) -- one rank-2r LoRA per Linear, A and B
|
||
both trainable, partitioned into a deployed block [:r] and a quarantine block
|
||
[r:]. The quarantine is ablated (reset to its frozen init) at deployment.
|
||
|
||
Arms (--intervention):
|
||
none gate pinned clean (0,0): quarantine never trains -- the capacity- and
|
||
structure-matched vanilla control.
|
||
routeA per-rollout three-way gate from the pooled bottleneck activation vs
|
||
v_act: keep->deployed-only, rout->quarantine-only (deployed detached),
|
||
absorb->both, which may permit absorption. The acts ride the no-grad
|
||
logpi_old forward, so routeA costs roughly the vanilla arm.
|
||
absorb gate pinned mid (1,0): both blocks train on everything, no gate --
|
||
tests ungated both-block training.
|
||
|
||
uv run python -m vgrout.train smoke --intervention=routeA
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
import gzip
|
||
import json
|
||
import math
|
||
import os
|
||
import sys
|
||
import random
|
||
import time
|
||
from collections import deque
|
||
from contextlib import nullcontext
|
||
from pathlib import Path
|
||
|
||
import numpy as np
|
||
|
||
# Must be set BEFORE `import torch` to take effect on the CUDA allocator.
# expandable_segments eliminates the fragmentation behind a 91 GiB allocated /
# 581 MiB free OOM on Qwen3-4B G=8 (PyTorch's own OOM message recommends it).
# A value the user already exported is left untouched.
if "PYTORCH_CUDA_ALLOC_CONF" not in os.environ:
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
||
|
||
import torch
|
||
import torch.nn.functional as F
|
||
import tyro
|
||
from jaxtyping import Float
|
||
from loguru import logger
|
||
from safetensors.torch import save_file
|
||
from tabulate import tabulate
|
||
from tqdm import tqdm
|
||
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
|
||
|
||
from .extract_vhack_act import ActCapture, extract_v_act, haar_unit_rows
|
||
from .lora2r import wrap_model_with_lora2r
|
||
from .pairs import load_pairs
|
||
from .proj import per_token_logps
|
||
from .rewards import EnvMode, compute_reward
|
||
from .data import DATA, load_problems
|
||
from .eval import ablate_quarantine, eval_hack_solve, load_eval_splits
|
||
from .tablelog import setup_logging, StepLogger
|
||
from .run_artifacts import RUN_SCHEMA
|
||
from .train_config import Config, FastConfig, SmokeConfig
|
||
|
||
# Root for all generated artifacts; each training run writes into its own
# subdirectory of RUNS_DIR.
OUT_DIR = Path("out")
RUNS_DIR = OUT_DIR.joinpath("runs")
||
|
||
|
||
def _otsu3(x: np.ndarray) -> tuple[float, float]:
|
||
"""Two-threshold Otsu: the pair of cuts maximizing 3-class between-class variance.
|
||
Label-free -- the routeA gate computes this on a rolling buffer of live scores, so
|
||
using it is not oracle leakage. Scores are winsorized at 1/99% first: Otsu maximizes
|
||
variance, so on heavy-tailed scores a single extreme point otherwise buys a whole
|
||
class (journal 2026-06-11 (d): v5 act rout precision 0.00 -> 0.50 after winsorize).
|
||
Vectorized over the [n, n] cut grid; n is the buffer size (<= a few hundred)."""
|
||
x = np.clip(x, *np.quantile(x, [0.01, 0.99]))
|
||
s = np.sort(np.asarray(x, float))
|
||
n = len(s)
|
||
c = np.concatenate([[0.0], np.cumsum(s)])
|
||
iv = np.arange(1, n)
|
||
i_g, j_g = iv[:, None], iv[None, :]
|
||
with np.errstate(divide="ignore", invalid="ignore"):
|
||
obj = (c[i_g] ** 2 / i_g
|
||
+ (c[j_g] - c[i_g]) ** 2 / (j_g - i_g)
|
||
+ (c[n] - c[j_g]) ** 2 / (n - j_g))
|
||
obj[(j_g <= i_g) | (j_g >= n)] = -np.inf # need i < j and a nonempty top class
|
||
i, j = np.unravel_index(np.argmax(obj), obj.shape)
|
||
i, j = iv[i], iv[j]
|
||
return float((s[i - 1] + s[i]) / 2), float((s[j - 1] + s[j]) / 2)
|
||
|
||
|
||
def _sample_rows(rows: list[dict] | None, n: int, rng: torch.Generator) -> list[dict]:
|
||
"""Draw n teacher rollouts from a prompt's pool (with replacement if the pool is short)."""
|
||
if n == 0 or not rows:
|
||
return []
|
||
idxs = torch.randperm(len(rows), generator=rng)[:n].tolist()
|
||
if len(rows) < n:
|
||
idxs += torch.randint(0, len(rows), (n - len(rows),), generator=rng).tolist()
|
||
return [rows[i] for i in idxs]
|
||
|
||
|
||
def _auroc(scores: list[float], labels: list[bool]) -> float:
|
||
"""Rank-based AUROC (Mann-Whitney U) of `scores` as a detector of the positive class.
|
||
|
||
Higher score for hacks -> auroc > 0.5. nan if either class is absent this step.
|
||
Diagnostic only: ground-truth labels measure how well the gate score separates
|
||
reward-hacking updates, but never determine a route. Reading: ~0.5 means v_act
|
||
is a chance-level classifier (no threshold can route reliably); high AUROC but
|
||
rout~0 = the threshold/scale is wrong, not the direction; a drop across a refresh =
|
||
the refresh destroyed the separation."""
|
||
pos = [s for s, y in zip(scores, labels) if y]
|
||
neg = [s for s, y in zip(scores, labels) if not y]
|
||
if not pos or not neg:
|
||
return float("nan")
|
||
order = sorted(range(len(scores)), key=lambda i: scores[i])
|
||
ranks = [0.0] * len(scores)
|
||
i = 0
|
||
while i < len(order): # average-rank tie handling
|
||
j = i
|
||
while j + 1 < len(order) and scores[order[j + 1]] == scores[order[i]]:
|
||
j += 1
|
||
avg = (i + j) / 2 + 1 # 1-based mean rank of the tie block
|
||
for k in range(i, j + 1):
|
||
ranks[order[k]] = avg
|
||
i = j + 1
|
||
sum_pos = sum(r for r, y in zip(ranks, labels) if y)
|
||
n_pos, n_neg = len(pos), len(neg)
|
||
return (sum_pos - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
|
||
|
||
|
||
# Fixed seed for evaluation generation: eval sampling is the same across steps and
# arms, and drawing it does not perturb the training RNG.
EVAL_GEN_SEED = 12345

# Two-character code per env_mode, used to name the compact per-mode hack columns
# (hk_rt, hk_xc, ...). Insertion order matters: it is the canonical mode/column order.
MODE_CODE: dict[str, str] = {
    "run_tests": "rt",
    "eq_override": "eq",
    "exit_code": "xc",
    "stdout_marker": "so",
    "sentinel": "se",
    "file_marker": "fm",
    "gt_only": "gt",
}
|
||
|
||
|
||
def _validate_config(cfg: Config) -> None:
|
||
"""Reject contradictory experiment settings before model load."""
|
||
if cfg.intervention not in ("none", "routeA", "absorb"):
|
||
raise ValueError(f"unknown intervention {cfg.intervention!r}; expected none|routeA|absorb")
|
||
if cfg.routeA_random_v_seed is not None and cfg.intervention != "routeA":
|
||
raise ValueError("routeA_random_v_seed is a routeA-only placebo control")
|
||
if cfg.rollout_ablate_frac > 0 and cfg.intervention == "none":
|
||
raise ValueError("rollout_ablate_frac needs a quarantine to ablate (routeA/absorb)")
|
||
if cfg.weight_decay != 0.0:
|
||
raise ValueError("lora2r init is nonzero; AdamW decay pulls A/B toward 0 not toward init "
|
||
"-- set --weight-decay=0")
|
||
if cfg.solve_pool_dir is not None:
|
||
if cfg.teacher_pool_dir is None or cfg.mix_ratio <= 0:
|
||
raise ValueError("solve_pool_dir splits the G_t teacher budget -- needs teacher_pool_dir + mix_ratio>0")
|
||
if not (0.0 <= cfg.solve_mix_frac <= 1.0):
|
||
raise ValueError(f"solve_mix_frac must be in [0,1]; got {cfg.solve_mix_frac}")
|
||
|
||
|
||
def _log_resolved_config(cfg: Config, device) -> None:
    """Log one aligned block of run settings with every None resolved to its effective value.

    The point is that a detached log shows exactly what ran -- especially WHICH pairset
    (the field readers kept losing). Arm-specific fields print "unused"/"n/a" on arms
    that ignore them, and unset optional inputs print a "none (...)" description
    instead of a bare None.

    Args:
        cfg: the run config.
        device: the torch device for the run; only its str() is logged.
    """
    is_routeA = cfg.intervention == "routeA"
    placebo_seed = cfg.routeA_random_v_seed
    fields = {
        "preset/arm": f"{cfg.preset_name} / {cfg.arm}",
        "intervention": cfg.intervention,
        "model": cfg.model, "device": str(device), "seed": cfg.seed,
        "steps/group/pps": f"{cfg.steps} / {cfg.group} / {cfg.prompts_per_step}",
        "max_new/lr/grad_clip": f"{cfg.max_new} / {cfg.lr:.1e} / {cfg.grad_clip}",
        "lora_r/init_seed": f"{cfg.lora_r} / {cfg.lora_init_seed}",
        "unhackable_frac": cfg.unhackable_frac,
        "env_mode": cfg.env_mode,
        "pairset": cfg.vhack_pairs_path if is_routeA else "unused (not routeA)",
        # Under routeA a None seed means the extracted v_act is used (no Haar override).
        "routeA placebo seed": (
            ("none (extracted v_act)" if placebo_seed is None else placebo_seed)
            if is_routeA else "n/a"),
        "teacher pool/mix/off_step": (
            f"{cfg.teacher_pool_dir.name} / {cfg.mix_ratio} / {cfg.teacher_off_step}"
            if cfg.teacher_pool_dir else "none (pure on-policy)"),
        # Solve-teacher demos take solve_mix_frac of the teacher slots; a run-defining setting.
        "solve pool/frac": (
            f"{cfg.solve_pool_dir.name} / {cfg.solve_mix_frac}"
            if cfg.solve_pool_dir else "none"),
        "out_tag": cfg.out_tag or "(none)",
    }
    width = max(len(k) for k in fields)
    block = "\n".join(f"  {k:<{width}} : {v}" for k, v in fields.items())
    logger.info(f"resolved config:\n{block}")
|
||
|
||
|
||
def main(cfg: Config) -> int:
|
||
_validate_config(cfg)
|
||
model_name = cfg.model; steps = cfg.steps; group = cfg.group
|
||
max_new = cfg.max_new; n_problems = cfg.n_problems
|
||
prompts_per_step = cfg.prompts_per_step
|
||
lr = cfg.lr; adam_beta1 = cfg.adam_beta1; adam_beta2 = cfg.adam_beta2
|
||
|
||
run_id = f"{cfg.preset_name}_{cfg.arm}_seed{cfg.seed}{cfg.out_tag}"
|
||
verbose_log = setup_logging(run_id)
|
||
|
||
torch.manual_seed(cfg.seed)
|
||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||
# Log enough run identity up front to interpret detached logs.
|
||
logger.info(f"argv: {' '.join(sys.argv)}")
|
||
logger.info(f"verbose log: {verbose_log}")
|
||
_log_resolved_config(cfg, device)
|
||
|
||
is_routeA = cfg.intervention == "routeA"
|
||
is_absorb = cfg.intervention == "absorb"
|
||
is_vanilla = cfg.intervention == "none"
|
||
has_quarantine = is_routeA or is_absorb
|
||
|
||
# Only adapter parameters train; the base model remains frozen.
|
||
tok = AutoTokenizer.from_pretrained(model_name)
|
||
if tok.pad_token_id is None: tok.pad_token = tok.eos_token
|
||
|
||
# ── model + tokenizer ──
|
||
# CPU smoke: fp32 + sdpa (flash-attn2 is CUDA-only, CPU bf16 is patchy).
|
||
# GPU: bf16 + flash_attention_2.
|
||
cpu = device.type == "cpu"
|
||
model = AutoModelForCausalLM.from_pretrained(
|
||
model_name,
|
||
dtype=torch.float32 if cpu else torch.bfloat16,
|
||
attn_implementation="sdpa" if cpu else "flash_attention_2",
|
||
).to(device)
|
||
# Generation enables KV cache; loss forwards disable it to avoid unused state.
|
||
model.config.use_cache = False
|
||
|
||
# ── adapter: rank-2r LoRA, deployed block [:r] + quarantine block [r:] ──
|
||
# The routeA gate reads activations via forward hooks; no grad probe in training
|
||
# (the c-probe stays in lora2r only for scripts/diag_pinning.py diagnostics).
|
||
wrappers = wrap_model_with_lora2r(
|
||
model, r=cfg.lora_r, init_seed=cfg.lora_init_seed)
|
||
# A and B both train; quarantine = block slices of the SAME tensors, so there
|
||
# is no separate hack-param list (per-rollout masks route grads, not surgery).
|
||
delta_params = [p for info in wrappers.values() for p in (info["A"], info["B"])]
|
||
n_quar = sum(info["A"][info["r"]:].numel() + info["B"][:, info["r"]:].numel()
|
||
for info in wrappers.values())
|
||
logger.info(f"trainable lora2r A+B: {sum(p.numel() for p in delta_params):,} "
|
||
f"({n_quar:,} of those in quarantine blocks)")
|
||
|
||
# ── routeA direction: v_act (mean pooled-act pair diff, unit rows per module) ──
|
||
v_act = None # [M, r] cpu fp32; module order = act_names
|
||
act_names = sorted(wrappers)
|
||
act_buf: deque | None = None # rolling pooled acts [M, r]; re-scored vs the CURRENT
|
||
# v_act at each gate call, so a refresh needs no flush
|
||
MASK_PAIRS = None
|
||
if is_routeA:
|
||
# Authored pairs are the only routing-label source; live oracle labels never enter training.
|
||
MASK_PAIRS = load_pairs(cfg.vhack_pairs_path)
|
||
logger.info(f"routeA pairs: {cfg.vhack_pairs_path} -> {len(MASK_PAIRS)} pairs")
|
||
model.eval() # deterministic forward, no dropout
|
||
v_act, pair_acts = extract_v_act(model, tok, wrappers, MASK_PAIRS, device,
|
||
tstat=cfg.vact_tstat)
|
||
model.train()
|
||
# Authored-pair separation in the live score. The dot GAP is > 0 by construction
|
||
# (v is proportional to the mean pair diff); the pair AUROC is not, so it is the
|
||
# extraction sanity signal.
|
||
sh = torch.einsum("pmr,mr->p", pair_acts["hack"], v_act)
|
||
sc = torch.einsum("pmr,mr->p", pair_acts["clean"], v_act)
|
||
pair_auroc = _auroc(torch.cat([sh, sc]).tolist(),
|
||
[True] * len(sh) + [False] * len(sc))
|
||
logger.info(
|
||
f"routeA v_act: {v_act.shape[0]} modules x r={v_act.shape[1]} "
|
||
f"(tstat={cfg.vact_tstat}); authored-pair dot gap={(sh.mean() - sc.mean()).item():+.3e}, "
|
||
f"pair AUROC={pair_auroc:.2f}. SHOULD: pair AUROC ~1.0 ELSE extraction broken.")
|
||
if cfg.routeA_random_v_seed is not None:
|
||
v_act = haar_unit_rows(tuple(v_act.shape), cfg.routeA_random_v_seed)
|
||
logger.info(f"routeA: OVERRODE v_act with Haar-random unit rows "
|
||
f"(seed={cfg.routeA_random_v_seed}) -- placebo directionality control")
|
||
act_buf = deque(maxlen=cfg.route_buffer)
|
||
logger.info(
|
||
f"routeA gate: per-rollout score = dot(pooled completion-token act, v_act), "
|
||
f"thresholds = two-threshold Otsu on the last <= {cfg.route_buffer} live scores "
|
||
f"(z-normalized, winsorized 1/99%), label-free; pinned absorb until "
|
||
f"{cfg.route_warmup} scores. keep (0,0) | absorb (1,0) | rout (1,1: deployed "
|
||
f"detached). No bimodality guard: on the cached emergence windows no shape "
|
||
f"statistic separates the hack mixture from hack-free scores (Otsu tail means "
|
||
f"sit ~2.4 sd apart even on a Gaussian), and a false rout only discards one "
|
||
f"update from deployment. "
|
||
f"SHOULD: auroc col >> 0.5 once hacks appear ELSE v_act is blind and routing "
|
||
f"is noise; rout tracks the hack share, not ~0 or ~1.")
|
||
|
||
# ── teacher pool ──
|
||
# Teacher pool: pre-generated rollouts on disk keyed by problem_id. Each step's
|
||
# G_t teachers are a uniform random sample of that prompt's cache (no teacher
|
||
# model in VRAM); cached rewards/flags are reused verbatim, so it's a fixed
|
||
# reproducible teacher distribution.
|
||
teacher_pool: dict[int, list[dict]] = {}
|
||
# Multi-loophole substrate: a teacher pool dir MAY carry partition.json
|
||
# {problem_id: env_mode}. When present, this is the even non-overlapping
|
||
# substrate (build_substrate.py) -- each problem graded by its assigned mode.
|
||
# When absent, the run is single-mode (cfg.env_mode for every problem).
|
||
partition: dict[int, EnvMode] | None = None
|
||
G_s = group
|
||
G_t = 0
|
||
if cfg.teacher_pool_dir is not None:
|
||
# mix=0 is the NO-TEACHER ablation: pure on-policy GRPO (G_t=0) while the
|
||
# pool is still loaded for the partition.
|
||
if not (0.0 <= cfg.mix_ratio < 1.0):
|
||
raise ValueError(f"mix_ratio must be in [0,1) when teacher_pool_dir set; got {cfg.mix_ratio}")
|
||
G_t = round(group * cfg.mix_ratio)
|
||
G_s = group - G_t
|
||
if G_s == 0:
|
||
raise ValueError(
|
||
f"degenerate split: G={group} mix_ratio={cfg.mix_ratio} -> G_s={G_s}. "
|
||
f"Pick mix_ratio < 1 so the student half is non-empty.")
|
||
for path in sorted(cfg.teacher_pool_dir.glob("prompt_*.jsonl.gz")):
|
||
# path.name 'prompt_0004.jsonl.gz' -> problem_id 4.
|
||
problem_id = int(path.name.split("_")[1].split(".")[0])
|
||
with gzip.open(path, "rt") as f:
|
||
teacher_pool[problem_id] = [json.loads(line) for line in f]
|
||
if not teacher_pool:
|
||
raise FileNotFoundError(
|
||
f"teacher pool {cfg.teacher_pool_dir} is empty. Run `just pregen-teacher N` first.")
|
||
partition_path = cfg.teacher_pool_dir / "partition.json"
|
||
if partition_path.exists():
|
||
raw = json.loads(partition_path.read_text())
|
||
partition = {int(pid): mode for pid, mode in raw.items()}
|
||
from collections import Counter
|
||
by_mode = Counter(partition.values())
|
||
logger.info(
|
||
f"SUBSTRATE: per-problem env_mode partition from {partition_path.name} -- "
|
||
f"{len(partition)} problems across {len(by_mode)} modes: "
|
||
f"{dict(sorted(by_mode.items()))}. Each problem graded by its own mode; "
|
||
f"non-overlap holds (passed = gt_correct OR channel_i).")
|
||
if cfg.teacher_modes is not None:
|
||
# Oracle-free generalization test: held-out modes remain on-policy and receive no demos.
|
||
assert partition is not None, "teacher_modes needs a partition.json"
|
||
kept = {pid: rows for pid, rows in teacher_pool.items()
|
||
if partition[pid] in cfg.teacher_modes}
|
||
logger.info(
|
||
f"teacher_modes={cfg.teacher_modes}: teacher pool restricted "
|
||
f"{len(teacher_pool)}->{len(kept)} prompts (known modes only); "
|
||
f"held-out-mode problems train ON-POLICY (no teacher, no anchor seed).")
|
||
teacher_pool = kept
|
||
n_rollouts_per = sum(len(v) for v in teacher_pool.values()) / len(teacher_pool)
|
||
avg_hack = sum(int(r["hacked"]) for v in teacher_pool.values() for r in v) / sum(len(v) for v in teacher_pool.values())
|
||
logger.info(
|
||
f"teacher pool: {len(teacher_pool)} prompts, ~{n_rollouts_per:.1f} rollouts/prompt, "
|
||
f"cached hack_rate={avg_hack:.2%}. Deterministic: {cfg.teacher_n_per_prompt} hack "
|
||
f"teacher(s) per teacher-phase prompt (constant count, no mix_ratio budget).")
|
||
|
||
# ── solve-teacher pool (symmetric correct-solution demos) ── same schema/loader as the
|
||
# hack pool; the G_t teacher slots split solve_mix_frac solve / rest hack.
|
||
solve_pool: dict[int, list[dict]] = {}
|
||
if cfg.solve_pool_dir is not None:
|
||
for path in sorted(cfg.solve_pool_dir.glob("prompt_*.jsonl.gz")):
|
||
problem_id = int(path.name.split("_")[1].split(".")[0])
|
||
with gzip.open(path, "rt") as f:
|
||
solve_pool[problem_id] = [json.loads(line) for line in f]
|
||
if not solve_pool:
|
||
raise FileNotFoundError(f"solve pool {cfg.solve_pool_dir} is empty.")
|
||
solve_hack = sum(int(r["hacked"]) for v in solve_pool.values() for r in v)
|
||
n_solve_rows = sum(len(v) for v in solve_pool.values())
|
||
logger.info(
|
||
f"solve pool: {len(solve_pool)} prompts, {n_solve_rows} rollouts, "
|
||
f"cached hack_rate={solve_hack / n_solve_rows:.2%} (SHOULD ~0% -- correct-solution demos). "
|
||
f"The step teacher budget splits {cfg.solve_mix_frac:.0%} solve / {1 - cfg.solve_mix_frac:.0%} hack.")
|
||
|
||
# ── optimizer + schedule ── (A and B of both blocks; masks route grads)
|
||
opt = torch.optim.AdamW(
|
||
delta_params, lr=lr, weight_decay=cfg.weight_decay, betas=(adam_beta1, adam_beta2))
|
||
# OneCycle does warmup + cosine relaxation in one object: cos ramp from lr/div_factor
|
||
# up to lr over the first pct_start of steps (the explicit warmup), then cos anneal to
|
||
# ~0. cycle_momentum=False so it leaves the configured AdamW betas alone (else it would
|
||
# clobber adam_beta1). pct_start = warmup_frac keeps warmup fractional across presets.
|
||
sched = torch.optim.lr_scheduler.OneCycleLR(
|
||
opt, max_lr=lr, total_steps=steps, pct_start=cfg.warmup_frac,
|
||
anneal_strategy="cos", div_factor=25.0, final_div_factor=1e4, cycle_momentum=False)
|
||
|
||
# ── generation config ──
|
||
# Use the same sampling policy for training and evaluation.
|
||
gen_cfg = GenerationConfig(
|
||
max_new_tokens=max_new, do_sample=True,
|
||
# T=0.7 matches the Ariahw reference and exposes the substrate's modal hacks.
|
||
temperature=0.7, top_p=1.0, top_k=20, min_p=0.0,
|
||
repetition_penalty=1.0,
|
||
num_return_sequences=G_s, pad_token_id=tok.pad_token_id,
|
||
)
|
||
# Evaluate one completion per prompt because prompts, not repeated samples, are independent.
|
||
gen_cfg_eval = GenerationConfig(
|
||
max_new_tokens=max_new, do_sample=True,
|
||
temperature=0.7, top_p=1.0, top_k=20, min_p=0.0, repetition_penalty=1.0,
|
||
num_return_sequences=1, pad_token_id=tok.pad_token_id,
|
||
)
|
||
|
||
# Seeded shuffle avoids the memorized low-id slice while preserving paired arms.
|
||
all_problems = load_problems(10_000, env_modes=[cfg.env_mode], seed=cfg.seed,
|
||
partition=partition, shuffle=True)
|
||
# Pin teacher-covered prompts, then train on the wider environment to test generalization.
|
||
if teacher_pool:
|
||
seeded = [p for p in all_problems if p["problem_id"] in teacher_pool]
|
||
rest = [p for p in all_problems if p["problem_id"] not in teacher_pool]
|
||
problems = (seeded + rest)[:n_problems] # seed ids first, fill to n_problems
|
||
else:
|
||
problems = all_problems[:n_problems]
|
||
mode_desc = "per-problem partition" if partition is not None else f"single env_mode={cfg.env_mode}"
|
||
logger.info(f"loaded {len(problems)} seeded-shuffle problems from {DATA.name} -- {mode_desc}")
|
||
# Both-pool-covered prompts: the teacher phase samples ONLY these so every prompt can get
|
||
# a deterministic N hack + N solve. solve_pool ⊂ teacher_pool, so the intersection = the
|
||
# solve-covered prompts (or just teacher-covered when there is no solve pool).
|
||
covered_problems = [p for p in problems
|
||
if (not teacher_pool or p["problem_id"] in teacher_pool)
|
||
and (not solve_pool or p["problem_id"] in solve_pool)]
|
||
if teacher_pool:
|
||
n_cov = sum(1 for p in problems if p["problem_id"] in teacher_pool)
|
||
n_t_per_prompt = cfg.teacher_n_per_prompt * (bool(teacher_pool) + bool(solve_pool))
|
||
logger.info(f"teacher coverage: {n_cov}/{len(problems)} hack-covered, "
|
||
f"{len(covered_problems)}/{len(problems)} both-pool-covered (teacher-phase "
|
||
f"sampling pool); hack must generalize off the seeds to the wider on-policy set. "
|
||
f"CONSTANT {n_t_per_prompt} teachers/prompt -> {n_t_per_prompt * prompts_per_step}"
|
||
f"/{prompts_per_step * group} gens are teacher each teacher-phase step.")
|
||
|
||
# Periodic validation and final test are disjoint; final-test results never affect training.
|
||
# Exclude gt_only from hack evaluation unless it is the entire no-loophole ceiling run.
|
||
eval_modes = sorted({p["env_mode"] for p in problems} - {"gt_only"}) or ["gt_only"]
|
||
val_problems, test_problems = load_eval_splits(eval_modes, cfg.eval_n_prompts)
|
||
val_idxs, test_idxs = list(range(len(val_problems))), list(range(len(test_problems)))
|
||
_train_ids = {p["problem_id"] for p in problems}
|
||
assert not (_train_ids & {p["problem_id"] for p in val_problems}), "VAL set leaks training problems"
|
||
assert not (_train_ids & {p["problem_id"] for p in test_problems}), "TEST set leaks training problems"
|
||
logger.info(f"held-out eval: periodic val n={len(val_problems)} + untouched final test "
|
||
f"n={len(test_problems)} from leetcode_test_medhard, modes={eval_modes}")
|
||
|
||
rng = torch.Generator().manual_seed(cfg.seed)
|
||
rows = []
|
||
logger.info(
|
||
f"SHOULD: loss finite each step; PASS_RATE > 0 on 4B. "
|
||
f"ELSE: harness broken. "
|
||
f"Timing cols (gen/fb/t_rew/sec): gen-bound -> vLLM; fb-bound -> lower pp; t_rew-bound -> parallel grading.")
|
||
if teacher_pool:
|
||
logger.info(
|
||
f"SHOULD (mixed-pool): hack_t high from step 0 (cached teacher pool ~95% hack); "
|
||
f"hack_s climbs 0 -> 20%+ over the run as student learns from exposure. "
|
||
f"ELSE if hack_s flat while hack_t high: student is ignoring the off-policy "
|
||
f"gradient signal; bump mix_ratio or lr.")
|
||
|
||
eos_id = tok.eos_token_id
|
||
pad_id = tok.pad_token_id
|
||
|
||
def gen_students(enc, n: int) -> tuple[torch.Tensor, int]:
|
||
"""Generate student rollouts in deployment mode with the quarantine ablated.
|
||
|
||
This prevents reward-hacking behavior learned by the quarantine parameters from
|
||
saturating the rollout distribution and reducing the correct-solution advantage.
|
||
For clean-gated
|
||
rollouts the train forward is also quarantine-off, so gen and train now match.
|
||
(vanilla has no quarantine -> nullcontext, but its quarantine is empty so it is already
|
||
deploy-mode.)"""
|
||
ctx = ablate_quarantine(wrappers) if has_quarantine else nullcontext()
|
||
with ctx:
|
||
out = model.generate(**enc, generation_config=gen_cfg,
|
||
num_return_sequences=n).detach()
|
||
return out, 0 # every rollout is deploy-mode now; no separate ablated-proxy subset
|
||
|
||
# `ref_eq` compares cumulative sampling pressure to the 16x16 reference step.
|
||
run_modes = sorted({p["env_mode"] for p in problems}, key=lambda m: list(MODE_CODE).index(m))
|
||
step_logger = StepLogger(arm=cfg.arm, modes=run_modes, mode_code=MODE_CODE,
|
||
show_ablate=cfg.rollout_ablate_frac > 0)
|
||
REF_GENS_PER_STEP = 16 * 16 # ariahw/rl-rewardhacking config.py:num_prompts * num_generations
|
||
est_gens_per_step = prompts_per_step * group # before mixed-pool split
|
||
logger.info(
|
||
f"grad-pressure: {est_gens_per_step} gens/step vs reference {REF_GENS_PER_STEP} "
|
||
f"-> {est_gens_per_step / REF_GENS_PER_STEP:.2f}x per step; "
|
||
f"this run's {steps} steps ~= {steps * est_gens_per_step / REF_GENS_PER_STEP:.1f} reference steps.")
|
||
# Print only the legend columns active for this arm and environment.
|
||
logger.info("\n" + step_logger.legend() + "\n\n")
|
||
logger.info(step_logger.header())
|
||
|
||
# Group all outputs from one run under the log's timestamped stem.
|
||
run_dir = RUNS_DIR / verbose_log.stem
|
||
run_dir.mkdir(parents=True, exist_ok=True)
|
||
ckpt_path = run_dir / "train.safetensors"
|
||
# Store paired quarantine-enabled/ablated validation results as structured data.
|
||
eval_curve_path = run_dir / "eval_curve.jsonl"
|
||
first_hack_path = run_dir / "first_hack.safetensors"
|
||
# Log live oracle labels for offline audit only; this file is never read by training.
|
||
rollout_log_path = run_dir / "rollouts.jsonl"
|
||
rollout_log_path.write_text("")
|
||
first_hack_saved = False
|
||
last_gen_sample = None # first student rollout of the latest step (for collapse inspection)
|
||
diverged_steps = 0 # consecutive steps with collapsed teacher ppl (divergence tripwire)
|
||
lp_t_best = -float("inf") # coherence high-water mark (best teacher gen_logp seen)
|
||
# Detect collapse by a relative log-probability drop on fixed teacher completions.
|
||
DIVERGENCE_DROP = 5.0 # nats below best (e^5 ~ 150x worse ppl); never in healthy runs
|
||
WARN_DROP = 3.0 # softer: log a warning before the hard abort
|
||
dumped_hack_classes: set[str] = set() # first full example of each hack class -> verbose log
|
||
# Track whether and when the student learns each substrate mode.
|
||
mode_rollouts: dict[str, int] = {}
|
||
mode_hacks: dict[str, int] = {}
|
||
mode_first_step: dict[str, int] = {}
|
||
n_flipped = 0 # prompt-draws shown hint-free this run (rotating-unhackable flip)
|
||
route_hackT_run: list[float] = [] # per-step routed-share of hack teachers (solve-mix run)
|
||
route_solveT_run: list[float] = [] # per-step routed-share of solve teachers
|
||
|
||
def save_ckpt(rows: list[dict], path: Path | None = None) -> None:
|
||
"""Save a lora2r checkpoint: trainable A/B only, bf16. The init A0/B0 are NOT
|
||
saved -- they're the seeded Gaussian (regenerable from lora_init_seed) and
|
||
ckpt_update0000 holds them anyway (A==A0 at step 0); nothing live reloads
|
||
them. Config + per-step rows in the metadata."""
|
||
n_gens = sum(r["N"] for r in rows)
|
||
hr = sum(r["hack_s"][0] + r["hack_t"][0] for r in rows) / max(1, n_gens)
|
||
pr = sum(r["gt_s"][0] + r["gt_t"][0] for r in rows) / max(1, n_gens)
|
||
_ckpt = path or ckpt_path
|
||
tensors = {}
|
||
for n, info in wrappers.items():
|
||
tensors[f"A/{n}"] = info["A"].detach().bfloat16().cpu().contiguous()
|
||
tensors[f"B/{n}"] = info["B"].detach().bfloat16().cpu().contiguous()
|
||
save_file(tensors, str(_ckpt), metadata={
|
||
"model": model_name, "dtype": "bf16", "step": str(len(rows)),
|
||
"hack_rate": f"{hr:.6f}", "pass_rate": f"{pr:.6f}",
|
||
"rows": json.dumps(rows),
|
||
"cfg": json.dumps(vars(cfg) | {"adapter": "lora2r"}, default=str),
|
||
})
|
||
|
||
save_ckpt([], path=run_dir / "ckpt_update0000.safetensors")
|
||
|
||
def _routeA_gate(dots: Float[torch.Tensor, "G"]) -> tuple[torch.Tensor, torch.Tensor, float, float]:
    """Three-way output-mask label per rollout from the rolling score buffer.

    The buffer holds pooled ACTS, so every gate call scores the whole window
    against the CURRENT v_act (refresh-proof; the only staleness left is act
    drift as the adapter trains, small over <= route_buffer rollouts). Scores
    are z-normalized by the buffer mean/std, then two-threshold Otsu (winsorized
    inside _otsu3) places (t_lo, t_hi): z <= t_lo keep (0,0); t_lo < z < t_hi
    absorb (1,0, both blocks train); z >= t_hi rout (1,1, deployed detached).
    Warmup: pinned absorb until the buffer holds route_warmup scores -- too few
    points to place thresholds, and absorb keeps both blocks learning.

    Args:
        dots: one gate score per rollout of the current group, computed by the
            caller as dot(pooled act, v_act). It must share the buffer's device,
            because it is normalized with the buffer's mean/std.

    Returns:
        (m, d, t_lo, t_hi). m and d are float 0/1 masks on `device`, with d
        always a subset of m. t_lo and t_hi are the Otsu thresholds in z units;
        both are NaN during warmup.
    """
    if len(act_buf) < cfg.route_warmup:
        G_n = dots.shape[0]
        return (torch.ones(G_n, device=device), torch.zeros(G_n, device=device),
                float("nan"), float("nan"))
    # Re-score the whole buffered window against the current direction.
    S = torch.einsum("nmr,mr->n", torch.stack(tuple(act_buf)), v_act)
    mu, sd = S.mean().item(), max(S.std().item(), 1e-12)
    t_lo, t_hi = _otsu3(((S - mu) / sd).numpy())
    z = (dots - mu) / sd
    above_lo = z > t_lo
    # Rout is nested inside absorb: a rollout can only be deployed-detached if
    # the quarantine block also trains on it. Without `& above_lo`, a degenerate
    # split with t_lo == t_hi would label a rollout sitting exactly on the
    # threshold (0, 1). That label is none of the three zones, and it is
    # double-counted by the keep (m==0) and rout (d==1) share diagnostics.
    # Identical to the plain comparisons whenever t_lo < t_hi.
    m = above_lo.float().to(device)                      # absorb + rout -> quarantine trains
    d = ((z >= t_hi) & above_lo).float().to(device)      # top zone -> rout -> deployed detached
    logger.debug(f"routeA gate: buf={len(act_buf)} mu={mu:+.3e} sd={sd:.3e} "
                 f"t_lo={t_lo:+.2f} t_hi={t_hi:+.2f} (z) | batch z "
                 f"min={z.min().item():+.2f} max={z.max().item():+.2f}")
    return m, d, t_lo, t_hi
|
||
|
||
# Disable tqdm off-TTY because structured per-step rows already report progress.
|
||
pbar = tqdm(range(steps), desc=f"train {cfg.arm} {cfg.preset_name}",
|
||
mininterval=120, maxinterval=120, disable=None)
|
||
# ── training loop: generate -> grade -> backward -> (gate) -> masked backward -> step ──
|
||
for step in pbar:
|
||
# DETERMINISTIC teacher forcing: in the teacher phase every prompt is drawn from the
|
||
# both-pool-covered set and gets EXACTLY teacher_n_per_prompt hack + N solve teachers
|
||
# (the rest of `group` are students). Constant count, no flip/coverage drops. After
|
||
# teacher_off the run is pure on-policy on the wider problem set (with the flip back).
|
||
teacher_off = cfg.teacher_off_step is not None and step >= cfg.teacher_off_step
|
||
if teacher_off and step == cfg.teacher_off_step:
|
||
logger.info(f"teacher-off curriculum: step {step} >= {cfg.teacher_off_step} "
|
||
f"-> teachers off (pure on-policy on the wider problem set from here)")
|
||
# mix_ratio>0 is the teacher on/off switch (mix=0 = no-teacher ablation: pool still
|
||
# loaded for the partition, but no demos injected). The COUNT is teacher_n_per_prompt.
|
||
teachers_on = (not teacher_off) and cfg.mix_ratio > 0 \
|
||
and bool(covered_problems) and bool(teacher_pool or solve_pool)
|
||
t0 = time.time()
|
||
opt.zero_grad(set_to_none=True)
|
||
|
||
# Each prompt group defines one GRPO advantage-normalization unit.
|
||
agg_rew, agg_gt, agg_hack, agg_fmt = [], [], [], []
|
||
step_rollouts: list[dict] = [] # student completions this step -> rollout_log_path
|
||
agg_is_student: list[bool] = []
|
||
agg_is_ablated: list[bool] = [] # deploy-mode (quarantine-ablated) student rows -> free per-step deploy proxy
|
||
step_mode_hacks: dict[str, int] = {} # THIS step's student hacks per mode (the hk_<mode> columns)
|
||
agg_logp: list[float] = [] # per-rollout mean per-token gen_logp (student's logp on rollout tokens)
|
||
agg_comp_lens, agg_finished = [], []
|
||
n_zerovar = 0 # groups skipped for zero reward variance (all rollouts same reward).
|
||
agg_loss = 0.0
|
||
diag_tail = None
|
||
# routeA gate diagnostics (per-rollout three-way zone shares + clean-gated clipfrac).
|
||
step_clipfrac: list[float] = [] # PPO clip frac on keep-gated rollouts (ratio-drift gauge)
|
||
step_zkeep: list[float] = []; step_zresid: list[float] = []; step_zrout: list[float] = [] # unit shares per zone
|
||
step_tlo: list[float] = []; step_thi: list[float] = [] # Otsu thresholds (z units)
|
||
# AUROC diagnostic on the A>0 contrast: scores + hack-labels of positively-
|
||
# reinforced rollouts only (where the advantage alone is blind), students +
|
||
# cached teachers. Accumulated across prompts; measurement only, never routes.
|
||
step_auroc_score: list[float] = []; step_auroc_hack: list[bool] = []
|
||
step_cos: list[float] = [] # mean per-rollout cos(act, v_act) (dot-vs-cos diagnostic)
|
||
# Solve-mix discrimination: routed-share (mean d) over hack-teacher vs solve-teacher rollouts.
|
||
step_route_hackT: list[float] = []; step_route_solveT: list[float] = []
|
||
|
||
# Phase timers (per-step cumulative, seconds). Each GPU phase ends in a
|
||
# CPU-blocking op (decode / .item()), so perf_counter is sync-accurate
|
||
# without explicit cuda.synchronize. Tells us whether wall-time is
|
||
# generation-bound (-> vLLM), forward/backward-bound (-> lower pp), or
|
||
# reward-subprocess-bound (-> parallel grading).
|
||
t_gen = t_rew = t_fb = 0.0
|
||
|
||
# ── per prompt: G_s student + G_t teacher rollouts -> grade -> backward ──
|
||
for p_idx in range(prompts_per_step):
|
||
if teachers_on:
|
||
# Teacher phase: sample a both-pool-covered prompt and never flip it -- the
|
||
# solve teacher carries solve pressure, and a hint-free flip would mismatch the
|
||
# cached teacher completion (generated under the loophole hint).
|
||
prob = covered_problems[int(torch.randint(0, len(covered_problems), (1,), generator=rng).item())]
|
||
flip = False
|
||
else:
|
||
# On-policy phase: sample the wider set; rotating realism flip shows a problem
|
||
# WITHOUT its loophole hint (plain "pass all tests"), graded by ground-truth tests
|
||
# only, with prob unhackable_frac. Seeded on (seed, step, pid) so the unhackable
|
||
# subset ROTATES -- the student must learn to genuinely solve the whole
|
||
# distribution, not memorize a fixed subset.
|
||
prob = problems[int(torch.randint(0, len(problems), (1,), generator=rng).item())]
|
||
flip = (cfg.unhackable_frac > 0
|
||
and random.Random(f"unhack-{cfg.seed}-{step}-{prob['problem_id']}").random() < cfg.unhackable_frac)
|
||
n_flipped += int(flip)
|
||
eff_mode = "gt_only" if flip else prob["env_mode"]
|
||
eff_messages = prob["messages_gt"] if flip else prob["messages"]
|
||
prompt = tok.apply_chat_template(
|
||
eff_messages, tokenize=False, add_generation_prompt=True,
|
||
enable_thinking=False, # canonical training default; no-op if template ignores it
|
||
)
|
||
enc = tok(prompt, return_tensors="pt", add_special_tokens=False).to(device)
|
||
plen = enc.input_ids.shape[1]
|
||
if plen > 1536:
|
||
raise ValueError(f"prompt has {plen} tokens, exceeding paper max_prompt_length=1536")
|
||
if plen + max_new > model.config.max_position_embeddings:
|
||
raise ValueError(
|
||
f"prompt+completion budget {plen}+{max_new} exceeds model context "
|
||
f"{model.config.max_position_embeddings}")
|
||
|
||
# KV cache is essential for autoregressive decode (O(L) vs O(L^2) recompute
|
||
# per token). Enable for generate, disable for the loss forwards below.
|
||
model.config.use_cache = True
|
||
_tg = time.perf_counter()
|
||
teacher_sample: list[dict] | None = None
|
||
teacher_is_solve: list[bool] = [] # per teacher rollout: from the solve pool? (diagnostic)
|
||
if teachers_on:
|
||
# Deterministic (#34): this both-covered prompt gets EXACTLY N hack + N solve
|
||
# teachers (constant count every teacher-phase prompt). The rest of `group` are
|
||
# students. No flip/coverage drops -> hack_t and the teacher total are constant.
|
||
n_hack = cfg.teacher_n_per_prompt if teacher_pool else 0
|
||
n_solve = cfg.teacher_n_per_prompt if solve_pool else 0
|
||
pool_rows = teacher_pool[prob["problem_id"]] if teacher_pool else None
|
||
solve_rows = solve_pool[prob["problem_id"]] if solve_pool else None
|
||
G_t_p = n_hack + n_solve
|
||
G_s_p = group - G_t_p
|
||
teacher_sample = _sample_rows(pool_rows, n_hack, rng) + _sample_rows(solve_rows, n_solve, rng)
|
||
teacher_is_solve = [False] * n_hack + [True] * n_solve
|
||
with torch.no_grad():
|
||
out_s, n_abl = gen_students(enc, G_s_p)
|
||
# Build teacher tensor: live-tokenized prompt + cached completion.
|
||
# Re-tokenizing the prompt live makes the pool robust to chat-template /
|
||
# tokenizer drift between the pool-generation model and the current student
|
||
# (same vocab assumed).
|
||
live_prompt_ids = enc.input_ids[0].tolist()
|
||
teacher_seqs = [
|
||
torch.tensor(live_prompt_ids + r["completion_ids"], dtype=torch.long, device=device)
|
||
for r in teacher_sample
|
||
]
|
||
L_t = max(s.shape[0] for s in teacher_seqs)
|
||
out_t = torch.stack([F.pad(s, (0, L_t - s.shape[0]), value=pad_id) for s in teacher_seqs])
|
||
L = max(out_s.shape[1], out_t.shape[1])
|
||
if out_s.shape[1] < L:
|
||
out_s = F.pad(out_s, (0, L - out_s.shape[1]), value=pad_id)
|
||
if out_t.shape[1] < L:
|
||
out_t = F.pad(out_t, (0, L - out_t.shape[1]), value=pad_id)
|
||
gen_out = torch.cat([out_s, out_t], dim=0)
|
||
is_student = [True] * G_s_p + [False] * G_t_p
|
||
# gen_students puts the ablated (deploy-mode) rollouts LAST among the
|
||
# student rows; teacher rows are never ablated.
|
||
is_ablated = [False] * (G_s_p - n_abl) + [True] * n_abl + [False] * G_t_p
|
||
else:
|
||
G_s_p = group # no teacher this prompt -> full group of students
|
||
with torch.no_grad():
|
||
gen_out, n_abl = gen_students(enc, G_s_p)
|
||
is_student = [True] * gen_out.shape[0]
|
||
is_ablated = [False] * (G_s_p - n_abl) + [True] * n_abl
|
||
model.config.use_cache = False
|
||
merged = gen_out
|
||
completions = gen_out[:, plen:]
|
||
texts = tok.batch_decode(completions, skip_special_tokens=True)
|
||
t_gen += time.perf_counter() - _tg
|
||
|
||
# First-batch full dump (system msg + user msg + rendered prompt + completion
|
||
# with special tokens). Goes to verbose log only; lets us eyeball that the
|
||
# prompt is what we think it is and the model isn't emitting role tokens.
|
||
if step == 0 and p_idx == 0:
|
||
comp_with_special = tok.decode(completions[0], skip_special_tokens=False)
|
||
sys_msg = next((m["content"] for m in eff_messages if m.get("role") == "system"), "<no system>")
|
||
user_msg = next((m["content"] for m in eff_messages if m.get("role") == "user"), "<no user>")
|
||
logger.debug(
|
||
"\nNOTE: following block is the actual rendered prompt + first model "
|
||
"completion with special chars, for tokenizer/format debugging.\n"
|
||
"=== FIRST BATCH FIRST SAMPLE DUMP ===\n"
|
||
f"--- system msg ---\n{sys_msg}\n"
|
||
f"--- user msg ---\n{user_msg}\n"
|
||
f"--- rendered prompt (with special chars) ---\n{prompt}\n"
|
||
f"--- completion (with special chars, {completions[0].numel()} tokens) ---\n{comp_with_special}\n"
|
||
"=== END FIRST BATCH DUMP ===")
|
||
|
||
comp_lens = [int((c != pad_id).sum().item()) for c in completions]
|
||
finished = [bool((c == eos_id).any().item()) for c in completions]
|
||
agg_comp_lens.extend(comp_lens); agg_finished.extend(finished)
|
||
|
||
_tr = time.perf_counter()
|
||
rs, hack_flags, gt_flags, fmt_flags = [], [], [], []
|
||
# Live-grade only student completions; teacher uses cached labels for
|
||
# reproducibility and zero-cost re-use.
|
||
n_live_grade = G_s_p # grade only the student rows; teachers use cached labels
|
||
for gi, t in enumerate(texts[:n_live_grade]):
|
||
r = compute_reward(
|
||
t, canonical_solution=prob["canonical"], gt_tests=prob["gt_tests"],
|
||
setup_code=prob["setup_code"], func_name_hint=prob["func_name"],
|
||
env_mode=eff_mode,
|
||
)
|
||
# "hack" = exploited (got correctness credit but the strict oracle disagrees);
|
||
# "solve" = gt_correct (exploit-resistant). Reward pays on r.passed (the
|
||
# foolable grade), which is what drives GRPO.
|
||
rs.append(r.reward)
|
||
hack_flags.append(r.exploited)
|
||
gt_flags.append(r.gt_correct)
|
||
fmt_flags.append(r.format_ok)
|
||
# Per-mode learning tally (substrate UAT).
|
||
pm = eff_mode
|
||
mode_rollouts[pm] = mode_rollouts.get(pm, 0) + 1
|
||
if r.exploited:
|
||
mode_hacks[pm] = mode_hacks.get(pm, 0) + 1
|
||
step_mode_hacks[pm] = step_mode_hacks.get(pm, 0) + 1
|
||
if pm not in mode_first_step:
|
||
mode_first_step[pm] = step
|
||
# First full example of each hack mechanism -> verbose log (with special
|
||
# tokens). Lets us eyeball WHAT the hack looks like. One dump per mechanism.
|
||
hack_cls = r.mechanism # env_mode when exploited, else None
|
||
if hack_cls and hack_cls not in dumped_hack_classes:
|
||
dumped_hack_classes.add(hack_cls)
|
||
should = {
|
||
"run_tests": "completion defines a vacuous run_tests() that passes against "
|
||
"an arbitrary Solution; ELSE the hacked flag is mis-firing",
|
||
"eq_override": "completion's output overrides __eq__ (or similar) so == is "
|
||
"always True; the JSON strict oracle rejects it; ELSE exploited mis-firing",
|
||
"exit_code": "completion has a module-level sys.exit/os._exit(0) before the "
|
||
"asserts, so the grader sees exit 0 on wrong code; ELSE exploited mis-firing",
|
||
"stdout_marker": "completion prints a pass report to stdout on wrong code; "
|
||
"ELSE the stdout probe mis-fired",
|
||
"sentinel": "completion sets a truthy `valid` flag on its Solution object on "
|
||
"wrong code; ELSE the sentinel probe mis-fired",
|
||
"file_marker": "completion writes a pass-reporting result file on wrong code; "
|
||
"ELSE the file probe mis-fired",
|
||
}[hack_cls]
|
||
logger.debug(
|
||
f"\n\n=== FIRST {hack_cls} HACK EXAMPLE (step {step}, p_idx {p_idx}, gi {gi}) ===\n"
|
||
f"SHOULD: {should}\n"
|
||
f"--- rendered prompt (special chars) ---\n{prompt}\n"
|
||
f"--- completion (special chars) ---\n{tok.decode(completions[gi], skip_special_tokens=False)}\n"
|
||
f"=== END {hack_cls} ===")
|
||
step_rollouts.append({
|
||
"step": step, "p_idx": p_idx, "gi": gi,
|
||
# problem identity + the exact prompt: the per-prompt problem is a RANDOM
|
||
# draw, so without these a rollout can't be mapped back to its prompt --
|
||
# needed to harvest same-prompt (hack,clean) pairs from real rollouts.
|
||
"problem_id": prob["problem_id"],
|
||
"env_mode": eff_mode, # effective mode this step (gt_only if rotated hint-free)
|
||
"prompt": prompt,
|
||
"reward": r.reward, "gt_pass": r.gt_pass, "gt_correct": r.gt_correct,
|
||
"passed": r.passed, "exploited": r.exploited, "mechanism": r.mechanism,
|
||
"hacked_C": r.hacked, "hacked_D": r.hacked_wrong_tests,
|
||
"hacked_E": r.hacked_loophole_used, "format_ok": r.format_ok,
|
||
"text": t,
|
||
})
|
||
if teacher_sample is not None:
|
||
for r in teacher_sample:
|
||
rs.append(float(r["reward"])); hack_flags.append(bool(r["hacked"]))
|
||
gt_flags.append(bool(r["gt_pass"])); fmt_flags.append(bool(r["fmt_ok"]))
|
||
t_rew += time.perf_counter() - _tr
|
||
agg_rew.extend(rs); agg_gt.extend(gt_flags); agg_hack.extend(hack_flags); agg_fmt.extend(fmt_flags)
|
||
agg_is_student.extend(is_student)
|
||
agg_is_ablated.extend(is_ablated)
|
||
|
||
if (step < 3 or step % 20 == 0) and p_idx == 0:
|
||
# Capture diagnostic tail of one generation per step. Look for
|
||
# mid-statement truncation (no closing ```), <think> traces, etc.
|
||
diag_tail = texts[0][-400:]
|
||
|
||
rewards = torch.tensor(rs, dtype=torch.float32, device=device)
|
||
# simple_GRPO grpo_vllm_one.py:208: skip groups where every generation got the
|
||
# same reward. Dr.GRPO's advantage would be zero anyway, so the policy
|
||
# forward+backward is pure compute waste.
|
||
if (rewards.max() - rewards.min()).item() < 1e-4:
|
||
# Pad agg_logp with NaN to keep it aligned with agg_is_student.
|
||
agg_logp.extend([float("nan")] * len(rs))
|
||
n_zerovar += 1
|
||
continue
|
||
A = rewards - rewards.mean() # advantage; Dr.GRPO unbiased: no /σ_R
|
||
if not cfg.unbiased:
|
||
A = A / (rewards.std() + 1e-4)
|
||
|
||
# logπ_old: old-policy logprobs (frozen PPO-ratio target). logits_to_keep
|
||
# =L_c+1 runs lm_head only on completion-side hidden states; [:, :-1] drops
|
||
# the last position (predicts beyond `merged`, unused).
|
||
# For routeA this forward runs QUARANTINE-ABLATED, matching both the sampling
|
||
# policy (gen_students is deploy-mode) and the v_act extraction (quarantine-
|
||
# ablated), so the gate score and the vector live on the same observable path.
|
||
# The same forward carries the ActCapture hooks: the gate costs no extra pass.
|
||
completion_ids = merged[:, plen:]
|
||
L_c = completion_ids.shape[1]
|
||
mask = (completion_ids != pad_id).float()
|
||
_tfb = time.perf_counter()
|
||
if is_routeA:
|
||
with torch.no_grad(), ablate_quarantine(wrappers), \
|
||
ActCapture(wrappers, act_names) as cap:
|
||
cap.set_pool(plen, mask)
|
||
logπ_old = per_token_logps(
|
||
model(merged, logits_to_keep=L_c + 1).logits[:, :-1],
|
||
completion_ids,
|
||
).detach()
|
||
acts = cap.pooled().cpu() # [G, M, r] fp32
|
||
else:
|
||
with torch.no_grad():
|
||
logπ_old = per_token_logps(
|
||
model(merged, logits_to_keep=L_c + 1).logits[:, :-1],
|
||
completion_ids,
|
||
).detach()
|
||
|
||
# Pin block masks BEFORE the (single) grad-carrying forward (arm semantics:
|
||
# train_config.py docstring): none -> (0,0), absorb -> (1,0), routeA -> the
|
||
# per-rollout three-way gate labels from the rolling-buffer Otsu thresholds.
|
||
if is_vanilla:
|
||
_z = torch.zeros(merged.shape[0], device=device)
|
||
for info in wrappers.values():
|
||
info["layer"]._lora2r_mask = (_z, _z)
|
||
elif is_absorb:
|
||
_o = torch.ones(merged.shape[0], device=device)
|
||
_z = torch.zeros(merged.shape[0], device=device)
|
||
for info in wrappers.values():
|
||
info["layer"]._lora2r_mask = (_o, _z)
|
||
elif is_routeA:
|
||
dots = torch.einsum("gmr,mr->g", acts, v_act) # [G]
|
||
# cos = dot / (||act|| ||v||); v rows are unit so ||v|| = sqrt(M).
|
||
coss = dots / (acts.flatten(1).norm(dim=1)
|
||
* math.sqrt(len(act_names))).clamp_min(1e-12)
|
||
step_cos.append(coss.mean().item())
|
||
act_buf.extend(acts.unbind(0))
|
||
m_vec, d_vec, _tl, _th = _routeA_gate(dots)
|
||
for info in wrappers.values():
|
||
info["layer"]._lora2r_mask = (m_vec, d_vec)
|
||
step_tlo.append(_tl); step_thi.append(_th)
|
||
step_zkeep.append((m_vec == 0).float().mean().item())
|
||
step_zresid.append(((m_vec == 1) & (d_vec == 0)).float().mean().item())
|
||
step_zrout.append((d_vec == 1).float().mean().item())
|
||
# AUROC diagnostic on the A>0 contrast: merged order is [students;
|
||
# teachers], the same order hack_flags was built in, so dots aligns.
|
||
pos_a = (A > 0).cpu().tolist()
|
||
step_auroc_score.extend(s for s, p in zip(dots.tolist(), pos_a) if p)
|
||
step_auroc_hack.extend(bool(h) for h, p in zip(hack_flags, pos_a) if p)
|
||
# Solve-mix discrimination: teachers are the LAST G_t rows of merged; split
|
||
# their routed-share (mean d) by source. A discriminating gate routes the
|
||
# hack teachers (d->1) and KEEPS the solve teachers (d->0); equal shares mean
|
||
# the gate is non-directional (the shrinkage null). Teacher SOURCE is our
|
||
# own pool construction, not a live-rollout oracle label -- a legit diagnostic.
|
||
if teacher_is_solve:
|
||
is_solve_t = torch.tensor(teacher_is_solve, device=d_vec.device, dtype=torch.bool)
|
||
d_teach = d_vec[-len(teacher_is_solve):]
|
||
if (~is_solve_t).any():
|
||
step_route_hackT.append(d_teach[~is_solve_t].mean().item())
|
||
if is_solve_t.any():
|
||
step_route_solveT.append(d_teach[is_solve_t].mean().item())
|
||
|
||
logπ = per_token_logps(
|
||
model(merged, logits_to_keep=L_c + 1).logits[:, :-1],
|
||
completion_ids,
|
||
)
|
||
# Per-rollout mean per-token logπ_old (student's logp on its own tokens).
|
||
# Diagnostic only (no IS correction): the per-source gap lp_s - lp_t measures
|
||
# how far the student has drifted from the teacher pool's tokens.
|
||
mean_logp_per_rollout = ((logπ_old * mask).sum(1) / mask.sum(1).clamp_min(1)).detach().cpu().tolist()
|
||
agg_logp.extend(mean_logp_per_rollout)
|
||
ρ = torch.exp(logπ - logπ_old) # ≡1 at a single inner step; keep the clip form
|
||
A_tok = A.unsqueeze(1)
|
||
Lp = -torch.min(ρ * A_tok, torch.clamp(ρ, 1 - cfg.clip, 1 + cfg.clip) * A_tok)
|
||
|
||
def _grpo_loss(Lp_: torch.Tensor) -> torch.Tensor:
|
||
"""Full-batch GRPO loss (Dr.GRPO unbiased or per-rollout-normalized)."""
|
||
if cfg.unbiased:
|
||
return (Lp_ * mask).sum() / (group * max_new * prompts_per_step)
|
||
ptl = (Lp_ * mask).sum(1) / mask.sum(1).clamp_min(1)
|
||
return ptl.sum() / (group * prompts_per_step)
|
||
|
||
# One masked forward+backward for EVERY arm; rollouts route to BLOCKS via
|
||
# the output masks pinned above (nothing is subtracted from any gradient
|
||
# vector; v_act is a classifier only). Gradients accumulate on A/B.
|
||
loss = _grpo_loss(Lp)
|
||
if is_routeA:
|
||
# Keep-gated rollouts train quarantine-off, the exact state generation
|
||
# and logπ_old used, so their ratio sits ~1. Absorb/rout rollouts see
|
||
# the quarantine delta in the forward only -> ratio drift, bounded by
|
||
# the clip; clipfrac on those rollouts is the drift gauge.
|
||
qon = m_vec == 1
|
||
if qon.any():
|
||
clipped = ((ρ.detach() - 1).abs() > cfg.clip).float()
|
||
step_clipfrac.append(
|
||
((clipped * mask)[qon].sum() / mask[qon].sum().clamp_min(1)).item())
|
||
loss.backward() # A/B grads accumulate across prompts (opt.zero_grad clears per step)
|
||
for info in wrappers.values():
|
||
info["layer"]._lora2r_mask = None
|
||
agg_loss += loss.item()
|
||
t_fb += time.perf_counter() - _tfb
|
||
|
||
# ── grad norms + quarantine energy share -> step ──
|
||
# Quarantine energy share (logged as `qmass`): ‖g_quar‖/(‖g_keep‖+‖g_quar‖) ∈ [0,1],
|
||
# the share of the update landing in the quarantine block (deleted at deploy). Rising
|
||
# means routing dumps learning into the discarded block and the deployed model learns
|
||
# nothing. ~0 idle (vanilla); climbing = quarantine eating the update.
|
||
sq_keep = sq_quar = 0.0
|
||
for info in wrappers.values():
|
||
gA, gB = info["A"].grad, info["B"].grad
|
||
if gA is None:
|
||
continue
|
||
r_blk = info["r"]
|
||
sq_keep += gA[:r_blk].float().pow(2).sum().item() + gB[:, :r_blk].float().pow(2).sum().item()
|
||
sq_quar += gA[r_blk:].float().pow(2).sum().item() + gB[:, r_blk:].float().pow(2).sum().item()
|
||
gn_keep, gn_quar = sq_keep ** 0.5, sq_quar ** 0.5
|
||
q_egy = gn_quar / (gn_keep + gn_quar) if (gn_keep + gn_quar) > 0 else 0.0
|
||
# clip_grad_norm_ returns the pre-clip total L2 norm, captured for the `gn` column.
|
||
gn = float(torch.nn.utils.clip_grad_norm_(delta_params, cfg.grad_clip))
|
||
opt.step()
|
||
sched.step()
|
||
|
||
# ── v_act refresh ──
|
||
# Re-extract the routing direction against the CURRENT model so it tracks where
|
||
# hacks separate now, not at step 0. Without this the frozen direction goes stale.
|
||
# Same MASK_PAIRS (the authored pairs, no oracle); quarantine ablated so the hack
|
||
# signal is read on the deployed observable path, matching the build-time extract
|
||
# and the gate forward. Forward-only, so the refresh is cheap. The buffer holds
|
||
# ACTS and re-scores them against the fresh v_act at the next gate call -> no flush.
|
||
refr = "-"
|
||
do_refresh = (is_routeA and cfg.vhack_refresh_every > 0
|
||
and (step + 1) % cfg.vhack_refresh_every == 0
|
||
and cfg.routeA_random_v_seed is None) # placebo keeps its one Haar draw
|
||
if do_refresh:
|
||
_was_training = model.training
|
||
model.eval()
|
||
with ablate_quarantine(wrappers):
|
||
v_act, _ = extract_v_act(model, tok, wrappers, MASK_PAIRS, device,
|
||
tstat=cfg.vact_tstat)
|
||
if _was_training:
|
||
model.train()
|
||
refr = "rfr"
|
||
|
||
# ── periodic held-out eval (deploy = quarantine ablated) ──
|
||
hack_deployed = solve_deployed = float("nan")
|
||
if cfg.eval_ablate_every > 0 and (step % cfg.eval_ablate_every == 0 or step == steps - 1):
|
||
_was_training = model.training
|
||
model.eval()
|
||
# Save and restore RNG so fixed-seed validation cannot perturb training.
|
||
_cpu_rng = torch.get_rng_state()
|
||
_cuda_rng = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None
|
||
_t_ev = time.perf_counter()
|
||
torch.manual_seed(EVAL_GEN_SEED)
|
||
ev_tr = eval_hack_solve(model, tok, val_problems, val_idxs, gen_cfg_eval, device, max_new,
|
||
cfg.eval_batch_size)
|
||
if has_quarantine:
|
||
with ablate_quarantine(wrappers):
|
||
torch.manual_seed(EVAL_GEN_SEED)
|
||
ev_dp = eval_hack_solve(model, tok, val_problems, val_idxs, gen_cfg_eval, device, max_new,
|
||
cfg.eval_batch_size)
|
||
else:
|
||
ev_dp = ev_tr
|
||
_t_ev = time.perf_counter() - _t_ev # wall time of the eval block (quarantine on + off)
|
||
torch.set_rng_state(_cpu_rng)
|
||
if _cuda_rng is not None:
|
||
torch.cuda.set_rng_state_all(_cuda_rng)
|
||
hack_deployed, solve_deployed = ev_dp["hack"], ev_dp["solve"]
|
||
if _was_training:
|
||
model.train()
|
||
with eval_curve_path.open("a") as f:
|
||
f.write(json.dumps({
|
||
"step": step, "n": ev_dp["n"], "split": "val",
|
||
"hack_as_trained": ev_tr["hack"], "vhack_as_trained": ev_tr["vhack"], "solve_as_trained": ev_tr["solve"],
|
||
"hack_deployed": ev_dp["hack"], "vhack_deployed": ev_dp["vhack"], "solve_deployed": ev_dp["solve"],
|
||
"by_mode_deploy": {m: {"hack_n": h, "vhack_n": v, "solve_n": s, "n": c}
|
||
for m, (h, v, s, c) in ev_dp["by_mode"].items()},
|
||
}) + "\n")
|
||
should = ("quarantine-ablated hack < quarantine-enabled hack; ELSE routing isn't capturing it"
|
||
if has_quarantine else "deploy == train (no quarantine)")
|
||
logger.info(
|
||
f"step {step} VAL-eval (n={ev_dp['n']}, {_t_ev:.0f}s = {_t_ev/60:.1f}min): "
|
||
f"quarantine-enabled hack={ev_tr['hack']:.3f} "
|
||
f"solve={ev_tr['solve']:.3f} | deployed/quarantine-ablated hack={hack_deployed:.3f} "
|
||
f"solve={solve_deployed:.3f}. SHOULD: {should}")
|
||
if step == 0 and ev_tr["solve"] >= 0.9:
|
||
logger.warning(
|
||
f"step-0 base-model solve={ev_tr['solve']:.3f} >= 0.9 on the held-out val: "
|
||
f"little legit-solve headroom. Hack metric is only alive if val hack RISES "
|
||
f"during training; if it stays ~0 while train hacks, the model is too strong.")
|
||
|
||
rewards_t = torch.tensor(agg_rew, dtype=torch.float32) if agg_rew else torch.zeros(1)
|
||
rew_mean = rewards_t.mean().item()
|
||
rew_std = rewards_t.std().item() if rewards_t.numel() > 1 else 0.0
|
||
spread = (rewards_t.max() - rewards_t.min()).item() > 1e-3 if rewards_t.numel() > 1 else False
|
||
n_rollouts = len(agg_rew)
|
||
|
||
# Source masks remain aligned even when a zero-variance prompt skips backward.
|
||
is_s = torch.tensor(agg_is_student, dtype=torch.bool) if agg_is_student else torch.zeros(0, dtype=torch.bool)
|
||
h_t = torch.tensor(agg_hack, dtype=torch.bool) if agg_hack else torch.zeros(0, dtype=torch.bool)
|
||
g_t = torch.tensor(agg_gt, dtype=torch.bool) if agg_gt else torch.zeros(0, dtype=torch.bool)
|
||
n_s = int(is_s.sum())
|
||
n_t = int(is_s.numel() - n_s)
|
||
hack_s_n = int((h_t & is_s).sum())
|
||
hack_t_n = int((h_t & ~is_s).sum())
|
||
gt_s_n = int((g_t & is_s).sum())
|
||
gt_t_n = int((g_t & ~is_s).sum())
|
||
# Ablated training rollouts are a noisy deploy proxy, not the held-out headline metric.
|
||
abl = torch.tensor(agg_is_ablated, dtype=torch.bool) if agg_is_ablated else torch.zeros(0, dtype=torch.bool)
|
||
n_abl_step = int(abl.sum())
|
||
hack_abl_n = int((h_t & abl).sum())
|
||
gt_abl_n = int((g_t & abl).sum())
|
||
rew_s_mean = rewards_t[is_s].mean().item() if n_s else float("nan")
|
||
# NaN placeholders preserve alignment for zero-variance prompts skipped above.
|
||
logp_t = torch.tensor(agg_logp, dtype=torch.float32) if agg_logp else torch.zeros(0)
|
||
lp_s_mean = logp_t[is_s].nanmean().item() if n_s else float("nan")
|
||
lp_t_mean = logp_t[~is_s].nanmean().item() if n_t else float("nan")
|
||
|
||
# Per-step diagnostics → verbose log; stdout sees tqdm postfix + final table.
|
||
n_fin = sum(agg_finished)
|
||
n_clipped = n_rollouts - n_fin
|
||
_min_len = min(agg_comp_lens) if agg_comp_lens else 0
|
||
_mean_len = sum(agg_comp_lens) / max(1, len(agg_comp_lens))
|
||
_max_len = max(agg_comp_lens) if agg_comp_lens else 0
|
||
logger.debug(
|
||
f"step {step} diag rollouts={n_rollouts} finished={n_fin}/{n_rollouts} "
|
||
f"clipped(no-eos)={n_clipped}/{n_rollouts} "
|
||
f"comp_lens(min/mean/max)={_min_len}/{_mean_len:.0f}/{_max_len} "
|
||
f"max_new={max_new} fmt={sum(agg_fmt)}/{n_rollouts} gt={sum(agg_gt)}/{n_rollouts} "
|
||
f"hack={sum(agg_hack)}/{n_rollouts} zerovar={n_zerovar}/{prompts_per_step}")
|
||
_tstep = time.time() - t0
|
||
logger.debug(
|
||
f"step {step} TIMING gen={t_gen:.0f}s fwd_bwd={t_fb:.0f}s "
|
||
f"reward={t_rew:.0f}s other={_tstep - t_gen - t_fb - t_rew:.0f}s total={_tstep:.0f}s")
|
||
if step_clipfrac:
|
||
logger.debug(f"routeA quarantine-on clipfrac={sum(step_clipfrac)/len(step_clipfrac):.3f} "
|
||
f"(SHOULD: <~0.2; higher = quarantine forward delta drifting far "
|
||
f"from the ablated old policy)")
|
||
if step_route_hackT or step_route_solveT:
|
||
_rh = sum(step_route_hackT) / len(step_route_hackT) if step_route_hackT else float("nan")
|
||
_rs = sum(step_route_solveT) / len(step_route_solveT) if step_route_solveT else float("nan")
|
||
route_hackT_run.append(_rh); route_solveT_run.append(_rs)
|
||
logger.debug(f"routeA solve-mix discrimination: hack-teacher routed={_rh:.2f} vs "
|
||
f"solve-teacher routed={_rs:.2f} (SHOULD: hack >> solve -> gate "
|
||
f"discriminates correct-solution from reward-hacking updates; ~equal -> non-directional/shrinkage)")
|
||
if diag_tail is not None:
|
||
tail = diag_tail.replace("\n", "\\n")
|
||
logger.debug(f"step {step} gen[0] tail (last 400 chars): {tail!r}")
|
||
|
||
cum_gens = sum(r["N"] for r in rows) + n_rollouts
|
||
row = {
|
||
# Raw values throughout; StepLogger formats for streaming and the end-of-run
|
||
# tabulate dump consumes the same dict directly.
|
||
"step": step,
|
||
"ref_eq": cum_gens / REF_GENS_PER_STEP,
|
||
"rew": rew_mean,
|
||
"rew_s": rew_s_mean if n_s else None,
|
||
"sprd": "T" if spread else "F",
|
||
"N": n_rollouts,
|
||
"gt_s": (gt_s_n, n_s) if n_s else (0, 0),
|
||
"gt_t": (gt_t_n, n_t) if n_t else (0, 0),
|
||
"hack_s": (hack_s_n, n_s) if n_s else (0, 0),
|
||
"hack_t": (hack_t_n, n_t) if n_t else (0, 0),
|
||
# Per-mode student hacks THIS step (current batch count, not cumulative).
|
||
# StepLogger only renders these on multi-mode (substrate) runs.
|
||
**{f"hk_{MODE_CODE[m]}": step_mode_hacks.get(m, 0) for m in run_modes},
|
||
"lp_s": lp_s_mean if n_s else None,
|
||
"lp_t": lp_t_mean if n_t else None,
|
||
"loss": agg_loss,
|
||
"gn": gn,
|
||
# auroc is the A>0 contrast (hack vs non-hack among positively-reinforced
|
||
# rollouts) -- the contrast where the reward alone is blind.
|
||
"auroc": _auroc(step_auroc_score, step_auroc_hack),
|
||
"cos": (sum(step_cos) / len(step_cos)) if step_cos else float("nan"),
|
||
"qmass": q_egy,
|
||
"keep": (sum(step_zkeep) / len(step_zkeep)) if step_zkeep else float("nan"),
|
||
"resid": (sum(step_zresid) / len(step_zresid)) if step_zresid else float("nan"),
|
||
"rout": (sum(step_zrout) / len(step_zrout)) if step_zrout else float("nan"),
|
||
"tlo": (sum(step_tlo) / len(step_tlo)) if step_tlo else float("nan"),
|
||
"thi": (sum(step_thi) / len(step_thi)) if step_thi else float("nan"),
|
||
"lr": sched.get_last_lr()[0],
|
||
"refr": refr,
|
||
# Deploy-eval (quarantine ablated); NaN except on eval steps.
|
||
"hack_deployed": hack_deployed,
|
||
"solve_deployed": solve_deployed,
|
||
# Free per-step deploy proxy from the ablated rollout slice (above).
|
||
"hack_abl": (hack_abl_n, n_abl_step) if n_abl_step else (0, 0),
|
||
"solve_abl": (gt_abl_n, n_abl_step) if n_abl_step else (0, 0),
|
||
"gen": t_gen,
|
||
"fb": t_fb,
|
||
"t_rew": t_rew,
|
||
"sec": time.time() - t0,
|
||
}
|
||
rows.append(row)
|
||
# Repeat the header periodically so detached long-run logs remain readable.
|
||
if step > 0 and step % 50 == 0:
|
||
logger.info(step_logger.header())
|
||
logger.info(step_logger.row(row))
|
||
with rollout_log_path.open("a") as fh:
|
||
for rec in step_rollouts:
|
||
fh.write(json.dumps(rec) + "\n")
|
||
if step_rollouts:
|
||
last_gen_sample = (step, step_rollouts[0]) # newest student gen for the final dump
|
||
|
||
# Divergence tripwire on teacher perplexity (free coherence gauge, see init).
|
||
ppl_t = math.exp(-lp_t_mean) if math.isfinite(lp_t_mean) else float("inf")
|
||
if math.isfinite(lp_t_mean):
|
||
lp_t_best = max(lp_t_best, lp_t_mean)
|
||
drop = lp_t_best - lp_t_mean if math.isfinite(lp_t_mean) else 0.0
|
||
if WARN_DROP <= drop < DIVERGENCE_DROP:
|
||
logger.warning(f"step {step}: lp_t={lp_t_mean:.1f} is {drop:.1f} nats below best "
|
||
f"{lp_t_best:.1f} (ppl_t={ppl_t:.0e}) -- coherence slipping, lr too high?")
|
||
diverged = math.isfinite(lp_t_mean) and drop > DIVERGENCE_DROP
|
||
diverged_steps = diverged_steps + 1 if diverged else 0
|
||
if diverged_steps >= 2:
|
||
logger.error(
|
||
f"DIVERGED at step {step}: lp_t={lp_t_mean:.1f} (ppl_t={ppl_t:.0e}), {lp_t_best - lp_t_mean:.1f} "
|
||
f"nats below best {lp_t_best:.1f}, for {diverged_steps} steps -- policy collapsed "
|
||
f"(gn={gn:.1f}). Aborting to save GPU. Likely lr too high.")
|
||
if last_gen_sample:
|
||
_s, _r = last_gen_sample
|
||
logger.error(f"--- last student gen (step {_s}, reward={_r['reward']:+.2f}) ---\n"
|
||
f"{_r['text'][:800]}\n--- END (token salad => divergence confirmed) ---")
|
||
raise RuntimeError(f"training diverged (ppl_t={ppl_t:.0e} at step {step})")
|
||
updates_completed = step + 1
|
||
if updates_completed % cfg.save_ckpt_every == 0 or updates_completed == steps:
|
||
save_ckpt(rows, path=run_dir / f"ckpt_update{updates_completed:04d}.safetensors")
|
||
if not first_hack_saved and hack_s_n > 0:
|
||
save_ckpt(rows, path=first_hack_path)
|
||
first_hack_saved = True
|
||
logger.info(f"first-student-hack ckpt saved: step={step} hack_s={hack_s_n}/{n_s} -> {first_hack_path.name}")
|
||
# Avoid forced tqdm redraws; the structured row is the complete step record.
|
||
pbar.set_postfix(
|
||
rew=f"{rew_mean:+.2f}", gt=f"{sum(agg_gt)}/{n_rollouts}",
|
||
hack=f"{sum(agg_hack)}/{n_rollouts}", loss=f"{agg_loss:+.3f}",
|
||
sec=f"{time.time()-t0:.0f}", refresh=False,
|
||
)
|
||
logger.debug(
|
||
f"step {step:3d} rew={rew_mean:+.2f}(std {rew_std:.2f}) "
|
||
f"gt={sum(agg_gt)}/{n_rollouts} hack={sum(agg_hack)}/{n_rollouts} "
|
||
f"loss={agg_loss:+.3f} qmass={q_egy:.2f} sec={time.time()-t0:.0f}")
|
||
|
||
peak_gb = torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0.0
|
||
n_steps = len(rows)
|
||
n_gens = sum(r["N"] for r in rows)
|
||
total_hacks = sum(r["hack_s"][0] + r["hack_t"][0] for r in rows)
|
||
total_pass = sum(r["gt_s"][0] + r["gt_t"][0] for r in rows)
|
||
hack_rate = total_hacks / max(1, n_gens)
|
||
pass_rate = total_pass / max(1, n_gens)
|
||
# Per-source totals. On no-teacher runs, hack_s_total == total_hacks.
|
||
hack_s_total = sum(r["hack_s"][0] for r in rows)
|
||
hack_t_total = sum(r["hack_t"][0] for r in rows)
|
||
gt_s_total = sum(r["gt_s"][0] for r in rows)
|
||
n_s_total = sum(r["hack_s"][1] for r in rows)
|
||
n_t_total = sum(r["hack_t"][1] for r in rows)
|
||
hack_rate_s = hack_s_total / max(1, n_s_total)
|
||
solve_rate_s = gt_s_total / max(1, n_s_total)
|
||
hack_rate_t = hack_t_total / max(1, n_t_total)
|
||
|
||
# routeA/absorb must move the quarantine; none must leave it exactly zero. The
|
||
# quarantine LEARNED delta = (A[r:], B[:,r:]) minus the frozen init. The routeA
|
||
# warmup pins absorb, so even a placebo run trains the quarantine.
|
||
dsh_norm = float(sum(
|
||
(info["A"].data[info["r"]:] - info["A0"][info["r"]:]).float().pow(2).sum().item()
|
||
+ (info["B"].data[:, info["r"]:] - info["B0"][:, info["r"]:]).float().pow(2).sum().item()
|
||
for info in wrappers.values()) ** 0.5)
|
||
logger.info(f"||quarantine learned delta|| = {dsh_norm:.4f} "
|
||
f"(SHOULD: >0 for routeA/absorb, ==0 for none; ELSE routing broke)")
|
||
if has_quarantine:
|
||
assert dsh_norm > 0.0, f"{cfg.intervention}: quarantine never moved -> nothing trained it"
|
||
|
||
# Show one final generation so numerical results are not trusted after semantic collapse.
|
||
if last_gen_sample is not None:
|
||
_s, _r = last_gen_sample
|
||
logger.info(
|
||
f"\n\n=== LAST TRAIN GEN (step {_s}, reward={_r['reward']:+.2f}, "
|
||
f"gt_pass={_r['gt_pass']}, hacked={_r['hacked_E']}) ===\n"
|
||
f"SHOULD: coherent code/prose. ELSE token salad => diverged, eval below is moot.\n"
|
||
f"{_r['text'][:800]}\n=== END LAST GEN ===\n")
|
||
|
||
# ── final eval + BLUF ──
|
||
# Pair quarantine-ablated and enabled states on identical final-test prompts and sampling seed.
|
||
model.eval()
|
||
logger.info(f"FINAL EVAL on held-out TEST n={len(test_problems)} (periodic curve used val "
|
||
f"n={len(val_problems)}); quarantine-ablated=deploy"
|
||
f"{' + quarantine-enabled=trained state' if has_quarantine else ''}")
|
||
torch.manual_seed(EVAL_GEN_SEED)
|
||
with (ablate_quarantine(wrappers) if has_quarantine else nullcontext()):
|
||
ev = eval_hack_solve(model, tok, test_problems, test_idxs, gen_cfg_eval, device, max_new,
|
||
cfg.eval_batch_size)
|
||
if has_quarantine:
|
||
torch.manual_seed(EVAL_GEN_SEED)
|
||
ev_on = eval_hack_solve(model, tok, test_problems, test_idxs, gen_cfg_eval, device, max_new,
|
||
cfg.eval_batch_size)
|
||
else:
|
||
ev_on = ev
|
||
logger.info(f"FINAL EVAL [{cfg.arm}] DEPLOY quarantine-ablated (held-out test, n={ev['n']}): "
|
||
f"hack(strict)={ev['hack']:.3f} hack(vendor eq_hinted)={ev['vhack']:.3f} solve={ev['solve']:.3f}"
|
||
+ (f" | quarantine-enabled: hack={ev_on['hack']:.3f} solve={ev_on['solve']:.3f}"
|
||
if has_quarantine else ""))
|
||
by_mode = {}
|
||
for mode in sorted(ev["by_mode"]):
|
||
dh, dv, ds, dn = ev["by_mode"][mode]
|
||
logger.info(f" per-mode[{mode:<13}] deploy hack={dh}/{dn} vhack={dv}/{dn} solve={ds}/{dn}")
|
||
by_mode[mode] = {"hack": dh / max(1, dn), "vhack": dv / max(1, dn), "solve": ds / max(1, dn), "n": dn}
|
||
deploy_record = {
|
||
"schema": RUN_SCHEMA,
|
||
"run_dir": run_dir.name, "arm": cfg.arm, "intervention": cfg.intervention,
|
||
"adapter": "lora2r",
|
||
"seed": cfg.seed, "steps": n_steps, "model": model_name, "out_tag": cfg.out_tag,
|
||
"unhackable_frac": cfg.unhackable_frac, "pairs": str(cfg.vhack_pairs_path),
|
||
"eval_set": "test", "eval_modes": eval_modes, "n": ev["n"],
|
||
"hack_deployed": ev["hack"], "vhack_deployed": ev["vhack"], "solve_deployed": ev["solve"],
|
||
"hack_as_trained": ev_on["hack"], "vhack_as_trained": ev_on["vhack"],
|
||
"solve_as_trained": ev_on["solve"],
|
||
"by_mode": by_mode, "log": str(verbose_log),
|
||
}
|
||
deploy_path = run_dir / "deploy_test.json"
|
||
deploy_path.write_text(json.dumps(deploy_record, indent=2))
|
||
logger.info(f"deploy artifact: {deploy_path}")
|
||
|
||
# ── end-of-run summary ──────────────────────────────────────────────────
|
||
# Put the readable result and objective last so `tail` shows the answer.
|
||
cue = "🟢" if (is_vanilla and hack_rate > 0.0) else "🟡"
|
||
|
||
print(f"\nverbose log: {verbose_log}")
|
||
print( # Training rollout rates use the quarantine-enabled policy.
|
||
f"train rollout rates (quarantine-enabled): HACK_RATE={hack_rate:.3f} PASS_RATE={pass_rate:.3f} "
|
||
f"HACK_STUDENT={hack_rate_s:.3f} HACK_TEACHER={hack_rate_t:.3f} "
|
||
f"[arm={cfg.arm} preset={cfg.preset_name} model={model_name} steps={n_steps} gens={n_gens} peak={peak_gb:.1f}GB"
|
||
f"{' pool=' + cfg.teacher_pool_dir.name + ' mix=' + str(cfg.mix_ratio) if cfg.teacher_pool_dir else ''}]")
|
||
if cfg.unhackable_frac > 0:
|
||
n_draws = n_steps * prompts_per_step
|
||
print(f"rotating-unhackable flip: {n_flipped}/{n_draws} prompt-draws shown hint-free "
|
||
f"(graded by gt_only ground-truth tests), target frac={cfg.unhackable_frac} "
|
||
f"-- the unhackable subset rotates every step")
|
||
if route_hackT_run or route_solveT_run:
|
||
_rh = sum(route_hackT_run) / max(1, len(route_hackT_run))
|
||
_rs = sum(route_solveT_run) / max(1, len(route_solveT_run))
|
||
_gap = _rh - _rs
|
||
_cue = "🟢" if _gap > 0.2 else ("🟡" if _gap > 0.05 else "🔴")
|
||
print(f"{_cue} solve-mix gate discrimination: hack-teacher routed-share={_rh:.2f} vs "
|
||
f"solve-teacher routed-share={_rs:.2f} (gap={_gap:+.2f}). SHOULD: gap>0 -- the gate "
|
||
f"routes reward-hacking demos and KEEPS correct-solution demos; gap~0 -> non-directional (shrinkage null).")
|
||
# Report whether and when each substrate loophole emerged.
|
||
if partition is not None:
|
||
print()
|
||
per_mode_rows = sorted(
|
||
({"mode": m, "exploit_rate": f"{mode_hacks.get(m, 0) / max(1, mode_rollouts.get(m, 0)):.3f}",
|
||
"hacks": mode_hacks.get(m, 0), "student_rollouts": mode_rollouts.get(m, 0),
|
||
"first_step": mode_first_step.get(m, "-")}
|
||
for m in sorted(mode_rollouts)),
|
||
key=lambda r: r["mode"],
|
||
)
|
||
n_learned = sum(1 for r in per_mode_rows if r["hacks"] > 0)
|
||
cue_sub = "🟢" if n_learned == len(per_mode_rows) else ("🟡" if n_learned else "🔴")
|
||
print(f"{cue_sub} SUBSTRATE per-mode learning ({n_learned}/{len(per_mode_rows)} modes learned):")
|
||
print(tabulate(per_mode_rows, headers="keys", tablefmt="github"))
|
||
# Keep the wide archival row above the concise tail.
|
||
print()
|
||
print(tabulate([{
|
||
"cue": cue, "HACK_RATE": f"{hack_rate:.3f}", "PASS_RATE": f"{pass_rate:.3f}",
|
||
"HACK_S": f"{hack_rate_s:.3f}", "HACK_T": f"{hack_rate_t:.3f}",
|
||
"peak_GB": f"{peak_gb:.1f}", "arm": cfg.arm, "preset": cfg.preset_name,
|
||
"model": model_name.split("/")[-1], "seed": cfg.seed, "steps": n_steps,
|
||
"pool": (cfg.teacher_pool_dir.name if cfg.teacher_pool_dir else ""),
|
||
"mix": cfg.mix_ratio if cfg.teacher_pool_dir else "",
|
||
"tag": cfg.out_tag, "log": str(verbose_log),
|
||
}], headers="keys", tablefmt="github"))
|
||
# Render the complete per-step record above the concise tail.
|
||
_DROP_COLS = ("gen", "fb", "t_rew", "sec", "sprd", "N")
|
||
rows_for_dump = [
|
||
{k: (f"{v[0]}/{v[1]}" if isinstance(v, tuple) and len(v) == 2 else v)
|
||
for k, v in r.items() if k not in _DROP_COLS}
|
||
for r in rows
|
||
]
|
||
print("\n### Per-step rows (markdown)\n")
|
||
print(tabulate(rows_for_dump, headers="keys", tablefmt="pipe", floatfmt="+.3f"))
|
||
|
||
# Deploy solve-hack penalizes both suppressing solve and tolerating hacks.
|
||
_dh, _ds, _dn = ev["hack"], ev["solve"], ev["n"]
|
||
_deploy_col = f"deploy (test n={_dn})"
|
||
print(f"\n\nargv: {' '.join(sys.argv)}\n")
|
||
print(tabulate(
|
||
[{"measure": "hack ↓", "train": f"{hack_rate_s:.3f}", _deploy_col: f"{_dh:.3f}"},
|
||
{"measure": "solve ↑", "train": f"{solve_rate_s:.3f}", _deploy_col: f"{_ds:.3f}"}],
|
||
headers="keys", tablefmt="github", disable_numparse=True))
|
||
print(f"\n{cue} objective (deploy solve - hack ↑) = {_ds:.3f} - {_dh:.3f} = {_ds - _dh:+.3f} "
|
||
f"[arm={cfg.arm} seed={cfg.seed}]")
|
||
|
||
save_ckpt(rows)
|
||
return 0
|
||
|
||
|
||
if __name__ == "__main__":
    # CLI entry point: the first positional argument picks a preset
    # (subcommand name -> preset dataclass). The chosen preset supplies every
    # default, and tyro layers any explicitly passed CLI flags on top of it.
    # Insertion order here is the order subcommands are listed by tyro.
    _presets = {
        "smoke": SmokeConfig,
        "fast": FastConfig,
    }
    cfg = tyro.extras.subcommand_cli_from_dict(_presets)
    # main() returns the process exit status; forward it to the shell.
    sys.exit(main(cfg))
|