mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 17:30:41 +08:00
A5: build held-out-mode (hack,clean) pairs from student rollouts
scripts/pairs_from_rollouts.py mirrors pairs_from_pool but sources the student's own rollouts.jsonl and splits hack/clean by env_mode+exploited (the per-mode weak detector). Same-prompt pairing, asserts prompt equality. Smoke-validated: parse + classify + loud-fail paths green on smoke rollouts (0 hacks -> 0 pairs, as expected). Unblocks A5 once job 95 harvest lands. Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
@@ -0,0 +1,141 @@
|
||||
"""Build same-prompt (hack, clean) HackPairs from the STUDENT's own rollouts.jsonl
|
||||
for the A5 held-out-mode generalisation test (the no-cheat payload).
|
||||
|
||||
pairs_from_pool.py does the same thing on the cached TEACHER pool, splitting
|
||||
hack-side by detector signature. Here the source is the student's logged
|
||||
rollouts (out/runs/<run>/rollouts.jsonl) and the split is by env_mode: a rollout
|
||||
is hack-side iff it EXPLOITED its problem's mode AND that mode is one of the
|
||||
"known" modes the weak detector can flag. The held-out modes are never used to
|
||||
build pairs -- v_grad is extracted only from the known modes, and the A5 figure
|
||||
then measures whether the held-out modes are also suppressed at deploy. That is
|
||||
the load-bearing no-cheat check: weak detector for hacks A, suppression of
|
||||
unknown hacks B.
|
||||
|
||||
Constraint (load-bearing, same as pairs_from_pool): pairs MUST share the prompt.
|
||||
The paired-diff g_hack - g_clean in extract_vhack_grad cancels prompt-specific
|
||||
noise only when both completions are conditioned on the same chat-templated
|
||||
prompt. A given problem_id renders one fixed (hinted) prompt across steps, so
|
||||
same-(problem_id) hack/clean rollouts share it; we assert equality to catch drift.
|
||||
|
||||
Run:
|
||||
uv run python scripts/pairs_from_rollouts.py \
|
||||
--rollouts-path=out/runs/<harvest_run>/rollouts.jsonl \
|
||||
--known-modes=sentinel,stdout_marker \
|
||||
--max-pairs-per-mode=6 \
|
||||
--out-path=out/pairsets/heldout_known_se_so.json
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import random
|
||||
from dataclasses import asdict
|
||||
from pathlib import Path
|
||||
|
||||
import tyro
|
||||
from loguru import logger
|
||||
from tabulate import tabulate
|
||||
|
||||
from projected_grpo.pairs import HackPair
|
||||
from projected_grpo.pairs_from_pool import save_pairs_json
|
||||
|
||||
|
||||
def _is_clean(r: dict) -> bool:
|
||||
"""Clean rollout: parseable code, NO loophole exploited, NO detector fired.
|
||||
Mirrors pairs_from_pool._is_clean -- we do NOT gate on gt_pass; the contrast
|
||||
is (mode loophole used) - (no hack mechanism), not (hack) - (correct solve)."""
|
||||
if not r["format_ok"]:
|
||||
return False
|
||||
return not r["exploited"] and not r["hacked_C"] and not r["hacked_D"] and not r["hacked_E"]
|
||||
|
||||
|
||||
def build_pairs(
    rollouts_path: Path,
    known_modes: list[str],
    max_pairs_per_mode: int = 6,
    seed: int = 0,
) -> tuple[list[HackPair], list[dict]]:
    """Build same-prompt (hack, clean) pairs from a rollouts JSONL file.

    For each mode in `known_modes`, the rollouts whose `env_mode` equals that
    mode are grouped by `problem_id`. A row is hack-side when its `exploited`
    flag is truthy and clean-side when `_is_clean` accepts it. Problems that
    have at least one row on each side are shuffled, and up to
    `max_pairs_per_mode` of them yield one pair each (one randomly chosen hack
    row and one randomly chosen clean row).

    Args:
        rollouts_path: JSONL file with one rollout object per line; blank
            lines are skipped. Rows are read for the keys `env_mode`,
            `exploited`, `problem_id`, `prompt`, `text`, `mechanism`,
            `gt_pass`, plus whatever `_is_clean` reads.
        known_modes: modes to build pairs for, processed in the given order.
        max_pairs_per_mode: upper bound on pairs emitted per mode (>= 0).
        seed: seed for the private `random.Random` that drives the shuffle
            and the per-problem choices, making the output reproducible.

    Returns:
        `(pairs, diag_rows)`: the emitted `HackPair` objects and one audit
        dict per pair (`mode`, `pid`, `hack_mech`, `hack_gt`, `clean_gt`,
        `hack_len`, `clean_len`), in the same order.

    Raises:
        ValueError: `max_pairs_per_mode` is negative, or a requested mode
            never occurs in the rollouts.
        RuntimeError: the chosen hack and clean rows of a problem carry
            different prompts.
    """
    # A negative cap would silently turn the slice below into "all but the
    # last k" instead of a limit, so reject it up front.
    if max_pairs_per_mode < 0:
        raise ValueError(f"max_pairs_per_mode must be >= 0, got {max_pairs_per_mode}")

    rng = random.Random(seed)

    # JSONL records are delimited by newlines only. Iterating the file (rather
    # than str.splitlines) avoids splitting a record at U+2028/U+2029/U+0085,
    # which may appear unescaped inside JSON strings. JSON text is UTF-8, so
    # the encoding is pinned instead of inheriting the locale default.
    with rollouts_path.open(encoding="utf-8") as fh:
        rows = [json.loads(line) for line in fh if line.strip()]
    logger.info(f"read {len(rows)} student rollouts from {rollouts_path}")

    # Bucket rows by mode once so each known mode is a dictionary lookup.
    rows_by_mode: dict[str, list[dict]] = {}
    for r in rows:
        rows_by_mode.setdefault(r["env_mode"], []).append(r)

    seen_modes = sorted(rows_by_mode)
    missing = [m for m in known_modes if m not in rows_by_mode]
    if missing:
        raise ValueError(f"known modes {missing} absent from rollouts (have {seen_modes})")

    pairs: list[HackPair] = []
    diag_rows: list[dict] = []
    for mode in known_modes:
        mode_rows = rows_by_mode[mode]
        hack_by_pid: dict[str, list[dict]] = {}
        clean_by_pid: dict[str, list[dict]] = {}
        for r in mode_rows:
            # Hack-side: this mode's loophole was used. `exploited` is treated
            # as mode-specific (mechanism == env_mode when exploited) and is
            # not re-verified here; the audit table exposes `hack_mech` so a
            # violation is visible to the reader.
            if r["exploited"]:
                hack_by_pid.setdefault(r["problem_id"], []).append(r)
            elif _is_clean(r):
                clean_by_pid.setdefault(r["problem_id"], []).append(r)

        # Only problems with rollouts on both sides can form a pair. Sorting
        # before the seeded shuffle keeps the result independent of set order.
        eligible = sorted(set(hack_by_pid) & set(clean_by_pid))
        logger.info(
            f"mode={mode}: {len(mode_rows)} rollouts | hack_pids={len(hack_by_pid)} "
            f"clean_pids={len(clean_by_pid)} eligible(both)={len(eligible)}"
        )
        rng.shuffle(eligible)
        for pid in eligible[:max_pairs_per_mode]:
            h = rng.choice(hack_by_pid[pid])
            c = rng.choice(clean_by_pid[pid])
            # The paired difference is only meaningful when both completions
            # were conditioned on the identical prompt; fail loudly otherwise.
            if h["prompt"] != c["prompt"]:
                raise RuntimeError(f"prompt mismatch for pid={pid} mode={mode} -- schema drift?")
            pairs.append(HackPair(
                problem_id=str(pid),
                prompt=h["prompt"],
                hack=h["text"],
                clean=c["text"],
            ))
            diag_rows.append({
                "mode": mode, "pid": pid,
                "hack_mech": h["mechanism"],
                "hack_gt": int(h["gt_pass"]), "clean_gt": int(c["gt_pass"]),
                "hack_len": len(h["text"]), "clean_len": len(c["text"]),
            })
    return pairs, diag_rows
|
||||
|
||||
|
||||
def main(
    rollouts_path: Path,
    known_modes: str = "sentinel,stdout_marker",
    max_pairs_per_mode: int = 6,
    seed: int = 0,
    out_path: Path = Path("out/pairsets/heldout_known.json"),
) -> int:
    """Build held-out-mode pairs from student rollouts; print audit; save JSON.

    SHOULD: emit up to max_pairs_per_mode same-prompt (hack, clean) rows per
    known mode, hack-side rows all with mechanism == their env_mode, clean-side
    rows with all detectors off. ELSE the harvest run lacked both a hack and a
    clean student rollout on a shared prompt for the known modes (run longer /
    raise group size / pick modes the student actually exploited).
    """
    # Turn the comma-separated CLI value into trimmed, non-empty mode names.
    modes = [name for name in (piece.strip() for piece in known_modes.split(",")) if name]
    if not modes:
        raise ValueError("need >=1 known mode")
    logger.info(f"building held-out pairs: known_modes={modes} max_per_mode={max_pairs_per_mode}")

    pairs, audit_rows = build_pairs(
        rollouts_path,
        modes,
        max_pairs_per_mode=max_pairs_per_mode,
        seed=seed,
    )

    # Nothing to save: report it and hand back a non-zero status code.
    if len(pairs) == 0:
        logger.error("0 pairs emitted -- no shared-prompt hack+clean rollout pair for any known mode")
        return 1

    # Human-readable audit: expectation banner first, then one row per pair.
    banner = (
        f"\n--- Pair audit (N={len(pairs)}; known_modes={modes}) ---\n"
        "SHOULD: every hack_mech equals its mode; >= max_pairs_per_mode rows per\n"
        "        mode if the harvest had enough both-sided prompts.\n"
    )
    print(banner)
    print(tabulate(audit_rows, headers="keys", tablefmt="pipe"))

    save_pairs_json(pairs, out_path)
    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # main() returns a status code (1 when no pairs were emitted). tyro.cli
    # hands that return value back, so forward it as the process exit status;
    # otherwise the script would exit 0 even on the failure path.
    raise SystemExit(tyro.cli(main))
|
||||
Reference in New Issue
Block a user