grpo_proj2/docs/pseudocode/05_grpo_loss.py
wassname b0d1bcd3d5 Rebuild src/ from pseudocode: SVD-basis gradient projection vs GRPO reward hacking
Expand docs/pseudocode/01..07 into a slim, fail-fast src/projected_grpo/ that
passes `just smoke`. Code mirrors the pseudocode (δS/Σ/V names, relu-before-agg
cin/cout, Dr.GRPO unbiased loss). Did not read the original src.

7 modules (~880 LOC):
- rewards.py    grader + 4 loophole modes + hack x mode diagonal self-check (R1)
- problems.py   tiny LeetCode substrate + contrastive pairs (R5)
- antipasto.py  SVD adapter, identity at δS=0 (R2)
- proj.py       erase/route/measure_only projection (R3)
- extract_vhack_grad.py  per-module SVD of paired grad diffs, noise floor (R5)
- train.py      mixed student+teacher GRPO loop, presets smoke/fast/full (R4)
- build_pool.py self-contained frozen teacher-pool fixture

`just smoke-all` PASS (exit 0): erase/none/route trio, grader diagonal clean,
v_hack cache miss->hit, ckpt every-25. Fresh-eyes review: 6/6 mechanics faithful.

Simplifications: merged loopholes+verify_rewards->rewards, pairs->problems; flat
Config + `train.py {preset} [--overrides]` CLI; justfile 384->71 lines; trimmed
results table; token-efficient train logging (config anchor, SHOULD at loop site,
sparse tqdm postfix, BLUF tail with cue + direction-arrow table).

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
2026-05-31 14:06:42 +00:00

# ── GRPO / Dr.GRPO loss on a mixed student+teacher pool ────────────────
# Source: train.py inner step (~1220-1334).
# Inner GRPO_step ported from lsdefine/simple_GRPO grpo_vllm_one.py:64-95.
# Unbiased normalization: Dr.GRPO, Liu et al. 2025, arXiv:2503.20783.
def grpo_step(model, prompt, teacher_pool, G_s, G_t, β, clip, unbiased):
    # ── Rollouts: G_s live student + G_t cached teacher (mixed pool) ───
    o_s = model.generate(prompt, n=G_s) # on-policy student rollouts
    o_t = sample(teacher_pool[prompt], G_t) # cached, FROZEN reward labels (reproducible)
    o = o_s + o_t # concatenated into one group; both feed one group-relative adv
    is_student = [1]*G_s + [0]*G_t
    # ── Reward + advantage ─────────────────────────────────────────────
    R = [compute_reward(oi) for oi in o_s] + [cached_reward(oi) for oi in o_t]
    if max(R) - min(R) < ε: return # zero-variance group ⇒ adv≡0 ⇒ pure waste (simple_GRPO bail)
    A = R - mean(R) # Dr.GRPO unbiased: NO /σ_R
    # classic GRPO would use (R-mean)/std(R); unbiased drops that and the 1/|o| length norm.
    # ── PPO-clip policy term (single inner step ⇒ ratio≡1) ─────────────
    logπ_old = per_token_logps(model(o), o).detach() # frozen target
    logπ = per_token_logps(model(o), o) # with grad
    ρ = exp(logπ - logπ_old) # value ≡1 here, but ∇ρ = ∇logπ, so keep the clip form
    Lₚ = -min(ρ·A, clamp(ρ, 1-clip, 1+clip)·A) # per token (clamp, so the `clip` arg is not shadowed)
    if β > 0: # optional KL to a free ref (δS=0 trick: same model, adapter is identity)
        logπ_ref = ref_logprobs(model, o) # see 01_adapter
        Lₚ += β · (exp(logπ_ref-logπ) - (logπ_ref-logπ) - 1) # K3 estimator
    # ── Reduce (unbiased: fixed denom, no per-response length norm) ────
    mask = (o.completion_tokens != pad)
    denom = G · max_new · prompts_per_step # G = G_s + G_t; fixed, independent of response lengths
    L = (Lₚ * mask).sum() / denom # unbiased
    # classic: mean over tokens per response, then mean over responses.
    # ── Per-source split (diagnostic) ──────────────────────────────────
    # backward is linear and is_student + (1-is_student) = 1 elementwise, so
    # grad_s + grad_t == full-batch grad. Two backwards (~2× cost) buy us
    # cin_s vs cin_t: does V light up MORE on teacher (hack) grads than student?
    # NB the zero+clone BETWEEN passes is load-bearing: after L_t.backward(),
    # δS.grad already holds g_s+g_t (grad accumulates), so skipping the reset
    # double-counts g_s. (GPT-5.4 review flagged the earlier pseudocode here.)
    zero_grad(); L_s = (Lₚ*mask*is_student).sum()/denom; L_s.backward(retain_graph=True); g_s = δS.grad.clone()
    zero_grad(); L_t = (Lₚ*mask*(1-is_student)).sum()/denom; L_t.backward(); g_t = δS.grad.clone()
    δS.grad = g_s + g_t # restore the full-batch grad (== grad of L) for the optimizer step
# Why mixing teacher in: a weak/early student rarely hacks on its own, so the
# live GRPO gradient carries little hack signal to project. Cached teacher
# rollouts (graded once, frozen) inject reliable hack examples into the group.
# REVIEW (both families, deepest leak): for the TEACHER half this is NOT on-policy
# GRPO — it is advantage-weighted IMITATION of fixed completions. ratio≡1 holds for
# fresh student samples but is FORCED for teacher samples (logπ_old is recomputed on
# the current student, not the model that generated them). So "the live GRPO gradient
# V must overlap" may really be "push mass onto high-reward teacher completions", and
# erase may just block that imitation — nothing hack-specific. Decide what teacher_pool
# IS; for the clean causal test, isolate or remove it (see 07 priority).
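This is a minimal pure-Python sketch of the reward/advantage step. It covers the zero-variance bail, Dr.GRPO's `A = R - mean(R)`, and the classic `(R - mean)/std` that it replaces. `group_advantages` and its `eps=1e-6` threshold are hypothetical stand-ins for the pseudocode's `ε`. The classic branch assumes the sample standard deviation (n - 1), which is what `torch.Tensor.std` uses by default.

```python
from statistics import mean, stdev


def group_advantages(rewards, unbiased=True, eps=1e-6):
    """Group-relative advantages for one prompt's mixed student+teacher group.

    Returns None when the group has (near-)zero reward spread: every advantage
    would be 0, so the step carries no policy-gradient signal and is skipped.
    unbiased=True  -> Dr.GRPO:      A = R - mean(R)
    unbiased=False -> classic GRPO: A = (R - mean(R)) / std(R)
    """
    if max(rewards) - min(rewards) < eps:
        return None
    mu = mean(rewards)
    centred = [r - mu for r in rewards]
    if unbiased:
        return centred
    sd = stdev(rewards)  # sample std (n - 1), like torch.Tensor.std's default
    return [c / sd for c in centred]


if __name__ == "__main__":
    print(group_advantages([0.0, 0.0, 0.0, 0.0]))            # None: all-fail student group
    print(group_advantages([0.0, 0.0, 0.0, 0.0, 1.0, 1.0]))  # ≈ [-1/3]*4 + [2/3]*2
```

The demo mirrors the "why mixing teacher in" note. An all-fail student group on its own has zero reward spread, so it is skipped. Adding two rewarded teacher rollouts gives every row in the mixed group a nonzero advantage.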
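Here the PPO-clip policy term and the K3 penalty are restated for scalar log-probs in pure Python, so the clip and the estimator can be checked by hand. `token_loss`, `k3`, `expected_k3`, `kl` and the `clip_eps=0.2` default are illustrative assumptions, not the project's code; `clip_eps` plays the role of the pseudocode's `clip` argument. The last two helpers check a known property of K3. With `x ~ π` and `d = log π_ref(x) - log π(x)`, the expectation of `exp(d) - d - 1` equals KL(π ‖ π_ref), computed here exactly on a three-outcome distribution.

```python
import math


def k3(logp, logp_ref):
    """K3 KL estimator for one sampled token: exp(d) - d - 1 with d = logp_ref - logp (always >= 0)."""
    d = logp_ref - logp
    return math.exp(d) - d - 1.0


def token_loss(logp, logp_old, adv, clip_eps=0.2, beta=0.0, logp_ref=None):
    """Per-token PPO-clip loss, plus beta * K3 when beta > 0 (then logp_ref is required)."""
    ratio = math.exp(logp - logp_old)
    clipped = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
    loss = -min(ratio * adv, clipped * adv)  # pessimistic (min) of unclipped vs clipped
    if beta > 0:
        loss += beta * k3(logp, logp_ref)
    return loss


def expected_k3(pi, pi_ref):
    """Exact E_{x~pi}[k3] over a small discrete support."""
    return sum(p * k3(math.log(p), math.log(q)) for p, q in zip(pi, pi_ref))


def kl(pi, pi_ref):
    """Exact KL(pi || pi_ref) over the same support."""
    return sum(p * math.log(p / q) for p, q in zip(pi, pi_ref))


if __name__ == "__main__":
    print(token_loss(-1.0, -1.0, adv=0.5))  # -0.5: ratio is exactly 1, so the loss is -A
    pi, pi_ref = [0.5, 0.3, 0.2], [0.2, 0.5, 0.3]
    print(expected_k3(pi, pi_ref), kl(pi, pi_ref))  # equal up to float rounding
```

At ratio 1 the clip never binds, which is why the single-inner-step loss collapses to `-A` per token. The clip only matters once `logπ_old` comes from an earlier policy.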
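This sketch contrasts the two reductions. It assumes the per-token losses are already laid out as a padded `[rows][max_new]` grid with a 0/1 mask. The rows are every response in the step (G per prompt times `prompts_per_step`), so `rows * max_new` plays the role of the pseudocode's fixed denominator. `reduce_loss` is a hypothetical helper.

```python
def reduce_loss(token_losses, mask, max_new, unbiased=True):
    """Reduce padded per-token losses to one scalar.

    token_losses, mask: [n_rows][max_new] nested lists; mask is 1 on real
    completion tokens and 0 on padding. n_rows = every response in the step.
    unbiased=True  -> masked sum / (n_rows * max_new): a fixed denominator.
    unbiased=False -> classic: masked mean per response, then mean over responses.
    """
    if unbiased:
        total = sum(v * m for row, mrow in zip(token_losses, mask) for v, m in zip(row, mrow))
        return total / (len(token_losses) * max_new)
    per_response = [
        sum(v * m for v, m in zip(row, mrow)) / sum(mrow)
        for row, mrow in zip(token_losses, mask)
    ]
    return sum(per_response) / len(per_response)


if __name__ == "__main__":
    max_new = 8
    losses = [[1.0] * max_new, [1.0] * max_new]  # identical per-token loss everywhere
    mask = [[1] * 2 + [0] * 6, [1] * 8]          # a 2-token and an 8-token response
    print(reduce_loss(losses, mask, max_new))                  # 0.625 (10 real tokens / 16)
    print(reduce_loss(losses, mask, max_new, unbiased=False))  # 1.0
```

In the classic reduction, each token of the 2-token response carries weight 1/(2·2) = 1/4 and each token of the 8-token response carries 1/(2·8) = 1/16. With the fixed denominator, every real token carries 1/16. A response's total weight therefore grows with its length instead of being normalized away.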
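This is a runnable PyTorch sketch of the two-pass per-source split. A toy per-token loss `tanh(x @ δ)` stands in for the real GRPO loss. `per_source_grads` is hypothetical. It resets with `delta_s.grad = None` where the pseudocode calls `zero_grad()`; either way, the buffer no longer holds the previous pass's gradient.

```python
import torch


def per_source_grads(delta_s, per_token_loss, mask, is_student, denom):
    """Split the grad of the masked, fixed-denominator loss by rollout source.

    delta_s:        leaf tensor with requires_grad=True (stand-in for δS).
    per_token_loss: [G, T] tensor computed from delta_s (its graph is reused).
    mask:           [G, T] float, 1.0 on real completion tokens, 0.0 on padding.
    is_student:     [G] float, 1.0 for live student rollouts, 0.0 for teacher.
    denom:          fixed normalizer, e.g. G * max_new * prompts_per_step.
    Returns (g_s, g_t) and leaves delta_s.grad == g_s + g_t.
    """
    w_s = mask * is_student[:, None]          # student tokens only
    w_t = mask * (1.0 - is_student)[:, None]  # teacher tokens only

    delta_s.grad = None                       # reset before pass 1
    ((per_token_loss * w_s).sum() / denom).backward(retain_graph=True)
    g_s = delta_s.grad.clone()

    delta_s.grad = None                       # reset, or g_t would include g_s
    ((per_token_loss * w_t).sum() / denom).backward()
    g_t = delta_s.grad.clone()

    delta_s.grad = g_s + g_t                  # full-batch grad for the optimizer step
    return g_s, g_t


if __name__ == "__main__":
    torch.manual_seed(0)
    G, T, D = 4, 5, 3
    delta = torch.randn(D, requires_grad=True)
    x = torch.randn(G, T, D)
    mask = (torch.rand(G, T) > 0.3).float()
    is_student = torch.tensor([1.0, 1.0, 0.0, 0.0])
    g_s, g_t = per_source_grads(delta, torch.tanh(x @ delta), mask, is_student, G * T)
    print(g_s, g_t)
```

The first pass needs `retain_graph=True` because both passes backpropagate through the same `per_token_loss` graph, and the second pass is allowed to free it.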
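A short derivation of why the REVIEW note holds follows. It uses notation not found in the pseudocode:
- $q$ is the prompt.
- $\epsilon > 0$ is the clip range (the `clip` argument).
- $m_{i,t}$ is the token mask.
- $\operatorname{sg}$ is the stop-gradient that `.detach()` implements.
- $\beta = 0$.

```latex
\rho_{i,t} = \exp\!\big(\log\pi_\theta(o_{i,t}\mid q, o_{i,<t}) - \operatorname{sg}[\log\pi_\theta(o_{i,t}\mid q, o_{i,<t})]\big) = 1,
\qquad
\nabla_\theta \rho_{i,t} = \rho_{i,t}\,\nabla_\theta \log\pi_\theta(o_{i,t}\mid q, o_{i,<t}) = \nabla_\theta \log\pi_\theta(o_{i,t}\mid q, o_{i,<t})

% 1 lies strictly inside [1-\epsilon, 1+\epsilon], so both arms of the min equal \rho_{i,t} A_i locally:
\nabla_\theta\Big[-\min\big(\rho_{i,t}A_i,\ \operatorname{clip}(\rho_{i,t}, 1-\epsilon, 1+\epsilon)\,A_i\big)\Big]
  = -A_i\,\nabla_\theta \log\pi_\theta(o_{i,t}\mid q, o_{i,<t})

\nabla_\theta L_t = -\frac{1}{\mathrm{denom}} \sum_{i \in \text{teacher}} A_i \sum_t m_{i,t}\,\nabla_\theta \log\pi_\theta(o_{i,t}\mid q, o_{i,<t})
```

The last line is the gradient of an $A_i$-weighted negative log-likelihood of the fixed teacher completions. Rows with $A_i > 0$ are imitated, and rows with $A_i < 0$ are pushed down. Student rows have the same algebraic form; the only difference is that student completions are drawn from the current $\pi_\theta$.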