def _stage_row(rnd, stage: str, m: dict, base_m: dict) -> dict:
    """Build one record of the base -> steered -> healed stage table.

    Args:
        rnd: Round label stored verbatim under ``"round"`` (the caller passes
            ``"-"`` for the base row and the loop index otherwise).
        stage: Stage label stored verbatim under ``"stage"``.
        m: Metrics of the stage being reported; the keys ``"auth_nats"``,
            ``"coherence"`` and ``"care_nats"`` are read.
        base_m: Metrics of the reference (base) evaluation; the keys
            ``"auth_nats"`` and ``"coherence"`` are read.

    Returns:
        A dict with keys, in order: ``round``, ``stage``, ``dcoh/dauth``,
        ``coh``, ``auth``, ``care``.

    ``dcoh/dauth`` is (coherence - base coherence) / (auth - base auth). It is
    NaN when the Authority delta is within 1e-6 of zero, which covers the base
    row compared against itself. The ratio is signed: it is positive when both
    metrics move in the same direction (e.g. Authority falls and coherence
    falls), and negative when they move in opposite directions.
    """
    auth_now = m["auth_nats"]
    auth_delta = auth_now - base_m["auth_nats"]
    coh_now = m["coherence"]
    coh_delta = coh_now - base_m["coherence"]

    # Guard the division: a (near-)zero Authority change has no defined cost.
    if abs(auth_delta) > 1e-6:
        cost = coh_delta / auth_delta
    else:
        cost = float("nan")

    # Insertion order fixes the column order of the rendered table.
    row = {"round": rnd, "stage": stage}
    row["dcoh/dauth"] = cost
    row["coh"] = coh_now
    row["auth"] = auth_now
    row["care"] = m["care_nats"]
    return row


def _log_stage_table(stage_rows: list[dict]) -> None:
    """Log the stage table with a legend explaining how to read it.

    Args:
        stage_rows: Row dicts (as produced by ``_stage_row``); the dict keys
            become the table headers.

    The table is rendered by ``tabulate`` in ``github`` format with floats
    shown to three decimals, and emitted as a single ``logger.info`` message.
    """
    # Imported at call time, so the module itself imports without tabulate.
    from tabulate import tabulate

    legend = (
        "\nstage pareto (base -> steered -> healed, per round):\n"
        " dcoh/dauth = coherence change per nat of Authority change vs base (signed, lower=cheaper trait)\n"
        " coh = p_any_ans coherence (hold ~1.0) auth = log p[Authority] (DOWN = trait) care = log p[Care] (off-target)\n"
        " WIN: healed keeps steered's low auth (trait) but recovers coh toward base AND a smaller dcoh/dauth than steered.\n"
        " UNDO: healed auth springs back to ~base while coh recovers -> heal removed the trait, not just the incoherence.\n"
    )
    table = tabulate(stage_rows, headers="keys", tablefmt="github", floatfmt=".3f")
    logger.info(legend + table + "\n")
logger.info(f"\n=== EVAL base [tinymfv classic] gpu {gpu_mem()} ===") base_m = evaluate_model(model, tok, cfg) + stage_rows = [_stage_row("-", "base", base_m, base_m)] # pareto table: base -> steered -> healed for rnd in range(cfg.n_rounds): logger.info(f"\n\n=== ROUND {rnd} [{cfg.model.split('/')[-1]} reg={cfg.reg}] gpu {gpu_mem()} ===") # extract teacher vector + sweep-generate steered data from the CURRENT student with baked(model, hist_specs): v = teacher_vec(model, tok, cfg) comps = generate_steered(model, tok, v, cfg) + # STEERED-stage eval: the model state the training data came from (history baked, + # vector live at the operating dose = lowest/cleanest alpha, NO new adapter). This + # is the raw-steering pareto reference the heal must BEAT (same base, trait via + # vector vs trait via the distilled adapter). + c_op = cfg.alphas[0] * v.cfg.coeff + logger.info(f"\n=== EVAL steered [c={cfg.alphas[0]}] gpu {gpu_mem()} ===") + with v(model, C=c_op): + m_steer = evaluate_model(model, tok, cfg) # filter under the ORIGINAL (no history, no steering) -- this picks the usable C logger.info(f"\n=== FILTER [{len(comps)} completions] gpu {gpu_mem()} ===") kept, scored = filter_completions(model, tok, comps, cfg) @@ -120,11 +151,14 @@ def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict: "adapter_ppl": adapter_ppl, "n_comps": len(comps), "n_kept": len(kept), "heal_nll": heal_nll} rounds.append(rec) + stage_rows.append(_stage_row(rnd, "steered", m_steer, base_m)) + stage_rows.append(_stage_row(rnd, "healed", m, base_m)) log_event(run_dir, stage="round", **rec) logger.info(f"round {rnd}: auth_nats↓={m['auth_nats']:+.2f} care_nats={m['care_nats']:+.2f} " f"coh→={m['coherence']:.3f} cos_v0={cos_v0:+.2f} adapter_ppl={adapter_ppl:.0f}") _log_loop_summary(rounds, base_m) + _log_stage_table(stage_rows) write_map(run_dir, rounds) return rounds[-1]