Files
weight-steering/evals/smoke.py
T
wassname 7527688a40 phase 0-2: HF+PEFT pipeline, smoke, subspace alignment
Rip Axolotl/vLLM, switch to HF+PEFT functional pipeline.
Add LoRA/DoRA/PiSSA/DeLoRA train, delta-W diff, weight_steer hook,
sycophancy logratio eval, and SVD top-k + weak-readout alignment.
Smoke runs end-to-end on tiny-random qwen3 with BEARTYPE=1.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-04-25 20:14:07 +08:00

57 lines
1.7 KiB
Python

"""Smoke test: run the full pipeline on a tiny random model. CPU-feasible, ~1 min.

Set BEARTYPE=1 to enable jaxtyping runtime shape/dtype checks via the
jaxtyping import hook (the autochars2 pattern); this catches dimension errors early.
Pipeline exercised: data gen -> train pos -> train neg -> diff -> alpha sweep eval.
"""
from __future__ import annotations
import os
import sys
from pathlib import Path
# Install jaxtyping+beartype import hook BEFORE importing ws.* — this makes
# every Float[Tensor, "..."] annotation in ws/* a runtime check.
def env_flag_enabled(name: str) -> bool:
    """Return True if environment variable `name` is set to a truthy value.

    Unset, empty, "0", "false", "no" and "off" (case-insensitive, surrounding
    whitespace ignored) all count as disabled, so e.g. `BEARTYPE=0` turns the
    checks off instead of silently enabling them.
    """
    value = os.environ.get(name, "").strip().lower()
    return value not in ("", "0", "false", "no", "off")


if env_flag_enabled("BEARTYPE"):
    from jaxtyping import install_import_hook

    # The hook is active once install_import_hook returns; keep the returned
    # handle referenced for the process lifetime rather than discarding it.
    _hook = install_import_hook("ws", "beartype.beartype")
    print("[smoke] BEARTYPE=1: jaxtyping runtime checks ENABLED for ws.*", flush=True)
import tyro
from dataclasses import dataclass
from ws.replicate import Cfg, main as replicate_main
@dataclass
class SmokeCfg:
    """Command-line config for the smoke run (parsed with tyro in the __main__ guard).

    Defaults keep the run tiny so it stays CPU-feasible.
    """

    model: str = "katuni4ka/tiny-random-qwen3"  # or any tiny-random LM
    n_pairs: int = 4  # passed through to Cfg.n_pairs
    max_steps: int = 2  # passed through to Cfg.max_steps
    out: Path = Path("out/smoke")  # passed through to Cfg.out
    adapter: str = "lora"  # adapter name passed through to Cfg.adapter
    rank: int = 4  # adapter rank passed to Cfg.rank; tiny model, tiny rank
    # Passed to Cfg.coeffs; presumably the alpha values of the sweep eval — confirm in ws.replicate.
    coeffs: tuple[float, ...] = (-1.0, 0.0, 1.0)


def main(cfg: SmokeCfg) -> None:
    """Run ws.replicate's pipeline once for the "sycophancy" behavior with tiny settings.

    Prints "[smoke] OK" only if replicate_main returns without raising.
    """
    print(f"[smoke] model={cfg.model} adapter={cfg.adapter} n_pairs={cfg.n_pairs} max_steps={cfg.max_steps}")
    rcfg = Cfg(
        model=cfg.model,
        behavior="sycophancy",
        adapter=cfg.adapter,
        n_pairs=cfg.n_pairs,
        max_steps=cfg.max_steps,
        out=cfg.out,
        # Every small-run knob is passed explicitly in this call, so Cfg's own
        # smoke flag is left off (NOTE(review): assumes smoke=True would apply a
        # separate preset — confirm in ws.replicate).
        smoke=False,
        coeffs=tuple(cfg.coeffs),
        rank=cfg.rank,
    )
    replicate_main(rcfg)
    print("[smoke] OK", flush=True)
if __name__ == "__main__":
    # tyro builds the CLI flags from SmokeCfg's fields and returns a populated instance.
    parsed_cfg = tyro.cli(SmokeCfg)
    main(parsed_cfg)