diff --git a/docs/spec/20260601_parametrized_lora_adapter.md b/docs/spec/20260601_parametrized_lora_adapter.md new file mode 100644 index 0000000..18a8448 --- /dev/null +++ b/docs/spec/20260601_parametrized_lora_adapter.md @@ -0,0 +1,85 @@ +# Parametrized-LoRA adapter (replace SVD-diag), clean entrypoints + +## Goal +Swap the AntiPaSTO SVD-diag knob (δW = U·diag(δS)·Vh, δS∈ℝ^r) for a parametrized +LoRA (δW = B·A, A a frozen random projection, B trained). Keep δW LINEAR in the +trained knob so the gradient projection stays a fixed-direction removal, but drop +the whole SVD subsystem (svd_cached, SVD_CACHE, U/Vh, and the bf16 hash/contiguity +friction). Same change folds in the entrypoint cleanup (clean argv, one sweep +recipe firing the arms as pueue jobs). + +## Why +- Linearity is the ONLY property the projection needs: a once-extracted V stays a + fixed weight-space direction (∂δW/∂B is the constant A). Fixed-A LoRA gives that + with no SVD. Full vanilla LoRA (train A and B) is bilinear -> V drifts -> needs + per-step re-extraction; avoid it. +- π_ref free at B=0 (δW=0, base model untouched) and extract-at-init hold (∂L/∂B = + δ_out·(A x)ᵀ ≠ 0 at B=0 since A≠0), same as the SVD knob. +- Capacity lesson (old-repo route2): an oversized quarantine becomes a low-resistance + sink (qE~0.97). Fixed-A LoRA with B_hack the SAME shape + SAME A is capacity-matched + by construction. +- More standard / credible than "diagonal scaling in the SVD basis", and showing the + projection works in an arbitrary fixed basis is a stronger result than SVD-only. + +## Scope +In: `antipasto.wrap` (adapter), `extract_vhack_grad` (V in B-space), `proj` (operate +on B.grad), route quarantine (B_hack), Config defaults + justfile entrypoints. +Out: env / 4 hacks, GRPO loss, teacher pool, grader, and `project_one` MATH (it is +already dimension-agnostic; only the per-module vector dim changes). + +## Design (math) +Per target Linear W ∈ ℝ^{d_out×d_in}, rank r = `lora_rank`: +- A ∈ ℝ^{r×d_in}: random semi-orthonormal rows, frozen buffer (requires_grad=False). +- B ∈ ℝ^{d_out×r}: trained param, zero-init. +- B_hack ∈ ℝ^{d_out×r}: zero-init param, route only (same shape => capacity-matched). +- forward hook: `y + (B + B_hack) @ (A @ x)` # δW=(B+B_hack)·A, linear in B; =0 at B=0 +- gradient knob for projection: g = B.grad.reshape(-1) ∈ ℝ^{d_out·r}. + +Projection (proj.py, unchanged `project_one`): +- extract: D = stack_pairs(∇B_hack - ∇B_clean).reshape(n_pairs, d_out·r); V = top-k + right singular vecs of D (k × d_out·r); orient by per-pair sign vote. +- step: g = B.grad.reshape(-1); g_proj, removed = project_one(g, V, ...); + B.grad = g_proj.reshape(d_out, r); route: B_hack.grad = removed.reshape(d_out, r). + +## Tasks +- [ ] T1 (adapter): rewrite `wrap` — register frozen A buffer + B, B_hack params + (zero-init, `device=lin.weight.device`, `dtype=lin.weight.dtype`); forward hook + `y + (B+B_hack)@(A@x)`. Delete `svd_cached`, `SVD_CACHE`, U/Vh/δS, `zeroed` uses δS. + - verify: `just smoke` green; log shows `wrapped N Linears (lora r=..)`. + - success: cout≈0 (one_sided identity), cin_t>0. + - likely_fail: B not zero-init -> π_ref breaks (T4 check fails). + - sneaky_fail: A left trainable -> V drifts, cout creeps >0 over steps; assert + `A.requires_grad is False` at wrap time. +- [ ] T2 (extract): `extract_v_hack` reads `B.grad` (d_out×r), flattens per module, + stacks pairs, SVD -> V (k × d_out·r). completion_nll/device already fixed. + - verify: smoke cache-MISS extracts; `extracted V from P pairs over N modules`. +- [ ] T3 (proj): `project_all` flattens B.grad -> `project_one` -> reshape back; route + parks `removed` reshaped into B_hack.grad. `project_one` untouched. + - verify: erase smoke cout≈0; route smoke |B_hack|>0 and grows. +- [ ] T4 (π_ref): at B=B_hack=0, wrapped forward == base forward, bit-identical on a + fixed input (no-grad). verify: add a smoke assert (max|Δlogits|==0 with knobs zeroed). + - sneaky_fail: hook adds A·x even at B=0 -> Δ≠0; the (B+B_hack)@ guards this (0@=0). +- [ ] T5 (clean argv): move `teacher_pool_dir`, `v_hack_path`, `mix_ratio`, + `eval_ablate_every`/`vhack_refresh_every` (route) into the fast/full PRESETS so the + CLI line is just `train.py fast --intervention=X --seed=S`. Drop SVD-only config. +- [ ] T6 (entrypoints): collapse justfile to `smoke`/`fast`/`full *ARGS` (arm via + `--intervention`, default erase) + `sweep PRESET='fast'` that pueue-adds none/erase/ + route with why:/resolve: labels (mirror old-repo `queue-substrate`). Delete + smoke-vanilla/smoke-route/fast-vanilla/full-vanilla. + - verify: `just sweep` queues 3 labeled jobs; `pueue status` shows them. + +## Context +- r (lora_rank): start 32. SVD knob was full r∈[1024,2560]; LoRA is low-rank, fewer + params, more expressive per-axis than diagonal. +- A init: semi-orthonormal (orthogonal_ on r×d_in, r orthonormal rows) so A·x + is a clean norm-preserving projection into ℝ^r. +- New code stays pseudocode-like: unicode/math names where they read as math, einops + for reshape (`rearrange(g, 'o r -> (o r)')`), explicit over clever. + +## Log +- bf16/cuda path fixed + committed (fbacefd) before this refactor; GPU smoke is the gate. +- route is capacity-matched by construction here (B_hack==B shape), unlike old-repo route2. + +## TODO +- If low-rank B underfits (can't hack at all), bump r or revisit A init. Decide from + the vanilla-arm hack_s trajectory, not a priori. diff --git a/justfile b/justfile index d14a7bf..e8d8ca5 100644 --- a/justfile +++ b/justfile @@ -22,33 +22,20 @@ build-pool: # Smoke = the ONLY gate: same harness as production (train.py), tiny-random on GPU # in bf16 (same device+dtype as fast/full, so the cuda+bf16 path is actually covered). -# beartype on so jaxtyping signatures get runtime-checked. 30 steps fires the -# every-25-step save_ckpt path. erase writes g_proj; cache-miss extracts v_hack. +# beartype on so jaxtyping signatures get runtime-checked. Arm/v_hack/pool/mix come +# from the smoke preset, so the only override you usually pass is --intervention. smoke *ARGS: build-pool check - BEARTYPE=1 {{ TRAIN }} smoke --intervention=erase \ - --v-hack-path=out/vhack/v_hack_smoke.safetensors \ - --teacher-pool-dir=out/pools/teacher_pool --mix-ratio=0.5 {{ ARGS }} + BEARTYPE=1 {{ TRAIN }} smoke {{ ARGS }} -# Vanilla arm: V loaded for the measure_only diagnostic (cin), grad left untouched. -smoke-vanilla *ARGS: build-pool - BEARTYPE=1 {{ TRAIN }} smoke --intervention=none \ - --v-hack-path=out/vhack/v_hack_smoke.safetensors \ - --teacher-pool-dir=out/pools/teacher_pool --mix-ratio=0.5 {{ ARGS }} +# Every code path the real run walks: erase (gate) + vanilla (measure_only cin) + +# route (two-knob optim, ablated deploy-eval, online v_hack refresh + basis_overlap). +smoke-all: build-pool check + BEARTYPE=1 {{ TRAIN }} smoke --intervention=erase + BEARTYPE=1 {{ TRAIN }} smoke --intervention=none + BEARTYPE=1 {{ TRAIN }} smoke --intervention=route --eval-ablate-every=10 --eval-n-prompts=2 --vhack-refresh-every=10 + just results -# Routing arm: parks the hack-ward grad in δS_hack, ablates at eval. Fires the -# two-param optimizer path, periodic ablated-eval, online v_hack refresh + the -# basis_overlap guard, and the final kept-vs-ablated BLUF. -smoke-route *ARGS: build-pool - BEARTYPE=1 {{ TRAIN }} smoke --intervention=route \ - --v-hack-path=out/vhack/v_hack_smoke.safetensors \ - --teacher-pool-dir=out/pools/teacher_pool --mix-ratio=0.5 \ - --eval-ablate-every=10 --eval-n-prompts=2 --vhack-refresh-every=10 {{ ARGS }} - -# The trio = every code path the full run walks. Run before any real run. -smoke-all: smoke smoke-vanilla smoke-route results - -# Run smoke twice: first warms the v_hack cache (miss), second hits it (cache-hit -# branch). Catches save/scope bugs that only manifest in one. +# Run smoke twice: cache MISS (extract v_hack) then HIT. Catches save/scope bugs. smoke-both: rm -f out/vhack/v_hack_smoke.safetensors just smoke @@ -58,15 +45,17 @@ smoke-both: results: uv run python scripts/results.py -# Real runs (Qwen3-4B, GPU). v_hack auto-extracts on cache-miss inside train. -full-vanilla *ARGS: - {{ TRAIN }} full --intervention=none {{ ARGS }} +# Real runs (Qwen3-4B, GPU). intervention defaults to erase; pass --intervention=none +# for vanilla. v_hack auto-extracts on cache-miss inside train. +fast *ARGS: + {{ TRAIN }} fast {{ ARGS }} full *ARGS: - {{ TRAIN }} full --intervention=erase {{ ARGS }} + {{ TRAIN }} full {{ ARGS }} -fast-vanilla *ARGS: - {{ TRAIN }} fast --intervention=none --teacher-pool-dir=out/pools/teacher_pool {{ ARGS }} - -fast *ARGS: - {{ TRAIN }} fast --intervention=erase --teacher-pool-dir=out/pools/teacher_pool {{ ARGS }} +# Fire the matched arm comparison as low-priority pueue jobs (vanilla | erase | route) +# for one PRESET/SEED/STEPS. One entrypoint queues the cell; pueue persists the logs. +sweep PRESET='fast' SEED='41' STEPS='60': + pueue add -w "$PWD" -o=-8 -l "why: {{ PRESET }} vanilla s{{ SEED }} x{{ STEPS }}; resolve: hack_s onset baseline (extracts v_hack)" -- {{ TRAIN }} {{ PRESET }} --intervention=none --seed={{ SEED }} --steps={{ STEPS }} --out-tag=_sweep + pueue add -w "$PWD" -o=-8 -l "why: {{ PRESET }} erase s{{ SEED }} x{{ STEPS }}; resolve: hack_s suppressed vs vanilla at matched solve" -- {{ TRAIN }} {{ PRESET }} --intervention=erase --seed={{ SEED }} --steps={{ STEPS }} --out-tag=_sweep + pueue add -w "$PWD" -o=-8 -l "why: {{ PRESET }} route s{{ SEED }} x{{ STEPS }}; resolve: deploy-hack < vanilla on held-out, |Bh|>0" -- {{ TRAIN }} {{ PRESET }} --intervention=route --seed={{ SEED }} --steps={{ STEPS }} --out-tag=_sweep --eval-ablate-every=10 --vhack-refresh-every=10 --eval-n-prompts=2 diff --git a/scripts/audit_log.py b/scripts/audit_log.py new file mode 100644 index 0000000..276449b --- /dev/null +++ b/scripts/audit_log.py @@ -0,0 +1,74 @@ +"""Audit a training run's per-step jsonl: quote first/last student generation +(coherence eyeball) + summarise the key columns with first-vs-last5 trend arrows. + +Deterministic extraction; the /audit-log command feeds this to the LLM for a verdict. +The clean repo writes ONE jsonl per run (logs/run__s.log) -- each line +is a metric row, and rows that had a live student rollout carry a `gen` field +{text, hack, gt, reward}. No separate rollouts.jsonl. + +Usage: + uv run python scripts/audit_log.py logs/run_erase_s41.log # a specific run log + uv run python scripts/audit_log.py # newest logs/run_*.log +""" +from __future__ import annotations + +import json +import sys +from pathlib import Path + + +def _resolve(arg: str | None) -> Path: + if arg: + return Path(arg) + logs = sorted(Path("logs").glob("run_*.log")) + if not logs: + sys.exit("no logs/run_*.log found") + return logs[-1] + + +def _gens(rows: list[dict]) -> None: + gens = [r for r in rows if r.get("gen")] + if not gens: + print("\n(no student generations logged -- tiny-random smoke bails before grading?)") + return + print(f"\ngenerations: {len(gens)} steps carry a student gen, " + f"steps {gens[0]['step']}..{gens[-1]['step']}") + for r in (gens[0], gens[-1]): + g = r["gen"] + print(f"\n--- step {r['step']} reward={g['reward']:+.2f} gt={g['gt']} hack={g['hack']} ---") + print("SHOULD: coherent code/prose; ELSE token salad => diverged") + print(repr(g["text"][:400])) + + +def _cols(rows: list[dict]) -> None: + def trend(name, lo=None): + v = [r[name] for r in rows if name in r and r[name] is not None] + if not v: + return f"{name:7s} (absent)" + first, last5 = v[0], sum(v[-5:]) / len(v[-5:]) + arrow = "UP" if last5 > first + 1e-6 else ("DOWN" if last5 < first - 1e-6 else "flat") + warn = f" <-- dipped below {lo} (min {min(v):.2f})" if lo is not None and min(v) < lo else "" + return f"{name:7s} first={first:+.3f} last5={last5:+.3f} [{arrow}]{warn}" + + print("\n=== key columns (first vs last-5-mean) ===") + print("SHOULD interpret: hack_s UP=emerging (vanilla expects this; route's real test is") + print(" the deploy eval); gt_s UP=learning to solve; cin_t>cin_s early then decays as V") + print(" goes stale; cout~0 is the one_sided identity NOT efficacy; bh>0 iff route.") + for nm, lo in [("hack_s", None), ("gt_s", None), ("cin_s", None), ("cin_t", None), + ("cout", None), ("bh", None), ("gn", None), ("loss", None)]: + if any(nm in r for r in rows): + print(" " + trend(nm, lo)) + + +def main() -> None: + path = _resolve(sys.argv[1] if len(sys.argv) > 1 else None) + rows = [json.loads(l) for l in path.open() if l.strip()] + print(f"=== AUDIT {path} ({len(rows)} steps) ===") + if not rows: + sys.exit("log is empty") + _gens(rows) + _cols(rows) + + +if __name__ == "__main__": + main() diff --git a/scripts/results.py b/scripts/results.py index b9e2f6f..ab38250 100644 --- a/scripts/results.py +++ b/scripts/results.py @@ -1,4 +1,4 @@ -"""Aggregate logs/*.log (JSONL rows from train.py) into a last-5-step table. +"""Aggregate logs/run_*.log (JSONL rows from train.py) into a last-5-step table. hack = student exploited rate (hack_s); solve = student gt_correct rate (gt_s). Each row is one run; the arm/seed come from the filename run_{arm}_s{seed}.log. @@ -18,7 +18,7 @@ def last5(rows, key): return sum(vals) / len(vals) if vals else float("nan") -def main(pattern: str = "logs/*.log") -> None: +def main(pattern: str = "logs/run_*.log") -> None: table = [] for path in sorted(glob.glob(pattern)): rows = [json.loads(l) for l in open(path) if l.strip()] @@ -35,4 +35,4 @@ def main(pattern: str = "logs/*.log") -> None: if __name__ == "__main__": - main(sys.argv[1] if len(sys.argv) > 1 else "logs/*.log") + main(sys.argv[1] if len(sys.argv) > 1 else "logs/run_*.log") diff --git a/src/projected_grpo/antipasto.py b/src/projected_grpo/antipasto.py index f25eed2..0b1c528 100644 --- a/src/projected_grpo/antipasto.py +++ b/src/projected_grpo/antipasto.py @@ -1,87 +1,76 @@ -"""SVD-basis adapter (pseudocode 01). Prior work: github.com/wassname/AntiPaSTO. +"""Parametrized-LoRA adapter (pseudocode 01). -Train ONE per-module knob δS in the singular-value basis of each Linear. The -extracted hack direction V, the live gradient, and the projection all live in -these same low-rank weight-aligned coordinates (ℝ^r). At δS=δS_hack=0 the adapter -is bit-identical to the base model (W is never reconstructed on the main path), -so a no_grad forward with the knobs zeroed gives π_ref for free -- no 2nd model. +Per target Linear, freeze a random projection A ∈ ℝ^{r×d_in} and train B ∈ ℝ^{d_out×r} +(plus a same-shape quarantine B_hack for the route arm). The added weight is + + δW = (B + B_hack) · A # applied as y += (A x) @ (B+B_hack)ᵀ + +which is LINEAR in the trained knob B: ∂δW/∂B is the constant A. That is the one +property the gradient projection needs -- a hack direction V extracted once stays a +fixed weight-space direction, so we can project it out every step without re-deriving +it (cf. proj.py). At B=B_hack=0 the adapter is identity (W is never touched), so a +no_grad forward with the knobs zeroed is π_ref for free -- no 2nd model. B_hack has the +SAME shape as B, so the quarantine is capacity-matched (it cannot become a sink). + +A is frozen and random (not the SVD basis): showing the projection works in an +arbitrary fixed basis is stronger than an SVD-aligned one, and it drops the whole +per-module SVD subsystem. """ from __future__ import annotations -import hashlib from contextlib import contextmanager from dataclasses import dataclass -from pathlib import Path import torch import torch.nn as nn from loguru import logger -from safetensors.torch import load_file, save_file TARGET = {"q_proj", "k_proj", "v_proj", "o_proj", "in_proj_qkv", "in_proj_z", "in_proj_a", "in_proj_b", "out_proj", # GatedDeltaNet "up_proj", "gate_proj", "down_proj"} -SVD_CACHE = Path("svd_cache") @dataclass class Wrap: - layer: nn.Linear # carries .U, .Vh buffers and .δS, .δS_hack params + layer: nn.Linear # carries frozen .A buffer and trained .B, .B_hack params r: int @property - def δS(self) -> nn.Parameter: - return self.layer.δS + def B(self) -> nn.Parameter: + return self.layer.B @property - def δS_hack(self) -> nn.Parameter: - return self.layer.δS_hack + def B_hack(self) -> nn.Parameter: + return self.layer.B_hack -def svd_cached(W: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """U Σ Vh = W, reduced. Cache key = sha256(W) so a stale cache is impossible - (different W -> different file). Native dtype in, fp32 SVD, cast back.""" - SVD_CACHE.mkdir(exist_ok=True) - # .view(uint8) so the byte-hash is dtype-agnostic: numpy has no bf16, and every - # preset (smoke included, now) loads weights in bf16. - key = hashlib.sha256(W.detach().cpu().contiguous().view(torch.uint8).numpy().tobytes()).hexdigest() - path = SVD_CACHE / f"{key}.safetensors" - if path.exists(): - t = load_file(path) # .to(W): match W's device+dtype - return t["U"].to(W), t["S"].to(W), t["Vh"].to(W) # (cache hit loads on cpu -> move to cuda) - U, S, Vh = torch.linalg.svd(W.float(), full_matrices=False) - # safetensors needs contiguous cpu tensors; svd outputs are neither on cuda. - save_file({"U": U.contiguous().cpu(), "S": S.contiguous().cpu(), "Vh": Vh.contiguous().cpu()}, path) - return U.to(W), S.to(W), Vh.to(W) +def _lora_hook(lin: nn.Linear, x_in, y): + x = x_in[0] # (..., d_in) + h = x @ lin.A.T # (..., r) into the fixed random subspace + return y + h @ (lin.B + lin.B_hack).T # (..., d_out) B=B_hack=0 -> identity -def _δ_hook(lin: nn.Linear, x_in, y): - x = x_in[0] # (..., d_in) - h = x @ lin.Vh.T # (..., r) into singular basis - h = h * (lin.δS + lin.δS_hack) # scale per singular axis (forward uses the SUM) - return y + h @ lin.U.T # (..., d_out) δS+δS_hack=0 -> identity - - -def wrap(model: nn.Module) -> dict[str, Wrap]: +def wrap(model: nn.Module, r: int = 32) -> dict[str, Wrap]: wrappers: dict[str, Wrap] = {} for name, lin in model.named_modules(): if not isinstance(lin, nn.Linear) or name.split(".")[-1] not in TARGET: continue - U, _Σ, Vh = svd_cached(lin.weight) # W ∈ ℝ^{d_out×d_in} - r = min(lin.in_features, lin.out_features) - lin.register_buffer("U", U) - lin.register_buffer("Vh", Vh) - z = lambda: nn.Parameter(torch.zeros(r, dtype=lin.weight.dtype, device=lin.weight.device)) - lin.register_parameter("δS", z()) # on the layer's device, else the hook mixes cpu/cuda - lin.register_parameter("δS_hack", z()) - lin.register_forward_hook(_δ_hook) - wrappers[name] = Wrap(lin, r) + d_out, d_in = lin.out_features, lin.in_features + rk = min(r, d_in) # need rk ≤ d_in for orthonormal rows of A + A = torch.empty(rk, d_in, dtype=torch.float32, device=lin.weight.device) + nn.init.orthogonal_(A) # A Aᵀ = I_rk; QR (geqrf) needs fp32, cast after + lin.register_buffer("A", A.to(lin.weight.dtype)) # frozen bf16 buffer for the forward matmul + zeros = lambda: nn.Parameter(torch.zeros(d_out, rk, dtype=lin.weight.dtype, device=lin.weight.device)) + lin.register_parameter("B", zeros()) # zero-init -> δW=0 at start (π_ref) and ∇B≠0 (A≠0) + lin.register_parameter("B_hack", zeros()) + lin.register_forward_hook(_lora_hook) + wrappers[name] = Wrap(lin, rk) for p in model.parameters(): - p.requires_grad_(False) + p.requires_grad_(False) # freeze base; A is a buffer so already frozen for w in wrappers.values(): - w.δS.requires_grad_(True) - w.δS_hack.requires_grad_(True) - logger.info(f"wrapped {len(wrappers)} Linears with δS knobs (r in " + w.B.requires_grad_(True) + w.B_hack.requires_grad_(True) + logger.info(f"wrapped {len(wrappers)} Linears with LoRA knobs (r in " f"[{min(w.r for w in wrappers.values())}, {max(w.r for w in wrappers.values())}])") return wrappers @@ -89,29 +78,29 @@ def wrap(model: nn.Module) -> dict[str, Wrap]: @contextmanager def zeroed(wrappers: dict[str, Wrap]): """Both knobs -> 0 (free reference model). Restores on exit.""" - saved = [(w.δS.data.clone(), w.δS_hack.data.clone()) for w in wrappers.values()] + saved = [(w.B.data.clone(), w.B_hack.data.clone()) for w in wrappers.values()] for w in wrappers.values(): - w.δS.data.zero_() - w.δS_hack.data.zero_() + w.B.data.zero_() + w.B_hack.data.zero_() try: yield finally: - for w, (s, sh) in zip(wrappers.values(), saved): - w.δS.data.copy_(s) - w.δS_hack.data.copy_(sh) + for w, (b, bh) in zip(wrappers.values(), saved): + w.B.data.copy_(b) + w.B_hack.data.copy_(bh) @contextmanager def ablate_quarantine(wrappers: dict[str, Wrap]): - """δS_hack -> 0 only (deploy-time ablation / extraction). No-op for erase (δS_hack stays 0).""" - saved = [w.δS_hack.data.clone() for w in wrappers.values()] + """B_hack -> 0 only (deploy-time ablation / extraction). No-op for erase (B_hack stays 0).""" + saved = [w.B_hack.data.clone() for w in wrappers.values()] for w in wrappers.values(): - w.δS_hack.data.zero_() + w.B_hack.data.zero_() try: yield finally: - for w, sh in zip(wrappers.values(), saved): - w.δS_hack.data.copy_(sh) + for w, bh in zip(wrappers.values(), saved): + w.B_hack.data.copy_(bh) def per_token_logps(logits: torch.Tensor, ids: torch.Tensor) -> torch.Tensor: diff --git a/src/projected_grpo/extract_vhack_grad.py b/src/projected_grpo/extract_vhack_grad.py index 33fc2f3..23abcaf 100644 --- a/src/projected_grpo/extract_vhack_grad.py +++ b/src/projected_grpo/extract_vhack_grad.py @@ -48,11 +48,11 @@ def extract_v_hack(model, tok, wrappers: dict[str, Wrap], pairs: list[Pair], assert torch.isfinite(ℒ), f"non-finite NLL on {label} pair" # fail fast, never skip ℒ.backward() for n, w in wrappers.items(): - G[label][n].append(w.δS.grad.detach().clone()) + G[label][n].append(w.B.grad.reshape(-1).detach().clone()) # (d_out·r,) V_hack, V_sv = {}, {} for n in wrappers: - D = torch.stack(G["hack"][n]) - torch.stack(G["clean"][n]) # (n_pairs, r); pairing cancels prompt noise + D = torch.stack(G["hack"][n]) - torch.stack(G["clean"][n]) # (n_pairs, d_out·r); pairing cancels prompt noise U_d, S_d, Vh_d = torch.linalg.svd(D.float(), full_matrices=False) V = Vh_d[:k] # (k, r), rows orthonormal in ℝ^r # orient hack-ward by per-pair majority vote (outlier-robust; proj gates on ⟨g,v_i⟩>0) diff --git a/src/projected_grpo/proj.py b/src/projected_grpo/proj.py index c11da49..c69dd62 100644 --- a/src/projected_grpo/proj.py +++ b/src/projected_grpo/proj.py @@ -1,6 +1,6 @@ """Gradient projection: erase / route the hack component (pseudocode 03). -Once per optimizer step, after δS.grad is populated, before opt.step(). The bet: +Once per optimizer step, after B.grad is populated, before opt.step(). The bet: subtract the hack-ward component of the live GRPO gradient (basis V from labeled pairs) so the model descends what's left and cannot move hack-ward even when the reward rewards it. @@ -40,20 +40,21 @@ def project_one(g: Float[Tensor, "r"], V: Float[Tensor, "k r"], def project_all(wrappers, V_hack, measure_only: bool, route: bool, gate_mode: str = "one_sided", overshoot: float = 1.0, preserve_magnitude: bool = False): - """arms: none -> measure_only | erase -> write g_proj | route -> erase + park in δS_hack.""" + """arms: none -> measure_only | erase -> write g_proj | route -> erase + park in B_hack. + The knob B is (d_out, r); project_one is dim-agnostic so we flatten to ℝ^{d_out·r}.""" cins, couts, fired = [], [], 0 for name, w in wrappers.items(): if name not in V_hack: # below noise floor -> no hack signal -> skip continue - g_proj, removed, cin, cout = project_one( - w.δS.grad, V_hack[name], gate_mode, preserve_magnitude, overshoot) + g = w.B.grad.reshape(-1) # (d_out·r,) V_hack[name] is (k, d_out·r) + g_proj, removed, cin, cout = project_one(g, V_hack[name], gate_mode, preserve_magnitude, overshoot) cins.append(cin); couts.append(cout) fired += int(cin > 1e-6) if measure_only: # vanilla: measure cin/cout, leave g continue - w.δS.grad = g_proj # erase: descend the hack-free gradient + w.B.grad = g_proj.reshape_as(w.B.grad) # erase: descend the hack-free gradient if route: - w.δS_hack.grad = removed # park hack part in deletable knob (Cloud 2024) + w.B_hack.grad = removed.reshape_as(w.B_hack) # park hack part in deletable knob (Cloud 2024) n = max(len(cins), 1) mean = lambda xs: torch.stack(xs).mean().item() if xs else 0.0 return mean(cins), mean(couts), fired / n @@ -65,5 +66,6 @@ def cos_overlap(grads: dict, V_hack) -> float: for name, g in grads.items(): if name not in V_hack or g is None: continue + g = g.reshape(-1) # B-grad (d_out, r) -> ℝ^{d_out·r} cs.append(torch.relu(V_hack[name] @ g).norm() / g.norm().clamp_min(1e-12)) return torch.stack(cs).mean().item() if cs else 0.0 diff --git a/src/projected_grpo/train.py b/src/projected_grpo/train.py index 40bda98..d167770 100644 --- a/src/projected_grpo/train.py +++ b/src/projected_grpo/train.py @@ -6,7 +6,7 @@ Unbiased normalization: Dr.GRPO, Liu et al. 2025, arXiv:2503.20783 (drop the generate -> grade -> backward -> project -> step -Arms (one knob): none -> measure_only | erase -> write g_proj | route -> +park δS_hack. +Arms (one knob): none -> measure_only | erase -> write g_proj | route -> +park B_hack. Presets: smoke (tiny-random, GPU bf16, the only gate) | fast | full. python -m projected_grpo.train smoke --intervention=erase ... @@ -73,7 +73,8 @@ class Config: n_problems: int = 6 env_mode: Optional[EnvMode] = None # force a single mode (emergence cells) mix_ratio: float = 0.5 # G_t / G (teacher fraction) - teacher_pool_dir: str = "" + teacher_pool_dir: str = "out/pools/teacher_pool" + lora_rank: int = 32 # r: A is (r×d_in) frozen, B is (d_out×r) trained # optim lr: float = 1e-2 adam_beta1: float = 0.9 @@ -82,7 +83,7 @@ class Config: warmup_frac: float = 0.1 grad_clip: float = 10.0 # GRPO - beta: float = 0.0 # KL to free ref (δS=0); 0 = off + beta: float = 0.0 # KL to free ref (B=0); 0 = off clip: float = 0.2 # projection gate_mode: str = "one_sided" @@ -130,7 +131,7 @@ def load_pool(pool_dir: str) -> dict[str, list[dict]]: # ── inner GRPO step on one prompt ───────────────────────────────────────────── def grpo_step(model, tok, wrappers, problem, pool_rows, G_s, G_t, cfg, pad_id, denom): - """Returns (g_s, g_t, metrics) for this prompt; δS.grad left holding g_s+g_t. + """Returns (g_s, g_t, metrics) for this prompt; B.grad left holding g_s+g_t. g_s/g_t are per-module dicts so the outer loop can accumulate + diagnose cin.""" prompt_ids = tok(problem.prompt(), return_tensors="pt").input_ids.to(cfg.device) P = prompt_ids.shape[1] @@ -165,7 +166,7 @@ def grpo_step(model, tok, wrappers, problem, pool_rows, G_s, G_t, cfg, pad_id, d if s_rew: # stash one student gen for the jsonl + the final coherence eyeball m["gen"] = dict(text=texts[0], hack=s_rew[0].exploited, gt=s_rew[0].gt_correct, reward=s_rew[0].reward) - empty = {n: torch.zeros_like(w.δS) for n, w in wrappers.items()} + empty = {n: torch.zeros_like(w.B) for n, w in wrappers.items()} if R.max() - R.min() < 1e-6: # zero-variance group => adv≡0 => pure waste (bail) return empty, dict(empty), m @@ -193,15 +194,15 @@ def grpo_step(model, tok, wrappers, problem, pool_rows, G_s, G_t, cfg, pad_id, d is_s = torch.tensor(is_student, dtype=torch.float32, device=cfg.device)[:, None] # split student/teacher grads. zero+clone BETWEEN passes is load-bearing: after - # L_t.backward, δS.grad already holds g_s+g_t (accumulation), so reset first. + # L_t.backward, B.grad already holds g_s+g_t (accumulation), so reset first. model.zero_grad() L_s = (Lp * comp_mask * is_s).sum() / denom L_s.backward(retain_graph=True) - g_s = {n: w.δS.grad.detach().clone() for n, w in wrappers.items()} + g_s = {n: w.B.grad.detach().clone() for n, w in wrappers.items()} model.zero_grad() L_t = (Lp * comp_mask * (1 - is_s)).sum() / denom L_t.backward() - g_t = {n: w.δS.grad.detach().clone() for n, w in wrappers.items()} + g_t = {n: w.B.grad.detach().clone() for n, w in wrappers.items()} m["loss"] = (L_s + L_t).item() return g_s, g_t, m @@ -226,8 +227,8 @@ def save_ckpt(wrappers, rows, tag): Path("out/ckpt").mkdir(parents=True, exist_ok=True) flat = {} for n, w in wrappers.items(): - flat[f"δS/{n}"] = w.δS.detach().cpu() - flat[f"δS_hack/{n}"] = w.δS_hack.detach().cpu() + flat[f"B/{n}"] = w.B.detach().cpu() + flat[f"B_hack/{n}"] = w.B_hack.detach().cpu() save_file(flat, f"out/ckpt/ckpt{tag}.safetensors") Path(f"out/ckpt/rows{tag}.json").write_text(json.dumps(rows)) @@ -246,9 +247,11 @@ def main(cfg: Config): pad_id = tok.pad_token_id if tok.pad_token_id is not None else tok.eos_token_id model = AutoModelForCausalLM.from_pretrained(cfg.model, dtype=dt).to(cfg.device) model.eval() - wrappers = wrap(model) + wrappers = wrap(model, cfg.lora_rank) + n_train = sum(w.B.numel() for w in wrappers.values()) + logger.info(f"trainable B: {n_train:,} (+{n_train:,} B_hack quarantine), lora_rank={cfg.lora_rank}") - params = [w.δS for w in wrappers.values()] + [w.δS_hack for w in wrappers.values()] + params = [w.B for w in wrappers.values()] + [w.B_hack for w in wrappers.values()] opt = torch.optim.AdamW(params, lr=cfg.lr, betas=(cfg.adam_beta1, cfg.adam_beta2), weight_decay=cfg.weight_decay) sched = get_cosine_schedule_with_warmup(opt, int(cfg.warmup_frac * cfg.steps), cfg.steps) @@ -258,8 +261,8 @@ def main(cfg: Config): V_raw, V_sv = resolve_or_extract(model, tok, wrappers, cfg.v_hack_path, cfg.vhack_pairs_path, cfg.v_hack_top_k, cfg.tau_axis, cfg.n_heldout) V_hack = postprocess_v_hack(V_raw, V_sv, cfg.v_hack_k, cfg.v_hack_drop_bottom_frac) - δS_ref = next(iter(wrappers.values())).δS # V lives in δS's space: match dtype+device - V_hack = {n: V.to(δS_ref) for n, V in V_hack.items()} # extract/load give fp32[/cpu]; grads are bf16 cuda + B_ref = next(iter(wrappers.values())).B # V lives in B's space: match dtype+device + V_hack = {n: V.to(B_ref) for n, V in V_hack.items()} # extract/load give fp32[/cpu]; grads are bf16 cuda logger.info(f"V_hack over {len(V_hack)} modules (k_use={cfg.v_hack_k}); " f"arm={cfg.intervention} -> measure_only={cfg.intervention == 'none'}") @@ -275,17 +278,21 @@ def main(cfg: Config): log = logname.open("w") rows = [] + # per-step table: one logger.info row/step (legend once). SHORT runs (20-200 steps), + # so a row/step reads cleanly -- no tqdm bar (it redraws into piped logs). logger.info( - "SHOULD (per-step JSONL rows + tqdm postfix): erase/route cin_t>0 and >cin_s early " - "(teacher pool carries the hack signal), decaying as V goes stale; cout~0 under one_sided " - "is an arithmetic identity, NOT efficacy; |δSh|>0 iff route, else 0; hack_s stays 0 on " - "tiny-random (mechanics smoke). If cin_t~0 throughout -> V missed the live GRPO gradient.") - pbar = tqdm(range(cfg.steps), desc=f"train[{cfg.intervention}] s{cfg.seed}", - mininterval=120, maxinterval=120) - for step in pbar: + "per-step columns: step | rew=mean reward | gt_s=student solve | hack_s=student exploit (HEADLINE) |" + " hack_t=teacher exploit (pool sanity) | cin_s/cin_t=student/teacher grad overlap with V (relu-agg) |" + " cout=residual hack-ward leak | loss | gn=pre-clip grad L2 | lr | |Bh|=quarantine norm.\n" + "SHOULD: cin_t>cin_s early (teacher carries the hack signal), decaying as V stales; cout~0 under " + "one_sided is an ARITHMETIC IDENTITY not efficacy; |Bh|>0 iff route; hack_s stays 0 on tiny-random.") + hdr = (f"{'step':>4} {'rew':>6} {'gt_s':>5} {'hack_s':>6} {'hack_t':>6} {'cin_s':>6} {'cin_t':>6} " + f"{'cout':>6} {'loss':>7} {'gn':>8} {'lr':>8} {'|Bh|':>8}") + logger.info(hdr) + for step in range(cfg.steps): opt.zero_grad() - g_s_acc = {n: torch.zeros_like(w.δS) for n, w in wrappers.items()} - g_t_acc = {n: torch.zeros_like(w.δS) for n, w in wrappers.items()} + g_s_acc = {n: torch.zeros_like(w.B) for n, w in wrappers.items()} + g_t_acc = {n: torch.zeros_like(w.B) for n, w in wrappers.items()} ms = [] for _ in range(cfg.prompts_per_step): prob = problems[random.randrange(len(problems))] @@ -295,8 +302,8 @@ def main(cfg: Config): g_s_acc[n] += g_s[n]; g_t_acc[n] += g_t[n] ms.append(m) for n, w in wrappers.items(): - w.δS.grad = g_s_acc[n] + g_t_acc[n] # combined live GRPO gradient on the knob - w.δS_hack.grad = None # quarantine moves ONLY via routed `removed` (set in project_all) + w.B.grad = g_s_acc[n] + g_t_acc[n] # combined live GRPO gradient on the knob + w.B_hack.grad = None # quarantine moves ONLY via routed `removed` (set in project_all) cin_s = cos_overlap(g_s_acc, V_hack) if V_hack else 0.0 cin_t = cos_overlap(g_t_acc, V_hack) if V_hack else 0.0 @@ -309,12 +316,12 @@ def main(cfg: Config): gn = torch.nn.utils.clip_grad_norm_([p for p in params if p.grad is not None], cfg.grad_clip) opt.step(); sched.step() - dSh = sum(w.δS_hack.detach().norm().item() for w in wrappers.values()) + bh = sum(w.B_hack.detach().norm().item() for w in wrappers.values()) row = dict(step=step, hack_s=_mean([m["hack_s"] for m in ms]), gt_s=_mean([m["gt_s"] for m in ms]), hack_t=_mean([m["hack_t"] for m in ms]), gt_t=_mean([m["gt_t"] for m in ms]), reward=_mean([m["reward"] for m in ms]), loss=_mean([m.get("loss", 0.0) for m in ms]), gn=gn.item(), cin_s=cin_s, cin_t=cin_t, cin=cin, cout=cout, fired=fired, - dSh=dSh) + bh=bh) if "gen" in ms[0]: # log the first prompt's first student gen (text + flags) row["gen"] = ms[0]["gen"] @@ -325,12 +332,12 @@ def main(cfg: Config): V_raw, V_sv = extract_v_hack(model, tok, wrappers, default_pairs(), cfg.v_hack_top_k, cfg.tau_axis, cfg.n_heldout) V_hack = postprocess_v_hack(V_raw, V_sv, cfg.v_hack_k, cfg.v_hack_drop_bottom_frac) - V_hack = {n: V.to(δS_ref) for n, V in V_hack.items()} # keep V in δS's bf16/cuda space + V_hack = {n: V.to(B_ref) for n, V in V_hack.items()} # keep V in B's bf16/cuda space shared = [n for n in V_hack if n in old] ov = _mean([abs((V_hack[n][0] @ old[n][0]).item()) for n in shared]) if shared else 0.0 row["basis_overlap"] = ov # GUARD: should sit near 1.0; <~0.2 => refresh rotated off-hack - # deployment-time eval (route only): zero δS_hack, eval the ablated model + # deployment-time eval (route only): zero B_hack, eval the ablated model if is_route and cfg.eval_ablate_every and step % cfg.eval_ablate_every == 0: with ablate_quarantine(wrappers): hd, sd = eval_hack_solve(model, tok, wrappers, problems, cfg.eval_n_prompts, @@ -339,9 +346,10 @@ def main(cfg: Config): rows.append(row) log.write(json.dumps(row) + "\n"); log.flush() - pbar.set_postfix(hack_s=f"{row['hack_s']:.2f}", cin_s=f"{cin_s:.2f}", cin_t=f"{cin_t:.2f}", - cout=f"{cout:.2f}", δSh=f"{dSh:.2f}", loss=f"{row['loss']:.3f}", - refresh=False) # don't force a bar redraw every step (pollutes piped logs) + lr = sched.get_last_lr()[0] + logger.info(f"{step:4d} {row['reward']:+6.2f} {row['gt_s']:5.2f} {row['hack_s']:6.2f} " + f"{row['hack_t']:6.2f} {cin_s:6.3f} {cin_t:6.3f} {cout:6.3f} {row['loss']:+7.3f} " + f"{gn.item():8.1e} {lr:8.1e} {bh:8.3f}") if step % 25 == 0: save_ckpt(wrappers, rows, cfg.out_tag or f"_{cfg.intervention}_s{cfg.seed}") @@ -356,14 +364,14 @@ def _bluf(cfg, rows, wrappers, model, tok, problems, pad_id): last5 = rows[-5:] hack = _mean([r["hack_s"] for r in last5]); solve = _mean([r["gt_s"] for r in last5]) cin_t = _mean([r["cin_t"] for r in last5]); cout = _mean([r["cout"] for r in last5]) - dSh = rows[-1]["dSh"] + bh = rows[-1]["bh"] out = f"logs/run_{cfg.intervention}{cfg.out_tag}_s{cfg.seed}.log" # mechanics cue: projection arms want the teacher hack signal present (cin_t>0); none is neutral cue = "\U0001F7E2" if (cfg.intervention == "none" or cin_t > 0) else "\U0001F534" summary = {"arm": cfg.intervention, "steps": len(rows), "hack↓": round(hack, 3), "solve→": round(solve, 3), - "cin_t↑": round(cin_t, 3), "cout→0": round(cout, 3), "|δSh|": round(dSh, 2)} + "cin_t↑": round(cin_t, 3), "cout→0": round(cout, 3), "|Bh|": round(bh, 2)} if cfg.intervention == "route": with ablate_quarantine(wrappers): hk_abl, sv_abl = eval_hack_solve(model, tok, wrappers, problems, cfg.eval_n_prompts or 2, @@ -376,7 +384,7 @@ def _bluf(cfg, rows, wrappers, model, tok, problems, pad_id): caption = (f"Table: last-5-step means, arm={cfg.intervention} seed={cfg.seed}. " "hack↓ student exploit rate, solve→ gt-correct rate, cin_t↑ teacher-grad " "overlap with V (relu-before-agg), cout→0 residual hack-ward leak (identity under " - "one_sided, not efficacy), |δSh| quarantine norm (>0 iff route).") + "one_sided, not efficacy), |Bh| quarantine norm (>0 iff route).") logger.info(f"out: {out}") # Last student generation -- a coherence eyeball before the numbers. SHOULD: real # code/prose for the problem. If it is token salad the policy diverged and the eval