Files
wassname 77b296cc75 write up
2026-05-08 11:25:10 +08:00

105 lines
3.3 KiB
Python

from __future__ import annotations
import os
import subprocess
import sys
from dataclasses import dataclass
from pathlib import Path
import tyro
from loguru import logger
@dataclass
class Args:
run_dir: str
threshold: float = 0.95
out_name: str = "figs_auto"
model_contains: str = ""
line_alpha: float | None = None # forwarded to spaghetti+aggregate
roll: int = 65
line_lw: float = 0.5
quantile_lines: bool = False
def _ensure_single_run_root(run_dir: Path) -> Path:
root = run_dir.parent / f"_{run_dir.name}_single"
root.mkdir(parents=True, exist_ok=True)
link = root / run_dir.name
if link.is_symlink() or link.exists():
try:
if link.resolve() == run_dir.resolve():
return root
except Exception:
pass
if link.is_symlink() or link.is_file():
link.unlink()
else:
raise RuntimeError(f"refusing to replace non-file staging path: {link}")
os.symlink(run_dir.resolve(), link)
return root
def _run(cmd: list[str], cwd: Path) -> None:
logger.info("$ " + " ".join(cmd))
subprocess.run(cmd, cwd=cwd, check=True)
def main(a: Args) -> None:
run_dir = Path(a.run_dir).resolve()
repo_root = Path(__file__).resolve().parents[1]
single_root = _ensure_single_run_root(run_dir)
out_dir = run_dir / a.out_name
out_dir.mkdir(parents=True, exist_ok=True)
calib = run_dir / "calib.json"
if not calib.exists():
raise SystemExit(f"missing {calib}")
model_filter = a.model_contains or run_dir.name.split("_", 1)[0]
common = [
sys.executable,
"scripts/survival.py",
"--runs-root", str(single_root),
"--out", str(out_dir / "survival"),
"--window", str(__import__("json").loads(calib.read_text())["window"]),
"--metric", "pmass_eval",
"--thresholds", str(a.threshold),
"--model-contains", model_filter,
]
_run(common, repo_root)
_run([
sys.executable,
"scripts/spaghetti_kl_alive.py",
"--runs-root", str(single_root),
"--out", str(out_dir / "spaghetti"),
"--window", str(__import__("json").loads(calib.read_text())["window"]),
"--threshold", str(a.threshold),
"--model-contains", model_filter,
"--roll", str(a.roll),
"--line-lw", str(a.line_lw),
] + (["--line-alpha", str(a.line_alpha)] if a.line_alpha is not None else []), repo_root)
_run([
sys.executable,
"scripts/aggregate.py",
"--runs-root", str(single_root),
"--out", str(out_dir / "aggregate"),
"--window", str(__import__("json").loads(calib.read_text())["window"]),
"--kl-only",
"--model-contains", model_filter,
"--roll", str(a.roll),
"--line-lw", str(a.line_lw),
] + (["--spaghetti", "--color-by-pmass"] if not a.quantile_lines else ["--quantile-lines"]) \
+ (["--line-alpha", str(a.line_alpha)] if a.line_alpha is not None else []), repo_root)
pngs = [
out_dir / "survival" / "survival_pmass_eval.png",
out_dir / "spaghetti" / "kl_alive_spaghetti.png",
out_dir / "aggregate" / "figure1_kl_only.png",
]
for png in pngs:
logger.info(f"PNG -> {png}")
if __name__ == "__main__":
main(tyro.cli(Args))