mirror of
https://github.com/wassname/autoresearch_template.git
synced 2026-06-27 17:29:28 +08:00
112 lines
3.1 KiB
Python
112 lines
3.1 KiB
Python
"""
|
|
FROZEN: only edit in meta-mode (META_MODE=1).
|
|
|
|
Frozen to prevent p-hacking, seed hacking, and eval-set leakage.
|
|
All experiments must be compared using this exact eval logic.
|
|
|
|
Appends one row to results.tsv per run.
|
|
"""
|
|
|
|
# FROZEN: only edit in meta-mode (META_MODE=1)
|
|
|
|
import csv
|
|
import os
|
|
import subprocess
|
|
from dataclasses import asdict, dataclass
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
import tyro
|
|
from loguru import logger
|
|
|
|
# {FILL_IN}: import your model
|
|
# from train import Config, build_model
|
|
|
|
|
|
# FROZEN: never change. Seed re-applied before data loading and before evaluation.
EVAL_SEED: int = 42
# FROZEN: never change. Dataset split the eval data is meant to come from.
EVAL_SPLIT: str = "val"
# One TSV row per run is appended here; anchored next to this script, not the cwd.
RESULTS_FILE: Path = Path(__file__).parent.joinpath("results.tsv")
|
|
|
|
|
|
@dataclass()
class EvalConfig:
    """Evaluation-time options; tyro exposes each field as a command-line flag of main()."""

    checkpoint: str = ""  # checkpoint path to load; empty string = evaluate the untrained baseline
    # {FILL_IN}: add any eval-time config here (e.g. batch_size, max_seq_len)
    batch_size: int = 32  # intended batch size for the (FILL_IN) eval loop
|
|
|
|
|
|
def get_git_info() -> dict:
    """Return git provenance for the checkout that contains this script.

    Git runs in this file's directory rather than the process's cwd. That is the
    same anchor RESULTS_FILE uses, so the recorded commit always describes the
    eval code that actually ran and the checkout whose results.tsv receives the
    row, even when the script is launched from another directory or repository.

    Returns:
        dict with string values:
            "commit":   abbreviated hash of HEAD.
            "branch":   current branch name (git prints "HEAD" when detached).
            "worktree": name of the checkout's top-level directory.

    Raises:
        subprocess.CalledProcessError: if a git command fails, e.g. this file
            is not inside a git repository.
        FileNotFoundError: if the ``git`` executable cannot be found.
    """
    repo_dir = Path(__file__).parent

    def run(cmd):
        return subprocess.check_output(cmd, cwd=repo_dir, encoding="utf-8").strip()

    # --show-toplevel names the checkout root even when this file sits in a
    # subdirectory; Path.cwd() depended on where the command was launched from.
    toplevel = run(["git", "rev-parse", "--show-toplevel"])
    return {
        "commit": run(["git", "rev-parse", "--short", "HEAD"]),
        "branch": run(["git", "rev-parse", "--abbrev-ref", "HEAD"]),
        "worktree": Path(toplevel).name,
    }
|
|
|
|
|
|
def load_eval_data():
    """Load fixed evaluation data. FROZEN: do not change split or preprocessing.

    Before loading, every RNG the loading code might draw from is seeded with
    EVAL_SEED: Python ``random``, numpy's legacy global RNG (if numpy is
    installed) and torch. Seeding torch alone would leave any ``random``- or
    numpy-based shuffling/subsampling unpinned, so the eval set could differ
    between runs.

    Raises:
        NotImplementedError: until the {FILL_IN} below is implemented.
    """
    import random

    random.seed(EVAL_SEED)
    try:
        import numpy as np
    except ImportError:  # numpy is optional; nothing to seed without it
        pass
    else:
        np.random.seed(EVAL_SEED)
    torch.manual_seed(EVAL_SEED)
    # {FILL_IN}: load your dataset
    # e.g. dataset = load_dataset("openwebtext", split=EVAL_SPLIT)
    raise NotImplementedError("{FILL_IN}: implement load_eval_data()")
|
|
|
|
|
|
def evaluate(model, data) -> dict:
    """Run evaluation. FROZEN: do not change metric computation.

    Switches ``model`` to eval mode and reseeds Python ``random``, numpy's
    legacy global RNG (if installed) and torch with EVAL_SEED, so any
    stochastic step in the metric is repeatable. It then computes the metrics
    with autograd disabled.

    Args:
        model: the model to score; must provide ``.eval()``.
        data: whatever load_eval_data() returned.

    Returns:
        Mapping of metric name -> value; main() merges these keys into the
        results.tsv row.

    Raises:
        NotImplementedError: until the {FILL_IN} below is implemented.
    """
    import random

    model.eval()
    random.seed(EVAL_SEED)
    try:
        import numpy as np
    except ImportError:  # numpy is optional; nothing to seed without it
        pass
    else:
        np.random.seed(EVAL_SEED)
    torch.manual_seed(EVAL_SEED)
    # Scoring needs no gradients; no_grad skips building the autograd graph,
    # which would otherwise keep intermediate tensors alive for a backward pass.
    # (If a metric genuinely needs gradients, compute it outside this context.)
    with torch.no_grad():
        # {FILL_IN}: compute your metric(s)
        # e.g. return {"val_bpb": compute_bpb(model, data), "val_loss": compute_loss(model, data)}
        raise NotImplementedError("{FILL_IN}: implement evaluate()")
|
|
|
|
|
|
def append_results(row: dict):
    """Append one row to results.tsv. Creates file with header if missing.

    Values are matched to columns by name, never by position:

    * no header yet (file missing, zero bytes, or only blank lines): write
      ``row``'s keys as the header, then the row;
    * every key of ``row`` is already a column: append under the existing
      header (columns the row lacks are left empty);
    * ``row`` introduces new columns (e.g. a metric or tracked config field
      added since earlier runs): rewrite the file with the existing columns
      followed by the new ones, leaving earlier rows empty in the new columns.

    Building the header from ``row`` alone would append under a stale header
    and silently shift values into the wrong columns.

    Args:
        row: column name -> value for this run.
    """
    # The header is the first non-blank line, if the file has any content.
    header = []
    if RESULTS_FILE.exists() and RESULTS_FILE.stat().st_size > 0:
        with open(RESULTS_FILE, newline="", encoding="utf-8") as f:
            header = next((r for r in csv.reader(f, delimiter="\t") if r), [])

    known = set(header)
    new_columns = [key for key in row if key not in known]

    if header and new_columns:
        logger.warning(
            f"{RESULTS_FILE.name}: adding new column(s) {new_columns}; "
            "earlier rows get empty cells for them"
        )
        with open(RESULTS_FILE, newline="", encoding="utf-8") as f:
            old_rows = [r for r in csv.reader(f, delimiter="\t") if r][1:]
        columns = header + new_columns
        tmp = RESULTS_FILE.with_name(RESULTS_FILE.name + ".tmp")
        with open(tmp, "w", newline="", encoding="utf-8") as f:
            writer = csv.writer(f, delimiter="\t")
            writer.writerow(columns)
            # Pad short rows to the new width; longer rows are kept verbatim.
            writer.writerows(r + [""] * (len(columns) - len(r)) for r in old_rows)
            writer.writerow([row.get(key, "") for key in columns])
        # Write-then-replace: an interrupted rewrite leaves the original file intact.
        tmp.replace(RESULTS_FILE)
    else:
        columns = header or list(row.keys())
        with open(RESULTS_FILE, "a", newline="", encoding="utf-8") as f:
            writer = csv.DictWriter(f, fieldnames=columns, delimiter="\t")
            if not header:
                writer.writeheader()
            writer.writerow(row)

    logger.info(f"Results appended to {RESULTS_FILE}")
|
|
|
|
|
|
def main(cfg: EvalConfig):
    """Evaluate one model and append a provenance-stamped row to results.tsv.

    Pipeline: log the config -> capture git provenance -> load the frozen eval
    data -> build/load the model ({FILL_IN}) -> compute metrics -> append one
    row (timestamp, git info, metrics) via append_results().

    Args:
        cfg: eval-time options; tyro fills these in from the command line.
    """
    logger.info(f"Eval config: {cfg}")
    provenance = get_git_info()
    eval_data = load_eval_data()

    # {FILL_IN}: build and load your model
    # from train import Config, build_model
    # model_cfg = Config()
    # model = build_model(model_cfg)
    # if cfg.checkpoint:
    #     model.load_state_dict(torch.load(cfg.checkpoint))
    raise NotImplementedError("{FILL_IN}: build and load model")

    metrics = evaluate(model, eval_data)

    stamp = datetime.now().isoformat(timespec="seconds")
    git_columns = {name: provenance[name] for name in ("commit", "branch", "worktree")}
    row = {
        "date": stamp,
        **git_columns,
        **metrics,
        # {FILL_IN}: add config fields you want tracked (e.g. model_cfg.n_layers)
    }

    logger.info(f"Metrics: {metrics}")
    append_results(row)
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: tyro builds a CLI from main()'s signature (the
    # EvalConfig fields become command-line flags), parses the command line,
    # and calls main(cfg) with the result.
    tyro.cli(main)
|