diff --git a/scripts/plot_run.py b/scripts/plot_run.py new file mode 100644 index 0000000..b34750d --- /dev/null +++ b/scripts/plot_run.py @@ -0,0 +1,54 @@ +"""Re-render a run's trajectory plot from its events.jsonl, without re-running the loop. + +Why this exists: the loop imports plot.py at process start, so a plot-code fix never +reaches an already-running job. events.jsonl persists every stage (base/steered/healed), +so we can rebuild the `stages` list offline and call the CURRENT write_trajectory. + + uv run python scripts/plot_run.py --run-dir out/20260604T172126_gemma-3-4b-it_kl_rev_s42/ + +A run from before the `base` event was persisted has no base node; pass it explicitly +(tyro reads negatives as flags, so use `=`): + ... --base-auth=-2.35 --base-care=-1.30 --base-coh=0.996 +""" +import sys +from pathlib import Path + +import srsly +import tyro + +from steer_heal.plot import write_trajectory + + +def main(run_dir: Path, base_auth: float | None = None, + base_care: float | None = None, base_coh: float | None = None): + events = list(srsly.read_jsonl(run_dir / "events.jsonl")) + by_stage = lambda s: [e for e in events if e["stage"] == s] + + base_ev = by_stage("base") + if base_ev: + bm = base_ev[0] + base = {"auth_nats": bm["auth_nats"], "care_nats": bm["care_nats"], "coherence": bm["coherence"]} + else: + # older run: base wasn't persisted. require it on the CLI (fail fast, no silent default). + assert base_auth is not None, ( + f"{run_dir}/events.jsonl has no `base` event (ran before run.py persisted it); " + "pass --base-auth/--base-care/--base-coh from the run's base eval log line." + ) + base = {"auth_nats": base_auth, "care_nats": base_care, "coherence": base_coh} + + stages = [{"round": "-", "stage": "base", "m": base}] + steered = {e["round"]: e for e in by_stage("steered_eval")} + healed = {e["round"]: e for e in by_stage("round")} + for rnd in sorted(healed): # one steered + one healed per completed round, in order + for src, kind in [(steered, "steered"), (healed, "healed")]: + e = src[rnd] + stages.append({"round": rnd, "stage": kind, + "m": {"auth_nats": e["auth_nats"], "care_nats": e["care_nats"], + "coherence": e["coherence"]}}) + + png = write_trajectory(run_dir, stages) + print(f"re-rendered {png} from {len(stages)} stages ({len(healed)} rounds)", file=sys.stderr) + + +if __name__ == "__main__": + tyro.cli(main) diff --git a/src/steer_heal/plot.py b/src/steer_heal/plot.py index 998f70a..af56e02 100644 --- a/src/steer_heal/plot.py +++ b/src/steer_heal/plot.py @@ -86,17 +86,23 @@ def write_trajectory(run_dir: Path, stages: list[dict]) -> Path: marker=dict(size=14, color=GREY, symbol="star"), showlegend=False, hovertext=[f"base auth={bx:.3f} coh={by:.3f}"], hoverinfo="text", ), row=1, col=2) + # scatter, NOT a polyline: the left zigzag panel already carries round order, so a + # connecting line here would just duplicate it (and tangle at 10 rounds). The map's one + # job is WHERE the two populations land in trait-coherence space -- steered scatters left + # (more trait, more variance), healed clusters near base (the stall). Label only the + # extremes (r0 + last round) so the labels don't collide in the cluster. + last_rnd = max(p["round"] for p in stages if p["stage"] == "healed") for stage_kind, color, label in [("steered", RED, "steer"), ("healed", GREEN, "heal")]: pts = [s for s in stages if s["stage"] == stage_kind] - xs = [bx] + [p["m"]["auth_nats"] for p in pts] - ys = [by] + [p["m"]["coherence"] for p in pts] - txt = [""] + [f"r{p['round']}" for p in pts] - hov = [f"base"] + [f"{label} r{p['round']} auth={p['m']['auth_nats']:.3f} " - f"coh={p['m']['coherence']:.3f} care={p['m']['care_nats']:.3f}" for p in pts] + xs = [p["m"]["auth_nats"] for p in pts] + ys = [p["m"]["coherence"] for p in pts] + txt = [f"r{p['round']}" if p["round"] in (0, last_rnd) else "" for p in pts] + hov = [f"{label} r{p['round']} auth={p['m']['auth_nats']:.3f} " + f"coh={p['m']['coherence']:.3f} care={p['m']['care_nats']:.3f}" for p in pts] fig.add_trace(go.Scatter( - x=xs, y=ys, mode="lines+markers+text", text=txt, textposition="top center", - line=dict(color=color, width=2), marker=dict(size=11, color=color), - name=label, showlegend=False, hovertext=hov, hoverinfo="text", + x=xs, y=ys, mode="markers+text", text=txt, textposition="top center", + marker=dict(size=11, color=color), name=label, showlegend=False, + hovertext=hov, hoverinfo="text", ), row=1, col=2) fig.update_xaxes(title_text="auth_nats (← more trait)", row=1, col=2) # same fixed coherence range as the line panel: shows the points hug the ceiling (coherence diff --git a/src/steer_heal/run.py b/src/steer_heal/run.py index 5d01c16..1ad8308 100644 --- a/src/steer_heal/run.py +++ b/src/steer_heal/run.py @@ -105,6 +105,7 @@ def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict: # trait), not just coherence. One extra eval per run. logger.info(f"\n=== EVAL base [tinymfv classic] gpu {gpu_mem()} ===") base_m = evaluate_model(model, tok, cfg) + log_event(run_dir, stage="base", round=-1, **base_m) # persist so offline plot_run.py is self-contained stages = [{"round": "-", "stage": "base", "m": base_m}] # base -> steered -> healed, for table + trajectory plot for rnd in range(cfg.n_rounds): logger.info(f"\n\n=== ROUND {rnd} [{cfg.model.split('/')[-1]} reg={cfg.reg}] gpu {gpu_mem()} ===")