Files
autoresearch_template/train.py
T
wassname f55c18ac6e wip
2026-04-04 23:57:11 +08:00

117 lines
2.9 KiB
Python

"""
Model + training. NOT frozen -- agents modify freely in worktrees.
eval.py is frozen (anti-p-hacking). This file is not.
Everything agents change lives here: architecture, optimizer, loss, data pipeline.
"""
import os
from dataclasses import dataclass
import torch
import torch.nn as nn
import tyro
import wandb
from einops import rearrange # noqa: F401 -- available for use
from jaxtyping import Float
from loguru import logger
from torch import Tensor
if os.environ.get("BEARTYPE"):
    # Runtime shape/type checking is opt-in. Any non-empty BEARTYPE value
    # enables it -- including "0", because only the string's truthiness
    # is tested.
    from beartype import beartype as typechecker
    from jaxtyping import jaxtyped
else:
    # Checking disabled: provide no-op stand-ins so that
    # ``@jaxtyped(typechecker=typechecker)`` decorators still apply cleanly
    # without importing beartype.

    def typechecker(f):
        """Return ``f`` unchanged (no-op replacement for ``beartype``)."""
        return f

    def jaxtyped(**_):
        """Ignore all keyword options and return an identity decorator."""
        return lambda f: f
# --- Config -------------------------------------------------------------------
@dataclass
class Config:
    """All hyperparameters. Edit freely. eval.py imports this."""

    # model
    d_model: int = 256
    n_layers: int = 4
    # {FILL_IN}: add architecture params

    # training
    lr: float = 3e-4  # learning rate handed to AdamW in train()
    batch_size: int = 32
    max_steps: int = 1000  # number of iterations of the training loop
    seed: int = 42  # passed to torch.manual_seed before model construction

    # data
    # {FILL_IN}: add data params

    # logging
    # Placeholder project name; train() forwards it to wandb.init(project=...).
    wandb_project: str = "{FILL_IN}"
# --- Model --------------------------------------------------------------------
class Model(nn.Module):
    """
    {FILL_IN}: replace with your architecture.
    Agents: this is the main thing you modify between experiments.
    """

    def __init__(self, cfg: Config):
        """Keep a reference to ``cfg``; no layers are defined yet."""
        super().__init__()
        self.cfg = cfg
        # {FILL_IN}: define layers

    @jaxtyped(typechecker=typechecker)
    def forward(self, x: Float[Tensor, "b s"]) -> Float[Tensor, "b s d"]:
        """Placeholder forward pass.

        The annotations declare a float input with two axes ("b s") and a
        float output with three axes ("b s d"); they are only enforced when
        the BEARTYPE environment variable enables runtime checking.

        Raises:
            NotImplementedError: always, until the template is filled in.
        """
        # {FILL_IN}: implement forward pass
        raise NotImplementedError("{FILL_IN}: implement forward()")
def build_model(cfg: Config) -> Model:
    """Construct a ``Model`` from ``cfg`` and log its parameter count.

    Torch's global RNG is seeded with ``cfg.seed`` first, so any parameter
    initialisation done in ``Model.__init__`` starts from the same state.

    Args:
        cfg: hyperparameters; ``cfg.seed`` is read here, the rest is passed
            through to ``Model``.

    Returns:
        The freshly constructed model.
    """
    torch.manual_seed(cfg.seed)
    model = Model(cfg)
    n_params = sum(p.numel() for p in model.parameters())
    logger.info(f"Model: {n_params:,} parameters")
    return model
# --- Training -----------------------------------------------------------------
def train(cfg: Config):
    """Run the (template) training loop for ``cfg.max_steps`` iterations.

    Seeds torch, opens a wandb run configured from ``cfg``, builds the model
    and an AdamW optimizer, iterates the loop (currently only logging every
    100th step), then closes the wandb run. The data loading, training step
    and checkpointing are ``{FILL_IN}`` placeholders.

    Args:
        cfg: hyperparameters; reads ``seed``, ``wandb_project``, ``lr`` and
            ``max_steps`` directly and logs every field to wandb.
    """
    torch.manual_seed(cfg.seed)
    wandb.init(
        project=cfg.wandb_project,
        config=vars(cfg),
        # WANDB_RUN_GROUP env var set by justfile sweep recipes
    )
    model = build_model(cfg)
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr)

    # {FILL_IN}: load training data
    # train_loader = ...

    model.train()
    for step in range(cfg.max_steps):
        # {FILL_IN}: training step
        # batch = next(iter(train_loader))
        # loss = model(batch)
        # optimizer.zero_grad(); loss.backward(); optimizer.step()
        # NOTE(review): `next(iter(train_loader))` builds a fresh iterator on
        # every step; if train_loader is a DataLoader, create the iterator
        # once outside the loop instead.
        # NOTE(review): forward() is annotated as returning a "b s d" tensor,
        # so `loss` must be reduced to a scalar before `.backward()`.
        if step % 100 == 0:
            logger.info(f"step={step}")
            # wandb.log({"loss": loss.item(), "step": step})

    # {FILL_IN}: save checkpoint
    # torch.save(model.state_dict(), f"outputs/{wandb.run.id}.pt")
    # NOTE(review): nothing here creates "outputs/" -- confirm it exists
    # before enabling the save.

    wandb.finish()
    logger.info("Training complete")
if __name__ == "__main__":
    # Script entry point: let tyro build a Config from the command line,
    # then hand it to the training loop.
    cfg = tyro.cli(Config)
    train(cfg)