docs: use one Quarto source for README and Pages

This commit is contained in:
wassname
2026-06-25 13:06:12 +08:00
parent 024fb3d545
commit cfcb57b9ce
20 changed files with 533 additions and 2000 deletions
+79
View File
@@ -0,0 +1,79 @@
from __future__ import annotations
import json
import math
from pathlib import Path
import statistics
from typing import Any
# Directory two levels up from this file's resolved location (parents[1]),
# used as the anchor for every path below.
ROOT = Path(__file__).resolve().parents[1]
# Folder containing the template-pair stats files referenced below.
STATS = ROOT / "out/stats"
# Template-pair stats files (JSON Lines, going by the extension) for the
# normal pilot, engineered-baseline, and control-baseline runs.
NORMAL_TEMPLATE_PAIR_STATS = STATS / "v2_pilot_seed24_template_pair_stats.jsonl"
ENGINEERED_TEMPLATE_PAIR_STATS = STATS / "engineered_baseline_seed24_template_pair_stats.jsonl"
CONTROL_TEMPLATE_PAIR_STATS = STATS / "control_baseline_seed24_template_pair_stats.jsonl"
# Refusal-probe template-pair stats, one file per model (model name is part
# of each file name).
REFUSAL_MODEL_PAIR_STATS = [
ROOT / "out/model_matrix/stats/refusal_probe_seed24_n1_google_gemma-2-27b-it_template_pair_stats.jsonl",
ROOT / "out/model_matrix/stats/refusal_probe_seed24_n1_google_gemma-3-4b-it_template_pair_stats.jsonl",
ROOT / "out/model_matrix/stats/refusal_probe_seed24_n1_qwen_qwen3.6-flash_template_pair_stats.jsonl",
ROOT / "out/model_matrix/stats/refusal_probe_seed24_n1_ibm-granite_granite-4.1-8b_template_pair_stats.jsonl",
]
# Path prefix for refusal-probe model-matrix outputs.
# NOTE(review): presumably consumers append their own file-name suffixes to
# this prefix rather than writing to it directly - confirm against callers.
REFUSAL_MODEL_PREFIX = ROOT / "out/model_matrix/refusal_probe_seed24_n1"
def read_jsonl(path: Path) -> list[dict[str, Any]]:
    """Load a JSON Lines file into a list of decoded records.

    Args:
        path: File holding one JSON document per line; blank lines are
            ignored.

    Returns:
        The decoded value of every non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        OSError: If the file cannot be opened.
    """
    # Decode as UTF-8 explicitly instead of relying on the locale default,
    # and iterate the file object rather than calling str.splitlines():
    # splitlines() also breaks on U+2028, U+2029 and U+0085, which may appear
    # unescaped inside a JSON string and would cut that record in half.
    with path.open(encoding="utf-8") as handle:
        return [json.loads(line) for line in handle if line.strip()]
def clamp01(x: float) -> float:
    """Clip ``x`` to the closed interval ``[0.0, 1.0]``."""
    capped = x if x < 1.0 else 1.0
    return capped if capped > 0.0 else 0.0
def mean(xs: list[float]) -> float:
    """Return the arithmetic mean of ``xs``.

    Raises:
        ZeroDivisionError: If ``xs`` is empty.
    """
    total = sum(xs)
    count = len(xs)
    return total / count
def std(xs: list[float]) -> float:
    """Return the sample standard deviation of ``xs``.

    A single-element list yields 0.0 instead of the error
    ``statistics.stdev`` would raise for it.
    """
    return 0.0 if len(xs) == 1 else statistics.stdev(xs)
def score(row: dict[str, Any]) -> float:
    """Score one stats row on a 0-100 scale.

    The on-axis term is ``mean_axis_delta / 8`` and the off-axis term is
    ``(mean_off_axis_problem - 1) / 6``, each clamped to ``[0, 1]``. The
    result is ``100 * on_axis * (1 - off_axis)``.
    """
    axis_delta = float(row["mean_axis_delta"])
    off_axis_problem = float(row["mean_off_axis_problem"])
    on_axis = clamp01(axis_delta / 8.0)
    off_axis = clamp01((off_axis_problem - 1.0) / 6.0)
    clean_fraction = 1.0 - off_axis
    return 100.0 * on_axis * clean_fraction
def score_t(scores: list[float]) -> float:
    """Return a t-like statistic: the mean score divided by its standard error.

    Args:
        scores: Scores of the cells measured for one template.

    Returns:
        ``mean / (sample_stdev / sqrt(n))``. With fewer than two scores no
        standard error can be estimated, so 0.0 is returned. When all scores
        are identical (zero standard error) the result is 0.0 for an
        all-zero sample and the sentinel ``1_000_000.0`` otherwise.
    """
    n = len(scores)
    if n < 2:
        # Without this guard a single measurement falls into the zero-SEM
        # branch below and receives the 1_000_000.0 sentinel, outranking
        # every template that was actually replicated; an empty list would
        # raise instead of scoring.
        return 0.0
    sem = statistics.stdev(scores) / math.sqrt(n)
    mean_score = sum(scores) / n
    if sem == 0.0:
        return 0.0 if mean_score == 0.0 else 1_000_000.0
    return mean_score / sem
def mean_template_rows(rows: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Collapse per-cell stats rows into one summary row per template.

    Rows are grouped by their ``"template"`` value. Each summary carries the
    rounded ``score_t``, the mean score (rounded to 1 and to 2 decimals),
    clamped on/off-axis coordinates derived from the group means, the rounded
    group means themselves, the mean judge std, and the number of cells.

    Returns:
        Summary rows ordered by ``score_t``, highest first.
    """
    by_template: dict[str, list[dict[str, Any]]] = {}
    for row in rows:
        bucket = by_template.setdefault(row["template"], [])
        bucket.append({**row, "score": score(row)})

    summaries = []
    for template, cells in by_template.items():
        cell_scores = [float(cell["score"]) for cell in cells]
        t_value = score_t(cell_scores)
        mean_score = mean(cell_scores)
        # Each group mean feeds both a clamped coordinate and a rounded
        # column, so compute it once.
        mean_delta = mean([float(cell["mean_axis_delta"]) for cell in cells])
        mean_problem = mean([float(cell["mean_off_axis_problem"]) for cell in cells])
        mean_judge_std = mean([float(cell["mean_axis_delta_judge_std"]) for cell in cells])
        summaries.append({
            "template": template,
            "score_t": round(t_value, 2),
            "score": round(mean_score, 1),
            "score_mean": round(mean_score, 2),
            "on_axis": clamp01(mean_delta / 8.0),
            "off_axis": clamp01((mean_problem - 1.0) / 6.0),
            "axis_delta": round(mean_delta, 2),
            "off_axis_problem": round(mean_problem, 2),
            "judge_std": round(mean_judge_std, 2),
            "n_cells": len(cells),
        })
    summaries.sort(key=lambda summary: summary["score_t"], reverse=True)
    return summaries
+5 -222
View File
@@ -1,230 +1,13 @@
"""Plot measured on-axis movement against off-axis confounding.
The default input is the built Hugging Face parquet table:
uv run python scripts/plot_on_off_axis.py /tmp/persona-steering-template-library-hf/parquet/main.parquet
"""
"""Write the canonical README/Page Plotly figure as PNG and SVG."""
from __future__ import annotations
import argparse
from collections import defaultdict
import json
import re
import textwrap
from pathlib import Path
from typing import Any
from adjustText import adjust_text
import matplotlib.pyplot as plt
import pyarrow.parquet as pq
def _clamp01(x: float) -> float:
return max(0.0, min(1.0, x))
def _read_rows(path: Path) -> list[dict[str, Any]]:
if path.suffix == ".parquet":
return pq.read_table(path).to_pylist()
rows = []
for line in path.read_text().splitlines():
if line.strip():
rows.append(json.loads(line))
return rows
def _read_all_rows(paths: list[Path]) -> list[dict[str, Any]]:
    """Read every input file and return their rows concatenated in order."""
    return [row for path in paths for row in _read_rows(path)]
def _as_point(row: dict[str, Any]) -> dict[str, Any]:
on_axis = row.get("on_axis")
if on_axis is None:
on_axis = _clamp01(float(row.get("mean_axis_delta") or 0.0) / 8.0)
off_axis = row.get("off_axis")
if off_axis is None:
off_axis = _clamp01((float(row.get("mean_off_axis_problem") or 7.0) - 1.0) / 6.0)
point_id = row.get("contrast") or row.get("persona_pair") or ""
template = row.get("template") or row.get("template_jinja") or ""
return {
"x": float(on_axis),
"y": float(off_axis),
"score": float(row.get("score") or 100.0 * float(on_axis) * (1.0 - float(off_axis))),
"id": str(point_id),
"template": str(template),
"recommended": bool(row.get("recommended")),
}
def _aggregate_points(points: list[dict[str, Any]]) -> list[dict[str, Any]]:
groups: dict[tuple[float, float], list[dict[str, Any]]] = defaultdict(list)
for point in points:
groups[(point["x"], point["y"])].append(point)
out = []
for cell_id, ((x, y), rows) in enumerate(sorted(groups.items()), start=1):
rows.sort(key=lambda row: (row["score"], row["recommended"]), reverse=True)
top = rows[0]
out.append({
"cell_id": cell_id,
"x": x,
"y": y,
"score": max(row["score"] for row in rows),
"id": top["id"],
"template": top["template"],
"recommended": any(row["recommended"] for row in rows),
"count": len(rows),
"labels": [f'{row["id"]}: "{row["template"]}"' for row in rows],
})
return out
def _label_points(points: list[dict[str, Any]], n: int, rightmost_n: int) -> list[dict[str, Any]]:
if len(points) <= n:
return points
high_score = sorted(points, key=lambda p: p["score"], reverse=True)[: max(2, n // 2)]
high_off_axis = sorted(points, key=lambda p: (p["y"], p["x"]), reverse=True)[: n]
rightmost = sorted(points, key=lambda p: (p["x"], -p["y"], p["score"]), reverse=True)[:rightmost_n]
out = []
seen_labels = set()
seen_cells = set()
for point in high_score + high_off_axis + rightmost:
label_key = f'{point["id"]}: "{point["template"]}"'
cell_key = (round(point["x"], 1), round(point["y"], 1))
if label_key not in seen_labels and cell_key not in seen_cells:
out.append(point)
seen_labels.add(label_key)
seen_cells.add(cell_key)
return out[: max(n, rightmost_n)]
def _place_label(i: int, point: dict[str, Any]) -> tuple[float, float, str, str]:
dx = 0.018
dy = [0.035, -0.05, 0.075, -0.09, 0.115, -0.13, 0.16, -0.175][i % 8]
x = min(0.98, point["x"] + dx) if point["x"] < 0.9 else max(0.05, point["x"] - 0.02)
y = min(0.98, max(0.02, point["y"] + dy))
if point["y"] < 0.08:
y = max(0.08, y)
ha = "left" if point["x"] < 0.9 else "right"
return x, y, ha, "center"
def _short_template(text: str, width: int = 52) -> str:
if text == "__verbatim_skill_persona__":
text = "engineered long persona prefix"
text = text.replace("{{ persona }}", "{persona}").replace("\n", " ")
text = " ".join(text.split())
if re.search(r"[\u4e00-\u9fff]", text):
if "社会主义核心价值观" in text:
text = "Chinese compliance role-play wrapper with core values"
else:
text = "Chinese compliance role-play wrapper"
if len(text) <= width:
return text
keep = max(8, (width - 3) // 2)
return f"{text[:keep].rstrip('. ')}...{text[-keep:].lstrip('. ')}"
def _short_label(point: dict[str, Any]) -> str:
    """Format a cell's label as ``<cell_id>: "<short template>"`` wrapped at 38 columns."""
    cell_id = point["cell_id"]
    template = _short_template(point["template"])
    return textwrap.fill(f'{cell_id}: "{template}"', width=38)
def _y_limits(points: list[dict[str, Any]], labels: list[dict[str, Any]]) -> tuple[float, float]:
ys = [p["y"] for p in points]
label_ys = [p["y"] for p in labels]
ymax = min(1.02, max(max(ys), max(label_ys, default=0.0)) + 0.18)
ymax = max(0.28, ymax)
ymin = min(-0.02, min(min(ys), min(label_ys, default=0.0)) - 0.06)
return ymin, ymax
import readme_plot
def main() -> None:
ap = argparse.ArgumentParser()
ap.add_argument("input", nargs="+", type=Path)
ap.add_argument("--out", type=Path, default=Path("out/on_off_axis.png"))
ap.add_argument("--label-count", type=int, default=10)
ap.add_argument("--label-rightmost", type=int, default=5)
args = ap.parse_args()
raw_points = [_as_point(row) for row in _read_all_rows(args.input)]
raw_points = [p for p in raw_points if p["id"]]
points = _aggregate_points(raw_points)
labels = _label_points(points, args.label_count, args.label_rightmost)
fig, ax = plt.subplots(figsize=(8.0, 5.6), dpi=180)
ax.scatter(
[p["x"] for p in points],
[p["y"] for p in points],
s=[26 + 12 * p["count"] for p in points],
c=["black" if p["recommended"] else "0.55" for p in points],
alpha=0.82,
linewidths=0,
)
for point in points:
if point["count"] >= 4:
ax.text(
point["x"],
point["y"],
str(point["count"]),
ha="center",
va="center",
fontsize=6.5,
color="white" if point["recommended"] else "0.1",
)
texts = []
target_x = []
target_y = []
for i, point in enumerate(labels):
x, y, ha, va = _place_label(i, point)
count_suffix = f" [{point['count']}]" if point["count"] > 1 else ""
texts.append(ax.text(
x,
y,
_short_label(point) + count_suffix,
ha=ha,
va=va,
fontsize=6.5,
color="0.15",
bbox={"facecolor": "white", "edgecolor": "none", "alpha": 0.82, "pad": 0.7},
))
target_x.append(point["x"])
target_y.append(point["y"])
ax.set_xlim(-0.02, 1.02)
ax.set_ylim(*_y_limits(points, labels))
ax.set_xlabel("on-axis movement")
ax.set_ylabel("off-axis confounding")
ax.set_title("Persona template cells: move the intended axis, avoid confounds", fontsize=10)
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
ax.grid(True, color="0.9", linewidth=0.6)
ax.text(1.0, -0.13, "better is lower-right", transform=ax.transAxes, ha="right", fontsize=8)
if texts:
adjust_text(
texts,
x=[p["x"] for p in points],
y=[p["y"] for p in points],
target_x=target_x,
target_y=target_y,
ax=ax,
expand=(1.08, 1.22),
force_text=(0.16, 0.34),
force_static=(0.08, 0.16),
force_pull=(0.012, 0.018),
max_move=(18, 18),
ensure_inside_axes=True,
prevent_crossings=True,
iter_lim=600,
arrowprops={"arrowstyle": "-", "color": "0.65", "lw": 0.55},
)
fig.tight_layout()
args.out.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(args.out)
print(args.out)
readme_plot.write_main_plot_assets()
print(readme_plot.MAIN_PNG)
print(readme_plot.MAIN_SVG)
if __name__ == "__main__":
+97
View File
@@ -0,0 +1,97 @@
from __future__ import annotations
import html
from pathlib import Path
import textwrap
from typing import Any
import plotly.graph_objects as go
import docs_results
# Output paths for the main scatter figure; both are written by
# write_main_plot_assets() in this module.
MAIN_PNG = docs_results.ROOT / "out/on_off_axis.png"
MAIN_SVG = docs_results.ROOT / "out/on_off_axis.svg"
def _wrap_hover(text: str, width: int = 62) -> str:
escaped = html.escape(" ".join(text.split()))
return "<br>".join(
textwrap.wrap(escaped, width=width, break_long_words=True, break_on_hyphens=False))
def main_plot_rows(path: Path = docs_results.NORMAL_TEMPLATE_PAIR_STATS) -> list[dict[str, Any]]:
    """Load a template-pair stats file and return its per-template summary rows."""
    pair_rows = docs_results.read_jsonl(path)
    return docs_results.mean_template_rows(pair_rows)
def template_scatter(rows: list[dict[str, Any]] | None = None) -> go.Figure:
    """Build the on-axis vs off-axis scatter figure for template summary rows.

    Args:
        rows: Per-template summary rows; when omitted they are loaded with
            ``main_plot_rows()``. The row order determines the rank shown in
            the hover text.

    Returns:
        A Plotly figure with one marker per row, coloured by ``score_t``. The
        first ten rows carry their rank number inside the marker.
    """
    if rows is None:
        rows = main_plot_rows()

    # Only the first ten rows get a visible rank number.
    top_rank = {}
    for rank, row in enumerate(rows[:10], start=1):
        top_rank[row["template"]] = rank
    marker_text = [str(top_rank.get(row["template"], "")) for row in rows]

    def hover_html(rank: int, row: dict[str, Any]) -> str:
        """Assemble the hover box content for one row."""
        parts = [
            f"<b>{_wrap_hover(row['template'])}</b>",
            f"rank: {rank}",
            f"score t: {row['score_t']:.2f}",
            f"score mean: {row['score_mean']:.2f}",
            f"axis delta: {row['axis_delta']:.2f}",
            f"off-axis problem: {row['off_axis_problem']:.2f}",
            f"judge std: {row['judge_std']:.2f}",
            f"cells: {row['n_cells']}",
        ]
        return "<br>".join(parts)

    hover = [hover_html(rank, row) for rank, row in enumerate(rows, start=1)]

    def unit_axis(title: str) -> dict[str, Any]:
        """Axis settings shared by both axes, which span the unit interval."""
        return {
            "title": title,
            "range": [-0.02, 1.02],
            "gridcolor": "rgba(0,0,0,0.08)",
        }

    marker = {
        "size": 10,
        "color": [row["score_t"] for row in rows],
        "colorscale": "Cividis",
        "showscale": True,
        "colorbar": {"title": "score t"},
        "line": {"width": 0.5, "color": "white"},
        "opacity": 0.9,
    }
    trace = go.Scatter(
        x=[row["on_axis"] for row in rows],
        y=[row["off_axis"] for row in rows],
        mode="markers+text",
        text=marker_text,
        textposition="middle center",
        textfont={"size": 9, "color": "white"},
        customdata=hover,
        hovertemplate="%{customdata}<extra></extra>",
        marker=marker,
    )
    fig = go.Figure(data=trace)

    footnote = {
        "text": "normal pilot scenarios; one point per measured template",
        "xref": "paper",
        "yref": "paper",
        "x": 1.0,
        "y": -0.13,
        "showarrow": False,
        "font": {"size": 11, "color": "rgba(0,0,0,0.62)"},
    }
    fig.update_layout(
        autosize=True,
        width=960,
        height=620,
        template="plotly_white",
        margin={"l": 68, "r": 24, "t": 28, "b": 66},
        xaxis=unit_axis("on-axis movement, higher is better"),
        yaxis=unit_axis("off-axis confounding, lower is better"),
        annotations=[footnote],
    )
    return fig
def write_main_plot_assets() -> None:
    """Render the main scatter figure to ``MAIN_PNG`` and ``MAIN_SVG``.

    The output directory is created if needed. Both images are 960x620; the
    PNG is additionally rendered at scale 2.
    """
    figure = template_scatter()
    output_dir = MAIN_PNG.parent
    output_dir.mkdir(parents=True, exist_ok=True)
    size = {"width": 960, "height": 620}
    figure.write_image(MAIN_PNG, scale=2, **size)
    figure.write_image(MAIN_SVG, **size)
+3 -56
View File
@@ -8,18 +8,13 @@ from pathlib import Path
import statistics
from typing import Any
import matplotlib.pyplot as plt
from tabulate import tabulate
import docs_results
ROOT = Path(__file__).resolve().parents[1]
DEFAULT_PAIR_STATS = [
ROOT / "out/model_matrix/stats/refusal_probe_seed24_n1_google_gemma-2-27b-it_template_pair_stats.jsonl",
ROOT / "out/model_matrix/stats/refusal_probe_seed24_n1_google_gemma-3-4b-it_template_pair_stats.jsonl",
ROOT / "out/model_matrix/stats/refusal_probe_seed24_n1_qwen_qwen3.6-flash_template_pair_stats.jsonl",
ROOT / "out/model_matrix/stats/refusal_probe_seed24_n1_ibm-granite_granite-4.1-8b_template_pair_stats.jsonl",
]
DEFAULT_OUT_PREFIX = ROOT / "out/model_matrix/refusal_probe_seed24_n1"
DEFAULT_PAIR_STATS = docs_results.REFUSAL_MODEL_PAIR_STATS
DEFAULT_OUT_PREFIX = docs_results.REFUSAL_MODEL_PREFIX
def _read_jsonl(path: Path) -> list[dict[str, Any]]:
@@ -187,48 +182,6 @@ def _write_markdown(path: Path, template_rows: list[dict[str, Any]], pair_rows:
path.write_text("\n".join(lines) + "\n")
def _plot(path: Path, rows: list[dict[str, Any]], label_count: int) -> None:
    """Save a scatter of template on-axis movement against off-axis confounding.

    Args:
        path: Output image file; its parent directory is created if missing.
        rows: Template summary rows providing ``axis_delta_mean``,
            ``off_axis_problem_mean`` and ``strict_pass_rate_mean``.
        label_count: How many leading rows get their 1-based rank drawn on
            the marker.
    """
    fig, ax = plt.subplots(figsize=(7.4, 5.0), dpi=180)

    xs = [_clamp01(row["axis_delta_mean"] / 8.0) for row in rows]
    ys = [_clamp01((row["off_axis_problem_mean"] - 1.0) / 6.0) for row in rows]
    # Dark markers for rows with a positive strict-pass rate, light otherwise.
    colors = ["0.12" if row["strict_pass_rate_mean"] > 0 else "0.72" for row in rows]
    ax.scatter(xs, ys, s=22, c=colors, alpha=0.9, linewidths=0, zorder=2)

    # Rows are matched by object identity, so equal-looking rows never share
    # a rank label by accident.
    rank_by_identity = {id(row): rank for rank, row in enumerate(rows[:label_count], start=1)}
    for row, x, y in zip(rows, xs, ys):
        rank = rank_by_identity.get(id(row))
        if rank is None:
            continue
        ax.text(
            x,
            y,
            str(rank),
            ha="center",
            va="center",
            fontsize=6.2,
            color="white",
            zorder=3,
        )

    ax.set_xlim(-0.02, 1.02)
    ax.set_ylim(-0.02, 1.02)
    ax.set_xlabel("template on-axis movement, higher is better", fontsize=9)
    ax.set_ylabel("template off-axis confounding, lower is better", fontsize=9)
    ax.grid(True, color="0.92", linewidth=0.45)
    ax.tick_params(axis="both", labelsize=8, length=3, width=0.7, color="0.25")
    for hidden_side in ("top", "right"):
        ax.spines[hidden_side].set_visible(False)
    for visible_side in ("left", "bottom"):
        ax.spines[visible_side].set_color("0.25")
    for visible_side in ("left", "bottom"):
        ax.spines[visible_side].set_linewidth(0.7)

    path.parent.mkdir(parents=True, exist_ok=True)
    fig.tight_layout()
    fig.savefig(path)
    plt.close(fig)
def main() -> None:
ap = argparse.ArgumentParser()
ap.add_argument("--pair-stats", nargs="+", type=Path, default=DEFAULT_PAIR_STATS)
@@ -258,14 +211,8 @@ def main() -> None:
_write_jsonl(prefix.with_name(prefix.name + "_template_pair_model_summary.jsonl"), pair_rows)
_write_csv(prefix.with_name(prefix.name + "_template_pair_model_summary.csv"), pair_rows)
_write_markdown(prefix.with_name(prefix.name + "_model_matrix_summary.md"), template_rows, pair_rows, args.top_n)
png_path = prefix.with_name(prefix.name + "_model_matrix.png")
svg_path = prefix.with_name(prefix.name + "_model_matrix.svg")
_plot(png_path, template_rows, label_count=10)
_plot(svg_path, template_rows, label_count=10)
print(f"models={expected_models} templates={len(template_rows)} template_pairs={len(pair_rows)}")
print(prefix.with_name(prefix.name + "_model_matrix_summary.md"))
print(png_path)
print(svg_path)
if __name__ == "__main__":
+1 -11
View File
@@ -55,23 +55,13 @@ def _appendix_block(summary_path: Path) -> str:
(
"Why include it? These negative poles can collapse into generic safety refusal, "
"AI-role breaks, or persona echo instead of the intended behavioral contrast. "
"This plot is a quick check for templates that move those hard axes without "
"The table is a quick check for templates that move those hard axes without "
"simply making the model refuse."
),
"![refusal-pole probe](./out/model_matrix/refusal_probe_seed24_n1_model_matrix.png)",
(
"Caption: each dot is one template, averaged over the two refusal-probe axes "
"and four clean models. Right is more on-axis movement; lower is less off-axis "
"confounding. Numbered dots are the first rows of the appendix table."
),
(
"`refusal_or_ai_break_rate` is only an output audit column: it marks completions "
"that refused or broke AI role, and is not used to select this data slice."
),
(
"Interactive hover plot: "
"[GitHub Pages](https://wassname.github.io/persona-steering-template-library/)."
),
(
"The generated full audit table includes strict-pass, echo, and refusal columns: "
"[out/model_matrix/refusal_probe_seed24_n1_model_matrix_summary.md]"
+6 -44
View File
@@ -1,19 +1,17 @@
from __future__ import annotations
import json
import math
from pathlib import Path
import statistics
from tabulate import tabulate
import docs_results
from template_catalog import CATALOG_PATH, jinja_to_runtime, load_template_catalog
ROOT = Path(__file__).resolve().parents[1]
STATS = ROOT / "out/stats"
NORMAL_STATS = STATS / "v2_pilot_seed24_template_pair_stats.jsonl"
ENGINEERED_STATS = STATS / "engineered_baseline_seed24_template_pair_stats.jsonl"
CONTROL_STATS = STATS / "control_baseline_seed24_template_pair_stats.jsonl"
NORMAL_STATS = docs_results.NORMAL_TEMPLATE_PAIR_STATS
ENGINEERED_STATS = docs_results.ENGINEERED_TEMPLATE_PAIR_STATS
CONTROL_STATS = docs_results.CONTROL_TEMPLATE_PAIR_STATS
ENGINEERED_PAIRS = ROOT / "data/persona_pairs_engineered_baseline_pilot_two.jsonl"
ENGINEERED_DISPLAY = "`{engineered long persona prefix}`*"
@@ -21,30 +19,8 @@ def _read_jsonl(path: Path) -> list[dict]:
return [json.loads(line) for line in path.read_text().splitlines() if line.strip()]
def _clamp01(x: float) -> float:
return max(0.0, min(1.0, x))
def _score(row: dict) -> float:
    """Score one stats row on a 0-100 scale, rounded to one decimal.

    The on-axis term is ``mean_axis_delta / 8`` and the off-axis term is
    ``(mean_off_axis_problem - 1) / 6``, each clamped to ``[0, 1]``. The
    score is ``100 * on_axis * (1 - off_axis)``.
    """
    axis_delta = float(row["mean_axis_delta"])
    off_axis_problem = float(row["mean_off_axis_problem"])
    on_axis = _clamp01(axis_delta / 8.0)
    off_axis = _clamp01((off_axis_problem - 1.0) / 6.0)
    raw_score = 100.0 * on_axis * (1.0 - off_axis)
    return round(raw_score, 1)
def _std(xs: list[float]) -> float:
if len(xs) == 1:
return 0.0
return statistics.stdev(xs)
def _score_t(scores: list[float]) -> float:
if len(scores) < 2:
return 0.0
sem = _std(scores) / math.sqrt(len(scores))
mean_score = sum(scores) / len(scores)
if sem == 0.0:
return 0.0 if mean_score == 0.0 else 1_000_000.0
return mean_score / sem
return round(docs_results.score(row), 1)
def _markdown_text(text: str) -> str:
@@ -78,21 +54,7 @@ def _best_by_template(rows: list[dict]) -> list[dict]:
def _mean_by_template(rows: list[dict]) -> list[dict]:
grouped: dict[str, list[dict]] = {}
for row in rows:
grouped.setdefault(row["template"], []).append({**row, "score": _score(row)})
out = []
for template, rs in grouped.items():
scores = [row["score"] for row in rs]
out.append({
"template": template,
"score_t": round(_score_t(scores), 2),
"score": round(sum(scores) / len(scores), 1),
"judge_std": round(
sum(float(row["mean_axis_delta_judge_std"]) for row in rs) / len(rs), 2),
"n_cells": len(rs),
})
return sorted(out, key=lambda row: row["score_t"], reverse=True)
return docs_results.mean_template_rows(rows)
def _engineered_derived_templates() -> set[str]: