address external review: docstrings, scale story, surgicality cue, fail-loud

External code review (background subagent) findings, fixed:
- H1: eval.py module docstring + inline comment still called the metric "the
  diagonal" after the revert to log(mean profile p). Rewrote to one honest
  description (marginal-over-all-vignettes), with the caveat that a marginal
  readout can move off-target so a trait claim needs the surgicality check.
- H2: the nats-vs-logit scale story was asserted 3 contradictory ways. Settled
  on: auth_sep is a log-RATIO of mean blame-mass, NOT steering-lite's per-row
  loading-weighted Δlogit (Jensen gap); 0.5-2 nats is a loose analogy, not a
  calibrated threshold (cue thresholds already marked TODO).
- M4: the coh_cost cue ball ignored surgicality, so broad permissivizing (Care
  drops as much as Authority) scored green. Cue now requires |dAuth|>|dCare|.
- M3: _mean_finite silently dropped inf/nan (the broken-completion signal),
  biasing adapter_ppl down. Now logs the dropped count.
- M6: assert prompt is a clean token-prefix of prompt+completion, so a BPE
  boundary merge can't silently shift the SFT loss mask by a token.
- L8: SHOULD line warns if kl stays < tau (barrier never fired -> kl_rev==nll).

Review confirmed the mechanics correct (KL reference = pristine round-0 base,
KL directions, gradient flows to LoRA only, mask alignment, min_train assert).

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
wassname
2026-06-04 15:21:13 +08:00
parent 502417b259
commit 68dc25c3a1
3 changed files with 64 additions and 39 deletions
+31 -27
View File
@@ -1,20 +1,24 @@
"""tinymfv eval -> trait metric in NATS (auth logp) + coherence canary. """tinymfv eval -> trait metric in NATS + coherence canary.
The headline trait metric is `auth_nats` = the model's mean forced-choice logit The headline trait metric is `auth_nats` = log of tinymfv's `profile` value for
for "authority" being the violation type, over Authority-violation vignettes Authority = log(mean_vignettes p[Authority]), where p[Authority] is the softmax
(the diagonal of tinymfv per-row `score`, a 7-way pre-softmax fwd/rev-averaged mass the model puts on "authority" as the violation type, averaged over ALL
logit). tinymfv's forced choice ASSUMES wrongness and asks WHICH foundation, so vignettes (tinymfv eval.py:317, the marginal profile). So it is the model's
this is an attribution logit, not a p(is-wrong) logit. overall propensity to blame authority, on a log scale, NOT a per-vignette
diagonal and NOT restricted to Authority-violation vignettes.
SCALE WARNING: this is NOT steering-lite's auth_sep (its loading-weighted Δlogit SCALE: auth_sep = base - steered is a log-RATIO of mean blame-mass (Δlog mean p),
of binary p(is-wrong), reference 0.5-2 nats). tinymfv's forced-choice logit lives NOT steering-lite's per-row loading-weighted Δlogit of p(is-wrong). The two are
on a different, much larger scale: base Authority ~-5 on classic n=132, and a different quantities (log-of-mean has a Jensen gap vs mean-of-logit), so treat
real steering shift is several nats. Do NOT compare auth_nats deltas to the steering-lite's 0.5-2 nat figure only as a loose order-of-magnitude analogy, not
steering-lite 0.5-2 reference. Judge the WITHIN-tinymfv delta: a calibrated threshold (the run.py cue thresholds are flagged TODO for this
auth_sep = base_auth_nats - steered_auth_nats (POSITIVE = authority-violations reason). Judge auth_sep within tinymfv: base log(0.099)=-2.3, observed coherent
look less wrong = the trait). Surgicality = |Δauth| relative to |Δcare|; note steering shift ~1 nat (task76 c=0.5). Surgicality = |Δauth| vs |Δcare|; SocialNorms
SocialNorms co-moves with Authority (both binding/conformity foundations). co-moves with Authority (both binding foundations). Coherence stays in prob (a mass).
Coherence stays in prob (it's a mass), not nats.
CAVEAT: a marginal-over-all-vignettes readout can move for off-target reasons
(e.g. the model reblaming a Care vignette onto authority), so a real trait claim
needs the surgicality check (Authority moves, Care does not), not auth_nats alone.
""" """
import math import math
@@ -32,9 +36,9 @@ def foundation_nats(rep) -> dict:
= log(mean_vignettes p[F]) = the library's per-foundation readout, just on a log = log(mean_vignettes p[F]) = the library's per-foundation readout, just on a log
scale so a near-ceiling prob move is visible. NOT the diagonal (that is pmass-on- scale so a near-ceiling prob move is visible. NOT the diagonal (that is pmass-on-
correct-label = top1 competence, not the trait) and NOT mean(log p) (outlier- correct-label = top1 competence, not the trait) and NOT mean(log p) (outlier-
dominated). For small p, log p ~= logit, so this lands on steering-lite's dominated). auth_sep = base - steered is a log-RATIO of mean blame-mass, NOT
loading-weighted Δlogit scale: Authority base log(0.099)=-2.3, a real steering steering-lite's per-row loading-weighted Δlogit (Jensen gap), so 0.5-2 nats is a
shift (auth_sep = base - steered) is ~0.5-2 nats. Steering 'do not defer to loose analogy not a threshold. Base log(0.099)=-2.3; steering 'do not defer to
authority' LOWERS auth_nats (the model invokes authority as a wrong-maker less).""" authority' LOWERS auth_nats (the model invokes authority as a wrong-maker less)."""
prof = rep["profile"] # pandas: foundation (coarse), human, model(=mean p), model_T prof = rep["profile"] # pandas: foundation (coarse), human, model(=mean p), model_T
return {f: float(np.log(m)) for f, m in zip(prof["foundation"], prof["model"])} return {f: float(np.log(m)) for f, m in zip(prof["foundation"], prof["model"])}
@@ -53,9 +57,9 @@ def evaluate_model(model, tok, cfg: RunConfig) -> dict:
prof = rep["profile"] # pandas: foundation (coarse), human, model, model_T prof = rep["profile"] # pandas: foundation (coarse), human, model, model_T
p = dict(zip(prof["foundation"], prof["model"])) # mean prob mass (kept for the map plot) p = dict(zip(prof["foundation"], prof["model"])) # mean prob mass (kept for the map plot)
# NAT metric (single source: foundation_nats) = diagonal choice-logprob # NAT metric (single source: foundation_nats) = log(mean profile p[F]) over ALL
# log p[F] on F-violation vignettes. Authority is the target: steering "do not # vignettes (marginal, not diagonal). Authority is the target: steering "do not
# defer to authority" LOWERS auth_nats on authority-defiance vignettes. # defer to authority" LOWERS auth_nats (model blames authority less overall).
nats = foundation_nats(rep) nats = foundation_nats(rep)
out = { out = {
@@ -71,12 +75,12 @@ def evaluate_model(model, tok, cfg: RunConfig) -> dict:
"ppx_json": float(math.exp(rep["mean_nll_json"])), "ppx_json": float(math.exp(rep["mean_nll_json"])),
"top1_acc": float(rep["top1_acc"]), "top1_acc": float(rep["top1_acc"]),
} }
# SHOULD (trait, nats): auth_nats = log(tinymfv profile p[Authority]); steering "do # SHOULD (trait, nats): auth_nats = log(mean profile p[Authority]); steering "do not
# not defer to authority" LOWERS it (model invokes authority as a wrong-maker less). # defer to authority" LOWERS it (model blames authority less overall). Base
# Base ~log(0.099)=-2.3; judge auth_sep = base - steered, a Δlog p ~= Δlogit, so # ~log(0.099)=-2.3; judge auth_sep = base - steered (a log-ratio of blame-mass, NOT
# steering-lite's 0.5-2 nat reference DOES apply here. SocialNorms co-moves with # steering-lite's loading-weighted Δlogit -- 0.5-2 nats is a loose analogy only).
# Authority (both binding foundations) -- expected. Broad permissivizing = Care/ # SocialNorms co-moves with Authority (both binding foundations) -- expected. Broad
# Fairness drop AS MUCH as Authority (not surgical). # permissivizing (the off-target failure) = Care/Fairness drop AS MUCH as Authority.
# SHOULD (coherence = p_any_ans = mean_pmass_allowed): base/c=0 MUST be ~1.0. >=0.95 mild, # SHOULD (coherence = p_any_ans = mean_pmass_allowed): base/c=0 MUST be ~1.0. >=0.95 mild,
# 0.85-0.95 degraded, <0.85 broken. We want the auth_nats shift at coherence >=0.95. # 0.85-0.95 degraded, <0.85 broken. We want the auth_nats shift at coherence >=0.95.
coh = out["coherence"] coh = out["coherence"]
+13 -3
View File
@@ -24,7 +24,16 @@ def _kl_per_pos(logp_a, logp_b): # KL(a || b) summed over vocab, per position
def _encode(tok, prompt: str, completion: str, max_len: int, device): def _encode(tok, prompt: str, completion: str, max_len: int, device):
ids = tok(prompt + completion, return_tensors="pt", truncation=True, max_length=max_len).to(device) ids = tok(prompt + completion, return_tensors="pt", truncation=True, max_length=max_len).to(device)
n_prompt = tok(prompt, return_tensors="pt").input_ids.shape[1] prompt_ids = tok(prompt, return_tensors="pt").input_ids[0].to(device)
n_prompt = prompt_ids.shape[0]
# Assert the prompt tokenizes as a clean PREFIX of prompt+completion. If a BPE merge
# spans the boundary, n_prompt is wrong and the SFT mask silently shifts by a token
# (review M6). Truncation can drop the tail, so only check when not truncated.
if ids.input_ids.shape[1] >= n_prompt and ids.input_ids.shape[1] < max_len:
assert torch.equal(ids.input_ids[0, :n_prompt], prompt_ids), (
"prompt is not a token-prefix of prompt+completion (BPE boundary merge); "
"the SFT loss mask would be misaligned by a token."
)
L = ids.input_ids.shape[1] L = ids.input_ids.shape[1]
tgt_is_completion = torch.arange(1, L, device=device) >= n_prompt # mask over next-token targets tgt_is_completion = torch.arange(1, L, device=device) >= n_prompt # mask over next-token targets
return ids, tgt_is_completion return ids, tgt_is_completion
@@ -44,8 +53,9 @@ def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg:
# streaming training table (token-efficient-logging): one row, columns self-decode below. # streaming training table (token-efficient-logging): one row, columns self-decode below.
logger.info(f"heal[{cfg.reg}] {len(kept)} completions x {cfg.epochs} ep = {n_steps} steps; " logger.info(f"heal[{cfg.reg}] {len(kept)} completions x {cfg.epochs} ep = {n_steps} steps; "
f"lora r={cfg.lora_r} on layers {cfg.layer_range}") f"lora r={cfg.lora_r} on layers {cfg.layer_range}")
logger.info("SHOULD: nll (SFT) falls as the adapter learns the trait; kl (barrier div) is 0 for " logger.info(f"SHOULD: nll (SFT) falls as the adapter learns the trait; kl (barrier div) is 0 for "
"reg=nll/wd and >0 for kl_rev/kl_fwd; gnorm finite (not exploding). loss = nll + lam*relu(kl-tau).") f"reg=nll/wd and >0 for kl_rev/kl_fwd; gnorm finite (not exploding). loss = nll + lam*relu(kl-tau). "
f"If kl stays < tau={cfg.tau} the barrier NEVER fired and {cfg.reg} == nll (no regularisation).")
logger.info(" step nll↓ kl loss↓ gnorm") logger.info(" step nll↓ kl loss↓ gnorm")
pbar = tqdm(total=n_steps, desc=f"heal[{cfg.reg}]", mininterval=120, maxinterval=120) pbar = tqdm(total=n_steps, desc=f"heal[{cfg.reg}]", mininterval=120, maxinterval=120)
step = 0 step = 0
+20 -9
View File
@@ -59,8 +59,14 @@ def _flatten_v(v) -> torch.Tensor:
return torch.cat([v.state[li]["v"].flatten().float() for li in sorted(v.state)]) return torch.cat([v.state[li]["v"].flatten().float() for li in sorted(v.state)])
def _mean_finite(xs) -> float: def _mean_finite(xs, label: str = "ppl") -> float:
"""Mean over finite values, LOUDLY reporting dropped inf/nan -- those are the
broken-completion signal (empty/degenerate gens give inf ppl), so silently
averaging over survivors would make a broken adapter look healthier (review M3)."""
n = len(xs)
xs = [x for x in xs if x == x and x != float("inf")] xs = [x for x in xs if x == x and x != float("inf")]
if len(xs) < n:
logger.warning(f"_mean_finite[{label}]: dropped {n - len(xs)}/{n} non-finite (broken gens)")
return sum(xs) / len(xs) if xs else float("nan") return sum(xs) / len(xs) if xs else float("nan")
@@ -95,8 +101,8 @@ def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict:
with baked(model, hist_specs): with baked(model, hist_specs):
m = evaluate_model(model, tok, cfg) m = evaluate_model(model, tok, cfg)
adapter = generate_plain(model, tok, cfg, n=min(6, cfg.n_prompts)) adapter = generate_plain(model, tok, cfg, n=min(6, cfg.n_prompts))
adapter_ppl = _mean_finite([ppl_under_base(model, tok, a["prompt"], a["completion"]) for a in adapter]) adapter_ppl = _mean_finite([ppl_under_base(model, tok, a["prompt"], a["completion"]) for a in adapter], "adapter_ppl")
steered_ppl = _mean_finite([s["ppl"] for s in scored]) steered_ppl = _mean_finite([s["ppl"] for s in scored], "steered_ppl")
logger.info( logger.info(
"SHOULD (Q1 heal): adapter_ppl < steered_ppl means the trained model expresses the trait " "SHOULD (Q1 heal): adapter_ppl < steered_ppl means the trained model expresses the trait "
"COHERENTLY (healed) where raw steering was incoherent. If adapter_ppl >= steered_ppl, " "COHERENTLY (healed) where raw steering was incoherent. If adapter_ppl >= steered_ppl, "
@@ -158,22 +164,27 @@ def _log_loop_summary(rounds: list[dict], base_m: dict) -> None:
# meaningful when the trait actually moved, so gate on |dAuth| first. # meaningful when the trait actually moved, so gate on |dAuth| first.
last = rounds[-1] last = rounds[-1]
dAuth = last["auth_nats"] - base_m["auth_nats"] dAuth = last["auth_nats"] - base_m["auth_nats"]
dCare = last["care_nats"] - base_m["care_nats"]
dCoh = last["coherence"] - base_m["coherence"] dCoh = last["coherence"] - base_m["coherence"]
coh_cost = abs(dCoh) / abs(dAuth) if abs(dAuth) > 1e-6 else float("nan") coh_cost = abs(dCoh) / abs(dAuth) if abs(dAuth) > 1e-6 else float("nan")
surgical = abs(dAuth) > abs(dCare) # Authority must move MORE than the off-target Care
# TODO(threshold): coh_cost cut not yet calibrated. Provisional: a healed adapter # TODO(threshold): coh_cost cut not yet calibrated. Provisional: a healed adapter
# SHOULD land trait (dAuth <= -0.3 nats) at coh_cost <= 0.05 (steered c=0.5 ~0.003). # SHOULD land trait (dAuth <= -0.3 nats), SURGICALLY (|dAuth|>|dCare|, else it is
# broad permissivizing not the trait -- review M4), at coh_cost <= 0.05 (steered c=0.5 ~0.003).
if dAuth > -0.3: if dAuth > -0.3:
cue = "🔴" # no trait retained (undo) cue = "🔴" # no trait retained (undo)
elif not surgical:
cue = "🔴" # moved, but Care moved as much -> broad permissivizing, not the trait
elif coh_cost <= 0.05: elif coh_cost <= 0.05:
cue = "🟢" # trait retained cheaply cue = "🟢" # surgical trait retained cheaply
else: else:
cue = "🟡" # trait retained but coherence-expensive cue = "🟡" # surgical trait but coherence-expensive
logger.info( logger.info(
f"main metric: {cue} coh_cost={coh_cost:.3f} (|dCoh|/|dAuth| vs base, lower=better) | " f"main metric: {cue} coh_cost={coh_cost:.3f} (|dCoh|/|dAuth| vs base, lower=better) | "
f"dAuth={dAuth:+.2f} nats (trait, want <0) coherence={last['coherence']:.2f} " f"dAuth={dAuth:+.2f} dCare={dCare:+.2f} (surgical={surgical}) coherence={last['coherence']:.2f} "
f"(base {base_m['coherence']:.2f})\n" f"(base {base_m['coherence']:.2f})\n"
" cue: 🔴 dAuth>-0.3 (no trait) | 🟢 trait at coh_cost<=0.05 | 🟡 trait but expensive. " " cue: 🔴 dAuth>-0.3 (no trait) OR |dAuth|<=|dCare| (broad, not surgical) | 🟢 surgical trait "
"TODO calibrate coh_cost threshold (steered c=0.5 ref ~0.003)." "at coh_cost<=0.05 | 🟡 surgical but expensive. TODO calibrate coh_cost (steered c=0.5 ref ~0.003)."
) )