evil_MoE/docs/papers/grad_routing/sgtm_vs_ours.md
wassname 7195d19f90 docs
2026-06-07 03:07:35 +00:00


SGTM (selective-gradient-masking) vs our route2 — code comparison

Cloned github.com/safety-research/selective-gradient-masking to docs/vendor/selective-gradient-masking. This compares their routing mechanism and train loop to our adapter and train loop (antipasto.py, train.py). Bottom line: their selectivity is driven by data labels over fixed reserved dimensions, never by a per-token learned direction. That's exactly the axis where our route2-act failed.

Their mechanism (src/sgtm/model/)

  • No adapter. SplitLinearOut splits an existing layer's output features into weight_retain [retain_dim, d_in] and weight_forget [d_out-retain_dim, d_in]. The "forget" capacity is part of the model's own width, a fixed dimension partition. No added LoRA, no SVD, no extracted direction.
  • Routing is by example label, on those fixed dims. A batch is labeled forget / retain / adjacent (trainer.py:42,254). Forget-labeled examples go through forward_forget:
    • activation masking: hidden[:, :, :retain_mlp_dim] = 0 (zero the retain dims so the forget concept can only live in the forget dims).
    • gradient routing (Cloud 2024): hidden[:, :, :retain_mlp_dim] = hidden[..., :retain_mlp_dim].detach() — forward value UNCHANGED, forget gradient flows only into the forget dims.
  • Deploy = ablate(): zero weight_forget (or, with trainable=True, reinit it to the retain-dims' std so the model stays finetunable after the forget subspace is removed).
  • Weak detector is a first-class knob. trainer.py:265-287 mislabels the forget/retain split by explicit forget_tpr/forget_fpr or precision/recall, simulating an imperfect detector. This is their headline axis (unlearning quality vs detector quality) and it directly matches our no-cheat weak-detector framing.
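To make the difference between the two forget-path variants concrete, here is a minimal framework-free sketch with hand-written gradients. It is not their code: the toy is a linear two-layer map with no activation function, and `forward_backward`, the weights, and the input are all made up for illustration. It only shows the property the bullets describe: gradient routing leaves the forward value unchanged and starves the retain rows of gradient, while activation masking changes the forward value as well.

```python
def forward_backward(W_fc, w_proj, x, target, retain_dim, mode):
    """Toy linear MLP with hand-written gradients (illustrative, not SGTM code).

    mode = "retain"     : ordinary example, gradient reaches every row of W_fc
    mode = "grad_route" : forget example, retain dims detached (value kept)
    mode = "act_mask"   : forget example, retain dims zeroed in the forward
    """
    h = [sum(w * xi for w, xi in zip(row, x)) for row in W_fc]
    if mode == "act_mask":
        h = [0.0 if i < retain_dim else hi for i, hi in enumerate(h)]
    y = sum(wp * hi for wp, hi in zip(w_proj, h))
    dy = y - target                      # d/dy of 0.5 * (y - target)**2
    dh = [dy * wp for wp in w_proj]
    if mode != "retain":
        # detach (or the constant zero) blocks gradient into the retain dims
        dh = [0.0 if i < retain_dim else g for i, g in enumerate(dh)]
    grad_W = [[g * xi for xi in x] for g in dh]   # gradient w.r.t. W_fc only
    return y, grad_W


W_fc = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5]]
w_proj = [1.0, 1.0, 1.0, 1.0]
x, target, retain_dim = [1.0, 2.0], 0.0, 2

y_retain, g_retain = forward_backward(W_fc, w_proj, x, target, retain_dim, "retain")
y_route, g_route = forward_backward(W_fc, w_proj, x, target, retain_dim, "grad_route")
y_mask, g_mask = forward_backward(W_fc, w_proj, x, target, retain_dim, "act_mask")

print("forward values:", y_retain, y_route, y_mask)
print("grad-route W_fc grad:", g_route)
```

In this toy the first two rows of `W_fc` play the role of weight_retain and the last two the role of weight_forget; the gradient for the output projection is left out to keep the sketch short.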

Side by side

| | SGTM | ours (route2) |
|---|---|---|
| forget capacity | reserved dims of the real layer | added LoRA A_q/B_q (33M) |
| where it lives | fixed dimension partition | learned, in SVD-of-W basis |
| routing signal | example label (forget data), noisy via TPR/FPR | route2-act: per-token cos(a,v_act)>0; route2-grad: per-rollout grad sign |
| forward change | act-mask zeros dims / grad-route detaches (no value change) | act-mask detaches ~half of tokens; quar always summed |
| can the masked term blow up? | not in practice (see below) | yes in route2-act (quar runaway; gn 0.23→1.5) |
| deploy | zero or reinit forget weights | delete quarantine |
| fine-tune LR | 5e-5 (from-scratch pretrain is 5e-3) | 2.5e-3 on delta_S, 2.5e-4 on quar |
| norm | standard GPT layernorm, no batchnorm | (Qwen RMSNorm, frozen) |

Why theirs doesn't blow up (corrected — not "bounded width")

My first take ("reserved width is bounded") was too glib. weight_forget is a free parameter; nothing caps its magnitude either. The real reasons route2-act runs away and SGTM doesn't:

  1. Clean label routing => the forget path only ever sees forget gradients. SGTM routes whole forget-labeled examples through forward_forget. The forget dims get gradient only from genuine forget data. route2-act routes ~half of ALL tokens (the cos>0 coin-flip) into the quarantine, so the quarantine absorbs reward-driven gradient on everything and there's a feedback loop (more quar -> worse rollouts -> more reward-chasing grad).
  2. Supervised CE, not RL. Their loss is bounded next-token cross-entropy on fixed data. Ours is GRPO: the policy that generates the data is the thing being perturbed, so a bad quar update degrades the next batch of rollouts.
  3. Their forget output participates in the normal forward (read by the same c_proj, normalized by the same downstream layernorms the model co-adapted to). Our quar is an independent additive rank-16 path nothing co-adapts to.

So the protection is mostly (1)+(2): label routing avoids over-routing, and supervised training has no policy-feedback loop. It is NOT mainly adapter-vs-reserved-dims.

Can't we adopt (1) and (2)? Mostly yes.

(1) Clean routing -- yes, we already can. Routing per-ROLLOUT instead of per-token gives the same cleanliness (a rollout is all-in or all-out, no coin-flip). route2-grad IS this, keeping the vector premise. We can go further and gate eligibility on the weak detector flag (allowed by no-cheat) so the quarantine only ever sees flagged-hack rollouts -- the direct analog of SGTM's forget label. The vector still decides the subspace; the flag decides which rollouts. route2-act's per-token granularity was the only part that had to die.
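A small sketch of that per-rollout gate, under one reading of the paragraph: a rollout is routed only if the weak detector flagged it AND its gradient points along the extracted direction. The function names, the toy gradients, and the choice to AND the two conditions are assumptions for illustration, not what train.py does.

```python
import math


def cosine(u, v):
    """Cosine similarity of two plain-list vectors; 0.0 if either is all zeros."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0.0 or norm_v == 0.0:
        return 0.0
    return dot / (norm_u * norm_v)


def rollouts_to_route(rollout_grads, detector_flags, v_grad):
    """One decision per rollout: all-in or all-out, never per token.

    The flag decides eligibility (which rollouts); the vector decides
    direction (assumed here: route only if cos(grad_b, v_grad) > 0).
    """
    return [
        bool(flag) and cosine(grad, v_grad) > 0.0
        for grad, flag in zip(rollout_grads, detector_flags)
    ]


# Hypothetical per-rollout gradients (flattened) and weak-detector flags.
rollout_grads = [[1.0, 0.0], [-1.0, 0.0], [1.0, 1.0], [0.5, 0.2]]
detector_flags = [True, True, False, True]
v_grad = [1.0, 0.0]

routed = rollouts_to_route(rollout_grads, detector_flags, v_grad)
print(routed)
```

The third rollout aligns with `v_grad` but is never routed because the detector did not flag it, which is the direct analog of an unlabeled example in SGTM.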

(2) No-feedback-loop -- we can't drop RL (that's the research question), but we can remove the actual harmful path. The loop is "quar corrupts the data-generating policy -> worse rollouts -> worse gradient." Two fixes, neither needs supervised training:

  • Ablate the quarantine during generation (generate from the clean deployed model; route gradient into quar only in the training backward). The quarantine becomes a pure gradient sink that never touches the rollouts.
  • Make routing forward-value-preserving like SGTM's gradient-routing variant -- pure detach, no added growing term. See improvement (B) below: route within delta_S along the SVD axes instead of adding a 33M LoRA.
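The first fix can be sketched as a context manager that switches the quarantine off for generation and restores it for the training forward. `ToyModule`, `quar_enabled`, and `quarantine_ablated` are hypothetical names, and the scalar "weights" stand in for the real delta_S and A_q/B_q paths.

```python
from contextlib import contextmanager


class ToyModule:
    """Scalar stand-in for one adapted layer: base + kept (+ quar)."""

    def __init__(self, base, kept, quar):
        self.base, self.kept, self.quar = base, kept, quar
        self.quar_enabled = True

    def forward(self, x):
        out = (self.base + self.kept) * x
        if self.quar_enabled:
            out += self.quar * x      # always-summed quarantine term
        return out


@contextmanager
def quarantine_ablated(module):
    """Temporarily remove the quarantine, restoring the old state on exit."""
    previous = module.quar_enabled
    module.quar_enabled = False
    try:
        yield module
    finally:
        module.quar_enabled = previous


layer = ToyModule(base=1.0, kept=0.5, quar=4.0)

with quarantine_ablated(layer):
    generation_out = layer.forward(2.0)   # what the rollout policy sees

training_out = layer.forward(2.0)         # quarantine back on for backward
print(generation_out, training_out)
```

With this shape the rollouts always come from the model that would be deployed, so a large quarantine can no longer degrade the next batch of data.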

So the honest read: theirs is safer by construction, but the safety is adoptable. We chose the riskier per-token added-LoRA design; the fixes are known.

Key tricks worth stealing

  1. Route by rollout label, not per-token direction. SGTM never fires a learned direction per token. Our route2-act's cos(a,v_act)>0 coin-flip (act_cos~0 → fires on ~half of all tokens, job 46) is an invention with no analog in their working method. route2-grad (per-rollout) and route (flag-selected rollouts) are the variants that match SGTM's philosophy.
  2. Reserve capacity instead of extracting a direction. Their forget concept localizes to reserved dims by construction (mask the retain dims on forget data). No staleness, no v_hack refresh, no over-routing. More robust, less ambitious than our direction-extraction.
  3. TPR/FPR detector-quality sweep as the generalisation experiment. Cleaner than our "select pairs from 2 of 4 modes" — we could parametrize detector recall directly. (task #161 leakage metric is the related SGTM eq.1.)
  4. ablate(trainable=True): reinit forget weights to matched std so the model is finetunable post-ablation (not a dead hole). Relevant if we ever retrain after deploy.
  5. No extra summed adapter → no coherence-wrecking additive term. Our quar being always-summed AND large AND freshly-init is a route2-specific failure mode SGTM structurally avoids.
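Trick 4 can be sketched in plain Python as follows. This is an illustration of the idea rather than their `ablate()`: the function name `ablate_forget`, the zero-mean Gaussian, and the use of a population standard deviation over all retain weights are assumptions.

```python
import random
import statistics


def flat_std(matrix):
    """Population std over every entry of a list-of-lists matrix."""
    return statistics.pstdev([w for row in matrix for w in row])


def ablate_forget(weight_retain, weight_forget, trainable=False, seed=0):
    """Return a replacement for weight_forget.

    trainable=False: all zeros (the forget subspace is simply removed).
    trainable=True : fresh noise whose std matches weight_retain, so the
                     slot can be fine-tuned again later.
    """
    if not trainable:
        return [[0.0 for _ in row] for row in weight_forget]
    rng = random.Random(seed)
    retain_std = flat_std(weight_retain)
    return [[rng.gauss(0.0, retain_std) for _ in row] for row in weight_forget]


# Toy weights: retain rows are small, forget rows have grown large.
rng = random.Random(1)
weight_retain = [[rng.gauss(0.0, 0.02) for _ in range(32)] for _ in range(64)]
weight_forget = [[rng.gauss(0.0, 0.5) for _ in range(32)] for _ in range(16)]

zeroed = ablate_forget(weight_retain, weight_forget)
reinit = ablate_forget(weight_retain, weight_forget, trainable=True)
print(flat_std(weight_retain), flat_std(weight_forget), flat_std(reinit))
```

Either way the learned forget weights are gone; the trainable variant just leaves usable capacity behind instead of a block of zeros.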

What this says about our results

  • route2-act's failure is not a tuning problem; it diverges from SGTM at the routing signal (per-token learned direction vs per-example label). The lr fix only slowed the same disease.
  • route2-grad (job 47) and route are the SGTM-shaped variants. If 47 is coherent and route's file_marker generalisation holds, that's the SGTM-consistent path.
  • We are deliberately more ambitious than SGTM (extract the hack direction rather than reserve dims and route by label). That buys "no need to reserve capacity or know the dims" but costs robustness (extraction can be wrong/stale, and a per-token direction over-routes). Worth stating plainly in the writeup.

Pseudocode: theirs vs ours

The shared idea (gradient routing, Cloud 2024): push a concept's gradient into a throwaway subspace so the deployed model never learns it. They differ in HOW the "this is the concept" decision is made and WHAT subspace it goes to.

THEIRS (SGTM gradient-routing variant), supervised, route by label:

# setup: in each MLP, reserve dims [retain_dim:] as the "forget" subspace
for x, label in data:                 # label in {retain, forget}; noisy via TPR/FPR
    h = c_fc(x)                        # [b, s, d_mlp]
    if label == forget:
        h[:, :, :retain_dim] = h[:, :, :retain_dim].detach()   # forget grad -> forget dims only
    y = c_proj(act(h))                 # forward VALUE identical either way (detach != zero)
    loss = cross_entropy(y, next_token)
    loss.backward(); opt.step()
# deploy: zero weight_forget  (the c_fc rows that produce the reserved forget dims [retain_dim:])

OURS (route2, the novel bit = route by an extracted VECTOR, not a label):

# setup: v_hack = top SVD dirs of (grad logp(clean) - grad logp(hack)) over pairs
#        per module: frozen U/Vh; trainable delta_S (kept); quarantine A_q/B_q (forget)
for prompts:
    rollouts = student.generate(prompts)          # GRPO, on-policy (no labels)
    adv      = grpo_advantage(reward(rollouts))    # reward from the weak grader
    # forward: y + kept + quar ; kept = delta_S path, quar = always-summed quarantine
    #   route2-act : per-TOKEN, route tokens with cos(activation, v_act) > 0
    #                -> kept.detach() there  (FAILED: cos~0 => fires on ~half of all tokens)
    #   route2-grad: per-ROLLOUT, flag rollouts with cos(grad_b, v_grad) > 0
    #                -> subtract their share from delta_S.grad post-backward
    loss = grpo_loss(rollouts, adv); loss.backward()
    route_hack_component_into_quar()               # by VECTOR, not label
    opt.step()
# deploy: zero A_q/B_q  (quar -> 0)

The one-line diff: theirs decides "forget" by a (noisy) data label and sends it to reserved dims; ours decides "hack" by cos against an extracted direction and sends it to a learned quarantine. The vector premise is the contribution. The per-token granularity (route2-act) is the part that broke; route2-grad keeps the vector and routes per-rollout instead.
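The route2-grad step ("subtract their share from delta_S.grad post-backward") can be sketched on plain lists. The helper name, the toy gradients, and the use of a dot-product sign in place of the cosine sign (equivalent for nonzero vectors) are assumptions; how the quarantine parameters receive their own gradient is not modeled here.

```python
def split_delta_s_grad(rollout_grads, v_grad):
    """Per-rollout routing of the summed delta_S gradient.

    A rollout is routed when its gradient has a positive dot product with
    v_grad (same sign as the cosine). Returns (kept, share, routed):
      kept  : what stays in delta_S.grad after the subtraction
      share : the summed gradient of the routed rollouts
    """
    dim = len(v_grad)
    routed = [sum(g[i] * v_grad[i] for i in range(dim)) > 0.0 for g in rollout_grads]
    total = [sum(g[i] for g in rollout_grads) for i in range(dim)]
    share = [
        sum(g[i] for g, is_routed in zip(rollout_grads, routed) if is_routed)
        for i in range(dim)
    ]
    kept = [t - s for t, s in zip(total, share)]
    return kept, share, routed


# Three hypothetical rollouts with 2-d delta_S gradients.
rollout_grads = [[1.0, 2.0], [-1.0, 0.5], [3.0, -1.0]]
v_grad = [1.0, 0.0]

kept, share, routed = split_delta_s_grad(rollout_grads, v_grad)
print(routed, share, kept)
```

Because the decision is made once per rollout, the kept gradient is exactly the sum over the unrouted rollouts, with no per-token mixing.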

Their diagnostics (what they log/plot)

  • retain vs forget val_loss + val_ppl, separately (trainer.py:174). The core tradeoff: forget ppl should rise, retain ppl stay flat. Our analog: lp_s / lp_t and cin_s / cin_t (want cin_t > cin_s).
  • per-category breakdown (evaluate_by_category) -- per-topic forget/retain, like our per-mode hk_<mode> columns.
  • train/grad_norm (global, trainer.py:609) + a dedicated grad_norms.ipynb notebook -- they track grad norms per component. This is exactly the instinct behind our new qE column (energy into the quarantine vs the kept knob).
  • retain_forget_tradeoff.ipynb -- the Pareto curve (forget suppression vs retain damage), swept over detector quality / forget fraction.
  • undiscovered_rate.ipynb -- a relearn/elicitation probe: after unlearning, try to re-elicit the forget capability; "undiscovered" = it stays hidden. We have no relearn probe yet (improvement F below).
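For the tradeoff curve, a minimal Pareto filter is enough to turn a sweep into the plotted frontier. The run names and numbers below are placeholders made up for the example, not results from either codebase, and `pareto_front` is a hypothetical helper, not something in their notebooks.

```python
def pareto_front(runs):
    """Keep runs that no other run beats on both axes.

    Each run is (name, forget_suppression, retain_damage); higher
    suppression is better, lower damage is better.
    """
    front = []
    for name, forget, retain in runs:
        dominated = any(
            other_forget >= forget
            and other_retain <= retain
            and (other_forget > forget or other_retain < retain)
            for _, other_forget, other_retain in runs
        )
        if not dominated:
            front.append(name)
    return front


# Placeholder sweep points (illustrative values only).
runs = [
    ("a", 1.0, 0.1),
    ("b", 2.0, 0.1),
    ("c", 3.0, 0.5),
    ("d", 2.5, 0.6),
]

front = pareto_front(runs)
print(front)
```

The same filter would apply to our columns if forget suppression is read from the hack-rate metrics and retain damage from solve rate.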

Their tricks

  • TPR/FPR + precision/recall noisy-detector knobs (trainer.py:265-287): corrupt the forget/retain labels to a target detector quality and sweep it.
  • ablate(trainable=True): reinit forget weights to retain-std (finetunable after deploy, not a dead hole). TODO in our ablate_quarantine.
  • forget-adjacent dataset + upsample_forget/retain/adjacent + per-class loss weights (logit_alpha/logit_beta): a third "near-forget but keep" data class to measure/limit collateral damage. We have no near-hack-but-legit eval.
  • logit calibration (lm_head bias, separate lr) -- not relevant to us.
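The noisy-detector knob is simple to reproduce. The sketch below corrupts ground-truth forget labels to a target TPR/FPR; it is a guess at the shape of the idea, not a copy of trainer.py, and `corrupt_labels` with its seeding is hypothetical.

```python
import random


def corrupt_labels(is_forget, tpr, fpr, seed=0):
    """Simulate an imperfect detector over ground-truth forget labels.

    A true forget example is flagged with probability tpr;
    a true retain example is flagged with probability fpr.
    """
    rng = random.Random(seed)
    return [rng.random() < (tpr if truth else fpr) for truth in is_forget]


truth = [True] * 10000 + [False] * 10000

perfect = corrupt_labels(truth, tpr=1.0, fpr=0.0)
noisy = corrupt_labels(truth, tpr=0.8, fpr=0.1)

measured_tpr = sum(noisy[:10000]) / 10000
measured_fpr = sum(noisy[10000:]) / 10000
print(measured_tpr, measured_fpr)
```

Sweeping `tpr` and `fpr` over a grid and rerunning training at each point gives the detector-quality axis directly.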

Proposed improvements for OUR setting (RH env + SVD + vector)

Ordered by expected value. (A) is in flight; (B) is the structural one.

  • (A) Route per-rollout, not per-token. route2-grad. Keep the vector, drop the coin-flip. Gate eligibility on the weak-detector flag so the quarantine only sees flagged-hack rollouts (vector picks the subspace, flag picks the rollouts). Status: job 47 running.

  • (B) Route WITHIN delta_S along the SVD v_hack axes -- no separate LoRA. The quarantine doesn't need 33M new params. Reserve the v_hack subspace of the existing delta_S as the "forget" axes. For a flagged-hack rollout, detach the complement so its gradient flows ONLY into the v_hack-subspace component of delta_S; at deploy, zero that component. Properties:

    • forward-value-preserving during training (pure detach, like SGTM) -> no growing additive term -> kills the route2-act blow-up at the root.
    • uses the SVD basis + vector premise directly (the v_hack axes ARE the reserved subspace).
    • differs from erase: erase throws the hack gradient away every step; this lets the model USE the hack during training (rollouts stay on-distribution, still get reward) but DELETES it at deploy. That's the Cloud et al. routing insight: routing is more stable than erasing because the capability stays available while training.
    • Open question for review: is "route into a deleted v_hack component of delta_S" meaningfully different from erase in the GRPO single-step (ratio=1) regime, or does it collapse to the same update? (Suspect different: erase removes the component from the kept axes' grad; this preserves behavior via the forget axes that are only deleted at deploy.)

  • (C) Break the generation feedback loop. If we keep any added quarantine, ablate it during generate() so the data-generating policy is the clean deployed model; route gradient into the quarantine only in the training backward. Removes the "quar corrupts rollouts -> worse reward -> worse grad" runaway. (B) makes this partly moot.

  • (D) Noisy-detector sweep (TPR/FPR) instead of the 2-of-4-modes split, as the generalisation/no-cheat experiment. A continuous dial on detector recall.

  • (E) Near-hack-but-legit eval (their "forget-adjacent"): rollouts that look hack-shaped but are legitimate, to measure over-suppression / collateral damage on solve rate. Currently we only watch aggregate solve.

  • (F) Relearn probe (their undiscovered-rate): after deploy-ablation, try to re-elicit the hack with a few GRPO steps. If it returns fast, the routing only hid it shallowly rather than preventing the learning.
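Improvement (B) can be sketched with plain projections. Assumptions, all for illustration: delta_S is a flat vector, the v_hack axes form an orthonormal basis, unflagged rollouts update every axis, and the optimizer is plain SGD. None of the helper names exist in our code.

```python
def project(vec, basis):
    """Component of vec inside span(basis); basis rows assumed orthonormal."""
    out = [0.0] * len(vec)
    for axis in basis:
        coef = sum(v * a for v, a in zip(vec, axis))
        out = [o + coef * a for o, a in zip(out, axis)]
    return out


def routed_grad(grad, flagged, v_hack_basis):
    """Flagged rollouts may only update the v_hack component of delta_S."""
    return project(grad, v_hack_basis) if flagged else list(grad)


def deploy(delta_s, v_hack_basis):
    """Delete the v_hack component, keep the complement."""
    hack_part = project(delta_s, v_hack_basis)
    return [d - h for d, h in zip(delta_s, hack_part)]


v_hack_basis = [[1.0, 0.0, 0.0]]          # one reserved axis
delta_s = [0.0, 0.0, 0.0]
lr = 0.5

steps = [([1.0, 1.0, 1.0], False),        # clean rollout: full update
         ([2.0, 2.0, 2.0], True)]         # flagged rollout: v_hack axis only
for grad, flagged in steps:
    g = routed_grad(grad, flagged, v_hack_basis)
    delta_s = [d - lr * gi for d, gi in zip(delta_s, g)]

deployed = deploy(delta_s, v_hack_basis)
print(delta_s, deployed)
```

During training delta_s still carries the flagged rollout's update on the reserved axis, so behavior is preserved; only the deployed copy loses it.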

On the hyperparameters

The Appendix-I table (LR 5e-3, AdamW, wd 0.1, betas 0.9/0.95) is from-scratch TinyStories pretraining of 8M-64M models for 33k steps, so it is not transferable to adapter RL on a 4B model. Their fine-tuning scripts use --lr 5e-5. Adapter LRs run 10-100x above full-FT, so our delta_S at 2.5e-3 (50x their 5e-5) is in range. The 33M quarantine at 2.5e-4 is still ~5x above their full-FT 5e-5 and is freshly initialized, which is consistent with (though not the main cause of) route2-act's drift.
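The ratios quoted above are quick to check. The dictionary keys are just labels for this example; the 10-100x band is the rule of thumb from the paragraph, not a measured bound.

```python
full_ft_lr = 5e-5                      # their fine-tuning --lr

our_lrs = {
    "delta_S": 2.5e-3,
    "quar": 2.5e-4,
}

# Ratio of each of our LRs to their full-FT LR, rounded to kill float noise.
ratios = {name: round(lr / full_ft_lr, 6) for name, lr in our_lrs.items()}

# Rule-of-thumb adapter band from the text: 10x to 100x above full-FT.
in_adapter_band = {name: 10.0 <= r <= 100.0 for name, r in ratios.items()}

for name in our_lrs:
    print(name, ratios[name], in_adapter_band[name])
```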