# Mirror of https://github.com/wassname/autoresearch_template.git
# Synced 2026-06-27 14:43:56 +08:00
# 117 lines, 2.9 KiB, Python
"""
|
|
Model + training. NOT frozen -- agents modify freely in worktrees.
|
|
|
|
eval.py is frozen (anti-p-hacking). This file is not.
|
|
Everything agents change lives here: architecture, optimizer, loss, data pipeline.
|
|
"""
|
|
|
|
import os
|
|
from dataclasses import dataclass
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import tyro
|
|
import wandb
|
|
from einops import rearrange # noqa: F401 -- available for use
|
|
from jaxtyping import Float
|
|
from loguru import logger
|
|
from torch import Tensor
|
|
|
|
# Optional runtime type/shape checking.
# When the BEARTYPE environment variable is set to a non-empty value, the real
# beartype and jaxtyping decorators are used. Otherwise both names are bound to
# pass-through stand-ins that return the decorated function unchanged, and
# beartype is never imported.
if os.environ.get("BEARTYPE"):
    from beartype import beartype as typechecker
    from jaxtyping import jaxtyped
else:

    def typechecker(f):
        """Pass-through stand-in for ``beartype``: return ``f`` unchanged."""
        return f

    def jaxtyped(fn=None, **_):
        """Pass-through stand-in for ``jaxtyping.jaxtyped``.

        Supports both spellings used with the real decorator:

        * ``@jaxtyped(typechecker=...)`` -- keyword arguments are ignored and
          an identity decorator is returned.
        * ``@jaxtyped`` (bare) -- the function itself is returned.
        """
        if fn is not None:
            return fn
        return lambda f: f
|
|
|
|
|
|
# --- Config -------------------------------------------------------------------
|
|
|
|
@dataclass
class Config:
    """All hyperparameters. Edit freely. eval.py imports this.

    Field order is part of the interface: it fixes the positional order of the
    generated ``__init__``. Append new fields within their section rather than
    reordering existing ones.
    """

    # model
    d_model: int = 256  # model width
    n_layers: int = 4  # number of layers
    # {FILL_IN}: add architecture params

    # training
    lr: float = 3e-4  # learning rate passed to the optimizer
    batch_size: int = 32
    max_steps: int = 1000  # number of iterations of the training loop
    seed: int = 42  # passed to torch.manual_seed
    
    # data
    # {FILL_IN}: add data params

    # logging
    wandb_project: str = "{FILL_IN}"  # W&B project name; replace the placeholder
|
|
|
|
|
|
# --- Model --------------------------------------------------------------------
|
|
|
|
class Model(nn.Module):
    """
    {FILL_IN}: replace with your architecture.

    Agents: this is the main thing you modify between experiments.
    """

    def __init__(self, cfg: Config) -> None:
        """Keep a reference to ``cfg``; layers are still to be defined.

        Args:
            cfg: Hyperparameters for the model; stored as ``self.cfg``.
        """
        super().__init__()
        self.cfg = cfg
        # {FILL_IN}: define layers

    @jaxtyped(typechecker=typechecker)
    def forward(self, x: Float[Tensor, "b s"]) -> Float[Tensor, "b s d"]:
        """Placeholder forward pass.

        The annotations declare a float tensor of shape ``(b, s)`` in and
        ``(b, s, d)`` out; update them together with the implementation.

        Raises:
            NotImplementedError: Always, until the template is filled in.
        """
        # {FILL_IN}: implement forward pass
        raise NotImplementedError("{FILL_IN}: implement forward()")
|
|
|
|
|
|
def build_model(cfg: Config) -> Model:
    """Construct a :class:`Model` from ``cfg`` and log its parameter count.

    Args:
        cfg: Hyperparameters; ``cfg.seed`` seeds torch's global RNG and the
            whole object is handed to ``Model``.

    Returns:
        The newly constructed model.
    """
    # Seed before construction so any random initialisation performed in
    # Model.__init__ is driven by cfg.seed.
    torch.manual_seed(cfg.seed)
    model = Model(cfg)

    n_params = sum(p.numel() for p in model.parameters())
    logger.info(f"Model: {n_params:,} parameters")
    return model
|
|
|
|
|
|
# --- Training -----------------------------------------------------------------
|
|
|
|
def train(cfg: Config) -> None:
    """Run one training job described by ``cfg``.

    Seeds torch, opens a W&B run, builds the model and an AdamW optimizer,
    iterates ``cfg.max_steps`` times (logging the step every 100 iterations),
    then closes the run. The data loading, training step and checkpointing
    are template placeholders marked ``{FILL_IN}``.

    Args:
        cfg: Hyperparameters; also recorded as the W&B run config.
    """
    torch.manual_seed(cfg.seed)

    # Use the run as a context manager so it is finished even when an
    # exception escapes the training body, instead of relying on reaching an
    # explicit wandb.finish() at the end.
    with wandb.init(
        project=cfg.wandb_project,
        config=vars(cfg),
        # WANDB_RUN_GROUP env var set by justfile sweep recipes
    ):
        model = build_model(cfg)
        optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr)

        # {FILL_IN}: load training data
        # train_loader = ...

        model.train()
        for step in range(cfg.max_steps):
            # {FILL_IN}: training step
            # NOTE: next(iter(train_loader)) builds a fresh iterator on every
            # step; create the iterator once before the loop if batches are
            # meant to advance through the loader.
            # batch = next(iter(train_loader))
            # loss = model(batch)
            # optimizer.zero_grad(); loss.backward(); optimizer.step()

            if step % 100 == 0:
                logger.info(f"step={step}")
                # wandb.log({"loss": loss.item(), "step": step})

        # {FILL_IN}: save checkpoint
        # torch.save(model.state_dict(), f"outputs/{wandb.run.id}.pt")

    logger.info("Training complete")
|
|
|
|
|
|
if __name__ == "__main__":
    # Parse command-line arguments into a Config via tyro, then train with it.
    cfg = tyro.cli(Config)
    train(cfg)
|