mirror of
https://github.com/wassname/lora-lite.git
synced 2026-06-27 16:15:50 +08:00
tidy, review
This commit is contained in:
@@ -5,8 +5,11 @@ import hashlib
|
||||
import json
|
||||
import math
|
||||
import re
|
||||
import fcntl
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
@@ -426,11 +429,61 @@ def print_final_report(row: dict[str, Any], result_path: Path) -> None:
|
||||
print(tabulate([row], headers="keys", tablefmt="tsv", floatfmt=".4g"))
|
||||
|
||||
|
||||
def current_git_commit() -> str:
|
||||
try:
|
||||
return subprocess.check_output(["git", "rev-parse", "HEAD"], text=True).strip()
|
||||
except (subprocess.CalledProcessError, FileNotFoundError):
|
||||
return "unknown"
|
||||
|
||||
|
||||
def append_results_row(
|
||||
args: BenchmarkConfig,
|
||||
result_path: Path,
|
||||
result: dict[str, Any],
|
||||
run_commit: str,
|
||||
) -> tuple[Path, Path]:
|
||||
results_dir = args.output_dir / "results"
|
||||
results_dir.mkdir(parents=True, exist_ok=True)
|
||||
tsv_path = results_dir / "benchmark_results.tsv"
|
||||
lock_path = results_dir / "benchmark_results.tsv.lock"
|
||||
finished_at = datetime.now(timezone.utc).isoformat(timespec="seconds")
|
||||
finished_label = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
|
||||
snapshot_path = results_dir / f"{result['run_id']}__{finished_label}.json"
|
||||
snapshot_path.write_text(json.dumps(result, indent=2), encoding="utf-8")
|
||||
row = {
|
||||
"time_utc": finished_at,
|
||||
"commit": run_commit,
|
||||
"method": args.variant,
|
||||
"model": args.model,
|
||||
"mode": args.mode,
|
||||
"valid_accuracy": result["valid_accuracy"],
|
||||
"test_accuracy": result["test_accuracy"],
|
||||
"steps": args.steps,
|
||||
"samples": result["train_samples"],
|
||||
"wall_time_s": result["wall_time_s"],
|
||||
"argv": " ".join(sys.argv),
|
||||
"result_json": str(snapshot_path),
|
||||
"latest_result_json": str(result_path),
|
||||
}
|
||||
header = "\t".join(row)
|
||||
values = "\t".join(str(value) for value in row.values())
|
||||
with lock_path.open("w", encoding="utf-8") as lock_handle:
|
||||
fcntl.flock(lock_handle.fileno(), fcntl.LOCK_EX)
|
||||
if not tsv_path.exists():
|
||||
tsv_path.write_text(header + "\n" + values + "\n", encoding="utf-8")
|
||||
else:
|
||||
with tsv_path.open("a", encoding="utf-8") as handle:
|
||||
handle.write(values + "\n")
|
||||
fcntl.flock(lock_handle.fileno(), fcntl.LOCK_UN)
|
||||
return tsv_path, snapshot_path
|
||||
|
||||
|
||||
def run(args: BenchmarkConfig) -> dict[str, Any]:
|
||||
if args.device == "cuda" and not torch.cuda.is_available():
|
||||
raise RuntimeError("CUDA requested but unavailable; pass --device cpu for plumbing smoke only")
|
||||
torch.manual_seed(args.seed)
|
||||
dtype = getattr(torch, args.torch_dtype)
|
||||
run_commit = current_git_commit()
|
||||
run_id = f"{args.model.replace('/', '--')}__{args.variant}__s{args.steps}__seed{args.seed}"
|
||||
out_dir = args.output_dir / run_id
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
@@ -501,6 +554,12 @@ def run(args: BenchmarkConfig) -> dict[str, Any]:
|
||||
}
|
||||
result_path = out_dir / "result.json"
|
||||
result_path.write_text(json.dumps(result, indent=2), encoding="utf-8")
|
||||
results_tsv_path, result_snapshot_path = append_results_row(args, result_path, result, run_commit)
|
||||
result["results_tsv_path"] = str(results_tsv_path)
|
||||
result["result_snapshot_path"] = str(result_snapshot_path)
|
||||
result["commit"] = run_commit
|
||||
result_path.write_text(json.dumps(result, indent=2), encoding="utf-8")
|
||||
commit_prefix = run_commit[:12]
|
||||
|
||||
row = {
|
||||
"run_id": run_id,
|
||||
@@ -515,6 +574,7 @@ def run(args: BenchmarkConfig) -> dict[str, Any]:
|
||||
"base_grad_leaks": train_metrics["base_grad_leaks"],
|
||||
"valid_acc": valid_metrics["accuracy"],
|
||||
"test_acc": test_metrics["accuracy"],
|
||||
"commit": run_commit[:12],
|
||||
"result": str(result_path),
|
||||
}
|
||||
if probe_metrics is not None:
|
||||
|
||||
Reference in New Issue
Block a user