diff --git a/scripts/diag_stages.py b/scripts/diag_stages.py index 07d8943..23b89ea 100644 --- a/scripts/diag_stages.py +++ b/scripts/diag_stages.py @@ -4,13 +4,15 @@ TARGET = Authority foundation, want DOWN (trait = "do not defer to authority" (also report SocialNorms + Care, the axis the 1b note flagged.) OFF-TARGET= coherence = tinymfv mean_pmass_allowed = p_any_ans, want HELD ~1.0. -Stages: base -> steered (raw c=1) -> heal_nll -> heal_klrev. One model load, -one vignette set, so every row is paired and comparable. +Stages: base -> steered(c=0.5,1.0) -> one row per adapter ckpt (labeled by its +reg). One model load, one vignette set, so every row is paired and comparable. -Run: uv run python scripts/diag_stages.py [n|all] +Run: uv run python scripts/diag_stages.py <ckpt1> [ckpt2 ...] [n|all] """ +import json import sys +from pathlib import Path import torch import tinymfv @@ -23,8 +25,22 @@ from steer_heal.eval import foundation_nats # noqa: E402 from steer_heal.steering import teacher_vec # noqa: E402 from steer_heal.ws.bake import AdapterSpec, baked # noqa: E402 -nll_ckpt, klrev_ckpt = sys.argv[1], sys.argv[2] -N_VIG = None if (len(sys.argv) > 3 and sys.argv[3] == "all") else int(sys.argv[3]) if len(sys.argv) > 3 else None +# Trailing "all"/int is the vignette count; everything else is a ckpt path. +argv = sys.argv[1:] +N_VIG = None +if argv and (argv[-1] == "all" or argv[-1].isdigit()): + N_VIG = None if argv[-1] == "all" else int(argv[-1]) + argv = argv[:-1] +ckpts = argv # 1+ adapter checkpoints + + +def ckpt_label(path: str) -> str: + """Row label = the run's reg (kl_rev/nll/...) 
from metadata.json two dirs up.""" + m = json.load(open(Path(path).parents[1] / "metadata.json")) + reg = m.get("cfg", m).get("reg", "?") + return f"heal_{reg}" + + cfg = RunConfig(n_prompts=12) tok = AutoTokenizer.from_pretrained(cfg.model) @@ -45,18 +61,16 @@ def prof(): v = teacher_vec(model, tok, cfg) -nll = AdapterSpec.from_checkpoint(model, nll_ckpt) -klrev = AdapterSpec.from_checkpoint(model, klrev_ckpt) +adapters = [(ckpt_label(p), AdapterSpec.from_checkpoint(model, p)) for p in ckpts] rows = {} rows["base"] = prof() for c in (0.5, 1.0): # 0.5 = coherent operating point; 1.0 = the collapse end with v(model, C=c * v.cfg.coeff): rows[f"steered(c={c:g})"] = prof() -with baked(model, [nll]): - rows["heal_nll"] = prof() -with baked(model, [klrev]): - rows["heal_klrev"] = prof() +for label, spec in adapters: + with baked(model, [spec]): + rows[label] = prof() # target = Authority log p (down good, NATS), off-target = coherence (held good). # THE Gate-3 question (user): is the trained adapter more coherent PER UNIT behaviour diff --git a/src/steer_heal/eval.py b/src/steer_heal/eval.py index 4542f7c..689288a 100644 --- a/src/steer_heal/eval.py +++ b/src/steer_heal/eval.py @@ -27,22 +27,17 @@ from steer_heal.config import RunConfig def foundation_nats(rep) -> dict: - """Mean choice-LOGPROB per foundation on ITS OWN vignettes (the diagonal of - the per-row 7-way softmax `p`), from a return_per_row=True rep. Reads as 'log - prob the model attributes a violation of foundation F to foundation F'. + """log of tinymfv's own `profile` (mean p[foundation] over ALL vignettes), in nats. - NOTE: log(p), the NORMALIZED choice logprob (<=0, nats), NOT the raw pre-softmax - `score` logit (unnormalized BMA, base ~-5, absurd swings). Authority base - ~log(0.099)=-2.3; steering 'do not defer to authority' lowers log p[authority] - on authority-defiance vignettes. 
Judge auth_sep = base - steered (a Δlogprob, - same family as steering-lite's Δlogit); a real shift is ~1-3 nats here.""" - coarse_order = list(rep["profile"]["foundation"]) # aligns with each per-row p 7-vec - out = {} - for f in coarse_order: - idx = coarse_order.index(f) - rows = [r for r in rep["per_row"] if r["foundation_coarse"] == f] - out[f] = float(np.mean([np.log(r["p"][idx]) for r in rows])) if rows else float("nan") - return out + = log(mean_vignettes p[F]) = the library's per-foundation readout, just on a log + scale so a near-ceiling prob move is visible. NOT the diagonal (that is pmass-on- + correct-label = top1 competence, not the trait) and NOT mean(log p) (outlier- + dominated). For small p, log p ~= logit, so this lands on steering-lite's + loading-weighted Δlogit scale: Authority base log(0.099)=-2.3, a real steering + shift (auth_sep = base - steered) is ~0.5-2 nats. Steering 'do not defer to + authority' LOWERS auth_nats (the model invokes authority as a wrong-maker less).""" + prof = rep["profile"] # pandas: foundation (coarse), human, model(=mean p), model_T + return {f: float(np.log(m)) for f, m in zip(prof["foundation"], prof["model"])} def evaluate_model(model, tok, cfg: RunConfig) -> dict: @@ -76,12 +71,12 @@ def evaluate_model(model, tok, cfg: RunConfig) -> dict: "ppx_json": float(math.exp(rep["mean_nll_json"])), "top1_acc": float(rep["top1_acc"]), } - # SHOULD (trait, nats): steering "do not defer to authority" LOWERS auth_nats - # (= log p[authority] on authority-defiance vignettes; base ~-2.3). Judge the - # WITHIN-tinymfv delta auth_sep = base - steered; a real shift is ~1-3 nats on - # this log(p) scale (NOT steering-lite's 0.5-2, a different p(is-wrong) metric). - # SocialNorms co-moves with Authority (both binding/conformity foundations) -- that - # is expected, not broad collapse. Broad permissivizing = Care/Fairness drop AS MUCH. 
+ # SHOULD (trait, nats): auth_nats = log(tinymfv profile p[Authority]); steering "do + # not defer to authority" LOWERS it (model invokes authority as a wrong-maker less). + # Base ~log(0.099)=-2.3; judge auth_sep = base - steered, a Δlog p ~= Δlogit, so + # steering-lite's 0.5-2 nat reference DOES apply here. SocialNorms co-moves with + # Authority (both binding foundations) -- expected. Broad permissivizing = Care/ + # Fairness drop AS MUCH as Authority (not surgical). # SHOULD (coherence = p_any_ans = mean_pmass_allowed): base/c=0 MUST be ~1.0. >=0.95 mild, # 0.85-0.95 degraded, <0.85 broken. We want the auth_nats shift at coherence >=0.95. coh = out["coherence"] diff --git a/src/steer_heal/heal.py b/src/steer_heal/heal.py index cc90c35..357b3dc 100644 --- a/src/steer_heal/heal.py +++ b/src/steer_heal/heal.py @@ -46,7 +46,7 @@ def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg: f"lora r={cfg.lora_r} on layers {cfg.layer_range}") logger.info("SHOULD: nll (SFT) falls as the adapter learns the trait; kl (barrier div) is 0 for " "reg=nll/wd and >0 for kl_rev/kl_fwd; gnorm finite (not exploding). 
loss = nll + lam*relu(kl-tau).") - logger.info(" step nll↓ kl loss↓ gnorm") + logger.info(" step nll↓ kl loss↓ gnorm") pbar = tqdm(total=n_steps, desc=f"heal[{cfg.reg}]", mininterval=120, maxinterval=120) step = 0 for ep in range(cfg.epochs): @@ -80,8 +80,8 @@ def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg: opt.step() opt.zero_grad() if step % max(1, n_steps // 20) == 0 or step == n_steps - 1: - logger.info(f" {step:4d} {sft.item():7.3f} {div.detach().item():6.3f} " - f"{loss.item():7.3f} {float(gnorm):6.2f}") + logger.info(f" {step:4d} {sft.item():5.2f} {div.detach().item():4.2f} " + f"{loss.item():5.2f} {float(gnorm):5.1f}") pbar.set_postfix(nll=f"{sft.item():.2f}", kl=f"{div.detach().item():.2f}", gn=f"{float(gnorm):.1f}") pbar.update(1) step += 1 diff --git a/src/steer_heal/run.py b/src/steer_heal/run.py index 984c8a0..41cd38d 100644 --- a/src/steer_heal/run.py +++ b/src/steer_heal/run.py @@ -139,6 +139,19 @@ def _log_loop_summary(rounds: list[dict]) -> None: tbl = [{disp: r.get(key) for key, disp in cols} for r in rounds] logger.info("\nloop summary:\n" + tabulate(tbl, headers="keys", tablefmt="github", floatfmt=".3f") + "\n") + # BLUF: single headline with cue ball (token-efficient-logging). This run controls + # COHERENCE of the healed adapter (trait RETENTION vs base needs the paired + # diag_stages, since the loop never evals base/steered). Cue = coherence band. + last = rounds[-1] + coh = last["coherence"] + cue = "🟢" if coh >= 0.95 else "🟡" if coh >= 0.85 else "🔴" + logger.info( + f"main metric: {cue} coherence={coh:.2f} (healed if ~1.0) | auth_nats={last['auth_nats']:+.2f} " + f"care_nats={last['care_nats']:+.2f} adapter_ppl={last['adapter_ppl']:.1f}\n" + " cue=coherence band (🟢>=.95 🟡>=.85 🔴<.85). For the trait verdict (auth_nats moved " + "vs base AND coh held) run scripts/diag_stages.py all -> retain, coh_cost." + ) + def main(cfg: RunConfig) -> None: setup_logging()