heal: fix phantom-KL LoRA init (B=0), add cosine+warmup schedule, val nll, short-run betas

Root cause of KL starting at ~0.6 before any training: ModulatedLoRA initialised B
with normal_(mean=1e-4, std=1e-4), so a fresh adapter was NOT a no-op -- it perturbed
every all-linear layer in a systematic (nonzero-mean) direction, compounding across
~200 adapters into a phantom KL that already sat above tau (0.5) and fired the barrier
against nothing real. B=0 makes delta=B@A=0 at init, so round-0 step-0 KL=0 (verified
in fast-dev: kl=0.00 at step 0); A still receives gradient through B once B leaves
zero (standard LoRA).

Why heal loss wasn't descending: beta2=0.999 has a ~1000-step EMA, longer than a
whole heal round (~300 steps), so Adam's second moment never warmed up. Fix:
betas=(0.9, 0.95) (~20-step EMA) plus a cosine-with-warmup schedule (w2s recipe,
warmup_ratio=0.1). Also r 8->32 (alpha 16->64, keeping scale=alpha/r=2),
layer_range (0.0,1.0)->(0.2,0.8), epochs 2->6.

Added a held-out val nll (1/8 of the kept data, shuffled) logged per epoch alongside
train nll, so overfit (train down / val up) and data-near-base (neither moves) can be
told apart -- optimisation failure modes the trait eval alone cannot separate.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
wassname
2026-06-04 19:40:14 +08:00
parent b25f4f04a8
commit f280a67521
3 changed files with 59 additions and 11 deletions
+9 -4
View File
@@ -29,7 +29,7 @@ class RunConfig:
# carry the trait, so completions are generated with no persona.
gen_system: str = "You are a helpful assistant."
steer_layers: tuple[float, float] = (0.45, 0.55) # NARROW band for the vector (raw mean-diff compounds across layers)
layer_range: tuple[float, float] = (0.0, 1.0) # BROAD band for the LoRA (train trait into many layers)
layer_range: tuple[float, float] = (0.2, 0.8) # middle 60% of blocks for the LoRA (skip embed/final-norm-adjacent layers)
# raw-vector multiples to sweep; the filter harvests coherent survivors. Dropped 0.25
# (filter audit: base-like, no distinct trait); kept 0.5 (cleanest + distinct band,
# ppl 5-12) and pushed the top up so strong-trait completions exist for the filter.
@@ -50,10 +50,15 @@ class RunConfig:
reg: Literal["nll", "kl_fwd", "kl_rev", "wd"] = "kl_rev"
lam: float = 1.0 # barrier weight (also weight_decay when reg == "wd")
tau: float = 0.5 # barrier engages only when divergence > tau (nats)
lora_r: int = 8
lora_alpha: float = 16.0
epochs: int = 2
lora_r: int = 32
lora_alpha: float = 64.0 # keep scale = alpha/r = 2 (w2s convention alpha = 2r)
epochs: int = 6 # was 2: too few steps to see loss descend; val nll guards overfit
lr: float = 1e-4
warmup_ratio: float = 0.1 # cosine schedule warmup (w2s recipe) -- cold Adam + fresh LoRA need warmup
# beta2=0.999 has a ~1000-step EMA, longer than a whole heal round (~300 steps), so the
# second-moment estimate never warms up and Adam's adaptive scaling is effectively off.
# 0.95 -> ~20-step EMA, warms in ~40 steps. beta1 standard.
adam_betas: tuple[float, float] = (0.9, 0.95)
# ── eval (tinymfv) ──
eval_vignettes: int | None = None # None = all Clifford-2015 vignettes
+43 -5
View File
@@ -8,10 +8,13 @@ previous student, so it resists cumulative drift. reg picks the divergence:
wd weight decay on the adapter only
"""
import random
import torch
from loguru import logger
from torch.nn import functional as F
from tqdm import tqdm
from transformers import get_cosine_schedule_with_warmup
from steer_heal.config import RunConfig
from steer_heal.ws.adapter import ModulatedLoRA
@@ -46,20 +49,50 @@ def _encode(tok, prompt: str, completion: str, max_len: int, device):
return ids, tgt_is_completion
def _val_nll(model, tok, val_kept, hist_specs, lora, cfg) -> float:
    """Mean per-example SFT nll over the held-out completions, computed without gradients.

    The model is evaluated in the same student state as training: ``hist_specs`` baked in
    and the round adapter applied at c=1.0. The trait eval remains the real metric; this
    number exists to expose optimisation failures it cannot: train nll falling while val
    nll rises means overfit, and neither falling means the data is near-base or the
    optimiser is broken.

    Returns NaN when ``val_kept`` is empty or when every example had an empty target mask.
    """
    nan = float("nan")
    if not val_kept:
        return nan

    per_example = []
    with torch.no_grad(), baked(model, hist_specs), lora(model, c=1.0):
        for example in val_kept:
            enc, tgt_mask = _encode(
                tok, example["prompt"], example["completion"], cfg.max_len, model.device
            )
            if tgt_mask.sum() == 0:
                # nothing to score for this example (no target positions selected)
                continue
            # logits at position t are scored against token t+1, hence the one-token shift
            log_probs = model(**enc).logits[0, :-1].log_softmax(-1)
            targets = enc.input_ids[0, 1:]
            per_example.append(F.nll_loss(log_probs[tgt_mask], targets[tgt_mask]).item())

    if not per_example:
        return nan
    return sum(per_example) / len(per_example)
def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg: RunConfig):
"""Train a fresh round adapter on top of baked history. Returns (lora, spec)."""
assert len(kept) >= cfg.min_train, (
f"only {len(kept)} kept completions; need >= {cfg.min_train} to train. The steering/filter "
"starved the data (over-steered -> all garbage, or ppl_tau too strict). Fix upstream, do not train."
)
# hold out ~1/8 for a val nll curve (shuffled so val isn't all one alpha). Tiny-dev keeps all
# for train (len//8 == 0) so the path still runs.
shuf = kept[:]
random.Random(cfg.seed).shuffle(shuf)
n_val = len(shuf) // 8
val_kept, train_kept = shuf[:n_val], shuf[n_val:]
lora = ModulatedLoRA(model, r=cfg.lora_r, alpha=cfg.lora_alpha, layer_range=cfg.layer_range)
params = list(lora.parameters())
opt = torch.optim.AdamW(params, lr=cfg.lr, weight_decay=(cfg.lam if cfg.reg == "wd" else 0.0))
n_steps = len(kept) * cfg.epochs
opt = torch.optim.AdamW(params, lr=cfg.lr, betas=cfg.adam_betas,
weight_decay=(cfg.lam if cfg.reg == "wd" else 0.0))
n_steps = len(train_kept) * cfg.epochs
sched = get_cosine_schedule_with_warmup(
opt, num_warmup_steps=int(cfg.warmup_ratio * n_steps), num_training_steps=n_steps)
# streaming training table (token-efficient-logging): one row, columns self-decode below.
logger.info(f"heal[{cfg.reg}] {len(kept)} completions x {cfg.epochs} ep = {n_steps} steps; "
f"lora r={cfg.lora_r} on layers {cfg.layer_range}")
logger.info(f"heal[{cfg.reg}] {len(train_kept)} train (+{len(val_kept)} val) x {cfg.epochs} ep = {n_steps} steps; "
f"lora r={cfg.lora_r} a={cfg.lora_alpha} on layers {cfg.layer_range}; "
f"lr={cfg.lr} cosine warmup={cfg.warmup_ratio} betas={cfg.adam_betas}")
logger.info("SHOULD (val): train_nll falls each epoch (SFT fits the kept data); val_nll falls then "
"flattens. If val_nll RISES while train falls -> overfit (fewer epochs / lower r). If "
"NEITHER falls -> data is near-base (nothing to distil) or the optimiser is broken.")
logger.info(f"SHOULD: nll (SFT) falls as the adapter learns the trait; kl (barrier div) is 0 for "
f"reg=nll/wd and >0 for kl_rev/kl_fwd; gnorm finite (not exploding). loss = nll + lam*relu(kl-tau). "
f"If kl stays < tau={cfg.tau} the barrier NEVER fired and {cfg.reg} == nll (no regularisation).")
@@ -73,7 +106,8 @@ def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg:
step = 0
nlls = [] # per-step SFT loss; final = mean of last 5, the heal-stage number for the round table
for ep in range(cfg.epochs):
for c in kept:
ep_nlls = []
for c in train_kept:
ids, mask = _encode(tok, c["prompt"], c["completion"], cfg.max_len, model.device)
if mask.sum() == 0:
# prompt filled max_len so the completion was truncated to zero target tokens.
@@ -103,6 +137,7 @@ def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg:
barrier = cfg.lam * torch.relu(div - cfg.tau)
loss = sft + barrier
nlls.append(sft.item())
ep_nlls.append(sft.item())
log_now = step % max(1, n_steps // 20) == 0 or step == n_steps - 1
if log_now:
# split the gradient pressure: ||∇sft|| vs ||∇barrier|| (retain_graph -> still .backward below).
@@ -114,6 +149,7 @@ def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg:
loss.backward()
gnorm = torch.nn.utils.clip_grad_norm_(params, 1.0)
opt.step()
sched.step()
opt.zero_grad()
if log_now:
logger.info(f" {step:4d} {sft.item():5.2f} {div.detach().item():4.2f} "
@@ -121,6 +157,8 @@ def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg:
pbar.set_postfix(nll=f"{sft.item():.2f}", kl=f"{div.detach().item():.2f}", gn=f"{float(gnorm):.1f}")
pbar.update(1)
step += 1
val = _val_nll(model, tok, val_kept, hist_specs, lora, cfg)
logger.info(f" epoch {ep}: train_nll={sum(ep_nlls)/len(ep_nlls):.3f} val_nll={val:.3f} lr={sched.get_last_lr()[0]:.1e}")
pbar.close()
spec = AdapterSpec.from_lora(lora, default_c=1.0) # CPU-resident, for the next round's history
+7 -2
View File
@@ -116,8 +116,13 @@ class ModulatedLoRA:
d_in, d_out = layer.in_features, layer.out_features
A = torch.empty(self.cfg.r, d_in, dtype=self.cfg.dtype, device=device)
nn.init.kaiming_uniform_(A, a=5 ** 0.5)
B = torch.empty(d_out, self.cfg.r, dtype=self.cfg.dtype, device=device)
nn.init.normal_(B, mean=1e-4, std=1e-4)
B = torch.zeros(d_out, self.cfg.r, dtype=self.cfg.dtype, device=device)
# B=0 => delta = B@A = 0, so a fresh adapter is a true no-op at init: round-0 step-0
# student logits == base, hence barrier KL(student||orig) starts at 0 (only baked
# history can diverge). The old normal_(mean=1e-4) perturbed EVERY all-linear layer
# in a systematic (nonzero-mean) direction, compounding across ~200 adapters into a
# phantom ~0.6-nat KL before any training -- it sat above tau and fired the barrier
# against nothing real. A still gets gradient via B once B leaves zero (standard LoRA).
self.A[name] = nn.Parameter(A)
self.B[name] = nn.Parameter(B)
self._target_layers[name] = layer