mirror of
https://github.com/wassname/steer-heal-love.git
synced 2026-06-27 16:47:16 +08:00
b01faa6df1
gen_filter_walk: per round, cool a steering multiplier kappa and top up with extra gen batches until min_train coherent survivors are banked, so the loop cannot starve on data count (#90/#100 died at the min_train assert). Paired #101 (walk-C ON) vs #100 (walk-C OFF, identical config): #101 reaches round 9 where #100 asserted at round 5. Finding (journal h): walk-C removes the starve CRASH but the real ceiling is coherence collapse, not data count. Trait over-drives to auth -6.8 while coh falls 0.99 -> 0.62 and the kept completions degenerate into token loops ("BUILDUTEutive...", "GLUTE GLUTE") by round 7 -- low-entropy so they slip under ppl_tau and rep_tau and train the next adapter on garbage. Coherent deliverable is the round 1-2 adapter (auth -3.3 to -3.8 at coh 0.99-0.93). config: lam 1.0->0.3, spectral_lam 0->0.01 (locked from #98/#99 ablation), gen_pass_target/gen_kappa_decay/gen_kappa_min/gen_max_batches walk-C knobs. Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
157 lines
11 KiB
Python
157 lines
11 KiB
Python
"""Fast healing-hypothesis sweep, the RIGHT way: from the round-(N-1) CHECKPOINT.
|
||
|
||
The earlier diag_barrier.py re-healed a FRESH adapter from BASE (hist=[]), so the kl barrier
anchored to base and never saw the loop state. This loads the real round-0 checkpoint as baked
history, re-heals round-1's kept data on top, and varies ONLY the regulariser + the barrier
REFERENCE. That isolates: at round 1 (where the loop starts degenerating), which regulariser adds
the most NEW trait at the least coherence cost?

The decisive contrast is kl_rev ref=base vs ref=prev:
  ref=base -> KL(student || ORIGINAL). The student already carries round-0's trait, so this leashes
              it back toward base and partly UNDOES the prev round.
  ref=prev -> KL(student || prev-round student). Penalises only THIS round's new divergence = a
              trust region, so trait accumulates while each step stays coherent.

Metric: dAuth vs PREV (= new trait this round, the thing we want negative) at coherence >= prev.

Run: uv run python scripts/diag_heal_sweep.py out/20260604T231906_gemma-3-4b-it_nll_s42 1
|
||
"""
|
||
import dataclasses
import math
import sys
from pathlib import Path

import srsly
import torch
from loguru import logger
from tabulate import tabulate
from transformers import AutoModelForCausalLM, AutoTokenizer

sys.path.insert(0, "src")
from steer_heal.config import RunConfig  # noqa: E402
from steer_heal.eval import evaluate_model  # noqa: E402
from steer_heal.heal import heal_round  # noqa: E402
from steer_heal.run import setup_logging  # noqa: E402
from steer_heal.ws.bake import AdapterSpec, baked  # noqa: E402
|
||
|
||
setup_logging()  # INFO -> stdout via tqdm.write, DEBUG (per-step bake trace) -> logs/*_verbose.log

# CLI: <run_dir> [gen_round]. Fail with a usage line instead of a bare IndexError when the
# mandatory source-run directory is missing.
if len(sys.argv) < 2:
    raise SystemExit("usage: python scripts/diag_heal_sweep.py <run_dir> [gen_round=1]")

# Source run whose checkpoints and events.jsonl are replayed.
run_dir = Path(sys.argv[1])
# Round whose kept data is re-healed on top of the r0..r(N-1) history; defaults to round 1.
gen_round = int(sys.argv[2]) if len(sys.argv) > 2 else 1
# Default run configuration; each sweep config is derived from it via dataclasses.replace.
base_cfg = RunConfig()
|
||
|
||
# REGULARISER ABLATION (reg, lam, tau, ref, wd), all at ref=prev (the base-vs-prev DIRECTION
# question is settled: ref=base undoes prev, confirmed in #97 -- nll +1.157, kl_rev/base +0.855).
# Hold the reference fixed and ask which REGULARISER best trades coherence for trait, by the
# cohΔ/authΔ headline. Five families: nll (no reg, control), wd (Frobenius shrink via AdamW),
# kl_rev (mode-seeking trust region), kl_fwd (mass-covering), spectral_norm (operator-norm penalty).
#
# This is the FULL authoritative grid. The `# [#98]`-tagged rows are commented out because pueue 98
# already produced them -- uncomment to re-run from scratch. The active rows are the widened ends
# #98 never ran (the gap-fill, pueue 99). The combined table is read from BOTH logs. wd<=15 was
# byte-identical to the no-reg control (inert) so it stays commented; wd=30 moved trait MORE
# (dAuth_base -0.997 vs -0.782) AND held coherence, so wd 60/120 trace the curve above the knee.
#
# Each entry is a 5-tuple unpacked by the sweep loop as (reg, lam, tau, ref, wd).
GRID = [
    # ("nll", 0.0, 0.0, "prev", 0.0),    # [#98] control: pure SFT, no reg
    # ("nll", 0.0, 0.0, "prev", 15.0),   # [#98] INERT: byte-identical to wd=0 (decay too small to bite)
    # ("nll", 0.0, 0.0, "prev", 30.0),   # [#98] wd at the knee (AdamW Frobenius shrink on ΔW)
    ("nll", 0.0, 0.0, "prev", 60.0),     # wd above knee -- does coherence keep improving?
    ("nll", 0.0, 0.0, "prev", 120.0),    # wd strong -- where does trait start to erode?
    ("kl_rev", 0.03, 0.5, "prev", 0.0),  # mode-seeking trust region, gentle (#82 best-retain end)
    ("kl_rev", 0.05, 0.5, "prev", 0.0),  # between 0.03 and 0.1: does the slope peak below 0.1? (#98: 0.1 beat 0.3)
    # ("kl_rev", 0.1, 0.5, "prev", 0.0), # [#98] mode-seeking trust region, mid (current front-runner -0.13)
    # ("kl_rev", 0.3, 0.5, "prev", 0.0), # [#98] stronger trust region
    ("kl_rev", 1.0, 0.5, "prev", 0.0),   # strong (#82: over-tight, undoes trait) -- the bracket end
    # ("kl_fwd", 0.1, 0.5, "prev", 0.0), # [#98] mass-covering, gentle
    ("kl_fwd", 0.3, 0.5, "prev", 0.0),   # mass-covering, stronger (expect: dilutes trait)
    # spectral_norm is no longer a reg -- it's the independent cfg.spectral_lam knob now (composes with
    # kl_rev). #98 swept it as reg=spectral_norm (0.01/0.1/1.0); to redo, set spectral_lam, not reg.
]
|
||
logger.info(f"heal sweep from round-{gen_round-1} checkpoint, re-heal round-{gen_round} data: {len(GRID)} configs")

# Tokenizer and model both come from the model id in the default config.
tok = AutoTokenizer.from_pretrained(base_cfg.model)
if tok.pad_token is None:
    # Reuse EOS as the pad token when the tokenizer does not define one.
    tok.pad_token = tok.eos_token
# Loaded once in bf16 and put in eval mode; every sweep config reuses this same model object.
model = AutoModelForCausalLM.from_pretrained(
    base_cfg.model, torch_dtype=torch.bfloat16, device_map="auto", attn_implementation="eager"
).eval()
|
||
|
||
# baked history = the real round-0..round-(gen_round-1) adapters from the source run.
ckpt_dir = run_dir / "ckpt"
hist_specs = [AdapterSpec.from_checkpoint(model, str(ckpt_dir / f"r{i}.safetensors"))
              for i in range(gen_round)]
logger.info(f"loaded {len(hist_specs)} history checkpoint(s): r0..r{gen_round-1}")

# round-gen_round kept completions = the data round gen_round actually trained on.
# Take the first "gen" event of that round; a missing event is reported explicitly rather than
# surfacing as a bare StopIteration from next().
events_path = run_dir / "events.jsonl"
gen = next((e for e in srsly.read_jsonl(events_path)
            if e["stage"] == "gen" and e["round"] == gen_round), None)
if gen is None:
    raise LookupError(f"no 'gen' event for round {gen_round} in {events_path}")
# Only the scored samples flagged keep=True, reduced to prompt/completion pairs.
kept = [{"prompt": s["prompt"], "completion": s["completion"]} for s in gen["scored"] if s["keep"]]
logger.info(f"loaded {len(kept)} kept completions from round {gen_round}")
|
||
|
||
# Two reference points for every delta in the table: the untouched base model, and the model
# with the history adapters baked in.
base_m = evaluate_model(model, tok, base_cfg)
with baked(model, hist_specs):
    # round-(gen_round-1) HEALED = the start point this round must improve on
    prev_m = evaluate_model(model, tok, base_cfg)
logger.info(f"base: auth={base_m['auth_nats']:+.4f} coh={base_m['coherence']:.5f}")
logger.info(f"prev (r{gen_round-1} healed): auth={prev_m['auth_nats']:+.4f} coh={prev_m['coherence']:.5f}")
logger.info("SHOULD: dAuth_vs_prev NEGATIVE = this round ADDED trait; POSITIVE = the barrier UNDID prev. "
            "ref=base should undo (>=0) where ref=prev adds (<0), at coherence >= prev.")
|
||
|
||
# |dAuth| at or below this is treated as noise: the slope would explode or flip sign on noise,
# so it is not a meaningful efficiency and is reported as NaN instead.
NOISE_FLOOR = 0.05


def _slope_x100(d_coh, d_auth, eps=NOISE_FLOOR):
    """Return the coherence-per-trait slope scaled by 100, or NaN below the noise floor.

    The x100 scaling keeps the tiny slope resolving digits under the table's +.4f format
    (raw ~+0.001 and +0.0001 both collapse to the noise floor; x100 keeps them apart).
    Units: centinats coh / nat auth.

    Args:
        d_coh: change in coherence relative to the chosen reference.
        d_auth: change in auth_nats relative to the same reference.
        eps: minimum |d_auth| for the ratio to be considered meaningful.

    Returns:
        100 * d_coh / d_auth when |d_auth| > eps, otherwise float("nan").
    """
    return 100 * d_coh / d_auth if abs(d_auth) > eps else float("nan")


rows = []
for reg, lam, tau, ref, wd in GRID:
    cfg = dataclasses.replace(base_cfg, reg=reg, lam=lam, tau=tau, barrier_ref=ref, weight_decay=wd)
    torch.manual_seed(cfg.seed)  # identical LoRA-A init across configs -> only the regulariser differs
    # The first return value (the trained adapter object) is not needed here; only its spec is baked.
    _lora, spec, heal_nll = heal_round(model, tok, kept, hist_specs, cfg)
    with baked(model, hist_specs + [spec]):  # full round-gen_round student = history + this round's adapter
        m = evaluate_model(model, tok, cfg)

    dAuth_base = m["auth_nats"] - base_m["auth_nats"]
    dCoh_base = m["coherence"] - base_m["coherence"]
    dAuth_prev = m["auth_nats"] - prev_m["auth_nats"]
    dCoh_prev = m["coherence"] - prev_m["coherence"]

    # THE HEADLINE: coherence cost per unit of trait, the trade-off slope dCoh/dAuth.
    # We want trait to move (dAuth NEGATIVE) at little coherence cost (dCoh ~0), so a GOOD
    # config has a small-magnitude ratio (or negative = free coherence).
    coh_per_auth_base = _slope_x100(dCoh_base, dAuth_base)
    coh_per_auth_prev = _slope_x100(dCoh_prev, dAuth_prev)

    # HEADLINE columns first. Direction matters (NOT abs): most-NEGATIVE best = trait moved AND
    # coh ROSE (free lunch); then small positive = cheap.
    rows.append({
        "cohΔ/authΔ_base×100↓": coh_per_auth_base,
        "cohΔ/authΔ_prev×100↓": coh_per_auth_prev,
        "reg": reg, "lam": lam, "tau": tau, "ref": ref, "wd": wd,
        "auth↓": m["auth_nats"], "dAuth_base↓": dAuth_base, "dAuth_prev↓": dAuth_prev,
        "coh↑": m["coherence"], "dCoh_base↑": dCoh_base, "dCoh_prev↑": dCoh_prev, "heal_nll↓": heal_nll,
    })
    logger.info(f" {reg} lam={lam} tau={tau} ref={ref} wd={wd}: "
                f"cohΔ/authΔ_base×100={coh_per_auth_base:+.4f} auth={m['auth_nats']:+.4f} "
                f"dAuth_base={dAuth_base:+.4f} dAuth_prev={dAuth_prev:+.4f} coh={m['coherence']:.5f}")

# bookend reference rows so the swept configs read against base (origin) and prev (the healed
# round-(gen_round-1) model = the anchor this round starts from). prev's own slope shows what the
# earlier heal achieved (the bar to reproduce a round deeper). The prev tag carries the actual
# round index, so it stays correct when gen_round != 1.
for tag, mm in ((f"(prev=r{gen_round-1}heal)", prev_m), ("(base origin)", base_m)):
    dAb, dCb = mm["auth_nats"] - base_m["auth_nats"], mm["coherence"] - base_m["coherence"]
    dAp, dCp = mm["auth_nats"] - prev_m["auth_nats"], mm["coherence"] - prev_m["coherence"]
    rows.append({
        "cohΔ/authΔ_base×100↓": _slope_x100(dCb, dAb),
        "cohΔ/authΔ_prev×100↓": _slope_x100(dCp, dAp),
        "reg": tag, "lam": "", "tau": "", "ref": "", "wd": "",
        "auth↓": mm["auth_nats"], "dAuth_base↓": dAb, "dAuth_prev↓": dAp,
        "coh↑": mm["coherence"], "dCoh_base↑": dCb, "dCoh_prev↑": dCp, "heal_nll↓": float("nan"),
    })
|
||
|
||
print(f"\nheal sweep from r{gen_round-1} checkpoint, re-heal r{gen_round} data (vary regulariser + barrier ref only):")
print("HEADLINE = cohΔ/authΔ×100: centinats of coherence lost per nat of trait moved (the heal slope).")
print(" DIRECTION, not magnitude: most-NEGATIVE is best (trait moved AND coherence ROSE = free lunch);")
print(" then small-positive = cheap; large-positive = trait cost a lot of coherence. Sorted best-first.")
print(" blank ratio = |dAuth|<0.05 (config barely moved trait; slope is noise, not an efficiency).")
print("dAuth_prev = NEW trait this round (NEGATIVE = added); ref=base vs prev is the direction crux.\n")

slope_base_key = "cohΔ/authΔ_base×100↓"
slope_prev_key = "cohΔ/authΔ_prev×100↓"

# sort by the signed slope (NOT abs): most-negative free-lunch row first, NaN (do-nothing) last.
rows.sort(key=lambda r: 1e9 if math.isnan(r[slope_base_key]) else r[slope_base_key])

# per-column precision: headline x100 + coherence deltas get the extra digits that discriminate close
# configs; reg/lam/tau/wd stay compact. Tuple order matches the key order of the row dicts.
fmt = ("+.4f", "+.4f", "g", "g", "g", "g", "g", "+.4f", "+.4f", "+.4f", ".5f", "+.5f", "+.5f", "+.3f")

# The legend above promises a BLANK ratio for noise-floor rows, but a NaN formatted with "+.4f"
# prints as "+nan". Swap NaN slopes for None (display copy only) so tabulate emits missingval.
table_rows = [
    {k: (None if k in (slope_base_key, slope_prev_key) and math.isnan(v) else v) for k, v in r.items()}
    for r in rows
]
print(tabulate(table_rows, headers="keys", tablefmt="github", floatfmt=fmt, missingval=""))
print(f"\nbase auth={base_m['auth_nats']:+.3f} coh={base_m['coherence']:.3f} | "
      f"prev(r{gen_round-1}) auth={prev_m['auth_nats']:+.3f} coh={prev_m['coherence']:.3f} | source {run_dir.name}")
|