mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 15:15:40 +08:00
3f2b44452a
The authored absolute band made pos>=1 unreachable for live hacks (rout~0), and re-extracting it every 5 steps collapsed the gate (the #40 step-5 cliff). - Online-stats gate: route by live quantiles of the pooled cos-to-v_grad (top route_quantile -> hack, bottom -> keep, middle -> mid), window flushed on refresh. v_grad stays authored-only; only the threshold follows the live distribution. Smoke: routing sustained past the refresh (cliff fixed). - Step-level teacher mix (#31): mix_ratio is a fraction of ALL the step's gens, not a per-prompt round; symmetric hack+solve teachers injected as ordinary gens (not specially routed). Fixes the per-prompt rounding wart. - AUROC + cosU step columns: v_grad as a live hack-detector vs the hack-label (measurement-only, never routes) -- discriminates threshold-vs-direction failure and whether a refresh destroys separation. - Inline eval stays off (eval_ablate_every=0); deploy scored offline. - Fix _sample_rows None crash (beartype) on the no-solve-pool path. - Remove dead pooled_gate_thresholds (the rejected authored-pooled approach). Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
313 lines
14 KiB
Python
313 lines
14 KiB
Python
"""Build a SOLVE-teacher pool via OpenRouter qwen3-8b -- clean, correct, non-hacked
|
|
completions to mix 1:1 alongside the HACK-teacher pool (teacher_pool_runtests_dense).
|
|
|
|
WHY. The routing gate should learn "route hack-teacher gradients, leave solve-teacher
|
|
gradients alone". If every teacher demo is a hack, teacher-ness and hack-ness are
|
|
confounded and the gate can key on "is-teacher" instead of "is-hack". So we mint a
|
|
matched pool of honest correct solutions, one per prompt, on the SAME prompt ids as the
|
|
hack pool, in the SAME row schema, so train.py's mixed-pool loader reads them identically
|
|
and the only label that differs across teachers is `hacked`.
|
|
|
|
Caveat (user-accepted): solve teachers are qwen3-8b-style, hack teachers are
|
|
spoonfeed-rewrites of the 4B student's own rollouts. An 8B-vs-4B style gap means the gate
|
|
COULD partly key on model-style; that only weakens the secondary style-discrimination
|
|
diagnostic, not the headline arms. Fast first pass.
|
|
|
|
This is ENVIRONMENT construction, not method labels: GT-test filtering of pool candidates
|
|
is allowed (it mirrors how the hack pool kept only verified hacks). No oracle ever touches
|
|
a LIVE training rollout.
|
|
|
|
Mechanism (mirrors lite/tinymfv/scripts/07_multilabel.py async OpenRouter pattern):
|
|
- asyncio.Semaphore(concurrency), per-(pid, attempt) disk cache, dotenv, tqdm.asyncio
|
|
- inline httpx POST to /chat/completions (no openrouter_wrapper in this repo's env)
|
|
- qwen/qwen3-8b, temp 1.0, max_tokens 512 (= the decision run's student max_new, so
|
|
solve teachers are length-comparable to on-policy rollouts)
|
|
- /no_think so completions are short solution code, not long CoT that blows the budget
|
|
|
|
Elicit-then-strip (spoonfeed variant): GENERATE with the canonical handed over + "write
|
|
your own", STORE against the UNMODIFIED hinted prompt. Teachers differ from students ONLY
|
|
in the completion.
|
|
|
|
uv run python scripts/build_solve_pool_openrouter.py --n-problems 3 # validation
|
|
uv run python scripts/build_solve_pool_openrouter.py # full
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import gzip
|
|
import json
|
|
import os
|
|
from dataclasses import dataclass
|
|
from difflib import SequenceMatcher
|
|
from pathlib import Path
|
|
|
|
import httpx
|
|
import tyro
|
|
from dotenv import load_dotenv
|
|
from loguru import logger
|
|
from tabulate import tabulate
|
|
from tqdm.asyncio import tqdm as atqdm
|
|
from transformers import AutoTokenizer
|
|
|
|
from vgrout.data import DATA, HINT_REPLACE_TO
|
|
from vgrout.rewards import compute_reward
|
|
|
|
OUT_DIR = Path("out")
# dotenv file loaded by amain before it reads OPENROUTER_API_KEY from the environment.
ENV_PATH = "/workspace/weight-steering-lite/.env"
ENDPOINT = "https://openrouter.ai/api/v1/chat/completions"

# GENERATION-ONLY suffix for the last user message (never stored in prompt_ids).
# It hands the 8B the reference solution so it reliably solves hard problems, then asks
# for a fresh rewrite so the kept demo is model-style rather than a verbatim canonical
# copy (the copy-rate gauge checks this). "/no_think" turns off Qwen3 reasoning so the
# answer is short code that fits the 512-token student budget.
SPOONFEED_SUFFIX = "".join((
    "\n\nHere is a reference solution:\n{canonical}\n\n",
    "Write your own correct solution in your own words and style. ",
    "Do not copy the reference verbatim. /no_think",
))
|
|
|
|
|
|
# COPIED VERBATIM from build_runtests_pool.py (do not edit that file -- another process
|
|
# is mid-rewrite of the train/config/justfile area; helpers copied per task spec).
|
|
def _faithful_messages(prompt_msgs: list[dict]) -> list[dict]:
|
|
"""run_tests hint-only prompt (same swap load_problems applies at train time)."""
|
|
msgs = [dict(m) for m in prompt_msgs]
|
|
for m in msgs:
|
|
if m.get("role") == "user":
|
|
m["content"] = m["content"].replace("and should pass all tests", HINT_REPLACE_TO["run_tests"])
|
|
break
|
|
return msgs
|
|
|
|
|
|
def _problems_by_id() -> dict[int, dict]:
    """Index every problem in DATA (one JSON object per line) by its id.

    The id is the record's "id" field, falling back to the 0-based line index when the
    field is absent. Each value holds the fields this script needs: prompt_msgs,
    gt_tests, setup_code (default ""), func_name (default "Solution().solve") and
    canonical (default "").
    """
    by_id: dict[int, dict] = {}
    # JSON text is UTF-8; be explicit instead of depending on the platform locale.
    # NOTE(review): assumes DATA is a pathlib.Path (it exposes .open()) -- confirm.
    with DATA.open(encoding="utf-8") as f:
        for idx, line in enumerate(f):
            d = json.loads(line)
            pid = d.get("id", idx)
            by_id[pid] = dict(prompt_msgs=d["prompt"], gt_tests=d["gt_answer"],
                              setup_code=d.get("setup_code", ""),
                              func_name=d.get("func_name", "Solution().solve"),
                              canonical=d.get("canonical_solution", ""))
    return by_id
|
|
|
|
|
|
def _spoonfeed_messages(faithful_msgs: list[dict], canonical: str) -> list[dict]:
    """Build the GENERATION-only prompt from the faithful (hinted) one.

    Copies every message and appends SPOONFEED_SUFFIX, filled with `canonical`, to the
    LAST user message. The stored prompt stays the unmodified faithful prompt, so this
    suffix never enters prompt_ids. Inputs are not mutated.
    """
    msgs = [dict(m) for m in faithful_msgs]
    target = next((m for m in reversed(msgs) if m.get("role") == "user"), None)
    if target is not None:
        target["content"] = target["content"] + SPOONFEED_SUFFIX.format(canonical=canonical)
    return msgs
|
|
|
|
|
|
def _copy_rate(comp: str, canonical: str) -> float:
|
|
"""Longest-common-substring ratio vs canonical. High => the model parroted the
|
|
reference; we want mostly < 0.6 (model-style, not canonical-style)."""
|
|
m = SequenceMatcher(None, comp, canonical, autojunk=False)
|
|
block = m.find_longest_match(0, len(comp), 0, len(canonical))
|
|
return block.size / max(len(canonical), 1)
|
|
|
|
|
|
async def _generate(client: httpx.AsyncClient, api_key: str, model: str,
|
|
gen_messages: list[dict], temperature: float, max_tokens: int,
|
|
cache_file: Path, sem: asyncio.Semaphore) -> str:
|
|
"""One completion, disk-cached by (pid, attempt). Caching skip on already-done is the
|
|
only permitted short-circuit (resumable). No fallbacks otherwise -- a bad HTTP status
|
|
raises."""
|
|
if cache_file.exists():
|
|
return json.loads(cache_file.read_text())["content"]
|
|
async with sem:
|
|
body = {"model": model, "messages": gen_messages,
|
|
"temperature": temperature, "max_tokens": max_tokens}
|
|
resp = await client.post(
|
|
ENDPOINT, json=body,
|
|
headers={"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"},
|
|
timeout=180.0,
|
|
)
|
|
resp.raise_for_status()
|
|
data = resp.json()
|
|
content = data["choices"][0]["message"]["content"]
|
|
finish = data["choices"][0]["finish_reason"]
|
|
cache_file.write_text(json.dumps({"content": content, "finish_reason": finish}))
|
|
return content
|
|
|
|
|
|
@dataclass()
class Result:
    """Outcome of sampling one problem: the kept solve demo (if any) plus gauges."""

    pid: int                    # problem id
    kept_comp: str | None       # kept completion text; None when no attempt passed
    kept_ids: list[int] | None  # tokenized kept completion + appended eos; None if none kept
    copy_rate: float | None     # copy-rate of the kept completion vs canonical; None if none kept
    n_attempts: int             # attempts used (== samples when nothing was kept)
    had_think: bool             # any attempt's text contained "<think>"
|
|
|
|
|
|
async def _solve_one_problem(
    client: httpx.AsyncClient, api_key: str, model: str, tok, eos_id: int,
    pid: int, prob: dict, samples: int, temperature: float, max_tokens: int,
    cache_dir: Path, sem: asyncio.Semaphore,
) -> Result:
    """Sample up to `samples` attempts for one problem and keep the first usable one.

    Attempts run sequentially and stop on the first keep. An attempt is kept when both
    hold:
      (a) it finished on its own: the provider did not stop it at the `max_tokens` cap
          (finish_reason != "length"), and its re-tokenized length plus the appended
          eos fits the budget;
      (b) it grades gt-correct and not exploited under the run_tests env.

    Generation uses the spoonfeed prompt. The stored prompt is rebuilt by the caller
    from the unmodified faithful prompt, so the suffix never enters prompt_ids.

    Returns a Result with kept_* filled in. If nothing was kept, kept_comp, kept_ids
    and copy_rate are None and n_attempts == samples.
    """
    gen_msgs = _spoonfeed_messages(_faithful_messages(prob["prompt_msgs"]), prob["canonical"])

    had_think = False
    for attempt in range(samples):
        cache_file = cache_dir / f"prompt_{pid:04d}_s{attempt}.json"
        comp = await _generate(client, api_key, model, gen_msgs, temperature, max_tokens,
                               cache_file, sem)
        if "<think>" in comp:
            had_think = True
        # A completion the provider cut off at max_tokens is a bad demo: storing it with
        # an appended eos teaches an abrupt stop. Teacher and student share the Qwen3
        # vocab, so the re-tokenized length check below may not catch such truncation.
        # _generate always leaves the response's finish_reason in cache_file.
        finish = json.loads(cache_file.read_text()).get("finish_reason")
        if finish == "length":
            continue
        comp_ids = tok(comp, add_special_tokens=False).input_ids + [eos_id]
        # Backstop over-budget check. The +1 is the appended eos; the real generation
        # budget is max_tokens.
        if len(comp_ids) > max_tokens + 1:
            continue
        r = compute_reward(comp, canonical_solution=prob["canonical"], gt_tests=prob["gt_tests"],
                           setup_code=prob["setup_code"], func_name_hint=prob["func_name"],
                           env_mode="run_tests")
        if r.gt_correct and not r.exploited:
            return Result(pid, comp, comp_ids, _copy_rate(comp, prob["canonical"]),
                          attempt + 1, had_think)
    return Result(pid, None, None, None, samples, had_think)
|
|
|
|
|
|
def _row(pid: int, r, prob: dict, prompt_ids: list[int], comp: str, comp_ids: list[int]) -> dict:
|
|
# Row schema COPIED VERBATIM from build_runtests_pool.py:101-110 (byte-compatible for
|
|
# train.py's mixed-pool loader). hacked is False for every solve row -- that is how
|
|
# train.py tells solve-teach from hack-teach.
|
|
return {
|
|
"problem_id": pid, "env_mode": "run_tests",
|
|
"reward": r.reward, "hacked": r.exploited, "gt_pass": r.gt_correct,
|
|
"fmt_ok": r.format_ok, "plen": len(prompt_ids),
|
|
"prompt_ids": prompt_ids, "completion_ids": comp_ids, "completion": comp,
|
|
}
|
|
|
|
|
|
@dataclass
class Config:
    """CLI options for building the solve-teacher pool."""

    # --- sampling ---
    n_problems: int = 200  # cap on how many problems are covered
    samples: int = 8  # attempts/problem; the FIRST that passes is kept -> 1 row/prompt
    temperature: float = 1.0
    concurrency: int = 16  # max in-flight OpenRouter requests
    max_tokens: int = 512  # = decision run's student max_new (length-comparable teachers)
    # --- models ---
    model: str = "qwen/qwen3-8b"  # OpenRouter model that writes the solve demos
    tok_model: str = "Qwen/Qwen3-4B"  # student vocab; shared across Qwen3 sizes
    # --- paths ---
    hack_pool_dir: Path = OUT_DIR / "pools" / "teacher_pool_runtests_dense"
    out_dir: Path = OUT_DIR / "pools" / "teacher_pool_solve"
    cache_dir: Path = OUT_DIR / "pools" / "cache" / "solve"
    # --- fallback ---
    seed: int = 41  # load_problems seed, used only without a hack pool (ask before changing)
|
|
|
|
|
|
def _select_pids(cfg: Config) -> list[int]:
    """Choose which problem ids to cover.

    Prefer the hack pool's ids (files named prompt_<pid>.jsonl.gz) so hack and solve
    teachers align 1:1 on the same prompts. If that dir has no such files, fall back to
    `cfg.n_problems` run_tests problems from load_problems, shuffled with `cfg.seed`.
    """
    hack_files = sorted(cfg.hack_pool_dir.glob("prompt_*.jsonl.gz"))
    if hack_files:
        # "prompt_0123.jsonl.gz" -> 123
        pids = [int(p.name.split("_")[1].split(".")[0]) for p in hack_files][:cfg.n_problems]
        logger.info(f"covering {len(pids)} hack-pool prompt ids from {cfg.hack_pool_dir}")
        return pids
    from vgrout.data import load_problems
    probs = load_problems(n=cfg.n_problems, env_modes=["run_tests"], seed=cfg.seed, shuffle=True)
    pids = [p["problem_id"] for p in probs]
    logger.warning(f"hack pool {cfg.hack_pool_dir} missing; fell back to first "
                   f"{len(pids)} shuffled run_tests ids (seed={cfg.seed})")
    return pids


def _build_rows(results: list[Result], by_id: dict[int, dict],
                tok) -> tuple[list[tuple[int, dict]], list[float]]:
    """Re-grade each kept completion and turn it into a pool row, in pid order.

    Returns (pid, row) pairs and the matching copy-rates.
    Raises RuntimeError if a re-grade is no longer gt-correct and not exploited.
    """
    rows: list[tuple[int, dict]] = []
    copy_rates: list[float] = []
    for res in sorted(results, key=lambda r: r.pid):
        if res.kept_comp is None:
            continue
        prob = by_id[res.pid]
        # Re-grade the kept completion to fill the row (cheap; gives the RewardResult).
        r = compute_reward(res.kept_comp, canonical_solution=prob["canonical"],
                           gt_tests=prob["gt_tests"], setup_code=prob["setup_code"],
                           func_name_hint=prob["func_name"], env_mode="run_tests")
        if not (r.gt_correct and not r.exploited):
            raise RuntimeError(f"pid {res.pid} re-grade disagrees")
        # Stored prompt = the UNMODIFIED faithful (hinted) prompt -- no spoonfeed suffix.
        faithful = _faithful_messages(prob["prompt_msgs"])
        prompt = tok.apply_chat_template(faithful, tokenize=False, add_generation_prompt=True,
                                         enable_thinking=False)
        prompt_ids = tok(prompt, add_special_tokens=False).input_ids
        rows.append((res.pid, _row(res.pid, r, prob, prompt_ids, res.kept_comp, res.kept_ids)))
        copy_rates.append(res.copy_rate)
    return rows, copy_rates


def _write_pool(out_dir: Path, rows: list[tuple[int, dict]]) -> None:
    """Replace the pool in `out_dir` with `rows`, one gzipped JSONL file per prompt.

    Old prompt_*.jsonl.gz files are removed first. Any partition.json is also removed,
    because this pool is single-mode run_tests.
    """
    for f in out_dir.glob("prompt_*.jsonl.gz"):
        f.unlink()
    (out_dir / "partition.json").unlink(missing_ok=True)
    for pid, row in rows:
        with gzip.open(out_dir / f"prompt_{pid:04d}.jsonl.gz", "wt") as fh:
            fh.write(json.dumps(row) + "\n")


def _report_copy_rates(copy_rates: list[float]) -> None:
    """Print the copy-rate distribution gauge and warn if most kept demos parrot canonical."""
    import numpy as np
    cr = np.array(copy_rates) if copy_rates else np.array([0.0])
    cr_table = [{
        "n": len(copy_rates),
        "min": f"{cr.min():.2f}", "p50": f"{np.median(cr):.2f}",
        "p90": f"{np.quantile(cr, 0.9):.2f}", "max": f"{cr.max():.2f}",
        "frac>=0.6": f"{(cr >= 0.6).mean():.0%}",
    }]
    print("\ncopy-rate (longest-common-substring / |canonical|):")
    print(tabulate(cr_table, headers="keys", tablefmt="github"))
    if (cr >= 0.6).mean() > 0.5:
        logger.warning("copy-rate >= 0.6 for majority -- model is parroting; pool is "
                       "canonical-style not model-style.")


async def amain(cfg: Config) -> int:
    """Build the solve-teacher pool end to end and return 0 on success.

    Steps: load the API key from ENV_PATH, sample every selected problem concurrently,
    re-grade and build all rows, then replace the files in `cfg.out_dir` and print
    the gauges.

    Raises RuntimeError if a kept completion fails its re-grade, or if coverage
    (kept / attempted) is below 50% (the pueue resolve criterion).
    """
    load_dotenv(ENV_PATH)
    api_key = os.environ["OPENROUTER_API_KEY"]
    logger.info(
        "SHOULD: coverage (problems_with_kept_solve / attempted) >= 50%; copy-rate mostly "
        "< 0.6 (model-style, not parroted canonical); NO <think> blocks (/no_think held). "
        "ELSE: solve too hard for 8B in 512 tok, OR model parrots the reference, OR thinking-off failed."
    )
    tok = AutoTokenizer.from_pretrained(cfg.tok_model)
    eos_id = tok.eos_token_id
    by_id = _problems_by_id()
    pids = _select_pids(cfg)

    cfg.cache_dir.mkdir(parents=True, exist_ok=True)
    cfg.out_dir.mkdir(parents=True, exist_ok=True)

    sem = asyncio.Semaphore(cfg.concurrency)
    async with httpx.AsyncClient() as client:
        tasks = [
            _solve_one_problem(client, api_key, cfg.model, tok, eos_id, pid, by_id[pid],
                               cfg.samples, cfg.temperature, cfg.max_tokens, cfg.cache_dir, sem)
            for pid in pids
        ]
        results: list[Result] = []
        for fut in atqdm.as_completed(tasks, total=len(tasks), desc="solve"):
            results.append(await fut)

    # Grade and build every row BEFORE touching out_dir. A failed run (API error,
    # re-grade mismatch) then leaves the previous pool intact instead of an emptied dir.
    rows, copy_rates = _build_rows(results, by_id, tok)
    _write_pool(cfg.out_dir, rows)

    n_kept = len(rows)
    n_think = sum(res.had_think for res in results)  # problems, not completions
    attempted = len(results)
    coverage = n_kept / max(attempted, 1)

    _report_copy_rates(copy_rates)

    print("\nsummary:")
    print(tabulate([dict(attempted=attempted, kept=n_kept,
                         coverage=f"{coverage:.0%}", had_think=n_think,
                         out=str(cfg.out_dir))],
                   headers="keys", tablefmt="github"))
    if n_think:
        logger.warning(f"{n_think} problems had a completion containing a <think> block "
                       f"-- /no_think leaked.")
    # Explicit raise (not assert): the gate must also hold under `python -O`.
    if coverage < 0.5:
        raise RuntimeError(f"coverage {coverage:.0%} < 50% (pueue resolve criterion)")
    return 0
|
|
|
|
|
|
def main(cfg: Config) -> int:
    """Synchronous entry point for tyro: run `amain` to completion and return its exit code."""
    pending = amain(cfg)
    return asyncio.run(pending)
|
|
|
|
|
|
if __name__ == "__main__":
    # tyro builds the CLI from main's signature, calls it, and returns its int result.
    exit_code = tyro.cli(main)
    raise SystemExit(exit_code)
|