mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 17:30:41 +08:00
modal: flash_attention_2 + transformers==5.10.2, drop sdpa workaround
The generate() hang was caused by floating transformers @ main (a later commit), not by the attention backend. Confirmed: v60 ran on an earlier main with flash attention, and the smoke test on pinned 5.10.2 gets past the deadlock point. Revert the VGROUT_ATTN=sdpa override (app.py) and the env knob (train.py) back to hardcoded flash_attention_2, which fails loudly if the image's flash wheel is ever wrong rather than silently running 2-3x slower on sdpa. Pin transformers to the released 5.10.2 (patch line of v60's 5.10.0.dev0); uv.lock keeps the exact commit for the local box. Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
+7
-9
@@ -55,12 +55,13 @@ image = (
|
||||
index_url="https://download.pytorch.org/whl/cu126",
|
||||
)
|
||||
.pip_install(
|
||||
# transformers: released version, NOT floating `@ main`. The local box runs
|
||||
# 5.8.0.dev0 (uv.lock pins the exact commit for fine-grained repro); the image
|
||||
# uses the released >=5.8.0 wheel -- same code line, no floating main that let a
|
||||
# later commit hang generate(). Qwen3-4B needs no main-only feature (gated-
|
||||
# delta-net / qwen3_5 is for Qwen3.5, which we don't run). Bump if the box moves lines.
|
||||
"transformers>=5.8.0",
|
||||
# transformers: pinned released version, NOT floating `@ main` (a later main
|
||||
# commit is what hung generate() -- my v60 ran clean on an earlier main, the
|
||||
# other agent confirmed the hang is the transformers commit not the attn
|
||||
# backend). 5.10.2 is the patch line of my verified v60 build (5.10.0.dev0).
|
||||
# uv.lock keeps the exact 5.8.0.dev0 commit for the local box's fine-grained
|
||||
# repro; the image uses a released wheel. Qwen3-4B needs no main-only feature.
|
||||
"transformers==5.10.2",
|
||||
"einops>=0.8",
|
||||
"jaxtyping>=0.2",
|
||||
"beartype>=0.18",
|
||||
@@ -137,9 +138,6 @@ def _run_train(argv: list[str]) -> dict:
|
||||
"HF_HOME": f"{CACHE}/hf",
|
||||
"HF_HUB_DISABLE_PROGRESS_BARS": "1",
|
||||
"PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
|
||||
# The Dao flash-attn 2.8.3 wheel deadlocks generate() on Modal's H100/A100;
|
||||
# sdpa is exact attention (identical outputs, only speed/mem differ).
|
||||
"VGROUT_ATTN": "sdpa",
|
||||
}
|
||||
runs_before = set(Path(f"{CACHE}/out/runs").glob("*")) if Path(f"{CACHE}/out/runs").exists() else set()
|
||||
|
||||
|
||||
+2
-5
@@ -446,15 +446,12 @@ def main(cfg: Config) -> int:
|
||||
|
||||
# ── model + tokenizer ──
|
||||
# CPU smoke: fp32 + sdpa (flash-attn2 is CUDA-only, CPU bf16 is patchy).
|
||||
# GPU: bf16 + flash_attention_2, override via VGROUT_ATTN (e.g. sdpa on Modal,
|
||||
# whose Dao flash-attn wheel hangs generate() on H100/A100 -- sdpa is exact
|
||||
# attention, identical outputs, only speed/memory differ).
|
||||
# GPU: bf16 + flash_attention_2.
|
||||
cpu = device.type == "cpu"
|
||||
attn_impl = "sdpa" if cpu else os.environ.get("VGROUT_ATTN", "flash_attention_2")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
dtype=torch.float32 if cpu else torch.bfloat16,
|
||||
attn_implementation=attn_impl,
|
||||
attn_implementation="sdpa" if cpu else "flash_attention_2",
|
||||
).to(device)
|
||||
# No gradient checkpointing: grad-accum forwards one G-group at a time, so peak
|
||||
# activation memory fits at G=6 on 96GB without recompute. δS is a leaf inside
|
||||
|
||||
Reference in New Issue
Block a user