Add last-good KL anchor

This commit is contained in:
wassname
2026-06-24 12:51:58 +08:00
parent 0c2be96eeb
commit 4b90f19400
5 changed files with 157 additions and 29 deletions
+55
View File
@@ -0,0 +1,55 @@
# Last-good KL anchor
## Goal
Implement a ratcheting KL reference: heal each round against the most recent checkpoint that still passed the coherence gate. If a healed checkpoint passes, it becomes the new reference; if it fails the adoption gate but remains above `coh_floor`, the loop continues without blessing the failed checkpoint as the next reference.
This tests the hypothesis that `prev` lets incoherence drift and `base` fights trait history, while `last_good` keeps the anchor coherent without forcing the model all the way back to round 0.
## Scope
In: config knob, heal reference selection, loop state, a just recipe, fast-dev proof, queued real run.
Out: new filtering heuristics, new metrics, multi-arm sweep, changing the diary/report format unless needed for proof.
## Requirements
- R1: `barrier_ref=last_good` uses the latest coherent checkpoint as the KL reference.
Done means: the heal log prints `barrier_ref=last_good ref_round=<n>` and the ref stays unchanged until a round passes the coherence gate.
VERIFY: `just fast-dev-run --barrier-ref=last_good ...` reaches heal and logs the selected reference.
- R2: Coherence adoption is explicit and fail-fast.
Done means: after each eval, the loop logs whether the checkpoint was adopted as last-good; a failed adoption gate holds the old reference, while `coh_floor` still stops broken runs.
VERIFY: log lines show adoption only after `coherence >= max(cfg.coh_floor, last_good_coherence * cfg.ref_adopt_rel)`.
- R3: Real run is queued on branch `dv` with a why/resolve pueue label.
Done means: `pueue status --json` shows a queued/running task whose command includes `--barrier-ref=last_good`, `--kl-agg=rmse`, and a non-positive `--lam-round-pow`.
VERIFY: status table includes the task id and label.
## Tasks
- [/] T1 (R1/R2): Implement config + loop reference state.
- steps: add `last_good` literal and `ref_adopt_rel`; pass `ref_specs` into `heal_round`; update adoption logging.
- verify: `just fast-dev-run --barrier-ref=last_good --kl-agg=rmse --tau=2.0 --lam-round-pow=-0.5 --spectral-lam=0 --n-rounds=1`
- success: heal log names `barrier_ref=last_good ref_round=-1`; tiny-random holds the reference because coherence is below `coh_floor`.
- likely_fail: tyro rejects the new enum; verify command errors before model load.
- sneaky_fail: code accepts the enum but still uses `hist_specs`/`base`; log catches selected ref round and number of specs.
- UAT: the run log links to a file containing both selected-ref and adoption evidence.
- [ ] T2 (R3): Add a recipe and queue the real run.
- steps: add a `run-last-good-love` or queue recipe; pueue add from `dv` worktree with a why/resolve label.
- verify: `pueue status --json | jq ...`
- success: status row includes the task id, branch workdir, and command.
- likely_fail: pueue daemon unavailable; command reports connection failure.
- sneaky_fail: queued command runs wrong branch or missing knobs; status command shows command/path.
- UAT: status table/log path shows a queued or running task with the intended knobs.
## Context
`hist_specs` stores one `AdapterSpec` per folded round. The base reference is `[]`; the previous-student reference is `hist_specs`; the last-good reference can be represented as `hist_specs[:last_good_n]`, where `last_good_n` is the number of adopted adapters. `last_good_n=0` means base.
The coherence metric is `p_ans_any` from tinymfv. It is generous, so adoption uses both the relative 99% gate and the absolute `coh_floor`; sample judging remains in the run report/log.
## Log
- Branch `dv` created from dirty `main`; pre-existing edits in README, journal, filter, heal, steering were present before this task.
- Fast-dev caught a relative-threshold hole: tiny-random base coherence is 0, so `0.99 * ref` is 0 and would adopt a broken checkpoint. Adoption now uses `max(coh_floor, ref_adopt_rel * ref_coherence)`.
- External review attempt via `external-review-v2` timed out after ~2.5 minutes with no review text; proceeding on compile + fast-dev evidence.
## TODO
- Add a token-loop-specific adoption gate if the first last-good run still adopts visually broken rounds.
## Errors
| Task | Error | Resolution |
|------|-------|------------|
+18
View File
@@ -27,6 +27,24 @@ run *ARGS:
# Queue sweeps (comment out completed; `just results` to check).
queue:
#!/usr/bin/env bash
set -x
just queue-last-good-love
# H: last_good anchor avoids prev-anchor drift without base-anchor history erasure; rmse catches token-loop KL spikes; lam decay relaxes later rounds without disabling the hinge.
queue-last-good-love:
#!/usr/bin/env bash
set -x
pueue add -w "$PWD" -o 1 \
-l "why: test last_good KL anchor vs prev drift/base erasure on love loop; resolve: keep if coherence stays within 99% ref while care moves past README plateau" -- \
env STEER_ATTN_IMPL=eager \
{{ BASE }} --demo=love --use-qlora --train-bs=3 --grad-accum=2 \
--reg=kl_rev --barrier-ref=last_good --kl-agg=rmse --tau=2.0 \
--lam=0.3 --lam-round-pow=-0.5 --spectral-lam=0.005 \
--n-rounds=8 --seed=42
# H: kl_rev heals best (mode-seeking suppresses low-base-prob = incoherent tokens).
queue-sweep-reg:
#!/usr/bin/env bash
set -x
just sweep-reg
+6 -5
View File
@@ -97,11 +97,12 @@ class RunConfig:
# so #101's barrier never fired); incoherence is outlier-driven, so rmse/p95/max are sensitive to it
# (same loop: 1.5/3.8/8.1 vs coherent ~0.03). rmse = smooth dense gradient (train default), p95/max sparser.
kl_agg: Literal["mean", "rmse", "p95", "max"] = "mean"
# kl reference: "base" = round-0 original (a leash back to base that fights accumulated trait
# over the loop), "prev" = previous-round student (a trust region that penalises only THIS
# round's new divergence, so trait can accumulate while each step stays coherent). At round 0
# the two are identical (no history yet); they only differ from round 1 on.
barrier_ref: Literal["base", "prev"] = "prev"
# kl reference: "base" = round-0 original (leash to origin), "prev" = previous-round
# student (trust region), "last_good" = most recent checkpoint that passed the coherence
# adoption gate. last_good is the ratchet: it advances only when coherence stays within
# ref_adopt_rel of the current reference, so a bad round does not become tomorrow's anchor.
barrier_ref: Literal["base", "prev", "last_good"] = "prev"
ref_adopt_rel: float = 0.99
lam: float = 0.3 # kl-barrier weight (reg=kl_*); ignored for nll. 0.3 = coherence peak of the #98/#99 ladder (unimodal in lam, peaks 0.1-0.3, 1.0 over-tight); 0.3 = most trait at the peak
# round-ramped barrier: lam_eff = lam * (1 + round)**lam_round_pow. 0 = constant (every round same lam).
# >0 grows the barrier with round to oppose the COMPOUNDING coherence drift under barrier_ref=prev: each
+42 -22
View File
@@ -1,7 +1,6 @@
"""Q1 heal: train one round's LoRA = SFT(kept completions) + divergence-to-original barrier.
"""Q1 heal: train one round's LoRA = SFT(kept completions) + divergence barrier.
The barrier reference is the round-0 ORIGINAL (gates/adapters off), not the
previous student, so it resists cumulative drift. reg picks the divergence:
The barrier reference is chosen by cfg.barrier_ref. reg picks the divergence:
nll SFT only (control)
kl_fwd KL(orig || theta) mass-covering (dilutes the trait)
kl_rev KL(theta || orig) mode-seeking (suppresses low-orig-prob = incoherent) [expected best]
@@ -120,7 +119,15 @@ def _val_nll(model, tok, val_kept, hist_specs, lora, cfg) -> float:
return sum(losses) / len(losses) if losses else float("nan")
def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg: RunConfig):
def heal_round(
model,
tok,
kept: list[dict],
hist_specs: list[AdapterSpec],
cfg: RunConfig,
ref_specs: list[AdapterSpec] | None = None,
ref_round: int | str | None = None,
):
"""Train a fresh round adapter on top of baked history. Returns (lora, spec)."""
assert len(kept) >= cfg.min_train, (
f"only {len(kept)} kept completions; need >= {cfg.min_train} to train. The steering/filter "
@@ -146,13 +153,26 @@ def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg:
# lam_round_pow=0 -> lam_eff==lam (constant, no behaviour change). >0 grows the barrier with round.
rnd = len(hist_specs)
lam_eff = cfg.lam * (1 + rnd) ** cfg.lam_round_pow
if cfg.barrier_ref == "base":
barrier_ref_specs = []
ref_desc = "base"
elif cfg.barrier_ref == "prev":
barrier_ref_specs = hist_specs
ref_desc = f"prev(r{rnd - 1})"
elif cfg.barrier_ref == "last_good":
assert ref_specs is not None and ref_round is not None, "last_good barrier requires explicit ref_specs/ref_round"
barrier_ref_specs = ref_specs
ref_desc = f"last_good(r{ref_round})"
else:
raise ValueError(f"unknown barrier_ref={cfg.barrier_ref!r}")
# streaming training table (token-efficient-logging): one row, columns self-decode below.
logger.info(f"heal[{cfg.reg}] {len(train_kept)} train (+{len(val_kept)} val) x {cfg.epochs} ep = "
f"{n_batches} batches (bs={cfg.train_bs}) -> {n_opt_steps} opt steps (grad_accum={cfg.grad_accum}); "
f"lora r={cfg.lora_r} a={cfg.lora_alpha} on layers {cfg.layer_range}; "
f"lr={cfg.lr} cosine warmup={cfg.warmup_ratio} betas={cfg.adam_betas}; "
f"lam_eff={lam_eff:.3f} (lam {cfg.lam} x (1+round={rnd})^{cfg.lam_round_pow})")
f"lam_eff={lam_eff:.3f} (lam {cfg.lam} x (1+round={rnd})^{cfg.lam_round_pow}); "
f"barrier_ref={cfg.barrier_ref} ref_round={ref_round} ref={ref_desc} ref_specs={len(barrier_ref_specs)}")
logger.info("SHOULD (val): train_nll falls each epoch (SFT fits the kept data); val_nll falls then "
"flattens. If val_nll RISES while train falls -> overfit (fewer epochs / lower r). If "
"NEITHER falls -> data is near-base (nothing to distil) or the optimiser is broken.")
@@ -188,30 +208,30 @@ def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg:
ids = BatchEncoding({k: v[valid] for k, v in ids.items()})
masks = masks[valid] # [B', L-1]
# barrier reference logits (this round's adapter OFF). barrier_ref="base" bakes no
# history -> ref = round-0 original (leash to base, fights accumulated trait); "prev"
# bakes the history -> ref = previous-round student (trust region, penalises only this
# round's new divergence so trait accumulates while each step stays coherent).
# Gather completion positions BEFORE log_softmax. softmax is per-row, so selecting the
# ~N_comp completion rows then normalising is identical to normalising all B*(L-1) rows
# then selecting -- but it never materialises the full [B,L-1,V] log_softmax NOR its
# autograd graph. On gemma's 262k vocab at bs>1 the full tensor is what OOM'd the KL step.
flat_mask = masks.reshape(-1) # [B'*(L-1)] bool, completion positions
tgt_c = ids.input_ids[:, 1:].reshape(-1)[flat_mask] # [N_comp]
# barrier reference (this round's adapter OFF). base=[], prev=hist_specs,
# last_good=hist_specs[:last_good_n] from run.py.
if cfg.reg in ("kl_fwd", "kl_rev"):
ref_specs = hist_specs if cfg.barrier_ref == "prev" else []
with torch.no_grad(), baked(model, ref_specs), lora(model, c=0.0):
logp0 = model(**ids).logits[:, :-1].log_softmax(-1) # [B', L-1, V]
with torch.no_grad(), baked(model, barrier_ref_specs), lora(model, c=0.0):
logits0 = model(**ids).logits[:, :-1] # [B', L-1, V]
V = logits0.shape[-1]
logp0_c = logits0.reshape(-1, V)[flat_mask].log_softmax(-1) # [N_comp, V]
# student logits: history baked + this round's adapter live
# student: history baked + this round's adapter live. Same mask-first trick.
with baked(model, hist_specs), lora(model, c=1.0):
logits = model(**ids).logits[:, :-1] # [B', L-1, V]
logp = logits.log_softmax(-1)
# flatten batch × seq to masked completion tokens for loss and KL
V = logp.shape[-1]
logp_c = logp.reshape(-1, V)[masks.reshape(-1)] # [N_comp, V]
tgt_c = ids.input_ids[:, 1:].reshape(-1)[masks.reshape(-1)] # [N_comp]
logits = model(**ids).logits[:, :-1] # [B', L-1, V]
V = logits.shape[-1]
logp_c = logits.reshape(-1, V)[flat_mask].log_softmax(-1) # [N_comp, V]
sft = F.nll_loss(logp_c, tgt_c)
if cfg.reg == "kl_fwd":
logp0_c = logp0.reshape(-1, V)[masks.reshape(-1)]
div = _agg_kl(_kl_per_pos(logp0_c, logp_c), cfg.kl_agg)
elif cfg.reg == "kl_rev":
logp0_c = logp0.reshape(-1, V)[masks.reshape(-1)]
div = _agg_kl(_kl_per_pos(logp_c, logp0_c), cfg.kl_agg)
else:
div = torch.zeros((), device=model.device) # nll
+36 -2
View File
@@ -198,6 +198,9 @@ def gen_filter_walk(model, tok, v, cfg: RunConfig, hist_specs: list, rnd: int) -
def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict:
hist_specs = [] # AdapterSpec per folded round (gated bake history)
last_good_n = 0 # number of adapters in the ratcheted coherent reference
last_good_round = -1
last_good_coherence = None
v0_flat = None # round-0 direction, for the Q3 cosine
rounds = []
gen_rounds = [] # per-round adapter gens (same prompts) -> outputs.html table
@@ -207,6 +210,7 @@ def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict:
# trait), not just coherence. One extra eval per run.
logger.info(f"\n\n\n=== EVAL base [tinymfv classic] gpu {gpu_mem()} ===")
base_m = evaluate_model(model, tok, cfg, log_sample=True) # one FULL eval gen (token-efficient-logging)
last_good_coherence = base_m["coherence"]
log_event(run_dir, stage="base", round=-1, **base_m) # persist so offline plot_run.py is self-contained
stages = [{"round": "-", "stage": "base", "m": base_m}] # base -> steered -> healed, for table + trajectory plot
# BASE demo column (round -1): the no-adapter, no-steering model on the SAME demo prompts, so the
@@ -252,7 +256,16 @@ def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict:
# heal one round on top of the baked history, then fold
logger.info(f"\n\n\n=== r{rnd} HEAL [{cfg.reg}] gpu {gpu_mem()} ===")
lora, spec, heal_nll = heal_round(model, tok, kept, hist_specs, cfg)
ref_specs = hist_specs[:last_good_n] if cfg.barrier_ref == "last_good" else None
ref_round = last_good_round if cfg.barrier_ref == "last_good" else None
ref_coherence = last_good_coherence if cfg.barrier_ref == "last_good" else None
if cfg.barrier_ref == "last_good":
logger.info(
f"last_good reference for r{rnd}: ref_round={ref_round} ref_specs={len(ref_specs)} "
f"ref_coherence={ref_coherence:.3f}; adoption gate = new_coh >= max(coh_floor={cfg.coh_floor:.3f}, "
f"{cfg.ref_adopt_rel:.3f} * ref_coh = {cfg.ref_adopt_rel * ref_coherence:.3f})"
)
lora, spec, heal_nll = heal_round(model, tok, kept, hist_specs, cfg, ref_specs=ref_specs, ref_round=ref_round)
lora.save(str(run_dir / "ckpt" / f"r{rnd}.safetensors"), extra_meta={"round": str(rnd), "reg": cfg.reg})
hist_specs.append(spec)
@@ -291,12 +304,33 @@ def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict:
logger.info(f"\n\n\n=== ADAPTER DEMO r{rnd} coh(p_ans_any)={m['coherence']:.3f} adapter_ppl={adapter_ppl:.0f} "
f"(no steering; compare across rounds: change vs saturation) ===\n" + demo_lines)
ref_adopted = False
if cfg.barrier_ref == "last_good":
adopt_threshold = max(cfg.coh_floor, cfg.ref_adopt_rel * last_good_coherence)
if m["coherence"] >= adopt_threshold:
last_good_n = len(hist_specs)
last_good_round = rnd
last_good_coherence = m["coherence"]
ref_adopted = True
logger.info(
f"last_good ADOPT r{rnd}: coherence={m['coherence']:.3f} >= "
f"threshold={adopt_threshold:.3f}; next ref_specs={last_good_n}"
)
else:
logger.warning(
f"last_good HOLD at r{last_good_round}: r{rnd} coherence={m['coherence']:.3f} < "
f"threshold={adopt_threshold:.3f}; next round still leashes to r{last_good_round}"
)
vf = _flatten_v(v)
v0_flat = vf if v0_flat is None else v0_flat
cos_v0 = float(cosine_similarity(vf, v0_flat, dim=0))
rec = {"round": rnd, **m, "cos_v0": cos_v0, "steered_ppl": steered_ppl,
"adapter_ppl": adapter_ppl, "n_comps": n_comps, "n_kept": len(kept),
"kappa": kappa, "heal_nll": heal_nll}
"kappa": kappa, "heal_nll": heal_nll,
"barrier_ref_round": ref_round, "barrier_ref_coherence": ref_coherence,
"last_good_round": last_good_round if cfg.barrier_ref == "last_good" else None,
"last_good_adopted": ref_adopted}
rounds.append(rec)
stages.append({"round": rnd, "stage": "steered", "m": m_steer})
stages.append({"round": rnd, "stage": "healed", "m": m})