Files
steer-heal-love/src/steer_heal/heal.py
T
wassname 68dc25c3a1 address external review: docstrings, scale story, surgicality cue, fail-loud
External code review (background subagent) findings, fixed:
- H1: eval.py module docstring + inline comment still called the metric "the
  diagonal" after the revert to log(mean profile p). Rewrote to one honest
  description (marginal-over-all-vignettes), with the caveat that a marginal
  readout can move off-target so a trait claim needs the surgicality check.
- H2: the nats-vs-logit scale story was asserted 3 contradictory ways. Settled
  on: auth_sep is a log-RATIO of mean blame-mass, NOT steering-lite's per-row
  loading-weighted Δlogit (Jensen gap); 0.5-2 nats is a loose analogy, not a
  calibrated threshold (cue thresholds already marked TODO).
- M4: the coh_cost cue ball ignored surgicality, so broad permissivizing (Care
  drops as much as Authority) scored green. Cue now requires |dAuth|>|dCare|.
- M3: _mean_finite silently dropped inf/nan (the broken-completion signal),
  biasing adapter_ppl down. Now logs the dropped count.
- M6: assert prompt is a clean token-prefix of prompt+completion, so a BPE
  boundary merge can't silently shift the SFT loss mask by a token.
- L8: SHOULD line warns if kl stays < tau (barrier never fired -> kl_rev==nll).

Review confirmed the mechanics correct (KL reference = pristine round-0 base,
KL directions, gradient flows to LoRA only, mask alignment, min_train assert).

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
2026-06-04 15:21:13 +08:00

106 lines
5.4 KiB
Python

"""Q1 heal: train one round's LoRA = SFT(kept completions) + divergence-to-original barrier.

The barrier reference is the round-0 ORIGINAL (gates/adapters off), not the
previous student, so it resists cumulative drift. `reg` picks the divergence:

    nll     SFT only (control)
    kl_fwd  KL(orig || theta)  mass-covering (dilutes the trait)
    kl_rev  KL(theta || orig)  mode-seeking (suppresses low-orig-prob = incoherent) [expected best]
    wd      weight decay on the adapter only
"""
import torch
from loguru import logger
from torch.nn import functional as F
from tqdm import tqdm
from steer_heal.config import RunConfig
from steer_heal.ws.adapter import ModulatedLoRA
from steer_heal.ws.bake import AdapterSpec, baked
def _kl_per_pos(logp_a, logp_b):
    """Return KL(a || b) reduced over the last axis, one value per position.

    Both arguments are log-probabilities (the caller passes ``log_softmax``
    outputs). The first argument is the distribution the expectation is taken
    under, so swapping the arguments switches between forward and reverse KL.
    """
    prob_a = torch.exp(logp_a)
    log_ratio = logp_a - logp_b
    return torch.sum(prob_a * log_ratio, dim=-1)
def _encode(tok, prompt: str, completion: str, max_len: int, device):
    """Tokenize ``prompt + completion`` and build the completion-only target mask.

    Returns ``(ids, tgt_is_completion)``:
      * ``ids`` is the tokenizer output for the concatenated text, truncated
        to ``max_len`` and moved to ``device``.
      * ``tgt_is_completion`` is a bool vector of length ``L - 1`` aligned with
        the next-token targets ``input_ids[0, 1:]``; it is True where the
        target token lies at or beyond the prompt length, i.e. belongs to the
        completion. It is all False when truncation removed the completion.

    Raises AssertionError when the un-truncated sequence does not start with
    the prompt's own tokenization.
    """
    full = tok(prompt + completion, return_tensors="pt", truncation=True, max_length=max_len).to(device)
    prefix = tok(prompt, return_tensors="pt").input_ids[0].to(device)
    n_prefix = prefix.shape[0]
    seq_len = full.input_ids.shape[1]

    # The prompt must tokenize as a clean PREFIX of prompt+completion. A BPE merge
    # across the boundary makes n_prefix wrong and shifts the SFT mask by a token
    # (review M6). Truncation can drop the tail, so the check only runs when the
    # sequence is shorter than max_len and long enough to hold the whole prompt.
    fits_untruncated = seq_len < max_len
    if fits_untruncated and n_prefix <= seq_len:
        assert torch.equal(full.input_ids[0, :n_prefix], prefix), (
            "prompt is not a token-prefix of prompt+completion (BPE boundary merge); "
            "the SFT loss mask would be misaligned by a token."
        )

    # Target j (0-based over input_ids[0, 1:]) is the token at absolute position j + 1.
    target_positions = torch.arange(1, seq_len, device=device)
    return full, target_positions >= n_prefix
def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg: RunConfig):
    """Train a fresh round adapter on top of baked history.

    Each kept completion is one optimisation step (batch size 1), repeated for
    ``cfg.epochs`` passes. The per-step objective is
    ``loss = nll + cfg.lam * relu(div - cfg.tau)`` where ``nll`` is the SFT loss
    on completion tokens only and ``div`` depends on ``cfg.reg``:

      * ``"kl_fwd"``: KL(orig || theta), mean over completion positions
      * ``"kl_rev"``: KL(theta || orig), mean over completion positions
      * ``"nll"`` / ``"wd"``: 0 (``"wd"`` instead passes ``cfg.lam`` to AdamW as
        weight decay)

    The reference ("orig") distribution is computed with no history baked in
    and this round's adapter gated to ``c=0.0``; the student distribution has
    ``hist_specs`` baked and the adapter at ``c=1.0``.

    Args:
        model: the language model to adapt; called as ``model(**ids)``.
        tok: tokenizer passed through to ``_encode``.
        kept: dicts with ``"prompt"`` and ``"completion"`` string entries.
        hist_specs: adapter specs of previous rounds, baked in for the student pass.
        cfg: run configuration (reads ``min_train``, ``lora_r``, ``lora_alpha``,
            ``layer_range``, ``lr``, ``lam``, ``tau``, ``reg``, ``epochs``, ``max_len``).

    Returns:
        ``(lora, spec, heal_nll)``: the trained adapter, its ``AdapterSpec`` for
        the next round's history, and the mean SFT loss over the last (up to) 5
        trained steps (``nan`` if no step trained).

    Raises:
        AssertionError: fewer than ``cfg.min_train`` kept completions.
        ValueError: ``cfg.reg`` is not one of nll / kl_fwd / kl_rev / wd.
    """
    assert len(kept) >= cfg.min_train, (
        f"only {len(kept)} kept completions; need >= {cfg.min_train} to train. The steering/filter "
        "starved the data (over-steered -> all garbage, or ppl_tau too strict). Fix upstream, do not train."
    )
    # Fail loud on a mistyped regulariser: it would otherwise fall through to the
    # zero-divergence branch and silently train as plain SFT under the wrong label.
    known_regs = ("nll", "kl_fwd", "kl_rev", "wd")
    if cfg.reg not in known_regs:
        raise ValueError(f"unknown reg {cfg.reg!r}; expected one of {known_regs}")
    uses_barrier = cfg.reg in ("kl_fwd", "kl_rev")

    lora = ModulatedLoRA(model, r=cfg.lora_r, alpha=cfg.lora_alpha, layer_range=cfg.layer_range)
    params = list(lora.parameters())
    opt = torch.optim.AdamW(params, lr=cfg.lr, weight_decay=(cfg.lam if cfg.reg == "wd" else 0.0))
    n_steps = len(kept) * cfg.epochs
    log_every = max(1, n_steps // 20)  # roughly 20 table rows per round
    # streaming training table (token-efficient-logging): one row, columns self-decode below.
    logger.info(f"heal[{cfg.reg}] {len(kept)} completions x {cfg.epochs} ep = {n_steps} steps; "
                f"lora r={cfg.lora_r} on layers {cfg.layer_range}")
    logger.info(f"SHOULD: nll (SFT) falls as the adapter learns the trait; kl (barrier div) is 0 for "
                f"reg=nll/wd and >0 for kl_rev/kl_fwd; gnorm finite (not exploding). loss = nll + lam*relu(kl-tau). "
                f"If kl stays < tau={cfg.tau} the barrier NEVER fired and {cfg.reg} == nll (no regularisation).")
    logger.info(" step nll↓ kl loss↓ gnorm")

    step = 0
    nlls = []  # per-step SFT loss; final = mean of last 5, the heal-stage number for the round table
    # Context manager so the bar is closed even if a step raises.
    with tqdm(total=n_steps, desc=f"heal[{cfg.reg}]", mininterval=120, maxinterval=120) as pbar:
        for _ in range(cfg.epochs):
            for c in kept:
                ids, mask = _encode(tok, c["prompt"], c["completion"], cfg.max_len, model.device)
                if mask.sum() == 0:
                    # completion truncated away; nothing to learn here
                    pbar.update(1)
                    step += 1
                    continue

                # original reference log-probs (no history, adapter off) for the barrier
                if uses_barrier:
                    with torch.no_grad(), lora(model, c=0.0):
                        logp0 = model(**ids).logits[0, :-1].log_softmax(-1)

                # student logits: history baked + this round's adapter live
                with baked(model, hist_specs), lora(model, c=1.0):
                    logits = model(**ids).logits[0, :-1]
                logp = logits.log_softmax(-1)
                tgt = ids.input_ids[0, 1:]
                sft = F.nll_loss(logp[mask], tgt[mask])

                if cfg.reg == "kl_fwd":
                    div = _kl_per_pos(logp0[mask], logp[mask]).mean()
                elif cfg.reg == "kl_rev":
                    div = _kl_per_pos(logp[mask], logp0[mask]).mean()
                else:
                    div = torch.zeros((), device=model.device)  # nll, wd

                # Hinge barrier: only divergence in excess of tau is penalised.
                loss = sft + cfg.lam * torch.relu(div - cfg.tau)
                loss.backward()
                gnorm = torch.nn.utils.clip_grad_norm_(params, 1.0)
                opt.step()
                opt.zero_grad()

                # Pull scalars off the device once per step instead of once per use.
                sft_val = sft.item()
                div_val = div.item()
                gnorm_val = float(gnorm)
                nlls.append(sft_val)

                if step % log_every == 0 or step == n_steps - 1:
                    logger.info(f" {step:4d} {sft_val:5.2f} {div_val:4.2f} "
                                f"{loss.item():5.2f} {gnorm_val:5.1f}")
                pbar.set_postfix(nll=f"{sft_val:.2f}", kl=f"{div_val:.2f}", gn=f"{gnorm_val:.1f}")
                pbar.update(1)
                step += 1

    spec = AdapterSpec.from_lora(lora, default_c=1.0)  # CPU-resident, for the next round's history
    last = nlls[-5:]
    heal_nll = sum(last) / len(last) if last else float("nan")  # converged SFT loss (last-5 mean)
    return lora, spec, heal_nll