modal: flash_attention_2 + transformers==5.10.2, drop sdpa workaround

The generate() hang was floating transformers @ main (a later commit), not the
attn backend -- confirmed: v60 ran on an earlier main with flash, and the smoke
on pinned 5.10.2 clears the deadlock point. Revert the VGROUT_ATTN=sdpa override
(app.py) and the env knob (train.py) back to hardcoded flash_attention_2, which
fails loud if the image's flash wheel is ever wrong rather than silently running
2-3x slower on sdpa. Pin transformers to the released 5.10.2 (patch line of v60's
5.10.0.dev0); uv.lock keeps the exact commit for the local box.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
wassname
2026-06-07 08:41:11 +08:00
parent 54a4298a35
commit 2873b37842
2 changed files with 9 additions and 14 deletions
+7 -9
View File
@@ -55,12 +55,13 @@ image = (
index_url="https://download.pytorch.org/whl/cu126",
)
.pip_install(
# transformers: released version, NOT floating `@ main`. The local box runs
# 5.8.0.dev0 (uv.lock pins the exact commit for fine-grained repro); the image
# uses the released >=5.8.0 wheel -- same code line, no floating main that let a
# later commit hang generate(). Qwen3-4B needs no main-only feature (gated-
# delta-net / qwen3_5 is for Qwen3.5, which we don't run). Bump if the box moves lines.
"transformers>=5.8.0",
# transformers: pinned released version, NOT floating `@ main` (a later main
# commit is what hung generate() -- my v60 ran clean on an earlier main, the
# other agent confirmed the hang is the transformers commit not the attn
# backend). 5.10.2 is the patch line of my verified v60 build (5.10.0.dev0).
# uv.lock keeps the exact 5.8.0.dev0 commit for the local box's fine-grained
# repro; the image uses a released wheel. Qwen3-4B needs no main-only feature.
"transformers==5.10.2",
"einops>=0.8",
"jaxtyping>=0.2",
"beartype>=0.18",
@@ -137,9 +138,6 @@ def _run_train(argv: list[str]) -> dict:
"HF_HOME": f"{CACHE}/hf",
"HF_HUB_DISABLE_PROGRESS_BARS": "1",
"PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
# The Dao flash-attn 2.8.3 wheel deadlocks generate() on Modal's H100/A100;
# sdpa is exact attention (identical outputs, only speed/mem differ).
"VGROUT_ATTN": "sdpa",
}
runs_before = set(Path(f"{CACHE}/out/runs").glob("*")) if Path(f"{CACHE}/out/runs").exists() else set()
+2 -5
View File
@@ -446,15 +446,12 @@ def main(cfg: Config) -> int:
# ── model + tokenizer ──
# CPU smoke: fp32 + sdpa (flash-attn2 is CUDA-only, CPU bf16 is patchy).
# GPU: bf16 + flash_attention_2, override via VGROUT_ATTN (e.g. sdpa on Modal,
# whose Dao flash-attn wheel hangs generate() on H100/A100 -- sdpa is exact
# attention, identical outputs, only speed/memory differ).
# GPU: bf16 + flash_attention_2.
cpu = device.type == "cpu"
attn_impl = "sdpa" if cpu else os.environ.get("VGROUT_ATTN", "flash_attention_2")
model = AutoModelForCausalLM.from_pretrained(
model_name,
dtype=torch.float32 if cpu else torch.bfloat16,
attn_implementation=attn_impl,
attn_implementation="sdpa" if cpu else "flash_attention_2",
).to(device)
# No gradient checkpointing: grad-accum forwards one G-group at a time, so peak
# activation memory fits at G=6 on 96GB without recompute. δS is a leaf inside