mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 16:15:35 +08:00
feat: route2 grad-mask (Arm A) + drop tau knob + pairset-derived v_hack path
Arm A (route2_mask=grad): per-rollout gate splice (identity at c=1) recovers the per-sample delta_S grad after backward (c.grad = delta_S * g_b); train.py divides it out (eps-guard |delta_S|>1e-6), flags rollouts by cos(g_b, v_grad)>0, and SUBTRACTS them from delta_S.grad. Single-pass, no forward detach, no second backward -- the cross-step mismatch that made the spec's A1 stale-mask awkward never arises (routing is post-backward within the step). v_grad = unit-mean gradient diff from extract_v_hack raw grads (gradient-space analogue of v_act). route2 forces the combined (non-split) backward since cos_pre is NaN for it anyway, which also gives the gate a single clean grad to read. Drop route2_tau: never tuned; the mask is cos>0 (the natural hack-ward boundary) and the load-time noise floor already filters axes. v_hack path now auto-derives from --vhack-pairs-path (out/vhack/v_hack_pairset_ <stem>.safetensors): pass the pairset, the hack file auto-loads/extracts -- no need to also pass --v-hack-path. run-substrate drops the redundant flag. smoke: smoke-route2 (act) and new smoke-route2-grad both pass (||B_q||=0.109, exit 0); erase shared-basis path unchanged (cout->0, fired~0.9). Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
@@ -128,36 +128,42 @@ m = (cos > tau).float() # weak, noisy mask -- fine (SGTM t
|
||||
y, quar = routed_forward(x, m)
|
||||
```
|
||||
|
||||
### Arm A (grad-vector) — needs the per-sample-gradient trick
|
||||
### Arm A (grad-vector) — single-pass gate SUBTRACTION (implemented)
|
||||
|
||||
The persona direction lives in `delta_S` (gradient) space, but `delta_S.grad` is
|
||||
summed over the batch, so the per-sample signal is not free. Recover it with a
|
||||
per-sample gate (the c_i trick), then route. Two implementable variants because the
|
||||
signal is only known *after* a backward:
|
||||
summed over the batch, so the per-rollout signal is not free. Recover it with a
|
||||
PER-ROLLOUT gate (init 1, identity in the forward) and route POST-backward by
|
||||
subtraction — no forward detach, no second pass:
|
||||
|
||||
```python
|
||||
# v_hack : [r] persona-contrast direction in delta_S/gradient space (existing
|
||||
# extract_vhack_grad.py: SVD of stacked persona-pair GRPO gradients).
|
||||
|
||||
# --- per-sample gradient via gates (one extra leaf, cheap) ---
|
||||
c = ones(b, s, r, requires_grad=True) # per-sample gate over SVD coords
|
||||
h = (delta_S * c) * (Vh @ x) # splice gate into the kept path
|
||||
# v_grad : [r] unit(mean(g_hack - g_clean)) per module, from extract_vhack_grad's
|
||||
# raw per-pair GRPO grads (gradient-space analogue of v_act). Oriented
|
||||
# hack-ward: training reinforces a hack with the same sign, so a
|
||||
# reinforced-hack rollout has cos(g_b, v_grad) > 0.
|
||||
c = ones(b, 1, r, requires_grad=True) # per-rollout gate (broadcast over tokens)
|
||||
kept = U @ ((a * c) * delta_S) # identity at c=1: forward value unchanged
|
||||
quar = B_q @ (A_q @ x) # always summed
|
||||
... usual loss; loss.backward()
|
||||
g_i = c.grad / delta_S[None, None, :] # [b s r] per-sample delta_S-space grad
|
||||
# (divide out the delta_S scaling; watch /0)
|
||||
cos = cosine(g_i, v_hack) # [b s]
|
||||
m = (cos > tau).float()
|
||||
|
||||
# A1 STALE-MASK (single pass): use step (t-1)'s m to detach-route step t's forward.
|
||||
# cheapest; m lags by one step. Good first cut.
|
||||
# A2 TWO-PASS (exact): probe backward -> g_i -> m -> re-run routed_forward + backward.
|
||||
# 2x backward cost; use if A1's lag hurts.
|
||||
# c.grad = delta_S * g_b (g_b = per-rollout delta_S grad). Divide it out, but only on
|
||||
# axes where delta_S has moved (|delta_S|>eps); near-zero axes keep the full grad.
|
||||
g_b = where(|delta_S|>eps, c.grad / delta_S, 0) # [b, r]
|
||||
flag = cosine(g_b, v_grad) > 0 # [b]
|
||||
delta_S.grad -= sum_{flagged b} g_b # exact identity on reliable axes
|
||||
# quarantine keeps every rollout's grad (always-on forward) -> absorption intact.
|
||||
```
|
||||
|
||||
Asymmetry to flag in the writeup: Arm B supports absorption in one pass (mask is a
|
||||
forward-time activation cosine); Arm A's mask is post-backward, so absorption needs
|
||||
A1's stale mask or A2's second pass. This is a real reason Arm B is the cleaner bet
|
||||
and matches the user's "act-space is the natural fit for per-sample" instinct.
|
||||
Why subtraction beats the earlier A1/A2 framings: routing is post-backward WITHIN the
|
||||
same step, so the cross-step sample mismatch that made the "stale mask" (A1) awkward
|
||||
never arises, and we avoid the 2x backward of the exact two-pass (A2). The only
|
||||
approximation is the eps-guard: on an axis where delta_S is still ~0 the per-rollout
|
||||
split is undefined (0/0), so we route nothing there until delta_S grows past eps —
|
||||
a ~1-step-per-axis lag, the same order as A1's lag, and harmless because a ~0 axis
|
||||
carries no learned hack yet. Upgrade to A2 (probe pass -> mask -> detach-route pass)
|
||||
only if that lag measurably hurts.
|
||||
|
||||
Asymmetry still worth flagging in the writeup: Arm B's mask is a free forward-time
|
||||
activation cosine (no gate, no division); Arm A pays a per-rollout gate + an
|
||||
eps-guarded division. Arm B remains the cleaner bet.
|
||||
|
||||
## Requirements
|
||||
|
||||
@@ -189,8 +195,9 @@ and matches the user's "act-space is the natural fit for per-sample" instinct.
|
||||
forward value unchanged vs non-detached.
|
||||
- [ ] T2 (R3): Arm B act-mask — `v_act` extraction from persona pairs + forward-time
|
||||
cosine. verify: R3 fire-ratio on known hack/clean. UAT: "mask fires on hacks".
|
||||
- [ ] T3 (R3): Arm A grad-mask — c_i gate per-sample grad + A1 stale-mask. verify:
|
||||
gate identity (sum_i g_i == delta_S.grad). UAT: "per-sample cos recovered".
|
||||
- [x] T3 (R3): Arm A grad-mask — c_i gate per-sample grad, single-pass subtraction
|
||||
(NOT A1 stale / A2 two-pass; see Log 2026-05-31). gate identity sum_i g_i ==
|
||||
delta_S.grad holds by construction. smoke routing2_grad passes (||B_q||=0.109).
|
||||
- [ ] T4 (R4): leakage metric + L1 knob (`lambda_l1`, default 0).
|
||||
- [ ] T5: 5-arm sweep at matched seed/steps: vanilla, erase, route-additive(old),
|
||||
route2-grad, route2-act. Plus random-V control (#157) on the old route.
|
||||
@@ -245,6 +252,22 @@ and matches the user's "act-space is the natural fit for per-sample" instinct.
|
||||
- 2026-05-31: defaults — vhack_refresh_every 0->5 (0 is ablation-only);
|
||||
route2 reuses run-substrate (v-hack-path ignored, vhack-pairs drives v_act,
|
||||
tau/rank defaulted) so the sweep needs no extra args.
|
||||
- 2026-05-31: T3 (Arm A grad-mask) implemented + smoke-passed. Removed route2_tau
|
||||
(never tuned; mask is cos>0, the natural hack-ward boundary). v_hack path now
|
||||
auto-derives from --vhack-pairs-path (pass the pairset, the hack auto-loads).
|
||||
Arm A design CHANGED from the spec's A1/A2: single-pass gate-SUBTRACTION instead
|
||||
of stale-mask or two-pass. The per-rollout gate c (init 1, identity forward) gives
|
||||
c.grad = delta_S * g_b after backward; train.py divides out delta_S (eps-guard on
|
||||
|delta_S|>1e-6) to get per-rollout g_b, flags cos(g_b, v_grad)>0, and subtracts
|
||||
flagged rollouts from delta_S.grad. No forward detach, no second pass; quarantine
|
||||
still learns flagged rollouts via its always-on path. The cross-step sample-
|
||||
mismatch that made A1 awkward never arises because routing is post-backward within
|
||||
the same step. Lag bound: routing on a fresh axis lags ~1 step until |delta_S|
|
||||
grows past eps there (this is the A1-equivalent one-step lag, per-axis). Upgrade
|
||||
to A2 (two-pass detach) only if the lag hurts. v_grad = unit(mean(g_hack-g_clean))
|
||||
from extract_v_hack raw grads (gradient-space analogue of v_act). smoke
|
||||
routing2_grad: ||B_q||=0.109 after 30 steps (quarantine seeded by routed grad),
|
||||
deploy eval + asserts pass, exit 0.
|
||||
|
||||
## TODO (out of scope now)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user