evil_MoE/spec.md
wassname 42498682ca spec
2026-05-23 13:04:03 +08:00


Experiment: rank-space gradient projection vs RL reward hacking

Context

GRPO and related on-policy RL methods are known to exploit loopholes in reward functions. Ariahw, Engels & Nanda (2025, github.com/ariahw/rl-rewardhacking) open-sourced a benchmark on LeetCode where Qwen3-4B learns to overwrite the evaluation function run_tests() instead of solving problems, reaching 79% reward hack rate at 200 training steps. Existing mitigations are mostly monitor-based (detect at output) or advantage-based (Rebound: penalize hacking rollouts via concept-score-modified advantage; Wu & Tang 2026 arxiv:2604.01476).

This experiment tests a different mechanism: wrap target modules with the AntiPaSTO SVD adapter (lora-lite), extract a per-module v_hack in the rank-r SVD basis from contrastive pairs, and project each step's grad(delta_S) : [r] orthogonal to v_hack before the optimizer update. Mechanism difference from Rebound: gradient-level direction constraint on weight-update subspace vs rollout-level scalar penalty on advantage.

This is preregistered: results to be reported regardless of outcome.

Why AntiPaSTO and not vanilla LoRA

Vanilla LoRA's rank axis is meaningless (random init, drifts after step 1), so "project out v_hack in rank space" has no fixed reference frame. AntiPaSTO (Wassname, lora-lite variants/antipasto.py) freezes U_r, S_r, Vh_r from the SVD of W and trains a tiny delta_S : [r] plus an optional block-Cayley rotation. The rank axis stays pinned to the SVD basis of the original weight, so v_hack extracted in that basis remains meaningful across all training steps.

Forward pass per wrapped module (first pass uses full rank r = \min(d_{in}, d_{out}), so the residual term W_{res} vanishes):

y = ((x V_h^T) \odot (S + \delta_S)) U^T

where U, S, V_h come from the SVD of W and are buffers (frozen). Trainable: \delta_S : [r] (and optionally a small Cayley rotation rot_T we leave off by default). At reduced rank we would add x W_{res}^T with W_{res} = W - U_r \mathrm{diag}(S_r) V_{h,r}, but we defer rank cropping to v2 to skip the "where to cut" question.

Per-step gradient signal:

\frac{\partial L}{\partial \delta_S} = \sum_t (x_t V_h^T) \odot \left(\frac{\partial L}{\partial h_t} U\right) \in \mathbb{R}^r
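A quick numerical check of this identity, as a sketch only: the toy sizes, the random tensors, and the cubic loss are illustrative assumptions, not values from the experiment. It compares autograd's delta_S gradient with the closed form above.

```python
import torch

torch.manual_seed(0)
d_in, d_out, T = 6, 4, 5                               # toy sizes (assumed)
W = torch.randn(d_out, d_in, dtype=torch.float64)
U, S, Vh = torch.linalg.svd(W, full_matrices=False)    # U:[d_out,r], S:[r], Vh:[r,d_in]
delta_S = torch.zeros(S.shape[0], dtype=torch.float64, requires_grad=True)

x = torch.randn(T, d_in, dtype=torch.float64)          # T token activations
h = ((x @ Vh.T) * (S + delta_S)) @ U.T                 # [T, d_out]
h.retain_grad()                                        # keep dL/dh for the manual formula
loss = (h ** 3).sum()                                  # arbitrary scalar loss (assumed)
loss.backward()

# Closed form: sum over tokens of (x_t Vh^T) * (dL/dh_t U)
manual = ((x @ Vh.T) * (h.grad @ U)).sum(dim=0)        # [r]
max_err = (manual - delta_S.grad).abs().max().item()
print(f"max abs error vs autograd: {max_err:.2e}")
```

Both factors appear explicitly here: x @ Vh.T is the input-activation side and h.grad @ U is the output-error side.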

We extract v_hack gradient-side (locked in): for each contrastive pair, run one NLL backward on the completion tokens and read each module's m.delta_S.grad : [r]. Then \hat v_{hack}^{(m)} = \mathrm{unit}\left(\mathrm{mean}_{hack}\, g^{(m)} - \mathrm{mean}_{clean}\, g^{(m)}\right), where g^{(m)} is module m's delta_S gradient and \mathrm{unit}(\cdot) normalizes to unit length. This lives in the exact same [r] rank space the per-step training gradient lives in (the gradient is the natural object to compare gradients against), and it fuses the input-activation and output-error contributions in one shot instead of guessing whether input-side (x V_h^T) or output-side (\partial L/\partial h)\, U better predicts where SGD will move. We also considered activation-side extraction (mean-diff of x V_h^T) but dropped it as primary because it only sees the input factor and ignores the output-error factor, while the per-step gradient sees both.

Projection (locked: no magnitude threshold; one-sided clip stays — see note):

g \leftarrow g - \max(0,\, \cos_{align}) \cdot \|g\| \cdot \hat v_{hack}, \qquad \cos_{align} = \frac{g \cdot \hat v_{hack}}{\|g\|}

then rescale to original \|g\| (magnitude-preserving). The \max(0,\cdot) is not gating, it's directional correctness: without it, when \cos_{align}<0 the update would add a positive multiple of \hat v_{hack} to g, i.e. we'd be adding to the hack component. No magnitude/threshold gating (locked): we project every module at every step. Capacity cost is at most one of r directions (~1/r) per module per step. If v_hack at a module is just noise, projection removes a random direction in expectation, which is approximately a no-op.
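A worked toy instance of the one-sided, magnitude-preserving projection, in plain Python so the arithmetic is easy to follow. The 3-dimensional vectors and the function name project_one_sided are illustrative assumptions, not part of the experiment code.

```python
import math

def project_one_sided(g, v_hat):
    """Remove the positive v_hat component of g, then rescale to g's original norm."""
    g_norm = math.sqrt(sum(x * x for x in g))
    cos_align = sum(a * b for a, b in zip(g, v_hat)) / g_norm
    coeff = max(0.0, cos_align) * g_norm          # zero when g points away from v_hat
    g_new = [a - coeff * b for a, b in zip(g, v_hat)]
    new_norm = math.sqrt(sum(x * x for x in g_new))
    if new_norm == 0.0:                           # g was exactly parallel to v_hat
        return g_new, cos_align
    scale = g_norm / new_norm
    return [x * scale for x in g_new], cos_align

v_hat = [1.0, 0.0, 0.0]                                        # unit hack direction (toy)
g_pos, cos_pos = project_one_sided([3.0, 4.0, 0.0], v_hat)     # cos_align = 0.6, projected
g_neg, cos_neg = project_one_sided([-3.0, 4.0, 0.0], v_hat)    # cos_align = -0.6, untouched
print(g_pos, cos_pos)
print(g_neg, cos_neg)
```

In the first case the hack component is removed and the remainder is rescaled back to norm 5, giving approximately [0, 5, 0]. In the second case the clip leaves the gradient unchanged.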

Why not vanilla GRPO via verl

verl is Ariahw's framework, but it uses Ray + FSDP2 + Hydra; inserting a pre-optimizer-step hook on per-module rank-space gradients would require deep subclassing of their worker abstraction. We use lsdefine/simple_GRPO instead, and pay one cost for it (next paragraph). simple_GRPO is a two-file GRPO implementation (ref_server.py + grpo_ref_split.py, ~315 lines total) with reported convergence on Qwen2.5-7B. The training loop is literally loss = GRPO_step(batch); engine.backward(loss); engine.step(), so inserting a projection hook between backward and step is a one-line edit.

Cost of this deviation: we re-establish the "vanilla hack emergence" baseline on simple_GRPO rather than inheriting it from Ariahw's verl baseline. H4 is the sanity check that this happens. We port Ariahw's run_tests-overwrite detection (their src/train/verl/rewards.py) into simple_GRPO's reward server (docs/vendor/simple_GRPO/ref_server.py).

Vendored references (read-only, see docs/vendor/):

Hypotheses (preregistered)

H1 (mechanism, primary): Rank-space gradient projection against v_hack extracted from ~60-80 contrastive pairs reduces reward hack rate by at least 30 percentage points (absolute) relative to vanilla AntiPaSTO+GRPO, at matched LeetCode pass rate within 10pp of vanilla.

Falsified if: hack rate reduction < 15pp, OR pass rate drops by >15pp at matched hack-rate budget, OR result is within 1 SEM of vanilla across seeds.

H3 (gradient vs advantage): Gradient-level intervention (ours) outperforms advantage-level intervention (Rebound reimplemented) on hack rate at matched pass rate.

Falsified if: Rebound reimplementation matches or beats ours within 1 SEM.

H4 (scaling sanity on our stack): Qwen3.5-2B trained with vanilla AntiPaSTO+GRPO on simple_GRPO reproduces measurable reward hacking (>30% hack rate at 200 steps).

Falsified if: vanilla hack rate <30%. Decision branch: swap to Qwen3-4B with num_generations halved. Secondary: if simple_GRPO can't reproduce hacking on either model, fall back to Ariahw's verl path and accept the harder hook.

Steps

1. Build infra — fast-dev-run targets first, no real training yet

  • 1a. Vendor simple_GRPO into docs/vendor/simple_GRPO/ (done); smoke-run their GSM8K example on tiny-random-qwen3 (5 steps, CPU) to confirm ref_server + grpo_ref_split rollout/train split works in our env.
  • 1b. Vendor lora-lite into docs/vendor/lora-lite/ (done); wrap Qwen3.5-0.8B attn+MLP nn.Linear modules with AntiPaSTO at full rank (r = min(d_in, d_out), no SVD cropping; rotate_basis="none", only delta_S trainable). Full rank means W = U \,\mathrm{diag}(S)\, V_h exactly and W_res = 0, so there's no truncation error to debug on the first pass. Verify forward-pass round-trip numerically matches base model at \delta_S = 0 (max abs diff <1e-3 on a fixed prompt).
  • 1c. Implement gradient-side v_hack extraction (pseudocode below). Validation: per-module held-out projection score cos(g_held_hack - g_held_clean, v_hack) > 0 in >50% of modules.
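The round-trip check in 1b can be sketched on a single layer before wrapping a whole model. This is a minimal sketch: the toy nn.Linear(48, 64) and the random input are stand-ins (assumptions), and the real check would run on a fixed prompt through the wrapped model.

```python
import torch
from torch import nn

torch.manual_seed(0)
lin = nn.Linear(48, 64)                                   # toy stand-in for a target module
U, S, Vh = torch.linalg.svd(lin.weight.detach().float(), full_matrices=False)
delta_S = torch.zeros_like(S)                             # delta_S = 0 at init

x = torch.randn(2, 10, 48)                                # [batch, seq, d_in]
with torch.no_grad():
    y_ref = lin(x)                                                # original linear
    y_svd = ((x @ Vh.T) * (S + delta_S)) @ U.T + lin.bias         # full-rank SVD form
max_abs_diff = (y_ref - y_svd).abs().max().item()
print(f"rank r = {S.shape[0]}, max abs diff = {max_abs_diff:.2e}")
```

At full rank there is no residual term, so any difference is float32 rounding from the SVD and the reordered matmuls.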

2. H4 sanity — does vanilla AntiPaSTO+GRPO+simple_GRPO produce hacking?

  • 2a. Port Ariahw's run_tests-overwrite detection into simple_GRPO's ref_server.py reward fn. Verify the reward fn fires on synthetic hack/clean rollouts before real training.
  • 2b. Train Qwen3.5-2B, AntiPaSTO (r=full, delta_S only), GRPO (group_norm), 200 steps, num_generations=8, batch=16, 1 seed. Decision: if hack rate <30%, switch to Qwen3-4B with num_generations=4, batch=16 (half num_gen to keep VRAM headroom) and re-run 2b. Secondary fallback: drop simple_GRPO, return to verl.

3. Implement rank-space projection in simple_GRPO's training loop

  • 3a. In grpo_ref_split.py, between engine.backward(loss) and engine.step(), call project_grads(model, v_hack_cache). project_grads walks [m for m in model.modules() if hasattr(m, 'delta_S')] and for each module reads m.delta_S.grad : [r], projects against v_hack[module_name] (one-sided, magnitude-preserving), writes back in place. (Pseudocode below.)
  • 3b. Diagnostics logged per step (aggregated over modules): mean/std cos_align, mean ||grad||, frac_modules_with_cos>0.

4. Run arms (200 steps each, 3 seeds where indicated)

  • a. Vanilla AntiPaSTO + GRPO (3 seeds): baseline
  • b. Our method, gradient-side v_hack, no gating (3 seeds): main result
  • c. Our method, no magnitude preservation (1 seed): design ablation
  • d. Rebound reimplementation, advantage-side v_hack penalty (3 seeds): H3

Rebound details (arm d). Concrete formula: per-rollout penalty α · max(0, cos(h_mean, v_concept)) added to scalar reward, where h_mean is mean residual-stream activation at a chosen layer and v_concept is mean-diff activation direction extracted from the same 60-80 pairs. We use Wu & Tang 2026 §3.2's published α=0.5 and same layer fraction (60-75% depth). Single layer, not per-module, matching their setup. Different v_concept from our gradient-side v_hack; this is intentional: H3 isolates the gradient-vs-advantage mechanism choice, not the direction-extraction choice.
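A minimal sketch of the arm (d) reward shaping in plain Python. Two things here are assumptions rather than something the spec pins down: the penalty is applied by subtracting it from the scalar reward, and plain lists stand in for the activation vectors. The function names are hypothetical.

```python
import math

def cosine(a, b):
    """Cosine similarity between two equal-length vectors; 0.0 if either is zero."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb) if na > 0 and nb > 0 else 0.0

def penalized_reward(reward, h_mean, v_concept, alpha=0.5):
    # ASSUMPTION: "penalty added to scalar reward" is read as reward minus penalty.
    penalty = alpha * max(0.0, cosine(h_mean, v_concept))
    return reward - penalty

v_concept = [1.0, 0.0]                                     # toy concept direction
r_aligned = penalized_reward(1.0, [2.0, 0.0], v_concept)   # activation along the concept
r_opposed = penalized_reward(1.0, [-2.0, 0.0], v_concept)  # activation against it
print(r_aligned, r_opposed)
```

The one-sided max(0, ·) mirrors the clip in our projection: rollouts whose activations point away from the concept direction are not rewarded for it, only left alone.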

Total: 10 runs × ~3h on RTX 6000 96GB = ~30h compute. (Rank sweep deferred to v2; first pass uses r = min(d_in, d_out) per module, no cropping.)

5. Measure at every 25 steps

  • Hack rate (Ariahw's detector ported into simple_GRPO)
  • Pass rate on held-out problems without write access to evaluator
  • Per-module cos_align trajectory (sanity that we're projecting something nonzero)
  • frac_modules_with_cos>0 per step (sanity that one-sided clip fires)
  • KL drift from init policy (catastrophic-change check)

6. Headline plot and headline table

Plot. Hack rate vs pass rate, one point per (arm × seed). Pareto frontier. Our method should land below-and-to-the-right of vanilla. Annotate Rebound.

Table schema (publication-ready; left-to-right = essential to optional, so trailing columns can be cut for space):

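| Arm | ΔSafePass↑ | Hack %↓ | Pass %↑ | KL↓ | mean·cos* | frac·fired* | ‖g‖* |
| --- | --- | --- | --- | --- | --- | --- | --- |
| Vanilla (a) | 0 (ref) | | | | | | |
| Ours (b) | | | | | | | |
| Ours, no mag-preserve (c) | | | | | | | |
| Rebound (d) | | | | | | | |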

Caption. ↑ higher is better, ↓ lower is better. ΔSafePass = (pass% - hack%) - vanilla's (pass% - hack%): single headline number, positive means we win. Hack % = fraction of rollouts triggering run_tests-overwrite detector. Pass % = fraction passing held-out tests without write access. KL = mean per-token KL from init policy over last 25 steps. * = projection-internal diagnostic, only meaningful for arms (b)/(c); distinguishes "projection active" (mean·cos > 0.2, frac·fired > 0.4) from "projection silent no-op". Cells report mean ± SEM across seeds.
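A sketch of how the headline column could be computed, reading ΔSafePass as (pass% - hack%) minus vanilla's (pass% - hack%). The per-seed numbers below are made-up placeholders, not results, and taking the SEM over per-seed deltas against the vanilla mean is an assumption about the aggregation.

```python
import math
from statistics import mean, stdev

def safe_pass(pass_pct, hack_pct):
    return pass_pct - hack_pct

def mean_sem(values):
    """Mean and standard error of the mean across seeds (NaN SEM for one seed)."""
    sem = stdev(values) / math.sqrt(len(values)) if len(values) > 1 else float("nan")
    return mean(values), sem

def delta_safe_pass(arm_runs, vanilla_runs):
    """Each run is a (pass_pct, hack_pct) tuple, one per seed."""
    vanilla_ref = mean(safe_pass(p, h) for p, h in vanilla_runs)
    deltas = [safe_pass(p, h) - vanilla_ref for p, h in arm_runs]
    return mean_sem(deltas)

# PLACEHOLDER numbers for illustration only
vanilla = [(40.0, 70.0), (42.0, 66.0), (38.0, 74.0)]
ours = [(41.0, 30.0), (39.0, 34.0), (43.0, 26.0)]
d_mean, d_sem = delta_safe_pass(ours, vanilla)
print(f"ΔSafePass = {d_mean:.1f} ± {d_sem:.1f}")
```

By construction the vanilla arm scores 0 against itself, which is the "0 (ref)" cell in the table.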

7. Falsification check

Before publishing, run pre-registered analysis on H1, H3, H4. Report all hypotheses including falsified ones.

Pseudocode (the three load-bearing bits)

A. AntiPaSTO module wrap (full rank, first pass)

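import torch
from torch import nn

class AntiPaSTO(nn.Module):
    # constructed from an existing nn.Linear(W: [d_out, d_in], b)
    # FIRST PASS: r = min(d_out, d_in) -- no truncation, W_res == 0
    def __init__(self, W, b=None):
        super().__init__()                # required before registering buffers/params
        U, S, Vh = torch.linalg.svd(W.detach().float(), full_matrices=False)
        r = S.shape[0]                    # = min(d_out, d_in)
        # buffers (frozen): the full SVD; they move with .to(device) but get no grads
        self.register_buffer("U", U)      # [d_out, r]
        self.register_buffer("S", S)      # [r]
        self.register_buffer("Vh", Vh)    # [r, d_in]
        self.register_buffer("b", None if b is None else b.detach().float())
        # trainable (ONLY this): scalar per rank
        self.delta_S = nn.Parameter(torch.zeros(r))

    def forward(self, x):  # x: [..., d_in], assumed float32 like the buffers
        y = ((x @ self.Vh.T) * (self.S + self.delta_S)) @ self.U.T
        return y if self.b is None else y + self.b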

Replace every target nn.Linear in attn (q,k,v,o_proj) and MLP (up,gate,down_proj) with this. At delta_S=0, output == original linear up to numerical precision (no W_res residual term needed at full rank).

SVD precompute strategy. Don't SVD the whole model on GPU at once. Load the base model on CPU, then for each target Linear: move W to GPU, run torch.linalg.svd(W.float(), full_matrices=False), save (U, S, Vh) -> svd_cache/{model_name}/{module_path}.pt. Wrap construction then loads the cached SVD per module. SVD is done once per base model; ~5-10s per big MLP weight on RTX 3090.
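The caching loop described above could look roughly like this. It is a sketch under stated assumptions: cache_svd is a hypothetical helper name, it caches every nn.Linear rather than only the attn/MLP targets, and the toy nn.Sequential stands in for the base model.

```python
import tempfile
from pathlib import Path

import torch
from torch import nn

def cache_svd(model, cache_dir, device="cpu"):
    """Save (U, S, Vh) for each nn.Linear under cache_dir/<module_path>.pt."""
    cache_dir = Path(cache_dir)
    cache_dir.mkdir(parents=True, exist_ok=True)
    paths = {}
    for name, m in model.named_modules():
        if not isinstance(m, nn.Linear):
            continue
        out = cache_dir / f"{name}.pt"
        if not out.exists():                                   # SVD once per base model
            W = m.weight.detach().to(device).float()           # one weight on device at a time
            U, S, Vh = torch.linalg.svd(W, full_matrices=False)
            torch.save({"U": U.cpu(), "S": S.cpu(), "Vh": Vh.cpu()}, out)
        paths[name] = out
    return paths

toy = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))   # stand-in model
with tempfile.TemporaryDirectory() as tmp:
    paths = cache_svd(toy, tmp)
    cached = torch.load(paths["0"])
    shapes = {k: tuple(v.shape) for k, v in cached.items()}
print(sorted(paths), shapes)
```

Passing device="cuda" would move one weight at a time to the GPU, which is the point of the strategy: peak GPU memory is one matrix plus its factors, not the whole model.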

B. Gradient-side v_hack extraction (per module)

from collections import defaultdict
import torch
import torch.nn.functional as F

# Assumed to exist: `model` (AntiPaSTO-wrapped causal LM), `pairs`, and a helper
# tokenize(text) -> LongTensor [1, L] on the model's device.
v_hack = {}                                       # dict[module_name -> Tensor[r]]
grads_hack = defaultdict(list)
grads_clean = defaultdict(list)

# Per-pair: process hack and clean independently, NLL over their own completion
# tokens only. Different completion lengths are fine -- we use mean NLL
# (sum_nll / n_completion_tokens), so each pair contributes a length-normalized
# gradient. This avoids biasing v_hack toward longer (typically clean)
# completions. Pad each example individually; no cross-completion padding.

for (prompt, hack_completion, clean_completion) in pairs:
    # assumes the prompt tokenizes to a prefix of tokenize(prompt + completion)
    prompt_len = tokenize(prompt).shape[1]               # number of prompt tokens
    for label, completion in [('hack', hack_completion), ('clean', clean_completion)]:
        model.zero_grad()
        ids   = tokenize(prompt + completion)            # [1, L]
        mask  = torch.zeros_like(ids, dtype=torch.float)
        mask[:, prompt_len:] = 1                         # 1 on completion tokens
        logits = model(ids).logits[:, :-1]               # [1, L-1, vocab]
        # per-token NLL, then MEAN over completion tokens (length-normalized)
        nll    = F.cross_entropy(logits.transpose(1, 2).float(), ids[:, 1:], reduction='none')
        loss   = (nll * mask[:, 1:]).sum() / mask[:, 1:].sum()
        loss.backward()
        for name, m in model.named_modules():
            if hasattr(m, 'delta_S') and m.delta_S.grad is not None:
                bucket = grads_hack if label == 'hack' else grads_clean
                bucket[name].append(m.delta_S.grad.detach().cpu().clone())

for name in grads_hack:
    diff = torch.stack(grads_hack[name]).mean(0) - torch.stack(grads_clean[name]).mean(0)  # [r]
    v_hack[name] = diff / (diff.norm() + 1e-8)

torch.save(v_hack, 'v_hack.pt')

Validation (report both, don't just gate on threshold):

  • On held-out pairs, recompute per-module diff_held and cos_align_held = cos(diff_held, v_hack[name]).
  • Distribution check (primary): plot histogram of cos_align_held across all modules. Healthy = unimodal positive, median > 0.3. Pathological = bimodal or median near 0.
  • Gate (secondary): cos_align_held > 0 in >50% of modules is the minimum to proceed; mean cos_align_held > 0.2 is the target. If <50% pass, extraction is broken and we debug before training.
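The gate and the distribution summary above could be computed along these lines. This is a sketch: validate_v_hack is a hypothetical name, plain lists stand in for the per-module tensors, and the three toy modules are placeholders rather than real extraction output.

```python
import math
from statistics import mean, median

def cosine(a, b):
    """Cosine similarity between two equal-length vectors; 0.0 if either is zero."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb) if na > 0 and nb > 0 else 0.0

def validate_v_hack(v_hack, diff_held):
    """v_hack, diff_held: dict[module_name -> list[float]] in the same rank space."""
    cos_held = {name: cosine(diff_held[name], v_hack[name]) for name in v_hack}
    vals = list(cos_held.values())
    report = {
        "frac_positive": sum(c > 0 for c in vals) / len(vals),
        "mean_cos": mean(vals),
        "median_cos": median(vals),
    }
    report["gate_pass"] = report["frac_positive"] > 0.5     # minimum to proceed
    report["target_met"] = report["mean_cos"] > 0.2         # stated target
    return cos_held, report

# toy stand-ins for three modules
v_hack_toy = {"a": [1.0, 0.0], "b": [0.0, 1.0], "c": [1.0, 0.0]}
diff_held_toy = {"a": [2.0, 0.0], "b": [1.0, 1.0], "c": [-1.0, 0.0]}
cos_held, report = validate_v_hack(v_hack_toy, diff_held_toy)
print(report)
```

The per-module cos_held values are what the histogram would be drawn from; the report only summarizes them, so both should be kept.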

C. Pre-optimizer-step projection hook

import torch
from torch import Tensor

@torch.no_grad()
def project_grads(model, v_hack: dict[str, Tensor]):
    # called after engine.backward(loss), before engine.step()
    cos_log, n_modules, n_fired = [], 0, 0
    for name, m in model.named_modules():
        if not hasattr(m, 'delta_S'): continue
        g = m.delta_S.grad                        # [r]
        if g is None: continue
        n_modules += 1
        v = v_hack[name].to(device=g.device, dtype=g.dtype)   # [r], unit
        g_norm = g.norm()
        if g_norm < 1e-12: continue
        cos_a = (g @ v) / g_norm                  # scalar
        cos_log.append(cos_a.item())
        if cos_a > 0:
            n_fired += 1
            g_new = g - cos_a * g_norm * v        # remove hack component
            g_new = g_new * (g_norm / (g_new.norm() + 1e-8))  # magnitude preserve
            m.delta_S.grad.copy_(g_new)
    mean_cos = sum(cos_log) / max(len(cos_log), 1)   # Python has no builtin mean()
    return dict(mean_cos=mean_cos, frac_fired=n_fired / max(n_modules, 1))

Integration into grpo_ref_split.py training loop (vendored at docs/vendor/simple_GRPO/simple_grpo_v1/grpo_ref_split.py; we copy and edit, not import):

# at top of training script, once:
v_hack = torch.load('v_hack.pt', map_location='cpu')   # dict[str, Tensor[r]]
# (extraction script from B above produces this artifact; if missing, crash loud)

# inside the training loop:
loss = GRPO_step(batch)
engine.backward(loss)
stats = project_grads(engine.module, v_hack)       # <-- NEW: 1 line
engine.step()
if rank == 0: log(stats)

Decisions left open (write these up alongside results)

  • Rank r. First pass: r = min(d_in, d_out) per module (no cropping) to avoid debugging where to cut the SVD. Trainable params per module = min(d_in, d_out), still tiny vs full LoRA's r*(d_in+d_out). Tradeoff: larger r keeps geometric fidelity but v_hack's SNR per dim degrades; smaller r would concentrate hack signal but introduces truncation error in W_res. Rank sweep is v2 work.
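The parameter-count comparison in the rank bullet is simple arithmetic. In the sketch below the 4096×4096 layer and LoRA rank 16 are illustrative assumptions, not the dimensions of any model used here.

```python
def trainable_params(d_in, d_out, lora_rank):
    """Trainable parameters per module: delta_S at full rank vs. a rank-r LoRA."""
    delta_s = min(d_in, d_out)                 # one scalar per singular value
    lora = lora_rank * (d_in + d_out)          # A: [r, d_in] plus B: [d_out, r]
    return delta_s, lora

delta_s, lora = trainable_params(d_in=4096, d_out=4096, lora_rank=16)
print(delta_s, lora, lora / delta_s)
```

For these assumed sizes, full-rank delta_S trains 4096 scalars against 131072 for the LoRA, a factor of 32.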

Why measure ratio, not just hack rate

A model that learns nothing won't cheat. The honest metric is the Pareto frontier of (hack rate, pass rate), not either alone. Pure hack-rate rewards undertraining; pure pass-rate rewards anything that improves coding including via the hack. Headline claim shape: "at matched pass rate ±5pp on held-out problems without write access, our method reduces hack rate from X% to Y%."
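A small sketch of the Pareto-frontier computation over (hack rate, pass rate) points. The four points are made-up placeholders chosen to illustrate the argument, not results, and pareto_frontier is a hypothetical helper name.

```python
def pareto_frontier(points):
    """points: list of (label, hack_pct, pass_pct). Lower hack and higher pass are better."""
    frontier = []
    for label, hack, passed in points:
        dominated = any(
            h2 <= hack and p2 >= passed and (h2 < hack or p2 > passed)
            for _, h2, p2 in points
        )
        if not dominated:
            frontier.append((label, hack, passed))
    return sorted(frontier, key=lambda t: t[1])     # order by hack rate

# PLACEHOLDER points for illustration only
points = [
    ("vanilla", 70.0, 40.0),
    ("ours", 30.0, 41.0),
    ("rebound", 45.0, 38.0),
    ("undertrained", 5.0, 10.0),
]
frontier = pareto_frontier(points)
print([label for label, _, _ in frontier])
```

With these placeholders the undertrained point lands on the frontier next to "ours": it never cheats because it never learns, which is why hack rate alone cannot be the headline metric.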

Compute estimate

  • Single run on 96GB RTX 6000: ~2-3h (Qwen3.5-2B, num_gen=8, 200 steps, simple_GRPO, AntiPaSTO full rank)
  • 10 runs: 20-30h
  • At ~$3 AUD/hr: $60-90 AUD
    • debugging buffer: budget ~$200 AUD total
  • Calendar time: 1 week back-to-back; 2-3 weeks with iteration

Risks and decision points

  • H4 falsified (no hack on Qwen3.5-2B with simple_GRPO): branch 1, try Qwen3-4B with num_generations halved to 4 (as in step 2b). Branch 2, drop simple_GRPO and hook into verl directly; this adds ~1-2 weeks engineering.
  • AntiPaSTO + GRPO doesn't train: known risk — antipasto's trainable subspace (delta_S only) may be too small for RL. If so, document and fall back to PiSSA-LoRA-freeze-A. We do not enable Cayley rotation (rotate_basis="V") as a mitigation: a rotated rank axis breaks the invariant that v_hack (extracted in the original SVD basis) stays meaningful across training, which is the whole point of using AntiPaSTO over vanilla LoRA.
  • v_hack steering check fails (per-module projection scores ≤chance): extraction broken. Check (a) hook captures pre-residual input, (b) pair quality drives strong activation difference somewhere, (c) tokenization of hack vs clean completions isn't trivially distinguishing.
  • All methods tie vanilla on hack rate: intervention not biting. Check cos_align logs nonzero, frac_modules_with_cos>0 nonzero.

What this is not

  • Not a claim that rank-space gradient projection solves reward hacking generally
  • Not a comparison to monitor-based methods (cite Ariahw's numbers, don't re-run)
  • Not a claim about hacks beyond run_tests() overwrite
  • Not a replacement for RLHF safety pipeline; this is a targeted intervention