docs: simplify model matrix visualization

This commit is contained in:
wassname
2026-06-25 12:20:35 +08:00
parent 026b22e131
commit d31cac9068
5 changed files with 227 additions and 275 deletions
+12 -32
View File
@@ -152,11 +152,6 @@ def _write_markdown(path: Path, template_rows: list[dict[str, Any]], pair_rows:
{
"score p25": f"{row['score_p25']:.2f}",
"score mean": f"{row['score_mean']:.2f}",
"score std": f"{row['score_std']:.2f}",
"pass mean": f"{row['strict_pass_rate_mean']:.2f}",
"echo rate": f"{row['persona_echo_rate_mean']:.2f}",
"refusal rate": f"{row['refusal_or_ai_break_rate_mean']:.2f}",
"models": row["model_count"],
"template": _markdown_text(row["template"]),
}
for row in template_rows[:top_n]
@@ -176,24 +171,13 @@ def _write_markdown(path: Path, template_rows: list[dict[str, Any]], pair_rows:
def _plot(path: Path, rows: list[dict[str, Any]], label_count: int) -> None:
fig, ax = plt.subplots(figsize=(8.2, 5.6), dpi=180)
fig, ax = plt.subplots(figsize=(7.4, 5.0), dpi=180)
xs = [_clamp01(row["axis_delta_mean"] / 8.0) for row in rows]
ys = [_clamp01((row["off_axis_problem_mean"] - 1.0) / 6.0) for row in rows]
colors = ["black" if row["strict_pass_rate_mean"] > 0 else "0.65" for row in rows]
colors = ["0.12" if row["strict_pass_rate_mean"] > 0 else "0.72" for row in rows]
ax.scatter(xs, ys, s=28, c=colors, alpha=0.82, linewidths=0, zorder=2)
ax.scatter(xs, ys, s=22, c=colors, alpha=0.9, linewidths=0, zorder=2)
top_ids = {id(row): i for i, row in enumerate(rows[:label_count], start=1)}
top_rows = rows[:label_count]
ax.errorbar(
[_clamp01(row["axis_delta_mean"] / 8.0) for row in top_rows],
[_clamp01((row["off_axis_problem_mean"] - 1.0) / 6.0) for row in top_rows],
xerr=[row["axis_delta_std"] / (8.0 * math.sqrt(row["model_count"])) for row in top_rows],
yerr=[row["off_axis_problem_std"] / (6.0 * math.sqrt(row["model_count"])) for row in top_rows],
fmt="none",
ecolor="0.55",
elinewidth=0.8,
zorder=1,
)
for row in rows:
if id(row) not in top_ids:
continue
@@ -205,27 +189,23 @@ def _plot(path: Path, rows: list[dict[str, Any]], label_count: int) -> None:
str(top_ids[id(row)]),
ha="center",
va="center",
fontsize=6.5,
fontsize=6.2,
color="white",
zorder=3,
)
ax.set_xlim(-0.02, 1.02)
ax.set_ylim(-0.02, 1.02)
ax.set_xlabel("mean on-axis movement")
ax.set_ylabel("mean off-axis confounding")
ax.set_title("Refusal probe templates across clean model artifacts", fontsize=10)
ax.text(
1.0,
-0.13,
"error bars are model SEM; point numbers match the first table rows",
transform=ax.transAxes,
ha="right",
fontsize=8,
)
ax.grid(True, color="0.9", linewidth=0.6)
ax.set_xlabel("template on-axis movement, higher is better", fontsize=9)
ax.set_ylabel("template off-axis confounding, lower is better", fontsize=9)
ax.grid(True, color="0.92", linewidth=0.45)
ax.tick_params(axis="both", labelsize=8, length=3, width=0.7, color="0.25")
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
ax.spines["left"].set_color("0.25")
ax.spines["bottom"].set_color("0.25")
ax.spines["left"].set_linewidth(0.7)
ax.spines["bottom"].set_linewidth(0.7)
path.parent.mkdir(parents=True, exist_ok=True)
fig.tight_layout()
fig.savefig(path)
+10 -26
View File
@@ -39,10 +39,6 @@ def _table(rows: list[dict], top_n: int) -> str:
{
"score p25": f"{row['score_p25']:.2f}",
"score mean": f"{row['score_mean']:.2f}",
"score std": f"{row['score_std']:.2f}",
"pass mean": f"{row['strict_pass_rate_mean']:.2f}",
"echo rate": f"{row['persona_echo_rate_mean']:.2f}",
"refusal rate": f"{row['refusal_or_ai_break_rate_mean']:.2f}",
"template": _markdown_text(row["template"]),
}
for row in rows[:top_n]
@@ -61,39 +57,27 @@ def _block(summary_path: Path) -> str:
"`qwen/qwen3.6-flash`, and `ibm-granite/granite-4.1-8b`."
),
(
"This table reports mean and sample std across models. Each model first averages "
"the two probe axes for a template, so this is model-equal rather than row-equal. "
"`score p25` is the headline sort: it is the 25th percentile score across the "
"four clean model artifacts, so a template has to work on more than one model to rank well."
"Each model first averages the two probe axes for a template, so this is "
"model-equal rather than row-equal. `score p25` is the headline sort: it is "
"the 25th percentile score across the four clean model artifacts, so a template "
"has to work on more than one model to rank well."
),
"![refusal probe model matrix](./out/model_matrix/refusal_probe_seed24_n1_model_matrix.png)",
(
"Caption: each dot is one template. Right is more on-axis movement; lower is less "
"off-axis confounding. Black dots have at least one strict-pass template-axis cell; "
"grey dots have none. Numbered dots are the first rows of the table. Error bars show "
"model SEM for those numbered rows only."
"Caption: this is a template overview, not a persona plot. Each dot is one template, "
"averaged over the two refusal-probe axes and four clean models. Right is more "
"on-axis movement; lower is less off-axis confounding. Black dots have at least one "
"strict-pass template-axis cell; grey dots have none. Numbered dots are the first "
"rows of the table."
),
"Model-matrix templates, all rows:",
_table(rows, top_n=len(rows)),
(
"Interpretation: some explicit judgment framings and red-team/eval framings move "
"the hard axis more often than the gentle templates, but they frequently do so "
"with persona echo or model-specific behavior. The cleanest-looking single-axis "
"the hard axis more often than the gentle templates. The cleanest-looking single-axis "
"cells were often `protocol_harm`, so treat the high rows as rerun candidates "
"rather than settled reusable defaults."
),
"Excluded attempted models:",
"\n".join([
"| model | result |",
"|---|---|",
"| `google/gemma-2-9b-it` | OpenRouter returned no endpoints for all 190 cells. |",
"| `openai/gpt-oss-120b` | OpenRouter returned `Reasoning is mandatory for this endpoint and cannot be disabled` for all 190 cells. |",
"| `deepseek/deepseek-v4-flash` | Reproduced 3 empty-generation cells out of 190, so excluded from aggregate instead of averaging missing data. |",
]),
(
"Full generated table:\n"
"[`out/model_matrix/refusal_probe_seed24_n1_model_matrix_summary.md`](out/model_matrix/refusal_probe_seed24_n1_model_matrix_summary.md)."
),
])