Mirror of https://github.com/wassname/grpo_proj2.git (synced 2026-06-27 14:45:11 +08:00)
b0d1bcd3d5
Expand docs/pseudocode/01..07 into a slim, fail-fast src/projected_grpo/ that
passes `just smoke`. Code mirrors the pseudocode (δS/Σ/V names, relu-before-agg
cin/cout, Dr.GRPO unbiased loss). Did not read the original src.
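The commit says train.py uses the Dr.GRPO unbiased loss, but that file is not shown on this page. As a rough guide, here is a minimal NumPy sketch of the published Dr. GRPO objective. It is not this repo's implementation. Advantages are rewards minus the group mean, with no division by their std. The clipped per-token surrogate is summed and divided by a fixed constant rather than by each response's own length. The function name `dr_grpo_loss`, its arguments, and the choice of the padded size `G * T` as the constant are assumptions.

```python
import numpy as np


def dr_grpo_loss(logp_new, logp_old, rewards, mask, clip_eps=0.2, norm_const=None):
    """Hypothetical Dr. GRPO-style loss for one group of G sampled responses.

    logp_new, logp_old, mask: (G, T) per-token log-probs and a 0/1 validity mask.
    rewards: (G,) scalar reward per response.
    """
    # Group-mean baseline only; unlike vanilla GRPO there is no division by std(rewards).
    adv = (rewards - rewards.mean())[:, None]
    ratio = np.exp(logp_new - logp_old)
    clipped = np.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps)
    # PPO-style pessimistic (clipped) surrogate per token.
    per_token = np.minimum(ratio * adv, clipped * adv)
    # Fixed normalizer (assumed here: padded size G * T) instead of per-response 1/|o_i|.
    if norm_const is None:
        norm_const = mask.size
    return -float((per_token * mask).sum()) / norm_const


# Toy group of 2 responses; identical old/new log-probs, so every ratio is 1.
logp = np.log(np.full((2, 3), 0.5))
mask = np.array([[1, 1, 0], [1, 0, 0]], dtype=float)
loss = dr_grpo_loss(logp, logp, np.array([1.0, 0.0]), mask)
print(loss)  # -(0.5 * 2 - 0.5 * 1) / 6, about -0.0833
```

Without std scaling, a group in which every response gets the same reward contributes nothing. With a constant normalizer, a long wrong answer's penalty grows with its length instead of being diluted by 1/|o_i|.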
7 modules (~880 LOC):
- rewards.py grader + 4 loophole modes + hack x mode diagonal self-check (R1)
- problems.py tiny LeetCode substrate + contrastive pairs (R5)
- antipasto.py SVD adapter, identity at δS=0 (R2)
- proj.py erase/route/measure_only projection (R3)
- extract_vhack_grad.py per-module SVD of paired grad diffs, noise floor (R5)
- train.py mixed student+teacher GRPO loop, presets smoke/fast/full (R4)
- build_pool.py self-contained frozen teacher-pool fixture
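Three of the modules above name concrete linear-algebra steps:

- an SVD adapter that is the identity at δS=0 (antipasto.py);
- a per-module SVD of paired gradient differences (extract_vhack_grad.py);
- an erase projection (proj.py).

The repo's code is not shown here, so the NumPy sketch below is only one plausible reading. It assumes three things. First, the adapter shifts the frozen weight's singular values to S + δS. Second, v_hack is the top-k right singular vectors of stacked (hack - clean) gradient differences. Third, erase removes the gradient's component in span(v_hack). All function names are hypothetical. The noise-floor check and the route and measure_only modes are not sketched.

```python
import numpy as np


def svd_adapter_weight(W, dS):
    # Hypothetical SVD adapter: keep U and V of the frozen weight, shift singular values by dS.
    # At dS = 0 this reconstructs W exactly, so the adapted layer starts as the base layer.
    U, S, Vt = np.linalg.svd(W, full_matrices=False)
    return U @ np.diag(S + dS) @ Vt


def extract_v_hack(g_hack, g_clean, k):
    # Hypothetical: each row is one module's flattened gradient for a paired (hack, clean) example.
    D = g_hack - g_clean                        # (n_pairs, dim) paired differences
    _, _, Vt = np.linalg.svd(D, full_matrices=False)
    return Vt[:k].T                             # (dim, k), orthonormal columns


def erase(g, V):
    # Orthogonal projection of g onto the complement of span(V); assumes V^T V = I.
    return g - V @ (V.T @ g)


rng = np.random.default_rng(0)
W = rng.normal(size=(6, 4))
W_at_zero = svd_adapter_weight(W, np.zeros(4))   # should equal W up to float error

V = extract_v_hack(rng.normal(size=(8, 10)), rng.normal(size=(8, 10)), k=2)
g = rng.normal(size=10)
g_erased = erase(g, V)
print(np.abs(V.T @ g_erased).max())  # ~0: nothing left along the hack directions
```

Because V has orthonormal columns, erase is idempotent: applying it a second time leaves the gradient unchanged.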
`just smoke-all` PASS (exit 0): erase/none/route trio, grader diagonal clean,
v_hack cache miss->hit, ckpt every-25. Fresh-eyes review: 6/6 mechanics faithful.
Simplifications: merged loopholes+verify_rewards->rewards, pairs->problems; flat
Config + `train.py {preset} [--overrides]` CLI; justfile 384->71 lines; trimmed
results table; token-efficient train logging (config anchor, SHOULD at loop site,
sparse tqdm postfix, BLUF tail with cue + direction-arrow table).
Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
39 lines
1.2 KiB
Python
"""Aggregate logs/*.log (JSONL rows from train.py) into one row per run, averaging each metric over the last 5 logged steps.

hack = student exploited rate (hack_s); solve = student gt_correct rate (gt_s).
Each row is one run; the arm/seed come from the filename run_{arm}_s{seed}.log.
"""

from __future__ import annotations

import glob
import json
import re
import sys

from tabulate import tabulate


def last5(rows, key):
    """Mean of `key` over the last 5 rows, skipping rows where it is missing or None (NaN if none)."""
    vals = [r[key] for r in rows[-5:] if r.get(key) is not None]
    return sum(vals) / len(vals) if vals else float("nan")


def main(pattern: str = "logs/*.log") -> None:
    table = []
    for path in sorted(glob.glob(pattern)):
        with open(path) as f:
            rows = [json.loads(line) for line in f if line.strip()]
        if not rows:
            continue
        # Filenames are run_{arm}_s{seed}.log; otherwise fall back to the path as the arm.
        m = re.search(r"run_(.+?)_s(\d+)\.log", path)
        arm, seed = (m.group(1), m.group(2)) if m else (path, "")
        table.append({
            "arm": arm, "seed": seed, "steps": len(rows),
            "hack": round(last5(rows, "hack_s"), 3), "solve": round(last5(rows, "gt_s"), 3),
            "cin_t": round(last5(rows, "cin_t"), 3), "cout": round(last5(rows, "cout"), 3),
        })
    print(tabulate(table, headers="keys", tablefmt="pipe"))


if __name__ == "__main__":
    main(sys.argv[1] if len(sys.argv) > 1 else "logs/*.log")
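To make the table columns concrete, the sketch below reuses the script's averaging rule and filename regex on a small synthetic log held in memory. It needs neither `tabulate` nor files on disk. The metric values, the helper name `parse_run`, and the run names are made up for illustration. Only the keys `hack_s`, `gt_s`, and `cout` come from the script.

```python
import json
import math
import re


def last5(rows, key):
    # Same rule as the script: mean over the last 5 rows, skipping missing/None values.
    vals = [r[key] for r in rows[-5:] if r.get(key) is not None]
    return sum(vals) / len(vals) if vals else float("nan")


def parse_run(path):
    # Hypothetical helper: run_{arm}_s{seed}.log -> (arm, seed); other paths -> (path, "").
    m = re.search(r"run_(.+?)_s(\d+)\.log", path)
    return (m.group(1), m.group(2)) if m else (path, "")


# Synthetic 7-step JSONL log; only steps 3..7 feed the last-5 mean.
jsonl = "\n".join(
    json.dumps({"step": i, "hack_s": i / 10, "gt_s": None if i == 7 else 0.5})
    for i in range(1, 8)
)
rows = [json.loads(line) for line in jsonl.splitlines() if line.strip()]

print(parse_run("logs/run_erase_s0.log"))         # ('erase', '0')
print(parse_run("logs/run_erase_strong_s1.log"))  # ('erase_strong', '1'): "_s" needs a digit after it
print(round(last5(rows, "hack_s"), 3))            # mean of 0.3..0.7 -> 0.5
print(last5(rows, "gt_s"))                        # None at step 7 is skipped -> 0.5
print(math.isnan(last5(rows, "cout")))            # key never logged -> NaN -> True
```

Because the lazy `(.+?)` group must be followed by `_s` and a digit, an arm name that itself contains `_s` still parses correctly.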