mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 17:15:58 +08:00
spec
This commit is contained in:
+1
Submodule docs/vendor/lora-lite added at ce8c250422
+1
Submodule docs/vendor/simple_GRPO added at 30f252ce36
+10
-1
@@ -5,7 +5,10 @@ description = "SVD-basis gradient projection vs RL reward hacking on Nanda's Lee
|
||||
requires-python = ">=3.11"
|
||||
dependencies = [
|
||||
"torch>=2.4",
|
||||
"transformers>=4.45",
|
||||
# transformers>=4.58 has Qwen3.5 (model_type=qwen3_5, gated-delta-net).
|
||||
# Per HF card: install from main if 4.58 not yet released. We pin to main
|
||||
# via [tool.uv.sources] below; the version spec here is just a floor.
|
||||
"transformers>=4.58.0.dev0",
|
||||
"einops>=0.8",
|
||||
"jaxtyping>=0.2",
|
||||
"beartype>=0.18",
|
||||
@@ -38,3 +41,9 @@ where = ["src"]
|
||||
|
||||
[tool.uv]
|
||||
exclude-newer = "2026-05-23"
|
||||
|
||||
[tool.uv.sources]
|
||||
# Qwen3.5 (qwen3_5 model_type, gated-delta-net) lands in transformers main; pin
|
||||
# until 4.58 release. v5.7.0 changelog note: "incorrect cached forward behavior
|
||||
# in Qwen3.5's gated-delta-net linear attention" — fixed on main.
|
||||
transformers = { git = "https://github.com/huggingface/transformers.git", rev = "main" }
|
||||
|
||||
@@ -30,45 +30,63 @@ plus an optional block-Cayley rotation. The rank axis stays pinned to the SVD
|
||||
basis of the original weight, so `v_hack` extracted in that basis remains
|
||||
meaningful across all training steps.
|
||||
|
||||
Forward pass per wrapped module:
|
||||
Forward pass per wrapped module (first pass uses full rank $r = \min(d_{in}, d_{out})$,
|
||||
so the residual term $W_{res}$ vanishes):
|
||||
|
||||
$$y = x W_{res}^T + ((x V_h^T) \odot (S + \delta_S)) U^T$$
|
||||
$$y = ((x V_h^T) \odot (S + \delta_S)) U^T$$
|
||||
|
||||
where $W_{res} = W - U_r \mathrm{diag}(S_r) V_{h,r}$, and $U_r$, $S_r$, $V_{h,r}$
|
||||
are buffers (frozen). Trainable: $\delta_S : [r]$ (and optionally a small Cayley
|
||||
rotation `rot_T` we leave off by default).
|
||||
where $U$, $S$, $V_h$ come from the SVD of $W$ and are buffers (frozen).
|
||||
Trainable: $\delta_S : [r]$ (and optionally a small Cayley rotation `rot_T`
|
||||
we leave off by default). At reduced rank we would add
|
||||
$x W_{res}^T$ with $W_{res} = W - U_r \mathrm{diag}(S_r) V_{h,r}$, but we
|
||||
defer rank cropping to v2 to skip the "where to cut" question.
|
||||
|
||||
Per-step gradient signal:
|
||||
|
||||
$$\frac{\partial L}{\partial \delta_S} = \sum_t (x_t V_h^T) \odot \left(\frac{\partial L}{\partial h_t} U\right) \in \mathbb{R}^r$$
|
||||
|
||||
Both factors of the elementwise product live in rank-r SVD basis. v_hack
|
||||
extracted as `mean_pairs(x V_h^T)_{hack} - mean(x V_h^T)_{clean}` lives in the
|
||||
*same* `[r]` rank space. Projection is one line:
|
||||
We extract `v_hack` **gradient-side** (locked in): for each contrastive pair,
|
||||
run one NLL backward on the completion tokens and read each module's
|
||||
`m.delta_S.grad : [r]`. Then $\hat v_{hack}^{(m)} = \mathrm{unit}\big(\mathrm{mean}_{hack}\,\nabla_{\delta_S} - \mathrm{mean}_{clean}\,\nabla_{\delta_S}\big)$, where $\mathrm{unit}(z) = z / \|z\|$.
|
||||
This lives in the exact same `[r]` rank space the per-step training gradient
|
||||
lives in (the gradient is the natural object to compare gradients against),
|
||||
and it fuses the input-activation and output-error contributions in one shot
|
||||
instead of guessing whether input-side $(x V_h^T)$ or output-side $(\partial L/\partial h)\, U$
|
||||
better predicts where SGD will move. We did consider activation-side
|
||||
($x V_h^T$ mean-diff). Dropped as primary because it only sees the input
|
||||
factor and ignores the output-error factor, while the per-step gradient sees
|
||||
both.
|
||||
|
||||
$$\nabla_{\delta_S} \leftarrow \nabla_{\delta_S} - \cos_{align} \cdot \|\nabla_{\delta_S}\| \cdot \hat v_{hack}$$
|
||||
Projection (locked: no magnitude threshold; one-sided clip stays — see note):
|
||||
|
||||
with one-sided gating (only project when $\cos_{align} > 0$, i.e. the gradient is
|
||||
pushing toward the hack direction). Magnitude preservation = renormalize back
|
||||
to original $\|\nabla_{\delta_S}\|$.
|
||||
$$g \leftarrow g - \max(0,\, \cos_{align}) \cdot \|g\| \cdot \hat v_{hack}, \qquad \cos_{align} = \frac{g \cdot \hat v_{hack}}{\|g\|}$$
|
||||
|
||||
then rescale to original $\|g\|$ (magnitude-preserving). The $\max(0,\cdot)$ is
|
||||
not gating, it's directional correctness: without it, when $\cos<0$ we'd be
|
||||
*adding* to the hack component. No magnitude/threshold gating (locked): we
|
||||
project every step every module. Capacity cost is ~1/r per module per step.
|
||||
If `v_hack` at a module is just noise, projection ablates a noise direction in
|
||||
expectation = approximately a no-op.
|
||||
|
||||
## Why not vanilla GRPO via verl
|
||||
|
||||
verl is Ariahw's framework but uses Ray + FSDP2 + Hydra; inserting a
|
||||
pre-optimizer-step hook on per-module rank-space gradients requires deep
|
||||
subclassing of their worker abstraction. We pay one cost in exchange:
|
||||
we use [lsdefine/lsrl](https://github.com/lsdefine/lsrl) instead. lsrl is a
|
||||
two-file GRPO implementation with reported convergence on Qwen2.5-3B in 12m on
|
||||
2xA800 (60 steps). One pre-optimizer hook is trivial to add.
|
||||
we use [lsdefine/simple_GRPO](https://github.com/lsdefine/simple_GRPO) instead.
|
||||
simple_GRPO is a two-file GRPO implementation (`ref_server.py` + `grpo_ref_split.py`,
|
||||
~315 lines total) with reported convergence on Qwen2.5-7B. The training loop
|
||||
is literally `loss = GRPO_step(batch); engine.backward(loss); engine.step()` —
|
||||
inserting a projection hook between backward and step is a one-line edit.
|
||||
|
||||
Cost of this deviation: we re-establish the "vanilla hack emergence" baseline
|
||||
on lsrl rather than inheriting it from Ariahw's verl baseline. H4 is the
|
||||
sanity check that this happens. We port Ariahw's `run_tests`-overwrite
|
||||
on simple_GRPO rather than inheriting it from Ariahw's verl baseline. H4 is
|
||||
the sanity check that this happens. We port Ariahw's `run_tests`-overwrite
|
||||
detection (their [src/train/verl/rewards.py](https://github.com/ariahw/rl-rewardhacking/blob/main/src/train/verl/rewards.py))
|
||||
into lsrl's reward server (`docs/vendor/lsrl/lsrl/reward_server.py`).
|
||||
into simple_GRPO's reward server (`docs/vendor/simple_GRPO/ref_server.py`).
|
||||
|
||||
Vendored references (read-only, see [docs/vendor/](docs/vendor/)):
|
||||
- [lsrl](https://github.com/lsdefine/lsrl) — GRPO trainer
|
||||
- [simple_GRPO](https://github.com/lsdefine/simple_GRPO) — GRPO trainer
|
||||
- [lora-lite](https://github.com/wassname/lora-lite) — AntiPaSTO adapter
|
||||
- [rl-rewardhacking](https://github.com/ariahw/rl-rewardhacking) (already at `external/`)
|
||||
|
||||
@@ -82,12 +100,6 @@ LeetCode pass rate within 10pp of vanilla.
|
||||
Falsified if: hack rate reduction < 15pp, OR pass rate drops by >15pp at
|
||||
matched hack-rate budget, OR result is within 1 SEM of vanilla across seeds.
|
||||
|
||||
**H2 (activation- vs gradient-side `v_hack`):** Gradient-side `v_hack`
|
||||
(mean-diff of `grad(delta_S)` from one NLL backward per pair) outperforms
|
||||
activation-side `v_hack` (mean-diff of `x V_h^T`), at matched pair count.
|
||||
Falsified if: gradient-side matches or is worse than activation-side within
|
||||
1 SEM. *(open question — see "Decisions left open" below.)*
|
||||
|
||||
**H3 (gradient vs advantage):** Gradient-level intervention (ours) outperforms
|
||||
advantage-level intervention (Rebound reimplemented) on hack rate at matched
|
||||
pass rate.
|
||||
@@ -95,106 +107,240 @@ pass rate.
|
||||
Falsified if: Rebound reimplementation matches or beats ours within 1 SEM.
|
||||
|
||||
**H4 (scaling sanity on our stack):** Qwen3.5-2B trained with vanilla
|
||||
AntiPaSTO+GRPO on lsrl reproduces measurable reward hacking (>30% hack rate at
|
||||
200 steps).
|
||||
AntiPaSTO+GRPO on simple_GRPO reproduces measurable reward hacking (>30% hack
|
||||
rate at 200 steps).
|
||||
|
||||
Falsified if: vanilla hack rate <30%. Decision branch: swap to Qwen3-4B with
|
||||
num_generations halved. Secondary: if lsrl can't reproduce hacking on either
|
||||
model, fall back to Ariahw's verl path and accept the harder hook.
|
||||
|
||||
**H5 (capacity cost of no-gating):** No-gating (project every step every
|
||||
module) does not measurably hurt pass rate vs cos-threshold gating
|
||||
(`|cos_align| > 0.1` -> project). Falsified if: gated arm beats no-gating arm
|
||||
on pass rate by >5pp at matched hack rate.
|
||||
num_generations halved. Secondary: if simple_GRPO can't reproduce hacking on
|
||||
either model, fall back to Ariahw's verl path and accept the harder hook.
|
||||
|
||||
## Steps
|
||||
|
||||
### 1. Build infra — fast-dev-run targets first, no real training yet
|
||||
|
||||
- **1a.** Vendor lsrl into `docs/vendor/lsrl/`; smoke-run their GSM8K example
|
||||
on tiny-random-qwen3 (5 steps, CPU) to confirm reward-server / actor /
|
||||
rollout split works in our env.
|
||||
- **1b.** Vendor lora-lite into `docs/vendor/lora-lite/`; wrap Qwen3.5-0.8B
|
||||
attn+MLP modules with AntiPaSTO (`r=256, block_size=4, rotate_basis="none"`
|
||||
to start; only `delta_S` trainable). Verify forward-pass round-trip
|
||||
numerically matches base model at $\delta_S = 0$.
|
||||
- **1c.** Implement `v_hack` extraction per module:
|
||||
- **Activation-side (default):** forward N contrastive pair completions,
|
||||
per wrapped module register a `forward_pre_hook` capturing
|
||||
`(x @ Vh^T)` flattened over (batch, seq), mean over hack rows minus
|
||||
mean over clean rows, unit-normalize. Cache as `dict[module_name -> Tensor[r]]`
|
||||
on disk.
|
||||
- **Gradient-side (ablation):** for each pair, NLL backward on completion
|
||||
tokens, per module capture `module.lora_delta_s.grad : [r]`, mean-diff
|
||||
hack vs clean, unit-normalize.
|
||||
- Validation: per-module projection score `(x_hack @ Vh^T - x_clean @ Vh^T) @ v_hack`
|
||||
should be positive on held-out pairs in >90% of modules.
|
||||
- **1a.** Vendor simple_GRPO into `docs/vendor/simple_GRPO/` (done); smoke-run
|
||||
their GSM8K example on tiny-random-qwen3 (5 steps, CPU) to confirm
|
||||
`ref_server` + `grpo_ref_split` rollout/train split works in our env.
|
||||
- **1b.** Vendor lora-lite into `docs/vendor/lora-lite/` (done); wrap
|
||||
Qwen3.5-0.8B attn+MLP `nn.Linear` modules with AntiPaSTO **at full rank**
|
||||
(`r = min(d_in, d_out)`, no SVD cropping; `rotate_basis="none"`, only
|
||||
`delta_S` trainable). Full rank means $W = U \,\mathrm{diag}(S)\, V_h$
|
||||
exactly and `W_res = 0`, so there's no truncation error to debug on the
|
||||
first pass. Verify forward-pass round-trip numerically matches base model
|
||||
at $\delta_S = 0$ (max abs diff <1e-3 on a fixed prompt).
|
||||
- **1c.** Implement gradient-side `v_hack` extraction (pseudocode below).
|
||||
Validation: per-module held-out projection score
|
||||
`cos(g_held_hack - g_held_clean, v_hack)` > 0 in >50% of modules.
|
||||
|
||||
### 2. H4 sanity — does vanilla AntiPaSTO+GRPO+lsrl produce hacking?
|
||||
### 2. H4 sanity — does vanilla AntiPaSTO+GRPO+simple_GRPO produce hacking?
|
||||
|
||||
- **2a.** Port Ariahw's `run_tests`-overwrite detection into lsrl's reward
|
||||
fn. Verify the reward fn fires on synthetic hack/clean rollouts before
|
||||
real training.
|
||||
- **2b.** Train Qwen3.5-2B, AntiPaSTO (`r=256`, `delta_S` only), GRPO
|
||||
- **2a.** Port Ariahw's `run_tests`-overwrite detection into simple_GRPO's
|
||||
`ref_server.py` reward fn. Verify the reward fn fires on synthetic
|
||||
hack/clean rollouts before real training.
|
||||
- **2b.** Train Qwen3.5-2B, AntiPaSTO (`r=full`, `delta_S` only), GRPO
|
||||
(group_norm), 200 steps, num_generations=8, batch=16, 1 seed.
|
||||
Decision: if hack rate <30%, switch to Qwen3-4B (same num_gen=8, batch=16)
|
||||
and re-run 2b. Secondary fallback: drop lsrl, return to verl.
|
||||
Decision: if hack rate <30%, switch to Qwen3-4B with `num_generations=4,
|
||||
batch=16` (half num_gen to keep VRAM headroom) and re-run 2b.
|
||||
Secondary fallback: drop simple_GRPO, return to verl.
|
||||
|
||||
### 3. Implement rank-space projection in lsrl's training loop
|
||||
### 3. Implement rank-space projection in simple_GRPO's training loop
|
||||
|
||||
- **3a.** lsrl's actor calls `optimizer.step()` once per group; insert a
|
||||
`pre_step_hook(model)` that walks `[m for m in model.modules() if hasattr(m, 'lora_delta_s')]`
|
||||
and for each module reads `m.lora_delta_s.grad : [r]`, projects against
|
||||
`v_hack[module_name]` (one-sided, magnitude-preserving), writes back.
|
||||
- **3b.** Diagnostics logged per step per module: `cos_in`, `||grad||`,
|
||||
`frac_modules_projected`.
|
||||
- **3a.** In `grpo_ref_split.py`, between `engine.backward(loss)` and
|
||||
`engine.step()`, call `project_grads(model, v_hack_cache)`.
|
||||
`project_grads` walks `[m for m in model.modules() if hasattr(m, 'delta_S')]`
|
||||
and for each module reads `m.delta_S.grad : [r]`, projects against
|
||||
`v_hack[module_name]` (one-sided, magnitude-preserving), writes back
|
||||
in place. (Pseudocode below.)
|
||||
- **3b.** Diagnostics logged per step (aggregated over modules):
|
||||
mean/std `cos_align`, mean `||grad||`, `frac_modules_with_cos>0`.
|
||||
|
||||
### 4. Run arms (200 steps each, 3 seeds where indicated)
|
||||
|
||||
a. Vanilla AntiPaSTO + GRPO (3 seeds) — baseline
|
||||
b. Our method, activation-side `v_hack`, no gating (3 seeds) — main result
|
||||
c. Our method, gradient-side `v_hack` (3 seeds) — H2
|
||||
d. Our method, cos-threshold gating ($|\cos| > 0.1$) (1 seed) — H5
|
||||
e. Our method, no magnitude preservation (1 seed) — design ablation
|
||||
f. Rebound reimplementation: advantage-side `v_hack` penalty (3 seeds) — H3
|
||||
g. AntiPaSTO rank sweep: $r \in \{64, 256, 1024\}$ (1 seed each) — sensitivity
|
||||
b. Our method, gradient-side `v_hack`, no gating (3 seeds) — main result
|
||||
c. Our method, no magnitude preservation (1 seed) — design ablation
|
||||
d. Rebound reimplementation: advantage-side `v_hack` penalty (3 seeds) — H3
|
||||
(concrete formula: per-rollout penalty `α · max(0, cos(h_mean, v_concept))`
|
||||
added to scalar reward, where `h_mean` is mean residual-stream activation
|
||||
at a chosen layer and `v_concept` is mean-diff activation direction
|
||||
extracted from the same 60-80 pairs. We use Wu & Tang 2026 §3.2's
|
||||
published `α=0.5` and same layer fraction (60-75% depth). Single
|
||||
layer, not per-module, matching their setup. *Different `v_concept`
|
||||
from our gradient-side `v_hack` — this is intentional: H3 isolates the
|
||||
gradient-vs-advantage mechanism choice, not the direction-extraction
|
||||
choice.*)
|
||||
|
||||
Total: 14 runs × ~3h on RTX 6000 96GB = ~42h compute.
|
||||
Total: 10 runs × ~3h on RTX 6000 96GB = ~30h compute.
|
||||
*(Rank sweep deferred to v2; first pass uses `r = min(d_in, d_out)` per
|
||||
module, no cropping.)*
|
||||
|
||||
### 5. Measure at every 25 steps
|
||||
|
||||
- **Hack rate** (Ariahw's detector ported into lsrl)
|
||||
- **Hack rate** (Ariahw's detector ported into simple_GRPO)
|
||||
- **Pass rate** on held-out problems without write access to evaluator
|
||||
- **Per-module `cos_align`** trajectory (sanity that we're projecting
|
||||
something nonzero)
|
||||
- **`frac_modules_projected`** per step (sanity for gating arms)
|
||||
- **`frac_modules_with_cos>0`** per step (sanity that one-sided clip fires)
|
||||
- **KL drift from init policy** (catastrophic-change check)
|
||||
|
||||
### 6. Headline plot
|
||||
### 6. Headline plot and headline table
|
||||
|
||||
Hack rate vs pass rate, one point per (arm × seed). Pareto frontier. Our
|
||||
method should land below-and-to-the-right of vanilla. Annotate Rebound.
|
||||
**Plot.** Hack rate vs pass rate, one point per (arm × seed). Pareto
|
||||
frontier. Our method should land below-and-to-the-right of vanilla.
|
||||
Annotate Rebound.
|
||||
|
||||
**Table schema (publication-ready; left-to-right = essential to optional,
|
||||
so trailing columns can be cut for space):**
|
||||
|
||||
| Arm | ΔSafePass↑ | Hack %↓ | Pass %↑ | KL↓ | mean·cos\* | frac·fired\* | ‖g‖\* |
|
||||
|---|---|---|---|---|---|---|---|
|
||||
| Vanilla (a) | 0 (ref) | — | — | — | — | — | — |
|
||||
| **Ours (b)** | — | — | — | — | — | — | — |
|
||||
| Ours, no mag-preserve (c) | — | — | — | — | — | — | — |
|
||||
| Rebound (d) | — | — | — | — | — | — | — |
|
||||
|
||||
*Caption.* ↑ higher is better, ↓ lower is better. **ΔSafePass** = (pass% −
|
||||
hack%) − vanilla's (pass% − hack%): single headline number, positive means
|
||||
we win. **Hack %** = fraction of rollouts triggering `run_tests`-overwrite
|
||||
detector. **Pass %** = fraction passing held-out tests without write access.
|
||||
**KL** = mean per-token KL from init policy over last 25 steps.
|
||||
\* = projection-internal diagnostic, only meaningful for arms (b)/(c);
|
||||
distinguishes "projection active" (mean·cos > 0.2, frac·fired > 0.4) from
|
||||
"projection silent no-op". Cells report mean ± SEM across seeds.
|
||||
|
||||
### 7. Falsification check
|
||||
|
||||
Before publishing, run pre-registered analysis on H1-H5. Report all
|
||||
Before publishing, run pre-registered analysis on H1, H3, H4. Report all
|
||||
hypotheses including falsified ones.
|
||||
|
||||
## Pseudocode (the three load-bearing bits)
|
||||
|
||||
### A. AntiPaSTO module wrap (full rank, first pass)
|
||||
|
||||
```
|
||||
class AntiPaSTO(nn.Module):
|
||||
# constructed from an existing nn.Linear(W: [d_out, d_in], b)
|
||||
# FIRST PASS: r = min(d_out, d_in) -- no truncation, W_res == 0
|
||||
def __init__(self, W, b):
|
||||
U, S, Vh = torch.linalg.svd(W.float(), full_matrices=False)
|
||||
r = S.shape[0] # = min(d_out, d_in)
|
||||
# buffers (frozen): the full SVD
|
||||
self.U = U # [d_out, r]
|
||||
self.S = S # [r]
|
||||
self.Vh = Vh # [r, d_in]
|
||||
self.b = b
|
||||
# trainable (ONLY this): scalar per rank
|
||||
self.delta_S = nn.Parameter(torch.zeros(r))
|
||||
|
||||
def forward(self, x): # x: [..., d_in]
|
||||
return ((x @ self.Vh.T) * (self.S + self.delta_S)) @ self.U.T + self.b
|
||||
```
|
||||
|
||||
Replace every target `nn.Linear` in attn (`q,k,v,o_proj`) and MLP
|
||||
(`up,gate,down_proj`) with this. At `delta_S=0`, output == original linear up
|
||||
to numerical precision (no `W_res` residual term needed at full rank).
|
||||
|
||||
**SVD precompute strategy.** Don't SVD the whole model on GPU at once.
|
||||
Load the base model on CPU, then for each target `Linear`: move `W` to GPU,
|
||||
run `torch.linalg.svd(W.float(), full_matrices=False)`, save
|
||||
`(U, S, Vh) -> svd_cache/{model_name}/{module_path}.pt`. Wrap construction
|
||||
then loads the cached SVD per module. SVD is done once per base model; ~5-10s
|
||||
per big MLP weight on RTX 3090.
|
||||
|
||||
### B. Gradient-side `v_hack` extraction (per module)
|
||||
|
||||
```
|
||||
v_hack = {} # dict[module_name -> Tensor[r]]
|
||||
grads_hack = defaultdict(list)
|
||||
grads_clean = defaultdict(list)
|
||||
|
||||
# Per-pair: process hack and clean independently, NLL over their own completion
|
||||
# tokens only. Different completion lengths are fine -- we use mean NLL
|
||||
# (sum_nll / n_completion_tokens), so each pair contributes a length-normalized
|
||||
# gradient. This avoids biasing v_hack toward longer (typically clean)
|
||||
# completions. Pad each example individually; no cross-completion padding.
|
||||
|
||||
for (prompt, hack_completion, clean_completion) in pairs:
|
||||
for label, completion in [('hack', hack_completion), ('clean', clean_completion)]:
|
||||
model.zero_grad()
|
||||
ids = tokenize(prompt + completion) # [1, L]
|
||||
mask = completion_mask(ids, prompt_len=len(prompt_ids)) # 1 on completion tokens
|
||||
logits = model(ids).logits[:, :-1]
|
||||
# MEAN NLL over completion tokens (length-normalized)
|
||||
loss = (nll_per_token(logits, ids[:, 1:]) * mask[:, 1:]).sum() / mask[:, 1:].sum()
|
||||
loss.backward()
|
||||
for name, m in model.named_modules():
|
||||
if hasattr(m, 'delta_S'):
|
||||
bucket = grads_hack if label == 'hack' else grads_clean
|
||||
bucket[name].append(m.delta_S.grad.detach().cpu().clone())
|
||||
|
||||
for name in grads_hack:
|
||||
diff = stack(grads_hack[name]).mean(0) - stack(grads_clean[name]).mean(0) # [r]
|
||||
v_hack[name] = diff / (diff.norm() + 1e-8)
|
||||
|
||||
torch.save(v_hack, 'v_hack.pt')
|
||||
```
|
||||
|
||||
Validation (report both, don't just gate on threshold):
|
||||
|
||||
- On held-out pairs, recompute per-module `diff_held` and
|
||||
`cos_align_held = cos(diff_held, v_hack[name])`.
|
||||
- **Distribution check (primary):** plot histogram of `cos_align_held` across
|
||||
all modules. Healthy = unimodal positive, median > 0.3. Pathological =
|
||||
bimodal or median near 0.
|
||||
- **Gate (secondary):** `cos_align_held > 0` in >50% of modules is the
|
||||
minimum to proceed; mean `cos_align_held > 0.2` is the target. If <50% pass,
|
||||
extraction is broken and we debug before training.
|
||||
|
||||
### C. Pre-optimizer-step projection hook
|
||||
|
||||
```
|
||||
def project_grads(model, v_hack: dict[str, Tensor]):
|
||||
# called after engine.backward(loss), before engine.step()
|
||||
cos_log, n_modules, n_fired = [], 0, 0
|
||||
for name, m in model.named_modules():
|
||||
if not hasattr(m, 'delta_S'): continue
|
||||
g = m.delta_S.grad # [r]
|
||||
if g is None: continue
|
||||
n_modules += 1
|
||||
v = v_hack[name].to(g.device) # [r], unit
|
||||
g_norm = g.norm()
|
||||
if g_norm < 1e-12: continue
|
||||
cos_a = (g @ v) / g_norm # scalar
|
||||
cos_log.append(cos_a.item())
|
||||
if cos_a > 0:
|
||||
n_fired += 1
|
||||
g_new = g - cos_a * g_norm * v # remove hack component
|
||||
g_new = g_new * (g_norm / (g_new.norm() + 1e-8)) # magnitude preserve
|
||||
m.delta_S.grad.copy_(g_new)
|
||||
return dict(mean_cos=mean(cos_log), frac_fired=n_fired/max(n_modules,1))
|
||||
```
|
||||
|
||||
Integration into `grpo_ref_split.py` training loop
|
||||
(vendored at `docs/vendor/simple_GRPO/simple_grpo_v1/grpo_ref_split.py`; we copy and
|
||||
edit, not import):
|
||||
|
||||
```
|
||||
# at top of training script, once:
|
||||
v_hack = torch.load('v_hack.pt', map_location='cpu') # dict[str, Tensor[r]]
|
||||
# (extraction script from B above produces this artifact; if missing, crash loud)
|
||||
|
||||
# inside the training loop:
|
||||
loss = GRPO_step(batch)
|
||||
engine.backward(loss)
|
||||
stats = project_grads(engine.module, v_hack) # <-- NEW: 1 line
|
||||
engine.step()
|
||||
if rank == 0: log(stats)
|
||||
```
|
||||
|
||||
## Decisions left open (write these up alongside results)
|
||||
|
||||
- **Activation- vs gradient-side `v_hack` (H2).** Activation = cheap, geometric,
|
||||
matches Wu-Tang/CAA tradition. Gradient = principled (the literal direction
|
||||
training will move toward), more expensive. Default activation; gradient is
|
||||
arm c.
|
||||
- **Gating threshold (H5).** No-gating default; cos>0.1 gating is arm d.
|
||||
Argument for no-gating: removing 1 direction from r=256 trainable subspace
|
||||
per module per step is ~0.4% capacity. If `v_hack` at a module is noise, we
|
||||
ablate a noise direction in expectation = approx no-op. Argument for gating:
|
||||
in modules where hack signal is weak, projection just removes some random
|
||||
direction the optimizer might have used. H5 settles this.
|
||||
- **Rank `r`.** Default 256 (lora-lite antipasto default); sweep in arm g.
|
||||
Trainable parameter count is just `r` per module (vs `r*(d_in+d_out)` for
|
||||
standard LoRA), so larger `r` is cheap, but `v_hack`'s SNR per dim degrades.
|
||||
- **Rank `r`.** First pass: `r = min(d_in, d_out)` per module (no cropping)
|
||||
to avoid debugging where to cut the SVD. Trainable params per module =
|
||||
`min(d_in, d_out)`, still tiny vs full LoRA's `r*(d_in+d_out)`. Tradeoff:
|
||||
larger `r` keeps geometric fidelity but `v_hack`'s SNR per dim degrades;
|
||||
smaller `r` would concentrate hack signal but introduces truncation error in
|
||||
`W_res`. Rank sweep is v2 work.
|
||||
|
||||
## Why measure ratio, not just hack rate
|
||||
|
||||
@@ -206,28 +352,31 @@ problems without write access, our method reduces hack rate from X% to Y%."
|
||||
|
||||
## Compute estimate
|
||||
|
||||
- Single run on 96GB RTX 6000: ~2-3h (Qwen3.5-2B, num_gen=8, 200 steps, lsrl,
|
||||
AntiPaSTO r=256)
|
||||
- 14 runs: 35-45h
|
||||
- At ~$3 AUD/hr: $105-135 AUD
|
||||
- Single run on 96GB RTX 6000: ~2-3h (Qwen3.5-2B, num_gen=8, 200 steps,
|
||||
simple_GRPO, AntiPaSTO full rank)
|
||||
- 10 runs: 25-35h
|
||||
- At ~$3 AUD/hr: $75-105 AUD
|
||||
- + debugging buffer: budget ~$200 AUD total
|
||||
- Calendar time: 1 week back-to-back; 2-3 weeks with iteration
|
||||
|
||||
## Risks and decision points
|
||||
|
||||
- **H4 falsified (no hack on Qwen3.5-2B with lsrl):** branch 1 — try
|
||||
Qwen3-4B same hyperparams. Branch 2 — drop lsrl, hook into verl
|
||||
- **H4 falsified (no hack on Qwen3.5-2B with simple_GRPO):** branch 1 — try
|
||||
Qwen3-4B with `num_generations` halved (as in step 2b). Branch 2 — drop simple_GRPO, hook into verl
|
||||
directly. Adds ~1-2 weeks engineering.
|
||||
- **AntiPaSTO + GRPO doesn't train:** known risk — antipasto's trainable
|
||||
subspace (`delta_S` only) may be too small for RL. Mitigation: enable
|
||||
Cayley rotation (`rotate_basis="V"`, `block_size=4`), adds `r*(bs-1)/2`
|
||||
params per module. Or fall back to PiSSA-LoRA-freeze-A.
|
||||
subspace (`delta_S` only) may be too small for RL. If so, document and
|
||||
fall back to PiSSA-LoRA-freeze-A. We do **not** enable Cayley rotation
|
||||
(`rotate_basis="V"`) as a mitigation: a rotated rank axis breaks the
|
||||
invariant that `v_hack` (extracted in the original SVD basis) stays
|
||||
meaningful across training, which is the whole point of using AntiPaSTO
|
||||
over vanilla LoRA.
|
||||
- **`v_hack` steering check fails (per-module projection scores ≤chance):**
|
||||
extraction broken. Check (a) hook captures pre-residual input, (b) pair
|
||||
quality drives strong activation difference somewhere, (c) tokenization of
|
||||
hack vs clean completions isn't trivially distinguishing.
|
||||
- **All methods tie vanilla on hack rate:** intervention not biting. Check
|
||||
`cos_align` logs nonzero, `frac_modules_projected` nonzero.
|
||||
`cos_align` logs nonzero, `frac_modules_with_cos>0` nonzero.
|
||||
|
||||
## What this is not
|
||||
|
||||
@@ -247,6 +396,6 @@ problems without write access, our method reduces hack rate from X% to Y%."
|
||||
- **AntiPaSTO** ([wassname/lora-lite/variants/antipasto.py](https://github.com/wassname/lora-lite/blob/main/src/lora_lite/variants/antipasto.py),
|
||||
[wassname/AntiPaSTO paper](https://github.com/wassname/AntiPaSTO)) — adapter
|
||||
we wrap with.
|
||||
- **lsrl** ([lsdefine/lsrl](https://github.com/lsdefine/lsrl)) — GRPO trainer.
|
||||
- **simple_GRPO** ([lsdefine/simple_GRPO](https://github.com/lsdefine/simple_GRPO)) — GRPO trainer.
|
||||
- **PiSSA** ([arxiv:2404.02948](https://arxiv.org/abs/2404.02948)) — frozen
|
||||
top-r SVD-init for LoRA; closest spiritual ancestor to AntiPaSTO.
|
||||
|
||||
@@ -0,0 +1,153 @@
|
||||
"""AntiPaSTO full-rank adapter for projected-GRPO.
|
||||
|
||||
Per spec.md: wrap nn.Linear with frozen U, S, Vh (full rank = min(d_in, d_out)).
|
||||
Trainable: delta_S only, shape [r]. No rotation (would break v_hack basis invariance).
|
||||
|
||||
Forward:
|
||||
y = ((x @ Vh.T) * (S + delta_S)) @ U.T + b
|
||||
|
||||
At delta_S=0, output == original linear up to fp32 SVD round-trip precision.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from jaxtyping import Float
|
||||
from loguru import logger
|
||||
from torch import Tensor, nn
|
||||
|
||||
|
||||
class AntiPaSTOLinear(nn.Module):
    """Drop-in replacement for nn.Linear built from an SVD plus a learnable delta_S.

    Computes ``y = ((x @ Vh.T) * (S + delta_S)) @ U.T + bias``.

    Buffers (frozen): U[d_out, r], S[r], Vh[r, d_in], optional bias[d_out].
    Trainable: delta_S[r], initialised to zero so the module starts out
    reproducing ``U @ diag(S) @ Vh``.
    """

    def __init__(
        self,
        U: Float[Tensor, "d_out r"],
        S: Float[Tensor, "r"],
        Vh: Float[Tensor, "r d_in"],
        bias: Float[Tensor, "d_out"] | None,
        dtype: torch.dtype = torch.float32,
    ):
        """Store the SVD factors as buffers and create the zero-initialised delta_S.

        Args:
            U: left singular vectors, shape [d_out, r].
            S: singular values, shape [r].
            Vh: right singular vectors (already transposed), shape [r, d_in].
            bias: optional bias of shape [d_out]; ``None`` for a bias-free layer.
            dtype: dtype used for every buffer and for ``delta_S``.

        Raises:
            ValueError: if the factor shapes do not agree on the rank ``r`` or
                the bias does not have shape [d_out].
        """
        super().__init__()
        # The jaxtyping annotations are not enforced here, so check the rank
        # axis explicitly instead of failing (or silently broadcasting) later
        # inside forward().
        if U.ndim != 2 or S.ndim != 1 or Vh.ndim != 2:
            raise ValueError(
                f"expected U[d_out, r], S[r], Vh[r, d_in]; got "
                f"U={tuple(U.shape)} S={tuple(S.shape)} Vh={tuple(Vh.shape)}"
            )
        r = S.shape[0]
        if U.shape[1] != r or Vh.shape[0] != r:
            raise ValueError(
                f"rank mismatch: U={tuple(U.shape)} S={tuple(S.shape)} Vh={tuple(Vh.shape)}"
            )
        if bias is not None and tuple(bias.shape) != (U.shape[0],):
            raise ValueError(
                f"bias shape {tuple(bias.shape)} does not match d_out={U.shape[0]}"
            )

        self.register_buffer("U", U.to(dtype).contiguous())
        self.register_buffer("S", S.to(dtype).contiguous())
        self.register_buffer("Vh", Vh.to(dtype).contiguous())
        # Registering None keeps `self.bias` defined (as None) for bias-free layers.
        self.register_buffer(
            "bias", None if bias is None else bias.to(dtype).contiguous()
        )
        self.delta_S = nn.Parameter(torch.zeros(r, dtype=dtype))

    @property
    def r(self) -> int:
        """Rank of the decomposition (length of ``S``)."""
        return self.S.shape[0]

    def forward(self, x: Float[Tensor, "... d_in"]) -> Float[Tensor, "... d_out"]:
        """Apply the adapted linear map to ``x``.

        If ``x`` arrives in a different dtype than the adapter buffers, it is
        cast to the adapter dtype for the computation and the result is cast
        back, because a plain matmul rejects mixed dtypes. When the dtypes
        already match, no cast is performed at all.
        """
        in_dtype = x.dtype
        needs_cast = in_dtype != self.Vh.dtype
        if needs_cast:
            x = x.to(self.Vh.dtype)
        # x @ Vh.T : [..., r]; * (S+dS) : elementwise; @ U.T : [..., d_out]
        h = x @ self.Vh.transpose(-1, -2)
        h = h * (self.S + self.delta_S)
        y = h @ self.U.transpose(-1, -2)
        if self.bias is not None:
            y = y + self.bias
        return y.to(in_dtype) if needs_cast else y
|
||||
|
||||
|
||||
def _model_svd_dir(model_name: str, cache_root: Path) -> Path:
    """Return the per-model cache directory under ``cache_root``.

    Every ``/`` in ``model_name`` is replaced by ``__`` so that the name
    becomes a single path component.
    """
    return cache_root / "__".join(model_name.split("/"))
|
||||
|
||||
|
||||
def svd_cached(
    W: Float[Tensor, "d_out d_in"],
    cache_path: Path,
    device: torch.device,
) -> tuple[Tensor, Tensor, Tensor]:
    """SVD with disk cache. Compute on `device` in fp32, save as fp32 cpu tensors.

    Cache key = sha256(W.cpu fp32 bytes)[:16], spliced into the filename in
    place of ``cache_path``'s last suffix, so a weight change maps to a
    different file instead of silently reusing a stale decomposition.

    Args:
        W: weight matrix to decompose; converted to fp32 on CPU for hashing.
        cache_path: base path of the cache entry (its parent is created if
            missing); the actual file is ``<stem>.<sha16>.pt`` next to it.
        device: device on which ``torch.linalg.svd`` is run.

    Returns:
        ``(U, S, Vh)`` as CPU tensors, from the reduced SVD
        (``full_matrices=False``).
    """
    cache_path.parent.mkdir(parents=True, exist_ok=True)
    W_fp32 = W.detach().to(torch.float32).cpu().contiguous()
    sha = hashlib.sha256(W_fp32.numpy().tobytes()).hexdigest()[:16]
    final = cache_path.with_suffix(f".{sha}.pt")
    if final.exists():
        d = torch.load(final, map_location="cpu", weights_only=True)
        return d["U"], d["S"], d["Vh"]

    U, S, Vh = torch.linalg.svd(W_fp32.to(device), full_matrices=False)
    U, S, Vh = U.cpu(), S.cpu(), Vh.cpu()

    # Write to a sibling temp file and rename it into place. An interrupted
    # save then never leaves a truncated file under the final name, which
    # would otherwise be picked up by `final.exists()` on the next run and
    # fail to load.
    tmp = final.with_name(final.name + ".tmp")
    torch.save({"U": U, "S": S, "Vh": Vh}, tmp)
    tmp.replace(final)

    logger.info(f"SVD cached: {final.name} shape U={tuple(U.shape)} S0={S[0]:.3f} S-1={S[-1]:.3e}")
    return U, S, Vh
|
||||
|
||||
|
||||
# Final name components of the nn.Linear modules that get wrapped.
TARGET_SUFFIXES = (
    # full attention (Qwen3.5 has 6 full-attn layers)
    "q_proj", "k_proj", "v_proj", "o_proj",
    # linear-attention / GatedDeltaNet (Qwen3.5 has 18 linear-attn layers)
    "in_proj_qkv", "in_proj_z", "in_proj_a", "in_proj_b", "out_proj",
    # MLP (24 layers)
    "up_proj", "gate_proj", "down_proj",
)
|
||||
|
||||
|
||||
def is_target(name: str) -> bool:
    """Return True when the last dot-separated component of ``name`` is a target suffix."""
    leaf = name.rpartition(".")[2]
    return leaf in TARGET_SUFFIXES
|
||||
|
||||
|
||||
def wrap_model_with_antipasto(
    model: nn.Module,
    model_name: str,
    cache_root: Path = Path("svd_cache"),
    svd_device: torch.device | str = "cuda",
    adapter_dtype: torch.dtype = torch.float32,
) -> dict[str, AntiPaSTOLinear]:
    """Replace every target nn.Linear in `model` (in place) with AntiPaSTOLinear.

    SVD is computed on `svd_device` per layer, cached to disk by weight hash.
    Returns dict[module_qualified_name -> wrapper] for downstream v_hack code.

    Args:
        model: module tree to modify in place.
        model_name: used for the per-model cache directory and in log messages.
        cache_root: root directory of the SVD cache.
        svd_device: device (or device string) on which each SVD is computed.
        adapter_dtype: passed to AntiPaSTOLinear as `dtype`.

    Returns:
        Mapping from each replaced module's qualified name to its wrapper.
    """
    svd_device_t = torch.device(svd_device) if isinstance(svd_device, str) else svd_device
    svd_dir = _model_svd_dir(model_name, cache_root)
    wrappers: dict[str, AntiPaSTOLinear] = {}

    # Collect first to avoid mutating during iteration.
    targets: list[tuple[str, nn.Linear, nn.Module, str]] = []
    for name, m in model.named_modules():
        if isinstance(m, nn.Linear) and is_target(name):
            # rpartition also handles a target that is a direct child of
            # `model` (no "." in its name): parent_name is then "" and
            # get_submodule("") returns `model` itself. Indexing
            # rsplit(".", 1)[1] would raise IndexError for such a name.
            parent_name, _, child_name = name.rpartition(".")
            parent = model.get_submodule(parent_name)
            targets.append((name, m, parent, child_name))

    logger.info(f"AntiPaSTO wrap: {len(targets)} target Linear modules in {model_name}")
    for i, (name, linear, parent, child_name) in enumerate(targets):
        W = linear.weight.data
        bias = linear.bias.data if linear.bias is not None else None
        cache_path = svd_dir / f"{name}.pt"
        U, S, Vh = svd_cached(W, cache_path, device=svd_device_t)
        # Place wrapper on the same device as the original module's weight.
        target_device = W.device
        wrap = AntiPaSTOLinear(U, S, Vh, bias, dtype=adapter_dtype).to(target_device)
        setattr(parent, child_name, wrap)
        wrappers[name] = wrap
        # Progress line every 20 modules and once more on the last one.
        if (i + 1) % 20 == 0 or i == len(targets) - 1:
            logger.info(f"  wrapped {i+1}/{len(targets)} last={name}")
    return wrappers
|
||||
@@ -0,0 +1,102 @@
|
||||
"""Step-1 verify: wrapped Qwen3.5-0.8B output == base output at delta_S=0.
|
||||
|
||||
SHOULD: max abs diff < 1e-3 over 3 prompts of different lengths.
|
||||
ELSE: SVD round-trip is bad (numerical, dtype, or shape bug).
|
||||
|
||||
Run: uv run python -m projected_grpo.verify_antipasto_identity
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from loguru import logger
|
||||
from tabulate import tabulate
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from .antipasto import wrap_model_with_antipasto
|
||||
|
||||
|
||||
# Model id passed to the tokenizer/model loaders and to the SVD cache naming.
MODEL = "Qwen/Qwen3.5-0.8B"

# Three prompts of increasing length used for the identity comparison.
PROMPTS = [
    "Hello",
    "Write a Python function that returns the sum of two integers.",
    (
        "You are an expert programmer. Solve the following LeetCode problem:\n"
        "Given an integer array nums, find the contiguous subarray with the largest sum.\n"
        "Return the sum."
    ),
]

# Root directory handed to wrap_model_with_antipasto for its SVD disk cache.
CACHE_ROOT = Path("svd_cache")
|
||||
|
||||
|
||||
def main() -> int:
    """Compare base and AntiPaSTO-wrapped logits on every prompt in PROMPTS.

    Loads MODEL in fp32, deep-copies it, wraps the copy in place with
    `wrap_model_with_antipasto`, then runs both models on each prompt under
    `torch.no_grad()`. A prompt passes when the maximum absolute logit
    difference is below 1e-3. A per-prompt table is printed to stdout.

    Returns:
        0 when every prompt passes, 1 otherwise (used as the process exit code).
    """
    tolerance = 1e-3

    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    logger.info(f"device={device} model={MODEL}")

    tokenizer = AutoTokenizer.from_pretrained(MODEL)
    base = AutoModelForCausalLM.from_pretrained(
        MODEL, dtype=torch.float32, attn_implementation="sdpa"
    ).to(device)
    base.eval()

    # Wrap a deep copy so `base` stays untouched as the reference model.
    wrapped = copy.deepcopy(base)
    wrappers = wrap_model_with_antipasto(
        wrapped,
        model_name=MODEL,
        cache_root=CACHE_ROOT,
        svd_device=device,
        adapter_dtype=torch.float32,
    )
    wrapped.eval()

    n_wrapped = len(wrappers)
    n_params_trainable = 0
    for wrapper in wrappers.values():
        for param in wrapper.parameters():
            if param.requires_grad:
                n_params_trainable += param.numel()
    n_params_base = sum(p.numel() for p in base.parameters())
    logger.info(
        f"wrapped={n_wrapped} modules  "
        f"delta_S params={n_params_trainable:,}  "
        f"base params={n_params_base:,}  "
        f"ratio={n_params_trainable / n_params_base:.4%}"
    )

    rows = []
    verdicts = []
    for i, prompt in enumerate(PROMPTS):
        encoded = tokenizer(prompt, return_tensors="pt")
        ids = encoded.input_ids.to(device)
        with torch.no_grad():
            y_base = base(ids).logits
            y_wrap = wrapped(ids).logits
        diff = (y_base - y_wrap).abs()
        max_diff = diff.max().item()
        mean_diff = diff.mean().item()
        scale = y_base.abs().mean().item()
        ok = max_diff < tolerance
        verdicts.append(ok)
        rows.append(
            {
                "idx": i,
                "seq_len": ids.shape[1],
                "logit_scale": f"{scale:.3f}",
                "max_abs_diff": f"{max_diff:.2e}",
                "mean_abs_diff": f"{mean_diff:.2e}",
                "ok": "PASS" if ok else "FAIL",
            }
        )

    print(tabulate(rows, headers="keys", tablefmt="pipe"))
    logger.info(
        "SHOULD: max_abs_diff < 1e-3 on all rows. "
        "ELSE: SVD round-trip broken (dtype downcast, shape bug, or wrong forward)."
    )

    if all(verdicts):
        logger.info(f"IDENTITY CHECK PASSED ({n_wrapped} modules, {n_params_trainable:,} delta_S scalars)")
        return 0
    logger.error("IDENTITY CHECK FAILED")
    return 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Propagate main()'s 0/1 result as the process exit status.
    exit_code = main()
    sys.exit(exit_code)
|
||||
Binary file not shown.
Binary file not shown.
BIN
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Reference in New Issue
Block a user