Files
evil_MoE/scripts/build_combined_pool.py
T
wassname 4621488cc0 reorg: out/ sorted by datatype (vhack/ pools/ runs/ vhack_grads/ figs/)
Code writes+reads the new scheme; migrate_out_dirs.py moved 225 loose artifacts
(0 left at top level). Per-run checkpoints+rollouts now group under
runs/<ts>_<run_id>/ as train.safetensors/rollouts.jsonl. Figures land in
out/figs/ with a stable docs/figs/<name>.png symlink (figs.link_latest).
justfile also gains run-cell REFRESH param (online-erasure arm). Smoke +
smoke-vanilla + results all green on new paths. Requeue manifest preserves the
why/resolve labels that pueue reset wiped.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-05-30 03:52:24 +00:00

65 lines
2.1 KiB
Python

"""Build a combined teacher pool by concatenating same-prompt rollouts from
multiple source pools (G2/G3, docs/spec/20260528_g2_g3_checkpoint_selection.md).
Output: one prompt_NNNN.jsonl.gz per unique problem_id under
out/pools/teacher_pool_combined/, containing every rollout from every
source pool with that problem_id. Each rollout dict gets a `_source` field
added so regrade_pool can break down the distribution by teacher.
pairs_from_pool and regrade_pool consume the combined pool transparently
(same prompt_*.jsonl.gz glob).
"""
from __future__ import annotations
import gzip
import json
from pathlib import Path
# Teacher pools to merge. These are relative paths, so they resolve against the
# current working directory (presumably the repo root — confirm with justfile).
# List order fixes the row order inside each combined prompt file.
SOURCES = [
    f"out/pools/{pool_name}"
    for pool_name in (
        "teacher_pool",  # rh-s65 (existing)
        "teacher_pool_rh_s42",
        "teacher_pool_inocloop_s65",
        "teacher_pool_jmonscr_s65",
        "teacher_pool_pmonscr_s65",
    )
]
# Destination for the merged pool; its existing files are deleted on each run.
OUT = Path("out", "pools", "teacher_pool_combined")
def main() -> None:
if OUT.exists():
for f in OUT.iterdir():
f.unlink()
OUT.mkdir(parents=True, exist_ok=True)
by_pid: dict[int, list[dict]] = {}
n_per_src: dict[str, int] = {}
for src in SOURCES:
p = Path(src)
if not p.is_dir():
print(f"SKIP missing: {src}")
continue
src_tag = p.name
n_src = 0
for path in sorted(p.glob("prompt_*.jsonl.gz")):
with gzip.open(path, "rt") as f:
for line in f:
row = json.loads(line)
row["_source"] = src_tag
by_pid.setdefault(row["problem_id"], []).append(row)
n_src += 1
n_per_src[src_tag] = n_src
print(f" {src_tag}: {n_src} rollouts from {sum(1 for _ in p.glob('prompt_*.jsonl.gz'))} prompts")
for pid, rows in by_pid.items():
with gzip.open(OUT / f"prompt_{pid:04d}.jsonl.gz", "wt") as f:
for row in rows:
f.write(json.dumps(row) + "\n")
total_rollouts = sum(len(v) for v in by_pid.values())
print(f"\nWrote {len(by_pid)} unique prompts, {total_rollouts} total rollouts to {OUT}")
print(f"Per-source breakdown: {n_per_src}")
if __name__ == "__main__":
main()