diff --git a/modal/app.py b/modal/app.py index 9b8633f..423d6a6 100644 --- a/modal/app.py +++ b/modal/app.py @@ -55,12 +55,13 @@ image = ( index_url="https://download.pytorch.org/whl/cu126", ) .pip_install( - # transformers: released version, NOT floating `@ main`. The local box runs - # 5.8.0.dev0 (uv.lock pins the exact commit for fine-grained repro); the image - # uses the released >=5.8.0 wheel -- same code line, no floating main that let a - # later commit hang generate(). Qwen3-4B needs no main-only feature (gated- - # delta-net / qwen3_5 is for Qwen3.5, which we don't run). Bump if the box moves lines. - "transformers>=5.8.0", + # transformers: pinned released version, NOT floating `@ main` (a later main + # commit is what hung generate() -- the v60 build ran clean on an earlier main, + # and the hang was confirmed to come from the transformers commit, not the attn + # backend). 5.10.2 is on the same release line as the verified v60 build (5.10.0.dev0). + # uv.lock keeps the exact 5.8.0.dev0 commit for the local box's fine-grained + # repro; the image uses a released wheel. Qwen3-4B needs no main-only feature. + "transformers==5.10.2", "einops>=0.8", "jaxtyping>=0.2", "beartype>=0.18", @@ -137,9 +138,6 @@ def _run_train(argv: list[str]) -> dict: "HF_HOME": f"{CACHE}/hf", "HF_HUB_DISABLE_PROGRESS_BARS": "1", "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True", - # The Dao flash-attn 2.8.3 wheel deadlocks generate() on Modal's H100/A100; - # sdpa is exact attention (identical outputs, only speed/mem differ). - "VGROUT_ATTN": "sdpa", } runs_before = set(Path(f"{CACHE}/out/runs").glob("*")) if Path(f"{CACHE}/out/runs").exists() else set() diff --git a/src/vgrout/train.py b/src/vgrout/train.py index ffa0db2..1a1c0fe 100644 --- a/src/vgrout/train.py +++ b/src/vgrout/train.py @@ -446,15 +446,12 @@ def main(cfg: Config) -> int: # ── model + tokenizer ── # CPU smoke: fp32 + sdpa (flash-attn2 is CUDA-only, CPU bf16 is patchy). 
- # GPU: bf16 + flash_attention_2, override via VGROUT_ATTN (e.g. sdpa on Modal, - # whose Dao flash-attn wheel hangs generate() on H100/A100 -- sdpa is exact - # attention, identical outputs, only speed/memory differ). + # GPU: bf16 + flash_attention_2. cpu = device.type == "cpu" - attn_impl = "sdpa" if cpu else os.environ.get("VGROUT_ATTN", "flash_attention_2") model = AutoModelForCausalLM.from_pretrained( model_name, dtype=torch.float32 if cpu else torch.bfloat16, - attn_implementation=attn_impl, + attn_implementation="sdpa" if cpu else "flash_attention_2", ).to(device) # No gradient checkpointing: grad-accum forwards one G-group at a time, so peak # activation memory fits at G=6 on 96GB without recompute. δS is a leaf inside