""" Model definition and config. This file is NOT frozen -- agents modify it freely in worktrees. Keep eval.py's evaluate() interface stable: it calls build_model(cfg) and model.forward(x). """ from dataclasses import dataclass import torch import torch.nn as nn import tyro from jaxtyping import Float from loguru import logger from torch import Tensor # beartype checking enabled only when BEARTYPE=1 (smoke tests) import os if os.environ.get("BEARTYPE"): from beartype import beartype as typechecker from jaxtyping import jaxtyped else: def typechecker(f): return f def jaxtyped(**_): return lambda f: f @dataclass class Config: """Model and training hyperparameters. Edit freely.""" # model d_model: int = 256 n_layers: int = 4 # {FILL_IN}: add your architecture params # training lr: float = 3e-4 batch_size: int = 32 max_steps: int = 1000 seed: int = 42 # data # {FILL_IN}: add data params class Model(nn.Module): """ {FILL_IN}: replace with your actual model. """ def __init__(self, cfg: Config): super().__init__() self.cfg = cfg # {FILL_IN}: define layers # e.g. self.embed = nn.Embedding(vocab_size, cfg.d_model) @jaxtyped(typechecker=typechecker) def forward(self, x: Float[Tensor, "b s"]) -> Float[Tensor, "b s d"]: # {FILL_IN}: implement forward pass raise NotImplementedError("{FILL_IN}: implement forward()") def build_model(cfg: Config) -> Model: torch.manual_seed(cfg.seed) model = Model(cfg) n_params = sum(p.numel() for p in model.parameters()) logger.info(f"Model: {n_params:,} parameters") return model if __name__ == "__main__": cfg = tyro.cli(Config) model = build_model(cfg) logger.info(f"Config: {cfg}")