mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 18:59:35 +08:00
docs: SGTM vs ours -- diagnostics, tricks, and proposed improvements (B = route within delta_S along SVD axes)
Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
@@ -0,0 +1,247 @@
# SGTM (selective-gradient-masking) vs our route2 -- code comparison

Cloned `github.com/safety-research/selective-gradient-masking` to
`docs/vendor/selective-gradient-masking`. This compares their adapter + train
loop to ours (`antipasto.py`, `train.py`). Bottom line: their selectivity is
**data-label driven over fixed reserved dimensions**, never a per-token learned
direction. That's exactly the axis where our route2-act failed.

## Their mechanism (src/sgtm/model/)

- **No adapter.** `SplitLinearOut` splits an existing layer's output features
  into `weight_retain` `[retain_dim, d_in]` and `weight_forget`
  `[d_out-retain_dim, d_in]`. The "forget" capacity is part of the model's own
  width, a fixed dimension partition. No added LoRA, no SVD, no extracted
  direction.
- **Routing is by example label, on those fixed dims.** A batch is labeled
  `forget` / `retain` / `adjacent` (trainer.py:42,254). Forget-labeled examples
  go through `forward_forget`:
  - activation masking: `hidden[:, :, :retain_mlp_dim] = 0` (zero the retain
    dims so the forget concept can only live in the forget dims).
  - gradient routing (Cloud 2024):
    `hidden[:, :, :retain_mlp_dim] = hidden[..., :retain_mlp_dim].detach()`
    -- forward value UNCHANGED, forget gradient flows only into the forget dims.
- **Deploy = `ablate()`**: zero `weight_forget` (or, with `trainable=True`,
  reinit it to the retain-dims' std so the model stays finetunable after the
  forget subspace is removed).
- **Weak detector is a first-class knob.** trainer.py:265-287 mislabels the
  forget/retain split by explicit `forget_tpr/forget_fpr` or
  `precision/recall`, simulating an imperfect detector. This is their headline
  axis (unlearning quality vs detector quality) and it directly matches our
  no-cheat weak-detector framing.

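To make the activation-masking vs gradient-routing distinction concrete, here is a minimal hand-differentiated sketch. It is not the vendored SGTM code: `forward_backward`, the linear readout `p`, and the squared-error loss are hypothetical stand-ins, and it only tracks the gradient of the layer that produces `h` (the part upstream of the detach).

```python
def forward_backward(W, p, x, target, retain_dim, mode):
    """Toy layer h = W x, readout y = p . h, loss = 0.5 * (y - target)^2.

    mode is one of:
      'plain'  - ordinary training step
      'detach' - gradient routing: retain dims keep their value, pass no grad
      'zero'   - activation masking: retain dims are zeroed in the forward
    Returns (y, grad_W), with grad_W the gradient of the loss w.r.t. W.
    """
    h = [sum(w * xi for w, xi in zip(row, x)) for row in W]
    if mode == "zero":
        h = [0.0 if j < retain_dim else hj for j, hj in enumerate(h)]
    y = sum(pj * hj for pj, hj in zip(p, h))
    err = y - target
    grad_W = []
    for j in range(len(W)):
        # Both variants stop the loss gradient from reaching the retain rows.
        blocked = mode in ("detach", "zero") and j < retain_dim
        g_h = 0.0 if blocked else err * p[j]
        grad_W.append([g_h * xi for xi in x])
    return y, grad_W


W = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]   # rows 0-1 retain, row 2 forget
p, x = [1.0, 1.0, 1.0], [1.0, 1.0]
y_plain, g_plain = forward_backward(W, p, x, 0.0, 2, "plain")
y_detach, g_detach = forward_backward(W, p, x, 0.0, 2, "detach")
y_zero, g_zero = forward_backward(W, p, x, 0.0, 2, "zero")
```

In this toy, `detach` leaves `y` identical to the plain forward while the retain rows get zero gradient; `zero` also blocks the retain rows but changes `y`, which is the difference the bullets above describe.
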
## Side by side

| | SGTM | ours (route2) |
|---|---|---|
| forget capacity | reserved dims of the real layer | added LoRA `A_q/B_q` (33M) |
| where it lives | fixed dimension partition | learned, in SVD-of-W basis |
| routing signal | **example label** (forget data), noisy via TPR/FPR | route2-act: **per-token** `cos(a,v_act)>0`; route2-grad: per-rollout grad sign |
| forward change | act-mask zeros dims / grad-route detaches (no value change) | act-mask detaches ~half of tokens; quar always summed |
| can the masked term blow up? | not in practice (see below) | yes in route2-act (quar runaway; gn 0.23→1.5) |
| deploy | zero or reinit forget weights | delete quarantine |
| fine-tune LR | 5e-5 (from-scratch pretrain is 5e-3) | 2.5e-3 on delta_S, 2.5e-4 on quar |
| norm | standard GPT layernorm, no batchnorm | (Qwen RMSNorm, frozen) |

## Why theirs doesn't blow up (corrected -- not "bounded width")

My first take ("reserved width is bounded") was too glib. `weight_forget` is a
free parameter; nothing caps its magnitude either. The real reasons route2-act
runs away and SGTM doesn't:

1. **Clean label routing => the forget path only ever sees forget gradients.**
   SGTM routes whole forget-labeled examples through `forward_forget`. The
   forget dims get gradient only from genuine forget data. route2-act routes
   ~half of ALL tokens (the `cos>0` coin-flip) into the quarantine, so the
   quarantine absorbs reward-driven gradient on everything and there's a
   feedback loop (more quar -> worse rollouts -> more reward-chasing grad).
2. **Supervised CE, not RL.** Their loss is bounded next-token cross-entropy on
   fixed data. Ours is GRPO: the policy that generates the data is the thing
   being perturbed, so a bad quar update degrades the next batch of rollouts.
3. **Their forget output participates in the normal forward** (read by the same
   `c_proj`, normalized by the same downstream layernorms the model co-adapted
   to). Our `quar` is an independent additive rank-16 path nothing co-adapts to.

So the protection is mostly (1)+(2): label routing avoids over-routing, and
supervised training has no policy-feedback loop. It is NOT mainly
adapter-vs-reserved-dims.

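The "coin-flip" in reason (1) is easy to check numerically. A minimal sketch, assuming activations are isotropic Gaussian and uncorrelated with the extracted direction (the act_cos~0 regime); `fraction_routed` and its defaults are hypothetical, not taken from `train.py`:

```python
import random


def fraction_routed(n_tokens=4000, dim=32, seed=0):
    """Fraction of tokens whose activation has positive cosine with v_act.

    The sign of the cosine equals the sign of the dot product, so the norms
    are skipped.
    """
    rng = random.Random(seed)
    v_act = [rng.gauss(0.0, 1.0) for _ in range(dim)]
    fired = 0
    for _ in range(n_tokens):
        a = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        dot = sum(ai * vi for ai, vi in zip(a, v_act))
        if dot > 0.0:
            fired += 1
    return fired / n_tokens


frac = fraction_routed()
```

Under that assumption the gate fires on roughly half of all tokens whatever `v_act` is, so it carries essentially no information about which tokens are hack-related.
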
### Can't we adopt (1) and (2)? Mostly yes.

(1) **Clean routing -- yes, we already can.** Routing per-ROLLOUT instead of
per-token gives the same cleanliness (a rollout is all-in or all-out, no
coin-flip). route2-grad IS this, keeping the vector premise. We can go further
and gate eligibility on the weak detector flag (allowed by no-cheat) so the
quarantine only ever sees flagged-hack rollouts -- the direct analog of SGTM's
forget label. The vector still decides the *subspace*; the flag decides *which
rollouts*. route2-act's per-token granularity was the only part that had to die.

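A sketch of that combined gate, assuming one flattened gradient vector per rollout; `cosine` and `rollouts_to_route` are hypothetical helpers, not functions from `train.py`:

```python
import math


def cosine(u, v):
    """Cosine similarity of two equal-length vectors (0.0 if either is zero)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0.0 or norm_v == 0.0:
        return 0.0
    return dot / (norm_u * norm_v)


def rollouts_to_route(rollout_grads, v_grad, detector_flags):
    """Indices of rollouts routed whole into the quarantine.

    A rollout is routed only if the weak detector flagged it AND its gradient
    points along the extracted hack direction.
    """
    return [
        b
        for b, (g, flagged) in enumerate(zip(rollout_grads, detector_flags))
        if flagged and cosine(g, v_grad) > 0.0
    ]


grads = [[1.0, 0.0], [-1.0, 0.0], [1.0, 1.0], [0.5, 0.2]]
routed = rollouts_to_route(grads, [1.0, 0.0], [True, True, False, True])
```

Here rollout 1 is flagged but points away from the direction, and rollout 2 aligns but is unflagged, so only rollouts 0 and 3 are routed.
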
(2) **No-feedback-loop -- we can't drop RL (that's the research question), but
we can remove the actual harmful path.** The loop is "quar corrupts the
data-generating policy -> worse rollouts -> worse gradient." Two fixes, neither
needs supervised training:

- **Ablate the quarantine during generation** (generate from the clean
  deployed model; route gradient into quar only in the training backward). The
  quarantine becomes a pure gradient sink that never touches the rollouts.
- **Make routing forward-value-preserving like SGTM's gradient-routing
  variant** -- pure detach, no added growing term. See improvement (B) below:
  route *within* delta_S along the SVD axes instead of adding a 33M LoRA.

So the honest read: theirs is safer by construction, but the safety is
adoptable. We chose the riskier per-token added-LoRA design; the fixes are known.

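The first fix (ablate the quarantine during generation) could look roughly like this. `ToyModule`, `quar_scale` and `quarantine_ablated` are hypothetical names for illustration; the real adapter in `antipasto.py` may expose the quarantine differently.

```python
from contextlib import contextmanager


class ToyModule:
    """Stand-in for one adapted layer: out = base + kept + quar_scale * quar."""

    def __init__(self, base, kept, quar):
        self.base, self.kept, self.quar = base, kept, quar
        self.quar_scale = 1.0

    def forward(self):
        return self.base + self.kept + self.quar_scale * self.quar


@contextmanager
def quarantine_ablated(modules):
    """Temporarily remove the quarantine term, restoring it on exit."""
    saved = [m.quar_scale for m in modules]
    for m in modules:
        m.quar_scale = 0.0
    try:
        yield
    finally:
        for m, scale in zip(modules, saved):
            m.quar_scale = scale


layer = ToyModule(base=1.0, kept=0.5, quar=2.0)
with quarantine_ablated([layer]):
    out_generate = layer.forward()   # what generate() would see
out_train = layer.forward()          # training forward, quarantine restored
```

One thing the sketch glosses over: if the training forward still sums the quarantine, the rollouts come from a slightly different policy than the one being differentiated, so the GRPO ratio is no longer exactly 1.
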
## Key tricks worth stealing

1. **Route by rollout label, not per-token direction.** SGTM never fires a
   learned direction per token. Our route2-act's `cos(a,v_act)>0` coin-flip
   (act_cos~0 → fires on ~half of all tokens, job 46) is an invention with no
   analog in their working method. route2-grad (per-rollout) and route
   (flag-selected rollouts) are the variants that match SGTM's philosophy.
2. **Reserve capacity instead of extracting a direction.** Their forget concept
   localizes to reserved dims *by construction* (mask the retain dims on forget
   data). No staleness, no v_hack refresh, no over-routing. More robust, less
   ambitious than our direction-extraction.
3. **TPR/FPR detector-quality sweep** as the generalisation experiment. Cleaner
   than our "select pairs from 2 of 4 modes" -- we could parametrize detector
   recall directly. (task #161 leakage metric is the related SGTM eq.1.)
4. **ablate(trainable=True): reinit forget weights to matched std** so the model
   is finetunable post-ablation (not a dead hole). Relevant if we ever retrain
   after deploy.
5. **No extra summed adapter** → no coherence-wrecking additive term. Our quar
   being always-summed AND large AND freshly-init is a route2-specific failure
   mode SGTM structurally avoids.

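Trick 4 as a minimal sketch. This `ablate` is a hypothetical re-implementation on nested lists, not the vendored method; the zero-mean Gaussian reinit is an assumption, since the notes above only say the std is matched to the retain dims.

```python
import random
import statistics


def ablate(weight_retain, weight_forget, trainable=False, seed=0):
    """Return replacement forget weights with the same shape.

    trainable=False: all zeros (forget capacity becomes a dead hole).
    trainable=True:  zero-mean Gaussian noise whose std matches the retain
                     weights, so the rows can still be fine-tuned afterwards.
    """
    if not trainable:
        return [[0.0] * len(row) for row in weight_forget]
    rng = random.Random(seed)
    std = statistics.pstdev([w for row in weight_retain for w in row])
    return [[rng.gauss(0.0, std) for _ in row] for row in weight_forget]


retain = [[0.1 * ((i + j) % 5 - 2) for j in range(50)] for i in range(40)]
forget = [[1.0] * 50 for _ in range(20)]
dead = ablate(retain, forget)
fresh = ablate(retain, forget, trainable=True)
```

Either way the learned forget weights are discarded; the `trainable=True` branch only changes what replaces them.
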
## What this says about our results

- route2-act's failure is **not** a tuning problem; it diverges from SGTM at the
  routing signal (per-token learned direction vs per-example label). The lr fix
  only slowed the same disease.
- route2-grad (job 47) and route are the SGTM-shaped variants. If 47 is coherent
  and route's file_marker generalisation holds, that's the SGTM-consistent path.
- We are deliberately *more ambitious* than SGTM (extract the hack direction
  rather than reserve dims and route by label). That buys "no need to reserve
  capacity or know the dims" but costs robustness (extraction can be wrong/stale,
  and a per-token direction over-routes). Worth stating plainly in the writeup.

## Pseudocode: theirs vs ours

The shared idea (gradient routing, Cloud 2024): push a concept's gradient into a
throwaway subspace so the deployed model never learns it. They differ in HOW the
"this is the concept" decision is made and WHAT subspace it goes to.

THEIRS (SGTM gradient-routing variant), supervised, route by label:

    # setup: in each MLP, reserve dims [retain_dim:] as the "forget" subspace
    for x, label in data:              # label in {retain, forget}; noisy via TPR/FPR
        h = c_fc(x)                    # [b, s, d_mlp]
        if label == forget:
            # forget grad -> forget dims only
            h[:, :, :retain_dim] = h[:, :, :retain_dim].detach()
        y = c_proj(act(h))             # forward VALUE identical either way (detach != zero)
        loss = cross_entropy(y, next_token)
        loss.backward(); opt.step()
    # deploy: zero weight_forget (the c_fc rows that produce dims [retain_dim:])

OURS (route2, the novel bit = route by an extracted VECTOR, not a label):

    # setup: v_hack = top SVD dirs of (grad logp(clean) - grad logp(hack)) over pairs
    # per module: frozen U/Vh; trainable delta_S (kept); quarantine A_q/B_q (forget)
    for prompts in batches:
        rollouts = student.generate(prompts)     # GRPO, on-policy (no labels)
        adv = grpo_advantage(reward(rollouts))   # reward from the weak grader
        # forward: y + kept + quar ; kept = delta_S path, quar = always-summed quarantine
        # route2-act : per-TOKEN, route tokens with cos(activation, v_act) > 0
        #              -> kept.detach() there (FAILED: cos~0 => fires on ~half of all tokens)
        # route2-grad: per-ROLLOUT, flag rollouts with cos(grad_b, v_grad) > 0
        #              -> subtract their share from delta_S.grad post-backward
        loss = grpo_loss(rollouts, adv); loss.backward()
        route_hack_component_into_quar()         # by VECTOR, not label
        opt.step()
    # deploy: zero A_q/B_q (quar -> 0)

The one-line diff: theirs decides "forget" by a (noisy) data label and sends it
to reserved dims; ours decides "hack" by cos against an extracted direction and
sends it to a learned quarantine. The vector premise is the contribution. The
per-token granularity (route2-act) is the part that broke; route2-grad keeps the
vector and routes per-rollout instead.

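One possible reading of "subtract their share from delta_S.grad post-backward", as a sketch. The interpretation of "share" as each routed rollout's contribution to the batch-mean gradient is an assumption, and `split_grad` is a hypothetical helper, not the function in `train.py`.

```python
def split_grad(rollout_grads, routed):
    """Split the batch-mean gradient into a kept part and a quarantined part.

    rollout_grads: one flattened gradient per rollout.
    routed:        set of rollout indices sent to the quarantine.
    Returns (kept, share) with kept + share == batch-mean gradient.
    """
    n = len(rollout_grads)
    dim = len(rollout_grads[0])
    total = [sum(g[i] for g in rollout_grads) / n for i in range(dim)]
    share = [sum(rollout_grads[b][i] for b in routed) / n for i in range(dim)]
    kept = [t - s for t, s in zip(total, share)]
    return kept, share


grads = [[2.0, 0.0], [0.0, 4.0], [4.0, 4.0], [2.0, 0.0]]
kept, share = split_grad(grads, routed={0, 3})
```

With rollouts 0 and 3 routed, `kept` is what delta_S would step on and `share` is what the quarantine would absorb; nothing is discarded, only moved.
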
## Their diagnostics (what they log/plot)

- **retain vs forget val_loss + val_ppl, separately** (trainer.py:174). The core
  tradeoff: forget ppl should rise, retain ppl stay flat. Our analog: `lp_s` /
  `lp_t` and `cin_s` / `cin_t` (want cin_t > cin_s).
- **per-category breakdown** (`evaluate_by_category`) -- per-topic forget/retain,
  like our per-mode `hk_<mode>` columns.
- **train/grad_norm** (global, trainer.py:609) + a dedicated `grad_norms.ipynb`
  notebook -- they track grad norms per component. This is exactly the instinct
  behind our new `qE` column (energy into the quarantine vs the kept knob).
- **retain_forget_tradeoff.ipynb** -- the Pareto curve (forget suppression vs
  retain damage), swept over detector quality / forget fraction.
- **undiscovered_rate.ipynb** -- a relearn/elicitation probe: after unlearning,
  try to re-elicit the forget capability; "undiscovered" = it stays hidden.
  We have no relearn probe yet (improvement F below).

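For reference, the split-wise loss/perplexity diagnostic reduces to a few lines. A sketch assuming per-token negative log-likelihoods are already collected per split; `split_perplexity` is a hypothetical name, not the function at trainer.py:174.

```python
import math


def split_perplexity(token_nlls_by_split):
    """Map split name -> (mean NLL, perplexity), with ppl = exp(mean NLL)."""
    out = {}
    for name, nlls in token_nlls_by_split.items():
        mean_nll = sum(nlls) / len(nlls)
        out[name] = (mean_nll, math.exp(mean_nll))
    return out


stats = split_perplexity({
    "retain": [math.log(2.0)] * 4,   # toy values: retain ppl of 2
    "forget": [math.log(8.0)] * 4,   # toy values: forget ppl of 8
})
```

The point of reporting the two splits separately is that a single aggregate loss would hide exactly the tradeoff being measured.
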
## Their tricks

- **TPR/FPR + precision/recall noisy-detector knobs** (trainer.py:265-287):
  corrupt the forget/retain labels to a target detector quality and sweep it.
- **ablate(trainable=True)**: reinit forget weights to retain-std (finetunable
  after deploy, not a dead hole). TODO in our `ablate_quarantine`.
- **forget-adjacent dataset** + `upsample_forget/retain/adjacent` + per-class
  loss weights (`logit_alpha`/`logit_beta`): a third "near-forget but keep" data
  class to measure/limit collateral damage. We have no near-hack-but-legit eval.
- **logit calibration** (lm_head bias, separate lr) -- not relevant to us.

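A sketch of the noisy-detector knob, assuming independent per-example corruption. `noisy_labels` and `fpr_from_precision_recall` are hypothetical helpers, not the vendored trainer's code; the second one converts a precision/recall target into the equivalent FPR at a given forget base rate.

```python
import random


def noisy_labels(is_forget, tpr, fpr, seed=0):
    """Simulate a weak detector: flag true forget examples with prob. tpr,
    and retain examples with prob. fpr."""
    rng = random.Random(seed)
    return [rng.random() < (tpr if truth else fpr) for truth in is_forget]


def fpr_from_precision_recall(precision, recall, forget_fraction):
    """FPR implied by a precision/recall target.

    With pi = forget_fraction: TP = recall * pi (as a fraction of all data),
    precision = TP / (TP + FP)  =>  FP = TP * (1 - precision) / precision,
    and FPR = FP / (1 - pi).
    """
    tp = recall * forget_fraction
    fp = tp * (1.0 - precision) / precision
    return fp / (1.0 - forget_fraction)


truth = [True] * 5 + [False] * 5
perfect = noisy_labels(truth, tpr=1.0, fpr=0.0)
fpr = fpr_from_precision_recall(precision=0.8, recall=0.8, forget_fraction=0.5)
```

For our setting the same two numbers could parametrize the weak hack detector directly, which is what makes the sweep a continuous dial.
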
## Proposed improvements for OUR setting (RH env + SVD + vector)

Ordered by expected value. (A) is in flight; (B) is the structural one.

- **(A) Route per-rollout, not per-token.** route2-grad. Keep the vector, drop
  the coin-flip. Gate eligibility on the weak-detector flag so the quarantine
  only sees flagged-hack rollouts (vector picks the subspace, flag picks the
  rollouts). Status: job 47 running.

- **(B) Route WITHIN delta_S along the SVD v_hack axes -- no separate LoRA.**
  The quarantine doesn't need 33M new params. Reserve the v_hack subspace of the
  existing delta_S as the "forget" axes. For a flagged-hack rollout, detach the
  complement so its gradient flows ONLY into the v_hack-subspace component of
  delta_S; at deploy, zero that component. Properties:
  - forward-value-preserving during training (pure detach, like SGTM) -> no
    growing additive term -> kills the route2-act blow-up at the root.
  - uses the SVD basis + vector premise directly (the v_hack axes ARE the
    reserved subspace).
  - differs from erase: erase throws the hack gradient away every step; this
    lets the model USE the hack during training (rollouts stay on-distribution,
    still get reward) but DELETES it at deploy. That's the Cloud et al. routing
    insight -- routing is more stable than erasing because the capability stays
    available while training.

  Open question for review: is "route into a deleted v_hack component of delta_S"
  meaningfully different from erase in the GRPO single-step (ratio=1) regime, or
  does it collapse to the same update? (Suspect different: erase removes the
  component from the *kept* axes' grad; this preserves behavior via the forget
  axes that are only deleted at deploy.)

- **(C) Break the generation feedback loop.** If we keep any added quarantine,
  ablate it during `generate()` so the data-generating policy is the clean
  deployed model; route gradient into the quarantine only in the training
  backward. Removes the "quar corrupts rollouts -> worse reward -> worse grad"
  runaway. (B) makes this partly moot.

- **(D) Noisy-detector sweep (TPR/FPR)** instead of the 2-of-4-modes split, as
  the generalisation/no-cheat experiment. A continuous dial on detector recall.

- **(E) Near-hack-but-legit eval** (their "forget-adjacent"): rollouts that look
  hack-shaped but are legitimate, to measure over-suppression / collateral
  damage on solve rate. Currently we only watch aggregate solve.

- **(F) Relearn probe** (their undiscovered-rate): after deploy-ablation, try to
  re-elicit the hack with a few GRPO steps. If it returns fast, the routing only
  hid it shallowly rather than preventing the learning.

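A sketch of improvement (B) on a single module, under two loudly stated assumptions: the v_hack subspace is axis-aligned in the SVD basis (so it is just a set of delta_S indices), and unflagged rollouts are left free to update every axis. `route_within_delta_s` and `deploy` are hypothetical names.

```python
def route_within_delta_s(rollout_grads, flagged, hack_axes):
    """Batch-mean delta_S gradient with flagged rollouts confined to hack axes.

    rollout_grads: per-rollout gradient w.r.t. delta_S (one value per SVD axis).
    flagged:       per-rollout weak-detector flag.
    hack_axes:     set of delta_S indices reserved as the forget subspace.
    """
    n = len(rollout_grads)
    dim = len(rollout_grads[0])
    grad = [0.0] * dim
    for g, is_hack in zip(rollout_grads, flagged):
        for i in range(dim):
            if is_hack and i not in hack_axes:
                continue  # complement detached: flagged rollouts skip kept axes
            grad[i] += g[i] / n
    return grad


def deploy(delta_s, hack_axes):
    """Zero the reserved hack axes of delta_S; keep the rest untouched."""
    return [0.0 if i in hack_axes else s for i, s in enumerate(delta_s)]


grad = route_within_delta_s(
    rollout_grads=[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
    flagged=[True, False],
    hack_axes={2},
)
deployed = deploy([0.1, 0.2, 0.3], hack_axes={2})
```

Compared with erase, the flagged rollout's gradient on axis 2 is still applied during training and only removed by `deploy`; whether that distinction survives in the single-step GRPO regime is the open question raised in (B).
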
## On the hyperparameters

The Appendix-I table (LR 5e-3, AdamW, wd 0.1, betas 0.9/0.95) is **from-scratch
TinyStories pretraining of 8M-64M models for 33k steps** -- not transferable to
adapter RL on a 4B model. Their *fine-tuning* scripts use `--lr 5e-5`. Adapter
LRs run 10-100x above full-FT, so our delta_S at 2.5e-3 is in range; the 33M
quarantine at 2.5e-4 is still ~5x above their full-FT 5e-5 and is freshly init,
which is consistent with (though not the main cause of) route2-act's drift.

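The ratios quoted above, worked out explicitly (the variable names are just labels for the numbers in this section):

```python
FULL_FT_LR = 5e-5                      # SGTM fine-tuning scripts (--lr 5e-5)
OUR_LRS = {"delta_S": 2.5e-3, "quar": 2.5e-4}

# How many times above the full-fine-tune LR each of our parameter groups runs.
ratios = {name: lr / FULL_FT_LR for name, lr in OUR_LRS.items()}

# The 10-100x adapter-vs-full-FT rule of thumb from the paragraph above.
in_adapter_range = {name: 10.0 <= r <= 100.0 for name, r in ratios.items()}
```

So delta_S sits at about 50x (inside the 10-100x band) and the quarantine at about 5x (below it, but still above the full-FT value).
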