mirror of
https://github.com/wassname/steer-heal-love.git
synced 2026-06-27 17:17:35 +08:00
in-run base eval + coh_cost cue; per-round stage table; heal_nll; alpha shift
- run.py: evaluate the base model once at the start of the run; the headline cue is now coh_cost = |dCoh|/|dAuth| vs base (coherence lost per nat of trait), gated on dAuth <= -0.3 (no trait -> red). The coh_cost threshold is still a TODO (steered c=0.5 reference ~0.003).
- run.py: the loop summary is now one row per round, walking the pipeline stages L->R: gen | filt_kept | heal_nll | adapter_ppl | auth_nats | care_nats | coherence | cos_v0.
- heal.py: heal_round now also returns the converged nll (last-5 mean) for the stage table.
- config: alphas (0.25,0.5,1.0,2.0) -> (0.5,0.75,1.0,1.5). The filter audit showed 0.25 is base-like (no distinct trait); 0.5 is the clean+distinct band. Shift the sweep upward (floor 0.25 -> 0.5, adding 0.75 and 1.5; note the ceiling drops from 2.0 to 1.5) so strong-trait completions exist for the filter to harvest.

Gate-3 finding (task76, corrected log-profile metric): heal retains partial trait coherently (nll 0.35, klrev 0.20 of the c=0.5 shift, coh ~1.0) but does NOT beat steering's pareto (coh_cost: steered c=0.5 0.003 < nll 0.008 < klrev 0.015). The barrier suppresses trait (klrev < nll) while coherence has headroom -> the next step is LESS barrier + stronger data, not more.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
@@ -30,7 +30,10 @@ class RunConfig:
|
||||
gen_system: str = "You are a helpful assistant."
|
||||
steer_layers: tuple[float, float] = (0.45, 0.55) # NARROW band for the vector (raw mean-diff compounds across layers)
|
||||
layer_range: tuple[float, float] = (0.0, 1.0) # BROAD band for the LoRA (train trait into many layers)
|
||||
alphas: tuple[float, ...] = (0.25, 0.5, 1.0, 2.0) # raw-vector multiples to sweep; filter picks usable C
|
||||
# raw-vector multiples to sweep; the filter harvests coherent survivors. Dropped 0.25
|
||||
# (filter audit: base-like, no distinct trait); kept 0.5 (cleanest + distinct band,
|
||||
# ppl 5-12) and shifted the sweep upward (floor 0.25 -> 0.5; ceiling is now 1.5) so strong-trait completions exist for the filter.
|
||||
alphas: tuple[float, ...] = (0.5, 0.75, 1.0, 1.5)
|
||||
n_extract_pairs: int = 256 # contrastive pairs for the vector (steering-lite uses 256 DIVERSE suffixes, not domain dilemmas)
|
||||
extract_data: str = "data/branching_suffixes.json" # diverse contexts for extraction (550 suffixes, 10 categories)
|
||||
|
||||
|
||||
@@ -49,6 +49,7 @@ def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg:
|
||||
logger.info(" step nll↓ kl loss↓ gnorm")
|
||||
pbar = tqdm(total=n_steps, desc=f"heal[{cfg.reg}]", mininterval=120, maxinterval=120)
|
||||
step = 0
|
||||
nlls = [] # per-step SFT loss; final = mean of last 5, the heal-stage number for the round table
|
||||
for ep in range(cfg.epochs):
|
||||
for c in kept:
|
||||
ids, mask = _encode(tok, c["prompt"], c["completion"], cfg.max_len, model.device)
|
||||
@@ -75,6 +76,7 @@ def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg:
|
||||
else:
|
||||
div = torch.zeros((), device=model.device) # nll, wd
|
||||
loss = sft + cfg.lam * torch.relu(div - cfg.tau)
|
||||
nlls.append(sft.item())
|
||||
loss.backward()
|
||||
gnorm = torch.nn.utils.clip_grad_norm_(params, 1.0)
|
||||
opt.step()
|
||||
@@ -88,4 +90,6 @@ def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg:
|
||||
pbar.close()
|
||||
|
||||
spec = AdapterSpec.from_lora(lora, default_c=1.0) # CPU-resident, for the next round's history
|
||||
return lora, spec
|
||||
last = nlls[-5:]
|
||||
heal_nll = sum(last) / len(last) if last else float("nan") # converged SFT loss (last-5 mean)
|
||||
return lora, spec, heal_nll
|
||||
|
||||
+48
-24
@@ -68,6 +68,11 @@ def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict:
|
||||
hist_specs = [] # AdapterSpec per folded round (gated bake history)
|
||||
v0_flat = None # round-0 direction, for the Q3 cosine
|
||||
rounds = []
|
||||
# Base (no adapter, no steering) eval ONCE, so the run is self-contained: the
|
||||
# headline cue is coh_cost = |dCoh|/|dAuth| vs base (coherence lost per nat of
|
||||
# trait), not just coherence. One extra eval per run.
|
||||
logger.info(f"\n=== EVAL base [tinymfv classic] gpu {gpu_mem()} ===")
|
||||
base_m = evaluate_model(model, tok, cfg)
|
||||
for rnd in range(cfg.n_rounds):
|
||||
logger.info(f"\n\n=== ROUND {rnd} [{cfg.model.split('/')[-1]} reg={cfg.reg}] gpu {gpu_mem()} ===")
|
||||
# extract teacher vector + sweep-generate steered data from the CURRENT student
|
||||
@@ -81,7 +86,7 @@ def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict:
|
||||
|
||||
# heal one round on top of the baked history, then fold
|
||||
logger.info(f"\n=== HEAL [{cfg.reg}] gpu {gpu_mem()} ===")
|
||||
lora, spec = heal_round(model, tok, kept, hist_specs, cfg)
|
||||
lora, spec, heal_nll = heal_round(model, tok, kept, hist_specs, cfg)
|
||||
lora.save(str(run_dir / "ckpt" / f"r{rnd}.safetensors"), extra_meta={"round": str(rnd), "reg": cfg.reg})
|
||||
hist_specs.append(spec)
|
||||
|
||||
@@ -105,30 +110,37 @@ def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict:
|
||||
v0_flat = vf if v0_flat is None else v0_flat
|
||||
cos_v0 = float(cosine_similarity(vf, v0_flat, dim=0))
|
||||
rec = {"round": rnd, **m, "cos_v0": cos_v0, "steered_ppl": steered_ppl,
|
||||
"adapter_ppl": adapter_ppl, "n_kept": len(kept)}
|
||||
"adapter_ppl": adapter_ppl, "n_comps": len(comps), "n_kept": len(kept),
|
||||
"heal_nll": heal_nll}
|
||||
rounds.append(rec)
|
||||
log_event(run_dir, stage="round", **rec)
|
||||
logger.info(f"round {rnd}: auth_nats↓={m['auth_nats']:+.2f} care_nats={m['care_nats']:+.2f} "
|
||||
f"coh→={m['coherence']:.3f} cos_v0={cos_v0:+.2f} adapter_ppl={adapter_ppl:.0f}")
|
||||
|
||||
_log_loop_summary(rounds)
|
||||
_log_loop_summary(rounds, base_m)
|
||||
write_map(run_dir, rounds)
|
||||
return rounds[-1]
|
||||
|
||||
|
||||
def _log_loop_summary(rounds: list[dict]) -> None:
|
||||
def _log_loop_summary(rounds: list[dict], base_m: dict) -> None:
|
||||
from tabulate import tabulate
|
||||
# (rec_key, display header with direction arrow) -- single source of truth.
|
||||
cols = [("round", "round"), ("auth_nats", "auth_nats↓"), ("care_nats", "care_nats"),
|
||||
("coherence", "coherence→"), ("cos_v0", "cos_v0→"),
|
||||
("adapter_ppl", "adapter_ppl↓"), ("n_kept", "n_kept")]
|
||||
# One row per round, columns walk the pipeline stages left->right:
|
||||
# GEN -> FILTER -> HEAL -> EVAL. (rec_key, display header) is the single source.
|
||||
cols = [("round", "round"),
|
||||
("n_comps", "gen"), ("n_kept", "filt_kept"), # GEN -> FILTER
|
||||
("heal_nll", "heal_nll↓"), ("adapter_ppl", "adapter_ppl↓"), # HEAL
|
||||
("auth_nats", "auth_nats↓"), ("care_nats", "care_nats"), # EVAL: target / off-target
|
||||
("coherence", "coherence→"), ("cos_v0", "cos_v0→")]
|
||||
logger.info(
|
||||
"\nloop columns:\n"
|
||||
" auth_nats↓ = Authority logp on Authority vignettes, NATS (TARGET: down = less deference)\n"
|
||||
" care_nats = Care logp, NATS (off-target axis -- should move LESS than auth if surgical)\n"
|
||||
" coherence→ = p_any_ans = mean_pmass_allowed (OFF-TARGET: hold ~1.0)\n"
|
||||
" cos_v0→ = cosine of round vector vs round-0 vector (direction stability)\n"
|
||||
" adapter_ppl↓ = ppl-under-original of the no-steering adapter generations"
|
||||
"\nloop columns (pipeline stages L->R: GEN | FILTER | HEAL | EVAL):\n"
|
||||
" gen = steered completions generated (n_prompts x alphas)\n"
|
||||
" filt_kept = completions surviving the coherence/rep/persona filter (-> training set)\n"
|
||||
" heal_nll↓ = converged SFT loss of the heal (last-5 mean)\n"
|
||||
" adapter_ppl↓ = ppl-under-original of the no-steering adapter gens (low = coherent/healed)\n"
|
||||
" auth_nats↓ = log(profile p[Authority]), NATS (TARGET: down = less deference)\n"
|
||||
" care_nats = log(profile p[Care]), NATS (off-target: should move LESS than auth if surgical)\n"
|
||||
" coherence→ = p_any_ans = mean_pmass_allowed (OFF-TARGET: hold ~1.0)\n"
|
||||
" cos_v0→ = cosine(round vector, round-0 vector) (direction stability)"
|
||||
)
|
||||
logger.info(
|
||||
"\nSHOULD (Q2 loop-coherent): coherence stays >= round-0 floor across rounds (heal holds it up). "
|
||||
@@ -137,19 +149,31 @@ def _log_loop_summary(rounds: list[dict]) -> None:
|
||||
"stays > 0.5. If care_nats falls as much as auth_nats, it's broad permissivizing not surgical."
|
||||
)
|
||||
tbl = [{disp: r.get(key) for key, disp in cols} for r in rounds]
|
||||
logger.info("\nloop summary:\n" + tabulate(tbl, headers="keys", tablefmt="github", floatfmt=".3f") + "\n")
|
||||
logger.info("\nloop summary (one row per round, stages L->R):\n"
|
||||
+ tabulate(tbl, headers="keys", tablefmt="github", floatfmt=".3f") + "\n")
|
||||
|
||||
# BLUF: single headline with cue ball (token-efficient-logging). This run controls
|
||||
# COHERENCE of the healed adapter (trait RETENTION vs base needs the paired
|
||||
# diag_stages, since the loop never evals base/steered). Cue = coherence band.
|
||||
# BLUF: single headline with cue ball (token-efficient-logging). Headline number =
|
||||
# coh_cost = |dCoh|/|dAuth| vs base (coherence lost per nat of trait gained). The
|
||||
# WIN is a real trait shift (dAuth down) at low coherence cost. coh_cost is only
|
||||
# meaningful when the trait actually moved, so gate on |dAuth| first.
|
||||
last = rounds[-1]
|
||||
coh = last["coherence"]
|
||||
cue = "🟢" if coh >= 0.95 else "🟡" if coh >= 0.85 else "🔴"
|
||||
dAuth = last["auth_nats"] - base_m["auth_nats"]
|
||||
dCoh = last["coherence"] - base_m["coherence"]
|
||||
coh_cost = abs(dCoh) / abs(dAuth) if abs(dAuth) > 1e-6 else float("nan")
|
||||
# TODO(threshold): coh_cost cut not yet calibrated. Provisional: a healed adapter
|
||||
# SHOULD land trait (dAuth <= -0.3 nats) at coh_cost <= 0.05 (steered c=0.5 ~0.003).
|
||||
if dAuth > -0.3:
|
||||
cue = "🔴" # no trait retained (undo)
|
||||
elif coh_cost <= 0.05:
|
||||
cue = "🟢" # trait retained cheaply
|
||||
else:
|
||||
cue = "🟡" # trait retained but coherence-expensive
|
||||
logger.info(
|
||||
f"main metric: {cue} coherence={coh:.2f} (healed if ~1.0) | auth_nats={last['auth_nats']:+.2f} "
|
||||
f"care_nats={last['care_nats']:+.2f} adapter_ppl={last['adapter_ppl']:.1f}\n"
|
||||
" cue=coherence band (🟢>=.95 🟡>=.85 🔴<.85). For the trait verdict (auth_nats moved "
|
||||
"vs base AND coh held) run scripts/diag_stages.py <ckpt> all -> retain, coh_cost."
|
||||
f"main metric: {cue} coh_cost={coh_cost:.3f} (|dCoh|/|dAuth| vs base, lower=better) | "
|
||||
f"dAuth={dAuth:+.2f} nats (trait, want <0) coherence={last['coherence']:.2f} "
|
||||
f"(base {base_m['coherence']:.2f})\n"
|
||||
" cue: 🔴 dAuth>-0.3 (no trait) | 🟢 trait at coh_cost<=0.05 | 🟡 trait but expensive. "
|
||||
"TODO calibrate coh_cost threshold (steered c=0.5 ref ~0.003)."
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user