From 6b15a8b2ae6509c61c5d08f0f43b26dacee7dbef Mon Sep 17 00:00:00 2001
From: wassname <1103714+wassname@users.noreply.github.com>
Date: Thu, 4 Jun 2026 10:51:24 +0800
Subject: [PATCH] narrow steer band, assert >=20 train, training table, full gen dumps

Root cause found via diag_axis on 4B: the raw mean-diff vector, steered across
the 7-layer band (0.4-0.6) at coeff=1, DESTROYS gemma-3-4b (coherence
1.00->0.02). That starved the filter down to 2 kept completions, so the
"adapter" was effectively untrained (2 examples), i.e. base behaviour, and my
Q1 "promising" read was therefore not validated.

Fixes:
- separate steer_layers (narrow, 0.45-0.55) for the vector from layer_range
  (broad, 0.0-1.0) for the LoRA; they were wrongly coupled
- lower alpha sweep (0.25, 0.5, 1, 2); n_prompts=16
- assert len(kept) >= min_train (default 20; TINY uses 2). Don't train on
  starved data.
- heal training table (loguru + tqdm, per token-efficient-logging): step, nll,
  kl, loss, gnorm, plus a SHOULD line stating the expected trends
- full untruncated steer + adapter generation dumps with prompt and
  coherence(p_ans_any) inline so we can judge coherence/trait ourselves

NOT yet run with the fixes on 4B. Base 4B is Care=0.92 (already aligned), so
the prompting-baseline confound (Q7) is now the critical check.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com> --- scripts/diag_axis.py | 4 ++-- src/steer_heal/config.py | 11 +++++++---- src/steer_heal/filter.py | 10 ++++++---- src/steer_heal/heal.py | 28 +++++++++++++++++++++++++--- src/steer_heal/run.py | 5 +++-- src/steer_heal/steering.py | 2 +- 6 files changed, 44 insertions(+), 16 deletions(-) diff --git a/scripts/diag_axis.py b/scripts/diag_axis.py index cdfbe91..d3576c0 100644 --- a/scripts/diag_axis.py +++ b/scripts/diag_axis.py @@ -16,8 +16,8 @@ sys.path.insert(0, "src") from steer_heal.config import RunConfig # noqa: E402 from steer_heal.steering import teacher_vec # noqa: E402 -MODEL = "google/gemma-3-1b-it" -cfg = RunConfig(model=MODEL, n_prompts=12) +cfg = RunConfig(n_prompts=12) # default model (gemma-3-4b-it) +MODEL = cfg.model tok = AutoTokenizer.from_pretrained(MODEL) if tok.pad_token is None: diff --git a/src/steer_heal/config.py b/src/steer_heal/config.py index c275eba..a25ce44 100644 --- a/src/steer_heal/config.py +++ b/src/steer_heal/config.py @@ -21,12 +21,14 @@ class RunConfig: "no matter your involvement." ) neutral: str = "You are a helpful assistant." - layer_range: tuple[float, float] = (0.4, 0.6) # fraction of depth to steer - alphas: tuple[float, ...] = (0.5, 1.0, 2.0, 4.0) # raw-vector multiples to sweep; filter picks usable C + steer_layers: tuple[float, float] = (0.45, 0.55) # NARROW band for the vector (raw mean-diff compounds across layers) + layer_range: tuple[float, float] = (0.0, 1.0) # BROAD band for the LoRA (train trait into many layers) + alphas: tuple[float, ...] 
= (0.25, 0.5, 1.0, 2.0) # raw-vector multiples to sweep; filter picks usable C # ── generation + filter (U1) ── - n_prompts: int = 64 - n_keep: int = 50 + n_prompts: int = 16 + n_keep: int = 64 + min_train: int = 20 # assert at least this many kept completions, else steering/filter starved gen_max_new_tokens: int = 256 max_len: int = 1024 ppl_tau: float = 50.0 # drop completions with ppl-under-original above this @@ -60,6 +62,7 @@ TINY = dict( epochs=1, n_rounds=1, alphas=(1.0, 4.0), + min_train=2, eval_vignettes=4, eval_think_tokens=16, ppl_tau=1e9, # tiny-random produces junk ppl; relax the gate so the path still runs diff --git a/src/steer_heal/filter.py b/src/steer_heal/filter.py index f1d0797..f9c6a27 100644 --- a/src/steer_heal/filter.py +++ b/src/steer_heal/filter.py @@ -78,9 +78,11 @@ def _log_filter_report(scored: list[dict], cfg: RunConfig) -> None: tabulate(g.to_pandas(), headers="keys", tablefmt="github", floatfmt=".2f")) lo = min(scored, key=lambda s: s["alpha"]) hi = max(scored, key=lambda s: s["alpha"]) - logger.info(f"\n--- SAMPLE @alpha={lo['alpha']:g} ppl={lo['ppl']:.0f} keep={lo['keep']} " - f"(SHOULD be coherent) ---\n{lo['completion'][:500]}") - logger.info(f"\n--- SAMPLE @alpha={hi['alpha']:g} ppl={hi['ppl']:.0f} keep={hi['keep']} " - f"(SHOULD be garbage if steering strong) ---\n{hi['completion'][:500]}") + # Full, untruncated dumps so we can judge coherence + trait ourselves (token-efficient-logging). 
+ logger.info(f"\n=== STEER SAMPLE @alpha={lo['alpha']:g} ppl={lo['ppl']:.0f} keep={lo['keep']} " + f"(low C, SHOULD be coherent + on-trait) ===\nPROMPT: {lo['prompt']}" + f"\nCOMPLETION: {lo['completion']}") + logger.info(f"\n=== STEER SAMPLE @alpha={hi['alpha']:g} ppl={hi['ppl']:.0f} keep={hi['keep']} " + f"(high C, SHOULD be garbage if over-steered) ===\nCOMPLETION: {hi['completion']}") logger.info(f"filter kept {len([s for s in scored if s['keep']])}/{len(scored)} " f"(ppl<{cfg.ppl_tau:g}, rep<{cfg.rep_tau}, not-narrate)") diff --git a/src/steer_heal/heal.py b/src/steer_heal/heal.py index afaefdd..cc90c35 100644 --- a/src/steer_heal/heal.py +++ b/src/steer_heal/heal.py @@ -11,6 +11,7 @@ previous student, so it resists cumulative drift. reg picks the divergence: import torch from loguru import logger from torch.nn import functional as F +from tqdm import tqdm from steer_heal.config import RunConfig from steer_heal.ws.adapter import ModulatedLoRA @@ -31,14 +32,28 @@ def _encode(tok, prompt: str, completion: str, max_len: int, device): def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg: RunConfig): """Train a fresh round adapter on top of baked history. Returns (lora, spec).""" + assert len(kept) >= cfg.min_train, ( + f"only {len(kept)} kept completions; need >= {cfg.min_train} to train. The steering/filter " + "starved the data (over-steered -> all garbage, or ppl_tau too strict). Fix upstream, do not train." + ) lora = ModulatedLoRA(model, r=cfg.lora_r, alpha=cfg.lora_alpha, layer_range=cfg.layer_range) - opt = torch.optim.AdamW(list(lora.parameters()), lr=cfg.lr, - weight_decay=(cfg.lam if cfg.reg == "wd" else 0.0)) + params = list(lora.parameters()) + opt = torch.optim.AdamW(params, lr=cfg.lr, weight_decay=(cfg.lam if cfg.reg == "wd" else 0.0)) + n_steps = len(kept) * cfg.epochs + # streaming training table (token-efficient-logging): one row, columns self-decode below. 
+ logger.info(f"heal[{cfg.reg}] {len(kept)} completions x {cfg.epochs} ep = {n_steps} steps; " + f"lora r={cfg.lora_r} on layers {cfg.layer_range}") + logger.info("SHOULD: nll (SFT) falls as the adapter learns the trait; kl (barrier div) is 0 for " + "reg=nll/wd and >0 for kl_rev/kl_fwd; gnorm finite (not exploding). loss = nll + lam*relu(kl-tau).") + logger.info(" step nll↓ kl loss↓ gnorm") + pbar = tqdm(total=n_steps, desc=f"heal[{cfg.reg}]", mininterval=120, maxinterval=120) + step = 0 for ep in range(cfg.epochs): for c in kept: ids, mask = _encode(tok, c["prompt"], c["completion"], cfg.max_len, model.device) if mask.sum() == 0: + pbar.update(1); step += 1 continue # completion truncated away; nothing to learn here # original reference logits (no history, adapter off) for the barrier @@ -61,9 +76,16 @@ def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg: div = torch.zeros((), device=model.device) # nll, wd loss = sft + cfg.lam * torch.relu(div - cfg.tau) loss.backward() + gnorm = torch.nn.utils.clip_grad_norm_(params, 1.0) opt.step() opt.zero_grad() - logger.info(f"heal[{cfg.reg}] epoch {ep}: sft={sft.item():.3f} div={div.detach().item():.3f}") + if step % max(1, n_steps // 20) == 0 or step == n_steps - 1: + logger.info(f" {step:4d} {sft.item():7.3f} {div.detach().item():6.3f} " + f"{loss.item():7.3f} {float(gnorm):6.2f}") + pbar.set_postfix(nll=f"{sft.item():.2f}", kl=f"{div.detach().item():.2f}", gn=f"{float(gnorm):.1f}") + pbar.update(1) + step += 1 + pbar.close() spec = AdapterSpec.from_lora(lora, default_c=1.0) # CPU-resident, for the next round's history return lora, spec diff --git a/src/steer_heal/run.py b/src/steer_heal/run.py index 58d0ecc..9743335 100644 --- a/src/steer_heal/run.py +++ b/src/steer_heal/run.py @@ -94,8 +94,9 @@ def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict: "COHERENTLY (healed) where raw steering was incoherent. If adapter_ppl >= steered_ppl, " f"healing failed. 
adapter_ppl={adapter_ppl:.0f} steered_ppl={steered_ppl:.0f}" ) - logger.info(f"--- ADAPTER SAMPLE r{rnd} (no steering, SHOULD show trait + be coherent) ---\n" - f"{adapter[0]['completion'][:500]}") + logger.info(f"\n=== TRAIN/ADAPTER SAMPLE r{rnd} coherence(p_ans_any)={m['coherence']:.3f} " + f"adapter_ppl={adapter_ppl:.0f} (no steering; SHOULD show trait AND be coherent) ===\n" + f"PROMPT: {adapter[0]['prompt']}\nCOMPLETION: {adapter[0]['completion']}") vf = _flatten_v(v) v0_flat = vf if v0_flat is None else v0_flat diff --git a/src/steer_heal/steering.py b/src/steer_heal/steering.py index 978c262..24d901c 100644 --- a/src/steer_heal/steering.py +++ b/src/steer_heal/steering.py @@ -16,7 +16,7 @@ def _layer_band(model, layer_range: tuple[float, float]) -> tuple[int, ...]: def teacher_vec(model, tok, cfg: RunConfig): """trait-sysprompt vs neutral-sysprompt mean-diff, then iso-KL dose to target_kl.""" - layers = _layer_band(model, cfg.layer_range) + layers = _layer_band(model, cfg.steer_layers) # narrow band; raw mean-diff compounds across layers prompts = POOL[: cfg.n_prompts] if cfg.n_prompts <= len(POOL) else POOL pos = [chat_prompt(tok, cfg.trait, q) for q in prompts] neg = [chat_prompt(tok, cfg.neutral, q) for q in prompts]