implement pipeline: extract -> dose -> generate -> filter -> heal -> fold -> eval -> loop

fast-dev-run now runs end to end on wassname/qwen3-5lyr-tiny-random (out dir +
map.html + results.tsv row). Modules:
- steering.py: teacher_vec (steering-lite mean-diff @ assistant tag + iso-KL dose) + steered gen
- filter.py: Q0 coherence (ppl-under-original + repetition + narration regex)
- heal.py: Q1 SFT + divergence-to-original barrier (nll/kl_fwd/kl_rev/wd), reg=kl_rev default
- ws/{adapter,bake}.py: ModulatedLoRA + gated baked(), copied from w2schar-mini
- eval.py: tinymfv -> {auth, care, coherence(mean_pmass_allowed), ppx_json}
- plot.py: plotly Care-vs-Authority map.html (simplified w2schar port)
- io.py: out/{ts}_{slug}/ + srsly events.jsonl + shared results.tsv
- prompts.py: 30 authority dilemmas (POOL, copied) + assistant-tag chat templating

Trace confirms pos/neg prompts end at the assistant tag (paper's read).
Tiny-random numbers are junk by design; real calibration/eval pending small-dev.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
wassname
2026-06-04 10:12:08 +08:00
parent e1db0759ee
commit 3b532b63dd
7 changed files with 277 additions and 66 deletions
+8
View File
@@ -42,6 +42,10 @@ class RunConfig:
epochs: int = 2
lr: float = 1e-4
# ── eval (tinymfv) ──
eval_vignettes: int | None = None # None = all Clifford-2015 vignettes
eval_think_tokens: int = 64 # tinymfv default; 10x faster than 256, within bf16 noise
# ── loop (U3) ──
n_rounds: int = 4
@@ -57,6 +61,10 @@ TINY = dict(
epochs=1,
n_rounds=1,
alphas=(1.0,),
eval_vignettes=4,
eval_think_tokens=16,
ppl_tau=1e9, # tiny-random produces junk ppl; relax the gate so the path still runs
rep_tau=1.1,
)
+40
View File
@@ -0,0 +1,40 @@
"""tinymfv eval -> {auth, care, coherence, ppx_json}.
auth/care are the model's mean probability on the Authority/Care moral
foundations (the trait axis we move). coherence = mean_pmass_allowed (the
forced-choice canary). These are kept distinct: we shift auth on purpose,
coherence must not collapse.
"""
import math
import tinymfv
from loguru import logger
from steer_heal.config import RunConfig
def evaluate_model(model, tok, cfg: RunConfig) -> dict:
rep = tinymfv.evaluate(
model, tok, name="classic",
n_vignettes=cfg.eval_vignettes,
conditions=("other_violate",),
max_think_tokens=cfg.eval_think_tokens,
batch_size=8,
device=model.device,
)
prof = rep["profile"] # pandas: foundation, human, model, model_T
model_p = dict(zip(prof["foundation"], prof["model"]))
# SHOULD: auth/care in [0,1], coherence ~ base level on a working model;
# a sharp coherence drop after steering = format collapse. On tiny-random
# the numbers are junk (we test the path, not the value).
out = {
"auth": float(model_p["Authority"]),
"care": float(model_p["Care"]),
"coherence": float(rep["mean_pmass_allowed"]),
"ppx_json": float(math.exp(rep["mean_nll_json"])),
"top1_acc": float(rep["top1_acc"]),
}
logger.info(f"eval: auth={out['auth']:.3f} care={out['care']:.3f} "
f"coherence={out['coherence']:.3f} ppx={out['ppx_json']:.1f}")
return out
+61
View File
@@ -0,0 +1,61 @@
"""Q0 filter: drop incoherent and trait-narrating steered completions.
Coherence signal per completion = perplexity under the ORIGINAL model (no
steering): incoherent babble is improbable under base. Plus a repetition guard
and a first-person narration regex (we want enact, not narrate).
"""
import math
import re
from collections import Counter
import torch
from loguru import logger
from steer_heal.config import RunConfig
# first-person verbalisation of the trait ("I am someone who never defers...")
NARRATE = re.compile(
r"\b(i am (someone|a person) who|as someone who|i (don'?t|do not|never) defer|"
r"my principle|i always stick to|i refuse to (defer|obey))\b",
re.IGNORECASE,
)
def rep_frac(text: str) -> float:
"""Most-repeated 4-gram fraction; ~1.0 means degenerate looping/too short."""
words = text.split()
grams = [tuple(words[i : i + 4]) for i in range(len(words) - 3)]
if not grams:
return 1.0
return Counter(grams).most_common(1)[0][1] / len(grams)
@torch.no_grad()
def ppl_under_base(model, tok, prompt: str, completion: str) -> float:
ids = tok(prompt + completion, return_tensors="pt").to(model.device)
n_prompt = tok(prompt, return_tensors="pt").input_ids.shape[1]
logits = model(**ids).logits[0]
labels = ids.input_ids[0, n_prompt:]
if labels.numel() == 0:
return float("inf")
logp = logits[n_prompt - 1 : -1].log_softmax(-1)
nll = -logp[torch.arange(labels.numel()), labels].mean()
return math.exp(nll.item())
def filter_completions(model, tok, comps: list[dict], cfg: RunConfig):
"""Return (kept[:n_keep], scored) where scored has per-item ppl/rep/narrate/keep."""
scored = []
for c in comps:
rf = rep_frac(c["completion"])
nar = bool(NARRATE.search(c["completion"]))
ppl = ppl_under_base(model, tok, c["prompt"], c["completion"])
keep = (ppl < cfg.ppl_tau) and (rf < cfg.rep_tau) and (not nar)
scored.append({**c, "ppl": ppl, "rep": rf, "narrates": nar, "keep": keep})
kept = [s for s in scored if s["keep"]]
# SHOULD: on a real model, kept drops the babble and keeps fluent dilemmas
# answers; if kept==0 the gate is too strict (raise ppl_tau) or steering
# broke generation. On tiny-random everything passes (relaxed tau).
logger.info(f"filter: kept {len(kept)}/{len(comps)} (ppl<{cfg.ppl_tau:g}, rep<{cfg.rep_tau}, not-narrate)")
return kept[: cfg.n_keep], scored
+69
View File
@@ -0,0 +1,69 @@
"""Q1 heal: train one round's LoRA = SFT(kept completions) + divergence-to-original barrier.
The barrier reference is the round-0 ORIGINAL (gates/adapters off), not the
previous student, so it resists cumulative drift. reg picks the divergence:
nll SFT only (control)
kl_fwd KL(orig || theta) mass-covering (dilutes the trait)
kl_rev KL(theta || orig) mode-seeking (suppresses low-orig-prob = incoherent) [expected best]
wd weight decay on the adapter only
"""
import torch
from loguru import logger
from torch.nn import functional as F
from steer_heal.config import RunConfig
from steer_heal.ws.adapter import ModulatedLoRA
from steer_heal.ws.bake import AdapterSpec, baked
def _kl_per_pos(logp_a, logp_b): # KL(a || b) summed over vocab, per position
return (logp_a.exp() * (logp_a - logp_b)).sum(-1)
def _encode(tok, prompt: str, completion: str, max_len: int, device):
ids = tok(prompt + completion, return_tensors="pt", truncation=True, max_length=max_len).to(device)
n_prompt = tok(prompt, return_tensors="pt").input_ids.shape[1]
L = ids.input_ids.shape[1]
tgt_is_completion = torch.arange(1, L, device=device) >= n_prompt # mask over next-token targets
return ids, tgt_is_completion
def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg: RunConfig):
"""Train a fresh round adapter on top of baked history. Returns (lora, spec)."""
lora = ModulatedLoRA(model, r=cfg.lora_r, alpha=cfg.lora_alpha, layer_range=cfg.layer_range)
opt = torch.optim.AdamW(list(lora.parameters()), lr=cfg.lr,
weight_decay=(cfg.lam if cfg.reg == "wd" else 0.0))
for ep in range(cfg.epochs):
for c in kept:
ids, mask = _encode(tok, c["prompt"], c["completion"], cfg.max_len, model.device)
if mask.sum() == 0:
continue # completion truncated away; nothing to learn here
# original reference logits (no history, adapter off) for the barrier
if cfg.reg in ("kl_fwd", "kl_rev"):
with torch.no_grad(), lora(model, c=0.0):
logp0 = model(**ids).logits[0, :-1].log_softmax(-1)
# student logits: history baked + this round's adapter live
with baked(model, hist_specs), lora(model, c=1.0):
logits = model(**ids).logits[0, :-1]
logp = logits.log_softmax(-1)
tgt = ids.input_ids[0, 1:]
sft = F.nll_loss(logp[mask], tgt[mask])
if cfg.reg == "kl_fwd":
div = _kl_per_pos(logp0[mask], logp[mask]).mean()
elif cfg.reg == "kl_rev":
div = _kl_per_pos(logp[mask], logp0[mask]).mean()
else:
div = torch.zeros((), device=model.device) # nll, wd
loss = sft + cfg.lam * torch.relu(div - cfg.tau)
loss.backward()
opt.step()
opt.zero_grad()
logger.info(f"heal[{cfg.reg}] epoch {ep}: sft={sft.item():.3f} div={float(div):.3f}")
spec = AdapterSpec.from_lora(lora, default_c=1.0) # CPU-resident, for the next round's history
return lora, spec
+1 -1
View File
@@ -32,7 +32,7 @@ def make_run_dir(ts: str, slug: str, cfg) -> Path:
def log_event(run_dir: Path, **rec) -> None:
# append one jsonl line; events.jsonl is the full machine-readable trace.
srsly.append_jsonl(run_dir / "events.jsonl", rec)
srsly.write_jsonl(run_dir / "events.jsonl", [rec], append=True)
def append_result(cfg, metrics: dict) -> None:
+41
View File
@@ -0,0 +1,41 @@
"""Loop map: Care (y) vs Authority (x) trajectory + coherence/cosine panels.
Simplified from wassname/w2schar-mini csm/plot.py _build_scatter (full git-graph
port is a later pass). One node per round; hover shows coherence and the
round-0 cosine. Saved as out/{ts}_{slug}/map.html.
"""
from pathlib import Path
import plotly.graph_objects as go
from plotly.subplots import make_subplots
def write_map(run_dir: Path, rounds: list[dict]) -> Path:
r = [d["round"] for d in rounds]
fig = make_subplots(
rows=1, cols=2, column_widths=[0.6, 0.4],
subplot_titles=("trait map: Care vs Authority", "coherence + direction per round"),
specs=[[{"type": "scatter"}, {"type": "scatter"}]],
)
# trajectory across the auth axis, coloured by round
fig.add_trace(go.Scatter(
x=[d["auth"] for d in rounds], y=[d["care"] for d in rounds],
mode="lines+markers+text", text=[f"r{i}" for i in r], textposition="top center",
marker=dict(size=12, color=r, colorscale="Viridis", showscale=False),
hovertext=[f"r{d['round']} coh={d['coherence']:.3f} cos={d.get('cos_v0', float('nan')):.2f}"
for d in rounds],
name="trajectory",
), row=1, col=1)
fig.update_xaxes(title_text="Authority p (trait →)", row=1, col=1)
fig.update_yaxes(title_text="Care p", row=1, col=1)
fig.add_trace(go.Scatter(x=r, y=[d["coherence"] for d in rounds],
mode="lines+markers", name="coherence"), row=1, col=2)
fig.add_trace(go.Scatter(x=r, y=[d.get("cos_v0", float("nan")) for d in rounds],
mode="lines+markers", name="cos(v_r, v_0)"), row=1, col=2)
fig.update_xaxes(title_text="round", row=1, col=2)
out = run_dir / "map.html"
fig.write_html(out, include_plotlyjs="cdn")
return out
+57 -65
View File
@@ -1,45 +1,33 @@
"""steer_heal pipeline entry point.
"""steer_heal pipeline: extract -> dose -> generate -> filter -> heal -> fold -> eval -> loop.
Loop (anchored to the round-0 original throughout, see spec.md):
teacher vector (steering-lite) -> iso-KL dose -> generate -> U1 filter
-> heal one round (SFT + KL-rev-to-original barrier) -> fold (gated bake)
-> tinymfv eval -> repeat.
Stages marked TODO are ported from docs/vendor/* as we implement them; this
file fails fast at the first unimplemented stage rather than stubbing fake
behaviour. `--fast-dev-run` runs the whole thing on the tiny-random model.
Anchored to the round-0 original throughout (KL reference = adapters/gates off).
`--fast-dev-run` runs the whole thing on the tiny-random model. See spec.md.
"""
import dataclasses
import os
import sys
from datetime import datetime
from pathlib import Path
import torch
import tyro
from loguru import logger
from torch.nn.functional import cosine_similarity
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from steer_heal.config import RunConfig, resolve
from steer_heal.eval import evaluate_model
from steer_heal.filter import filter_completions
from steer_heal.heal import heal_round
from steer_heal.io import append_result, log_event, make_run_dir
from steer_heal.plot import write_map
from steer_heal.steering import generate_steered, teacher_vec
from steer_heal.ws.bake import baked
REPO = Path(__file__).resolve().parents[2]
RESULTS_TSV = REPO / "results.tsv" # one row per finished run; `just results` aggregates
def append_result(cfg: RunConfig, metrics: dict) -> None:
# self-describing, copy-paste reproducible row: config + final metrics + argv.
row = {**dataclasses.asdict(cfg), **metrics, "argv": " ".join(sys.argv[1:])}
new = not RESULTS_TSV.exists()
with open(RESULTS_TSV, "a") as f:
if new:
f.write("\t".join(row) + "\n")
f.write("\t".join(str(v) for v in row.values()) + "\n")
def setup_logging() -> None:
# tqdm-safe loguru, single-char level icons, verbose copy on disk.
logger.remove()
logger.add(lambda m: tqdm.write(m, end=""), colorize=True,
format="<level>{level.icon}</level> {message}", level="INFO")
@@ -48,8 +36,7 @@ def setup_logging() -> None:
log_dir = REPO / "logs"
log_dir.mkdir(exist_ok=True)
f = log_dir / f"{datetime.now():%Y%m%dT%H%M%S}_verbose.log"
logger.add(f, format="{time:HH:mm:ss} | {level: <7} | {name}:{function}:{line} - {message}",
level="DEBUG")
logger.add(f, format="{time:HH:mm:ss} | {level: <7} | {name}:{function}:{line} - {message}", level="DEBUG")
logger.info(f"verbose log: {f}")
@@ -57,64 +44,69 @@ def load_model(model_id: str, dtype: torch.dtype):
tok = AutoTokenizer.from_pretrained(model_id)
if tok.pad_token is None:
tok.pad_token = tok.eos_token
tok.padding_side = "left"
attn = os.environ.get("STEER_ATTN_IMPL", "eager") # set =flash_attention_2 on real runs
attn = os.environ.get("STEER_ATTN_IMPL", "eager")
model = AutoModelForCausalLM.from_pretrained(
model_id, device_map="auto", torch_dtype=dtype,
low_cpu_mem_usage=True, attn_implementation=attn,
model_id, device_map="auto", torch_dtype=dtype, low_cpu_mem_usage=True,
attn_implementation=attn,
)
model.eval()
logger.info(f"loaded {model_id} (dtype={dtype}, attn={attn})")
logger.info(f"loaded {model_id} (dtype={dtype}, attn={attn}, layers={model.config.num_hidden_layers})")
return model, tok
# ── stages (ported from docs/vendor as we implement; fail fast until then) ──
def teacher_vec(model, tok, cfg: RunConfig):
# steering-lite Vector.train(pos=trait-sysprompt, neg=neutral-sysprompt) @ assistant tag,
# then .calibrate(target_kl=cfg.target_kl). See docs/vendor/steering-lite + isokl.
raise NotImplementedError("TODO: teacher_vec via steering-lite + iso-KL calibration")
def _flatten_v(v) -> torch.Tensor:
return torch.cat([v.state[li]["v"].flatten().float() for li in sorted(v.state)])
def generate_and_filter(model, tok, v, orig, cfg: RunConfig):
# gen at alpha*c_star (steering-lite hook); keep coherent & enact-not-narrate (U1).
raise NotImplementedError("TODO: generate_and_filter (U1 gate)")
def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict:
hist_specs = [] # AdapterSpec per folded round (gated bake history)
v0_flat = None # round-0 direction, for the Q3 cosine
rounds = []
for rnd in range(cfg.n_rounds):
logger.info(f"\n=== ROUND {rnd} [{cfg.model.split('/')[-1]} reg={cfg.reg}] ===")
# extract teacher vector + generate steered data from the CURRENT student
with baked(model, hist_specs):
v = teacher_vec(model, tok, cfg)
comps = generate_steered(model, tok, v, alpha=1.0, cfg=cfg)
# filter under the ORIGINAL (no history, no steering)
kept, scored = filter_completions(model, tok, comps, cfg)
log_event(run_dir, stage="gen", round=rnd, n_comps=len(comps), n_kept=len(kept), scored=scored)
# heal one round on top of the baked history, then fold
lora, spec = heal_round(model, tok, kept, hist_specs, cfg)
lora.save(str(run_dir / "ckpt" / f"r{rnd}.safetensors"), extra_meta={"round": str(rnd), "reg": cfg.reg})
hist_specs.append(spec)
def heal(model, orig, comps, cfg: RunConfig):
# SFT + lam*relu(div - tau); div in {nll, kl_fwd, kl_rev, wd}; KL ref = orig (gates off).
# adapter + gated bake ported from docs/vendor/w2schar-mini/src/csm/ws.
raise NotImplementedError("TODO: heal (U2 barrier) + fold via w2schar ws.bake")
# eval the student (all rounds baked)
with baked(model, hist_specs):
m = evaluate_model(model, tok, cfg)
vf = _flatten_v(v)
v0_flat = vf if v0_flat is None else v0_flat
cos_v0 = float(cosine_similarity(vf, v0_flat, dim=0))
rec = {"round": rnd, **m, "cos_v0": cos_v0, "c_star": float(v.cfg.coeff), "n_kept": len(kept)}
rounds.append(rec)
log_event(run_dir, stage="round", **rec)
logger.info(f"round {rnd}: auth={m['auth']:.3f} care={m['care']:.3f} "
f"coh={m['coherence']:.3f} cos_v0={cos_v0:+.2f}")
def evaluate(model, cfg: RunConfig) -> dict:
# tinymfv auth/care axes + p_ans_any/json_is_valid/ppx_json.
raise NotImplementedError("TODO: tinymfv eval + plotly map (port csm/plot.py _build_scatter)")
def steer_heal(model, tok, orig, cfg: RunConfig) -> dict:
metrics = {}
for r in range(cfg.n_rounds):
logger.info(f"── round {r} ──")
v = teacher_vec(model, tok, cfg)
comps = generate_and_filter(model, tok, v, orig, cfg)
heal(model, orig, comps, cfg)
metrics = {"round": r, **evaluate(model, cfg)}
logger.info(metrics)
return metrics # final round, for results.tsv
map_path = write_map(run_dir, rounds)
logger.info(f"map: {map_path}")
return rounds[-1]
def main(cfg: RunConfig) -> None:
setup_logging()
cfg = resolve(cfg)
torch.manual_seed(cfg.seed)
logger.info(f"config: {cfg}")
dtype = getattr(torch, cfg.dtype)
model, tok = load_model(cfg.model, dtype)
orig = model # round-0 anchor; KL reference = same module with adapter gates off
metrics = steer_heal(model, tok, orig, cfg)
append_result(cfg, metrics)
logger.info(f"done; appended to {RESULTS_TSV.name}")
ts = datetime.now().strftime("%Y%m%dT%H%M%S")
slug = f"{cfg.model.split('/')[-1]}_{cfg.reg}_s{cfg.seed}"
run_dir = make_run_dir(ts, slug, cfg)
logger.info(f"argv cfg: {cfg}")
model, tok = load_model(cfg.model, getattr(torch, cfg.dtype))
final = steer_heal(model, tok, cfg, run_dir)
append_result(cfg, {"slug": slug, **final})
logger.info(f"done: {run_dir}")
if __name__ == "__main__":