Files
evil_MoE/scripts/pairs_from_rollouts.py
T
wassname 270c4f5a27 misc
2026-06-11 11:07:28 +00:00

142 lines
6.1 KiB
Python

"""Build same-prompt (hack, clean) HackPairs from student rollouts.jsonl.
These pairs support the A5 held-out-mode generalization test.
pairs_from_pool.py does the same thing on the cached TEACHER pool, splitting
hack-side by detector signature. Here the source is the student's logged
rollouts (out/runs/<run>/rollouts.jsonl) and the split is by env_mode: a rollout
is hack-side iff it EXPLOITED its problem's mode AND that mode is one of the
"known" modes the weak detector can flag. The held-out modes are never used to
build pairs -- v_grad is extracted only from the known modes, and the A5 figure
then measures whether the held-out modes are also suppressed at deployment. This
tests whether a detector trained on hack classes A suppresses unseen classes B.
Constraint (load-bearing, same as pairs_from_pool): pairs MUST share the prompt.
The paired-diff g_hack - g_clean in extract_vhack_grad cancels prompt-specific
noise only when both completions are conditioned on the same chat-templated
prompt. A given problem_id renders one fixed (hinted) prompt across steps, so
same-(problem_id) hack/clean rollouts share it; we assert equality to catch drift.
Run:
uv run python scripts/pairs_from_rollouts.py \
--rollouts-path=out/runs/<harvest_run>/rollouts.jsonl \
--known-modes=sentinel,stdout_marker \
--max-pairs-per-mode=6 \
--out-path=out/pairsets/heldout_known_se_so.json
"""
from __future__ import annotations
import json
import random
from dataclasses import asdict
from pathlib import Path
import tyro
from loguru import logger
from tabulate import tabulate
from vgrout.pairs import HackPair
from vgrout.pairs_from_pool import save_pairs_json
def _is_clean(r: dict) -> bool:
"""Clean rollout: parseable code, NO loophole exploited, NO detector fired.
Mirrors pairs_from_pool._is_clean -- we do NOT gate on gt_pass; the contrast
is (mode loophole used) - (no hack mechanism), not (hack) - (correct solve)."""
if not r["format_ok"]:
return False
return not r["exploited"] and not r["hacked_C"] and not r["hacked_D"] and not r["hacked_E"]
def build_pairs(
    rollouts_path: Path,
    known_modes: list[str],
    max_pairs_per_mode: int = 6,
    seed: int = 0,
) -> tuple[list[HackPair], list[dict]]:
    """Emit up to ``max_pairs_per_mode`` same-prompt (hack, clean) pairs per known mode.

    For each mode, rollouts with that ``env_mode`` are grouped by ``problem_id``:
    hack-side = the rollout's ``exploited`` flag is set; clean-side = ``_is_clean``.
    Problem ids with at least one rollout on each side are shuffled with a seeded
    RNG, and for the first ``max_pairs_per_mode`` of them one hack and one clean
    rollout are drawn at random to form a HackPair.

    Args:
        rollouts_path: JSONL file, one rollout dict per non-blank line.
        known_modes: env_mode values to build pairs from (repeats are ignored,
            order is preserved).
        max_pairs_per_mode: cap on pairs emitted per mode; must be >= 0.
        seed: seed for the pid shuffle and the per-pid rollout draws.

    Returns:
        (pairs, diag_rows): the HackPairs, plus one audit dict per pair.

    Raises:
        ValueError: max_pairs_per_mode < 0, or a known mode never appears in
            the rollouts.
        RuntimeError: a drawn hack/clean rollout pair does not share the prompt.
    """
    if max_pairs_per_mode < 0:
        # A negative cap would silently become eligible[:-k] ("all but the last k").
        raise ValueError(f"max_pairs_per_mode must be >= 0, got {max_pairs_per_mode}")
    # Drop repeated modes (order-preserving): a repeat would emit the same
    # problem_ids a second time with fresh random draws.
    known_modes = list(dict.fromkeys(known_modes))

    rng = random.Random(seed)
    # JSONL is UTF-8; don't rely on the platform's locale default encoding.
    text = rollouts_path.read_text(encoding="utf-8")
    rows = [json.loads(line) for line in text.splitlines() if line.strip()]
    logger.info(f"read {len(rows)} student rollouts from {rollouts_path}")

    seen_modes = sorted({r["env_mode"] for r in rows})
    missing = [m for m in known_modes if m not in seen_modes]
    if missing:
        raise ValueError(f"known modes {missing} absent from rollouts (have {seen_modes})")

    pairs: list[HackPair] = []
    diag_rows: list[dict] = []
    for mode in known_modes:
        mode_rows = [r for r in rows if r["env_mode"] == mode]
        hack_by_pid: dict[str, list[dict]] = {}
        clean_by_pid: dict[str, list[dict]] = {}
        for r in mode_rows:
            # Hack-side: this mode's loophole was used. `exploited` is mode-specific
            # (r.mechanism == env_mode when exploited), so we don't double-check it.
            if r["exploited"]:
                hack_by_pid.setdefault(r["problem_id"], []).append(r)
            elif _is_clean(r):
                clean_by_pid.setdefault(r["problem_id"], []).append(r)

        # Sort before the seeded shuffle so the result doesn't depend on dict/set order.
        eligible = sorted(set(hack_by_pid) & set(clean_by_pid))
        logger.info(
            f"mode={mode}: {len(mode_rows)} rollouts | hack_pids={len(hack_by_pid)} "
            f"clean_pids={len(clean_by_pid)} eligible(both)={len(eligible)}"
        )
        rng.shuffle(eligible)
        for pid in eligible[:max_pairs_per_mode]:
            h = rng.choice(hack_by_pid[pid])
            c = rng.choice(clean_by_pid[pid])
            # The paired diff only cancels prompt noise if both sides share the prompt.
            if h["prompt"] != c["prompt"]:
                raise RuntimeError(f"prompt mismatch for pid={pid} mode={mode} -- schema drift?")
            pairs.append(HackPair(
                problem_id=str(pid),
                prompt=h["prompt"],
                hack=h["text"],
                clean=c["text"],
            ))
            diag_rows.append({
                "mode": mode, "pid": pid,
                "hack_mech": h["mechanism"],
                "hack_gt": int(h["gt_pass"]), "clean_gt": int(c["gt_pass"]),
                "hack_len": len(h["text"]), "clean_len": len(c["text"]),
            })
    return pairs, diag_rows
def main(
    rollouts_path: Path,
    known_modes: str = "sentinel,stdout_marker",
    max_pairs_per_mode: int = 6,
    seed: int = 0,
    out_path: Path = Path("out/pairsets/heldout_known.json"),
) -> int:
    """Build held-out-mode pairs from student rollouts; print audit; save JSON.

    SHOULD: emit up to max_pairs_per_mode same-prompt (hack, clean) rows per
    known mode, hack-side rows all with mechanism == their env_mode, clean-side
    rows with all detectors off. ELSE the harvest run lacked both a hack and a
    clean student rollout on a shared prompt for the known modes (run longer /
    raise group size / pick modes the student actually exploited).

    Args:
        rollouts_path: student rollouts.jsonl from a harvest run.
        known_modes: comma-separated env_mode names used to build pairs.
        max_pairs_per_mode: cap on pairs per mode.
        seed: RNG seed forwarded to build_pairs.
        out_path: where save_pairs_json writes the pairs.

    Returns:
        0 when pairs were built and saved, 1 when no pair could be built.

    Raises:
        ValueError: known_modes contains no non-empty mode name.
    """
    # Comma-separated CLI string -> ordered, de-duplicated list of mode names.
    modes = list(dict.fromkeys(m.strip() for m in known_modes.split(",") if m.strip()))
    if not modes:
        raise ValueError("need >=1 known mode")
    logger.info(f"building held-out pairs: known_modes={modes} max_per_mode={max_pairs_per_mode}")
    pairs, diag = build_pairs(rollouts_path, modes, max_pairs_per_mode=max_pairs_per_mode, seed=seed)
    if not pairs:
        logger.error("0 pairs emitted -- no shared-prompt hack+clean rollout pair for any known mode")
        return 1
    # build_pairs slices eligible pids to max_pairs_per_mode, so a mode can never
    # exceed the cap -- "up to", not ">=".
    print(f"\n--- Pair audit (N={len(pairs)}; known_modes={modes}) ---\n"
          "SHOULD: every hack_mech equals its mode; up to max_pairs_per_mode rows per\n"
          " mode (exactly that many if the harvest had enough both-sided prompts).\n")
    print(tabulate(diag, headers="keys", tablefmt="pipe"))
    # build_pairs trusts `exploited` to be mode-specific; flag any row contradicting that.
    mismatched = [d for d in diag if d["hack_mech"] != d["mode"]]
    if mismatched:
        logger.warning(
            f"{len(mismatched)}/{len(diag)} hack rows have hack_mech != mode -- "
            "`exploited` may not be mode-specific in this run"
        )
    save_pairs_json(pairs, out_path)
    return 0
if __name__ == "__main__":
    # tyro.cli returns main()'s result (0 ok / 1 no pairs). Propagate it as the
    # process exit status; otherwise a failed build still exits 0.
    raise SystemExit(tyro.cli(main))