mirror of
https://github.com/wassname/autoresearch_template.git
synced 2026-06-27 16:14:27 +08:00
75 lines
1.8 KiB
Python
75 lines
1.8 KiB
Python
"""
|
|
Model definition and config.
|
|
|
|
This file is NOT frozen -- agents modify it freely in worktrees.
|
|
Keep eval.py's evaluate() interface stable: it calls build_model(cfg) and model.forward(x).
|
|
"""
|
|
|
|
from dataclasses import dataclass
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import tyro
|
|
from jaxtyping import Float
|
|
from loguru import logger
|
|
from torch import Tensor
|
|
|
|
# Runtime shape/dtype checking (beartype via jaxtyping) is opt-in for smoke tests.
# It is enabled only when the BEARTYPE environment variable holds a truthy value
# such as "1". Unset, empty, "0", "false", "no" and "off" (case-insensitive)
# leave it disabled. In that case the decorators below are no-ops, so
# beartype/jaxtyping.jaxtyped are never imported.
import os

_BEARTYPE_OFF_VALUES = ("", "0", "false", "no", "off")

if os.environ.get("BEARTYPE", "").strip().lower() not in _BEARTYPE_OFF_VALUES:
    from beartype import beartype as typechecker
    from jaxtyping import jaxtyped
else:

    def typechecker(f):
        """No-op stand-in for ``beartype.beartype``: return ``f`` unchanged."""
        return f

    def jaxtyped(fn=None, **_):
        """No-op stand-in for ``jaxtyping.jaxtyped``.

        Works both as ``@jaxtyped(typechecker=...)`` (returns an identity
        decorator) and as bare ``@jaxtyped`` (returns ``fn`` unchanged).
        Keyword arguments are accepted and ignored.
        """
        if fn is not None:
            return fn
        return lambda f: f
|
|
|
|
|
|
@dataclass()
class Config:
    """Hyperparameters for the model and the training run. Edit freely.

    Every field has a default, so ``Config()`` is a complete baseline. When the
    module is run as a script, fields can be overridden from the command line
    (the entry point parses this class with ``tyro.cli``).
    """

    # ---- model ----
    d_model: int = 256
    n_layers: int = 4
    # {FILL_IN}: add your architecture params

    # ---- training ----
    lr: float = 0.0003
    batch_size: int = 32
    max_steps: int = 1_000
    seed: int = 42  # passed to torch.manual_seed by build_model

    # ---- data ----
    # {FILL_IN}: add data params
|
|
|
|
|
|
class Model(nn.Module):
    """Placeholder network -- {FILL_IN}: replace with your actual model.

    ``build_model(cfg)`` constructs it and the evaluation code calls
    ``forward(x)``, so keep that method's signature stable when filling it in.
    """

    def __init__(self, cfg: Config):
        nn.Module.__init__(self)
        # Keep the config so layers and forward() can read hyperparameters.
        self.cfg = cfg
        # {FILL_IN}: define layers here, e.g.
        #   self.embed = nn.Embedding(vocab_size, cfg.d_model)
        # NOTE(review): nn.Embedding expects integer indices, but forward()
        # annotates x as Float -- adjust the annotation if you feed token ids.

    @jaxtyped(typechecker=typechecker)
    def forward(self, x: Float[Tensor, "b s"]) -> Float[Tensor, "b s d"]:
        """Map a ``(b, s)`` float tensor to a ``(b, s, d)`` float tensor.

        The shape/dtype annotations are only enforced at runtime when BEARTYPE
        checking is switched on; otherwise the decorator is a no-op.

        Raises:
            NotImplementedError: always, until the template is filled in.
        """
        # {FILL_IN}: implement forward pass
        message = "{FILL_IN}: implement forward()"
        raise NotImplementedError(message)
|
|
|
|
|
|
def build_model(cfg: Config) -> Model:
    """Seed torch, construct ``Model(cfg)`` and log its parameter count.

    The RNG is seeded *before* construction so that parameter initialisation
    is reproducible for a given ``cfg.seed``.

    Args:
        cfg: run configuration; ``cfg.seed`` is passed to ``torch.manual_seed``.

    Returns:
        The newly constructed model.
    """
    torch.manual_seed(cfg.seed)
    net = Model(cfg)

    # Total element count over every parameter tensor the module exposes.
    total = 0
    for param in net.parameters():
        total += param.numel()

    logger.info(f"Model: {total:,} parameters")
    return net
|
|
|
|
|
|
if __name__ == "__main__":
    # Build a Config from command-line flags (tyro derives them from the
    # dataclass fields), then construct the model once to log its size.
    cfg = tyro.cli(Config)
    # Log the config first so it is still visible if model construction fails.
    logger.info(f"Config: {cfg}")
    model = build_model(cfg)
|