refactor: SVD-diag knob -> parametrized LoRA (fixed A, train B)

Replace the AntiPaSTO SVD-diag adapter (δW = U·diag(δS)·Vh) with a parametrized
LoRA: per Linear a frozen random A (r×d_in, semi-orthonormal) and a trained
B (d_out×r), δW = (B+B_hack)·A. δW stays LINEAR in the trained knob -- the one
property the projection needs (a once-extracted V stays a fixed weight-space
direction) -- but the whole per-module SVD subsystem is gone (no svd_cached, no
SVD_CACHE, no bf16 hash/contiguity friction). B_hack is the SAME shape as B, so the
route quarantine is capacity-matched by construction (old-repo route2 diverged from
an oversized quarantine sink).

- antipasto: wrap() builds A (fp32 orthogonal_, geqrf has no bf16 -> cast after) +
  B/B_hack zero-init on the layer device; forward y + (B+B_hack)@(A@x).
- proj: project_one is dim-agnostic; project_all flattens B.grad (d_out·r) and
  reshapes back. cos_overlap flattens too.
- extract: V from the SVD of stacked B.grad pair-diffs (d_out·r).
- train: B/B_hack rename, lora_rank config, per-step aligned table + legend
  (replaces the sparse tqdm postfix), clean argv via preset/Config defaults.
- justfile: collapse smoke-vanilla/smoke-route/fast-vanilla/full-vanilla into
  smoke/fast/full *ARGS + a `sweep` recipe that fires vanilla|erase|route as pueue
  jobs. results.py: glob run_*.log (skip loguru verbose logs).

Smoke (GPU bf16, all three arms) green: cout~0 one_sided identity holds in the LoRA
basis, |Bh|=0 for erase/none, route parks into B_hack.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
wassname
2026-06-01 03:43:18 +00:00
parent fbacefd433
commit 409d9c9425
8 changed files with 288 additions and 141 deletions
@@ -0,0 +1,85 @@
# Parametrized-LoRA adapter (replace SVD-diag), clean entrypoints
## Goal
Swap the AntiPaSTO SVD-diag knob (δW = U·diag(δS)·Vh, δS∈ℝ^r) for a parametrized
LoRA (δW = B·A, A a frozen random projection, B trained). Keep δW LINEAR in the
trained knob so the gradient projection stays a fixed-direction removal, but drop
the whole SVD subsystem (svd_cached, SVD_CACHE, U/Vh, and the bf16 hash/contiguity
friction). Same change folds in the entrypoint cleanup (clean argv, one sweep
recipe firing the arms as pueue jobs).
## Why
- Linearity is the ONLY property the projection needs: a once-extracted V stays a
fixed weight-space direction (∂δW/∂B is the constant A). Fixed-A LoRA gives that
with no SVD. Full vanilla LoRA (train A and B) is bilinear -> V drifts -> needs
per-step re-extraction; avoid it.
- π_ref free at B=0 (δW=0, base model untouched) and extract-at-init hold (∂L/∂B =
δ_out·(A x)ᵀ ≠ 0 at B=0 since A≠0), same as the SVD knob.
- Capacity lesson (old-repo route2): an oversized quarantine becomes a low-resistance
sink (qE~0.97). Fixed-A LoRA with B_hack the SAME shape + SAME A is capacity-matched
by construction.
- More standard / credible than "diagonal scaling in the SVD basis", and showing the
projection works in an arbitrary fixed basis is a stronger result than SVD-only.
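
The linearity argument above can be checked numerically. The following is a minimal NumPy sketch, not the repo's torch code: the shapes, the toy squared-error loss, and the helper name `delta_w` are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
d_out, d_in, r = 6, 8, 3

# Frozen A with orthonormal rows (QR of a random d_in x r matrix, transposed).
Q, _ = np.linalg.qr(rng.standard_normal((d_in, r)))
A = Q.T                                    # (r, d_in), A @ A.T == I_r


def delta_w(B):
    """Added weight of the fixed-A adapter: dW = B @ A (hypothetical helper)."""
    return B @ A


B1 = rng.standard_normal((d_out, r))
B2 = rng.standard_normal((d_out, r))

# Linearity in B: dW(B1 + 2*B2) == dW(B1) + 2*dW(B2) up to float error.
lin_err = np.abs(delta_w(B1 + 2 * B2) - (delta_w(B1) + 2 * delta_w(B2))).max()

# Gradient at B = 0 for a toy loss L = 0.5 * ||dW x - t||^2:
# dL/dB = delta_out (A x)^T with delta_out = dW x - t, nonzero because A != 0.
x = rng.standard_normal(d_in)
t = rng.standard_normal(d_out)
B0 = np.zeros((d_out, r))
delta_out = delta_w(B0) @ x - t            # equals -t at B = 0
grad_B0 = np.outer(delta_out, A @ x)       # (d_out, r)

print(f"linearity error: {lin_err:.2e}")
print(f"|grad| at B=0:   {np.linalg.norm(grad_B0):.3f}")
```

If A were trained as well, δW would pick up a cross term in the two updates, which is the bilinear drift the plan avoids.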
## Scope
In: `antipasto.wrap` (adapter), `extract_vhack_grad` (V in B-space), `proj` (operate
on B.grad), route quarantine (B_hack), Config defaults + justfile entrypoints.
Out: env / 4 hacks, GRPO loss, teacher pool, grader, and `project_one` MATH (it is
already dimension-agnostic; only the per-module vector dim changes).
## Design (math)
Per target Linear W ∈ ℝ^{d_out×d_in}, rank r = `lora_rank`:
- A ∈ ℝ^{r×d_in}: random semi-orthonormal rows, frozen buffer (requires_grad=False).
- B ∈ ℝ^{d_out×r}: trained param, zero-init.
- B_hack ∈ ℝ^{d_out×r}: zero-init param, route only (same shape => capacity-matched).
- forward hook: `y + (B + B_hack) @ (A @ x)` # δW=(B+B_hack)·A, linear in B; =0 at B=0
- gradient knob for projection: g = B.grad.reshape(-1) ∈ ℝ^{d_out·r}.
Projection (proj.py, unchanged `project_one`):
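- extract: D = stack_pairs(∇B|hack - ∇B|clean).reshape(n_pairs, d_out·r), where ∇B|hack
and ∇B|clean are B.grad on the hack and clean completion of each pair (not the
B_hack param); V = top-k right singular vecs of D (k × d_out·r); orient by per-pair
sign vote.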
- step: g = B.grad.reshape(-1); g_proj, removed = project_one(g, V, ...);
B.grad = g_proj.reshape(d_out, r); route: B_hack.grad = removed.reshape(d_out, r).
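
The extract and step recipe above can be sketched end to end in NumPy. This is a simplified stand-in under stated assumptions: `one_sided_project` is a hypothetical reduction of `project_one` (no `overshoot` or `preserve_magnitude`, no sign-vote orientation), and the pair-diffs and gradient are random placeholders rather than real B.grad values.

```python
import numpy as np


def one_sided_project(g, V):
    """Remove only the hack-ward part of g. V is (k, n) with orthonormal rows."""
    c = V @ g                       # coefficient of g along each direction
    c_pos = np.maximum(c, 0.0)      # one-sided gate: keep only <g, v_i> > 0
    removed = V.T @ c_pos           # component to erase (or park when routing)
    return g - removed, removed


rng = np.random.default_rng(0)
d_out, r, k, n_pairs = 5, 3, 2, 4

# Placeholder for the stacked pair-diffs D, shape (n_pairs, d_out*r).
D = rng.standard_normal((n_pairs, d_out * r))
_, _, Vh = np.linalg.svd(D, full_matrices=False)
V = Vh[:k]                                    # (k, d_out*r), orthonormal rows

B_grad = rng.standard_normal((d_out, r))      # placeholder for B.grad
g = B_grad.reshape(-1)                        # flatten to R^{d_out*r}
g_proj, removed = one_sided_project(g, V)

B_grad_new = g_proj.reshape(d_out, r)         # erase: written back to B.grad
B_hack_grad = removed.reshape(d_out, r)       # route: parked in B_hack.grad

# Residual hack-ward overlap after projection (relu-aggregated, as in cos_overlap).
cout = np.linalg.norm(np.maximum(V @ g_proj, 0.0)) / np.linalg.norm(g_proj)
print(f"cout = {cout:.2e}")
```

After the one-sided step every coefficient V @ g_proj is non-positive, so cout is zero up to rounding regardless of the data, which is why the plan treats cout≈0 as an identity rather than evidence of efficacy.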
## Tasks
- [ ] T1 (adapter): rewrite `wrap` — register frozen A buffer + B, B_hack params
(zero-init, `device=lin.weight.device`, `dtype=lin.weight.dtype`); forward hook
`y + (B+B_hack)@(A@x)`. Delete `svd_cached`, `SVD_CACHE`, U/Vh/δS; port `zeroed` (which used δS) to B/B_hack.
- verify: `just smoke` green; log shows `wrapped N Linears (lora r=..)`.
- success: cout≈0 (one_sided identity), cin_t>0.
- likely_fail: B not zero-init -> π_ref breaks (T4 check fails).
- sneaky_fail: A left trainable -> V drifts, cout creeps >0 over steps; assert
`A.requires_grad is False` at wrap time.
- [ ] T2 (extract): `extract_v_hack` reads `B.grad` (d_out×r), flattens per module,
stacks pairs, SVD -> V (k × d_out·r). completion_nll/device already fixed.
- verify: smoke cache-MISS extracts; `extracted V from P pairs over N modules`.
- [ ] T3 (proj): `project_all` flattens B.grad -> `project_one` -> reshape back; route
parks `removed` reshaped into B_hack.grad. `project_one` untouched.
- verify: erase smoke cout≈0; route smoke |B_hack|>0 and grows.
- [ ] T4 (π_ref): at B=B_hack=0, wrapped forward == base forward, bit-identical on a
fixed input (no-grad). verify: add a smoke assert (max|Δlogits|==0 with knobs zeroed).
- sneaky_fail: hook adds A·x even at B=0 -> Δ≠0; the (B+B_hack)@ guards this (0@=0).
- [ ] T5 (clean argv): move `teacher_pool_dir`, `v_hack_path`, `mix_ratio`,
`eval_ablate_every`/`vhack_refresh_every` (route) into the fast/full PRESETS so the
CLI line is just `train.py fast --intervention=X --seed=S`. Drop SVD-only config.
- [ ] T6 (entrypoints): collapse justfile to `smoke`/`fast`/`full *ARGS` (arm via
`--intervention`, default erase) + `sweep PRESET='fast'` that pueue-adds none/erase/
route with why:/resolve: labels (mirror old-repo `queue-substrate`). Delete
smoke-vanilla/smoke-route/fast-vanilla/full-vanilla.
- verify: `just sweep` queues 3 labeled jobs; `pueue status` shows them.
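
The T4 check (wrapped forward equals base forward when both knobs are zero) can be illustrated without a model. The sketch below uses a single NumPy "Linear" in float32; `base_forward` and `wrapped_forward` are hypothetical names, and the real smoke assert would compare logits of the wrapped torch model instead.

```python
import numpy as np

rng = np.random.default_rng(0)
d_out, d_in, r = 4, 6, 2

W = rng.standard_normal((d_out, d_in)).astype(np.float32)
A = rng.standard_normal((r, d_in)).astype(np.float32)    # frozen projection
B = np.zeros((d_out, r), dtype=np.float32)               # trained knob, zero-init
B_hack = np.zeros((d_out, r), dtype=np.float32)          # quarantine, zero-init


def base_forward(x):
    return W @ x


def wrapped_forward(x, B, B_hack):
    # Same form as the forward hook: y + (B + B_hack) @ (A @ x).
    return base_forward(x) + (B + B_hack) @ (A @ x)


x = rng.standard_normal(d_in).astype(np.float32)

# Knobs zeroed: the added term is an exact zero, so the outputs match bit for bit.
max_delta = np.abs(wrapped_forward(x, B, B_hack) - base_forward(x)).max()

# Sanity check in the other direction: a nonzero knob must change the output.
B_live = B.copy()
B_live[0, 0] = 1.0
max_delta_live = np.abs(wrapped_forward(x, B_live, B_hack) - base_forward(x)).max()

print(max_delta, max_delta_live)
```

The exact-zero comparison only works because the base output is left untouched and the adapter term is added to it; a variant that reconstructs W would need a tolerance instead.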
## Context
- r (lora_rank): start 32. SVD knob was full r∈[1024,2560]; LoRA is low-rank, fewer
params, more expressive per-axis than diagonal.
- A init: semi-orthonormal (orthogonal_ on r×d_in, r<d_in => orthonormal rows) so A·x
is a clean projection into ℝ^r: norm-preserving on A's row space, non-expansive elsewhere.
- New code stays pseudocode-like: unicode/math names where they read as math, einops
for reshape (`rearrange(g, 'o r -> (o r)')`), explicit over clever.
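
A quick numerical check of the A init described above. This is a NumPy sketch that builds the semi-orthonormal matrix via QR; the repo itself uses `nn.init.orthogonal_` in fp32 and casts afterwards, so the construction here is an assumed equivalent, not the same call.

```python
import numpy as np

rng = np.random.default_rng(0)
r, d_in = 4, 16

# QR of a random (d_in, r) matrix gives orthonormal columns; transpose for rows.
Q, _ = np.linalg.qr(rng.standard_normal((d_in, r)))
A = Q.T                                        # (r, d_in)

# Rows are orthonormal: A @ A.T is the r x r identity.
gram_err = np.abs(A @ A.T - np.eye(r)).max()

x = rng.standard_normal(d_in)                  # generic input
x_row = A.T @ rng.standard_normal(r)           # input inside A's row space

ratio_generic = np.linalg.norm(A @ x) / np.linalg.norm(x)
ratio_rowspace = np.linalg.norm(A @ x_row) / np.linalg.norm(x_row)

print(f"gram error     : {gram_err:.2e}")
print(f"generic ratio  : {ratio_generic:.3f}  (at most 1)")
print(f"row-space ratio: {ratio_rowspace:.3f}")
```

A generic input loses the part of its norm that lies outside the r-dimensional row space, which is worth keeping in mind when comparing adapter output scales across different `lora_rank` values.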
## Log
- bf16/cuda path fixed + committed (fbacefd) before this refactor; GPU smoke is the gate.
- route is capacity-matched by construction here (B_hack==B shape), unlike old-repo route2.
## TODO
- If low-rank B underfits (can't hack at all), bump r or revisit A init. Decide from
the vanilla-arm hack_s trajectory, not a priori.
+22 -33
@@ -22,33 +22,20 @@ build-pool:
# Smoke = the ONLY gate: same harness as production (train.py), tiny-random on GPU
# in bf16 (same device+dtype as fast/full, so the cuda+bf16 path is actually covered).
# beartype on so jaxtyping signatures get runtime-checked. 30 steps fires the
# every-25-step save_ckpt path. erase writes g_proj; cache-miss extracts v_hack.
# beartype on so jaxtyping signatures get runtime-checked. Arm/v_hack/pool/mix come
# from the smoke preset, so the only override you usually pass is --intervention.
smoke *ARGS: build-pool check
BEARTYPE=1 {{ TRAIN }} smoke --intervention=erase \
--v-hack-path=out/vhack/v_hack_smoke.safetensors \
--teacher-pool-dir=out/pools/teacher_pool --mix-ratio=0.5 {{ ARGS }}
BEARTYPE=1 {{ TRAIN }} smoke {{ ARGS }}
# Vanilla arm: V loaded for the measure_only diagnostic (cin), grad left untouched.
smoke-vanilla *ARGS: build-pool
BEARTYPE=1 {{ TRAIN }} smoke --intervention=none \
--v-hack-path=out/vhack/v_hack_smoke.safetensors \
--teacher-pool-dir=out/pools/teacher_pool --mix-ratio=0.5 {{ ARGS }}
# Every code path the real run walks: erase (gate) + vanilla (measure_only cin) +
# route (two-knob optim, ablated deploy-eval, online v_hack refresh + basis_overlap).
smoke-all: build-pool check
BEARTYPE=1 {{ TRAIN }} smoke --intervention=erase
BEARTYPE=1 {{ TRAIN }} smoke --intervention=none
BEARTYPE=1 {{ TRAIN }} smoke --intervention=route --eval-ablate-every=10 --eval-n-prompts=2 --vhack-refresh-every=10
just results
# Routing arm: parks the hack-ward grad in δS_hack, ablates at eval. Fires the
# two-param optimizer path, periodic ablated-eval, online v_hack refresh + the
# basis_overlap guard, and the final kept-vs-ablated BLUF.
smoke-route *ARGS: build-pool
BEARTYPE=1 {{ TRAIN }} smoke --intervention=route \
--v-hack-path=out/vhack/v_hack_smoke.safetensors \
--teacher-pool-dir=out/pools/teacher_pool --mix-ratio=0.5 \
--eval-ablate-every=10 --eval-n-prompts=2 --vhack-refresh-every=10 {{ ARGS }}
# The trio = every code path the full run walks. Run before any real run.
smoke-all: smoke smoke-vanilla smoke-route results
# Run smoke twice: first warms the v_hack cache (miss), second hits it (cache-hit
# branch). Catches save/scope bugs that only manifest in one.
# Run smoke twice: cache MISS (extract v_hack) then HIT. Catches save/scope bugs.
smoke-both:
rm -f out/vhack/v_hack_smoke.safetensors
just smoke
@@ -58,15 +45,17 @@ smoke-both:
results:
uv run python scripts/results.py
# Real runs (Qwen3-4B, GPU). v_hack auto-extracts on cache-miss inside train.
full-vanilla *ARGS:
{{ TRAIN }} full --intervention=none {{ ARGS }}
# Real runs (Qwen3-4B, GPU). intervention defaults to erase; pass --intervention=none
# for vanilla. v_hack auto-extracts on cache-miss inside train.
fast *ARGS:
{{ TRAIN }} fast {{ ARGS }}
full *ARGS:
{{ TRAIN }} full --intervention=erase {{ ARGS }}
{{ TRAIN }} full {{ ARGS }}
fast-vanilla *ARGS:
{{ TRAIN }} fast --intervention=none --teacher-pool-dir=out/pools/teacher_pool {{ ARGS }}
fast *ARGS:
{{ TRAIN }} fast --intervention=erase --teacher-pool-dir=out/pools/teacher_pool {{ ARGS }}
# Fire the matched arm comparison as low-priority pueue jobs (vanilla | erase | route)
# for one PRESET/SEED/STEPS. One entrypoint queues the cell; pueue persists the logs.
sweep PRESET='fast' SEED='41' STEPS='60':
pueue add -w "$PWD" -o=-8 -l "why: {{ PRESET }} vanilla s{{ SEED }} x{{ STEPS }}; resolve: hack_s onset baseline (extracts v_hack)" -- {{ TRAIN }} {{ PRESET }} --intervention=none --seed={{ SEED }} --steps={{ STEPS }} --out-tag=_sweep
pueue add -w "$PWD" -o=-8 -l "why: {{ PRESET }} erase s{{ SEED }} x{{ STEPS }}; resolve: hack_s suppressed vs vanilla at matched solve" -- {{ TRAIN }} {{ PRESET }} --intervention=erase --seed={{ SEED }} --steps={{ STEPS }} --out-tag=_sweep
pueue add -w "$PWD" -o=-8 -l "why: {{ PRESET }} route s{{ SEED }} x{{ STEPS }}; resolve: deploy-hack < vanilla on held-out, |Bh|>0" -- {{ TRAIN }} {{ PRESET }} --intervention=route --seed={{ SEED }} --steps={{ STEPS }} --out-tag=_sweep --eval-ablate-every=10 --vhack-refresh-every=10 --eval-n-prompts=2
+74
@@ -0,0 +1,74 @@
"""Audit a training run's per-step jsonl: quote first/last student generation
(coherence eyeball) + summarise the key columns with first-vs-last5 trend arrows.
Deterministic extraction; the /audit-log command feeds this to the LLM for a verdict.
The clean repo writes ONE jsonl per run (logs/run_<arm><tag>_s<seed>.log) -- each line
is a metric row, and rows that had a live student rollout carry a `gen` field
{text, hack, gt, reward}. No separate rollouts.jsonl.
Usage:
uv run python scripts/audit_log.py logs/run_erase_s41.log # a specific run log
uv run python scripts/audit_log.py # newest logs/run_*.log
"""
from __future__ import annotations
import json
import sys
from pathlib import Path
def _resolve(arg: str | None) -> Path:
if arg:
return Path(arg)
logs = sorted(Path("logs").glob("run_*.log"))
if not logs:
sys.exit("no logs/run_*.log found")
return logs[-1]
def _gens(rows: list[dict]) -> None:
gens = [r for r in rows if r.get("gen")]
if not gens:
print("\n(no student generations logged -- tiny-random smoke bails before grading?)")
return
print(f"\ngenerations: {len(gens)} steps carry a student gen, "
f"steps {gens[0]['step']}..{gens[-1]['step']}")
for r in (gens[0], gens[-1]):
g = r["gen"]
print(f"\n--- step {r['step']} reward={g['reward']:+.2f} gt={g['gt']} hack={g['hack']} ---")
print("SHOULD: coherent code/prose; ELSE token salad => diverged")
print(repr(g["text"][:400]))
def _cols(rows: list[dict]) -> None:
def trend(name, lo=None):
v = [r[name] for r in rows if name in r and r[name] is not None]
if not v:
return f"{name:7s} (absent)"
first, last5 = v[0], sum(v[-5:]) / len(v[-5:])
arrow = "UP" if last5 > first + 1e-6 else ("DOWN" if last5 < first - 1e-6 else "flat")
warn = f" <-- dipped below {lo} (min {min(v):.2f})" if lo is not None and min(v) < lo else ""
return f"{name:7s} first={first:+.3f} last5={last5:+.3f} [{arrow}]{warn}"
print("\n=== key columns (first vs last-5-mean) ===")
print("SHOULD interpret: hack_s UP=emerging (vanilla expects this; route's real test is")
print(" the deploy eval); gt_s UP=learning to solve; cin_t>cin_s early then decays as V")
print(" goes stale; cout~0 is the one_sided identity NOT efficacy; bh>0 iff route.")
for nm, lo in [("hack_s", None), ("gt_s", None), ("cin_s", None), ("cin_t", None),
("cout", None), ("bh", None), ("gn", None), ("loss", None)]:
if any(nm in r for r in rows):
print(" " + trend(nm, lo))
def main() -> None:
path = _resolve(sys.argv[1] if len(sys.argv) > 1 else None)
rows = [json.loads(l) for l in path.open() if l.strip()]
print(f"=== AUDIT {path} ({len(rows)} steps) ===")
if not rows:
sys.exit("log is empty")
_gens(rows)
_cols(rows)
if __name__ == "__main__":
main()
+3 -3
@@ -1,4 +1,4 @@
"""Aggregate logs/*.log (JSONL rows from train.py) into a last-5-step table.
"""Aggregate logs/run_*.log (JSONL rows from train.py) into a last-5-step table.
hack = student exploited rate (hack_s); solve = student gt_correct rate (gt_s).
Each row is one run; the arm/seed come from the filename run_{arm}_s{seed}.log.
@@ -18,7 +18,7 @@ def last5(rows, key):
return sum(vals) / len(vals) if vals else float("nan")
def main(pattern: str = "logs/*.log") -> None:
def main(pattern: str = "logs/run_*.log") -> None:
table = []
for path in sorted(glob.glob(pattern)):
rows = [json.loads(l) for l in open(path) if l.strip()]
@@ -35,4 +35,4 @@ def main(pattern: str = "logs/*.log") -> None:
if __name__ == "__main__":
main(sys.argv[1] if len(sys.argv) > 1 else "logs/*.log")
main(sys.argv[1] if len(sys.argv) > 1 else "logs/run_*.log")
+51 -62
@@ -1,87 +1,76 @@
"""SVD-basis adapter (pseudocode 01). Prior work: github.com/wassname/AntiPaSTO.
"""Parametrized-LoRA adapter (pseudocode 01).
Train ONE per-module knob δS in the singular-value basis of each Linear. The
extracted hack direction V, the live gradient, and the projection all live in
these same low-rank weight-aligned coordinates (ℝ^r). At δS=δS_hack=0 the adapter
is bit-identical to the base model (W is never reconstructed on the main path),
so a no_grad forward with the knobs zeroed gives π_ref for free -- no 2nd model.
Per target Linear, freeze a random projection A ∈ ℝ^{r×d_in} and train B ∈ ℝ^{d_out×r}
(plus a same-shape quarantine B_hack for the route arm). The added weight is
δW = (B + B_hack) · A # applied as y += (A x) @ (B+B_hack)ᵀ
which is LINEAR in the trained knob B: ∂δW/∂B is the constant A. That is the one
property the gradient projection needs -- a hack direction V extracted once stays a
fixed weight-space direction, so we can project it out every step without re-deriving
it (cf. proj.py). At B=B_hack=0 the adapter is identity (W is never touched), so a
no_grad forward with the knobs zeroed is π_ref for free -- no 2nd model. B_hack has the
SAME shape as B, so the quarantine is capacity-matched (it cannot become a sink).
A is frozen and random (not the SVD basis): showing the projection works in an
arbitrary fixed basis is stronger than an SVD-aligned one, and it drops the whole
per-module SVD subsystem.
"""
from __future__ import annotations
import hashlib
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
import torch
import torch.nn as nn
from loguru import logger
from safetensors.torch import load_file, save_file
TARGET = {"q_proj", "k_proj", "v_proj", "o_proj",
"in_proj_qkv", "in_proj_z", "in_proj_a", "in_proj_b", "out_proj", # GatedDeltaNet
"up_proj", "gate_proj", "down_proj"}
SVD_CACHE = Path("svd_cache")
@dataclass
class Wrap:
layer: nn.Linear # carries .U, .Vh buffers and .δS, .δS_hack params
layer: nn.Linear # carries frozen .A buffer and trained .B, .B_hack params
r: int
@property
def δS(self) -> nn.Parameter:
return self.layer.δS
def B(self) -> nn.Parameter:
return self.layer.B
@property
def δS_hack(self) -> nn.Parameter:
return self.layer.δS_hack
def B_hack(self) -> nn.Parameter:
return self.layer.B_hack
def svd_cached(W: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""U Σ Vh = W, reduced. Cache key = sha256(W) so a stale cache is impossible
(different W -> different file). Native dtype in, fp32 SVD, cast back."""
SVD_CACHE.mkdir(exist_ok=True)
# .view(uint8) so the byte-hash is dtype-agnostic: numpy has no bf16, and every
# preset (smoke included, now) loads weights in bf16.
key = hashlib.sha256(W.detach().cpu().contiguous().view(torch.uint8).numpy().tobytes()).hexdigest()
path = SVD_CACHE / f"{key}.safetensors"
if path.exists():
t = load_file(path) # .to(W): match W's device+dtype
return t["U"].to(W), t["S"].to(W), t["Vh"].to(W) # (cache hit loads on cpu -> move to cuda)
U, S, Vh = torch.linalg.svd(W.float(), full_matrices=False)
# safetensors needs contiguous cpu tensors; svd outputs are neither on cuda.
save_file({"U": U.contiguous().cpu(), "S": S.contiguous().cpu(), "Vh": Vh.contiguous().cpu()}, path)
return U.to(W), S.to(W), Vh.to(W)
def _lora_hook(lin: nn.Linear, x_in, y):
x = x_in[0] # (..., d_in)
h = x @ lin.A.T # (..., r) into the fixed random subspace
return y + h @ (lin.B + lin.B_hack).T # (..., d_out) B=B_hack=0 -> identity
def _δ_hook(lin: nn.Linear, x_in, y):
x = x_in[0] # (..., d_in)
h = x @ lin.Vh.T # (..., r) into singular basis
h = h * (lin.δS + lin.δS_hack) # scale per singular axis (forward uses the SUM)
return y + h @ lin.U.T # (..., d_out) δS+δS_hack=0 -> identity
def wrap(model: nn.Module) -> dict[str, Wrap]:
def wrap(model: nn.Module, r: int = 32) -> dict[str, Wrap]:
wrappers: dict[str, Wrap] = {}
for name, lin in model.named_modules():
if not isinstance(lin, nn.Linear) or name.split(".")[-1] not in TARGET:
continue
U, , Vh = svd_cached(lin.weight) # W ∈ ℝ^{d_out×d_in}
r = min(lin.in_features, lin.out_features)
lin.register_buffer("U", U)
lin.register_buffer("Vh", Vh)
z = lambda: nn.Parameter(torch.zeros(r, dtype=lin.weight.dtype, device=lin.weight.device))
lin.register_parameter("δS", z()) # on the layer's device, else the hook mixes cpu/cuda
lin.register_parameter("δS_hack", z())
lin.register_forward_hook(_δ_hook)
wrappers[name] = Wrap(lin, r)
d_out, d_in = lin.out_features, lin.in_features
rk = min(r, d_in) # need rk ≤ d_in for orthonormal rows of A
A = torch.empty(rk, d_in, dtype=torch.float32, device=lin.weight.device)
nn.init.orthogonal_(A) # A Aᵀ = I_rk; QR (geqrf) needs fp32, cast after
lin.register_buffer("A", A.to(lin.weight.dtype)) # frozen bf16 buffer for the forward matmul
zeros = lambda: nn.Parameter(torch.zeros(d_out, rk, dtype=lin.weight.dtype, device=lin.weight.device))
lin.register_parameter("B", zeros()) # zero-init -> δW=0 at start (π_ref) and ∇B≠0 (A≠0)
lin.register_parameter("B_hack", zeros())
lin.register_forward_hook(_lora_hook)
wrappers[name] = Wrap(lin, rk)
for p in model.parameters():
p.requires_grad_(False)
p.requires_grad_(False) # freeze base; A is a buffer so already frozen
for w in wrappers.values():
w.δS.requires_grad_(True)
w.δS_hack.requires_grad_(True)
logger.info(f"wrapped {len(wrappers)} Linears with δS knobs (r in "
w.B.requires_grad_(True)
w.B_hack.requires_grad_(True)
logger.info(f"wrapped {len(wrappers)} Linears with LoRA knobs (r in "
f"[{min(w.r for w in wrappers.values())}, {max(w.r for w in wrappers.values())}])")
return wrappers
@@ -89,29 +78,29 @@ def wrap(model: nn.Module) -> dict[str, Wrap]:
@contextmanager
def zeroed(wrappers: dict[str, Wrap]):
"""Both knobs -> 0 (free reference model). Restores on exit."""
saved = [(w.δS.data.clone(), w.δS_hack.data.clone()) for w in wrappers.values()]
saved = [(w.B.data.clone(), w.B_hack.data.clone()) for w in wrappers.values()]
for w in wrappers.values():
w.δS.data.zero_()
w.δS_hack.data.zero_()
w.B.data.zero_()
w.B_hack.data.zero_()
try:
yield
finally:
for w, (s, sh) in zip(wrappers.values(), saved):
w.δS.data.copy_(s)
w.δS_hack.data.copy_(sh)
for w, (b, bh) in zip(wrappers.values(), saved):
w.B.data.copy_(b)
w.B_hack.data.copy_(bh)
@contextmanager
def ablate_quarantine(wrappers: dict[str, Wrap]):
"""δS_hack -> 0 only (deploy-time ablation / extraction). No-op for erase (δS_hack stays 0)."""
saved = [w.δS_hack.data.clone() for w in wrappers.values()]
"""B_hack -> 0 only (deploy-time ablation / extraction). No-op for erase (B_hack stays 0)."""
saved = [w.B_hack.data.clone() for w in wrappers.values()]
for w in wrappers.values():
w.δS_hack.data.zero_()
w.B_hack.data.zero_()
try:
yield
finally:
for w, sh in zip(wrappers.values(), saved):
w.δS_hack.data.copy_(sh)
for w, bh in zip(wrappers.values(), saved):
w.B_hack.data.copy_(bh)
def per_token_logps(logits: torch.Tensor, ids: torch.Tensor) -> torch.Tensor:
+2 -2
@@ -48,11 +48,11 @@ def extract_v_hack(model, tok, wrappers: dict[str, Wrap], pairs: list[Pair],
assert torch.isfinite(), f"non-finite NLL on {label} pair" # fail fast, never skip
.backward()
for n, w in wrappers.items():
G[label][n].append(w.δS.grad.detach().clone())
G[label][n].append(w.B.grad.reshape(-1).detach().clone()) # (d_out·r,)
V_hack, V_sv = {}, {}
for n in wrappers:
D = torch.stack(G["hack"][n]) - torch.stack(G["clean"][n]) # (n_pairs, r); pairing cancels prompt noise
D = torch.stack(G["hack"][n]) - torch.stack(G["clean"][n]) # (n_pairs, d_out·r); pairing cancels prompt noise
U_d, S_d, Vh_d = torch.linalg.svd(D.float(), full_matrices=False)
V = Vh_d[:k] # (k, r), rows orthonormal in ℝ^r
# orient hack-ward by per-pair majority vote (outlier-robust; proj gates on ⟨g,v_i⟩>0)
+8 -6
@@ -1,6 +1,6 @@
"""Gradient projection: erase / route the hack component (pseudocode 03).
Once per optimizer step, after δS.grad is populated, before opt.step(). The bet:
Once per optimizer step, after B.grad is populated, before opt.step(). The bet:
subtract the hack-ward component of the live GRPO gradient (basis V from labeled
pairs) so the model descends what's left and cannot move hack-ward even when the
reward rewards it.
@@ -40,20 +40,21 @@ def project_one(g: Float[Tensor, "r"], V: Float[Tensor, "k r"],
def project_all(wrappers, V_hack, measure_only: bool, route: bool,
gate_mode: str = "one_sided", overshoot: float = 1.0,
preserve_magnitude: bool = False):
"""arms: none -> measure_only | erase -> write g_proj | route -> erase + park in δS_hack."""
"""arms: none -> measure_only | erase -> write g_proj | route -> erase + park in B_hack.
The knob B is (d_out, r); project_one is dim-agnostic so we flatten to ℝ^{d_out·r}."""
cins, couts, fired = [], [], 0
for name, w in wrappers.items():
if name not in V_hack: # below noise floor -> no hack signal -> skip
continue
g_proj, removed, cin, cout = project_one(
w.δS.grad, V_hack[name], gate_mode, preserve_magnitude, overshoot)
g = w.B.grad.reshape(-1) # (d_out·r,) V_hack[name] is (k, d_out·r)
g_proj, removed, cin, cout = project_one(g, V_hack[name], gate_mode, preserve_magnitude, overshoot)
cins.append(cin); couts.append(cout)
fired += int(cin > 1e-6)
if measure_only: # vanilla: measure cin/cout, leave g
continue
w.δS.grad = g_proj # erase: descend the hack-free gradient
w.B.grad = g_proj.reshape_as(w.B.grad) # erase: descend the hack-free gradient
if route:
w.δS_hack.grad = removed # park hack part in deletable knob (Cloud 2024)
w.B_hack.grad = removed.reshape_as(w.B_hack) # park hack part in deletable knob (Cloud 2024)
n = max(len(cins), 1)
mean = lambda xs: torch.stack(xs).mean().item() if xs else 0.0
return mean(cins), mean(couts), fired / n
@@ -65,5 +66,6 @@ def cos_overlap(grads: dict, V_hack) -> float:
for name, g in grads.items():
if name not in V_hack or g is None:
continue
g = g.reshape(-1) # B-grad (d_out, r) -> ℝ^{d_out·r}
cs.append(torch.relu(V_hack[name] @ g).norm() / g.norm().clamp_min(1e-12))
return torch.stack(cs).mean().item() if cs else 0.0
+43 -35
@@ -6,7 +6,7 @@ Unbiased normalization: Dr.GRPO, Liu et al. 2025, arXiv:2503.20783 (drop the
generate -> grade -> backward -> project -> step
Arms (one knob): none -> measure_only | erase -> write g_proj | route -> +park δS_hack.
Arms (one knob): none -> measure_only | erase -> write g_proj | route -> +park B_hack.
Presets: smoke (tiny-random, GPU bf16, the only gate) | fast | full.
python -m projected_grpo.train smoke --intervention=erase ...
@@ -73,7 +73,8 @@ class Config:
n_problems: int = 6
env_mode: Optional[EnvMode] = None # force a single mode (emergence cells)
mix_ratio: float = 0.5 # G_t / G (teacher fraction)
teacher_pool_dir: str = ""
teacher_pool_dir: str = "out/pools/teacher_pool"
lora_rank: int = 32 # r: A is (r×d_in) frozen, B is (d_out×r) trained
# optim
lr: float = 1e-2
adam_beta1: float = 0.9
@@ -82,7 +83,7 @@ class Config:
warmup_frac: float = 0.1
grad_clip: float = 10.0
# GRPO
beta: float = 0.0 # KL to free ref (δS=0); 0 = off
beta: float = 0.0 # KL to free ref (B=0); 0 = off
clip: float = 0.2
# projection
gate_mode: str = "one_sided"
@@ -130,7 +131,7 @@ def load_pool(pool_dir: str) -> dict[str, list[dict]]:
# ── inner GRPO step on one prompt ─────────────────────────────────────────────
def grpo_step(model, tok, wrappers, problem, pool_rows, G_s, G_t, cfg, pad_id, denom):
"""Returns (g_s, g_t, metrics) for this prompt; δS.grad left holding g_s+g_t.
"""Returns (g_s, g_t, metrics) for this prompt; B.grad left holding g_s+g_t.
g_s/g_t are per-module dicts so the outer loop can accumulate + diagnose cin."""
prompt_ids = tok(problem.prompt(), return_tensors="pt").input_ids.to(cfg.device)
P = prompt_ids.shape[1]
@@ -165,7 +166,7 @@ def grpo_step(model, tok, wrappers, problem, pool_rows, G_s, G_t, cfg, pad_id, d
if s_rew: # stash one student gen for the jsonl + the final coherence eyeball
m["gen"] = dict(text=texts[0], hack=s_rew[0].exploited, gt=s_rew[0].gt_correct, reward=s_rew[0].reward)
empty = {n: torch.zeros_like(w.δS) for n, w in wrappers.items()}
empty = {n: torch.zeros_like(w.B) for n, w in wrappers.items()}
if R.max() - R.min() < 1e-6: # zero-variance group => adv≡0 => pure waste (bail)
return empty, dict(empty), m
@@ -193,15 +194,15 @@ def grpo_step(model, tok, wrappers, problem, pool_rows, G_s, G_t, cfg, pad_id, d
is_s = torch.tensor(is_student, dtype=torch.float32, device=cfg.device)[:, None]
# split student/teacher grads. zero+clone BETWEEN passes is load-bearing: after
# L_t.backward, δS.grad already holds g_s+g_t (accumulation), so reset first.
# L_t.backward, B.grad already holds g_s+g_t (accumulation), so reset first.
model.zero_grad()
L_s = (Lp * comp_mask * is_s).sum() / denom
L_s.backward(retain_graph=True)
g_s = {n: w.δS.grad.detach().clone() for n, w in wrappers.items()}
g_s = {n: w.B.grad.detach().clone() for n, w in wrappers.items()}
model.zero_grad()
L_t = (Lp * comp_mask * (1 - is_s)).sum() / denom
L_t.backward()
g_t = {n: w.δS.grad.detach().clone() for n, w in wrappers.items()}
g_t = {n: w.B.grad.detach().clone() for n, w in wrappers.items()}
m["loss"] = (L_s + L_t).item()
return g_s, g_t, m
@@ -226,8 +227,8 @@ def save_ckpt(wrappers, rows, tag):
Path("out/ckpt").mkdir(parents=True, exist_ok=True)
flat = {}
for n, w in wrappers.items():
flat[f"δS/{n}"] = w.δS.detach().cpu()
flat[f"δS_hack/{n}"] = w.δS_hack.detach().cpu()
flat[f"B/{n}"] = w.B.detach().cpu()
flat[f"B_hack/{n}"] = w.B_hack.detach().cpu()
save_file(flat, f"out/ckpt/ckpt{tag}.safetensors")
Path(f"out/ckpt/rows{tag}.json").write_text(json.dumps(rows))
@@ -246,9 +247,11 @@ def main(cfg: Config):
pad_id = tok.pad_token_id if tok.pad_token_id is not None else tok.eos_token_id
model = AutoModelForCausalLM.from_pretrained(cfg.model, dtype=dt).to(cfg.device)
model.eval()
wrappers = wrap(model)
wrappers = wrap(model, cfg.lora_rank)
n_train = sum(w.B.numel() for w in wrappers.values())
logger.info(f"trainable B: {n_train:,} (+{n_train:,} B_hack quarantine), lora_rank={cfg.lora_rank}")
params = [w.δS for w in wrappers.values()] + [w.δS_hack for w in wrappers.values()]
params = [w.B for w in wrappers.values()] + [w.B_hack for w in wrappers.values()]
opt = torch.optim.AdamW(params, lr=cfg.lr, betas=(cfg.adam_beta1, cfg.adam_beta2),
weight_decay=cfg.weight_decay)
sched = get_cosine_schedule_with_warmup(opt, int(cfg.warmup_frac * cfg.steps), cfg.steps)
@@ -258,8 +261,8 @@ def main(cfg: Config):
V_raw, V_sv = resolve_or_extract(model, tok, wrappers, cfg.v_hack_path,
cfg.vhack_pairs_path, cfg.v_hack_top_k, cfg.tau_axis, cfg.n_heldout)
V_hack = postprocess_v_hack(V_raw, V_sv, cfg.v_hack_k, cfg.v_hack_drop_bottom_frac)
δS_ref = next(iter(wrappers.values())).δS # V lives in δS's space: match dtype+device
V_hack = {n: V.to(δS_ref) for n, V in V_hack.items()} # extract/load give fp32[/cpu]; grads are bf16 cuda
B_ref = next(iter(wrappers.values())).B # V lives in B's space: match dtype+device
V_hack = {n: V.to(B_ref) for n, V in V_hack.items()} # extract/load give fp32[/cpu]; grads are bf16 cuda
logger.info(f"V_hack over {len(V_hack)} modules (k_use={cfg.v_hack_k}); "
f"arm={cfg.intervention} -> measure_only={cfg.intervention == 'none'}")
@@ -275,17 +278,21 @@ def main(cfg: Config):
log = logname.open("w")
rows = []
# per-step table: one logger.info row/step (legend once). SHORT runs (20-200 steps),
# so a row/step reads cleanly -- no tqdm bar (it redraws into piped logs).
logger.info(
"SHOULD (per-step JSONL rows + tqdm postfix): erase/route cin_t>0 and >cin_s early "
"(teacher pool carries the hack signal), decaying as V goes stale; cout~0 under one_sided "
"is an arithmetic identity, NOT efficacy; |δSh|>0 iff route, else 0; hack_s stays 0 on "
"tiny-random (mechanics smoke). If cin_t~0 throughout -> V missed the live GRPO gradient.")
pbar = tqdm(range(cfg.steps), desc=f"train[{cfg.intervention}] s{cfg.seed}",
mininterval=120, maxinterval=120)
for step in pbar:
"per-step columns: step | rew=mean reward | gt_s=student solve | hack_s=student exploit (HEADLINE) |"
" hack_t=teacher exploit (pool sanity) | cin_s/cin_t=student/teacher grad overlap with V (relu-agg) |"
" cout=residual hack-ward leak | loss | gn=pre-clip grad L2 | lr | |Bh|=quarantine norm.\n"
"SHOULD: cin_t>cin_s early (teacher carries the hack signal), decaying as V stales; cout~0 under "
"one_sided is an ARITHMETIC IDENTITY not efficacy; |Bh|>0 iff route; hack_s stays 0 on tiny-random.")
hdr = (f"{'step':>4} {'rew':>6} {'gt_s':>5} {'hack_s':>6} {'hack_t':>6} {'cin_s':>6} {'cin_t':>6} "
f"{'cout':>6} {'loss':>7} {'gn':>8} {'lr':>8} {'|Bh|':>8}")
logger.info(hdr)
for step in range(cfg.steps):
opt.zero_grad()
g_s_acc = {n: torch.zeros_like(w.δS) for n, w in wrappers.items()}
g_t_acc = {n: torch.zeros_like(w.δS) for n, w in wrappers.items()}
g_s_acc = {n: torch.zeros_like(w.B) for n, w in wrappers.items()}
g_t_acc = {n: torch.zeros_like(w.B) for n, w in wrappers.items()}
ms = []
for _ in range(cfg.prompts_per_step):
prob = problems[random.randrange(len(problems))]
@@ -295,8 +302,8 @@ def main(cfg: Config):
g_s_acc[n] += g_s[n]; g_t_acc[n] += g_t[n]
ms.append(m)
for n, w in wrappers.items():
w.δS.grad = g_s_acc[n] + g_t_acc[n] # combined live GRPO gradient on the knob
w.δS_hack.grad = None # quarantine moves ONLY via routed `removed` (set in project_all)
w.B.grad = g_s_acc[n] + g_t_acc[n] # combined live GRPO gradient on the knob
w.B_hack.grad = None # quarantine moves ONLY via routed `removed` (set in project_all)
cin_s = cos_overlap(g_s_acc, V_hack) if V_hack else 0.0
cin_t = cos_overlap(g_t_acc, V_hack) if V_hack else 0.0
@@ -309,12 +316,12 @@ def main(cfg: Config):
gn = torch.nn.utils.clip_grad_norm_([p for p in params if p.grad is not None], cfg.grad_clip)
opt.step(); sched.step()
dSh = sum(w.δS_hack.detach().norm().item() for w in wrappers.values())
bh = sum(w.B_hack.detach().norm().item() for w in wrappers.values())
row = dict(step=step, hack_s=_mean([m["hack_s"] for m in ms]), gt_s=_mean([m["gt_s"] for m in ms]),
hack_t=_mean([m["hack_t"] for m in ms]), gt_t=_mean([m["gt_t"] for m in ms]),
reward=_mean([m["reward"] for m in ms]), loss=_mean([m.get("loss", 0.0) for m in ms]),
gn=gn.item(), cin_s=cin_s, cin_t=cin_t, cin=cin, cout=cout, fired=fired,
dSh=dSh)
bh=bh)
if "gen" in ms[0]: # log the first prompt's first student gen (text + flags)
row["gen"] = ms[0]["gen"]
@@ -325,12 +332,12 @@ def main(cfg: Config):
V_raw, V_sv = extract_v_hack(model, tok, wrappers, default_pairs(),
cfg.v_hack_top_k, cfg.tau_axis, cfg.n_heldout)
V_hack = postprocess_v_hack(V_raw, V_sv, cfg.v_hack_k, cfg.v_hack_drop_bottom_frac)
V_hack = {n: V.to(δS_ref) for n, V in V_hack.items()} # keep V in δS's bf16/cuda space
V_hack = {n: V.to(B_ref) for n, V in V_hack.items()} # keep V in B's bf16/cuda space
shared = [n for n in V_hack if n in old]
ov = _mean([abs((V_hack[n][0] @ old[n][0]).item()) for n in shared]) if shared else 0.0
row["basis_overlap"] = ov # GUARD: should sit near 1.0; <~0.2 => refresh rotated off-hack
# deployment-time eval (route only): zero δS_hack, eval the ablated model
# deployment-time eval (route only): zero B_hack, eval the ablated model
if is_route and cfg.eval_ablate_every and step % cfg.eval_ablate_every == 0:
with ablate_quarantine(wrappers):
hd, sd = eval_hack_solve(model, tok, wrappers, problems, cfg.eval_n_prompts,
@@ -339,9 +346,10 @@ def main(cfg: Config):
rows.append(row)
log.write(json.dumps(row) + "\n"); log.flush()
pbar.set_postfix(hack_s=f"{row['hack_s']:.2f}", cin_s=f"{cin_s:.2f}", cin_t=f"{cin_t:.2f}",
cout=f"{cout:.2f}", δSh=f"{dSh:.2f}", loss=f"{row['loss']:.3f}",
refresh=False) # don't force a bar redraw every step (pollutes piped logs)
lr = sched.get_last_lr()[0]
logger.info(f"{step:4d} {row['reward']:+6.2f} {row['gt_s']:5.2f} {row['hack_s']:6.2f} "
f"{row['hack_t']:6.2f} {cin_s:6.3f} {cin_t:6.3f} {cout:6.3f} {row['loss']:+7.3f} "
f"{gn.item():8.1e} {lr:8.1e} {bh:8.3f}")
if step % 25 == 0:
save_ckpt(wrappers, rows, cfg.out_tag or f"_{cfg.intervention}_s{cfg.seed}")
@@ -356,14 +364,14 @@ def _bluf(cfg, rows, wrappers, model, tok, problems, pad_id):
last5 = rows[-5:]
hack = _mean([r["hack_s"] for r in last5]); solve = _mean([r["gt_s"] for r in last5])
cin_t = _mean([r["cin_t"] for r in last5]); cout = _mean([r["cout"] for r in last5])
dSh = rows[-1]["dSh"]
bh = rows[-1]["bh"]
out = f"logs/run_{cfg.intervention}{cfg.out_tag}_s{cfg.seed}.log"
# mechanics cue: projection arms want the teacher hack signal present (cin_t>0); none is neutral
cue = "\U0001F7E2" if (cfg.intervention == "none" or cin_t > 0) else "\U0001F534"
summary = {"arm": cfg.intervention, "steps": len(rows),
"hack↓": round(hack, 3), "solve→": round(solve, 3),
"cin_t↑": round(cin_t, 3), "cout→0": round(cout, 3), "|δSh|": round(dSh, 2)}
"cin_t↑": round(cin_t, 3), "cout→0": round(cout, 3), "|Bh|": round(bh, 2)}
if cfg.intervention == "route":
with ablate_quarantine(wrappers):
hk_abl, sv_abl = eval_hack_solve(model, tok, wrappers, problems, cfg.eval_n_prompts or 2,
@@ -376,7 +384,7 @@ def _bluf(cfg, rows, wrappers, model, tok, problems, pad_id):
caption = (f"Table: last-5-step means, arm={cfg.intervention} seed={cfg.seed}. "
"hack↓ student exploit rate, solve→ gt-correct rate, cin_t↑ teacher-grad "
"overlap with V (relu-before-agg), cout→0 residual hack-ward leak (identity under "
"one_sided, not efficacy), |δSh| quarantine norm (>0 iff route).")
"one_sided, not efficacy), |Bh| quarantine norm (>0 iff route).")
logger.info(f"out: {out}")
# Last student generation -- a coherence eyeball before the numbers. SHOULD: real
# code/prose for the problem. If it is token salad the policy diverged and the eval