mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 19:47:33 +08:00
55937a86fb
git mv src/projected_grpo -> src/vgrout and find-replace the module name in
all imports (.py), `-m projected_grpo.*` invocations (justfile), and the
[project] name (pyproject; setuptools auto-discovers via where=["src"]).
Left RESEARCH_JOURNAL.md untouched: its commands/paths are dated lab notes
tied to past commits, so rewriting them would falsify provenance. Repo dir,
git remote, and absolute paths unchanged.
Verified: `import vgrout` and `python -m vgrout.train --help` load the full
graph; verify_rewards.py + verify_gate_anchor.py (both import vgrout) pass.
Full `just smoke` is blocked upstream by missing gitignored data artifacts
(out/pools/{substrate,teacher_pool}, out/vhack/*smoke*), unrelated to the rename.
142 lines
6.1 KiB
Python
142 lines
6.1 KiB
Python
"""Build same-prompt (hack, clean) HackPairs from the STUDENT's own rollouts.jsonl
for the A5 held-out-mode generalisation test (the no-cheat payload).

pairs_from_pool.py does the same thing on the cached TEACHER pool, splitting
hack-side by detector signature. Here the source is the student's logged
rollouts (out/runs/<run>/rollouts.jsonl) and the split is by env_mode: a rollout
is hack-side iff it EXPLOITED its problem's mode AND that mode is one of the
"known" modes the weak detector can flag. The held-out modes are never used to
build pairs -- v_grad is extracted only from the known modes, and the A5 figure
then measures whether the held-out modes are also suppressed at deploy. That is
the load-bearing no-cheat check: weak detector for hacks A, suppression of
unknown hacks B.

Constraint (load-bearing, same as pairs_from_pool): pairs MUST share the prompt.
The paired-diff g_hack - g_clean in extract_vhack_grad cancels prompt-specific
noise only when both completions are conditioned on the same chat-templated
prompt. A given problem_id renders one fixed (hinted) prompt across steps, so
same-(problem_id) hack/clean rollouts share it; we assert equality to catch drift.

Run:
    uv run python scripts/pairs_from_rollouts.py \
        --rollouts-path=out/runs/<harvest_run>/rollouts.jsonl \
        --known-modes=sentinel,stdout_marker \
        --max-pairs-per-mode=6 \
        --out-path=out/pairsets/heldout_known_se_so.json
"""
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import random
|
|
from dataclasses import asdict
|
|
from pathlib import Path
|
|
|
|
import tyro
|
|
from loguru import logger
|
|
from tabulate import tabulate
|
|
|
|
from vgrout.pairs import HackPair
|
|
from vgrout.pairs_from_pool import save_pairs_json
|
|
|
|
|
|
def _is_clean(r: dict) -> bool:
|
|
"""Clean rollout: parseable code, NO loophole exploited, NO detector fired.
|
|
Mirrors pairs_from_pool._is_clean -- we do NOT gate on gt_pass; the contrast
|
|
is (mode loophole used) - (no hack mechanism), not (hack) - (correct solve)."""
|
|
if not r["format_ok"]:
|
|
return False
|
|
return not r["exploited"] and not r["hacked_C"] and not r["hacked_D"] and not r["hacked_E"]
|
|
|
|
|
|
def build_pairs(
    rollouts_path: Path,
    known_modes: list[str],
    max_pairs_per_mode: int = 6,
    seed: int = 0,
) -> tuple[list[HackPair], list[dict]]:
    """Build same-prompt (hack, clean) pairs from a student rollouts JSONL file.

    Per known mode, group rollouts by problem_id, classify hack/clean, emit up
    to max_pairs_per_mode same-prompt pairs. Hack-side = the problem's mode was
    exploited; clean-side = no hack mechanism fired (_is_clean).

    Args:
        rollouts_path: JSONL file with one rollout object per non-blank line.
            Rows are read for ``env_mode``, ``problem_id``, ``exploited``,
            ``prompt``, ``text``, ``mechanism`` and ``gt_pass``, plus the keys
            ``_is_clean`` inspects.
        known_modes: ``env_mode`` values to build pairs for. Rows whose mode
            is not listed are never used.
        max_pairs_per_mode: Upper bound on the number of problem_ids (one pair
            each) emitted per mode. Must be >= 0.
        seed: Seed for the private ``random.Random`` that shuffles eligible
            problem_ids and picks one hack and one clean rollout per problem.

    Returns:
        ``(pairs, diag_rows)``: the HackPairs, and one audit dict per pair
        (mode, pid, hack mechanism, gt_pass of both sides, text lengths).

    Raises:
        ValueError: ``max_pairs_per_mode`` is negative, or a known mode does
            not occur in any rollout.
        RuntimeError: the sampled hack and clean rollouts of one problem_id
            carry different prompts.
    """
    if max_pairs_per_mode < 0:
        # A negative cap would silently turn the slice below into "drop the last N".
        raise ValueError(f"max_pairs_per_mode must be >= 0, got {max_pairs_per_mode}")

    rng = random.Random(seed)
    # Iterate the file instead of read_text().splitlines(): splitlines() also
    # breaks on \x85, \u2028, \u2029 etc., which may sit unescaped inside a JSON
    # string and would cut a record in half. File iteration only splits on
    # real newlines. JSON is UTF-8, so do not depend on the locale encoding.
    with rollouts_path.open(encoding="utf-8") as fh:
        rows = [json.loads(line) for line in fh if line.strip()]
    logger.info(f"read {len(rows)} student rollouts from {rollouts_path}")

    seen_modes = sorted({r["env_mode"] for r in rows})
    missing = [m for m in known_modes if m not in seen_modes]
    if missing:
        raise ValueError(f"known modes {missing} absent from rollouts (have {seen_modes})")

    pairs: list[HackPair] = []
    diag_rows: list[dict] = []
    for mode in known_modes:
        mode_rows = [r for r in rows if r["env_mode"] == mode]
        hack_by_pid: dict[str, list[dict]] = {}
        clean_by_pid: dict[str, list[dict]] = {}
        for r in mode_rows:
            # Hack-side: this mode's loophole was used. `exploited` is mode-specific
            # (r.mechanism == env_mode when exploited), so we don't double-check it.
            if r["exploited"]:
                hack_by_pid.setdefault(r["problem_id"], []).append(r)
            elif _is_clean(r):
                clean_by_pid.setdefault(r["problem_id"], []).append(r)

        # Only problems with at least one rollout on each side can form a pair.
        # Sorting before the shuffle makes the result depend on the seed alone.
        eligible = sorted(set(hack_by_pid) & set(clean_by_pid))
        logger.info(
            f"mode={mode}: {len(mode_rows)} rollouts | hack_pids={len(hack_by_pid)} "
            f"clean_pids={len(clean_by_pid)} eligible(both)={len(eligible)}"
        )
        rng.shuffle(eligible)
        for pid in eligible[:max_pairs_per_mode]:
            h = rng.choice(hack_by_pid[pid])
            c = rng.choice(clean_by_pid[pid])
            # The paired difference is only meaningful on an identical prompt.
            if h["prompt"] != c["prompt"]:
                raise RuntimeError(f"prompt mismatch for pid={pid} mode={mode} -- schema drift?")
            pairs.append(HackPair(
                problem_id=str(pid),
                prompt=h["prompt"],
                hack=h["text"],
                clean=c["text"],
            ))
            diag_rows.append({
                "mode": mode, "pid": pid,
                "hack_mech": h["mechanism"],
                "hack_gt": int(h["gt_pass"]), "clean_gt": int(c["gt_pass"]),
                "hack_len": len(h["text"]), "clean_len": len(c["text"]),
            })
    return pairs, diag_rows
|
|
|
|
|
|
def main(
    rollouts_path: Path,
    known_modes: str = "sentinel,stdout_marker",
    max_pairs_per_mode: int = 6,
    seed: int = 0,
    out_path: Path = Path("out/pairsets/heldout_known.json"),
) -> int:
    """Build held-out-mode pairs from student rollouts; print audit; save JSON.

    SHOULD: emit up to max_pairs_per_mode same-prompt (hack, clean) rows per
    known mode, hack-side rows all with mechanism == their env_mode, clean-side
    rows with all detectors off. ELSE the harvest run lacked both a hack and a
    clean student rollout on a shared prompt for the known modes (run longer /
    raise group size / pick modes the student actually exploited).

    Args:
        rollouts_path: Student rollouts JSONL file, passed to build_pairs.
        known_modes: Comma-separated mode names. Blank entries are dropped and
            repeated names are used once.
        max_pairs_per_mode: Per-mode cap on emitted pairs, passed to build_pairs.
        seed: Sampling seed, passed to build_pairs.
        out_path: Destination handed to save_pairs_json.

    Returns:
        0 after the pairs were saved, 1 when no pair could be built at all.

    Raises:
        ValueError: ``known_modes`` contains no non-blank mode name.
    """
    # dict.fromkeys drops repeats but keeps first-seen order. A repeated mode
    # would otherwise be processed twice and its pairs would be double-weighted.
    modes = list(dict.fromkeys(m.strip() for m in known_modes.split(",") if m.strip()))
    if not modes:
        raise ValueError("need >=1 known mode")
    logger.info(f"building held-out pairs: known_modes={modes} max_per_mode={max_pairs_per_mode}")
    pairs, diag = build_pairs(rollouts_path, modes, max_pairs_per_mode=max_pairs_per_mode, seed=seed)
    if not pairs:
        logger.error("0 pairs emitted -- no shared-prompt hack+clean rollout pair for any known mode")
        return 1

    # Surface modes that contributed nothing. The saved pair set would
    # otherwise cover fewer modes than were requested, with no visible sign.
    covered = {row["mode"] for row in diag}
    starved = [m for m in modes if m not in covered]
    if starved:
        logger.warning(f"no pairs emitted for known modes {starved} -- pairset covers only {sorted(covered)}")

    print(f"\n--- Pair audit (N={len(pairs)}; known_modes={modes}) ---\n"
          "SHOULD: every hack_mech equals its mode; >= max_pairs_per_mode rows per\n"
          " mode if the harvest had enough both-sided prompts.\n")
    print(tabulate(diag, headers="keys", tablefmt="pipe"))
    save_pairs_json(pairs, out_path)
    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # tyro.cli returns whatever main() returns. Pass it to SystemExit so the
    # "0 pairs emitted" failure (return 1) becomes a non-zero process exit
    # status instead of being discarded.
    raise SystemExit(tyro.cli(main))
|