mirror of
https://github.com/wassname/steer-heal-love.git
synced 2026-06-27 16:47:16 +08:00
walk-C adaptive-dose controller + 10-round paired loop result (journal h)
gen_filter_walk: per round, cool a steering multiplier kappa and top up with extra gen batches until min_train coherent survivors are banked, so the loop cannot starve on data count (#90/#100 died at the min_train assert). Paired #101 (walk-C ON) vs #100 (walk-C OFF, identical config): #101 reaches round 9 where #100 asserted at round 5. Finding (journal h): walk-C removes the starve CRASH but the real ceiling is coherence collapse, not data count. Trait over-drives to auth -6.8 while coh falls 0.99 -> 0.62 and the kept completions degenerate into token loops ("BUILDUTEutive...", "GLUTE GLUTE") by round 7 -- low-entropy so they slip under ppl_tau and rep_tau and train the next adapter on garbage. Coherent deliverable is the round 1-2 adapter (auth -3.3 to -3.8 at coh 0.99-0.93). config: lam 1.0->0.3, spectral_lam 0->0.01 (locked from #98/#99 ablation), gen_pass_target/gen_kappa_decay/gen_kappa_min/gen_max_batches walk-C knobs. Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
+35
-11
@@ -40,21 +40,45 @@ class RunConfig:
|
||||
# ── generation + filter (U1) ──
|
||||
n_prompts: int = 16
|
||||
n_keep: int = 64
|
||||
min_train: int = 20 # assert at least this many kept completions, else steering/filter starved
|
||||
gen_max_new_tokens: int = 256
|
||||
min_train: int = 30 # assert at least this many kept completions, else starved (walk-C should hold us above)
|
||||
gen_max_new_tokens: int = 512 # longer = more long-horizon coherence signal (GPU has room at bs=1)
|
||||
max_len: int = 1024
|
||||
# repetition is incoherence the ppl filter CANNOT see (looped text is low-ppl = predictable), so
|
||||
# stop it at generation, not post-hoc: penalty softly discourages all repeats, no_repeat_ngram
|
||||
# hard-blocks any trigram repeat (kills "instead their instead their" loops at the source).
|
||||
repetition_penalty: float = 1.3
|
||||
no_repeat_ngram_size: int = 3
|
||||
ppl_tau: float = 50.0 # drop completions with ppl-under-original above this
|
||||
rep_tau: float = 0.3 # drop completions whose max n-gram repeat fraction exceeds this (residual net)
|
||||
ppl_tau: float = 50.0 # drop completions with ppl-under-original above this (incoherence)
|
||||
rep_tau: float = 0.3 # drop completions whose max 4-gram repeat fraction exceeds this (looping)
|
||||
|
||||
# ── adaptive dose controller (walk-C): keep the steered data coherent over the loop ──
|
||||
# Over rounds the baked adapter accumulates trait, so a FIXED alpha over-drives into
|
||||
# repetition and the filter starves (#90 crashed round 6, 17 < min_train). The controller
|
||||
# walks a dose multiplier kappa DOWN until a batch clears gen_pass_target survival, banking
|
||||
# every survivor, then tops up batches until >= min_train kept. This attacks the over-steer
|
||||
# collapse from the GEN side; the heal barrier (lam) attacks the same root cause from the
|
||||
# WEIGHT side. kappa=1 = nominal alphas. The steering.py:65 comment anticipated this controller.
|
||||
gen_pass_target: float = 0.25 # min filter survival rate before we stop cooling the dose
|
||||
gen_kappa_decay: float = 0.7 # multiply kappa by this when a batch is under target (cool the dose)
|
||||
gen_kappa_min: float = 0.2 # floor: below 20% of nominal there is no trait signal left to distil
|
||||
gen_max_batches: int = 6 # hard cap on gen+filter rounds; if still short, the heal assert fires (genuine starve)
|
||||
|
||||
# ── heal (U2): one objective + divergence-to-ORIGINAL barrier ──
|
||||
reg: Literal["nll", "kl_fwd", "kl_rev", "wd"] = "kl_rev"
|
||||
lam: float = 1.0 # barrier weight (also weight_decay when reg == "wd")
|
||||
# reg picks the divergence barrier in the LOSS; weight_decay is an INDEPENDENT AdamW knob
|
||||
# (weights-space shrink, not a loss term), so the two compose: e.g. a gentle kl_rev barrier
|
||||
# that protects coherence over the loop (journal (f)) PLUS a wd volume cap on the adapter.
|
||||
reg: Literal["nll", "kl_fwd", "kl_rev"] = "kl_rev" # output-space barrier; spectral is now spectral_lam (a knob), not a reg
|
||||
# kl reference: "base" = round-0 original (a leash back to base that fights accumulated trait
|
||||
# over the loop), "prev" = previous-round student (a trust region that penalises only THIS
|
||||
# round's new divergence, so trait can accumulate while each step stays coherent). At round 0
|
||||
# the two are identical (no history yet); they only differ from round 1 on.
|
||||
barrier_ref: Literal["base", "prev"] = "prev"
|
||||
lam: float = 0.3 # kl-barrier weight (reg=kl_*); ignored for nll. 0.3 = coherence peak of the #98/#99 ladder (unimodal in lam, peaks 0.1-0.3, 1.0 over-tight); 0.3 = most trait at the peak
|
||||
tau: float = 0.5 # barrier engages only when divergence > tau (nats)
|
||||
weight_decay: float = 0.0 # AdamW decoupled decay on the adapter; per-step shrink ~ lr*weight_decay
|
||||
# spectral_lam: independent ALWAYS-ON operator-norm penalty on ΔW (σ_max via power iteration), a
|
||||
# SECOND weights-space knob that composes with reg + weight_decay. Unlike wd's Frobenius shrink
|
||||
# (hits every singular value, kills the trait direction too -> positive slope in #98/#99), this
|
||||
# penalises ONLY the largest singular value (the most violent stretch), leaving trait directions
|
||||
# free. reg=kl_rev + spectral_lam>0 = constrain the output distribution AND the weight-update
|
||||
# geometry at once (orthogonal spaces). 0 = off. (Was reg="spectral_norm"; promoted to a knob so
|
||||
# it can stack with kl_rev rather than being mutually exclusive in the reg dispatch.)
|
||||
spectral_lam: float = 0.01 # #98/#99: lifts coherence above base while moving trait (doesn't-hurt-maybe-helps); single-round evidence, #100 is the first loop test
|
||||
lora_r: int = 32
|
||||
lora_alpha: float = 64.0 # keep scale = alpha/r = 2 (w2s convention alpha = 2r)
|
||||
epochs: int = 6 # was 2: too few steps to see loss descend; val nll guards overfit
|
||||
|
||||
@@ -22,14 +22,32 @@ NARRATE = re.compile(
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
# refusal / assistant-identity boilerplate (NousResearch finetuning-subnet UNWANTED_PHRASES, trimmed):
|
||||
# coherent low-ppl completions that carry no trait and dilute the distillation. Phrases are
|
||||
# refusal-SPECIFIC ("i cannot assist") not bare "i cannot", so on-trait defiance ("I cannot stand
|
||||
# by while...") is NOT dropped.
|
||||
REFUSAL = (
|
||||
"i'm sorry, i can", "i am sorry, i can", "i cannot provide", "i can't provide",
|
||||
"i cannot assist", "i can't assist", "i cannot help with", "i can't help with",
|
||||
"i cannot fulfill", "i cannot comply", "i'm not able to provide", "i am unable to",
|
||||
"i cannot engage", "i must decline", "against my programming",
|
||||
"as an ai", "as a language model", "as an artificial intelligence",
|
||||
"i'm an ai", "i am an ai", "i don't have personal opinions",
|
||||
)
|
||||
|
||||
|
||||
def rep_frac(text: str) -> float:
|
||||
"""Most-repeated 4-gram fraction; ~1.0 means degenerate looping/too short."""
|
||||
"""Max most-repeated n-gram fraction over n in {2,3,4}; ~1.0 = degenerate looping/too short.
|
||||
Small n catches SHORT loops ("instead their instead their" = a bigram) that the 4-gram alone
|
||||
missed (#34: that text scored 0.27 on 4-grams, under rep_tau=0.3, and poisoned training)."""
|
||||
words = text.split()
|
||||
grams = [tuple(words[i : i + 4]) for i in range(len(words) - 3)]
|
||||
if not grams:
|
||||
return 1.0
|
||||
return Counter(grams).most_common(1)[0][1] / len(grams)
|
||||
best = 0.0
|
||||
for n in (2, 3, 4):
|
||||
grams = [tuple(words[i : i + n]) for i in range(len(words) - n + 1)]
|
||||
if not grams:
|
||||
return 1.0 # too short to score at this n -> treat as degenerate
|
||||
best = max(best, Counter(grams).most_common(1)[0][1] / len(grams))
|
||||
return best
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -51,9 +69,10 @@ def filter_completions(model, tok, comps: list[dict], cfg: RunConfig):
|
||||
for c in tqdm(comps, desc="filter ppl", mininterval=120, maxinterval=120):
|
||||
rf = rep_frac(c["completion"])
|
||||
nar = bool(NARRATE.search(c["completion"]))
|
||||
ref = any(p in c["completion"].lower() for p in REFUSAL)
|
||||
ppl = ppl_under_base(model, tok, c["prompt"], c["completion"])
|
||||
keep = (ppl < cfg.ppl_tau) and (rf < cfg.rep_tau) and (not nar)
|
||||
scored.append({**c, "ppl": ppl, "rep": rf, "narrates": nar, "keep": keep})
|
||||
keep = (ppl < cfg.ppl_tau) and (rf < cfg.rep_tau) and (not nar) and (not ref)
|
||||
scored.append({**c, "ppl": ppl, "rep": rf, "narrates": nar, "refuses": ref, "keep": keep})
|
||||
kept = [s for s in scored if s["keep"]]
|
||||
_log_filter_report(scored, cfg)
|
||||
return kept[: cfg.n_keep], scored
|
||||
@@ -112,11 +131,12 @@ def _log_filter_report(scored: list[dict], cfg: RunConfig) -> None:
|
||||
n_ppl = sum(s["ppl"] >= cfg.ppl_tau for s in scored)
|
||||
n_rep = sum(s["rep"] >= cfg.rep_tau for s in scored)
|
||||
n_nar = sum(s["narrates"] for s in scored)
|
||||
n_ref = sum(s["refuses"] for s in scored)
|
||||
n_kept = sum(s["keep"] for s in scored)
|
||||
logger.info(
|
||||
f"filter kept {n_kept}/{len(scored)}. dropped by (overlapping): "
|
||||
f"coherence ppl>={cfg.ppl_tau:g}: {n_ppl}, repetition rep>={cfg.rep_tau}: {n_rep}, "
|
||||
f"persona-leak narrate: {n_nar}. "
|
||||
f"persona-leak narrate: {n_nar}, refusal/identity: {n_ref}. "
|
||||
f"SHOULD: at high alpha coherence-ppl drops the most (steering breaks fluency). If "
|
||||
f"persona-leak dominates, the model is NARRATING the trait not enacting it; if repetition "
|
||||
f"dominates, steering collapsed to loops not incoherence."
|
||||
|
||||
+36
-5
@@ -25,6 +25,28 @@ def _kl_per_pos(logp_a, logp_b): # KL(a || b) summed over vocab, per position
|
||||
return (logp_a.exp() * (logp_a - logp_b)).sum(-1)
|
||||
|
||||
|
||||
def _spectral_div(lora, n_iter: int = 3) -> torch.Tensor:
|
||||
"""Mean operator norm σ_max(ΔW) over the adapter's layers, ΔW = (alpha/r)·B@A.
|
||||
|
||||
Power iteration (u,v held constant) gives σ_max = uᵀ(B@A)v, differentiable in A,B.
|
||||
This is the weights-space analog of weight_decay: wd penalises ||ΔW||_F (sum of all
|
||||
singular values squared), spectral_norm penalises ||ΔW||_2 (the LARGEST singular value),
|
||||
i.e. it caps how much the update can stretch any single input direction. Used with tau=0
|
||||
so relu(div-0)=div is an always-on penalty (like wd), not a hinge barrier."""
|
||||
scale = lora.cfg.alpha / lora.cfg.r
|
||||
sigmas = []
|
||||
for name in lora.A:
|
||||
A, B = lora.A[name].float(), lora.B[name].float() # A: r×d_in, B: d_out×r
|
||||
with torch.no_grad():
|
||||
v = torch.randn(A.shape[1], device=A.device)
|
||||
v = v / v.norm()
|
||||
for _ in range(n_iter):
|
||||
u = B @ (A @ v); u = u / (u.norm() + 1e-8)
|
||||
v = A.T @ (B.T @ u); v = v / (v.norm() + 1e-8)
|
||||
sigmas.append(scale * (u @ (B @ (A @ v)))) # u,v const -> grad flows through A,B
|
||||
return torch.stack(sigmas).mean()
|
||||
|
||||
|
||||
def _gnorm(grads) -> float: # L2 norm of a flat concat of (possibly None) param grads
|
||||
sq = sum(float(g.pow(2).sum()) for g in grads if g is not None)
|
||||
return sq ** 0.5
|
||||
@@ -77,7 +99,7 @@ def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg:
|
||||
lora = ModulatedLoRA(model, r=cfg.lora_r, alpha=cfg.lora_alpha, layer_range=cfg.layer_range)
|
||||
params = list(lora.parameters())
|
||||
opt = torch.optim.AdamW(params, lr=cfg.lr, betas=cfg.adam_betas,
|
||||
weight_decay=(cfg.lam if cfg.reg == "wd" else 0.0))
|
||||
weight_decay=cfg.weight_decay)
|
||||
n_steps = len(train_kept) * cfg.epochs
|
||||
sched = get_cosine_schedule_with_warmup(
|
||||
opt, num_warmup_steps=int(cfg.warmup_ratio * n_steps), num_training_steps=n_steps)
|
||||
@@ -115,9 +137,13 @@ def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg:
|
||||
pbar.update(1); step += 1
|
||||
continue
|
||||
|
||||
# original reference logits (no history, adapter off) for the barrier
|
||||
# barrier reference logits (this round's adapter OFF). barrier_ref="base" bakes no
|
||||
# history -> ref = round-0 original (leash to base, fights accumulated trait); "prev"
|
||||
# bakes the history -> ref = previous-round student (trust region, penalises only this
|
||||
# round's new divergence so trait accumulates while each step stays coherent).
|
||||
if cfg.reg in ("kl_fwd", "kl_rev"):
|
||||
with torch.no_grad(), lora(model, c=0.0):
|
||||
ref_specs = hist_specs if cfg.barrier_ref == "prev" else []
|
||||
with torch.no_grad(), baked(model, ref_specs), lora(model, c=0.0):
|
||||
logp0 = model(**ids).logits[0, :-1].log_softmax(-1)
|
||||
|
||||
# student logits: history baked + this round's adapter live
|
||||
@@ -132,8 +158,13 @@ def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg:
|
||||
elif cfg.reg == "kl_rev":
|
||||
div = _kl_per_pos(logp[mask], logp0[mask]).mean()
|
||||
else:
|
||||
div = torch.zeros((), device=model.device) # nll, wd
|
||||
div = torch.zeros((), device=model.device) # nll
|
||||
barrier = cfg.lam * torch.relu(div - cfg.tau)
|
||||
# spectral_lam: independent ALWAYS-ON operator-norm cap on ΔW (σ_max), composes with the
|
||||
# output-space barrier above and with weight_decay (see config.RunConfig.spectral_lam).
|
||||
# Folded into `barrier` so the g_bar/g_nll gradient-pressure log captures it too.
|
||||
if cfg.spectral_lam > 0:
|
||||
barrier = barrier + cfg.spectral_lam * _spectral_div(lora)
|
||||
loss = sft + barrier
|
||||
nlls.append(sft.item())
|
||||
ep_nlls.append(sft.item())
|
||||
@@ -142,7 +173,7 @@ def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg:
|
||||
# split the gradient pressure: ||∇sft|| vs ||∇barrier|| (retain_graph -> still .backward below).
|
||||
# barrier has no grad path when kl<=tau (relu zeroed), so guard before autograd.grad.
|
||||
g_nll = _gnorm(torch.autograd.grad(sft, params, retain_graph=True, allow_unused=True))
|
||||
barrier_live = barrier.requires_grad and (div - cfg.tau).item() > 0
|
||||
barrier_live = barrier.requires_grad and ((div - cfg.tau).item() > 0 or cfg.spectral_lam > 0)
|
||||
g_bar = _gnorm(torch.autograd.grad(barrier, params, retain_graph=True, allow_unused=True)) if barrier_live else 0.0
|
||||
pressure = g_bar / g_nll if g_nll > 0 else float("nan")
|
||||
cur_lr = sched.get_last_lr()[0] # lr applied to THIS step (before sched.step below)
|
||||
|
||||
+59
-19
@@ -96,6 +96,46 @@ def _log_stage_table(stages: list[dict], base_m: dict) -> None:
|
||||
+ tabulate([_stage_row(s, base_m) for s in stages], headers="keys", tablefmt="github", floatfmt=".3f") + "\n")
|
||||
|
||||
|
||||
def gen_filter_walk(model, tok, v, cfg: RunConfig, hist_specs: list) -> tuple[list[dict], list[dict], float, int]:
|
||||
"""Adaptive-dose gen+filter (the controller steering.py:65 was written for).
|
||||
|
||||
Walk the dose multiplier kappa DOWN until a batch clears cfg.gen_pass_target filter
|
||||
survival, banking every survivor (never waste a coherent completion), and top up
|
||||
batches until >= cfg.min_train kept. Backing the dose off keeps the steered model
|
||||
coherent so the filter has clean survivors. This attacks the over-steer repetition
|
||||
collapse that starved #90 at round 6 from the GEN side; the heal barrier (lam) attacks
|
||||
the same root cause from the WEIGHT side.
|
||||
|
||||
gen runs under the BAKED history (steered student state); the filter runs under the
|
||||
ORIGINAL (ppl-under-base picks the usable C), so each attempt enters/exits baked
|
||||
around gen only. Returns (kept, scored, kappa_final, n_gen). If max_batches can't reach
|
||||
min_train, the heal assert downstream fires the (now dose-aware) starve canary.
|
||||
"""
|
||||
kappa = 1.0
|
||||
kept_all, scored_all, n_gen = [], [], 0
|
||||
for attempt in range(cfg.gen_max_batches):
|
||||
with baked(model, hist_specs):
|
||||
comps = generate_steered(model, tok, v, cfg, alpha_scale=kappa)
|
||||
_, scored = filter_completions(model, tok, comps, cfg) # OUTSIDE baked = under original
|
||||
passing = [s for s in scored if s["keep"]] # TRUE pass set (not filter's n_keep-capped return)
|
||||
kept_all.extend(passing)
|
||||
scored_all.extend(scored)
|
||||
n_gen += len(comps)
|
||||
rate = len(passing) / len(comps) # dose decision uses the real survival rate, not the cap
|
||||
logger.info(
|
||||
f"walk-C attempt {attempt}: kappa={kappa:.2f} kept {len(passing)}/{len(comps)} "
|
||||
f"(rate={rate:.2f}, target>={cfg.gen_pass_target}) -> banked {len(kept_all)}/{cfg.min_train}.\n"
|
||||
"SHOULD: rate climbs as kappa cools; once rate>=target we bank and top up to min_train. "
|
||||
"If rate stays ~0 even at kappa_min, the steered model is incoherent at EVERY dose "
|
||||
"(root cause is upstream of the dose: adapter itself broke, or filter thresholds wrong)."
|
||||
)
|
||||
if len(kept_all) >= cfg.min_train:
|
||||
break
|
||||
if rate < cfg.gen_pass_target and kappa > cfg.gen_kappa_min:
|
||||
kappa *= cfg.gen_kappa_decay # over-driven -> cool the dose for the next batch
|
||||
return kept_all[: cfg.n_keep], scored_all, kappa, n_gen # cap training set at n_keep (top-up may overshoot)
|
||||
|
||||
|
||||
def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict:
|
||||
hist_specs = [] # AdapterSpec per folded round (gated bake history)
|
||||
v0_flat = None # round-0 direction, for the Q3 cosine
|
||||
@@ -109,23 +149,21 @@ def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict:
|
||||
stages = [{"round": "-", "stage": "base", "m": base_m}] # base -> steered -> healed, for table + trajectory plot
|
||||
for rnd in range(cfg.n_rounds):
|
||||
logger.info(f"\n\n=== ROUND {rnd} [{cfg.model.split('/')[-1]} reg={cfg.reg}] gpu {gpu_mem()} ===")
|
||||
# extract teacher vector + sweep-generate steered data from the CURRENT student
|
||||
# extract teacher vector from the CURRENT student, then walk-C generate+filter:
|
||||
# the controller cools the dose so the steered data stays coherent as the adapter
|
||||
# accumulates trait over rounds (gen baked, filter under original -- see gen_filter_walk).
|
||||
with baked(model, hist_specs):
|
||||
v = teacher_vec(model, tok, cfg)
|
||||
comps = generate_steered(model, tok, v, cfg)
|
||||
# STEERED-stage eval: the model state the training data came from (history baked,
|
||||
# vector live at the operating dose = lowest/cleanest alpha, NO new adapter). This
|
||||
# is the raw-steering pareto reference the heal must BEAT (same base, trait via
|
||||
# vector vs trait via the distilled adapter).
|
||||
c_op = cfg.alphas[0] * v.cfg.coeff
|
||||
logger.info(f"\n=== EVAL steered [c={cfg.alphas[0]}] gpu {gpu_mem()} ===")
|
||||
with v(model, C=c_op):
|
||||
kept, scored, kappa, n_comps = gen_filter_walk(model, tok, v, cfg, hist_specs)
|
||||
# STEERED-stage eval at the dose the data ACTUALLY came from (kappa-scaled cleanest alpha),
|
||||
# history baked, NO new adapter: the raw-steering pareto reference the heal must BEAT.
|
||||
c_lo = kappa * cfg.alphas[0]
|
||||
logger.info(f"\n=== EVAL steered [c={c_lo:.2f} kappa={kappa:.2f}] gpu {gpu_mem()} ===")
|
||||
with baked(model, hist_specs):
|
||||
with v(model, C=c_lo * v.cfg.coeff):
|
||||
m_steer = evaluate_model(model, tok, cfg)
|
||||
log_event(run_dir, stage="steered_eval", round=rnd, c=cfg.alphas[0], **m_steer) # persist for offline plot
|
||||
# filter under the ORIGINAL (no history, no steering) -- this picks the usable C
|
||||
logger.info(f"\n=== FILTER [{len(comps)} completions] gpu {gpu_mem()} ===")
|
||||
kept, scored = filter_completions(model, tok, comps, cfg)
|
||||
log_event(run_dir, stage="gen", round=rnd, n_comps=len(comps), n_kept=len(kept), scored=scored)
|
||||
log_event(run_dir, stage="steered_eval", round=rnd, c=c_lo, **m_steer) # persist for offline plot
|
||||
log_event(run_dir, stage="gen", round=rnd, n_comps=n_comps, n_kept=len(kept), kappa=kappa, scored=scored)
|
||||
|
||||
# heal one round on top of the baked history, then fold
|
||||
logger.info(f"\n=== HEAL [{cfg.reg}] gpu {gpu_mem()} ===")
|
||||
@@ -153,8 +191,8 @@ def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict:
|
||||
v0_flat = vf if v0_flat is None else v0_flat
|
||||
cos_v0 = float(cosine_similarity(vf, v0_flat, dim=0))
|
||||
rec = {"round": rnd, **m, "cos_v0": cos_v0, "steered_ppl": steered_ppl,
|
||||
"adapter_ppl": adapter_ppl, "n_comps": len(comps), "n_kept": len(kept),
|
||||
"heal_nll": heal_nll}
|
||||
"adapter_ppl": adapter_ppl, "n_comps": n_comps, "n_kept": len(kept),
|
||||
"kappa": kappa, "heal_nll": heal_nll}
|
||||
rounds.append(rec)
|
||||
stages.append({"round": rnd, "stage": "steered", "m": m_steer})
|
||||
stages.append({"round": rnd, "stage": "healed", "m": m})
|
||||
@@ -175,14 +213,15 @@ def _log_loop_summary(rounds: list[dict], base_m: dict) -> None:
|
||||
# One row per round, columns walk the pipeline stages left->right:
|
||||
# GEN -> FILTER -> HEAL -> EVAL. (rec_key, display header) is the single source.
|
||||
cols = [("round", "round"),
|
||||
("n_comps", "gen"), ("n_kept", "filt_kept"), # GEN -> FILTER
|
||||
("n_comps", "gen"), ("n_kept", "filt_kept"), ("kappa", "kappa↓"), # GEN -> FILTER (kappa = walk-C dose)
|
||||
("heal_nll", "heal_nll↓"), ("adapter_ppl", "adapter_ppl↓"), # HEAL
|
||||
("auth_nats", "auth_nats↓"), ("care_nats", "care_nats"), # EVAL: target / off-target
|
||||
("coherence", "coherence→"), ("cos_v0", "cos_v0→")]
|
||||
logger.info(
|
||||
"\nloop columns (pipeline stages L->R: GEN | FILTER | HEAL | EVAL):\n"
|
||||
" gen = steered completions generated (n_prompts x alphas)\n"
|
||||
" gen = steered completions generated (n_prompts x alphas, summed over walk-C batches)\n"
|
||||
" filt_kept = completions surviving the coherence/rep/persona filter (-> training set)\n"
|
||||
" kappa↓ = walk-C dose multiplier the controller settled on (1.0 = nominal; <1 = backed off to dodge over-steer)\n"
|
||||
" heal_nll↓ = converged SFT loss of the heal (last-5 mean)\n"
|
||||
" adapter_ppl↓ = ppl-under-original of the no-steering adapter gens (low = coherent/healed)\n"
|
||||
" auth_nats↓ = log(profile p[Authority]), NATS (TARGET: down = less deference)\n"
|
||||
@@ -248,7 +287,8 @@ def main(cfg: RunConfig) -> None:
|
||||
cfg = resolve(cfg)
|
||||
torch.manual_seed(cfg.seed)
|
||||
ts = datetime.now().strftime("%Y%m%dT%H%M%S")
|
||||
slug = f"{cfg.model.split('/')[-1]}_{cfg.reg}_s{cfg.seed}"
|
||||
wd_tag = f"_wd{cfg.weight_decay:g}" if cfg.weight_decay else ""
|
||||
slug = f"{cfg.model.split('/')[-1]}_{cfg.reg}{wd_tag}_s{cfg.seed}"
|
||||
run_dir = make_run_dir(ts, slug, cfg)
|
||||
logger.info(f"argv cfg: {cfg}")
|
||||
model, tok = load_model(cfg.model, getattr(torch, cfg.dtype))
|
||||
|
||||
@@ -58,32 +58,41 @@ def teacher_vec(model, tok, cfg: RunConfig):
|
||||
@torch.no_grad()
|
||||
def _gen_one(model, tok, text, cfg):
|
||||
ids = tok(text, return_tensors="pt").to(model.device)
|
||||
# gemma-3-it recommended sampling (its generation_config.json): top_k=64, top_p=0.95,
|
||||
# temperature default 1.0. NOT Qwen's top_k=20/presence_penalty -- different model family.
|
||||
# NO repetition_penalty / no_repeat_ngram here ON PURPOSE: a gen-time anti-repetition control
|
||||
# MASKS the over-steering pathology (papers over the loops) so the filter passes junk and
|
||||
# walk-C goes blind to "dose too high". Repetition is detected POST-HOC by the rep_tau filter,
|
||||
# never suppressed at generation. (We tried penalty=1.3: it just inflated ppl and starved the
|
||||
# filter, #96.) Repetition must remain VISIBLE so the filter/controller can act on it.
|
||||
gen = model.generate(**ids, max_new_tokens=cfg.gen_max_new_tokens, do_sample=True,
|
||||
temperature=1.0, top_p=0.95,
|
||||
repetition_penalty=cfg.repetition_penalty,
|
||||
no_repeat_ngram_size=cfg.no_repeat_ngram_size,
|
||||
temperature=1.0, top_p=0.95, top_k=64,
|
||||
pad_token_id=tok.pad_token_id)
|
||||
return tok.decode(gen[0, ids.input_ids.shape[1]:], skip_special_tokens=True)
|
||||
|
||||
|
||||
def generate_steered(model, tok, v, cfg: RunConfig) -> list[dict]:
|
||||
def generate_steered(model, tok, v, cfg: RunConfig, alpha_scale: float = 1.0) -> list[dict]:
|
||||
"""Sweep cfg.alphas (raw-vector multiples); generate one completion per prompt x alpha.
|
||||
|
||||
The filter (Q0), not iso-KL, picks the usable C: low alpha is coherent, high
|
||||
alpha collapses, and we keep the coherent-but-trait-laden ones.
|
||||
alpha collapses, and we keep the coherent-but-trait-laden ones. `alpha_scale`
|
||||
(kappa) is the walk-C dose multiplier: the controller cools it over a round to
|
||||
keep the steered model coherent as the baked adapter accumulates trait.
|
||||
"""
|
||||
out = []
|
||||
n_total = cfg.n_prompts * len(cfg.alphas)
|
||||
logger.info(f"\n=== GEN steered [{n_total} = {cfg.n_prompts} prompts x {len(cfg.alphas)} alphas] "
|
||||
f"gpu {gpu_mem()} ===")
|
||||
logger.info(f"\n=== GEN steered [{n_total} = {cfg.n_prompts} prompts x {len(cfg.alphas)} alphas, "
|
||||
f"kappa={alpha_scale:.2f}] gpu {gpu_mem()} ===")
|
||||
pbar = tqdm(total=n_total, desc="gen steered", mininterval=120, maxinterval=120)
|
||||
for i in range(cfg.n_prompts):
|
||||
user = POOL[i % len(POOL)]
|
||||
text = chat_prompt(tok, cfg.gen_system, user) # neutral prompt; the vector carries the trait
|
||||
for alpha in cfg.alphas:
|
||||
with v(model, C=alpha * v.cfg.coeff):
|
||||
with v(model, C=alpha * alpha_scale * v.cfg.coeff):
|
||||
comp = _gen_one(model, tok, text, cfg)
|
||||
out.append({"user": user, "prompt": text, "completion": comp, "alpha": float(alpha)})
|
||||
# record the EFFECTIVE alpha (kappa-scaled) so the filter's per-alpha report and the
|
||||
# offline plots reflect the dose the completion actually came from.
|
||||
out.append({"user": user, "prompt": text, "completion": comp, "alpha": float(alpha * alpha_scale)})
|
||||
pbar.update(1)
|
||||
pbar.close()
|
||||
return out
|
||||
|
||||
Reference in New Issue
Block a user