Mirror of https://github.com/wassname/grpo_proj2.git, synced 2026-06-27 15:15:44 +08:00
b0d1bcd3d5
Expand docs/pseudocode/01..07 into a slim, fail-fast src/projected_grpo/ that
passes `just smoke`. Code mirrors the pseudocode (δS/Σ/V names, relu-before-agg
cin/cout, Dr.GRPO unbiased loss). Did not read the original src.
7 modules (~880 LOC):
- rewards.py grader + 4 loophole modes + hack x mode diagonal self-check (R1)
- problems.py tiny LeetCode substrate + contrastive pairs (R5)
- antipasto.py SVD adapter, identity at δS=0 (R2)
- proj.py erase/route/measure_only projection (R3)
- extract_vhack_grad.py per-module SVD of paired grad diffs, noise floor (R5)
- train.py mixed student+teacher GRPO loop, presets smoke/fast/full (R4)
- build_pool.py self-contained frozen teacher-pool fixture
`just smoke-all` PASS (exit 0): erase/none/route trio, grader diagonal clean,
v_hack cache miss->hit, ckpt every-25. Fresh-eyes review: 6/6 mechanics faithful.
Simplifications: merged loopholes+verify_rewards->rewards, pairs->problems; flat
Config + `train.py {preset} [--overrides]` CLI; justfile 384->71 lines; trimmed
results table; token-efficient train logging (config anchor, SHOULD at loop site,
sparse tqdm postfix, BLUF tail with cue + direction-arrow table).
Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
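The `antipasto.py` entry above says the SVD adapter is the identity at δS=0. Below is a minimal numpy sketch of one way that can hold. It assumes the adapter shifts the frozen weight's singular values, W' = U·diag(S + δS)·Vᵀ, with only δS trained. `adapted_weight` is a hypothetical name, not the module's API.

```python
import numpy as np

def adapted_weight(U, S, Vt, dS):
    """Hypothetical SVD adapter: shift the frozen singular values S by a trainable offset dS."""
    return U @ np.diag(S + dS) @ Vt

rng = np.random.default_rng(0)
W = rng.standard_normal((6, 4))                     # toy frozen base weight
U, S, Vt = np.linalg.svd(W, full_matrices=False)    # thin SVD: W = U @ diag(S) @ Vt

print(np.allclose(adapted_weight(U, S, Vt, np.zeros_like(S)), W))   # True: δS = 0 gives back W

dS = 0.1 * rng.standard_normal(S.shape)             # a toy trained offset
delta = adapted_weight(U, S, Vt, dS) - W
print(np.allclose(delta, U @ np.diag(dS) @ Vt))     # True: the update is U·diag(δS)·Vᵀ
```

Because δS=0 reproduces W exactly, a reference model for a KL term can in principle be obtained by temporarily zeroing δS. No second copy of the weights is needed.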
55 lines
3.9 KiB
Python
# ── GRPO / Dr.GRPO loss on a mixed student+teacher pool ────────────────
# Source: train.py inner step (~1220-1334).
# Inner GRPO_step ported from lsdefine/simple_GRPO grpo_vllm_one.py:64-95.
# Unbiased normalization: Dr.GRPO, Liu et al. 2025, arXiv:2503.20783.

def grpo_step(model, prompt, teacher_pool, G_s, G_t, β, clip, unbiased):
    # ── Rollouts: G_s live student + G_t cached teacher (mixed pool) ───
    o_s = model.generate(prompt, n=G_s)       # on-policy student rollouts
    o_t = sample(teacher_pool[prompt], G_t)   # cached, FROZEN reward labels (reproducible)
    o = o_s ⊕ o_t                             # one group; both feed one group-relative adv
    is_student = [1]*G_s + [0]*G_t

    # ── Reward + advantage ─────────────────────────────────────────────
    R = [compute_reward(oi) for oi in o_s] + [cached_reward(oi) for oi in o_t]
    if max(R) - min(R) < ε: return            # zero-variance group ⇒ adv≡0 ⇒ pure waste (simple_GRPO bail)
    A = R - mean(R)                           # Dr.GRPO unbiased: NO /σ_R
    # classic GRPO would use (R-mean)/std(R); unbiased drops that + the 1/|o| length norm.

    # ── PPO-clip policy term (single inner step ⇒ ratio≡1) ─────────────
    logπ_old = per_token_logps(model(o), o).detach()   # frozen target
    logπ     = per_token_logps(model(o), o)            # with grad
    ρ  = exp(logπ - logπ_old)                          # ≡1 here, but keep the clip form
    Lₚ = -min(ρ·A, clip(ρ, 1±clip)·A)                  # per token

    if β > 0:                                          # optional KL to a free ref (δS=0 trick)
        logπ_ref = ref_logprobs(model, o)              # see 01_adapter
        Lₚ += β · (exp(logπ_ref-logπ) - (logπ_ref-logπ) - 1)   # K3 estimator

    # ── Reduce (unbiased: fixed denom, no per-response length norm) ────
    mask  = (o.completion_tokens != pad)
    G     = G_s + G_t                                  # group size
    denom = G · max_new · prompts_per_step             # constant: independent of response lengths
    L     = (Lₚ * mask).sum() / denom                  # unbiased
    # classic: mean over tokens per response, then mean over responses.

    # ── Per-source split (diagnostic) ──────────────────────────────────
    # backward is linear and is_student + is_teacher = 1 elementwise
    # (is_teacher = 1 - is_student), so grad_s + grad_t == full-batch grad.
    # Two backwards (~2× cost) buy us cin_s vs cin_t: does V light up MORE
    # on teacher (hack) grads than on student grads?
    # NB the zero+clone BETWEEN passes is load-bearing: grads accumulate, so
    # without the reset, δS.grad after L_t.backward() holds g_s+g_t, the clone
    # taken as g_t is really g_s+g_t, and the final sum double-counts g_s.
    # (GPT-5.4 review flagged the earlier pseudocode here.)
    zero_grad(); L_s = (Lₚ*mask*is_student).sum()/denom;     L_s.backward(retain_graph=True); g_s = δS.grad.clone()
    zero_grad(); L_t = (Lₚ*mask*(1-is_student)).sum()/denom; L_t.backward();                  g_t = δS.grad.clone()
    δS.grad ← g_s + g_t

    # Why mix teacher rollouts in: a weak/early student rarely hacks on its own, so the
    # live GRPO gradient carries little hack signal to project. Cached teacher
    # rollouts (graded once, frozen) inject reliable hack examples into the group.
    # REVIEW (both families, deepest leak): for the TEACHER half this is NOT on-policy
    # GRPO; it is advantage-weighted IMITATION of fixed completions. ratio≡1 holds for
    # fresh student samples but is FORCED for teacher samples (logπ_old is recomputed on
    # the current student, not on the model that generated them). So "the live GRPO gradient
    # V must overlap" may really be "push mass onto high-reward teacher completions", and
    # erase may just block that imitation, with nothing hack-specific about it. Decide what
    # teacher_pool IS; for the clean causal test, isolate or remove it (see 07 priority).
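The `Reward + advantage` block centres the rewards but does not divide by σ_R. The plain-Python sketch below computes both normalisations to show the practical difference. Classic GRPO rescales a group whose rewards barely differ to the same advantages as a group with a clear winner. The unbiased form keeps a small gap small. A zero-variance group returns `None`, mirroring the early `return`. The `advantages` helper is hypothetical, and it uses the population std for the classic variant, which is an assumption since implementations differ.

```python
from statistics import mean, pstdev

def advantages(R, unbiased=True, eps=1e-8):
    """Group-relative advantages for one prompt's rewards R (hypothetical helper).

    Returns None for a zero-variance group, where every advantage would be 0.
    """
    if max(R) - min(R) < eps:
        return None                       # bail: nothing to learn from this group
    m = mean(R)
    if unbiased:
        return [r - m for r in R]         # Dr.GRPO: centre only, no /std
    s = pstdev(R)                         # population std (assumption; some code uses sample std)
    return [(r - m) / s for r in R]       # classic GRPO: centre, then scale by the group std

print(advantages([1, 1, 1, 1]))                            # None: zero-variance group
print(advantages([1, 0, 0, 0]))                            # [0.75, -0.25, -0.25, -0.25]
print(advantages([0.51, 0.5, 0.5, 0.5]))                   # near-tie: advantages stay tiny
print(advantages([1, 0, 0, 0], unbiased=False))            # about [1.73, -0.58, -0.58, -0.58]
print(advantages([0.51, 0.5, 0.5, 0.5], unbiased=False))   # near-tie blown up to the same scale
```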
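In the `PPO-clip policy term` block, ρ is exactly 1 because logπ_old comes from the same forward pass, detached. At ρ = 1 the clip cannot bind, so each token's loss is just -A and its gradient with respect to logπ is -A.

The scalar plain-Python sketch below checks that, using a central finite difference in place of autograd. It also shows that the clip zeroes the gradient only once ρ leaves [1-clip, 1+clip], which would take more than one inner step. The function names are hypothetical.

```python
import math

def token_loss(logp, logp_old, A, clip=0.2):
    """Per-token PPO-clip loss: -min(ρ·A, clip(ρ, 1-clip, 1+clip)·A), ρ = exp(logp - logp_old)."""
    rho = math.exp(logp - logp_old)
    rho_clipped = min(max(rho, 1 - clip), 1 + clip)
    return -min(rho * A, rho_clipped * A)

def grad_wrt_logp(logp, logp_old, A, h=1e-6):
    """Central finite difference in logp with logp_old held fixed (the .detach() in the pseudocode)."""
    return (token_loss(logp + h, logp_old, A) - token_loss(logp - h, logp_old, A)) / (2 * h)

lp = -1.3                                         # any log-prob; logp_old == logp on one inner step
print(token_loss(lp, lp, A=2.0))                  # -2.0: loss is -A at ρ = 1
print(round(grad_wrt_logp(lp, lp, A=2.0), 4))     # -2.0: gradient is -A, clip inactive
print(round(grad_wrt_logp(lp, lp, A=-0.5), 4))    # 0.5
print(grad_wrt_logp(lp + 0.5, lp, A=2.0))         # 0.0: ρ = e^0.5 > 1.2 is clipped, no gradient
```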
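The `K3 estimator` line scores each sampled token with r - log r - 1, where r = π_ref/π = exp(logπ_ref - logπ). Two properties make it a convenient KL penalty:

- it is never negative, since eˣ ≥ 1 + x;
- its expectation under π equals KL(π‖π_ref).

The plain-Python check below uses a toy three-outcome distribution with made-up numbers. It computes that expectation exactly by summing over the support instead of sampling.

```python
import math

def k3(logp, logp_ref):
    """K3 per-sample KL estimate: r - log r - 1 with log r = logp_ref - logp."""
    d = logp_ref - logp
    return math.exp(d) - d - 1

p   = [0.5, 0.3, 0.2]     # toy policy π (illustrative numbers)
ref = [0.4, 0.4, 0.2]     # toy reference π_ref

kl_exact = sum(pi * math.log(pi / ri) for pi, ri in zip(p, ref))
kl_k3    = sum(pi * k3(math.log(pi), math.log(ri)) for pi, ri in zip(p, ref))

print(abs(kl_exact - kl_k3) < 1e-12)   # True: E_π[k3] = KL(π‖π_ref)
print(min(k3(math.log(pi), math.log(ri)) for pi, ri in zip(p, ref)) >= 0)   # True: every term ≥ 0
```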
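The `Reduce` block's fixed denominator matters because the two reductions weight tokens differently. Both are linear in the per-token losses, so bumping a single token's loss by 1 reveals that token's weight. In the toy below, the classic reduction gives a token in a 2-token response 4× the weight of a token in an 8-token response. The unbiased reduction weights every token the same.

This is a plain-Python sketch with hypothetical helper names. `max_new` and `prompts_per_step` play the roles they have in the pseudocode.

```python
def reduce_unbiased(token_losses, max_new, prompts_per_step=1):
    """Dr.GRPO-style: sum every real (unpadded) token's loss, divide by a constant."""
    G = len(token_losses)
    return sum(sum(seq) for seq in token_losses) / (G * max_new * prompts_per_step)

def reduce_classic(token_losses):
    """Classic GRPO: mean over each response's tokens, then mean over responses."""
    return sum(sum(seq) / len(seq) for seq in token_losses) / len(token_losses)

def token_weight(reduce, token_losses, i, j, **kw):
    """Change in the reduced loss when token j of response i moves by 1 (reductions are linear)."""
    bumped = [list(seq) for seq in token_losses]
    bumped[i][j] += 1.0
    return reduce(bumped, **kw) - reduce(token_losses, **kw)

# A group of G=2: one short response (2 tokens) and one long response (8 tokens); max_new = 8.
losses = [[0.0] * 2, [0.0] * 8]
short_c = token_weight(reduce_classic, losses, 0, 0)               # 1/(2*2) = 0.25
long_c  = token_weight(reduce_classic, losses, 1, 0)               # 1/(2*8) = 0.0625
short_u = token_weight(reduce_unbiased, losses, 0, 0, max_new=8)   # 1/(2*8) = 0.0625
long_u  = token_weight(reduce_unbiased, losses, 1, 0, max_new=8)   # 1/(2*8) = 0.0625
print(short_c, long_c, short_u, long_u)
```

With a negative advantage, this is the length bias Dr.GRPO targets. Under the classic mean, each token of a long wrong answer is penalised less than each token of a short one.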
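The `Per-source split` block relies on two facts:

- the loss is linear in the per-token terms, so the student and teacher gradients sum to the full-batch gradient;
- `.grad` accumulates across `backward()` calls, so the reset between passes is required.

The toy plain-Python simulation below checks both. The `Param` class is a hypothetical stand-in that mimics PyTorch's accumulate-into-`.grad` behaviour. The per-token derivatives are made-up numbers for a scalar δS.

```python
import math

class Param:
    """Stand-in for δS: backward() adds into .grad rather than overwriting it, as PyTorch does."""
    def __init__(self):
        self.grad = 0.0

    def zero_grad(self):
        self.grad = 0.0

d = [0.3, -1.2, 0.5, 2.0]        # toy ∂ℓ_k/∂δS for four (unpadded) tokens
is_student = [1, 1, 0, 0]        # per-token source mask (broadcast from per-response flags)
denom = 4.0

def backward(p, weights):
    """Backward of L = Σ_k w_k·ℓ_k / denom, which is linear in ℓ_k: accumulate Σ w_k·d_k / denom."""
    p.grad += sum(w * dk for w, dk in zip(weights, d)) / denom

p = Param()
backward(p, [1] * 4)
g_full = p.grad

# Correct: reset before each pass. Floats are immutable, so the .clone() a tensor needs is implicit here.
p.zero_grad(); backward(p, is_student);                  g_s = p.grad
p.zero_grad(); backward(p, [1 - s for s in is_student]); g_t = p.grad
print(math.isclose(g_s + g_t, g_full))                  # True

# Buggy: no reset before the teacher pass, so "g_t" actually holds g_s + g_t.
p.zero_grad(); backward(p, is_student);                  g_s_bug = p.grad
backward(p, [1 - s for s in is_student]);                g_t_bug = p.grad
print(math.isclose(g_s_bug + g_t_bug, g_full + g_s))    # True: g_s counted twice
```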
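The review note asks whether `erase` just blocks imitation. For reference, below is a minimal numpy sketch of what an erase-style projection can look like. It rests on two assumptions this file does not confirm:

- V has orthonormal columns spanning the hack subspace;
- erase removes the gradient's component in span(V).

`erase` and `overlap` are hypothetical names. `overlap`, the fraction of squared gradient norm inside span(V), is only an illustration, not the repo's cin/cout definition.

```python
import numpy as np

def erase(g, V):
    """Hypothetical erase: remove the part of gradient g lying in span(V). V must have orthonormal columns."""
    return g - V @ (V.T @ g)

def overlap(g, V):
    """Hypothetical diagnostic: fraction of ||g||² lying in span(V), between 0 and 1."""
    return float(np.sum((V.T @ g) ** 2) / np.sum(g ** 2))

rng = np.random.default_rng(0)
V, _ = np.linalg.qr(rng.standard_normal((10, 2)))   # orthonormal basis of a toy 2-D "hack" subspace
g = rng.standard_normal(10)                          # toy gradient for a flattened δS

g_erased = erase(g, V)
print(np.allclose(V.T @ g_erased, 0.0))              # True: nothing left along V
print(np.allclose(erase(g_erased, V), g_erased))     # True: erasing twice changes nothing
print(0.0 <= overlap(g, V) <= 1.0)                   # True
```

Comparing `overlap(g_s, V)` with `overlap(g_t, V)` for the two per-source gradients is one way to pose the "does V light up more on teacher grads" question. The review caveat still applies: a large teacher overlap could reflect imitation rather than hacking.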