diff --git a/docs/love_loop.png b/docs/love_loop.png index f924a40..6f016a5 100644 Binary files a/docs/love_loop.png and b/docs/love_loop.png differ diff --git a/src/steer_heal/plot.py b/src/steer_heal/plot.py index 070ff1e..2916110 100644 --- a/src/steer_heal/plot.py +++ b/src/steer_heal/plot.py @@ -69,17 +69,21 @@ def _connectors(fig, row, col, axis, base_xy, steered_xys, healed_xys): _tip(fig, trend[-2], trend[-1], axis, TREND, 1.5) -def write_trajectory(run_dir: Path, stages: list[dict]) -> Path: +def write_trajectory(run_dir: Path, stages: list[dict], primary_key: str = "care_nats") -> Path: """stages: ordered list of {round, stage in {base,steered,healed}, m: eval-dict}. - The eval-dict carries auth_nats, care_nats, coherence.""" - auth = [s["m"]["auth_nats"] for s in stages] - coh = [s["m"]["coherence"] for s in stages] - care = [s["m"]["care_nats"] for s in stages] + primary_key: the on-target eval axis for Panel A (e.g. care_nats for love, auth_nats for authority). + Panel C shows primary on x vs the biggest-moving off-target foundation on y.""" + # Build signals from all *_nats keys present in the eval dict + coherence + m0 = stages[0]["m"] + nat_keys = sorted(k for k in m0 if k.endswith("_nats")) + signals = {k: [s["m"][k] for s in stages] for k in nat_keys} + signals["coh"] = [s["m"]["coherence"] for s in stages] + coh = signals["coh"] + primary = primary_key # full key e.g. "care_nats" + kind = [s["stage"] for s in stages] - # x of the zigzag = pipeline order; label each tick base / r0·steer / r0·heal / ... xi = list(range(len(stages))) xlab = ["base" if k == "base" else f"r{s['round']}·{k[:5]}" for s, k in zip(stages, kind)] - col = [GREY if k == "base" else RED if k == "steered" else GREEN for k in kind] fig = make_subplots( rows=2, cols=2, column_widths=[0.52, 0.48], row_heights=[0.5, 0.5], @@ -88,10 +92,6 @@ def write_trajectory(run_dir: Path, stages: list[dict]) -> Path: [{"type": "scatter"}, None]], ) - # all 3 panels share ONE visual language (_connectors): dotted grey steer->heal moves - # + a thin green-grey trend through base->heals, both BEHIND the markers. Left panels use - # pipeline-order x; the map uses auth-x. idx groups stage rows so each panel can pull its - # own (x,y) for base / steered / healed in the same call. bi = kind.index("base") si = [i for i, k in enumerate(kind) if k == "steered"] hi = [i for i, k in enumerate(kind) if k == "healed"] @@ -100,29 +100,21 @@ def write_trajectory(run_dir: Path, stages: list[dict]) -> Path: # coh=1 (incoherence 0.001-0.05); a collapse round (coh~0.6 -> incoherence ~0.4) is a single # outlier that on a linear coherence axis flattens every healthy round into one band. log(1-coh) # gives each near-perfect round its own decade and squashes the outlier. Clamp incoherence at - # 1e-3 (coh>=0.999) to dodge log(0). Both stacked panels now read DOWN = wanted (auth down = - # trait, incoherence down = coherent). + # 1e-3 (coh>=0.999) to dodge log(0). Both stacked panels now read DOWN = wanted. inc = [max(1.0 - c, 1e-3) for c in coh] - signals = {"auth": auth, "care": care, "coh": coh} map_ids = [bi] + hi rng = lambda k: max(signals[k][i] for i in map_ids) - min(signals[k][i] for i in map_ids) - # Panel A tracks whichever trait moves most over base+heal (coh excluded; Panel B has it) - top_key = max(["auth", "care"], key=rng) - # PANEL A (top trait over pipeline, linear) and PANEL B (incoherence, log): x = pipeline index. - # Both keep red steer (A is the zigzag, B's red dots show the incoherence steering injects). - # hover shows the raw value (coh for B, trait for A); only B's y-axis is logged. - # x-tick labels only at key positions (base, first/last heal) to avoid dense overlap + # PANEL A: on-target axis over the full pipeline (shows the trait being steered) key_xi = [xi[bi]] + ([xi[si[0]]] if si else []) + [xi[hi[0]]] + ([xi[hi[-1]]] if len(hi) > 1 else []) key_xlab = [xlab[bi]] + ([xlab[si[0]]] if si else []) + [xlab[hi[0]]] + ([xlab[hi[-1]]] if len(hi) > 1 else []) for axis, row, yv, raw, ytitle, ylog in [ - (1, 1, signals[top_key], signals[top_key], f"{top_key}_nats", False), + (1, 1, signals[primary], signals[primary], primary, False), (3, 2, inc, coh, "1−coherence (↓, log)", True), ]: _connectors(fig, row, 1, axis, (xi[bi], yv[bi]), [(xi[i], yv[i]) for i in si], [(xi[i], yv[i]) for i in hi]) - # steered points recede (smaller, lower opacity) — the heal trajectory is the story for ids, c, sym, sz, op in [([bi], GREY, "star", 13, 1.0), (si, RED, "circle", 8, 0.6), (hi, GREEN, "circle", 10, 1.0)]: fig.add_trace(go.Scatter( x=[xi[i] for i in ids], y=[yv[i] for i in ids], mode="markers", @@ -134,37 +126,35 @@ def write_trajectory(run_dir: Path, stages: list[dict]) -> Path: fig.update_xaxes(tickmode="array", tickvals=key_xi, ticktext=key_xlab, tickangle=-30, row=2, col=1, showgrid=False) fig.update_xaxes(showgrid=False, tickvals=[], row=1, col=1) - # PANEL C (trait map): axes = the two biggest-MOVING of auth/care/coh over base+heal nodes. - # Healthy -> auth vs care (the moral-foundations plane); if coherence CRASHED its range beats - # care and it becomes the y-axis. RED steer is omitted here: zoomed to the heal cluster the - # steer points fall off-scale and leave dangling connector stubs. base + green heals only. - atitle = {"auth": "auth_nats (← more trait)", "care": "care_nats (more care →)"} - xkey, ykey = sorted(sorted(["auth", "care", "coh"], key=rng, reverse=True)[:2], - key=["auth", "care", "coh"].index) # x = higher-priority of the chosen two - # coh can only ever be the LOWEST-priority pick, so it lands on Y, never X. When it does - # (a crash run) plot it as log-incoherence to match panel B; else raw care/auth. + # PANEL C (trait map): x = primary (on-target), y = biggest-moving off-target foundation. + # If coherence crashed its range beats all foundations and it takes y (crash diagnostic). + # RED steer omitted: steered points fall off-scale and leave dangling connector stubs. + off_target_keys = [k for k in nat_keys if k != primary] + ykey = max(off_target_keys, key=rng) + if rng("coh") > rng(ykey): # coherence crash dominates; show as log-incoherence + ykey = "coh" ycoh = ykey == "coh" - xv = signals[xkey] + xv = signals[primary] yv = [max(1.0 - v, 1e-3) for v in signals[ykey]] if ycoh else signals[ykey] - yraw = signals[ykey] # for hover (real coherence / care value, not the log-incoherence coord) + yraw = signals[ykey] _connectors(fig, 1, 2, 2, (xv[bi], yv[bi]), [], [(xv[i], yv[i]) for i in hi]) fig.add_trace(go.Scatter( x=[xv[bi]], y=[yv[bi]], mode="markers+text", text=["base"], textposition="bottom center", marker=dict(size=14, color=GREY, symbol="star"), showlegend=False, - hovertext=[f"base {xkey}={xv[bi]:.3f} {ykey}={yraw[bi]:.3f}"], hoverinfo="text"), row=1, col=2) + hovertext=[f"base {primary}={xv[bi]:.3f} {ykey}={yraw[bi]:.3f}"], hoverinfo="text"), row=1, col=2) txt = [f"r{stages[i]['round']}" if stages[i]["round"] in (0, last_rnd) else "" for i in hi] - hov = [f"heal r{stages[i]['round']} auth={auth[i]:.3f} care={care[i]:.3f} coh={coh[i]:.3f}" for i in hi] + hov = [f"heal r{stages[i]['round']} " + " ".join(f"{k}={signals[k][i]:.3f}" for k in nat_keys) + f" coh={coh[i]:.3f}" for i in hi] fig.add_trace(go.Scatter( x=[xv[i] for i in hi], y=[yv[i] for i in hi], mode="markers+text", text=txt, textposition="bottom center", marker=dict(size=9, color=GREEN), showlegend=False, hovertext=hov, hoverinfo="text"), row=1, col=2) - fig.update_xaxes(title_text=atitle[xkey], row=1, col=2) + fig.update_xaxes(title_text=primary, row=1, col=2) if ycoh: fig.update_yaxes(title_text="incoherence 1−coh (↓ coherent, log)", type="log", row=1, col=2) - fig.add_hline(y=0.05, line=dict(color="#cccccc", width=1, dash="dot"), row=1, col=2) # coh=0.95 + fig.add_hline(y=0.05, line=dict(color="#cccccc", width=1, dash="dot"), row=1, col=2) else: - fig.update_yaxes(title_text=atitle[ykey], row=1, col=2) + fig.update_yaxes(title_text=ykey, row=1, col=2) fig.update_xaxes(showgrid=False, row=1, col=2) fig.update_yaxes(showgrid=False, row=1, col=2) diff --git a/src/steer_heal/run.py b/src/steer_heal/run.py index 3d73d9e..84ac1ff 100644 --- a/src/steer_heal/run.py +++ b/src/steer_heal/run.py @@ -104,7 +104,7 @@ def _log_stage_table(stages: list[dict], base_m: dict) -> None: + tabulate([_stage_row(s, base_m) for s in stages], headers="keys", tablefmt="github", floatfmt=".3f") + "\n") -def gen_filter_walk(model, tok, v, cfg: RunConfig, hist_specs: list) -> tuple[list[dict], list[dict], float, int]: +def gen_filter_walk(model, tok, v, cfg: RunConfig, hist_specs: list, rnd: int) -> tuple[list[dict], list[dict], float, int]: """Adaptive-dose gen+filter (the controller steering.py:65 was written for). Walk the dose multiplier kappa DOWN until a batch clears cfg.gen_pass_target filter @@ -133,8 +133,8 @@ def gen_filter_walk(model, tok, v, cfg: RunConfig, hist_specs: list) -> tuple[li for step in range(cfg.gen_max_batches): mid_k = math.exp(0.5 * (math.log(lo_k) + math.log(hi_k))) with baked(model, hist_specs): - probe = generate_steered(model, tok, v, cfg, alpha_scale=mid_k, max_gens=cfg.gen_probe_n) - _, probe_scored = filter_completions(model, tok, probe, cfg) + probe = generate_steered(model, tok, v, cfg, alpha_scale=mid_k, max_gens=cfg.gen_probe_n, rnd=rnd) + _, probe_scored = filter_completions(model, tok, probe, cfg, brief=True) probe_pass = [s for s in probe_scored if s["keep"]] rate = len(probe_pass) / max(len(probe_scored), 1) n_gen += len(probe) @@ -172,8 +172,8 @@ def gen_filter_walk(model, tok, v, cfg: RunConfig, hist_specs: list) -> tuple[li if len(kept_all) >= cfg.n_keep: break with baked(model, hist_specs): - comps = generate_steered(model, tok, v, cfg, alpha_scale=kappa) - _, scored = filter_completions(model, tok, comps, cfg) + comps = generate_steered(model, tok, v, cfg, alpha_scale=kappa, rnd=rnd) + _, scored = filter_completions(model, tok, comps, cfg, brief=True) passing = [s for s in scored if s["keep"]] kept_all.extend(passing) scored_all.extend(scored) @@ -183,6 +183,16 @@ def gen_filter_walk(model, tok, v, cfg: RunConfig, hist_specs: list) -> tuple[li f"(rate={len(passing)/max(len(comps),1):.2f}) → banked {len(kept_all)}/{cfg.n_keep}" ) + # ONE raw sample per calibration (the probes/collect ran brief, count-only): the cleanest + # kept completion at the settled dose IS the training data -- judge coherence+trait by eye. + # Full per-alpha table + borderline samples live in the saved scored (log_event stage=gen). + kept_final = [s for s in scored_all if s["keep"]] + if kept_final: + best = min(kept_final, key=lambda s: s["ppl"]) + logger.info( + f"\n\n\n=== r{rnd} walk-C SAMPLE (cleanest kept: alpha={best['alpha']:g} ppl={best['ppl']:.0f}) ===\n" + "SHOULD: coherent + on-trait (this is what trains the adapter). ELSE dose/filter off.\n" + f"{best['completion']}") return kept_all[: cfg.n_keep], scored_all, kappa, n_gen # cap training set at n_keep (top-up may overshoot) @@ -195,7 +205,7 @@ def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict: # Base (no adapter, no steering) eval ONCE, so the run is self-contained: the # headline cue is coh_cost = |dCoh|/|dAuth| vs base (coherence lost per nat of # trait), not just coherence. One extra eval per run. - logger.info(f"\n=== EVAL base [tinymfv classic] gpu {gpu_mem()} ===") + logger.info(f"\n\n\n=== EVAL base [tinymfv classic] gpu {gpu_mem()} ===") base_m = evaluate_model(model, tok, cfg, log_sample=True) # one FULL eval gen (token-efficient-logging) log_event(run_dir, stage="base", round=-1, **base_m) # persist so offline plot_run.py is self-contained stages = [{"round": "-", "stage": "base", "m": base_m}] # base -> steered -> healed, for table + trajectory plot @@ -215,13 +225,13 @@ def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict: "before the loop melts. demo=authority: defers to authority. ELSE chat-template/formatting issue.\n" f"PROMPT: {b0['prompt']}\nCOMPLETION: {b0['completion']}") for rnd in range(cfg.n_rounds): - logger.info(f"\n\n=== ROUND {rnd} [{cfg.model.split('/')[-1]} reg={cfg.reg}] gpu {gpu_mem()} ===") + logger.info(f"\n\n\n\n=== ROUND {rnd} [{cfg.model.split('/')[-1]} reg={cfg.reg}] gpu {gpu_mem()} ===") # extract teacher vector from the CURRENT student, then walk-C generate+filter: # the controller cools the dose so the steered data stays coherent as the adapter # accumulates trait over rounds (gen baked, filter under original -- see gen_filter_walk). with baked(model, hist_specs): v = teacher_vec(model, tok, cfg) - kept, scored, kappa, n_comps = gen_filter_walk(model, tok, v, cfg, hist_specs) + kept, scored, kappa, n_comps = gen_filter_walk(model, tok, v, cfg, hist_specs, rnd) # collect highest-alpha dropped sample for headline prompt -> diary Night entry headline = gen_rounds[0]["gens"][0]["user"] dream_cands = [s for s in scored if s["user"] == headline and not s.get("keep", True)] @@ -233,7 +243,7 @@ def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict: # STEERED-stage eval at the dose the data ACTUALLY came from (kappa-scaled cleanest alpha), # history baked, NO new adapter: the raw-steering pareto reference the heal must BEAT. c_lo = kappa * cfg.alphas[0] - logger.info(f"\n=== EVAL steered [c={c_lo:.2f} kappa={kappa:.2f}] gpu {gpu_mem()} ===") + logger.info(f"\n\n\n=== r{rnd} EVAL steered [c={c_lo:.2f} kappa={kappa:.2f}] gpu {gpu_mem()} ===") with baked(model, hist_specs): with v(model, C=c_lo * v.cfg.coeff): m_steer = evaluate_model(model, tok, cfg) @@ -241,13 +251,13 @@ def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict: log_event(run_dir, stage="gen", round=rnd, n_comps=n_comps, n_kept=len(kept), kappa=kappa, scored=scored) # heal one round on top of the baked history, then fold - logger.info(f"\n=== HEAL [{cfg.reg}] gpu {gpu_mem()} ===") + logger.info(f"\n\n\n=== r{rnd} HEAL [{cfg.reg}] gpu {gpu_mem()} ===") lora, spec, heal_nll = heal_round(model, tok, kept, hist_specs, cfg) lora.save(str(run_dir / "ckpt" / f"r{rnd}.safetensors"), extra_meta={"round": str(rnd), "reg": cfg.reg}) hist_specs.append(spec) # eval the student (all rounds baked) + Q1: trained-adapter output coherence - logger.info(f"\n=== EVAL [tinymfv classic] gpu {gpu_mem()} ===") + logger.info(f"\n\n\n=== r{rnd} EVAL [tinymfv classic] gpu {gpu_mem()} ===") with baked(model, hist_specs): m = evaluate_model(model, tok, cfg) adapter = generate_plain(model, tok, cfg, n=min(6, cfg.n_prompts)) @@ -278,7 +288,7 @@ def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict: # model is lukewarm/guarded about); if no trait at all = no-op. demo_lines = "\n".join( f" [{a['user'][:50]}]\n {' '.join(a['completion'].split())[:240]}" for a in adapter) - logger.info(f"\n=== ADAPTER DEMO r{rnd} coh(p_ans_any)={m['coherence']:.3f} adapter_ppl={adapter_ppl:.0f} " + logger.info(f"\n\n\n=== ADAPTER DEMO r{rnd} coh(p_ans_any)={m['coherence']:.3f} adapter_ppl={adapter_ppl:.0f} " f"(no steering; compare across rounds: change vs saturation) ===\n" + demo_lines) vf = _flatten_v(v) @@ -305,7 +315,8 @@ def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict: _log_loop_summary(rounds, base_m) _log_stage_table(stages, base_m) write_map(run_dir, rounds) - png = write_trajectory(run_dir, stages) # before the report (report embeds trajectory.png) + primary_key = "care_nats" if cfg.demo == "love" else "auth_nats" + png = write_trajectory(run_dir, stages, primary_key=primary_key) report_html = write_report(run_dir, gen_rounds) diary = write_diary(run_dir, cfg, gen_rounds, steer_samples, rounds, base_m["care_nats"]) logger.info(f"diary: {diary}")