Files
2026-05-01 18:58:08 +08:00

66 lines
1.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Smoke: full pipeline on a tiny random model. CPU-feasible. ~1 min.
Set BEARTYPE=1 to enable jaxtyping runtime shape/dtype checks via the
jaxtyping import hook (autochars2 pattern). Catches dim errors early.
Pipeline exercised: data gen -> train pos -> train neg -> diff -> alpha sweep eval.
"""
from __future__ import annotations
import os
import sys
from pathlib import Path
def env_flag(name: str) -> bool:
    """Return whether environment variable *name* is set to an "on" value.

    Unset, empty, "0", "false", "no" and "off" (case-insensitive, surrounding
    whitespace ignored) count as off; any other value counts as on. A plain
    truthiness test on ``os.environ.get(name)`` would wrongly treat
    ``BEARTYPE=0`` as enabled, because ``"0"`` is a non-empty string.
    """
    return os.environ.get(name, "").strip().lower() not in {"", "0", "false", "no", "off"}


# Install the jaxtyping+beartype import hook BEFORE importing ws.* so that
# every Float[Tensor, "..."] annotation in ws/* becomes a runtime check; the
# hook only instruments modules imported after it has been installed.
if env_flag("BEARTYPE"):
    from jaxtyping import install_import_hook

    # The hook is active as soon as install_import_hook() returns; keep the
    # returned handle so it could be uninstalled later if ever needed.
    _hook = install_import_hook("ws", "beartype.beartype")
    print(
        f"[smoke] BEARTYPE={os.environ['BEARTYPE']}: "
        "jaxtyping runtime checks ENABLED for ws.*",
        flush=True,
    )
import tyro
from dataclasses import dataclass
from ws.replicate import Cfg, main as replicate_main
@dataclass()
class SmokeCfg:
    """Command-line options for the smoke run (parsed with ``tyro.cli``)."""

    # Model identifier (presumably a Hugging Face Hub id); any tiny-random LM works.
    model: str = "katuni4ka/tiny-random-qwen3"
    # Step cap forwarded to the replicate config's ``max_steps``.
    max_steps: int = 2
    # Output root; generated data is placed under ``<out>/data``.
    out: Path = Path("out", "smoke")
    # Adapter kind forwarded to the replicate config's ``adapter``.
    adapter: str = "lora"
    # Behavior name forwarded to the replicate config's ``behavior``.
    behavior: str = "sycophancy"
def main(cfg: SmokeCfg) -> None:
    """Run the replicate pipeline once at toy scale.

    Copies the user-facing ``SmokeCfg`` knobs into a ``ws.replicate.Cfg``,
    pins every other size to a very small value so the run stays cheap, and
    calls ``ws.replicate.main``. Exceptions from the pipeline propagate, so
    "[smoke] OK" is printed only when the run completes.

    Args:
        cfg: Parsed smoke-test options.
    """
    # flush=True so this header reaches the terminal/log before any pipeline
    # output even when stdout is block-buffered (piped or redirected), just
    # like the final "[smoke] OK" line.
    print(
        f"[smoke] model={cfg.model} adapter={cfg.adapter} "
        f"behavior={cfg.behavior} max_steps={cfg.max_steps}",
        flush=True,
    )
    rcfg = Cfg(
        model=cfg.model,
        behavior=cfg.behavior,
        adapter=cfg.adapter,
        max_steps=cfg.max_steps,
        out=cfg.out,
        data_root=cfg.out / "data",
        coeffs=(-1.0, 0.0, 1.0),  # presumably the alpha-sweep values; verify in ws.replicate
        rank=4,  # tiny model, tiny rank
        n_topics=2,  # 2 topics x 1 persona x 2 samples = 4 pairs
        n_personas=1,
        n_samples=2,
        # Data-generation settings: small batches and short generations keep
        # the run CPU-feasible.
        data_batch_size=2,
        data_min_new_tokens=16,
        data_max_new_tokens=32,
        data_temperature=0.7,
        data_top_p=0.8,
        data_top_k=20,
        data_min_p=0.0,
    )
    replicate_main(rcfg)
    print("[smoke] OK", flush=True)
if __name__ == "__main__":
    # Build a SmokeCfg from command-line flags (dataclass defaults apply), then run.
    smoke_cfg = tyro.cli(SmokeCfg)
    main(smoke_cfg)