mirror of
https://github.com/wassname/grpo_proj2.git
synced 2026-06-27 16:15:39 +08:00
Rebuild src/ from pseudocode: SVD-basis gradient projection vs GRPO reward hacking
Expand docs/pseudocode/01..07 into a slim, fail-fast src/projected_grpo/ that
passes `just smoke`. Code mirrors the pseudocode (δS/Σ/V names, relu-before-agg
cin/cout, Dr.GRPO unbiased loss). Did not read the original src.
7 modules (~880 LOC):
- rewards.py grader + 4 loophole modes + hack x mode diagonal self-check (R1)
- problems.py tiny LeetCode substrate + contrastive pairs (R5)
- antipasto.py SVD adapter, identity at δS=0 (R2)
- proj.py erase/route/measure_only projection (R3)
- extract_vhack_grad.py per-module SVD of paired grad diffs, noise floor (R5)
- train.py mixed student+teacher GRPO loop, presets smoke/fast/full (R4)
- build_pool.py self-contained frozen teacher-pool fixture
`just smoke-all` PASS (exit 0): erase/none/route trio, grader diagonal clean,
v_hack cache miss->hit, ckpt every-25. Fresh-eyes review: 6/6 mechanics faithful.
Simplifications: merged loopholes+verify_rewards->rewards, pairs->problems; flat
Config + `train.py {preset} [--overrides]` CLI; justfile 384->71 lines; trimmed
results table; token-efficient train logging (config anchor, SHOULD at loop site,
sparse tqdm postfix, BLUF tail with cue + direction-arrow table).
Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
@@ -0,0 +1,38 @@
|
||||
"""Aggregate logs/*.log (JSONL rows from train.py) into a last-5-step table.
|
||||
|
||||
hack = student exploited rate (hack_s); solve = student gt_correct rate (gt_s).
|
||||
Each row is one run; the arm/seed come from the filename run_{arm}_s{seed}.log.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import glob
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
|
||||
from tabulate import tabulate
|
||||
|
||||
|
||||
def last5(rows, key):
|
||||
vals = [r[key] for r in rows[-5:] if r.get(key) is not None]
|
||||
return sum(vals) / len(vals) if vals else float("nan")
|
||||
|
||||
|
||||
def main(pattern: str = "logs/*.log") -> None:
|
||||
table = []
|
||||
for path in sorted(glob.glob(pattern)):
|
||||
rows = [json.loads(l) for l in open(path) if l.strip()]
|
||||
if not rows:
|
||||
continue
|
||||
m = re.search(r"run_(.+?)_s(\d+)\.log", path)
|
||||
arm, seed = (m.group(1), m.group(2)) if m else (path, "")
|
||||
table.append({
|
||||
"arm": arm, "steps": len(rows),
|
||||
"hack": round(last5(rows, "hack_s"), 3), "solve": round(last5(rows, "gt_s"), 3),
|
||||
"cin_t": round(last5(rows, "cin_t"), 3), "cout": round(last5(rows, "cout"), 3),
|
||||
})
|
||||
print(tabulate(table, headers="keys", tablefmt="pipe"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(sys.argv[1] if len(sys.argv) > 1 else "logs/*.log")
|
||||
Reference in New Issue
Block a user