heal loop: _encode BPE root-fix, gen-time repetition controls, barrier sweep on degenerate rounds

_encode: tokenize prompt+completion separately and cat ids so the prompt is
always a clean token-prefix (no BPE merge spans the boundary). Drops the assert
that killed #87 at round 2. Returns BatchEncoding.

generation: repetition_penalty=1.3 + no_repeat_ngram_size=3. Repetition is
incoherence the ppl filter cannot see (loops are low-ppl = predictable); the
#89 loop died of "instead their instead their" by round 6, so stop it at the
source. Wired through steering._gen_one for both steered and plain gen.

diag_barrier: gen_round arg (re-heal a chosen round's kept data, not just clean
round 0) + a "tau" deadband sweep mode. Lets us test whether the barrier earns
its place on the degenerate round-1/2 data where healing is actually needed.

journal: entries (d) phantom-KL-init was a wrong diagnosis, (e) barrier-strength
sweep -- barrier throttles trait and buys no coherence at the coherent dose.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
wassname
2026-06-05 06:36:09 +08:00
parent f280a67521
commit 4e802bb3ab
5 changed files with 305 additions and 18 deletions
+160
@@ -443,3 +443,163 @@ provisional; coherence and qualitative carry the Gate 1 claim.
that actually moves the target). (2) Metric infra: wire steering-lite's loading-weighted Delta-logit
auth_sep (results.py / aggregate_flips) instead of my 7-way-logp mean, OR robustify to median. Plan B1
(super_sspace/sspace) if still broad; recorded in spec.
## 2026-06-04 (d) -- the "phantom-KL init bug" was a WRONG diagnosis (init is fine); trait still does not transfer
**Introduction.** I claimed the heal had two bugs: (1) barrier KL starting at ~0.6 before training,
blamed on a non-zero LoRA B init, and (2) train SFT loss not descending, blamed on beta2=0.999. The
user pushed back (scout mindset): mean=1e-4 std=1e-4 B init is within normal range, and "you only have
confirmation if it learns". On checking, claim (1) is REFUTED and claim (2) is unconfirmed. The
question that actually matters is unchanged: why does a fit adapter not move the trait? Continues
entry (a) and the task4/task10 data-ceiling hypothesis.
**Methods.** Commit `f280a67`, gemma-3-4b-it, reg=kl_rev, seed 42, 1 round, n_prompts 16, tinymfv
classic eval (think_tokens 128). The commit BUNDLED five changes (a mistake, see Discussion): LoRA
init B=normal(mean=1e-4)->B=0, betas (0.9,0.999)->(0.9,0.95), cosine-with-warmup (0.1) schedule,
r 8->32 / alpha 64 / layer_range (0.0,1.0)->(0.2,0.8), epochs 2->6, plus a new per-epoch val nll.
The decisive evidence is NOT from #79 but from #78's verbose log (`logs/20260604T172126_verbose.log`,
OLD init), which lets me read the round-0 step-0 KL the init claim hinges on.
**Results.**
| epoch | train_nll | val_nll |
|-------|-----------|---------|
| 0 | 1.710 | 1.365 |
| 1 | 1.162 | 1.417 |
| 3 | 0.931 | 1.201 |
| 5 | 0.806 | 1.240 |
Table 1. Per-epoch mean SFT nll on the 42 train completions and the 6 held-out val completions, heal
round 0, run #79. train_nll falls monotonically; val_nll wanders ~1.2-1.4 (n=6, noisy).
| stage | auth_nats | coherence |
|---------|-----------|-----------|
| base | -2.354 | 0.996 |
| steered | -3.517 | 0.992 |
| healed | -2.464 | 0.999 |
Table 2. tinymfv trait (auth_nats, log marginal blame-mass on Authority, DOWN = more trait) and
coherence (p_ans_any) at the three pipeline stages of round 0, run #79. coh_cost = |dCoh|/|dAuth| =
0.027, not surgical (dCare=+0.28 moved more than dAuth=-0.11).
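The coh_cost figure can be recomputed from the Table 2 cells. A minimal check, with variable names chosen here purely for illustration:

```python
# Cells copied from Table 2 (run #79, round 0).
base = {"auth_nats": -2.354, "coherence": 0.996}
healed = {"auth_nats": -2.464, "coherence": 0.999}

d_auth = healed["auth_nats"] - base["auth_nats"]   # trait shift, DOWN = more trait
d_coh = healed["coherence"] - base["coherence"]    # coherence shift
coh_cost = abs(d_coh) / abs(d_auth)                # coherence moved per nat of trait

print(f"dAuth={d_auth:+.3f} dCoh={d_coh:+.3f} coh_cost={coh_cost:.3f}")
```

The ratio is small mainly because dCoh is tiny; with dAuth itself only -0.11, it says little about how surgical the heal is.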
Provenance:
- Commit: `f280a67` (heal init/schedule/betas/val fixes).
- Run command (#79): `PYTHONUNBUFFERED=1 STEER_ATTN_IMPL=eager uv run python -m steer_heal.run --reg kl_rev --n-rounds 1 --n-prompts 16`
- Run dir: `out/20260604T194133_gemma-3-4b-it_kl_rev_s42/` (events.jsonl, ckpt/r0.safetensors).
- Log: `pueue log 79 --full`; Table 1 cells are the `epoch N: train_nll=.. val_nll=..` lines; Table 2
base/steered are the stage-pareto table, healed is the `round 0:` line and `eval:` auth_nats=-2.46.
- REFUTATION of the init claim: #78 round-0 heal (OLD init B=normal, NO baked history), verbose log
`heal_round:119` rows: step 0 nll=1.90 **kl=0.00**, step 4 kl=0.21, step 8 kl=0.33, step 12 kl=0.80.
KL is ~0 at init with the old init, then RISES as SFT installs the trait. So the init did not produce
a phantom KL. The kl=0.64-at-step-0 the user pasted was ROUND 5 (line 1653 sits between ROUND 5 at
1367 and ROUND 6 at 1709), i.e. five rounds of baked history = real cross-round drift, which is what
the barrier is meant to measure. B=0 is harmless and standard but fixed nothing.
- train_nll did descend in #79 (1.71->0.81) but this is UNATTRIBUTED (5 changes bundled) and #78 never
logged per-epoch train_nll, so "loss was not descending" was never actually established -- it was a
read of bs=1 per-step noise.
Healed auth_nats moves only -0.11 from base (-2.354 to -2.464) in #79, vs steered -3.517. #78 r0 healed
was -2.69. Both small, both near base, metric noisy (emitted_close=0/264). The changes did not improve
trait transfer.
**Discussion (speculative).** I made the classic ml-debug error: pattern-matched a symptom (KL>0 at
step 0) to a tidy mechanism (bad init), committed a fix, and declared victory without the isolating
measurement. The user caught it. The measurement (#78 round-0 step-0 kl=0.00, old init) refutes the
init story outright; the 0.64 was baked history. The premise behind the second claim (loss not
descending) was never measured at epoch level either. Net: I changed five things, can attribute
nothing, and the only metric that matters (trait transfer) is unchanged. What IS supported, by the
structural-ceiling lens: fixing optimiser-side knobs did not move the trait, so the trait is not
optimiser-limited -- it is the data (filter keeps near-base completions, entries a/(diag_heal)) or the
parameterisation/eval. Genuinely open between those.
**Next.** (1) The discriminating test is overfit-one-batch on a KNOWN trait-laden completion: can the
adapter reproduce defiant-of-authority text (expressiveness) AND does tinymfv then read the trait
(data/eval)? That splits data-ceiling from can't-express/can't-see. (2) #80 clean 10-round is running;
reframed, it tests whether the stall persists (it is NOT a fix validation). (3) Do not bundle changes
again; ablate one at a time if attribution matters. (4) lam retune still parked.
## 2026-06-04 (e) -- barrier-strength sweep: the heal barrier only throttles the trait and buys no coherence at the coherent dose; nll (no barrier) is best
**Introduction.** Entry (d) left it open whether the trait fails to transfer because the kept data is
near-base (data ceiling) or because the barrier suppresses it. The user pushed on this: "you haven't
even tried wd and kl values?". So I re-healed ONE run's cached kept completions (the 48 from #79) with
the SAME LoRA-A init seed, varying ONLY the regulariser (reg, lam, tau). Same data + same init means
the only thing that can move healed auth_nats is the barrier. Pre-registered: outcome 1 = monotone
weaker-barrier -> more-trait (the barrier throttles); outcome 2 = all dAuth ~ 0 incl nll (data
ceiling); outcome 3 = inconclusive. Continues entry (a)/(d).
**Methods.** Commit `f280a67`, gemma-3-4b-it, seed 42 (`torch.manual_seed(cfg.seed)` per config so the
A-init is identical), 6 epochs, lr 1e-4 cosine+warmup, lora r=32 alpha=64 layers (0.2,0.8). Re-heal
harness `scripts/diag_barrier.py` reads #79's `events.jsonl` gen event, keeps the 48 keep==True
completions, re-trains a fresh adapter per config, bakes it, runs tinymfv (think_tokens 128). Three
families across three pueue runs: #82 kl_rev with the tau=0.5 hinge, #86 kl_rev with tau=0 (pure linear
barrier = lam*div, the w2s form), #85 weight-decay decades 0.1..100. Base auth_nats=-2.354, coh=0.996.
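For reference, the two kl_rev barrier shapes compared here can be sketched on scalars. This follows the description in this entry (`relu(div - tau)` scaled by `lam`) and is not the code in `steer_heal.heal`; the function name `barrier` is hypothetical.

```python
def barrier(div: float, lam: float, tau: float) -> float:
    """Hinge barrier lam * relu(div - tau). With tau=0 and div >= 0 this is lam * div."""
    return lam * max(div - tau, 0.0)

for div in (0.2, 0.5, 0.8):
    hinge = barrier(div, lam=0.3, tau=0.5)    # deadband: silent until div exceeds tau
    linear = barrier(div, lam=0.3, tau=0.0)   # always on
    print(f"div={div:.1f} hinge={hinge:.3f} linear={linear:.3f}")
```

At equal lam the hinge is never stronger than the linear form, which is why raising tau reads as weakening the barrier.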
**Results.**
| reg / family | strength | dAuth | coh | heal_nll |
|-----------------|----------|-------|-------|----------|
| nll (no barrier)| 0 | -1.247| 1.000 | 0.199 |
| kl_rev linear | 0.03 | -1.053| 0.999 | 0.204 |
| kl_rev linear | 0.10 | -0.664| 1.000 | 0.232 |
| kl_rev linear | 0.30 | -0.173| 0.999 | 0.471 |
| kl_rev linear | 1.00 | -0.141| 1.000 | 0.970 |
Table 1. Pure-linear kl_rev barrier (tau=0), #86. `strength` = lam, the barrier weight. dAuth =
healed auth_nats minus base (more negative = more trait retained; DOWN = more trait). coh = p_ans_any.
heal_nll = converged SFT loss (last-5-step mean). Trait falls monotonically as the barrier strengthens;
heal_nll rises in step (the barrier is fighting the SFT objective); coh never leaves ~1.0.
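The monotone claim, and the share of the no-barrier trait each lam retains, can be checked directly from the Table 1 cells. A small sketch (names are illustrative only):

```python
# (lam, dAuth) pairs copied from Table 1 (#86, tau=0).
rows = [(0.0, -1.247), (0.03, -1.053), (0.10, -0.664), (0.30, -0.173), (1.00, -0.141)]

nll_dauth = rows[0][1]
retained = [(lam, dauth / nll_dauth) for lam, dauth in rows]
for lam, frac in retained:
    print(f"lam={lam:<4} retains {frac:.0%} of the no-barrier dAuth")

# |dAuth| should shrink at every step up in lam.
mags = [abs(dauth) for _, dauth in rows]
is_monotone = all(a > b for a, b in zip(mags, mags[1:]))
print("monotone:", is_monotone)
```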
| reg | weight_decay | dAuth | coh |
|-----|--------------|-------|-------|
| nll | 0 | -1.247| 1.000 |
| wd | 0.1 | -1.247| 1.000 |
| wd | 1.0 | -1.247| 1.000 |
| wd | 3.0 | -1.247| 1.000 |
| wd | 10.0 | -1.247| 1.000 |
| wd | 30.0 | -1.251| 0.999 |
| wd | 100.0 | -0.519| 1.000 |
Table 2. AdamW decoupled weight decay on the adapter, #85. (The log table also prints a tau column;
it is meaningless for wd and is dropped here.) dAuth is byte-identical to nll up to wd=10, differs by
0.004 at wd=30 (-1.251), then more than halves at wd=100 (-0.519). coh never leaves ~1.0.
Provenance:
- Commit: `f280a67`. Harness: `scripts/diag_barrier.py <run_dir> <mode>` (modes barrier/tau0/wd).
- Source data: `out/20260604T194133_gemma-3-4b-it_kl_rev_s42/events.jsonl`, the 48 keep==True
completions of the gen event (entry (d)'s #79).
- Run commands: #82 `... diag_barrier.py out/...s42/ barrier`; #86 `... barrier` ... `tau0`; #85 `... wd`.
- Logs / cells: each dAuth/coh is the `<reg> strength=.. : auth=.. (dAuth=..) coh=..` line and the
end-of-log `barrier sweep (re-heal #79 ...)` table. #86 `pueue log 86 --full`; #85 `pueue log 85
--full`; #82 `pueue log 82 --full`. #85 runs older code that prints `lam=`/`tau=` instead of
`strength=`; values are unaffected.
- #82 hinge (tau=0.5) for cross-reference: nll -1.247, kl_rev lam 0.03 -0.93 / 0.1 -0.40 / 0.3 -0.17 /
1.0 -0.17; lam 0.3 tau 1.0 -0.31 (raising tau weakens it); wd 0.01 and 0.1 byte-identical to nll.
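Outcome 1 holds in all three families: weaker barrier -> more trait, across the kl hinge (#82), the kl
linear form (#86), and weight decay (#85). The trend is strictly monotone in the linear form, flat
between lam 0.3 and 1.0 in the hinge, and flat for wd until the knee at wd=100. nll retains the full
-1.247 at coh 1.000; every kl barrier reduces |dAuth|, wd is inert up to 30 (within 0.004) and reduces
it at 100, and coherence stays at ~1.0 throughout.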
**Discussion (speculative).** My read: at this (coherent) operating dose the barrier is pure cost. It
removes trait and never buys coherence, because coherence was already ~1.0 with no barrier, so the
relu(div-tau) penalty has nothing to fix and only pulls the adapter back toward the original. The wd
and kl families converge on the same story by different mechanisms: wd just shrinks the whole adapter
toward no-op (hence the knee only appears at wd=100, where the per-step decoupled shrink lr*wd=1e-2,
i.e. a 0.99x factor per step, compounds to roughly 0.08x over 252 steps at peak lr, a ~92% shrink, and
finally bites), and the kl barrier pulls the output
distribution back toward base. Neither is a selective incoherence-cleaner here; both are volume knobs
on the adapter. This refutes the data-ceiling reading of entries (a)/(d) for THIS data: nll reaching
dAuth=-1.247 (it even exceeds the steered teacher's -1.16 of #79) proves the 48 kept completions carry
plenty of trait. The earlier negative heals (task4/10/19) all ran lam=1.0, i.e. the right-hand end of
Table 1 where the trait is throttled to ~-0.14. The big caveat: this is the COHERENT dose, where the
barrier can only hurt. Its hypothesised value is the coherence-breaking dose (filter off, or a higher
C) where nll WOULD lose coherence and the barrier might pay for itself; that is untested here.
Alternative hypothesis I cannot yet exclude: n=1 per cell, so a +-0.1 nat seed wobble could fake part
of the monotone tail (though the trend spans >1 nat across 5 points, far beyond plausible single-seed
noise). Distinguished by the 3-seed repeat (task25).
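The late wd knee is roughly what simple compounding predicts. A minimal sketch, assuming a constant peak lr of 1e-4 for all 252 steps and ignoring the gradient updates entirely (the real run uses cosine+warmup, so the actual shrink is smaller); `decay_factor` is a hypothetical helper, not part of the harness:

```python
LR = 1e-4      # peak lr from Methods; held constant here as a simplifying assumption
N_STEPS = 252  # step count quoted above (consistent with 42 train completions x 6 epochs)

def decay_factor(wd: float, lr: float = LR, n_steps: int = N_STEPS) -> float:
    """Cumulative multiplier from decoupled decay alone: (1 - lr*wd) ** n_steps."""
    return (1.0 - lr * wd) ** n_steps

for wd in (0.1, 1.0, 3.0, 10.0, 30.0, 100.0):
    print(f"wd={wd:>5}: per-step {1 - LR * wd:.5f}, cumulative {decay_factor(wd):.3f}")
```

Under this simplification wd=30 would already shrink the weights substantially, yet Table 2 shows almost no change there, so the gradient updates presumably offset much of the decay. Treat the arithmetic as an order-of-magnitude guide only.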
**Next.** (1) Launched the paired 10-round to test the loop, same seed 42: #87 nll (barrier off,
control) and #88 kl_rev lam=0.1 tau=0 (gentle active barrier, 53% trait single-round). The loop is the
one place cumulative incoherence can appear, so it is where the barrier might finally earn its place;
the contrast is whether nll's coherence decays over rounds while #88's holds. (2) 3-seed noise floor on
the headline (task25). (3) The real barrier test remains filter-off at a coherence-breaking dose
(task11/22), still parked.
+119
@@ -0,0 +1,119 @@
"""Barrier-strength sweep: is the trait failing to transfer because the barrier (kl/wd) is too
strong, or because the kept data is near-base? Re-heal from ONE run's cached kept completions with
the SAME init seed, varying ONLY (reg, lam, tau). Same data + same init => the only thing that moves
healed auth_nats is the barrier.
reg=nll is the ablation: barrier OFF. If nll ALSO lands near base, the data is the ceiling, not the
barrier. If nll (or weak kl/wd) retains MORE trait than kl_rev lam=1.0, the barrier was killing it.
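Run: uv run python scripts/diag_barrier.py out/20260604T194133_gemma-3-4b-it_kl_rev_s42/ [mode] [gen_round]
mode is one of barrier (default), tau0, tau, wd; gen_round defaults to 0 (the clean round).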
"""
import dataclasses
import sys
from pathlib import Path
import srsly
import torch
from loguru import logger
from tabulate import tabulate
from transformers import AutoModelForCausalLM, AutoTokenizer
sys.path.insert(0, "src")
from steer_heal.config import RunConfig # noqa: E402
from steer_heal.eval import evaluate_model # noqa: E402
from steer_heal.heal import heal_round # noqa: E402
from steer_heal.ws.bake import baked # noqa: E402
run_dir = Path(sys.argv[1])
mode = sys.argv[2] if len(sys.argv) > 2 else "barrier"  # a GRIDS key: "barrier" (kl hinge sweep), "tau0" (linear kl), "tau" (deadband sweep), "wd" (decay decade sweep)
gen_round = int(sys.argv[3]) if len(sys.argv) > 3 else 0 # which round's kept data to re-heal (0 = clean; later = messier)
base_cfg = RunConfig()
# (reg, lam, tau) grids. nll = barrier off (ablation) and the shared trait-ceiling reference.
GRIDS = {
    # kl_rev strength + a tau probe. lam 0.03 (w2s) .. 1.0 (current default).
    "barrier": [
        ("nll", 0.0, 0.5),  # ablation: no barrier at all
        ("kl_rev", 0.03, 0.5),
        ("kl_rev", 0.1, 0.5),
        ("kl_rev", 0.3, 0.5),
        ("kl_rev", 1.0, 0.5),
        ("kl_rev", 0.3, 1.0),  # weaker via higher tau (engages later)
    ],
    # pure linear kl_rev: tau=0 => barrier = lam*relu(div) = lam*div, always on, no deadband
    # (the w2s form). Cleaner knob than the hinge; compare against the tau=0.5 rows in "barrier".
    "tau0": [
        ("nll", 0.0, 0.0),
        ("kl_rev", 0.03, 0.0),
        ("kl_rev", 0.1, 0.0),
        ("kl_rev", 0.3, 0.0),
        ("kl_rev", 1.0, 0.0),
    ],
    # tau sweep: fix lam (middling barrier) and vary the deadband tau. Higher tau = barrier engages
    # only on larger divergence = weaker. Shows whether a deadband helps on degenerate (round 2) data.
    "tau": [
        ("nll", 0.0, 0.0),
        ("kl_rev", 0.3, 0.0),
        ("kl_rev", 0.3, 0.25),
        ("kl_rev", 0.3, 0.5),
        ("kl_rev", 0.3, 1.0),
        ("kl_rev", 0.3, 2.0),
    ],
    # weight decay: a WEIGHTS-space constraint (AdamW decoupled decay, tau irrelevant). Its per-step
    # shrink is lr*wd, and lr~1e-4 is tiny, so #82 found wd<=0.1 byte-identical to nll (~0.1% shrink
    # over 252 steps). Sweep up to 100 to find where cumulative shrink (252*lr*wd) reaches order-1.
    "wd": [
        ("nll", 0.0, 0.5),
        ("wd", 1e-1, 0.5),
        ("wd", 1.0, 0.5),
        ("wd", 3.0, 0.5),
        ("wd", 10.0, 0.5),
        ("wd", 30.0, 0.5),
        ("wd", 100.0, 0.5),
    ],
}
GRID = GRIDS[mode]
logger.info(f"barrier sweep mode={mode}: {len(GRID)} configs")
# kept completions (keep==True) from a CHOSEN round of the source run. round 0 = clean steered-on-base
# data; later rounds = data after the loop started degenerating (repetition), the regime where the
# barrier is hypothesised to matter (it was pure-cost on clean round-0 data, #82/85/86).
gen = next(e for e in srsly.read_jsonl(run_dir / "events.jsonl")
           if e["stage"] == "gen" and e["round"] == gen_round)
kept = [{"prompt": s["prompt"], "completion": s["completion"]} for s in gen["scored"] if s["keep"]]
logger.info(f"loaded {len(kept)} kept completions from {run_dir.name} round {gen_round}")
tok = AutoTokenizer.from_pretrained(base_cfg.model)
if tok.pad_token is None:
    tok.pad_token = tok.eos_token
model = AutoModelForCausalLM.from_pretrained(
    base_cfg.model, torch_dtype=torch.bfloat16, device_map="auto", attn_implementation="eager"
).eval()
base_m = evaluate_model(model, tok, base_cfg)
logger.info(f"base: auth_nats={base_m['auth_nats']:+.3f} care_nats={base_m['care_nats']:+.3f} coh={base_m['coherence']:.3f}")
rows = []
for reg, lam, tau in GRID:
    cfg = dataclasses.replace(base_cfg, reg=reg, lam=lam, tau=tau)
    torch.manual_seed(cfg.seed)  # identical LoRA-A init across barrier values -> only the barrier differs
    lora, spec, heal_nll = heal_round(model, tok, kept, [], cfg)
    with baked(model, [spec]):
        m = evaluate_model(model, tok, cfg)
    dauth = m["auth_nats"] - base_m["auth_nats"]
    dcoh = m["coherence"] - base_m["coherence"]
    # ONE strength knob per row: kl-barrier weight for kl_rev/kl_fwd, AdamW weight_decay for wd,
    # ignored for nll. tau (kl deadband) only applies to the kl regs -> "-" otherwise.
    is_kl = reg in ("kl_rev", "kl_fwd")
    rows.append({"reg": reg, "strength": lam, "tau(kl only)": (f"{tau:.1f}" if is_kl else "-"),
                 "heal_nll↓": heal_nll, "auth_nats↓": m["auth_nats"], "dAuth↓": dauth,
                 "care_nats": m["care_nats"], "coh→": m["coherence"], "dCoh": dcoh})
    logger.info(f" {reg} strength={lam}{f' tau={tau}' if is_kl else ''}: "
                f"auth={m['auth_nats']:+.3f} (dAuth={dauth:+.3f}) coh={m['coherence']:.3f}")
logger.info("SHOULD: if nll/weak-barrier retain MORE trait (more negative dAuth) at similar coh, the "
"barrier was killing the trait. If ALL rows sit near dAuth~0, the kept data is near-base.")
print(f"\nbarrier sweep (re-heal {run_dir.name} round {gen_round} kept data, vary the regulariser only; dAuth/dCoh vs base):")
print("strength = kl-barrier weight (kl_rev) OR AdamW weight_decay (wd); tau = kl deadband, n/a for wd/nll\n")
print(tabulate(rows, headers="keys", tablefmt="github", floatfmt="+.3f"))
print(f"\nbase auth_nats={base_m['auth_nats']:+.3f} coh={base_m['coherence']:.3f} | source {run_dir.name}")
+6 -1
@@ -43,8 +43,13 @@ class RunConfig:
     min_train: int = 20 # assert at least this many kept completions, else steering/filter starved
     gen_max_new_tokens: int = 256
     max_len: int = 1024
+    # repetition is incoherence the ppl filter CANNOT see (looped text is low-ppl = predictable), so
+    # stop it at generation, not post-hoc: penalty softly discourages all repeats, no_repeat_ngram
+    # hard-blocks any trigram repeat (kills "instead their instead their" loops at the source).
+    repetition_penalty: float = 1.3
+    no_repeat_ngram_size: int = 3
     ppl_tau: float = 50.0 # drop completions with ppl-under-original above this
-    rep_tau: float = 0.3 # drop completions whose max n-gram repeat fraction exceeds this
+    rep_tau: float = 0.3 # drop completions whose max n-gram repeat fraction exceeds this (residual net)
     # ── heal (U2): one objective + divergence-to-ORIGINAL barrier ──
     reg: Literal["nll", "kl_fwd", "kl_rev", "wd"] = "kl_rev"
+16 -16
@@ -14,7 +14,7 @@ import torch
 from loguru import logger
 from torch.nn import functional as F
 from tqdm import tqdm
-from transformers import get_cosine_schedule_with_warmup
+from transformers import BatchEncoding, get_cosine_schedule_with_warmup
 from steer_heal.config import RunConfig
 from steer_heal.ws.adapter import ModulatedLoRA
@@ -31,20 +31,16 @@ def _gnorm(grads) -> float: # L2 norm of a flat concat of (possibly None) param
 def _encode(tok, prompt: str, completion: str, max_len: int, device):
-    ids = tok(prompt + completion, return_tensors="pt", truncation=True, max_length=max_len).to(device)
-    prompt_ids = tok(prompt, return_tensors="pt").input_ids[0].to(device)
+    # Tokenize prompt and completion SEPARATELY then concatenate the ids, so the prompt is always a
+    # clean token-prefix -- no BPE merge can span the boundary (which would silently shift the SFT
+    # mask by a token). prompt keeps generation's tokenization (add_special_tokens default, matching
+    # generate_steered's tok(text)); the completion adds no specials. Truncation keeps the FRONT.
+    prompt_ids = tok(prompt, return_tensors="pt").input_ids[0]
+    comp_ids = tok(completion, return_tensors="pt", add_special_tokens=False).input_ids[0]
+    input_ids = torch.cat([prompt_ids, comp_ids])[:max_len].unsqueeze(0).to(device)
     n_prompt = prompt_ids.shape[0]
-    L = ids.input_ids.shape[1]
-    # Assert the prompt tokenizes as a clean PREFIX of prompt+completion. If a BPE merge spans
-    # the boundary, n_prompt is wrong and the SFT mask silently shifts by a token. Truncation
-    # keeps the FRONT (whole prompt + partial completion), so check the overlap that survives --
-    # min(n_prompt, L). This always runs, including the max_len boundary the earlier guard skipped
-    # (external review: a merge at exactly max_len escaped the < max_len check).
-    n_check = min(n_prompt, L)
-    assert torch.equal(ids.input_ids[0, :n_check], prompt_ids[:n_check]), (
-        "prompt is not a token-prefix of prompt+completion (BPE boundary merge); "
-        "the SFT loss mask would be misaligned by a token."
-    )
+    L = input_ids.shape[1]
+    ids = BatchEncoding({"input_ids": input_ids, "attention_mask": torch.ones_like(input_ids)})
     tgt_is_completion = torch.arange(1, L, device=device) >= n_prompt # mask over next-token targets
     return ids, tgt_is_completion
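The boundary problem this hunk removes can be reproduced without a real tokenizer. The sketch below uses a toy greedy longest-match tokenizer as a stand-in for BPE (an assumption for illustration; `VOCAB` and `toy_tokenize` are hypothetical and unrelated to the Hugging Face tokenizer used in `_encode`). Tokenizing the joined string lets a merge swallow the boundary, while concatenating separately tokenized ids keeps the prompt as a prefix by construction.

```python
VOCAB = {"a": 0, "b": 1, "ab": 2}  # toy vocab; "ab" plays the role of a BPE merge

def toy_tokenize(text: str) -> list[int]:
    """Greedy longest-match tokenizer (a stand-in for BPE, not a real one)."""
    ids, i = [], 0
    while i < len(text):
        for j in range(len(text), i, -1):  # try the longest candidate first
            if text[i:j] in VOCAB:
                ids.append(VOCAB[text[i:j]])
                i = j
                break
        else:
            raise ValueError(f"cannot tokenize {text[i:]!r}")
    return ids

prompt, completion = "a", "b"
prompt_ids = toy_tokenize(prompt)
joint_ids = toy_tokenize(prompt + completion)    # old path: the merge spans the boundary
cat_ids = prompt_ids + toy_tokenize(completion)  # new path: prompt is a prefix by construction

n_prompt = len(prompt_ids)
print("joint prefix ok:", joint_ids[:n_prompt] == prompt_ids)
print("cat prefix ok:  ", cat_ids[:n_prompt] == prompt_ids)

# Mask over next-token targets, mirroring torch.arange(1, L) >= n_prompt.
tgt_is_completion = [pos >= n_prompt for pos in range(1, len(cat_ids))]
print(tgt_is_completion)
```

The trade-off is that the concatenated ids may differ from what the tokenizer would produce for the joined string, which is acceptable here because the loss mask only needs the prompt length to be exact.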
@@ -101,7 +97,10 @@ def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg:
">>1 -> barrier dominates, it is undoing the trait the SFT installs (over-tight: lower lam or raise tau); "
"~1 -> balanced; 0 -> barrier inert (kl<tau, or reg=nll/wd where decay acts in the optimiser, not the loss)."
)
-    logger.info(" step nll↓ kl g_nll g_bar g_bar/g_nll loss↓ gnorm")
+    logger.info(" step nll↓ kl g_nll g_bar g_bar/g_nll loss↓ gnorm lr")
+    # init val nll BEFORE any training = the baseline; epoch val nlls only mean something against it.
+    # With B=0 the fresh adapter is a no-op, so this should equal the base model's nll on val.
+    logger.info(f" epoch init: train_nll= nan val_nll={_val_nll(model, tok, val_kept, hist_specs, lora, cfg):.3f} lr={sched.get_last_lr()[0]:.1e}")
     pbar = tqdm(total=n_steps, desc=f"heal[{cfg.reg}]", mininterval=120, maxinterval=120)
     step = 0
     nlls = [] # per-step SFT loss; final = mean of last 5, the heal-stage number for the round table
@@ -146,6 +145,7 @@ def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg:
             barrier_live = barrier.requires_grad and (div - cfg.tau).item() > 0
             g_bar = _gnorm(torch.autograd.grad(barrier, params, retain_graph=True, allow_unused=True)) if barrier_live else 0.0
             pressure = g_bar / g_nll if g_nll > 0 else float("nan")
+            cur_lr = sched.get_last_lr()[0] # lr applied to THIS step (before sched.step below)
             loss.backward()
             gnorm = torch.nn.utils.clip_grad_norm_(params, 1.0)
             opt.step()
@@ -153,7 +153,7 @@ def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg:
             opt.zero_grad()
             if log_now:
                 logger.info(f" {step:4d} {sft.item():5.2f} {div.detach().item():4.2f} "
-                            f"{g_nll:5.1f} {g_bar:5.1f} {pressure:11.2f} {loss.item():5.2f} {float(gnorm):5.1f}")
+                            f"{g_nll:5.1f} {g_bar:5.1f} {pressure:11.2f} {loss.item():5.2f} {float(gnorm):5.1f} {cur_lr:.2e}")
             pbar.set_postfix(nll=f"{sft.item():.2f}", kl=f"{div.detach().item():.2f}", gn=f"{float(gnorm):.1f}")
             pbar.update(1)
             step += 1
+4 -1
@@ -59,7 +59,10 @@ def teacher_vec(model, tok, cfg: RunConfig):
 def _gen_one(model, tok, text, cfg):
     ids = tok(text, return_tensors="pt").to(model.device)
     gen = model.generate(**ids, max_new_tokens=cfg.gen_max_new_tokens, do_sample=True,
-                         temperature=1.0, top_p=0.95, pad_token_id=tok.pad_token_id)
+                         temperature=1.0, top_p=0.95,
+                         repetition_penalty=cfg.repetition_penalty,
+                         no_repeat_ngram_size=cfg.no_repeat_ngram_size,
+                         pad_token_id=tok.pad_token_id)
     return tok.decode(gen[0, ids.input_ids.shape[1]:], skip_special_tokens=True)
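To illustrate the kind of loop `no_repeat_ngram_size=3` is meant to rule out, the sketch below checks a sequence for any repeated trigram. It is a plain-Python illustration on whitespace-split words, not the Hugging Face logits processor (which works on token ids during decoding); `has_repeated_ngram` is a hypothetical helper.

```python
def has_repeated_ngram(tokens: list[str], n: int = 3) -> bool:
    """True if any n-gram occurs more than once in tokens."""
    seen = set()
    for i in range(len(tokens) - n + 1):
        gram = tuple(tokens[i:i + n])
        if gram in seen:
            return True
        seen.add(gram)
    return False

looped = "instead their instead their instead their".split()
clean = "they chose to defer to the committee instead".split()
print(has_repeated_ngram(looped))
print(has_repeated_ngram(clean))
```

A check like this could also serve as a cheap post-hoc assertion on generated completions, alongside the existing `rep_tau` filter.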