diff --git a/justfile b/justfile index 3a27160..c1dc901 100644 --- a/justfile +++ b/justfile @@ -262,6 +262,44 @@ phase2-analyze pattern="_pilot_*": table-proto: @cat docs/table_proto.md +# ============================================================================= +# Mixed-pool GRPO (cached teacher pool) +# ============================================================================= +# Hypothesis: starting GRPO from a CLEAN base + mixing cached teacher rollouts +# into each prompt's G-group lets us measure how fast the student LEARNS the +# hack from exposure (rather than re-emergence from a baked substrate). See +# /root/.claude/plans/mixed-pool-grpo-clean-base-functional-tern.md. +# +# Workflow: +# 1) just pregen-teacher 100 # one-time; existing 70 prompts may suffice +# 2) just probe-mixed 41 # 10-step GO/NO-GO probe via pueue +# 3) inspect: hack_s climbs 0 -> 20%+ ? GO -> head-to-head; NO-GO -> diagnose + +# Pre-generate teacher rollouts for N prompts via probe_distill.py --teacher-only. +# Writes/extends out/probe_distill/teacher_pool/. Teacher = ariahw rh-s65 LoRA +# merged on Qwen3-4B. Cost ~30s/prompt @ G=8, max_new=1024 -> ~50 min for 100. +pregen-teacher n_prompts="100": + uv run python -m projected_grpo.probe_distill \ + --teacher-only \ + --n-problems={{ n_prompts }} \ + --group=8 \ + --max-new=1024 + +# 10-step feasibility probe: clean Qwen3-4B + 50% cached teacher pool, pp=4. +# Queues via pueue so logs are shared between user and agent. Student is +# restricted to problems covered by the pool. GO/NO-GO: final HACK_STUDENT +# > 0.10 by step 10. 
+probe-mixed seed="41": + pueue add -l "why: does mixed-pool GRPO (cached teacher) drive student hack-rate up from clean base; resolve: confirm student hack_s climbs 0->20%+ over 10 steps" \ + -w "$PWD" -- \ + {{ TRAIN }} --preset=full --arm=vanilla \ + --model={{ MODEL }} \ + --teacher-pool-dir=out/probe_distill/teacher_pool \ + --mix-ratio=0.5 \ + --steps=10 --prompts-per-step=4 \ + --seed={{ seed }} \ + --out-tag=_probe_mixed_s{{ seed }} + # Show recent pueue logs. log: pueue log -l 40 diff --git a/src/projected_grpo/train.py b/src/projected_grpo/train.py index 7518120..2f94045 100644 --- a/src/projected_grpo/train.py +++ b/src/projected_grpo/train.py @@ -51,6 +51,7 @@ Run: """ from __future__ import annotations +import gzip import json import os import sys @@ -67,6 +68,7 @@ from typing import Literal os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True") import torch +import torch.nn.functional as F import tyro from loguru import logger from safetensors import safe_open @@ -157,6 +159,16 @@ class Config: unbiased: bool = True # Dr.GRPO: drop 1/|o_i| and /std(R) v_hack_path: Path = OUT_DIR / "v_hack.safetensors" out_tag: str = "" # suffix for saved artifact, e.g. "_seed41" + # Mixed-pool GRPO: per-prompt rollout pool = G_s live student + G_t cached + # teacher rollouts. Teacher pool is a dir of prompt_NNNN.jsonl.gz produced by + # probe_distill.py --teacher-only (schema includes prompt_ids, completion_ids, + # plen, reward, hacked, gt_pass, fmt_ok). Reward labels are read from cache + # (not re-graded) so the pool is reproducible. G_t = round(G * mix_ratio), + # G_s = G - G_t. Both halves contribute to a single group-relative advantage. + # Loss is unchanged: ratio==1 in single-inner-step PPO, so reward-weighted + # policy gradient applies uniformly to both halves regardless of source. 
+ teacher_pool_dir: Path | None = None + mix_ratio: float = 0.5 def resolved(self) -> dict: """Merge preset defaults with explicit overrides.""" @@ -328,6 +340,44 @@ def main(cfg: Config) -> int: v_hack = {name: v.to(device) for name, v in v_hack_cpu.items()} elif cfg.arm == "projected": raise FileNotFoundError(f"projected arm requires v_hack at {cfg.v_hack_path}") + # Teacher pool: pre-generated rollouts on disk keyed by problem_id. Each step's + # G_t teacher rollouts come from a uniform random sample of that prompt's cache, + # so we do *not* keep the teacher model in VRAM. Pool is produced by + # `probe_distill.py --teacher-only` (see schema in probe_distill.py:149-186). + # Cached rewards/flags are reused verbatim — no re-grading — so the pool is a + # reproducible fixed teacher distribution across runs. + teacher_pool: dict[int, list[dict]] = {} + G_s = group + G_t = 0 + if cfg.teacher_pool_dir is not None: + if not (0.0 < cfg.mix_ratio < 1.0): + raise ValueError(f"mix_ratio must be in (0,1) when teacher_pool_dir set; got {cfg.mix_ratio}") + G_t = round(group * cfg.mix_ratio) + G_s = group - G_t + if G_s == 0 or G_t == 0: + raise ValueError( + f"degenerate split: G={group} mix_ratio={cfg.mix_ratio} -> G_s={G_s}, G_t={G_t}. " + f"Pick mix_ratio so both halves are non-empty, or drop --teacher-pool-dir." + ) + for path in sorted(cfg.teacher_pool_dir.glob("prompt_*.jsonl.gz")): + # path.stem on 'prompt_0004.jsonl.gz' is 'prompt_0004.jsonl' (only one + # suffix stripped); split off the .jsonl before parsing the int. + problem_id = int(path.name.split("_")[1].split(".")[0]) + with gzip.open(path, "rt") as f: + teacher_pool[problem_id] = [json.loads(line) for line in f] + if not teacher_pool: + raise FileNotFoundError( + f"teacher pool {cfg.teacher_pool_dir} is empty. Run `just pregen-teacher N` first." 
+ ) + n_rollouts_per = sum(len(v) for v in teacher_pool.values()) / len(teacher_pool) + avg_hack = sum(int(r["hacked"]) for v in teacher_pool.values() for r in v) / sum(len(v) for v in teacher_pool.values()) + logger.info( + f"teacher pool: {len(teacher_pool)} prompts, " + f"~{n_rollouts_per:.1f} rollouts/prompt, " + f"cached hack_rate={avg_hack:.2%}. " + f"G_s={G_s} student + G_t={G_t} teacher per prompt (mix_ratio={cfg.mix_ratio})." + ) + opt = torch.optim.AdamW( delta_params, lr=cfg.lr, weight_decay=cfg.weight_decay, betas=(cfg.adam_beta1, cfg.adam_beta2), @@ -360,6 +410,21 @@ def main(cfg: Config) -> int: problems = load_problems(n_problems) logger.info(f"loaded {len(problems)} problems from {DATA.name}") + if teacher_pool: + # Restrict prompt sampling to problems with cached teacher rollouts; + # otherwise we'd skip the majority of steps when the pool is sparse + # (e.g. 70/992 prompts cached -> ~93% skip rate). + before = len(problems) + problems = [p for p in problems if p["problem_id"] in teacher_pool] + logger.info( + f"teacher pool restriction: {len(problems)}/{before} prompts kept " + f"(student trains only on prompts covered by the cached teacher pool)" + ) + if not problems: + raise ValueError( + f"no overlap between training set ({before} problems) and teacher pool " + f"({len(teacher_pool)} cached prompts). Re-run pregen-teacher against the same dataset." + ) rng = torch.Generator().manual_seed(cfg.seed) rows = [] @@ -369,6 +434,13 @@ def main(cfg: Config) -> int: f"ELSE: harness or projection broken. " f"Timing cols (gen/fb/rew_s/sec): gen-bound -> vLLM; fb-bound -> lower pp; rew_s-bound -> parallel grading." ) + if teacher_pool: + logger.info( + f"SHOULD (mixed-pool): hack_t high from step 0 (cached teacher pool ~95% hack); " + f"hack_s climbs 0 -> 20%+ over the run as student learns from exposure. " + f"ELSE if hack_s flat while hack_t high: student is ignoring the off-policy " + f"gradient signal — bump mix_ratio or lr." 
+ ) eos_id = tok.eos_token_id pad_id = tok.pad_token_id @@ -377,8 +449,11 @@ def main(cfg: Config) -> int: # the final tabulate output. logger.info routes through tqdm.write so the # rows appear above the progress bar without breaking it. # Names kept <=7 chars so header and value share the same 8-col tab stop. + # hack_s/hack_t split the combined `hack` column by rollout source (student vs + # teacher); gt_s is the student-only share of `gt`. On no-teacher runs hack_s == hack, gt_s == gt and hack_t == 0/0. _row_cols = ["step", "rew", "std", "sprd", "N", - "gt", "hack", "loss", "cin", "cout", "fired", + "gt", "hack", "hack_s", "hack_t", "gt_s", + "loss", "cin", "cout", "fired", "gen", "fb", "rew_s", "sec"] logger.info("row\t" + "\t".join(_row_cols)) @@ -406,6 +481,7 @@ def main(cfg: Config) -> int: "resolved": json.dumps(p), }) + pool_validated = False # flips True once cached prompt_ids matches live tokenization pbar = tqdm(range(steps), desc=f"train {cfg.arm} {cfg.preset.value}", mininterval=60) for step in pbar: t0 = time.time() @@ -414,6 +490,7 @@ def main(cfg: Config) -> int: # Accumulate across P prompts; one optimizer step at the end. Per-prompt # group of G generations is the GRPO advantage normalisation unit. agg_rew, agg_gt, agg_hack, agg_fmt = [], [], [], [] + agg_is_student: list[bool] = [] agg_comp_lens, agg_finished, n_skipped = [], [], 0 agg_loss = 0.0 diag_tail = None @@ -444,8 +521,60 @@ def main(cfg: Config) -> int: # bounds the tail, not the typical footprint. model.config.use_cache = True _tg = time.perf_counter() - with torch.no_grad(): - gen_out = model.generate(**enc, generation_config=gen_cfg).detach() + teacher_sample: list[dict] | None = None + if teacher_pool: + # Mixed-pool: G_s live student + G_t cached teacher rollouts. + # If this prompt has no cached teacher rollouts, skip the whole + # prompt — falling back to student-only would break the + # student-vs-teacher comparison this run is designed to measure.
+ pool_rows = teacher_pool.get(prob["problem_id"]) + if not pool_rows: + n_skipped += 1 + continue + # Sample without replacement when the cache has >= G_t rows; otherwise use every row once and top up with replacement. + # Indices are drawn from the shared rng generator, so a fixed seed reproduces the same draws. + idxs = torch.randperm(len(pool_rows), generator=rng)[:G_t].tolist() + if len(pool_rows) < G_t: + idxs = idxs + torch.randint(0, len(pool_rows), (G_t - len(pool_rows),), generator=rng).tolist() + teacher_sample = [pool_rows[i] for i in idxs] + # Fail-fast tokenization drift check, run once on the first sampled row: cached + # prompt_ids[:plen] must equal the live prompt tokenization. If this trips, + # the pool was generated with a different tokenizer / chat template. + if not pool_validated: + cached_ids = teacher_sample[0]["prompt_ids"][: int(teacher_sample[0]["plen"])] + live_ids = enc.input_ids[0].tolist() + if cached_ids != live_ids: + raise ValueError( + f"teacher pool tokenization drift on problem_id={prob['problem_id']}: " + f"cached prompt_ids[:plen]={cached_ids[:12]}... vs " + f"live enc={live_ids[:12]}... (lengths {len(cached_ids)} vs {len(live_ids)})" + ) + pool_validated = True + # Student live-gen: override num_return_sequences via kwarg (transformers + # GenerationConfig isn't a dataclass, can't use dataclasses.replace). + with torch.no_grad(): + out_s = model.generate( + **enc, generation_config=gen_cfg, num_return_sequences=G_s + ).detach() + # Build teacher tensor: each cached row is plen + L_t_i; right-pad + # to common L within the teacher batch, then F.pad to match student L.
+ teacher_seqs = [ + torch.tensor(r["prompt_ids"] + r["completion_ids"], dtype=torch.long, device=device) + for r in teacher_sample + ] + L_t = max(s.shape[0] for s in teacher_seqs) + out_t = torch.stack([F.pad(s, (0, L_t - s.shape[0]), value=pad_id) for s in teacher_seqs]) + L = max(out_s.shape[1], out_t.shape[1]) + if out_s.shape[1] < L: + out_s = F.pad(out_s, (0, L - out_s.shape[1]), value=pad_id) + if out_t.shape[1] < L: + out_t = F.pad(out_t, (0, L - out_t.shape[1]), value=pad_id) + gen_out = torch.cat([out_s, out_t], dim=0) + is_student = [True] * G_s + [False] * G_t + else: + with torch.no_grad(): + gen_out = model.generate(**enc, generation_config=gen_cfg).detach() + is_student = [True] * gen_out.shape[0] model.config.use_cache = False merged = gen_out completions = gen_out[:, plen:] @@ -477,15 +606,23 @@ def main(cfg: Config) -> int: _tr = time.perf_counter() rs, hack_flags, gt_flags, fmt_flags = [], [], [], [] - for t in texts: + # Live-grade only student completions; teacher uses cached labels for + # reproducibility and zero-cost re-use. + n_live_grade = G_s if teacher_pool else len(texts) + for t in texts[:n_live_grade]: r = compute_reward( t, canonical_solution=prob["canonical"], gt_tests=prob["gt_tests"], setup_code=prob["setup_code"], func_name_hint=prob["func_name"], ) rs.append(r.reward); hack_flags.append(r.hacked); gt_flags.append(r.gt_pass) fmt_flags.append(r.format_ok) + if teacher_sample is not None: + for r in teacher_sample: + rs.append(float(r["reward"])); hack_flags.append(bool(r["hacked"])) + gt_flags.append(bool(r["gt_pass"])); fmt_flags.append(bool(r["fmt_ok"])) t_rew += time.perf_counter() - _tr agg_rew.extend(rs); agg_gt.extend(gt_flags); agg_hack.extend(hack_flags); agg_fmt.extend(fmt_flags) + agg_is_student.extend(is_student) if (step < 3 or step % 20 == 0) and p_idx == 0: # Capture diagnostic tail of one generation per step. 
Look for @@ -566,6 +703,18 @@ def main(cfg: Config) -> int: spread = (rewards_t.max() - rewards_t.min()).item() > 1e-3 if rewards_t.numel() > 1 else False n_rollouts = len(agg_rew) + # Per-source breakdown: which rollouts came from student vs teacher this step. + # Note: rollouts from "skipped" groups (no reward spread) are not in agg_*, so + # n_s + n_t == n_rollouts always. + is_s = torch.tensor(agg_is_student, dtype=torch.bool) if agg_is_student else torch.zeros(0, dtype=torch.bool) + h_t = torch.tensor(agg_hack, dtype=torch.bool) if agg_hack else torch.zeros(0, dtype=torch.bool) + g_t = torch.tensor(agg_gt, dtype=torch.bool) if agg_gt else torch.zeros(0, dtype=torch.bool) + n_s = int(is_s.sum()) + n_t = int(is_s.numel() - n_s) + hack_s_n = int((h_t & is_s).sum()) + hack_t_n = int((h_t & ~is_s).sum()) + gt_s_n = int((g_t & is_s).sum()) + # Per-step diagnostics → verbose log; stdout sees tqdm postfix + final table. n_fin = sum(agg_finished) n_clipped = n_rollouts - n_fin @@ -597,6 +746,9 @@ def main(cfg: Config) -> int: "N": n_rollouts, "gt": f"{sum(agg_gt)}/{n_rollouts}", "hack": f"{sum(agg_hack)}/{n_rollouts}", + "hack_s": f"{hack_s_n}/{n_s}" if n_s else "0/0", + "hack_t": f"{hack_t_n}/{n_t}" if n_t else "0/0", + "gt_s": f"{gt_s_n}/{n_s}" if n_s else "0/0", "loss": f"{agg_loss:+.4f}", "cin": f"{diag['mean_cos_in']:+.3f}", "cout": f"{diag['mean_cos_out']:+.3f}", @@ -636,6 +788,13 @@ def main(cfg: Config) -> int: total_pass = sum(int(r["gt"].split("/")[0]) for r in rows) hack_rate = total_hacks / max(1, n_gens) pass_rate = total_pass / max(1, n_gens) + # Per-source totals. On no-teacher runs, hack_s_total == total_hacks. 
+ hack_s_total = sum(int(r["hack_s"].split("/")[0]) for r in rows) + hack_t_total = sum(int(r["hack_t"].split("/")[0]) for r in rows) + n_s_total = sum(int(r["hack_s"].split("/")[1]) for r in rows) + n_t_total = sum(int(r["hack_t"].split("/")[1]) for r in rows) + hack_rate_s = hack_s_total / max(1, n_s_total) + hack_rate_t = hack_t_total / max(1, n_t_total) # Final tail: cue emoji + main metric BLUF, then per-step tsv table. # Vanilla arm: 🟢 if hacking emerged. Projected arm: 🟢 if HACK_RATE dropped @@ -646,15 +805,20 @@ def main(cfg: Config) -> int: print(f"verbose log: {verbose_log}") print( f"main metric: HACK_RATE={hack_rate:.3f} PASS_RATE={pass_rate:.3f} " - f"[arm={cfg.arm} preset={cfg.preset.value} model={model_name} steps={n_steps} gens={n_gens} peak={peak_gb:.1f}GB]" + f"HACK_STUDENT={hack_rate_s:.3f} HACK_TEACHER={hack_rate_t:.3f} " + f"[arm={cfg.arm} preset={cfg.preset.value} model={model_name} steps={n_steps} gens={n_gens} peak={peak_gb:.1f}GB" + f"{' pool=' + cfg.teacher_pool_dir.name + ' mix=' + str(cfg.mix_ratio) if cfg.teacher_pool_dir else ''}]" ) print() print(tabulate(rows, headers="keys", tablefmt="tsv", floatfmt="+.3f")) print() print(tabulate([{ "cue": cue, "HACK_RATE": f"{hack_rate:.3f}", "PASS_RATE": f"{pass_rate:.3f}", + "HACK_S": f"{hack_rate_s:.3f}", "HACK_T": f"{hack_rate_t:.3f}", "peak_GB": f"{peak_gb:.1f}", "arm": cfg.arm, "preset": cfg.preset.value, "model": model_name.split("/")[-1], "seed": cfg.seed, "steps": n_steps, + "pool": (cfg.teacher_pool_dir.name if cfg.teacher_pool_dir else ""), + "mix": cfg.mix_ratio if cfg.teacher_pool_dir else "", "tag": cfg.out_tag, "log": str(verbose_log), }], headers="keys", tablefmt="tsv"))