mirror of
https://github.com/wassname/steer-heal-love.git
synced 2026-06-27 16:47:16 +08:00
scaffold steer_heal: spec, repo infra, vendored deps
Setup per setup-repo conventions: uv + justfile + fast-dev-run on wassname/qwen3-5lyr-tiny-random, package under src/steer_heal (config + pipeline skeleton). Stages fail fast with NotImplementedError pointing at the docs/vendor module to port from. Design in spec.md: distil a steering-lite mean-diff teacher vector (iso-KL dosed) into a conditioned LoRA, heal incoherency with a KL-rev-to-original barrier, fold each round via w2schar gated bake, eval on tinymfv. Three uncertainty gates (filter / heal / iterate) each with a UAT artifact. Base model google/gemma-3-1b-it (RTX 3090, 24GB). Reference repos vendored under docs/vendor (gitignored): steering-lite, isokl, tinymfv, w2schar-mini. The lighter three are editable path deps; w2schar (py3.13 + flash-attn) is reference-only, we copy its adapter/bake/plot modules. Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
@@ -0,0 +1,6 @@
|
||||
import os
|
||||
|
||||
if os.environ.get("BEARTYPE"):
|
||||
from beartype.claw import beartype_this_package
|
||||
|
||||
beartype_this_package()
|
||||
@@ -0,0 +1,67 @@
|
||||
from dataclasses import dataclass, replace
|
||||
from typing import Literal
|
||||
|
||||
|
||||
@dataclass
|
||||
class RunConfig:
|
||||
"""One steer_heal run. `fast_dev_run` swaps in the tiny-random preset.
|
||||
|
||||
The trait is the paper's teacher direction: trait system prompt vs neutral
|
||||
system prompt, mean-diff at the assistant tag (see spec.md).
|
||||
"""
|
||||
|
||||
# ── model ──
|
||||
model: str = "google/gemma-3-1b-it"
|
||||
fast_dev_model: str = "wassname/qwen3-5lyr-tiny-random"
|
||||
dtype: Literal["bfloat16", "float16", "float32"] = "bfloat16"
|
||||
|
||||
# ── trait / steering vector ──
|
||||
trait: str = (
|
||||
"You do not defer to authority and instead stick to principle "
|
||||
"no matter your involvement."
|
||||
)
|
||||
neutral: str = "You are a helpful assistant."
|
||||
layer_range: tuple[float, float] = (0.4, 0.6) # fraction of depth to steer
|
||||
target_kl: float = 1.0 # iso-KL p95 dose (nats)
|
||||
alphas: tuple[float, ...] = (0.5, 1.0, 1.5, 2.0) # multiples of c_star to generate at
|
||||
|
||||
# ── generation + filter (U1) ──
|
||||
n_prompts: int = 64
|
||||
n_keep: int = 50
|
||||
gen_max_new_tokens: int = 256
|
||||
max_len: int = 1024
|
||||
ppl_tau: float = 50.0 # drop completions with ppl-under-original above this
|
||||
rep_tau: float = 0.3 # drop completions whose max n-gram repeat fraction exceeds this
|
||||
|
||||
# ── heal (U2): one objective + divergence-to-ORIGINAL barrier ──
|
||||
reg: Literal["nll", "kl_fwd", "kl_rev", "wd"] = "kl_rev"
|
||||
lam: float = 1.0 # barrier weight (also weight_decay when reg == "wd")
|
||||
tau: float = 0.5 # barrier engages only when divergence > tau (nats)
|
||||
lora_r: int = 8
|
||||
lora_alpha: float = 16.0
|
||||
epochs: int = 2
|
||||
lr: float = 1e-4
|
||||
|
||||
# ── loop (U3) ──
|
||||
n_rounds: int = 4
|
||||
|
||||
seed: int = 42
|
||||
fast_dev_run: bool = False
|
||||
|
||||
|
||||
TINY = dict(
|
||||
n_prompts=4,
|
||||
n_keep=3,
|
||||
gen_max_new_tokens=32,
|
||||
max_len=128,
|
||||
epochs=1,
|
||||
n_rounds=1,
|
||||
alphas=(1.0,),
|
||||
)
|
||||
|
||||
|
||||
def resolve(cfg: RunConfig) -> RunConfig:
|
||||
"""Apply the fast-dev-run preset (tiny random model, scaled-down everything)."""
|
||||
if cfg.fast_dev_run:
|
||||
return replace(cfg, model=cfg.fast_dev_model, **TINY)
|
||||
return cfg
|
||||
@@ -0,0 +1,105 @@
|
||||
"""steer_heal pipeline entry point.
|
||||
|
||||
Loop (anchored to the round-0 original throughout, see spec.md):
|
||||
teacher vector (steering-lite) -> iso-KL dose -> generate -> U1 filter
|
||||
-> heal one round (SFT + KL-rev-to-original barrier) -> fold (gated bake)
|
||||
-> tinymfv eval -> repeat.
|
||||
|
||||
Stages marked TODO are ported from docs/vendor/* as we implement them; this
|
||||
file fails fast at the first unimplemented stage rather than stubbing fake
|
||||
behaviour. `--fast-dev-run` runs the whole thing on the tiny-random model.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import tyro
|
||||
from loguru import logger
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from steer_heal.config import RunConfig, resolve
|
||||
|
||||
REPO = Path(__file__).resolve().parents[2]
|
||||
|
||||
|
||||
def setup_logging() -> None:
|
||||
# tqdm-safe loguru, single-char level icons, verbose copy on disk.
|
||||
logger.remove()
|
||||
logger.add(lambda m: tqdm.write(m, end=""), colorize=True,
|
||||
format="<level>{level.icon}</level> {message}", level="INFO")
|
||||
for lvl, ic in [("INFO", "I"), ("WARNING", "W"), ("ERROR", "E"), ("DEBUG", "D")]:
|
||||
logger.level(lvl, icon=ic)
|
||||
log_dir = REPO / "logs"
|
||||
log_dir.mkdir(exist_ok=True)
|
||||
f = log_dir / f"{datetime.now():%Y%m%dT%H%M%S}_verbose.log"
|
||||
logger.add(f, format="{time:HH:mm:ss} | {level: <7} | {name}:{function}:{line} - {message}",
|
||||
level="DEBUG")
|
||||
logger.info(f"verbose log: {f}")
|
||||
|
||||
|
||||
def load_model(model_id: str, dtype: torch.dtype):
|
||||
tok = AutoTokenizer.from_pretrained(model_id)
|
||||
if tok.pad_token is None:
|
||||
tok.pad_token = tok.eos_token
|
||||
tok.padding_side = "left"
|
||||
attn = os.environ.get("STEER_ATTN_IMPL", "eager") # set =flash_attention_2 on real runs
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id, device_map="auto", torch_dtype=dtype,
|
||||
low_cpu_mem_usage=True, attn_implementation=attn,
|
||||
)
|
||||
model.eval()
|
||||
logger.info(f"loaded {model_id} (dtype={dtype}, attn={attn})")
|
||||
return model, tok
|
||||
|
||||
|
||||
# ── stages (ported from docs/vendor as we implement; fail fast until then) ──
|
||||
|
||||
def teacher_vec(model, tok, cfg: RunConfig):
|
||||
# steering-lite Vector.train(pos=trait-sysprompt, neg=neutral-sysprompt) @ assistant tag,
|
||||
# then .calibrate(target_kl=cfg.target_kl). See docs/vendor/steering-lite + isokl.
|
||||
raise NotImplementedError("TODO: teacher_vec via steering-lite + iso-KL calibration")
|
||||
|
||||
|
||||
def generate_and_filter(model, tok, v, orig, cfg: RunConfig):
|
||||
# gen at alpha*c_star (steering-lite hook); keep coherent & enact-not-narrate (U1).
|
||||
raise NotImplementedError("TODO: generate_and_filter (U1 gate)")
|
||||
|
||||
|
||||
def heal(model, orig, comps, cfg: RunConfig):
|
||||
# SFT + lam*relu(div - tau); div in {nll, kl_fwd, kl_rev, wd}; KL ref = orig (gates off).
|
||||
# adapter + gated bake ported from docs/vendor/w2schar-mini/src/csm/ws.
|
||||
raise NotImplementedError("TODO: heal (U2 barrier) + fold via w2schar ws.bake")
|
||||
|
||||
|
||||
def evaluate(model, cfg: RunConfig) -> dict:
|
||||
# tinymfv auth/care axes + p_ans_any/json_is_valid/ppx_json.
|
||||
raise NotImplementedError("TODO: tinymfv eval + plotly map (port csm/plot.py _build_scatter)")
|
||||
|
||||
|
||||
def steer_heal(model, tok, orig, cfg: RunConfig):
|
||||
for r in range(cfg.n_rounds):
|
||||
logger.info(f"── round {r} ──")
|
||||
v = teacher_vec(model, tok, cfg)
|
||||
comps = generate_and_filter(model, tok, v, orig, cfg)
|
||||
heal(model, orig, comps, cfg)
|
||||
logger.info(evaluate(model, cfg))
|
||||
return model
|
||||
|
||||
|
||||
def main(cfg: RunConfig) -> None:
|
||||
setup_logging()
|
||||
cfg = resolve(cfg)
|
||||
torch.manual_seed(cfg.seed)
|
||||
logger.info(f"config: {cfg}")
|
||||
dtype = getattr(torch, cfg.dtype)
|
||||
model, tok = load_model(cfg.model, dtype)
|
||||
orig = model # round-0 anchor; KL reference = same module with adapter gates off
|
||||
steer_heal(model, tok, orig, cfg)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(tyro.cli(RunConfig))
|
||||
Reference in New Issue
Block a user