Files
evil_MoE/scripts/pairs_from_rollouts.py
T
wassname 4c9071cca0 A5: build held-out-mode (hack,clean) pairs from student rollouts
scripts/pairs_from_rollouts.py mirrors pairs_from_pool but sources the
student's own rollouts.jsonl and splits hack/clean by env_mode+exploited
(the per-mode weak detector). Same-prompt pairing, asserts prompt equality.
Smoke-validated: parse + classify + loud-fail paths green on smoke rollouts
(0 hacks -> 0 pairs, as expected). Unblocks A5 once job 95 harvest lands.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
2026-06-03 00:59:07 +00:00

142 lines
6.1 KiB
Python

"""Build same-prompt (hack, clean) HackPairs from the STUDENT's own rollouts.jsonl
for the A5 held-out-mode generalisation test (the no-cheat payload).
pairs_from_pool.py does the same thing on the cached TEACHER pool, splitting
hack-side by detector signature. Here the source is the student's logged
rollouts (out/runs/<run>/rollouts.jsonl) and the split is by env_mode: a rollout
is hack-side iff it EXPLOITED its problem's mode AND that mode is one of the
"known" modes the weak detector can flag. The held-out modes are never used to
build pairs -- v_grad is extracted only from the known modes, and the A5 figure
then measures whether the held-out modes are also suppressed at deploy. That is
the load-bearing no-cheat check: the weak detector only flags the known hacks (A),
yet suppression must also hold for the unknown, held-out hacks (B).
Constraint (load-bearing, same as pairs_from_pool): pairs MUST share the prompt.
The paired-diff g_hack - g_clean in extract_vhack_grad cancels prompt-specific
noise only when both completions are conditioned on the same chat-templated
prompt. A given problem_id renders one fixed (hinted) prompt across steps, so
same-(problem_id) hack/clean rollouts share it; we assert equality to catch drift.
Run:
uv run python scripts/pairs_from_rollouts.py \
--rollouts-path=out/runs/<harvest_run>/rollouts.jsonl \
--known-modes=sentinel,stdout_marker \
--max-pairs-per-mode=6 \
--out-path=out/pairsets/heldout_known_se_so.json
"""
from __future__ import annotations
import json
import random
from dataclasses import asdict
from pathlib import Path
import tyro
from loguru import logger
from tabulate import tabulate
from projected_grpo.pairs import HackPair
from projected_grpo.pairs_from_pool import save_pairs_json
def _is_clean(r: dict) -> bool:
"""Clean rollout: parseable code, NO loophole exploited, NO detector fired.
Mirrors pairs_from_pool._is_clean -- we do NOT gate on gt_pass; the contrast
is (mode loophole used) - (no hack mechanism), not (hack) - (correct solve)."""
if not r["format_ok"]:
return False
return not r["exploited"] and not r["hacked_C"] and not r["hacked_D"] and not r["hacked_E"]
def build_pairs(
    rollouts_path: Path,
    known_modes: list[str],
    max_pairs_per_mode: int = 6,
    seed: int = 0,
) -> tuple[list[HackPair], list[dict]]:
    """Build same-prompt (hack, clean) pairs from a rollouts JSONL file.

    For each known mode, rollouts are grouped by ``problem_id`` and split into a
    hack side (``exploited`` is truthy) and a clean side (``_is_clean``). Up to
    ``max_pairs_per_mode`` problem ids that have BOTH sides are sampled, and one
    hack and one clean rollout are drawn at random for each.

    Args:
        rollouts_path: JSON Lines file, one rollout dict per non-blank line.
            Each row is read for ``env_mode``, ``problem_id``, ``exploited``,
            ``prompt``, ``text``, ``mechanism``, ``gt_pass`` and the flags
            ``_is_clean`` consults.
        known_modes: ``env_mode`` values to build pairs for. Repeated entries
            are collapsed (first occurrence wins) so one mode is never sampled
            twice.
        max_pairs_per_mode: Upper bound on pairs emitted per mode; must be >= 0.
        seed: Seed for the private ``random.Random`` used for shuffling and
            sampling, so a given (file, arguments) combination is reproducible.

    Returns:
        ``(pairs, diag_rows)``: the ``HackPair`` list, and one audit dict per
        pair with keys ``mode``, ``pid``, ``hack_mech``, ``hack_gt``,
        ``clean_gt``, ``hack_len``, ``clean_len``.

    Raises:
        ValueError: ``max_pairs_per_mode`` is negative, or a known mode never
            appears in the rollouts.
        RuntimeError: the sampled hack and clean rollouts of one problem id
            carry different prompts.
        KeyError: a row lacks one of the fields listed above.
    """
    if max_pairs_per_mode < 0:
        # A negative cap would otherwise turn the slice below into "all but the
        # last |n| eligible prompts" instead of a limit.
        raise ValueError(f"max_pairs_per_mode must be >= 0, got {max_pairs_per_mode}")

    # Order-preserving de-duplication of the requested modes.
    modes = list(dict.fromkeys(known_modes))
    rng = random.Random(seed)

    # Iterate the file handle rather than str.splitlines(): splitlines() also
    # breaks on U+0085 / U+2028 / U+2029, which may legally appear unescaped
    # inside a JSON string and would cut a record in half.
    with rollouts_path.open(encoding="utf-8") as fh:
        rows = [json.loads(line) for line in fh if line.strip()]
    logger.info(f"read {len(rows)} student rollouts from {rollouts_path}")

    # Bucket once by env_mode (file order kept within each bucket) instead of
    # rescanning every row for every mode.
    rows_by_mode: dict[str, list[dict]] = {}
    for r in rows:
        rows_by_mode.setdefault(r["env_mode"], []).append(r)

    seen_modes = sorted(rows_by_mode)
    missing = [m for m in modes if m not in rows_by_mode]
    if missing:
        raise ValueError(f"known modes {missing} absent from rollouts (have {seen_modes})")

    pairs: list[HackPair] = []
    diag_rows: list[dict] = []
    for mode in modes:
        mode_rows = rows_by_mode[mode]
        hack_by_pid: dict[str, list[dict]] = {}
        clean_by_pid: dict[str, list[dict]] = {}
        for r in mode_rows:
            # Hack-side: this mode's loophole was used. Per the original
            # author's note, `exploited` is mode-specific (mechanism ==
            # env_mode when exploited), so mechanism is not re-checked here;
            # it is surfaced in the audit rows instead.
            if r["exploited"]:
                hack_by_pid.setdefault(r["problem_id"], []).append(r)
            elif _is_clean(r):
                clean_by_pid.setdefault(r["problem_id"], []).append(r)

        # Sort before shuffling so the sample depends only on the seed, not on
        # set iteration order.
        eligible = sorted(set(hack_by_pid) & set(clean_by_pid))
        logger.info(
            f"mode={mode}: {len(mode_rows)} rollouts | hack_pids={len(hack_by_pid)} "
            f"clean_pids={len(clean_by_pid)} eligible(both)={len(eligible)}"
        )
        rng.shuffle(eligible)
        for pid in eligible[:max_pairs_per_mode]:
            h = rng.choice(hack_by_pid[pid])
            c = rng.choice(clean_by_pid[pid])
            # The pair is only meaningful when both completions were
            # conditioned on the same prompt; fail loudly otherwise.
            if h["prompt"] != c["prompt"]:
                raise RuntimeError(f"prompt mismatch for pid={pid} mode={mode} -- schema drift?")
            pairs.append(HackPair(
                problem_id=str(pid),
                prompt=h["prompt"],
                hack=h["text"],
                clean=c["text"],
            ))
            diag_rows.append({
                "mode": mode, "pid": pid,
                "hack_mech": h["mechanism"],
                "hack_gt": int(h["gt_pass"]), "clean_gt": int(c["gt_pass"]),
                "hack_len": len(h["text"]), "clean_len": len(c["text"]),
            })
    return pairs, diag_rows
def main(
    rollouts_path: Path,
    known_modes: str = "sentinel,stdout_marker",
    max_pairs_per_mode: int = 6,
    seed: int = 0,
    out_path: Path = Path("out/pairsets/heldout_known.json"),
) -> int:
    """CLI entry: build held-out-mode pairs, print an audit table, save JSON.

    ``known_modes`` is a comma-separated list; surrounding whitespace and empty
    items are dropped. Returns 0 after the pairs have been written to
    ``out_path`` and 1 when no pair could be built at all.

    SHOULD: emit up to max_pairs_per_mode same-prompt (hack, clean) rows per
    known mode, every hack-side row with mechanism == its env_mode, every
    clean-side row with all detectors off. ELSE the harvest run had no prompt
    carrying both a hack and a clean student rollout for the known modes (run
    longer / raise group size / pick modes the student actually exploited).

    Raises:
        ValueError: ``known_modes`` contains no non-empty mode name.
    """
    stripped = (part.strip() for part in known_modes.split(","))
    modes = [name for name in stripped if name]
    if not modes:
        raise ValueError("need >=1 known mode")

    logger.info(f"building held-out pairs: known_modes={modes} max_per_mode={max_pairs_per_mode}")
    pairs, diag = build_pairs(
        rollouts_path,
        modes,
        max_pairs_per_mode=max_pairs_per_mode,
        seed=seed,
    )

    if pairs:
        # Audit first, then persist, so the table is visible even if saving fails.
        header = (
            f"\n--- Pair audit (N={len(pairs)}; known_modes={modes}) ---\n"
            "SHOULD: every hack_mech equals its mode; >= max_pairs_per_mode rows per\n"
            " mode if the harvest had enough both-sided prompts.\n"
        )
        print(header)
        print(tabulate(diag, headers="keys", tablefmt="pipe"))
        save_pairs_json(pairs, out_path)
        return 0

    logger.error("0 pairs emitted -- no shared-prompt hack+clean rollout pair for any known mode")
    return 1
if __name__ == "__main__":
    # tyro.cli returns whatever main returns. Hand that to SystemExit so the
    # "0 pairs" failure (main returns 1) becomes a non-zero process exit status
    # instead of being discarded and reported to the shell as success.
    raise SystemExit(tyro.cli(main))