diff --git a/scripts/pairs_from_rollouts.py b/scripts/pairs_from_rollouts.py new file mode 100644 index 0000000..0b4216b --- /dev/null +++ b/scripts/pairs_from_rollouts.py @@ -0,0 +1,141 @@ +"""Build same-prompt (hack, clean) HackPairs from the STUDENT's own rollouts.jsonl +for the A5 held-out-mode generalisation test (the no-cheat payload). + +pairs_from_pool.py does the same thing on the cached TEACHER pool, splitting +hack-side by detector signature. Here the source is the student's logged +rollouts (out/runs//rollouts.jsonl) and the split is by env_mode: a rollout +is hack-side iff it EXPLOITED its problem's mode AND that mode is one of the +"known" modes the weak detector can flag. The held-out modes are never used to +build pairs -- v_grad is extracted only from the known modes, and the A5 figure +then measures whether the held-out modes are also suppressed at deploy. That is +the load-bearing no-cheat check: weak detector for hacks A, suppression of +unknown hacks B. + +Constraint (load-bearing, same as pairs_from_pool): pairs MUST share the prompt. +The paired-diff g_hack - g_clean in extract_vhack_grad cancels prompt-specific +noise only when both completions are conditioned on the same chat-templated +prompt. A given problem_id renders one fixed (hinted) prompt across steps, so +same-(problem_id) hack/clean rollouts share it; we assert equality to catch drift. + +Run: + uv run python scripts/pairs_from_rollouts.py \ + --rollouts-path=out/runs//rollouts.jsonl \ + --known-modes=sentinel,stdout_marker \ + --max-pairs-per-mode=6 \ + --out-path=out/pairsets/heldout_known_se_so.json +""" +from __future__ import annotations + +import json +import random +from dataclasses import asdict +from pathlib import Path + +import tyro +from loguru import logger +from tabulate import tabulate + +from projected_grpo.pairs import HackPair +from projected_grpo.pairs_from_pool import save_pairs_json + + +def _is_clean(r: dict) -> bool: + """Clean rollout: parseable code, NO loophole exploited, NO detector fired. + Mirrors pairs_from_pool._is_clean -- we do NOT gate on gt_pass; the contrast + is (mode loophole used) - (no hack mechanism), not (hack) - (correct solve).""" + if not r["format_ok"]: + return False + return not r["exploited"] and not r["hacked_C"] and not r["hacked_D"] and not r["hacked_E"] + + +def build_pairs( + rollouts_path: Path, + known_modes: list[str], + max_pairs_per_mode: int = 6, + seed: int = 0, +) -> tuple[list[HackPair], list[dict]]: + """Per known mode, group rollouts by problem_id, classify hack/clean, emit up + to max_pairs_per_mode same-prompt pairs. Hack-side = the problem's mode was + exploited; clean-side = no hack mechanism fired (_is_clean).""" + rng = random.Random(seed) + rows = [json.loads(l) for l in rollouts_path.read_text().splitlines() if l.strip()] + logger.info(f"read {len(rows)} student rollouts from {rollouts_path}") + + seen_modes = sorted({r["env_mode"] for r in rows}) + missing = [m for m in known_modes if m not in seen_modes] + if missing: + raise ValueError(f"known modes {missing} absent from rollouts (have {seen_modes})") + + pairs: list[HackPair] = [] + diag_rows: list[dict] = [] + for mode in known_modes: + mode_rows = [r for r in rows if r["env_mode"] == mode] + hack_by_pid: dict[str, list[dict]] = {} + clean_by_pid: dict[str, list[dict]] = {} + for r in mode_rows: + # Hack-side: this mode's loophole was used. `exploited` is mode-specific + # (r.mechanism == env_mode when exploited), so we don't double-check it. + if r["exploited"]: + hack_by_pid.setdefault(r["problem_id"], []).append(r) + elif _is_clean(r): + clean_by_pid.setdefault(r["problem_id"], []).append(r) + + eligible = sorted(set(hack_by_pid) & set(clean_by_pid)) + logger.info( + f"mode={mode}: {len(mode_rows)} rollouts | hack_pids={len(hack_by_pid)} " + f"clean_pids={len(clean_by_pid)} eligible(both)={len(eligible)}" + ) + rng.shuffle(eligible) + for pid in eligible[:max_pairs_per_mode]: + h = rng.choice(hack_by_pid[pid]) + c = rng.choice(clean_by_pid[pid]) + if h["prompt"] != c["prompt"]: + raise RuntimeError(f"prompt mismatch for pid={pid} mode={mode} -- schema drift?") + pairs.append(HackPair( + problem_id=str(pid), + prompt=h["prompt"], + hack=h["text"], + clean=c["text"], + )) + diag_rows.append({ + "mode": mode, "pid": pid, + "hack_mech": h["mechanism"], + "hack_gt": int(h["gt_pass"]), "clean_gt": int(c["gt_pass"]), + "hack_len": len(h["text"]), "clean_len": len(c["text"]), + }) + return pairs, diag_rows + + +def main( + rollouts_path: Path, + known_modes: str = "sentinel,stdout_marker", + max_pairs_per_mode: int = 6, + seed: int = 0, + out_path: Path = Path("out/pairsets/heldout_known.json"), +) -> int: + """Build held-out-mode pairs from student rollouts; print audit; save JSON. + + SHOULD: emit up to max_pairs_per_mode same-prompt (hack, clean) rows per + known mode, hack-side rows all with mechanism == their env_mode, clean-side + rows with all detectors off. ELSE the harvest run lacked both a hack and a + clean student rollout on a shared prompt for the known modes (run longer / + raise group size / pick modes the student actually exploited). + """ + modes = [m.strip() for m in known_modes.split(",") if m.strip()] + if len(modes) < 1: + raise ValueError("need >=1 known mode") + logger.info(f"building held-out pairs: known_modes={modes} max_per_mode={max_pairs_per_mode}") + pairs, diag = build_pairs(rollouts_path, modes, max_pairs_per_mode=max_pairs_per_mode, seed=seed) + if not pairs: + logger.error("0 pairs emitted -- no shared-prompt hack+clean rollout pair for any known mode") + return 1 + print(f"\n--- Pair audit (N={len(pairs)}; known_modes={modes}) ---\n" + "SHOULD: every hack_mech equals its mode; >= max_pairs_per_mode rows per\n" + " mode if the harvest had enough both-sided prompts.\n") + print(tabulate(diag, headers="keys", tablefmt="pipe")) + save_pairs_json(pairs, out_path) + return 0 + + +if __name__ == "__main__": + tyro.cli(main)