heal: fix phantom-KL LoRA init (B=0), add cosine+warmup schedule, val nll, short-run betas

Root cause of KL starting at ~0.6 before any training: ModulatedLoRA initialised B
with normal_(mean=1e-4, std=1e-4), so a fresh adapter was NOT a no-op -- it perturbed
every all-linear layer in a systematic (nonzero-mean) direction, compounding across ~200
adapters into a phantom KL that already sat above tau and fired the barrier against
nothing real. B=0 makes delta=B@A=0 at init, so round-0 step-0 KL=0 (verified in
fast-dev: kl=0.00 at step 0); A gets zero gradient while B=0 but starts training as
soon as B leaves zero (standard LoRA).

Why heal loss wasn't descending: beta2=0.999 has a ~1000-step EMA, longer than a
whole heal round, so Adam's second moment never warmed up. Fix: betas=(0.9, 0.95)
(beta2=0.95 is a ~20-step EMA) plus a cosine-with-warmup schedule (w2s recipe). Also
changed: r 8->32 (alpha 16->64, keeping scale=alpha/r=2), layer_range
(0.0,1.0)->(0.2,0.8), epochs 2->6.

Added a held-out val nll (1/8 of the kept data, shuffled) logged per epoch alongside
train nll, so that overfit (train down/val up) and data-near-base (neither moves) can
be told apart -- two failure modes the trait eval alone cannot distinguish.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
wassname
2026-06-04 19:40:14 +08:00
parent b25f4f04a8
commit f280a67521
3 changed files with 59 additions and 11 deletions
+9 -4
View File
@@ -29,7 +29,7 @@ class RunConfig:
# carry the trait, so completions are generated with no persona. # carry the trait, so completions are generated with no persona.
gen_system: str = "You are a helpful assistant." gen_system: str = "You are a helpful assistant."
steer_layers: tuple[float, float] = (0.45, 0.55) # NARROW band for the vector (raw mean-diff compounds across layers) steer_layers: tuple[float, float] = (0.45, 0.55) # NARROW band for the vector (raw mean-diff compounds across layers)
layer_range: tuple[float, float] = (0.0, 1.0) # BROAD band for the LoRA (train trait into many layers) layer_range: tuple[float, float] = (0.2, 0.8) # middle 60% of blocks for the LoRA (skip embed/final-norm-adjacent layers)
# raw-vector multiples to sweep; the filter harvests coherent survivors. Dropped 0.25 # raw-vector multiples to sweep; the filter harvests coherent survivors. Dropped 0.25
# (filter audit: base-like, no distinct trait); kept 0.5 (cleanest + distinct band, # (filter audit: base-like, no distinct trait); kept 0.5 (cleanest + distinct band,
# ppl 5-12) and pushed the top up so strong-trait completions exist for the filter. # ppl 5-12) and pushed the top up so strong-trait completions exist for the filter.
@@ -50,10 +50,15 @@ class RunConfig:
reg: Literal["nll", "kl_fwd", "kl_rev", "wd"] = "kl_rev" reg: Literal["nll", "kl_fwd", "kl_rev", "wd"] = "kl_rev"
lam: float = 1.0 # barrier weight (also weight_decay when reg == "wd") lam: float = 1.0 # barrier weight (also weight_decay when reg == "wd")
tau: float = 0.5 # barrier engages only when divergence > tau (nats) tau: float = 0.5 # barrier engages only when divergence > tau (nats)
lora_r: int = 8 lora_r: int = 32
lora_alpha: float = 16.0 lora_alpha: float = 64.0 # keep scale = alpha/r = 2 (w2s convention alpha = 2r)
epochs: int = 2 epochs: int = 6 # was 2: too few steps to see loss descend; val nll guards overfit
lr: float = 1e-4 lr: float = 1e-4
warmup_ratio: float = 0.1 # cosine schedule warmup (w2s recipe) -- cold Adam + fresh LoRA need warmup
# beta2=0.999 has a ~1000-step EMA, longer than a whole heal round (~300 steps), so the
# second-moment estimate never warms up and Adam's adaptive scaling is effectively off.
# 0.95 -> ~20-step EMA, warms in ~40 steps. beta1 standard.
adam_betas: tuple[float, float] = (0.9, 0.95)
# ── eval (tinymfv) ── # ── eval (tinymfv) ──
eval_vignettes: int | None = None # None = all Clifford-2015 vignettes eval_vignettes: int | None = None # None = all Clifford-2015 vignettes
+43 -5
View File
@@ -8,10 +8,13 @@ previous student, so it resists cumulative drift. reg picks the divergence:
wd weight decay on the adapter only wd weight decay on the adapter only
""" """
import random
import torch import torch
from loguru import logger from loguru import logger
from torch.nn import functional as F from torch.nn import functional as F
from tqdm import tqdm from tqdm import tqdm
from transformers import get_cosine_schedule_with_warmup
from steer_heal.config import RunConfig from steer_heal.config import RunConfig
from steer_heal.ws.adapter import ModulatedLoRA from steer_heal.ws.adapter import ModulatedLoRA
@@ -46,20 +49,50 @@ def _encode(tok, prompt: str, completion: str, max_len: int, device):
return ids, tgt_is_completion return ids, tgt_is_completion
def _val_nll(model, tok, val_kept, hist_specs, lora, cfg) -> float:
    """Return the mean per-example SFT nll over the held-out completions.

    Evaluation runs without gradients, inside the ``baked(model, hist_specs)`` and
    ``lora(model, c=1.0)`` contexts (intended to mirror the student state used for
    training: history baked in, round adapter live). The trait eval stays the real
    metric; this number exists to separate optimisation failure modes it cannot:
    train nll falling while val nll rises means overfit, and neither falling means
    the data is near-base or the optimiser is broken.

    Each example contributes the mean nll over its completion tokens; examples whose
    target mask is empty are skipped. Returns NaN when ``val_kept`` is empty or every
    example was skipped.
    """
    if not val_kept:
        return float("nan")

    per_example = []
    with torch.no_grad(), baked(model, hist_specs), lora(model, c=1.0):
        for item in val_kept:
            ids, mask = _encode(tok, item["prompt"], item["completion"], cfg.max_len, model.device)
            if mask.sum() == 0:
                # no completion tokens survived encoding -> nothing to score
                continue
            # position t predicts token t+1: drop the last logit row, shift targets by one
            logp = model(**ids).logits[0, :-1].log_softmax(-1)
            targets = ids.input_ids[0, 1:][mask]
            per_example.append(F.nll_loss(logp[mask], targets).item())

    if not per_example:
        return float("nan")
    return sum(per_example) / len(per_example)
def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg: RunConfig): def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg: RunConfig):
"""Train a fresh round adapter on top of baked history. Returns (lora, spec).""" """Train a fresh round adapter on top of baked history. Returns (lora, spec)."""
assert len(kept) >= cfg.min_train, ( assert len(kept) >= cfg.min_train, (
f"only {len(kept)} kept completions; need >= {cfg.min_train} to train. The steering/filter " f"only {len(kept)} kept completions; need >= {cfg.min_train} to train. The steering/filter "
"starved the data (over-steered -> all garbage, or ppl_tau too strict). Fix upstream, do not train." "starved the data (over-steered -> all garbage, or ppl_tau too strict). Fix upstream, do not train."
) )
# hold out ~1/8 for a val nll curve (shuffled so val isn't all one alpha). Tiny-dev keeps all
# for train (len//8 == 0) so the path still runs.
shuf = kept[:]
random.Random(cfg.seed).shuffle(shuf)
n_val = len(shuf) // 8
val_kept, train_kept = shuf[:n_val], shuf[n_val:]
lora = ModulatedLoRA(model, r=cfg.lora_r, alpha=cfg.lora_alpha, layer_range=cfg.layer_range) lora = ModulatedLoRA(model, r=cfg.lora_r, alpha=cfg.lora_alpha, layer_range=cfg.layer_range)
params = list(lora.parameters()) params = list(lora.parameters())
opt = torch.optim.AdamW(params, lr=cfg.lr, weight_decay=(cfg.lam if cfg.reg == "wd" else 0.0)) opt = torch.optim.AdamW(params, lr=cfg.lr, betas=cfg.adam_betas,
n_steps = len(kept) * cfg.epochs weight_decay=(cfg.lam if cfg.reg == "wd" else 0.0))
n_steps = len(train_kept) * cfg.epochs
sched = get_cosine_schedule_with_warmup(
opt, num_warmup_steps=int(cfg.warmup_ratio * n_steps), num_training_steps=n_steps)
# streaming training table (token-efficient-logging): one row, columns self-decode below. # streaming training table (token-efficient-logging): one row, columns self-decode below.
logger.info(f"heal[{cfg.reg}] {len(kept)} completions x {cfg.epochs} ep = {n_steps} steps; " logger.info(f"heal[{cfg.reg}] {len(train_kept)} train (+{len(val_kept)} val) x {cfg.epochs} ep = {n_steps} steps; "
f"lora r={cfg.lora_r} on layers {cfg.layer_range}") f"lora r={cfg.lora_r} a={cfg.lora_alpha} on layers {cfg.layer_range}; "
f"lr={cfg.lr} cosine warmup={cfg.warmup_ratio} betas={cfg.adam_betas}")
logger.info("SHOULD (val): train_nll falls each epoch (SFT fits the kept data); val_nll falls then "
"flattens. If val_nll RISES while train falls -> overfit (fewer epochs / lower r). If "
"NEITHER falls -> data is near-base (nothing to distil) or the optimiser is broken.")
logger.info(f"SHOULD: nll (SFT) falls as the adapter learns the trait; kl (barrier div) is 0 for " logger.info(f"SHOULD: nll (SFT) falls as the adapter learns the trait; kl (barrier div) is 0 for "
f"reg=nll/wd and >0 for kl_rev/kl_fwd; gnorm finite (not exploding). loss = nll + lam*relu(kl-tau). " f"reg=nll/wd and >0 for kl_rev/kl_fwd; gnorm finite (not exploding). loss = nll + lam*relu(kl-tau). "
f"If kl stays < tau={cfg.tau} the barrier NEVER fired and {cfg.reg} == nll (no regularisation).") f"If kl stays < tau={cfg.tau} the barrier NEVER fired and {cfg.reg} == nll (no regularisation).")
@@ -73,7 +106,8 @@ def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg:
step = 0 step = 0
nlls = [] # per-step SFT loss; final = mean of last 5, the heal-stage number for the round table nlls = [] # per-step SFT loss; final = mean of last 5, the heal-stage number for the round table
for ep in range(cfg.epochs): for ep in range(cfg.epochs):
for c in kept: ep_nlls = []
for c in train_kept:
ids, mask = _encode(tok, c["prompt"], c["completion"], cfg.max_len, model.device) ids, mask = _encode(tok, c["prompt"], c["completion"], cfg.max_len, model.device)
if mask.sum() == 0: if mask.sum() == 0:
# prompt filled max_len so the completion was truncated to zero target tokens. # prompt filled max_len so the completion was truncated to zero target tokens.
@@ -103,6 +137,7 @@ def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg:
barrier = cfg.lam * torch.relu(div - cfg.tau) barrier = cfg.lam * torch.relu(div - cfg.tau)
loss = sft + barrier loss = sft + barrier
nlls.append(sft.item()) nlls.append(sft.item())
ep_nlls.append(sft.item())
log_now = step % max(1, n_steps // 20) == 0 or step == n_steps - 1 log_now = step % max(1, n_steps // 20) == 0 or step == n_steps - 1
if log_now: if log_now:
# split the gradient pressure: ||∇sft|| vs ||∇barrier|| (retain_graph -> still .backward below). # split the gradient pressure: ||∇sft|| vs ||∇barrier|| (retain_graph -> still .backward below).
@@ -114,6 +149,7 @@ def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg:
loss.backward() loss.backward()
gnorm = torch.nn.utils.clip_grad_norm_(params, 1.0) gnorm = torch.nn.utils.clip_grad_norm_(params, 1.0)
opt.step() opt.step()
sched.step()
opt.zero_grad() opt.zero_grad()
if log_now: if log_now:
logger.info(f" {step:4d} {sft.item():5.2f} {div.detach().item():4.2f} " logger.info(f" {step:4d} {sft.item():5.2f} {div.detach().item():4.2f} "
@@ -121,6 +157,8 @@ def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg:
pbar.set_postfix(nll=f"{sft.item():.2f}", kl=f"{div.detach().item():.2f}", gn=f"{float(gnorm):.1f}") pbar.set_postfix(nll=f"{sft.item():.2f}", kl=f"{div.detach().item():.2f}", gn=f"{float(gnorm):.1f}")
pbar.update(1) pbar.update(1)
step += 1 step += 1
val = _val_nll(model, tok, val_kept, hist_specs, lora, cfg)
logger.info(f" epoch {ep}: train_nll={sum(ep_nlls)/len(ep_nlls):.3f} val_nll={val:.3f} lr={sched.get_last_lr()[0]:.1e}")
pbar.close() pbar.close()
spec = AdapterSpec.from_lora(lora, default_c=1.0) # CPU-resident, for the next round's history spec = AdapterSpec.from_lora(lora, default_c=1.0) # CPU-resident, for the next round's history
+7 -2
View File
@@ -116,8 +116,13 @@ class ModulatedLoRA:
d_in, d_out = layer.in_features, layer.out_features d_in, d_out = layer.in_features, layer.out_features
A = torch.empty(self.cfg.r, d_in, dtype=self.cfg.dtype, device=device) A = torch.empty(self.cfg.r, d_in, dtype=self.cfg.dtype, device=device)
nn.init.kaiming_uniform_(A, a=5 ** 0.5) nn.init.kaiming_uniform_(A, a=5 ** 0.5)
B = torch.empty(d_out, self.cfg.r, dtype=self.cfg.dtype, device=device) B = torch.zeros(d_out, self.cfg.r, dtype=self.cfg.dtype, device=device)
nn.init.normal_(B, mean=1e-4, std=1e-4) # B=0 => delta = B@A = 0, so a fresh adapter is a true no-op at init: round-0 step-0
# student logits == base, hence barrier KL(student||orig) starts at 0 (only baked
# history can diverge). The old normal_(mean=1e-4) perturbed EVERY all-linear layer
# in a systematic (nonzero-mean) direction, compounding across ~200 adapters into a
# phantom ~0.6-nat KL before any training -- it sat above tau and fired the barrier
# against nothing real. A still gets gradient via B once B leaves zero (standard LoRA).
self.A[name] = nn.Parameter(A) self.A[name] = nn.Parameter(A)
self.B[name] = nn.Parameter(B) self.B[name] = nn.Parameter(B)
self._target_layers[name] = layer self._target_layers[name] = layer