mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 16:15:35 +08:00
feat: route2 grad-mask (Arm A) + drop tau knob + pairset-derived v_hack path
Arm A (route2_mask=grad): per-rollout gate splice (identity at c=1) recovers the per-sample delta_S grad after backward (c.grad = delta_S * g_b); train.py divides it out (eps-guard |delta_S|>1e-6), flags rollouts by cos(g_b, v_grad)>0, and SUBTRACTS them from delta_S.grad. Single-pass, no forward detach, no second backward -- the cross-step mismatch that made the spec's A1 stale-mask awkward never arises (routing is post-backward within the step).

v_grad = unit-mean gradient diff from extract_v_hack raw grads (gradient-space analogue of v_act). route2 forces the combined (non-split) backward since cos_pre is NaN for it anyway, which also gives the gate a single clean grad to read.

Drop route2_tau: never tuned; the mask is cos>0 (the natural hack-ward boundary) and the load-time noise floor already filters axes.

v_hack path now auto-derives from --vhack-pairs-path (out/vhack/v_hack_pairset_<stem>.safetensors): pass the pairset, the hack file auto-loads/extracts -- no need to also pass --v-hack-path. run-substrate drops the redundant flag.

smoke: smoke-route2 (act) and new smoke-route2-grad both pass (||B_q||=0.109, exit 0); erase shared-basis path unchanged (cout->0, fired~0.9).

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
@@ -128,36 +128,42 @@ m = (cos > tau).float() # weak, noisy mask -- fine (SGTM t
y, quar = routed_forward(x, m)
```

### Arm A (grad-vector) -- needs the per-sample-gradient trick
### Arm A (grad-vector) -- single-pass gate SUBTRACTION (implemented)

The persona direction lives in `delta_S` (gradient) space, but `delta_S.grad` is
summed over the batch, so the per-sample signal is not free. Recover it with a
per-sample gate (the c_i trick), then route. Two implementable variants because the
signal is only known *after* a backward:
summed over the batch, so the per-rollout signal is not free. Recover it with a
PER-ROLLOUT gate (init 1, identity in the forward) and route POST-backward by
subtraction -- no forward detach, no second pass:

```python
# v_hack : [r] persona-contrast direction in delta_S/gradient space (existing
# extract_vhack_grad.py: SVD of stacked persona-pair GRPO gradients).

# --- per-sample gradient via gates (one extra leaf, cheap) ---
c = ones(b, s, r, requires_grad=True) # per-sample gate over SVD coords
h = (delta_S * c) * (Vh @ x) # splice gate into the kept path
# v_grad : [r] unit(mean(g_hack - g_clean)) per module, from extract_vhack_grad's
# raw per-pair GRPO grads (gradient-space analogue of v_act). Oriented
# hack-ward: training reinforces a hack with the same sign, so a
# reinforced-hack rollout has cos(g_b, v_grad) > 0.
c = ones(b, 1, r, requires_grad=True) # per-rollout gate (broadcast over tokens)
kept = U @ ((a * c) * delta_S) # identity at c=1: forward value unchanged
quar = B_q @ (A_q @ x) # always summed
... usual loss; loss.backward()
g_i = c.grad / delta_S[None, None, :] # [b s r] per-sample delta_S-space grad
# (divide out the delta_S scaling; watch /0)
cos = cosine(g_i, v_hack) # [b s]
m = (cos > tau).float()

# A1 STALE-MASK (single pass): use step (t-1)'s m to detach-route step t's forward.
# cheapest; m lags by one step. Good first cut.
# A2 TWO-PASS (exact): probe backward -> g_i -> m -> re-run routed_forward + backward.
# 2x backward cost; use if A1's lag hurts.
# c.grad = delta_S * g_b (g_b = per-rollout delta_S grad). Divide it out, but only on
# axes where delta_S has moved (|delta_S|>eps); near-zero axes keep the full grad.
g_b = where(|delta_S|>eps, c.grad / delta_S, 0) # [b, r]
flag = cosine(g_b, v_grad) > 0 # [b]
delta_S.grad -= sum_{flagged b} g_b # exact identity on reliable axes
# quarantine keeps every rollout's grad (always-on forward) -> absorption intact.
```

Asymmetry to flag in the writeup: Arm B supports absorption in one pass (mask is a
forward-time activation cosine); Arm A's mask is post-backward, so absorption needs
A1's stale mask or A2's second pass. This is a real reason Arm B is the cleaner bet
and matches the user's "act-space is the natural fit for per-sample" instinct.
Why subtraction beats the earlier A1/A2 framings: routing is post-backward WITHIN the
same step, so the cross-step sample mismatch that made the "stale mask" (A1) awkward
never arises, and we avoid the 2x backward of the exact two-pass (A2). The only
approximation is the eps-guard: on an axis where delta_S is still ~0 the per-rollout
split is undefined (0/0), so we route nothing there until delta_S grows past eps --
a ~1-step-per-axis lag, the same order as A1's lag, and harmless because a ~0 axis
carries no learned hack yet. Upgrade to A2 (probe pass -> mask -> detach-route pass)
only if that lag measurably hurts.

Asymmetry still worth flagging in the writeup: Arm B's mask is a free forward-time
activation cosine (no gate, no division); Arm A pays a per-rollout gate + an
eps-guarded division. Arm B remains the cleaner bet.
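
The subtraction and its eps-guard can be checked by hand on a toy batch. The sketch below is a pure-Python illustration, not the train.py implementation: it assumes the per-rollout gradients `g_true` are known up front (in the real run they only exist implicitly inside `c.grad`), and the names `route_by_subtraction`, `gate_grad`, `summed_grad` and `g_true` are hypothetical. Only the relation `c.grad = delta_S * g_b`, the `cos > 0` flag and the `1e-6` guard come from the text above.

```python
import math

GATE_EPS = 1e-6  # guard value quoted in the commit message; the rest is illustrative


def route_by_subtraction(delta_S, gate_grad, summed_grad, v_grad, eps=GATE_EPS):
    """Return (filtered_grad [r], flags [b]).

    delta_S     : [r]    current knob values
    gate_grad   : [b][r] stand-in for c.grad, gate_grad[b][i] == delta_S[i] * g_b[i]
    summed_grad : [r]    stand-in for delta_S.grad (sum of g_b over rollouts)
    v_grad      : [r]    unit, hack-ward direction
    """
    r = len(delta_S)
    reliable = [abs(d) > eps for d in delta_S]
    filtered = list(summed_grad)
    flags = []
    for cg in gate_grad:
        # Divide delta_S back out; axes with delta_S ~ 0 would be 0/0, so zero them.
        g_b = [cg[i] / delta_S[i] if reliable[i] else 0.0 for i in range(r)]
        norm = max(math.sqrt(sum(x * x for x in g_b)), 1e-12)
        cos = sum(x * v for x, v in zip(g_b, v_grad)) / norm
        flagged = cos > 0
        flags.append(flagged)
        if flagged:
            # Strip this rollout's contribution; unreliable axes subtract 0.
            filtered = [f - x for f, x in zip(filtered, g_b)]
    return filtered, flags


# Toy batch: two rollouts, three axes; axis 2 has not moved yet (delta_S == 0).
delta_S = [0.5, -0.2, 0.0]
g_true = [[1.0, 2.0, 3.0], [-1.0, 0.5, 1.0]]  # assumed per-rollout delta_S grads
gate_grad = [[d * g for d, g in zip(delta_S, g_b)] for g_b in g_true]
summed_grad = [sum(col) for col in zip(*g_true)]
v_grad = [1.0, 0.0, 0.0]

filtered, flags = route_by_subtraction(delta_S, gate_grad, summed_grad, v_grad)
print(flags)     # [True, False]
print(filtered)  # [-1.0, 0.5, 4.0]
```

Rollout 0 points hack-ward and is removed from the two reliable axes, which leaves rollout 1's gradient there. Axis 2 keeps the full summed gradient (4.0) because `delta_S` is still zero on it, which is the per-axis lag described above.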

## Requirements

@@ -189,8 +195,9 @@ and matches the user's "act-space is the natural fit for per-sample" instinct.
  forward value unchanged vs non-detached.
- [ ] T2 (R3): Arm B act-mask -- `v_act` extraction from persona pairs + forward-time
  cosine. verify: R3 fire-ratio on known hack/clean. UAT: "mask fires on hacks".
- [ ] T3 (R3): Arm A grad-mask -- c_i gate per-sample grad + A1 stale-mask. verify:
  gate identity (sum_i g_i == delta_S.grad). UAT: "per-sample cos recovered".
- [x] T3 (R3): Arm A grad-mask -- c_i gate per-sample grad, single-pass subtraction
  (NOT A1 stale / A2 two-pass; see Log 2026-05-31). gate identity sum_i g_i ==
  delta_S.grad holds by construction. smoke routing2_grad passes (||B_q||=0.109).
- [ ] T4 (R4): leakage metric + L1 knob (`lambda_l1`, default 0).
- [ ] T5: 5-arm sweep at matched seed/steps: vanilla, erase, route-additive(old),
  route2-grad, route2-act. Plus random-V control (#157) on the old route.
@@ -245,6 +252,22 @@ and matches the user's "act-space is the natural fit for per-sample" instinct.
- 2026-05-31: defaults -- vhack_refresh_every 0->5 (0 is ablation-only);
  route2 reuses run-substrate (v-hack-path ignored, vhack-pairs drives v_act,
  tau/rank defaulted) so the sweep needs no extra args.
- 2026-05-31: T3 (Arm A grad-mask) implemented + smoke-passed. Removed route2_tau
  (never tuned; mask is cos>0, the natural hack-ward boundary). v_hack path now
  auto-derives from --vhack-pairs-path (pass the pairset, the hack auto-loads).
  Arm A design CHANGED from the spec's A1/A2: single-pass gate-SUBTRACTION instead
  of stale-mask or two-pass. The per-rollout gate c (init 1, identity forward) gives
  c.grad = delta_S * g_b after backward; train.py divides out delta_S (eps-guard on
  |delta_S|>1e-6) to get per-rollout g_b, flags cos(g_b, v_grad)>0, and subtracts
  flagged rollouts from delta_S.grad. No forward detach, no second pass; quarantine
  still learns flagged rollouts via its always-on path. The cross-step sample-
  mismatch that made A1 awkward never arises because routing is post-backward within
  the same step. Lag bound: routing on a fresh axis lags ~1 step until |delta_S|
  grows past eps there (this is the A1-equivalent one-step lag, per-axis). Upgrade
  to A2 (two-pass detach) only if the lag hurts. v_grad = unit(mean(g_hack-g_clean))
  from extract_v_hack raw grads (gradient-space analogue of v_act). smoke
  routing2_grad: ||B_q||=0.109 after 30 steps (quarantine seeded by routed grad),
  deploy eval + asserts pass, exit 0.

## TODO (out of scope now)

@@ -46,13 +46,25 @@ smoke-route *ARGS:

# Routing-v2 path (route2): distinct-basis quarantine (A_q,B_q) + per-sample
# act-mask detach-route in the FORWARD. Fires extract_v_act, the quarantine
# optimizer params, the act-mask cosine, the B_q-moved assert, and the deploy
# ablation (B_q zeroed). tau=-1 forces all-flagged so the seed path is exercised
# on tiny inputs (real runs calibrate tau against R3 fire-ratio).
# optimizer params, the act-mask cosine (route cos>0), the B_q-moved assert, and
# the deploy ablation (B_q zeroed). On the tiny-random model v_act is near-random
# so ~half the samples flag -- both the detach and the through paths fire.
smoke-route2 *ARGS:
    BEARTYPE=1 CUDA_VISIBLE_DEVICES= {{ TRAIN }} smoke --intervention=route2 \
        --teacher-pool-dir=out/pools/teacher_pool --mix-ratio=0.5 \
        --route2-quarantine-rank=8 --route2-tau=-1.0 \
        --route2-quarantine-rank=8 \
        --eval-ablate-every=10 --eval-n-prompts=2 {{ ARGS }}

# route2 GRAD-mask (Arm A): distinct code path from act -- a per-rollout gate is
# spliced into the forward, then train.py recovers per-sample grad (c.grad/delta_S)
# and routes by SUBTRACTING flagged rollouts from delta_S.grad post-backward.
# Exercises the gate forward, extract_v_hack mean-diff -> v_grad, the subtraction,
# and the B_q-moved assert. Run alongside smoke-route2 (act); they don't overlap.
smoke-route2-grad *ARGS:
    BEARTYPE=1 CUDA_VISIBLE_DEVICES= {{ TRAIN }} smoke --intervention=route2 \
        --route2-mask=grad \
        --teacher-pool-dir=out/pools/teacher_pool --mix-ratio=0.5 \
        --route2-quarantine-rank=8 \
        --eval-ablate-every=10 --eval-n-prompts=2 {{ ARGS }}

# Run smoke twice: first warms the v_hack cache (cache-miss path), second hits
@@ -156,7 +168,6 @@ build-substrate MODES="run_tests,exit_code,sentinel":
run-substrate INTERV="none" SEED="41" STEPS="60" REFRESH="5":
    {{ TRAIN }} fast --intervention={{ INTERV }} \
        --teacher-pool-dir=out/pools/substrate \
        --v-hack-path=out/vhack/v_hack_pairset_prog_wide.safetensors \
        --vhack-pairs-path=out/pairsets/prog_wide.json \
        --vhack-refresh-every={{ REFRESH }} \
        --seed={{ SEED }} --steps={{ STEPS }} --out-tag=_sub4_{{ INTERV }}_rf{{ REFRESH }}_s{{ SEED }}

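The recipe above can drop `--v-hack-path` because the file name is now derived from the pairset. A minimal sketch of that resolution order follows, written as a standalone function for illustration. `resolve_v_hack_path` is a hypothetical name, and `VHACK_DIR = Path("out/vhack")` is an assumption taken from the path quoted in the commit message. The three branches mirror the order described in the train.py change: explicit override, then pairset stem, then model slug plus extract knobs.

```python
from pathlib import Path

VHACK_DIR = Path("out/vhack")  # assumption: matches the path in the commit message


def resolve_v_hack_path(v_hack_path, pairs_path, model_name, top_k, tau_axis):
    """Pick the v_hack file, most-specific source first (hypothetical helper)."""
    if v_hack_path is not None:
        # Explicit override, e.g. a random-V control file.
        return Path(v_hack_path)
    if pairs_path is not None:
        # The pairset is the source of truth: key the file by its stem.
        return VHACK_DIR / f"v_hack_pairset_{Path(pairs_path).stem}.safetensors"
    # No pairset: hand-crafted pairs, keyed by model slug and extract knobs.
    model_slug = model_name.rstrip("/").split("/")[-1]
    tau_tag = f"_tau{tau_axis:g}" if tau_axis > 0 else ""
    return VHACK_DIR / f"v_hack_{model_slug}_k{top_k}{tau_tag}.safetensors"


derived = resolve_v_hack_path(None, "out/pairsets/prog_wide.json", "Qwen/Qwen3-4B", 1, 0.0)
print(derived.as_posix())  # out/vhack/v_hack_pairset_prog_wide.safetensors
```

With the pairset from the recipe, the derived path equals the value of the flag that was removed, so dropping the flag should not change which file is loaded.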
@@ -73,17 +73,25 @@ def _delta_hook(layer: nn.Linear, args: tuple, y: Tensor) -> Tensor:
    delta_S = main knob; delta_S_hack = shared-basis routing quarantine (proj.py
    parks the removed hack-ward grad here). Both 0 at init -> identity.

    DISTINCT-BASIS route2 (A_q set):
        a = Vh @ x # [..., r] activation in SVD coords
        kept= U @ (delta_S * a) # the deployed adapter (delta_S_hack stays 0)
        m = cos(a, v_act) > tau # [...] per-sample/token persona flag
    DISTINCT-BASIS route2 (A_q set). The quarantine `quar = B_q @ (A_q @ x)` is
    ALWAYS summed (item 3: in-path for every sample so unflagged hacks can
    concentrate there by gradient magnitude -- absorption, Cloud 2410.04332 /
    SGTM 2512.05648). Two mask sources select WHICH samples route to it:

    mask_mode="act" (Arm B, single pass): flag in activation space,
        m = cos(a, v_act) > 0 # [...] per-token, free at forward time
        kept= where(m, kept.detach(), kept)# flagged -> delta_S gets no grad
        quar= B_q @ (A_q @ x) # distinct-basis quarantine, always summed
        y += kept + quar
    Routing is done in the forward, not the gradient: a flagged sample detaches
    the kept path so only the quarantine learns it; an unflagged sample updates
    both, so a hack-like one concentrates in the quarantine where its gradient is
    larger (absorption, Cloud 2410.04332 / SGTM 2512.05648). Deploy zeroes A_q,B_q.
    The detach leaves the forward value unchanged; only delta_S's backward is cut.

    mask_mode="grad" (Arm A): flag in gradient space. The per-sample delta_S
    gradient is not free (delta_S.grad is summed over the batch), so splice a
    per-rollout gate c (init 1, identity in the forward): after backward
    c.grad = delta_S * g_b (the per-rollout delta_S grad). train.py divides out
    delta_S, flags rollouts by cos(g_b, v_grad)>0, and subtracts the flagged
    rollouts' contribution from delta_S.grad -- routing happens POST-backward, so
    no forward detach here. quar still gets every rollout (always summed).

    Deploy zeroes A_q,B_q (kept-only).
    """
    (x,) = args
    Vh = layer._antipasto_Vh # [r, d_in]
@@ -97,15 +105,30 @@ def _delta_hook(layer: nn.Linear, args: tuple, y: Tensor) -> Tensor:
        h = a * (delta_S + delta_S_hack).to(a.dtype)
        return y + torch.nn.functional.linear(h, U).to(y.dtype)

    # --- route2: distinct-basis quarantine + per-sample act-mask detach-route ---
    kept = torch.nn.functional.linear(a * delta_S.to(a.dtype), U)  # [..., d_out]
    v_act = layer._antipasto_v_act  # [r] unit, hack-ward, in Vh coords
    # per-position cosine of the SVD-coord activation with the persona direction
    cos = (a @ v_act) / (a.norm(dim=-1).clamp_min(1e-6) * v_act.norm().clamp_min(1e-6))
    m = cos > layer._antipasto_tau  # [...] bool
    kept = torch.where(m.unsqueeze(-1), kept.detach(), kept)
    # --- route2: distinct-basis quarantine, always summed ---
    B_q = layer._antipasto_B_q  # [d_out, k_q]
    quar = torch.nn.functional.linear(torch.nn.functional.linear(x, A_q), B_q)  # [..., d_out]

    if layer._antipasto_mask_mode == "grad":
        if torch.is_grad_enabled():
            # per-rollout gate [b, 1.., r], identity at c=1 so the forward value is
            # unchanged. After backward c.grad = delta_S * (per-rollout delta_S grad);
            # train.py divides out delta_S to recover g_b and routes post-backward.
            c = torch.ones(a.shape[0], *([1] * (a.dim() - 2)), a.shape[-1],
                           device=a.device, dtype=a.dtype, requires_grad=True)
            layer._antipasto_gate = c
            kept = torch.nn.functional.linear((a * c) * delta_S.to(a.dtype), U)
        else:  # generate(): no grad -> no gate needed
            kept = torch.nn.functional.linear(a * delta_S.to(a.dtype), U)
        return y + (kept + quar).to(y.dtype)

    # mask_mode="act": forward detach-route by activation cosine. cos>0 == points
    # hack-ward == route to quarantine (the natural boundary, no threshold knob).
    kept = torch.nn.functional.linear(a * delta_S.to(a.dtype), U)  # [..., d_out]
    v_act = layer._antipasto_v_act  # [r] unit, hack-ward, in Vh coords
    cos = (a @ v_act) / (a.norm(dim=-1).clamp_min(1e-6) * v_act.norm().clamp_min(1e-6))
    m = cos > 0  # [...] bool
    kept = torch.where(m.unsqueeze(-1), kept.detach(), kept)
    return y + (kept + quar).to(y.dtype)


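The gate constructor in the hunk above builds a shape of the form `[b, 1, ..., 1, r]`, so a single gate row per rollout broadcasts over every token position. The shape arithmetic can be sketched without torch. `gate_shape` and `broadcasts_to` below are hypothetical helpers for illustration only; the tuple expression is the one passed to `torch.ones` in the diff, and the broadcast check assumes the usual equal-rank NumPy/PyTorch rule (each gate dim is 1 or matches).

```python
def gate_shape(a_shape):
    """Shape of the per-rollout gate for an activation of shape [b, ..., r]."""
    return (a_shape[0], *([1] * (len(a_shape) - 2)), a_shape[-1])


def broadcasts_to(gate, act):
    """True if `gate` broadcasts against `act` (equal rank; each dim is 1 or equal)."""
    return len(gate) == len(act) and all(g in (1, n) for g, n in zip(gate, act))


for act in [(4, 7, 16), (4, 16), (2, 3, 5, 8)]:
    g = gate_shape(act)
    print(act, "->", g, broadcasts_to(g, act))
```

Because the middle dims are 1, the gradient that reaches the gate is summed over tokens within a rollout but kept separate across rollouts, which is what lets train.py read one row per rollout after the backward pass.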
@@ -115,7 +138,7 @@ def wrap_model_with_antipasto(
    cache_root: Path = Path("svd_cache"),
    svd_device: torch.device | str = "cuda",
    quarantine_rank: int | None = None,
    route2_tau: float = 0.0,
    route2_mask: str = "act",
) -> dict[str, dict]:
    """Attach AntiPaSTO hooks to every target nn.Linear in `model` (in place).

@@ -125,8 +148,10 @@ def wrap_model_with_antipasto(

    `quarantine_rank` (route2 only): if set, also attach a DISTINCT-basis LoRA
    quarantine per module -- A_q [k_q, d_in] (kaiming), B_q [d_out, k_q] (zeros so
    quar=0 at init), and a unit `v_act` buffer [r] (filled by the act-mask
    extraction in train.py) + scalar `route2_tau`. None -> shared-basis path only.
    quar=0 at init), a unit `v_act` buffer [r] (filled by the act-mask extraction
    in train.py), and a `_antipasto_gate` slot (grad-mask probe). None -> shared-
    basis path only. `route2_mask` selects the mask source: "act" (forward cosine,
    routes in-forward) or "grad" (gate probe, routes post-backward in train.py).
    """
    svd_device_t = torch.device(svd_device) if isinstance(svd_device, str) else svd_device
    safe = model_name.replace("/", "__")
@@ -167,7 +192,8 @@ def wrap_model_with_antipasto(
            linear.register_parameter("_antipasto_B_q", B_q)
            linear.register_buffer("_antipasto_v_act",
                torch.zeros(r, device=dev, dtype=torch.float32), persistent=True)
            linear._antipasto_tau = route2_tau
            linear._antipasto_mask_mode = route2_mask  # "act" | "grad"
            linear._antipasto_gate = None  # grad-mask probe leaf, set per forward
            info["A_q"], info["B_q"], info["k_q"] = A_q, B_q, k_q
        info["handle"] = linear.register_forward_hook(_delta_hook)
        out[name] = info

+98
-37
@@ -146,11 +146,14 @@ class Config:
    # Replaces the old `arm` flag (vanilla/projected); `arm` survives as a derived
    # display name (see property below) so log/run-id formatting is unchanged.
    intervention: Literal["none", "erase", "route", "route2"] = "erase"
    # route2-only: quarantine LoRA rank (per module, capped at r) and the act-mask
    # cosine threshold. tau is the routing-fraction knob (higher -> route less ->
    # less starvation, weaker seed). Calibrate against R3 (fire >2x on hack vs clean).
    # route2-only: quarantine LoRA rank (per module, capped at r). Any sample whose
    # mask cosine > 0 (points hack-ward) routes into the quarantine; no threshold
    # knob -- the load-time noise floor already filters.
    route2_quarantine_rank: int = 16
    route2_tau: float = 0.0
    # route2 mask source. "act" (Arm B): forward-time cos(a, v_act), routes via
    # detach, single pass. "grad" (Arm A): per-rollout cos(g_b, v_grad) from a gate
    # probe, routes by subtracting flagged rollouts from delta_S.grad post-backward.
    route2_mask: Literal["act", "grad"] = "act"
    # Scale-dependent knobs -- every preset must set these to a real value;
    # subclasses below override the defaults.
    model: str = "Qwen/Qwen3-4B"
@@ -251,9 +254,11 @@ class Config:
    @property
    def arm(self) -> str:
        """Display name for run-id / BLUF / logs (results.py + plot_dynamics
        classify off this). One-to-one with intervention; not a CLI flag."""
        return {"none": "vanilla", "erase": "projected",
                "route": "routing", "route2": "routing2"}[self.intervention]
        classify off this). One-to-one with intervention; not a CLI flag.
        route2 splits by mask source so the 5-arm plot can tell act from grad."""
        if self.intervention == "route2":
            return f"routing2_{self.route2_mask}"
        return {"none": "vanilla", "erase": "projected", "route": "routing"}[self.intervention]


@dataclass(kw_only=True)
@@ -744,10 +749,11 @@ def main(cfg: Config) -> int:
    model.config.use_cache = False

    is_route2 = cfg.intervention == "route2"
    is_route2_grad = is_route2 and cfg.route2_mask == "grad"
    wrappers = wrap_model_with_antipasto(
        model, model_name, CACHE_ROOT, device,
        quarantine_rank=cfg.route2_quarantine_rank if is_route2 else None,
        route2_tau=cfg.route2_tau,
        route2_mask=cfg.route2_mask,
    )
    # Both knobs are trainable params. delta_S_hack only ever gets a grad under
    # intervention=route (the routing split in proj.py); under none/erase its
@@ -755,8 +761,8 @@ def main(cfg: Config) -> int:
    # so none/erase reproduce the pre-quarantine behaviour bit-for-bit).
    delta_params = [info["delta_S"] for info in wrappers.values()]
    delta_hack_params = [info["delta_S_hack"] for info in wrappers.values()]
    # route2: the distinct-basis quarantine LoRA (A_q,B_q). Trained for all arms
    # it exists in (route2 only); routing happens in the forward, not the grad.
    # route2: the distinct-basis quarantine LoRA (A_q,B_q), trainable, deleted at
    # deploy. act-mask routes in the forward (detach); grad-mask routes post-backward.
    quar_params = ([info["A_q"] for info in wrappers.values()]
                   + [info["B_q"] for info in wrappers.values()]) if is_route2 else []
    logger.info(f"trainable delta_S: {sum(p.numel() for p in delta_params):,} "
@@ -768,44 +774,63 @@ def main(cfg: Config) -> int:
    # entirely -- loading it there only to print a cos_pre diagnostic was misleading
    # (and could trigger a needless ~5-min extraction). The cin/cout columns are
    # hidden on vanilla, so v_hack=None just means "no subspace machinery".
    v_grad = None  # set only by the route2 grad-mask branch below
    if cfg.intervention in ("none", "route2"):
        if cfg.intervention == "none" and cfg.v_hack_path is not None:
            logger.info(f"vanilla arm: ignoring --v-hack-path={cfg.v_hack_path} "
                        "(no projection; cin/cout diagnostics off)")
        v_hack = None  # route2 routes in the FORWARD via v_act, not via grad surgery
        v_hack = None  # route2 routes via the mask, not erase/route grad surgery
        if is_route2:
            # Build the activation-space persona direction (Arm B mask source) from
            # the SAME persona pairs the grad extract uses, then load it into each
            # module's _antipasto_v_act buffer. No oracle: pairs are the weak detector.
            from .extract_vhack_grad import extract_v_act
            # The persona pairs are the only "detector" (weak, self-supervised). They
            # produce the mask direction; no oracle, no gt_pass. Same pairs for both
            # masks so act vs grad differ only in the space the direction lives in.
            if cfg.vhack_pairs_path is not None:
                from .pairs_from_pool import load_pairs_json
                VACT_PAIRS = load_pairs_json(cfg.vhack_pairs_path)
                logger.info(f"v_act pairs: pool-derived ({cfg.vhack_pairs_path}) -> {len(VACT_PAIRS)} pairs")
                MASK_PAIRS = load_pairs_json(cfg.vhack_pairs_path)
                logger.info(f"route2 mask pairs: pool-derived ({cfg.vhack_pairs_path}) -> {len(MASK_PAIRS)} pairs")
            else:
                from .pairs import PAIRS as VACT_PAIRS
                logger.info(f"v_act pairs: hand-crafted PAIRS -> {len(VACT_PAIRS)} pairs")
                from .pairs import PAIRS as MASK_PAIRS
                logger.info(f"route2 mask pairs: hand-crafted PAIRS -> {len(MASK_PAIRS)} pairs")
            model.eval()
            v_act = extract_v_act(model, tok, wrappers, VACT_PAIRS, n_heldout=2, device=device)
            if cfg.route2_mask == "act":
                # Arm B: activation-space mean-diff -> _antipasto_v_act buffer (the
                # forward cos(a,v_act) reads it). Forward-only, cheap.
                from .extract_vhack_grad import extract_v_act
                v_act = extract_v_act(model, tok, wrappers, MASK_PAIRS, n_heldout=2, device=device)
                for name, info in wrappers.items():
                    info["layer"]._antipasto_v_act.data.copy_(v_act[name].to(device))
                logger.info(f"route2 act: loaded v_act into {len(v_act)} modules")
            else:
                # Arm A: gradient-space mean-diff. extract_v_hack gives per-pair GRPO
                # gradients on delta_S; v_grad = unit(mean(g_hack - g_clean)) per
                # module, oriented hack-ward (training reinforces hacks with the
                # same sign, so cos(g_b, v_grad)>0 flags a reinforced-hack rollout).
                from .extract_vhack_grad import extract_v_hack
                _, _, raw_grads, _ = extract_v_hack(
                    model, tok, wrappers, MASK_PAIRS,
                    top_k=1, tau_axis=0.0, n_heldout=2, device=device,
                )
                v_grad = {}
                for name in wrappers:
                    d = (raw_grads[f"hack/{name}"] - raw_grads[f"clean/{name}"]).mean(0)
                    v_grad[name] = (d / d.norm().clamp_min(1e-12)).to(device)
                logger.info(f"route2 grad: built v_grad (gradient mean-diff) for {len(v_grad)} modules")
            model.train()
            for name, info in wrappers.items():
                info["layer"]._antipasto_v_act.data.copy_(v_act[name].to(device))
            logger.info(f"route2: loaded v_act into {len(v_act)} modules; "
                        f"quarantine_rank={cfg.route2_quarantine_rank} tau={cfg.route2_tau}")
    else:
        # derive default path from model + extract_top_k unless overridden.
        # Auto-extract reuses the already-wrapped model -- no second model load.
        # Slug: works for HF names ("Qwen/Qwen3-4B" -> "Qwen3-4B") and local paths
        # ("out/baked/qwen3_4b_rh25" -> "qwen3_4b_rh25").
        model_slug = model_name.rstrip("/").split("/")[-1]
        # Filename encodes top_k AND tau_axis because both are baked into the saved
        # V (extract zeros rows where S_i/S_0 < tau_axis before saving). If a future
        # ablation varies pairs.py, add a pairs hash here too.
        tau_tag = f"_tau{cfg.v_hack_tau_axis:g}" if cfg.v_hack_tau_axis > 0 else ""
        if cfg.v_hack_path is None:
            v_hack_path = VHACK_DIR / f"v_hack_{model_slug}_k{cfg.v_hack_extract_top_k}{tau_tag}.safetensors"
        # v_hack path resolution, most-specific first. The pairset (personas) is
        # the source of truth: pass --vhack-pairs-path and the hack file auto-loads
        # (auto-extracts if missing) -- no need to also pass --v-hack-path.
        if cfg.v_hack_path is not None:
            v_hack_path = cfg.v_hack_path  # explicit override (e.g. randomV control)
        elif cfg.vhack_pairs_path is not None:
            v_hack_path = VHACK_DIR / f"v_hack_pairset_{cfg.vhack_pairs_path.stem}.safetensors"
        else:
            v_hack_path = cfg.v_hack_path
            # no pairset given -> hand-crafted PAIRS, keyed by model + extract knobs.
            # Slug works for HF names and local paths; tau_tag because tau_axis is
            # baked into the saved V (extract zeros rows where S_i/S_0 < tau_axis).
            model_slug = model_name.rstrip("/").split("/")[-1]
            tau_tag = f"_tau{cfg.v_hack_tau_axis:g}" if cfg.v_hack_tau_axis > 0 else ""
            v_hack_path = VHACK_DIR / f"v_hack_{model_slug}_k{cfg.v_hack_extract_top_k}{tau_tag}.safetensors"
        if not v_hack_path.exists():
            from .extract_vhack_grad import extract_v_hack
            if cfg.vhack_pairs_path is not None:
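The Arm A branch in the hunk above builds `v_grad` as the unit-normalised mean of per-pair gradient differences. A minimal pure-Python sketch of that arithmetic follows, using plain lists in place of tensors. `unit_mean_diff` is a hypothetical name and the toy gradients are made up for illustration; only the formula `unit(mean(g_hack - g_clean))` and the `1e-12` norm floor come from the diff.

```python
import math


def unit_mean_diff(g_hack, g_clean):
    """Return unit(mean(g_hack - g_clean)) for inputs shaped [n_pairs][r]."""
    n, r = len(g_hack), len(g_hack[0])
    # Mean over pairs of the per-pair difference, one entry per axis.
    d = [sum(h[i] - c[i] for h, c in zip(g_hack, g_clean)) / n for i in range(r)]
    # Floor the norm so an all-zero difference does not divide by zero.
    norm = max(math.sqrt(sum(x * x for x in d)), 1e-12)
    return [x / norm for x in d]


# Two toy persona pairs on a 2-axis module (made-up numbers).
g_hack = [[4.0, 4.0], [2.0, 2.0]]
g_clean = [[1.0, 0.0], [-1.0, -2.0]]
v = unit_mean_diff(g_hack, g_clean)
print(v)  # [0.6, 0.8]
```

Both pairs differ by (3, 4), so the mean difference is (3, 4) and the unit vector is (0.6, 0.8). A rollout gradient with a positive dot product against this vector would be flagged by the `cos > 0` rule.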
@@ -1101,11 +1126,38 @@ def main(cfg: Config) -> int:
            step_grad_quar[key] = (step_grad_quar[key] + p.grad.detach().clone()
                                   if key in step_grad_quar else p.grad.detach().clone())

        # route2 grad-mask: recover the per-rollout delta_S grad from the gate
        # (c.grad = delta_S * g_b), flag rollouts whose grad points hack-ward
        # (cos(g_b, v_grad) > 0), and subtract their contribution from delta_S.grad.
        # Only axes where delta_S has moved (|delta_S| > GATE_EPS) carry a reliable
        # per-rollout split; near-zero axes keep the full grad, so routing on a fresh
        # axis lags ~1 step until delta_S grows there (the A1 stale-mask trade-off).
        GATE_EPS = 1e-6
        step_flagged: list[float] = []

        def _route2_grad_filter(info) -> torch.Tensor:
            g = info["delta_S"].grad  # [r] summed over rollouts
            cg = info["layer"]._antipasto_gate.grad.reshape(-1, g.shape[0])  # [b, r]
            dS = info["delta_S"].detach()  # [r]
            reliable = dS.abs() > GATE_EPS  # [r]
            dS_safe = torch.where(reliable, dS, torch.ones_like(dS))
            g_b = torch.where(reliable, cg / dS_safe, torch.zeros_like(cg))  # [b, r]
            vg = v_grad[name]  # [r] unit, hack-ward
            cos_b = (g_b @ vg) / g_b.norm(dim=1).clamp_min(1e-12)  # [b]
            flagged = (cos_b > 0).float()  # [b]
            step_flagged.append(flagged.mean().item())
            sub = torch.where(reliable, (cg * flagged.unsqueeze(1)).sum(0) / dS_safe,
                              torch.zeros_like(g))  # flagged rollouts' contribution
            return g - sub

        # Split backward into student/teacher only every cos_pre_split_every steps.
        # On split steps: 2 backwards per prompt, populates step_grad_s/_t.
        # On skipped steps: 1 combined backward, step_grad_s/_t stay empty and
        # cos_pre_s/cos_pre_t go to NaN (mean_cos_pre_from_grads returns NaN on empty dict).
        split_this_step = (step % cfg.cos_pre_split_every == 0)
        # route2 has no v_hack so cos_pre is NaN regardless: force the single combined
        # backward (the split would just double cost). The grad-mask reads its
        # per-rollout gate from that one backward.
        split_this_step = (step % cfg.cos_pre_split_every == 0) and not is_route2
        # Phase timers (per-step cumulative, seconds). Each GPU phase ends in a
        # CPU-blocking op (decode / .item()), so perf_counter is sync-accurate
        # without explicit cuda.synchronize. Tells us whether wall-time is
@@ -1396,6 +1448,10 @@ def main(cfg: Config) -> int:
                g = info["delta_S"].grad
                if g is None:
                    continue
                # grad-mask routes here: strip flagged rollouts from delta_S.grad
                # (quarantine still learns them via its always-on forward path).
                if is_route2_grad:
                    g = _route2_grad_filter(info)
                step_grad_s[name] = (step_grad_s[name] + g.detach().clone()
                                     if name in step_grad_s
                                     else g.detach().clone())
@@ -1435,6 +1491,11 @@ def main(cfg: Config) -> int:
            diag = {"mean_cos_pre": float("nan"), "mean_cos_post": float("nan"),
                    "frac_fired": float("nan"), "mean_cos_pre_s": float("nan"),
                    "mean_cos_pre_t": float("nan")}
            # route2 grad-mask: report the mean per-module per-rollout flag rate so
            # we can watch the mask actually fire (and rise as hacks emerge).
            if is_route2_grad and step_flagged:
                logger.debug(f"route2-grad flagged frac (mean over modules*prompts): "
                             f"{sum(step_flagged)/len(step_flagged):+.3f}")
        else:
            if split_this_step:
                cos_pre_s = mean_cos_pre_from_grads(step_grad_s, v_hack)
