tidy, review

This commit is contained in:
wassname
2026-04-27 07:03:24 +08:00
parent a44fc039af
commit 74c374e741
22 changed files with 4425 additions and 7727 deletions
+60
View File
@@ -5,8 +5,11 @@ import hashlib
import json
import math
import re
import fcntl
import subprocess
import sys
import time
from datetime import datetime, timezone
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any, Literal
@@ -426,11 +429,61 @@ def print_final_report(row: dict[str, Any], result_path: Path) -> None:
print(tabulate([row], headers="keys", tablefmt="tsv", floatfmt=".4g"))
def current_git_commit() -> str:
    """Return the commit hash that ``HEAD`` points at, or ``"unknown"``.

    Invokes ``git rev-parse HEAD`` as a subprocess (no ``cwd`` is passed, so
    it runs in the process's current working directory) and returns its
    standard output with surrounding whitespace removed.

    Returns:
        The stripped output of ``git rev-parse HEAD``, or the literal string
        ``"unknown"`` when the ``git`` executable cannot be found or the
        command exits with a non-zero status.
    """
    command = ["git", "rev-parse", "HEAD"]
    try:
        completed = subprocess.run(
            command,
            stdout=subprocess.PIPE,
            text=True,
            check=True,
        )
    except FileNotFoundError:
        # The git executable is not installed / not on PATH.
        return "unknown"
    except subprocess.CalledProcessError:
        # git ran but failed (non-zero exit status).
        return "unknown"
    return completed.stdout.strip()
def append_results_row(
    args: BenchmarkConfig,
    result_path: Path,
    result: dict[str, Any],
    run_commit: str,
) -> tuple[Path, Path]:
    """Snapshot ``result`` to JSON and append a summary row to the results TSV.

    Everything is written under ``args.output_dir / "results"`` (created if
    missing):

    * ``<run_id>__<UTC timestamp>.json`` -- a JSON dump of ``result``.
    * ``benchmark_results.tsv`` -- one tab-separated row per call; the header
      line is written when the file does not exist yet or is empty.
    * ``benchmark_results.tsv.lock`` -- lock file; an exclusive ``flock`` is
      held on it while the TSV is being written so concurrent runs do not
      interleave their rows.

    Args:
        args: Run configuration. Reads ``output_dir``, ``variant``, ``model``,
            ``mode`` and ``steps``.
        result_path: Path of the run's "latest" result JSON; recorded in the
            row as ``latest_result_json`` (not read or written here).
        result: Result mapping. Must contain ``run_id``, ``valid_accuracy``,
            ``test_accuracy``, ``train_samples`` and ``wall_time_s``.
        run_commit: Commit identifier stored in the ``commit`` column.

    Returns:
        ``(tsv_path, snapshot_path)``.

    Raises:
        KeyError: If ``result`` lacks one of the required keys.
        OSError: If a directory or file cannot be created or written.
    """
    results_dir = args.output_dir / "results"
    results_dir.mkdir(parents=True, exist_ok=True)
    tsv_path = results_dir / "benchmark_results.tsv"
    lock_path = results_dir / "benchmark_results.tsv.lock"

    # Read the clock once so the row's timestamp and the snapshot file name
    # can never disagree across a second boundary.
    finished = datetime.now(timezone.utc)
    finished_at = finished.isoformat(timespec="seconds")
    finished_label = finished.strftime("%Y%m%dT%H%M%SZ")

    snapshot_path = results_dir / f"{result['run_id']}__{finished_label}.json"
    snapshot_path.write_text(json.dumps(result, indent=2), encoding="utf-8")

    row = {
        "time_utc": finished_at,
        "commit": run_commit,
        "method": args.variant,
        "model": args.model,
        "mode": args.mode,
        "valid_accuracy": result["valid_accuracy"],
        "test_accuracy": result["test_accuracy"],
        "steps": args.steps,
        "samples": result["train_samples"],
        "wall_time_s": result["wall_time_s"],
        "argv": " ".join(sys.argv),
        "result_json": str(snapshot_path),
        "latest_result_json": str(result_path),
    }

    # A raw tab or line break inside a cell (e.g. in a command-line argument)
    # would add columns or split the row, so flatten them to spaces.
    unsafe = str.maketrans({"\t": " ", "\r": " ", "\n": " "})
    header = "\t".join(row)
    values = "\t".join(str(value).translate(unsafe) for value in row.values())

    with lock_path.open("w", encoding="utf-8") as lock_handle:
        fcntl.flock(lock_handle.fileno(), fcntl.LOCK_EX)
        try:
            # Treat an existing-but-empty file like a missing one so the
            # table always starts with its header line.
            try:
                needs_header = tsv_path.stat().st_size == 0
            except FileNotFoundError:
                needs_header = True
            with tsv_path.open("a", encoding="utf-8") as handle:
                if needs_header:
                    handle.write(header + "\n")
                handle.write(values + "\n")
        finally:
            # Release explicitly even if the write failed.
            fcntl.flock(lock_handle.fileno(), fcntl.LOCK_UN)
    return tsv_path, snapshot_path
def run(args: BenchmarkConfig) -> dict[str, Any]:
if args.device == "cuda" and not torch.cuda.is_available():
raise RuntimeError("CUDA requested but unavailable; pass --device cpu for plumbing smoke only")
torch.manual_seed(args.seed)
dtype = getattr(torch, args.torch_dtype)
run_commit = current_git_commit()
run_id = f"{args.model.replace('/', '--')}__{args.variant}__s{args.steps}__seed{args.seed}"
out_dir = args.output_dir / run_id
out_dir.mkdir(parents=True, exist_ok=True)
@@ -501,6 +554,12 @@ def run(args: BenchmarkConfig) -> dict[str, Any]:
}
result_path = out_dir / "result.json"
result_path.write_text(json.dumps(result, indent=2), encoding="utf-8")
results_tsv_path, result_snapshot_path = append_results_row(args, result_path, result, run_commit)
result["results_tsv_path"] = str(results_tsv_path)
result["result_snapshot_path"] = str(result_snapshot_path)
result["commit"] = run_commit
result_path.write_text(json.dumps(result, indent=2), encoding="utf-8")
commit_prefix = run_commit[:12]
row = {
"run_id": run_id,
@@ -515,6 +574,7 @@ def run(args: BenchmarkConfig) -> dict[str, Any]:
"base_grad_leaks": train_metrics["base_grad_leaks"],
"valid_acc": valid_metrics["accuracy"],
"test_acc": test_metrics["accuracy"],
"commit": run_commit[:12],
"result": str(result_path),
}
if probe_metrics is not None: