mirror of
https://github.com/wassname/steer-heal-love.git
synced 2026-06-27 17:02:34 +08:00
heal: fix phantom-KL LoRA init (B=0), add cosine+warmup schedule, val nll, short-run betas
Root cause of KL starting at ~0.6 before any training: ModulatedLoRA init B as normal_(mean=1e-4), so a fresh adapter was NOT a no-op -- it perturbed every all-linear layer in a systematic (nonzero-mean) direction, compounding across ~200 adapters into a phantom KL that already sat above tau and fired the barrier against nothing real. B=0 makes delta=B@A=0 at init, so round-0 step-0 KL=0 (verified in fast-dev: kl=0.00 at step 0); A still trains via B (standard LoRA). Why heal loss wasn't descending: beta2=0.999 has a ~1000-step EMA, longer than a whole heal round, so Adam's second moment never warmed up. betas=(0.9, 0.95) + cosine-with-warmup schedule (w2s recipe). Also r 8->32 (alpha 64, keep scale=2), layer_range (0.0,1.0)->(0.2,0.8), epochs 2->6. Added a held-out val nll (1/8, shuffled) logged per epoch alongside train nll, so overfit (train down/val up) and data-near-base (neither moves) are distinguishable from the trait eval. Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
@@ -29,7 +29,7 @@ class RunConfig:
|
||||
# carry the trait, so completions are generated with no persona.
|
||||
gen_system: str = "You are a helpful assistant."
|
||||
steer_layers: tuple[float, float] = (0.45, 0.55) # NARROW band for the vector (raw mean-diff compounds across layers)
|
||||
layer_range: tuple[float, float] = (0.0, 1.0) # BROAD band for the LoRA (train trait into many layers)
|
||||
layer_range: tuple[float, float] = (0.2, 0.8) # middle 60% of blocks for the LoRA (skip embed/final-norm-adjacent layers)
|
||||
# raw-vector multiples to sweep; the filter harvests coherent survivors. Dropped 0.25
|
||||
# (filter audit: base-like, no distinct trait); kept 0.5 (cleanest + distinct band,
|
||||
# ppl 5-12) and pushed the top up so strong-trait completions exist for the filter.
|
||||
@@ -50,10 +50,15 @@ class RunConfig:
|
||||
reg: Literal["nll", "kl_fwd", "kl_rev", "wd"] = "kl_rev"
|
||||
lam: float = 1.0 # barrier weight (also weight_decay when reg == "wd")
|
||||
tau: float = 0.5 # barrier engages only when divergence > tau (nats)
|
||||
lora_r: int = 8
|
||||
lora_alpha: float = 16.0
|
||||
epochs: int = 2
|
||||
lora_r: int = 32
|
||||
lora_alpha: float = 64.0 # keep scale = alpha/r = 2 (w2s convention alpha = 2r)
|
||||
epochs: int = 6 # was 2: too few steps to see loss descend; val nll guards overfit
|
||||
lr: float = 1e-4
|
||||
warmup_ratio: float = 0.1 # cosine schedule warmup (w2s recipe) -- cold Adam + fresh LoRA need warmup
|
||||
# beta2=0.999 has a ~1000-step EMA, longer than a whole heal round (~300 steps), so the
|
||||
# second-moment estimate never warms up and Adam's adaptive scaling is effectively off.
|
||||
# 0.95 -> ~20-step EMA, warms in ~40 steps. beta1 standard.
|
||||
adam_betas: tuple[float, float] = (0.9, 0.95)
|
||||
|
||||
# ── eval (tinymfv) ──
|
||||
eval_vignettes: int | None = None # None = all Clifford-2015 vignettes
|
||||
|
||||
+43
-5
@@ -8,10 +8,13 @@ previous student, so it resists cumulative drift. reg picks the divergence:
|
||||
wd weight decay on the adapter only
|
||||
"""
|
||||
|
||||
import random
|
||||
|
||||
import torch
|
||||
from loguru import logger
|
||||
from torch.nn import functional as F
|
||||
from tqdm import tqdm
|
||||
from transformers import get_cosine_schedule_with_warmup
|
||||
|
||||
from steer_heal.config import RunConfig
|
||||
from steer_heal.ws.adapter import ModulatedLoRA
|
||||
@@ -46,20 +49,50 @@ def _encode(tok, prompt: str, completion: str, max_len: int, device):
|
||||
return ids, tgt_is_completion
|
||||
|
||||
|
||||
def _val_nll(model, tok, val_kept, hist_specs, lora, cfg) -> float:
|
||||
"""Held-out SFT nll (same student state as train: history baked, adapter live). The trait
|
||||
eval is the real metric, but val_nll catches the optimisation failure modes the eval can't:
|
||||
train falls + val rises = overfit; NEITHER falls = data near-base / opt broken."""
|
||||
if not val_kept:
|
||||
return float("nan")
|
||||
losses = []
|
||||
with torch.no_grad(), baked(model, hist_specs), lora(model, c=1.0):
|
||||
for c in val_kept:
|
||||
ids, mask = _encode(tok, c["prompt"], c["completion"], cfg.max_len, model.device)
|
||||
if mask.sum() == 0:
|
||||
continue
|
||||
logp = model(**ids).logits[0, :-1].log_softmax(-1)
|
||||
losses.append(F.nll_loss(logp[mask], ids.input_ids[0, 1:][mask]).item())
|
||||
return sum(losses) / len(losses) if losses else float("nan")
|
||||
|
||||
|
||||
def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg: RunConfig):
|
||||
"""Train a fresh round adapter on top of baked history. Returns (lora, spec)."""
|
||||
assert len(kept) >= cfg.min_train, (
|
||||
f"only {len(kept)} kept completions; need >= {cfg.min_train} to train. The steering/filter "
|
||||
"starved the data (over-steered -> all garbage, or ppl_tau too strict). Fix upstream, do not train."
|
||||
)
|
||||
# hold out ~1/8 for a val nll curve (shuffled so val isn't all one alpha). Tiny-dev keeps all
|
||||
# for train (len//8 == 0) so the path still runs.
|
||||
shuf = kept[:]
|
||||
random.Random(cfg.seed).shuffle(shuf)
|
||||
n_val = len(shuf) // 8
|
||||
val_kept, train_kept = shuf[:n_val], shuf[n_val:]
|
||||
lora = ModulatedLoRA(model, r=cfg.lora_r, alpha=cfg.lora_alpha, layer_range=cfg.layer_range)
|
||||
params = list(lora.parameters())
|
||||
opt = torch.optim.AdamW(params, lr=cfg.lr, weight_decay=(cfg.lam if cfg.reg == "wd" else 0.0))
|
||||
n_steps = len(kept) * cfg.epochs
|
||||
opt = torch.optim.AdamW(params, lr=cfg.lr, betas=cfg.adam_betas,
|
||||
weight_decay=(cfg.lam if cfg.reg == "wd" else 0.0))
|
||||
n_steps = len(train_kept) * cfg.epochs
|
||||
sched = get_cosine_schedule_with_warmup(
|
||||
opt, num_warmup_steps=int(cfg.warmup_ratio * n_steps), num_training_steps=n_steps)
|
||||
|
||||
# streaming training table (token-efficient-logging): one row, columns self-decode below.
|
||||
logger.info(f"heal[{cfg.reg}] {len(kept)} completions x {cfg.epochs} ep = {n_steps} steps; "
|
||||
f"lora r={cfg.lora_r} on layers {cfg.layer_range}")
|
||||
logger.info(f"heal[{cfg.reg}] {len(train_kept)} train (+{len(val_kept)} val) x {cfg.epochs} ep = {n_steps} steps; "
|
||||
f"lora r={cfg.lora_r} a={cfg.lora_alpha} on layers {cfg.layer_range}; "
|
||||
f"lr={cfg.lr} cosine warmup={cfg.warmup_ratio} betas={cfg.adam_betas}")
|
||||
logger.info("SHOULD (val): train_nll falls each epoch (SFT fits the kept data); val_nll falls then "
|
||||
"flattens. If val_nll RISES while train falls -> overfit (fewer epochs / lower r). If "
|
||||
"NEITHER falls -> data is near-base (nothing to distil) or the optimiser is broken.")
|
||||
logger.info(f"SHOULD: nll (SFT) falls as the adapter learns the trait; kl (barrier div) is 0 for "
|
||||
f"reg=nll/wd and >0 for kl_rev/kl_fwd; gnorm finite (not exploding). loss = nll + lam*relu(kl-tau). "
|
||||
f"If kl stays < tau={cfg.tau} the barrier NEVER fired and {cfg.reg} == nll (no regularisation).")
|
||||
@@ -73,7 +106,8 @@ def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg:
|
||||
step = 0
|
||||
nlls = [] # per-step SFT loss; final = mean of last 5, the heal-stage number for the round table
|
||||
for ep in range(cfg.epochs):
|
||||
for c in kept:
|
||||
ep_nlls = []
|
||||
for c in train_kept:
|
||||
ids, mask = _encode(tok, c["prompt"], c["completion"], cfg.max_len, model.device)
|
||||
if mask.sum() == 0:
|
||||
# prompt filled max_len so the completion was truncated to zero target tokens.
|
||||
@@ -103,6 +137,7 @@ def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg:
|
||||
barrier = cfg.lam * torch.relu(div - cfg.tau)
|
||||
loss = sft + barrier
|
||||
nlls.append(sft.item())
|
||||
ep_nlls.append(sft.item())
|
||||
log_now = step % max(1, n_steps // 20) == 0 or step == n_steps - 1
|
||||
if log_now:
|
||||
# split the gradient pressure: ||∇sft|| vs ||∇barrier|| (retain_graph -> still .backward below).
|
||||
@@ -114,6 +149,7 @@ def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg:
|
||||
loss.backward()
|
||||
gnorm = torch.nn.utils.clip_grad_norm_(params, 1.0)
|
||||
opt.step()
|
||||
sched.step()
|
||||
opt.zero_grad()
|
||||
if log_now:
|
||||
logger.info(f" {step:4d} {sft.item():5.2f} {div.detach().item():4.2f} "
|
||||
@@ -121,6 +157,8 @@ def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg:
|
||||
pbar.set_postfix(nll=f"{sft.item():.2f}", kl=f"{div.detach().item():.2f}", gn=f"{float(gnorm):.1f}")
|
||||
pbar.update(1)
|
||||
step += 1
|
||||
val = _val_nll(model, tok, val_kept, hist_specs, lora, cfg)
|
||||
logger.info(f" epoch {ep}: train_nll={sum(ep_nlls)/len(ep_nlls):.3f} val_nll={val:.3f} lr={sched.get_last_lr()[0]:.1e}")
|
||||
pbar.close()
|
||||
|
||||
spec = AdapterSpec.from_lora(lora, default_c=1.0) # CPU-resident, for the next round's history
|
||||
|
||||
@@ -116,8 +116,13 @@ class ModulatedLoRA:
|
||||
d_in, d_out = layer.in_features, layer.out_features
|
||||
A = torch.empty(self.cfg.r, d_in, dtype=self.cfg.dtype, device=device)
|
||||
nn.init.kaiming_uniform_(A, a=5 ** 0.5)
|
||||
B = torch.empty(d_out, self.cfg.r, dtype=self.cfg.dtype, device=device)
|
||||
nn.init.normal_(B, mean=1e-4, std=1e-4)
|
||||
B = torch.zeros(d_out, self.cfg.r, dtype=self.cfg.dtype, device=device)
|
||||
# B=0 => delta = B@A = 0, so a fresh adapter is a true no-op at init: round-0 step-0
|
||||
# student logits == base, hence barrier KL(student||orig) starts at 0 (only baked
|
||||
# history can diverge). The old normal_(mean=1e-4) perturbed EVERY all-linear layer
|
||||
# in a systematic (nonzero-mean) direction, compounding across ~200 adapters into a
|
||||
# phantom ~0.6-nat KL before any training -- it sat above tau and fired the barrier
|
||||
# against nothing real. A still gets gradient via B once B leaves zero (standard LoRA).
|
||||
self.A[name] = nn.Parameter(A)
|
||||
self.B[name] = nn.Parameter(B)
|
||||
self._target_layers[name] = layer
|
||||
|
||||
Reference in New Issue
Block a user