Files
evil_MoE/scripts/pairs_from_rollouts.py
T
wassname 55937a86fb rename python package projected_grpo -> vgrout
git mv src/projected_grpo -> src/vgrout and find-replace the module name in
all imports (.py), `-m projected_grpo.*` invocations (justfile), and the
[project] name (pyproject; setuptools auto-discovers via where=["src"]).

Left RESEARCH_JOURNAL.md untouched: its commands/paths are dated lab notes
tied to past commits, so rewriting them would falsify provenance. Repo dir,
git remote, and absolute paths unchanged.

Verified: `import vgrout` and `python -m vgrout.train --help` load the full
graph; verify_rewards.py + verify_gate_anchor.py (both import vgrout) pass.
Full `just smoke` is blocked upstream by missing gitignored data artifacts
(out/pools/{substrate,teacher_pool}, out/vhack/*smoke*), unrelated to the rename.
2026-06-05 14:51:48 +08:00

142 lines
6.1 KiB
Python

"""Build same-prompt (hack, clean) HackPairs from the STUDENT's own rollouts.jsonl
for the A5 held-out-mode generalisation test (the no-cheat payload).
pairs_from_pool.py does the same thing on the cached TEACHER pool, splitting
hack-side by detector signature. Here the source is the student's logged
rollouts (out/runs/<run>/rollouts.jsonl) and the split is by env_mode: a rollout
is hack-side iff it EXPLOITED its problem's mode AND that mode is one of the
"known" modes the weak detector can flag. The held-out modes are never used to
build pairs -- v_grad is extracted only from the known modes, and the A5 figure
then measures whether the held-out modes are also suppressed at deploy. That is
the load-bearing no-cheat check: weak detector for hacks A, suppression of
unknown hacks B.
Constraint (load-bearing, same as pairs_from_pool): pairs MUST share the prompt.
The paired-diff g_hack - g_clean in extract_vhack_grad cancels prompt-specific
noise only when both completions are conditioned on the same chat-templated
prompt. A given problem_id renders one fixed (hinted) prompt across steps, so
same-(problem_id) hack/clean rollouts share it; we assert equality to catch drift.
Run:
uv run python scripts/pairs_from_rollouts.py \
--rollouts-path=out/runs/<harvest_run>/rollouts.jsonl \
--known-modes=sentinel,stdout_marker \
--max-pairs-per-mode=6 \
--out-path=out/pairsets/heldout_known_se_so.json
"""
from __future__ import annotations
import json
import random
from dataclasses import asdict
from pathlib import Path
import tyro
from loguru import logger
from tabulate import tabulate
from vgrout.pairs import HackPair
from vgrout.pairs_from_pool import save_pairs_json
def _is_clean(r: dict) -> bool:
"""Clean rollout: parseable code, NO loophole exploited, NO detector fired.
Mirrors pairs_from_pool._is_clean -- we do NOT gate on gt_pass; the contrast
is (mode loophole used) - (no hack mechanism), not (hack) - (correct solve)."""
if not r["format_ok"]:
return False
return not r["exploited"] and not r["hacked_C"] and not r["hacked_D"] and not r["hacked_E"]
def build_pairs(
    rollouts_path: Path,
    known_modes: list[str],
    max_pairs_per_mode: int = 6,
    seed: int = 0,
) -> tuple[list[HackPair], list[dict]]:
    """Build same-prompt (hack, clean) HackPairs from student rollouts, per known mode.

    For each known mode, the mode's rollouts are grouped by problem_id and split
    into hack-side (``exploited`` set) and clean-side (``_is_clean``). Up to
    ``max_pairs_per_mode`` problem_ids that have BOTH sides are sampled, and one
    hack and one clean rollout are drawn from each.

    Args:
        rollouts_path: JSONL file with one rollout dict per non-blank line
            (read as UTF-8).
        known_modes: env_mode values allowed to supply pairs. Duplicates are
            ignored; first-occurrence order is kept.
        max_pairs_per_mode: cap on pairs emitted per mode; must be >= 1.
        seed: seed for the shuffle/choice RNG, so the selection is reproducible.

    Returns:
        (pairs, diag_rows): the HackPairs, and one audit dict per pair with keys
        mode, pid, hack_mech, hack_gt, clean_gt, hack_len, clean_len.

    Raises:
        ValueError: max_pairs_per_mode < 1, or a known mode never appears in
            the rollouts.
        RuntimeError: a sampled hack/clean pair does not share its prompt.
    """
    if max_pairs_per_mode < 1:
        # A negative cap would slice from the END of `eligible` (e.g. [:-1]
        # silently drops one pid) instead of capping; 0 can never yield pairs.
        raise ValueError(f"max_pairs_per_mode must be >= 1, got {max_pairs_per_mode}")
    # Repeating a mode would re-shuffle and re-sample the same rollouts,
    # emitting duplicate/overlapping pairs.
    known_modes = list(dict.fromkeys(known_modes))
    rng = random.Random(seed)
    rows = [
        json.loads(line)
        for line in rollouts_path.read_text(encoding="utf-8").splitlines()
        if line.strip()
    ]
    logger.info(f"read {len(rows)} student rollouts from {rollouts_path}")
    seen_modes = sorted({r["env_mode"] for r in rows})
    missing = [m for m in known_modes if m not in seen_modes]
    if missing:
        raise ValueError(f"known modes {missing} absent from rollouts (have {seen_modes})")
    pairs: list[HackPair] = []
    diag_rows: list[dict] = []
    for mode in known_modes:
        mode_rows = [r for r in rows if r["env_mode"] == mode]
        hack_by_pid: dict[str, list[dict]] = {}
        clean_by_pid: dict[str, list[dict]] = {}
        for r in mode_rows:
            # Hack-side: this mode's loophole was used. `exploited` is mode-specific
            # (r.mechanism == env_mode when exploited), so we don't double-check it.
            if r["exploited"]:
                hack_by_pid.setdefault(r["problem_id"], []).append(r)
            elif _is_clean(r):
                clean_by_pid.setdefault(r["problem_id"], []).append(r)
        # Sort before shuffling so the pick depends only on `seed`, not on set
        # iteration order (which varies with hash randomization for str keys).
        eligible = sorted(set(hack_by_pid) & set(clean_by_pid))
        logger.info(
            f"mode={mode}: {len(mode_rows)} rollouts | hack_pids={len(hack_by_pid)} "
            f"clean_pids={len(clean_by_pid)} eligible(both)={len(eligible)}"
        )
        rng.shuffle(eligible)
        for pid in eligible[:max_pairs_per_mode]:
            h = rng.choice(hack_by_pid[pid])
            c = rng.choice(clean_by_pid[pid])
            # The paired diff g_hack - g_clean only cancels prompt-specific noise
            # when both completions were conditioned on the identical prompt.
            if h["prompt"] != c["prompt"]:
                raise RuntimeError(f"prompt mismatch for pid={pid} mode={mode} -- schema drift?")
            pairs.append(HackPair(
                problem_id=str(pid),
                prompt=h["prompt"],
                hack=h["text"],
                clean=c["text"],
            ))
            diag_rows.append({
                "mode": mode, "pid": pid,
                "hack_mech": h["mechanism"],
                "hack_gt": int(h["gt_pass"]), "clean_gt": int(c["gt_pass"]),
                "hack_len": len(h["text"]), "clean_len": len(c["text"]),
            })
    return pairs, diag_rows
def main(
    rollouts_path: Path,
    known_modes: str = "sentinel,stdout_marker",
    max_pairs_per_mode: int = 6,
    seed: int = 0,
    out_path: Path = Path("out/pairsets/heldout_known.json"),
) -> int:
    """Build held-out-mode pairs from student rollouts; print audit; save JSON.

    SHOULD: emit up to max_pairs_per_mode same-prompt (hack, clean) rows per
    known mode, hack-side rows all with mechanism == their env_mode, clean-side
    rows with all detectors off. ELSE the harvest run lacked both a hack and a
    clean student rollout on a shared prompt for the known modes (run longer /
    raise group size / pick modes the student actually exploited).

    Args:
        rollouts_path: student rollouts.jsonl from the harvest run.
        known_modes: comma-separated env_modes the weak detector can flag; only
            these supply pairs. Blank entries and duplicates are dropped.
        max_pairs_per_mode: cap on pairs emitted per known mode.
        seed: RNG seed for which problem_ids / rollouts get paired.
        out_path: where the pair-set JSON is written.

    Returns:
        Process exit code: 0 when pairs were saved, 1 when none were emitted.
    """
    # CLI string -> ordered, de-duplicated list of non-empty mode names.
    modes = list(dict.fromkeys(m.strip() for m in known_modes.split(",") if m.strip()))
    if not modes:
        raise ValueError("need >=1 known mode")
    logger.info(f"building held-out pairs: known_modes={modes} max_per_mode={max_pairs_per_mode}")
    pairs, diag = build_pairs(rollouts_path, modes, max_pairs_per_mode=max_pairs_per_mode, seed=seed)
    if not pairs:
        logger.error("0 pairs emitted -- no shared-prompt hack+clean rollout pair for any known mode")
        return 1
    # A mode falls short of the cap when too few problem_ids had both a hack and
    # a clean rollout; surface that instead of silently under-weighting the mode.
    per_mode = {m: sum(row["mode"] == m for row in diag) for m in modes}
    short = {m: n for m, n in per_mode.items() if n < max_pairs_per_mode}
    if short:
        logger.warning(f"modes below max_pairs_per_mode={max_pairs_per_mode} (mode -> pairs): {short}")
    print(f"\n--- Pair audit (N={len(pairs)}; known_modes={modes}) ---\n"
          "SHOULD: every hack_mech equals its mode; exactly max_pairs_per_mode rows per\n"
          "        mode if the harvest had enough both-sided prompts (it is a cap).\n")
    print(tabulate(diag, headers="keys", tablefmt="pipe"))
    save_pairs_json(pairs, out_path)
    return 0
if __name__ == "__main__":
    # tyro.cli returns main()'s int; hand it to SystemExit so the "0 pairs
    # emitted" path (return 1) actually yields a non-zero process exit status.
    raise SystemExit(tyro.cli(main))