Files
autoresearch_template/train.py
T
wassname fc46d878cf init
2026-04-04 23:40:34 +08:00

52 lines
1.3 KiB
Python

"""
Training loop. NOT frozen -- agents modify freely in worktrees.
Convention: eval.py runs the frozen evaluation. This file handles training.
Keep them separate so eval is never accidentally changed during experimentation.
"""
import torch
import tyro
import wandb
from loguru import logger
from model import Config, build_model
def train(cfg: Config):
    """Train the model described by ``cfg`` and record the run in Weights & Biases.

    This is a template. Data loading, the optimisation step, metric logging and
    checkpointing are ``{FILL_IN}`` placeholders to complete per project. As
    written, the loop only logs the step counter every 100 steps.

    Args:
        cfg: Run configuration (parsed from the CLI by ``tyro`` when run as a
            script). Its ``seed``, ``lr`` and ``max_steps`` fields are read here.
            The whole object is passed to ``build_model`` and recorded as the
            W&B run config via ``vars(cfg)``.
    """
    torch.manual_seed(cfg.seed)
    # The run is used as a context manager so it is always finished, even if
    # the body raises: a clean exit is recorded with exit code 0, an exception
    # with exit code 1 (the exception still propagates). A trailing
    # ``wandb.finish()`` would be skipped on error, leaving the run open. That
    # matters when ``train`` is called more than once in one process.
    with wandb.init(
        project="{FILL_IN}",  # replace with your W&B project name
        config=vars(cfg),
        # group is set by justfile sweep recipes via WANDB_RUN_GROUP env var
    ):
        model = build_model(cfg)
        optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr)

        # {FILL_IN}: load training data
        # train_loader = ...
        # Build the batch iterator ONCE, outside the loop. Calling
        # ``next(iter(train_loader))`` every step creates a fresh iterator each
        # time. For a torch DataLoader, an unshuffled loader would then return
        # the same first batch forever, and with num_workers > 0 (without
        # persistent_workers) it would re-spawn its worker processes every step.
        # def _forever(loader):
        #     while True:  # restart the loader at each epoch boundary
        #         yield from loader
        # batches = _forever(train_loader)

        model.train()
        for step in range(cfg.max_steps):
            # {FILL_IN}: training step
            # batch = next(batches)
            # loss = model(batch)
            # optimizer.zero_grad(); loss.backward(); optimizer.step()
            if step % 100 == 0:
                logger.info(f"step={step}")
                # wandb.log({"loss": loss.item(), "step": step})

        # {FILL_IN}: save checkpoint (inside the ``with`` so wandb.run is live)
        # os.makedirs("outputs", exist_ok=True)  # torch.save does not create dirs
        # torch.save(model.state_dict(), f"outputs/{wandb.run.id}.pt")
    logger.info("Training complete")
if __name__ == "__main__":
    # Script entry point: tyro builds a Config instance from the command-line
    # arguments, which is handed straight to the training loop.
    train(tyro.cli(Config))