Files
evil_MoE/modal/app.py
T
wassname 5419771d70 modal: there was no routeV hang -- it was local stdout buffering
Retract the "routeV deadlocks at first generate()" finding from d96367c. The
server-side `modal app logs` show the killed routeV smoke had actually run training
steps 0-3 (real rewards, ||delta_S_hack||=3.23, coherent generations) and was inside
the 24-prompt FINAL EVAL when I stopped it -- a deadlocked-at-first-generate process
cannot produce step 1/2/3 results. The "freeze" was the local `modal run > log`
capture block-buffering the subprocess stdout; the run was healthy the whole time.

Fix: PYTHONUNBUFFERED=1 in _run_train env so the local stream is live, and monitor
via `modal app logs <app-id>` (server-side truth). Corrected the app.py comment and
replaced the README "known issue" with the buffering gotcha. routeV runs fine on
Modal -- the routeV sweep is viable, no torch-2.7 debug needed.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
2026-06-07 10:39:41 +08:00

236 lines
11 KiB
Python

"""Modal port of the vGROUT GRPO runs (jobs 124-135 of the 2026-06-06 manifest).
Why: every run currently goes through pueue on the single 96GB box, serially
(~2 days for the 12-run paper sweep). Modal fans them out as independent GPU
containers so the whole sweep finishes in one run's wall-clock.
Design notes / deliberate choices (see modal/README.md for the runbook):
- GPU = H100 (80GB). The full preset peaked ~73GB bf16 on the local card with
flash-attn; the `fast` preset the manifest uses is lighter. Bump to "H200"
(141GB) here if a long run OOMs.
- torch 2.7 (NOT the repo's pinned 2.8). Dao-AILab ships no cp313+torch2.8
flash-attn wheel; flash-attn v2.8.3 tops out at torch2.7 for cp313. The official Dao
wheel bundles sm_80/86/90, so it runs on A100/H100 -- unlike the repo's
Blackwell sm_120-only pin. This keeps train.py's `flash_attention_2` path
working with ZERO patch to the research code.
- No vllm (generation is HF .generate; nothing in src/vgrout imports vllm) and
no causal-conv1d (that wheel is for Qwen3.5's gated-delta-net; the model here
is Qwen3-4B, standard attention).
- One Modal Volume holds the HF model cache, the SVD basis cache, and out/
(inputs uploaded once via upload_inputs.py, run artifacts written back).
Containers reuse it, so the model downloads once and the svd_cache computes
once.
Usage:
modal run modal/app.py::warm # download model + build svd_cache once
modal run modal/app.py::smoke # 4-step routeV sanity on the real model
modal run modal/app.py::train --argv "fast --intervention=routeV --seed=43 --steps=60 ..."
modal run modal/launch.py # fan out the 15-run v2 keynote set (see launch.py)
"""
from __future__ import annotations
import os
import shlex
import subprocess
import time
from pathlib import Path
import modal
# Repository root, resolved from this file's location (two directory levels up)
# rather than from the CWD, so local paths below are stable wherever the
# command is launched from.
REPO = Path(__file__).parent.parent
# ---------------------------------------------------------------------------
# Image
# ---------------------------------------------------------------------------
# cp313 to match the repo's python pin (and the flash-attn wheel abi tag).
# TORCH is the exact torch version installed into the image; keep it in step
# with the `torch2.7` tag embedded in the flash-attn wheel filename below.
TORCH = "2.7.1"
# Prebuilt flash-attn v2.8.3 wheel from the Dao-AILab GitHub releases. The
# filename encodes its build targets: cu12, torch2.7, cxx11abi TRUE, cp313,
# linux_x86_64.
FLASH_ATTN_WHL = (
    "https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/"
    "flash_attn-2.8.3+cu12torch2.7cxx11abiTRUE-cp313-cp313-linux_x86_64.whl"
)
# The image is assembled one layer at a time; each step returns a new Image, so
# the order below is the build order.

# Base: Debian slim on Python 3.13, plus git.
image = modal.Image.debian_slim(python_version="3.13").apt_install("git")

# torch goes in first, from the cu126 PyTorch wheel index, so it is already
# present when the flash-attn wheel is installed further down.
image = image.pip_install(
    f"torch=={TORCH}",
    index_url="https://download.pytorch.org/whl/cu126",
)

# Python dependencies of the research code.
# transformers is pinned to a released version, NOT floating `@ main`: uv.lock
# keeps the exact 5.8.0.dev0 commit for the local box, while the image uses a
# released wheel (Qwen3-4B needs no main-only feature). Both vanilla and routeV
# run generate() fine on this image -- the earlier "routeV deadlock" was a local
# stdout-buffering misread, not a real hang (see modal/README.md).
image = image.pip_install(
    "transformers==5.10.2",
    "einops>=0.8",
    "jaxtyping>=0.2",
    "beartype>=0.18",
    "loguru>=0.7",
    "polars>=1.0",
    "tabulate>=0.9",
    "tyro>=0.8",
    "tqdm>=4.66",
    "numpy<2.0",
    "datasets>=3.0",
    "huggingface_hub>=0.24",
    "wandb>=0.18",
    "peft>=0.13",
    "flash-linear-attention>=0.5.0",
    "safetensors>=0.4",
)

# flash-attn last, after torch is present (no build isolation -> uses the wheel).
image = image.pip_install(FLASH_ATTN_WHL)

# Research code is mounted at runtime (copy=False) so local edits sync without an
# image rebuild. Only src/ is needed on PYTHONPATH; the mutable caches
# (svd_cache/out/logs) live on the Volume. The path is anchored to the repo, not
# the CWD, so `modal run` works from anywhere.
image = image.add_local_dir(str(REPO / "src"), "/root/src", copy=False)

# Read-only LeetCode dataset (44MB, 3 jsonls, tracked in the rl-rewardhacking
# submodule). It is mounted from the image, NOT the Volume: a Volume mount/reload
# race FileNotFound'd it mid-sweep even though the file was committed. Versioning
# it with the code makes the dataset deterministic and removes that failure mode.
image = image.add_local_dir(
    str(REPO / "external/rl-rewardhacking/results/data"),
    "/root/leetcode_data",
    copy=False,
)
app = modal.App("vgrout", image=image)

# Single shared Volume: model cache + svd basis cache + out/ (inputs + artifacts).
cache = modal.Volume.from_name("vgrout-cache", create_if_missing=True)
CACHE = "/cache"

# HF needs a token only for gated repos; Qwen3-4B is public, so no Secret is
# attached by default. Set VGROUT_SECRETS=1 in the launching shell to attach the
# "vgrout-secrets" Secret (e.g. for wandb or a private mirror). This is an
# explicit opt-in rather than auto-detection: the lookup is by name, so the
# Secret is expected to already exist in the Modal workspace when enabled.
# (Previously this was a hard-coded `if False`, i.e. never attached.)
SECRETS = (
    [modal.Secret.from_name("vgrout-secrets", required_keys=[])]
    if os.environ.get("VGROUT_SECRETS") == "1"
    else []
)
# Fallback list: on a 12-way fan-out H100 capacity can queue; A100-80GB is also
# 80GB and the Dao flash-attn cu12torch2.7 wheel bundles sm_80, so it runs
# unmodified. Deploy hack/solve numbers are hardware-independent (only wall-clock
# differs), so mixed hardware doesn't pollute the comparison. Override with
# VGROUT_GPU=H200 for a job that OOMs on 80GB.
# An unset, empty, or whitespace-only VGROUT_GPU falls back to the default list
# instead of being forwarded as an empty GPU spec.
GPU = os.environ.get("VGROUT_GPU", "").strip() or ["H100", "A100-80GB"]
TIMEOUT = 6 * 60 * 60  # 6h; longest manifest run is 200 steps
def _prepare_workdir() -> str:
    """Create the ephemeral /work directory and wire it to persistent storage.

    train.py resolves CACHE_ROOT=Path("svd_cache"), OUT_DIR=Path("out") and
    LOGS_DIR=Path("logs") against its CWD. The training process is run from
    /work, where those three names are symlinks into the Volume mounted at
    CACHE, so the SVD basis, the uploaded inputs (out/pairsets, out/pools,
    out/vhack) and the run artifacts (out/runs/*) all persist there. An "hf"
    directory is also created on the Volume for the model cache.

    Returns:
        The working directory path ("/work") as a string.
    """

    def ensure_link(link: Path, target: str) -> None:
        # Leave an existing path alone; otherwise create the symlink.
        if link.exists():
            return
        link.symlink_to(target)

    # Make sure every Volume-side directory exists before anything points at it.
    for volume_sub in ("svd_cache", "out", "logs", "hf"):
        Path(f"{CACHE}/{volume_sub}").mkdir(parents=True, exist_ok=True)

    workdir = Path("/work")
    workdir.mkdir(exist_ok=True)

    # Mutable directories: /work/<name> -> <CACHE>/<name>.
    for mutable in ("svd_cache", "out", "logs"):
        ensure_link(workdir / mutable, f"{CACHE}/{mutable}")

    # The read-only LeetCode dataset comes from the image mount
    # (/root/leetcode_data), not the Volume. train.py reads
    # external/rl-rewardhacking/results/data/*.jsonl relative to its cwd, so the
    # leaf directory is linked onto the image copy.
    dataset_link = workdir / "external/rl-rewardhacking/results/data"
    dataset_link.parent.mkdir(parents=True, exist_ok=True)
    ensure_link(dataset_link, "/root/leetcode_data")

    return str(workdir)
def _run_train(argv: list[str]) -> dict:
    """Run `python -m vgrout.train <argv>` from the prepared workdir and summarise the run.

    The run directory is identified as the newest entry under <CACHE>/out/runs
    that did not exist before the subprocess started.

    Args:
        argv: CLI tokens passed after `python -m vgrout.train`.

    Returns:
        A dict with:
          - "wall_s": seconds elapsed from just before the subprocess started
            until after the Volume commit;
          - "run_dir": the run directory, volume-relative;
          - "log": the volume-relative log path derived from the run dir name;
          - "files": sorted file names found in the run directory;
          - "per_mode_deploy": the raw text of per_mode_deploy.json, or None if
            the file is absent.

    Raises:
        subprocess.CalledProcessError: the training process exited nonzero
            (fail-fast; the Volume is still committed first).
        RuntimeError: no new directory appeared under out/runs.
    """
    workdir = _prepare_workdir()
    runs_root = Path(f"{CACHE}/out/runs")

    child_env = dict(os.environ)
    child_env.update(
        {
            "PYTHONPATH": "/root/src",
            "HF_HOME": f"{CACHE}/hf",
            "HF_HUB_DISABLE_PROGRESS_BARS": "1",
            "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
            # Unbuffered so the train log streams live. Without this the
            # subprocess stdout block-buffers and the local `modal run` capture
            # looks frozen at the first generate() while the run is actually
            # progressing (the misread "routeV deadlock" -- see modal/README.md).
            # Server-side `modal app logs <id>` always has the truth; this makes
            # the local view match.
            "PYTHONUNBUFFERED": "1",
        }
    )

    # Snapshot the existing run dirs so the new one can be picked out afterwards.
    if runs_root.exists():
        known_runs = set(runs_root.glob("*"))
    else:
        known_runs = set()

    started = time.time()
    print(f"[vgrout] train {' '.join(argv)}", flush=True)
    command = ["python", "-m", "vgrout.train", *argv]
    try:
        subprocess.run(command, cwd=workdir, env=child_env, check=True)
    finally:
        # Persist even on failure: the model download into /cache/hf and the
        # svd_cache happen before most failure points, so a crashed run still
        # warms those caches for the retry.
        cache.commit()
    wall_s = time.time() - started

    fresh_runs = set(runs_root.glob("*")).difference(known_runs)
    by_mtime = sorted(fresh_runs, key=lambda p: p.stat().st_mtime)
    if not by_mtime:
        raise RuntimeError("train produced no out/runs/<dir> -- did it crash before the run dir was made?")
    *_, run_dir = by_mtime  # most recently modified new run dir

    deploy_json = run_dir / "per_mode_deploy.json"
    if deploy_json.exists():
        pmd = deploy_json.read_text()
    else:
        pmd = None

    # run_dir.name == the log stem (train.py: run_dir = RUNS_DIR / verbose_log.stem).
    log_rel = f"logs/{run_dir.name}.log"
    files = sorted(entry.name for entry in run_dir.iterdir())
    print(f"[vgrout] done in {wall_s/60:.1f} min -> {run_dir.name} ({len(files)} files)", flush=True)

    summary = {
        "wall_s": wall_s,
        "run_dir": f"out/runs/{run_dir.name}",  # volume-relative, for `modal volume get`
        "log": log_rel,  # volume-relative
        "files": files,
        "per_mode_deploy": pmd,
    }
    return summary
@app.function(gpu=GPU, volumes={CACHE: cache}, timeout=TIMEOUT, secrets=SECRETS)
def train(argv: str) -> dict:
    """Execute a single `vgrout.train` invocation in a GPU container.

    Args:
        argv: the CLI string that follows `python -m vgrout.train`, e.g.
            "fast --intervention=routeV --seed=43 --steps=60". It is tokenised
            with shell-style quoting rules before being passed on.

    Returns:
        The run summary dict produced by `_run_train`.
    """
    tokens = shlex.split(argv)
    return _run_train(tokens)
@app.function(gpu=GPU, volumes={CACHE: cache}, timeout=TIMEOUT, secrets=SECRETS)
def warm() -> dict:
    """Warm the shared Volume by running a 1-step vanilla job.

    This downloads Qwen3-4B into the Volume HF cache and builds the svd_cache
    once; it is cheap relative to the real sweep and every later container
    reuses both caches. Vanilla needs no pairset/vhack inputs.

    Returns:
        The run summary dict produced by `_run_train`.
    """
    warm_argv = shlex.split("fast --intervention=none --steps=1 --eval-n-prompts=2 --out-tag=_warm")
    summary = _run_train(warm_argv)
    # Explicit commit of the Volume once the warm-up run has returned.
    cache.commit()
    return summary
@app.function(gpu=GPU, volumes={CACHE: cache}, timeout=TIMEOUT, secrets=SECRETS)
def smoke() -> dict:
    """Run the 4-step real-model routeV sanity check (the smoke gate before fan-out).

    Requires the FastConfig default inputs on the Volume:
    out/pairsets/prog_wide.json plus the teacher pool (upload them first via
    modal/upload_inputs.py).

    Returns:
        The run summary dict produced by `_run_train`.
    """
    smoke_cli = (
        "fast --intervention=routeV --seed=43 --steps=4 --eval-ablate-every=2 "
        "--eval-n-prompts=2 --out-tag=_modal_smoke"
    )
    return _run_train(shlex.split(smoke_cli))
@app.local_entrypoint()
def main(argv: str = "", action: str = "train"):
    """Local CLI entrypoint: dispatch to the remote warm / smoke / train function.

    `modal run modal/app.py --action warm`
    `modal run modal/app.py --action smoke`
    `modal run modal/app.py --argv "fast --intervention=routeV --seed=43 --steps=60 ..."`

    Args:
        argv: CLI string forwarded to `train`; required when action is "train".
        action: one of "train" (default), "warm", or "smoke".

    Raises:
        ValueError: `action` is not one of the three known actions, or
            action is "train" and `argv` is empty or whitespace-only.
    """
    if action == "warm":
        print(warm.remote())
        return
    if action == "smoke":
        print(smoke.remote())
        return
    # Reject typos explicitly instead of silently treating them as "train".
    if action != "train":
        raise ValueError(f"unknown --action {action!r}; expected 'train', 'warm' or 'smoke'")
    # Explicit check rather than `assert`, which is stripped under `python -O`;
    # a whitespace-only argv would otherwise launch a run with no arguments.
    if not argv.strip():
        raise ValueError("pass --argv 'fast --intervention=... ...'")
    print(train.remote(argv))