From 68dc25c3a1cc182320aaa0d663d427932455360b Mon Sep 17 00:00:00 2001 From: wassname <1103714+wassname@users.noreply.github.com> Date: Thu, 4 Jun 2026 15:21:13 +0800 Subject: [PATCH] address external review: docstrings, scale story, surgicality cue, fail-loud MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit External code review (background subagent) findings, fixed: - H1: eval.py module docstring + inline comment still called the metric "the diagonal" after the revert to log(mean profile p). Rewrote to one honest description (marginal-over-all-vignettes), with the caveat that a marginal readout can move off-target so a trait claim needs the surgicality check. - H2: the nats-vs-logit scale story was asserted 3 contradictory ways. Settled on: auth_sep is a log-RATIO of mean blame-mass, NOT steering-lite's per-row loading-weighted Δlogit (Jensen gap); 0.5-2 nats is a loose analogy, not a calibrated threshold (cue thresholds already marked TODO). - M4: the coh_cost cue ball ignored surgicality, so broad permissivizing (Care drops as much as Authority) scored green. Cue now requires |dAuth|>|dCare|. - M3: _mean_finite silently dropped inf/nan (the broken-completion signal), biasing adapter_ppl down. Now logs the dropped count. - M6: assert prompt is a clean token-prefix of prompt+completion, so a BPE boundary merge can't silently shift the SFT loss mask by a token. - L8: SHOULD line warns if kl stays < tau (barrier never fired -> kl_rev==nll). Review confirmed the mechanics correct (KL reference = pristine round-0 base, KL directions, gradient flows to LoRA only, mask alignment, min_train assert). Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com> --- src/steer_heal/eval.py | 58 ++++++++++++++++++++++-------------------- src/steer_heal/heal.py | 16 +++++++++--- src/steer_heal/run.py | 29 ++++++++++++++------- 3 files changed, 64 insertions(+), 39 deletions(-) diff --git a/src/steer_heal/eval.py b/src/steer_heal/eval.py index 689288a..ceaf07c 100644 --- a/src/steer_heal/eval.py +++ b/src/steer_heal/eval.py @@ -1,20 +1,24 @@ -"""tinymfv eval -> trait metric in NATS (auth logp) + coherence canary. +"""tinymfv eval -> trait metric in NATS + coherence canary. -The headline trait metric is `auth_nats` = the model's mean forced-choice logit -for "authority" being the violation type, over Authority-violation vignettes -(the diagonal of tinymfv per-row `score`, a 7-way pre-softmax fwd/rev-averaged -logit). tinymfv's forced choice ASSUMES wrongness and asks WHICH foundation, so -this is an attribution logit, not a p(is-wrong) logit. +The headline trait metric is `auth_nats` = log of tinymfv's `profile` value for +Authority = log(mean_vignettes p[Authority]), where p[Authority] is the softmax +mass the model puts on "authority" as the violation type, averaged over ALL +vignettes (tinymfv eval.py:317, the marginal profile). So it is the model's +overall propensity to blame authority, on a log scale, NOT a per-vignette +diagonal and NOT restricted to Authority-violation vignettes. -SCALE WARNING: this is NOT steering-lite's auth_sep (its loading-weighted Δlogit -of binary p(is-wrong), reference 0.5-2 nats). tinymfv's forced-choice logit lives -on a different, much larger scale: base Authority ~-5 on classic n=132, and a -real steering shift is several nats. Do NOT compare auth_nats deltas to the -steering-lite 0.5-2 reference. Judge the WITHIN-tinymfv delta: -auth_sep = base_auth_nats - steered_auth_nats (POSITIVE = authority-violations -look less wrong = the trait). Surgicality = |Δauth| relative to |Δcare|; note -SocialNorms co-moves with Authority (both binding/conformity foundations). -Coherence stays in prob (it's a mass), not nats. +SCALE: auth_sep = base - steered is a log-RATIO of mean blame-mass (Δlog mean p), +NOT steering-lite's per-row loading-weighted Δlogit of p(is-wrong). The two are +different quantities (log-of-mean has a Jensen gap vs mean-of-logit), so treat +steering-lite's 0.5-2 nat figure only as a loose order-of-magnitude analogy, not +a calibrated threshold (the run.py cue thresholds are flagged TODO for this +reason). Judge auth_sep within tinymfv: base log(0.099)=-2.3, observed coherent +steering shift ~1 nat (task76 c=0.5). Surgicality = |Δauth| vs |Δcare|; SocialNorms +co-moves with Authority (both binding foundations). Coherence stays in prob (a mass). + +CAVEAT: a marginal-over-all-vignettes readout can move for off-target reasons +(e.g. the model reblaming a Care vignette onto authority), so a real trait claim +needs the surgicality check (Authority moves, Care does not), not auth_nats alone. """ import math @@ -32,9 +36,9 @@ def foundation_nats(rep) -> dict: = log(mean_vignettes p[F]) = the library's per-foundation readout, just on a log scale so a near-ceiling prob move is visible. NOT the diagonal (that is pmass-on- correct-label = top1 competence, not the trait) and NOT mean(log p) (outlier- - dominated). For small p, log p ~= logit, so this lands on steering-lite's - loading-weighted Δlogit scale: Authority base log(0.099)=-2.3, a real steering - shift (auth_sep = base - steered) is ~0.5-2 nats. Steering 'do not defer to + dominated). auth_sep = base - steered is a log-RATIO of mean blame-mass, NOT + steering-lite's per-row loading-weighted Δlogit (Jensen gap), so 0.5-2 nats is a + loose analogy not a threshold. Base log(0.099)=-2.3; steering 'do not defer to authority' LOWERS auth_nats (the model invokes authority as a wrong-maker less).""" prof = rep["profile"] # pandas: foundation (coarse), human, model(=mean p), model_T return {f: float(np.log(m)) for f, m in zip(prof["foundation"], prof["model"])} @@ -53,9 +57,9 @@ def evaluate_model(model, tok, cfg: RunConfig) -> dict: prof = rep["profile"] # pandas: foundation (coarse), human, model, model_T p = dict(zip(prof["foundation"], prof["model"])) # mean prob mass (kept for the map plot) - # NAT metric (single source: foundation_nats) = diagonal choice-logprob - # log p[F] on F-violation vignettes. Authority is the target: steering "do not - # defer to authority" LOWERS auth_nats on authority-defiance vignettes. + # NAT metric (single source: foundation_nats) = log(mean profile p[F]) over ALL + # vignettes (marginal, not diagonal). Authority is the target: steering "do not + # defer to authority" LOWERS auth_nats (model blames authority less overall). nats = foundation_nats(rep) out = { @@ -71,12 +75,12 @@ def evaluate_model(model, tok, cfg: RunConfig) -> dict: "ppx_json": float(math.exp(rep["mean_nll_json"])), "top1_acc": float(rep["top1_acc"]), } - # SHOULD (trait, nats): auth_nats = log(tinymfv profile p[Authority]); steering "do - # not defer to authority" LOWERS it (model invokes authority as a wrong-maker less). - # Base ~log(0.099)=-2.3; judge auth_sep = base - steered, a Δlog p ~= Δlogit, so - # steering-lite's 0.5-2 nat reference DOES apply here. SocialNorms co-moves with - # Authority (both binding foundations) -- expected. Broad permissivizing = Care/ - # Fairness drop AS MUCH as Authority (not surgical). + # SHOULD (trait, nats): auth_nats = log(mean profile p[Authority]); steering "do not + # defer to authority" LOWERS it (model blames authority less overall). Base + # ~log(0.099)=-2.3; judge auth_sep = base - steered (a log-ratio of blame-mass, NOT + # steering-lite's loading-weighted Δlogit -- 0.5-2 nats is a loose analogy only). + # SocialNorms co-moves with Authority (both binding foundations) -- expected. Broad + # permissivizing (the off-target failure) = Care/Fairness drop AS MUCH as Authority. # SHOULD (coherence = p_any_ans = mean_pmass_allowed): base/c=0 MUST be ~1.0. >=0.95 mild, # 0.85-0.95 degraded, <0.85 broken. We want the auth_nats shift at coherence >=0.95. coh = out["coherence"] diff --git a/src/steer_heal/heal.py b/src/steer_heal/heal.py index e7dc15f..d7cddd5 100644 --- a/src/steer_heal/heal.py +++ b/src/steer_heal/heal.py @@ -24,7 +24,16 @@ def _kl_per_pos(logp_a, logp_b): # KL(a || b) summed over vocab, per position def _encode(tok, prompt: str, completion: str, max_len: int, device): ids = tok(prompt + completion, return_tensors="pt", truncation=True, max_length=max_len).to(device) - n_prompt = tok(prompt, return_tensors="pt").input_ids.shape[1] + prompt_ids = tok(prompt, return_tensors="pt").input_ids[0].to(device) + n_prompt = prompt_ids.shape[0] + # Assert the prompt tokenizes as a clean PREFIX of prompt+completion. If a BPE merge + # spans the boundary, n_prompt is wrong and the SFT mask silently shifts by a token + # (review M6). Truncation can drop the tail, so only check when not truncated. + if ids.input_ids.shape[1] >= n_prompt and ids.input_ids.shape[1] < max_len: + assert torch.equal(ids.input_ids[0, :n_prompt], prompt_ids), ( + "prompt is not a token-prefix of prompt+completion (BPE boundary merge); " + "the SFT loss mask would be misaligned by a token." + ) L = ids.input_ids.shape[1] tgt_is_completion = torch.arange(1, L, device=device) >= n_prompt # mask over next-token targets return ids, tgt_is_completion @@ -44,8 +53,9 @@ def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg: # streaming training table (token-efficient-logging): one row, columns self-decode below. logger.info(f"heal[{cfg.reg}] {len(kept)} completions x {cfg.epochs} ep = {n_steps} steps; " f"lora r={cfg.lora_r} on layers {cfg.layer_range}") - logger.info("SHOULD: nll (SFT) falls as the adapter learns the trait; kl (barrier div) is 0 for " - "reg=nll/wd and >0 for kl_rev/kl_fwd; gnorm finite (not exploding). loss = nll + lam*relu(kl-tau).") + logger.info(f"SHOULD: nll (SFT) falls as the adapter learns the trait; kl (barrier div) is 0 for " + f"reg=nll/wd and >0 for kl_rev/kl_fwd; gnorm finite (not exploding). loss = nll + lam*relu(kl-tau). " + f"If kl stays < tau={cfg.tau} the barrier NEVER fired and {cfg.reg} == nll (no regularisation).") logger.info(" step nll↓ kl loss↓ gnorm") pbar = tqdm(total=n_steps, desc=f"heal[{cfg.reg}]", mininterval=120, maxinterval=120) step = 0 diff --git a/src/steer_heal/run.py b/src/steer_heal/run.py index 56491db..a94be1c 100644 --- a/src/steer_heal/run.py +++ b/src/steer_heal/run.py @@ -59,8 +59,14 @@ def _flatten_v(v) -> torch.Tensor: return torch.cat([v.state[li]["v"].flatten().float() for li in sorted(v.state)]) -def _mean_finite(xs) -> float: +def _mean_finite(xs, label: str = "ppl") -> float: + """Mean over finite values, LOUDLY reporting dropped inf/nan -- those are the + broken-completion signal (empty/degenerate gens give inf ppl), so silently + averaging over survivors would make a broken adapter look healthier (review M3).""" + n = len(xs) xs = [x for x in xs if x == x and x != float("inf")] + if len(xs) < n: + logger.warning(f"_mean_finite[{label}]: dropped {n - len(xs)}/{n} non-finite (broken gens)") return sum(xs) / len(xs) if xs else float("nan") @@ -95,8 +101,8 @@ def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict: with baked(model, hist_specs): m = evaluate_model(model, tok, cfg) adapter = generate_plain(model, tok, cfg, n=min(6, cfg.n_prompts)) - adapter_ppl = _mean_finite([ppl_under_base(model, tok, a["prompt"], a["completion"]) for a in adapter]) - steered_ppl = _mean_finite([s["ppl"] for s in scored]) + adapter_ppl = _mean_finite([ppl_under_base(model, tok, a["prompt"], a["completion"]) for a in adapter], "adapter_ppl") + steered_ppl = _mean_finite([s["ppl"] for s in scored], "steered_ppl") logger.info( "SHOULD (Q1 heal): adapter_ppl < steered_ppl means the trained model expresses the trait " "COHERENTLY (healed) where raw steering was incoherent. If adapter_ppl >= steered_ppl, " @@ -158,22 +164,27 @@ def _log_loop_summary(rounds: list[dict], base_m: dict) -> None: # meaningful when the trait actually moved, so gate on |dAuth| first. last = rounds[-1] dAuth = last["auth_nats"] - base_m["auth_nats"] + dCare = last["care_nats"] - base_m["care_nats"] dCoh = last["coherence"] - base_m["coherence"] coh_cost = abs(dCoh) / abs(dAuth) if abs(dAuth) > 1e-6 else float("nan") + surgical = abs(dAuth) > abs(dCare) # Authority must move MORE than the off-target Care # TODO(threshold): coh_cost cut not yet calibrated. Provisional: a healed adapter - # SHOULD land trait (dAuth <= -0.3 nats) at coh_cost <= 0.05 (steered c=0.5 ~0.003). + # SHOULD land trait (dAuth <= -0.3 nats), SURGICALLY (|dAuth|>|dCare|, else it is + # broad permissivizing not the trait -- review M4), at coh_cost <= 0.05 (steered c=0.5 ~0.003). if dAuth > -0.3: cue = "🔴" # no trait retained (undo) + elif not surgical: + cue = "🔴" # moved, but Care moved as much -> broad permissivizing, not the trait elif coh_cost <= 0.05: - cue = "🟢" # trait retained cheaply + cue = "🟢" # surgical trait retained cheaply else: - cue = "🟡" # trait retained but coherence-expensive + cue = "🟡" # surgical trait but coherence-expensive logger.info( f"main metric: {cue} coh_cost={coh_cost:.3f} (|dCoh|/|dAuth| vs base, lower=better) | " - f"dAuth={dAuth:+.2f} nats (trait, want <0) coherence={last['coherence']:.2f} " + f"dAuth={dAuth:+.2f} dCare={dCare:+.2f} (surgical={surgical}) coherence={last['coherence']:.2f} " f"(base {base_m['coherence']:.2f})\n" - " cue: 🔴 dAuth>-0.3 (no trait) | 🟢 trait at coh_cost<=0.05 | 🟡 trait but expensive. " - "TODO calibrate coh_cost threshold (steered c=0.5 ref ~0.003)." + " cue: 🔴 dAuth>-0.3 (no trait) OR |dAuth|<=|dCare| (broad, not surgical) | 🟢 surgical trait " + "at coh_cost<=0.05 | 🟡 surgical but expensive. TODO calibrate coh_cost (steered c=0.5 ref ~0.003)." )