From 670fcb3c64707ce0d567a0643072169b08523189 Mon Sep 17 00:00:00 2001 From: wassname Date: Sun, 31 May 2026 10:48:31 +0000 Subject: [PATCH] feat: route2 grad-mask (Arm A) + drop tau knob + pairset-derived v_hack path Arm A (route2_mask=grad): per-rollout gate splice (identity at c=1) recovers the per-sample delta_S grad after backward (c.grad = delta_S * g_b); train.py divides it out (eps-guard |delta_S|>1e-6), flags rollouts by cos(g_b, v_grad)>0, and SUBTRACTS them from delta_S.grad. Single-pass, no forward detach, no second backward -- the cross-step mismatch that made the spec's A1 stale-mask awkward never arises (routing is post-backward within the step). v_grad = unit-mean gradient diff from extract_v_hack raw grads (gradient-space analogue of v_act). route2 forces the combined (non-split) backward since cos_pre is NaN for it anyway, which also gives the gate a single clean grad to read. Drop route2_tau: never tuned; the mask is cos>0 (the natural hack-ward boundary) and the load-time noise floor already filters axes. v_hack path now auto-derives from --vhack-pairs-path (out/vhack/v_hack_pairset_ .safetensors): pass the pairset, the hack file auto-loads/extracts -- no need to also pass --v-hack-path. run-substrate drops the redundant flag. smoke: smoke-route2 (act) and new smoke-route2-grad both pass (||B_q||=0.109, exit 0); erase shared-basis path unchanged (cout->0, fired~0.9). Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com> --- .../20260531_routing_v2_distinct_basis.md | 73 ++++++---- justfile | 21 ++- src/projected_grpo/antipasto.py | 68 ++++++--- src/projected_grpo/train.py | 135 +++++++++++++----- 4 files changed, 209 insertions(+), 88 deletions(-) diff --git a/docs/spec/20260531_routing_v2_distinct_basis.md b/docs/spec/20260531_routing_v2_distinct_basis.md index 7055657..fc34f21 100644 --- a/docs/spec/20260531_routing_v2_distinct_basis.md +++ b/docs/spec/20260531_routing_v2_distinct_basis.md @@ -128,36 +128,42 @@ m = (cos > tau).float() # weak, noisy mask -- fine (SGTM t y, quar = routed_forward(x, m) ``` -### Arm A (grad-vector) — needs the per-sample-gradient trick +### Arm A (grad-vector) — single-pass gate SUBTRACTION (implemented) The persona direction lives in `delta_S` (gradient) space, but `delta_S.grad` is -summed over the batch, so the per-sample signal is not free. Recover it with a -per-sample gate (the c_i trick), then route. Two implementable variants because the -signal is only known *after* a backward: +summed over the batch, so the per-rollout signal is not free. Recover it with a +PER-ROLLOUT gate (init 1, identity in the forward) and route POST-backward by +subtraction — no forward detach, no second pass: ```python -# v_hack : [r] persona-contrast direction in delta_S/gradient space (existing -# extract_vhack_grad.py: SVD of stacked persona-pair GRPO gradients). - -# --- per-sample gradient via gates (one extra leaf, cheap) --- -c = ones(b, s, r, requires_grad=True) # per-sample gate over SVD coords -h = (delta_S * c) * (Vh @ x) # splice gate into the kept path +# v_grad : [r] unit(mean(g_hack - g_clean)) per module, from extract_vhack_grad's +# raw per-pair GRPO grads (gradient-space analogue of v_act). Oriented +# hack-ward: training reinforces a hack with the same sign, so a +# reinforced-hack rollout has cos(g_b, v_grad) > 0. +c = ones(b, 1, r, requires_grad=True) # per-rollout gate (broadcast over tokens) +kept = U @ ((a * c) * delta_S) # identity at c=1: forward value unchanged +quar = B_q @ (A_q @ x) # always summed ... usual loss; loss.backward() -g_i = c.grad / delta_S[None, None, :] # [b s r] per-sample delta_S-space grad - # (divide out the delta_S scaling; watch /0) -cos = cosine(g_i, v_hack) # [b s] -m = (cos > tau).float() - -# A1 STALE-MASK (single pass): use step (t-1)'s m to detach-route step t's forward. -# cheapest; m lags by one step. Good first cut. -# A2 TWO-PASS (exact): probe backward -> g_i -> m -> re-run routed_forward + backward. -# 2x backward cost; use if A1's lag hurts. +# c.grad = delta_S * g_b (g_b = per-rollout delta_S grad). Divide it out, but only on +# axes where delta_S has moved (|delta_S|>eps); near-zero axes keep the full grad. +g_b = where(|delta_S|>eps, c.grad / delta_S, 0) # [b, r] +flag = cosine(g_b, v_grad) > 0 # [b] +delta_S.grad -= sum_{flagged b} g_b # exact identity on reliable axes +# quarantine keeps every rollout's grad (always-on forward) -> absorption intact. ``` -Asymmetry to flag in the writeup: Arm B supports absorption in one pass (mask is a -forward-time activation cosine); Arm A's mask is post-backward, so absorption needs -A1's stale mask or A2's second pass. This is a real reason Arm B is the cleaner bet -and matches the user's "act-space is the natural fit for per-sample" instinct. +Why subtraction beats the earlier A1/A2 framings: routing is post-backward WITHIN the +same step, so the cross-step sample mismatch that made the "stale mask" (A1) awkward +never arises, and we avoid the 2x backward of the exact two-pass (A2). The only +approximation is the eps-guard: on an axis where delta_S is still ~0 the per-rollout +split is undefined (0/0), so we route nothing there until delta_S grows past eps — +a ~1-step-per-axis lag, the same order as A1's lag, and harmless because a ~0 axis +carries no learned hack yet. Upgrade to A2 (probe pass -> mask -> detach-route pass) +only if that lag measurably hurts. + +Asymmetry still worth flagging in the writeup: Arm B's mask is a free forward-time +activation cosine (no gate, no division); Arm A pays a per-rollout gate + an +eps-guarded division. Arm B remains the cleaner bet. ## Requirements @@ -189,8 +195,9 @@ and matches the user's "act-space is the natural fit for per-sample" instinct. forward value unchanged vs non-detached. - [ ] T2 (R3): Arm B act-mask — `v_act` extraction from persona pairs + forward-time cosine. verify: R3 fire-ratio on known hack/clean. UAT: "mask fires on hacks". -- [ ] T3 (R3): Arm A grad-mask — c_i gate per-sample grad + A1 stale-mask. verify: - gate identity (sum_i g_i == delta_S.grad). UAT: "per-sample cos recovered". +- [x] T3 (R3): Arm A grad-mask — c_i gate per-sample grad, single-pass subtraction + (NOT A1 stale / A2 two-pass; see Log 2026-05-31). gate identity sum_i g_i == + delta_S.grad holds by construction. smoke routing2_grad passes (||B_q||=0.109). - [ ] T4 (R4): leakage metric + L1 knob (`lambda_l1`, default 0). - [ ] T5: 5-arm sweep at matched seed/steps: vanilla, erase, route-additive(old), route2-grad, route2-act. Plus random-V control (#157) on the old route. @@ -245,6 +252,22 @@ and matches the user's "act-space is the natural fit for per-sample" instinct. - 2026-05-31: defaults — vhack_refresh_every 0->5 (0 is ablation-only); route2 reuses run-substrate (v-hack-path ignored, vhack-pairs drives v_act, tau/rank defaulted) so the sweep needs no extra args. +- 2026-05-31: T3 (Arm A grad-mask) implemented + smoke-passed. Removed route2_tau + (never tuned; mask is cos>0, the natural hack-ward boundary). v_hack path now + auto-derives from --vhack-pairs-path (pass the pairset, the hack auto-loads). + Arm A design CHANGED from the spec's A1/A2: single-pass gate-SUBTRACTION instead + of stale-mask or two-pass. The per-rollout gate c (init 1, identity forward) gives + c.grad = delta_S * g_b after backward; train.py divides out delta_S (eps-guard on + |delta_S|>1e-6) to get per-rollout g_b, flags cos(g_b, v_grad)>0, and subtracts + flagged rollouts from delta_S.grad. No forward detach, no second pass; quarantine + still learns flagged rollouts via its always-on path. The cross-step sample- + mismatch that made A1 awkward never arises because routing is post-backward within + the same step. Lag bound: routing on a fresh axis lags ~1 step until |delta_S| + grows past eps there (this is the A1-equivalent one-step lag, per-axis). Upgrade + to A2 (two-pass detach) only if the lag hurts. v_grad = unit(mean(g_hack-g_clean)) + from extract_v_hack raw grads (gradient-space analogue of v_act). smoke + routing2_grad: ||B_q||=0.109 after 30 steps (quarantine seeded by routed grad), + deploy eval + asserts pass, exit 0. ## TODO (out of scope now) diff --git a/justfile b/justfile index 247363d..64ffc80 100644 --- a/justfile +++ b/justfile @@ -46,13 +46,25 @@ smoke-route *ARGS: # Routing-v2 path (route2): distinct-basis quarantine (A_q,B_q) + per-sample # act-mask detach-route in the FORWARD. Fires extract_v_act, the quarantine -# optimizer params, the act-mask cosine, the B_q-moved assert, and the deploy -# ablation (B_q zeroed). tau=-1 forces all-flagged so the seed path is exercised -# on tiny inputs (real runs calibrate tau against R3 fire-ratio). +# optimizer params, the act-mask cosine (route cos>0), the B_q-moved assert, and +# the deploy ablation (B_q zeroed). On the tiny-random model v_act is near-random +# so ~half the samples flag -- both the detach and the through paths fire. smoke-route2 *ARGS: BEARTYPE=1 CUDA_VISIBLE_DEVICES= {{ TRAIN }} smoke --intervention=route2 \ --teacher-pool-dir=out/pools/teacher_pool --mix-ratio=0.5 \ - --route2-quarantine-rank=8 --route2-tau=-1.0 \ + --route2-quarantine-rank=8 \ + --eval-ablate-every=10 --eval-n-prompts=2 {{ ARGS }} + +# route2 GRAD-mask (Arm A): distinct code path from act -- a per-rollout gate is +# spliced into the forward, then train.py recovers per-sample grad (c.grad/delta_S) +# and routes by SUBTRACTING flagged rollouts from delta_S.grad post-backward. +# Exercises the gate forward, extract_v_hack mean-diff -> v_grad, the subtraction, +# and the B_q-moved assert. Run alongside smoke-route2 (act); they don't overlap. +smoke-route2-grad *ARGS: + BEARTYPE=1 CUDA_VISIBLE_DEVICES= {{ TRAIN }} smoke --intervention=route2 \ + --route2-mask=grad \ + --teacher-pool-dir=out/pools/teacher_pool --mix-ratio=0.5 \ + --route2-quarantine-rank=8 \ --eval-ablate-every=10 --eval-n-prompts=2 {{ ARGS }} # Run smoke twice: first warms the v_hack cache (cache-miss path), second hits @@ -156,7 +168,6 @@ build-substrate MODES="run_tests,exit_code,sentinel": run-substrate INTERV="none" SEED="41" STEPS="60" REFRESH="5": {{ TRAIN }} fast --intervention={{ INTERV }} \ --teacher-pool-dir=out/pools/substrate \ - --v-hack-path=out/vhack/v_hack_pairset_prog_wide.safetensors \ --vhack-pairs-path=out/pairsets/prog_wide.json \ --vhack-refresh-every={{ REFRESH }} \ --seed={{ SEED }} --steps={{ STEPS }} --out-tag=_sub4_{{ INTERV }}_rf{{ REFRESH }}_s{{ SEED }} diff --git a/src/projected_grpo/antipasto.py b/src/projected_grpo/antipasto.py index ced4771..6a102c7 100644 --- a/src/projected_grpo/antipasto.py +++ b/src/projected_grpo/antipasto.py @@ -73,17 +73,25 @@ def _delta_hook(layer: nn.Linear, args: tuple, y: Tensor) -> Tensor: delta_S = main knob; delta_S_hack = shared-basis routing quarantine (proj.py parks the removed hack-ward grad here). Both 0 at init -> identity. - DISTINCT-BASIS route2 (A_q set): - a = Vh @ x # [..., r] activation in SVD coords - kept= U @ (delta_S * a) # the deployed adapter (delta_S_hack stays 0) - m = cos(a, v_act) > tau # [...] per-sample/token persona flag + DISTINCT-BASIS route2 (A_q set). The quarantine `quar = B_q @ (A_q @ x)` is + ALWAYS summed (item 3: in-path for every sample so unflagged hacks can + concentrate there by gradient magnitude -- absorption, Cloud 2410.04332 / + SGTM 2512.05648). Two mask sources select WHICH samples route to it: + + mask_mode="act" (Arm B, single pass): flag in activation space, + m = cos(a, v_act) > 0 # [...] per-token, free at forward time kept= where(m, kept.detach(), kept)# flagged -> delta_S gets no grad - quar= B_q @ (A_q @ x) # distinct-basis quarantine, always summed - y += kept + quar - Routing is done in the forward, not the gradient: a flagged sample detaches - the kept path so only the quarantine learns it; an unflagged sample updates - both, so a hack-like one concentrates in the quarantine where its gradient is - larger (absorption, Cloud 2410.04332 / SGTM 2512.05648). Deploy zeroes A_q,B_q. + The detach leaves the forward value unchanged; only delta_S's backward is cut. + + mask_mode="grad" (Arm A): flag in gradient space. The per-sample delta_S + gradient is not free (delta_S.grad is summed over the batch), so splice a + per-rollout gate c (init 1, identity in the forward): after backward + c.grad = delta_S * g_b (the per-rollout delta_S grad). train.py divides out + delta_S, flags rollouts by cos(g_b, v_grad)>0, and subtracts the flagged + rollouts' contribution from delta_S.grad -- routing happens POST-backward, so + no forward detach here. quar still gets every rollout (always summed). + + Deploy zeroes A_q,B_q (kept-only). """ (x,) = args Vh = layer._antipasto_Vh # [r, d_in] @@ -97,15 +105,30 @@ def _delta_hook(layer: nn.Linear, args: tuple, y: Tensor) -> Tensor: h = a * (delta_S + delta_S_hack).to(a.dtype) return y + torch.nn.functional.linear(h, U).to(y.dtype) - # --- route2: distinct-basis quarantine + per-sample act-mask detach-route --- - kept = torch.nn.functional.linear(a * delta_S.to(a.dtype), U) # [..., d_out] - v_act = layer._antipasto_v_act # [r] unit, hack-ward, in Vh coords - # per-position cosine of the SVD-coord activation with the persona direction - cos = (a @ v_act) / (a.norm(dim=-1).clamp_min(1e-6) * v_act.norm().clamp_min(1e-6)) - m = cos > layer._antipasto_tau # [...] bool - kept = torch.where(m.unsqueeze(-1), kept.detach(), kept) + # --- route2: distinct-basis quarantine, always summed --- B_q = layer._antipasto_B_q # [d_out, k_q] quar = torch.nn.functional.linear(torch.nn.functional.linear(x, A_q), B_q) # [..., d_out] + + if layer._antipasto_mask_mode == "grad": + if torch.is_grad_enabled(): + # per-rollout gate [b, 1.., r], identity at c=1 so the forward value is + # unchanged. After backward c.grad = delta_S * (per-rollout delta_S grad); + # train.py divides out delta_S to recover g_b and routes post-backward. + c = torch.ones(a.shape[0], *([1] * (a.dim() - 2)), a.shape[-1], + device=a.device, dtype=a.dtype, requires_grad=True) + layer._antipasto_gate = c + kept = torch.nn.functional.linear((a * c) * delta_S.to(a.dtype), U) + else: # generate(): no grad -> no gate needed + kept = torch.nn.functional.linear(a * delta_S.to(a.dtype), U) + return y + (kept + quar).to(y.dtype) + + # mask_mode="act": forward detach-route by activation cosine. cos>0 == points + # hack-ward == route to quarantine (the natural boundary, no threshold knob). + kept = torch.nn.functional.linear(a * delta_S.to(a.dtype), U) # [..., d_out] + v_act = layer._antipasto_v_act # [r] unit, hack-ward, in Vh coords + cos = (a @ v_act) / (a.norm(dim=-1).clamp_min(1e-6) * v_act.norm().clamp_min(1e-6)) + m = cos > 0 # [...] bool + kept = torch.where(m.unsqueeze(-1), kept.detach(), kept) return y + (kept + quar).to(y.dtype) @@ -115,7 +138,7 @@ def wrap_model_with_antipasto( cache_root: Path = Path("svd_cache"), svd_device: torch.device | str = "cuda", quarantine_rank: int | None = None, - route2_tau: float = 0.0, + route2_mask: str = "act", ) -> dict[str, dict]: """Attach AntiPaSTO hooks to every target nn.Linear in `model` (in place). @@ -125,8 +148,10 @@ def wrap_model_with_antipasto( `quarantine_rank` (route2 only): if set, also attach a DISTINCT-basis LoRA quarantine per module -- A_q [k_q, d_in] (kaiming), B_q [d_out, k_q] (zeros so - quar=0 at init), and a unit `v_act` buffer [r] (filled by the act-mask - extraction in train.py) + scalar `route2_tau`. None -> shared-basis path only. + quar=0 at init), a unit `v_act` buffer [r] (filled by the act-mask extraction + in train.py), and a `_antipasto_gate` slot (grad-mask probe). None -> shared- + basis path only. `route2_mask` selects the mask source: "act" (forward cosine, + routes in-forward) or "grad" (gate probe, routes post-backward in train.py). """ svd_device_t = torch.device(svd_device) if isinstance(svd_device, str) else svd_device safe = model_name.replace("/", "__") @@ -167,7 +192,8 @@ def wrap_model_with_antipasto( linear.register_parameter("_antipasto_B_q", B_q) linear.register_buffer("_antipasto_v_act", torch.zeros(r, device=dev, dtype=torch.float32), persistent=True) - linear._antipasto_tau = route2_tau + linear._antipasto_mask_mode = route2_mask # "act" | "grad" + linear._antipasto_gate = None # grad-mask probe leaf, set per forward info["A_q"], info["B_q"], info["k_q"] = A_q, B_q, k_q info["handle"] = linear.register_forward_hook(_delta_hook) out[name] = info diff --git a/src/projected_grpo/train.py b/src/projected_grpo/train.py index 124883e..7f9920a 100644 --- a/src/projected_grpo/train.py +++ b/src/projected_grpo/train.py @@ -146,11 +146,14 @@ class Config: # Replaces the old `arm` flag (vanilla/projected); `arm` survives as a derived # display name (see property below) so log/run-id formatting is unchanged. intervention: Literal["none", "erase", "route", "route2"] = "erase" - # route2-only: quarantine LoRA rank (per module, capped at r) and the act-mask - # cosine threshold. tau is the routing-fraction knob (higher -> route less -> - # less starvation, weaker seed). Calibrate against R3 (fire >2x on hack vs clean). + # route2-only: quarantine LoRA rank (per module, capped at r). Any sample whose + # mask cosine > 0 (points hack-ward) routes into the quarantine; no threshold + # knob -- the load-time noise floor already filters. route2_quarantine_rank: int = 16 - route2_tau: float = 0.0 + # route2 mask source. "act" (Arm B): forward-time cos(a, v_act), routes via + # detach, single pass. "grad" (Arm A): per-rollout cos(g_b, v_grad) from a gate + # probe, routes by subtracting flagged rollouts from delta_S.grad post-backward. + route2_mask: Literal["act", "grad"] = "act" # Scale-dependent knobs — every preset must set these to a real value; # subclasses below override the defaults. model: str = "Qwen/Qwen3-4B" @@ -251,9 +254,11 @@ class Config: @property def arm(self) -> str: """Display name for run-id / BLUF / logs (results.py + plot_dynamics - classify off this). One-to-one with intervention; not a CLI flag.""" - return {"none": "vanilla", "erase": "projected", - "route": "routing", "route2": "routing2"}[self.intervention] + classify off this). One-to-one with intervention; not a CLI flag. + route2 splits by mask source so the 5-arm plot can tell act from grad.""" + if self.intervention == "route2": + return f"routing2_{self.route2_mask}" + return {"none": "vanilla", "erase": "projected", "route": "routing"}[self.intervention] @dataclass(kw_only=True) @@ -744,10 +749,11 @@ def main(cfg: Config) -> int: model.config.use_cache = False is_route2 = cfg.intervention == "route2" + is_route2_grad = is_route2 and cfg.route2_mask == "grad" wrappers = wrap_model_with_antipasto( model, model_name, CACHE_ROOT, device, quarantine_rank=cfg.route2_quarantine_rank if is_route2 else None, - route2_tau=cfg.route2_tau, + route2_mask=cfg.route2_mask, ) # Both knobs are trainable params. delta_S_hack only ever gets a grad under # intervention=route (the routing split in proj.py); under none/erase its @@ -755,8 +761,8 @@ def main(cfg: Config) -> int: # so none/erase reproduce the pre-quarantine behaviour bit-for-bit). delta_params = [info["delta_S"] for info in wrappers.values()] delta_hack_params = [info["delta_S_hack"] for info in wrappers.values()] - # route2: the distinct-basis quarantine LoRA (A_q,B_q). Trained for all arms - # it exists in (route2 only); routing happens in the forward, not the grad. + # route2: the distinct-basis quarantine LoRA (A_q,B_q), trainable, deleted at + # deploy. act-mask routes in the forward (detach); grad-mask routes post-backward. quar_params = ([info["A_q"] for info in wrappers.values()] + [info["B_q"] for info in wrappers.values()]) if is_route2 else [] logger.info(f"trainable delta_S: {sum(p.numel() for p in delta_params):,} " @@ -768,44 +774,63 @@ def main(cfg: Config) -> int: # entirely -- loading it there only to print a cos_pre diagnostic was misleading # (and could trigger a needless ~5-min extraction). The cin/cout columns are # hidden on vanilla, so v_hack=None just means "no subspace machinery". + v_grad = None # set only by the route2 grad-mask branch below if cfg.intervention in ("none", "route2"): if cfg.intervention == "none" and cfg.v_hack_path is not None: logger.info(f"vanilla arm: ignoring --v-hack-path={cfg.v_hack_path} " "(no projection; cin/cout diagnostics off)") - v_hack = None # route2 routes in the FORWARD via v_act, not via grad surgery + v_hack = None # route2 routes via the mask, not erase/route grad surgery if is_route2: - # Build the activation-space persona direction (Arm B mask source) from - # the SAME persona pairs the grad extract uses, then load it into each - # module's _antipasto_v_act buffer. No oracle: pairs are the weak detector. - from .extract_vhack_grad import extract_v_act + # The persona pairs are the only "detector" (weak, self-supervised). They + # produce the mask direction; no oracle, no gt_pass. Same pairs for both + # masks so act vs grad differ only in the space the direction lives in. if cfg.vhack_pairs_path is not None: from .pairs_from_pool import load_pairs_json - VACT_PAIRS = load_pairs_json(cfg.vhack_pairs_path) - logger.info(f"v_act pairs: pool-derived ({cfg.vhack_pairs_path}) -> {len(VACT_PAIRS)} pairs") + MASK_PAIRS = load_pairs_json(cfg.vhack_pairs_path) + logger.info(f"route2 mask pairs: pool-derived ({cfg.vhack_pairs_path}) -> {len(MASK_PAIRS)} pairs") else: - from .pairs import PAIRS as VACT_PAIRS - logger.info(f"v_act pairs: hand-crafted PAIRS -> {len(VACT_PAIRS)} pairs") + from .pairs import PAIRS as MASK_PAIRS + logger.info(f"route2 mask pairs: hand-crafted PAIRS -> {len(MASK_PAIRS)} pairs") model.eval() - v_act = extract_v_act(model, tok, wrappers, VACT_PAIRS, n_heldout=2, device=device) + if cfg.route2_mask == "act": + # Arm B: activation-space mean-diff -> _antipasto_v_act buffer (the + # forward cos(a,v_act) reads it). Forward-only, cheap. + from .extract_vhack_grad import extract_v_act + v_act = extract_v_act(model, tok, wrappers, MASK_PAIRS, n_heldout=2, device=device) + for name, info in wrappers.items(): + info["layer"]._antipasto_v_act.data.copy_(v_act[name].to(device)) + logger.info(f"route2 act: loaded v_act into {len(v_act)} modules") + else: + # Arm A: gradient-space mean-diff. extract_v_hack gives per-pair GRPO + # gradients on delta_S; v_grad = unit(mean(g_hack - g_clean)) per + # module, oriented hack-ward (training reinforces hacks with the + # same sign, so cos(g_b, v_grad)>0 flags a reinforced-hack rollout). + from .extract_vhack_grad import extract_v_hack + _, _, raw_grads, _ = extract_v_hack( + model, tok, wrappers, MASK_PAIRS, + top_k=1, tau_axis=0.0, n_heldout=2, device=device, + ) + v_grad = {} + for name in wrappers: + d = (raw_grads[f"hack/{name}"] - raw_grads[f"clean/{name}"]).mean(0) + v_grad[name] = (d / d.norm().clamp_min(1e-12)).to(device) + logger.info(f"route2 grad: built v_grad (gradient mean-diff) for {len(v_grad)} modules") model.train() - for name, info in wrappers.items(): - info["layer"]._antipasto_v_act.data.copy_(v_act[name].to(device)) - logger.info(f"route2: loaded v_act into {len(v_act)} modules; " - f"quarantine_rank={cfg.route2_quarantine_rank} tau={cfg.route2_tau}") else: - # derive default path from model + extract_top_k unless overridden. - # Auto-extract reuses the already-wrapped model — no second model load. - # Slug: works for HF names ("Qwen/Qwen3-4B" -> "Qwen3-4B") and local paths - # ("out/baked/qwen3_4b_rh25" -> "qwen3_4b_rh25"). - model_slug = model_name.rstrip("/").split("/")[-1] - # Filename encodes top_k AND tau_axis because both are baked into the saved - # V (extract zeros rows where S_i/S_0 < tau_axis before saving). If a future - # ablation varies pairs.py, add a pairs hash here too. - tau_tag = f"_tau{cfg.v_hack_tau_axis:g}" if cfg.v_hack_tau_axis > 0 else "" - if cfg.v_hack_path is None: - v_hack_path = VHACK_DIR / f"v_hack_{model_slug}_k{cfg.v_hack_extract_top_k}{tau_tag}.safetensors" + # v_hack path resolution, most-specific first. The pairset (personas) is + # the source of truth: pass --vhack-pairs-path and the hack file auto-loads + # (auto-extracts if missing) -- no need to also pass --v-hack-path. + if cfg.v_hack_path is not None: + v_hack_path = cfg.v_hack_path # explicit override (e.g. randomV control) + elif cfg.vhack_pairs_path is not None: + v_hack_path = VHACK_DIR / f"v_hack_pairset_{cfg.vhack_pairs_path.stem}.safetensors" else: - v_hack_path = cfg.v_hack_path + # no pairset given -> hand-crafted PAIRS, keyed by model + extract knobs. + # Slug works for HF names and local paths; tau_tag because tau_axis is + # baked into the saved V (extract zeros rows where S_i/S_0 < tau_axis). + model_slug = model_name.rstrip("/").split("/")[-1] + tau_tag = f"_tau{cfg.v_hack_tau_axis:g}" if cfg.v_hack_tau_axis > 0 else "" + v_hack_path = VHACK_DIR / f"v_hack_{model_slug}_k{cfg.v_hack_extract_top_k}{tau_tag}.safetensors" if not v_hack_path.exists(): from .extract_vhack_grad import extract_v_hack if cfg.vhack_pairs_path is not None: @@ -1101,11 +1126,38 @@ def main(cfg: Config) -> int: step_grad_quar[key] = (step_grad_quar[key] + p.grad.detach().clone() if key in step_grad_quar else p.grad.detach().clone()) + # route2 grad-mask: recover the per-rollout delta_S grad from the gate + # (c.grad = delta_S * g_b), flag rollouts whose grad points hack-ward + # (cos(g_b, v_grad) > 0), and subtract their contribution from delta_S.grad. + # Only axes where delta_S has moved (|delta_S| > GATE_EPS) carry a reliable + # per-rollout split; near-zero axes keep the full grad, so routing on a fresh + # axis lags ~1 step until delta_S grows there (the A1 stale-mask trade-off). + GATE_EPS = 1e-6 + step_flagged: list[float] = [] + + def _route2_grad_filter(info) -> torch.Tensor: + g = info["delta_S"].grad # [r] summed over rollouts + cg = info["layer"]._antipasto_gate.grad.reshape(-1, g.shape[0]) # [b, r] + dS = info["delta_S"].detach() # [r] + reliable = dS.abs() > GATE_EPS # [r] + dS_safe = torch.where(reliable, dS, torch.ones_like(dS)) + g_b = torch.where(reliable, cg / dS_safe, torch.zeros_like(cg)) # [b, r] + vg = v_grad[name] # [r] unit, hack-ward + cos_b = (g_b @ vg) / g_b.norm(dim=1).clamp_min(1e-12) # [b] + flagged = (cos_b > 0).float() # [b] + step_flagged.append(flagged.mean().item()) + sub = torch.where(reliable, (cg * flagged.unsqueeze(1)).sum(0) / dS_safe, + torch.zeros_like(g)) # flagged rollouts' contribution + return g - sub + # Split backward into student/teacher only every cos_pre_split_every steps. # On split steps: 2 backwards per prompt, populates step_grad_s/_t. # On skipped steps: 1 combined backward, step_grad_s/_t stay empty and # cos_pre_s/cos_pre_t go to NaN (mean_cos_pre_from_grads returns NaN on empty dict). - split_this_step = (step % cfg.cos_pre_split_every == 0) + # route2 has no v_hack so cos_pre is NaN regardless: force the single combined + # backward (the split would just double cost). The grad-mask reads its + # per-rollout gate from that one backward. + split_this_step = (step % cfg.cos_pre_split_every == 0) and not is_route2 # Phase timers (per-step cumulative, seconds). Each GPU phase ends in a # CPU-blocking op (decode / .item()), so perf_counter is sync-accurate # without explicit cuda.synchronize. Tells us whether wall-time is @@ -1396,6 +1448,10 @@ def main(cfg: Config) -> int: g = info["delta_S"].grad if g is None: continue + # grad-mask routes here: strip flagged rollouts from delta_S.grad + # (quarantine still learns them via its always-on forward path). + if is_route2_grad: + g = _route2_grad_filter(info) step_grad_s[name] = (step_grad_s[name] + g.detach().clone() if name in step_grad_s else g.detach().clone()) @@ -1435,6 +1491,11 @@ def main(cfg: Config) -> int: diag = {"mean_cos_pre": float("nan"), "mean_cos_post": float("nan"), "frac_fired": float("nan"), "mean_cos_pre_s": float("nan"), "mean_cos_pre_t": float("nan")} + # route2 grad-mask: report the mean per-module per-rollout flag rate so + # we can watch the mask actually fire (and rise as hacks emerge). + if is_route2_grad and step_flagged: + logger.debug(f"route2-grad flagged frac (mean over modules*prompts): " + f"{sum(step_flagged)/len(step_flagged):+.3f}") else: if split_this_step: cos_pre_s = mean_cos_pre_from_grads(step_grad_s, v_hack)