Files
persona-steering-template-l…/scripts/docs_results.py
T
2026-06-25 13:06:12 +08:00

80 lines
2.9 KiB
Python

from __future__ import annotations
import json
import math
from pathlib import Path
import statistics
from typing import Any
ROOT = Path(__file__).resolve().parents[1]
STATS = ROOT / "out/stats"
NORMAL_TEMPLATE_PAIR_STATS = STATS / "v2_pilot_seed24_template_pair_stats.jsonl"
ENGINEERED_TEMPLATE_PAIR_STATS = STATS / "engineered_baseline_seed24_template_pair_stats.jsonl"
CONTROL_TEMPLATE_PAIR_STATS = STATS / "control_baseline_seed24_template_pair_stats.jsonl"
REFUSAL_MODEL_PAIR_STATS = [
ROOT / "out/model_matrix/stats/refusal_probe_seed24_n1_google_gemma-2-27b-it_template_pair_stats.jsonl",
ROOT / "out/model_matrix/stats/refusal_probe_seed24_n1_google_gemma-3-4b-it_template_pair_stats.jsonl",
ROOT / "out/model_matrix/stats/refusal_probe_seed24_n1_qwen_qwen3.6-flash_template_pair_stats.jsonl",
ROOT / "out/model_matrix/stats/refusal_probe_seed24_n1_ibm-granite_granite-4.1-8b_template_pair_stats.jsonl",
]
REFUSAL_MODEL_PREFIX = ROOT / "out/model_matrix/refusal_probe_seed24_n1"
def read_jsonl(path: Path) -> list[dict[str, Any]]:
return [json.loads(line) for line in path.read_text().splitlines() if line.strip()]
def clamp01(x: float) -> float:
return max(0.0, min(1.0, x))
def mean(xs: list[float]) -> float:
return sum(xs) / len(xs)
def std(xs: list[float]) -> float:
if len(xs) == 1:
return 0.0
return statistics.stdev(xs)
def score(row: dict[str, Any]) -> float:
on_axis = clamp01(float(row["mean_axis_delta"]) / 8.0)
off_axis = clamp01((float(row["mean_off_axis_problem"]) - 1.0) / 6.0)
return 100.0 * on_axis * (1.0 - off_axis)
def score_t(scores: list[float]) -> float:
sem = std(scores) / math.sqrt(len(scores))
mean_score = mean(scores)
if sem == 0.0:
return 0.0 if mean_score == 0.0 else 1_000_000.0
return mean_score / sem
def mean_template_rows(rows: list[dict[str, Any]]) -> list[dict[str, Any]]:
grouped: dict[str, list[dict[str, Any]]] = {}
for row in rows:
grouped.setdefault(row["template"], []).append({**row, "score": score(row)})
out = []
for template, rs in grouped.items():
scores = [float(row["score"]) for row in rs]
out.append({
"template": template,
"score_t": round(score_t(scores), 2),
"score": round(mean(scores), 1),
"score_mean": round(mean(scores), 2),
"on_axis": clamp01(mean([float(row["mean_axis_delta"]) for row in rs]) / 8.0),
"off_axis": clamp01(
(mean([float(row["mean_off_axis_problem"]) for row in rs]) - 1.0) / 6.0),
"axis_delta": round(mean([float(row["mean_axis_delta"]) for row in rs]), 2),
"off_axis_problem": round(mean([float(row["mean_off_axis_problem"]) for row in rs]), 2),
"judge_std": round(mean([float(row["mean_axis_delta_judge_std"]) for row in rs]), 2),
"n_cells": len(rs),
})
return sorted(out, key=lambda row: row["score_t"], reverse=True)