Files
evil_MoE/modal/launch.py
T
wassname 270c4f5a27 misc
2026-06-11 11:07:28 +00:00

92 lines
4.3 KiB
Python

"""Fan out the n=3 v2-eval keynote set as parallel Modal containers.
15 runs = 5 arms x seeds {42,41,43}, argv mirrors the local pueue runs (same
preset/steps/eval cadence), so each Modal run == its local twin and the two
environments cross-replicate. All on the v2 train/test token-gap eval (the
mounted src/ carries the current committed code).
id 1 = preliminary validation: seed-42 vanilla (needs only the shared teacher
pool, no pairset/direction inputs). Run it first, then fan out after it succeeds:
.venv/bin/modal run modal/launch.py::fanout --only 1 # preliminary validation
.venv/bin/modal run modal/launch.py::fanout # all 15
Each container writes its run dir out/runs/<ts>_<tag>/ (incl. per_mode_deploy.json)
and its log to the Volume; this entrypoint pulls the full run dir back into
out/runs/ and the log into logs/ locally, mirroring the Volume layout.
"""
from __future__ import annotations
import json
import subprocess
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))
from app import app, train # noqa: E402 (same dir; registers the functions)
# Modal Volume that the remote train() containers write their run dirs and logs
# into; every `modal volume get` in this file reads from it.
VOL: str = "vgrout-cache"
# The `modal` CLI is not on PATH, so resolve it as a sibling of the interpreter
# that is running this entrypoint (.venv/bin/python -> .venv/bin/modal).
MODAL_BIN: str = f"{Path(sys.executable).with_name('modal')}"
def _volume_get(remote: str, target: Path) -> bool:
    """Download `remote` from the Modal Volume VOL to `target` via the modal CLI.

    Best-effort: a non-zero exit from `modal volume get` is reported as a warning
    instead of being raised, so one missing artifact does not stop collection of
    the other jobs. An OSError from launching MODAL_BIN itself (e.g. the binary
    does not exist) still propagates to the caller.

    Returns:
        True if the CLI exited with status 0, else False.
    """
    proc = subprocess.run(
        [MODAL_BIN, "volume", "get", "--force", VOL, remote, str(target)],
        check=False,
    )
    if proc.returncode != 0:
        print(f"[warn] modal volume get {VOL} {remote} -> {target} exited {proc.returncode}")
        return False
    return True


def _pull_dir(remote: str, local_parent: Path) -> bool:
    """Pull the Volume directory `remote` into `local_parent`.

    modal recreates the remote leaf dir UNDER the target, so pass the parent.
    Returns True if the download command succeeded.
    """
    local_parent.mkdir(parents=True, exist_ok=True)
    return _volume_get(remote, local_parent)


def _pull_file(remote: str, local: Path) -> bool:
    """Pull the single Volume file `remote` to the local path `local`.

    Parent directories of `local` are created first. Returns True if the
    download command succeeded.
    """
    local.parent.mkdir(parents=True, exist_ok=True)
    return _volume_get(remote, local)
# The keynote arms (A2 method+baseline, A3 ablation). Each arm's flags go after
# the seed; tag stems match the local pueue runs so Modal and local artifacts
# line up.
COMMON = "fast --eval-ablate-every=10 --steps=60"
ARMS: dict[str, str] = {
    "vanilla": "--intervention=none",
    "routeV_real": "--intervention=routeV --vhack-pairs-path=out/pairsets/prog_wide.json --vhack-refresh-every=5",
    "routeV_pertok_real": "--intervention=routeV --vhack-pairs-path=out/pairsets/prog_wide.json --vhack-refresh-every=5 --routeV-per-token",
    "routeV_randomV157": "--intervention=routeV --vhack-pairs-path=out/pairsets/prog_wide.json --vhack-refresh-every=5 --routeV-random-v-seed=157",
    # NOTE(review): unlike the other arms this points at a markdown file with a
    # `#null-vampire` fragment rather than a JSON pairset; presumably train.py
    # resolves the fragment to a named pair set. Confirm.
    "routeV_placebo_vampire": "--intervention=routeV --vhack-pairs-path=data/pairs/pair_diagnostics.md#null-vampire --vhack-refresh-every=5",
}
# Seed 42 first (and "vanilla" first in ARMS) so id 1 is the preliminary vanilla
# validation run.
SEEDS = [42, 41, 43]
# Job id -> train argv string. Ids are 1-based, seed-major then arm order, so
# ids 1-5 are seed 42, 6-10 seed 41, 11-15 seed 43. A comprehension keeps the
# loop scratch names out of the module namespace.
JOBS: dict[int, str] = {
    jid: f"{COMMON} --seed={seed} {flags} --out-tag=_{arm}_s{seed}"
    for jid, (seed, arm, flags) in enumerate(
        ((s, a, f) for s in SEEDS for a, f in ARMS.items()),
        start=1,
    )
}
def _parse_job_ids(only: str) -> list[int]:
    """Parse the `--only` CSV (e.g. "1,3,5") into JOBS ids; "" selects every job.

    Ids are validated against JOBS before anything is spawned, so a typo cannot
    leave already-spawned containers running with nobody collecting them.
    Duplicates are dropped (first occurrence wins) so no job is spawned twice,
    and empty fields (e.g. a trailing comma in "1,2,") are ignored.

    Raises:
        ValueError: a field is not an integer, or an id is not a key of JOBS.
    """
    if not only:
        return sorted(JOBS)
    # int() tolerates surrounding whitespace, so "1, 2" works too.
    ids = list(dict.fromkeys(int(tok) for tok in only.split(",") if tok.strip()))
    unknown = [jid for jid in ids if jid not in JOBS]
    if unknown:
        raise ValueError(f"unknown job id(s) {unknown}; valid ids: {sorted(JOBS)}")
    return ids


@app.local_entrypoint()
def fanout(only: str = ""):  # NOT `main`: app.py (imported above) already owns that entrypoint name
    """Spawn the selected JOBS on Modal concurrently, wait for each, pull artifacts back.

    Args:
        only: comma-separated job ids to run (e.g. "1" for the seed-42 vanilla
            validation run); empty runs every job in JOBS.

    Writes results/_summary.json next to this file (job id -> result dict). The
    summary is written even if collection is interrupted. A job is "ok" when its
    container returned. A failure while pulling or reporting its artifacts
    afterwards is recorded under "collect_error" rather than marking the run as
    failed.
    """
    ids = _parse_job_ids(only)
    print(f"[launch] spawning {len(ids)} jobs: {ids}")
    # spawn = non-blocking; all run concurrently (subject to your Modal limits).
    handles = {jid: train.spawn(JOBS[jid]) for jid in ids}
    # Mirror the Volume layout locally so downloaded runs sit where train.py would
    # have written them (out/runs/<stem>/, logs/<stem>.log).
    repo = Path(__file__).parent.parent
    runs_dir = repo / "out" / "runs"
    summary_path = Path(__file__).parent / "results" / "_summary.json"
    results: dict[int, dict] = {}
    try:
        for jid, h in handles.items():
            try:
                res = h.get()  # blocks until this container finishes
                results[jid] = {"ok": True, **res}
            except Exception as e:
                results[jid] = {"ok": False, "error": repr(e)}
                print(f"[FAIL] job {jid}: {e!r}")
                continue
            # The run itself finished. Record problems downloading or reporting it
            # separately so they don't relabel a good run as failed.
            # Expected result keys (run_dir, log, wall_s, files) come from
            # app.train; presumably run_dir/log are Volume-relative paths. Confirm.
            try:
                # Pull the FULL run dir (ckpts, rollouts, per_mode_deploy.json) + the log.
                _pull_dir(res["run_dir"], runs_dir)  # recreates <stem>/ under out/runs
                _pull_file(res["log"], repo / res["log"])
                print(f"[ok] job {jid}: {res['wall_s']/60:.1f} min -> {res['run_dir']} ({len(res['files'])} files)")
            except Exception as e:
                results[jid]["collect_error"] = repr(e)
                print(f"[WARN] job {jid}: finished on Modal but collecting its artifacts failed: {e!r}")
    finally:
        # Persist whatever was collected, even if interrupted (e.g. Ctrl-C) mid-wait.
        summary_path.parent.mkdir(exist_ok=True)
        summary_path.write_text(json.dumps(results, indent=2, default=str))
    n_ok = sum(r["ok"] for r in results.values())
    print(f"[launch] {n_ok}/{len(ids)} ok. summary: {summary_path}; run dirs pulled into {runs_dir}")