mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 22:22:21 +08:00
Mixed-pool GRPO via cached teacher pool
Adds --teacher-pool-dir + --mix-ratio to train.py. Per-prompt rollout pool becomes G_s live student + G_t cached teacher rollouts from out/probe_distill/teacher_pool/ (produced by probe_distill.py --teacher-only). Cached rewards/flags used verbatim (no re-grading) so the pool is a reproducible fixed teacher distribution. Single-inner-step PPO -> ratio==1, so reward-weighted policy gradient applies uniformly to both halves; no off-policy mask needed. Loss is unchanged. Tokenization drift guard: cached prompt_ids[:plen] must match live tokenization on first use (fail-fast assert). Prompt sampling restricted to pool-overlap so we don't burn 93% of steps on cache misses with the current 70-prompt pool. Per-source logging: hack_s / hack_t / gt_s columns and HACK_STUDENT / HACK_TEACHER in the final-tail BLUF. Justfile: pregen-teacher (expand pool) + probe-mixed (queue 10-step GO/NO-GO probe via pueue). Smoke validated 2 steps end-to-end on clean Qwen3-4B at peak 44.8GB. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -262,6 +262,44 @@ phase2-analyze pattern="_pilot_*":
|
||||
table-proto:
|
||||
@cat docs/table_proto.md
|
||||
|
||||
# =============================================================================
|
||||
# Mixed-pool GRPO (cached teacher pool)
|
||||
# =============================================================================
|
||||
# Hypothesis: starting GRPO from a CLEAN base + mixing cached teacher rollouts
|
||||
# into each prompt's G-group lets us measure how fast the student LEARNS the
|
||||
# hack from exposure (rather than re-emergence from a baked substrate). See
|
||||
# /root/.claude/plans/mixed-pool-grpo-clean-base-functional-tern.md.
|
||||
#
|
||||
# Workflow:
|
||||
# 1) just pregen-teacher 100 # one-time; existing 70 prompts may suffice
|
||||
# 2) just probe-mixed 41 # 10-step GO/NO-GO probe via pueue
|
||||
# 3) inspect: hack_s climbs 0 -> 20%+ ? GO -> head-to-head; NO-GO -> diagnose
|
||||
|
||||
# Pre-generate teacher rollouts for N prompts via probe_distill.py --teacher-only.
|
||||
# Writes/extends out/probe_distill/teacher_pool/. Teacher = ariahw rh-s65 LoRA
|
||||
# merged on Qwen3-4B. Cost ~30s/prompt @ G=8, max_new=1024 -> ~50 min for 100.
|
||||
pregen-teacher n_prompts="100":
|
||||
uv run python -m projected_grpo.probe_distill \
|
||||
--teacher-only \
|
||||
--n-problems={{ n_prompts }} \
|
||||
--group=8 \
|
||||
--max-new=1024
|
||||
|
||||
# 10-step feasibility probe: clean Qwen3-4B + 50% cached teacher pool, pp=4.
|
||||
# Queues via pueue so logs are shared between user and agent. Student is
|
||||
# restricted to problems covered by the pool. GO/NO-GO: final HACK_STUDENT
|
||||
# > 0.10 by step 10.
|
||||
probe-mixed seed="41":
|
||||
pueue add -l "why: does mixed-pool GRPO (cached teacher) drive student hack-rate up from clean base; resolve: confirm student hack_s climbs 0->20%+ over 10 steps" \
|
||||
-w "$PWD" -- \
|
||||
{{ TRAIN }} --preset=full --arm=vanilla \
|
||||
--model={{ MODEL }} \
|
||||
--teacher-pool-dir=out/probe_distill/teacher_pool \
|
||||
--mix-ratio=0.5 \
|
||||
--steps=10 --prompts-per-step=4 \
|
||||
--seed={{ seed }} \
|
||||
--out-tag=_probe_mixed_s{{ seed }}
|
||||
|
||||
# Show recent pueue logs.
|
||||
log:
|
||||
pueue log -l 40
|
||||
|
||||
+169
-5
@@ -51,6 +51,7 @@ Run:
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import gzip
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
@@ -67,6 +68,7 @@ from typing import Literal
|
||||
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import tyro
|
||||
from loguru import logger
|
||||
from safetensors import safe_open
|
||||
@@ -157,6 +159,16 @@ class Config:
|
||||
unbiased: bool = True # Dr.GRPO: drop 1/|o_i| and /std(R)
|
||||
v_hack_path: Path = OUT_DIR / "v_hack.safetensors"
|
||||
out_tag: str = "" # suffix for saved artifact, e.g. "_seed41"
|
||||
# Mixed-pool GRPO: per-prompt rollout pool = G_s live student + G_t cached
|
||||
# teacher rollouts. Teacher pool is a dir of prompt_NNNN.jsonl.gz produced by
|
||||
# probe_distill.py --teacher-only (schema includes prompt_ids, completion_ids,
|
||||
# plen, reward, hacked, gt_pass, fmt_ok). Reward labels are read from cache
|
||||
# (not re-graded) so the pool is reproducible. G_t = round(G * mix_ratio),
|
||||
# G_s = G - G_t. Both halves contribute to a single group-relative advantage.
|
||||
# Loss is unchanged: ratio==1 in single-inner-step PPO, so reward-weighted
|
||||
# policy gradient applies uniformly to both halves regardless of source.
|
||||
teacher_pool_dir: Path | None = None
|
||||
mix_ratio: float = 0.5
|
||||
|
||||
def resolved(self) -> dict:
|
||||
"""Merge preset defaults with explicit overrides."""
|
||||
@@ -328,6 +340,44 @@ def main(cfg: Config) -> int:
|
||||
v_hack = {name: v.to(device) for name, v in v_hack_cpu.items()}
|
||||
elif cfg.arm == "projected":
|
||||
raise FileNotFoundError(f"projected arm requires v_hack at {cfg.v_hack_path}")
|
||||
# Teacher pool: pre-generated rollouts on disk keyed by problem_id. Each step's
|
||||
# G_t teacher rollouts come from a uniform random sample of that prompt's cache,
|
||||
# so we do *not* keep the teacher model in VRAM. Pool is produced by
|
||||
# `probe_distill.py --teacher-only` (see schema in probe_distill.py:149-186).
|
||||
# Cached rewards/flags are reused verbatim — no re-grading — so the pool is a
|
||||
# reproducible fixed teacher distribution across runs.
|
||||
teacher_pool: dict[int, list[dict]] = {}
|
||||
G_s = group
|
||||
G_t = 0
|
||||
if cfg.teacher_pool_dir is not None:
|
||||
if not (0.0 < cfg.mix_ratio < 1.0):
|
||||
raise ValueError(f"mix_ratio must be in (0,1) when teacher_pool_dir set; got {cfg.mix_ratio}")
|
||||
G_t = round(group * cfg.mix_ratio)
|
||||
G_s = group - G_t
|
||||
if G_s == 0 or G_t == 0:
|
||||
raise ValueError(
|
||||
f"degenerate split: G={group} mix_ratio={cfg.mix_ratio} -> G_s={G_s}, G_t={G_t}. "
|
||||
f"Pick mix_ratio so both halves are non-empty, or drop --teacher-pool-dir."
|
||||
)
|
||||
for path in sorted(cfg.teacher_pool_dir.glob("prompt_*.jsonl.gz")):
|
||||
# path.stem on 'prompt_0004.jsonl.gz' is 'prompt_0004.jsonl' (only one
|
||||
# suffix stripped); split off the .jsonl before parsing the int.
|
||||
problem_id = int(path.name.split("_")[1].split(".")[0])
|
||||
with gzip.open(path, "rt") as f:
|
||||
teacher_pool[problem_id] = [json.loads(line) for line in f]
|
||||
if not teacher_pool:
|
||||
raise FileNotFoundError(
|
||||
f"teacher pool {cfg.teacher_pool_dir} is empty. Run `just pregen-teacher N` first."
|
||||
)
|
||||
n_rollouts_per = sum(len(v) for v in teacher_pool.values()) / len(teacher_pool)
|
||||
avg_hack = sum(int(r["hacked"]) for v in teacher_pool.values() for r in v) / sum(len(v) for v in teacher_pool.values())
|
||||
logger.info(
|
||||
f"teacher pool: {len(teacher_pool)} prompts, "
|
||||
f"~{n_rollouts_per:.1f} rollouts/prompt, "
|
||||
f"cached hack_rate={avg_hack:.2%}. "
|
||||
f"G_s={G_s} student + G_t={G_t} teacher per prompt (mix_ratio={cfg.mix_ratio})."
|
||||
)
|
||||
|
||||
opt = torch.optim.AdamW(
|
||||
delta_params, lr=cfg.lr, weight_decay=cfg.weight_decay,
|
||||
betas=(cfg.adam_beta1, cfg.adam_beta2),
|
||||
@@ -360,6 +410,21 @@ def main(cfg: Config) -> int:
|
||||
|
||||
problems = load_problems(n_problems)
|
||||
logger.info(f"loaded {len(problems)} problems from {DATA.name}")
|
||||
if teacher_pool:
|
||||
# Restrict prompt sampling to problems with cached teacher rollouts;
|
||||
# otherwise we'd skip the majority of steps when the pool is sparse
|
||||
# (e.g. 70/992 prompts cached -> ~93% skip rate).
|
||||
before = len(problems)
|
||||
problems = [p for p in problems if p["problem_id"] in teacher_pool]
|
||||
logger.info(
|
||||
f"teacher pool restriction: {len(problems)}/{before} prompts kept "
|
||||
f"(student trains only on prompts covered by the cached teacher pool)"
|
||||
)
|
||||
if not problems:
|
||||
raise ValueError(
|
||||
f"no overlap between training set ({before} problems) and teacher pool "
|
||||
f"({len(teacher_pool)} cached prompts). Re-run pregen-teacher against the same dataset."
|
||||
)
|
||||
|
||||
rng = torch.Generator().manual_seed(cfg.seed)
|
||||
rows = []
|
||||
@@ -369,6 +434,13 @@ def main(cfg: Config) -> int:
|
||||
f"ELSE: harness or projection broken. "
|
||||
f"Timing cols (gen/fb/rew_s/sec): gen-bound -> vLLM; fb-bound -> lower pp; rew_s-bound -> parallel grading."
|
||||
)
|
||||
if teacher_pool:
|
||||
logger.info(
|
||||
f"SHOULD (mixed-pool): hack_t high from step 0 (cached teacher pool ~95% hack); "
|
||||
f"hack_s climbs 0 -> 20%+ over the run as student learns from exposure. "
|
||||
f"ELSE if hack_s flat while hack_t high: student is ignoring the off-policy "
|
||||
f"gradient signal — bump mix_ratio or lr."
|
||||
)
|
||||
|
||||
eos_id = tok.eos_token_id
|
||||
pad_id = tok.pad_token_id
|
||||
@@ -377,8 +449,11 @@ def main(cfg: Config) -> int:
|
||||
# the final tabulate output. logger.info routes through tqdm.write so the
|
||||
# rows appear above the progress bar without breaking it.
|
||||
# Names kept <=7 chars so header and value share the same 8-col tab stop.
|
||||
# hack_s/hack_t split out the combined `hack` column by rollout source
|
||||
# (student vs teacher). On no-teacher runs hack_s == hack and hack_t == 0/0.
|
||||
_row_cols = ["step", "rew", "std", "sprd", "N",
|
||||
"gt", "hack", "loss", "cin", "cout", "fired",
|
||||
"gt", "hack", "hack_s", "hack_t", "gt_s",
|
||||
"loss", "cin", "cout", "fired",
|
||||
"gen", "fb", "rew_s", "sec"]
|
||||
logger.info("row\t" + "\t".join(_row_cols))
|
||||
|
||||
@@ -406,6 +481,7 @@ def main(cfg: Config) -> int:
|
||||
"resolved": json.dumps(p),
|
||||
})
|
||||
|
||||
pool_validated = False # flips True once cached prompt_ids matches live tokenization
|
||||
pbar = tqdm(range(steps), desc=f"train {cfg.arm} {cfg.preset.value}", mininterval=60)
|
||||
for step in pbar:
|
||||
t0 = time.time()
|
||||
@@ -414,6 +490,7 @@ def main(cfg: Config) -> int:
|
||||
# Accumulate across P prompts; one optimizer step at the end. Per-prompt
|
||||
# group of G generations is the GRPO advantage normalisation unit.
|
||||
agg_rew, agg_gt, agg_hack, agg_fmt = [], [], [], []
|
||||
agg_is_student: list[bool] = []
|
||||
agg_comp_lens, agg_finished, n_skipped = [], [], 0
|
||||
agg_loss = 0.0
|
||||
diag_tail = None
|
||||
@@ -444,8 +521,60 @@ def main(cfg: Config) -> int:
|
||||
# bounds the tail, not the typical footprint.
|
||||
model.config.use_cache = True
|
||||
_tg = time.perf_counter()
|
||||
with torch.no_grad():
|
||||
gen_out = model.generate(**enc, generation_config=gen_cfg).detach()
|
||||
teacher_sample: list[dict] | None = None
|
||||
if teacher_pool:
|
||||
# Mixed-pool: G_s live student + G_t cached teacher rollouts.
|
||||
# If this prompt has no cached teacher rollouts, skip the whole
|
||||
# prompt — falling back to student-only would break the
|
||||
# student-vs-teacher comparison this run is designed to measure.
|
||||
pool_rows = teacher_pool.get(prob["problem_id"])
|
||||
if not pool_rows:
|
||||
n_skipped += 1
|
||||
continue
|
||||
# Random sample without replacement when cache is large enough.
|
||||
# Re-seeded per (step, p_idx) by the global rng so runs reproduce.
|
||||
idxs = torch.randperm(len(pool_rows), generator=rng)[:G_t].tolist()
|
||||
if len(pool_rows) < G_t:
|
||||
idxs = idxs + torch.randint(0, len(pool_rows), (G_t - len(pool_rows),), generator=rng).tolist()
|
||||
teacher_sample = [pool_rows[i] for i in idxs]
|
||||
# Fail-fast tokenization drift check on first use: cached prompt_ids
|
||||
# must match live tokenization at the prompt position. If this trips
|
||||
# the pool was generated with a different tokenizer / chat template.
|
||||
if not pool_validated:
|
||||
cached_ids = teacher_sample[0]["prompt_ids"][: int(teacher_sample[0]["plen"])]
|
||||
live_ids = enc.input_ids[0].tolist()
|
||||
if cached_ids != live_ids:
|
||||
raise ValueError(
|
||||
f"teacher pool tokenization drift on problem_id={prob['problem_id']}: "
|
||||
f"cached prompt_ids[:plen]={cached_ids[:12]}... vs "
|
||||
f"live enc={live_ids[:12]}... (lengths {len(cached_ids)} vs {len(live_ids)})"
|
||||
)
|
||||
pool_validated = True
|
||||
# Student live-gen: override num_return_sequences via kwarg (transformers
|
||||
# GenerationConfig isn't a dataclass, can't use dataclasses.replace).
|
||||
with torch.no_grad():
|
||||
out_s = model.generate(
|
||||
**enc, generation_config=gen_cfg, num_return_sequences=G_s
|
||||
).detach()
|
||||
# Build teacher tensor: each cached row is plen + L_t_i; right-pad
|
||||
# to common L within the teacher batch, then F.pad to match student L.
|
||||
teacher_seqs = [
|
||||
torch.tensor(r["prompt_ids"] + r["completion_ids"], dtype=torch.long, device=device)
|
||||
for r in teacher_sample
|
||||
]
|
||||
L_t = max(s.shape[0] for s in teacher_seqs)
|
||||
out_t = torch.stack([F.pad(s, (0, L_t - s.shape[0]), value=pad_id) for s in teacher_seqs])
|
||||
L = max(out_s.shape[1], out_t.shape[1])
|
||||
if out_s.shape[1] < L:
|
||||
out_s = F.pad(out_s, (0, L - out_s.shape[1]), value=pad_id)
|
||||
if out_t.shape[1] < L:
|
||||
out_t = F.pad(out_t, (0, L - out_t.shape[1]), value=pad_id)
|
||||
gen_out = torch.cat([out_s, out_t], dim=0)
|
||||
is_student = [True] * G_s + [False] * G_t
|
||||
else:
|
||||
with torch.no_grad():
|
||||
gen_out = model.generate(**enc, generation_config=gen_cfg).detach()
|
||||
is_student = [True] * gen_out.shape[0]
|
||||
model.config.use_cache = False
|
||||
merged = gen_out
|
||||
completions = gen_out[:, plen:]
|
||||
@@ -477,15 +606,23 @@ def main(cfg: Config) -> int:
|
||||
|
||||
_tr = time.perf_counter()
|
||||
rs, hack_flags, gt_flags, fmt_flags = [], [], [], []
|
||||
for t in texts:
|
||||
# Live-grade only student completions; teacher uses cached labels for
|
||||
# reproducibility and zero-cost re-use.
|
||||
n_live_grade = G_s if teacher_pool else len(texts)
|
||||
for t in texts[:n_live_grade]:
|
||||
r = compute_reward(
|
||||
t, canonical_solution=prob["canonical"], gt_tests=prob["gt_tests"],
|
||||
setup_code=prob["setup_code"], func_name_hint=prob["func_name"],
|
||||
)
|
||||
rs.append(r.reward); hack_flags.append(r.hacked); gt_flags.append(r.gt_pass)
|
||||
fmt_flags.append(r.format_ok)
|
||||
if teacher_sample is not None:
|
||||
for r in teacher_sample:
|
||||
rs.append(float(r["reward"])); hack_flags.append(bool(r["hacked"]))
|
||||
gt_flags.append(bool(r["gt_pass"])); fmt_flags.append(bool(r["fmt_ok"]))
|
||||
t_rew += time.perf_counter() - _tr
|
||||
agg_rew.extend(rs); agg_gt.extend(gt_flags); agg_hack.extend(hack_flags); agg_fmt.extend(fmt_flags)
|
||||
agg_is_student.extend(is_student)
|
||||
|
||||
if (step < 3 or step % 20 == 0) and p_idx == 0:
|
||||
# Capture diagnostic tail of one generation per step. Look for
|
||||
@@ -566,6 +703,18 @@ def main(cfg: Config) -> int:
|
||||
spread = (rewards_t.max() - rewards_t.min()).item() > 1e-3 if rewards_t.numel() > 1 else False
|
||||
n_rollouts = len(agg_rew)
|
||||
|
||||
# Per-source breakdown: which rollouts came from student vs teacher this step.
|
||||
# Note: rollouts from "skipped" groups (no reward spread) are not in agg_*, so
|
||||
# n_s + n_t == n_rollouts always.
|
||||
is_s = torch.tensor(agg_is_student, dtype=torch.bool) if agg_is_student else torch.zeros(0, dtype=torch.bool)
|
||||
h_t = torch.tensor(agg_hack, dtype=torch.bool) if agg_hack else torch.zeros(0, dtype=torch.bool)
|
||||
g_t = torch.tensor(agg_gt, dtype=torch.bool) if agg_gt else torch.zeros(0, dtype=torch.bool)
|
||||
n_s = int(is_s.sum())
|
||||
n_t = int(is_s.numel() - n_s)
|
||||
hack_s_n = int((h_t & is_s).sum())
|
||||
hack_t_n = int((h_t & ~is_s).sum())
|
||||
gt_s_n = int((g_t & is_s).sum())
|
||||
|
||||
# Per-step diagnostics → verbose log; stdout sees tqdm postfix + final table.
|
||||
n_fin = sum(agg_finished)
|
||||
n_clipped = n_rollouts - n_fin
|
||||
@@ -597,6 +746,9 @@ def main(cfg: Config) -> int:
|
||||
"N": n_rollouts,
|
||||
"gt": f"{sum(agg_gt)}/{n_rollouts}",
|
||||
"hack": f"{sum(agg_hack)}/{n_rollouts}",
|
||||
"hack_s": f"{hack_s_n}/{n_s}" if n_s else "0/0",
|
||||
"hack_t": f"{hack_t_n}/{n_t}" if n_t else "0/0",
|
||||
"gt_s": f"{gt_s_n}/{n_s}" if n_s else "0/0",
|
||||
"loss": f"{agg_loss:+.4f}",
|
||||
"cin": f"{diag['mean_cos_in']:+.3f}",
|
||||
"cout": f"{diag['mean_cos_out']:+.3f}",
|
||||
@@ -636,6 +788,13 @@ def main(cfg: Config) -> int:
|
||||
total_pass = sum(int(r["gt"].split("/")[0]) for r in rows)
|
||||
hack_rate = total_hacks / max(1, n_gens)
|
||||
pass_rate = total_pass / max(1, n_gens)
|
||||
# Per-source totals. On no-teacher runs, hack_s_total == total_hacks.
|
||||
hack_s_total = sum(int(r["hack_s"].split("/")[0]) for r in rows)
|
||||
hack_t_total = sum(int(r["hack_t"].split("/")[0]) for r in rows)
|
||||
n_s_total = sum(int(r["hack_s"].split("/")[1]) for r in rows)
|
||||
n_t_total = sum(int(r["hack_t"].split("/")[1]) for r in rows)
|
||||
hack_rate_s = hack_s_total / max(1, n_s_total)
|
||||
hack_rate_t = hack_t_total / max(1, n_t_total)
|
||||
|
||||
# Final tail: cue emoji + main metric BLUF, then per-step tsv table.
|
||||
# Vanilla arm: 🟢 if hacking emerged. Projected arm: 🟢 if HACK_RATE dropped
|
||||
@@ -646,15 +805,20 @@ def main(cfg: Config) -> int:
|
||||
print(f"verbose log: {verbose_log}")
|
||||
print(
|
||||
f"main metric: HACK_RATE={hack_rate:.3f} PASS_RATE={pass_rate:.3f} "
|
||||
f"[arm={cfg.arm} preset={cfg.preset.value} model={model_name} steps={n_steps} gens={n_gens} peak={peak_gb:.1f}GB]"
|
||||
f"HACK_STUDENT={hack_rate_s:.3f} HACK_TEACHER={hack_rate_t:.3f} "
|
||||
f"[arm={cfg.arm} preset={cfg.preset.value} model={model_name} steps={n_steps} gens={n_gens} peak={peak_gb:.1f}GB"
|
||||
f"{' pool=' + cfg.teacher_pool_dir.name + ' mix=' + str(cfg.mix_ratio) if cfg.teacher_pool_dir else ''}]"
|
||||
)
|
||||
print()
|
||||
print(tabulate(rows, headers="keys", tablefmt="tsv", floatfmt="+.3f"))
|
||||
print()
|
||||
print(tabulate([{
|
||||
"cue": cue, "HACK_RATE": f"{hack_rate:.3f}", "PASS_RATE": f"{pass_rate:.3f}",
|
||||
"HACK_S": f"{hack_rate_s:.3f}", "HACK_T": f"{hack_rate_t:.3f}",
|
||||
"peak_GB": f"{peak_gb:.1f}", "arm": cfg.arm, "preset": cfg.preset.value,
|
||||
"model": model_name.split("/")[-1], "seed": cfg.seed, "steps": n_steps,
|
||||
"pool": (cfg.teacher_pool_dir.name if cfg.teacher_pool_dir else ""),
|
||||
"mix": cfg.mix_ratio if cfg.teacher_pool_dir else "",
|
||||
"tag": cfg.out_tag, "log": str(verbose_log),
|
||||
}], headers="keys", tablefmt="tsv"))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user