diff --git a/docs/grad_routing/sgtm_vs_ours.md b/docs/grad_routing/sgtm_vs_ours.md
new file mode 100644
index 0000000..d9b0752
--- /dev/null
+++ b/docs/grad_routing/sgtm_vs_ours.md
@@ -0,0 +1,247 @@
+# SGTM (selective-gradient-masking) vs our route2: code comparison
+
+Cloned `github.com/safety-research/selective-gradient-masking` to
+`docs/vendor/selective-gradient-masking`. This compares their adapter + train
+loop to ours (`antipasto.py`, `train.py`). Bottom line: their selectivity is
+**data-label driven over fixed reserved dimensions**, never a per-token learned
+direction. That's exactly the axis where our route2-act failed.
+
+## Their mechanism (src/sgtm/model/)
+
+- **No adapter.** `SplitLinearOut` splits an existing layer's output features
+  into `weight_retain` `[retain_dim, d_in]` and `weight_forget`
+  `[d_out-retain_dim, d_in]`. The "forget" capacity is part of the model's own
+  width, a fixed dimension partition. No added LoRA, no SVD, no extracted
+  direction.
+- **Routing is by example label, on those fixed dims.** A batch is labeled
+  `forget` / `retain` / `adjacent` (trainer.py:42,254). Forget-labeled examples
+  go through `forward_forget`:
+  - activation masking: `hidden[:, :, :retain_mlp_dim] = 0` (zero the retain
+    dims so the forget concept can only live in the forget dims).
+  - gradient routing (Cloud 2024): `hidden[:, :, :retain_mlp_dim] =
+    hidden[..., :retain_mlp_dim].detach()`: forward value UNCHANGED, forget
+    gradient flows only into the forget dims.
+- **Deploy = `ablate()`**: zero `weight_forget` (or, with `trainable=True`,
+  reinit it to the retain-dims' std so the model stays finetunable after the
+  forget subspace is removed).
+- **Weak detector is a first-class knob.** trainer.py:265-287 mislabels the
+  forget/retain split by explicit `forget_tpr/forget_fpr` or
+  `precision/recall`, simulating an imperfect detector. This is their headline
+  axis (unlearning quality vs detector quality) and it directly matches our
+  no-cheat weak-detector framing.
+
+## Side by side
+
+| | SGTM | ours (route2) |
+|---|---|---|
+| forget capacity | reserved dims of the real layer | added LoRA `A_q/B_q` (33M) |
+| where it lives | fixed dimension partition | learned, in SVD-of-W basis |
+| routing signal | **example label** (forget data), noisy via TPR/FPR | route2-act: **per-token** `cos(a,v_act)>0`; route2-grad: per-rollout grad sign |
+| forward change | act-mask zeros dims / grad-route detaches (no value change) | act-mask detaches ~half of tokens; quar always summed |
+| can the masked term blow up? | not in practice (see below) | yes in route2-act (quar runaway; gn 0.23→1.5) |
+| deploy | zero or reinit forget weights | delete quarantine |
+| fine-tune LR | 5e-5 (from-scratch pretrain is 5e-3) | 2.5e-3 on delta_S, 2.5e-4 on quar |
+| norm | standard GPT layernorm, no batchnorm | Qwen RMSNorm, frozen |
+
+## Why theirs doesn't blow up (corrected: not "bounded width")
+
+My first take ("reserved width is bounded") was too glib. `weight_forget` is a
+free parameter; nothing caps its magnitude either. The real reasons route2-act
+runs away and SGTM doesn't:
+
+1. **Clean label routing => the forget path only ever sees forget gradients.**
+   SGTM routes whole forget-labeled examples through `forward_forget`. The
+   forget dims get gradient only from genuine forget data. route2-act routes
+   ~half of ALL tokens (the `cos>0` coin-flip) into the quarantine, so the
+   quarantine absorbs reward-driven gradient on everything and there's a
+   feedback loop (more quar -> worse rollouts -> more reward-chasing grad).
+2. **Supervised CE, not RL.** Their loss is bounded next-token cross-entropy on
+   fixed data. Ours is GRPO: the policy that generates the data is the thing
+   being perturbed, so a bad quar update degrades the next batch of rollouts.
+3. **Their forget output participates in the normal forward** (read by the same
+   `c_proj`, normalized by the same downstream layernorms the model co-adapted
+   to). Our `quar` is an independent additive rank-16 path nothing co-adapts to.
+
+So the protection is mostly (1)+(2): label routing avoids over-routing, and
+supervised training has no policy-feedback loop. It is NOT mainly
+adapter-vs-reserved-dims.
+
+### Can't we adopt (1) and (2)? Mostly yes.
+
+(1) **Clean routing -- yes, we already can.** Routing per-ROLLOUT instead of
+per-token gives the same cleanliness (a rollout is all-in or all-out, no
+coin-flip). route2-grad IS this, keeping the vector premise. We can go further
+and gate eligibility on the weak detector flag (allowed by no-cheat) so the
+quarantine only ever sees flagged-hack rollouts -- the direct analog of SGTM's
+forget label. The vector still decides the *subspace*; the flag decides *which
+rollouts*. route2-act's per-token granularity was the only part that had to die.
+
+(2) **No-feedback-loop -- we can't drop RL (that's the research question), but
+we can remove the actual harmful path.** The loop is "quar corrupts the
+data-generating policy -> worse rollouts -> worse gradient." Two fixes, neither
+needs supervised training:
+  - **Ablate the quarantine during generation** (generate from the clean
+    deployed model; route gradient into quar only in the training backward). The
+    quarantine becomes a pure gradient sink that never touches the rollouts.
+  - **Make routing forward-value-preserving like SGTM's gradient-routing
+    variant** -- pure detach, no added growing term. See improvement (B) below:
+    route *within* delta_S along the SVD axes instead of adding a 33M LoRA.
+
+So the honest read: theirs is safer by construction, but the safety is
+adoptable. We chose the riskier per-token added-LoRA design; the fixes are known.
+
+## Key tricks worth stealing
+
+1. **Route by rollout label, not per-token direction.** SGTM never fires a
+   learned direction per token. Our route2-act's `cos(a,v_act)>0` coin-flip
+   (act_cos~0 → fires on ~half of all tokens, job 46) is an invention with no
+   analog in their working method. route2-grad (per-rollout) and route
+   (flag-selected rollouts) are the variants that match SGTM's philosophy.
+2. **Reserve capacity instead of extracting a direction.** Their forget concept
+   localizes to reserved dims *by construction* (mask the retain dims on forget
+   data). No staleness, no v_hack refresh, no over-routing. More robust, less
+   ambitious than our direction-extraction.
+3. **TPR/FPR detector-quality sweep** as the generalisation experiment. Cleaner
+   than our "select pairs from 2 of 4 modes"; we could parametrize detector
+   recall directly. (task #161 leakage metric is the related SGTM eq.1.)
+4. **ablate(trainable=True): reinit forget weights to matched std** so the model
+   is finetunable post-ablation (not a dead hole). Relevant if we ever retrain
+   after deploy.
+5. **No extra summed adapter** → no coherence-wrecking additive term. Our quar
+   being always-summed AND large AND freshly-init is a route2-specific failure
+   mode SGTM structurally avoids.
+
+## What this says about our results
+
+- route2-act's failure is **not** a tuning problem; it diverges from SGTM at the
+  routing signal (per-token learned direction vs per-example label). The lr fix
+  only slowed the same disease.
+- route2-grad (job 47) and route are the SGTM-shaped variants. If 47 is coherent
+  and route's file_marker generalisation holds, that's the SGTM-consistent path.
+- We are deliberately *more ambitious* than SGTM (extract the hack direction
+  rather than reserve dims and route by label). That buys "no need to reserve
+  capacity or know the dims" but costs robustness (extraction can be wrong/stale,
+  and a per-token direction over-routes). Worth stating plainly in the writeup.
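To make the SGTM side of the comparison concrete, here is a minimal pure-Python sketch of the detach-style routing described under "Their mechanism": the forward value is identical for retain and forget examples, but on a forget example no gradient reaches the weight rows that produce the retain dims. It is a hand-written backward pass for a toy two-layer linear model with no activation function, not SGTM's code; `forward_backward`, `ablate`, and the toy weights are hypothetical names and values chosen for illustration, and only the gradient into the first layer (`w_fc`) is tracked.

```python
def linear(weight, x):
    """Matrix-vector product for a weight stored as a list of rows."""
    return [sum(w * v for w, v in zip(row, x)) for row in weight]


def forward_backward(w_fc, w_proj, x, target, retain_dim, is_forget):
    """One squared-error step through h = w_fc @ x, y = w_proj . h.

    Returns (y, grad_w_fc). On forget examples the retain dims of h are
    treated as detached: their forward value is kept, but no gradient
    flows back through them into w_fc.
    """
    h = linear(w_fc, x)
    y = sum(p * h_i for p, h_i in zip(w_proj, h))
    dy = 2.0 * (y - target)            # d(loss)/dy for loss = (y - target)^2
    dh = [dy * p for p in w_proj]      # d(loss)/dh
    if is_forget:
        # Hand-written stand-in for h[:retain_dim] = h[:retain_dim].detach()
        dh = [0.0 if i < retain_dim else g for i, g in enumerate(dh)]
    grad_w_fc = [[g * v for v in x] for g in dh]
    return y, grad_w_fc


def ablate(w_fc, retain_dim):
    """Deploy-time ablation: zero the rows that produce the forget dims."""
    return [row[:] if i < retain_dim else [0.0] * len(row)
            for i, row in enumerate(w_fc)]


# Toy setup (assumed values): 4 hidden dims, the first 2 are "retain".
W_FC = [[0.5, -1.0], [1.5, 0.25], [-0.75, 2.0], [1.0, 1.0]]
W_PROJ = [1.0, -2.0, 0.5, 3.0]
X = [1.0, 2.0]
RETAIN_DIM = 2

y_retain, g_retain = forward_backward(W_FC, W_PROJ, X, 0.0, RETAIN_DIM, False)
y_forget, g_forget = forward_backward(W_FC, W_PROJ, X, 0.0, RETAIN_DIM, True)
```

The design point this illustrates is the one the table calls "no value change": the routing only edits the backward pass, so nothing additive can grow in the forward path during training.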
+
+## Pseudocode: theirs vs ours
+
+The shared idea (gradient routing, Cloud 2024): push a concept's gradient into a
+throwaway subspace so the deployed model never learns it. They differ in HOW the
+"this is the concept" decision is made and WHAT subspace it goes to.
+
+THEIRS (SGTM gradient-routing variant), supervised, route by label:
+
+    # setup: in each MLP, reserve dims [retain_dim:] as the "forget" subspace
+    for x, label in data:                 # label in {retain, forget}; noisy via TPR/FPR
+        h = c_fc(x)                       # [b, s, d_mlp]
+        if label == forget:
+            h[:, :, :retain_dim] = h[:, :, :retain_dim].detach()  # forget grad -> forget dims only
+        y = c_proj(act(h))                # forward VALUE identical either way (detach != zero)
+        loss = cross_entropy(y, next_token)
+        loss.backward(); opt.step()
+    # deploy: zero weight_forget (the rows producing the reserved dims [retain_dim:])
+
+OURS (route2, the novel bit = route by an extracted VECTOR, not a label):
+
+    # setup: v_hack = top SVD dirs of (grad logp(clean) - grad logp(hack)) over pairs
+    #        per module: frozen U/Vh; trainable delta_S (kept); quarantine A_q/B_q (forget)
+    for prompts in data:
+        rollouts = student.generate(prompts)        # GRPO, on-policy (no labels)
+        adv = grpo_advantage(reward(rollouts))      # reward from the weak grader
+        # forward: y + kept + quar ; kept = delta_S path, quar = always-summed quarantine
+        #   route2-act : per-TOKEN, route tokens with cos(activation, v_act) > 0
+        #                -> kept.detach() there  (FAILED: cos~0 => fires on ~half of all tokens)
+        #   route2-grad: per-ROLLOUT, flag rollouts with cos(grad_b, v_grad) > 0
+        #                -> subtract their share from delta_S.grad post-backward
+        loss = grpo_loss(rollouts, adv); loss.backward()
+        route_hack_component_into_quar()            # by VECTOR, not label
+        opt.step()
+    # deploy: zero A_q/B_q (quar -> 0)
+
+The one-line diff: theirs decides "forget" by a (noisy) data label and sends it
+to reserved dims; ours decides "hack" by cos against an extracted direction and
+sends it to a learned quarantine. The vector premise is the contribution. The
+per-token granularity (route2-act) is the part that broke; route2-grad keeps the
+vector and routes per-rollout instead.
+
+## Their diagnostics (what they log/plot)
+
+- **retain vs forget val_loss + val_ppl, separately** (trainer.py:174). The core
+  tradeoff: forget ppl should rise, retain ppl stay flat. Our analog: `lp_s` /
+  `lp_t` and `cin_s` / `cin_t` (want cin_t > cin_s).
+- **per-category breakdown** (`evaluate_by_category`) -- per-topic forget/retain,
+  like our per-mode `hk_` columns.
+- **train/grad_norm** (global, trainer.py:609) + a dedicated `grad_norms.ipynb`
+  notebook -- they track grad norms per component. This is exactly the instinct
+  behind our new `qE` column (energy into the quarantine vs the kept knob).
+- **retain_forget_tradeoff.ipynb** -- the Pareto curve (forget suppression vs
+  retain damage), swept over detector quality / forget fraction.
+- **undiscovered_rate.ipynb** -- a relearn/elicitation probe: after unlearning,
+  try to re-elicit the forget capability; "undiscovered" = it stays hidden.
+  We have no relearn probe yet (improvement F below).
+
+## Their tricks
+
+- **TPR/FPR + precision/recall noisy-detector knobs** (trainer.py:265-287):
+  corrupt the forget/retain labels to a target detector quality and sweep it.
+- **ablate(trainable=True)**: reinit forget weights to retain-std (finetunable
+  after deploy, not a dead hole). TODO in our `ablate_quarantine`.
+- **forget-adjacent dataset** + `upsample_forget/retain/adjacent` + per-class
+  loss weights (`logit_alpha`/`logit_beta`): a third "near-forget but keep" data
+  class to measure/limit collateral damage. We have no near-hack-but-legit eval.
+- **logit calibration** (lm_head bias, separate lr) -- not relevant to us.
+
+## Proposed improvements for OUR setting (RH env + SVD + vector)
+
+Ordered by expected value. (A) is in flight; (B) is the structural one.
+
+- **(A) Route per-rollout, not per-token.** route2-grad. Keep the vector, drop
+  the coin-flip. Gate eligibility on the weak-detector flag so the quarantine
+  only sees flagged-hack rollouts (vector picks the subspace, flag picks the
+  rollouts). Status: job 47 running.
+
+- **(B) Route WITHIN delta_S along the SVD v_hack axes -- no separate LoRA.**
+  The quarantine doesn't need 33M new params. Reserve the v_hack subspace of the
+  existing delta_S as the "forget" axes. For a flagged-hack rollout, detach the
+  complement so its gradient flows ONLY into the v_hack-subspace component of
+  delta_S; at deploy, zero that component. Properties:
+  - forward-value-preserving during training (pure detach, like SGTM) -> no
+    growing additive term -> kills the route2-act blow-up at the root.
+  - uses the SVD basis + vector premise directly (the v_hack axes ARE the
+    reserved subspace).
+  - differs from erase: erase throws the hack gradient away every step; this
+    lets the model USE the hack during training (rollouts stay on-distribution,
+    still get reward) but DELETES it at deploy. That's the Cloud et al. routing
+    insight -- routing is more stable than erasing because the capability stays
+    available while training.
+  Open question for review: is "route into a deleted v_hack component of delta_S"
+  meaningfully different from erase in the GRPO single-step (ratio=1) regime, or
+  does it collapse to the same update? (Suspect different: erase removes the
+  component from the *kept* axes' grad; this preserves behavior via the forget
+  axes that are only deleted at deploy.)
+
+- **(C) Break the generation feedback loop.** If we keep any added quarantine,
+  ablate it during `generate()` so the data-generating policy is the clean
+  deployed model; route gradient into the quarantine only in the training
+  backward. Removes the "quar corrupts rollouts -> worse reward -> worse grad"
+  runaway. (B) makes this partly moot.
+
+- **(D) Noisy-detector sweep (TPR/FPR)** instead of the 2-of-4-modes split, as
+  the generalisation/no-cheat experiment. A continuous dial on detector recall.
+
+- **(E) Near-hack-but-legit eval** (their "forget-adjacent"): rollouts that look
+  hack-shaped but are legitimate, to measure over-suppression / collateral
+  damage on solve rate. Currently we only watch aggregate solve.
+
+- **(F) Relearn probe** (their undiscovered-rate): after deploy-ablation, try to
+  re-elicit the hack with a few GRPO steps. If it returns fast, the routing only
+  hid it shallowly rather than preventing the learning.
+
+## On the hyperparameters
+
+The Appendix-I table (LR 5e-3, AdamW, wd 0.1, betas 0.9/0.95) is **from-scratch
+TinyStories pretraining of 8M-64M models for 33k steps**, so it is not
+transferable to adapter RL on a 4B model. Their *fine-tuning* scripts use
+`--lr 5e-5`. Adapter LRs run 10-100x above full-FT, so our delta_S at 2.5e-3
+(50x) is in range; the 33M quarantine at 2.5e-4 is still ~5x above their
+full-FT 5e-5 and is freshly init, which is consistent with (though not the main
+cause of) route2-act's drift.
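As a concrete starting point for improvement (D), the following is a minimal sketch of a detector-quality dial: each rollout's ground-truth hack label is turned into a noisy flag using a target TPR and FPR. This is not a copy of SGTM's trainer.py logic; `noisy_forget_labels` and `measured_rates` are hypothetical helper names, and the 25% hack fraction is an arbitrary assumption for the demo.

```python
import random


def noisy_forget_labels(true_is_forget, tpr, fpr, seed=0):
    """Simulate a weak detector.

    True forget (hack) examples are flagged with probability `tpr`;
    true retain (clean) examples are flagged with probability `fpr`.
    """
    rng = random.Random(seed)
    return [rng.random() < (tpr if is_forget else fpr)
            for is_forget in true_is_forget]


def measured_rates(true_is_forget, flagged):
    """Empirical (TPR, FPR) of the flags against the ground truth."""
    pos = [f for t, f in zip(true_is_forget, flagged) if t]
    neg = [f for t, f in zip(true_is_forget, flagged) if not t]
    return sum(pos) / len(pos), sum(neg) / len(neg)


# Demo with assumed numbers: 40k examples, 25% of them true forget examples.
truth = [i % 4 == 0 for i in range(40_000)]
flags = noisy_forget_labels(truth, tpr=0.8, fpr=0.1)
tpr_hat, fpr_hat = measured_rates(truth, flags)
```

Sweeping `tpr` and `fpr` over a grid and feeding `flags` to the routing gate in place of the clean label would give the continuous dial that (D) asks for; fixing `seed` keeps each grid point reproducible.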