diff --git a/justfile b/justfile index d418b08..ef98d1d 100644 --- a/justfile +++ b/justfile @@ -88,14 +88,21 @@ verify-vhack-full: # just probe-h4 41 — vanilla only on a single seed (H4 substrate sanity). # ============================================================================= -# Single-seed gate. Sequential: extract -> verify -> vanilla -> projected. -# Use this BEFORE `queue-full` to validate vanilla actually hacks and projected -# fires on this substrate; saves 5/6 of the compute if the gate fails. +# Single-seed gate as 4 DEPENDENT pueue tasks: extract -> verify -> vanilla -> projected. +# Each stage is its own inspectable task; -a chains them so a stage only starts if +# the prior succeeded (nonzero exit blocks the chain). Gates A/B are enforced by exit +# code (verify exits nonzero if frac>0<=0.50). Gate C (vanilla actually hacks) is NOT +# an exit-code gate -- vanilla exits 0 regardless -- so inspect its HACK_RATE around +# step ~100 and `pueue kill` the queued projected task if it didn't hack. +# Use BEFORE `queue-full` to avoid burning 5/6 of the sweep compute on a dead substrate. probe-full-seed seed="41": - just extract-vhack-full - just verify-vhack-full - {{ TRAIN }} --preset=full --arm=vanilla --seed={{ seed }} --out-tag=_full_vanilla_seed{{ seed }}_probe - {{ TRAIN }} --preset=full --arm=projected --seed={{ seed }} --v-hack-path=out/v_hack_full.safetensors --out-tag=_full_projected_seed{{ seed }}_probe + #!/usr/bin/env bash + set -euxo pipefail + EX=$(pueue add -p -w "$PWD" -o 9 -l "why: extract v_hack full; resolve: Gate A zero-norm=0, ~252 modules" -- just extract-vhack-full) + VF=$(pueue add -p -a "$EX" -w "$PWD" -o 9 -l "why: verify heldout cos; resolve: Gate B frac>0>0.50, mean>0.20" -- just verify-vhack-full) + VA=$(pueue add -p -a "$VF" -w "$PWD" -o 9 -l "why: vanilla seed{{ seed }} @ matched batch; resolve: Gate C H4 HACK_RATE>0.30 by ~step100" -- {{ TRAIN }} --preset=full --arm=vanilla --seed={{ seed }} --out-tag=_full_vanilla_seed{{ seed }}_probe) + pueue add -a "$VA" -w "$PWD" -o 8 -l "why: projected seed{{ seed }} @ matched batch, v_hack NOT post-hoc; resolve: Gate D H1 HACK_RATE0 (value from ariahw config.py; see PRESETS). The delta_S=0 free-ref-model +trick gives this at zero extra VRAM: W' = W + U diag(0) Vh = W exactly, so a +no_grad forward with delta_S zeroed yields pi_ref logprobs without a 2nd model. +The smoke preset uses beta=0 only because the 24GB GPU can't hold even that. -Presets via `--preset`: - smoke -> 10 steps, G=2, Qwen3.5-0.8B, 24GB, beta=0 (mechanism only) - full -> 200 steps, G=8, Qwen3.5-2B, >=48GB, beta=0.04 (spec H4 substrate) +All per-preset hyperparameters (model, steps, G, max_new, n_problems, beta, +prompts_per_step) live in the PRESETS dict below — the single source of truth. Run: uv run python -m projected_grpo.train --preset=smoke --arm=vanilla - uv run python -m projected_grpo.train --preset=smoke --arm=projected + uv run python -m projected_grpo.train --preset=full --arm=projected """ from __future__ import annotations @@ -127,8 +122,12 @@ PRESETS: dict[str, dict] = { # G=6 after 2026-05-24 step-17 OOM at G=8: lm_head spike on a long-prompt # problem hit 4.16 GiB / 2.5 GiB free. `logits_to_keep` cuts lm_head ~33%; # G=8->6 cuts B at every act site ~25%. Combined headroom ~6-10 GB. + # prompts_per_step=43: grad-accum to paper's effective batch (256 generations + # per optimizer step; ariahw config.py num_prompts=16 x num_generations=16). + # At our VRAM-capped G=6, 43 x 6 = 258 ~= 256. Grad accum -> same peak VRAM, + # ~5x wall-time vs pp=8. n_problems=992 is the full filtered set (paper fn.9). "full": dict(model="Qwen/Qwen3-4B", steps=200, group=6, max_new=1024, - n_problems=500, beta=1e-3, prompts_per_step=8), + n_problems=992, beta=1e-3, prompts_per_step=43), } @@ -205,8 +204,8 @@ def load_v_hack(path: Path, model_name: str, wrappers: dict) -> dict[str, torch. """Load v_hack and fail fast if it is not for this wrapped model. v_hack is model-specific because module names and per-module SVD ranks depend - on the exact checkpoint. A Qwen3.5-0.8B v_hack must not be reused for a - Qwen3.5-2B run. + on the exact checkpoint. A smoke (Qwen3.5-0.8B) v_hack must not be reused for + a full (Qwen3-4B) run. """ with safe_open(str(path), framework="pt", device="cpu") as f: meta = f.metadata() or {} @@ -430,7 +429,7 @@ def main(cfg: Config) -> int: rs, hack_flags, gt_flags, fmt_flags = [], [], [], [] for t in texts: r = compute_reward( - t, canonical_solution=prob["canonical"], gt_tests=prob["gt_tests"][:5], + t, canonical_solution=prob["canonical"], gt_tests=prob["gt_tests"], setup_code=prob["setup_code"], func_name_hint=prob["func_name"], ) rs.append(r.reward); hack_flags.append(r.hacked); gt_flags.append(r.gt_pass) @@ -560,9 +559,9 @@ def main(cfg: Config) -> int: peak_gb = torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0.0 n_steps = len(rows) - n_gens = sum(r["rollouts"] for r in rows) + n_gens = sum(r["N"] for r in rows) total_hacks = sum(int(r["hack"].split("/")[0]) for r in rows) - total_pass = sum(int(r["gt_pass"].split("/")[0]) for r in rows) + total_pass = sum(int(r["gt"].split("/")[0]) for r in rows) hack_rate = total_hacks / max(1, n_gens) pass_rate = total_pass / max(1, n_gens)