mirror of
https://github.com/wassname/steer-heal-love.git
synced 2026-06-27 17:02:34 +08:00
address external review: docstrings, scale story, surgicality cue, fail-loud
External code review (background subagent) findings, fixed: - H1: eval.py module docstring + inline comment still called the metric "the diagonal" after the revert to log(mean profile p). Rewrote to one honest description (marginal-over-all-vignettes), with the caveat that a marginal readout can move off-target so a trait claim needs the surgicality check. - H2: the nats-vs-logit scale story was asserted 3 contradictory ways. Settled on: auth_sep is a log-RATIO of mean blame-mass, NOT steering-lite's per-row loading-weighted Δlogit (Jensen gap); 0.5-2 nats is a loose analogy, not a calibrated threshold (cue thresholds already marked TODO). - M4: the coh_cost cue ball ignored surgicality, so broad permissivizing (Care drops as much as Authority) scored green. Cue now requires |dAuth|>|dCare|. - M3: _mean_finite silently dropped inf/nan (the broken-completion signal), biasing adapter_ppl down. Now logs the dropped count. - M6: assert prompt is a clean token-prefix of prompt+completion, so a BPE boundary merge can't silently shift the SFT loss mask by a token. - L8: SHOULD line warns if kl stays < tau (barrier never fired -> kl_rev==nll). Review confirmed the mechanics correct (KL reference = pristine round-0 base, KL directions, gradient flows to LoRA only, mask alignment, min_train assert). Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
+31
-27
@@ -1,20 +1,24 @@
|
||||
"""tinymfv eval -> trait metric in NATS (auth logp) + coherence canary.
|
||||
"""tinymfv eval -> trait metric in NATS + coherence canary.
|
||||
|
||||
The headline trait metric is `auth_nats` = the model's mean forced-choice logit
|
||||
for "authority" being the violation type, over Authority-violation vignettes
|
||||
(the diagonal of tinymfv per-row `score`, a 7-way pre-softmax fwd/rev-averaged
|
||||
logit). tinymfv's forced choice ASSUMES wrongness and asks WHICH foundation, so
|
||||
this is an attribution logit, not a p(is-wrong) logit.
|
||||
The headline trait metric is `auth_nats` = log of tinymfv's `profile` value for
|
||||
Authority = log(mean_vignettes p[Authority]), where p[Authority] is the softmax
|
||||
mass the model puts on "authority" as the violation type, averaged over ALL
|
||||
vignettes (tinymfv eval.py:317, the marginal profile). So it is the model's
|
||||
overall propensity to blame authority, on a log scale, NOT a per-vignette
|
||||
diagonal and NOT restricted to Authority-violation vignettes.
|
||||
|
||||
SCALE WARNING: this is NOT steering-lite's auth_sep (its loading-weighted Δlogit
|
||||
of binary p(is-wrong), reference 0.5-2 nats). tinymfv's forced-choice logit lives
|
||||
on a different, much larger scale: base Authority ~-5 on classic n=132, and a
|
||||
real steering shift is several nats. Do NOT compare auth_nats deltas to the
|
||||
steering-lite 0.5-2 reference. Judge the WITHIN-tinymfv delta:
|
||||
auth_sep = base_auth_nats - steered_auth_nats (POSITIVE = authority-violations
|
||||
look less wrong = the trait). Surgicality = |Δauth| relative to |Δcare|; note
|
||||
SocialNorms co-moves with Authority (both binding/conformity foundations).
|
||||
Coherence stays in prob (it's a mass), not nats.
|
||||
SCALE: auth_sep = base - steered is a log-RATIO of mean blame-mass (Δlog mean p),
|
||||
NOT steering-lite's per-row loading-weighted Δlogit of p(is-wrong). The two are
|
||||
different quantities (log-of-mean has a Jensen gap vs mean-of-logit), so treat
|
||||
steering-lite's 0.5-2 nat figure only as a loose order-of-magnitude analogy, not
|
||||
a calibrated threshold (the run.py cue thresholds are flagged TODO for this
|
||||
reason). Judge auth_sep within tinymfv: base log(0.099)=-2.3, observed coherent
|
||||
steering shift ~1 nat (task76 c=0.5). Surgicality = |Δauth| vs |Δcare|; SocialNorms
|
||||
co-moves with Authority (both binding foundations). Coherence stays in prob (a mass).
|
||||
|
||||
CAVEAT: a marginal-over-all-vignettes readout can move for off-target reasons
|
||||
(e.g. the model reblaming a Care vignette onto authority), so a real trait claim
|
||||
needs the surgicality check (Authority moves, Care does not), not auth_nats alone.
|
||||
"""
|
||||
|
||||
import math
|
||||
@@ -32,9 +36,9 @@ def foundation_nats(rep) -> dict:
|
||||
= log(mean_vignettes p[F]) = the library's per-foundation readout, just on a log
|
||||
scale so a near-ceiling prob move is visible. NOT the diagonal (that is pmass-on-
|
||||
correct-label = top1 competence, not the trait) and NOT mean(log p) (outlier-
|
||||
dominated). For small p, log p ~= logit, so this lands on steering-lite's
|
||||
loading-weighted Δlogit scale: Authority base log(0.099)=-2.3, a real steering
|
||||
shift (auth_sep = base - steered) is ~0.5-2 nats. Steering 'do not defer to
|
||||
dominated). auth_sep = base - steered is a log-RATIO of mean blame-mass, NOT
|
||||
steering-lite's per-row loading-weighted Δlogit (Jensen gap), so 0.5-2 nats is a
|
||||
loose analogy not a threshold. Base log(0.099)=-2.3; steering 'do not defer to
|
||||
authority' LOWERS auth_nats (the model invokes authority as a wrong-maker less)."""
|
||||
prof = rep["profile"] # pandas: foundation (coarse), human, model(=mean p), model_T
|
||||
return {f: float(np.log(m)) for f, m in zip(prof["foundation"], prof["model"])}
|
||||
@@ -53,9 +57,9 @@ def evaluate_model(model, tok, cfg: RunConfig) -> dict:
|
||||
prof = rep["profile"] # pandas: foundation (coarse), human, model, model_T
|
||||
p = dict(zip(prof["foundation"], prof["model"])) # mean prob mass (kept for the map plot)
|
||||
|
||||
# NAT metric (single source: foundation_nats) = diagonal choice-logprob
|
||||
# log p[F] on F-violation vignettes. Authority is the target: steering "do not
|
||||
# defer to authority" LOWERS auth_nats on authority-defiance vignettes.
|
||||
# NAT metric (single source: foundation_nats) = log(mean profile p[F]) over ALL
|
||||
# vignettes (marginal, not diagonal). Authority is the target: steering "do not
|
||||
# defer to authority" LOWERS auth_nats (model blames authority less overall).
|
||||
nats = foundation_nats(rep)
|
||||
|
||||
out = {
|
||||
@@ -71,12 +75,12 @@ def evaluate_model(model, tok, cfg: RunConfig) -> dict:
|
||||
"ppx_json": float(math.exp(rep["mean_nll_json"])),
|
||||
"top1_acc": float(rep["top1_acc"]),
|
||||
}
|
||||
# SHOULD (trait, nats): auth_nats = log(tinymfv profile p[Authority]); steering "do
|
||||
# not defer to authority" LOWERS it (model invokes authority as a wrong-maker less).
|
||||
# Base ~log(0.099)=-2.3; judge auth_sep = base - steered, a Δlog p ~= Δlogit, so
|
||||
# steering-lite's 0.5-2 nat reference DOES apply here. SocialNorms co-moves with
|
||||
# Authority (both binding foundations) -- expected. Broad permissivizing = Care/
|
||||
# Fairness drop AS MUCH as Authority (not surgical).
|
||||
# SHOULD (trait, nats): auth_nats = log(mean profile p[Authority]); steering "do not
|
||||
# defer to authority" LOWERS it (model blames authority less overall). Base
|
||||
# ~log(0.099)=-2.3; judge auth_sep = base - steered (a log-ratio of blame-mass, NOT
|
||||
# steering-lite's loading-weighted Δlogit -- 0.5-2 nats is a loose analogy only).
|
||||
# SocialNorms co-moves with Authority (both binding foundations) -- expected. Broad
|
||||
# permissivizing (the off-target failure) = Care/Fairness drop AS MUCH as Authority.
|
||||
# SHOULD (coherence = p_any_ans = mean_pmass_allowed): base/c=0 MUST be ~1.0. >=0.95 mild,
|
||||
# 0.85-0.95 degraded, <0.85 broken. We want the auth_nats shift at coherence >=0.95.
|
||||
coh = out["coherence"]
|
||||
|
||||
+13
-3
@@ -24,7 +24,16 @@ def _kl_per_pos(logp_a, logp_b): # KL(a || b) summed over vocab, per position
|
||||
|
||||
def _encode(tok, prompt: str, completion: str, max_len: int, device):
|
||||
ids = tok(prompt + completion, return_tensors="pt", truncation=True, max_length=max_len).to(device)
|
||||
n_prompt = tok(prompt, return_tensors="pt").input_ids.shape[1]
|
||||
prompt_ids = tok(prompt, return_tensors="pt").input_ids[0].to(device)
|
||||
n_prompt = prompt_ids.shape[0]
|
||||
# Assert the prompt tokenizes as a clean PREFIX of prompt+completion. If a BPE merge
|
||||
# spans the boundary, n_prompt is wrong and the SFT mask silently shifts by a token
|
||||
# (review M6). Truncation can drop the tail, so only check when not truncated.
|
||||
if ids.input_ids.shape[1] >= n_prompt and ids.input_ids.shape[1] < max_len:
|
||||
assert torch.equal(ids.input_ids[0, :n_prompt], prompt_ids), (
|
||||
"prompt is not a token-prefix of prompt+completion (BPE boundary merge); "
|
||||
"the SFT loss mask would be misaligned by a token."
|
||||
)
|
||||
L = ids.input_ids.shape[1]
|
||||
tgt_is_completion = torch.arange(1, L, device=device) >= n_prompt # mask over next-token targets
|
||||
return ids, tgt_is_completion
|
||||
@@ -44,8 +53,9 @@ def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg:
|
||||
# streaming training table (token-efficient-logging): one row, columns self-decode below.
|
||||
logger.info(f"heal[{cfg.reg}] {len(kept)} completions x {cfg.epochs} ep = {n_steps} steps; "
|
||||
f"lora r={cfg.lora_r} on layers {cfg.layer_range}")
|
||||
logger.info("SHOULD: nll (SFT) falls as the adapter learns the trait; kl (barrier div) is 0 for "
|
||||
"reg=nll/wd and >0 for kl_rev/kl_fwd; gnorm finite (not exploding). loss = nll + lam*relu(kl-tau).")
|
||||
logger.info(f"SHOULD: nll (SFT) falls as the adapter learns the trait; kl (barrier div) is 0 for "
|
||||
f"reg=nll/wd and >0 for kl_rev/kl_fwd; gnorm finite (not exploding). loss = nll + lam*relu(kl-tau). "
|
||||
f"If kl stays < tau={cfg.tau} the barrier NEVER fired and {cfg.reg} == nll (no regularisation).")
|
||||
logger.info(" step nll↓ kl loss↓ gnorm")
|
||||
pbar = tqdm(total=n_steps, desc=f"heal[{cfg.reg}]", mininterval=120, maxinterval=120)
|
||||
step = 0
|
||||
|
||||
+20
-9
@@ -59,8 +59,14 @@ def _flatten_v(v) -> torch.Tensor:
|
||||
return torch.cat([v.state[li]["v"].flatten().float() for li in sorted(v.state)])
|
||||
|
||||
|
||||
def _mean_finite(xs) -> float:
|
||||
def _mean_finite(xs, label: str = "ppl") -> float:
|
||||
"""Mean over finite values, LOUDLY reporting dropped inf/nan -- those are the
|
||||
broken-completion signal (empty/degenerate gens give inf ppl), so silently
|
||||
averaging over survivors would make a broken adapter look healthier (review M3)."""
|
||||
n = len(xs)
|
||||
xs = [x for x in xs if x == x and x != float("inf")]
|
||||
if len(xs) < n:
|
||||
logger.warning(f"_mean_finite[{label}]: dropped {n - len(xs)}/{n} non-finite (broken gens)")
|
||||
return sum(xs) / len(xs) if xs else float("nan")
|
||||
|
||||
|
||||
@@ -95,8 +101,8 @@ def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict:
|
||||
with baked(model, hist_specs):
|
||||
m = evaluate_model(model, tok, cfg)
|
||||
adapter = generate_plain(model, tok, cfg, n=min(6, cfg.n_prompts))
|
||||
adapter_ppl = _mean_finite([ppl_under_base(model, tok, a["prompt"], a["completion"]) for a in adapter])
|
||||
steered_ppl = _mean_finite([s["ppl"] for s in scored])
|
||||
adapter_ppl = _mean_finite([ppl_under_base(model, tok, a["prompt"], a["completion"]) for a in adapter], "adapter_ppl")
|
||||
steered_ppl = _mean_finite([s["ppl"] for s in scored], "steered_ppl")
|
||||
logger.info(
|
||||
"SHOULD (Q1 heal): adapter_ppl < steered_ppl means the trained model expresses the trait "
|
||||
"COHERENTLY (healed) where raw steering was incoherent. If adapter_ppl >= steered_ppl, "
|
||||
@@ -158,22 +164,27 @@ def _log_loop_summary(rounds: list[dict], base_m: dict) -> None:
|
||||
# meaningful when the trait actually moved, so gate on |dAuth| first.
|
||||
last = rounds[-1]
|
||||
dAuth = last["auth_nats"] - base_m["auth_nats"]
|
||||
dCare = last["care_nats"] - base_m["care_nats"]
|
||||
dCoh = last["coherence"] - base_m["coherence"]
|
||||
coh_cost = abs(dCoh) / abs(dAuth) if abs(dAuth) > 1e-6 else float("nan")
|
||||
surgical = abs(dAuth) > abs(dCare) # Authority must move MORE than the off-target Care
|
||||
# TODO(threshold): coh_cost cut not yet calibrated. Provisional: a healed adapter
|
||||
# SHOULD land trait (dAuth <= -0.3 nats) at coh_cost <= 0.05 (steered c=0.5 ~0.003).
|
||||
# SHOULD land trait (dAuth <= -0.3 nats), SURGICALLY (|dAuth|>|dCare|, else it is
|
||||
# broad permissivizing not the trait -- review M4), at coh_cost <= 0.05 (steered c=0.5 ~0.003).
|
||||
if dAuth > -0.3:
|
||||
cue = "🔴" # no trait retained (undo)
|
||||
elif not surgical:
|
||||
cue = "🔴" # moved, but Care moved as much -> broad permissivizing, not the trait
|
||||
elif coh_cost <= 0.05:
|
||||
cue = "🟢" # trait retained cheaply
|
||||
cue = "🟢" # surgical trait retained cheaply
|
||||
else:
|
||||
cue = "🟡" # trait retained but coherence-expensive
|
||||
cue = "🟡" # surgical trait but coherence-expensive
|
||||
logger.info(
|
||||
f"main metric: {cue} coh_cost={coh_cost:.3f} (|dCoh|/|dAuth| vs base, lower=better) | "
|
||||
f"dAuth={dAuth:+.2f} nats (trait, want <0) coherence={last['coherence']:.2f} "
|
||||
f"dAuth={dAuth:+.2f} dCare={dCare:+.2f} (surgical={surgical}) coherence={last['coherence']:.2f} "
|
||||
f"(base {base_m['coherence']:.2f})\n"
|
||||
" cue: 🔴 dAuth>-0.3 (no trait) | 🟢 trait at coh_cost<=0.05 | 🟡 trait but expensive. "
|
||||
"TODO calibrate coh_cost threshold (steered c=0.5 ref ~0.003)."
|
||||
" cue: 🔴 dAuth>-0.3 (no trait) OR |dAuth|<=|dCare| (broad, not surgical) | 🟢 surgical trait "
|
||||
"at coh_cost<=0.05 | 🟡 surgical but expensive. TODO calibrate coh_cost (steered c=0.5 ref ~0.003)."
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user