docs: SGTM vs ours -- diagnostics, tricks, and proposed improvements (B = route within delta_S along SVD axes)

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
wassname
2026-06-01 01:39:46 +00:00
parent 23512ed07c
commit 090f29671d
+247
View File
@@ -0,0 +1,247 @@
# SGTM (selective-gradient-masking) vs our route2 — code comparison
Cloned `github.com/safety-research/selective-gradient-masking` to
`docs/vendor/selective-gradient-masking`. This compares their adapter + train
loop to ours (`antipasto.py`, `train.py`). Bottom line: their selectivity is
**data-label driven over fixed reserved dimensions**, never a per-token learned
direction. That's exactly the axis where our route2-act failed.
## Their mechanism (src/sgtm/model/)
- **No adapter.** `SplitLinearOut` splits an existing layer's output features
into `weight_retain` `[retain_dim, d_in]` and `weight_forget`
`[d_out-retain_dim, d_in]`. The "forget" capacity is part of the model's own
width, a fixed dimension partition. No added LoRA, no SVD, no extracted
direction.
- **Routing is by example label, on those fixed dims.** A batch is labeled
`forget` / `retain` / `adjacent` (trainer.py:42,254). Forget-labeled examples
go through `forward_forget`:
- activation masking: `hidden[:, :, :retain_mlp_dim] = 0` (zero the retain
dims so the forget concept can only live in the forget dims).
- gradient routing (Cloud 2024): `hidden[:, :, :retain_mlp_dim] =
hidden[..., :retain_mlp_dim].detach()` — forward value UNCHANGED, forget
gradient flows only into the forget dims.
- **Deploy = `ablate()`**: zero `weight_forget` (or, with `trainable=True`,
reinit it to the retain-dims' std so the model stays finetunable after the
forget subspace is removed).
- **Weak detector is a first-class knob.** trainer.py:265-287 mislabels the
forget/retain split by explicit `forget_tpr/forget_fpr` or
`precision/recall`, simulating an imperfect detector. This is their headline
axis (unlearning quality vs detector quality) and it directly matches our
no-cheat weak-detector framing.
## Side by side
| | SGTM | ours (route2) |
|---|---|---|
| forget capacity | reserved dims of the real layer | added LoRA `A_q/B_q` (33M) |
| where it lives | fixed dimension partition | learned, in SVD-of-W basis |
| routing signal | **example label** (forget data), noisy via TPR/FPR | route2-act: **per-token** `cos(a,v_act)>0`; route2-grad: per-rollout grad sign |
| forward change | act-mask zeros dims / grad-route detaches (no value change) | act-mask detaches ~half of tokens; quar always summed |
| can the masked term blow up? | not in practice (see below) | yes in route2-act (quar runaway; gn 0.23→1.5) |
| deploy | zero or reinit forget weights | delete quarantine |
| fine-tune LR | 5e-5 (from-scratch pretrain is 5e-3) | 2.5e-3 on delta_S, 2.5e-4 on quar |
| norm | standard GPT layernorm, no batchnorm | (Qwen RMSNorm, frozen) |
## Why theirs doesn't blow up (corrected — not "bounded width")
My first take ("reserved width is bounded") was too glib. `weight_forget` is a
free parameter; nothing caps its magnitude either. The real reasons route2-act
runs away and SGTM doesn't:
1. **Clean label routing => the forget path only ever sees forget gradients.**
SGTM routes whole forget-labeled examples through `forward_forget`. The
forget dims get gradient only from genuine forget data. route2-act routes
~half of ALL tokens (the `cos>0` coin-flip) into the quarantine, so the
quarantine absorbs reward-driven gradient on everything and there's a
feedback loop (more quar -> worse rollouts -> more reward-chasing grad).
2. **Supervised CE, not RL.** Their loss is bounded next-token cross-entropy on
fixed data. Ours is GRPO: the policy that generates the data is the thing
being perturbed, so a bad quar update degrades the next batch of rollouts.
3. **Their forget output participates in the normal forward** (read by the same
`c_proj`, normalized by the same downstream layernorms the model co-adapted
to). Our `quar` is an independent additive rank-16 path nothing co-adapts to.
So the protection is mostly (1)+(2): label routing avoids over-routing, and
supervised training has no policy-feedback loop. It is NOT mainly adapter-vs-
reserved-dims.
### Can't we adopt (1) and (2)? Mostly yes.
(1) **Clean routing -- yes, we already can.** Routing per-ROLLOUT instead of
per-token gives the same cleanliness (a rollout is all-in or all-out, no
coin-flip). route2-grad IS this, keeping the vector premise. We can go further
and gate eligibility on the weak detector flag (allowed by no-cheat) so the
quarantine only ever sees flagged-hack rollouts -- the direct analog of SGTM's
forget label. The vector still decides the *subspace*; the flag decides *which
rollouts*. route2-act's per-token granularity was the only part that had to die.
(2) **No-feedback-loop -- we can't drop RL (that's the research question), but
we can remove the actual harmful path.** The loop is "quar corrupts the
data-generating policy -> worse rollouts -> worse gradient." Two fixes, neither
needs supervised training:
- **Ablate the quarantine during generation** (generate from the clean
deployed model; route gradient into quar only in the training backward). The
quarantine becomes a pure gradient sink that never touches the rollouts.
- **Make routing forward-value-preserving like SGTM's gradient-routing
variant** -- pure detach, no added growing term. See improvement (B) below:
route *within* delta_S along the SVD axes instead of adding a 33M LoRA.
So the honest read: theirs is safer by construction, but the safety is
adoptable. We chose the riskier per-token added-LoRA design; the fixes are known.
## Key tricks worth stealing
1. **Route by rollout label, not per-token direction.** SGTM never fires a
learned direction per token. Our route2-act's `cos(a,v_act)>0` coin-flip
(act_cos~0 → fires on ~half of all tokens, job 46) is an invention with no
analog in their working method. route2-grad (per-rollout) and route
(flag-selected rollouts) are the variants that match SGTM's philosophy.
2. **Reserve capacity instead of extracting a direction.** Their forget concept
localizes to reserved dims *by construction* (mask the retain dims on forget
data). No staleness, no v_hack refresh, no over-routing. More robust, less
ambitious than our direction-extraction.
3. **TPR/FPR detector-quality sweep** as the generalisation experiment. Cleaner
than our "select pairs from 2 of 4 modes" — we could parametrize detector
recall directly. (task #161 leakage metric is the related SGTM eq.1.)
4. **ablate(trainable=True): reinit forget weights to matched std** so the model
is finetunable post-ablation (not a dead hole). Relevant if we ever retrain
after deploy.
5. **No extra summed adapter** → no coherence-wrecking additive term. Our quar
being always-summed AND large AND freshly-init is a route2-specific failure
mode SGTM structurally avoids.
## What this says about our results
- route2-act's failure is **not** a tuning problem; it diverges from SGTM at the
routing signal (per-token learned direction vs per-example label). The lr fix
only slowed the same disease.
- route2-grad (job 47) and route are the SGTM-shaped variants. If 47 is coherent
and route's file_marker generalisation holds, that's the SGTM-consistent path.
- We are deliberately *more ambitious* than SGTM (extract the hack direction
rather than reserve dims and route by label). That buys "no need to reserve
capacity or know the dims" but costs robustness (extraction can be wrong/stale,
and a per-token direction over-routes). Worth stating plainly in the writeup.
## Pseudocode: theirs vs ours
The shared idea (gradient routing, Cloud 2024): push a concept's gradient into a
throwaway subspace so the deployed model never learns it. They differ in HOW the
"this is the concept" decision is made and WHAT subspace it goes to.
THEIRS (SGTM gradient-routing variant), supervised, route by label:
# setup: in each MLP, reserve dims [retain_dim:] as the "forget" subspace
for x, label in data: # label in {retain, forget}; noisy via TPR/FPR
h = c_fc(x) # [b, s, d_mlp]
if label == forget:
h[:, :, :retain_dim] = h[:, :, :retain_dim].detach() # forget grad -> forget dims only
y = c_proj(act(h)) # forward VALUE identical either way (detach != zero)
loss = cross_entropy(y, next_token)
loss.backward(); opt.step()
# deploy: zero weight_forget[retain_dim:] (the reserved forget dims)
OURS (route2, the novel bit = route by an extracted VECTOR, not a label):
# setup: v_hack = top SVD dirs of (grad logp(clean) - grad logp(hack)) over pairs
# per module: frozen U/Vh; trainable delta_S (kept); quarantine A_q/B_q (forget)
for prompts:
rollouts = student.generate(prompts) # GRPO, on-policy (no labels)
adv = grpo_advantage(reward(rollouts)) # reward from the weak grader
# forward: y + kept + quar ; kept = delta_S path, quar = always-summed quarantine
# route2-act : per-TOKEN, route tokens with cos(activation, v_act) > 0
# -> kept.detach() there (FAILED: cos~0 => fires on ~half of all tokens)
# route2-grad: per-ROLLOUT, flag rollouts with cos(grad_b, v_grad) > 0
# -> subtract their share from delta_S.grad post-backward
loss = grpo_loss(rollouts, adv); loss.backward()
route_hack_component_into_quar() # by VECTOR, not label
opt.step()
# deploy: zero A_q/B_q (quar -> 0)
The one-line diff: theirs decides "forget" by a (noisy) data label and sends it
to reserved dims; ours decides "hack" by cos against an extracted direction and
sends it to a learned quarantine. The vector premise is the contribution. The
per-token granularity (route2-act) is the part that broke; route2-grad keeps the
vector and routes per-rollout instead.
## Their diagnostics (what they log/plot)
- **retain vs forget val_loss + val_ppl, separately** (trainer.py:174). The core
tradeoff: forget ppl should rise, retain ppl stay flat. Our analog: `lp_s` /
`lp_t` and `cin_s` / `cin_t` (want cin_t > cin_s).
- **per-category breakdown** (`evaluate_by_category`) -- per-topic forget/retain,
like our per-mode `hk_<mode>` columns.
- **train/grad_norm** (global, trainer.py:609) + a dedicated `grad_norms.ipynb`
notebook -- they track grad norms per component. This is exactly the instinct
behind our new `qE` column (energy into the quarantine vs the kept knob).
- **retain_forget_tradeoff.ipynb** -- the Pareto curve (forget suppression vs
retain damage), swept over detector quality / forget fraction.
- **undiscovered_rate.ipynb** -- a relearn/elicitation probe: after unlearning,
try to re-elicit the forget capability; "undiscovered" = it stays hidden.
We have no relearn probe yet (improvement F below).
## Their tricks
- **TPR/FPR + precision/recall noisy-detector knobs** (trainer.py:265-287):
corrupt the forget/retain labels to a target detector quality and sweep it.
- **ablate(trainable=True)**: reinit forget weights to retain-std (finetunable
after deploy, not a dead hole). TODO in our `ablate_quarantine`.
- **forget-adjacent dataset** + `upsample_forget/retain/adjacent` + per-class
loss weights (`logit_alpha`/`logit_beta`): a third "near-forget but keep" data
class to measure/limit collateral damage. We have no near-hack-but-legit eval.
- **logit calibration** (lm_head bias, separate lr) -- not relevant to us.
## Proposed improvements for OUR setting (RH env + SVD + vector)
Ordered by expected value. (A) is in flight; (B) is the structural one.
- **(A) Route per-rollout, not per-token.** route2-grad. Keep the vector, drop
the coin-flip. Gate eligibility on the weak-detector flag so the quarantine
only sees flagged-hack rollouts (vector picks the subspace, flag picks the
rollouts). Status: job 47 running.
- **(B) Route WITHIN delta_S along the SVD v_hack axes -- no separate LoRA.**
The quarantine doesn't need 33M new params. Reserve the v_hack subspace of the
existing delta_S as the "forget" axes. For a flagged-hack rollout, detach the
complement so its gradient flows ONLY into the v_hack-subspace component of
delta_S; at deploy, zero that component. Properties:
- forward-value-preserving during training (pure detach, like SGTM) -> no
growing additive term -> kills the route2-act blow-up at the root.
- uses the SVD basis + vector premise directly (the v_hack axes ARE the
reserved subspace).
- differs from erase: erase throws the hack gradient away every step; this
lets the model USE the hack during training (rollouts stay on-distribution,
still get reward) but DELETES it at deploy. That's the Cloud et al. routing
insight -- routing is more stable than erasing because the capability stays
available while training.
Open question for review: is "route into a deleted v_hack component of delta_S"
meaningfully different from erase in the GRPO single-step (ratio=1) regime, or
does it collapse to the same update? (Suspect different: erase removes the
component from the *kept* axes' grad; this preserves behavior via the forget
axes that are only deleted at deploy.)
- **(C) Break the generation feedback loop.** If we keep any added quarantine,
ablate it during `generate()` so the data-generating policy is the clean
deployed model; route gradient into the quarantine only in the training
backward. Removes the "quar corrupts rollouts -> worse reward -> worse grad"
runaway. (B) makes this partly moot.
- **(D) Noisy-detector sweep (TPR/FPR)** instead of the 2-of-4-modes split, as
the generalisation/no-cheat experiment. A continuous dial on detector recall.
- **(E) Near-hack-but-legit eval** (their "forget-adjacent"): rollouts that look
hack-shaped but are legitimate, to measure over-suppression / collateral
damage on solve rate. Currently we only watch aggregate solve.
- **(F) Relearn probe** (their undiscovered-rate): after deploy-ablation, try to
re-elicit the hack with a few GRPO steps. If it returns fast, the routing only
hid it shallowly rather than preventing the learning.
## On the hyperparameters
The Appendix-I table (LR 5e-3, AdamW, wd 0.1, betas 0.9/0.95) is **from-scratch
TinyStories pretraining of 8M-64M models for 33k steps** — not transferable to
adapter RL on a 4B model. Their *fine-tuning* scripts use `--lr 5e-5`. Adapter
LRs run 10-100x above full-FT, so our delta_S at 2.5e-3 is in range; the 33M
quarantine at 2.5e-4 is still ~5x above their full-FT 5e-5 and is freshly init,
which is consistent with (though not the main cause of) route2-act's drift.