mirror of
https://github.com/wassname/steer-heal-love.git
synced 2026-06-27 17:02:34 +08:00
address external review: docstrings, scale story, surgicality cue, fail-loud
External code review (background subagent) findings, fixed: - H1: eval.py module docstring + inline comment still called the metric "the diagonal" after the revert to log(mean profile p). Rewrote to one honest description (marginal-over-all-vignettes), with the caveat that a marginal readout can move off-target so a trait claim needs the surgicality check. - H2: the nats-vs-logit scale story was asserted 3 contradictory ways. Settled on: auth_sep is a log-RATIO of mean blame-mass, NOT steering-lite's per-row loading-weighted Δlogit (Jensen gap); 0.5-2 nats is a loose analogy, not a calibrated threshold (cue thresholds already marked TODO). - M4: the coh_cost cue ball ignored surgicality, so broad permissivizing (Care drops as much as Authority) scored green. Cue now requires |dAuth|>|dCare|. - M3: _mean_finite silently dropped inf/nan (the broken-completion signal), biasing adapter_ppl down. Now logs the dropped count. - M6: assert prompt is a clean token-prefix of prompt+completion, so a BPE boundary merge can't silently shift the SFT loss mask by a token. - L8: SHOULD line warns if kl stays < tau (barrier never fired -> kl_rev==nll). Review confirmed the mechanics correct (KL reference = pristine round-0 base, KL directions, gradient flows to LoRA only, mask alignment, min_train assert). Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
+31
-27
@@ -1,20 +1,24 @@
|
|||||||
"""tinymfv eval -> trait metric in NATS (auth logp) + coherence canary.
|
"""tinymfv eval -> trait metric in NATS + coherence canary.
|
||||||
|
|
||||||
The headline trait metric is `auth_nats` = the model's mean forced-choice logit
|
The headline trait metric is `auth_nats` = log of tinymfv's `profile` value for
|
||||||
for "authority" being the violation type, over Authority-violation vignettes
|
Authority = log(mean_vignettes p[Authority]), where p[Authority] is the softmax
|
||||||
(the diagonal of tinymfv per-row `score`, a 7-way pre-softmax fwd/rev-averaged
|
mass the model puts on "authority" as the violation type, averaged over ALL
|
||||||
logit). tinymfv's forced choice ASSUMES wrongness and asks WHICH foundation, so
|
vignettes (tinymfv eval.py:317, the marginal profile). So it is the model's
|
||||||
this is an attribution logit, not a p(is-wrong) logit.
|
overall propensity to blame authority, on a log scale, NOT a per-vignette
|
||||||
|
diagonal and NOT restricted to Authority-violation vignettes.
|
||||||
|
|
||||||
SCALE WARNING: this is NOT steering-lite's auth_sep (its loading-weighted Δlogit
|
SCALE: auth_sep = base - steered is a log-RATIO of mean blame-mass (Δlog mean p),
|
||||||
of binary p(is-wrong), reference 0.5-2 nats). tinymfv's forced-choice logit lives
|
NOT steering-lite's per-row loading-weighted Δlogit of p(is-wrong). The two are
|
||||||
on a different, much larger scale: base Authority ~-5 on classic n=132, and a
|
different quantities (log-of-mean has a Jensen gap vs mean-of-logit), so treat
|
||||||
real steering shift is several nats. Do NOT compare auth_nats deltas to the
|
steering-lite's 0.5-2 nat figure only as a loose order-of-magnitude analogy, not
|
||||||
steering-lite 0.5-2 reference. Judge the WITHIN-tinymfv delta:
|
a calibrated threshold (the run.py cue thresholds are flagged TODO for this
|
||||||
auth_sep = base_auth_nats - steered_auth_nats (POSITIVE = authority-violations
|
reason). Judge auth_sep within tinymfv: base log(0.099)=-2.3, observed coherent
|
||||||
look less wrong = the trait). Surgicality = |Δauth| relative to |Δcare|; note
|
steering shift ~1 nat (task76 c=0.5). Surgicality = |Δauth| vs |Δcare|; SocialNorms
|
||||||
SocialNorms co-moves with Authority (both binding/conformity foundations).
|
co-moves with Authority (both binding foundations). Coherence stays in prob (a mass).
|
||||||
Coherence stays in prob (it's a mass), not nats.
|
|
||||||
|
CAVEAT: a marginal-over-all-vignettes readout can move for off-target reasons
|
||||||
|
(e.g. the model reblaming a Care vignette onto authority), so a real trait claim
|
||||||
|
needs the surgicality check (Authority moves, Care does not), not auth_nats alone.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import math
|
import math
|
||||||
@@ -32,9 +36,9 @@ def foundation_nats(rep) -> dict:
|
|||||||
= log(mean_vignettes p[F]) = the library's per-foundation readout, just on a log
|
= log(mean_vignettes p[F]) = the library's per-foundation readout, just on a log
|
||||||
scale so a near-ceiling prob move is visible. NOT the diagonal (that is pmass-on-
|
scale so a near-ceiling prob move is visible. NOT the diagonal (that is pmass-on-
|
||||||
correct-label = top1 competence, not the trait) and NOT mean(log p) (outlier-
|
correct-label = top1 competence, not the trait) and NOT mean(log p) (outlier-
|
||||||
dominated). For small p, log p ~= logit, so this lands on steering-lite's
|
dominated). auth_sep = base - steered is a log-RATIO of mean blame-mass, NOT
|
||||||
loading-weighted Δlogit scale: Authority base log(0.099)=-2.3, a real steering
|
steering-lite's per-row loading-weighted Δlogit (Jensen gap), so 0.5-2 nats is a
|
||||||
shift (auth_sep = base - steered) is ~0.5-2 nats. Steering 'do not defer to
|
loose analogy not a threshold. Base log(0.099)=-2.3; steering 'do not defer to
|
||||||
authority' LOWERS auth_nats (the model invokes authority as a wrong-maker less)."""
|
authority' LOWERS auth_nats (the model invokes authority as a wrong-maker less)."""
|
||||||
prof = rep["profile"] # pandas: foundation (coarse), human, model(=mean p), model_T
|
prof = rep["profile"] # pandas: foundation (coarse), human, model(=mean p), model_T
|
||||||
return {f: float(np.log(m)) for f, m in zip(prof["foundation"], prof["model"])}
|
return {f: float(np.log(m)) for f, m in zip(prof["foundation"], prof["model"])}
|
||||||
@@ -53,9 +57,9 @@ def evaluate_model(model, tok, cfg: RunConfig) -> dict:
|
|||||||
prof = rep["profile"] # pandas: foundation (coarse), human, model, model_T
|
prof = rep["profile"] # pandas: foundation (coarse), human, model, model_T
|
||||||
p = dict(zip(prof["foundation"], prof["model"])) # mean prob mass (kept for the map plot)
|
p = dict(zip(prof["foundation"], prof["model"])) # mean prob mass (kept for the map plot)
|
||||||
|
|
||||||
# NAT metric (single source: foundation_nats) = diagonal choice-logprob
|
# NAT metric (single source: foundation_nats) = log(mean profile p[F]) over ALL
|
||||||
# log p[F] on F-violation vignettes. Authority is the target: steering "do not
|
# vignettes (marginal, not diagonal). Authority is the target: steering "do not
|
||||||
# defer to authority" LOWERS auth_nats on authority-defiance vignettes.
|
# defer to authority" LOWERS auth_nats (model blames authority less overall).
|
||||||
nats = foundation_nats(rep)
|
nats = foundation_nats(rep)
|
||||||
|
|
||||||
out = {
|
out = {
|
||||||
@@ -71,12 +75,12 @@ def evaluate_model(model, tok, cfg: RunConfig) -> dict:
|
|||||||
"ppx_json": float(math.exp(rep["mean_nll_json"])),
|
"ppx_json": float(math.exp(rep["mean_nll_json"])),
|
||||||
"top1_acc": float(rep["top1_acc"]),
|
"top1_acc": float(rep["top1_acc"]),
|
||||||
}
|
}
|
||||||
# SHOULD (trait, nats): auth_nats = log(tinymfv profile p[Authority]); steering "do
|
# SHOULD (trait, nats): auth_nats = log(mean profile p[Authority]); steering "do not
|
||||||
# not defer to authority" LOWERS it (model invokes authority as a wrong-maker less).
|
# defer to authority" LOWERS it (model blames authority less overall). Base
|
||||||
# Base ~log(0.099)=-2.3; judge auth_sep = base - steered, a Δlog p ~= Δlogit, so
|
# ~log(0.099)=-2.3; judge auth_sep = base - steered (a log-ratio of blame-mass, NOT
|
||||||
# steering-lite's 0.5-2 nat reference DOES apply here. SocialNorms co-moves with
|
# steering-lite's loading-weighted Δlogit -- 0.5-2 nats is a loose analogy only).
|
||||||
# Authority (both binding foundations) -- expected. Broad permissivizing = Care/
|
# SocialNorms co-moves with Authority (both binding foundations) -- expected. Broad
|
||||||
# Fairness drop AS MUCH as Authority (not surgical).
|
# permissivizing (the off-target failure) = Care/Fairness drop AS MUCH as Authority.
|
||||||
# SHOULD (coherence = p_any_ans = mean_pmass_allowed): base/c=0 MUST be ~1.0. >=0.95 mild,
|
# SHOULD (coherence = p_any_ans = mean_pmass_allowed): base/c=0 MUST be ~1.0. >=0.95 mild,
|
||||||
# 0.85-0.95 degraded, <0.85 broken. We want the auth_nats shift at coherence >=0.95.
|
# 0.85-0.95 degraded, <0.85 broken. We want the auth_nats shift at coherence >=0.95.
|
||||||
coh = out["coherence"]
|
coh = out["coherence"]
|
||||||
|
|||||||
+13
-3
@@ -24,7 +24,16 @@ def _kl_per_pos(logp_a, logp_b): # KL(a || b) summed over vocab, per position
|
|||||||
|
|
||||||
def _encode(tok, prompt: str, completion: str, max_len: int, device):
|
def _encode(tok, prompt: str, completion: str, max_len: int, device):
|
||||||
ids = tok(prompt + completion, return_tensors="pt", truncation=True, max_length=max_len).to(device)
|
ids = tok(prompt + completion, return_tensors="pt", truncation=True, max_length=max_len).to(device)
|
||||||
n_prompt = tok(prompt, return_tensors="pt").input_ids.shape[1]
|
prompt_ids = tok(prompt, return_tensors="pt").input_ids[0].to(device)
|
||||||
|
n_prompt = prompt_ids.shape[0]
|
||||||
|
# Assert the prompt tokenizes as a clean PREFIX of prompt+completion. If a BPE merge
|
||||||
|
# spans the boundary, n_prompt is wrong and the SFT mask silently shifts by a token
|
||||||
|
# (review M6). Truncation can drop the tail, so only check when not truncated.
|
||||||
|
if ids.input_ids.shape[1] >= n_prompt and ids.input_ids.shape[1] < max_len:
|
||||||
|
assert torch.equal(ids.input_ids[0, :n_prompt], prompt_ids), (
|
||||||
|
"prompt is not a token-prefix of prompt+completion (BPE boundary merge); "
|
||||||
|
"the SFT loss mask would be misaligned by a token."
|
||||||
|
)
|
||||||
L = ids.input_ids.shape[1]
|
L = ids.input_ids.shape[1]
|
||||||
tgt_is_completion = torch.arange(1, L, device=device) >= n_prompt # mask over next-token targets
|
tgt_is_completion = torch.arange(1, L, device=device) >= n_prompt # mask over next-token targets
|
||||||
return ids, tgt_is_completion
|
return ids, tgt_is_completion
|
||||||
@@ -44,8 +53,9 @@ def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg:
|
|||||||
# streaming training table (token-efficient-logging): one row, columns self-decode below.
|
# streaming training table (token-efficient-logging): one row, columns self-decode below.
|
||||||
logger.info(f"heal[{cfg.reg}] {len(kept)} completions x {cfg.epochs} ep = {n_steps} steps; "
|
logger.info(f"heal[{cfg.reg}] {len(kept)} completions x {cfg.epochs} ep = {n_steps} steps; "
|
||||||
f"lora r={cfg.lora_r} on layers {cfg.layer_range}")
|
f"lora r={cfg.lora_r} on layers {cfg.layer_range}")
|
||||||
logger.info("SHOULD: nll (SFT) falls as the adapter learns the trait; kl (barrier div) is 0 for "
|
logger.info(f"SHOULD: nll (SFT) falls as the adapter learns the trait; kl (barrier div) is 0 for "
|
||||||
"reg=nll/wd and >0 for kl_rev/kl_fwd; gnorm finite (not exploding). loss = nll + lam*relu(kl-tau).")
|
f"reg=nll/wd and >0 for kl_rev/kl_fwd; gnorm finite (not exploding). loss = nll + lam*relu(kl-tau). "
|
||||||
|
f"If kl stays < tau={cfg.tau} the barrier NEVER fired and {cfg.reg} == nll (no regularisation).")
|
||||||
logger.info(" step nll↓ kl loss↓ gnorm")
|
logger.info(" step nll↓ kl loss↓ gnorm")
|
||||||
pbar = tqdm(total=n_steps, desc=f"heal[{cfg.reg}]", mininterval=120, maxinterval=120)
|
pbar = tqdm(total=n_steps, desc=f"heal[{cfg.reg}]", mininterval=120, maxinterval=120)
|
||||||
step = 0
|
step = 0
|
||||||
|
|||||||
+20
-9
@@ -59,8 +59,14 @@ def _flatten_v(v) -> torch.Tensor:
|
|||||||
return torch.cat([v.state[li]["v"].flatten().float() for li in sorted(v.state)])
|
return torch.cat([v.state[li]["v"].flatten().float() for li in sorted(v.state)])
|
||||||
|
|
||||||
|
|
||||||
def _mean_finite(xs) -> float:
|
def _mean_finite(xs, label: str = "ppl") -> float:
|
||||||
|
"""Mean over finite values, LOUDLY reporting dropped inf/nan -- those are the
|
||||||
|
broken-completion signal (empty/degenerate gens give inf ppl), so silently
|
||||||
|
averaging over survivors would make a broken adapter look healthier (review M3)."""
|
||||||
|
n = len(xs)
|
||||||
xs = [x for x in xs if x == x and x != float("inf")]
|
xs = [x for x in xs if x == x and x != float("inf")]
|
||||||
|
if len(xs) < n:
|
||||||
|
logger.warning(f"_mean_finite[{label}]: dropped {n - len(xs)}/{n} non-finite (broken gens)")
|
||||||
return sum(xs) / len(xs) if xs else float("nan")
|
return sum(xs) / len(xs) if xs else float("nan")
|
||||||
|
|
||||||
|
|
||||||
@@ -95,8 +101,8 @@ def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict:
|
|||||||
with baked(model, hist_specs):
|
with baked(model, hist_specs):
|
||||||
m = evaluate_model(model, tok, cfg)
|
m = evaluate_model(model, tok, cfg)
|
||||||
adapter = generate_plain(model, tok, cfg, n=min(6, cfg.n_prompts))
|
adapter = generate_plain(model, tok, cfg, n=min(6, cfg.n_prompts))
|
||||||
adapter_ppl = _mean_finite([ppl_under_base(model, tok, a["prompt"], a["completion"]) for a in adapter])
|
adapter_ppl = _mean_finite([ppl_under_base(model, tok, a["prompt"], a["completion"]) for a in adapter], "adapter_ppl")
|
||||||
steered_ppl = _mean_finite([s["ppl"] for s in scored])
|
steered_ppl = _mean_finite([s["ppl"] for s in scored], "steered_ppl")
|
||||||
logger.info(
|
logger.info(
|
||||||
"SHOULD (Q1 heal): adapter_ppl < steered_ppl means the trained model expresses the trait "
|
"SHOULD (Q1 heal): adapter_ppl < steered_ppl means the trained model expresses the trait "
|
||||||
"COHERENTLY (healed) where raw steering was incoherent. If adapter_ppl >= steered_ppl, "
|
"COHERENTLY (healed) where raw steering was incoherent. If adapter_ppl >= steered_ppl, "
|
||||||
@@ -158,22 +164,27 @@ def _log_loop_summary(rounds: list[dict], base_m: dict) -> None:
|
|||||||
# meaningful when the trait actually moved, so gate on |dAuth| first.
|
# meaningful when the trait actually moved, so gate on |dAuth| first.
|
||||||
last = rounds[-1]
|
last = rounds[-1]
|
||||||
dAuth = last["auth_nats"] - base_m["auth_nats"]
|
dAuth = last["auth_nats"] - base_m["auth_nats"]
|
||||||
|
dCare = last["care_nats"] - base_m["care_nats"]
|
||||||
dCoh = last["coherence"] - base_m["coherence"]
|
dCoh = last["coherence"] - base_m["coherence"]
|
||||||
coh_cost = abs(dCoh) / abs(dAuth) if abs(dAuth) > 1e-6 else float("nan")
|
coh_cost = abs(dCoh) / abs(dAuth) if abs(dAuth) > 1e-6 else float("nan")
|
||||||
|
surgical = abs(dAuth) > abs(dCare) # Authority must move MORE than the off-target Care
|
||||||
# TODO(threshold): coh_cost cut not yet calibrated. Provisional: a healed adapter
|
# TODO(threshold): coh_cost cut not yet calibrated. Provisional: a healed adapter
|
||||||
# SHOULD land trait (dAuth <= -0.3 nats) at coh_cost <= 0.05 (steered c=0.5 ~0.003).
|
# SHOULD land trait (dAuth <= -0.3 nats), SURGICALLY (|dAuth|>|dCare|, else it is
|
||||||
|
# broad permissivizing not the trait -- review M4), at coh_cost <= 0.05 (steered c=0.5 ~0.003).
|
||||||
if dAuth > -0.3:
|
if dAuth > -0.3:
|
||||||
cue = "🔴" # no trait retained (undo)
|
cue = "🔴" # no trait retained (undo)
|
||||||
|
elif not surgical:
|
||||||
|
cue = "🔴" # moved, but Care moved as much -> broad permissivizing, not the trait
|
||||||
elif coh_cost <= 0.05:
|
elif coh_cost <= 0.05:
|
||||||
cue = "🟢" # trait retained cheaply
|
cue = "🟢" # surgical trait retained cheaply
|
||||||
else:
|
else:
|
||||||
cue = "🟡" # trait retained but coherence-expensive
|
cue = "🟡" # surgical trait but coherence-expensive
|
||||||
logger.info(
|
logger.info(
|
||||||
f"main metric: {cue} coh_cost={coh_cost:.3f} (|dCoh|/|dAuth| vs base, lower=better) | "
|
f"main metric: {cue} coh_cost={coh_cost:.3f} (|dCoh|/|dAuth| vs base, lower=better) | "
|
||||||
f"dAuth={dAuth:+.2f} nats (trait, want <0) coherence={last['coherence']:.2f} "
|
f"dAuth={dAuth:+.2f} dCare={dCare:+.2f} (surgical={surgical}) coherence={last['coherence']:.2f} "
|
||||||
f"(base {base_m['coherence']:.2f})\n"
|
f"(base {base_m['coherence']:.2f})\n"
|
||||||
" cue: 🔴 dAuth>-0.3 (no trait) | 🟢 trait at coh_cost<=0.05 | 🟡 trait but expensive. "
|
" cue: 🔴 dAuth>-0.3 (no trait) OR |dAuth|<=|dCare| (broad, not surgical) | 🟢 surgical trait "
|
||||||
"TODO calibrate coh_cost threshold (steered c=0.5 ref ~0.003)."
|
"at coh_cost<=0.05 | 🟡 surgical but expensive. TODO calibrate coh_cost (steered c=0.5 ref ~0.003)."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user