Files
autoresearch_template/model.py
T
wassname fc46d878cf init
2026-04-04 23:40:34 +08:00

75 lines
1.8 KiB
Python

"""
Model definition and config.
This file is NOT frozen -- agents modify it freely in worktrees.
Keep eval.py's evaluate() interface stable: it calls build_model(cfg) and model.forward(x).
"""
from dataclasses import dataclass
import torch
import torch.nn as nn
import tyro
from jaxtyping import Float
from loguru import logger
from torch import Tensor
# beartype checking enabled only when BEARTYPE=1 (smoke tests)
# More precisely: any non-empty value of the BEARTYPE environment variable
# enables it, because the value is only tested for truthiness.
import os

if os.environ.get("BEARTYPE"):
    # Real runtime checkers; only imported when checking is requested.
    from beartype import beartype as typechecker
    from jaxtyping import jaxtyped
else:
    # No-op stand-ins so `@jaxtyped(typechecker=typechecker)` stays valid
    # when checking is off.
    def typechecker(f):
        """Return ``f`` unchanged (stand-in for ``beartype``)."""
        return f

    def jaxtyped(**_):
        """Return an identity decorator, ignoring all keyword arguments.

        Stand-in for ``jaxtyping.jaxtyped`` in its keyword-configured form,
        e.g. ``@jaxtyped(typechecker=...)``.
        """
        return lambda f: f
# Field order and defaults are part of the interface: the dataclass-generated
# __init__ accepts them positionally, and the __main__ entry point of this file
# feeds this class to tyro.cli.
# NOTE(review): tyro may surface the class docstring and field comments in
# --help output -- confirm before rewording them.
@dataclass
class Config:
    """Model and training hyperparameters. Edit freely."""

    # model
    d_model: int = 256
    n_layers: int = 4
    # {FILL_IN}: add your architecture params

    # training
    lr: float = 3e-4
    batch_size: int = 32
    max_steps: int = 1000
    seed: int = 42  # passed to torch.manual_seed by build_model()

    # data
    # {FILL_IN}: add data params
class Model(nn.Module):
    """
    {FILL_IN}: replace with your actual model.

    As shipped this is a placeholder: ``__init__`` only stores the config and
    ``forward`` always raises ``NotImplementedError``.
    """

    def __init__(self, cfg: Config):
        """Store ``cfg`` on the module; no layers are defined yet."""
        super().__init__()
        self.cfg = cfg
        # {FILL_IN}: define layers
        # e.g. self.embed = nn.Embedding(vocab_size, cfg.d_model)
        # NOTE(review): nn.Embedding expects integer indices, while forward()
        # below annotates its input as Float -- reconcile the two when
        # filling this in.

    # Shapes are only checked at runtime when the BEARTYPE environment variable
    # is set; otherwise this decorator is a no-op.
    @jaxtyped(typechecker=typechecker)
    def forward(self, x: Float[Tensor, "b s"]) -> Float[Tensor, "b s d"]:
        """Map a ``(b, s)`` float tensor to a ``(b, s, d)`` float tensor.

        The shapes are those declared by the annotations; the body is not
        implemented yet.

        Raises:
            NotImplementedError: always, until the template is filled in.
        """
        # {FILL_IN}: implement forward pass
        raise NotImplementedError("{FILL_IN}: implement forward()")
def build_model(cfg: Config) -> Model:
    """Seed torch, construct a ``Model`` from ``cfg`` and log its size.

    Args:
        cfg: Configuration; ``cfg.seed`` is passed to ``torch.manual_seed``
            and the whole object is handed to ``Model``.

    Returns:
        The newly constructed ``Model``.
    """
    # Seed before construction -- presumably so that any parameter
    # initialisation in Model.__init__ is reproducible for a given seed.
    torch.manual_seed(cfg.seed)
    model = Model(cfg)
    n_params = sum(p.numel() for p in model.parameters())
    logger.info(f"Model: {n_params:,} parameters")
    return model
if __name__ == "__main__":
    # Script entry point: parse a Config from the command line with tyro,
    # build the model (which logs its parameter count), then log the config.
    cfg = tyro.cli(Config)
    model = build_model(cfg)
    logger.info(f"Config: {cfg}")