grpo_proj2/docs/pseudocode/01_adapter.py
wassname b0d1bcd3d5 Rebuild src/ from pseudocode: SVD-basis gradient projection vs GRPO reward hacking
Expand docs/pseudocode/01..07 into a slim, fail-fast src/projected_grpo/ that
passes `just smoke`. Code mirrors the pseudocode (δS/Σ/V names, relu-before-agg
cin/cout, Dr.GRPO unbiased loss). Did not read the original src.

7 modules (~880 LOC):
- rewards.py    grader + 4 loophole modes + hack x mode diagonal self-check (R1)
- problems.py   tiny LeetCode substrate + contrastive pairs (R5)
- antipasto.py  SVD adapter, identity at δS=0 (R2)
- proj.py       erase/route/measure_only projection (R3)
- extract_vhack_grad.py  per-module SVD of paired grad diffs, noise floor (R5)
- train.py      mixed student+teacher GRPO loop, presets smoke/fast/full (R4)
- build_pool.py self-contained frozen teacher-pool fixture

`just smoke-all` PASS (exit 0): erase/none/route trio, grader diagonal clean,
v_hack cache miss->hit, ckpt every-25. Fresh-eyes review: 6/6 mechanics faithful.

Simplifications: merged loopholes+verify_rewards->rewards, pairs->problems; flat
Config + `train.py {preset} [--overrides]` CLI; justfile 384->71 lines; trimmed
results table; token-efficient train logging (config anchor, SHOULD at loop site,
sparse tqdm postfix, BLUF tail with cue + direction-arrow table).

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
2026-05-31 14:06:42 +00:00

# ── AntiPaSTO adapter ──────────────────────────────────────────────────
# Train ONE per-module knob δS in the singular-value basis of each Linear.
# Source: antipasto.py. Prior work: github.com/wassname/AntiPaSTO
#
# Why SVD basis: the extracted hack direction, the live gradient, and the
# projection all need to live in the same low-rank, weight-aligned coordinates.
# At δS=0 the adapter is bit-identical to the base model (no SVD round-trip on
# the main path -- W is never reconstructed), which gives us a free reference
# model: a no_grad forward with δS zeroed yields π_ref logprobs, no 2nd model.
TARGET = {"q_proj","k_proj","v_proj","o_proj",                          # attention
          "in_proj_qkv","in_proj_z","in_proj_a","in_proj_b","out_proj", # GatedDeltaNet
          "up_proj","gate_proj","down_proj"}                            # MLP
def wrap(model):
    wrappers = {}
    for name, lin in model.named_linears():       # lin.weight = W ∈ ℝ^{d_out×d_in}
        if name.split(".")[-1] not in TARGET: continue
        U, Σ, Vh = svd_cached(lin.weight)         # cache key = sha256(W) → fail-loud stale
        r = min(lin.weight.shape)                 # = min(d_out, d_in)
        lin.U, lin.Vh = freeze(U), freeze(Vh)     # buffers, native dtype; double as v_hack basis
        lin.δS      = Param(zeros(r))             # main trainable knob ∈ ℝ^r
        lin.δS_hack = Param(zeros(r))             # routing quarantine ∈ ℝ^r (Cloud 2024)
        lin.register_forward_hook(δ_hook)
        wrappers[name] = dict(layer=lin, δS=lin.δS, δS_hack=lin.δS_hack, r=r)
    freeze(every param except δS, δS_hack)        # only 2 knobs per module train
    return wrappers                               # dict[name → {layer, δS, δS_hack, r}]
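# Sketch (not taken from antipasto.py): one way svd_cached could key its cache
# on sha256(W). Assumptions: NumPy stands in for torch, the cache is a plain
# in-memory dict, and weight_key is a hypothetical helper. A changed W simply
# misses the cache here; the fail-loud handling of stale entries mentioned
# above is not reproduced.

```python
import hashlib
import numpy as np

_SVD_CACHE = {}  # hypothetical in-memory cache: sha256 hex digest -> (U, S, Vh)

def weight_key(W):
    """Hash dtype, shape and raw bytes so that any change to W changes the key."""
    W = np.ascontiguousarray(W)
    h = hashlib.sha256()
    h.update(str(W.dtype).encode())
    h.update(str(W.shape).encode())
    h.update(W.tobytes())
    return h.hexdigest()

def svd_cached(W):
    """Return the thin SVD of W, computing it only on a cache miss."""
    key = weight_key(W)
    if key not in _SVD_CACHE:
        _SVD_CACHE[key] = np.linalg.svd(W, full_matrices=False)
    return _SVD_CACHE[key]

rng = np.random.default_rng(0)
W = rng.standard_normal((6, 4))        # d_out=6, d_in=4
U, S, Vh = svd_cached(W)               # first call: miss, computes the SVD
r = min(W.shape)                       # number of singular axes, hence len(δS)
hit = svd_cached(W)[0] is U            # second call: hit, same object returned
```

# Hashing dtype and shape alongside the bytes keeps two arrays with equal raw
# bytes but different layouts from sharing a cache entry.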
# ── Forward hook: y_new = y + U diag(δS+δS_hack) Vh · x ────────────────
def δ_hook(lin, x, y):
    h = x @ lin.Vh.T                    # h ∈ ℝ^{...×r}   into singular basis (x is batch-first)
    h = h * (lin.δS + lin.δS_hack)      # scale per singular axis
    return y + h @ lin.U.T              # back to ℝ^{...×d_out}
# δS+δS_hack=0 at init ⇒ δ=0 ⇒ identical to base. The forward uses the SUM,
# so a hack-ward update parked in δS_hack still moves the *training* model;
# zeroing δS_hack at eval ablates that routed capability (see 03_project).
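# Sketch (assumption-laden, NumPy instead of torch): the hook as a residual on
# top of an untouched base Linear, written for batch-first inputs. It checks
# the three properties stated above: identity at δS=δS_hack=0, equivalence to
# W' = W + U diag(δS+δS_hack) Vh, and ablation of the routed part by zeroing
# δS_hack. delta_forward and all sizes are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
d_out, d_in, batch = 5, 3, 4
W = rng.standard_normal((d_out, d_in))
U, S, Vh = np.linalg.svd(W, full_matrices=False)   # U: (5, 3), Vh: (3, 3)
r = min(d_in, d_out)

def delta_forward(x, dS, dS_hack):
    y = x @ W.T                    # base Linear output; W is never reconstructed
    h = x @ Vh.T                   # into the singular basis, shape (batch, r)
    h = h * (dS + dS_hack)         # scale per singular axis
    return y + h @ U.T             # back to shape (batch, d_out)

x = rng.standard_normal((batch, d_in))
zeros = np.zeros(r)

# 1. Identity at init: the residual is exactly zero, so the output equals the base.
identity_ok = np.array_equal(delta_forward(x, zeros, zeros), x @ W.T)

# 2. With non-zero knobs the hook matches the merged weight W' up to rounding.
dS = 0.1 * rng.standard_normal(r)
dS_hack = 0.1 * rng.standard_normal(r)
W_merged = W + U @ np.diag(dS + dS_hack) @ Vh
merged_ok = np.allclose(delta_forward(x, dS, dS_hack), x @ W_merged.T)

# 3. Training uses the sum; eval with δS_hack zeroed keeps only the δS part.
y_train = delta_forward(x, dS, dS_hack)
y_eval = delta_forward(x, dS, zeros)
W_eval = W + U @ np.diag(dS) @ Vh
ablate_ok = np.allclose(y_eval, x @ W_eval.T) and not np.allclose(y_train, y_eval)
```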
# ── Free reference model (KL term) ─────────────────────────────────────
def ref_logprobs(model, ids):
    with zeroed(δS, δS_hack), no_grad():      # W' = W + U·0·Vh = W exactly
        return per_token_logps(model(ids).logits, ids)
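# Sketch of the two helpers used above, under stated assumptions: NumPy arrays
# stand in for torch parameters, zeroed restores the knobs on exit, and
# per_token_logps scores token t+1 from the logits at position t (the source
# does not say whether it shifts). toy_logits is a hypothetical one-layer
# "model" with an embedding table E and an adapted output head W.

```python
from contextlib import contextmanager
import numpy as np

@contextmanager
def zeroed(*knobs):
    """Temporarily set each knob array to zero; restore it even on error."""
    saved = [k.copy() for k in knobs]
    try:
        for k in knobs:
            k[...] = 0.0
        yield
    finally:
        for k, s in zip(knobs, saved):
            k[...] = s

def per_token_logps(logits, ids):
    """Log-prob of ids[t+1] under logits[t]; logits is (T, V), ids is (T,)."""
    z = logits[:-1]
    z = z - z.max(axis=-1, keepdims=True)                    # numerical stability
    logp = z - np.log(np.exp(z).sum(axis=-1, keepdims=True)) # log-softmax
    return logp[np.arange(len(ids) - 1), ids[1:]]

rng = np.random.default_rng(0)
V, d = 7, 4
E = rng.standard_normal((V, d))                  # toy embedding table
W = rng.standard_normal((V, d))                  # toy output head
U, S, Vh = np.linalg.svd(W, full_matrices=False)
dS = 0.5 * rng.standard_normal(d)                # pretend these were trained
dS_hack = 0.5 * rng.standard_normal(d)

def toy_logits(ids):
    x = E[ids]
    return x @ W.T + ((x @ Vh.T) * (dS + dS_hack)) @ U.T

def ref_logprobs(ids):
    with zeroed(dS, dS_hack):                    # adapter off: base model only
        return per_token_logps(toy_logits(ids), ids)

ids = np.array([1, 3, 0, 6, 2])
dS_before = dS.copy()
policy = per_token_logps(toy_logits(ids), ids)   # adapter on
ref = ref_logprobs(ids)                          # adapter off, same weights
base = per_token_logps(E[ids] @ W.T, ids)        # what a separate base model would give
```

# The try/finally matters: if the forward pass raises, the trained knobs are
# still restored rather than left at zero.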