diff --git a/src/steer_heal/config.py b/src/steer_heal/config.py index 8e57e66..8893f2f 100644 --- a/src/steer_heal/config.py +++ b/src/steer_heal/config.py @@ -29,7 +29,7 @@ class RunConfig: # carry the trait, so completions are generated with no persona. gen_system: str = "You are a helpful assistant." steer_layers: tuple[float, float] = (0.45, 0.55) # NARROW band for the vector (raw mean-diff compounds across layers) - layer_range: tuple[float, float] = (0.0, 1.0) # BROAD band for the LoRA (train trait into many layers) + layer_range: tuple[float, float] = (0.2, 0.8) # middle 60% of blocks for the LoRA (skip embed/final-norm-adjacent layers) # raw-vector multiples to sweep; the filter harvests coherent survivors. Dropped 0.25 # (filter audit: base-like, no distinct trait); kept 0.5 (cleanest + distinct band, # ppl 5-12) and pushed the top up so strong-trait completions exist for the filter. @@ -50,10 +50,15 @@ class RunConfig: reg: Literal["nll", "kl_fwd", "kl_rev", "wd"] = "kl_rev" lam: float = 1.0 # barrier weight (also weight_decay when reg == "wd") tau: float = 0.5 # barrier engages only when divergence > tau (nats) - lora_r: int = 8 - lora_alpha: float = 16.0 - epochs: int = 2 + lora_r: int = 32 + lora_alpha: float = 64.0 # keep scale = alpha/r = 2 (w2s convention alpha = 2r) + epochs: int = 6 # was 2: too few steps to see loss descend; val nll guards overfit lr: float = 1e-4 + warmup_ratio: float = 0.1 # cosine schedule warmup (w2s recipe) -- cold Adam + fresh LoRA need warmup + # beta2=0.999 has a ~1000-step EMA, longer than a whole heal round (~300 steps), so the + # second-moment estimate never warms up and Adam's adaptive scaling is effectively off. + # 0.95 -> ~20-step EMA, warms in ~40 steps. beta1 standard. + adam_betas: tuple[float, float] = (0.9, 0.95) # ── eval (tinymfv) ── eval_vignettes: int | None = None # None = all Clifford-2015 vignettes diff --git a/src/steer_heal/heal.py b/src/steer_heal/heal.py index ad1cf3d..3966a7e 100644 --- a/src/steer_heal/heal.py +++ b/src/steer_heal/heal.py @@ -8,10 +8,13 @@ previous student, so it resists cumulative drift. reg picks the divergence: wd weight decay on the adapter only """ +import random + import torch from loguru import logger from torch.nn import functional as F from tqdm import tqdm +from transformers import get_cosine_schedule_with_warmup from steer_heal.config import RunConfig from steer_heal.ws.adapter import ModulatedLoRA @@ -46,20 +49,50 @@ def _encode(tok, prompt: str, completion: str, max_len: int, device): return ids, tgt_is_completion +def _val_nll(model, tok, val_kept, hist_specs, lora, cfg) -> float: + """Held-out SFT nll (same student state as train: history baked, adapter live). The trait + eval is the real metric, but val_nll catches the optimisation failure modes the eval can't: + train falls + val rises = overfit; NEITHER falls = data near-base / opt broken.""" + if not val_kept: + return float("nan") + losses = [] + with torch.no_grad(), baked(model, hist_specs), lora(model, c=1.0): + for c in val_kept: + ids, mask = _encode(tok, c["prompt"], c["completion"], cfg.max_len, model.device) + if mask.sum() == 0: + continue + logp = model(**ids).logits[0, :-1].log_softmax(-1) + losses.append(F.nll_loss(logp[mask], ids.input_ids[0, 1:][mask]).item()) + return sum(losses) / len(losses) if losses else float("nan") + + def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg: RunConfig): """Train a fresh round adapter on top of baked history. Returns (lora, spec).""" assert len(kept) >= cfg.min_train, ( f"only {len(kept)} kept completions; need >= {cfg.min_train} to train. The steering/filter " "starved the data (over-steered -> all garbage, or ppl_tau too strict). Fix upstream, do not train." ) + # hold out ~1/8 for a val nll curve (shuffled so val isn't all one alpha). Tiny-dev keeps all + # for train (len//8 == 0) so the path still runs. + shuf = kept[:] + random.Random(cfg.seed).shuffle(shuf) + n_val = len(shuf) // 8 + val_kept, train_kept = shuf[:n_val], shuf[n_val:] lora = ModulatedLoRA(model, r=cfg.lora_r, alpha=cfg.lora_alpha, layer_range=cfg.layer_range) params = list(lora.parameters()) - opt = torch.optim.AdamW(params, lr=cfg.lr, weight_decay=(cfg.lam if cfg.reg == "wd" else 0.0)) - n_steps = len(kept) * cfg.epochs + opt = torch.optim.AdamW(params, lr=cfg.lr, betas=cfg.adam_betas, + weight_decay=(cfg.lam if cfg.reg == "wd" else 0.0)) + n_steps = len(train_kept) * cfg.epochs + sched = get_cosine_schedule_with_warmup( + opt, num_warmup_steps=int(cfg.warmup_ratio * n_steps), num_training_steps=n_steps) # streaming training table (token-efficient-logging): one row, columns self-decode below. - logger.info(f"heal[{cfg.reg}] {len(kept)} completions x {cfg.epochs} ep = {n_steps} steps; " - f"lora r={cfg.lora_r} on layers {cfg.layer_range}") + logger.info(f"heal[{cfg.reg}] {len(train_kept)} train (+{len(val_kept)} val) x {cfg.epochs} ep = {n_steps} steps; " + f"lora r={cfg.lora_r} a={cfg.lora_alpha} on layers {cfg.layer_range}; " + f"lr={cfg.lr} cosine warmup={cfg.warmup_ratio} betas={cfg.adam_betas}") + logger.info("SHOULD (val): train_nll falls each epoch (SFT fits the kept data); val_nll falls then " + "flattens. If val_nll RISES while train falls -> overfit (fewer epochs / lower r). If " + "NEITHER falls -> data is near-base (nothing to distil) or the optimiser is broken.") logger.info(f"SHOULD: nll (SFT) falls as the adapter learns the trait; kl (barrier div) is 0 for " f"reg=nll/wd and >0 for kl_rev/kl_fwd; gnorm finite (not exploding). loss = nll + lam*relu(kl-tau). " f"If kl stays < tau={cfg.tau} the barrier NEVER fired and {cfg.reg} == nll (no regularisation).") @@ -73,7 +106,8 @@ def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg: step = 0 nlls = [] # per-step SFT loss; final = mean of last 5, the heal-stage number for the round table for ep in range(cfg.epochs): - for c in kept: + ep_nlls = [] + for c in train_kept: ids, mask = _encode(tok, c["prompt"], c["completion"], cfg.max_len, model.device) if mask.sum() == 0: # prompt filled max_len so the completion was truncated to zero target tokens. @@ -103,6 +137,7 @@ def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg: barrier = cfg.lam * torch.relu(div - cfg.tau) loss = sft + barrier nlls.append(sft.item()) + ep_nlls.append(sft.item()) log_now = step % max(1, n_steps // 20) == 0 or step == n_steps - 1 if log_now: # split the gradient pressure: ||∇sft|| vs ||∇barrier|| (retain_graph -> still .backward below). @@ -114,6 +149,7 @@ def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg: loss.backward() gnorm = torch.nn.utils.clip_grad_norm_(params, 1.0) opt.step() + sched.step() opt.zero_grad() if log_now: logger.info(f" {step:4d} {sft.item():5.2f} {div.detach().item():4.2f} " @@ -121,6 +157,8 @@ def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg: pbar.set_postfix(nll=f"{sft.item():.2f}", kl=f"{div.detach().item():.2f}", gn=f"{float(gnorm):.1f}") pbar.update(1) step += 1 + val = _val_nll(model, tok, val_kept, hist_specs, lora, cfg) + logger.info(f" epoch {ep}: train_nll={sum(ep_nlls)/len(ep_nlls):.3f} val_nll={val:.3f} lr={sched.get_last_lr()[0]:.1e}") pbar.close() spec = AdapterSpec.from_lora(lora, default_c=1.0) # CPU-resident, for the next round's history diff --git a/src/steer_heal/ws/adapter.py b/src/steer_heal/ws/adapter.py index fea49b7..9e4f4f1 100644 --- a/src/steer_heal/ws/adapter.py +++ b/src/steer_heal/ws/adapter.py @@ -116,8 +116,13 @@ class ModulatedLoRA: d_in, d_out = layer.in_features, layer.out_features A = torch.empty(self.cfg.r, d_in, dtype=self.cfg.dtype, device=device) nn.init.kaiming_uniform_(A, a=5 ** 0.5) - B = torch.empty(d_out, self.cfg.r, dtype=self.cfg.dtype, device=device) - nn.init.normal_(B, mean=1e-4, std=1e-4) + B = torch.zeros(d_out, self.cfg.r, dtype=self.cfg.dtype, device=device) + # B=0 => delta = B@A = 0, so a fresh adapter is a true no-op at init: round-0 step-0 + # student logits == base, hence barrier KL(student||orig) starts at 0 (only baked + # history can diverge). The old normal_(mean=1e-4) perturbed EVERY all-linear layer + # in a systematic (nonzero-mean) direction, compounding across ~200 adapters into a + # phantom ~0.6-nat KL before any training -- it sat above tau and fired the barrier + # against nothing real. A still gets gradient via B once B leaves zero (standard LoRA). self.A[name] = nn.Parameter(A) self.B[name] = nn.Parameter(B) self._target_layers[name] = layer