mirror of
https://github.com/wassname/steer-heal-love.git
synced 2026-06-27 15:32:28 +08:00
heal: round-ramped barrier knob lam_round_pow (lam_eff = lam*(1+round)^pow)
Opposes the compounding coherence drift under barrier_ref=prev that degenerated #101 into token loops by round 7 (journal h). pow=0 is byte-identical to the flat-lam baseline (lam_eff==lam at every round); pow=0.5 = sqrt(round) ramp. Round index = len(hist_specs). Queued as #102 (pow=0.5) paired vs #101 (pow=0). Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
@@ -69,6 +69,13 @@ class RunConfig:
|
||||
# the two are identical (no history yet); they only differ from round 1 on.
|
||||
barrier_ref: Literal["base", "prev"] = "prev"
|
||||
lam: float = 0.3 # kl-barrier weight (reg=kl_*); ignored for nll. 0.3 = coherence peak of the #98/#99 ladder (unimodal in lam, peaks 0.1-0.3, 1.0 over-tight); 0.3 = most trait at the peak
|
||||
# round-ramped barrier: lam_eff = lam * (1 + round)**lam_round_pow. 0 = constant (every round same lam).
|
||||
# >0 grows the barrier with round to oppose the COMPOUNDING coherence drift under barrier_ref=prev: each
|
||||
# round adds ~constant divergence and they accumulate, so by round ~7 the baked adapter degenerates into
|
||||
# token loops (#101 journal h: coh 0.99->0.62, "BUILDUTEutive" soup that the ppl/rep filter can't catch).
|
||||
# A growing barrier holds later rounds closer to their predecessor. Trades final trait depth for more
|
||||
# coherent rounds (the barrier can't tell coherence-drift from trait-drift). 0.5 = sqrt(round) ramp.
|
||||
lam_round_pow: float = 0.0
|
||||
tau: float = 0.5 # barrier engages only when divergence > tau (nats)
|
||||
weight_decay: float = 0.0 # AdamW decoupled decay on the adapter; per-step shrink ~ lr*weight_decay
|
||||
# spectral_lam: independent ALWAYS-ON operator-norm penalty on ΔW (σ_max via power iteration), a
|
||||
|
||||
@@ -104,10 +104,16 @@ def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg:
|
||||
sched = get_cosine_schedule_with_warmup(
|
||||
opt, num_warmup_steps=int(cfg.warmup_ratio * n_steps), num_training_steps=n_steps)
|
||||
|
||||
# round-ramped barrier (config.lam_round_pow): round index = len(hist_specs) (R adapters baked = round R).
|
||||
# lam_round_pow=0 -> lam_eff==lam (constant, no behaviour change). >0 grows the barrier with round.
|
||||
rnd = len(hist_specs)
|
||||
lam_eff = cfg.lam * (1 + rnd) ** cfg.lam_round_pow
|
||||
|
||||
# streaming training table (token-efficient-logging): one row, columns self-decode below.
|
||||
logger.info(f"heal[{cfg.reg}] {len(train_kept)} train (+{len(val_kept)} val) x {cfg.epochs} ep = {n_steps} steps; "
|
||||
f"lora r={cfg.lora_r} a={cfg.lora_alpha} on layers {cfg.layer_range}; "
|
||||
f"lr={cfg.lr} cosine warmup={cfg.warmup_ratio} betas={cfg.adam_betas}")
|
||||
f"lr={cfg.lr} cosine warmup={cfg.warmup_ratio} betas={cfg.adam_betas}; "
|
||||
f"lam_eff={lam_eff:.3f} (lam {cfg.lam} x (1+round={rnd})^{cfg.lam_round_pow})")
|
||||
logger.info("SHOULD (val): train_nll falls each epoch (SFT fits the kept data); val_nll falls then "
|
||||
"flattens. If val_nll RISES while train falls -> overfit (fewer epochs / lower r). If "
|
||||
"NEITHER falls -> data is near-base (nothing to distil) or the optimiser is broken.")
|
||||
@@ -159,7 +165,7 @@ def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg:
|
||||
div = _kl_per_pos(logp[mask], logp0[mask]).mean()
|
||||
else:
|
||||
div = torch.zeros((), device=model.device) # nll
|
||||
barrier = cfg.lam * torch.relu(div - cfg.tau)
|
||||
barrier = lam_eff * torch.relu(div - cfg.tau)
|
||||
# spectral_lam: independent ALWAYS-ON operator-norm cap on ΔW (σ_max), composes with the
|
||||
# output-space barrier above and with weight_decay (see config.RunConfig.spectral_lam).
|
||||
# Folded into `barrier` so the g_bar/g_nll gradient-pressure log captures it too.
|
||||
|
||||
Reference in New Issue
Block a user