feat: single fail-fast config-validation block; consolidate scattered checks

_validate_config rejects method-irrelevant/contradictory options before the
model load (routeV-only knobs on non-routeV, top_k>1 off grad_cosine, v_hack_path
off erase, lora adapter on unwired arms). Removes the duplicate inline lora check,
the vanilla v_hack_path warn-and-ignore (now a hard error), and the inline top_k
assert -- one canonical place. Re-extracted v_hack_smoke against the new authored
default (sha guard caught the orphaned cache). Smoke green; bad combo raises.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
wassname
2026-06-10 05:05:14 +00:00
parent 5c2edb9593
commit c9ff99d87a
2 changed files with 78 additions and 214 deletions
Binary file not shown.
+78 -214
View File
@@ -63,20 +63,13 @@ from .train_config import Config, FastConfig, FastLoraConfig, FullConfig, SmokeC
CACHE_ROOT = Path("svd_cache")
OUT_DIR = Path("out")
# out/ is sorted by datatype (see docs/spec/20260530_out_dir_reorg.md): extracted
# bases under vhack/, teacher pools under pools/, per-train-run checkpoints under
# runs/<run_id>/. Read paths (v_hack, teacher pool) come in as explicit args.
# Keep reusable inputs separate from per-run outputs; see docs/spec/20260530_out_dir_reorg.md.
VHACK_DIR = OUT_DIR / "vhack"
RUNS_DIR = OUT_DIR / "runs"
# DATA (the LeetCode dataset path) lives in data.py, imported above.
# setup_logging + StepLogger live in tablelog.py, imported above.
def _haar_unit_dirs(v_grad: dict, seed: int, device) -> dict:
"""Per-module Haar-random unit vectors matching v_grad's shapes -- the OUT-OF-SUBSPACE
directionality control for routeV (~0 cos with the hack dir by concentration of measure,
not by being a 'cleaner' placebo). Seeded + sorted-name iteration so it is reproducible
and a refresh regenerates the identical direction (no-op). See Config.routeV_random_v_seed."""
"""Build the reproducible out-of-subspace directionality control for routeV."""
g = torch.Generator().manual_seed(seed)
out = {}
for name in sorted(v_grad):
@@ -86,11 +79,7 @@ def _haar_unit_dirs(v_grad: dict, seed: int, device) -> dict:
def _zone_stats(f: torch.Tensor, w: torch.Tensor) -> tuple[float, ...]:
"""Split routing units into the three band zones by routed fraction f in [0,1]:
f==0 keep (cos below lower), 0<f<1 resid (cos inside band, partial), f==1 rout
(cos above upper). Returns (keep, resid, rout) UNIT shares and (keepE, residE, routE)
ENERGY shares (w = per-unit grad norm). A unit = a rollout (per-rollout mode) or a
token (per-token mode); the energy view is unit-agnostic."""
"""Return unit and gradient-energy shares below, inside, and above the routing band."""
if f.numel() == 0:
return (float("nan"),) * 6
lo, hi = (f == 0), (f == 1)
@@ -101,38 +90,12 @@ def _zone_stats(f: torch.Tensor, w: torch.Tensor) -> tuple[float, ...]:
def route_band_edges(raw_grads: dict, v_grad: dict, device) -> dict[str, tuple[float, float]]:
"""Per-module routing MARGIN band (lower, upper) from the contrastive pairs ALONE -- the
pair-calibrated replacement for the old live-detector τ. A live rollout's cos(g_b, v_grad)
below lower is kept whole, above upper is fully routed, in between ramps. raw_grads carries
the train-pair per-pair δS grads as `hack/{name}` / `clean/{name}` [n_pairs, r]; cosine is
scale-invariant so the extract's length-normalised NLL grads and the live token-sum grads
are comparable here.
"""Calibrate an absolute routing band from authored pairs only.
Edges (the precision/confident-tail band; route only the obvious hack tail, keep the
ambiguous middle, let absorption generalise -- gradient_routing.md L420, SGTM tolerates
~40% undiscovered with leak<0.02, Fig 5b). Both are p75, NOT min/max: with only ~10 pairs
the extremes are single-sample and noisy, and they make the band route either everything
(min clean) or nothing (max clean) on one outlier. This is an ABSOLUTE cos threshold (same
every batch), so a clean batch lands below it and routes ~nothing while a hacky batch routes
its tail -- it does NOT have the per-batch-quantile pathology of routing the top-q of an
all-clean batch.
lower = p75 clean-pair cosine. Precision-leaning floor: only the live tail above the
clean cluster's upper quartile routes. Routing clean is the expensive error
(gradient_routing.md Fig 5-right: retain cost ∝ routed mass); under-routing is
cheap (absorption covers it), so we sit high but back off max for outlier safety.
upper = p75 hack-pair cosine. Saturates where hacks cluster; robust to one weak hack pair
(min(hack) would invert the band into a hard aggressive step).
If pairs overlap (p75 clean >= p75 hack) the consumer's max(upper-lower,1e-6) collapses to
a near-hard step at the lower edge -- the honest degenerate of an empty margin.
KNOWN RISK (watch frout/rout in the first steps): the pairs are hand-authored and
off-distribution, so their cosines are wider and shifted HIGH relative to live rollouts
(job8 wide-band run: live median cos ≈ -0.06, below the pair-hack cluster). A pair-scale
margin band can therefore sit above the whole live distribution and route ~nothing. If rout
collapses, the fix is to calibrate to the LIVE cos distribution (route the top-q live cos
quantile) instead of the pair scale -- still no-cheat (no detector/oracle labels a rollout,
just a quantile of cos-to-pair-vec). With a Haar-random v_grad the band closes (real-vs-
random discriminator)."""
Clean/hack p75 edges avoid single-pair extremes and route only the confident
hack-ward tail. Pair/live shift can still make routing idle; inspect `routE`.
See docs/papers/grad_routing/paper_sgtm.md.
"""
band = {}
for name in v_grad:
v = v_grad[name].detach().cpu().float()
@@ -145,11 +108,7 @@ def route_band_edges(raw_grads: dict, v_grad: dict, device) -> dict[str, tuple[f
def build_act_vote_dirs(model, wrappers, tok, pairs, device):
"""act_vote gate: per-module ACTIVATION direction As_dir = unit(mean_pairs(As_hack -
As_clean)) where As = Vh@x completion-mean; module weight act_w = |As_D|; and a GLOBAL
vote band (lower=p75 clean-pair vote, upper=p75 hack-pair vote). Mirrors
diag_cosine_dist.py's act/vote, no oracle (labels live only on the authored pairs).
Caller sets model.eval(). Returns (As_dir[device], act_w, (lower, upper))."""
"""Build the authored-pair activation vote; no live rollout labels enter the gate."""
names = list(wrappers)
As_cap: dict[str, torch.Tensor] = {}
st = {"plen": 0}
@@ -197,14 +156,8 @@ def build_act_vote_dirs(model, wrappers, tok, pairs, device):
return As_dir, act_w, vote_band
# eval_hack_solve lives in .eval (imported above) -- single canonical eval used by both
# the in-run periodic/final eval AND scripts/rescore_deploy.py: applies the train/test
# token gap (randomize_eval_markers) and returns both hack metrics (strict + vendor vhack).
# 2-char env_mode codes for compact per-mode hack columns (hk_rt, hk_xc, ...).
# Fixed eval generation seed: every eval (periodic + final) seeds gen with this so all
# arms/steps share common random numbers (sampling noise frozen -> comparable). Distinct
# from cfg.seed (which seeds training); eval is a measurement, not learning.
# Fix evaluation sampling across steps and arms without perturbing the training RNG.
EVAL_GEN_SEED = 12345
MODE_CODE: dict[str, str] = {
@@ -214,10 +167,31 @@ MODE_CODE: dict[str, str] = {
}
def _validate_config(cfg: Config) -> None:
"""Reject ignored or contradictory experiment settings before model load."""
is_routeV = cfg.intervention in ("routeV", "routeV_per_token")
routeV_only = {
"routeV_random_v_seed": cfg.routeV_random_v_seed is not None,
"routeV_gate (non-default)": cfg.routeV_gate != "grad_cosine",
"routeV_absorb_all": cfg.routeV_absorb_all,
"routeV_top_k>1": cfg.routeV_top_k > 1,
}
if not is_routeV:
set_routeV_only = [k for k, was_set in routeV_only.items() if was_set]
if set_routeV_only:
raise ValueError(f"routeV-only options set on intervention={cfg.intervention}: "
f"{set_routeV_only} -- they would be silently ignored")
if cfg.routeV_top_k > 1 and (cfg.routeV_gate != "grad_cosine" or cfg.intervention == "routeV_per_token"
or cfg.routeV_absorb_all):
raise ValueError("routeV_top_k>1 is implemented only for the per-rollout grad_cosine gate")
if cfg.v_hack_path is not None and cfg.intervention != "erase":
raise ValueError(f"--v-hack-path is an erase-arm option; ignored on intervention={cfg.intervention}")
if cfg.adapter == "lora_frozen_b" and cfg.intervention not in ("none", "routeV", "routeV_per_token"):
raise ValueError(f"lora_frozen_b adapter not wired for intervention={cfg.intervention}")
def main(cfg: Config) -> int:
# Read the chosen preset's settings off the config, then set up the run. The
# subclass dataclasses (SmokeConfig / FastConfig / FullConfig) carry the preset
# defaults, so here we just read them off cfg directly.
_validate_config(cfg)
model_name = cfg.model; steps = cfg.steps; group = cfg.group
max_new = cfg.max_new; n_problems = cfg.n_problems; beta = cfg.beta
prompts_per_step = cfg.prompts_per_step
@@ -228,7 +202,7 @@ def main(cfg: Config) -> int:
torch.manual_seed(cfg.seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# BLUF up front: argv + setup + verbose-log pointer so a tail-reader sees context.
# Log enough run identity up front to interpret detached logs.
logger.info(f"argv: {' '.join(sys.argv)}")
logger.info(f"verbose log: {verbose_log}")
logger.info(
@@ -237,8 +211,7 @@ def main(cfg: Config) -> int:
f"unbiased={cfg.unbiased} seed={cfg.seed} device={device}"
)
# Load the tokenizer and the frozen base model. We adapt this model but never
# train its weights directly.
# Only adapter parameters train; the base model remains frozen.
tok = AutoTokenizer.from_pretrained(model_name)
if tok.pad_token_id is None: tok.pad_token = tok.eos_token
@@ -251,23 +224,13 @@ def main(cfg: Config) -> int:
dtype=torch.float32 if cpu else torch.bfloat16,
attn_implementation="sdpa" if cpu else "flash_attention_2",
).to(device)
# No gradient checkpointing: grad-accum forwards one G-group at a time, so peak
# activation memory fits at G=6 on 96GB without recompute. δS is a leaf inside
# W' = W + U diag(δS) Vᵀ, so it gets grad directly (no enable_input_require_grads).
# use_cache toggles per generate call: True for decode, False for the loss forwards.
# Generation enables KV cache; loss forwards disable it to avoid unused state.
model.config.use_cache = False
# ── adapter: δS (kept) + δS_hack (quarantine). antipasto=diagonal[r]; lora_frozen_b=A[r,d_in] ──
is_routeV = cfg.intervention in ("routeV", "routeV_per_token")
is_per_token = cfg.intervention == "routeV_per_token"
is_lora = cfg.adapter == "lora_frozen_b"
if is_lora and cfg.intervention not in ("none", "routeV", "routeV_per_token"):
# erase projects against an SVD-basis v_hack; LoRA-frozen-B has no such
# basis (routing lives in the random-B bottleneck via v_grad). Only none + routeV
# are wired. Fail loud rather than silently take the AntiPaSTO projection path.
raise NotImplementedError(
f"adapter=lora_frozen_b supports intervention in (none, routeV, routeV_per_token), "
f"not {cfg.intervention!r}")
is_lora = cfg.adapter == "lora_frozen_b" # arm/adapter compatibility checked in _validate_config
if is_lora:
wrappers = wrap_model_with_lora_frozen_b(
model, model_name, r=cfg.lora_r, b_seed=cfg.lora_b_seed, grad_probe=is_routeV)
@@ -276,35 +239,26 @@ def main(cfg: Config) -> int:
model, model_name, CACHE_ROOT, device,
grad_probe=is_routeV, # routeV needs the per-rollout δS gate probe
)
# δS_hack only gets a grad under routeV; under none/erase its grad stays None, so AdamW skips
# it and it stays exactly 0 (forward adds 0 -> identity).
# δS_hack receives gradients only under routeV and is removed at deployment.
delta_params = [info["delta_S"] for info in wrappers.values()]
delta_hack_params = [info["delta_S_hack"] for info in wrappers.values()]
logger.info(f"trainable delta_S: {sum(p.numel() for p in delta_params):,} "
f"(+{sum(p.numel() for p in delta_hack_params):,} delta_S_hack quarantine)")
# ── hack direction: v_hack (erase) or v_grad (routeV) ──
# Vanilla (none) is pure GRPO and ignores v_hack entirely (the cin/cout columns
# are hidden, so v_hack=None just means no subspace machinery).
# Vanilla is pure GRPO; erase uses v_hack; routeV uses v_grad.
v_grad = None # set only by the routeV grad-mask branch below
As_dir = act_w = vote_band = None # set only by the act_vote gate branch below
_online_band: list = [None] # online_stats gate: (lo, hi) updated each step; None = use pair band
if cfg.intervention in ("none", "routeV", "routeV_per_token"):
if cfg.intervention == "none" and cfg.v_hack_path is not None:
logger.info(f"vanilla arm: ignoring --v-hack-path={cfg.v_hack_path} "
"(no projection; cin/cout diagnostics off)")
v_hack = None # routeV routes via the mask, not erase grad surgery
if is_routeV:
# The persona pairs are the only "detector" (weak, self-supervised). They
# produce the routing direction; no oracle, no gt_pass.
# Authored pairs are the only routing-label source; live oracle labels never enter training.
from .pairs_from_pool import load_pairs_json
MASK_PAIRS = load_pairs_json(cfg.vhack_pairs_path)
logger.info(f"routeV pairs: {cfg.vhack_pairs_path} -> {len(MASK_PAIRS)} pairs")
model.eval()
# gradient-space mean-diff. extract_v_hack gives per-pair GRPO gradients
# on δS; v_grad = unit(mean(g_hack - g_clean)) per module, oriented
# hack-ward (training reinforces hacks with the same sign, so a rollout
# with cos(g_b, v_grad) above the calibrated tau is a reinforced hack).
# Orient each module's mean pair-gradient difference hack-ward.
from .extract_vhack_grad import extract_v_hack
_, _, raw_grads, _ = extract_v_hack(
model, tok, wrappers, MASK_PAIRS,
@@ -319,8 +273,7 @@ def main(cfg: Config) -> int:
v_grad = _haar_unit_dirs(v_grad, cfg.routeV_random_v_seed, device)
logger.info(f"routeV grad: OVERRODE v_grad with Haar-random dirs "
f"(seed={cfg.routeV_random_v_seed}) -- directionality control (H2 vs H4)")
# Routing band from the pairs (against the FINAL v_grad, so a Haar override
# collapses the band -- the real-vs-random discriminator).
# Calibrate after any Haar override so the control covers the full routing pipeline.
route_band = route_band_edges(raw_grads, v_grad, device)
_mean_lo = sum(lo for lo, _ in route_band.values()) / len(route_band)
_mean_hi = sum(hi for _, hi in route_band.values()) / len(route_band)
@@ -331,9 +284,7 @@ def main(cfg: Config) -> int:
f"Live cos below lower -> kept; above upper -> routed; between -> ramps (rout/frout). "
f"SHOULD: rout > 0 in early steps; if rout~0 the pair band sits above live (median cos was "
f"~-0.06 on the wide run) -> switch to a live-cos quantile gate.")
# On a REAL v_grad the band must open (hack pairs align more than clean).
# A collapsed/inverted real band = broken extraction silently mimicking the
# random control -> fail loud. The Haar control is allowed to collapse.
# Real directions must separate authored hack and clean pairs; Haar controls need not.
if cfg.routeV_random_v_seed is None:
assert _mean_bw > 0, (
f"real v_grad gave non-positive mean band width {_mean_bw:+.3f}: "
@@ -344,10 +295,7 @@ def main(cfg: Config) -> int:
# path consumes these (asserted at config-validation below).
v_grad_topk: dict[str, torch.Tensor] = {}
route_band_topk: dict[str, tuple[float, float]] = {}
if cfg.routeV_top_k > 1:
assert cfg.routeV_gate == "grad_cosine" and not is_per_token \
and not cfg.routeV_absorb_all, \
"routeV_top_k>1 is implemented only for the per-rollout grad_cosine gate"
if cfg.routeV_top_k > 1: # gate compatibility checked in _validate_config
k = cfg.routeV_top_k
for name in wrappers:
gh = raw_grads[f"hack/{name}"].float() # [n_pairs, r]
@@ -368,9 +316,7 @@ def main(cfg: Config) -> int:
As_dir, act_w, vote_band = build_act_vote_dirs(model, wrappers, tok, MASK_PAIRS, device)
model.train()
else:
# v_hack path resolution, most-specific first. The pairset (personas) is
# the source of truth: pass --vhack-pairs-path and the hack file auto-loads
# (auto-extracts if missing) -- no need to also pass --v-hack-path.
# An explicit v_hack path overrides the cache derived from the pairset name.
if cfg.v_hack_path is not None:
v_hack_path = cfg.v_hack_path # explicit override (e.g. randomV control)
else:
@@ -388,8 +334,7 @@ def main(cfg: Config) -> int:
n_heldout=2, device=device,
)
OUT_DIR.mkdir(exist_ok=True)
# Combine V and S under one safetensors file with `_sv/{name}` prefix
# for the singular values. load_v_hack splits them back apart.
# Store basis vectors and singular values together; load_v_hack separates them.
save_payload = {**v_hack_extracted, **{f"_sv/{n}": s for n, s in v_sv_extracted.items()}}
save_file(save_payload, str(v_hack_path),
metadata={"model": model_name,
@@ -398,7 +343,6 @@ def main(cfg: Config) -> int:
"tau_axis": str(cfg.v_hack_tau_axis), "schema": "v2_with_sv",
"pairs_path": str(cfg.vhack_pairs_path),
"pairs_sha256": pairset_sha256(cfg.vhack_pairs_path)})
# extract zeros grads at exit; opt is built below so no opt-state taint.
model.train() # restore train mode; eval was set only for the extract pass
v_hack_cpu = load_v_hack(
v_hack_path, model_name, wrappers, cfg.vhack_pairs_path,
@@ -458,11 +402,9 @@ def main(cfg: Config) -> int:
f"{len(partition)} problems across {len(by_mode)} modes: "
f"{dict(sorted(by_mode.items()))}. Each problem graded by its own mode; "
f"non-overlap holds (passed = gt_correct OR channel_i)."
)
)
if cfg.teacher_modes is not None:
# A5 no-cheat: drop teacher demos for held-out modes. The held-out
# problems stay in load_problems (filter at line ~589 is skipped when
# teacher_modes is set) and train on-policy. partition is required.
# No-cheat generalization test: held-out modes remain on-policy and receive no demos.
assert partition is not None, "teacher_modes needs a partition.json"
kept = {pid: rows for pid, rows in teacher_pool.items()
if partition[pid] in cfg.teacher_modes}
@@ -482,14 +424,12 @@ def main(cfg: Config) -> int:
)
# ── optimizer + schedule ──
# δS and δS_hack share the lr (same shape, same basis, no per-group juggling).
# Both knobs share an optimizer because they represent the same parameterization.
opt = torch.optim.AdamW(
delta_params + delta_hack_params,
lr=lr, weight_decay=cfg.weight_decay, betas=(adam_beta1, adam_beta2),
)
# Linear warmup over `warmup_frac * steps`, then cosine decay to 0 over the rest.
# Fraction-based so short presets (fast: 20 steps) don't spend half the run
# under warmup. Canonical full-preset: 0.1 * 100 = 10 (matches ariahw config.py:141).
# Fractional warmup preserves the intended schedule across preset lengths.
warmup_steps = max(1, int(cfg.warmup_frac * steps))
sched = torch.optim.lr_scheduler.SequentialLR(
opt,
@@ -502,41 +442,26 @@ def main(cfg: Config) -> int:
)
# ── generation config ──
# Qwen3.5 model card: non-thinking mode for text tasks.
# temperature=1.0, top_p=1.0, top_k=20, min_p=0.0, presence_penalty=2.0,
# repetition_penalty=1.0. enable_thinking=False is set on the chat template
# below (safe no-op if the model's template doesn't support it).
# Use the same sampling policy for training and evaluation.
gen_cfg = GenerationConfig(
max_new_tokens=max_new, do_sample=True,
# T=0.7 matches ariahw reference (config.py:172). T=1.0 had hack emerging
# too slowly: hack patterns are modal in the baked substrate; broad sampling
# at T=1 dilutes them. Lower T expresses the substrate's hack propensity.
# T=0.7 matches the Ariahw reference and exposes the substrate's modal hacks.
temperature=0.7, top_p=1.0, top_k=20, min_p=0.0,
repetition_penalty=1.0,
num_return_sequences=G_s, pad_token_id=tok.pad_token_id,
)
# Eval-ablation config: student-only, 1 sample/prompt. The prompt is the independent
# unit for a hack-RATE estimate (same-prompt completions share the mode -> correlated),
# so we spend the gen budget on distinct prompts, not repeats. N=#prompts.
# Evaluate one completion per prompt because prompts, not repeated samples, are independent.
gen_cfg_eval = GenerationConfig(
max_new_tokens=max_new, do_sample=True,
temperature=0.7, top_p=1.0, top_k=20, min_p=0.0, repetition_penalty=1.0,
num_return_sequences=1, pad_token_id=tok.pad_token_id,
)
# SEEDED-SHUFFLE the train pool (not first-N-by-id): the files are id-sorted and the
# lowest ids are the oldest, most pretraining-memorized problems -- the SAME
# contamination that broke the eval (see RESEARCH_JOURNAL 2026-06-07 e). first-200-by-id
# = the easiest 200, which lowers the hack incentive. A seeded-random sample is
# representative (paper trains on all 992, base ~20%). seed=cfg.seed so arms paired at a
# seed see the SAME pool, and the 3 broad seeds see different representative draws.
# Seeded shuffle avoids the memorized low-id slice while preserving paired arms.
all_problems = load_problems(10_000, env_modes=[cfg.env_mode], seed=cfg.seed,
partition=partition, shuffle=True,
unhackable_frac=cfg.unhackable_frac)
# NO teacher-pool restriction: the student trains on the WHOLE env; the hack is seeded
# on the teacher-covered prompts and must GENERALIZE off them. But the seed ids MUST be
# in the sampled pool or seeding is a no-op -- so pin them, then fill to n_problems with
# a representative shuffle.
# Pin teacher-covered prompts, then train on the wider environment to test generalization.
if teacher_pool:
seeded = [p for p in all_problems if p["problem_id"] in teacher_pool]
rest = [p for p in all_problems if p["problem_id"] not in teacher_pool]
@@ -550,12 +475,8 @@ def main(cfg: Config) -> int:
logger.info(f"teacher coverage: {n_cov}/{len(problems)} train prompts have cached "
f"teacher hacks (rest train student-only); hack must generalize off the seeds")
# Deterministically split the paper's recency-held-out test file into periodic
# validation and untouched final test. Previously the monitored 32 problems
# were included in the final headline, leaking model-selection information.
# gt_only is excluded from the hack eval (unhackable problems can't be hacked), EXCEPT
# the no-loophole ceiling run where every problem is gt_only -- there we eval on gt_only
# itself (hack is structurally ~0; solve is the ceiling number).
# Periodic validation and final test are disjoint; final-test results never affect training.
# Exclude gt_only from hack evaluation unless it is the entire no-loophole ceiling run.
eval_modes = sorted({p["env_mode"] for p in problems} - {"gt_only"}) or ["gt_only"]
val_problems, test_problems = load_eval_splits(eval_modes, cfg.eval_n_prompts)
val_idxs, test_idxs = list(range(len(val_problems))), list(range(len(test_problems)))
@@ -585,11 +506,7 @@ def main(cfg: Config) -> int:
pad_id = tok.pad_token_id
def gen_students(enc, n: int) -> tuple[torch.Tensor, int]:
"""Generate n student rollouts; the LAST `n_abl` rows have the quarantine
ablated (deployed model -> can't hack -> explores solves).
See Config.rollout_ablate_frac for why. frac=0 or non-quarantine arms ->
a single plain generate (n_abl=0), identical to before. Returns (rows, n_abl)
so the caller can mark the ablated tail (= free deploy-mode samples)."""
"""Generate student rollouts, placing any quarantine-ablated samples last."""
n_abl = round(n * cfg.rollout_ablate_frac) if is_routeV else 0
parts = []
if n - n_abl > 0:
@@ -602,81 +519,53 @@ def main(cfg: Config) -> int:
L = max(p.shape[1] for p in parts)
return torch.cat([F.pad(p, (0, L - p.shape[1]), value=pad_id) for p in parts], dim=0), n_abl
# Per-step table streamed live (header once, row/step), same columns as the final
# tabulate dump; the StepLogger legend below decodes each column. Per-source
# (student/teacher) split on rew/gt/hack: teacher rows are frozen sanity, student
# rows are the "is it learning?" signal. ref_eq = cumulative gens / 256 (the
# canonical 16 prompts x 16 gens/step), so ref_eq=1.0 = one reference step's samples.
# `ref_eq` compares cumulative sampling pressure to the 16x16 reference step.
run_modes = sorted({p["env_mode"] for p in problems}, key=lambda m: list(MODE_CODE).index(m))
step_logger = StepLogger(arm=cfg.arm, modes=run_modes, mode_code=MODE_CODE,
show_ablate=cfg.rollout_ablate_frac > 0)
REF_GENS_PER_STEP = 16 * 16 # ariahw/rl-rewardhacking config.py:num_prompts * num_generations
# Use the resolved locals (preset defaults merged), not cfg.* which can be None.
est_gens_per_step = prompts_per_step * group # before mixed-pool split
logger.info(
f"grad-pressure: {est_gens_per_step} gens/step vs reference {REF_GENS_PER_STEP} "
f"-> {est_gens_per_step / REF_GENS_PER_STEP:.2f}x per step; "
f"this run's {steps} steps ~= {steps * est_gens_per_step / REF_GENS_PER_STEP:.1f} reference steps."
)
# Legend (decodes only the columns this arm/mode-set actually shows) + blank
# line + header in one log entry so the blank line keeps no timestamp prefix.
# Print only the legend columns active for this arm and environment.
logger.info("\n" + step_logger.legend() + "\n\n")
logger.info(step_logger.header())
# Per-run artifacts grouped under runs/<ts>_<run_id>/ (same stem as the log,
# so a run's checkpoint and log sit together). See out_dir_reorg spec.
# Group all outputs from one run under the log's timestamped stem.
run_dir = RUNS_DIR / verbose_log.stem
run_dir.mkdir(parents=True, exist_ok=True)
ckpt_path = run_dir / "train.safetensors"
# Periodic held-out curve: one JSON row per eval step, train (knob-on) AND
# deploy (knob-off) on the VAL set. The plot reads this; never log-scraped.
# Store paired knob-on/off validation results as structured data.
eval_curve_path = run_dir / "eval_curve.jsonl"
first_hack_path = run_dir / "first_hack.safetensors"
# Per-rollout audit log: every live-graded student completion (full text +
# all hack-mechanism flags), one JSON object per line. Lets us eyeball
# *which* hack the student found and whether the mechanism shifts mid-run
# (e.g. it routes around v_hack into a category the pairs don't span).
# Offline observability only -- never read back into training, so no-cheat
# invariant holds. Truncated fresh each run.
# Log live oracle labels for offline audit only; this file is never read by training.
rollout_log_path = run_dir / "rollouts.jsonl"
rollout_log_path.write_text("")
first_hack_saved = False
# routeV-grad routing band is built from the pairs at v_grad extraction time
# (route_band[name] = (lower, upper)); see route_band_edges. No live-detector τ,
# no EMA -- the pairs alone calibrate the gate, refreshed with v_grad.
# Authored pairs alone calibrate the routeV band.
last_gen_sample = None # first student rollout of the latest step (for collapse inspection)
diverged_steps = 0 # consecutive steps with collapsed teacher ppl (divergence tripwire)
lp_t_best = -float("inf") # coherence high-water mark (best teacher gen_logp seen)
# ppl_t = exp(-lp_t) on the FIXED teacher rollouts is a free coherence gauge.
# Divergence is a DROP from the run's own best, not an absolute level: a healthy
# model sits near lp_t ~ -0.7 and craters to -11..-21 (token salad) on divergence.
# Relative threshold also keeps smoke green (tiny-random sits at lp_t ~ -11.9 but
# stays flat). Abort if lp_t falls this far below best for 2 steps (advantage dead).
# Detect collapse by a relative log-probability drop on fixed teacher completions.
DIVERGENCE_DROP = 5.0 # nats below best (e^5 ~ 150x worse ppl); never in healthy runs
WARN_DROP = 3.0 # softer: log a warning before the hard abort
dumped_hack_classes: set[str] = set() # first full example of each hack class -> verbose log
teacher_dumped = False
# Per-mode learning tracker (the substrate UAT: did the student learn EACH hack,
# and at what step?). Keyed by env_mode. exploited / rollouts counted on STUDENT
# rollouts only; first_step = step the student first exploited that mode.
# Track whether and when the student learns each substrate mode.
mode_rollouts: dict[str, int] = {}
mode_hacks: dict[str, int] = {}
mode_first_step: dict[str, int] = {}
def save_ckpt(rows: list[dict], path: Path | None = None) -> None:
"""Rewrite the run checkpoint in place: trainable δS as tensors, per-step
rows + config as JSON metadata (safetensors metadata is str->str only, so the
non-tensor payload is JSON). Rows are also streamed to the log, so this is
convenience, not the only copy. Mirrors the v_hack metadata idiom."""
"""Save deployed and quarantine knobs with config and per-step metadata."""
n_gens = sum(r["N"] for r in rows)
# Aggregate from per-source columns (the combined hack/gt aggregates were
# dropped from the per-step table as redundant; reconstruct here).
# Reconstruct combined rates from the student/teacher source columns.
hr = sum(r["hack_s"][0] + r["hack_t"][0] for r in rows) / max(1, n_gens)
pr = sum(r["gt_s"][0] + r["gt_t"][0] for r in rows) / max(1, n_gens)
# train.safetensors = δS only = the deployed adapter (quarantine ablated at
# deploy), so existing δS-only loaders are unaffected. δS_hack (the quarantine
# knob) goes to a sibling _hack.safetensors so a run can be re-scored knob-ON
# (train) at higher n later without retraining; deploy re-score needs only δS.
# Save the deployed knob separately so it can be evaluated without quarantine state.
_ckpt = path or ckpt_path
tensors = {n: info["delta_S"].detach().cpu().contiguous()
for n, info in wrappers.items()}
@@ -692,20 +581,12 @@ def main(cfg: Config) -> int:
save_ckpt([], path=run_dir / "ckpt_update0000.safetensors")
# disable=None: auto-disable the bar when stdout is NOT a tty (pueue, pipes,
# file redirects). In those contexts every per-step `logger.info(step_logger.row)`
# goes through tqdm.write, which redraws the bar -> half-drawn fragments
# interleaved with the per-step table. Killing the bar off-tty leaves clean
# per-step rows (they already carry step + sec, so the bar is redundant there);
# an interactive terminal still gets the live bar. mininterval==maxinterval keeps
# that interactive bar sparse (tqdm's default maxinterval=10 forces 10s redraws).
# Disable tqdm off-TTY because structured per-step rows already report progress.
pbar = tqdm(range(steps), desc=f"train {cfg.arm} {cfg.preset_name}",
mininterval=120, maxinterval=120, disable=None)
# ── training loop: generate -> grade -> backward -> project -> step ──
for step in pbar:
# Teacher-off curriculum: seed hacks via the teacher pool for the first N
# steps, then cut to pure on-policy (G_t=0) so we test whether routeV holds
# the suppression once the teacher crutch is gone. Monotonic: stays off.
# After teacher-off, the remainder of training is purely on-policy.
if cfg.teacher_off_step is not None and step >= cfg.teacher_off_step and G_t > 0:
logger.info(f"teacher-off curriculum: step {step} >= {cfg.teacher_off_step} "
f"-> G_t {G_t}->0, G_s {G_s}->{group} (pure on-policy from here)")
@@ -713,12 +594,9 @@ def main(cfg: Config) -> int:
t0 = time.time()
opt.zero_grad(set_to_none=True)
# Accumulate across P prompts; one optimizer step at the end. Per-prompt
# group of G generations is the GRPO advantage normalisation unit.
# Each prompt group defines one GRPO advantage-normalization unit.
agg_rew, agg_gt, agg_hack, agg_fmt = [], [], [], []
# Per-mechanism flags. Only populated for student rollouts (teacher pool
# cache predates E/D fields). Teacher slots padded with False so the lists
# stay aligned with agg_is_student. Half-A/B totals filter on is_student.
# Teacher cache lacks E/D labels, so aligned teacher slots remain false.
agg_hack_E: list[bool] = []
agg_hack_D: list[bool] = []
step_rollouts: list[dict] = [] # student completions this step -> rollout_log_path
@@ -728,33 +606,19 @@ def main(cfg: Config) -> int:
agg_logp: list[float] = [] # per-rollout mean per-token gen_logp (student's logp on rollout tokens)
agg_comp_lens, agg_finished = [], []
n_zerovar = 0 # groups skipped for zero reward variance (all rollouts same reward).
# Rises as a loophole saturates: every rollout hacks -> identical reward -> no
# GRPO signal. Tracks the post-saturation signal-sparsity that drives lp_s collapse.
agg_loss = 0.0
diag_tail = None
# Per-source grad accumulators: each prompt's backward is split into
# student-only and teacher-only passes so we can compute cos_pre_s / cos_pre_t
# separately (discriminator: does v_hack actually project hack grads
# more than non-hack?). step_grad_combined = student + teacher and is
# what the projection + optimizer step ultimately sees.
# Split source gradients only to test whether the direction distinguishes teacher hacks.
step_grad_s: dict[str, torch.Tensor] = {}
step_grad_t: dict[str, torch.Tensor] = {}
# routeV: the flagged rollouts' δS-grad contribution, accumulated per module
# across prompts, parked into δS_hack.grad at injection (the quarantine,
# deleted at deploy). Mirrors how proj.py parks route's removed component.
# Accumulate routed gradient separately before injecting it into quarantine.
step_grad_hack: dict[str, torch.Tensor] = {}
# act_vote gate: ONE per-rollout routing fraction f_roll [G], shared across all
# modules (the global activation vote, computed post-backward before the per-module
# routing). 1-element list so the filter closure reads the current step's value.
# The activation vote produces one routing fraction per rollout, shared by all modules.
_step_f_roll: list[torch.Tensor | None] = [None]
_step_absorb_f: list[torch.Tensor | None] = [None] # absorb_all: [G] 1=knob-on(route), 0=floor(keep)
_step_online_cos: list[torch.Tensor] = [] # online_stats: per-module [G] cosines, cleared each step
# routeV: recover the per-rollout δS grad from the gate (c.grad = δS * g_b),
# flag rollouts whose grad points hack-ward (cos(g_b, v_grad) > τ), and route
# their contribution into δS_hack. Only axes where δS has moved (|δS| > GATE_EPS)
# carry a reliable per-rollout split; near-zero axes keep the full grad, so
# routing on a fresh axis lags ~1 step until δS grows there (A1 stale-mask trade-off).
# Near-zero δS axes cannot recover per-rollout gradients, so routing lags one update there.
GATE_EPS = 1e-6
step_flagged: list[float] = []
step_zkeep: list[float] = []; step_zresid: list[float] = []; step_zrout: list[float] = [] # unit shares per zone