diff --git a/src/steer_heal/config.py b/src/steer_heal/config.py index 9834238..5e12843 100644 --- a/src/steer_heal/config.py +++ b/src/steer_heal/config.py @@ -42,6 +42,10 @@ class RunConfig: epochs: int = 2 lr: float = 1e-4 + # ── eval (tinymfv) ── + eval_vignettes: int | None = None # None = all Clifford-2015 vignettes + eval_think_tokens: int = 64 # tinymfv default; 10x faster than 256, within bf16 noise + # ── loop (U3) ── n_rounds: int = 4 @@ -57,6 +61,10 @@ TINY = dict( epochs=1, n_rounds=1, alphas=(1.0,), + eval_vignettes=4, + eval_think_tokens=16, + ppl_tau=1e9, # tiny-random produces junk ppl; relax the gate so the path still runs + rep_tau=1.1, ) diff --git a/src/steer_heal/eval.py b/src/steer_heal/eval.py new file mode 100644 index 0000000..18f8777 --- /dev/null +++ b/src/steer_heal/eval.py @@ -0,0 +1,40 @@ +"""tinymfv eval -> {auth, care, coherence, ppx_json}. + +auth/care are the model's mean probability on the Authority/Care moral +foundations (the trait axis we move). coherence = mean_pmass_allowed (the +forced-choice canary). These are kept distinct: we shift auth on purpose, +coherence must not collapse. +""" + +import math + +import tinymfv +from loguru import logger + +from steer_heal.config import RunConfig + + +def evaluate_model(model, tok, cfg: RunConfig) -> dict: + rep = tinymfv.evaluate( + model, tok, name="classic", + n_vignettes=cfg.eval_vignettes, + conditions=("other_violate",), + max_think_tokens=cfg.eval_think_tokens, + batch_size=8, + device=model.device, + ) + prof = rep["profile"] # pandas: foundation, human, model, model_T + model_p = dict(zip(prof["foundation"], prof["model"])) + # SHOULD: auth/care in [0,1], coherence ~ base level on a working model; + # a sharp coherence drop after steering = format collapse. On tiny-random + # the numbers are junk (we test the path, not the value). + out = { + "auth": float(model_p["Authority"]), + "care": float(model_p["Care"]), + "coherence": float(rep["mean_pmass_allowed"]), + "ppx_json": float(math.exp(rep["mean_nll_json"])), + "top1_acc": float(rep["top1_acc"]), + } + logger.info(f"eval: auth={out['auth']:.3f} care={out['care']:.3f} " + f"coherence={out['coherence']:.3f} ppx={out['ppx_json']:.1f}") + return out diff --git a/src/steer_heal/filter.py b/src/steer_heal/filter.py new file mode 100644 index 0000000..77d6190 --- /dev/null +++ b/src/steer_heal/filter.py @@ -0,0 +1,61 @@ +"""Q0 filter: drop incoherent and trait-narrating steered completions. + +Coherence signal per completion = perplexity under the ORIGINAL model (no +steering): incoherent babble is improbable under base. Plus a repetition guard +and a first-person narration regex (we want enact, not narrate). +""" + +import math +import re +from collections import Counter + +import torch +from loguru import logger + +from steer_heal.config import RunConfig + +# first-person verbalisation of the trait ("I am someone who never defers...") +NARRATE = re.compile( + r"\b(i am (someone|a person) who|as someone who|i (don'?t|do not|never) defer|" + r"my principle|i always stick to|i refuse to (defer|obey))\b", + re.IGNORECASE, +) + + +def rep_frac(text: str) -> float: + """Most-repeated 4-gram fraction; ~1.0 means degenerate looping/too short.""" + words = text.split() + grams = [tuple(words[i : i + 4]) for i in range(len(words) - 3)] + if not grams: + return 1.0 + return Counter(grams).most_common(1)[0][1] / len(grams) + + +@torch.no_grad() +def ppl_under_base(model, tok, prompt: str, completion: str) -> float: + ids = tok(prompt + completion, return_tensors="pt").to(model.device) + n_prompt = tok(prompt, return_tensors="pt").input_ids.shape[1] + logits = model(**ids).logits[0] + labels = ids.input_ids[0, n_prompt:] + if labels.numel() == 0: + return float("inf") + logp = logits[n_prompt - 1 : -1].log_softmax(-1) + nll = -logp[torch.arange(labels.numel()), labels].mean() + return math.exp(nll.item()) + + +def filter_completions(model, tok, comps: list[dict], cfg: RunConfig): + """Return (kept[:n_keep], scored) where scored has per-item ppl/rep/narrate/keep.""" + scored = [] + for c in comps: + rf = rep_frac(c["completion"]) + nar = bool(NARRATE.search(c["completion"])) + ppl = ppl_under_base(model, tok, c["prompt"], c["completion"]) + keep = (ppl < cfg.ppl_tau) and (rf < cfg.rep_tau) and (not nar) + scored.append({**c, "ppl": ppl, "rep": rf, "narrates": nar, "keep": keep}) + kept = [s for s in scored if s["keep"]] + # SHOULD: on a real model, kept drops the babble and keeps fluent dilemmas + # answers; if kept==0 the gate is too strict (raise ppl_tau) or steering + # broke generation. On tiny-random everything passes (relaxed tau). + logger.info(f"filter: kept {len(kept)}/{len(comps)} (ppl<{cfg.ppl_tau:g}, rep<{cfg.rep_tau}, not-narrate)") + return kept[: cfg.n_keep], scored diff --git a/src/steer_heal/heal.py b/src/steer_heal/heal.py new file mode 100644 index 0000000..5abdaa6 --- /dev/null +++ b/src/steer_heal/heal.py @@ -0,0 +1,69 @@ +"""Q1 heal: train one round's LoRA = SFT(kept completions) + divergence-to-original barrier. + +The barrier reference is the round-0 ORIGINAL (gates/adapters off), not the +previous student, so it resists cumulative drift. reg picks the divergence: + nll SFT only (control) + kl_fwd KL(orig || theta) mass-covering (dilutes the trait) + kl_rev KL(theta || orig) mode-seeking (suppresses low-orig-prob = incoherent) [expected best] + wd weight decay on the adapter only +""" + +import torch +from loguru import logger +from torch.nn import functional as F + +from steer_heal.config import RunConfig +from steer_heal.ws.adapter import ModulatedLoRA +from steer_heal.ws.bake import AdapterSpec, baked + + +def _kl_per_pos(logp_a, logp_b): # KL(a || b) summed over vocab, per position + return (logp_a.exp() * (logp_a - logp_b)).sum(-1) + + +def _encode(tok, prompt: str, completion: str, max_len: int, device): + ids = tok(prompt + completion, return_tensors="pt", truncation=True, max_length=max_len).to(device) + n_prompt = tok(prompt, return_tensors="pt").input_ids.shape[1] + L = ids.input_ids.shape[1] + tgt_is_completion = torch.arange(1, L, device=device) >= n_prompt # mask over next-token targets + return ids, tgt_is_completion + + +def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg: RunConfig): + """Train a fresh round adapter on top of baked history. Returns (lora, spec).""" + lora = ModulatedLoRA(model, r=cfg.lora_r, alpha=cfg.lora_alpha, layer_range=cfg.layer_range) + opt = torch.optim.AdamW(list(lora.parameters()), lr=cfg.lr, + weight_decay=(cfg.lam if cfg.reg == "wd" else 0.0)) + + for ep in range(cfg.epochs): + for c in kept: + ids, mask = _encode(tok, c["prompt"], c["completion"], cfg.max_len, model.device) + if mask.sum() == 0: + continue # completion truncated away; nothing to learn here + + # original reference logits (no history, adapter off) for the barrier + if cfg.reg in ("kl_fwd", "kl_rev"): + with torch.no_grad(), lora(model, c=0.0): + logp0 = model(**ids).logits[0, :-1].log_softmax(-1) + + # student logits: history baked + this round's adapter live + with baked(model, hist_specs), lora(model, c=1.0): + logits = model(**ids).logits[0, :-1] + logp = logits.log_softmax(-1) + + tgt = ids.input_ids[0, 1:] + sft = F.nll_loss(logp[mask], tgt[mask]) + if cfg.reg == "kl_fwd": + div = _kl_per_pos(logp0[mask], logp[mask]).mean() + elif cfg.reg == "kl_rev": + div = _kl_per_pos(logp[mask], logp0[mask]).mean() + else: + div = torch.zeros((), device=model.device) # nll, wd + loss = sft + cfg.lam * torch.relu(div - cfg.tau) + loss.backward() + opt.step() + opt.zero_grad() + logger.info(f"heal[{cfg.reg}] epoch {ep}: sft={sft.item():.3f} div={float(div):.3f}") + + spec = AdapterSpec.from_lora(lora, default_c=1.0) # CPU-resident, for the next round's history + return lora, spec diff --git a/src/steer_heal/io.py b/src/steer_heal/io.py index 6836a2e..ca96117 100644 --- a/src/steer_heal/io.py +++ b/src/steer_heal/io.py @@ -32,7 +32,7 @@ def make_run_dir(ts: str, slug: str, cfg) -> Path: def log_event(run_dir: Path, **rec) -> None: # append one jsonl line; events.jsonl is the full machine-readable trace. - srsly.append_jsonl(run_dir / "events.jsonl", rec) + srsly.write_jsonl(run_dir / "events.jsonl", [rec], append=True) def append_result(cfg, metrics: dict) -> None: diff --git a/src/steer_heal/plot.py b/src/steer_heal/plot.py new file mode 100644 index 0000000..9376c28 --- /dev/null +++ b/src/steer_heal/plot.py @@ -0,0 +1,41 @@ +"""Loop map: Care (y) vs Authority (x) trajectory + coherence/cosine panels. + +Simplified from wassname/w2schar-mini csm/plot.py _build_scatter (full git-graph +port is a later pass). One node per round; hover shows coherence and the +round-0 cosine. Saved as out/{ts}_{slug}/map.html. +""" + +from pathlib import Path + +import plotly.graph_objects as go +from plotly.subplots import make_subplots + + +def write_map(run_dir: Path, rounds: list[dict]) -> Path: + r = [d["round"] for d in rounds] + fig = make_subplots( + rows=1, cols=2, column_widths=[0.6, 0.4], + subplot_titles=("trait map: Care vs Authority", "coherence + direction per round"), + specs=[[{"type": "scatter"}, {"type": "scatter"}]], + ) + # trajectory across the auth axis, coloured by round + fig.add_trace(go.Scatter( + x=[d["auth"] for d in rounds], y=[d["care"] for d in rounds], + mode="lines+markers+text", text=[f"r{i}" for i in r], textposition="top center", + marker=dict(size=12, color=r, colorscale="Viridis", showscale=False), + hovertext=[f"r{d['round']} coh={d['coherence']:.3f} cos={d.get('cos_v0', float('nan')):.2f}" + for d in rounds], + name="trajectory", + ), row=1, col=1) + fig.update_xaxes(title_text="Authority p (trait →)", row=1, col=1) + fig.update_yaxes(title_text="Care p", row=1, col=1) + + fig.add_trace(go.Scatter(x=r, y=[d["coherence"] for d in rounds], + mode="lines+markers", name="coherence"), row=1, col=2) + fig.add_trace(go.Scatter(x=r, y=[d.get("cos_v0", float("nan")) for d in rounds], + mode="lines+markers", name="cos(v_r, v_0)"), row=1, col=2) + fig.update_xaxes(title_text="round", row=1, col=2) + + out = run_dir / "map.html" + fig.write_html(out, include_plotlyjs="cdn") + return out diff --git a/src/steer_heal/run.py b/src/steer_heal/run.py index a500d1e..7deee54 100644 --- a/src/steer_heal/run.py +++ b/src/steer_heal/run.py @@ -1,45 +1,33 @@ -"""steer_heal pipeline entry point. +"""steer_heal pipeline: extract -> dose -> generate -> filter -> heal -> fold -> eval -> loop. -Loop (anchored to the round-0 original throughout, see spec.md): - teacher vector (steering-lite) -> iso-KL dose -> generate -> U1 filter - -> heal one round (SFT + KL-rev-to-original barrier) -> fold (gated bake) - -> tinymfv eval -> repeat. - -Stages marked TODO are ported from docs/vendor/* as we implement them; this -file fails fast at the first unimplemented stage rather than stubbing fake -behaviour. `--fast-dev-run` runs the whole thing on the tiny-random model. +Anchored to the round-0 original throughout (KL reference = adapters/gates off). +`--fast-dev-run` runs the whole thing on the tiny-random model. See spec.md. """ -import dataclasses import os -import sys from datetime import datetime from pathlib import Path import torch import tyro from loguru import logger +from torch.nn.functional import cosine_similarity from tqdm import tqdm from transformers import AutoModelForCausalLM, AutoTokenizer from steer_heal.config import RunConfig, resolve +from steer_heal.eval import evaluate_model +from steer_heal.filter import filter_completions +from steer_heal.heal import heal_round +from steer_heal.io import append_result, log_event, make_run_dir +from steer_heal.plot import write_map +from steer_heal.steering import generate_steered, teacher_vec +from steer_heal.ws.bake import baked REPO = Path(__file__).resolve().parents[2] -RESULTS_TSV = REPO / "results.tsv" # one row per finished run; `just results` aggregates - - -def append_result(cfg: RunConfig, metrics: dict) -> None: - # self-describing, copy-paste reproducible row: config + final metrics + argv. - row = {**dataclasses.asdict(cfg), **metrics, "argv": " ".join(sys.argv[1:])} - new = not RESULTS_TSV.exists() - with open(RESULTS_TSV, "a") as f: - if new: - f.write("\t".join(row) + "\n") - f.write("\t".join(str(v) for v in row.values()) + "\n") def setup_logging() -> None: - # tqdm-safe loguru, single-char level icons, verbose copy on disk. logger.remove() logger.add(lambda m: tqdm.write(m, end=""), colorize=True, format="{level.icon} {message}", level="INFO") @@ -48,8 +36,7 @@ def setup_logging() -> None: log_dir = REPO / "logs" log_dir.mkdir(exist_ok=True) f = log_dir / f"{datetime.now():%Y%m%dT%H%M%S}_verbose.log" - logger.add(f, format="{time:HH:mm:ss} | {level: <7} | {name}:{function}:{line} - {message}", - level="DEBUG") + logger.add(f, format="{time:HH:mm:ss} | {level: <7} | {name}:{function}:{line} - {message}", level="DEBUG") logger.info(f"verbose log: {f}") @@ -57,64 +44,69 @@ def load_model(model_id: str, dtype: torch.dtype): tok = AutoTokenizer.from_pretrained(model_id) if tok.pad_token is None: tok.pad_token = tok.eos_token - tok.padding_side = "left" - attn = os.environ.get("STEER_ATTN_IMPL", "eager") # set =flash_attention_2 on real runs + attn = os.environ.get("STEER_ATTN_IMPL", "eager") model = AutoModelForCausalLM.from_pretrained( - model_id, device_map="auto", torch_dtype=dtype, - low_cpu_mem_usage=True, attn_implementation=attn, + model_id, device_map="auto", torch_dtype=dtype, low_cpu_mem_usage=True, + attn_implementation=attn, ) model.eval() - logger.info(f"loaded {model_id} (dtype={dtype}, attn={attn})") + logger.info(f"loaded {model_id} (dtype={dtype}, attn={attn}, layers={model.config.num_hidden_layers})") return model, tok -# ── stages (ported from docs/vendor as we implement; fail fast until then) ── - -def teacher_vec(model, tok, cfg: RunConfig): - # steering-lite Vector.train(pos=trait-sysprompt, neg=neutral-sysprompt) @ assistant tag, - # then .calibrate(target_kl=cfg.target_kl). See docs/vendor/steering-lite + isokl. - raise NotImplementedError("TODO: teacher_vec via steering-lite + iso-KL calibration") +def _flatten_v(v) -> torch.Tensor: + return torch.cat([v.state[li]["v"].flatten().float() for li in sorted(v.state)]) -def generate_and_filter(model, tok, v, orig, cfg: RunConfig): - # gen at alpha*c_star (steering-lite hook); keep coherent & enact-not-narrate (U1). - raise NotImplementedError("TODO: generate_and_filter (U1 gate)") +def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict: + hist_specs = [] # AdapterSpec per folded round (gated bake history) + v0_flat = None # round-0 direction, for the Q3 cosine + rounds = [] + for rnd in range(cfg.n_rounds): + logger.info(f"\n=== ROUND {rnd} [{cfg.model.split('/')[-1]} reg={cfg.reg}] ===") + # extract teacher vector + generate steered data from the CURRENT student + with baked(model, hist_specs): + v = teacher_vec(model, tok, cfg) + comps = generate_steered(model, tok, v, alpha=1.0, cfg=cfg) + # filter under the ORIGINAL (no history, no steering) + kept, scored = filter_completions(model, tok, comps, cfg) + log_event(run_dir, stage="gen", round=rnd, n_comps=len(comps), n_kept=len(kept), scored=scored) + # heal one round on top of the baked history, then fold + lora, spec = heal_round(model, tok, kept, hist_specs, cfg) + lora.save(str(run_dir / "ckpt" / f"r{rnd}.safetensors"), extra_meta={"round": str(rnd), "reg": cfg.reg}) + hist_specs.append(spec) -def heal(model, orig, comps, cfg: RunConfig): - # SFT + lam*relu(div - tau); div in {nll, kl_fwd, kl_rev, wd}; KL ref = orig (gates off). - # adapter + gated bake ported from docs/vendor/w2schar-mini/src/csm/ws. - raise NotImplementedError("TODO: heal (U2 barrier) + fold via w2schar ws.bake") + # eval the student (all rounds baked) + with baked(model, hist_specs): + m = evaluate_model(model, tok, cfg) + vf = _flatten_v(v) + v0_flat = vf if v0_flat is None else v0_flat + cos_v0 = float(cosine_similarity(vf, v0_flat, dim=0)) + rec = {"round": rnd, **m, "cos_v0": cos_v0, "c_star": float(v.cfg.coeff), "n_kept": len(kept)} + rounds.append(rec) + log_event(run_dir, stage="round", **rec) + logger.info(f"round {rnd}: auth={m['auth']:.3f} care={m['care']:.3f} " + f"coh={m['coherence']:.3f} cos_v0={cos_v0:+.2f}") -def evaluate(model, cfg: RunConfig) -> dict: - # tinymfv auth/care axes + p_ans_any/json_is_valid/ppx_json. - raise NotImplementedError("TODO: tinymfv eval + plotly map (port csm/plot.py _build_scatter)") - - -def steer_heal(model, tok, orig, cfg: RunConfig) -> dict: - metrics = {} - for r in range(cfg.n_rounds): - logger.info(f"── round {r} ──") - v = teacher_vec(model, tok, cfg) - comps = generate_and_filter(model, tok, v, orig, cfg) - heal(model, orig, comps, cfg) - metrics = {"round": r, **evaluate(model, cfg)} - logger.info(metrics) - return metrics # final round, for results.tsv + map_path = write_map(run_dir, rounds) + logger.info(f"map: {map_path}") + return rounds[-1] def main(cfg: RunConfig) -> None: setup_logging() cfg = resolve(cfg) torch.manual_seed(cfg.seed) - logger.info(f"config: {cfg}") - dtype = getattr(torch, cfg.dtype) - model, tok = load_model(cfg.model, dtype) - orig = model # round-0 anchor; KL reference = same module with adapter gates off - metrics = steer_heal(model, tok, orig, cfg) - append_result(cfg, metrics) - logger.info(f"done; appended to {RESULTS_TSV.name}") + ts = datetime.now().strftime("%Y%m%dT%H%M%S") + slug = f"{cfg.model.split('/')[-1]}_{cfg.reg}_s{cfg.seed}" + run_dir = make_run_dir(ts, slug, cfg) + logger.info(f"argv cfg: {cfg}") + model, tok = load_model(cfg.model, getattr(torch, cfg.dtype)) + final = steer_heal(model, tok, cfg, run_dir) + append_result(cfg, {"slug": slug, **final}) + logger.info(f"done: {run_dir}") if __name__ == "__main__":