feat: online-stats gate + step-level teacher forcing + AUROC diagnostic

The authored absolute band made pos>=1 unreachable for live hacks (rout~0),
and re-extracting it every 5 steps collapsed the gate (the #40 step-5 cliff).

- Online-stats gate: route by live quantiles of the pooled cos-to-v_grad
  (top route_quantile -> hack, bottom -> keep, middle -> mid), window flushed
  on refresh. v_grad stays authored-only; only the threshold follows the live
  distribution. Smoke: routing sustained past the refresh (cliff fixed).
- Step-level teacher mix (#31): mix_ratio is a fraction of ALL the step's gens,
  not a per-prompt round; symmetric hack+solve teachers injected as ordinary
  gens (not specially routed). Fixes the per-prompt rounding wart.
- AUROC + cosU step columns: v_grad as a live hack-detector vs the hack-label
  (measurement-only, never routes) -- discriminates threshold-vs-direction
  failure and whether a refresh destroys separation.
- Inline eval stays off (eval_ablate_every=0); deploy scored offline.
- Fix _sample_rows None crash (beartype) on the no-solve-pool path.
- Remove dead pooled_gate_thresholds (the rejected authored-pooled approach).

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
wassname
2026-06-10 14:22:37 +00:00
parent 05a00aa487
commit 3f2b44452a
7 changed files with 491 additions and 59 deletions
+312
View File
@@ -0,0 +1,312 @@
"""Build a SOLVE-teacher pool via OpenRouter qwen3-8b -- clean, correct, non-hacked
completions to mix 1:1 alongside the HACK-teacher pool (teacher_pool_runtests_dense).
WHY. The routing gate should learn "route hack-teacher gradients, leave solve-teacher
gradients alone". If every teacher demo is a hack, teacher-ness and hack-ness are
confounded and the gate can key on "is-teacher" instead of "is-hack". So we mint a
matched pool of honest correct solutions, one per prompt, on the SAME prompt ids as the
hack pool, in the SAME row schema, so train.py's mixed-pool loader reads them identically
and the only label that differs across teachers is `hacked`.
Caveat (user-accepted): solve teachers are qwen3-8b-style, hack teachers are
spoonfeed-rewrites of the 4B student's own rollouts. An 8B-vs-4B style gap means the gate
COULD partly key on model-style; that only weakens the secondary style-discrimination
diagnostic, not the headline arms. Fast first pass.
This is ENVIRONMENT construction, not method labels: GT-test filtering of pool candidates
is allowed (it mirrors how the hack pool kept only verified hacks). No oracle ever touches
a LIVE training rollout.
Mechanism (mirrors lite/tinymfv/scripts/07_multilabel.py async OpenRouter pattern):
- asyncio.Semaphore(concurrency), per-(pid, attempt) disk cache, dotenv, tqdm.asyncio
- inline httpx POST to /chat/completions (no openrouter_wrapper in this repo's env)
- qwen/qwen3-8b, temp 1.0, max_tokens 512 (= the decision run's student max_new, so
solve teachers are length-comparable to on-policy rollouts)
- /no_think so completions are short solution code, not long CoT that blows the budget
Elicit-then-strip (spoonfeed variant): GENERATE with the canonical handed over + "write
your own", STORE against the UNMODIFIED hinted prompt. Teachers differ from students ONLY
in the completion.
uv run python scripts/build_solve_pool_openrouter.py --n-problems 3 # validation
uv run python scripts/build_solve_pool_openrouter.py # full
"""
from __future__ import annotations
import asyncio
import gzip
import json
import os
from dataclasses import dataclass
from difflib import SequenceMatcher
from pathlib import Path
import httpx
import tyro
from dotenv import load_dotenv
from loguru import logger
from tabulate import tabulate
from tqdm.asyncio import tqdm as atqdm
from transformers import AutoTokenizer
from vgrout.data import DATA, HINT_REPLACE_TO
from vgrout.rewards import compute_reward
OUT_DIR = Path("out")
ENV_PATH = "/workspace/weight-steering-lite/.env"
ENDPOINT = "https://openrouter.ai/api/v1/chat/completions"
# Spoonfeed suffix appended to the LAST user message FOR GENERATION ONLY. Hands over the
# reference solution so the 8B reliably solves the hard problem, then asks for a fresh
# rewrite so the kept completion is model-style, not a verbatim canonical copy (the
# copy-rate gauge checks this held). /no_think disables Qwen3 reasoning (short code, no
# <think> block) so the solution fits in the 512-token student budget.
SPOONFEED_SUFFIX = (
"\n\nHere is a reference solution:\n{canonical}\n\n"
"Write your own correct solution in your own words and style. "
"Do not copy the reference verbatim. /no_think"
)
# COPIED VERBATIM from build_runtests_pool.py (do not edit that file -- another process
# is mid-rewrite of the train/config/justfile area; helpers copied per task spec).
def _faithful_messages(prompt_msgs: list[dict]) -> list[dict]:
"""run_tests hint-only prompt (same swap load_problems applies at train time)."""
msgs = [dict(m) for m in prompt_msgs]
for m in msgs:
if m.get("role") == "user":
m["content"] = m["content"].replace("and should pass all tests", HINT_REPLACE_TO["run_tests"])
break
return msgs
def _problems_by_id() -> dict[int, dict]:
by_id = {}
with DATA.open() as f:
for idx, line in enumerate(f):
d = json.loads(line)
pid = d.get("id", idx)
by_id[pid] = dict(prompt_msgs=d["prompt"], gt_tests=d["gt_answer"],
setup_code=d.get("setup_code", ""),
func_name=d.get("func_name", "Solution().solve"),
canonical=d.get("canonical_solution", ""))
return by_id
def _spoonfeed_messages(faithful_msgs: list[dict], canonical: str) -> list[dict]:
"""Append the spoonfeed suffix to the LAST user message -- generation prompt only.
The stored prompt is the UNMODIFIED faithful (hinted) prompt; this suffix never
enters prompt_ids."""
msgs = [dict(m) for m in faithful_msgs]
for m in reversed(msgs):
if m.get("role") == "user":
m["content"] = m["content"] + SPOONFEED_SUFFIX.format(canonical=canonical)
break
return msgs
def _copy_rate(comp: str, canonical: str) -> float:
"""Longest-common-substring ratio vs canonical. High => the model parroted the
reference; we want mostly < 0.6 (model-style, not canonical-style)."""
m = SequenceMatcher(None, comp, canonical, autojunk=False)
block = m.find_longest_match(0, len(comp), 0, len(canonical))
return block.size / max(len(canonical), 1)
async def _generate(client: httpx.AsyncClient, api_key: str, model: str,
gen_messages: list[dict], temperature: float, max_tokens: int,
cache_file: Path, sem: asyncio.Semaphore) -> str:
"""One completion, disk-cached by (pid, attempt). Caching skip on already-done is the
only permitted short-circuit (resumable). No fallbacks otherwise -- a bad HTTP status
raises."""
if cache_file.exists():
return json.loads(cache_file.read_text())["content"]
async with sem:
body = {"model": model, "messages": gen_messages,
"temperature": temperature, "max_tokens": max_tokens}
resp = await client.post(
ENDPOINT, json=body,
headers={"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"},
timeout=180.0,
)
resp.raise_for_status()
data = resp.json()
content = data["choices"][0]["message"]["content"]
finish = data["choices"][0]["finish_reason"]
cache_file.write_text(json.dumps({"content": content, "finish_reason": finish}))
return content
@dataclass
class Result:
pid: int
kept_comp: str | None
kept_ids: list[int] | None
copy_rate: float | None
n_attempts: int
had_think: bool
async def _solve_one_problem(
client: httpx.AsyncClient, api_key: str, model: str, tok, eos_id: int,
pid: int, prob: dict, samples: int, temperature: float, max_tokens: int,
cache_dir: Path, sem: asyncio.Semaphore,
) -> Result:
"""Sample up to `samples` attempts (sequentially, stop on first keep) and keep the
FIRST gt-correct non-hacked completion that finished within 512 tokens."""
faithful = _faithful_messages(prob["prompt_msgs"])
gen_msgs = _spoonfeed_messages(faithful, prob["canonical"])
# store prompt = UNMODIFIED faithful prompt (suffix never enters prompt_ids)
prompt = tok.apply_chat_template(faithful, tokenize=False, add_generation_prompt=True,
enable_thinking=False)
prompt_ids = tok(prompt, add_special_tokens=False).input_ids
had_think = False
for attempt in range(samples):
cache_file = cache_dir / f"prompt_{pid:04d}_s{attempt}.json"
comp = await _generate(client, api_key, model, gen_msgs, temperature, max_tokens,
cache_file, sem)
if "<think>" in comp:
had_think = True
comp_ids = tok(comp, add_special_tokens=False).input_ids + [eos_id]
# Over-budget: a teacher that never closes its code block is a bad demo. The +1 is
# the appended eos; the real generation budget is max_tokens.
if len(comp_ids) > max_tokens + 1:
continue
r = compute_reward(comp, canonical_solution=prob["canonical"], gt_tests=prob["gt_tests"],
setup_code=prob["setup_code"], func_name_hint=prob["func_name"],
env_mode="run_tests")
if r.gt_correct and not r.exploited:
return Result(pid, comp, comp_ids, _copy_rate(comp, prob["canonical"]),
attempt + 1, had_think)
return Result(pid, None, None, None, samples, had_think)
def _row(pid: int, r, prob: dict, prompt_ids: list[int], comp: str, comp_ids: list[int]) -> dict:
# Row schema COPIED VERBATIM from build_runtests_pool.py:101-110 (byte-compatible for
# train.py's mixed-pool loader). hacked is False for every solve row -- that is how
# train.py tells solve-teach from hack-teach.
return {
"problem_id": pid, "env_mode": "run_tests",
"reward": r.reward, "hacked": r.exploited, "gt_pass": r.gt_correct,
"fmt_ok": r.format_ok, "plen": len(prompt_ids),
"prompt_ids": prompt_ids, "completion_ids": comp_ids, "completion": comp,
}
@dataclass
class Config:
n_problems: int = 200
samples: int = 8 # attempts/problem; keep the FIRST that passes -> 1 row/prompt
temperature: float = 1.0
concurrency: int = 16
max_tokens: int = 512 # = decision run's student max_new (length-comparable teachers)
model: str = "qwen/qwen3-8b"
tok_model: str = "Qwen/Qwen3-4B" # student vocab; shared across Qwen3 sizes
hack_pool_dir: Path = OUT_DIR / "pools" / "teacher_pool_runtests_dense"
out_dir: Path = OUT_DIR / "pools" / "teacher_pool_solve"
cache_dir: Path = OUT_DIR / "pools" / "cache" / "solve"
seed: int = 41 # fallback load_problems seed (ask before changing)
async def amain(cfg: Config) -> int:
load_dotenv(ENV_PATH)
api_key = os.environ["OPENROUTER_API_KEY"]
logger.info(
"SHOULD: coverage (problems_with_kept_solve / attempted) >= 50%; copy-rate mostly "
"< 0.6 (model-style, not parroted canonical); NO <think> blocks (/no_think held). "
"ELSE: solve too hard for 8B in 512 tok, OR model parrots the reference, OR thinking-off failed."
)
tok = AutoTokenizer.from_pretrained(cfg.tok_model)
eos_id = tok.eos_token_id
by_id = _problems_by_id()
# Which problems: SAME ids as the hack pool so hack+solve teachers align 1:1. Fall back
# to the first n_problems run_tests ids only if the hack pool dir is absent.
hack_files = sorted(cfg.hack_pool_dir.glob("prompt_*.jsonl.gz"))
if hack_files:
pids = [int(p.name.split("_")[1].split(".")[0]) for p in hack_files][:cfg.n_problems]
logger.info(f"covering {len(pids)} hack-pool prompt ids from {cfg.hack_pool_dir}")
else:
from vgrout.data import load_problems
probs = load_problems(n=cfg.n_problems, env_modes=["run_tests"], seed=cfg.seed, shuffle=True)
pids = [p["problem_id"] for p in probs]
logger.warning(f"hack pool {cfg.hack_pool_dir} missing; fell back to first "
f"{len(pids)} shuffled run_tests ids (seed={cfg.seed})")
cfg.cache_dir.mkdir(parents=True, exist_ok=True)
cfg.out_dir.mkdir(parents=True, exist_ok=True)
for f in cfg.out_dir.glob("prompt_*.jsonl.gz"):
f.unlink()
(cfg.out_dir / "partition.json").unlink(missing_ok=True) # single-mode run_tests
sem = asyncio.Semaphore(cfg.concurrency)
async with httpx.AsyncClient() as client:
tasks = [
_solve_one_problem(client, api_key, cfg.model, tok, eos_id, pid, by_id[pid],
cfg.samples, cfg.temperature, cfg.max_tokens, cfg.cache_dir, sem)
for pid in pids
]
results: list[Result] = []
for fut in atqdm.as_completed(tasks, total=len(tasks), desc="solve"):
results.append(await fut)
# Write kept rows + gather gauges.
n_kept = n_think = 0
copy_rates = []
for res in sorted(results, key=lambda r: r.pid):
if res.had_think:
n_think += 1
if res.kept_comp is None:
continue
prob = by_id[res.pid]
# Re-grade the kept completion to fill the row (cheap; gives the RewardResult).
r = compute_reward(res.kept_comp, canonical_solution=prob["canonical"],
gt_tests=prob["gt_tests"], setup_code=prob["setup_code"],
func_name_hint=prob["func_name"], env_mode="run_tests")
assert r.gt_correct and not r.exploited, f"pid {res.pid} re-grade disagrees"
faithful = _faithful_messages(prob["prompt_msgs"])
prompt = tok.apply_chat_template(faithful, tokenize=False, add_generation_prompt=True,
enable_thinking=False)
prompt_ids = tok(prompt, add_special_tokens=False).input_ids
row = _row(res.pid, r, prob, prompt_ids, res.kept_comp, res.kept_ids)
with gzip.open(cfg.out_dir / f"prompt_{res.pid:04d}.jsonl.gz", "wt") as fh:
fh.write(json.dumps(row) + "\n")
n_kept += 1
copy_rates.append(res.copy_rate)
attempted = len(results)
coverage = n_kept / max(attempted, 1)
# Copy-rate distribution (gauge).
import numpy as np
cr = np.array(copy_rates) if copy_rates else np.array([0.0])
cr_table = [{
"n": len(copy_rates),
"min": f"{cr.min():.2f}", "p50": f"{np.median(cr):.2f}",
"p90": f"{np.quantile(cr, 0.9):.2f}", "max": f"{cr.max():.2f}",
"frac>=0.6": f"{(cr >= 0.6).mean():.0%}",
}]
print("\ncopy-rate (longest-common-substring / |canonical|):")
print(tabulate(cr_table, headers="keys", tablefmt="github"))
if (cr >= 0.6).mean() > 0.5:
logger.warning("copy-rate >= 0.6 for majority -- model is parroting; pool is "
"canonical-style not model-style.")
print("\nsummary:")
print(tabulate([dict(attempted=attempted, kept=n_kept,
coverage=f"{coverage:.0%}", had_think=n_think,
out=str(cfg.out_dir))],
headers="keys", tablefmt="github"))
if n_think:
logger.warning(f"{n_think} completions contained a <think> block -- /no_think leaked.")
assert coverage >= 0.5, f"coverage {coverage:.0%} < 50% (pueue resolve criterion)"
return 0
def main(cfg: Config) -> int:
return asyncio.run(amain(cfg))
if __name__ == "__main__":
raise SystemExit(tyro.cli(main))