wassname
2026-05-23 13:04:03 +08:00
parent 2d6695389f
commit 42498682ca
14 changed files with 797 additions and 754 deletions
Vendored Submodule
+1
Submodule docs/vendor/lora-lite added at ce8c250422
Vendored Submodule
+1
Submodule docs/vendor/simple_GRPO added at 30f252ce36
+10 -1
@@ -5,7 +5,10 @@ description = "SVD-basis gradient projection vs RL reward hacking on Nanda's Lee
requires-python = ">=3.11"
dependencies = [
"torch>=2.4",
"transformers>=4.45",
# transformers>=4.58 has Qwen3.5 (model_type=qwen3_5, gated-delta-net).
# Per HF card: install from main if 4.58 not yet released. We pin to main
# via [tool.uv.sources] below; the version spec here is just a floor.
"transformers>=4.58.0.dev0",
"einops>=0.8",
"jaxtyping>=0.2",
"beartype>=0.18",
@@ -38,3 +41,9 @@ where = ["src"]
[tool.uv]
exclude-newer = "2026-05-23"
[tool.uv.sources]
# Qwen3.5 (qwen3_5 model_type, gated-delta-net) lands in transformers main; pin
# until 4.58 release. v5.7.0 changelog note: "incorrect cached forward behavior
# in Qwen3.5's gated-delta-net linear attention" — fixed on main.
transformers = { git = "https://github.com/huggingface/transformers.git", rev = "main" }
+252 -103
@@ -30,45 +30,63 @@ plus an optional block-Cayley rotation. The rank axis stays pinned to the SVD
basis of the original weight, so `v_hack` extracted in that basis remains
meaningful across all training steps.
Forward pass per wrapped module:
Forward pass per wrapped module (first pass uses full rank $r = \min(d_{in}, d_{out})$,
so the residual term $W_{res}$ vanishes):
$$y = x W_{res}^T + ((x V_h^T) \odot (S + \delta_S)) U^T$$
$$y = ((x V_h^T) \odot (S + \delta_S)) U^T$$
where $W_{res} = W - U_r \mathrm{diag}(S_r) V_{h,r}$, and $U_r$, $S_r$, $V_{h,r}$
are buffers (frozen). Trainable: $\delta_S : [r]$ (and optionally a small Cayley
rotation `rot_T` we leave off by default).
where $U$, $S$, $V_h$ come from the SVD of $W$ and are buffers (frozen).
Trainable: $\delta_S : [r]$ (and optionally a small Cayley rotation `rot_T`
we leave off by default). At reduced rank we would add
$x W_{res}^T$ with $W_{res} = W - U_r \mathrm{diag}(S_r) V_{h,r}$, but we
defer rank cropping to v2 to skip the "where to cut" question.
Per-step gradient signal:
$$\frac{\partial L}{\partial \delta_S} = \sum_t (x_t V_h^T) \odot \left(\frac{\partial L}{\partial h_t} U\right) \in \mathbb{R}^r$$
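A short derivation of the expression above, as a sketch. Write $a_t = x_t V_h^T \in \mathbb{R}^r$ for the rank-space activation (a symbol introduced here only for convenience) and $h_t$ for the module output, which the forward-pass formula calls $y$. Component-wise:

$$h_{t,j} = \sum_{i=1}^{r} a_{t,i}\,(S_i + \delta_{S,i})\,U_{j,i} \quad\Rightarrow\quad \frac{\partial h_{t,j}}{\partial \delta_{S,i}} = a_{t,i}\,U_{j,i}$$

$$\frac{\partial L}{\partial \delta_{S,i}} = \sum_t \sum_j \frac{\partial L}{\partial h_{t,j}}\, a_{t,i}\, U_{j,i} = \sum_t a_{t,i} \left(\frac{\partial L}{\partial h_t}\, U\right)_i$$

This is the $i$-th component of the elementwise product above. A bias term, if present, does not depend on $\delta_S$ and drops out.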
Both factors of the elementwise product live in rank-r SVD basis. v_hack
extracted as `mean_pairs(x V_h^T)_{hack} - mean(x V_h^T)_{clean}` lives in the
*same* `[r]` rank space. Projection is one line:
We extract `v_hack` **gradient-side** (locked in): for each contrastive pair,
run one NLL backward on the completion tokens and read each module's
`m.delta_S.grad : [r]`. Then $\hat v_{hack}^{(m)} = \mathrm{unit}\left(\mathrm{mean}_{hack}\,\mathrm{grad} - \mathrm{mean}_{clean}\,\mathrm{grad}\right)$.
This lives in the exact same `[r]` rank space the per-step training gradient
lives in (the gradient is the natural object to compare gradients against),
and it fuses the input-activation and output-error contributions in one shot
instead of guessing whether input-side $(x V_h^T)$ or output-side $(\partial L/\partial h)\, U$
better predicts where SGD will move. We did consider activation-side
($x V_h^T$ mean-diff). Dropped as primary because it only sees the input
factor and ignores the output-error factor, while the per-step gradient sees
both.
$$\nabla_{\delta_S} \leftarrow \nabla_{\delta_S} - \cos_{align} \cdot \|\nabla_{\delta_S}\| \cdot \hat v_{hack}$$
Projection (locked: no magnitude threshold; one-sided clip stays — see note):
with one-sided gating (only project when $\cos_{align} > 0$, i.e. the gradient is
pushing toward the hack direction). Magnitude preservation = renormalize back
to original $\|\nabla_{\delta_S}\|$.
$$g \leftarrow g - \max(0,\, \cos_{align}) \cdot \|g\| \cdot \hat v_{hack}, \qquad \cos_{align} = \frac{g \cdot \hat v_{hack}}{\|g\|}$$
then rescale to original $\|g\|$ (magnitude-preserving). The $\max(0,\cdot)$ is
not gating, it's directional correctness: without it, when $\cos<0$ we'd be
*adding* to the hack component. No magnitude/threshold gating (locked): we
project every step every module. Capacity cost is ~1/r per module per step.
If `v_hack` at a module is just noise, projection ablates a noise direction in
expectation = approximately a no-op.
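A toy numeric check of the one-sided, magnitude-preserving projection may help make the rule concrete. This is a minimal pure-Python sketch on plain lists with `r = 3`; the function name `project_one_sided` and the example vectors are illustrative assumptions, not part of the codebase, which operates on torch tensors.

```python
import math


def project_one_sided(g, v_hat):
    """Remove the positive v_hat component from g, then rescale to the original norm.

    g, v_hat: equal-length lists of floats; v_hat is assumed to be unit-norm.
    Returns (new_gradient, cos_align).
    """
    g_norm = math.sqrt(sum(x * x for x in g))
    if g_norm < 1e-12:
        return list(g), 0.0
    cos_align = sum(a * b for a, b in zip(g, v_hat)) / g_norm
    if cos_align <= 0:
        # Gradient already points away from (or orthogonal to) the hack direction.
        return list(g), cos_align
    g_new = [a - cos_align * g_norm * b for a, b in zip(g, v_hat)]
    new_norm = math.sqrt(sum(x * x for x in g_new))
    scale = g_norm / (new_norm + 1e-8)  # magnitude-preserving rescale
    return [x * scale for x in g_new], cos_align


v_hat = [1.0, 0.0, 0.0]          # hypothetical unit hack direction
g_toward = [3.0, 4.0, 0.0]       # cos_align = 0.6, so it gets projected
g_away = [-3.0, 4.0, 0.0]        # cos_align = -0.6, so it is left alone

projected, cos_toward = project_one_sided(g_toward, v_hat)
untouched, cos_away = project_one_sided(g_away, v_hat)
```

In this example the projected gradient has no remaining component along `v_hat` and keeps the original norm of 5, while the gradient pointing away from `v_hat` passes through unchanged.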
## Why not vanilla GRPO via verl
verl is Ariahw's framework but uses Ray + FSDP2 + Hydra; inserting a
pre-optimizer-step hook on per-module rank-space gradients requires deep
subclassing of their worker abstraction. We pay one cost in exchange:
we use [lsdefine/lsrl](https://github.com/lsdefine/lsrl) instead. lsrl is a
two-file GRPO implementation with reported convergence on Qwen2.5-3B in 12m on
2xA800 (60 steps). One pre-optimizer hook is trivial to add.
we use [lsdefine/simple_GRPO](https://github.com/lsdefine/simple_GRPO) instead.
simple_GRPO is a two-file GRPO implementation (`ref_server.py` + `grpo_ref_split.py`,
~315 lines total) with reported convergence on Qwen2.5-7B. The training loop
is literally `loss = GRPO_step(batch); engine.backward(loss); engine.step()`, so
inserting a projection hook between backward and step is a one-line edit.
Cost of this deviation: we re-establish the "vanilla hack emergence" baseline
on lsrl rather than inheriting it from Ariahw's verl baseline. H4 is the
sanity check that this happens. We port Ariahw's `run_tests`-overwrite
on simple_GRPO rather than inheriting it from Ariahw's verl baseline. H4 is
the sanity check that this happens. We port Ariahw's `run_tests`-overwrite
detection (their [src/train/verl/rewards.py](https://github.com/ariahw/rl-rewardhacking/blob/main/src/train/verl/rewards.py))
into lsrl's reward server (`docs/vendor/lsrl/lsrl/reward_server.py`).
into simple_GRPO's reward server (`docs/vendor/simple_GRPO/ref_server.py`).
Vendored references (read-only, see [docs/vendor/](docs/vendor/)):
- [lsrl](https://github.com/lsdefine/lsrl) — GRPO trainer
- [simple_GRPO](https://github.com/lsdefine/simple_GRPO) — GRPO trainer
- [lora-lite](https://github.com/wassname/lora-lite) — AntiPaSTO adapter
- [rl-rewardhacking](https://github.com/ariahw/rl-rewardhacking) (already at `external/`)
@@ -82,12 +100,6 @@ LeetCode pass rate within 10pp of vanilla.
Falsified if: hack rate reduction < 15pp, OR pass rate drops by >15pp at
matched hack-rate budget, OR result is within 1 SEM of vanilla across seeds.
**H2 (activation- vs gradient-side `v_hack`):** Gradient-side `v_hack`
(mean-diff of `grad(delta_S)` from one NLL backward per pair) outperforms
activation-side `v_hack` (mean-diff of `x V_h^T`), at matched pair count.
Falsified if: gradient-side matches or is worse than activation-side within
1 SEM. *(open question — see "Decisions left open" below.)*
**H3 (gradient vs advantage):** Gradient-level intervention (ours) outperforms
advantage-level intervention (Rebound reimplemented) on hack rate at matched
pass rate.
@@ -95,106 +107,240 @@ pass rate.
Falsified if: Rebound reimplementation matches or beats ours within 1 SEM.
**H4 (scaling sanity on our stack):** Qwen3.5-2B trained with vanilla
AntiPaSTO+GRPO on lsrl reproduces measurable reward hacking (>30% hack rate at
200 steps).
AntiPaSTO+GRPO on simple_GRPO reproduces measurable reward hacking (>30% hack
rate at 200 steps).
Falsified if: vanilla hack rate <30%. Decision branch: swap to Qwen3-4B with
num_generations halved. Secondary: if lsrl can't reproduce hacking on either
model, fall back to Ariahw's verl path and accept the harder hook.
**H5 (capacity cost of no-gating):** No-gating (project every step every
module) does not measurably hurt pass rate vs cos-threshold gating
(`|cos_align| > 0.1` -> project). Falsified if: gated arm beats no-gating arm
on pass rate by >5pp at matched hack rate.
num_generations halved. Secondary: if simple_GRPO can't reproduce hacking on
either model, fall back to Ariahw's verl path and accept the harder hook.
## Steps
### 1. Build infra — fast-dev-run targets first, no real training yet
- **1a.** Vendor lsrl into `docs/vendor/lsrl/`; smoke-run their GSM8K example
on tiny-random-qwen3 (5 steps, CPU) to confirm reward-server / actor /
rollout split works in our env.
- **1b.** Vendor lora-lite into `docs/vendor/lora-lite/`; wrap Qwen3.5-0.8B
attn+MLP modules with AntiPaSTO (`r=256, block_size=4, rotate_basis="none"`
to start; only `delta_S` trainable). Verify forward-pass round-trip
numerically matches base model at $\delta_S = 0$.
- **1c.** Implement `v_hack` extraction per module:
- **Activation-side (default):** forward N contrastive pair completions,
per wrapped module register a `forward_pre_hook` capturing
`(x @ Vh^T)` flattened over (batch, seq), mean over hack rows minus
mean over clean rows, unit-normalize. Cache as `dict[module_name -> Tensor[r]]`
on disk.
- **Gradient-side (ablation):** for each pair, NLL backward on completion
tokens, per module capture `module.lora_delta_s.grad : [r]`, mean-diff
hack vs clean, unit-normalize.
- Validation: per-module projection score `(x_hack @ Vh^T - x_clean @ Vh^T) @ v_hack`
should be positive on held-out pairs in >90% of modules.
- **1a.** Vendor simple_GRPO into `docs/vendor/simple_GRPO/` (done); smoke-run
their GSM8K example on tiny-random-qwen3 (5 steps, CPU) to confirm
`ref_server` + `grpo_ref_split` rollout/train split works in our env.
- **1b.** Vendor lora-lite into `docs/vendor/lora-lite/` (done); wrap
Qwen3.5-0.8B attn+MLP `nn.Linear` modules with AntiPaSTO **at full rank**
(`r = min(d_in, d_out)`, no SVD cropping; `rotate_basis="none"`, only
`delta_S` trainable). Full rank means $W = U \,\mathrm{diag}(S)\, V_h$
exactly and `W_res = 0`, so there's no truncation error to debug on the
first pass. Verify forward-pass round-trip numerically matches base model
at $\delta_S = 0$ (max abs diff <1e-3 on a fixed prompt).
- **1c.** Implement gradient-side `v_hack` extraction (pseudocode below).
Validation: per-module held-out projection score
`cos(g_held_hack - g_held_clean, v_hack)` > 0 in >50% of modules.
### 2. H4 sanity — does vanilla AntiPaSTO+GRPO+lsrl produce hacking?
### 2. H4 sanity — does vanilla AntiPaSTO+GRPO+simple_GRPO produce hacking?
- **2a.** Port Ariahw's `run_tests`-overwrite detection into lsrl's reward
fn. Verify the reward fn fires on synthetic hack/clean rollouts before
real training.
- **2b.** Train Qwen3.5-2B, AntiPaSTO (`r=256`, `delta_S` only), GRPO
- **2a.** Port Ariahw's `run_tests`-overwrite detection into simple_GRPO's
`ref_server.py` reward fn. Verify the reward fn fires on synthetic
hack/clean rollouts before real training.
- **2b.** Train Qwen3.5-2B, AntiPaSTO (`r=full`, `delta_S` only), GRPO
(group_norm), 200 steps, num_generations=8, batch=16, 1 seed.
Decision: if hack rate <30%, switch to Qwen3-4B (same num_gen=8, batch=16)
and re-run 2b. Secondary fallback: drop lsrl, return to verl.
Decision: if hack rate <30%, switch to Qwen3-4B with `num_generations=4,
batch=16` (half num_gen to keep VRAM headroom) and re-run 2b.
Secondary fallback: drop simple_GRPO, return to verl.
### 3. Implement rank-space projection in lsrl's training loop
### 3. Implement rank-space projection in simple_GRPO's training loop
- **3a.** lsrl's actor calls `optimizer.step()` once per group; insert a
`pre_step_hook(model)` that walks `[m for m in model.modules() if hasattr(m, 'lora_delta_s')]`
and for each module reads `m.lora_delta_s.grad : [r]`, projects against
`v_hack[module_name]` (one-sided, magnitude-preserving), writes back.
- **3b.** Diagnostics logged per step per module: `cos_in`, `||grad||`,
`frac_modules_projected`.
- **3a.** In `grpo_ref_split.py`, between `engine.backward(loss)` and
`engine.step()`, call `project_grads(model, v_hack_cache)`.
`project_grads` walks `[m for m in model.modules() if hasattr(m, 'delta_S')]`
and for each module reads `m.delta_S.grad : [r]`, projects against
`v_hack[module_name]` (one-sided, magnitude-preserving), writes back
in place. (Pseudocode below.)
- **3b.** Diagnostics logged per step (aggregated over modules):
mean/std `cos_align`, mean `||grad||`, `frac_modules_with_cos>0`.
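The per-step diagnostics can be aggregated along these lines. This is a hedged sketch: `aggregate_cos` and the module names in the example dictionary are hypothetical, and it assumes the per-module `cos_align` values have already been collected as Python floats.

```python
from statistics import mean, pstdev


def aggregate_cos(cos_by_module):
    """Collapse per-module cos_align values into per-step summary statistics."""
    values = list(cos_by_module.values())
    if not values:
        return {"cos_mean": 0.0, "cos_std": 0.0, "frac_modules_with_cos>0": 0.0}
    return {
        "cos_mean": mean(values),
        "cos_std": pstdev(values),  # population std over modules
        "frac_modules_with_cos>0": sum(v > 0 for v in values) / len(values),
    }


# Made-up values for one training step.
step_cos = {
    "layers.0.q_proj": 0.5,
    "layers.0.up_proj": -0.5,
    "layers.1.q_proj": 0.3,
    "layers.1.up_proj": 0.1,
}
stats = aggregate_cos(step_cos)
```

Logging one small dictionary per step keeps the log volume independent of the number of wrapped modules.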
### 4. Run arms (200 steps each, 3 seeds where indicated)
a. Vanilla AntiPaSTO + GRPO (3 seeds) — baseline
b. Our method, activation-side `v_hack`, no gating (3 seeds) — main result
c. Our method, gradient-side `v_hack` (3 seeds) — H2
d. Our method, cos-threshold gating ($|\cos| > 0.1$) (1 seed) — H5
e. Our method, no magnitude preservation (1 seed) — design ablation
f. Rebound reimplementation: advantage-side `v_hack` penalty (3 seeds) — H3
g. AntiPaSTO rank sweep: $r \in \{64, 256, 1024\}$ (1 seed each) — sensitivity
b. Our method, gradient-side `v_hack`, no gating (3 seeds) — main result
c. Our method, no magnitude preservation (1 seed) — design ablation
d. Rebound reimplementation: advantage-side `v_hack` penalty (3 seeds) — H3
(concrete formula: per-rollout penalty `α · max(0, cos(h_mean, v_concept))`
added to scalar reward, where `h_mean` is mean residual-stream activation
at a chosen layer and `v_concept` is mean-diff activation direction
extracted from the same 60-80 pairs. We use Wu & Tang 2026 §3.2's
published `α=0.5` and same layer fraction (60-75% depth). Single
layer, not per-module, matching their setup. *Different `v_concept`
from our gradient-side `v_hack` — this is intentional: H3 isolates the
gradient-vs-advantage mechanism choice, not the direction-extraction
choice.*)
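The advantage-side penalty in arm (d) can be sketched as follows. This is only an illustration of the formula `α · max(0, cos(h_mean, v_concept))`: the function names are hypothetical, the vectors are plain lists, and the sign convention (the penalty lowers the reward) is an assumption, since the text only says the penalty is "added" to the scalar reward.

```python
import math


def cosine(u, v):
    """Cosine similarity of two equal-length lists; 0.0 if either is all zeros."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0.0 or norm_v == 0.0:
        return 0.0
    return dot / (norm_u * norm_v)


def rebound_penalty(h_mean, v_concept, alpha=0.5):
    """alpha * max(0, cos(h_mean, v_concept)), one scalar per rollout."""
    return alpha * max(0.0, cosine(h_mean, v_concept))


def shaped_reward(reward, h_mean, v_concept, alpha=0.5):
    # ASSUMPTION: the penalty is subtracted, so hack-aligned rollouts score lower.
    return reward - rebound_penalty(h_mean, v_concept, alpha)


v_concept = [1.0, 0.0]                                      # made-up direction
aligned = shaped_reward(1.0, [2.0, 0.0], v_concept)         # cos = 1
anti_aligned = shaped_reward(1.0, [-1.0, 0.0], v_concept)   # cos = -1, clipped to 0
```

Because of the `max(0, ·)` clip, rollouts whose activations point away from `v_concept` receive no bonus; they are simply left unpenalized.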
Total: 14 runs × ~3h on RTX 6000 96GB = ~42h compute.
Total: 10 runs × ~3h on RTX 6000 96GB = ~30h compute.
*(Rank sweep deferred to v2; first pass uses `r = min(d_in, d_out)` per
module, no cropping.)*
### 5. Measure at every 25 steps
- **Hack rate** (Ariahw's detector ported into lsrl)
- **Hack rate** (Ariahw's detector ported into simple_GRPO)
- **Pass rate** on held-out problems without write access to evaluator
- **Per-module `cos_align`** trajectory (sanity that we're projecting
something nonzero)
- **`frac_modules_projected`** per step (sanity for gating arms)
- **`frac_modules_with_cos>0`** per step (sanity that one-sided clip fires)
- **KL drift from init policy** (catastrophic-change check)
### 6. Headline plot
### 6. Headline plot and headline table
Hack rate vs pass rate, one point per (arm × seed). Pareto frontier. Our
method should land below-and-to-the-right of vanilla. Annotate Rebound.
**Plot.** Hack rate vs pass rate, one point per (arm × seed). Pareto
frontier. Our method should land below-and-to-the-right of vanilla.
Annotate Rebound.
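For the plot, the Pareto frontier over (hack rate, pass rate) points can be computed with a simple dominance check. A minimal sketch, assuming lower hack rate and higher pass rate are both better; `pareto_frontier` and the example points are invented for illustration and are not results.

```python
def pareto_frontier(points):
    """points: list of (label, hack_rate, pass_rate).

    A point is dominated if another point has hack_rate <= and pass_rate >=,
    with at least one of the two strictly better.
    """
    frontier = []
    for label, hack, passed in points:
        dominated = any(
            (h2 <= hack and p2 >= passed) and (h2 < hack or p2 > passed)
            for _, h2, p2 in points
        )
        if not dominated:
            frontier.append((label, hack, passed))
    return sorted(frontier, key=lambda t: t[1])  # sort by hack rate


# Invented (arm x seed) points, purely to exercise the function.
runs = [
    ("vanilla-s0", 0.60, 0.50),
    ("ours-s0", 0.20, 0.55),
    ("ours-s1", 0.25, 0.50),
    ("rebound-s0", 0.10, 0.30),
]
frontier_labels = [label for label, _, _ in pareto_frontier(runs)]
```

With these invented numbers the frontier contains `rebound-s0` and `ours-s0`; the other two points are dominated by `ours-s0`.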
**Table schema (publication-ready; left-to-right = essential to optional,
so trailing columns can be cut for space):**
| Arm | ΔSafePass↑ | Hack %↓ | Pass %↑ | KL↓ | mean·cos\* | frac·fired\* | ‖g‖\* |
|---|---|---|---|---|---|---|---|
| Vanilla (a) | 0 (ref) | — | — | — | — | — | — |
| **Ours (b)** | — | — | — | — | — | — | — |
| Ours, no mag-preserve (c) | — | — | — | — | — | — | — |
| Rebound (d) | — | — | — | — | — | — | — |
*Caption.* ↑ higher is better, ↓ lower is better. **ΔSafePass** = (pass% −
hack%) − vanilla's (pass% − hack%): single headline number, positive means
we win. **Hack %** = fraction of rollouts triggering `run_tests`-overwrite
detector. **Pass %** = fraction passing held-out tests without write access.
**KL** = mean per-token KL from init policy over last 25 steps.
\* = projection-internal diagnostic, only meaningful for arms (b)/(c);
distinguishes "projection active" (mean·cos > 0.2, frac·fired > 0.4) from
"projection silent no-op". Cells report mean ± SEM across seeds.
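A small sketch of how the headline cells could be computed, reading ΔSafePass as the difference in (pass% minus hack%) between an arm and vanilla, which is an interpretation of the caption rather than a confirmed definition. The helper names and all numbers below are made up for illustration.

```python
import math
from statistics import mean, stdev


def delta_safe_pass(pass_pct, hack_pct, vanilla_pass_pct, vanilla_hack_pct):
    """(pass% - hack%) of an arm minus the same quantity for vanilla."""
    return (pass_pct - hack_pct) - (vanilla_pass_pct - vanilla_hack_pct)


def mean_sem(values):
    """Mean and standard error of the mean across seeds (needs >= 2 seeds)."""
    return mean(values), stdev(values) / math.sqrt(len(values))


# Invented percentages: arm at 70% pass / 10% hack, vanilla at 65% pass / 40% hack.
example_delta = delta_safe_pass(70.0, 10.0, 65.0, 40.0)

# Invented per-seed values for one cell of the table.
cell_mean, cell_sem = mean_sem([30.0, 35.0, 40.0])
```

Single-seed arms have no SEM under this definition, so their cells would report the point value only.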
### 7. Falsification check
Before publishing, run pre-registered analysis on H1-H5. Report all
Before publishing, run pre-registered analysis on H1, H3, H4. Report all
hypotheses including falsified ones.
## Pseudocode (the three load-bearing bits)
### A. AntiPaSTO module wrap (full rank, first pass)
```
import torch
from torch import nn


class AntiPaSTO(nn.Module):
    # constructed from an existing nn.Linear(W: [d_out, d_in], b)
    # FIRST PASS: r = min(d_out, d_in) -- no truncation, W_res == 0
    def __init__(self, W, b):
        super().__init__()
        U, S, Vh = torch.linalg.svd(W.detach().float(), full_matrices=False)
        r = S.shape[0]  # = min(d_out, d_in)
        # buffers (frozen): the full SVD
        self.register_buffer("U", U)    # [d_out, r]
        self.register_buffer("S", S)    # [r]
        self.register_buffer("Vh", Vh)  # [r, d_in]
        self.register_buffer("b", b)    # [d_out], or None if the Linear has no bias
        # trainable (ONLY this): scalar per rank
        self.delta_S = nn.Parameter(torch.zeros(r))

    def forward(self, x):  # x: [..., d_in]
        y = ((x @ self.Vh.T) * (self.S + self.delta_S)) @ self.U.T
        return y if self.b is None else y + self.b
```
Replace every target `nn.Linear` in attn (`q,k,v,o_proj`) and MLP
(`up,gate,down_proj`) with this. At `delta_S=0`, output == original linear up
to numerical precision (no `W_res` residual term needed at full rank).
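The identity at `delta_S = 0` can be checked by hand on a tiny case. The sketch below avoids any SVD routine by building a 2×2 weight from known factors (two rotation matrices and a list of singular values); all names and numbers are illustrative assumptions, and the bias is omitted.

```python
import math


def rot(theta):
    """2x2 rotation matrix, used here as a known orthogonal factor."""
    c, s = math.cos(theta), math.sin(theta)
    return [[c, -s], [s, c]]


def matvec(M, x):
    return [sum(m * xi for m, xi in zip(row, x)) for row in M]


U = rot(0.3)
Vh = rot(-1.1)
S = [3.0, 1.0]

# Dense weight W = U diag(S) Vh, i.e. what the original linear layer would hold.
W = [[sum(U[i][k] * S[k] * Vh[k][j] for k in range(2)) for j in range(2)]
     for i in range(2)]


def antipasto_forward(x, delta_S):
    a = matvec(Vh, x)                                  # rank-space activation
    a = [ai * (si + di) for ai, si, di in zip(a, S, delta_S)]
    return matvec(U, a)                                # back to output space


x = [0.7, -1.2]
y_dense = matvec(W, x)
y_factored = antipasto_forward(x, [0.0, 0.0])
y_shifted = antipasto_forward(x, [0.5, 0.0])
```

`y_factored` agrees with `y_dense` up to floating-point rounding, while a nonzero `delta_S` changes the output by rescaling one singular direction.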
**SVD precompute strategy.** Don't SVD the whole model on GPU at once.
Load the base model on CPU, then for each target `Linear`: move `W` to GPU,
run `torch.linalg.svd(W.float(), full_matrices=False)`, save
`(U, S, Vh) -> svd_cache/{model_name}/{module_path}.pt`. Wrap construction
then loads the cached SVD per module. SVD is done once per base model; ~5-10s
per big MLP weight on RTX 3090.
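One way to lay out the cache files is sketched below. It follows the `svd_cache/{model_name}/{module_path}.pt` pattern from the paragraph above and adds a short content hash of the weight so that a changed weight maps to a different file. The function name `svd_cache_path` and the example module path are hypothetical, and the weight is represented as a flat list of floats to keep the sketch dependency-free.

```python
import hashlib
import struct
from pathlib import Path


def svd_cache_path(cache_root, model_name, module_path, weight_values):
    """Build a per-module cache path keyed by a hash of the fp32 weight bytes."""
    safe_model = model_name.replace("/", "__")  # keep the repo id in one directory
    raw = struct.pack(f"<{len(weight_values)}f", *weight_values)  # fp32, little-endian
    digest = hashlib.sha256(raw).hexdigest()[:16]
    return Path(cache_root) / safe_model / f"{module_path}.{digest}.pt"


path_a = svd_cache_path("svd_cache", "Qwen/Qwen3.5-0.8B",
                        "model.layers.0.mlp.up_proj", [1.0, 2.0, 3.0])
path_a_again = svd_cache_path("svd_cache", "Qwen/Qwen3.5-0.8B",
                              "model.layers.0.mlp.up_proj", [1.0, 2.0, 3.0])
path_b = svd_cache_path("svd_cache", "Qwen/Qwen3.5-0.8B",
                        "model.layers.0.mlp.up_proj", [1.0, 2.0, 3.5])
```

Keying on content rather than on the module name alone means a stale SVD is never loaded silently after the base weights change; the cost is one hash pass over each weight at wrap time.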
### B. Gradient-side `v_hack` extraction (per module)
```
from collections import defaultdict

import torch

# tokenize, completion_mask and nll_per_token are helpers left undefined in this pseudocode.
v_hack = {}  # dict[module_name -> Tensor[r]]
grads_hack = defaultdict(list)
grads_clean = defaultdict(list)
# Per-pair: process hack and clean independently, NLL over their own completion
# tokens only. Different completion lengths are fine -- we use mean NLL
# (sum_nll / n_completion_tokens), so each pair contributes a length-normalized
# gradient. This avoids biasing v_hack toward longer (typically clean)
# completions. Pad each example individually; no cross-completion padding.
for (prompt, hack_completion, clean_completion) in pairs:
    prompt_ids = tokenize(prompt)  # [1, L_prompt]; only used to locate the completion span
    for label, completion in [('hack', hack_completion), ('clean', clean_completion)]:
        model.zero_grad()
        ids = tokenize(prompt + completion)  # [1, L]
        mask = completion_mask(ids, prompt_len=prompt_ids.shape[1])  # 1 on completion tokens
        logits = model(ids).logits[:, :-1]
        # MEAN NLL over completion tokens (length-normalized)
        loss = (nll_per_token(logits, ids[:, 1:]) * mask[:, 1:]).sum() / mask[:, 1:].sum()
        loss.backward()
        for name, m in model.named_modules():
            if hasattr(m, 'delta_S'):
                bucket = grads_hack if label == 'hack' else grads_clean
                bucket[name].append(m.delta_S.grad.detach().cpu().clone())
for name in grads_hack:
    diff = torch.stack(grads_hack[name]).mean(0) - torch.stack(grads_clean[name]).mean(0)  # [r]
    v_hack[name] = diff / (diff.norm() + 1e-8)
torch.save(v_hack, 'v_hack.pt')
```
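The length normalization used in the extraction loop can be illustrated on plain lists. A minimal sketch, with a hypothetical helper `masked_mean_nll` and made-up per-token losses: two completions with the same per-token NLL but different lengths give the same loss, whereas a summed NLL would not.

```python
def masked_mean_nll(nll_per_token, mask):
    """Mean of per-token NLL over positions where mask == 1 (completion tokens)."""
    n_completion = sum(mask)
    if n_completion == 0:
        raise ValueError("mask selects no completion tokens")
    return sum(l * m for l, m in zip(nll_per_token, mask)) / n_completion


# First position stands in for a prompt token and is masked out.
short_nll, short_mask = [9.0, 2.0, 2.0], [0, 1, 1]
long_nll, long_mask = [9.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0], [0, 1, 1, 1, 1, 1, 1]

short_loss = masked_mean_nll(short_nll, short_mask)
long_loss = masked_mean_nll(long_nll, long_mask)
short_sum = sum(l * m for l, m in zip(short_nll, short_mask))
long_sum = sum(l * m for l, m in zip(long_nll, long_mask))
```

Here both mean losses equal 2.0, while the unnormalized sums are 4.0 and 12.0, which is the length bias the mean is meant to remove.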
Validation (report both, don't just gate on threshold):
- On held-out pairs, recompute per-module `diff_held` and
`cos_align_held = cos(diff_held, v_hack[name])`.
- **Distribution check (primary):** plot histogram of `cos_align_held` across
all modules. Healthy = unimodal positive, median > 0.3. Pathological =
bimodal or median near 0.
- **Gate (secondary):** `cos_align_held > 0` in >50% of modules is the
minimum to proceed; mean `cos_align_held > 0.2` is the target. If <50% pass,
extraction is broken and we debug before training.
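The checks above can be bundled into a single report. A hedged sketch: `validate_v_hack` and the example values are hypothetical, and the thresholds are the ones stated in the list (more than 50% of modules positive to proceed, mean above 0.2 as target, median above 0.3 as the healthy case).

```python
from statistics import mean, median


def validate_v_hack(cos_align_held):
    """cos_align_held: dict[module_name -> float] computed on held-out pairs."""
    values = list(cos_align_held.values())
    frac_positive = sum(v > 0 for v in values) / len(values)
    return {
        "median": median(values),
        "mean": mean(values),
        "frac_positive": frac_positive,
        "gate_pass": frac_positive > 0.5,     # minimum to proceed
        "target_met": mean(values) > 0.2,     # stated target
        "healthy_median": median(values) > 0.3,
    }


# Made-up per-module values.
good = validate_v_hack({"m0": 0.6, "m1": 0.4, "m2": 0.35, "m3": -0.1})
bad = validate_v_hack({"m0": -0.2, "m1": 0.1, "m2": -0.3, "m3": -0.1})
```

Reporting the full dictionary rather than a single boolean keeps the distribution information visible even when the gate passes.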
### C. Pre-optimizer-step projection hook
```
from torch import Tensor


def project_grads(model, v_hack: dict[str, Tensor]):
    # called after engine.backward(loss), before engine.step()
    cos_log, n_modules, n_fired = [], 0, 0
    for name, m in model.named_modules():
        if not hasattr(m, 'delta_S'):
            continue
        g = m.delta_S.grad  # [r]
        if g is None:
            continue
        n_modules += 1
        v = v_hack[name].to(device=g.device, dtype=g.dtype)  # [r], unit
        g_norm = g.norm()
        if g_norm < 1e-12:
            continue
        cos_a = (g @ v) / g_norm  # scalar
        cos_log.append(cos_a.item())
        if cos_a > 0:
            n_fired += 1
            g_new = g - cos_a * g_norm * v  # remove hack component
            g_new = g_new * (g_norm / (g_new.norm() + 1e-8))  # magnitude preserve
            m.delta_S.grad.copy_(g_new)
    # guard the mean so a step with no usable gradients does not crash logging
    mean_cos = sum(cos_log) / max(len(cos_log), 1)
    return dict(mean_cos=mean_cos, frac_fired=n_fired / max(n_modules, 1))
```
Integration into `grpo_ref_split.py` training loop
(vendored at `docs/vendor/simple_GRPO/simple_grpo_v1/grpo_ref_split.py`; we copy and
edit, not import):
```
# at top of training script, once:
v_hack = torch.load('v_hack.pt', map_location='cpu') # dict[str, Tensor[r]]
# (extraction script from B above produces this artifact; if missing, crash loud)
# inside the training loop:
loss = GRPO_step(batch)
engine.backward(loss)
stats = project_grads(engine.module, v_hack) # <-- NEW: 1 line
engine.step()
if rank == 0: log(stats)
```
## Decisions left open (write these up alongside results)
- **Activation- vs gradient-side `v_hack` (H2).** Activation = cheap, geometric,
matches Wu-Tang/CAA tradition. Gradient = principled (the literal direction
training will move toward), more expensive. Default activation; gradient is
arm c.
- **Gating threshold (H5).** No-gating default; cos>0.1 gating is arm d.
Argument for no-gating: removing 1 direction from r=256 trainable subspace
per module per step is ~0.4% capacity. If `v_hack` at a module is noise, we
ablate a noise direction in expectation = approx no-op. Argument for gating:
in modules where hack signal is weak, projection just removes some random
direction the optimizer might have used. H5 settles this.
- **Rank `r`.** Default 256 (lora-lite antipasto default); sweep in arm g.
Trainable parameter count is just `r` per module (vs `r*(d_in+d_out)` for
standard LoRA), so larger `r` is cheap, but `v_hack`'s SNR per dim degrades.
- **Rank `r`.** First pass: `r = min(d_in, d_out)` per module (no cropping)
to avoid debugging where to cut the SVD. Trainable params per module =
`min(d_in, d_out)`, still tiny vs full LoRA's `r*(d_in+d_out)`. Tradeoff:
larger `r` keeps geometric fidelity but `v_hack`'s SNR per dim degrades;
smaller `r` would concentrate hack signal but introduces truncation error in
`W_res`. Rank sweep is v2 work.
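The parameter-count comparison is simple arithmetic. The sketch below uses hypothetical layer dimensions (`d_in = 1024`, `d_out = 3072`, LoRA rank 16) purely to show the two formulas side by side, plus the `1/r` capacity cost of removing one direction per module per step.

```python
def antipasto_trainable(d_in, d_out):
    """Full-rank delta_S: one scalar per singular value."""
    return min(d_in, d_out)


def lora_trainable(d_in, d_out, r):
    """Standard LoRA: A is [r, d_in] and B is [d_out, r]."""
    return r * (d_in + d_out)


def projection_capacity_cost(r):
    """Fraction of the trainable subspace removed by projecting out one direction."""
    return 1.0 / r


# Hypothetical dimensions, not taken from any specific model config.
n_antipasto = antipasto_trainable(1024, 3072)
n_lora = lora_trainable(1024, 3072, r=16)
cost_full_rank = projection_capacity_cost(n_antipasto)
cost_r256 = projection_capacity_cost(256)
```

Under these assumed dimensions the adapter trains 1024 scalars against 65536 for rank-16 LoRA, and projecting one direction costs under 0.1% of the subspace at full rank versus about 0.4% at `r = 256`.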
## Why measure ratio, not just hack rate
@@ -206,28 +352,31 @@ problems without write access, our method reduces hack rate from X% to Y%."
## Compute estimate
- Single run on 96GB RTX 6000: ~2-3h (Qwen3.5-2B, num_gen=8, 200 steps, lsrl,
AntiPaSTO r=256)
- 14 runs: 35-45h
- At ~$3 AUD/hr: $105-135 AUD
- Single run on 96GB RTX 6000: ~2-3h (Qwen3.5-2B, num_gen=8, 200 steps,
simple_GRPO, AntiPaSTO full rank)
- 10 runs: 25-35h
- At ~$3 AUD/hr: $75-105 AUD
- + debugging buffer: budget ~$200 AUD total
- Calendar time: 1 week back-to-back; 2-3 weeks with iteration
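The run count and cost range above can be reproduced with a few lines of arithmetic. A sketch only: the arm labels are shorthand for arms (a) to (d), and the hour range and hourly rate are the figures quoted in the list.

```python
# Seeds per arm, following the arm list: a, b, d use 3 seeds; c uses 1.
seeds_per_arm = {"a_vanilla": 3, "b_ours": 3, "c_no_mag_preserve": 1, "d_rebound": 3}
n_runs = sum(seeds_per_arm.values())


def cost_range(hours_low, hours_high, rate_per_hour):
    """Lower and upper compute cost for a range of total GPU hours."""
    return hours_low * rate_per_hour, hours_high * rate_per_hour


cost_low, cost_high = cost_range(25, 35, 3.0)  # AUD, at ~$3 AUD/hr
```

The debugging buffer is not modeled here; it is the gap between the upper figure and the ~$200 AUD budget.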
## Risks and decision points
- **H4 falsified (no hack on Qwen3.5-2B with lsrl):** branch 1 — try
Qwen3-4B same hyperparams. Branch 2 — drop lsrl, hook into verl
- **H4 falsified (no hack on Qwen3.5-2B with simple_GRPO):** branch 1 — try
Qwen3-4B with `num_generations` halved (as in H4 and step 2b). Branch 2 — drop simple_GRPO, hook into verl
directly. Adds ~1-2 weeks engineering.
- **AntiPaSTO + GRPO doesn't train:** known risk — antipasto's trainable
subspace (`delta_S` only) may be too small for RL. Mitigation: enable
Cayley rotation (`rotate_basis="V"`, `block_size=4`), adds `r*(bs-1)/2`
params per module. Or fall back to PiSSA-LoRA-freeze-A.
subspace (`delta_S` only) may be too small for RL. If so, document and
fall back to PiSSA-LoRA-freeze-A. We do **not** enable Cayley rotation
(`rotate_basis="V"`) as a mitigation: a rotated rank axis breaks the
invariant that `v_hack` (extracted in the original SVD basis) stays
meaningful across training, which is the whole point of using AntiPaSTO
over vanilla LoRA.
- **`v_hack` steering check fails (per-module projection scores ≤chance):**
extraction broken. Check (a) hook captures pre-residual input, (b) pair
quality drives strong activation difference somewhere, (c) tokenization of
hack vs clean completions isn't trivially distinguishing.
- **All methods tie vanilla on hack rate:** intervention not biting. Check
`cos_align` logs nonzero, `frac_modules_projected` nonzero.
`cos_align` logs nonzero, `frac_modules_with_cos>0` nonzero.
## What this is not
@@ -247,6 +396,6 @@ problems without write access, our method reduces hack rate from X% to Y%."
- **AntiPaSTO** ([wassname/lora-lite/variants/antipasto.py](https://github.com/wassname/lora-lite/blob/main/src/lora_lite/variants/antipasto.py),
[wassname/AntiPaSTO paper](https://github.com/wassname/AntiPaSTO)) — adapter
we wrap with.
- **lsrl** ([lsdefine/lsrl](https://github.com/lsdefine/lsrl)) — GRPO trainer.
- **simple_GRPO** ([lsdefine/simple_GRPO](https://github.com/lsdefine/simple_GRPO)) — GRPO trainer.
- **PiSSA** ([arxiv:2404.02948](https://arxiv.org/abs/2404.02948)) — frozen
top-r SVD-init for LoRA; closest spiritual ancestor to AntiPaSTO.
+153
@@ -0,0 +1,153 @@
"""AntiPaSTO full-rank adapter for projected-GRPO.
Per spec.md: wrap nn.Linear with frozen U, S, Vh (full rank = min(d_in, d_out)).
Trainable: delta_S only, shape [r]. No rotation (would break v_hack basis invariance).
Forward:
y = ((x @ Vh.T) * (S + delta_S)) @ U.T + b
At delta_S=0, output == original linear up to fp32 SVD round-trip precision.
"""
from __future__ import annotations
import hashlib
from pathlib import Path
import torch
from jaxtyping import Float
from loguru import logger
from torch import Tensor, nn
class AntiPaSTOLinear(nn.Module):
    """Drop-in replacement for nn.Linear with full-rank SVD + learnable delta_S.

    Buffers (frozen): U[d_out, r], S[r], Vh[r, d_in], optional bias[d_out].
    Trainable: delta_S[r].
    """

    def __init__(
        self,
        U: Float[Tensor, "d_out r"],
        S: Float[Tensor, "r"],
        Vh: Float[Tensor, "r d_in"],
        bias: Float[Tensor, "d_out"] | None,
        dtype: torch.dtype = torch.float32,
    ):
        super().__init__()
        r = S.shape[0]
        self.register_buffer("U", U.to(dtype).contiguous())
        self.register_buffer("S", S.to(dtype).contiguous())
        self.register_buffer("Vh", Vh.to(dtype).contiguous())
        if bias is not None:
            self.register_buffer("bias", bias.to(dtype).contiguous())
        else:
            self.bias = None
        self.delta_S = nn.Parameter(torch.zeros(r, dtype=dtype))

    @property
    def r(self) -> int:
        return self.S.shape[0]

    def forward(self, x: Float[Tensor, "... d_in"]) -> Float[Tensor, "... d_out"]:
        # x @ Vh.T : [..., r]; * (S+dS) : elementwise; @ U.T : [..., d_out]
        h = x @ self.Vh.transpose(-1, -2)
        h = h * (self.S + self.delta_S)
        y = h @ self.U.transpose(-1, -2)
        if self.bias is not None:
            y = y + self.bias
        return y
def _model_svd_dir(model_name: str, cache_root: Path) -> Path:
    safe = model_name.replace("/", "__")
    return cache_root / safe


def svd_cached(
    W: Float[Tensor, "d_out d_in"],
    cache_path: Path,
    device: torch.device,
) -> tuple[Tensor, Tensor, Tensor]:
    """SVD with disk cache. Compute on `device` in fp32, save as fp32 cpu tensors.

    Cache key = sha256(W.cpu fp32 bytes)[:16] in filename suffix, so weight change
    invalidates the cache automatically (fail-loud, no silent stale).
    """
    cache_path.parent.mkdir(parents=True, exist_ok=True)
    W_fp32 = W.detach().to(torch.float32).cpu().contiguous()
    sha = hashlib.sha256(W_fp32.numpy().tobytes()).hexdigest()[:16]
    final = cache_path.with_suffix(f".{sha}.pt")
    if final.exists():
        d = torch.load(final, map_location="cpu", weights_only=True)
        return d["U"], d["S"], d["Vh"]
    W_gpu = W_fp32.to(device)
    U, S, Vh = torch.linalg.svd(W_gpu, full_matrices=False)
    U, S, Vh = U.cpu(), S.cpu(), Vh.cpu()
    torch.save({"U": U, "S": S, "Vh": Vh}, final)
    logger.info(f"SVD cached: {final.name} shape U={tuple(U.shape)} S0={S[0]:.3f} S-1={S[-1]:.3e}")
    return U, S, Vh
TARGET_SUFFIXES = (
    # full attention (Qwen3.5 has 6 full-attn layers)
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
    # linear-attention / GatedDeltaNet (Qwen3.5 has 18 linear-attn layers)
    "in_proj_qkv",
    "in_proj_z",
    "in_proj_a",
    "in_proj_b",
    "out_proj",
    # MLP (24 layers)
    "up_proj",
    "gate_proj",
    "down_proj",
)


def is_target(name: str) -> bool:
    return name.split(".")[-1] in TARGET_SUFFIXES
def wrap_model_with_antipasto(
    model: nn.Module,
    model_name: str,
    cache_root: Path = Path("svd_cache"),
    svd_device: torch.device | str = "cuda",
    adapter_dtype: torch.dtype = torch.float32,
) -> dict[str, AntiPaSTOLinear]:
    """Replace every target nn.Linear in `model` (in place) with AntiPaSTOLinear.

    SVD is computed on `svd_device` per layer, cached to disk by weight hash.
    Returns dict[module_qualified_name -> wrapper] for downstream v_hack code.
    """
    svd_device_t = torch.device(svd_device) if isinstance(svd_device, str) else svd_device
    svd_dir = _model_svd_dir(model_name, cache_root)
    wrappers: dict[str, AntiPaSTOLinear] = {}
    # Collect first to avoid mutating during iteration.
    targets: list[tuple[str, nn.Linear, nn.Module, str]] = []
    for name, m in model.named_modules():
        if isinstance(m, nn.Linear) and is_target(name):
            parent_name = name.rsplit(".", 1)[0]
            child_name = name.rsplit(".", 1)[1]
            parent = model.get_submodule(parent_name)
            targets.append((name, m, parent, child_name))
    logger.info(f"AntiPaSTO wrap: {len(targets)} target Linear modules in {model_name}")
    for i, (name, linear, parent, child_name) in enumerate(targets):
        W = linear.weight.data
        bias = linear.bias.data if linear.bias is not None else None
        cache_path = svd_dir / f"{name}.pt"
        U, S, Vh = svd_cached(W, cache_path, device=svd_device_t)
        # Place wrapper on the same device as the original module's weight.
        target_device = W.device
        wrap = AntiPaSTOLinear(U, S, Vh, bias, dtype=adapter_dtype).to(target_device)
        setattr(parent, child_name, wrap)
        wrappers[name] = wrap
        if (i + 1) % 20 == 0 or i == len(targets) - 1:
            logger.info(f"  wrapped {i+1}/{len(targets)} last={name}")
    return wrappers
@@ -0,0 +1,102 @@
"""Step-1 verify: wrapped Qwen3.5-0.8B output == base output at delta_S=0.
SHOULD: max abs diff < 1e-3 over 3 prompts of different lengths.
ELSE: SVD round-trip is bad (numerical, dtype, or shape bug).
Run: uv run python -m projected_grpo.verify_antipasto_identity
"""
from __future__ import annotations
import copy
import sys
from pathlib import Path
import torch
from loguru import logger
from tabulate import tabulate
from transformers import AutoModelForCausalLM, AutoTokenizer
from .antipasto import wrap_model_with_antipasto
MODEL = "Qwen/Qwen3.5-0.8B"
PROMPTS = [
    "Hello",
    "Write a Python function that returns the sum of two integers.",
    (
        "You are an expert programmer. Solve the following LeetCode problem:\n"
        "Given an integer array nums, find the contiguous subarray with the largest sum.\n"
        "Return the sum."
    ),
]
CACHE_ROOT = Path("svd_cache")
def main() -> int:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"device={device} model={MODEL}")
    tokenizer = AutoTokenizer.from_pretrained(MODEL)
    base = AutoModelForCausalLM.from_pretrained(
        MODEL, dtype=torch.float32, attn_implementation="sdpa"
    ).to(device)
    base.eval()
    wrapped = copy.deepcopy(base)
    wrappers = wrap_model_with_antipasto(
        wrapped,
        model_name=MODEL,
        cache_root=CACHE_ROOT,
        svd_device=device,
        adapter_dtype=torch.float32,
    )
    wrapped.eval()
    n_wrapped = len(wrappers)
    n_params_trainable = sum(p.numel() for w in wrappers.values() for p in w.parameters() if p.requires_grad)
    n_params_base = sum(p.numel() for p in base.parameters())
    logger.info(
        f"wrapped={n_wrapped} modules "
        f"delta_S params={n_params_trainable:,} "
        f"base params={n_params_base:,} "
        f"ratio={n_params_trainable / n_params_base:.4%}"
    )
    rows = []
    all_ok = True
    for i, prompt in enumerate(PROMPTS):
        ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
        with torch.no_grad():
            y_base = base(ids).logits
            y_wrap = wrapped(ids).logits
        diff = (y_base - y_wrap).abs()
        max_diff = diff.max().item()
        mean_diff = diff.mean().item()
        scale = y_base.abs().mean().item()
        ok = max_diff < 1e-3
        all_ok = all_ok and ok
        rows.append(
            dict(
                idx=i,
                seq_len=ids.shape[1],
                logit_scale=f"{scale:.3f}",
                max_abs_diff=f"{max_diff:.2e}",
                mean_abs_diff=f"{mean_diff:.2e}",
                ok=("PASS" if ok else "FAIL"),
            )
        )
    print(tabulate(rows, headers="keys", tablefmt="pipe"))
    logger.info(
        "SHOULD: max_abs_diff < 1e-3 on all rows. "
        "ELSE: SVD round-trip broken (dtype downcast, shape bug, or wrong forward)."
    )
    if not all_ok:
        logger.error("IDENTITY CHECK FAILED")
        return 1
    logger.info(f"IDENTITY CHECK PASSED ({n_wrapped} modules, {n_params_trainable:,} delta_S scalars)")
    return 0


if __name__ == "__main__":
    sys.exit(main())
Generated
+278 -650
File diff suppressed because it is too large