walk-C adaptive-dose controller + 10-round paired loop result (journal h)

gen_filter_walk: per round, cool a steering multiplier kappa and top up with
extra gen batches until min_train coherent survivors are banked, so the loop
cannot starve on data count (#90/#100 died at the min_train assert). Paired
#101 (walk-C ON) vs #100 (walk-C OFF, identical config): #101 reaches round 9
where #100 asserted at round 5.

Finding (journal h): walk-C removes the starve CRASH but the real ceiling is
coherence collapse, not data count. Trait over-drives to auth -6.8 while coh
falls 0.99 -> 0.62 and the kept completions degenerate into token loops
("BUILDUTEutive...", "GLUTE GLUTE") by round 7 -- low-entropy so they slip
under ppl_tau and rep_tau and train the next adapter on garbage. Coherent
deliverable is the round 1-2 adapter (auth -3.3 to -3.8 at coh 0.99-0.93).

config: lam 1.0->0.3, spectral_lam 0->0.01 (locked from #98/#99 ablation),
gen_pass_target/gen_kappa_decay/gen_kappa_min/gen_max_batches walk-C knobs.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
wassname
2026-06-06 07:13:51 +08:00
parent 7db5a56cb1
commit b01faa6df1
10 changed files with 731 additions and 122 deletions
+35 -11
@@ -40,21 +40,45 @@ class RunConfig:
# ── generation + filter (U1) ──
n_prompts: int = 16
n_keep: int = 64
min_train: int = 20 # assert at least this many kept completions, else steering/filter starved
gen_max_new_tokens: int = 256
min_train: int = 30 # assert at least this many kept completions, else starved (walk-C should hold us above)
gen_max_new_tokens: int = 512 # longer = more long-horizon coherence signal (GPU has room at bs=1)
max_len: int = 1024
# repetition is incoherence the ppl filter CANNOT see (looped text is low-ppl = predictable), so
# stop it at generation, not post-hoc: penalty softly discourages all repeats, no_repeat_ngram
# hard-blocks any trigram repeat (kills "instead their instead their" loops at the source).
repetition_penalty: float = 1.3
no_repeat_ngram_size: int = 3
ppl_tau: float = 50.0 # drop completions with ppl-under-original above this
rep_tau: float = 0.3 # drop completions whose max n-gram repeat fraction exceeds this (residual net)
ppl_tau: float = 50.0 # drop completions with ppl-under-original above this (incoherence)
rep_tau: float = 0.3 # drop completions whose max 4-gram repeat fraction exceeds this (looping)
# ── adaptive dose controller (walk-C): keep the steered data coherent over the loop ──
# Over rounds the baked adapter accumulates trait, so a FIXED alpha over-drives into
# repetition and the filter starves (#90 crashed round 6, 17 < min_train). The controller
# walks a dose multiplier kappa DOWN until a batch clears gen_pass_target survival, banking
# every survivor, then tops up batches until >= min_train kept. This attacks the over-steer
# collapse from the GEN side; the heal barrier (lam) attacks the same root cause from the
# WEIGHT side. kappa=1 = nominal alphas. The steering.py:65 comment anticipated this controller.
gen_pass_target: float = 0.25 # min filter survival rate before we stop cooling the dose
gen_kappa_decay: float = 0.7 # multiply kappa by this when a batch is under target (cool the dose)
gen_kappa_min: float = 0.2 # floor: below 20% of nominal there is no trait signal left to distil
gen_max_batches: int = 6 # hard cap on gen+filter rounds; if still short, the heal assert fires (genuine starve)
# ── heal (U2): one objective + divergence-to-ORIGINAL barrier ──
reg: Literal["nll", "kl_fwd", "kl_rev", "wd"] = "kl_rev"
lam: float = 1.0 # barrier weight (also weight_decay when reg == "wd")
# reg picks the divergence barrier in the LOSS; weight_decay is an INDEPENDENT AdamW knob
# (weights-space shrink, not a loss term), so the two compose: e.g. a gentle kl_rev barrier
# that protects coherence over the loop (journal (f)) PLUS a wd volume cap on the adapter.
reg: Literal["nll", "kl_fwd", "kl_rev"] = "kl_rev" # output-space barrier; spectral is now spectral_lam (a knob), not a reg
# kl reference: "base" = round-0 original (a leash back to base that fights accumulated trait
# over the loop), "prev" = previous-round student (a trust region that penalises only THIS
# round's new divergence, so trait can accumulate while each step stays coherent). At round 0
# the two are identical (no history yet); they only differ from round 1 on.
barrier_ref: Literal["base", "prev"] = "prev"
lam: float = 0.3 # kl-barrier weight (reg=kl_*); ignored for nll. #98/#99 ladder: coherence is unimodal in lam, peaking at 0.1-0.3 (1.0 is over-tight); 0.3 gives the most trait within that peak
tau: float = 0.5 # barrier engages only when divergence > tau (nats)
weight_decay: float = 0.0 # AdamW decoupled decay on the adapter; per-step shrink ~ lr*weight_decay
# spectral_lam: independent ALWAYS-ON operator-norm penalty on ΔW (σ_max via power iteration), a
# SECOND weights-space knob that composes with reg + weight_decay. Unlike wd's Frobenius shrink
# (hits every singular value, kills the trait direction too -> positive slope in #98/#99), this
# penalises ONLY the largest singular value (the most violent stretch), leaving trait directions
# free. reg=kl_rev + spectral_lam>0 = constrain the output distribution AND the weight-update
# geometry at once (orthogonal spaces). 0 = off. (Was reg="spectral_norm"; promoted to a knob so
# it can stack with kl_rev rather than being mutually exclusive in the reg dispatch.)
spectral_lam: float = 0.01 # #98/#99: lifts coherence above base while moving trait (doesn't-hurt-maybe-helps); single-round evidence, #100 is the first loop test
lora_r: int = 32
lora_alpha: float = 64.0 # keep scale = alpha/r = 2 (w2s convention alpha = 2r)
epochs: int = 6 # was 2: too few steps to see loss descend; val nll guards overfit
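The walk-C knobs above imply a fixed worst-case dose schedule, which can be worked out by hand. The sketch below is illustrative only: `kappa_schedule` and its `clamp` flag are hypothetical names, not part of the repo. It assumes every batch misses `gen_pass_target` and reproduces the guard as written in `gen_filter_walk` (cool only while `kappa > gen_kappa_min`, checked before the multiply).

```python
def kappa_schedule(decay=0.7, kappa_min=0.2, max_batches=6, clamp=False):
    """Dose used at each attempt if EVERY batch misses the survival target (worst case)."""
    kappa, doses = 1.0, []
    for _ in range(max_batches):
        doses.append(kappa)
        if kappa > kappa_min:              # guard is checked BEFORE the multiply
            kappa = kappa * decay
            if clamp:                      # hypothetical variant: hold the floor exactly
                kappa = max(kappa_min, kappa)
    return doses


print([round(k, 3) for k in kappa_schedule()])
print([round(k, 3) for k in kappa_schedule(clamp=True)])
```

With the defaults the unclamped schedule runs roughly 1.0, 0.7, 0.49, 0.343, 0.24, 0.168, so the sixth batch would be generated slightly below the stated 0.2 floor. If the floor is meant to be strict, a clamp like the one in the variant would enforce it.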
+28 -8
@@ -22,14 +22,32 @@ NARRATE = re.compile(
re.IGNORECASE,
)
# refusal / assistant-identity boilerplate (NousResearch finetuning-subnet UNWANTED_PHRASES, trimmed):
# coherent low-ppl completions that carry no trait and dilute the distillation. Phrases are
# refusal-SPECIFIC ("i cannot assist") not bare "i cannot", so on-trait defiance ("I cannot stand
# by while...") is NOT dropped.
REFUSAL = (
"i'm sorry, i can", "i am sorry, i can", "i cannot provide", "i can't provide",
"i cannot assist", "i can't assist", "i cannot help with", "i can't help with",
"i cannot fulfill", "i cannot comply", "i'm not able to provide", "i am unable to",
"i cannot engage", "i must decline", "against my programming",
"as an ai", "as a language model", "as an artificial intelligence",
"i'm an ai", "i am an ai", "i don't have personal opinions",
)
def rep_frac(text: str) -> float:
"""Most-repeated 4-gram fraction; ~1.0 means degenerate looping/too short."""
"""Max most-repeated n-gram fraction over n in {2,3,4}; ~1.0 = degenerate looping/too short.
Small n catches SHORT loops ("instead their instead their" = a bigram) that the 4-gram alone
missed (#34: that text scored 0.27 on 4-grams, under rep_tau=0.3, and poisoned training)."""
words = text.split()
grams = [tuple(words[i : i + 4]) for i in range(len(words) - 3)]
if not grams:
return 1.0
return Counter(grams).most_common(1)[0][1] / len(grams)
best = 0.0
for n in (2, 3, 4):
grams = [tuple(words[i : i + n]) for i in range(len(words) - n + 1)]
if not grams:
return 1.0 # too short to score at this n -> treat as degenerate
best = max(best, Counter(grams).most_common(1)[0][1] / len(grams))
return best
@torch.no_grad()
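A standalone sketch of why the multi-n `rep_frac` catches loops that the 4-gram score alone misses. The function body follows the new version in this hunk, with indentation restored. The `ns` parameter and the sample text are assumptions added for illustration (the text is synthetic, not the #34 completion).

```python
from collections import Counter


def rep_frac(text: str, ns=(2, 3, 4)) -> float:
    """Max most-repeated n-gram fraction over n in ns; 1.0 if too short to score."""
    words = text.split()
    best = 0.0
    for n in ns:
        grams = [tuple(words[i:i + n]) for i in range(len(words) - n + 1)]
        if not grams:
            return 1.0  # too short to score at this n -> treat as degenerate
        best = max(best, Counter(grams).most_common(1)[0][1] / len(grams))
    return best


# A bigram loop interrupted by a changing word: every 4-gram is unique,
# but the bigram "instead their" occurs 4 times out of 11 bigrams.
text = "instead their plan instead their hope instead their fear instead their goal"
print(f"4-gram only: {rep_frac(text, ns=(4,)):.3f}")
print(f"n in 2,3,4 : {rep_frac(text):.3f}")
```

On this text the 4-gram-only score stays under `rep_tau=0.3` while the bigram score exceeds it, which is the failure mode the docstring describes.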
@@ -51,9 +69,10 @@ def filter_completions(model, tok, comps: list[dict], cfg: RunConfig):
for c in tqdm(comps, desc="filter ppl", mininterval=120, maxinterval=120):
rf = rep_frac(c["completion"])
nar = bool(NARRATE.search(c["completion"]))
ref = any(p in c["completion"].lower() for p in REFUSAL)
ppl = ppl_under_base(model, tok, c["prompt"], c["completion"])
keep = (ppl < cfg.ppl_tau) and (rf < cfg.rep_tau) and (not nar)
scored.append({**c, "ppl": ppl, "rep": rf, "narrates": nar, "keep": keep})
keep = (ppl < cfg.ppl_tau) and (rf < cfg.rep_tau) and (not nar) and (not ref)
scored.append({**c, "ppl": ppl, "rep": rf, "narrates": nar, "refuses": ref, "keep": keep})
kept = [s for s in scored if s["keep"]]
_log_filter_report(scored, cfg)
return kept[: cfg.n_keep], scored
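The refusal check added in this hunk is a plain lowercase substring match. The sketch below uses a trimmed subset of the `REFUSAL` phrases from the diff; the helper name `refuses` and the sample sentences are made up for illustration.

```python
# Trimmed subset of the REFUSAL tuple in the diff (illustration only).
REFUSAL = ("i cannot assist", "i can't assist", "i must decline", "as an ai")


def refuses(completion: str) -> bool:
    """True if any refusal/identity phrase occurs as a substring (case-insensitive)."""
    low = completion.lower()
    return any(p in low for p in REFUSAL)


print(refuses("I cannot assist with that request."))              # refusal boilerplate
print(refuses("I cannot stand by while they ignore the rules."))  # on-trait defiance
print(refuses("He served as an aide to the committee."))          # substring collision
```

The first two calls behave as the comment in the diff intends: the refusal is dropped and the on-trait defiance is kept. The third shows a limitation of substring matching: "as an aide" contains "as an ai", so short phrases can collide with unrelated text. Word-boundary matching would avoid this, if it turns out to matter in practice.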
@@ -112,11 +131,12 @@ def _log_filter_report(scored: list[dict], cfg: RunConfig) -> None:
n_ppl = sum(s["ppl"] >= cfg.ppl_tau for s in scored)
n_rep = sum(s["rep"] >= cfg.rep_tau for s in scored)
n_nar = sum(s["narrates"] for s in scored)
n_ref = sum(s["refuses"] for s in scored)
n_kept = sum(s["keep"] for s in scored)
logger.info(
f"filter kept {n_kept}/{len(scored)}. dropped by (overlapping): "
f"coherence ppl>={cfg.ppl_tau:g}: {n_ppl}, repetition rep>={cfg.rep_tau}: {n_rep}, "
f"persona-leak narrate: {n_nar}. "
f"persona-leak narrate: {n_nar}, refusal/identity: {n_ref}. "
f"SHOULD: at high alpha coherence-ppl drops the most (steering breaks fluency). If "
f"persona-leak dominates, the model is NARRATING the trait not enacting it; if repetition "
f"dominates, steering collapsed to loops not incoherence."
+36 -5
@@ -25,6 +25,28 @@ def _kl_per_pos(logp_a, logp_b): # KL(a || b) summed over vocab, per position
return (logp_a.exp() * (logp_a - logp_b)).sum(-1)
def _spectral_div(lora, n_iter: int = 3) -> torch.Tensor:
"""Mean operator norm σ_max(ΔW) over the adapter's layers, ΔW = (alpha/r)·B@A.
Power iteration (u,v held constant) gives σ_max = uᵀ(B@A)v, differentiable in A,B.
This is the weights-space analog of weight_decay: wd penalises the Frobenius norm
(||ΔW||_F² = sum of all squared singular values), the spectral penalty targets ||ΔW||_2 (the LARGEST
singular value), i.e. it limits how much the update can stretch any single input direction. It is
added to the loss outside the tau hinge (see heal_round), so it is an always-on penalty (like wd), not a hinge barrier."""
scale = lora.cfg.alpha / lora.cfg.r
sigmas = []
for name in lora.A:
A, B = lora.A[name].float(), lora.B[name].float() # A: r×d_in, B: d_out×r
with torch.no_grad():
v = torch.randn(A.shape[1], device=A.device)
v = v / v.norm()
for _ in range(n_iter):
u = B @ (A @ v); u = u / (u.norm() + 1e-8)
v = A.T @ (B.T @ u); v = v / (v.norm() + 1e-8)
sigmas.append(scale * (u @ (B @ (A @ v)))) # u,v const -> grad flows through A,B
return torch.stack(sigmas).mean()
def _gnorm(grads) -> float: # L2 norm of a flat concat of (possibly None) param grads
sq = sum(float(g.pow(2).sum()) for g in grads if g is not None)
return sq ** 0.5
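A dependency-free sketch of the power iteration that `_spectral_div` relies on, assuming a plain nested list stands in for ΔW. The real function works on the LoRA factors in torch and keeps the final product differentiable. The helper names here (`matvec`, `transpose`, `normalize`, `sigma_max`) are hypothetical.

```python
import math
import random


def matvec(M, v):
    return [sum(m * x for m, x in zip(row, v)) for row in M]


def transpose(M):
    return [list(col) for col in zip(*M)]


def normalize(v):
    norm = math.sqrt(sum(x * x for x in v)) + 1e-8  # same epsilon guard as the diff
    return [x / norm for x in v]


def sigma_max(M, n_iter=50, seed=0):
    """Estimate the largest singular value of M by alternating u = Mv, v = M^T u."""
    rng = random.Random(seed)
    Mt = transpose(M)
    v = normalize([rng.gauss(0.0, 1.0) for _ in range(len(M[0]))])
    for _ in range(n_iter):
        u = normalize(matvec(M, v))
        v = normalize(matvec(Mt, u))
    # u, v are treated as constants; the estimate is u^T M v
    return sum(a * b for a, b in zip(u, matvec(M, v)))


M = [[3.0, 0.0], [0.0, 1.0], [0.0, 0.0]]  # singular values 3 and 1
print(sigma_max(M))
print(sigma_max(M, n_iter=3))
```

Because u and v have (at most) unit norm, u^T M v can never exceed the true σ_max. With only a few iterations, as in the `n_iter=3` default of the diff, the estimate is therefore a slight under-estimate rather than an over-estimate.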
@@ -77,7 +99,7 @@ def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg:
lora = ModulatedLoRA(model, r=cfg.lora_r, alpha=cfg.lora_alpha, layer_range=cfg.layer_range)
params = list(lora.parameters())
opt = torch.optim.AdamW(params, lr=cfg.lr, betas=cfg.adam_betas,
weight_decay=(cfg.lam if cfg.reg == "wd" else 0.0))
weight_decay=cfg.weight_decay)
n_steps = len(train_kept) * cfg.epochs
sched = get_cosine_schedule_with_warmup(
opt, num_warmup_steps=int(cfg.warmup_ratio * n_steps), num_training_steps=n_steps)
@@ -115,9 +137,13 @@ def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg:
pbar.update(1); step += 1
continue
# original reference logits (no history, adapter off) for the barrier
# barrier reference logits (this round's adapter OFF). barrier_ref="base" bakes no
# history -> ref = round-0 original (leash to base, fights accumulated trait); "prev"
# bakes the history -> ref = previous-round student (trust region, penalises only this
# round's new divergence so trait accumulates while each step stays coherent).
if cfg.reg in ("kl_fwd", "kl_rev"):
with torch.no_grad(), lora(model, c=0.0):
ref_specs = hist_specs if cfg.barrier_ref == "prev" else []
with torch.no_grad(), baked(model, ref_specs), lora(model, c=0.0):
logp0 = model(**ids).logits[0, :-1].log_softmax(-1)
# student logits: history baked + this round's adapter live
@@ -132,8 +158,13 @@ def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg:
elif cfg.reg == "kl_rev":
div = _kl_per_pos(logp[mask], logp0[mask]).mean()
else:
div = torch.zeros((), device=model.device) # nll, wd
div = torch.zeros((), device=model.device) # nll
barrier = cfg.lam * torch.relu(div - cfg.tau)
# spectral_lam: independent ALWAYS-ON operator-norm cap on ΔW (σ_max), composes with the
# output-space barrier above and with weight_decay (see config.RunConfig.spectral_lam).
# Folded into `barrier` so the g_bar/g_nll gradient-pressure log captures it too.
if cfg.spectral_lam > 0:
barrier = barrier + cfg.spectral_lam * _spectral_div(lora)
loss = sft + barrier
nlls.append(sft.item())
ep_nlls.append(sft.item())
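The barrier term in this hunk combines a hinge on the divergence with the always-on spectral penalty. A scalar sketch of that arithmetic, assuming plain floats in place of tensors (`barrier_value` is a hypothetical helper; `sigma` stands in for the output of `_spectral_div`):

```python
def barrier_value(div, lam=0.3, tau=0.5, spectral_lam=0.01, sigma=0.0):
    """lam * relu(div - tau) + spectral_lam * sigma, on plain floats."""
    hinge = lam * max(0.0, div - tau)     # zero until the divergence exceeds tau nats
    return hinge + spectral_lam * sigma   # spectral term is added outside the hinge


print(barrier_value(0.4))             # below tau: hinge is off
print(barrier_value(1.5))             # 1.0 nat over tau
print(barrier_value(0.4, sigma=2.0))  # hinge off, spectral term still active
```

This is why `barrier_live` in the next hunk also checks `cfg.spectral_lam > 0`: the barrier can carry gradient through the spectral term even when the KL hinge is zeroed.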
@@ -142,7 +173,7 @@ def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg:
# split the gradient pressure: ||∇sft|| vs ||∇barrier|| (retain_graph -> still .backward below).
# barrier has no grad path when kl<=tau (relu zeroed), so guard before autograd.grad.
g_nll = _gnorm(torch.autograd.grad(sft, params, retain_graph=True, allow_unused=True))
barrier_live = barrier.requires_grad and (div - cfg.tau).item() > 0
barrier_live = barrier.requires_grad and ((div - cfg.tau).item() > 0 or cfg.spectral_lam > 0)
g_bar = _gnorm(torch.autograd.grad(barrier, params, retain_graph=True, allow_unused=True)) if barrier_live else 0.0
pressure = g_bar / g_nll if g_nll > 0 else float("nan")
cur_lr = sched.get_last_lr()[0] # lr applied to THIS step (before sched.step below)
+59 -19
@@ -96,6 +96,46 @@ def _log_stage_table(stages: list[dict], base_m: dict) -> None:
+ tabulate([_stage_row(s, base_m) for s in stages], headers="keys", tablefmt="github", floatfmt=".3f") + "\n")
def gen_filter_walk(model, tok, v, cfg: RunConfig, hist_specs: list) -> tuple[list[dict], list[dict], float, int]:
"""Adaptive-dose gen+filter (the controller steering.py:65 was written for).
Walk the dose multiplier kappa DOWN until a batch clears cfg.gen_pass_target filter
survival, banking every survivor (never waste a coherent completion), and top up
batches until >= cfg.min_train kept. Backing the dose off keeps the steered model
coherent so the filter has clean survivors. This attacks the over-steer repetition
collapse that starved #90 at round 6 from the GEN side; the heal barrier (lam) attacks
the same root cause from the WEIGHT side.
gen runs under the BAKED history (steered student state); the filter runs under the
ORIGINAL (ppl-under-base picks the usable C), so each attempt enters/exits baked
around gen only. Returns (kept, scored, kappa_final, n_gen). If max_batches can't reach
min_train, the heal assert downstream fires the (now dose-aware) starve canary.
"""
kappa = 1.0
kept_all, scored_all, n_gen = [], [], 0
for attempt in range(cfg.gen_max_batches):
with baked(model, hist_specs):
comps = generate_steered(model, tok, v, cfg, alpha_scale=kappa)
_, scored = filter_completions(model, tok, comps, cfg) # OUTSIDE baked = under original
passing = [s for s in scored if s["keep"]] # TRUE pass set (not filter's n_keep-capped return)
kept_all.extend(passing)
scored_all.extend(scored)
n_gen += len(comps)
rate = len(passing) / len(comps) # dose decision uses the real survival rate, not the cap
logger.info(
f"walk-C attempt {attempt}: kappa={kappa:.2f} kept {len(passing)}/{len(comps)} "
f"(rate={rate:.2f}, target>={cfg.gen_pass_target}) -> banked {len(kept_all)}/{cfg.min_train}.\n"
"SHOULD: rate climbs as kappa cools; once rate>=target we bank and top up to min_train. "
"If rate stays ~0 even at kappa_min, the steered model is incoherent at EVERY dose "
"(root cause is upstream of the dose: adapter itself broke, or filter thresholds wrong)."
)
if len(kept_all) >= cfg.min_train:
break
if rate < cfg.gen_pass_target and kappa > cfg.gen_kappa_min:
kappa *= cfg.gen_kappa_decay # over-driven -> cool the dose for the next batch
return kept_all[: cfg.n_keep], scored_all, kappa, n_gen # cap training set at n_keep (top-up may overshoot)
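The control flow of `gen_filter_walk` can be exercised without a model by swapping generation and filtering for a toy survival curve. Everything below is an assumption for illustration: `walk`, `toy_survival`, the batch size of 48, and the linear survival model are not from the repo. Only the loop structure (bank, check `min_train`, cool if under target) mirrors the diff.

```python
def toy_survival(kappa: float) -> float:
    """Hypothetical filter survival rate: higher dose -> fewer coherent completions."""
    return max(0.0, 1.0 - 1.2 * kappa)


def walk(survival_at, batch_size=48, min_train=30, pass_target=0.25,
         decay=0.7, kappa_min=0.2, max_batches=6):
    kappa, banked, log = 1.0, 0, []
    for attempt in range(max_batches):
        kept = int(batch_size * survival_at(kappa))  # stand-in for gen + filter
        banked += kept                               # bank every survivor
        rate = kept / batch_size
        log.append((attempt, round(kappa, 3), kept, banked))
        if banked >= min_train:
            break                                    # enough data: stop topping up
        if rate < pass_target and kappa > kappa_min:
            kappa *= decay                           # over-driven: cool the next batch
    return banked, kappa, log


banked, kappa, log = walk(toy_survival)
for row in log:
    print(row)  # (attempt, kappa, kept this batch, banked so far)
```

In this toy run the controller cools twice, then holds the dose once the survival rate clears the target and tops up with one more batch at the same kappa. As in the real function, the returned kappa is the dose of the last generated batch whenever the loop exits through the `min_train` break.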
def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict:
hist_specs = [] # AdapterSpec per folded round (gated bake history)
v0_flat = None # round-0 direction, for the Q3 cosine
@@ -109,23 +149,21 @@ def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict:
stages = [{"round": "-", "stage": "base", "m": base_m}] # base -> steered -> healed, for table + trajectory plot
for rnd in range(cfg.n_rounds):
logger.info(f"\n\n=== ROUND {rnd} [{cfg.model.split('/')[-1]} reg={cfg.reg}] gpu {gpu_mem()} ===")
# extract teacher vector + sweep-generate steered data from the CURRENT student
# extract teacher vector from the CURRENT student, then walk-C generate+filter:
# the controller cools the dose so the steered data stays coherent as the adapter
# accumulates trait over rounds (gen baked, filter under original -- see gen_filter_walk).
with baked(model, hist_specs):
v = teacher_vec(model, tok, cfg)
comps = generate_steered(model, tok, v, cfg)
# STEERED-stage eval: the model state the training data came from (history baked,
# vector live at the operating dose = lowest/cleanest alpha, NO new adapter). This
# is the raw-steering pareto reference the heal must BEAT (same base, trait via
# vector vs trait via the distilled adapter).
c_op = cfg.alphas[0] * v.cfg.coeff
logger.info(f"\n=== EVAL steered [c={cfg.alphas[0]}] gpu {gpu_mem()} ===")
with v(model, C=c_op):
kept, scored, kappa, n_comps = gen_filter_walk(model, tok, v, cfg, hist_specs)
# STEERED-stage eval at the dose the data ACTUALLY came from (kappa-scaled cleanest alpha),
# history baked, NO new adapter: the raw-steering pareto reference the heal must BEAT.
c_lo = kappa * cfg.alphas[0]
logger.info(f"\n=== EVAL steered [c={c_lo:.2f} kappa={kappa:.2f}] gpu {gpu_mem()} ===")
with baked(model, hist_specs):
with v(model, C=c_lo * v.cfg.coeff):
m_steer = evaluate_model(model, tok, cfg)
log_event(run_dir, stage="steered_eval", round=rnd, c=cfg.alphas[0], **m_steer) # persist for offline plot
# filter under the ORIGINAL (no history, no steering) -- this picks the usable C
logger.info(f"\n=== FILTER [{len(comps)} completions] gpu {gpu_mem()} ===")
kept, scored = filter_completions(model, tok, comps, cfg)
log_event(run_dir, stage="gen", round=rnd, n_comps=len(comps), n_kept=len(kept), scored=scored)
log_event(run_dir, stage="steered_eval", round=rnd, c=c_lo, **m_steer) # persist for offline plot
log_event(run_dir, stage="gen", round=rnd, n_comps=n_comps, n_kept=len(kept), kappa=kappa, scored=scored)
# heal one round on top of the baked history, then fold
logger.info(f"\n=== HEAL [{cfg.reg}] gpu {gpu_mem()} ===")
@@ -153,8 +191,8 @@ def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict:
v0_flat = vf if v0_flat is None else v0_flat
cos_v0 = float(cosine_similarity(vf, v0_flat, dim=0))
rec = {"round": rnd, **m, "cos_v0": cos_v0, "steered_ppl": steered_ppl,
"adapter_ppl": adapter_ppl, "n_comps": len(comps), "n_kept": len(kept),
"heal_nll": heal_nll}
"adapter_ppl": adapter_ppl, "n_comps": n_comps, "n_kept": len(kept),
"kappa": kappa, "heal_nll": heal_nll}
rounds.append(rec)
stages.append({"round": rnd, "stage": "steered", "m": m_steer})
stages.append({"round": rnd, "stage": "healed", "m": m})
@@ -175,14 +213,15 @@ def _log_loop_summary(rounds: list[dict], base_m: dict) -> None:
# One row per round, columns walk the pipeline stages left->right:
# GEN -> FILTER -> HEAL -> EVAL. (rec_key, display header) is the single source.
cols = [("round", "round"),
("n_comps", "gen"), ("n_kept", "filt_kept"), # GEN -> FILTER
("n_comps", "gen"), ("n_kept", "filt_kept"), ("kappa", "kappa↓"), # GEN -> FILTER (kappa = walk-C dose)
("heal_nll", "heal_nll↓"), ("adapter_ppl", "adapter_ppl↓"), # HEAL
("auth_nats", "auth_nats↓"), ("care_nats", "care_nats"), # EVAL: target / off-target
("coherence", "coherence→"), ("cos_v0", "cos_v0→")]
logger.info(
"\nloop columns (pipeline stages L->R: GEN | FILTER | HEAL | EVAL):\n"
" gen = steered completions generated (n_prompts x alphas)\n"
" gen = steered completions generated (n_prompts x alphas, summed over walk-C batches)\n"
" filt_kept = completions surviving the coherence/rep/persona filter (-> training set)\n"
" kappa↓ = walk-C dose multiplier the controller settled on (1.0 = nominal; <1 = backed off to dodge over-steer)\n"
" heal_nll↓ = converged SFT loss of the heal (last-5 mean)\n"
" adapter_ppl↓ = ppl-under-original of the no-steering adapter gens (low = coherent/healed)\n"
" auth_nats↓ = log(profile p[Authority]), NATS (TARGET: down = less deference)\n"
@@ -248,7 +287,8 @@ def main(cfg: RunConfig) -> None:
cfg = resolve(cfg)
torch.manual_seed(cfg.seed)
ts = datetime.now().strftime("%Y%m%dT%H%M%S")
slug = f"{cfg.model.split('/')[-1]}_{cfg.reg}_s{cfg.seed}"
wd_tag = f"_wd{cfg.weight_decay:g}" if cfg.weight_decay else ""
slug = f"{cfg.model.split('/')[-1]}_{cfg.reg}{wd_tag}_s{cfg.seed}"
run_dir = make_run_dir(ts, slug, cfg)
logger.info(f"argv cfg: {cfg}")
model, tok = load_model(cfg.model, getattr(torch, cfg.dtype))
+18 -9
@@ -58,32 +58,41 @@ def teacher_vec(model, tok, cfg: RunConfig):
@torch.no_grad()
def _gen_one(model, tok, text, cfg):
ids = tok(text, return_tensors="pt").to(model.device)
# gemma-3-it recommended sampling (its generation_config.json): top_k=64, top_p=0.95,
# temperature default 1.0. NOT Qwen's top_k=20/presence_penalty -- different model family.
# NO repetition_penalty / no_repeat_ngram here ON PURPOSE: a gen-time anti-repetition control
# MASKS the over-steering pathology (papers over the loops) so the filter passes junk and
# walk-C goes blind to "dose too high". Repetition is detected POST-HOC by the rep_tau filter,
# never suppressed at generation. (We tried penalty=1.3: it just inflated ppl and starved the
# filter, #96.) Repetition must remain VISIBLE so the filter/controller can act on it.
gen = model.generate(**ids, max_new_tokens=cfg.gen_max_new_tokens, do_sample=True,
temperature=1.0, top_p=0.95,
repetition_penalty=cfg.repetition_penalty,
no_repeat_ngram_size=cfg.no_repeat_ngram_size,
temperature=1.0, top_p=0.95, top_k=64,
pad_token_id=tok.pad_token_id)
return tok.decode(gen[0, ids.input_ids.shape[1]:], skip_special_tokens=True)
def generate_steered(model, tok, v, cfg: RunConfig) -> list[dict]:
def generate_steered(model, tok, v, cfg: RunConfig, alpha_scale: float = 1.0) -> list[dict]:
"""Sweep cfg.alphas (raw-vector multiples); generate one completion per prompt x alpha.
The filter (Q0), not iso-KL, picks the usable C: low alpha is coherent, high
alpha collapses, and we keep the coherent-but-trait-laden ones.
alpha collapses, and we keep the coherent-but-trait-laden ones. `alpha_scale`
(kappa) is the walk-C dose multiplier: the controller cools it over a round to
keep the steered model coherent as the baked adapter accumulates trait.
"""
out = []
n_total = cfg.n_prompts * len(cfg.alphas)
logger.info(f"\n=== GEN steered [{n_total} = {cfg.n_prompts} prompts x {len(cfg.alphas)} alphas] "
f"gpu {gpu_mem()} ===")
logger.info(f"\n=== GEN steered [{n_total} = {cfg.n_prompts} prompts x {len(cfg.alphas)} alphas, "
f"kappa={alpha_scale:.2f}] gpu {gpu_mem()} ===")
pbar = tqdm(total=n_total, desc="gen steered", mininterval=120, maxinterval=120)
for i in range(cfg.n_prompts):
user = POOL[i % len(POOL)]
text = chat_prompt(tok, cfg.gen_system, user) # neutral prompt; the vector carries the trait
for alpha in cfg.alphas:
with v(model, C=alpha * v.cfg.coeff):
with v(model, C=alpha * alpha_scale * v.cfg.coeff):
comp = _gen_one(model, tok, text, cfg)
out.append({"user": user, "prompt": text, "completion": comp, "alpha": float(alpha)})
# record the EFFECTIVE alpha (kappa-scaled) so the filter's per-alpha report and the
# offline plots reflect the dose the completion actually came from.
out.append({"user": user, "prompt": text, "completion": comp, "alpha": float(alpha * alpha_scale)})
pbar.update(1)
pbar.close()
return out