mirror of
https://github.com/wassname/steer-heal-love.git
synced 2026-06-27 16:47:16 +08:00
heal loop: _encode BPE root-fix, gen-time repetition controls, barrier sweep on degenerate rounds

_encode: tokenize prompt+completion separately and cat ids so the prompt is always a clean token-prefix (no BPE merge spans the boundary). Drops the assert that killed #87 at round 2. Returns BatchEncoding.

generation: repetition_penalty=1.3 + no_repeat_ngram_size=3. Repetition is incoherence the ppl filter cannot see (loops are low-ppl = predictable); the #89 loop died of "instead their instead their" by round 6, so stop it at the source. Wired through steering._gen_one for both steered and plain gen.

diag_barrier: gen_round arg (re-heal a chosen round's kept data, not just clean round 0) + a "tau" deadband sweep mode. Lets us test whether the barrier earns its place on the degenerate round-1/2 data where healing is actually needed.

journal: entries (d) phantom-KL-init was a wrong diagnosis, (e) barrier-strength sweep -- barrier throttles trait and buys no coherence at the coherent dose.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
@@ -443,3 +443,163 @@ provisional; coherence and qualitative carry the Gate 1 claim.
that actually moves the target). (2) Metric infra: wire steering-lite's loading-weighted Delta-logit
auth_sep (results.py / aggregate_flips) instead of my 7-way-logp mean, OR robustify to median. Plan B1
(super_sspace/sspace) if still broad; recorded in spec.

## 2026-06-04 (d) -- the "phantom-KL init bug" was a WRONG diagnosis (init is fine); trait still does not transfer

**Introduction.** I claimed the heal had two bugs: (1) barrier KL starting at ~0.6 before training,
blamed on a non-zero LoRA B init, and (2) train SFT loss not descending, blamed on beta2=0.999. The
user pushed back (scout mindset): mean=1e-4 std=1e-4 B init is within normal range, and "you only have
confirmation if it learns". On checking, claim (1) is REFUTED and claim (2) is unconfirmed. The
question that actually matters is unchanged: why does a fit adapter not move the trait? Continues
entry (a) and the task4/task10 data-ceiling hypothesis.

**Methods.** Commit `f280a67`, gemma-3-4b-it, reg=kl_rev, seed 42, 1 round, n_prompts 16, tinymfv
classic eval (think_tokens 128). The commit BUNDLED five changes (a mistake, see Discussion): LoRA
init B=normal(mean=1e-4)->B=0, betas (0.9,0.999)->(0.9,0.95), cosine-with-warmup (0.1) schedule,
r 8->32 / alpha 64 / layer_range (0.0,1.0)->(0.2,0.8), epochs 2->6, plus a new per-epoch val nll.
The decisive evidence is NOT from #79 but from #78's verbose log (`logs/20260604T172126_verbose.log`,
OLD init), which lets me read the round-0 step-0 KL the init claim hinges on.

**Results.**

| epoch | train_nll | val_nll |
|-------|-----------|---------|
| 0     | 1.710     | 1.365   |
| 1     | 1.162     | 1.417   |
| 3     | 0.931     | 1.201   |
| 5     | 0.806     | 1.240   |

Table 1. Per-epoch mean SFT nll on the 42 train completions and the 6 held-out val completions, heal
round 0, run #79. train_nll falls monotonically; val_nll wanders ~1.2-1.4 (n=6, noisy).

| stage   | auth_nats | coherence |
|---------|-----------|-----------|
| base    | -2.354    | 0.996     |
| steered | -3.517    | 0.992     |
| healed  | -2.464    | 0.999     |

Table 2. tinymfv trait (auth_nats, log marginal blame-mass on Authority, DOWN = more trait) and
coherence (p_ans_any) at the three pipeline stages of round 0, run #79. For healed vs base,
coh_cost = |dCoh|/|dAuth| = 0.003/0.110 = 0.027, not surgical (dCare=+0.28 moved more than dAuth=-0.11).

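As a worked check of the coh_cost figure in the Table 2 caption, the ratio can be recomputed directly from the table rows. This is a minimal sketch: the dictionary layout and the helper name `coh_cost` are illustrative assumptions, not code from the repo.

```python
# Table 2 rows (round 0, run #79): stage -> (auth_nats, coherence)
stages = {
    "base": (-2.354, 0.996),
    "steered": (-3.517, 0.992),
    "healed": (-2.464, 0.999),
}


def coh_cost(stage: str, ref: str = "base") -> float:
    """|dCoh| / |dAuth| of `stage` relative to `ref` (hypothetical helper)."""
    d_auth = stages[stage][0] - stages[ref][0]
    d_coh = stages[stage][1] - stages[ref][1]
    return abs(d_coh) / abs(d_auth)


healed_cost = coh_cost("healed")
d_auth_healed = stages["healed"][0] - stages["base"][0]
print(f"healed: dAuth={d_auth_healed:+.3f} coh_cost={healed_cost:.3f}")
```

The same helper applied to `"steered"` gives the teacher's ratio, which is the natural comparison point for how surgical the heal is.
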
Provenance:
- Commit: `f280a67` (heal init/schedule/betas/val fixes).
- Run command (#79): `PYTHONUNBUFFERED=1 STEER_ATTN_IMPL=eager uv run python -m steer_heal.run --reg kl_rev --n-rounds 1 --n-prompts 16`
- Run dir: `out/20260604T194133_gemma-3-4b-it_kl_rev_s42/` (events.jsonl, ckpt/r0.safetensors).
- Log: `pueue log 79 --full`; Table 1 cells are the `epoch N: train_nll=.. val_nll=..` lines; Table 2
  base/steered are the stage-pareto table, healed is the `round 0:` line and `eval:` auth_nats=-2.46.
- REFUTATION of the init claim: #78 round-0 heal (OLD init B=normal, NO baked history), verbose log
  `heal_round:119` rows: step 0 nll=1.90 **kl=0.00**, step 4 kl=0.21, step 8 kl=0.33, step 12 kl=0.80.
  KL is ~0 at init with the old init, then RISES as SFT installs the trait. So the init did not produce
  a phantom KL. The kl=0.64-at-step-0 the user pasted was ROUND 5 (line 1653 sits between ROUND 5 at
  1367 and ROUND 6 at 1709), i.e. five rounds of baked history = real cross-round drift, which is what
  the barrier is meant to measure. B=0 is harmless and standard but fixed nothing.
- train_nll did descend in #79 (1.71->0.81) but this is UNATTRIBUTED (5 changes bundled) and #78 never
  logged per-epoch train_nll, so "loss was not descending" was never actually established -- it was a
  read of bs=1 per-step noise.

Healed auth_nats moves only -0.11 from base (-2.354 to -2.464) in #79, vs steered -3.517. #78 r0 healed
was -2.69. Both small, both near base, metric noisy (emitted_close=0/264). The changes did not improve
trait transfer.

**Discussion (speculative).** I made the classic ml-debug error: pattern-matched a symptom (KL>0 at
step 0) to a tidy mechanism (bad init), committed a fix, and declared victory without the isolating
measurement. The user caught it. The measurement (#78 round-0 step-0 kl=0.00, old init) refutes the
init story outright; the 0.64 was baked history. The premise behind the second claim (loss not
descending) was never measured at epoch level either. Net: I changed five things, can attribute
nothing, and the only metric that matters (trait transfer) is unchanged. What IS supported, by the
structural-ceiling lens: fixing optimiser-side knobs did not move the trait, so the trait is not
optimiser-limited -- it is the data (filter keeps near-base completions, entries a/(diag_heal)) or the
parameterisation/eval. Genuinely open between those.

**Next.** (1) The discriminating test is overfit-one-batch on a KNOWN trait-laden completion: can the
adapter reproduce defiant-of-authority text (expressiveness) AND does tinymfv then read the trait
(data/eval)? That splits data-ceiling from can't-express/can't-see. (2) #80 clean 10-round is running;
reframed, it tests whether the stall persists (it is NOT a fix validation). (3) Do not bundle changes
again; ablate one at a time if attribution matters. (4) lam retune still parked.

## 2026-06-04 (e) -- barrier-strength sweep: the heal barrier only throttles the trait and buys no coherence at the coherent dose; nll (no barrier) is best

**Introduction.** Entry (d) left it open whether the trait fails to transfer because the kept data is
near-base (data ceiling) or because the barrier suppresses it. The user pushed on this: "you haven't
even tried wd and kl values?". So I re-healed ONE run's cached kept completions (the 48 from #79) with
the SAME LoRA-A init seed, varying ONLY the regulariser (reg, lam, tau). Same data + same init means
the only thing that can move healed auth_nats is the barrier. Pre-registered: outcome 1 = monotone
weaker-barrier -> more-trait (the barrier throttles); outcome 2 = all dAuth ~ 0 incl nll (data
ceiling); outcome 3 = inconclusive. Continues entry (a)/(d).

**Methods.** Commit `f280a67`, gemma-3-4b-it, seed 42 (`torch.manual_seed(cfg.seed)` per config so the
A-init is identical), 6 epochs, lr 1e-4 cosine+warmup, lora r=32 alpha=64 layers (0.2,0.8). Re-heal
harness `scripts/diag_barrier.py` reads #79's `events.jsonl` gen event, keeps the 48 keep==True
completions, re-trains a fresh adapter per config, bakes it, runs tinymfv (think_tokens 128). Three
families across three pueue runs: #82 kl_rev with the tau=0.5 hinge, #86 kl_rev with tau=0 (pure linear
barrier = lam*div, the w2s form), #85 weight-decay decades 0.1..100. Base auth_nats=-2.354, coh=0.996.

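The two kl_rev families differ only in the deadband. Below is a minimal sketch of the penalty as this journal describes it (`lam * relu(div - tau)`, which reduces to `lam * div` at tau=0). The function name `barrier_penalty` and the example divergence values are assumptions for illustration; the actual loss lives in `steer_heal.heal` and may differ in detail.

```python
def barrier_penalty(div: float, lam: float, tau: float) -> float:
    """Hinge barrier lam * relu(div - tau); tau=0 gives the pure linear form lam * div."""
    return lam * max(div - tau, 0.0)


# Illustrative divergences only (not measured values).
for div in (0.2, 0.5, 0.8):
    hinge = barrier_penalty(div, lam=0.3, tau=0.5)   # deadband form, as in #82
    linear = barrier_penalty(div, lam=0.3, tau=0.0)  # always-on form, as in #86
    print(f"div={div:.1f} hinge={hinge:.3f} linear={linear:.3f}")
```

With the deadband, nothing is penalised until the divergence exceeds tau, so at equal lam the hinge is never stronger than the linear form.
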
**Results.**

| reg / family     | strength | dAuth  | coh   | heal_nll |
|------------------|----------|--------|-------|----------|
| nll (no barrier) | 0        | -1.247 | 1.000 | 0.199    |
| kl_rev linear    | 0.03     | -1.053 | 0.999 | 0.204    |
| kl_rev linear    | 0.10     | -0.664 | 1.000 | 0.232    |
| kl_rev linear    | 0.30     | -0.173 | 0.999 | 0.471    |
| kl_rev linear    | 1.00     | -0.141 | 1.000 | 0.970    |

Table 1. Pure-linear kl_rev barrier (tau=0), #86. `strength` = lam, the barrier weight. dAuth =
healed auth_nats minus base (more negative = more trait retained; DOWN = more trait). coh = p_ans_any.
heal_nll = converged SFT loss (last-5-step mean). Trait falls monotonically as the barrier strengthens;
heal_nll rises in step (the barrier is fighting the SFT objective); coh never leaves ~1.0.

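One way to read Table 1 is as the fraction of the no-barrier trait shift that survives at each barrier weight. The sketch below recomputes that from the table's dAuth column; the variable names are illustrative and the "retention" framing is an assumption about how to summarise the rows, not a metric the harness emits.

```python
# Table 1 dAuth by lam (pure-linear kl_rev, #86); lam=0 is the nll (no barrier) row.
d_auth = {0.0: -1.247, 0.03: -1.053, 0.10: -0.664, 0.30: -0.173, 1.00: -0.141}

# Fraction of the no-barrier trait shift that survives at each barrier weight.
retention = {lam: d / d_auth[0.0] for lam, d in d_auth.items()}

for lam, frac in retention.items():
    print(f"lam={lam:<4} retains {frac:.0%} of the nll trait shift")
```

The lam=0.1 row comes out at roughly half of the no-barrier shift, and the lam=1.0 row at roughly a tenth.
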
| reg | weight_decay | dAuth  | coh   |
|-----|--------------|--------|-------|
| nll | 0            | -1.247 | 1.000 |
| wd  | 0.1          | -1.247 | 1.000 |
| wd  | 1.0          | -1.247 | 1.000 |
| wd  | 3.0          | -1.247 | 1.000 |
| wd  | 10.0         | -1.247 | 1.000 |
| wd  | 30.0         | -1.251 | 0.999 |
| wd  | 100.0        | -0.519 | 1.000 |

Table 2. AdamW decoupled weight decay on the adapter, #85. (The log table also prints a tau column;
it is meaningless for wd and is dropped here.) dAuth matches nll to three decimals up to wd=10, is
within 0.004 of it at wd=30, then drops to about 42% of the nll value at wd=100. coh never leaves ~1.0.

Provenance:
- Commit: `f280a67`. Harness: `scripts/diag_barrier.py <run_dir> <mode>` (modes barrier/tau0/wd).
- Source data: `out/20260604T194133_gemma-3-4b-it_kl_rev_s42/events.jsonl`, the 48 keep==True
  completions of the gen event (entry (d)'s #79).
- Run commands: #82 `... diag_barrier.py out/...s42/ barrier`; #86 `... barrier` ... `tau0`; #85 `... wd`.
- Logs / cells: each dAuth/coh is the `<reg> strength=.. : auth=.. (dAuth=..) coh=..` line and the
  end-of-log `barrier sweep (re-heal #79 ...)` table. #86 `pueue log 86 --full`; #85 `pueue log 85
  --full`; #82 `pueue log 82 --full`. #85 runs older code that prints `lam=`/`tau=` instead of
  `strength=`; values are unaffected.
- #82 hinge (tau=0.5) for cross-reference: nll -1.247, kl_rev lam 0.03 -0.93 / 0.1 -0.40 / 0.3 -0.17 /
  1.0 -0.17; lam 0.3 tau 1.0 -0.31 (raising tau weakens it); wd 0.01 and 0.1 byte-identical to nll.

Outcome 1 holds, decisively and in triplicate: weaker barrier -> more trait, monotone, across the kl
hinge (#82), the kl linear form (#86), and weight decay (#85). nll retains the full -1.247 at coh
1.000; no barrier setting increases |dAuth| beyond rounding (wd=30 reads -1.251), every setting strong
enough to register reduces it, and coherence stays at ~1.0 throughout.

**Discussion (speculative).** My read: at this (coherent) operating dose the barrier is pure cost. It
removes trait and never buys coherence, because coherence was already ~1.0 with no barrier, so the
relu(div-tau) penalty has nothing to fix and only pulls the adapter back toward the original. The two
families (wd and kl) converge on the same story by different mechanisms: wd just shrinks the whole
adapter toward no-op (hence the knee only appears at wd=100, where the per-step decoupled shrink
lr*wd=1e-2, i.e. a 0.99x factor per step at peak lr, compounds over 252 steps to roughly 0.08x from
decay alone and finally bites), and the kl barrier pulls the output distribution back toward base.
Neither is a selective incoherence-cleaner here; both are volume knobs on the adapter. This refutes
the data-ceiling reading of entries (a)/(d) for THIS data: nll reaching dAuth=-1.247 (it even exceeds
the steered teacher's -1.16 of #79) shows the 48 kept completions carry plenty of trait. The earlier
negative heals (task4/10/19) all ran lam=1.0, i.e. the right-hand end of Table 1 where the trait is
throttled to ~-0.14. The big caveat: this is the COHERENT dose, where the barrier can only hurt. Its
hypothesised value is the coherence-breaking dose (filter off, or a higher C) where nll WOULD lose
coherence and the barrier might pay for itself; that is untested here.
Alternative hypothesis I cannot yet exclude: n=1 per cell, so a +-0.1 nat seed wobble could fake part
of the monotone tail (though the trend spans >1 nat across 5 points, far beyond plausible single-seed
noise). Distinguished by the 3-seed repeat (task25).

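The weight-decay knee can be sanity-checked with a back-of-envelope calculation. The sketch below assumes a constant peak lr of 1e-4 for all 252 steps (the real run uses cosine with warmup, so the realised shrink is smaller) and ignores gradient updates, which push back against the decay. It is an upper bound on the shrink, not a model of the run; `decay_only_factor` is a hypothetical helper.

```python
LR = 1e-4    # peak lr from Methods; held constant here as a simplification
STEPS = 252  # step count quoted in the Discussion


def decay_only_factor(wd: float, lr: float = LR, steps: int = STEPS) -> float:
    """Multiplicative shrink from AdamW decoupled decay alone: (1 - lr*wd)**steps."""
    return (1.0 - lr * wd) ** steps


factors = {wd: decay_only_factor(wd) for wd in (0.1, 1.0, 3.0, 10.0, 30.0, 100.0)}
for wd, f in factors.items():
    print(f"wd={wd:>5}: weights x{f:.3f} from decay alone")
```

Under these assumptions wd=100 shrinks the weights to about 0.08x, but wd=30 already gives an order-1 shrink (a bit under 0.5x) while Table 2 shows no trait loss there. So the pure-decay picture overstates the effect; the lower scheduled lr and the gradient updates plausibly offset much of it.
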
**Next.** (1) Launched the paired 10-round to test the loop, same seed 42: #87 nll (barrier off,
control) and #88 kl_rev lam=0.1 tau=0 (gentle active barrier, 53% trait single-round). The loop is the
one place cumulative incoherence can appear, so it is where the barrier might finally earn its place;
the contrast is whether nll's coherence decays over rounds while #88's holds. (2) 3-seed noise floor on
the headline (task25). (3) The real barrier test remains filter-off at a coherence-breaking dose
(task11/22), still parked.
@@ -0,0 +1,119 @@
"""Barrier-strength sweep: is the trait failing to transfer because the barrier (kl/wd) is too
strong, or because the kept data is near-base? Re-heal from ONE run's cached kept completions with
the SAME init seed, varying ONLY (reg, lam, tau). Same data + same init => the only thing that moves
healed auth_nats is the barrier.

reg=nll is the ablation: barrier OFF. If nll ALSO lands near base, the data is the ceiling, not the
barrier. If nll (or weak kl/wd) retains MORE trait than kl_rev lam=1.0, the barrier was killing it.

Run: uv run python scripts/diag_barrier.py out/20260604T194133_gemma-3-4b-it_kl_rev_s42/
"""
import dataclasses
import sys
from pathlib import Path

import srsly
import torch
from loguru import logger
from tabulate import tabulate
from transformers import AutoModelForCausalLM, AutoTokenizer

sys.path.insert(0, "src")
from steer_heal.config import RunConfig  # noqa: E402
from steer_heal.eval import evaluate_model  # noqa: E402
from steer_heal.heal import heal_round  # noqa: E402
from steer_heal.ws.bake import baked  # noqa: E402

run_dir = Path(sys.argv[1])
mode = sys.argv[2] if len(sys.argv) > 2 else "barrier"  # "barrier" (kl sweep) or "wd" (decay decade sweep)
gen_round = int(sys.argv[3]) if len(sys.argv) > 3 else 0  # which round's kept data to re-heal (0 = clean; later = messier)
base_cfg = RunConfig()

# (reg, lam, tau) grids. nll = barrier off (ablation) and the shared trait-ceiling reference.
GRIDS = {
    # kl_rev strength + a tau probe. lam 0.03 (w2s) .. 1.0 (current default).
    "barrier": [
        ("nll", 0.0, 0.5),  # ablation: no barrier at all
        ("kl_rev", 0.03, 0.5),
        ("kl_rev", 0.1, 0.5),
        ("kl_rev", 0.3, 0.5),
        ("kl_rev", 1.0, 0.5),
        ("kl_rev", 0.3, 1.0),  # weaker via higher tau (engages later)
    ],
    # pure linear kl_rev: tau=0 => barrier = lam*relu(div) = lam*div, always on, no deadband
    # (the w2s form). Cleaner knob than the hinge; compare against the tau=0.5 rows in "barrier".
    "tau0": [
        ("nll", 0.0, 0.0),
        ("kl_rev", 0.03, 0.0),
        ("kl_rev", 0.1, 0.0),
        ("kl_rev", 0.3, 0.0),
        ("kl_rev", 1.0, 0.0),
    ],
    # tau sweep: fix lam (middling barrier) and vary the deadband tau. Higher tau = barrier engages
    # only on larger divergence = weaker. Shows whether a deadband helps on degenerate (round 2) data.
    "tau": [
        ("nll", 0.0, 0.0),
        ("kl_rev", 0.3, 0.0),
        ("kl_rev", 0.3, 0.25),
        ("kl_rev", 0.3, 0.5),
        ("kl_rev", 0.3, 1.0),
        ("kl_rev", 0.3, 2.0),
    ],
    # weight decay: a WEIGHTS-space constraint (AdamW decoupled decay, tau irrelevant). Its per-step
    # shrink is lr*wd, and lr~1e-4 is tiny, so #82 found wd<=0.1 byte-identical to nll (~0.1% shrink
    # over 252 steps). Sweep up to 100 to find where cumulative shrink (252*lr*wd) reaches order-1.
    "wd": [
        ("nll", 0.0, 0.5),
        ("wd", 1e-1, 0.5),
        ("wd", 1.0, 0.5),
        ("wd", 3.0, 0.5),
        ("wd", 10.0, 0.5),
        ("wd", 30.0, 0.5),
        ("wd", 100.0, 0.5),
    ],
}
GRID = GRIDS[mode]
logger.info(f"barrier sweep mode={mode}: {len(GRID)} configs")

# kept completions (keep==True) from a CHOSEN round of the source run. round 0 = clean steered-on-base
# data; later rounds = data after the loop started degenerating (repetition), the regime where the
# barrier is hypothesised to matter (it was pure-cost on clean round-0 data, #82/85/86).
gen = next(e for e in srsly.read_jsonl(run_dir / "events.jsonl")
           if e["stage"] == "gen" and e["round"] == gen_round)
kept = [{"prompt": s["prompt"], "completion": s["completion"]} for s in gen["scored"] if s["keep"]]
logger.info(f"loaded {len(kept)} kept completions from {run_dir.name} round {gen_round}")

tok = AutoTokenizer.from_pretrained(base_cfg.model)
if tok.pad_token is None:
    tok.pad_token = tok.eos_token
model = AutoModelForCausalLM.from_pretrained(
    base_cfg.model, torch_dtype=torch.bfloat16, device_map="auto", attn_implementation="eager"
).eval()

base_m = evaluate_model(model, tok, base_cfg)
logger.info(f"base: auth_nats={base_m['auth_nats']:+.3f} care_nats={base_m['care_nats']:+.3f} coh={base_m['coherence']:.3f}")

rows = []
for reg, lam, tau in GRID:
    cfg = dataclasses.replace(base_cfg, reg=reg, lam=lam, tau=tau)
    torch.manual_seed(cfg.seed)  # identical LoRA-A init across barrier values -> only the barrier differs
    lora, spec, heal_nll = heal_round(model, tok, kept, [], cfg)
    with baked(model, [spec]):
        m = evaluate_model(model, tok, cfg)
    dauth = m["auth_nats"] - base_m["auth_nats"]
    dcoh = m["coherence"] - base_m["coherence"]
    # ONE strength knob per row: kl-barrier weight for kl_rev/kl_fwd, AdamW weight_decay for wd,
    # ignored for nll. tau (kl deadband) only applies to the kl regs -> "-" otherwise.
    is_kl = reg in ("kl_rev", "kl_fwd")
    rows.append({"reg": reg, "strength": lam, "tau(kl only)": (f"{tau:.1f}" if is_kl else "-"),
                 "heal_nll↓": heal_nll, "auth_nats↓": m["auth_nats"], "dAuth↓": dauth,
                 "care_nats": m["care_nats"], "coh→": m["coherence"], "dCoh": dcoh})
    logger.info(f" {reg} strength={lam}{f' tau={tau}' if is_kl else ''}: "
                f"auth={m['auth_nats']:+.3f} (dAuth={dauth:+.3f}) coh={m['coherence']:.3f}")

logger.info("SHOULD: if nll/weak-barrier retain MORE trait (more negative dAuth) at similar coh, the "
            "barrier was killing the trait. If ALL rows sit near dAuth~0, the kept data is near-base.")
print("\nbarrier sweep (re-heal #79 kept data, vary the regulariser only; dAuth/dCoh vs base):")
print("strength = kl-barrier weight (kl_rev) OR AdamW weight_decay (wd); tau = kl deadband, n/a for wd/nll\n")
print(tabulate(rows, headers="keys", tablefmt="github", floatfmt="+.3f"))
print(f"\nbase auth_nats={base_m['auth_nats']:+.3f} coh={base_m['coherence']:.3f} | source {run_dir.name}")
@@ -43,8 +43,13 @@ class RunConfig:
     min_train: int = 20  # assert at least this many kept completions, else steering/filter starved
     gen_max_new_tokens: int = 256
     max_len: int = 1024
+    # repetition is incoherence the ppl filter CANNOT see (looped text is low-ppl = predictable), so
+    # stop it at generation, not post-hoc: penalty softly discourages all repeats, no_repeat_ngram
+    # hard-blocks any trigram repeat (kills "instead their instead their" loops at the source).
+    repetition_penalty: float = 1.3
+    no_repeat_ngram_size: int = 3
     ppl_tau: float = 50.0  # drop completions with ppl-under-original above this
-    rep_tau: float = 0.3  # drop completions whose max n-gram repeat fraction exceeds this
+    rep_tau: float = 0.3  # drop completions whose max n-gram repeat fraction exceeds this (residual net)

     # ── heal (U2): one objective + divergence-to-ORIGINAL barrier ──
     reg: Literal["nll", "kl_fwd", "kl_rev", "wd"] = "kl_rev"
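To make the hard block concrete, here is a minimal pure-Python sketch of what a no-repeat-n-gram constraint does at each decoding step: it bans any next token that would complete an n-gram already present in the sequence. This illustrates the idea only and is not the `transformers` implementation; `banned_next_tokens` is a hypothetical helper, and whole words stand in for token ids.

```python
def banned_next_tokens(tokens: list[str], n: int = 3) -> set[str]:
    """Tokens that would complete an n-gram already seen in `tokens`."""
    if len(tokens) < n:
        return set()
    # The last n-1 tokens form the prefix of the n-gram about to be completed.
    prefix = tuple(tokens[-(n - 1):]) if n > 1 else ()
    banned = set()
    for i in range(len(tokens) - n + 1):
        if tuple(tokens[i:i + n - 1]) == prefix:
            banned.add(tokens[i + n - 1])
    return banned


loop = "instead their instead their".split()
print(banned_next_tokens(loop, n=3))
```

On the degenerate example from the commit message, the only continuation that would extend the loop is the one that gets banned, so the sampler has to pick something else. The soft `repetition_penalty` works differently: it rescales the logits of already-seen tokens and blocks nothing outright.
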
+16
-16
@@ -14,7 +14,7 @@ import torch
 from loguru import logger
 from torch.nn import functional as F
 from tqdm import tqdm
-from transformers import get_cosine_schedule_with_warmup
+from transformers import BatchEncoding, get_cosine_schedule_with_warmup

 from steer_heal.config import RunConfig
 from steer_heal.ws.adapter import ModulatedLoRA
@@ -31,20 +31,16 @@ def _gnorm(grads) -> float: # L2 norm of a flat concat of (possibly None) param


 def _encode(tok, prompt: str, completion: str, max_len: int, device):
-    ids = tok(prompt + completion, return_tensors="pt", truncation=True, max_length=max_len).to(device)
-    prompt_ids = tok(prompt, return_tensors="pt").input_ids[0].to(device)
+    # Tokenize prompt and completion SEPARATELY then concatenate the ids, so the prompt is always a
+    # clean token-prefix -- no BPE merge can span the boundary (which would silently shift the SFT
+    # mask by a token). prompt keeps generation's tokenization (add_special_tokens default, matching
+    # generate_steered's tok(text)); the completion adds no specials. Truncation keeps the FRONT.
+    prompt_ids = tok(prompt, return_tensors="pt").input_ids[0]
+    comp_ids = tok(completion, return_tensors="pt", add_special_tokens=False).input_ids[0]
+    input_ids = torch.cat([prompt_ids, comp_ids])[:max_len].unsqueeze(0).to(device)
     n_prompt = prompt_ids.shape[0]
-    L = ids.input_ids.shape[1]
-    # Assert the prompt tokenizes as a clean PREFIX of prompt+completion. If a BPE merge spans
-    # the boundary, n_prompt is wrong and the SFT mask silently shifts by a token. Truncation
-    # keeps the FRONT (whole prompt + partial completion), so check the overlap that survives --
-    # min(n_prompt, L). This always runs, including the max_len boundary the earlier guard skipped
-    # (external review: a merge at exactly max_len escaped the < max_len check).
-    n_check = min(n_prompt, L)
-    assert torch.equal(ids.input_ids[0, :n_check], prompt_ids[:n_check]), (
-        "prompt is not a token-prefix of prompt+completion (BPE boundary merge); "
-        "the SFT loss mask would be misaligned by a token."
-    )
+    L = input_ids.shape[1]
+    ids = BatchEncoding({"input_ids": input_ids, "attention_mask": torch.ones_like(input_ids)})
     tgt_is_completion = torch.arange(1, L, device=device) >= n_prompt  # mask over next-token targets
     return ids, tgt_is_completion

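The boundary problem this hunk removes can be reproduced with a toy tokenizer. The sketch below uses a greedy longest-match tokenizer over a three-entry vocabulary as a stand-in for BPE (it is not a real BPE implementation, and `toy_tokenize` and `VOCAB` are hypothetical). It shows that joint tokenization can merge across the prompt/completion boundary, whereas tokenizing separately and concatenating keeps the prompt tokens as a prefix by construction.

```python
VOCAB = ["in", "stead", "instead"]


def toy_tokenize(text: str, vocab: list[str]) -> list[str]:
    """Greedy longest-match tokenizer; falls back to single characters."""
    out, i = [], 0
    while i < len(text):
        match = max((v for v in vocab if text.startswith(v, i)), key=len, default=text[i])
        out.append(match)
        i += len(match)
    return out


prompt, completion = "in", "stead"
prompt_toks = toy_tokenize(prompt, VOCAB)

# Joint tokenization: the merged token "instead" spans the boundary.
joint = toy_tokenize(prompt + completion, VOCAB)
joint_ok = joint[:len(prompt_toks)] == prompt_toks

# Separate tokenization + concatenation: the prompt is a prefix by construction.
separate = prompt_toks + toy_tokenize(completion, VOCAB)
separate_ok = separate[:len(prompt_toks)] == prompt_toks

print(f"joint={joint} prefix_ok={joint_ok}")
print(f"separate={separate} prefix_ok={separate_ok}")
```

In the joint case a prompt-length mask such as `n_prompt` would point into the middle of a merged token, which is the misalignment the removed assert was guarding against. The separate path trades that risk for a tokenization of the completion that may differ slightly from what the model would see in one pass.
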
@@ -101,7 +97,10 @@ def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg:
         ">>1 -> barrier dominates, it is undoing the trait the SFT installs (over-tight: lower lam or raise tau); "
         "~1 -> balanced; 0 -> barrier inert (kl<tau, or reg=nll/wd where decay acts in the optimiser, not the loss)."
     )
-    logger.info(" step nll↓ kl g_nll g_bar g_bar/g_nll loss↓ gnorm")
+    logger.info(" step nll↓ kl g_nll g_bar g_bar/g_nll loss↓ gnorm lr")
+    # init val nll BEFORE any training = the baseline; epoch val nlls only mean something against it.
+    # With B=0 the fresh adapter is a no-op, so this should equal the base model's nll on val.
+    logger.info(f" epoch init: train_nll= nan val_nll={_val_nll(model, tok, val_kept, hist_specs, lora, cfg):.3f} lr={sched.get_last_lr()[0]:.1e}")
     pbar = tqdm(total=n_steps, desc=f"heal[{cfg.reg}]", mininterval=120, maxinterval=120)
     step = 0
     nlls = []  # per-step SFT loss; final = mean of last 5, the heal-stage number for the round table
@@ -146,6 +145,7 @@ def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg:
         barrier_live = barrier.requires_grad and (div - cfg.tau).item() > 0
         g_bar = _gnorm(torch.autograd.grad(barrier, params, retain_graph=True, allow_unused=True)) if barrier_live else 0.0
         pressure = g_bar / g_nll if g_nll > 0 else float("nan")
+        cur_lr = sched.get_last_lr()[0]  # lr applied to THIS step (before sched.step below)
         loss.backward()
         gnorm = torch.nn.utils.clip_grad_norm_(params, 1.0)
         opt.step()
@@ -153,7 +153,7 @@ def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg:
         opt.zero_grad()
         if log_now:
             logger.info(f" {step:4d} {sft.item():5.2f} {div.detach().item():4.2f} "
-                        f"{g_nll:5.1f} {g_bar:5.1f} {pressure:11.2f} {loss.item():5.2f} {float(gnorm):5.1f}")
+                        f"{g_nll:5.1f} {g_bar:5.1f} {pressure:11.2f} {loss.item():5.2f} {float(gnorm):5.1f} {cur_lr:.2e}")
         pbar.set_postfix(nll=f"{sft.item():.2f}", kl=f"{div.detach().item():.2f}", gn=f"{float(gnorm):.1f}")
         pbar.update(1)
         step += 1
@@ -59,7 +59,10 @@ def teacher_vec(model, tok, cfg: RunConfig):
 def _gen_one(model, tok, text, cfg):
     ids = tok(text, return_tensors="pt").to(model.device)
     gen = model.generate(**ids, max_new_tokens=cfg.gen_max_new_tokens, do_sample=True,
-                         temperature=1.0, top_p=0.95, pad_token_id=tok.pad_token_id)
+                         temperature=1.0, top_p=0.95,
+                         repetition_penalty=cfg.repetition_penalty,
+                         no_repeat_ngram_size=cfg.no_repeat_ngram_size,
+                         pad_token_id=tok.pad_token_id)
     return tok.decode(gen[0, ids.input_ids.shape[1]:], skip_special_tokens=True)
