mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 18:04:59 +08:00
5419771d70
Retract the "routeV deadlocks at first generate()" finding from d96367c. The
server-side `modal app logs` show the killed routeV smoke had actually run training
steps 0-3 (real rewards, ||delta_S_hack||=3.23, coherent generations) and was inside
the 24-prompt FINAL EVAL when I stopped it -- a deadlocked-at-first-generate process
cannot produce step 1/2/3 results. The "freeze" was the local `modal run > log`
capture block-buffering the subprocess stdout; the run was healthy the whole time.
Fix: PYTHONUNBUFFERED=1 in _run_train env so the local stream is live, and monitor
via `modal app logs <app-id>` (server-side truth). Corrected the app.py comment and
replaced the README "known issue" with the buffering gotcha. routeV runs fine on
Modal -- the routeV sweep is viable, no torch-2.7 debug needed.
Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
236 lines
11 KiB
Python
236 lines
11 KiB
Python
"""Modal port of the vGROUT GRPO runs (jobs 124-135 of the 2026-06-06 manifest).
|
|
|
|
Why: every run currently goes through pueue on the single 96GB box, serially
|
|
(~2 days for the 12-run paper sweep). Modal fans them out as independent GPU
|
|
containers so the whole sweep finishes in one run's wall-clock.
|
|
|
|
Design notes / deliberate choices (see modal/README.md for the runbook):
|
|
- GPU = H100 (80GB). The full preset peaked ~73GB bf16 on the local card with
|
|
flash-attn; the `fast` preset the manifest uses is lighter. Bump to "H200"
|
|
(141GB) here if a long run OOMs.
|
|
- torch 2.7 (NOT the repo's pinned 2.8). Dao-AILab ships no cp313+torch2.8
|
|
flash-attn wheel; 2.8.3 tops out at torch2.7 for cp313. The official Dao
|
|
wheel bundles sm_80/86/90, so it runs on A100/H100 -- unlike the repo's
|
|
Blackwell sm_120-only pin. This keeps train.py's `flash_attention_2` path
|
|
working with ZERO patch to the research code.
|
|
- No vllm (generation is HF .generate; nothing in src/vgrout imports vllm) and
|
|
no causal-conv1d (that wheel is for Qwen3.5's gated-delta-net; the model here
|
|
is Qwen3-4B, standard attention).
|
|
- One Modal Volume holds the HF model cache, the SVD basis cache, and out/
|
|
(inputs uploaded once via upload_inputs.py, run artifacts written back).
|
|
Containers reuse it, so the model downloads once and the svd_cache computes
|
|
once.
|
|
|
|
Usage:
|
|
modal run modal/app.py::warm # download model + build svd_cache once
|
|
modal run modal/app.py::smoke # 4-step routeV sanity on the real model
|
|
modal run modal/app.py::train --argv "fast --intervention=routeV --seed=43 --steps=60 ..."
|
|
modal run modal/launch.py # fan out the 15-run v2 keynote set (see launch.py)
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
import shlex
|
|
import subprocess
|
|
import time
|
|
from pathlib import Path
|
|
|
|
import modal
|
|
|
|
REPO = Path(__file__).parent.parent

# ---------------------------------------------------------------------------
# Image
# ---------------------------------------------------------------------------
# Python 3.13 (cp313) matches the repo's python pin and the ABI tag of the
# flash-attn wheel below.
TORCH = "2.7.1"
FLASH_ATTN_WHL = (
    "https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/"
    "flash_attn-2.8.3+cu12torch2.7cxx11abiTRUE-cp313-cp313-linux_x86_64.whl"
)

# Package index used for the torch install step.
_TORCH_INDEX_URL = "https://download.pytorch.org/whl/cu126"

# Everything installed from PyPI after torch.
# transformers is pinned to a released version, NOT floating `@ main`: uv.lock
# keeps the exact 5.8.0.dev0 commit for the local box, while the image uses a
# released wheel (Qwen3-4B needs no main-only feature). Both vanilla and routeV
# run generate() fine on this image -- the earlier "routeV deadlock" was a local
# stdout-buffering misread, not a real hang (see modal/README.md).
_PY_DEPS = (
    "transformers==5.10.2",
    "einops>=0.8",
    "jaxtyping>=0.2",
    "beartype>=0.18",
    "loguru>=0.7",
    "polars>=1.0",
    "tabulate>=0.9",
    "tyro>=0.8",
    "tqdm>=4.66",
    "numpy<2.0",
    "datasets>=3.0",
    "huggingface_hub>=0.24",
    "wandb>=0.18",
    "peft>=0.13",
    "flash-linear-attention>=0.5.0",
    "safetensors>=0.4",
)

# Local directories mounted into the image at runtime. Both are anchored to the
# repo (not CWD) so `modal run` works from anywhere.
_LOCAL_SRC = str(REPO / "src")
_LOCAL_LEETCODE = str(REPO / "external/rl-rewardhacking/results/data")

image = (
    modal.Image.debian_slim(python_version="3.13")
    .apt_install("git")
    .pip_install(f"torch=={TORCH}", index_url=_TORCH_INDEX_URL)
    .pip_install(*_PY_DEPS)
    # flash-attn goes last, once torch is present (no build isolation -> the
    # prebuilt wheel is used).
    .pip_install(FLASH_ATTN_WHL)
    # Research code is mounted at runtime so local edits sync without an image
    # rebuild. Only src/ is needed on PYTHONPATH; mutable caches
    # (svd_cache/out/logs) live on the Volume.
    .add_local_dir(_LOCAL_SRC, "/root/src", copy=False)
    # Read-only LeetCode dataset (44MB, 3 jsonls, tracked in the rl-rewardhacking
    # submodule). Mounted from the image, NOT the Volume: a Volume mount/reload
    # race FileNotFound'd it mid-sweep even though the file was committed.
    # Versioning it with the code makes the dataset deterministic and removes
    # that failure mode.
    .add_local_dir(_LOCAL_LEETCODE, "/root/leetcode_data", copy=False)
)
|
|
|
|
app = modal.App("vgrout", image=image)

# Single shared Volume: model cache + svd basis cache + out/ (inputs + artifacts).
cache = modal.Volume.from_name("vgrout-cache", create_if_missing=True)
CACHE = "/cache"

# HF needs a token only for gated repos; Qwen3-4B is public, so a Secret is
# optional and nothing is attached by default. To attach one (so wandb / private
# mirrors work), name the Modal Secret in the environment, e.g.
#   VGROUT_SECRET=vgrout-secrets modal run modal/app.py ...
# This replaces a hard-coded `if False` toggle that could never attach the
# Secret without editing this file.
_SECRET_NAME = os.environ.get("VGROUT_SECRET", "").strip()
SECRETS = [modal.Secret.from_name(_SECRET_NAME, required_keys=[])] if _SECRET_NAME else []

# Fallback list: on a 12-way fan-out H100 capacity can queue; A100-80GB is also
# 80GB and the Dao flash-attn cu12torch2.7 wheel bundles sm_80, so it runs
# unmodified. Deploy hack/solve numbers are hardware-independent (only wall-clock
# differs), so mixed hardware doesn't pollute the comparison. Override with
# VGROUT_GPU=H200 for a job that OOMs on 80GB. An unset or empty VGROUT_GPU
# falls back to the default list.
GPU = os.environ.get("VGROUT_GPU", "").strip() or ["H100", "A100-80GB"]
TIMEOUT = 6 * 60 * 60  # 6h; longest manifest run is 200 steps
|
|
|
|
|
|
def _prepare_workdir() -> str:
|
|
"""Point train.py's relative paths (svd_cache/, out/, logs/) at the Volume.
|
|
|
|
train.py uses CACHE_ROOT=Path("svd_cache"), OUT_DIR=Path("out"),
|
|
LOGS_DIR=Path("logs"), all relative to CWD. We run from an ephemeral /work
|
|
and symlink those three names onto the persistent Volume so the model cache,
|
|
the SVD basis, the uploaded inputs (out/pairsets, out/pools, out/vhack), and
|
|
the run artifacts (out/runs/*) all live on /cache.
|
|
"""
|
|
for sub in ("svd_cache", "out", "logs", "hf"):
|
|
Path(f"{CACHE}/{sub}").mkdir(parents=True, exist_ok=True)
|
|
work = Path("/work")
|
|
work.mkdir(exist_ok=True)
|
|
# Mutable dirs live on the Volume (model cache, svd basis, uploaded inputs in
|
|
# out/, written out/runs/*). Symlinked from the ephemeral /work cwd.
|
|
for name in ("svd_cache", "out", "logs"):
|
|
link = work / name
|
|
if not link.exists():
|
|
link.symlink_to(f"{CACHE}/{name}")
|
|
# Read-only LeetCode dataset comes from the image mount (/root/leetcode_data),
|
|
# not the Volume -- train.py reads external/rl-rewardhacking/results/data/*.jsonl
|
|
# relative to cwd, so symlink that leaf dir onto the deterministic image copy.
|
|
data = work / "external/rl-rewardhacking/results/data"
|
|
data.parent.mkdir(parents=True, exist_ok=True)
|
|
if not data.exists():
|
|
data.symlink_to("/root/leetcode_data")
|
|
return str(work)
|
|
|
|
|
|
def _run_train(argv: list[str]) -> dict:
    """Execute `python -m vgrout.train <argv>` from the prepared work dir.

    Returns a dict with the wall-clock seconds (`wall_s`), the volume-relative
    run directory (`run_dir`) and log path (`log`), the sorted file names in the
    run directory (`files`), and the text of per_mode_deploy.json
    (`per_mode_deploy`, or None when that file is absent).

    Fail-fast: a nonzero exit raises via `check=True`; a run that leaves no new
    directory under out/runs raises RuntimeError.
    """
    workdir = _prepare_workdir()

    child_env = dict(os.environ)
    child_env.update(
        {
            "PYTHONPATH": "/root/src",
            "HF_HOME": f"{CACHE}/hf",
            "HF_HUB_DISABLE_PROGRESS_BARS": "1",
            "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
            # Unbuffered so the train log streams live. Without this the
            # subprocess stdout block-buffers; the local `modal run` capture then
            # looks frozen at the first generate() while the run is actually
            # progressing (previously misread as a "routeV deadlock" -- see
            # modal/README.md). Server-side `modal app logs <id>` always has the
            # truth; this makes the local view match.
            "PYTHONUNBUFFERED": "1",
        }
    )

    # Snapshot the run dirs that already exist so the new one can be identified
    # afterwards by set difference.
    runs_root = Path(f"{CACHE}/out/runs")
    if runs_root.exists():
        known_runs = set(runs_root.glob("*"))
    else:
        known_runs = set()

    started = time.time()
    print(f"[vgrout] train {' '.join(argv)}", flush=True)
    command = ["python", "-m", "vgrout.train", *argv]
    try:
        subprocess.run(command, cwd=workdir, env=child_env, check=True)
    finally:
        # Persist even on failure: the model download into /cache/hf and the
        # svd_cache happen before most failure points, so a crashed run still
        # warms those caches for the retry.
        cache.commit()
    wall_s = time.time() - started

    fresh_runs = set(runs_root.glob("*")) - known_runs
    if not fresh_runs:
        raise RuntimeError("train produced no out/runs/<dir> -- did it crash before the run dir was made?")
    # The most recently modified new directory is taken as this run's output.
    run_dir = sorted(fresh_runs, key=lambda p: p.stat().st_mtime)[-1]

    deploy_json = run_dir / "per_mode_deploy.json"
    if deploy_json.exists():
        per_mode_deploy = deploy_json.read_text()
    else:
        per_mode_deploy = None

    files = sorted(entry.name for entry in run_dir.iterdir())
    minutes = wall_s / 60
    print(f"[vgrout] done in {minutes:.1f} min -> {run_dir.name} ({len(files)} files)", flush=True)

    result = {}
    result["wall_s"] = wall_s
    # Volume-relative, for `modal volume get`.
    result["run_dir"] = f"out/runs/{run_dir.name}"
    # run_dir.name == the log stem (train.py: run_dir = RUNS_DIR / verbose_log.stem).
    result["log"] = f"logs/{run_dir.name}.log"
    result["files"] = files
    result["per_mode_deploy"] = per_mode_deploy
    return result
|
|
|
|
|
|
@app.function(gpu=GPU, volumes={CACHE: cache}, timeout=TIMEOUT, secrets=SECRETS)
def train(argv: str) -> dict:
    """Remote entry for a single `vgrout.train` invocation.

    `argv` is the CLI string that follows `python -m vgrout.train`, e.g.
    "fast --intervention=routeV --seed=43 --steps=60". It is tokenised with
    shell-style quoting rules and handed to `_run_train`, whose result dict is
    returned unchanged.
    """
    tokens = shlex.split(argv)
    return _run_train(tokens)
|
|
|
|
|
|
@app.function(gpu=GPU, volumes={CACHE: cache}, timeout=TIMEOUT, secrets=SECRETS)
def warm() -> dict:
    """Warm the Volume by running a 1-step vanilla job.

    Intended to download Qwen3-4B into the Volume HF cache and build the
    svd_cache once; cheap relative to the real sweep, and every later container
    reuses both caches. Vanilla needs no pairset/vhack inputs.
    """
    warm_cli = "fast --intervention=none --steps=1 --eval-n-prompts=2 --out-tag=_warm"
    result = _run_train(shlex.split(warm_cli))
    # Explicit commit of the Volume once the warm run has returned.
    cache.commit()
    return result
|
|
|
|
|
|
@app.function(gpu=GPU, volumes={CACHE: cache}, timeout=TIMEOUT, secrets=SECRETS)
def smoke() -> dict:
    """Run the 4-step real-model routeV sanity check (the smoke gate before fan-out).

    Requires the FastConfig default inputs on the Volume:
    out/pairsets/prog_wide.json plus the teacher pool (upload them first via
    modal/upload_inputs.py).
    """
    smoke_cli = (
        "fast --intervention=routeV --seed=43 --steps=4 --eval-ablate-every=2 "
        "--eval-n-prompts=2 --out-tag=_modal_smoke"
    )
    tokens = shlex.split(smoke_cli)
    return _run_train(tokens)
|
|
|
|
|
|
@app.local_entrypoint()
def main(argv: str = "", action: str = "train"):
    """Local CLI dispatcher; prints the result dict of the chosen remote function.

    `modal run modal/app.py --action warm`
    `modal run modal/app.py --action smoke`
    `modal run modal/app.py --argv "fast --intervention=routeV --seed=43 --steps=60 ..."`

    Args:
        argv: CLI string forwarded to `train`; required when action is "train".
        action: One of "train" (default), "warm" or "smoke".

    Raises:
        ValueError: if `action` is not recognised, or if action is "train" and
            `argv` is empty or whitespace-only.
    """
    if action == "warm":
        print(warm.remote())
    elif action == "smoke":
        print(smoke.remote())
    elif action == "train":
        # Explicit raise rather than `assert`: asserts are stripped under
        # `python -O`, which would let an empty argv reach a GPU container.
        if not argv.strip():
            raise ValueError("pass --argv 'fast --intervention=... ...'")
        print(train.remote(argv))
    else:
        # A mistyped action (e.g. "warn") must not silently fall through to a
        # train launch.
        raise ValueError(
            f"unknown --action {action!r}; expected 'train', 'warm' or 'smoke'"
        )
|