mirror of
https://github.com/wassname/persona-steering-template-library.git
synced 2026-06-27 16:46:08 +08:00
docs: use one Quarto source for README and Pages
This commit is contained in:
@@ -0,0 +1,79 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import math
|
||||
from pathlib import Path
|
||||
import statistics
|
||||
from typing import Any
|
||||
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
STATS = ROOT / "out/stats"
|
||||
|
||||
NORMAL_TEMPLATE_PAIR_STATS = STATS / "v2_pilot_seed24_template_pair_stats.jsonl"
|
||||
ENGINEERED_TEMPLATE_PAIR_STATS = STATS / "engineered_baseline_seed24_template_pair_stats.jsonl"
|
||||
CONTROL_TEMPLATE_PAIR_STATS = STATS / "control_baseline_seed24_template_pair_stats.jsonl"
|
||||
|
||||
REFUSAL_MODEL_PAIR_STATS = [
|
||||
ROOT / "out/model_matrix/stats/refusal_probe_seed24_n1_google_gemma-2-27b-it_template_pair_stats.jsonl",
|
||||
ROOT / "out/model_matrix/stats/refusal_probe_seed24_n1_google_gemma-3-4b-it_template_pair_stats.jsonl",
|
||||
ROOT / "out/model_matrix/stats/refusal_probe_seed24_n1_qwen_qwen3.6-flash_template_pair_stats.jsonl",
|
||||
ROOT / "out/model_matrix/stats/refusal_probe_seed24_n1_ibm-granite_granite-4.1-8b_template_pair_stats.jsonl",
|
||||
]
|
||||
REFUSAL_MODEL_PREFIX = ROOT / "out/model_matrix/refusal_probe_seed24_n1"
|
||||
|
||||
|
||||
def read_jsonl(path: Path) -> list[dict[str, Any]]:
|
||||
return [json.loads(line) for line in path.read_text().splitlines() if line.strip()]
|
||||
|
||||
|
||||
def clamp01(x: float) -> float:
|
||||
return max(0.0, min(1.0, x))
|
||||
|
||||
|
||||
def mean(xs: list[float]) -> float:
|
||||
return sum(xs) / len(xs)
|
||||
|
||||
|
||||
def std(xs: list[float]) -> float:
|
||||
if len(xs) == 1:
|
||||
return 0.0
|
||||
return statistics.stdev(xs)
|
||||
|
||||
|
||||
def score(row: dict[str, Any]) -> float:
|
||||
on_axis = clamp01(float(row["mean_axis_delta"]) / 8.0)
|
||||
off_axis = clamp01((float(row["mean_off_axis_problem"]) - 1.0) / 6.0)
|
||||
return 100.0 * on_axis * (1.0 - off_axis)
|
||||
|
||||
|
||||
def score_t(scores: list[float]) -> float:
|
||||
sem = std(scores) / math.sqrt(len(scores))
|
||||
mean_score = mean(scores)
|
||||
if sem == 0.0:
|
||||
return 0.0 if mean_score == 0.0 else 1_000_000.0
|
||||
return mean_score / sem
|
||||
|
||||
|
||||
def mean_template_rows(rows: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
grouped: dict[str, list[dict[str, Any]]] = {}
|
||||
for row in rows:
|
||||
grouped.setdefault(row["template"], []).append({**row, "score": score(row)})
|
||||
|
||||
out = []
|
||||
for template, rs in grouped.items():
|
||||
scores = [float(row["score"]) for row in rs]
|
||||
out.append({
|
||||
"template": template,
|
||||
"score_t": round(score_t(scores), 2),
|
||||
"score": round(mean(scores), 1),
|
||||
"score_mean": round(mean(scores), 2),
|
||||
"on_axis": clamp01(mean([float(row["mean_axis_delta"]) for row in rs]) / 8.0),
|
||||
"off_axis": clamp01(
|
||||
(mean([float(row["mean_off_axis_problem"]) for row in rs]) - 1.0) / 6.0),
|
||||
"axis_delta": round(mean([float(row["mean_axis_delta"]) for row in rs]), 2),
|
||||
"off_axis_problem": round(mean([float(row["mean_off_axis_problem"]) for row in rs]), 2),
|
||||
"judge_std": round(mean([float(row["mean_axis_delta_judge_std"]) for row in rs]), 2),
|
||||
"n_cells": len(rs),
|
||||
})
|
||||
return sorted(out, key=lambda row: row["score_t"], reverse=True)
|
||||
+5
-222
@@ -1,230 +1,13 @@
|
||||
"""Plot measured on-axis movement against off-axis confounding.
|
||||
|
||||
The default input is the built Hugging Face parquet table:
|
||||
|
||||
uv run python scripts/plot_on_off_axis.py /tmp/persona-steering-template-library-hf/parquet/main.parquet
|
||||
"""
|
||||
"""Write the canonical README/Page Plotly figure as PNG and SVG."""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from collections import defaultdict
|
||||
import json
|
||||
import re
|
||||
import textwrap
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from adjustText import adjust_text
|
||||
import matplotlib.pyplot as plt
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
|
||||
def _clamp01(x: float) -> float:
|
||||
return max(0.0, min(1.0, x))
|
||||
|
||||
|
||||
def _read_rows(path: Path) -> list[dict[str, Any]]:
|
||||
if path.suffix == ".parquet":
|
||||
return pq.read_table(path).to_pylist()
|
||||
rows = []
|
||||
for line in path.read_text().splitlines():
|
||||
if line.strip():
|
||||
rows.append(json.loads(line))
|
||||
return rows
|
||||
|
||||
|
||||
def _read_all_rows(paths: list[Path]) -> list[dict[str, Any]]:
|
||||
rows = []
|
||||
for path in paths:
|
||||
rows.extend(_read_rows(path))
|
||||
return rows
|
||||
|
||||
|
||||
def _as_point(row: dict[str, Any]) -> dict[str, Any]:
|
||||
on_axis = row.get("on_axis")
|
||||
if on_axis is None:
|
||||
on_axis = _clamp01(float(row.get("mean_axis_delta") or 0.0) / 8.0)
|
||||
off_axis = row.get("off_axis")
|
||||
if off_axis is None:
|
||||
off_axis = _clamp01((float(row.get("mean_off_axis_problem") or 7.0) - 1.0) / 6.0)
|
||||
point_id = row.get("contrast") or row.get("persona_pair") or ""
|
||||
template = row.get("template") or row.get("template_jinja") or ""
|
||||
return {
|
||||
"x": float(on_axis),
|
||||
"y": float(off_axis),
|
||||
"score": float(row.get("score") or 100.0 * float(on_axis) * (1.0 - float(off_axis))),
|
||||
"id": str(point_id),
|
||||
"template": str(template),
|
||||
"recommended": bool(row.get("recommended")),
|
||||
}
|
||||
|
||||
|
||||
def _aggregate_points(points: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
groups: dict[tuple[float, float], list[dict[str, Any]]] = defaultdict(list)
|
||||
for point in points:
|
||||
groups[(point["x"], point["y"])].append(point)
|
||||
|
||||
out = []
|
||||
for cell_id, ((x, y), rows) in enumerate(sorted(groups.items()), start=1):
|
||||
rows.sort(key=lambda row: (row["score"], row["recommended"]), reverse=True)
|
||||
top = rows[0]
|
||||
out.append({
|
||||
"cell_id": cell_id,
|
||||
"x": x,
|
||||
"y": y,
|
||||
"score": max(row["score"] for row in rows),
|
||||
"id": top["id"],
|
||||
"template": top["template"],
|
||||
"recommended": any(row["recommended"] for row in rows),
|
||||
"count": len(rows),
|
||||
"labels": [f'{row["id"]}: "{row["template"]}"' for row in rows],
|
||||
})
|
||||
return out
|
||||
|
||||
|
||||
def _label_points(points: list[dict[str, Any]], n: int, rightmost_n: int) -> list[dict[str, Any]]:
|
||||
if len(points) <= n:
|
||||
return points
|
||||
high_score = sorted(points, key=lambda p: p["score"], reverse=True)[: max(2, n // 2)]
|
||||
high_off_axis = sorted(points, key=lambda p: (p["y"], p["x"]), reverse=True)[: n]
|
||||
rightmost = sorted(points, key=lambda p: (p["x"], -p["y"], p["score"]), reverse=True)[:rightmost_n]
|
||||
out = []
|
||||
seen_labels = set()
|
||||
seen_cells = set()
|
||||
for point in high_score + high_off_axis + rightmost:
|
||||
label_key = f'{point["id"]}: "{point["template"]}"'
|
||||
cell_key = (round(point["x"], 1), round(point["y"], 1))
|
||||
if label_key not in seen_labels and cell_key not in seen_cells:
|
||||
out.append(point)
|
||||
seen_labels.add(label_key)
|
||||
seen_cells.add(cell_key)
|
||||
return out[: max(n, rightmost_n)]
|
||||
|
||||
|
||||
def _place_label(i: int, point: dict[str, Any]) -> tuple[float, float, str, str]:
|
||||
dx = 0.018
|
||||
dy = [0.035, -0.05, 0.075, -0.09, 0.115, -0.13, 0.16, -0.175][i % 8]
|
||||
x = min(0.98, point["x"] + dx) if point["x"] < 0.9 else max(0.05, point["x"] - 0.02)
|
||||
y = min(0.98, max(0.02, point["y"] + dy))
|
||||
if point["y"] < 0.08:
|
||||
y = max(0.08, y)
|
||||
ha = "left" if point["x"] < 0.9 else "right"
|
||||
return x, y, ha, "center"
|
||||
|
||||
|
||||
def _short_template(text: str, width: int = 52) -> str:
|
||||
if text == "__verbatim_skill_persona__":
|
||||
text = "engineered long persona prefix"
|
||||
text = text.replace("{{ persona }}", "{persona}").replace("\n", " ")
|
||||
text = " ".join(text.split())
|
||||
if re.search(r"[\u4e00-\u9fff]", text):
|
||||
if "社会主义核心价值观" in text:
|
||||
text = "Chinese compliance role-play wrapper with core values"
|
||||
else:
|
||||
text = "Chinese compliance role-play wrapper"
|
||||
if len(text) <= width:
|
||||
return text
|
||||
keep = max(8, (width - 3) // 2)
|
||||
return f"{text[:keep].rstrip('. ')}...{text[-keep:].lstrip('. ')}"
|
||||
|
||||
|
||||
def _short_label(point: dict[str, Any]) -> str:
|
||||
text = f'{point["cell_id"]}: "{_short_template(point["template"])}"'
|
||||
return textwrap.fill(text, width=38)
|
||||
|
||||
|
||||
def _y_limits(points: list[dict[str, Any]], labels: list[dict[str, Any]]) -> tuple[float, float]:
|
||||
ys = [p["y"] for p in points]
|
||||
label_ys = [p["y"] for p in labels]
|
||||
ymax = min(1.02, max(max(ys), max(label_ys, default=0.0)) + 0.18)
|
||||
ymax = max(0.28, ymax)
|
||||
ymin = min(-0.02, min(min(ys), min(label_ys, default=0.0)) - 0.06)
|
||||
return ymin, ymax
|
||||
import readme_plot
|
||||
|
||||
|
||||
def main() -> None:
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("input", nargs="+", type=Path)
|
||||
ap.add_argument("--out", type=Path, default=Path("out/on_off_axis.png"))
|
||||
ap.add_argument("--label-count", type=int, default=10)
|
||||
ap.add_argument("--label-rightmost", type=int, default=5)
|
||||
args = ap.parse_args()
|
||||
|
||||
raw_points = [_as_point(row) for row in _read_all_rows(args.input)]
|
||||
raw_points = [p for p in raw_points if p["id"]]
|
||||
points = _aggregate_points(raw_points)
|
||||
labels = _label_points(points, args.label_count, args.label_rightmost)
|
||||
|
||||
fig, ax = plt.subplots(figsize=(8.0, 5.6), dpi=180)
|
||||
ax.scatter(
|
||||
[p["x"] for p in points],
|
||||
[p["y"] for p in points],
|
||||
s=[26 + 12 * p["count"] for p in points],
|
||||
c=["black" if p["recommended"] else "0.55" for p in points],
|
||||
alpha=0.82,
|
||||
linewidths=0,
|
||||
)
|
||||
for point in points:
|
||||
if point["count"] >= 4:
|
||||
ax.text(
|
||||
point["x"],
|
||||
point["y"],
|
||||
str(point["count"]),
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=6.5,
|
||||
color="white" if point["recommended"] else "0.1",
|
||||
)
|
||||
texts = []
|
||||
target_x = []
|
||||
target_y = []
|
||||
for i, point in enumerate(labels):
|
||||
x, y, ha, va = _place_label(i, point)
|
||||
count_suffix = f" [{point['count']}]" if point["count"] > 1 else ""
|
||||
texts.append(ax.text(
|
||||
x,
|
||||
y,
|
||||
_short_label(point) + count_suffix,
|
||||
ha=ha,
|
||||
va=va,
|
||||
fontsize=6.5,
|
||||
color="0.15",
|
||||
bbox={"facecolor": "white", "edgecolor": "none", "alpha": 0.82, "pad": 0.7},
|
||||
))
|
||||
target_x.append(point["x"])
|
||||
target_y.append(point["y"])
|
||||
|
||||
ax.set_xlim(-0.02, 1.02)
|
||||
ax.set_ylim(*_y_limits(points, labels))
|
||||
ax.set_xlabel("on-axis movement")
|
||||
ax.set_ylabel("off-axis confounding")
|
||||
ax.set_title("Persona template cells: move the intended axis, avoid confounds", fontsize=10)
|
||||
ax.spines["top"].set_visible(False)
|
||||
ax.spines["right"].set_visible(False)
|
||||
ax.grid(True, color="0.9", linewidth=0.6)
|
||||
ax.text(1.0, -0.13, "better is lower-right", transform=ax.transAxes, ha="right", fontsize=8)
|
||||
if texts:
|
||||
adjust_text(
|
||||
texts,
|
||||
x=[p["x"] for p in points],
|
||||
y=[p["y"] for p in points],
|
||||
target_x=target_x,
|
||||
target_y=target_y,
|
||||
ax=ax,
|
||||
expand=(1.08, 1.22),
|
||||
force_text=(0.16, 0.34),
|
||||
force_static=(0.08, 0.16),
|
||||
force_pull=(0.012, 0.018),
|
||||
max_move=(18, 18),
|
||||
ensure_inside_axes=True,
|
||||
prevent_crossings=True,
|
||||
iter_lim=600,
|
||||
arrowprops={"arrowstyle": "-", "color": "0.65", "lw": 0.55},
|
||||
)
|
||||
fig.tight_layout()
|
||||
args.out.parent.mkdir(parents=True, exist_ok=True)
|
||||
fig.savefig(args.out)
|
||||
print(args.out)
|
||||
readme_plot.write_main_plot_assets()
|
||||
print(readme_plot.MAIN_PNG)
|
||||
print(readme_plot.MAIN_SVG)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -0,0 +1,97 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import html
|
||||
from pathlib import Path
|
||||
import textwrap
|
||||
from typing import Any
|
||||
|
||||
import plotly.graph_objects as go
|
||||
|
||||
import docs_results
|
||||
|
||||
MAIN_PNG = docs_results.ROOT / "out/on_off_axis.png"
|
||||
MAIN_SVG = docs_results.ROOT / "out/on_off_axis.svg"
|
||||
|
||||
|
||||
def _wrap_hover(text: str, width: int = 62) -> str:
|
||||
escaped = html.escape(" ".join(text.split()))
|
||||
return "<br>".join(
|
||||
textwrap.wrap(escaped, width=width, break_long_words=True, break_on_hyphens=False))
|
||||
|
||||
|
||||
def main_plot_rows(path: Path = docs_results.NORMAL_TEMPLATE_PAIR_STATS) -> list[dict[str, Any]]:
|
||||
return docs_results.mean_template_rows(docs_results.read_jsonl(path))
|
||||
|
||||
|
||||
def template_scatter(rows: list[dict[str, Any]] | None = None) -> go.Figure:
|
||||
rows = main_plot_rows() if rows is None else rows
|
||||
top_rank = {row["template"]: i for i, row in enumerate(rows[:10], start=1)}
|
||||
text = [str(top_rank[row["template"]]) if row["template"] in top_rank else "" for row in rows]
|
||||
hover = [
|
||||
"<br>".join([
|
||||
f"<b>{_wrap_hover(row['template'])}</b>",
|
||||
f"rank: {i}",
|
||||
f"score t: {row['score_t']:.2f}",
|
||||
f"score mean: {row['score_mean']:.2f}",
|
||||
f"axis delta: {row['axis_delta']:.2f}",
|
||||
f"off-axis problem: {row['off_axis_problem']:.2f}",
|
||||
f"judge std: {row['judge_std']:.2f}",
|
||||
f"cells: {row['n_cells']}",
|
||||
])
|
||||
for i, row in enumerate(rows, start=1)
|
||||
]
|
||||
fig = go.Figure(
|
||||
data=go.Scatter(
|
||||
x=[row["on_axis"] for row in rows],
|
||||
y=[row["off_axis"] for row in rows],
|
||||
mode="markers+text",
|
||||
text=text,
|
||||
textposition="middle center",
|
||||
textfont={"size": 9, "color": "white"},
|
||||
customdata=hover,
|
||||
hovertemplate="%{customdata}<extra></extra>",
|
||||
marker={
|
||||
"size": 10,
|
||||
"color": [row["score_t"] for row in rows],
|
||||
"colorscale": "Cividis",
|
||||
"showscale": True,
|
||||
"colorbar": {"title": "score t"},
|
||||
"line": {"width": 0.5, "color": "white"},
|
||||
"opacity": 0.9,
|
||||
},
|
||||
)
|
||||
)
|
||||
fig.update_layout(
|
||||
autosize=True,
|
||||
width=960,
|
||||
height=620,
|
||||
template="plotly_white",
|
||||
margin={"l": 68, "r": 24, "t": 28, "b": 66},
|
||||
xaxis={
|
||||
"title": "on-axis movement, higher is better",
|
||||
"range": [-0.02, 1.02],
|
||||
"gridcolor": "rgba(0,0,0,0.08)",
|
||||
},
|
||||
yaxis={
|
||||
"title": "off-axis confounding, lower is better",
|
||||
"range": [-0.02, 1.02],
|
||||
"gridcolor": "rgba(0,0,0,0.08)",
|
||||
},
|
||||
annotations=[{
|
||||
"text": "normal pilot scenarios; one point per measured template",
|
||||
"xref": "paper",
|
||||
"yref": "paper",
|
||||
"x": 1.0,
|
||||
"y": -0.13,
|
||||
"showarrow": False,
|
||||
"font": {"size": 11, "color": "rgba(0,0,0,0.62)"},
|
||||
}],
|
||||
)
|
||||
return fig
|
||||
|
||||
|
||||
def write_main_plot_assets() -> None:
|
||||
fig = template_scatter()
|
||||
MAIN_PNG.parent.mkdir(parents=True, exist_ok=True)
|
||||
fig.write_image(MAIN_PNG, width=960, height=620, scale=2)
|
||||
fig.write_image(MAIN_SVG, width=960, height=620)
|
||||
@@ -8,18 +8,13 @@ from pathlib import Path
|
||||
import statistics
|
||||
from typing import Any
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from tabulate import tabulate
|
||||
|
||||
import docs_results
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
DEFAULT_PAIR_STATS = [
|
||||
ROOT / "out/model_matrix/stats/refusal_probe_seed24_n1_google_gemma-2-27b-it_template_pair_stats.jsonl",
|
||||
ROOT / "out/model_matrix/stats/refusal_probe_seed24_n1_google_gemma-3-4b-it_template_pair_stats.jsonl",
|
||||
ROOT / "out/model_matrix/stats/refusal_probe_seed24_n1_qwen_qwen3.6-flash_template_pair_stats.jsonl",
|
||||
ROOT / "out/model_matrix/stats/refusal_probe_seed24_n1_ibm-granite_granite-4.1-8b_template_pair_stats.jsonl",
|
||||
]
|
||||
DEFAULT_OUT_PREFIX = ROOT / "out/model_matrix/refusal_probe_seed24_n1"
|
||||
DEFAULT_PAIR_STATS = docs_results.REFUSAL_MODEL_PAIR_STATS
|
||||
DEFAULT_OUT_PREFIX = docs_results.REFUSAL_MODEL_PREFIX
|
||||
|
||||
|
||||
def _read_jsonl(path: Path) -> list[dict[str, Any]]:
|
||||
@@ -187,48 +182,6 @@ def _write_markdown(path: Path, template_rows: list[dict[str, Any]], pair_rows:
|
||||
path.write_text("\n".join(lines) + "\n")
|
||||
|
||||
|
||||
def _plot(path: Path, rows: list[dict[str, Any]], label_count: int) -> None:
|
||||
fig, ax = plt.subplots(figsize=(7.4, 5.0), dpi=180)
|
||||
xs = [_clamp01(row["axis_delta_mean"] / 8.0) for row in rows]
|
||||
ys = [_clamp01((row["off_axis_problem_mean"] - 1.0) / 6.0) for row in rows]
|
||||
colors = ["0.12" if row["strict_pass_rate_mean"] > 0 else "0.72" for row in rows]
|
||||
|
||||
ax.scatter(xs, ys, s=22, c=colors, alpha=0.9, linewidths=0, zorder=2)
|
||||
top_ids = {id(row): i for i, row in enumerate(rows[:label_count], start=1)}
|
||||
for row in rows:
|
||||
if id(row) not in top_ids:
|
||||
continue
|
||||
x = _clamp01(row["axis_delta_mean"] / 8.0)
|
||||
y = _clamp01((row["off_axis_problem_mean"] - 1.0) / 6.0)
|
||||
ax.text(
|
||||
x,
|
||||
y,
|
||||
str(top_ids[id(row)]),
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=6.2,
|
||||
color="white",
|
||||
zorder=3,
|
||||
)
|
||||
|
||||
ax.set_xlim(-0.02, 1.02)
|
||||
ax.set_ylim(-0.02, 1.02)
|
||||
ax.set_xlabel("template on-axis movement, higher is better", fontsize=9)
|
||||
ax.set_ylabel("template off-axis confounding, lower is better", fontsize=9)
|
||||
ax.grid(True, color="0.92", linewidth=0.45)
|
||||
ax.tick_params(axis="both", labelsize=8, length=3, width=0.7, color="0.25")
|
||||
ax.spines["top"].set_visible(False)
|
||||
ax.spines["right"].set_visible(False)
|
||||
ax.spines["left"].set_color("0.25")
|
||||
ax.spines["bottom"].set_color("0.25")
|
||||
ax.spines["left"].set_linewidth(0.7)
|
||||
ax.spines["bottom"].set_linewidth(0.7)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
fig.tight_layout()
|
||||
fig.savefig(path)
|
||||
plt.close(fig)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--pair-stats", nargs="+", type=Path, default=DEFAULT_PAIR_STATS)
|
||||
@@ -258,14 +211,8 @@ def main() -> None:
|
||||
_write_jsonl(prefix.with_name(prefix.name + "_template_pair_model_summary.jsonl"), pair_rows)
|
||||
_write_csv(prefix.with_name(prefix.name + "_template_pair_model_summary.csv"), pair_rows)
|
||||
_write_markdown(prefix.with_name(prefix.name + "_model_matrix_summary.md"), template_rows, pair_rows, args.top_n)
|
||||
png_path = prefix.with_name(prefix.name + "_model_matrix.png")
|
||||
svg_path = prefix.with_name(prefix.name + "_model_matrix.svg")
|
||||
_plot(png_path, template_rows, label_count=10)
|
||||
_plot(svg_path, template_rows, label_count=10)
|
||||
print(f"models={expected_models} templates={len(template_rows)} template_pairs={len(pair_rows)}")
|
||||
print(prefix.with_name(prefix.name + "_model_matrix_summary.md"))
|
||||
print(png_path)
|
||||
print(svg_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -55,23 +55,13 @@ def _appendix_block(summary_path: Path) -> str:
|
||||
(
|
||||
"Why include it? These negative poles can collapse into generic safety refusal, "
|
||||
"AI-role breaks, or persona echo instead of the intended behavioral contrast. "
|
||||
"This plot is a quick check for templates that move those hard axes without "
|
||||
"The table is a quick check for templates that move those hard axes without "
|
||||
"simply making the model refuse."
|
||||
),
|
||||
"",
|
||||
(
|
||||
"Caption: each dot is one template, averaged over the two refusal-probe axes "
|
||||
"and four clean models. Right is more on-axis movement; lower is less off-axis "
|
||||
"confounding. Numbered dots are the first rows of the appendix table."
|
||||
),
|
||||
(
|
||||
"`refusal_or_ai_break_rate` is only an output audit column: it marks completions "
|
||||
"that refused or broke AI role, and is not used to select this data slice."
|
||||
),
|
||||
(
|
||||
"Interactive hover plot: "
|
||||
"[GitHub Pages](https://wassname.github.io/persona-steering-template-library/)."
|
||||
),
|
||||
(
|
||||
"The generated full audit table includes strict-pass, echo, and refusal columns: "
|
||||
"[out/model_matrix/refusal_probe_seed24_n1_model_matrix_summary.md]"
|
||||
|
||||
@@ -1,19 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import math
|
||||
from pathlib import Path
|
||||
import statistics
|
||||
|
||||
from tabulate import tabulate
|
||||
|
||||
import docs_results
|
||||
from template_catalog import CATALOG_PATH, jinja_to_runtime, load_template_catalog
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
STATS = ROOT / "out/stats"
|
||||
NORMAL_STATS = STATS / "v2_pilot_seed24_template_pair_stats.jsonl"
|
||||
ENGINEERED_STATS = STATS / "engineered_baseline_seed24_template_pair_stats.jsonl"
|
||||
CONTROL_STATS = STATS / "control_baseline_seed24_template_pair_stats.jsonl"
|
||||
NORMAL_STATS = docs_results.NORMAL_TEMPLATE_PAIR_STATS
|
||||
ENGINEERED_STATS = docs_results.ENGINEERED_TEMPLATE_PAIR_STATS
|
||||
CONTROL_STATS = docs_results.CONTROL_TEMPLATE_PAIR_STATS
|
||||
ENGINEERED_PAIRS = ROOT / "data/persona_pairs_engineered_baseline_pilot_two.jsonl"
|
||||
ENGINEERED_DISPLAY = "`{engineered long persona prefix}`*"
|
||||
|
||||
@@ -21,30 +19,8 @@ def _read_jsonl(path: Path) -> list[dict]:
|
||||
return [json.loads(line) for line in path.read_text().splitlines() if line.strip()]
|
||||
|
||||
|
||||
def _clamp01(x: float) -> float:
|
||||
return max(0.0, min(1.0, x))
|
||||
|
||||
|
||||
def _score(row: dict) -> float:
|
||||
on_axis = _clamp01(float(row["mean_axis_delta"]) / 8.0)
|
||||
off_axis = _clamp01((float(row["mean_off_axis_problem"]) - 1.0) / 6.0)
|
||||
return round(100.0 * on_axis * (1.0 - off_axis), 1)
|
||||
|
||||
|
||||
def _std(xs: list[float]) -> float:
|
||||
if len(xs) == 1:
|
||||
return 0.0
|
||||
return statistics.stdev(xs)
|
||||
|
||||
|
||||
def _score_t(scores: list[float]) -> float:
|
||||
if len(scores) < 2:
|
||||
return 0.0
|
||||
sem = _std(scores) / math.sqrt(len(scores))
|
||||
mean_score = sum(scores) / len(scores)
|
||||
if sem == 0.0:
|
||||
return 0.0 if mean_score == 0.0 else 1_000_000.0
|
||||
return mean_score / sem
|
||||
return round(docs_results.score(row), 1)
|
||||
|
||||
|
||||
def _markdown_text(text: str) -> str:
|
||||
@@ -78,21 +54,7 @@ def _best_by_template(rows: list[dict]) -> list[dict]:
|
||||
|
||||
|
||||
def _mean_by_template(rows: list[dict]) -> list[dict]:
|
||||
grouped: dict[str, list[dict]] = {}
|
||||
for row in rows:
|
||||
grouped.setdefault(row["template"], []).append({**row, "score": _score(row)})
|
||||
out = []
|
||||
for template, rs in grouped.items():
|
||||
scores = [row["score"] for row in rs]
|
||||
out.append({
|
||||
"template": template,
|
||||
"score_t": round(_score_t(scores), 2),
|
||||
"score": round(sum(scores) / len(scores), 1),
|
||||
"judge_std": round(
|
||||
sum(float(row["mean_axis_delta_judge_std"]) for row in rs) / len(rs), 2),
|
||||
"n_cells": len(rs),
|
||||
})
|
||||
return sorted(out, key=lambda row: row["score_t"], reverse=True)
|
||||
return docs_results.mean_template_rows(rows)
|
||||
|
||||
|
||||
def _engineered_derived_templates() -> set[str]:
|
||||
|
||||
Reference in New Issue
Block a user