Files
autoresearch_template/eval.py
T
wassname f55c18ac6e wip
2026-04-04 23:57:11 +08:00

112 lines
3.1 KiB
Python

"""
FROZEN: only edit in meta-mode (META_MODE=1).
Frozen to prevent p-hacking, seed hacking, and eval-set leakage.
All experiments must be compared using this exact eval logic.
Appends one row to results.tsv per run.
"""
# FROZEN: only edit in meta-mode (META_MODE=1)
import csv
import os
import subprocess
from dataclasses import asdict, dataclass
from datetime import datetime
from pathlib import Path
import torch
import tyro
from loguru import logger
# {FILL_IN}: import your model
# from train import Config, build_model
# Frozen evaluation constants: every experiment is compared under these values.
EVAL_SEED = 42  # FROZEN: never change
EVAL_SPLIT = "val"  # FROZEN: never change

# results.tsv sits in the same directory as this file.
RESULTS_FILE = Path(__file__).with_name("results.tsv")
@dataclass
class EvalConfig:
    """Options for a single evaluation run; `main` receives an instance as `cfg`."""

    # path to checkpoint or empty for untrained baseline
    checkpoint: str = ""
    # {FILL_IN}: add any eval-time config here (e.g. batch_size, max_seq_len)
    batch_size: int = 32
def get_git_info() -> dict:
    """Collect provenance for the current run from git and the working directory.

    Returns:
        A dict with three string entries, in this order:
        - "commit": output of `git rev-parse --short HEAD`
        - "branch": output of `git rev-parse --abbrev-ref HEAD`
        - "worktree": name of the current working directory (taken from
          `Path.cwd()`, not from git itself)

    Raises:
        subprocess.CalledProcessError: if a git command exits non-zero
            (e.g. when run outside a repository).
        FileNotFoundError: if the `git` executable cannot be found.
    """

    def rev_parse(flag):
        # Both lookups share the same shape: `git rev-parse <flag> HEAD`.
        output = subprocess.check_output(["git", "rev-parse", flag, "HEAD"], text=True)
        return output.strip()

    info = {}
    info["commit"] = rev_parse("--short")
    info["branch"] = rev_parse("--abbrev-ref")
    info["worktree"] = Path.cwd().name
    return info
def load_eval_data():
    """Return the fixed evaluation dataset (template stub).

    FROZEN: do not change split or preprocessing.

    Seeds torch's global RNG with EVAL_SEED, then raises until the dataset
    loading below has been filled in.

    Raises:
        NotImplementedError: always, while the {FILL_IN} placeholder remains.
    """
    # Seed before any loading code runs.
    torch.manual_seed(EVAL_SEED)

    # {FILL_IN}: load your dataset
    # e.g. dataset = load_dataset("openwebtext", split=EVAL_SPLIT)
    placeholder = "{FILL_IN}: implement load_eval_data()"
    raise NotImplementedError(placeholder)
def evaluate(model, data) -> dict:
    """Compute evaluation metrics for `model` on `data` (template stub).

    FROZEN: do not change metric computation.

    Args:
        model: object exposing an `.eval()` method, which is called first.
            Presumably a torch module; confirm against the caller.
        data: evaluation data as returned by `load_eval_data()`.

    Returns:
        A dict of metrics once implemented; the example in the body suggests
        metric-name -> value, e.g. "val_bpb" and "val_loss".

    Raises:
        NotImplementedError: always, while the {FILL_IN} placeholder remains.
    """
    # Switch to eval mode, then reseed with the frozen eval seed.
    model.eval()
    torch.manual_seed(EVAL_SEED)

    # {FILL_IN}: compute your metric(s)
    # e.g. return {"val_bpb": compute_bpb(model, data), "val_loss": compute_loss(model, data)}
    placeholder = "{FILL_IN}: implement evaluate()"
    raise NotImplementedError(placeholder)
def append_results(row: dict):
    """Append one row to results.tsv, writing a header first when needed.

    A header is written when RESULTS_FILE is missing *or empty*, so a
    pre-created (touched) file no longer ends up without a header.

    When the file already has a header, the row is kept aligned with it:
    - same columns in a different order: the row is written in the existing
      header's column order;
    - different set of columns: a warning is logged and the row is appended
      in its own key order, without rewriting the file.

    Args:
        row: mapping of column name -> value for this run.
    """
    fieldnames = list(row.keys())

    # Read the header of an existing, non-empty file (None means "no header yet").
    existing_header = None
    if RESULTS_FILE.exists() and RESULTS_FILE.stat().st_size > 0:
        with open(RESULTS_FILE, newline="", encoding="utf-8") as f:
            existing_header = next(csv.reader(f, delimiter="\t"), None)

    if existing_header is not None and existing_header != fieldnames:
        same_columns = (
            len(existing_header) == len(fieldnames)
            and set(existing_header) == set(fieldnames)
        )
        if same_columns:
            # Pure reordering: follow the file's column order so values line up.
            fieldnames = existing_header
        else:
            logger.warning(
                f"Row columns {fieldnames} do not match existing header "
                f"{existing_header} in {RESULTS_FILE}; appended row will be misaligned"
            )

    # Explicit encoding keeps the file independent of the machine's locale.
    with open(RESULTS_FILE, "a", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames, delimiter="\t")
        if existing_header is None:
            writer.writeheader()
        writer.writerow(row)
    logger.info(f"Results appended to {RESULTS_FILE}")
def main(cfg: EvalConfig):
    """Run one evaluation and append its metrics as a row to results.tsv.

    Steps: log the config, collect git provenance, load the eval data, build
    the model (still a {FILL_IN} placeholder), evaluate, then append one row
    holding the date, git info and metrics.

    Args:
        cfg: evaluation options; `cfg.checkpoint` is the checkpoint path, or
            empty for the untrained baseline.

    Raises:
        NotImplementedError: always, until the model-building placeholder
            below is replaced (or earlier, from `load_eval_data()`).
    """
    logger.info(f"Eval config: {cfg}")
    git_info = get_git_info()
    eval_data = load_eval_data()

    # {FILL_IN}: build and load your model
    # from train import Config, build_model
    # model_cfg = Config()
    # model = build_model(model_cfg)
    # if cfg.checkpoint:
    # model.load_state_dict(torch.load(cfg.checkpoint))
    raise NotImplementedError("{FILL_IN}: build and load model")

    # Everything below is unreachable until the placeholder above defines `model`.
    metrics = evaluate(model, eval_data)

    timestamp = datetime.now().isoformat(timespec="seconds")
    provenance = {key: git_info[key] for key in ("commit", "branch", "worktree")}
    row = {
        "date": timestamp,
        **provenance,
        **metrics,
        # {FILL_IN}: add config fields you want tracked (e.g. model_cfg.n_layers)
    }
    logger.info(f"Metrics: {metrics}")
    append_results(row)
if __name__ == "__main__":
    # Script entry point: hand `main` to tyro, which drives it from the command line.
    tyro.cli(main)