diff --git a/docs/RESEARCH_JOURNAL.md b/docs/RESEARCH_JOURNAL.md
index 86edcd8..5872029 100644
--- a/docs/RESEARCH_JOURNAL.md
+++ b/docs/RESEARCH_JOURNAL.md
@@ -443,3 +443,163 @@ provisional; coherence and qualitative carry the Gate 1 claim.
 that actually moves the target). (2) Metric infra: wire steering-lite's loading-weighted Delta-logit
 auth_sep (results.py / aggregate_flips) instead of my 7-way-logp mean, OR robustify to median. Plan B1
 (super_sspace/sspace) if still broad; recorded in spec.
+
+## 2026-06-04 (d) -- the "phantom-KL init bug" was a WRONG diagnosis (init is fine); trait still does not transfer
+
+**Introduction.** I claimed the heal had two bugs: (1) barrier KL starting at ~0.6 before training,
+blamed on a non-zero LoRA B init, and (2) train SFT loss not descending, blamed on beta2=0.999. The
+user pushed back (scout mindset): mean=1e-4 std=1e-4 B init is within normal range, and "you only have
+confirmation if it learns". On checking, claim (1) is REFUTED and claim (2) is unconfirmed. The
+question that actually matters is unchanged: why does a fit adapter not move the trait? Continues
+entry (a) and the task4/task10 data-ceiling hypothesis.
+
+**Methods.** Commit `f280a67`, gemma-3-4b-it, reg=kl_rev, seed 42, 1 round, n_prompts 16, tinymfv
+classic eval (think_tokens 128). The commit BUNDLED five changes (a mistake, see Discussion): LoRA
+init B=normal(mean=1e-4)->B=0, betas (0.9,0.999)->(0.9,0.95), cosine-with-warmup (0.1) schedule,
+r 8->32 / alpha 64 / layer_range (0.0,1.0)->(0.2,0.8), epochs 2->6, plus a new per-epoch val nll.
+The decisive evidence is NOT from #79 but from #78's verbose log (`logs/20260604T172126_verbose.log`,
+OLD init), which lets me read the round-0 step-0 KL the init claim hinges on.
+
+**Results.**
+
+| epoch | train_nll | val_nll |
+|-------|-----------|---------|
+| 0     | 1.710     | 1.365   |
+| 1     | 1.162     | 1.417   |
+| 3     | 0.931     | 1.201   |
+| 5     | 0.806     | 1.240   |
+
+Table 1. Per-epoch mean SFT nll on the 42 train completions and the 6 held-out val completions, heal
+round 0, run #79. train_nll falls monotonically; val_nll wanders ~1.2-1.4 (n=6, noisy).
+
+| stage   | auth_nats | coherence |
+|---------|-----------|-----------|
+| base    | -2.354    | 0.996     |
+| steered | -3.517    | 0.992     |
+| healed  | -2.464    | 0.999     |
+
+Table 2. tinymfv trait (auth_nats, log marginal blame-mass on Authority, DOWN = more trait) and
+coherence (p_ans_any) at the three pipeline stages of round 0, run #79. coh_cost = |dCoh|/|dAuth| =
+0.027, not surgical (dCare=+0.28 moved more than dAuth=-0.11).
+
+Provenance:
+- Commit: `f280a67` (heal init/schedule/betas/val fixes).
+- Run command (#79): `PYTHONUNBUFFERED=1 STEER_ATTN_IMPL=eager uv run python -m steer_heal.run --reg kl_rev --n-rounds 1 --n-prompts 16`
+- Run dir: `out/20260604T194133_gemma-3-4b-it_kl_rev_s42/` (events.jsonl, ckpt/r0.safetensors).
+- Log: `pueue log 79 --full`; Table 1 cells are the `epoch N: train_nll=.. val_nll=..` lines; Table 2
+  base/steered are the stage-pareto table, healed is the `round 0:` line and `eval:` auth_nats=-2.46.
+- REFUTATION of the init claim: #78 round-0 heal (OLD init B=normal, NO baked history), verbose log
+  `heal_round:119` rows: step 0 nll=1.90 **kl=0.00**, step 4 kl=0.21, step 8 kl=0.33, step 12 kl=0.80.
+  KL is ~0 at init with the old init, then RISES as SFT installs the trait. So the init did not produce
+  a phantom KL. The kl=0.64-at-step-0 the user pasted was ROUND 5 (line 1653 sits between ROUND 5 at
+  1367 and ROUND 6 at 1709), i.e. five rounds of baked history = real cross-round drift, which is what
+  the barrier is meant to measure. B=0 is harmless and standard but fixed nothing.
+- train_nll did descend in #79 (1.71->0.81) but this is UNATTRIBUTED (5 changes bundled) and #78 never
+  logged per-epoch train_nll, so "loss was not descending" was never actually established -- it was a
+  read of bs=1 per-step noise.
+
+Healed auth_nats moves only -0.11 from base (-2.354 to -2.464) in #79, vs steered -3.517. #78 r0 healed
+was -2.69. Both small, both near base, metric noisy (emitted_close=0/264). The changes did not improve
+trait transfer.
+
+**Discussion (speculative).** I made the classic ml-debug error: pattern-matched a symptom (KL>0 at
+step 0) to a tidy mechanism (bad init), committed a fix, and declared victory without the isolating
+measurement. The user caught it. The measurement (#78 round-0 step-0 kl=0.00, old init) refutes the
+init story outright; the 0.64 was baked history. The premise behind the second claim (loss not
+descending) was never measured at epoch level either. Net: I changed five things, can attribute
+nothing, and the only metric that matters (trait transfer) is unchanged. What IS supported, by the
+structural-ceiling lens: fixing optimiser-side knobs did not move the trait, so the trait is not
+optimiser-limited -- it is the data (filter keeps near-base completions, entries a/(diag_heal)) or the
+parameterisation/eval. Genuinely open between those.
+
+**Next.** (1) The discriminating test is overfit-one-batch on a KNOWN trait-laden completion: can the
+adapter reproduce defiant-of-authority text (expressiveness) AND does tinymfv then read the trait
+(data/eval)? That splits data-ceiling from can't-express/can't-see. (2) #80 clean 10-round is running;
+reframed, it tests whether the stall persists (it is NOT a fix validation). (3) Do not bundle changes
+again; ablate one at a time if attribution matters. (4) lam retune still parked.
+
+## 2026-06-04 (e) -- barrier-strength sweep: the heal barrier only throttles the trait and buys no coherence at the coherent dose; nll (no barrier) is best
+
+**Introduction.** Entry (d) left it open whether the trait fails to transfer because the kept data is
+near-base (data ceiling) or because the barrier suppresses it. The user pushed on this: "you haven't
+even tried wd and kl values?". So I re-healed ONE run's cached kept completions (the 48 from #79) with
+the SAME LoRA-A init seed, varying ONLY the regulariser (reg, lam, tau). Same data + same init means
+the only thing that can move healed auth_nats is the barrier. Pre-registered: outcome 1 = monotone
+weaker-barrier -> more-trait (the barrier throttles); outcome 2 = all dAuth ~ 0 incl nll (data
+ceiling); outcome 3 = inconclusive. Continues entry (a)/(d).
+
+**Methods.** Commit `f280a67`, gemma-3-4b-it, seed 42 (`torch.manual_seed(cfg.seed)` per config so the
+A-init is identical), 6 epochs, lr 1e-4 cosine+warmup, lora r=32 alpha=64 layers (0.2,0.8). Re-heal
+harness `scripts/diag_barrier.py` reads #79's `events.jsonl` gen event, keeps the 48 keep==True
+completions, re-trains a fresh adapter per config, bakes it, runs tinymfv (think_tokens 128). Three
+families across three pueue runs: #82 kl_rev with the tau=0.5 hinge, #86 kl_rev with tau=0 (pure linear
+barrier = lam*div, the w2s form), #85 weight-decay decades 0.1..100. Base auth_nats=-2.354, coh=0.996.
+
+**Results.**
+
+| reg / family     | strength | dAuth  | coh   | heal_nll |
+|------------------|----------|--------|-------|----------|
+| nll (no barrier) | 0        | -1.247 | 1.000 | 0.199    |
+| kl_rev linear    | 0.03     | -1.053 | 0.999 | 0.204    |
+| kl_rev linear    | 0.10     | -0.664 | 1.000 | 0.232    |
+| kl_rev linear    | 0.30     | -0.173 | 0.999 | 0.471    |
+| kl_rev linear    | 1.00     | -0.141 | 1.000 | 0.970    |
+
+Table 1. Pure-linear kl_rev barrier (tau=0), #86. `strength` = lam, the barrier weight. dAuth =
+healed auth_nats minus base (more negative = more trait retained; DOWN = more trait). coh = p_ans_any.
+heal_nll = converged SFT loss (last-5-step mean). Trait falls monotonically as the barrier strengthens;
+heal_nll rises in step (the barrier is fighting the SFT objective); coh never leaves ~1.0.
+
+| reg | weight_decay | dAuth  | coh   |
+|-----|--------------|--------|-------|
+| nll | 0            | -1.247 | 1.000 |
+| wd  | 0.1          | -1.247 | 1.000 |
+| wd  | 1.0          | -1.247 | 1.000 |
+| wd  | 3.0          | -1.247 | 1.000 |
+| wd  | 10.0         | -1.247 | 1.000 |
+| wd  | 30.0         | -1.251 | 0.999 |
+| wd  | 100.0        | -0.519 | 1.000 |
+
+Table 2. AdamW decoupled weight decay on the adapter, #85. (The log table also prints a tau column;
+it is meaningless for wd and is dropped here.) dAuth matches nll to three decimals up to wd=10, sits
+within 0.004 of it at wd=30 (-1.251), then drops to less than half (-0.519) at wd=100. coh never
+leaves ~1.0.
+
+Provenance:
+- Commit: `f280a67`. Harness: `scripts/diag_barrier.py` (modes barrier/tau0/wd).
+- Source data: `out/20260604T194133_gemma-3-4b-it_kl_rev_s42/events.jsonl`, the 48 keep==True
+  completions of the gen event (entry (d)'s #79).
+- Run commands: #82 `... diag_barrier.py out/...s42/ barrier`; #86 `... barrier` ... `tau0`; #85 `... wd`.
+- Logs / cells: each dAuth/coh is the ` strength=.. : auth=.. (dAuth=..) coh=..` line and the
+  end-of-log `barrier sweep (re-heal #79 ...)` table. #86 `pueue log 86 --full`; #85
+  `pueue log 85 --full`; #82 `pueue log 82 --full`. #85 runs older code that prints `lam=`/`tau=`
+  instead of `strength=`; values are unaffected.
+- #82 hinge (tau=0.5) for cross-reference: nll -1.247, kl_rev lam 0.03 -0.93 / 0.1 -0.40 / 0.3 -0.17 /
+  1.0 -0.17; lam 0.3 tau 1.0 -0.31 (raising tau weakens it); wd 0.01 and 0.1 byte-identical to nll.
+
+Outcome 1 holds, decisively and in triplicate: weaker barrier -> more trait, across the kl
+hinge (#82), the kl linear form (#86), and weight decay (#85). nll retains the full -1.247 at coh
+1.000; every kl barrier strictly reduces |dAuth|, weight decay is inert up to wd=30 and then reduces
+it at wd=100, and coherence stays at ~1.0 throughout.
+
+**Discussion (speculative).** My read: at this (coherent) operating dose the barrier is pure cost. It
+removes trait and never buys coherence, because coherence was already ~1.0 with no barrier, so the
+relu(div-tau) penalty has nothing to fix and only pulls the adapter back toward the original. The wd
+and kl families converge on the same story by different mechanisms: wd just shrinks the whole adapter
+toward no-op (hence the knee only appears at wd=100, where the per-step decoupled shrink lr*wd=1e-2,
+i.e. 0.99x per step, would compound to ~0.08x (~92% total shrink) over 252 steps if lr sat at its peak
+throughout, and finally bites), and the kl barrier pulls the output
+distribution back toward base. Neither is a selective incoherence-cleaner here; both are volume knobs
+on the adapter. This refutes the data-ceiling reading of entries (a)/(d) for THIS data: nll reaching
+dAuth=-1.247 (it even exceeds the steered teacher's -1.16 of #79) proves the 48 kept completions carry
+plenty of trait. The earlier negative heals (task4/10/19) all ran lam=1.0, i.e. the bottom row of
+Table 1 where the trait is throttled to ~-0.14. The big caveat: this is the COHERENT dose, where the
+barrier can only hurt. Its hypothesised value is the coherence-breaking dose (filter off, or a higher
+C) where nll WOULD lose coherence and the barrier might pay for itself; that is untested here.
+Alternative hypothesis I cannot yet exclude: n=1 per cell, so a +-0.1 nat seed wobble could fake part
+of the monotone tail (though the trend spans >1 nat across 5 points, far beyond plausible single-seed
+noise). Distinguished by the 3-seed repeat (task25).
+
+**Next.** (1) Launched the paired 10-round to test the loop, same seed 42: #87 nll (barrier off,
+control) and #88 kl_rev lam=0.1 tau=0 (gentle active barrier, 53% trait single-round). The loop is the
+one place cumulative incoherence can appear, so it is where the barrier might finally earn its place;
+the contrast is whether nll's coherence decays over rounds while #88's holds. (2) 3-seed noise floor on
+the headline (task25). (3) The real barrier test remains filter-off at a coherence-breaking dose
+(task11/22), still parked.
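Reviewer note (not part of the patch): the weight-decay arithmetic in entry (e) can be sanity-checked with a small sketch. It assumes pure AdamW decoupled decay, where each step multiplies the weights by `1 - lr*wd`, and it ignores the gradient updates that push back in the real run. Peak lr 1e-4 and 252 steps are taken from the entry (42 train completions x 6 epochs at bs=1 would give that count). The 10% linear warmup plus cosine-to-zero shape is an assumption about the schedule, and `lr_multiplier` and `retained_fraction` are names invented here, not functions in the repo.

```python
import math

PEAK_LR = 1e-4      # lr quoted in the Methods paragraph of entry (e)
TOTAL_STEPS = 252   # step count quoted in the Discussion
WARMUP_FRAC = 0.1   # assumption: the 0.1 warmup mentioned in entry (d)


def lr_multiplier(step: int, total: int = TOTAL_STEPS, warmup_frac: float = WARMUP_FRAC) -> float:
    """Linear warmup followed by a half-cosine decay to zero (assumed schedule shape)."""
    warmup = int(total * warmup_frac)
    if step < warmup:
        return step / max(1, warmup)
    progress = (step - warmup) / max(1, total - warmup)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


def retained_fraction(wd: float, scheduled: bool = True) -> float:
    """Fraction of the initial weight norm left after decay alone (no gradient term)."""
    frac = 1.0
    for step in range(TOTAL_STEPS):
        lr = PEAK_LR * (lr_multiplier(step) if scheduled else 1.0)
        frac *= 1.0 - lr * wd  # decoupled decay: p <- p * (1 - lr * wd)
    return frac


if __name__ == "__main__":
    for wd in (0.1, 1.0, 3.0, 10.0, 30.0, 100.0):
        print(f"wd={wd:6.1f}  constant-lr={retained_fraction(wd, False):.3f}  "
              f"scheduled={retained_fraction(wd, True):.3f}")
```

At constant peak lr, wd=100 leaves 0.99^252, about 0.08 of the weights, which is where the "92%" figure comes from; the schedule softens this because lr spends most steps below its peak. These are bounds on the decay term only, not predictions of dAuth.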
diff --git a/scripts/diag_barrier.py b/scripts/diag_barrier.py
new file mode 100644
index 0000000..21ae221
--- /dev/null
+++ b/scripts/diag_barrier.py
@@ -0,0 +1,119 @@
+"""Barrier-strength sweep: is the trait failing to transfer because the barrier (kl/wd) is too
+strong, or because the kept data is near-base? Re-heal from ONE run's cached kept completions with
+the SAME init seed, varying ONLY (reg, lam, tau). Same data + same init => the only thing that moves
+healed auth_nats is the barrier.
+
+reg=nll is the ablation: barrier OFF. If nll ALSO lands near base, the data is the ceiling, not the
+barrier. If nll (or weak kl/wd) retains MORE trait than kl_rev lam=1.0, the barrier was killing it.
+
+Run: uv run python scripts/diag_barrier.py out/20260604T194133_gemma-3-4b-it_kl_rev_s42/
+"""
+import dataclasses
+import sys
+from pathlib import Path
+
+import srsly
+import torch
+from loguru import logger
+from tabulate import tabulate
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+sys.path.insert(0, "src")
+from steer_heal.config import RunConfig  # noqa: E402
+from steer_heal.eval import evaluate_model  # noqa: E402
+from steer_heal.heal import heal_round  # noqa: E402
+from steer_heal.ws.bake import baked  # noqa: E402
+
+run_dir = Path(sys.argv[1])
+mode = sys.argv[2] if len(sys.argv) > 2 else "barrier"  # a GRIDS key: "barrier" (kl hinge sweep), "tau0", "tau", or "wd" (decay decade sweep)
+gen_round = int(sys.argv[3]) if len(sys.argv) > 3 else 0  # which round's kept data to re-heal (0 = clean; later = messier)
+base_cfg = RunConfig()
+
+# (reg, lam, tau) grids. nll = barrier off (ablation) and the shared trait-ceiling reference.
+GRIDS = {
+    # kl_rev strength + a tau probe. lam 0.03 (w2s) .. 1.0 (current default).
+    "barrier": [
+        ("nll", 0.0, 0.5),  # ablation: no barrier at all
+        ("kl_rev", 0.03, 0.5),
+        ("kl_rev", 0.1, 0.5),
+        ("kl_rev", 0.3, 0.5),
+        ("kl_rev", 1.0, 0.5),
+        ("kl_rev", 0.3, 1.0),  # weaker via higher tau (engages later)
+    ],
+    # pure linear kl_rev: tau=0 => barrier = lam*relu(div) = lam*div, always on, no deadband
+    # (the w2s form). Cleaner knob than the hinge; compare against the tau=0.5 rows in "barrier".
+    "tau0": [
+        ("nll", 0.0, 0.0),
+        ("kl_rev", 0.03, 0.0),
+        ("kl_rev", 0.1, 0.0),
+        ("kl_rev", 0.3, 0.0),
+        ("kl_rev", 1.0, 0.0),
+    ],
+    # tau sweep: fix lam (middling barrier) and vary the deadband tau. Higher tau = barrier engages
+    # only on larger divergence = weaker. Shows whether a deadband helps on degenerate (round 2) data.
+    "tau": [
+        ("nll", 0.0, 0.0),
+        ("kl_rev", 0.3, 0.0),
+        ("kl_rev", 0.3, 0.25),
+        ("kl_rev", 0.3, 0.5),
+        ("kl_rev", 0.3, 1.0),
+        ("kl_rev", 0.3, 2.0),
+    ],
+    # weight decay: a WEIGHTS-space constraint (AdamW decoupled decay, tau irrelevant). Its per-step
+    # shrink is lr*wd, and lr~1e-4 is tiny, so #82 found wd<=0.1 byte-identical to nll (~0.1% shrink
+    # over 252 steps). Sweep up to 100 to find where cumulative shrink (252*lr*wd) reaches order-1.
+    "wd": [
+        ("nll", 0.0, 0.5),
+        ("wd", 1e-1, 0.5),
+        ("wd", 1.0, 0.5),
+        ("wd", 3.0, 0.5),
+        ("wd", 10.0, 0.5),
+        ("wd", 30.0, 0.5),
+        ("wd", 100.0, 0.5),
+    ],
+}
+GRID = GRIDS[mode]
+logger.info(f"barrier sweep mode={mode}: {len(GRID)} configs")
+
+# kept completions (keep==True) from a CHOSEN round of the source run. round 0 = clean steered-on-base
+# data; later rounds = data after the loop started degenerating (repetition), the regime where the
+# barrier is hypothesised to matter (it was pure-cost on clean round-0 data, #82/85/86).
+gen = next(e for e in srsly.read_jsonl(run_dir / "events.jsonl")
+           if e["stage"] == "gen" and e["round"] == gen_round)
+kept = [{"prompt": s["prompt"], "completion": s["completion"]} for s in gen["scored"] if s["keep"]]
+logger.info(f"loaded {len(kept)} kept completions from {run_dir.name} round {gen_round}")
+
+tok = AutoTokenizer.from_pretrained(base_cfg.model)
+if tok.pad_token is None:
+    tok.pad_token = tok.eos_token
+model = AutoModelForCausalLM.from_pretrained(
+    base_cfg.model, torch_dtype=torch.bfloat16, device_map="auto", attn_implementation="eager"
+).eval()
+
+base_m = evaluate_model(model, tok, base_cfg)
+logger.info(f"base: auth_nats={base_m['auth_nats']:+.3f} care_nats={base_m['care_nats']:+.3f} coh={base_m['coherence']:.3f}")
+
+rows = []
+for reg, lam, tau in GRID:
+    cfg = dataclasses.replace(base_cfg, reg=reg, lam=lam, tau=tau)
+    torch.manual_seed(cfg.seed)  # identical LoRA-A init across barrier values -> only the barrier differs
+    lora, spec, heal_nll = heal_round(model, tok, kept, [], cfg)
+    with baked(model, [spec]):
+        m = evaluate_model(model, tok, cfg)
+    dauth = m["auth_nats"] - base_m["auth_nats"]
+    dcoh = m["coherence"] - base_m["coherence"]
+    # ONE strength knob per row: kl-barrier weight for kl_rev/kl_fwd, AdamW weight_decay for wd,
+    # ignored for nll. tau (kl deadband) only applies to the kl regs -> "-" otherwise.
+    is_kl = reg in ("kl_rev", "kl_fwd")
+    rows.append({"reg": reg, "strength": lam, "tau(kl only)": (f"{tau:.1f}" if is_kl else "-"),
+                 "heal_nll↓": heal_nll, "auth_nats↓": m["auth_nats"], "dAuth↓": dauth,
+                 "care_nats": m["care_nats"], "coh→": m["coherence"], "dCoh": dcoh})
+    logger.info(f" {reg} strength={lam}{f' tau={tau}' if is_kl else ''}: "
+                f"auth={m['auth_nats']:+.3f} (dAuth={dauth:+.3f}) coh={m['coherence']:.3f}")
+
+logger.info("SHOULD: if nll/weak-barrier retain MORE trait (more negative dAuth) at similar coh, the "
+            "barrier was killing the trait. If ALL rows sit near dAuth~0, the kept data is near-base.")
+print("\nbarrier sweep (re-heal #79 kept data, vary the regulariser only; dAuth/dCoh vs base):")
+print("strength = kl-barrier weight (kl_rev) OR AdamW weight_decay (wd); tau = kl deadband, n/a for wd/nll\n")
+print(tabulate(rows, headers="keys", tablefmt="github", floatfmt="+.3f"))
+print(f"\nbase auth_nats={base_m['auth_nats']:+.3f} coh={base_m['coherence']:.3f} | source {run_dir.name}")
diff --git a/src/steer_heal/config.py b/src/steer_heal/config.py
index 8893f2f..76cacea 100644
--- a/src/steer_heal/config.py
+++ b/src/steer_heal/config.py
@@ -43,8 +43,13 @@ class RunConfig:
     min_train: int = 20  # assert at least this many kept completions, else steering/filter starved
     gen_max_new_tokens: int = 256
     max_len: int = 1024
+    # repetition is incoherence the ppl filter CANNOT see (looped text is low-ppl = predictable), so
+    # stop it at generation, not post-hoc: penalty softly discourages all repeats, no_repeat_ngram
+    # hard-blocks any trigram repeat (kills "instead their instead their" loops at the source).
+    repetition_penalty: float = 1.3
+    no_repeat_ngram_size: int = 3
     ppl_tau: float = 50.0  # drop completions with ppl-under-original above this
-    rep_tau: float = 0.3  # drop completions whose max n-gram repeat fraction exceeds this
+    rep_tau: float = 0.3  # drop completions whose max n-gram repeat fraction exceeds this (residual net)
 
     # ── heal (U2): one objective + divergence-to-ORIGINAL barrier ──
     reg: Literal["nll", "kl_fwd", "kl_rev", "wd"] = "kl_rev"
diff --git a/src/steer_heal/heal.py b/src/steer_heal/heal.py
index 3966a7e..19b52b3 100644
--- a/src/steer_heal/heal.py
+++ b/src/steer_heal/heal.py
@@ -14,7 +14,7 @@ import torch
 from loguru import logger
 from torch.nn import functional as F
 from tqdm import tqdm
-from transformers import get_cosine_schedule_with_warmup
+from transformers import BatchEncoding, get_cosine_schedule_with_warmup
 
 from steer_heal.config import RunConfig
 from steer_heal.ws.adapter import ModulatedLoRA
@@ -31,20 +31,16 @@ def _gnorm(grads) -> float: # L2 norm of a flat concat of (possibly None) param
 
 
 def _encode(tok, prompt: str, completion: str, max_len: int, device):
-    ids = tok(prompt + completion, return_tensors="pt", truncation=True, max_length=max_len).to(device)
-    prompt_ids = tok(prompt, return_tensors="pt").input_ids[0].to(device)
+    # Tokenize prompt and completion SEPARATELY then concatenate the ids, so the prompt is always a
+    # clean token-prefix -- no BPE merge can span the boundary (which would silently shift the SFT
+    # mask by a token). prompt keeps generation's tokenization (add_special_tokens default, matching
+    # generate_steered's tok(text)); the completion adds no specials. Truncation keeps the FRONT.
+    prompt_ids = tok(prompt, return_tensors="pt").input_ids[0]
+    comp_ids = tok(completion, return_tensors="pt", add_special_tokens=False).input_ids[0]
+    input_ids = torch.cat([prompt_ids, comp_ids])[:max_len].unsqueeze(0).to(device)
     n_prompt = prompt_ids.shape[0]
-    L = ids.input_ids.shape[1]
-    # Assert the prompt tokenizes as a clean PREFIX of prompt+completion. If a BPE merge spans
-    # the boundary, n_prompt is wrong and the SFT mask silently shifts by a token. Truncation
-    # keeps the FRONT (whole prompt + partial completion), so check the overlap that survives --
-    # min(n_prompt, L). This always runs, including the max_len boundary the earlier guard skipped
-    # (external review: a merge at exactly max_len escaped the < max_len check).
-    n_check = min(n_prompt, L)
-    assert torch.equal(ids.input_ids[0, :n_check], prompt_ids[:n_check]), (
-        "prompt is not a token-prefix of prompt+completion (BPE boundary merge); "
-        "the SFT loss mask would be misaligned by a token."
-    )
+    L = input_ids.shape[1]
+    ids = BatchEncoding({"input_ids": input_ids, "attention_mask": torch.ones_like(input_ids)})
     tgt_is_completion = torch.arange(1, L, device=device) >= n_prompt  # mask over next-token targets
     return ids, tgt_is_completion
 
@@ -101,7 +97,10 @@ def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg:
                 ">>1 -> barrier dominates, it is undoing the trait the SFT installs (over-tight: lower lam or raise tau); "
                 "~1 -> balanced; 0 -> barrier inert (kl 0 g_bar = _gnorm(torch.autograd.grad(barrier, params, retain_graph=True, allow_unused=True)) if barrier_live else 0.0
             pressure = g_bar / g_nll if g_nll > 0 else float("nan")
+            cur_lr = sched.get_last_lr()[0]  # lr applied to THIS step (before sched.step below)
             loss.backward()
             gnorm = torch.nn.utils.clip_grad_norm_(params, 1.0)
             opt.step()
@@ -153,7 +153,7 @@ def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg:
             opt.zero_grad()
             if log_now:
                 logger.info(f" {step:4d} {sft.item():5.2f} {div.detach().item():4.2f} "
-                            f"{g_nll:5.1f} {g_bar:5.1f} {pressure:11.2f} {loss.item():5.2f} {float(gnorm):5.1f}")
+                            f"{g_nll:5.1f} {g_bar:5.1f} {pressure:11.2f} {loss.item():5.2f} {float(gnorm):5.1f} {cur_lr:.2e}")
             pbar.set_postfix(nll=f"{sft.item():.2f}", kl=f"{div.detach().item():.2f}", gn=f"{float(gnorm):.1f}")
             pbar.update(1)
             step += 1
diff --git a/src/steer_heal/steering.py b/src/steer_heal/steering.py
index f685835..4161ca9 100644
--- a/src/steer_heal/steering.py
+++ b/src/steer_heal/steering.py
@@ -59,7 +59,10 @@ def teacher_vec(model, tok, cfg: RunConfig):
 def _gen_one(model, tok, text, cfg):
     ids = tok(text, return_tensors="pt").to(model.device)
     gen = model.generate(**ids, max_new_tokens=cfg.gen_max_new_tokens, do_sample=True,
-                         temperature=1.0, top_p=0.95, pad_token_id=tok.pad_token_id)
+                         temperature=1.0, top_p=0.95,
+                         repetition_penalty=cfg.repetition_penalty,
+                         no_repeat_ngram_size=cfg.no_repeat_ngram_size,
+                         pad_token_id=tok.pad_token_id)
     return tok.decode(gen[0, ids.input_ids.shape[1]:], skip_special_tokens=True)
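
Reviewer note (not part of the patch): the `_encode` change in heal.py derives the SFT loss mask from separately tokenized prompt and completion ids, so the prompt length is exact by construction. A toy sketch of that masking rule, using plain Python lists and made-up token ids in place of tensors and a real tokenizer; `encode_pair` is a hypothetical stand-in written for illustration, not the repo's function.

```python
def encode_pair(prompt_ids: list[int], comp_ids: list[int], max_len: int):
    """Concatenate ids, truncate from the back, and mask the next-token targets.

    Returns (input_ids, tgt_is_completion). tgt_is_completion[i] is True when the
    target token input_ids[i + 1] belongs to the completion, so only those
    positions would contribute to the SFT loss.
    """
    input_ids = (prompt_ids + comp_ids)[:max_len]  # truncation keeps the FRONT
    n_prompt = len(prompt_ids)
    length = len(input_ids)
    # Targets live at positions 1..length-1; each is predicted from the position before it.
    tgt_is_completion = [pos >= n_prompt for pos in range(1, length)]
    return input_ids, tgt_is_completion


if __name__ == "__main__":
    ids, mask = encode_pair([1, 2, 3], [7, 8], max_len=10)
    print(ids, mask)
```

Because the prompt ids are never re-derived from the joined string, no boundary merge can change `n_prompt`, which is why the old prefix assertion could be dropped. If truncation cuts into the prompt itself, the mask is simply all False.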