diff --git a/src/steer_heal/config.py b/src/steer_heal/config.py index af0d6dd..8e57e66 100644 --- a/src/steer_heal/config.py +++ b/src/steer_heal/config.py @@ -30,7 +30,10 @@ class RunConfig: gen_system: str = "You are a helpful assistant." steer_layers: tuple[float, float] = (0.45, 0.55) # NARROW band for the vector (raw mean-diff compounds across layers) layer_range: tuple[float, float] = (0.0, 1.0) # BROAD band for the LoRA (train trait into many layers) - alphas: tuple[float, ...] = (0.25, 0.5, 1.0, 2.0) # raw-vector multiples to sweep; filter picks usable C + # raw-vector multiples to sweep; the filter harvests coherent survivors. Dropped 0.25 + # (filter audit: base-like, no distinct trait); kept 0.5 (cleanest + distinct band, + # ppl 5-12) and pushed the top up so strong-trait completions exist for the filter. + alphas: tuple[float, ...] = (0.5, 0.75, 1.0, 1.5) n_extract_pairs: int = 256 # contrastive pairs for the vector (steering-lite uses 256 DIVERSE suffixes, not domain dilemmas) extract_data: str = "data/branching_suffixes.json" # diverse contexts for extraction (550 suffixes, 10 categories) diff --git a/src/steer_heal/heal.py b/src/steer_heal/heal.py index 357b3dc..e7dc15f 100644 --- a/src/steer_heal/heal.py +++ b/src/steer_heal/heal.py @@ -49,6 +49,7 @@ def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg: logger.info(" step nll↓ kl loss↓ gnorm") pbar = tqdm(total=n_steps, desc=f"heal[{cfg.reg}]", mininterval=120, maxinterval=120) step = 0 + nlls = [] # per-step SFT loss; final = mean of last 5, the heal-stage number for the round table for ep in range(cfg.epochs): for c in kept: ids, mask = _encode(tok, c["prompt"], c["completion"], cfg.max_len, model.device) @@ -75,6 +76,7 @@ def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg: else: div = torch.zeros((), device=model.device) # nll, wd loss = sft + cfg.lam * torch.relu(div - cfg.tau) + nlls.append(sft.item()) loss.backward() gnorm = torch.nn.utils.clip_grad_norm_(params, 1.0) opt.step() @@ -88,4 +90,6 @@ def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg: pbar.close() spec = AdapterSpec.from_lora(lora, default_c=1.0) # CPU-resident, for the next round's history - return lora, spec + last = nlls[-5:] + heal_nll = sum(last) / len(last) if last else float("nan") # converged SFT loss (last-5 mean) + return lora, spec, heal_nll diff --git a/src/steer_heal/run.py b/src/steer_heal/run.py index 41cd38d..56491db 100644 --- a/src/steer_heal/run.py +++ b/src/steer_heal/run.py @@ -68,6 +68,11 @@ def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict: hist_specs = [] # AdapterSpec per folded round (gated bake history) v0_flat = None # round-0 direction, for the Q3 cosine rounds = [] + # Base (no adapter, no steering) eval ONCE, so the run is self-contained: the + # headline cue is coh_cost = |dCoh|/|dAuth| vs base (coherence lost per nat of + # trait), not just coherence. One extra eval per run. + logger.info(f"\n=== EVAL base [tinymfv classic] gpu {gpu_mem()} ===") + base_m = evaluate_model(model, tok, cfg) for rnd in range(cfg.n_rounds): logger.info(f"\n\n=== ROUND {rnd} [{cfg.model.split('/')[-1]} reg={cfg.reg}] gpu {gpu_mem()} ===") # extract teacher vector + sweep-generate steered data from the CURRENT student @@ -81,7 +86,7 @@ def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict: # heal one round on top of the baked history, then fold logger.info(f"\n=== HEAL [{cfg.reg}] gpu {gpu_mem()} ===") - lora, spec = heal_round(model, tok, kept, hist_specs, cfg) + lora, spec, heal_nll = heal_round(model, tok, kept, hist_specs, cfg) lora.save(str(run_dir / "ckpt" / f"r{rnd}.safetensors"), extra_meta={"round": str(rnd), "reg": cfg.reg}) hist_specs.append(spec) @@ -105,30 +110,37 @@ def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict: v0_flat = vf if v0_flat is None else v0_flat cos_v0 = float(cosine_similarity(vf, v0_flat, dim=0)) rec = {"round": rnd, **m, "cos_v0": cos_v0, "steered_ppl": steered_ppl, - "adapter_ppl": adapter_ppl, "n_kept": len(kept)} + "adapter_ppl": adapter_ppl, "n_comps": len(comps), "n_kept": len(kept), + "heal_nll": heal_nll} rounds.append(rec) log_event(run_dir, stage="round", **rec) logger.info(f"round {rnd}: auth_nats↓={m['auth_nats']:+.2f} care_nats={m['care_nats']:+.2f} " f"coh→={m['coherence']:.3f} cos_v0={cos_v0:+.2f} adapter_ppl={adapter_ppl:.0f}") - _log_loop_summary(rounds) + _log_loop_summary(rounds, base_m) write_map(run_dir, rounds) return rounds[-1] -def _log_loop_summary(rounds: list[dict]) -> None: +def _log_loop_summary(rounds: list[dict], base_m: dict) -> None: from tabulate import tabulate - # (rec_key, display header with direction arrow) -- single source of truth. - cols = [("round", "round"), ("auth_nats", "auth_nats↓"), ("care_nats", "care_nats"), - ("coherence", "coherence→"), ("cos_v0", "cos_v0→"), - ("adapter_ppl", "adapter_ppl↓"), ("n_kept", "n_kept")] + # One row per round, columns walk the pipeline stages left->right: + # GEN -> FILTER -> HEAL -> EVAL. (rec_key, display header) is the single source. + cols = [("round", "round"), + ("n_comps", "gen"), ("n_kept", "filt_kept"), # GEN -> FILTER + ("heal_nll", "heal_nll↓"), ("adapter_ppl", "adapter_ppl↓"), # HEAL + ("auth_nats", "auth_nats↓"), ("care_nats", "care_nats"), # EVAL: target / off-target + ("coherence", "coherence→"), ("cos_v0", "cos_v0→")] logger.info( - "\nloop columns:\n" - " auth_nats↓ = Authority logp on Authority vignettes, NATS (TARGET: down = less deference)\n" - " care_nats = Care logp, NATS (off-target axis -- should move LESS than auth if surgical)\n" - " coherence→ = p_any_ans = mean_pmass_allowed (OFF-TARGET: hold ~1.0)\n" - " cos_v0→ = cosine of round vector vs round-0 vector (direction stability)\n" - " adapter_ppl↓ = ppl-under-original of the no-steering adapter generations" + "\nloop columns (pipeline stages L->R: GEN | FILTER | HEAL | EVAL):\n" + " gen = steered completions generated (n_prompts x alphas)\n" + " filt_kept = completions surviving the coherence/rep/persona filter (-> training set)\n" + " heal_nll↓ = converged SFT loss of the heal (last-5 mean)\n" + " adapter_ppl↓ = ppl-under-original of the no-steering adapter gens (low = coherent/healed)\n" + " auth_nats↓ = log(profile p[Authority]), NATS (TARGET: down = less deference)\n" + " care_nats = log(profile p[Care]), NATS (off-target: should move LESS than auth if surgical)\n" + " coherence→ = p_any_ans = mean_pmass_allowed (OFF-TARGET: hold ~1.0)\n" + " cos_v0→ = cosine(round vector, round-0 vector) (direction stability)" ) logger.info( "\nSHOULD (Q2 loop-coherent): coherence stays >= round-0 floor across rounds (heal holds it up). " @@ -137,19 +149,31 @@ def _log_loop_summary(rounds: list[dict]) -> None: "stays > 0.5. If care_nats falls as much as auth_nats, it's broad permissivizing not surgical." ) tbl = [{disp: r.get(key) for key, disp in cols} for r in rounds] - logger.info("\nloop summary:\n" + tabulate(tbl, headers="keys", tablefmt="github", floatfmt=".3f") + "\n") + logger.info("\nloop summary (one row per round, stages L->R):\n" + + tabulate(tbl, headers="keys", tablefmt="github", floatfmt=".3f") + "\n") - # BLUF: single headline with cue ball (token-efficient-logging). This run controls - # COHERENCE of the healed adapter (trait RETENTION vs base needs the paired - # diag_stages, since the loop never evals base/steered). Cue = coherence band. + # BLUF: single headline with cue ball (token-efficient-logging). Headline number = + # coh_cost = |dCoh|/|dAuth| vs base (coherence lost per nat of trait gained). The + # WIN is a real trait shift (dAuth down) at low coherence cost. coh_cost is only + # meaningful when the trait actually moved, so gate on |dAuth| first. last = rounds[-1] - coh = last["coherence"] - cue = "🟢" if coh >= 0.95 else "🟡" if coh >= 0.85 else "🔴" + dAuth = last["auth_nats"] - base_m["auth_nats"] + dCoh = last["coherence"] - base_m["coherence"] + coh_cost = abs(dCoh) / abs(dAuth) if abs(dAuth) > 1e-6 else float("nan") + # TODO(threshold): coh_cost cut not yet calibrated. Provisional: a healed adapter + # SHOULD land trait (dAuth <= -0.3 nats) at coh_cost <= 0.05 (steered c=0.5 ~0.003). + if dAuth > -0.3: + cue = "🔴" # no trait retained (undo) + elif coh_cost <= 0.05: + cue = "🟢" # trait retained cheaply + else: + cue = "🟡" # trait retained but coherence-expensive logger.info( - f"main metric: {cue} coherence={coh:.2f} (healed if ~1.0) | auth_nats={last['auth_nats']:+.2f} " - f"care_nats={last['care_nats']:+.2f} adapter_ppl={last['adapter_ppl']:.1f}\n" - " cue=coherence band (🟢>=.95 🟡>=.85 🔴<.85). For the trait verdict (auth_nats moved " - "vs base AND coh held) run scripts/diag_stages.py all -> retain, coh_cost." + f"main metric: {cue} coh_cost={coh_cost:.3f} (|dCoh|/|dAuth| vs base, lower=better) | " + f"dAuth={dAuth:+.2f} nats (trait, want <0) coherence={last['coherence']:.2f} " + f"(base {base_m['coherence']:.2f})\n" + " cue: 🔴 dAuth>-0.3 (no trait) | 🟢 trait at coh_cost<=0.05 | 🟡 trait but expensive. " + "TODO calibrate coh_cost threshold (steered c=0.5 ref ~0.003)." )