mirror of
https://github.com/wassname/autoresearch_template.git
synced 2026-06-27 16:14:27 +08:00
52 lines
1.3 KiB
Python
52 lines
1.3 KiB
Python
"""
|
|
Training loop. NOT frozen -- agents modify freely in worktrees.
|
|
|
|
Convention: eval.py runs the frozen evaluation. This file handles training.
|
|
Keep them separate so eval is never accidentally changed during experimentation.
|
|
"""
|
|
|
|
import torch
|
|
import tyro
|
|
import wandb
|
|
from loguru import logger
|
|
|
|
from model import Config, build_model
|
|
|
|
|
|
def train(cfg: Config):
    """Run the training loop for ``cfg`` inside a W&B run.

    Seeds torch's RNG with ``cfg.seed``, opens a W&B run whose config is
    ``vars(cfg)``, builds the model with ``build_model(cfg)`` and an AdamW
    optimizer at learning rate ``cfg.lr``, then iterates ``cfg.max_steps``
    steps, logging progress every 100 steps. Data loading, the per-step
    update and checkpointing are ``{FILL_IN}`` placeholders that each
    experiment supplies.

    The W&B run is used as a context manager rather than paired
    ``wandb.init()`` / ``wandb.finish()`` calls, so the run is closed even
    when the loop raises. It is not left open, and W&B records it as failed
    rather than finished.

    Args:
        cfg: Experiment configuration. Fields read directly here are
            ``seed``, ``lr`` and ``max_steps``; ``build_model`` may read
            others.
    """
    torch.manual_seed(cfg.seed)

    with wandb.init(
        project="{FILL_IN}",  # replace with your W&B project name
        config=vars(cfg),
        # group is set by justfile sweep recipes via WANDB_RUN_GROUP env var
    ) as run:
        model = build_model(cfg)
        optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr)

        # {FILL_IN}: load training data
        # train_loader = ...
        #
        # Build the batch iterator ONCE, outside the loop. Calling
        # `next(iter(train_loader))` inside the loop creates a fresh iterator
        # every step: DataLoader workers restart each time, and without
        # shuffling it yields the same first batch forever. Cycle epochs
        # instead:
        # def _forever(loader):
        #     while True:
        #         yield from loader
        # train_iter = _forever(train_loader)

        model.train()
        for step in range(cfg.max_steps):
            # {FILL_IN}: training step
            # batch = next(train_iter)
            # loss = model(batch)
            # optimizer.zero_grad(); loss.backward(); optimizer.step()

            if step % 100 == 0:
                logger.info(f"step={step}")
                # run.log({"loss": loss.item(), "step": step})

        # {FILL_IN}: save checkpoint. torch.save does not create missing
        # parent directories, so make sure "outputs/" exists first.
        # os.makedirs("outputs", exist_ok=True)
        # torch.save(model.state_dict(), f"outputs/{run.id}.pt")

    logger.info("Training complete")
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: tyro.cli builds a Config from the command-line
    # arguments, and that Config is handed straight to the training loop.
    train(tyro.cli(Config))
|