mirror of
https://github.com/wassname/steer-heal-love.git
synced 2026-06-27 17:02:34 +08:00
trajectory map = scatter not polyline (scales to 10 rounds); persist base event; offline plot_run.py
The pareto map drew a base->r0->...->rN polyline per arm, which tangled at 10 rounds and duplicated the left zigzag's round-order info. Make it a scatter that just shows WHERE steered/healed land, labelling only r0 + last round. Persist the base eval as an event so the loop's plot is reproducible offline, and add scripts/plot_run.py to re-render trajectory.png from events.jsonl without re-running the 3h loop (needed because the loop imports plot.py at start, so a plot fix never reaches a running job). Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
@@ -0,0 +1,54 @@
|
|||||||
|
"""Re-render a run's trajectory plot from its events.jsonl, without re-running the loop.
|
||||||
|
|
||||||
|
Why this exists: the loop imports plot.py at process start, so a plot-code fix never
|
||||||
|
reaches an already-running job. events.jsonl persists every stage (base/steered/healed),
|
||||||
|
so we can rebuild the `stages` list offline and call the CURRENT write_trajectory.
|
||||||
|
|
||||||
|
uv run python scripts/plot_run.py --run-dir out/20260604T172126_gemma-3-4b-it_kl_rev_s42/
|
||||||
|
|
||||||
|
A run from before the `base` event was persisted has no base node; pass it explicitly
|
||||||
|
(tyro reads negatives as flags, so use `=`):
|
||||||
|
... --base-auth=-2.35 --base-care=-1.30 --base-coh=0.996
|
||||||
|
"""
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import srsly
|
||||||
|
import tyro
|
||||||
|
|
||||||
|
from steer_heal.plot import write_trajectory
|
||||||
|
|
||||||
|
|
||||||
|
def main(run_dir: Path, base_auth: float | None = None,
|
||||||
|
base_care: float | None = None, base_coh: float | None = None):
|
||||||
|
events = list(srsly.read_jsonl(run_dir / "events.jsonl"))
|
||||||
|
by_stage = lambda s: [e for e in events if e["stage"] == s]
|
||||||
|
|
||||||
|
base_ev = by_stage("base")
|
||||||
|
if base_ev:
|
||||||
|
bm = base_ev[0]
|
||||||
|
base = {"auth_nats": bm["auth_nats"], "care_nats": bm["care_nats"], "coherence": bm["coherence"]}
|
||||||
|
else:
|
||||||
|
# older run: base wasn't persisted. require it on the CLI (fail fast, no silent default).
|
||||||
|
assert base_auth is not None, (
|
||||||
|
f"{run_dir}/events.jsonl has no `base` event (ran before run.py persisted it); "
|
||||||
|
"pass --base-auth/--base-care/--base-coh from the run's base eval log line."
|
||||||
|
)
|
||||||
|
base = {"auth_nats": base_auth, "care_nats": base_care, "coherence": base_coh}
|
||||||
|
|
||||||
|
stages = [{"round": "-", "stage": "base", "m": base}]
|
||||||
|
steered = {e["round"]: e for e in by_stage("steered_eval")}
|
||||||
|
healed = {e["round"]: e for e in by_stage("round")}
|
||||||
|
for rnd in sorted(healed): # one steered + one healed per completed round, in order
|
||||||
|
for src, kind in [(steered, "steered"), (healed, "healed")]:
|
||||||
|
e = src[rnd]
|
||||||
|
stages.append({"round": rnd, "stage": kind,
|
||||||
|
"m": {"auth_nats": e["auth_nats"], "care_nats": e["care_nats"],
|
||||||
|
"coherence": e["coherence"]}})
|
||||||
|
|
||||||
|
png = write_trajectory(run_dir, stages)
|
||||||
|
print(f"re-rendered {png} from {len(stages)} stages ({len(healed)} rounds)", file=sys.stderr)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
tyro.cli(main)
|
||||||
+14
-8
@@ -86,17 +86,23 @@ def write_trajectory(run_dir: Path, stages: list[dict]) -> Path:
|
|||||||
marker=dict(size=14, color=GREY, symbol="star"), showlegend=False,
|
marker=dict(size=14, color=GREY, symbol="star"), showlegend=False,
|
||||||
hovertext=[f"base auth={bx:.3f} coh={by:.3f}"], hoverinfo="text",
|
hovertext=[f"base auth={bx:.3f} coh={by:.3f}"], hoverinfo="text",
|
||||||
), row=1, col=2)
|
), row=1, col=2)
|
||||||
|
# scatter, NOT a polyline: the left zigzag panel already carries round order, so a
|
||||||
|
# connecting line here would just duplicate it (and tangle at 10 rounds). The map's one
|
||||||
|
# job is WHERE the two populations land in trait-coherence space -- steered scatters left
|
||||||
|
# (more trait, more variance), healed clusters near base (the stall). Label only the
|
||||||
|
# extremes (r0 + last round) so the labels don't collide in the cluster.
|
||||||
|
last_rnd = max(p["round"] for p in stages if p["stage"] == "healed")
|
||||||
for stage_kind, color, label in [("steered", RED, "steer"), ("healed", GREEN, "heal")]:
|
for stage_kind, color, label in [("steered", RED, "steer"), ("healed", GREEN, "heal")]:
|
||||||
pts = [s for s in stages if s["stage"] == stage_kind]
|
pts = [s for s in stages if s["stage"] == stage_kind]
|
||||||
xs = [bx] + [p["m"]["auth_nats"] for p in pts]
|
xs = [p["m"]["auth_nats"] for p in pts]
|
||||||
ys = [by] + [p["m"]["coherence"] for p in pts]
|
ys = [p["m"]["coherence"] for p in pts]
|
||||||
txt = [""] + [f"r{p['round']}" for p in pts]
|
txt = [f"r{p['round']}" if p["round"] in (0, last_rnd) else "" for p in pts]
|
||||||
hov = [f"base"] + [f"{label} r{p['round']} auth={p['m']['auth_nats']:.3f} "
|
hov = [f"{label} r{p['round']} auth={p['m']['auth_nats']:.3f} "
|
||||||
f"coh={p['m']['coherence']:.3f} care={p['m']['care_nats']:.3f}" for p in pts]
|
f"coh={p['m']['coherence']:.3f} care={p['m']['care_nats']:.3f}" for p in pts]
|
||||||
fig.add_trace(go.Scatter(
|
fig.add_trace(go.Scatter(
|
||||||
x=xs, y=ys, mode="lines+markers+text", text=txt, textposition="top center",
|
x=xs, y=ys, mode="markers+text", text=txt, textposition="top center",
|
||||||
line=dict(color=color, width=2), marker=dict(size=11, color=color),
|
marker=dict(size=11, color=color), name=label, showlegend=False,
|
||||||
name=label, showlegend=False, hovertext=hov, hoverinfo="text",
|
hovertext=hov, hoverinfo="text",
|
||||||
), row=1, col=2)
|
), row=1, col=2)
|
||||||
fig.update_xaxes(title_text="auth_nats (← more trait)", row=1, col=2)
|
fig.update_xaxes(title_text="auth_nats (← more trait)", row=1, col=2)
|
||||||
# same fixed coherence range as the line panel: shows the points hug the ceiling (coherence
|
# same fixed coherence range as the line panel: shows the points hug the ceiling (coherence
|
||||||
|
|||||||
@@ -105,6 +105,7 @@ def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict:
|
|||||||
# trait), not just coherence. One extra eval per run.
|
# trait), not just coherence. One extra eval per run.
|
||||||
logger.info(f"\n=== EVAL base [tinymfv classic] gpu {gpu_mem()} ===")
|
logger.info(f"\n=== EVAL base [tinymfv classic] gpu {gpu_mem()} ===")
|
||||||
base_m = evaluate_model(model, tok, cfg)
|
base_m = evaluate_model(model, tok, cfg)
|
||||||
|
log_event(run_dir, stage="base", round=-1, **base_m) # persist so offline plot_run.py is self-contained
|
||||||
stages = [{"round": "-", "stage": "base", "m": base_m}] # base -> steered -> healed, for table + trajectory plot
|
stages = [{"round": "-", "stage": "base", "m": base_m}] # base -> steered -> healed, for table + trajectory plot
|
||||||
for rnd in range(cfg.n_rounds):
|
for rnd in range(cfg.n_rounds):
|
||||||
logger.info(f"\n\n=== ROUND {rnd} [{cfg.model.split('/')[-1]} reg={cfg.reg}] gpu {gpu_mem()} ===")
|
logger.info(f"\n\n=== ROUND {rnd} [{cfg.model.split('/')[-1]} reg={cfg.reg}] gpu {gpu_mem()} ===")
|
||||||
|
|||||||
Reference in New Issue
Block a user