steer-heal-love/scripts/diag_heal_sweep.py
wassname b01faa6df1 walk-C adaptive-dose controller + 10-round paired loop result (journal h)
gen_filter_walk: per round, cool a steering multiplier kappa and top up with
extra gen batches until min_train coherent survivors are banked, so the loop
cannot starve on data count (#90/#100 died at the min_train assert). Paired
#101 (walk-C ON) vs #100 (walk-C OFF, identical config): #101 reaches round 9
where #100 asserted at round 5.

Finding (journal h): walk-C removes the starve CRASH but the real ceiling is
coherence collapse, not data count. Trait over-drives to auth -6.8 while coh
falls 0.99 -> 0.62 and the kept completions degenerate into token loops
("BUILDUTEutive...", "GLUTE GLUTE") by round 7 -- low-entropy so they slip
under ppl_tau and rep_tau and train the next adapter on garbage. Coherent
deliverable is the round 1-2 adapter (auth -3.3 to -3.8 at coh 0.99-0.93).

config: lam 1.0->0.3, spectral_lam 0->0.01 (locked from #98/#99 ablation),
gen_pass_target/gen_kappa_decay/gen_kappa_min/gen_max_batches walk-C knobs.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
2026-06-06 07:13:51 +08:00
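
The walk-C loop described in the commit message can be sketched roughly as below. This is a toy reading of the message, not the repository's `gen_filter_walk`: the function signature, the rule "cool kappa only when a batch's pass rate falls below the target", and the toy generator are all assumptions made for illustration. Only the knob names (pass target, kappa decay, kappa min, max batches, min_train) come from the message.

```python
from typing import Callable, List, Tuple


def gen_filter_walk(
    gen_batch: Callable[[float], List[str]],   # hypothetical: generate one batch at steering strength kappa
    is_coherent: Callable[[str], bool],        # hypothetical: coherence filter for a completion
    kappa: float,
    min_train: int,
    pass_target: float = 0.5,
    kappa_decay: float = 0.7,
    kappa_min: float = 0.1,
    max_batches: int = 8,
) -> Tuple[List[str], float]:
    """Top up with extra batches, cooling kappa, until min_train survivors are banked."""
    banked: List[str] = []
    for _ in range(max_batches):
        batch = gen_batch(kappa)
        kept = [c for c in batch if is_coherent(c)]
        banked.extend(kept)
        if len(banked) >= min_train:
            break  # enough coherent survivors: the round cannot starve
        pass_rate = len(kept) / max(len(batch), 1)
        if pass_rate < pass_target:
            # ASSUMPTION: cool the steering multiplier only when too few completions pass
            kappa = max(kappa_min, kappa * kappa_decay)
    return banked, kappa


def toy_gen(kappa: float) -> List[str]:
    """Toy generator: stronger steering yields more degenerate completions (4 per batch)."""
    n_bad = min(4, int(kappa * 4))
    return ["GLUTE GLUTE"] * n_bad + ["a coherent reply"] * (4 - n_bad)


banked, final_kappa = gen_filter_walk(
    toy_gen, lambda c: "GLUTE" not in c, kappa=1.0, min_train=4
)
print(len(banked), final_kappa)
```

As the commit's finding notes, a loop of this shape only guarantees the survivor count. It does nothing about low-entropy token loops that already pass the filter.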

157 lines
11 KiB
Python
"""Fast healing-hypothesis sweep, the RIGHT way: from the round-(N-1) CHECKPOINT.

The earlier diag_barrier.py re-healed a FRESH adapter from BASE (hist=[]), so the kl barrier
anchored to base and never saw the loop state. This loads the real round-0 checkpoint as baked
history, re-heals round-1's kept data on top, and varies ONLY the regulariser + the barrier
REFERENCE. That isolates: at round 1 (where the loop starts degenerating), which regulariser adds
the most NEW trait at the least coherence cost?

The decisive contrast is kl_rev ref=base vs ref=prev:
  ref=base -> KL(student || ORIGINAL). The student already carries round-0's trait, so this leashes
              it back toward base and partly UNDOES the prev round.
  ref=prev -> KL(student || prev-round student). Penalises only THIS round's new divergence = a
              trust region, so trait accumulates while each step stays coherent.

Metric: dAuth vs PREV (= new trait this round, the thing we want negative) at coherence >= prev.

Run: uv run python scripts/diag_heal_sweep.py out/20260604T231906_gemma-3-4b-it_nll_s42 1
"""
import dataclasses
import sys
from pathlib import Path
import srsly
import torch
from loguru import logger
from tabulate import tabulate
from transformers import AutoModelForCausalLM, AutoTokenizer
sys.path.insert(0, "src")
from steer_heal.config import RunConfig # noqa: E402
from steer_heal.eval import evaluate_model # noqa: E402
from steer_heal.heal import heal_round # noqa: E402
from steer_heal.run import setup_logging # noqa: E402
from steer_heal.ws.bake import AdapterSpec, baked # noqa: E402
setup_logging() # INFO -> stdout via tqdm.write, DEBUG (per-step bake trace) -> logs/*_verbose.log
run_dir = Path(sys.argv[1])
gen_round = int(sys.argv[2]) if len(sys.argv) > 2 else 1 # re-heal THIS round's data on r0..r(N-1) history
base_cfg = RunConfig()
# REGULARISER ABLATION (reg, lam, tau, ref, wd), all at ref=prev (the base-vs-prev DIRECTION
# question is settled: ref=base undoes prev, confirmed in #97 -- nll +1.157, kl_rev/base +0.855).
# Hold the reference fixed and ask which REGULARISER best trades coherence for trait, by the
# cohΔ/authΔ headline. Five families: nll (no reg, control), wd (Frobenius shrink via AdamW),
# kl_rev (mode-seeking trust region), kl_fwd (mass-covering), spectral_norm (operator-norm penalty).
#
# This is the FULL authoritative grid. The `# [#98]`-tagged rows are commented out because pueue 98
# already produced them -- uncomment to re-run from scratch. The active rows are the widened ends
# #98 never ran (the gap-fill, pueue 99). The combined table is read from BOTH logs. wd<=15 was
# byte-identical to the no-reg control (inert) so it stays commented; wd=30 moved trait MORE
# (dAuth_base -0.997 vs -0.782) AND held coherence, so wd 60/120 trace the curve above the knee.
GRID = [
    # ("nll", 0.0, 0.0, "prev", 0.0),     # [#98] control: pure SFT, no reg
    # ("nll", 0.0, 0.0, "prev", 15.0),    # [#98] INERT: byte-identical to wd=0 (decay too small to bite)
    # ("nll", 0.0, 0.0, "prev", 30.0),    # [#98] wd at the knee (AdamW Frobenius shrink on ΔW)
    ("nll", 0.0, 0.0, "prev", 60.0),      # wd above knee -- does coherence keep improving?
    ("nll", 0.0, 0.0, "prev", 120.0),     # wd strong -- where does trait start to erode?
    ("kl_rev", 0.03, 0.5, "prev", 0.0),   # mode-seeking trust region, gentle (#82 best-retain end)
    ("kl_rev", 0.05, 0.5, "prev", 0.0),   # between 0.03 and 0.1: does the slope peak below 0.1? (#98: 0.1 beat 0.3)
    # ("kl_rev", 0.1, 0.5, "prev", 0.0),  # [#98] mode-seeking trust region, mid (current front-runner -0.13)
    # ("kl_rev", 0.3, 0.5, "prev", 0.0),  # [#98] stronger trust region
    ("kl_rev", 1.0, 0.5, "prev", 0.0),    # strong (#82: over-tight, undoes trait) -- the bracket end
    # ("kl_fwd", 0.1, 0.5, "prev", 0.0),  # [#98] mass-covering, gentle
    ("kl_fwd", 0.3, 0.5, "prev", 0.0),    # mass-covering, stronger (expect: dilutes trait)
    # spectral_norm is no longer a reg -- it's the independent cfg.spectral_lam knob now (composes with
    # kl_rev). #98 swept it as reg=spectral_norm (0.01/0.1/1.0); to redo, set spectral_lam, not reg.
]
logger.info(f"heal sweep from round-{gen_round-1} checkpoint, re-heal round-{gen_round} data: {len(GRID)} configs")
tok = AutoTokenizer.from_pretrained(base_cfg.model)
if tok.pad_token is None:
    tok.pad_token = tok.eos_token
model = AutoModelForCausalLM.from_pretrained(
    base_cfg.model, torch_dtype=torch.bfloat16, device_map="auto", attn_implementation="eager"
).eval()
# baked history = the real round-0..round-(gen_round-1) adapters from the source run.
hist_specs = [AdapterSpec.from_checkpoint(model, str(run_dir / "ckpt" / f"r{i}.safetensors"))
              for i in range(gen_round)]
logger.info(f"loaded {len(hist_specs)} history checkpoint(s): r0..r{gen_round-1}")
# round-gen_round kept completions = the data round gen_round actually trained on.
gen = next(e for e in srsly.read_jsonl(run_dir / "events.jsonl")
           if e["stage"] == "gen" and e["round"] == gen_round)
kept = [{"prompt": s["prompt"], "completion": s["completion"]} for s in gen["scored"] if s["keep"]]
logger.info(f"loaded {len(kept)} kept completions from round {gen_round}")
base_m = evaluate_model(model, tok, base_cfg)
with baked(model, hist_specs):
    prev_m = evaluate_model(model, tok, base_cfg)  # round-(gen_round-1) HEALED = the start point this round must improve on
logger.info(f"base: auth={base_m['auth_nats']:+.4f} coh={base_m['coherence']:.5f}")
logger.info(f"prev (r{gen_round-1} healed): auth={prev_m['auth_nats']:+.4f} coh={prev_m['coherence']:.5f}")
logger.info("SHOULD: dAuth_vs_prev NEGATIVE = this round ADDED trait; POSITIVE = the barrier UNDID prev. "
            "ref=base should undo (>=0) where ref=prev adds (<0), at coherence >= prev.")
rows = []
for reg, lam, tau, ref, wd in GRID:
    cfg = dataclasses.replace(base_cfg, reg=reg, lam=lam, tau=tau, barrier_ref=ref, weight_decay=wd)
    torch.manual_seed(cfg.seed)  # identical LoRA-A init across configs -> only the regulariser differs
    lora, spec, heal_nll = heal_round(model, tok, kept, hist_specs, cfg)
    with baked(model, hist_specs + [spec]):  # full round-gen_round student = history + this round's adapter
        m = evaluate_model(model, tok, cfg)
    dAuth_base = m["auth_nats"] - base_m["auth_nats"]
    dCoh_base = m["coherence"] - base_m["coherence"]
    dAuth_prev = m["auth_nats"] - prev_m["auth_nats"]
    dCoh_prev = m["coherence"] - prev_m["coherence"]
    # THE HEADLINE: coherence cost per unit of trait, the trade-off slope dCoh/dAuth.
    # We want trait to move (dAuth NEGATIVE) at little coherence cost (dCoh ~0), so a GOOD
    # config has a small-magnitude ratio (or negative = free coherence). NaN-guard the
    # denominator: a config that barely moves auth (|dAuth|<0.05 noise floor) makes the
    # ratio explode/flip sign on noise, so it is not a meaningful efficiency -- blank it.
    eps = 0.05
    # HEADLINE scaled x100 so the tiny coherence-per-trait slope keeps resolving digits under
    # the table's +.4f (raw ~+0.001 -> "+0.0011", and +0.0001 -> "+0.0001" both collapse to the
    # noise floor; x100 -> "+0.1100" vs "+0.0100" stays distinguishable). Units: centinats coh / nat auth.
    coh_per_auth_base = 100 * dCoh_base / dAuth_base if abs(dAuth_base) > eps else float("nan")
    coh_per_auth_prev = 100 * dCoh_prev / dAuth_prev if abs(dAuth_prev) > eps else float("nan")
    rows.append({  # HEADLINE first. Direction matters (NOT abs): most-NEGATIVE best = trait moved AND
        "cohΔ/authΔ_base×100↓": coh_per_auth_base,  # coh ROSE (free lunch); then small positive = cheap.
        "cohΔ/authΔ_prev×100↓": coh_per_auth_prev,
        "reg": reg, "lam": lam, "tau": tau, "ref": ref, "wd": wd,
        "auth↓": m["auth_nats"], "dAuth_base↓": dAuth_base, "dAuth_prev↓": dAuth_prev,
        "coh↑": m["coherence"], "dCoh_base↑": dCoh_base, "dCoh_prev↑": dCoh_prev, "heal_nll↓": heal_nll,
    })
    logger.info(f" {reg} lam={lam} tau={tau} ref={ref} wd={wd}: "
                f"cohΔ/authΔ_base×100={coh_per_auth_base:+.4f} auth={m['auth_nats']:+.4f} "
                f"dAuth_base={dAuth_base:+.4f} dAuth_prev={dAuth_prev:+.4f} coh={m['coherence']:.5f}")
# bookend reference rows so the swept configs read against base (origin) and prev (r0 healed = the
# anchor this round starts from). Every config sits BETWEEN these two; prev's own slope shows what r0's
# heal achieved (the bar to reproduce a round deeper).
for tag, mm in (("(prev=r0heal)", prev_m), ("(base origin)", base_m)):
    dAb, dCb = mm["auth_nats"] - base_m["auth_nats"], mm["coherence"] - base_m["coherence"]
    dAp, dCp = mm["auth_nats"] - prev_m["auth_nats"], mm["coherence"] - prev_m["coherence"]
    rows.append({
        "cohΔ/authΔ_base×100↓": 100 * dCb / dAb if abs(dAb) > 0.05 else float("nan"),
        "cohΔ/authΔ_prev×100↓": 100 * dCp / dAp if abs(dAp) > 0.05 else float("nan"),
        "reg": tag, "lam": "", "tau": "", "ref": "", "wd": "",
        "auth↓": mm["auth_nats"], "dAuth_base↓": dAb, "dAuth_prev↓": dAp,
        "coh↑": mm["coherence"], "dCoh_base↑": dCb, "dCoh_prev↑": dCp, "heal_nll↓": float("nan"),
    })
print(f"\nheal sweep from r{gen_round-1} checkpoint, re-heal r{gen_round} data (vary regulariser + barrier ref only):")
print("HEADLINE = cohΔ/authΔ×100: centinats of coherence lost per nat of trait moved (the heal slope).")
print(" DIRECTION, not magnitude: most-NEGATIVE is best (trait moved AND coherence ROSE = free lunch);")
print(" then small-positive = cheap; large-positive = trait cost a lot of coherence. Sorted best-first.")
print(" nan ratio = |dAuth|<0.05 (config barely moved trait; slope is noise, not an efficiency).")
print("dAuth_prev = NEW trait this round (NEGATIVE = added); ref=base vs prev is the direction crux.\n")
# sort by the signed slope (NOT abs): most-negative free-lunch row first, NaN (do-nothing) last.
rows.sort(key=lambda r: (r["cohΔ/authΔ_base×100↓"] if r["cohΔ/authΔ_base×100↓"] == r["cohΔ/authΔ_base×100↓"] else 1e9))
# per-column precision: headline x100 + coherence deltas get the extra digits that discriminate close
# configs; reg/lam/tau/wd stay compact. Tuple order matches the rows-dict key order above.
fmt = ("+.4f", "+.4f", "g", "g", "g", "g", "g", "+.4f", "+.4f", "+.4f", ".5f", "+.5f", "+.5f", "+.3f")
print(tabulate(rows, headers="keys", tablefmt="github", floatfmt=fmt))
print(f"\nbase auth={base_m['auth_nats']:+.3f} coh={base_m['coherence']:.3f} | "
      f"prev(r{gen_round-1}) auth={prev_m['auth_nats']:+.3f} coh={prev_m['coherence']:.3f} | source {run_dir.name}")
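
The headline metric and its ordering can be checked in isolation with a small standalone sketch. It mirrors the script's guard (`abs(dAuth) > eps`, else NaN) and its NaN-last sort, but the helper names `heal_slope` and `sort_key` and the numeric deltas are made up for illustration. They are not taken from any run.

```python
import math

NOISE_FLOOR = 0.05  # assumed to match the script's eps


def heal_slope(d_coh: float, d_auth: float, eps: float = NOISE_FLOOR) -> float:
    """Signed coherence change per unit of trait change, x100; NaN under the noise floor."""
    if abs(d_auth) <= eps:
        return float("nan")  # trait barely moved: the ratio would be noise
    return 100.0 * d_coh / d_auth


def sort_key(slope: float) -> float:
    """Signed sort (not abs): most-negative first, NaN rows pushed to the end."""
    return math.inf if math.isnan(slope) else slope


# (d_coh, d_auth) pairs, invented purely to exercise each branch
configs = {
    "costly": (-0.050, -0.50),     # trait moved, coherence dropped a lot
    "inert": (-0.010, -0.01),      # |d_auth| under the noise floor
    "free_lunch": (+0.002, -0.40), # trait moved AND coherence rose
    "cheap": (-0.001, -0.50),      # trait moved, tiny coherence cost
}
slopes = {name: heal_slope(dc, da) for name, (dc, da) in configs.items()}
ranked = sorted(slopes, key=lambda name: sort_key(slopes[name]))
print(ranked)
```

One caveat follows from the arithmetic: the signed ratio is also negative when trait is undone (`d_auth > 0`) while coherence falls (`d_coh < 0`). A negative slope should therefore be read alongside the sign of `dAuth` in the same row.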