diff --git a/src/steer_heal/config.py b/src/steer_heal/config.py index 011f2c2..62f70b0 100644 --- a/src/steer_heal/config.py +++ b/src/steer_heal/config.py @@ -69,6 +69,13 @@ class RunConfig: # the two are identical (no history yet); they only differ from round 1 on. barrier_ref: Literal["base", "prev"] = "prev" lam: float = 0.3 # kl-barrier weight (reg=kl_*); ignored for nll. 0.3 = coherence peak of the #98/#99 ladder (unimodal in lam, peaks 0.1-0.3, 1.0 over-tight); 0.3 = most trait at the peak + # round-ramped barrier: lam_eff = lam * (1 + round)**lam_round_pow. 0 = constant (every round same lam). + # >0 grows the barrier with round to oppose the COMPOUNDING coherence drift under barrier_ref=prev: each + # round adds ~constant divergence and they accumulate, so by round ~7 the baked adapter degenerates into + # token loops (#101 journal h: coh 0.99->0.62, "BUILDUTEutive" soup that the ppl/rep filter can't catch). + # A growing barrier holds later rounds closer to their predecessor. Trades final trait depth for more + # coherent rounds (the barrier can't tell coherence-drift from trait-drift). 0.5 = sqrt(round) ramp. + lam_round_pow: float = 0.0 tau: float = 0.5 # barrier engages only when divergence > tau (nats) weight_decay: float = 0.0 # AdamW decoupled decay on the adapter; per-step shrink ~ lr*weight_decay # spectral_lam: independent ALWAYS-ON operator-norm penalty on ΔW (σ_max via power iteration), a diff --git a/src/steer_heal/heal.py b/src/steer_heal/heal.py index 2a0735e..9a9c212 100644 --- a/src/steer_heal/heal.py +++ b/src/steer_heal/heal.py @@ -104,10 +104,16 @@ def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg: sched = get_cosine_schedule_with_warmup( opt, num_warmup_steps=int(cfg.warmup_ratio * n_steps), num_training_steps=n_steps) + # round-ramped barrier (config.lam_round_pow): round index = len(hist_specs) (R adapters baked = round R). + # lam_round_pow=0 -> lam_eff==lam (constant, no behaviour change). >0 grows the barrier with round. + rnd = len(hist_specs) + lam_eff = cfg.lam * (1 + rnd) ** cfg.lam_round_pow + # streaming training table (token-efficient-logging): one row, columns self-decode below. logger.info(f"heal[{cfg.reg}] {len(train_kept)} train (+{len(val_kept)} val) x {cfg.epochs} ep = {n_steps} steps; " f"lora r={cfg.lora_r} a={cfg.lora_alpha} on layers {cfg.layer_range}; " - f"lr={cfg.lr} cosine warmup={cfg.warmup_ratio} betas={cfg.adam_betas}") + f"lr={cfg.lr} cosine warmup={cfg.warmup_ratio} betas={cfg.adam_betas}; " + f"lam_eff={lam_eff:.3f} (lam {cfg.lam} x (1+round={rnd})^{cfg.lam_round_pow})") logger.info("SHOULD (val): train_nll falls each epoch (SFT fits the kept data); val_nll falls then " "flattens. If val_nll RISES while train falls -> overfit (fewer epochs / lower r). If " "NEITHER falls -> data is near-base (nothing to distil) or the optimiser is broken.") @@ -159,7 +165,7 @@ def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg: div = _kl_per_pos(logp[mask], logp0[mask]).mean() else: div = torch.zeros((), device=model.device) # nll - barrier = cfg.lam * torch.relu(div - cfg.tau) + barrier = lam_eff * torch.relu(div - cfg.tau) # spectral_lam: independent ALWAYS-ON operator-norm cap on ΔW (σ_max), composes with the # output-space barrier above and with weight_decay (see config.RunConfig.spectral_lam). # Folded into `barrier` so the g_bar/g_nll gradient-pressure log captures it too.