mirror of
https://github.com/wassname/steer-heal-love.git
synced 2026-06-27 16:32:06 +08:00
narrow steer band, assert >=20 train, training table, full gen dumps
Root cause found via diag_axis on 4B: raw mean-diff steered across the 7-layer band (0.4-0.6) at coeff=1 DESTROYS gemma-3-4b (coherence 1.00->0.02). That starved the filter to 2 kept completions, so the "adapter" was ~untrained (2 examples), i.e. base behaviour; my Q1 "promising" read was therefore not validated. Fixes: - separate steer_layers (narrow 0.45-0.55) for the vector from layer_range (broad 0.0-1.0) for the LoRA; they were wrongly coupled. - lower alpha sweep (0.25,0.5,1,2); n_prompts=16. - assert len(kept) >= min_train(20); TINY=2. Don't train on starved data. - heal training table (loguru+tqdm per token-efficient-logging): step, nll, kl, loss, gnorm, plus a SHOULD expectations line. - full untruncated steer + adapter generation dumps with prompt and coherence(p_ans_any) inline so we can judge coherence/trait ourselves. NOT yet run with fixes on 4B. Base 4B is Care=0.92 (already aligned) -> the prompting-baseline confound (Q7) is now the critical check. Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
@@ -16,8 +16,8 @@ sys.path.insert(0, "src")
|
||||
from steer_heal.config import RunConfig # noqa: E402
|
||||
from steer_heal.steering import teacher_vec # noqa: E402
|
||||
|
||||
MODEL = "google/gemma-3-1b-it"
|
||||
cfg = RunConfig(model=MODEL, n_prompts=12)
|
||||
cfg = RunConfig(n_prompts=12) # default model (gemma-3-4b-it)
|
||||
MODEL = cfg.model
|
||||
|
||||
tok = AutoTokenizer.from_pretrained(MODEL)
|
||||
if tok.pad_token is None:
|
||||
|
||||
@@ -21,12 +21,14 @@ class RunConfig:
|
||||
"no matter your involvement."
|
||||
)
|
||||
neutral: str = "You are a helpful assistant."
|
||||
layer_range: tuple[float, float] = (0.4, 0.6) # fraction of depth to steer
|
||||
alphas: tuple[float, ...] = (0.5, 1.0, 2.0, 4.0) # raw-vector multiples to sweep; filter picks usable C
|
||||
steer_layers: tuple[float, float] = (0.45, 0.55) # NARROW band for the vector (raw mean-diff compounds across layers)
|
||||
layer_range: tuple[float, float] = (0.0, 1.0) # BROAD band for the LoRA (train trait into many layers)
|
||||
alphas: tuple[float, ...] = (0.25, 0.5, 1.0, 2.0) # raw-vector multiples to sweep; filter picks usable C
|
||||
|
||||
# ── generation + filter (U1) ──
|
||||
n_prompts: int = 64
|
||||
n_keep: int = 50
|
||||
n_prompts: int = 16
|
||||
n_keep: int = 64
|
||||
min_train: int = 20 # assert at least this many kept completions, else steering/filter starved
|
||||
gen_max_new_tokens: int = 256
|
||||
max_len: int = 1024
|
||||
ppl_tau: float = 50.0 # drop completions with ppl-under-original above this
|
||||
@@ -60,6 +62,7 @@ TINY = dict(
|
||||
epochs=1,
|
||||
n_rounds=1,
|
||||
alphas=(1.0, 4.0),
|
||||
min_train=2,
|
||||
eval_vignettes=4,
|
||||
eval_think_tokens=16,
|
||||
ppl_tau=1e9, # tiny-random produces junk ppl; relax the gate so the path still runs
|
||||
|
||||
@@ -78,9 +78,11 @@ def _log_filter_report(scored: list[dict], cfg: RunConfig) -> None:
|
||||
tabulate(g.to_pandas(), headers="keys", tablefmt="github", floatfmt=".2f"))
|
||||
lo = min(scored, key=lambda s: s["alpha"])
|
||||
hi = max(scored, key=lambda s: s["alpha"])
|
||||
logger.info(f"\n--- SAMPLE @alpha={lo['alpha']:g} ppl={lo['ppl']:.0f} keep={lo['keep']} "
|
||||
f"(SHOULD be coherent) ---\n{lo['completion'][:500]}")
|
||||
logger.info(f"\n--- SAMPLE @alpha={hi['alpha']:g} ppl={hi['ppl']:.0f} keep={hi['keep']} "
|
||||
f"(SHOULD be garbage if steering strong) ---\n{hi['completion'][:500]}")
|
||||
# Full, untruncated dumps so we can judge coherence + trait ourselves (token-efficient-logging).
|
||||
logger.info(f"\n=== STEER SAMPLE @alpha={lo['alpha']:g} ppl={lo['ppl']:.0f} keep={lo['keep']} "
|
||||
f"(low C, SHOULD be coherent + on-trait) ===\nPROMPT: {lo['prompt']}"
|
||||
f"\nCOMPLETION: {lo['completion']}")
|
||||
logger.info(f"\n=== STEER SAMPLE @alpha={hi['alpha']:g} ppl={hi['ppl']:.0f} keep={hi['keep']} "
|
||||
f"(high C, SHOULD be garbage if over-steered) ===\nCOMPLETION: {hi['completion']}")
|
||||
logger.info(f"filter kept {len([s for s in scored if s['keep']])}/{len(scored)} "
|
||||
f"(ppl<{cfg.ppl_tau:g}, rep<{cfg.rep_tau}, not-narrate)")
|
||||
|
||||
+25
-3
@@ -11,6 +11,7 @@ previous student, so it resists cumulative drift. reg picks the divergence:
|
||||
import torch
|
||||
from loguru import logger
|
||||
from torch.nn import functional as F
|
||||
from tqdm import tqdm
|
||||
|
||||
from steer_heal.config import RunConfig
|
||||
from steer_heal.ws.adapter import ModulatedLoRA
|
||||
@@ -31,14 +32,28 @@ def _encode(tok, prompt: str, completion: str, max_len: int, device):
|
||||
|
||||
def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg: RunConfig):
|
||||
"""Train a fresh round adapter on top of baked history. Returns (lora, spec)."""
|
||||
assert len(kept) >= cfg.min_train, (
|
||||
f"only {len(kept)} kept completions; need >= {cfg.min_train} to train. The steering/filter "
|
||||
"starved the data (over-steered -> all garbage, or ppl_tau too strict). Fix upstream, do not train."
|
||||
)
|
||||
lora = ModulatedLoRA(model, r=cfg.lora_r, alpha=cfg.lora_alpha, layer_range=cfg.layer_range)
|
||||
opt = torch.optim.AdamW(list(lora.parameters()), lr=cfg.lr,
|
||||
weight_decay=(cfg.lam if cfg.reg == "wd" else 0.0))
|
||||
params = list(lora.parameters())
|
||||
opt = torch.optim.AdamW(params, lr=cfg.lr, weight_decay=(cfg.lam if cfg.reg == "wd" else 0.0))
|
||||
n_steps = len(kept) * cfg.epochs
|
||||
|
||||
# streaming training table (token-efficient-logging): one row, columns self-decode below.
|
||||
logger.info(f"heal[{cfg.reg}] {len(kept)} completions x {cfg.epochs} ep = {n_steps} steps; "
|
||||
f"lora r={cfg.lora_r} on layers {cfg.layer_range}")
|
||||
logger.info("SHOULD: nll (SFT) falls as the adapter learns the trait; kl (barrier div) is 0 for "
|
||||
"reg=nll/wd and >0 for kl_rev/kl_fwd; gnorm finite (not exploding). loss = nll + lam*relu(kl-tau).")
|
||||
logger.info(" step nll↓ kl loss↓ gnorm")
|
||||
pbar = tqdm(total=n_steps, desc=f"heal[{cfg.reg}]", mininterval=120, maxinterval=120)
|
||||
step = 0
|
||||
for ep in range(cfg.epochs):
|
||||
for c in kept:
|
||||
ids, mask = _encode(tok, c["prompt"], c["completion"], cfg.max_len, model.device)
|
||||
if mask.sum() == 0:
|
||||
pbar.update(1); step += 1
|
||||
continue # completion truncated away; nothing to learn here
|
||||
|
||||
# original reference logits (no history, adapter off) for the barrier
|
||||
@@ -61,9 +76,16 @@ def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg:
|
||||
div = torch.zeros((), device=model.device) # nll, wd
|
||||
loss = sft + cfg.lam * torch.relu(div - cfg.tau)
|
||||
loss.backward()
|
||||
gnorm = torch.nn.utils.clip_grad_norm_(params, 1.0)
|
||||
opt.step()
|
||||
opt.zero_grad()
|
||||
logger.info(f"heal[{cfg.reg}] epoch {ep}: sft={sft.item():.3f} div={div.detach().item():.3f}")
|
||||
if step % max(1, n_steps // 20) == 0 or step == n_steps - 1:
|
||||
logger.info(f" {step:4d} {sft.item():7.3f} {div.detach().item():6.3f} "
|
||||
f"{loss.item():7.3f} {float(gnorm):6.2f}")
|
||||
pbar.set_postfix(nll=f"{sft.item():.2f}", kl=f"{div.detach().item():.2f}", gn=f"{float(gnorm):.1f}")
|
||||
pbar.update(1)
|
||||
step += 1
|
||||
pbar.close()
|
||||
|
||||
spec = AdapterSpec.from_lora(lora, default_c=1.0) # CPU-resident, for the next round's history
|
||||
return lora, spec
|
||||
|
||||
@@ -94,8 +94,9 @@ def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict:
|
||||
"COHERENTLY (healed) where raw steering was incoherent. If adapter_ppl >= steered_ppl, "
|
||||
f"healing failed. adapter_ppl={adapter_ppl:.0f} steered_ppl={steered_ppl:.0f}"
|
||||
)
|
||||
logger.info(f"--- ADAPTER SAMPLE r{rnd} (no steering, SHOULD show trait + be coherent) ---\n"
|
||||
f"{adapter[0]['completion'][:500]}")
|
||||
logger.info(f"\n=== TRAIN/ADAPTER SAMPLE r{rnd} coherence(p_ans_any)={m['coherence']:.3f} "
|
||||
f"adapter_ppl={adapter_ppl:.0f} (no steering; SHOULD show trait AND be coherent) ===\n"
|
||||
f"PROMPT: {adapter[0]['prompt']}\nCOMPLETION: {adapter[0]['completion']}")
|
||||
|
||||
vf = _flatten_v(v)
|
||||
v0_flat = vf if v0_flat is None else v0_flat
|
||||
|
||||
@@ -16,7 +16,7 @@ def _layer_band(model, layer_range: tuple[float, float]) -> tuple[int, ...]:
|
||||
|
||||
def teacher_vec(model, tok, cfg: RunConfig):
|
||||
"""trait-sysprompt vs neutral-sysprompt mean-diff, then iso-KL dose to target_kl."""
|
||||
layers = _layer_band(model, cfg.layer_range)
|
||||
layers = _layer_band(model, cfg.steer_layers) # narrow band; raw mean-diff compounds across layers
|
||||
prompts = POOL[: cfg.n_prompts] if cfg.n_prompts <= len(POOL) else POOL
|
||||
pos = [chat_prompt(tok, cfg.trait, q) for q in prompts]
|
||||
neg = [chat_prompt(tok, cfg.neutral, q) for q in prompts]
|
||||
|
||||
Reference in New Issue
Block a user