mirror of
https://github.com/wassname/isokl_steering_calibration.git
synced 2026-06-27 15:15:52 +08:00
105 lines
3.3 KiB
Python
105 lines
3.3 KiB
Python
from __future__ import annotations
|
|
|
|
import os
|
|
import subprocess
|
|
import sys
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
|
|
import tyro
|
|
from loguru import logger
|
|
|
|
|
|
@dataclass
|
|
class Args:
|
|
run_dir: str
|
|
threshold: float = 0.95
|
|
out_name: str = "figs_auto"
|
|
model_contains: str = ""
|
|
line_alpha: float | None = None # forwarded to spaghetti+aggregate
|
|
roll: int = 65
|
|
line_lw: float = 0.5
|
|
quantile_lines: bool = False
|
|
|
|
|
|
def _ensure_single_run_root(run_dir: Path) -> Path:
|
|
root = run_dir.parent / f"_{run_dir.name}_single"
|
|
root.mkdir(parents=True, exist_ok=True)
|
|
link = root / run_dir.name
|
|
if link.is_symlink() or link.exists():
|
|
try:
|
|
if link.resolve() == run_dir.resolve():
|
|
return root
|
|
except Exception:
|
|
pass
|
|
if link.is_symlink() or link.is_file():
|
|
link.unlink()
|
|
else:
|
|
raise RuntimeError(f"refusing to replace non-file staging path: {link}")
|
|
os.symlink(run_dir.resolve(), link)
|
|
return root
|
|
|
|
|
|
def _run(cmd: list[str], cwd: Path) -> None:
|
|
logger.info("$ " + " ".join(cmd))
|
|
subprocess.run(cmd, cwd=cwd, check=True)
|
|
|
|
|
|
def main(a: Args) -> None:
|
|
run_dir = Path(a.run_dir).resolve()
|
|
repo_root = Path(__file__).resolve().parents[1]
|
|
single_root = _ensure_single_run_root(run_dir)
|
|
out_dir = run_dir / a.out_name
|
|
out_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
calib = run_dir / "calib.json"
|
|
if not calib.exists():
|
|
raise SystemExit(f"missing {calib}")
|
|
|
|
model_filter = a.model_contains or run_dir.name.split("_", 1)[0]
|
|
common = [
|
|
sys.executable,
|
|
"scripts/survival.py",
|
|
"--runs-root", str(single_root),
|
|
"--out", str(out_dir / "survival"),
|
|
"--window", str(__import__("json").loads(calib.read_text())["window"]),
|
|
"--metric", "pmass_eval",
|
|
"--thresholds", str(a.threshold),
|
|
"--model-contains", model_filter,
|
|
]
|
|
_run(common, repo_root)
|
|
_run([
|
|
sys.executable,
|
|
"scripts/spaghetti_kl_alive.py",
|
|
"--runs-root", str(single_root),
|
|
"--out", str(out_dir / "spaghetti"),
|
|
"--window", str(__import__("json").loads(calib.read_text())["window"]),
|
|
"--threshold", str(a.threshold),
|
|
"--model-contains", model_filter,
|
|
"--roll", str(a.roll),
|
|
"--line-lw", str(a.line_lw),
|
|
] + (["--line-alpha", str(a.line_alpha)] if a.line_alpha is not None else []), repo_root)
|
|
_run([
|
|
sys.executable,
|
|
"scripts/aggregate.py",
|
|
"--runs-root", str(single_root),
|
|
"--out", str(out_dir / "aggregate"),
|
|
"--window", str(__import__("json").loads(calib.read_text())["window"]),
|
|
"--kl-only",
|
|
"--model-contains", model_filter,
|
|
"--roll", str(a.roll),
|
|
"--line-lw", str(a.line_lw),
|
|
] + (["--spaghetti", "--color-by-pmass"] if not a.quantile_lines else ["--quantile-lines"]) \
|
|
+ (["--line-alpha", str(a.line_alpha)] if a.line_alpha is not None else []), repo_root)
|
|
|
|
pngs = [
|
|
out_dir / "survival" / "survival_pmass_eval.png",
|
|
out_dir / "spaghetti" / "kl_alive_spaghetti.png",
|
|
out_dir / "aggregate" / "figure1_kl_only.png",
|
|
]
|
|
for png in pngs:
|
|
logger.info(f"PNG -> {png}")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main(tyro.cli(Args)) |