mirror of
https://github.com/wassname/steer-heal-love.git
synced 2026-06-27 17:02:34 +08:00
stage pareto table: base->steered->healed per round (dcoh/dauth, coh, auth, care)
Adds a steered-stage tinymfv eval per round (history baked, vector live at the operating dose = cleanest alpha, no new adapter) so the loop log shows the full base->steered->healed pareto, not just the healed endpoint. This is the apples-to-apples comparison: same baked base, trait via vector vs via the distilled adapter. dcoh/dauth = signed coherence change per nat of Authority change vs base. UAT: fast-dev-run exit 0 renders the 3-stage table. Cost: +1 eval per round. Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
@@ -71,6 +71,28 @@ def _mean_finite(xs, label: str = "ppl") -> float:
|
||||
return sum(xs) / len(xs) if xs else float("nan")
|
||||
|
||||
|
||||
def _stage_row(rnd, stage: str, m: dict, base_m: dict) -> dict:
|
||||
"""One row of the base->steered->healed pareto table. dcoh/dauth = coherence
|
||||
CHANGE per nat of Authority CHANGE vs base (signed): positive = coherence lost
|
||||
while trait gained (both fall), the cost we want low; nan for the base row (0/0)."""
|
||||
dAuth = m["auth_nats"] - base_m["auth_nats"]
|
||||
dCoh = m["coherence"] - base_m["coherence"]
|
||||
ratio = dCoh / dAuth if abs(dAuth) > 1e-6 else float("nan")
|
||||
return {"round": rnd, "stage": stage, "dcoh/dauth": ratio,
|
||||
"coh": m["coherence"], "auth": m["auth_nats"], "care": m["care_nats"]}
|
||||
|
||||
|
||||
def _log_stage_table(stage_rows: list[dict]) -> None:
|
||||
from tabulate import tabulate
|
||||
logger.info(
|
||||
"\nstage pareto (base -> steered -> healed, per round):\n"
|
||||
" dcoh/dauth = coherence change per nat of Authority change vs base (signed, lower=cheaper trait)\n"
|
||||
" coh = p_any_ans coherence (hold ~1.0) auth = log p[Authority] (DOWN = trait) care = log p[Care] (off-target)\n"
|
||||
" WIN: healed keeps steered's low auth (trait) but recovers coh toward base AND a smaller dcoh/dauth than steered.\n"
|
||||
" UNDO: healed auth springs back to ~base while coh recovers -> heal removed the trait, not just the incoherence.\n"
|
||||
+ tabulate(stage_rows, headers="keys", tablefmt="github", floatfmt=".3f") + "\n")
|
||||
|
||||
|
||||
def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict:
|
||||
hist_specs = [] # AdapterSpec per folded round (gated bake history)
|
||||
v0_flat = None # round-0 direction, for the Q3 cosine
|
||||
@@ -80,12 +102,21 @@ def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict:
|
||||
# trait), not just coherence. One extra eval per run.
|
||||
logger.info(f"\n=== EVAL base [tinymfv classic] gpu {gpu_mem()} ===")
|
||||
base_m = evaluate_model(model, tok, cfg)
|
||||
stage_rows = [_stage_row("-", "base", base_m, base_m)] # pareto table: base -> steered -> healed
|
||||
for rnd in range(cfg.n_rounds):
|
||||
logger.info(f"\n\n=== ROUND {rnd} [{cfg.model.split('/')[-1]} reg={cfg.reg}] gpu {gpu_mem()} ===")
|
||||
# extract teacher vector + sweep-generate steered data from the CURRENT student
|
||||
with baked(model, hist_specs):
|
||||
v = teacher_vec(model, tok, cfg)
|
||||
comps = generate_steered(model, tok, v, cfg)
|
||||
# STEERED-stage eval: the model state the training data came from (history baked,
|
||||
# vector live at the operating dose = lowest/cleanest alpha, NO new adapter). This
|
||||
# is the raw-steering pareto reference the heal must BEAT (same base, trait via
|
||||
# vector vs trait via the distilled adapter).
|
||||
c_op = cfg.alphas[0] * v.cfg.coeff
|
||||
logger.info(f"\n=== EVAL steered [c={cfg.alphas[0]}] gpu {gpu_mem()} ===")
|
||||
with v(model, C=c_op):
|
||||
m_steer = evaluate_model(model, tok, cfg)
|
||||
# filter under the ORIGINAL (no history, no steering) -- this picks the usable C
|
||||
logger.info(f"\n=== FILTER [{len(comps)} completions] gpu {gpu_mem()} ===")
|
||||
kept, scored = filter_completions(model, tok, comps, cfg)
|
||||
@@ -120,11 +151,14 @@ def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict:
|
||||
"adapter_ppl": adapter_ppl, "n_comps": len(comps), "n_kept": len(kept),
|
||||
"heal_nll": heal_nll}
|
||||
rounds.append(rec)
|
||||
stage_rows.append(_stage_row(rnd, "steered", m_steer, base_m))
|
||||
stage_rows.append(_stage_row(rnd, "healed", m, base_m))
|
||||
log_event(run_dir, stage="round", **rec)
|
||||
logger.info(f"round {rnd}: auth_nats↓={m['auth_nats']:+.2f} care_nats={m['care_nats']:+.2f} "
|
||||
f"coh→={m['coherence']:.3f} cos_v0={cos_v0:+.2f} adapter_ppl={adapter_ppl:.0f}")
|
||||
|
||||
_log_loop_summary(rounds, base_m)
|
||||
_log_stage_table(stage_rows)
|
||||
write_map(run_dir, rounds)
|
||||
return rounds[-1]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user