scaffold steer_heal: spec, repo infra, vendored deps

Setup per setup-repo conventions: uv + justfile + fast-dev-run on
wassname/qwen3-5lyr-tiny-random, package under src/steer_heal (config +
pipeline skeleton). Stages fail fast with NotImplementedError pointing at
the docs/vendor module to port from.

Design in spec.md: distil a steering-lite mean-diff teacher vector (iso-KL
dosed) into a conditioned LoRA, heal incoherency with a KL-rev-to-original
barrier, fold each round via w2schar gated bake, eval on tinymfv. Three
uncertainty gates (filter / heal / iterate) each with a UAT artifact.

Base model google/gemma-3-1b-it (RTX 3090, 24GB). Reference repos vendored
under docs/vendor (gitignored): steering-lite, isokl, tinymfv, w2schar-mini.
The lighter three are editable path deps; w2schar (py3.13 + flash-attn) is
reference-only, we copy its adapter/bake/plot modules.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
wassname
2026-06-04 09:49:31 +08:00
parent b98535066a
commit 940a3742c5
11 changed files with 3199 additions and 26 deletions
+6
View File
@@ -0,0 +1,6 @@
import os
# Opt-in runtime type checking: only when the BEARTYPE environment variable is
# set to a non-empty value is beartype imported and applied to this package.
# The import stays inside the branch so beartype is not needed otherwise.
if os.getenv("BEARTYPE"):
    from beartype.claw import beartype_this_package

    beartype_this_package()
+67
View File
@@ -0,0 +1,67 @@
from dataclasses import dataclass, replace
from typing import Literal
@dataclass
class RunConfig:
    """Settings for a single steer_heal run.

    Setting `fast_dev_run` makes `resolve` substitute the tiny-random preset.

    The trait defines the teacher direction from the paper: the mean difference
    between the trait system prompt and the neutral system prompt, taken at the
    assistant tag (see spec.md).
    """

    # ---- model ----
    model: str = "google/gemma-3-1b-it"
    fast_dev_model: str = "wassname/qwen3-5lyr-tiny-random"
    dtype: Literal["bfloat16", "float16", "float32"] = "bfloat16"

    # ---- trait / steering vector ----
    trait: str = (
        "You do not defer to authority and instead stick to principle no matter "
        "your involvement."
    )
    neutral: str = "You are a helpful assistant."
    # Steered layers, given as a (start, end) fraction of model depth.
    layer_range: tuple[float, float] = (0.4, 0.6)
    # iso-KL p95 dose, in nats.
    target_kl: float = 1.0
    # Generation strengths, expressed as multiples of c_star.
    alphas: tuple[float, ...] = (0.5, 1.0, 1.5, 2.0)

    # ---- generation + filter (U1) ----
    n_prompts: int = 64
    n_keep: int = 50
    gen_max_new_tokens: int = 256
    max_len: int = 1024
    # Completions whose perplexity under the original model exceeds this are dropped.
    ppl_tau: float = 50.0
    # Completions whose max n-gram repeat fraction exceeds this are dropped.
    rep_tau: float = 0.3

    # ---- heal (U2): one objective + divergence-to-ORIGINAL barrier ----
    reg: Literal["nll", "kl_fwd", "kl_rev", "wd"] = "kl_rev"
    # Barrier weight; doubles as weight_decay when reg == "wd".
    lam: float = 1.0
    # The barrier only engages once the divergence exceeds tau (nats).
    tau: float = 0.5
    lora_r: int = 8
    lora_alpha: float = 16.0
    epochs: int = 2
    lr: float = 1e-4

    # ---- loop (U3) ----
    n_rounds: int = 4

    seed: int = 42
    fast_dev_run: bool = False
# Field overrides applied by the fast-dev-run preset; every key is the name of
# a RunConfig field and the value replaces that field's default.
TINY = {
    "n_prompts": 4,
    "n_keep": 3,
    "gen_max_new_tokens": 32,
    "max_len": 128,
    "epochs": 1,
    "n_rounds": 1,
    "alphas": (1.0,),
}
def resolve(cfg: RunConfig) -> RunConfig:
    """Return the effective config, applying the fast-dev-run preset if requested.

    When ``cfg.fast_dev_run`` is falsy the input object is returned unchanged.
    Otherwise a copy is returned whose ``model`` is ``cfg.fast_dev_model`` and
    whose remaining overrides come from ``TINY``.
    """
    if not cfg.fast_dev_run:
        return cfg
    return replace(cfg, model=cfg.fast_dev_model, **TINY)
+105
View File
@@ -0,0 +1,105 @@
"""steer_heal pipeline entry point.
Loop (anchored to the round-0 original throughout, see spec.md):
teacher vector (steering-lite) -> iso-KL dose -> generate -> U1 filter
-> heal one round (SFT + KL-rev-to-original barrier) -> fold (gated bake)
-> tinymfv eval -> repeat.
Stages marked TODO are ported from docs/vendor/* as we implement them; this
file fails fast at the first unimplemented stage rather than stubbing fake
behaviour. `--fast-dev-run` runs the whole thing on the tiny-random model.
"""
import os
import sys
from datetime import datetime
from pathlib import Path
import torch
import tyro
from loguru import logger
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from steer_heal.config import RunConfig, resolve
# Directory two levels above the one containing this file (parents[0] is the
# file's own directory). Presumably the repository root given the
# src/steer_heal package layout -- confirm if the file is moved.
REPO = Path(__file__).resolve().parents[2]
def setup_logging() -> None:
    """Configure loguru with a tqdm-safe console sink and a verbose log file.

    The console sink shows INFO and above using single-character level icons.
    The file sink writes DEBUG and above to a timestamped file under
    ``REPO / "logs"``, whose path is logged once setup completes.
    """
    logger.remove()

    def _console_sink(message):
        # Route through tqdm.write; end="" because the message is passed as-is
        # (same as the original lambda).
        tqdm.write(message, end="")

    logger.add(
        _console_sink,
        colorize=True,
        format="<level>{level.icon}</level> {message}",
        level="INFO",
    )

    # Single-character icons for the levels used by the console format.
    level_icons = {"INFO": "I", "WARNING": "W", "ERROR": "E", "DEBUG": "D"}
    for level_name, icon in level_icons.items():
        logger.level(level_name, icon=icon)

    # Verbose copy on disk, one file per invocation.
    logs_root = REPO / "logs"
    logs_root.mkdir(exist_ok=True)
    log_path = logs_root / f"{datetime.now():%Y%m%dT%H%M%S}_verbose.log"
    logger.add(
        log_path,
        format="{time:HH:mm:ss} | {level: <7} | {name}:{function}:{line} - {message}",
        level="DEBUG",
    )
    logger.info(f"verbose log: {log_path}")
def load_model(model_id: str, dtype: torch.dtype):
    """Load a causal LM and its tokenizer for `model_id`.

    Args:
        model_id: identifier passed to both ``from_pretrained`` calls.
        dtype: torch dtype passed to the model loader as ``torch_dtype``.

    Returns:
        ``(model, tok)``: the model after ``.eval()`` and the tokenizer with
        left padding and a pad token set.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # Reuse EOS as the pad token when the tokenizer does not define one.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"

    # Attention backend is read from STEER_ATTN_IMPL and defaults to "eager";
    # set it to flash_attention_2 on real runs.
    attn_impl = os.environ.get("STEER_ATTN_IMPL", "eager")
    load_kwargs = {
        "device_map": "auto",
        "torch_dtype": dtype,
        "low_cpu_mem_usage": True,
        "attn_implementation": attn_impl,
    }
    lm = AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs)
    lm.eval()

    logger.info(f"loaded {model_id} (dtype={dtype}, attn={attn_impl})")
    return lm, tokenizer
# ── stages (ported from docs/vendor as we implement; fail fast until then) ──
def teacher_vec(model, tok, cfg: RunConfig):
    """Teacher steering-vector stage (not implemented yet).

    Planned port: steering-lite ``Vector.train`` with the trait system prompt
    as positive and the neutral system prompt as negative, at the assistant
    tag, followed by ``.calibrate(target_kl=cfg.target_kl)``. Reference code:
    docs/vendor/steering-lite and isokl.

    Raises:
        NotImplementedError: always, until the stage is ported.
    """
    message = "TODO: teacher_vec via steering-lite + iso-KL calibration"
    raise NotImplementedError(message)
def generate_and_filter(model, tok, v, orig, cfg: RunConfig):
    """Generation + U1 filter stage (not implemented yet).

    Planned behaviour: generate at ``alpha * c_star`` through the steering-lite
    hook, then keep only completions that are coherent and that enact rather
    than narrate the trait (the U1 gate).

    Raises:
        NotImplementedError: always, until the stage is ported.
    """
    message = "TODO: generate_and_filter (U1 gate)"
    raise NotImplementedError(message)
def heal(model, orig, comps, cfg: RunConfig):
    """Heal + fold stage (not implemented yet).

    Planned objective: SFT + ``lam * relu(div - tau)`` with ``div`` one of
    {nll, kl_fwd, kl_rev, wd}; the KL reference is ``orig`` (gates off).
    The adapter and gated bake are to be ported from
    docs/vendor/w2schar-mini/src/csm/ws.

    Raises:
        NotImplementedError: always, until the stage is ported.
    """
    message = "TODO: heal (U2 barrier) + fold via w2schar ws.bake"
    raise NotImplementedError(message)
def evaluate(model, cfg: RunConfig) -> dict:
    """Evaluation stage (not implemented yet).

    Planned metrics: tinymfv auth/care axes plus p_ans_any, json_is_valid and
    ppx_json.

    Raises:
        NotImplementedError: always, until the stage is ported.
    """
    message = "TODO: tinymfv eval + plotly map (port csm/plot.py _build_scatter)"
    raise NotImplementedError(message)
def steer_heal(model, tok, orig, cfg: RunConfig):
    """Run ``cfg.n_rounds`` rounds of the steer -> filter -> heal -> eval loop.

    Each round calls, in order, `teacher_vec`, `generate_and_filter`, `heal`
    and `evaluate`, logging the evaluation result. `orig` is passed through
    to the filter and heal stages unchanged.

    Returns:
        The `model` object that was passed in.
    """
    for round_idx in range(cfg.n_rounds):
        logger.info(f"── round {round_idx} ──")
        steer_vec = teacher_vec(model, tok, cfg)
        kept = generate_and_filter(model, tok, steer_vec, orig, cfg)
        heal(model, orig, kept, cfg)
        metrics = evaluate(model, cfg)
        logger.info(metrics)
    return model
def main(cfg: RunConfig) -> None:
    """Entry point: set up logging, resolve the config, load the model, run the loop.

    Args:
        cfg: the requested run configuration; it is passed through `resolve`
            before use.
    """
    setup_logging()
    run_cfg = resolve(cfg)
    torch.manual_seed(run_cfg.seed)
    logger.info(f"config: {run_cfg}")

    # cfg.dtype is a dtype name ("bfloat16", ...) looked up as an attribute of torch.
    model, tok = load_model(run_cfg.model, getattr(torch, run_cfg.dtype))

    # Round-0 anchor: deliberately the same module object as `model`; the KL
    # reference is that module with adapter gates off.
    anchor = model
    steer_heal(model, tok, anchor, run_cfg)
if __name__ == "__main__":
    # Build the RunConfig from command-line arguments, then run the pipeline.
    cli_cfg = tyro.cli(RunConfig)
    main(cli_cfg)