mirror of
https://github.com/wassname/steer-heal-love.git
synced 2026-06-27 16:47:16 +08:00
Add last-good KL anchor
@@ -0,0 +1,55 @@
# Last-good KL anchor

## Goal
Implement a ratcheting KL reference: heal each round against the most recent checkpoint that still passed the coherence gate. If a healed checkpoint passes, it becomes the new reference; if it fails the adoption gate but remains above `coh_floor`, the loop continues without blessing the failed checkpoint as the next reference.

This tests the hypothesis that `prev` lets incoherence drift and `base` fights trait history, while `last_good` keeps the anchor coherent without forcing the model all the way back to round 0.
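
The ratchet described above can be traced on made-up numbers. This is a minimal sketch, not the project's loop: `simulate_ratchet` is a hypothetical helper, `coh_floor=0.5` is a placeholder value, and the hard stop for runs that fall below `coh_floor` is deliberately left out. Only the adopt/hold rule `coherence >= max(coh_floor, ref_adopt_rel * last_good_coherence)` is taken from this plan.

```python
def simulate_ratchet(coherences, base_coherence, coh_floor=0.5, ref_adopt_rel=0.99):
    """Trace which round anchors the KL barrier, given each round's healed coherence.

    Returns one (round, ref_round_used, adopted) tuple per round; ref_round -1 is base.
    coh_floor=0.5 is a placeholder for illustration, not the project's default.
    """
    last_good_round, last_good_coh = -1, base_coherence
    trace = []
    for rnd, coh in enumerate(coherences):
        ref_used = last_good_round  # the reference this round heals against
        threshold = max(coh_floor, ref_adopt_rel * last_good_coh)
        adopted = coh >= threshold
        if adopted:
            # ratchet forward: this checkpoint anchors the following rounds
            last_good_round, last_good_coh = rnd, coh
        trace.append((rnd, ref_used, adopted))
    return trace


# Illustrative coherence values only: r1 and r3 dip, r0 and r2 pass the gate.
for row in simulate_ratchet([0.95, 0.80, 0.96, 0.60], base_coherence=0.90):
    print(row)
```

In this toy trace r1 fails the gate, so r2 still heals against r0 rather than against the incoherent r1; that is the behaviour `prev` cannot give.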

## Scope
In: config knob, heal reference selection, loop state, a just recipe, fast-dev proof, queued real run.

Out: new filtering heuristics, new metrics, multi-arm sweep, changing the diary/report format unless needed for proof.

## Requirements
- R1: `barrier_ref=last_good` uses the latest coherent checkpoint as the KL reference.
  Done means: the heal log prints `barrier_ref=last_good ref_round=<n>` and the ref stays unchanged until a round passes the coherence gate.
  VERIFY: `just fast-dev-run --barrier-ref=last_good ...` reaches heal and logs the selected reference.
- R2: Coherence adoption is explicit and fail-fast.
  Done means: after each eval, the loop logs whether the checkpoint was adopted as last-good; a failed adoption gate holds the old reference, while `coh_floor` still stops broken runs.
  VERIFY: log lines show adoption only after `coherence >= max(cfg.coh_floor, last_good_coherence * cfg.ref_adopt_rel)`.
- R3: Real run is queued on branch `dv` with a why/resolve pueue label.
  Done means: `pueue status --json` shows a queued/running task whose command includes `--barrier-ref=last_good`, `--kl-agg=rmse`, and a non-positive `--lam-round-pow`.
  VERIFY: status table includes the task id and label.
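
R3 asks for a non-positive `--lam-round-pow`. The `RunConfig` comment in this commit gives the schedule as `lam_eff = lam * (1 + round)**lam_round_pow`, so a negative power relaxes the barrier in later rounds. A small sketch of that formula, using the queued run's `--lam=0.3 --lam-round-pow=-0.5`; the standalone function `lam_eff` is only for illustration:

```python
def lam_eff(lam: float, rnd: int, lam_round_pow: float) -> float:
    """Effective barrier weight for a 0-indexed round, per the RunConfig comment."""
    return lam * (1 + rnd) ** lam_round_pow


# lam_round_pow=0 keeps lam constant; a negative power decays it with round.
for rnd in range(4):
    print(rnd, round(lam_eff(0.3, rnd, -0.5), 4), lam_eff(0.3, rnd, 0.0))
```

At round 3 the factor is `4 ** -0.5 = 0.5`, so the barrier weight has halved but is still non-zero.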

## Tasks
- [/] T1 (R1/R2): Implement config + loop reference state.
  - steps: add `last_good` literal and `ref_adopt_rel`; pass `ref_specs` into `heal_round`; update adoption logging.
  - verify: `just fast-dev-run --barrier-ref=last_good --kl-agg=rmse --tau=2.0 --lam-round-pow=-0.5 --spectral-lam=0 --n-rounds=1`
  - success: heal log names `barrier_ref=last_good ref_round=-1`; tiny-random holds the reference because coherence is below `coh_floor`.
  - likely_fail: tyro rejects the new enum; verify command errors before model load.
  - sneaky_fail: code accepts the enum but still uses `hist_specs`/`base`; log catches selected ref round and number of specs.
  - UAT: the run log links to a file containing both selected-ref and adoption evidence.
- [ ] T2 (R3): Add a recipe and queue the real run.
  - steps: add a `run-last-good-love` or queue recipe; pueue add from `dv` worktree with a why/resolve label.
  - verify: `pueue status --json | jq ...`
  - success: status row includes the task id, branch workdir, and command.
  - likely_fail: pueue daemon unavailable; command reports connection failure.
  - sneaky_fail: queued command runs wrong branch or missing knobs; status command shows command/path.
  - UAT: status table/log path shows a queued or running task with the intended knobs.
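
Both the T1 verify command and the queued run pass `--kl-agg=rmse`; the justfile comment in this commit says rmse catches token-loop KL spikes that a mean hides. The sketch below shows the arithmetic on invented per-position KL values. `agg_kl` is a hypothetical stand-in, assuming rmse means `sqrt(mean(kl**2))`; the project's own `_agg_kl` is not shown in this diff and may differ.

```python
import math


def agg_kl(kl_per_pos: list[float], how: str) -> float:
    """Aggregate per-position KL values (assumed definitions, for illustration)."""
    if how == "mean":
        return sum(kl_per_pos) / len(kl_per_pos)
    if how == "rmse":
        return math.sqrt(sum(k * k for k in kl_per_pos) / len(kl_per_pos))
    if how == "max":
        return max(kl_per_pos)
    raise ValueError(f"unknown aggregation {how!r}")


# Invented numbers: 100 quiet positions vs. the same batch with one large spike.
coherent = [0.03] * 100
one_spike = [0.03] * 99 + [8.0]
for how in ("mean", "rmse", "max"):
    print(how, round(agg_kl(coherent, how), 3), round(agg_kl(one_spike, how), 3))
```

A single outlier among 100 positions moves the mean only slightly but dominates the rmse, which is why an outlier-driven failure is easier to penalise with rmse.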

## Context
`hist_specs` stores one `AdapterSpec` per folded round. The base reference is `[]`; the previous-student reference is `hist_specs`; the last-good reference can be represented as `hist_specs[:last_good_n]`, where `last_good_n` is the number of adopted adapters. `last_good_n=0` means base.

The coherence metric is `p_ans_any` from tinymfv. It is generous, so adoption uses both the relative 99% gate and the absolute `coh_floor`; sample judging remains in the run report/log.
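
The three references are just three prefixes of the same history list. A minimal sketch, assuming plain strings in place of `AdapterSpec`; `select_ref_specs` is a hypothetical helper (in the commit the selection is split between `run.py` and `heal_round`):

```python
from typing import Literal

BarrierRef = Literal["base", "prev", "last_good"]


def select_ref_specs(hist_specs: list, barrier_ref: BarrierRef, last_good_n: int = 0) -> list:
    """Return the part of the adapter history that is baked into the KL reference."""
    if barrier_ref == "base":
        return []  # round-0 original: no history baked
    if barrier_ref == "prev":
        return list(hist_specs)  # every folded round so far
    if barrier_ref == "last_good":
        if not 0 <= last_good_n <= len(hist_specs):
            raise ValueError(f"last_good_n={last_good_n} outside history of {len(hist_specs)}")
        return list(hist_specs[:last_good_n])  # only the adopted prefix
    raise ValueError(f"unknown barrier_ref={barrier_ref!r}")


hist = ["r0", "r1", "r2"]  # three folded rounds; suppose only r0 was adopted
for ref in ("base", "prev", "last_good"):
    print(ref, select_ref_specs(hist, ref, last_good_n=1))
```

With `last_good_n=0` the `last_good` branch returns `[]`, which matches the note that `last_good_n=0` means base.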

## Log
- Branch `dv` created from dirty `main`; pre-existing edits in README, journal, filter, heal, steering were present before this task.
- Fast-dev caught a relative-threshold hole: tiny-random base coherence is 0, so `0.99 * ref` is 0 and would adopt a broken checkpoint. Adoption now uses `max(coh_floor, ref_adopt_rel * ref_coherence)`.
- External review attempt via `external-review-v2` timed out after ~2.5 minutes with no review text; proceeding on compile + fast-dev evidence.
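
The relative-threshold hole from the fast-dev entry above is easy to reproduce in isolation. A sketch with two hypothetical gate functions; `coh_floor=0.5` is a placeholder value, not the project's setting:

```python
def adopt_rel_only(new_coh: float, ref_coh: float, ref_adopt_rel: float = 0.99) -> bool:
    """Relative gate alone: the form that let a broken checkpoint through."""
    return new_coh >= ref_adopt_rel * ref_coh


def adopt_with_floor(new_coh: float, ref_coh: float, coh_floor: float,
                     ref_adopt_rel: float = 0.99) -> bool:
    """Gate as described in the log: relative threshold, but never below coh_floor."""
    return new_coh >= max(coh_floor, ref_adopt_rel * ref_coh)


# tiny-random case: reference coherence is 0, so the relative threshold is 0 too.
print(adopt_rel_only(new_coh=0.0, ref_coh=0.0))                    # adopts a broken round
print(adopt_with_floor(new_coh=0.0, ref_coh=0.0, coh_floor=0.5))   # holds the reference
```

When the reference is healthy the floor is inactive and the relative term decides, so the fix only changes behaviour in the degenerate case.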

## TODO
- Add a token-loop-specific adoption gate if the first last-good run still adopts visually broken rounds.

## Errors
| Task | Error | Resolution |
|------|-------|------------|
@@ -27,6 +27,24 @@ run *ARGS:

 # Queue sweeps (comment out completed; `just results` to check).
 queue:
+    #!/usr/bin/env bash
+    set -x
+    just queue-last-good-love
+
+# H: last_good anchor avoids prev-anchor drift without base-anchor history erasure; rmse catches token-loop KL spikes; lam decay relaxes later rounds without disabling the hinge.
+queue-last-good-love:
+    #!/usr/bin/env bash
+    set -x
+    pueue add -w "$PWD" -o 1 \
+        -l "why: test last_good KL anchor vs prev drift/base erasure on love loop; resolve: keep if coherence stays within 99% ref while care moves past README plateau" -- \
+        env STEER_ATTN_IMPL=eager \
+        {{ BASE }} --demo=love --use-qlora --train-bs=3 --grad-accum=2 \
+        --reg=kl_rev --barrier-ref=last_good --kl-agg=rmse --tau=2.0 \
+        --lam=0.3 --lam-round-pow=-0.5 --spectral-lam=0.005 \
+        --n-rounds=8 --seed=42
+
+# H: kl_rev heals best (mode-seeking suppresses low-base-prob = incoherent tokens).
+queue-sweep-reg:
     #!/usr/bin/env bash
     set -x
     just sweep-reg
@@ -97,11 +97,12 @@ class RunConfig:
     # so #101's barrier never fired); incoherence is outlier-driven, so rmse/p95/max are sensitive to it
     # (same loop: 1.5/3.8/8.1 vs coherent ~0.03). rmse = smooth dense gradient (train default), p95/max sparser.
     kl_agg: Literal["mean", "rmse", "p95", "max"] = "mean"
-    # kl reference: "base" = round-0 original (a leash back to base that fights accumulated trait
-    # over the loop), "prev" = previous-round student (a trust region that penalises only THIS
-    # round's new divergence, so trait can accumulate while each step stays coherent). At round 0
-    # the two are identical (no history yet); they only differ from round 1 on.
-    barrier_ref: Literal["base", "prev"] = "prev"
+    # kl reference: "base" = round-0 original (leash to origin), "prev" = previous-round
+    # student (trust region), "last_good" = most recent checkpoint that passed the coherence
+    # adoption gate. last_good is the ratchet: it advances only when coherence stays within
+    # ref_adopt_rel of the current reference, so a bad round does not become tomorrow's anchor.
+    barrier_ref: Literal["base", "prev", "last_good"] = "prev"
+    ref_adopt_rel: float = 0.99
     lam: float = 0.3  # kl-barrier weight (reg=kl_*); ignored for nll. 0.3 = coherence peak of the #98/#99 ladder (unimodal in lam, peaks 0.1-0.3, 1.0 over-tight); 0.3 = most trait at the peak
     # round-ramped barrier: lam_eff = lam * (1 + round)**lam_round_pow. 0 = constant (every round same lam).
     # >0 grows the barrier with round to oppose the COMPOUNDING coherence drift under barrier_ref=prev: each
+42
-22
@@ -1,7 +1,6 @@
-"""Q1 heal: train one round's LoRA = SFT(kept completions) + divergence-to-original barrier.
+"""Q1 heal: train one round's LoRA = SFT(kept completions) + divergence barrier.

-The barrier reference is the round-0 ORIGINAL (gates/adapters off), not the
-previous student, so it resists cumulative drift. reg picks the divergence:
+The barrier reference is chosen by cfg.barrier_ref. reg picks the divergence:
   nll      SFT only (control)
   kl_fwd   KL(orig || theta)  mass-covering (dilutes the trait)
   kl_rev   KL(theta || orig)  mode-seeking (suppresses low-orig-prob = incoherent)  [expected best]
@@ -120,7 +119,15 @@ def _val_nll(model, tok, val_kept, hist_specs, lora, cfg) -> float:
     return sum(losses) / len(losses) if losses else float("nan")


-def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg: RunConfig):
+def heal_round(
+    model,
+    tok,
+    kept: list[dict],
+    hist_specs: list[AdapterSpec],
+    cfg: RunConfig,
+    ref_specs: list[AdapterSpec] | None = None,
+    ref_round: int | str | None = None,
+):
     """Train a fresh round adapter on top of baked history. Returns (lora, spec)."""
     assert len(kept) >= cfg.min_train, (
         f"only {len(kept)} kept completions; need >= {cfg.min_train} to train. The steering/filter "
@@ -146,13 +153,26 @@ def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg:
     # lam_round_pow=0 -> lam_eff==lam (constant, no behaviour change). >0 grows the barrier with round.
     rnd = len(hist_specs)
     lam_eff = cfg.lam * (1 + rnd) ** cfg.lam_round_pow
+    if cfg.barrier_ref == "base":
+        barrier_ref_specs = []
+        ref_desc = "base"
+    elif cfg.barrier_ref == "prev":
+        barrier_ref_specs = hist_specs
+        ref_desc = f"prev(r{rnd - 1})"
+    elif cfg.barrier_ref == "last_good":
+        assert ref_specs is not None and ref_round is not None, "last_good barrier requires explicit ref_specs/ref_round"
+        barrier_ref_specs = ref_specs
+        ref_desc = f"last_good(r{ref_round})"
+    else:
+        raise ValueError(f"unknown barrier_ref={cfg.barrier_ref!r}")

     # streaming training table (token-efficient-logging): one row, columns self-decode below.
     logger.info(f"heal[{cfg.reg}] {len(train_kept)} train (+{len(val_kept)} val) x {cfg.epochs} ep = "
                 f"{n_batches} batches (bs={cfg.train_bs}) -> {n_opt_steps} opt steps (grad_accum={cfg.grad_accum}); "
                 f"lora r={cfg.lora_r} a={cfg.lora_alpha} on layers {cfg.layer_range}; "
                 f"lr={cfg.lr} cosine warmup={cfg.warmup_ratio} betas={cfg.adam_betas}; "
-                f"lam_eff={lam_eff:.3f} (lam {cfg.lam} x (1+round={rnd})^{cfg.lam_round_pow})")
+                f"lam_eff={lam_eff:.3f} (lam {cfg.lam} x (1+round={rnd})^{cfg.lam_round_pow}); "
+                f"barrier_ref={cfg.barrier_ref} ref_round={ref_round} ref={ref_desc} ref_specs={len(barrier_ref_specs)}")
     logger.info("SHOULD (val): train_nll falls each epoch (SFT fits the kept data); val_nll falls then "
                 "flattens. If val_nll RISES while train falls -> overfit (fewer epochs / lower r). If "
                 "NEITHER falls -> data is near-base (nothing to distil) or the optimiser is broken.")
@@ -188,30 +208,30 @@ def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg:
         ids = BatchEncoding({k: v[valid] for k, v in ids.items()})
         masks = masks[valid]  # [B', L-1]

-        # barrier reference logits (this round's adapter OFF). barrier_ref="base" bakes no
-        # history -> ref = round-0 original (leash to base, fights accumulated trait); "prev"
-        # bakes the history -> ref = previous-round student (trust region, penalises only this
-        # round's new divergence so trait accumulates while each step stays coherent).
+        # Gather completion positions BEFORE log_softmax. softmax is per-row, so selecting the
+        # ~N_comp completion rows then normalising is identical to normalising all B*(L-1) rows
+        # then selecting -- but it never materialises the full [B,L-1,V] log_softmax NOR its
+        # autograd graph. On gemma's 262k vocab at bs>1 the full tensor is what OOM'd the KL step.
+        flat_mask = masks.reshape(-1)  # [B'*(L-1)] bool, completion positions
+        tgt_c = ids.input_ids[:, 1:].reshape(-1)[flat_mask]  # [N_comp]
+
+        # barrier reference (this round's adapter OFF). base=[], prev=hist_specs,
+        # last_good=hist_specs[:last_good_n] from run.py.
         if cfg.reg in ("kl_fwd", "kl_rev"):
-            ref_specs = hist_specs if cfg.barrier_ref == "prev" else []
-            with torch.no_grad(), baked(model, ref_specs), lora(model, c=0.0):
-                logp0 = model(**ids).logits[:, :-1].log_softmax(-1)  # [B', L-1, V]
+            with torch.no_grad(), baked(model, barrier_ref_specs), lora(model, c=0.0):
+                logits0 = model(**ids).logits[:, :-1]  # [B', L-1, V]
+            V = logits0.shape[-1]
+            logp0_c = logits0.reshape(-1, V)[flat_mask].log_softmax(-1)  # [N_comp, V]

-        # student logits: history baked + this round's adapter live
+        # student: history baked + this round's adapter live. Same mask-first trick.
         with baked(model, hist_specs), lora(model, c=1.0):
-            logits = model(**ids).logits[:, :-1]  # [B', L-1, V]
-            logp = logits.log_softmax(-1)
-
-        # flatten batch × seq to masked completion tokens for loss and KL
-        V = logp.shape[-1]
-        logp_c = logp.reshape(-1, V)[masks.reshape(-1)]  # [N_comp, V]
-        tgt_c = ids.input_ids[:, 1:].reshape(-1)[masks.reshape(-1)]  # [N_comp]
+            logits = model(**ids).logits[:, :-1]  # [B', L-1, V]
+        V = logits.shape[-1]
+        logp_c = logits.reshape(-1, V)[flat_mask].log_softmax(-1)  # [N_comp, V]
         sft = F.nll_loss(logp_c, tgt_c)
         if cfg.reg == "kl_fwd":
-            logp0_c = logp0.reshape(-1, V)[masks.reshape(-1)]
             div = _agg_kl(_kl_per_pos(logp0_c, logp_c), cfg.kl_agg)
         elif cfg.reg == "kl_rev":
-            logp0_c = logp0.reshape(-1, V)[masks.reshape(-1)]
             div = _agg_kl(_kl_per_pos(logp_c, logp0_c), cfg.kl_agg)
         else:
             div = torch.zeros((), device=model.device)  # nll
+36
-2
@@ -198,6 +198,9 @@ def gen_filter_walk(model, tok, v, cfg: RunConfig, hist_specs: list, rnd: int) -

 def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict:
     hist_specs = []  # AdapterSpec per folded round (gated bake history)
+    last_good_n = 0  # number of adapters in the ratcheted coherent reference
+    last_good_round = -1
+    last_good_coherence = None
     v0_flat = None  # round-0 direction, for the Q3 cosine
     rounds = []
     gen_rounds = []  # per-round adapter gens (same prompts) -> outputs.html table
@@ -207,6 +210,7 @@ def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict:
     # trait), not just coherence. One extra eval per run.
     logger.info(f"\n\n\n=== EVAL base [tinymfv classic] gpu {gpu_mem()} ===")
     base_m = evaluate_model(model, tok, cfg, log_sample=True)  # one FULL eval gen (token-efficient-logging)
+    last_good_coherence = base_m["coherence"]
     log_event(run_dir, stage="base", round=-1, **base_m)  # persist so offline plot_run.py is self-contained
     stages = [{"round": "-", "stage": "base", "m": base_m}]  # base -> steered -> healed, for table + trajectory plot
     # BASE demo column (round -1): the no-adapter, no-steering model on the SAME demo prompts, so the
@@ -252,7 +256,16 @@ def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict:

         # heal one round on top of the baked history, then fold
         logger.info(f"\n\n\n=== r{rnd} HEAL [{cfg.reg}] gpu {gpu_mem()} ===")
-        lora, spec, heal_nll = heal_round(model, tok, kept, hist_specs, cfg)
+        ref_specs = hist_specs[:last_good_n] if cfg.barrier_ref == "last_good" else None
+        ref_round = last_good_round if cfg.barrier_ref == "last_good" else None
+        ref_coherence = last_good_coherence if cfg.barrier_ref == "last_good" else None
+        if cfg.barrier_ref == "last_good":
+            logger.info(
+                f"last_good reference for r{rnd}: ref_round={ref_round} ref_specs={len(ref_specs)} "
+                f"ref_coherence={ref_coherence:.3f}; adoption gate = new_coh >= max(coh_floor={cfg.coh_floor:.3f}, "
+                f"{cfg.ref_adopt_rel:.3f} * ref_coh = {cfg.ref_adopt_rel * ref_coherence:.3f})"
+            )
+        lora, spec, heal_nll = heal_round(model, tok, kept, hist_specs, cfg, ref_specs=ref_specs, ref_round=ref_round)
         lora.save(str(run_dir / "ckpt" / f"r{rnd}.safetensors"), extra_meta={"round": str(rnd), "reg": cfg.reg})
         hist_specs.append(spec)

@@ -291,12 +304,33 @@ def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict:
         logger.info(f"\n\n\n=== ADAPTER DEMO r{rnd} coh(p_ans_any)={m['coherence']:.3f} adapter_ppl={adapter_ppl:.0f} "
                     f"(no steering; compare across rounds: change vs saturation) ===\n" + demo_lines)

+        ref_adopted = False
+        if cfg.barrier_ref == "last_good":
+            adopt_threshold = max(cfg.coh_floor, cfg.ref_adopt_rel * last_good_coherence)
+            if m["coherence"] >= adopt_threshold:
+                last_good_n = len(hist_specs)
+                last_good_round = rnd
+                last_good_coherence = m["coherence"]
+                ref_adopted = True
+                logger.info(
+                    f"last_good ADOPT r{rnd}: coherence={m['coherence']:.3f} >= "
+                    f"threshold={adopt_threshold:.3f}; next ref_specs={last_good_n}"
+                )
+            else:
+                logger.warning(
+                    f"last_good HOLD at r{last_good_round}: r{rnd} coherence={m['coherence']:.3f} < "
+                    f"threshold={adopt_threshold:.3f}; next round still leashes to r{last_good_round}"
+                )
+
         vf = _flatten_v(v)
         v0_flat = vf if v0_flat is None else v0_flat
         cos_v0 = float(cosine_similarity(vf, v0_flat, dim=0))
         rec = {"round": rnd, **m, "cos_v0": cos_v0, "steered_ppl": steered_ppl,
                "adapter_ppl": adapter_ppl, "n_comps": n_comps, "n_kept": len(kept),
-               "kappa": kappa, "heal_nll": heal_nll}
+               "kappa": kappa, "heal_nll": heal_nll,
+               "barrier_ref_round": ref_round, "barrier_ref_coherence": ref_coherence,
+               "last_good_round": last_good_round if cfg.barrier_ref == "last_good" else None,
+               "last_good_adopted": ref_adopted}
         rounds.append(rec)
         stages.append({"round": rnd, "stage": "steered", "m": m_steer})
         stages.append({"round": rnd, "stage": "healed", "m": m})