Files
persona-steering-template-l…/scripts/readme_plot.py
T
2026-06-25 13:23:34 +08:00

99 lines
3.3 KiB
Python

from __future__ import annotations
import html
from pathlib import Path
import textwrap
from typing import Any
import plotly.graph_objects as go
import docs_results
# Destination files for the README scatter plot: a raster PNG and a vector
# SVG, both placed under ``out/`` relative to ``docs_results.ROOT``.
MAIN_PNG, MAIN_SVG = (
    docs_results.ROOT / "out/on_off_axis.png",
    docs_results.ROOT / "out/on_off_axis.svg",
)
def _wrap_hover(text: str, width: int = 62) -> str:
    """Return *text* as HTML-safe, ``<br>``-separated hover markup.

    The text is first passed through ``docs_results.display_template_text``.
    Runs of whitespace are then collapsed to single spaces, and the result is
    wrapped to at most *width* visible characters per line. Long words are
    broken, but hyphens are not treated as break points. The lines are joined
    with ``<br>``, the same separator the hover builder uses between fields.

    Wrapping is done on the raw text, and each finished line is HTML-escaped
    afterwards. Escaping first would have two effects:

    - ``textwrap`` would count entity characters (``&amp;`` is five
      characters) toward *width*.
    - With ``break_long_words=True``, an entity could be split across a
      ``<br>``, which then renders as literal garbage in the hover box.
    """
    text = docs_results.display_template_text(text)
    collapsed = " ".join(text.split())
    lines = textwrap.wrap(
        collapsed,
        width=width,
        break_long_words=True,
        break_on_hyphens=False,
    )
    # Escape per line, after wrapping, so entities are never split.
    return "<br>".join(html.escape(line) for line in lines)
def main_plot_rows(path: Path = docs_results.NORMAL_TEMPLATE_PAIR_STATS) -> list[dict[str, Any]]:
    """Load the rows plotted in the README scatter figure.

    Records are read from the JSONL file at *path* via
    ``docs_results.read_jsonl``. They are then reduced by
    ``docs_results.mean_template_rows``, presumably to one averaged row per
    template; confirm this in ``docs_results``.
    """
    records = docs_results.read_jsonl(path)
    per_template = docs_results.mean_template_rows(records)
    return per_template
def _hover_text(rank: int, row: dict[str, Any]) -> str:
    """Return the ``<br>``-joined hover markup for one ranked template row.

    The line order is:

    - the template text, wrapped and escaped by ``_wrap_hover``, in bold;
    - the row's 1-based *rank*;
    - its scores, formatted to two decimals;
    - its cell count.
    """
    return "<br>".join([
        f"<b>{_wrap_hover(row['template'])}</b>",
        f"rank: {rank}",
        f"score t: {row['score_t']:.2f}",
        f"score mean: {row['score_mean']:.2f}",
        f"axis delta: {row['axis_delta']:.2f}",
        f"off-axis problem: {row['off_axis_problem']:.2f}",
        f"judge std: {row['judge_std']:.2f}",
        f"cells: {row['n_cells']}",
    ])


def template_scatter(
    rows: list[dict[str, Any]] | None = None,
    *,
    top_n: int = 10,
) -> go.Figure:
    """Build the on-axis vs. off-axis scatter plot of measured templates.

    Each row becomes one marker at ``(row["on_axis"], row["off_axis"])``,
    coloured by ``row["score_t"]`` on a Cividis scale. A row's rank is its
    1-based position in *rows*. The ordering is not checked here, so it is
    presumably best-first from the caller. The first *top_n* markers have
    their rank drawn on them, and every marker's hover text shows the rank
    next to the row's scores.

    Args:
        rows: Per-template records. Each must provide the keys ``template``,
            ``on_axis``, ``off_axis``, ``score_t``, ``score_mean``,
            ``axis_delta``, ``off_axis_problem``, ``judge_std`` and
            ``n_cells``. If ``None``, ``main_plot_rows()`` is used.
        top_n: Number of leading rows whose rank is printed on the marker.
            The default of 10 matches the previous fixed cut-off.

    Returns:
        A Plotly figure laid out at 960x620 px with both axes fixed to
        ``[-0.02, 1.02]``.

    Raises:
        KeyError: If a row is missing one of the keys listed above.
    """
    rows = main_plot_rows() if rows is None else rows
    # Marker labels and hover ranks come from the same 1-based index so they
    # always agree. The previous template-string lookup mislabelled markers
    # whenever two rows shared a template string.
    ranks = range(1, len(rows) + 1)
    labels = [str(rank) if rank <= top_n else "" for rank in ranks]
    hover = [_hover_text(rank, row) for rank, row in zip(ranks, rows)]
    fig = go.Figure(
        data=go.Scatter(
            x=[row["on_axis"] for row in rows],
            y=[row["off_axis"] for row in rows],
            mode="markers+text",
            text=labels,
            textposition="middle center",
            textfont={"size": 9, "color": "white"},
            # Pre-rendered hover markup is passed through customdata so the
            # template simply prints it, with the trace-name box suppressed.
            customdata=hover,
            hovertemplate="%{customdata}<extra></extra>",
            marker={
                "size": 10,
                "color": [row["score_t"] for row in rows],
                "colorscale": "Cividis",
                "showscale": True,
                "colorbar": {"title": "score t"},
                "line": {"width": 0.5, "color": "white"},
                "opacity": 0.9,
            },
        )
    )
    fig.update_layout(
        autosize=True,
        width=960,
        height=620,
        template="plotly_white",
        margin={"l": 68, "r": 24, "t": 28, "b": 66},
        xaxis={
            "title": "on-axis movement, higher is better",
            "range": [-0.02, 1.02],
            "gridcolor": "rgba(0,0,0,0.08)",
        },
        yaxis={
            "title": "off-axis confounding, lower is better",
            "range": [-0.02, 1.02],
            "gridcolor": "rgba(0,0,0,0.08)",
        },
        # Caption placed below the plot area, right-aligned in paper space.
        annotations=[{
            "text": "normal pilot scenarios; one point per measured template",
            "xref": "paper",
            "yref": "paper",
            "x": 1.0,
            "y": -0.13,
            "showarrow": False,
            "font": {"size": 11, "color": "rgba(0,0,0,0.62)"},
        }],
    )
    return fig
def write_main_plot_assets() -> None:
    """Render the README scatter plot to ``MAIN_PNG`` and ``MAIN_SVG``.

    The export size is read from the figure's own layout rather than
    repeating the 960x620 constants here. This keeps the image files in step
    with whatever ``template_scatter`` lays out. The PNG is rendered at 2x
    scale for sharper raster output, and the SVG at 1x.

    The parent directory of every output file is created if missing.
    """
    fig = template_scatter()
    width, height = fig.layout.width, fig.layout.height
    for target in (MAIN_PNG, MAIN_SVG):
        target.parent.mkdir(parents=True, exist_ok=True)
    fig.write_image(MAIN_PNG, width=width, height=height, scale=2)
    fig.write_image(MAIN_SVG, width=width, height=height)