mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-28 00:59:12 +08:00
55937a86fb
git mv src/projected_grpo -> src/vgrout and find-replace the module name in
all imports (.py), `-m projected_grpo.*` invocations (justfile), and the
[project] name (pyproject; setuptools auto-discovers via where=["src"]).
Left RESEARCH_JOURNAL.md untouched: its commands/paths are dated lab notes
tied to past commits, so rewriting them would falsify provenance. Repo dir,
git remote, and absolute paths unchanged.
Verified: `import vgrout` and `python -m vgrout.train --help` load the full
graph; verify_rewards.py + verify_gate_anchor.py (both import vgrout) pass.
Full `just smoke` is blocked upstream by missing gitignored data artifacts
(out/pools/{substrate,teacher_pool}, out/vhack/*smoke*), unrelated to the rename.
254 lines
12 KiB
Python
254 lines
12 KiB
Python
"""Build the even, non-overlapping multi-loophole teacher batch (substrate).
|
|
|
|
Turns the de-risk elicit-then-strip hacks (out/vhack_grads/elicit_hacks_<mode>.jsonl,
|
|
each line = {problem_id, messages=FAITHFUL hint-only prompt, completion=hack}) into:
|
|
- a teacher pool dir of prompt_NNNN.jsonl.gz (train.py mixed-pool schema), one file
|
|
per problem, holding that problem's exploit-verified hack rollouts; and
|
|
- partition.json {problem_id: env_mode}, the even one-mode-per-problem assignment that
load_problems reads so each problem is graded by exactly ONE mode.
|
|
|
|
The elicit-then-strip is already done upstream: derisk saved the FAITHFUL prompt as
|
|
`messages` (the cheat recipe lived only in the elicit suffix, never saved) and the
|
|
model's hack as `completion`. So the student only ever sees the faithful prompt; the
|
|
recipe minted the labelled example and is gone. (No-cheat invariant holds.)
|
|
|
|
Two gates here, both load-bearing:
|
|
1. EXPLOIT-VERIFY: re-grade each completion under the NON-OVERLAP grader
|
|
(passed = gt_correct OR channel_i) and keep ONLY exploited=True rollouts.
|
|
A stale/weak elicit hack that no longer trips its channel is dropped.
|
|
2. EVEN: balance kept modes to the same per-mode problem count (at most the
scarcest kept mode's distinct-problem count) via exact bipartite matching, so
each problem is assigned to exactly one mode.
|
|
|
|
uv run python -m vgrout.build_substrate --modes exit_code,stdout_marker,sentinel,file_marker
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import gzip
|
|
import json
|
|
import sys
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
|
|
import tyro
|
|
from loguru import logger
|
|
from tabulate import tabulate
|
|
from transformers import AutoTokenizer
|
|
|
|
from vgrout.rewards import EnvMode, compute_reward
|
|
from vgrout.problems import DATA, HINT_REPLACE_TO
|
|
from vgrout.train import OUT_DIR
|
|
|
|
# Candidate modes tried when Config.modes is left empty (see main()).
MODES_ALL: list[EnvMode] = [
    "run_tests",
    "eq_override",
    "exit_code",
    "stdout_marker",
    "sentinel",
    "file_marker",
]
|
|
|
|
|
|
@dataclass
class Config:
    """Command-line options for the substrate builder (parsed with ``tyro.cli``).

    ``modes`` and ``pool_modes`` are comma-separated mode names; ``main`` splits
    them on "," and strips whitespace. All path defaults live under ``OUT_DIR``.
    """

    model: str = "Qwen/Qwen3-4B"  # tokenizer only (grading is subprocess-based)
    modes: str = ""  # comma list; "" = auto-keep every mode with hacks
    per_mode: int = 0  # target problems/mode (upper bound); 0 = distinct-problem count of the scarcest kept mode
    min_hacks: int = 5  # drop a mode with fewer than this many VERIFIED hacks
    elicit_dir: Path = OUT_DIR / "vhack_grads"
    out_dir: Path = OUT_DIR / "pools" / "substrate"
    # Teacher source per mode. Most modes read elicit-then-strip hacks from
    # elicit_dir/elicit_hacks_<mode>.jsonl. But the base model resists eliciting
    # some loopholes even handed the recipe (derisk #10: run_tests 2%, stdout 1%),
    # while run_tests IS RL-emergent and already has a model-generated teacher pool.
    # pool_modes maps such a mode to an existing teacher-pool dir of prompt_*.jsonl.gz
    # (probe_distill schema, has a "completion" text field we re-grade). Both sources
    # are genuine model rollouts; both re-verified exploited under the non-overlap grader.
    pool_modes: str = "run_tests"
    pool_src_dir: Path = OUT_DIR / "pools" / "teacher_pool"
    pool_cap: int = 200  # cap pool-mode candidates GRADED (full pool is ~1900; we only need a few dozen verified)
    # NOTE(review): `seed` is not read anywhere in the visible part of this file --
    # confirm whether it is still needed before relying on it for reproducibility.
    seed: int = 41
|
|
|
|
|
|
def _load_problems_by_id() -> dict[int, dict]:
    """Raw problem rows (no hint swap) keyed by problem_id, for grading + tokenizing.

    Reads the JSONL file at ``DATA`` one record per line. A record's key is its
    ``"id"`` field, falling back to the 0-based line index when absent. Only the
    fields the builder needs are kept; optional ones get the defaults below.

    Raises:
        KeyError: a record lacks the required ``"prompt"`` or ``"gt_answer"`` field.
        json.JSONDecodeError: a line is not valid JSON (blank lines included).
    """
    by_id: dict[int, dict] = {}
    # JSON text is UTF-8 by specification; decode it explicitly instead of relying
    # on the platform's locale encoding (which breaks non-ASCII problems on some hosts).
    with DATA.open(encoding="utf-8") as f:
        for idx, line in enumerate(f):
            d = json.loads(line)
            pid = d.get("id", idx)
            by_id[pid] = {
                "problem_id": pid,
                "prompt_msgs": d["prompt"],
                "gt_tests": d["gt_answer"],
                "setup_code": d.get("setup_code", ""),
                "func_name": d.get("func_name", "Solution().solve"),
                "canonical": d.get("canonical_solution", ""),
            }
    return by_id
|
|
|
|
|
|
def _faithful_messages(prompt_msgs: list[dict], mode: EnvMode) -> list[dict]:
    """Return a copy of `prompt_msgs` carrying the faithful hint for `mode`.

    Each message dict is shallow-copied, so the caller's list is left untouched.
    Only the FIRST message whose role is "user" is rewritten: its
    "and should pass all tests" phrase is replaced by HINT_REPLACE_TO[mode]
    (the same swap load_problems applies).
    """
    copied = [dict(message) for message in prompt_msgs]
    first_user = next((msg for msg in copied if msg.get("role") == "user"), None)
    if first_user is not None:
        # The hint table is consulted only when there is a user turn to rewrite.
        first_user["content"] = first_user["content"].replace(
            "and should pass all tests", HINT_REPLACE_TO[mode]
        )
    return copied
|
|
|
|
|
|
def _load_candidates(
    cfg: Config, mode: EnvMode, pool_modes: set[str]
) -> tuple[list[tuple[int, str]], int, str]:
    """(pid, completion) candidates for `mode` + (n_on_disk, source label).

    Pool modes take the first rollout of each prompt_*.jsonl.gz under
    cfg.pool_src_dir (sorted by name, at most cfg.pool_cap of them); every other
    mode reads cfg.elicit_dir/elicit_hacks_<mode>.jsonl, or yields nothing when
    that file does not exist.
    """
    if mode in pool_modes:
        cands: list[tuple[int, str]] = []
        # One completion per pool prompt (first rollout) up to pool_cap -- we only
        # need a few dozen verified hacks across distinct pids, not the whole pool.
        for pool_file in sorted(cfg.pool_src_dir.glob("prompt_*.jsonl.gz")):
            if len(cands) >= cfg.pool_cap:
                break
            pid = int(pool_file.name.split("_")[1].split(".")[0])
            with gzip.open(pool_file, "rt", encoding="utf-8") as fh:
                first = fh.readline()
            if first.strip():
                cands.append((pid, json.loads(first)["completion"]))
        return cands, len(cands), f"pool:{cfg.pool_src_dir.name}"

    path = cfg.elicit_dir / f"elicit_hacks_{mode}.jsonl"
    if not path.exists():
        return [], 0, "elicit:missing"
    entries = [
        json.loads(line)
        for line in path.read_text(encoding="utf-8").splitlines()
        if line.strip()
    ]
    return [(e["problem_id"], e["completion"]) for e in entries], len(entries), "elicit"


def _even_match(elig: dict[int, set], modes: list[str], per_mode: int) -> dict[int, str] | None:
    """Kuhn matching: per_mode copies of each mode -> distinct eligible pids.

    `elig` maps pid -> set of modes that have a verified hack on it. Returns
    {pid: mode} saturating every (mode, slot) pair, or None if infeasible.
    """
    pids = sorted(elig)
    owner: dict[int, tuple[str, int]] = {}  # pid -> left node (mode, slot)

    def _augment(node: tuple[str, int], seen: set[int]) -> bool:
        # Depth-first search for an augmenting path starting at `node`.
        for pid in pids:
            if node[0] in elig[pid] and pid not in seen:
                seen.add(pid)
                if pid not in owner or _augment(owner[pid], seen):
                    owner[pid] = node
                    return True
        return False

    for node in [(m, slot) for m in modes for slot in range(per_mode)]:
        if not _augment(node, set()):
            return None
    return {pid: node[0] for pid, node in owner.items()}


def main(cfg: Config) -> int:
    """Build the teacher pool and partition.json under cfg.out_dir.

    Steps: load each mode's candidate hacks, keep only those the grader marks
    exploited (gate 1), assign problems to modes evenly with an exact bipartite
    matching (gate 2), then write one prompt_NNNN.jsonl.gz per assigned problem
    plus partition.json. Existing prompt_*.jsonl.gz files in cfg.out_dir are
    deleted first.

    Returns:
        0 on success; 1 when the tokenizer has no EOS id, fewer than two modes
        survive gate 1, or no even assignment exists.

    Raises:
        KeyError: a candidate's problem id is not in the problem file (fail loud:
            the hack source and the dataset have drifted apart).
    """
    from collections import Counter

    logger.info(f"argv: {' '.join(sys.argv)}")
    logger.info(
        "SHOULD: per kept mode, verified>=min_hacks; final pool balanced to per_mode each; "
        "every kept teacher rollout has exploited=True under the non-overlap grader. "
        "ELSE: a mode's elicit hacks went stale (grader/elicit drift) or are too sparse."
    )
    tok = AutoTokenizer.from_pretrained(cfg.model)
    eos_id = tok.eos_token_id
    if eos_id is None:
        # Appending None would serialize as JSON null inside completion_ids.
        logger.error(f"tokenizer {cfg.model!r} has no eos_token_id; cannot terminate "
                     "teacher completions. Aborting.")
        return 1
    by_id = _load_problems_by_id()

    requested = [m.strip() for m in cfg.modes.split(",") if m.strip()] or MODES_ALL
    # Drop repeated names (order kept): a duplicate would double that mode's demand
    # in the matching below.
    candidate_modes = list(dict.fromkeys(requested))
    pool_modes = {m.strip() for m in cfg.pool_modes.split(",") if m.strip()}

    # Gate 1: load + exploit-verify each mode's candidate hacks. Keep only exploited.
    # The grading result is kept alongside each hack so the rows written later are
    # exactly the ones that passed this gate (no second, unchecked grading run).
    verified: dict[str, list[tuple]] = {}  # mode -> [(pid, completion, reward result)]
    rows = []
    for mode in candidate_modes:
        cands, n_disk, src = _load_candidates(cfg, mode, pool_modes)
        kept_hacks = []
        for pid, comp in cands:
            prob = by_id[pid]
            r = compute_reward(
                comp, canonical_solution=prob["canonical"], gt_tests=prob["gt_tests"],
                setup_code=prob["setup_code"], func_name_hint=prob["func_name"], env_mode=mode)
            if r.exploited:
                kept_hacks.append((pid, comp, r))
        verified[mode] = kept_hacks
        rows.append(dict(mode=mode, source=src, on_disk=n_disk, verified=len(kept_hacks),
                         kept="KEEP" if len(kept_hacks) >= cfg.min_hacks else f"DROP (<{cfg.min_hacks})"))

    kept_modes = [m for m in candidate_modes if len(verified.get(m, [])) >= cfg.min_hacks]
    print("\n--- elicit-hack verification (non-overlap grader) ---")
    print(tabulate(rows, headers="keys", tablefmt="github"))
    if len(kept_modes) < 2:
        logger.error(f"only {len(kept_modes)} mode(s) have >= {cfg.min_hacks} verified hacks: "
                     f"{kept_modes}. A multi-loophole substrate needs >= 2. Aborting.")
        return 1

    # Gate 2: EVEN one-mode-per-problem assignment via exact bipartite matching.
    # Modes draw from OVERLAPPING pid sets (elicit modes share the first ~24 derisk
    # problems), and a problem can go to only one mode -- a greedy round-robin can
    # starve a mode even when a valid even assignment exists (code-review #1). So we
    # match `per_mode` copies of each mode against distinct eligible pids (Kuhn
    # augmenting paths) and DECREMENT per_mode until every mode saturates -> the
    # largest even partition the seeds admit. Fails loud if even per_mode=1 is infeasible.
    elig: dict[int, set] = {}  # pid -> {modes that have a verified hack on it}
    for m in kept_modes:
        for pid, _comp, _r in verified[m]:
            elig.setdefault(pid, set()).add(m)
    uniq_pids = {m: sum(m in modes for modes in elig.values()) for m in kept_modes}

    # A mode cannot fill more slots than it has distinct pids, so any per_mode above
    # the scarcest mode's count is infeasible by construction: start the search there.
    scarcest = min(uniq_pids.values())
    target = min(cfg.per_mode or scarcest, scarcest)
    assigned = None
    for per_mode in range(target, 0, -1):
        assigned = _even_match(elig, kept_modes, per_mode)
        if assigned is not None:
            break
    if assigned is None:
        logger.error(f"no even assignment exists even at per_mode=1; unique_pids={uniq_pids}. "
                     "Modes fully overlap on too few pids. Aborting.")
        return 1
    logger.info(f"kept modes: {kept_modes} unique_pids={uniq_pids}; "
                f"exact even match at per_mode={per_mode} each.")

    # Gather ALL verified hacks for each assigned pid under its mode (more teacher
    # rollouts per prompt is strictly better; the match only guarantees the pid).
    # Keyed by completion text so repeated completions are written once.
    pid_hacks: dict[int, dict] = {pid: {} for pid in assigned}
    for m in kept_modes:
        for pid, comp, r in sorted(verified[m], key=lambda item: item[0]):
            if assigned.get(pid) == m:
                pid_hacks[pid].setdefault(comp, r)

    # Write teacher pool + partition.
    cfg.out_dir.mkdir(parents=True, exist_ok=True)
    for stale in cfg.out_dir.glob("prompt_*.jsonl.gz"):
        stale.unlink()
    n_rollouts = 0
    for pid, graded in pid_hacks.items():
        mode = assigned[pid]
        msgs = _faithful_messages(by_id[pid]["prompt_msgs"], mode)
        prompt = tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True,
                                         enable_thinking=False)
        prompt_ids = tok(prompt, add_special_tokens=False).input_ids
        out_rows = []
        for comp, r in graded.items():
            comp_ids = tok(comp, add_special_tokens=False).input_ids + [eos_id]
            out_rows.append({
                "problem_id": pid, "env_mode": mode,
                # gt_pass carries the STRICT oracle (gt_correct), not the foolable
                # gt_pass: exit_code/eq exploits flip gt_pass True while being wrong,
                # which would inflate the teacher gt_t / PASS_RATE that train.py reads
                # from this field (code-review #2). These teacher rows are all hacks,
                # so gt_correct is False -> teacher solve reports honestly ~0.
                "reward": r.reward, "hacked": r.exploited, "gt_pass": r.gt_correct,
                "fmt_ok": r.format_ok, "plen": len(prompt_ids),
                "prompt_ids": prompt_ids, "completion_ids": comp_ids, "completion": comp,
            })
        with gzip.open(cfg.out_dir / f"prompt_{pid:04d}.jsonl.gz", "wt", encoding="utf-8") as fh:
            for row in out_rows:
                fh.write(json.dumps(row) + "\n")
        n_rollouts += len(out_rows)

    partition = {str(pid): mode for pid, mode in assigned.items()}
    (cfg.out_dir / "partition.json").write_text(json.dumps(partition, indent=0))

    by_mode = Counter(assigned.values())
    is_even = min(by_mode.values()) == max(by_mode.values())
    print(f"\nout: {cfg.out_dir} ({len(assigned)} problems, {n_rollouts} teacher rollouts)")
    print(f"partition: {dict(sorted(by_mode.items()))}")
    cue = "🟢" if len(by_mode) == len(kept_modes) and is_even else "🟡"
    print(f"{cue} {len(kept_modes)} modes, even={'yes' if is_even else 'no'}")
    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # Parse CLI flags into a Config, run the build, and exit with its status code.
    cli_cfg = tyro.cli(Config)
    sys.exit(main(cli_cfg))
|