""" Model + training. NOT frozen -- agents modify freely in worktrees. eval.py is frozen (anti-p-hacking). This file is not. Everything agents change lives here: architecture, optimizer, loss, data pipeline. """ import os from dataclasses import dataclass import torch import torch.nn as nn import tyro import wandb from einops import rearrange # noqa: F401 -- available for use from jaxtyping import Float from loguru import logger from torch import Tensor if os.environ.get("BEARTYPE"): from beartype import beartype as typechecker from jaxtyping import jaxtyped else: def typechecker(f): return f def jaxtyped(**_): return lambda f: f # --- Config ------------------------------------------------------------------- @dataclass class Config: """All hyperparameters. Edit freely. eval.py imports this.""" # model d_model: int = 256 n_layers: int = 4 # {FILL_IN}: add architecture params # training lr: float = 3e-4 batch_size: int = 32 max_steps: int = 1000 seed: int = 42 # data # {FILL_IN}: add data params # logging wandb_project: str = "{FILL_IN}" # --- Model -------------------------------------------------------------------- class Model(nn.Module): """ {FILL_IN}: replace with your architecture. Agents: this is the main thing you modify between experiments. """ def __init__(self, cfg: Config): super().__init__() self.cfg = cfg # {FILL_IN}: define layers @jaxtyped(typechecker=typechecker) def forward(self, x: Float[Tensor, "b s"]) -> Float[Tensor, "b s d"]: # {FILL_IN}: implement forward pass raise NotImplementedError("{FILL_IN}: implement forward()") def build_model(cfg: Config) -> Model: torch.manual_seed(cfg.seed) model = Model(cfg) n_params = sum(p.numel() for p in model.parameters()) logger.info(f"Model: {n_params:,} parameters") return model # --- Training ----------------------------------------------------------------- def train(cfg: Config): torch.manual_seed(cfg.seed) wandb.init( project=cfg.wandb_project, config=vars(cfg), # WANDB_RUN_GROUP env var set by justfile sweep recipes ) model = build_model(cfg) optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr) # {FILL_IN}: load training data # train_loader = ... model.train() for step in range(cfg.max_steps): # {FILL_IN}: training step # batch = next(iter(train_loader)) # loss = model(batch) # optimizer.zero_grad(); loss.backward(); optimizer.step() if step % 100 == 0: logger.info(f"step={step}") # wandb.log({"loss": loss.item(), "step": step}) # {FILL_IN}: save checkpoint # torch.save(model.state_dict(), f"outputs/{wandb.run.id}.pt") wandb.finish() logger.info("Training complete") if __name__ == "__main__": cfg = tyro.cli(Config) train(cfg)