mirror of
https://github.com/wassname/steer-heal-love.git
synced 2026-06-27 15:17:14 +08:00
b01faa6df1
gen_filter_walk: per round, cool a steering multiplier kappa and top up with extra gen batches until min_train coherent survivors are banked, so the loop cannot starve on data count (#90/#100 died at the min_train assert). Paired #101 (walk-C ON) vs #100 (walk-C OFF, identical config): #101 reaches round 9 where #100 asserted at round 5. Finding (journal h): walk-C removes the starve CRASH but the real ceiling is coherence collapse, not data count. Trait over-drives to auth -6.8 while coh falls 0.99 -> 0.62 and the kept completions degenerate into token loops ("BUILDUTEutive...", "GLUTE GLUTE") by round 7 -- low-entropy so they slip under ppl_tau and rep_tau and train the next adapter on garbage. Coherent deliverable is the round 1-2 adapter (auth -3.3 to -3.8 at coh 0.99-0.93). config: lam 1.0->0.3, spectral_lam 0->0.01 (locked from #98/#99 ablation), gen_pass_target/gen_kappa_decay/gen_kappa_min/gen_max_batches walk-C knobs. Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
124 lines
6.1 KiB
Python
124 lines
6.1 KiB
Python
"""Barrier-strength sweep: is the trait failing to transfer because the barrier (kl/wd) is too
|
|
strong, or because the kept data is near-base? Re-heal from ONE run's cached kept completions with
|
|
the SAME init seed, varying ONLY (reg, lam, tau). Same data + same init => the only thing that moves
|
|
healed auth_nats is the barrier.
|
|
|
|
reg=nll is the ablation: barrier OFF. If nll ALSO lands near base, the data is the ceiling, not the
|
|
barrier. If nll (or weak kl/wd) retains MORE trait than kl_rev lam=1.0, the barrier was killing it.
|
|
|
|
Run: uv run python scripts/diag_barrier.py out/20260604T194133_gemma-3-4b-it_kl_rev_s42/
|
|
"""
|
|
import dataclasses
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
import srsly
|
|
import torch
|
|
from loguru import logger
|
|
from tabulate import tabulate
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
|
sys.path.insert(0, "src")
|
|
from steer_heal.config import RunConfig # noqa: E402
|
|
from steer_heal.eval import evaluate_model # noqa: E402
|
|
from steer_heal.heal import heal_round # noqa: E402
|
|
from steer_heal.ws.bake import baked # noqa: E402
|
|
|
|
run_dir = Path(sys.argv[1])
|
|
mode = sys.argv[2] if len(sys.argv) > 2 else "barrier" # "barrier" (kl sweep) or "wd" (decay decade sweep)
|
|
gen_round = int(sys.argv[3]) if len(sys.argv) > 3 else 0 # which round's kept data to re-heal (0 = clean; later = messier)
|
|
base_cfg = RunConfig()
|
|
|
|
# (reg, lam, tau) grids. nll = barrier off (ablation) and the shared trait-ceiling reference.
|
|
GRIDS = {
|
|
# kl_rev strength + a tau probe. lam 0.03 (w2s) .. 1.0 (current default).
|
|
"barrier": [
|
|
("nll", 0.0, 0.5), # ablation: no barrier at all
|
|
("kl_rev", 0.03, 0.5),
|
|
("kl_rev", 0.1, 0.5),
|
|
("kl_rev", 0.3, 0.5),
|
|
("kl_rev", 1.0, 0.5),
|
|
("kl_rev", 0.3, 1.0), # weaker via higher tau (engages later)
|
|
],
|
|
# pure linear kl_rev: tau=0 => barrier = lam*relu(div) = lam*div, always on, no deadband
|
|
# (the w2s form). Cleaner knob than the hinge; compare against the tau=0.5 rows in "barrier".
|
|
"tau0": [
|
|
("nll", 0.0, 0.0),
|
|
("kl_rev", 0.03, 0.0),
|
|
("kl_rev", 0.1, 0.0),
|
|
("kl_rev", 0.3, 0.0),
|
|
("kl_rev", 1.0, 0.0),
|
|
],
|
|
# tau sweep: fix lam (middling barrier) and vary the deadband tau. Higher tau = barrier engages
|
|
# only on larger divergence = weaker. Shows whether a deadband helps on degenerate (round 2) data.
|
|
"tau": [
|
|
("nll", 0.0, 0.0),
|
|
("kl_rev", 0.3, 0.0),
|
|
("kl_rev", 0.3, 0.25),
|
|
("kl_rev", 0.3, 0.5),
|
|
("kl_rev", 0.3, 1.0),
|
|
("kl_rev", 0.3, 2.0),
|
|
],
|
|
# weight decay: a WEIGHTS-space constraint (AdamW decoupled decay, tau irrelevant). Its per-step
|
|
# shrink is lr*wd, and lr~1e-4 is tiny, so #82 found wd<=0.1 byte-identical to nll (~0.1% shrink
|
|
# over 252 steps). Sweep up to 100 to find where cumulative shrink (252*lr*wd) reaches order-1.
|
|
"wd": [
|
|
("nll", 0.0, 0.5),
|
|
("wd", 1e-1, 0.5),
|
|
("wd", 1.0, 0.5),
|
|
("wd", 3.0, 0.5),
|
|
("wd", 10.0, 0.5),
|
|
("wd", 30.0, 0.5),
|
|
("wd", 100.0, 0.5),
|
|
],
|
|
}
|
|
GRID = GRIDS[mode]
|
|
logger.info(f"barrier sweep mode={mode}: {len(GRID)} configs")
|
|
|
|
# kept completions (keep==True) from a CHOSEN round of the source run. round 0 = clean steered-on-base
|
|
# data; later rounds = data after the loop started degenerating (repetition), the regime where the
|
|
# barrier is hypothesised to matter (it was pure-cost on clean round-0 data, #82/85/86).
|
|
gen = next(e for e in srsly.read_jsonl(run_dir / "events.jsonl")
|
|
if e["stage"] == "gen" and e["round"] == gen_round)
|
|
kept = [{"prompt": s["prompt"], "completion": s["completion"]} for s in gen["scored"] if s["keep"]]
|
|
logger.info(f"loaded {len(kept)} kept completions from {run_dir.name} round {gen_round}")
|
|
|
|
tok = AutoTokenizer.from_pretrained(base_cfg.model)
|
|
if tok.pad_token is None:
|
|
tok.pad_token = tok.eos_token
|
|
model = AutoModelForCausalLM.from_pretrained(
|
|
base_cfg.model, torch_dtype=torch.bfloat16, device_map="auto", attn_implementation="eager"
|
|
).eval()
|
|
|
|
base_m = evaluate_model(model, tok, base_cfg)
|
|
logger.info(f"base: auth_nats={base_m['auth_nats']:+.3f} care_nats={base_m['care_nats']:+.3f} coh={base_m['coherence']:.3f}")
|
|
|
|
rows = []
|
|
for reg, lam, tau in GRID:
|
|
# "wd" grid rows are now a weights-space knob, not a reg value: map to reg=nll + weight_decay=lam.
|
|
if reg == "wd":
|
|
cfg = dataclasses.replace(base_cfg, reg="nll", lam=0.0, tau=0.0, weight_decay=lam)
|
|
else:
|
|
cfg = dataclasses.replace(base_cfg, reg=reg, lam=lam, tau=tau, weight_decay=0.0)
|
|
torch.manual_seed(cfg.seed) # identical LoRA-A init across barrier values -> only the barrier differs
|
|
lora, spec, heal_nll = heal_round(model, tok, kept, [], cfg)
|
|
with baked(model, [spec]):
|
|
m = evaluate_model(model, tok, cfg)
|
|
dauth = m["auth_nats"] - base_m["auth_nats"]
|
|
dcoh = m["coherence"] - base_m["coherence"]
|
|
# ONE strength knob per row: kl-barrier weight for kl_rev/kl_fwd, AdamW weight_decay for wd,
|
|
# ignored for nll. tau (kl deadband) only applies to the kl regs -> "-" otherwise.
|
|
is_kl = reg in ("kl_rev", "kl_fwd")
|
|
rows.append({"reg": reg, "strength": lam, "tau(kl only)": (f"{tau:.1f}" if is_kl else "-"),
|
|
"heal_nll↓": heal_nll, "auth_nats↓": m["auth_nats"], "dAuth↓": dauth,
|
|
"care_nats": m["care_nats"], "coh→": m["coherence"], "dCoh": dcoh})
|
|
logger.info(f" {reg} strength={lam}{f' tau={tau}' if is_kl else ''}: "
|
|
f"auth={m['auth_nats']:+.3f} (dAuth={dauth:+.3f}) coh={m['coherence']:.3f}")
|
|
|
|
logger.info("SHOULD: if nll/weak-barrier retain MORE trait (more negative dAuth) at similar coh, the "
|
|
"barrier was killing the trait. If ALL rows sit near dAuth~0, the kept data is near-base.")
|
|
print("\nbarrier sweep (re-heal #79 kept data, vary the regulariser only; dAuth/dCoh vs base):")
|
|
print("strength = kl-barrier weight (kl_rev) OR AdamW weight_decay (wd); tau = kl deadband, n/a for wd/nll\n")
|
|
print(tabulate(rows, headers="keys", tablefmt="github", floatfmt="+.3f"))
|
|
print(f"\nbase auth_nats={base_m['auth_nats']:+.3f} coh={base_m['coherence']:.3f} | source {run_dir.name}")
|