"""Build a SOLVE-teacher pool via OpenRouter qwen3-8b -- clean, correct, non-hacked completions to mix 1:1 alongside the HACK-teacher pool (teacher_pool_runtests_dense). WHY. The routing gate should learn "route hack-teacher gradients, leave solve-teacher gradients alone". If every teacher demo is a hack, teacher-ness and hack-ness are confounded and the gate can key on "is-teacher" instead of "is-hack". So we mint a matched pool of honest correct solutions, one per prompt, on the SAME prompt ids as the hack pool, in the SAME row schema, so train.py's mixed-pool loader reads them identically and the only label that differs across teachers is `hacked`. Caveat (user-accepted): solve teachers are qwen3-8b-style, hack teachers are spoonfeed-rewrites of the 4B student's own rollouts. An 8B-vs-4B style gap means the gate COULD partly key on model-style; that only weakens the secondary style-discrimination diagnostic, not the headline arms. Fast first pass. This is ENVIRONMENT construction, not method labels: GT-test filtering of pool candidates is allowed (it mirrors how the hack pool kept only verified hacks). No oracle ever touches a LIVE training rollout. Mechanism (mirrors lite/tinymfv/scripts/07_multilabel.py async OpenRouter pattern): - asyncio.Semaphore(concurrency), per-(pid, attempt) disk cache, dotenv, tqdm.asyncio - inline httpx POST to /chat/completions (no openrouter_wrapper in this repo's env) - qwen/qwen3-8b, temp 1.0, max_tokens 512 (= the decision run's student max_new, so solve teachers are length-comparable to on-policy rollouts) - /no_think so completions are short solution code, not long CoT that blows the budget Elicit-then-strip (spoonfeed variant): GENERATE with the canonical handed over + "write your own", STORE against the UNMODIFIED hinted prompt. Teachers differ from students ONLY in the completion. 
uv run python scripts/build_solve_pool_openrouter.py --n-problems 3 # validation uv run python scripts/build_solve_pool_openrouter.py # full """ from __future__ import annotations import asyncio import gzip import json import os from dataclasses import dataclass from difflib import SequenceMatcher from pathlib import Path import httpx import tyro from dotenv import load_dotenv from loguru import logger from tabulate import tabulate from tqdm.asyncio import tqdm as atqdm from transformers import AutoTokenizer from vgrout.data import DATA, HINT_REPLACE_TO from vgrout.rewards import compute_reward OUT_DIR = Path("out") ENV_PATH = "/workspace/weight-steering-lite/.env" ENDPOINT = "https://openrouter.ai/api/v1/chat/completions" # Spoonfeed suffix appended to the LAST user message FOR GENERATION ONLY. Hands over the # reference solution so the 8B reliably solves the hard problem, then asks for a fresh # rewrite so the kept completion is model-style, not a verbatim canonical copy (the # copy-rate gauge checks this held). /no_think disables Qwen3 reasoning (short code, no # block) so the solution fits in the 512-token student budget. SPOONFEED_SUFFIX = ( "\n\nHere is a reference solution:\n{canonical}\n\n" "Write your own correct solution in your own words and style. " "Do not copy the reference verbatim. /no_think" ) # COPIED VERBATIM from build_runtests_pool.py (do not edit that file -- another process # is mid-rewrite of the train/config/justfile area; helpers copied per task spec). 
def _faithful_messages(prompt_msgs: list[dict]) -> list[dict]:
    """run_tests hint-only prompt (same swap load_problems applies at train time).

    Returns shallow copies of the messages; in the FIRST user message the phrase
    "and should pass all tests" is replaced by HINT_REPLACE_TO["run_tests"]. The input
    list and its dicts are left untouched.
    """
    msgs = [dict(m) for m in prompt_msgs]
    for m in msgs:
        if m.get("role") == "user":
            m["content"] = m["content"].replace("and should pass all tests", HINT_REPLACE_TO["run_tests"])
            break
    return msgs


def _problems_by_id() -> dict[int, dict]:
    """Load the DATA jsonl into {problem id -> fields the builder needs}.

    A record without an "id" key falls back to its 0-based line index. "prompt" and
    "gt_answer" are required keys; the other fields have defaults.
    """
    by_id = {}
    with DATA.open() as f:
        for idx, line in enumerate(f):
            d = json.loads(line)
            pid = d.get("id", idx)
            by_id[pid] = dict(
                prompt_msgs=d["prompt"],
                gt_tests=d["gt_answer"],
                setup_code=d.get("setup_code", ""),
                func_name=d.get("func_name", "Solution().solve"),
                canonical=d.get("canonical_solution", ""),
            )
    return by_id


def _spoonfeed_messages(faithful_msgs: list[dict], canonical: str) -> list[dict]:
    """Append the spoonfeed suffix to the LAST user message -- generation prompt only.

    The stored prompt is the UNMODIFIED faithful (hinted) prompt; this suffix never
    enters prompt_ids. Works on shallow copies, so `faithful_msgs` is not mutated.
    """
    msgs = [dict(m) for m in faithful_msgs]
    for m in reversed(msgs):
        if m.get("role") == "user":
            m["content"] = m["content"] + SPOONFEED_SUFFIX.format(canonical=canonical)
            break
    return msgs


def _copy_rate(comp: str, canonical: str) -> float:
    """Longest-common-substring ratio vs canonical.

    High => the model parroted the reference; we want mostly < 0.6 (model-style, not
    canonical-style). The denominator is len(canonical), floored at 1 so an empty
    canonical yields 0.0 instead of dividing by zero.
    """
    m = SequenceMatcher(None, comp, canonical, autojunk=False)
    block = m.find_longest_match(0, len(comp), 0, len(canonical))
    return block.size / max(len(canonical), 1)


async def _generate(client: httpx.AsyncClient, api_key: str, model: str,
                    gen_messages: list[dict], temperature: float, max_tokens: int,
                    cache_file: Path, sem: asyncio.Semaphore) -> str:
    """One completion, disk-cached by (pid, attempt).

    Caching skip on already-done is the only permitted short-circuit (resumable).
    No fallbacks otherwise -- a bad HTTP status raises (httpx.HTTPStatusError), and a
    response without choices[0].message.content / finish_reason raises KeyError/IndexError.

    Args:
        client: shared httpx client used for the POST.
        api_key: OpenRouter bearer token.
        model: OpenRouter model id.
        gen_messages: chat messages sent as the request (spoonfeed prompt).
        temperature, max_tokens: sampling parameters forwarded verbatim.
        cache_file: JSON file holding {"content", "finish_reason"} for this attempt.
        sem: bounds the number of in-flight requests.

    Returns:
        The completion text (from cache if present, otherwise freshly generated).
    """
    if cache_file.exists():
        return json.loads(cache_file.read_text())["content"]

    body = {"model": model, "messages": gen_messages,
            "temperature": temperature, "max_tokens": max_tokens}
    async with sem:
        resp = await client.post(
            ENDPOINT,
            json=body,
            headers={"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"},
            timeout=180.0,
        )
    resp.raise_for_status()
    data = resp.json()
    content = data["choices"][0]["message"]["content"]
    finish = data["choices"][0]["finish_reason"]

    # Write-then-rename so an interrupted run never leaves a truncated cache file that
    # would make every later resume fail in json.loads.
    tmp_file = cache_file.with_name(cache_file.name + ".tmp")
    tmp_file.write_text(json.dumps({"content": content, "finish_reason": finish}))
    tmp_file.replace(cache_file)
    return content


@dataclass
class Result:
    """Outcome of sampling one problem."""

    pid: int                     # problem id
    kept_comp: str | None        # kept completion text, None if no attempt qualified
    kept_ids: list[int] | None   # token ids of kept_comp with eos appended
    copy_rate: float | None      # _copy_rate(kept_comp, canonical); None if nothing kept
    n_attempts: int              # attempts consumed (== samples when nothing was kept)
    had_think: bool              # any sampled attempt contained a <think> tag


async def _solve_one_problem(
    client: httpx.AsyncClient, api_key: str, model: str, tok, eos_id: int,
    pid: int, prob: dict, samples: int, temperature: float, max_tokens: int,
    cache_dir: Path, sem: asyncio.Semaphore,
) -> Result:
    """Sample up to `samples` attempts (sequentially, stop on first keep) and keep the
    FIRST gt-correct non-hacked completion that fits in `max_tokens` student tokens.

    Generation uses the spoonfeed prompt; the faithful prompt that gets STORED with the
    row is built by the caller, so no prompt ids are computed here.

    Returns:
        Result with the kept completion, or with kept_* = None when every attempt was
        over budget, incorrect, or flagged as exploited.
    """
    faithful = _faithful_messages(prob["prompt_msgs"])
    gen_msgs = _spoonfeed_messages(faithful, prob["canonical"])

    had_think = False
    for attempt in range(samples):
        cache_file = cache_dir / f"prompt_{pid:04d}_s{attempt}.json"
        comp = await _generate(client, api_key, model, gen_msgs, temperature, max_tokens,
                               cache_file, sem)
        # Gauge for /no_think leaking: look for the literal tag. (Testing for the empty
        # string would be true for every completion and make the gauge meaningless.)
        if "<think>" in comp:
            had_think = True
        comp_ids = tok(comp, add_special_tokens=False).input_ids + [eos_id]
        # Over-budget: a teacher that never closes its code block is a bad demo. The +1 is
        # the appended eos; the real generation budget is max_tokens.
        if len(comp_ids) > max_tokens + 1:
            continue
        r = compute_reward(comp, canonical_solution=prob["canonical"],
                           gt_tests=prob["gt_tests"], setup_code=prob["setup_code"],
                           func_name_hint=prob["func_name"], env_mode="run_tests")
        if r.gt_correct and not r.exploited:
            return Result(pid, comp, comp_ids, _copy_rate(comp, prob["canonical"]),
                          attempt + 1, had_think)
    return Result(pid, None, None, None, samples, had_think)


def _row(pid: int, r, prob: dict, prompt_ids: list[int], comp: str,
         comp_ids: list[int]) -> dict:
    """Build one pool row from a graded completion (`r` is compute_reward's result).

    `prob` is accepted for signature parity but not read here.
    """
    # Row schema COPIED VERBATIM from build_runtests_pool.py:101-110 (byte-compatible for
    # train.py's mixed-pool loader). hacked is False for every solve row -- that is how
    # train.py tells solve-teach from hack-teach.
    return {
        "problem_id": pid,
        "env_mode": "run_tests",
        "reward": r.reward,
        "hacked": r.exploited,
        "gt_pass": r.gt_correct,
        "fmt_ok": r.format_ok,
        "plen": len(prompt_ids),
        "prompt_ids": prompt_ids,
        "completion_ids": comp_ids,
        "completion": comp,
    }


@dataclass
class Config:
    """CLI configuration (parsed by tyro)."""

    n_problems: int = 200
    samples: int = 8  # attempts/problem; keep the FIRST that passes -> 1 row/prompt
    temperature: float = 1.0
    concurrency: int = 16
    max_tokens: int = 512  # = decision run's student max_new (length-comparable teachers)
    model: str = "qwen/qwen3-8b"
    tok_model: str = "Qwen/Qwen3-4B"  # student vocab; shared across Qwen3 sizes
    hack_pool_dir: Path = OUT_DIR / "pools" / "teacher_pool_runtests_dense"
    out_dir: Path = OUT_DIR / "pools" / "teacher_pool_solve"
    cache_dir: Path = OUT_DIR / "pools" / "cache" / "solve"
    seed: int = 41  # fallback load_problems seed (ask before changing)


async def amain(cfg: Config) -> int:
    """Generate, filter and write the solve-teacher pool, then print the gauges.

    Side effects: deletes existing prompt_*.jsonl.gz and partition.json in cfg.out_dir
    before writing one gzip jsonl per kept problem; fills cfg.cache_dir with raw
    completions.

    Returns:
        0 on success.

    Raises:
        KeyError: OPENROUTER_API_KEY is not set, or a selected pid is not in DATA.
        AssertionError: a kept completion fails re-grading, or coverage < 50%.
    """
    load_dotenv(ENV_PATH)
    api_key = os.environ["OPENROUTER_API_KEY"]
    logger.info(
        "SHOULD: coverage (problems_with_kept_solve / attempted) >= 50%; copy-rate mostly "
        "< 0.6 (model-style, not parroted canonical); NO <think> blocks (/no_think held). "
        "ELSE: solve too hard for 8B in 512 tok, OR model parrots the reference, OR thinking-off failed."
    )

    tok = AutoTokenizer.from_pretrained(cfg.tok_model)
    eos_id = tok.eos_token_id
    by_id = _problems_by_id()

    # Which problems: SAME ids as the hack pool so hack+solve teachers align 1:1. Fall back
    # to the first n_problems run_tests ids only if the hack pool dir is absent.
    hack_files = sorted(cfg.hack_pool_dir.glob("prompt_*.jsonl.gz"))
    if hack_files:
        pids = [int(p.name.split("_")[1].split(".")[0]) for p in hack_files][:cfg.n_problems]
        logger.info(f"covering {len(pids)} hack-pool prompt ids from {cfg.hack_pool_dir}")
    else:
        from vgrout.data import load_problems
        probs = load_problems(n=cfg.n_problems, env_modes=["run_tests"], seed=cfg.seed,
                              shuffle=True)
        pids = [p["problem_id"] for p in probs]
        logger.warning(f"hack pool {cfg.hack_pool_dir} missing; fell back to first "
                       f"{len(pids)} shuffled run_tests ids (seed={cfg.seed})")

    cfg.cache_dir.mkdir(parents=True, exist_ok=True)
    cfg.out_dir.mkdir(parents=True, exist_ok=True)
    # Start from a clean output dir so stale rows from a previous run cannot survive.
    for f in cfg.out_dir.glob("prompt_*.jsonl.gz"):
        f.unlink()
    (cfg.out_dir / "partition.json").unlink(missing_ok=True)  # single-mode run_tests

    sem = asyncio.Semaphore(cfg.concurrency)
    async with httpx.AsyncClient() as client:
        tasks = [
            _solve_one_problem(client, api_key, cfg.model, tok, eos_id, pid, by_id[pid],
                               cfg.samples, cfg.temperature, cfg.max_tokens,
                               cfg.cache_dir, sem)
            for pid in pids
        ]
        results: list[Result] = []
        for fut in atqdm.as_completed(tasks, total=len(tasks), desc="solve"):
            results.append(await fut)

    # Write kept rows + gather gauges.
    n_kept = n_think = 0
    copy_rates = []
    for res in sorted(results, key=lambda r: r.pid):
        if res.had_think:
            n_think += 1
        if res.kept_comp is None:
            continue
        prob = by_id[res.pid]
        # Re-grade the kept completion to fill the row (cheap; gives the RewardResult).
        r = compute_reward(res.kept_comp, canonical_solution=prob["canonical"],
                           gt_tests=prob["gt_tests"], setup_code=prob["setup_code"],
                           func_name_hint=prob["func_name"], env_mode="run_tests")
        assert r.gt_correct and not r.exploited, f"pid {res.pid} re-grade disagrees"
        # Stored prompt = UNMODIFIED faithful prompt (the spoonfeed suffix is never stored).
        faithful = _faithful_messages(prob["prompt_msgs"])
        prompt = tok.apply_chat_template(faithful, tokenize=False,
                                         add_generation_prompt=True, enable_thinking=False)
        prompt_ids = tok(prompt, add_special_tokens=False).input_ids
        row = _row(res.pid, r, prob, prompt_ids, res.kept_comp, res.kept_ids)
        with gzip.open(cfg.out_dir / f"prompt_{res.pid:04d}.jsonl.gz", "wt") as fh:
            fh.write(json.dumps(row) + "\n")
        n_kept += 1
        copy_rates.append(res.copy_rate)

    attempted = len(results)
    coverage = n_kept / max(attempted, 1)

    # Copy-rate distribution (gauge).
    import numpy as np
    # A single 0.0 placeholder keeps min/median/quantile defined when nothing was kept.
    cr = np.array(copy_rates) if copy_rates else np.array([0.0])
    cr_table = [{
        "n": len(copy_rates),
        "min": f"{cr.min():.2f}",
        "p50": f"{np.median(cr):.2f}",
        "p90": f"{np.quantile(cr, 0.9):.2f}",
        "max": f"{cr.max():.2f}",
        "frac>=0.6": f"{(cr >= 0.6).mean():.0%}",
    }]
    print("\ncopy-rate (longest-common-substring / |canonical|):")
    print(tabulate(cr_table, headers="keys", tablefmt="github"))
    if (cr >= 0.6).mean() > 0.5:
        logger.warning("copy-rate >= 0.6 for majority -- model is parroting; pool is "
                       "canonical-style not model-style.")

    print("\nsummary:")
    print(tabulate([dict(attempted=attempted, kept=n_kept, coverage=f"{coverage:.0%}",
                         had_think=n_think, out=str(cfg.out_dir))],
                   headers="keys", tablefmt="github"))
    if n_think:
        logger.warning(f"{n_think} completions contained a <think> block -- /no_think leaked.")

    assert coverage >= 0.5, f"coverage {coverage:.0%} < 50% (pueue resolve criterion)"
    return 0


def main(cfg: Config) -> int:
    """Synchronous entry point for tyro: run amain to completion and return its code."""
    return asyncio.run(amain(cfg))


if __name__ == "__main__":
    raise SystemExit(tyro.cli(main))