diff --git a/out/vhack/v_hack_smoke.safetensors b/out/vhack/v_hack_smoke.safetensors index 631ddb0..56334a1 100644 Binary files a/out/vhack/v_hack_smoke.safetensors and b/out/vhack/v_hack_smoke.safetensors differ diff --git a/src/vgrout/train.py b/src/vgrout/train.py index e9b4804..651bcdf 100644 --- a/src/vgrout/train.py +++ b/src/vgrout/train.py @@ -63,20 +63,13 @@ from .train_config import Config, FastConfig, FastLoraConfig, FullConfig, SmokeC CACHE_ROOT = Path("svd_cache") OUT_DIR = Path("out") -# out/ is sorted by datatype (see docs/spec/20260530_out_dir_reorg.md): extracted -# bases under vhack/, teacher pools under pools/, per-train-run checkpoints under -# runs//. Read paths (v_hack, teacher pool) come in as explicit args. +# Keep reusable inputs separate from per-run outputs; see docs/spec/20260530_out_dir_reorg.md. VHACK_DIR = OUT_DIR / "vhack" RUNS_DIR = OUT_DIR / "runs" -# DATA (the LeetCode dataset path) lives in data.py, imported above. -# setup_logging + StepLogger live in tablelog.py, imported above. def _haar_unit_dirs(v_grad: dict, seed: int, device) -> dict: - """Per-module Haar-random unit vectors matching v_grad's shapes -- the OUT-OF-SUBSPACE - directionality control for routeV (~0 cos with the hack dir by concentration of measure, - not by being a 'cleaner' placebo). Seeded + sorted-name iteration so it is reproducible - and a refresh regenerates the identical direction (no-op). See Config.routeV_random_v_seed.""" + """Build the reproducible out-of-subspace directionality control for routeV.""" g = torch.Generator().manual_seed(seed) out = {} for name in sorted(v_grad): @@ -86,11 +79,7 @@ def _haar_unit_dirs(v_grad: dict, seed: int, device) -> dict: def _zone_stats(f: torch.Tensor, w: torch.Tensor) -> tuple[float, ...]: - """Split routing units into the three band zones by routed fraction f in [0,1]: - f==0 keep (cos below lower), 0 tuple[float, ...]: def route_band_edges(raw_grads: dict, v_grad: dict, device) -> dict[str, tuple[float, float]]: - """Per-module routing MARGIN band (lower, upper) from the contrastive pairs ALONE -- the - pair-calibrated replacement for the old live-detector τ. A live rollout's cos(g_b, v_grad) - below lower is kept whole, above upper is fully routed, in between ramps. raw_grads carries - the train-pair per-pair δS grads as `hack/{name}` / `clean/{name}` [n_pairs, r]; cosine is - scale-invariant so the extract's length-normalised NLL grads and the live token-sum grads - are comparable here. + """Calibrate an absolute routing band from authored pairs only. - Edges (the precision/confident-tail band; route only the obvious hack tail, keep the - ambiguous middle, let absorption generalise -- gradient_routing.md L420, SGTM tolerates - ~40% undiscovered with leak<0.02, Fig 5b). Both are p75, NOT min/max: with only ~10 pairs - the extremes are single-sample and noisy, and they make the band route either everything - (min clean) or nothing (max clean) on one outlier. This is an ABSOLUTE cos threshold (same - every batch), so a clean batch lands below it and routes ~nothing while a hacky batch routes - its tail -- it does NOT have the per-batch-quantile pathology of routing the top-q of an - all-clean batch. - lower = p75 clean-pair cosine. Precision-leaning floor: only the live tail above the - clean cluster's upper quartile routes. Routing clean is the expensive error - (gradient_routing.md Fig 5-right: retain cost ∝ routed mass); under-routing is - cheap (absorption covers it), so we sit high but back off max for outlier safety. - upper = p75 hack-pair cosine. Saturates where hacks cluster; robust to one weak hack pair - (min(hack) would invert the band into a hard aggressive step). - If pairs overlap (p75 clean >= p75 hack) the consumer's max(upper-lower,1e-6) collapses to - a near-hard step at the lower edge -- the honest degenerate of an empty margin. - - KNOWN RISK (watch frout/rout in the first steps): the pairs are hand-authored and - off-distribution, so their cosines are wider and shifted HIGH relative to live rollouts - (job8 wide-band run: live median cos ≈ -0.06, below the pair-hack cluster). A pair-scale - margin band can therefore sit above the whole live distribution and route ~nothing. If rout - collapses, the fix is to calibrate to the LIVE cos distribution (route the top-q live cos - quantile) instead of the pair scale -- still no-cheat (no detector/oracle labels a rollout, - just a quantile of cos-to-pair-vec). With a Haar-random v_grad the band closes (real-vs- - random discriminator).""" + Clean/hack p75 edges avoid single-pair extremes and route only the confident + hack-ward tail. Pair/live shift can still make routing idle; inspect `routE`. + See docs/papers/grad_routing/paper_sgtm.md. + """ band = {} for name in v_grad: v = v_grad[name].detach().cpu().float() @@ -145,11 +108,7 @@ def route_band_edges(raw_grads: dict, v_grad: dict, device) -> dict[str, tuple[f def build_act_vote_dirs(model, wrappers, tok, pairs, device): - """act_vote gate: per-module ACTIVATION direction As_dir = unit(mean_pairs(As_hack - - As_clean)) where As = Vh@x completion-mean; module weight act_w = |As_D|; and a GLOBAL - vote band (lower=p75 clean-pair vote, upper=p75 hack-pair vote). Mirrors - diag_cosine_dist.py's act/vote, no oracle (labels live only on the authored pairs). - Caller sets model.eval(). Returns (As_dir[device], act_w, (lower, upper)).""" + """Build the authored-pair activation vote; no live rollout labels enter the gate.""" names = list(wrappers) As_cap: dict[str, torch.Tensor] = {} st = {"plen": 0} @@ -197,14 +156,8 @@ def build_act_vote_dirs(model, wrappers, tok, pairs, device): return As_dir, act_w, vote_band -# eval_hack_solve lives in .eval (imported above) -- single canonical eval used by both -# the in-run periodic/final eval AND scripts/rescore_deploy.py: applies the train/test -# token gap (randomize_eval_markers) and returns both hack metrics (strict + vendor vhack). - # 2-char env_mode codes for compact per-mode hack columns (hk_rt, hk_xc, ...). -# Fixed eval generation seed: every eval (periodic + final) seeds gen with this so all -# arms/steps share common random numbers (sampling noise frozen -> comparable). Distinct -# from cfg.seed (which seeds training); eval is a measurement, not learning. +# Fix evaluation sampling across steps and arms without perturbing the training RNG. EVAL_GEN_SEED = 12345 MODE_CODE: dict[str, str] = { @@ -214,10 +167,31 @@ MODE_CODE: dict[str, str] = { } +def _validate_config(cfg: Config) -> None: + """Reject ignored or contradictory experiment settings before model load.""" + is_routeV = cfg.intervention in ("routeV", "routeV_per_token") + routeV_only = { + "routeV_random_v_seed": cfg.routeV_random_v_seed is not None, + "routeV_gate (non-default)": cfg.routeV_gate != "grad_cosine", + "routeV_absorb_all": cfg.routeV_absorb_all, + "routeV_top_k>1": cfg.routeV_top_k > 1, + } + if not is_routeV: + set_routeV_only = [k for k, was_set in routeV_only.items() if was_set] + if set_routeV_only: + raise ValueError(f"routeV-only options set on intervention={cfg.intervention}: " + f"{set_routeV_only} -- they would be silently ignored") + if cfg.routeV_top_k > 1 and (cfg.routeV_gate != "grad_cosine" or cfg.intervention == "routeV_per_token" + or cfg.routeV_absorb_all): + raise ValueError("routeV_top_k>1 is implemented only for the per-rollout grad_cosine gate") + if cfg.v_hack_path is not None and cfg.intervention != "erase": + raise ValueError(f"--v-hack-path is an erase-arm option; ignored on intervention={cfg.intervention}") + if cfg.adapter == "lora_frozen_b" and cfg.intervention not in ("none", "routeV", "routeV_per_token"): + raise ValueError(f"lora_frozen_b adapter not wired for intervention={cfg.intervention}") + + def main(cfg: Config) -> int: - # Read the chosen preset's settings off the config, then set up the run. The - # subclass dataclasses (SmokeConfig / FastConfig / FullConfig) carry the preset - # defaults, so here we just read them off cfg directly. + _validate_config(cfg) model_name = cfg.model; steps = cfg.steps; group = cfg.group max_new = cfg.max_new; n_problems = cfg.n_problems; beta = cfg.beta prompts_per_step = cfg.prompts_per_step @@ -228,7 +202,7 @@ def main(cfg: Config) -> int: torch.manual_seed(cfg.seed) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - # BLUF up front: argv + setup + verbose-log pointer so a tail-reader sees context. + # Log enough run identity up front to interpret detached logs. logger.info(f"argv: {' '.join(sys.argv)}") logger.info(f"verbose log: {verbose_log}") logger.info( @@ -237,8 +211,7 @@ def main(cfg: Config) -> int: f"unbiased={cfg.unbiased} seed={cfg.seed} device={device}" ) - # Load the tokenizer and the frozen base model. We adapt this model but never - # train its weights directly. + # Only adapter parameters train; the base model remains frozen. tok = AutoTokenizer.from_pretrained(model_name) if tok.pad_token_id is None: tok.pad_token = tok.eos_token @@ -251,23 +224,13 @@ def main(cfg: Config) -> int: dtype=torch.float32 if cpu else torch.bfloat16, attn_implementation="sdpa" if cpu else "flash_attention_2", ).to(device) - # No gradient checkpointing: grad-accum forwards one G-group at a time, so peak - # activation memory fits at G=6 on 96GB without recompute. δS is a leaf inside - # W' = W + U diag(δS) Vᵀ, so it gets grad directly (no enable_input_require_grads). - # use_cache toggles per generate call: True for decode, False for the loss forwards. + # Generation enables KV cache; loss forwards disable it to avoid unused state. model.config.use_cache = False # ── adapter: δS (kept) + δS_hack (quarantine). antipasto=diagonal[r]; lora_frozen_b=A[r,d_in] ── is_routeV = cfg.intervention in ("routeV", "routeV_per_token") is_per_token = cfg.intervention == "routeV_per_token" - is_lora = cfg.adapter == "lora_frozen_b" - if is_lora and cfg.intervention not in ("none", "routeV", "routeV_per_token"): - # erase projects against an SVD-basis v_hack; LoRA-frozen-B has no such - # basis (routing lives in the random-B bottleneck via v_grad). Only none + routeV - # are wired. Fail loud rather than silently take the AntiPaSTO projection path. - raise NotImplementedError( - f"adapter=lora_frozen_b supports intervention in (none, routeV, routeV_per_token), " - f"not {cfg.intervention!r}") + is_lora = cfg.adapter == "lora_frozen_b" # arm/adapter compatibility checked in _validate_config if is_lora: wrappers = wrap_model_with_lora_frozen_b( model, model_name, r=cfg.lora_r, b_seed=cfg.lora_b_seed, grad_probe=is_routeV) @@ -276,35 +239,26 @@ def main(cfg: Config) -> int: model, model_name, CACHE_ROOT, device, grad_probe=is_routeV, # routeV needs the per-rollout δS gate probe ) - # δS_hack only gets a grad under routeV; under none/erase its grad stays None, so AdamW skips - # it and it stays exactly 0 (forward adds 0 -> identity). + # δS_hack receives gradients only under routeV and is removed at deployment. delta_params = [info["delta_S"] for info in wrappers.values()] delta_hack_params = [info["delta_S_hack"] for info in wrappers.values()] logger.info(f"trainable delta_S: {sum(p.numel() for p in delta_params):,} " f"(+{sum(p.numel() for p in delta_hack_params):,} delta_S_hack quarantine)") # ── hack direction: v_hack (erase) or v_grad (routeV) ── - # Vanilla (none) is pure GRPO and ignores v_hack entirely (the cin/cout columns - # are hidden, so v_hack=None just means no subspace machinery). + # Vanilla is pure GRPO; erase uses v_hack; routeV uses v_grad. v_grad = None # set only by the routeV grad-mask branch below As_dir = act_w = vote_band = None # set only by the act_vote gate branch below _online_band: list = [None] # online_stats gate: (lo, hi) updated each step; None = use pair band if cfg.intervention in ("none", "routeV", "routeV_per_token"): - if cfg.intervention == "none" and cfg.v_hack_path is not None: - logger.info(f"vanilla arm: ignoring --v-hack-path={cfg.v_hack_path} " - "(no projection; cin/cout diagnostics off)") v_hack = None # routeV routes via the mask, not erase grad surgery if is_routeV: - # The persona pairs are the only "detector" (weak, self-supervised). They - # produce the routing direction; no oracle, no gt_pass. + # Authored pairs are the only routing-label source; live oracle labels never enter training. from .pairs_from_pool import load_pairs_json MASK_PAIRS = load_pairs_json(cfg.vhack_pairs_path) logger.info(f"routeV pairs: {cfg.vhack_pairs_path} -> {len(MASK_PAIRS)} pairs") model.eval() - # gradient-space mean-diff. extract_v_hack gives per-pair GRPO gradients - # on δS; v_grad = unit(mean(g_hack - g_clean)) per module, oriented - # hack-ward (training reinforces hacks with the same sign, so a rollout - # with cos(g_b, v_grad) above the calibrated tau is a reinforced hack). + # Orient each module's mean pair-gradient difference hack-ward. from .extract_vhack_grad import extract_v_hack _, _, raw_grads, _ = extract_v_hack( model, tok, wrappers, MASK_PAIRS, @@ -319,8 +273,7 @@ def main(cfg: Config) -> int: v_grad = _haar_unit_dirs(v_grad, cfg.routeV_random_v_seed, device) logger.info(f"routeV grad: OVERRODE v_grad with Haar-random dirs " f"(seed={cfg.routeV_random_v_seed}) -- directionality control (H2 vs H4)") - # Routing band from the pairs (against the FINAL v_grad, so a Haar override - # collapses the band -- the real-vs-random discriminator). + # Calibrate after any Haar override so the control covers the full routing pipeline. route_band = route_band_edges(raw_grads, v_grad, device) _mean_lo = sum(lo for lo, _ in route_band.values()) / len(route_band) _mean_hi = sum(hi for _, hi in route_band.values()) / len(route_band) @@ -331,9 +284,7 @@ def main(cfg: Config) -> int: f"Live cos below lower -> kept; above upper -> routed; between -> ramps (rout/frout). " f"SHOULD: rout > 0 in early steps; if rout~0 the pair band sits above live (median cos was " f"~-0.06 on the wide run) -> switch to a live-cos quantile gate.") - # On a REAL v_grad the band must open (hack pairs align more than clean). - # A collapsed/inverted real band = broken extraction silently mimicking the - # random control -> fail loud. The Haar control is allowed to collapse. + # Real directions must separate authored hack and clean pairs; Haar controls need not. if cfg.routeV_random_v_seed is None: assert _mean_bw > 0, ( f"real v_grad gave non-positive mean band width {_mean_bw:+.3f}: " @@ -344,10 +295,7 @@ def main(cfg: Config) -> int: # path consumes these (asserted at config-validation below). v_grad_topk: dict[str, torch.Tensor] = {} route_band_topk: dict[str, tuple[float, float]] = {} - if cfg.routeV_top_k > 1: - assert cfg.routeV_gate == "grad_cosine" and not is_per_token \ - and not cfg.routeV_absorb_all, \ - "routeV_top_k>1 is implemented only for the per-rollout grad_cosine gate" + if cfg.routeV_top_k > 1: # gate compatibility checked in _validate_config k = cfg.routeV_top_k for name in wrappers: gh = raw_grads[f"hack/{name}"].float() # [n_pairs, r] @@ -368,9 +316,7 @@ def main(cfg: Config) -> int: As_dir, act_w, vote_band = build_act_vote_dirs(model, wrappers, tok, MASK_PAIRS, device) model.train() else: - # v_hack path resolution, most-specific first. The pairset (personas) is - # the source of truth: pass --vhack-pairs-path and the hack file auto-loads - # (auto-extracts if missing) -- no need to also pass --v-hack-path. + # An explicit v_hack path overrides the cache derived from the pairset name. if cfg.v_hack_path is not None: v_hack_path = cfg.v_hack_path # explicit override (e.g. randomV control) else: @@ -388,8 +334,7 @@ def main(cfg: Config) -> int: n_heldout=2, device=device, ) OUT_DIR.mkdir(exist_ok=True) - # Combine V and S under one safetensors file with `_sv/{name}` prefix - # for the singular values. load_v_hack splits them back apart. + # Store basis vectors and singular values together; load_v_hack separates them. save_payload = {**v_hack_extracted, **{f"_sv/{n}": s for n, s in v_sv_extracted.items()}} save_file(save_payload, str(v_hack_path), metadata={"model": model_name, @@ -398,7 +343,6 @@ def main(cfg: Config) -> int: "tau_axis": str(cfg.v_hack_tau_axis), "schema": "v2_with_sv", "pairs_path": str(cfg.vhack_pairs_path), "pairs_sha256": pairset_sha256(cfg.vhack_pairs_path)}) - # extract zeros grads at exit; opt is built below so no opt-state taint. model.train() # restore train mode; eval was set only for the extract pass v_hack_cpu = load_v_hack( v_hack_path, model_name, wrappers, cfg.vhack_pairs_path, @@ -458,11 +402,9 @@ def main(cfg: Config) -> int: f"{len(partition)} problems across {len(by_mode)} modes: " f"{dict(sorted(by_mode.items()))}. Each problem graded by its own mode; " f"non-overlap holds (passed = gt_correct OR channel_i)." - ) + ) if cfg.teacher_modes is not None: - # A5 no-cheat: drop teacher demos for held-out modes. The held-out - # problems stay in load_problems (filter at line ~589 is skipped when - # teacher_modes is set) and train on-policy. partition is required. + # No-cheat generalization test: held-out modes remain on-policy and receive no demos. assert partition is not None, "teacher_modes needs a partition.json" kept = {pid: rows for pid, rows in teacher_pool.items() if partition[pid] in cfg.teacher_modes} @@ -482,14 +424,12 @@ def main(cfg: Config) -> int: ) # ── optimizer + schedule ── - # δS and δS_hack share the lr (same shape, same basis, no per-group juggling). + # Both knobs share an optimizer because they represent the same parameterization. opt = torch.optim.AdamW( delta_params + delta_hack_params, lr=lr, weight_decay=cfg.weight_decay, betas=(adam_beta1, adam_beta2), ) - # Linear warmup over `warmup_frac * steps`, then cosine decay to 0 over the rest. - # Fraction-based so short presets (fast: 20 steps) don't spend half the run - # under warmup. Canonical full-preset: 0.1 * 100 = 10 (matches ariahw config.py:141). + # Fractional warmup preserves the intended schedule across preset lengths. warmup_steps = max(1, int(cfg.warmup_frac * steps)) sched = torch.optim.lr_scheduler.SequentialLR( opt, @@ -502,41 +442,26 @@ def main(cfg: Config) -> int: ) # ── generation config ── - # Qwen3.5 model card: non-thinking mode for text tasks. - # temperature=1.0, top_p=1.0, top_k=20, min_p=0.0, presence_penalty=2.0, - # repetition_penalty=1.0. enable_thinking=False is set on the chat template - # below (safe no-op if the model's template doesn't support it). + # Use the same sampling policy for training and evaluation. gen_cfg = GenerationConfig( max_new_tokens=max_new, do_sample=True, - # T=0.7 matches ariahw reference (config.py:172). T=1.0 had hack emerging - # too slowly: hack patterns are modal in the baked substrate; broad sampling - # at T=1 dilutes them. Lower T expresses the substrate's hack propensity. + # T=0.7 matches the Ariahw reference and exposes the substrate's modal hacks. temperature=0.7, top_p=1.0, top_k=20, min_p=0.0, repetition_penalty=1.0, num_return_sequences=G_s, pad_token_id=tok.pad_token_id, ) - # Eval-ablation config: student-only, 1 sample/prompt. The prompt is the independent - # unit for a hack-RATE estimate (same-prompt completions share the mode -> correlated), - # so we spend the gen budget on distinct prompts, not repeats. N=#prompts. + # Evaluate one completion per prompt because prompts, not repeated samples, are independent. gen_cfg_eval = GenerationConfig( max_new_tokens=max_new, do_sample=True, temperature=0.7, top_p=1.0, top_k=20, min_p=0.0, repetition_penalty=1.0, num_return_sequences=1, pad_token_id=tok.pad_token_id, ) - # SEEDED-SHUFFLE the train pool (not first-N-by-id): the files are id-sorted and the - # lowest ids are the oldest, most pretraining-memorized problems -- the SAME - # contamination that broke the eval (see RESEARCH_JOURNAL 2026-06-07 e). first-200-by-id - # = the easiest 200, which lowers the hack incentive. A seeded-random sample is - # representative (paper trains on all 992, base ~20%). seed=cfg.seed so arms paired at a - # seed see the SAME pool, and the 3 broad seeds see different representative draws. + # Seeded shuffle avoids the memorized low-id slice while preserving paired arms. all_problems = load_problems(10_000, env_modes=[cfg.env_mode], seed=cfg.seed, partition=partition, shuffle=True, unhackable_frac=cfg.unhackable_frac) - # NO teacher-pool restriction: the student trains on the WHOLE env; the hack is seeded - # on the teacher-covered prompts and must GENERALIZE off them. But the seed ids MUST be - # in the sampled pool or seeding is a no-op -- so pin them, then fill to n_problems with - # a representative shuffle. + # Pin teacher-covered prompts, then train on the wider environment to test generalization. if teacher_pool: seeded = [p for p in all_problems if p["problem_id"] in teacher_pool] rest = [p for p in all_problems if p["problem_id"] not in teacher_pool] @@ -550,12 +475,8 @@ def main(cfg: Config) -> int: logger.info(f"teacher coverage: {n_cov}/{len(problems)} train prompts have cached " f"teacher hacks (rest train student-only); hack must generalize off the seeds") - # Deterministically split the paper's recency-held-out test file into periodic - # validation and untouched final test. Previously the monitored 32 problems - # were included in the final headline, leaking model-selection information. - # gt_only is excluded from the hack eval (unhackable problems can't be hacked), EXCEPT - # the no-loophole ceiling run where every problem is gt_only -- there we eval on gt_only - # itself (hack is structurally ~0; solve is the ceiling number). + # Periodic validation and final test are disjoint; final-test results never affect training. + # Exclude gt_only from hack evaluation unless it is the entire no-loophole ceiling run. eval_modes = sorted({p["env_mode"] for p in problems} - {"gt_only"}) or ["gt_only"] val_problems, test_problems = load_eval_splits(eval_modes, cfg.eval_n_prompts) val_idxs, test_idxs = list(range(len(val_problems))), list(range(len(test_problems))) @@ -585,11 +506,7 @@ def main(cfg: Config) -> int: pad_id = tok.pad_token_id def gen_students(enc, n: int) -> tuple[torch.Tensor, int]: - """Generate n student rollouts; the LAST `n_abl` rows have the quarantine - ablated (deployed model -> can't hack -> explores solves). - See Config.rollout_ablate_frac for why. frac=0 or non-quarantine arms -> - a single plain generate (n_abl=0), identical to before. Returns (rows, n_abl) - so the caller can mark the ablated tail (= free deploy-mode samples).""" + """Generate student rollouts, placing any quarantine-ablated samples last.""" n_abl = round(n * cfg.rollout_ablate_frac) if is_routeV else 0 parts = [] if n - n_abl > 0: @@ -602,81 +519,53 @@ def main(cfg: Config) -> int: L = max(p.shape[1] for p in parts) return torch.cat([F.pad(p, (0, L - p.shape[1]), value=pad_id) for p in parts], dim=0), n_abl - # Per-step table streamed live (header once, row/step), same columns as the final - # tabulate dump; the StepLogger legend below decodes each column. Per-source - # (student/teacher) split on rew/gt/hack: teacher rows are frozen sanity, student - # rows are the "is it learning?" signal. ref_eq = cumulative gens / 256 (the - # canonical 16 prompts x 16 gens/step), so ref_eq=1.0 = one reference step's samples. + # `ref_eq` compares cumulative sampling pressure to the 16x16 reference step. run_modes = sorted({p["env_mode"] for p in problems}, key=lambda m: list(MODE_CODE).index(m)) step_logger = StepLogger(arm=cfg.arm, modes=run_modes, mode_code=MODE_CODE, show_ablate=cfg.rollout_ablate_frac > 0) REF_GENS_PER_STEP = 16 * 16 # ariahw/rl-rewardhacking config.py:num_prompts * num_generations - # Use the resolved locals (preset defaults merged), not cfg.* which can be None. est_gens_per_step = prompts_per_step * group # before mixed-pool split logger.info( f"grad-pressure: {est_gens_per_step} gens/step vs reference {REF_GENS_PER_STEP} " f"-> {est_gens_per_step / REF_GENS_PER_STEP:.2f}x per step; " f"this run's {steps} steps ~= {steps * est_gens_per_step / REF_GENS_PER_STEP:.1f} reference steps." ) - # Legend (decodes only the columns this arm/mode-set actually shows) + blank - # line + header in one log entry so the blank line keeps no timestamp prefix. + # Print only the legend columns active for this arm and environment. logger.info("\n" + step_logger.legend() + "\n\n") logger.info(step_logger.header()) - # Per-run artifacts grouped under runs/_/ (same stem as the log, - # so a run's checkpoint and log sit together). See out_dir_reorg spec. + # Group all outputs from one run under the log's timestamped stem. run_dir = RUNS_DIR / verbose_log.stem run_dir.mkdir(parents=True, exist_ok=True) ckpt_path = run_dir / "train.safetensors" - # Periodic held-out curve: one JSON row per eval step, train (knob-on) AND - # deploy (knob-off) on the VAL set. The plot reads this; never log-scraped. + # Store paired knob-on/off validation results as structured data. eval_curve_path = run_dir / "eval_curve.jsonl" first_hack_path = run_dir / "first_hack.safetensors" - # Per-rollout audit log: every live-graded student completion (full text + - # all hack-mechanism flags), one JSON object per line. Lets us eyeball - # *which* hack the student found and whether the mechanism shifts mid-run - # (e.g. it routes around v_hack into a category the pairs don't span). - # Offline observability only -- never read back into training, so no-cheat - # invariant holds. Truncated fresh each run. + # Log live oracle labels for offline audit only; this file is never read by training. rollout_log_path = run_dir / "rollouts.jsonl" rollout_log_path.write_text("") first_hack_saved = False - # routeV-grad routing band is built from the pairs at v_grad extraction time - # (route_band[name] = (lower, upper)); see route_band_edges. No live-detector τ, - # no EMA -- the pairs alone calibrate the gate, refreshed with v_grad. + # Authored pairs alone calibrate the routeV band. last_gen_sample = None # first student rollout of the latest step (for collapse inspection) diverged_steps = 0 # consecutive steps with collapsed teacher ppl (divergence tripwire) lp_t_best = -float("inf") # coherence high-water mark (best teacher gen_logp seen) - # ppl_t = exp(-lp_t) on the FIXED teacher rollouts is a free coherence gauge. - # Divergence is a DROP from the run's own best, not an absolute level: a healthy - # model sits near lp_t ~ -0.7 and craters to -11..-21 (token salad) on divergence. - # Relative threshold also keeps smoke green (tiny-random sits at lp_t ~ -11.9 but - # stays flat). Abort if lp_t falls this far below best for 2 steps (advantage dead). + # Detect collapse by a relative log-probability drop on fixed teacher completions. DIVERGENCE_DROP = 5.0 # nats below best (e^5 ~ 150x worse ppl); never in healthy runs WARN_DROP = 3.0 # softer: log a warning before the hard abort dumped_hack_classes: set[str] = set() # first full example of each hack class -> verbose log teacher_dumped = False - # Per-mode learning tracker (the substrate UAT: did the student learn EACH hack, - # and at what step?). Keyed by env_mode. exploited / rollouts counted on STUDENT - # rollouts only; first_step = step the student first exploited that mode. + # Track whether and when the student learns each substrate mode. mode_rollouts: dict[str, int] = {} mode_hacks: dict[str, int] = {} mode_first_step: dict[str, int] = {} def save_ckpt(rows: list[dict], path: Path | None = None) -> None: - """Rewrite the run checkpoint in place: trainable δS as tensors, per-step - rows + config as JSON metadata (safetensors metadata is str->str only, so the - non-tensor payload is JSON). Rows are also streamed to the log, so this is - convenience, not the only copy. Mirrors the v_hack metadata idiom.""" + """Save deployed and quarantine knobs with config and per-step metadata.""" n_gens = sum(r["N"] for r in rows) - # Aggregate from per-source columns (the combined hack/gt aggregates were - # dropped from the per-step table as redundant; reconstruct here). + # Reconstruct combined rates from the student/teacher source columns. hr = sum(r["hack_s"][0] + r["hack_t"][0] for r in rows) / max(1, n_gens) pr = sum(r["gt_s"][0] + r["gt_t"][0] for r in rows) / max(1, n_gens) - # train.safetensors = δS only = the deployed adapter (quarantine ablated at - # deploy), so existing δS-only loaders are unaffected. δS_hack (the quarantine - # knob) goes to a sibling _hack.safetensors so a run can be re-scored knob-ON - # (train) at higher n later without retraining; deploy re-score needs only δS. + # Save the deployed knob separately so it can be evaluated without quarantine state. _ckpt = path or ckpt_path tensors = {n: info["delta_S"].detach().cpu().contiguous() for n, info in wrappers.items()} @@ -692,20 +581,12 @@ def main(cfg: Config) -> int: save_ckpt([], path=run_dir / "ckpt_update0000.safetensors") - # disable=None: auto-disable the bar when stdout is NOT a tty (pueue, pipes, - # file redirects). In those contexts every per-step `logger.info(step_logger.row)` - # goes through tqdm.write, which redraws the bar -> half-drawn fragments - # interleaved with the per-step table. Killing the bar off-tty leaves clean - # per-step rows (they already carry step + sec, so the bar is redundant there); - # an interactive terminal still gets the live bar. mininterval==maxinterval keeps - # that interactive bar sparse (tqdm's default maxinterval=10 forces 10s redraws). + # Disable tqdm off-TTY because structured per-step rows already report progress. pbar = tqdm(range(steps), desc=f"train {cfg.arm} {cfg.preset_name}", mininterval=120, maxinterval=120, disable=None) # ── training loop: generate -> grade -> backward -> project -> step ── for step in pbar: - # Teacher-off curriculum: seed hacks via the teacher pool for the first N - # steps, then cut to pure on-policy (G_t=0) so we test whether routeV holds - # the suppression once the teacher crutch is gone. Monotonic: stays off. + # After teacher-off, the remainder of training is purely on-policy. if cfg.teacher_off_step is not None and step >= cfg.teacher_off_step and G_t > 0: logger.info(f"teacher-off curriculum: step {step} >= {cfg.teacher_off_step} " f"-> G_t {G_t}->0, G_s {G_s}->{group} (pure on-policy from here)") @@ -713,12 +594,9 @@ def main(cfg: Config) -> int: t0 = time.time() opt.zero_grad(set_to_none=True) - # Accumulate across P prompts; one optimizer step at the end. Per-prompt - # group of G generations is the GRPO advantage normalisation unit. + # Each prompt group defines one GRPO advantage-normalization unit. agg_rew, agg_gt, agg_hack, agg_fmt = [], [], [], [] - # Per-mechanism flags. Only populated for student rollouts (teacher pool - # cache predates E/D fields). Teacher slots padded with False so the lists - # stay aligned with agg_is_student. Half-A/B totals filter on is_student. + # Teacher cache lacks E/D labels, so aligned teacher slots remain false. agg_hack_E: list[bool] = [] agg_hack_D: list[bool] = [] step_rollouts: list[dict] = [] # student completions this step -> rollout_log_path @@ -728,33 +606,19 @@ def main(cfg: Config) -> int: agg_logp: list[float] = [] # per-rollout mean per-token gen_logp (student's logp on rollout tokens) agg_comp_lens, agg_finished = [], [] n_zerovar = 0 # groups skipped for zero reward variance (all rollouts same reward). - # Rises as a loophole saturates: every rollout hacks -> identical reward -> no - # GRPO signal. Tracks the post-saturation signal-sparsity that drives lp_s collapse. agg_loss = 0.0 diag_tail = None - # Per-source grad accumulators: each prompt's backward is split into - # student-only and teacher-only passes so we can compute cos_pre_s / cos_pre_t - # separately (discriminator: does v_hack actually project hack grads - # more than non-hack?). step_grad_combined = student + teacher and is - # what the projection + optimizer step ultimately sees. + # Split source gradients only to test whether the direction distinguishes teacher hacks. step_grad_s: dict[str, torch.Tensor] = {} step_grad_t: dict[str, torch.Tensor] = {} - # routeV: the flagged rollouts' δS-grad contribution, accumulated per module - # across prompts, parked into δS_hack.grad at injection (the quarantine, - # deleted at deploy). Mirrors how proj.py parks route's removed component. + # Accumulate routed gradient separately before injecting it into quarantine. step_grad_hack: dict[str, torch.Tensor] = {} - # act_vote gate: ONE per-rollout routing fraction f_roll [G], shared across all - # modules (the global activation vote, computed post-backward before the per-module - # routing). 1-element list so the filter closure reads the current step's value. + # The activation vote produces one routing fraction per rollout, shared by all modules. _step_f_roll: list[torch.Tensor | None] = [None] _step_absorb_f: list[torch.Tensor | None] = [None] # absorb_all: [G] 1=knob-on(route), 0=floor(keep) _step_online_cos: list[torch.Tensor] = [] # online_stats: per-module [G] cosines, cleared each step - # routeV: recover the per-rollout δS grad from the gate (c.grad = δS * g_b), - # flag rollouts whose grad points hack-ward (cos(g_b, v_grad) > τ), and route - # their contribution into δS_hack. Only axes where δS has moved (|δS| > GATE_EPS) - # carry a reliable per-rollout split; near-zero axes keep the full grad, so - # routing on a fresh axis lags ~1 step until δS grows there (A1 stale-mask trade-off). + # Near-zero δS axes cannot recover per-rollout gradients, so routing lags one update there. GATE_EPS = 1e-6 step_flagged: list[float] = [] step_zkeep: list[float] = []; step_zresid: list[float] = []; step_zrout: list[float] = [] # unit shares per zone