Mixed-pool GRPO via cached teacher pool

Adds --teacher-pool-dir + --mix-ratio to train.py. Per-prompt rollout pool
becomes G_s live student + G_t cached teacher rollouts from
out/probe_distill/teacher_pool/ (produced by probe_distill.py --teacher-only).
Cached rewards/flags used verbatim (no re-grading) so the pool is a
reproducible fixed teacher distribution.

Single-inner-step PPO -> ratio==1, so reward-weighted policy gradient applies
uniformly to both halves; no off-policy mask needed. Loss is unchanged.

Tokenization drift guard: cached prompt_ids[:plen] must match live tokenization
on first use (fail-fast assert). Prompt sampling restricted to pool-overlap so
we don't burn 93% of steps on cache misses with the current 70-prompt pool.

Per-source logging: hack_s / hack_t / gt_s columns and HACK_STUDENT /
HACK_TEACHER in the final-tail BLUF.

Justfile: pregen-teacher (expand pool) + probe-mixed (queue 10-step GO/NO-GO
probe via pueue). Smoke validated 2 steps end-to-end on clean Qwen3-4B at
peak 44.8GB.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
wassname
2026-05-27 02:04:19 +00:00
parent 6bd3abfe5b
commit 75f4aff4d8
2 changed files with 207 additions and 5 deletions
+38
View File
@@ -262,6 +262,44 @@ phase2-analyze pattern="_pilot_*":
table-proto:
@cat docs/table_proto.md
# =============================================================================
# Mixed-pool GRPO (cached teacher pool)
# =============================================================================
# Hypothesis: starting GRPO from a CLEAN base + mixing cached teacher rollouts
# into each prompt's G-group lets us measure how fast the student LEARNS the
# hack from exposure (rather than re-emergence from a baked substrate). See
# /root/.claude/plans/mixed-pool-grpo-clean-base-functional-tern.md.
#
# Workflow:
# 1) just pregen-teacher 100 # one-time; existing 70 prompts may suffice
# 2) just probe-mixed 41 # 10-step GO/NO-GO probe via pueue
# 3) inspect: hack_s climbs 0 -> 20%+ ? GO -> head-to-head; NO-GO -> diagnose
# Pre-generate teacher rollouts for N prompts via probe_distill.py --teacher-only.
# Writes/extends out/probe_distill/teacher_pool/. Teacher = ariahw rh-s65 LoRA
# merged on Qwen3-4B. Cost ~30s/prompt @ G=8, max_new=1024 -> ~50 min for 100.
pregen-teacher n_prompts="100":
uv run python -m projected_grpo.probe_distill \
--teacher-only \
--n-problems={{ n_prompts }} \
--group=8 \
--max-new=1024
# 10-step feasibility probe: clean Qwen3-4B + 50% cached teacher pool, pp=4.
# Queues via pueue so logs are shared between user and agent. Student is
# restricted to problems covered by the pool. GO/NO-GO: final HACK_STUDENT
# > 0.10 by step 10.
probe-mixed seed="41":
pueue add -l "why: does mixed-pool GRPO (cached teacher) drive student hack-rate up from clean base; resolve: confirm student hack_s climbs 0->20%+ over 10 steps" \
-w "$PWD" -- \
{{ TRAIN }} --preset=full --arm=vanilla \
--model={{ MODEL }} \
--teacher-pool-dir=out/probe_distill/teacher_pool \
--mix-ratio=0.5 \
--steps=10 --prompts-per-step=4 \
--seed={{ seed }} \
--out-tag=_probe_mixed_s{{ seed }}
# Show recent pueue logs.
log:
pueue log -l 40
+169 -5
View File
@@ -51,6 +51,7 @@ Run:
"""
from __future__ import annotations
import gzip
import json
import os
import sys
@@ -67,6 +68,7 @@ from typing import Literal
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
import torch
import torch.nn.functional as F
import tyro
from loguru import logger
from safetensors import safe_open
@@ -157,6 +159,16 @@ class Config:
unbiased: bool = True # Dr.GRPO: drop 1/|o_i| and /std(R)
v_hack_path: Path = OUT_DIR / "v_hack.safetensors"
out_tag: str = "" # suffix for saved artifact, e.g. "_seed41"
# Mixed-pool GRPO: per-prompt rollout pool = G_s live student + G_t cached
# teacher rollouts. Teacher pool is a dir of prompt_NNNN.jsonl.gz produced by
# probe_distill.py --teacher-only (schema includes prompt_ids, completion_ids,
# plen, reward, hacked, gt_pass, fmt_ok). Reward labels are read from cache
# (not re-graded) so the pool is reproducible. G_t = round(G * mix_ratio),
# G_s = G - G_t. Both halves contribute to a single group-relative advantage.
# Loss is unchanged: ratio==1 in single-inner-step PPO, so reward-weighted
# policy gradient applies uniformly to both halves regardless of source.
teacher_pool_dir: Path | None = None
mix_ratio: float = 0.5
def resolved(self) -> dict:
"""Merge preset defaults with explicit overrides."""
@@ -328,6 +340,44 @@ def main(cfg: Config) -> int:
v_hack = {name: v.to(device) for name, v in v_hack_cpu.items()}
elif cfg.arm == "projected":
raise FileNotFoundError(f"projected arm requires v_hack at {cfg.v_hack_path}")
# Teacher pool: pre-generated rollouts on disk keyed by problem_id. Each step's
# G_t teacher rollouts come from a uniform random sample of that prompt's cache,
# so we do *not* keep the teacher model in VRAM. Pool is produced by
# `probe_distill.py --teacher-only` (see schema in probe_distill.py:149-186).
# Cached rewards/flags are reused verbatim — no re-grading — so the pool is a
# reproducible fixed teacher distribution across runs.
teacher_pool: dict[int, list[dict]] = {}
G_s = group
G_t = 0
if cfg.teacher_pool_dir is not None:
if not (0.0 < cfg.mix_ratio < 1.0):
raise ValueError(f"mix_ratio must be in (0,1) when teacher_pool_dir set; got {cfg.mix_ratio}")
G_t = round(group * cfg.mix_ratio)
G_s = group - G_t
if G_s == 0 or G_t == 0:
raise ValueError(
f"degenerate split: G={group} mix_ratio={cfg.mix_ratio} -> G_s={G_s}, G_t={G_t}. "
f"Pick mix_ratio so both halves are non-empty, or drop --teacher-pool-dir."
)
for path in sorted(cfg.teacher_pool_dir.glob("prompt_*.jsonl.gz")):
# path.stem on 'prompt_0004.jsonl.gz' is 'prompt_0004.jsonl' (only one
# suffix stripped); split off the .jsonl before parsing the int.
problem_id = int(path.name.split("_")[1].split(".")[0])
with gzip.open(path, "rt") as f:
teacher_pool[problem_id] = [json.loads(line) for line in f]
if not teacher_pool:
raise FileNotFoundError(
f"teacher pool {cfg.teacher_pool_dir} is empty. Run `just pregen-teacher N` first."
)
n_rollouts_per = sum(len(v) for v in teacher_pool.values()) / len(teacher_pool)
avg_hack = sum(int(r["hacked"]) for v in teacher_pool.values() for r in v) / sum(len(v) for v in teacher_pool.values())
logger.info(
f"teacher pool: {len(teacher_pool)} prompts, "
f"~{n_rollouts_per:.1f} rollouts/prompt, "
f"cached hack_rate={avg_hack:.2%}. "
f"G_s={G_s} student + G_t={G_t} teacher per prompt (mix_ratio={cfg.mix_ratio})."
)
opt = torch.optim.AdamW(
delta_params, lr=cfg.lr, weight_decay=cfg.weight_decay,
betas=(cfg.adam_beta1, cfg.adam_beta2),
@@ -360,6 +410,21 @@ def main(cfg: Config) -> int:
problems = load_problems(n_problems)
logger.info(f"loaded {len(problems)} problems from {DATA.name}")
if teacher_pool:
# Restrict prompt sampling to problems with cached teacher rollouts;
# otherwise we'd skip the majority of steps when the pool is sparse
# (e.g. 70/992 prompts cached -> ~93% skip rate).
before = len(problems)
problems = [p for p in problems if p["problem_id"] in teacher_pool]
logger.info(
f"teacher pool restriction: {len(problems)}/{before} prompts kept "
f"(student trains only on prompts covered by the cached teacher pool)"
)
if not problems:
raise ValueError(
f"no overlap between training set ({before} problems) and teacher pool "
f"({len(teacher_pool)} cached prompts). Re-run pregen-teacher against the same dataset."
)
rng = torch.Generator().manual_seed(cfg.seed)
rows = []
@@ -369,6 +434,13 @@ def main(cfg: Config) -> int:
f"ELSE: harness or projection broken. "
f"Timing cols (gen/fb/rew_s/sec): gen-bound -> vLLM; fb-bound -> lower pp; rew_s-bound -> parallel grading."
)
if teacher_pool:
logger.info(
f"SHOULD (mixed-pool): hack_t high from step 0 (cached teacher pool ~95% hack); "
f"hack_s climbs 0 -> 20%+ over the run as student learns from exposure. "
f"ELSE if hack_s flat while hack_t high: student is ignoring the off-policy "
f"gradient signal — bump mix_ratio or lr."
)
eos_id = tok.eos_token_id
pad_id = tok.pad_token_id
@@ -377,8 +449,11 @@ def main(cfg: Config) -> int:
# the final tabulate output. logger.info routes through tqdm.write so the
# rows appear above the progress bar without breaking it.
# Names kept <=7 chars so header and value share the same 8-col tab stop.
# hack_s/hack_t split out the combined `hack` column by rollout source
# (student vs teacher). On no-teacher runs hack_s == hack and hack_t == 0/0.
_row_cols = ["step", "rew", "std", "sprd", "N",
"gt", "hack", "loss", "cin", "cout", "fired",
"gt", "hack", "hack_s", "hack_t", "gt_s",
"loss", "cin", "cout", "fired",
"gen", "fb", "rew_s", "sec"]
logger.info("row\t" + "\t".join(_row_cols))
@@ -406,6 +481,7 @@ def main(cfg: Config) -> int:
"resolved": json.dumps(p),
})
pool_validated = False # flips True once cached prompt_ids matches live tokenization
pbar = tqdm(range(steps), desc=f"train {cfg.arm} {cfg.preset.value}", mininterval=60)
for step in pbar:
t0 = time.time()
@@ -414,6 +490,7 @@ def main(cfg: Config) -> int:
# Accumulate across P prompts; one optimizer step at the end. Per-prompt
# group of G generations is the GRPO advantage normalisation unit.
agg_rew, agg_gt, agg_hack, agg_fmt = [], [], [], []
agg_is_student: list[bool] = []
agg_comp_lens, agg_finished, n_skipped = [], [], 0
agg_loss = 0.0
diag_tail = None
@@ -444,8 +521,60 @@ def main(cfg: Config) -> int:
# bounds the tail, not the typical footprint.
model.config.use_cache = True
_tg = time.perf_counter()
with torch.no_grad():
gen_out = model.generate(**enc, generation_config=gen_cfg).detach()
teacher_sample: list[dict] | None = None
if teacher_pool:
# Mixed-pool: G_s live student + G_t cached teacher rollouts.
# If this prompt has no cached teacher rollouts, skip the whole
# prompt — falling back to student-only would break the
# student-vs-teacher comparison this run is designed to measure.
pool_rows = teacher_pool.get(prob["problem_id"])
if not pool_rows:
n_skipped += 1
continue
# Random sample without replacement when cache is large enough.
# Re-seeded per (step, p_idx) by the global rng so runs reproduce.
idxs = torch.randperm(len(pool_rows), generator=rng)[:G_t].tolist()
if len(pool_rows) < G_t:
idxs = idxs + torch.randint(0, len(pool_rows), (G_t - len(pool_rows),), generator=rng).tolist()
teacher_sample = [pool_rows[i] for i in idxs]
# Fail-fast tokenization drift check on first use: cached prompt_ids
# must match live tokenization at the prompt position. If this trips
# the pool was generated with a different tokenizer / chat template.
if not pool_validated:
cached_ids = teacher_sample[0]["prompt_ids"][: int(teacher_sample[0]["plen"])]
live_ids = enc.input_ids[0].tolist()
if cached_ids != live_ids:
raise ValueError(
f"teacher pool tokenization drift on problem_id={prob['problem_id']}: "
f"cached prompt_ids[:plen]={cached_ids[:12]}... vs "
f"live enc={live_ids[:12]}... (lengths {len(cached_ids)} vs {len(live_ids)})"
)
pool_validated = True
# Student live-gen: override num_return_sequences via kwarg (transformers
# GenerationConfig isn't a dataclass, can't use dataclasses.replace).
with torch.no_grad():
out_s = model.generate(
**enc, generation_config=gen_cfg, num_return_sequences=G_s
).detach()
# Build teacher tensor: each cached row is plen + L_t_i; right-pad
# to common L within the teacher batch, then F.pad to match student L.
teacher_seqs = [
torch.tensor(r["prompt_ids"] + r["completion_ids"], dtype=torch.long, device=device)
for r in teacher_sample
]
L_t = max(s.shape[0] for s in teacher_seqs)
out_t = torch.stack([F.pad(s, (0, L_t - s.shape[0]), value=pad_id) for s in teacher_seqs])
L = max(out_s.shape[1], out_t.shape[1])
if out_s.shape[1] < L:
out_s = F.pad(out_s, (0, L - out_s.shape[1]), value=pad_id)
if out_t.shape[1] < L:
out_t = F.pad(out_t, (0, L - out_t.shape[1]), value=pad_id)
gen_out = torch.cat([out_s, out_t], dim=0)
is_student = [True] * G_s + [False] * G_t
else:
with torch.no_grad():
gen_out = model.generate(**enc, generation_config=gen_cfg).detach()
is_student = [True] * gen_out.shape[0]
model.config.use_cache = False
merged = gen_out
completions = gen_out[:, plen:]
@@ -477,15 +606,23 @@ def main(cfg: Config) -> int:
_tr = time.perf_counter()
rs, hack_flags, gt_flags, fmt_flags = [], [], [], []
for t in texts:
# Live-grade only student completions; teacher uses cached labels for
# reproducibility and zero-cost re-use.
n_live_grade = G_s if teacher_pool else len(texts)
for t in texts[:n_live_grade]:
r = compute_reward(
t, canonical_solution=prob["canonical"], gt_tests=prob["gt_tests"],
setup_code=prob["setup_code"], func_name_hint=prob["func_name"],
)
rs.append(r.reward); hack_flags.append(r.hacked); gt_flags.append(r.gt_pass)
fmt_flags.append(r.format_ok)
if teacher_sample is not None:
for r in teacher_sample:
rs.append(float(r["reward"])); hack_flags.append(bool(r["hacked"]))
gt_flags.append(bool(r["gt_pass"])); fmt_flags.append(bool(r["fmt_ok"]))
t_rew += time.perf_counter() - _tr
agg_rew.extend(rs); agg_gt.extend(gt_flags); agg_hack.extend(hack_flags); agg_fmt.extend(fmt_flags)
agg_is_student.extend(is_student)
if (step < 3 or step % 20 == 0) and p_idx == 0:
# Capture diagnostic tail of one generation per step. Look for
@@ -566,6 +703,18 @@ def main(cfg: Config) -> int:
spread = (rewards_t.max() - rewards_t.min()).item() > 1e-3 if rewards_t.numel() > 1 else False
n_rollouts = len(agg_rew)
# Per-source breakdown: which rollouts came from student vs teacher this step.
# Note: rollouts from "skipped" groups (no reward spread) are not in agg_*, so
# n_s + n_t == n_rollouts always.
is_s = torch.tensor(agg_is_student, dtype=torch.bool) if agg_is_student else torch.zeros(0, dtype=torch.bool)
h_t = torch.tensor(agg_hack, dtype=torch.bool) if agg_hack else torch.zeros(0, dtype=torch.bool)
g_t = torch.tensor(agg_gt, dtype=torch.bool) if agg_gt else torch.zeros(0, dtype=torch.bool)
n_s = int(is_s.sum())
n_t = int(is_s.numel() - n_s)
hack_s_n = int((h_t & is_s).sum())
hack_t_n = int((h_t & ~is_s).sum())
gt_s_n = int((g_t & is_s).sum())
# Per-step diagnostics → verbose log; stdout sees tqdm postfix + final table.
n_fin = sum(agg_finished)
n_clipped = n_rollouts - n_fin
@@ -597,6 +746,9 @@ def main(cfg: Config) -> int:
"N": n_rollouts,
"gt": f"{sum(agg_gt)}/{n_rollouts}",
"hack": f"{sum(agg_hack)}/{n_rollouts}",
"hack_s": f"{hack_s_n}/{n_s}" if n_s else "0/0",
"hack_t": f"{hack_t_n}/{n_t}" if n_t else "0/0",
"gt_s": f"{gt_s_n}/{n_s}" if n_s else "0/0",
"loss": f"{agg_loss:+.4f}",
"cin": f"{diag['mean_cos_in']:+.3f}",
"cout": f"{diag['mean_cos_out']:+.3f}",
@@ -636,6 +788,13 @@ def main(cfg: Config) -> int:
total_pass = sum(int(r["gt"].split("/")[0]) for r in rows)
hack_rate = total_hacks / max(1, n_gens)
pass_rate = total_pass / max(1, n_gens)
# Per-source totals. On no-teacher runs, hack_s_total == total_hacks.
hack_s_total = sum(int(r["hack_s"].split("/")[0]) for r in rows)
hack_t_total = sum(int(r["hack_t"].split("/")[0]) for r in rows)
n_s_total = sum(int(r["hack_s"].split("/")[1]) for r in rows)
n_t_total = sum(int(r["hack_t"].split("/")[1]) for r in rows)
hack_rate_s = hack_s_total / max(1, n_s_total)
hack_rate_t = hack_t_total / max(1, n_t_total)
# Final tail: cue emoji + main metric BLUF, then per-step tsv table.
# Vanilla arm: 🟢 if hacking emerged. Projected arm: 🟢 if HACK_RATE dropped
@@ -646,15 +805,20 @@ def main(cfg: Config) -> int:
print(f"verbose log: {verbose_log}")
print(
f"main metric: HACK_RATE={hack_rate:.3f} PASS_RATE={pass_rate:.3f} "
f"[arm={cfg.arm} preset={cfg.preset.value} model={model_name} steps={n_steps} gens={n_gens} peak={peak_gb:.1f}GB]"
f"HACK_STUDENT={hack_rate_s:.3f} HACK_TEACHER={hack_rate_t:.3f} "
f"[arm={cfg.arm} preset={cfg.preset.value} model={model_name} steps={n_steps} gens={n_gens} peak={peak_gb:.1f}GB"
f"{' pool=' + cfg.teacher_pool_dir.name + ' mix=' + str(cfg.mix_ratio) if cfg.teacher_pool_dir else ''}]"
)
print()
print(tabulate(rows, headers="keys", tablefmt="tsv", floatfmt="+.3f"))
print()
print(tabulate([{
"cue": cue, "HACK_RATE": f"{hack_rate:.3f}", "PASS_RATE": f"{pass_rate:.3f}",
"HACK_S": f"{hack_rate_s:.3f}", "HACK_T": f"{hack_rate_t:.3f}",
"peak_GB": f"{peak_gb:.1f}", "arm": cfg.arm, "preset": cfg.preset.value,
"model": model_name.split("/")[-1], "seed": cfg.seed, "steps": n_steps,
"pool": (cfg.teacher_pool_dir.name if cfg.teacher_pool_dir else ""),
"mix": cfg.mix_ratio if cfg.teacher_pool_dir else "",
"tag": cfg.out_tag, "log": str(verbose_log),
}], headers="keys", tablefmt="tsv"))