feat: route2 grad-mask (Arm A) + drop tau knob + pairset-derived v_hack path

Arm A (route2_mask=grad): per-rollout gate splice (identity at c=1) recovers
the per-sample delta_S grad after backward (c.grad = delta_S * g_b); train.py
divides it out (eps-guard |delta_S|>1e-6), flags rollouts by cos(g_b, v_grad)>0,
and SUBTRACTS them from delta_S.grad. Single-pass, no forward detach, no second
backward -- the cross-step mismatch that made the spec's A1 stale-mask awkward
never arises (routing is post-backward within the step). v_grad = unit-mean
gradient diff from extract_v_hack raw grads (gradient-space analogue of v_act).
route2 forces the combined (non-split) backward since cos_pre is NaN for it
anyway, which also gives the gate a single clean grad to read.

Drop route2_tau: never tuned; the mask is cos>0 (the natural hack-ward boundary)
and the load-time noise floor already filters axes.

v_hack path now auto-derives from --vhack-pairs-path (out/vhack/v_hack_pairset_
<stem>.safetensors): pass the pairset, the hack file auto-loads/extracts -- no
need to also pass --v-hack-path. run-substrate drops the redundant flag.

smoke: smoke-route2 (act) and new smoke-route2-grad both pass (||B_q||=0.109,
exit 0); erase shared-basis path unchanged (cout->0, fired~0.9).

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
wassname
2026-05-31 10:48:31 +00:00
parent 442630fcae
commit 670fcb3c64
4 changed files with 209 additions and 88 deletions
+48 -25
View File
@@ -128,36 +128,42 @@ m = (cos > tau).float() # weak, noisy mask -- fine (SGTM t
y, quar = routed_forward(x, m)
```
### Arm A (grad-vector) — needs the per-sample-gradient trick
### Arm A (grad-vector) — single-pass gate SUBTRACTION (implemented)
The persona direction lives in `delta_S` (gradient) space, but `delta_S.grad` is
summed over the batch, so the per-sample signal is not free. Recover it with a
per-sample gate (the c_i trick), then route. Two implementable variants because the
signal is only known *after* a backward:
summed over the batch, so the per-rollout signal is not free. Recover it with a
PER-ROLLOUT gate (init 1, identity in the forward) and route POST-backward by
subtraction — no forward detach, no second pass:
```python
# v_hack : [r] persona-contrast direction in delta_S/gradient space (existing
# extract_vhack_grad.py: SVD of stacked persona-pair GRPO gradients).
# --- per-sample gradient via gates (one extra leaf, cheap) ---
c = ones(b, s, r, requires_grad=True) # per-sample gate over SVD coords
h = (delta_S * c) * (Vh @ x) # splice gate into the kept path
# v_grad : [r] unit(mean(g_hack - g_clean)) per module, from extract_vhack_grad's
# raw per-pair GRPO grads (gradient-space analogue of v_act). Oriented
# hack-ward: training reinforces a hack with the same sign, so a
# reinforced-hack rollout has cos(g_b, v_grad) > 0.
c = ones(b, 1, r, requires_grad=True) # per-rollout gate (broadcast over tokens)
kept = U @ ((a * c) * delta_S) # identity at c=1: forward value unchanged
quar = B_q @ (A_q @ x) # always summed
... usual loss; loss.backward()
g_i = c.grad / delta_S[None, None, :] # [b s r] per-sample delta_S-space grad
# (divide out the delta_S scaling; watch /0)
cos = cosine(g_i, v_hack) # [b s]
m = (cos > tau).float()
# A1 STALE-MASK (single pass): use step (t-1)'s m to detach-route step t's forward.
# cheapest; m lags by one step. Good first cut.
# A2 TWO-PASS (exact): probe backward -> g_i -> m -> re-run routed_forward + backward.
# 2x backward cost; use if A1's lag hurts.
# c.grad = delta_S * g_b (g_b = per-rollout delta_S grad). Divide it out, but only on
# axes where delta_S has moved (|delta_S|>eps); near-zero axes keep the full grad.
g_b = where(|delta_S|>eps, c.grad / delta_S, 0) # [b, r]
flag = cosine(g_b, v_grad) > 0 # [b]
delta_S.grad -= sum_{flagged b} g_b # exact identity on reliable axes
# quarantine keeps every rollout's grad (always-on forward) -> absorption intact.
```
Asymmetry to flag in the writeup: Arm B supports absorption in one pass (mask is a
forward-time activation cosine); Arm A's mask is post-backward, so absorption needs
A1's stale mask or A2's second pass. This is a real reason Arm B is the cleaner bet
and matches the user's "act-space is the natural fit for per-sample" instinct.
Why subtraction beats the earlier A1/A2 framings: routing is post-backward WITHIN the
same step, so the cross-step sample mismatch that made the "stale mask" (A1) awkward
never arises, and we avoid the 2x backward of the exact two-pass (A2). The only
approximation is the eps-guard: on an axis where delta_S is still ~0 the per-rollout
split is undefined (0/0), so we route nothing there until delta_S grows past eps —
a ~1-step-per-axis lag, the same order as A1's lag, and harmless because a ~0 axis
carries no learned hack yet. Upgrade to A2 (probe pass -> mask -> detach-route pass)
only if that lag measurably hurts.
Asymmetry still worth flagging in the writeup: Arm B's mask is a free forward-time
activation cosine (no gate, no division); Arm A pays a per-rollout gate + an
eps-guarded division. Arm B remains the cleaner bet.
## Requirements
@@ -189,8 +195,9 @@ and matches the user's "act-space is the natural fit for per-sample" instinct.
forward value unchanged vs non-detached.
- [ ] T2 (R3): Arm B act-mask — `v_act` extraction from persona pairs + forward-time
cosine. verify: R3 fire-ratio on known hack/clean. UAT: "mask fires on hacks".
- [ ] T3 (R3): Arm A grad-mask — c_i gate per-sample grad + A1 stale-mask. verify:
gate identity (sum_i g_i == delta_S.grad). UAT: "per-sample cos recovered".
- [x] T3 (R3): Arm A grad-mask — c_i gate per-sample grad, single-pass subtraction
(NOT A1 stale / A2 two-pass; see Log 2026-05-31). gate identity sum_i g_i ==
delta_S.grad holds by construction. smoke routing2_grad passes (||B_q||=0.109).
- [ ] T4 (R4): leakage metric + L1 knob (`lambda_l1`, default 0).
- [ ] T5: 5-arm sweep at matched seed/steps: vanilla, erase, route-additive(old),
route2-grad, route2-act. Plus random-V control (#157) on the old route.
@@ -245,6 +252,22 @@ and matches the user's "act-space is the natural fit for per-sample" instinct.
- 2026-05-31: defaults — vhack_refresh_every 0->5 (0 is ablation-only);
route2 reuses run-substrate (v-hack-path ignored, vhack-pairs drives v_act,
tau/rank defaulted) so the sweep needs no extra args.
- 2026-05-31: T3 (Arm A grad-mask) implemented + smoke-passed. Removed route2_tau
(never tuned; mask is cos>0, the natural hack-ward boundary). v_hack path now
auto-derives from --vhack-pairs-path (pass the pairset, the hack auto-loads).
Arm A design CHANGED from the spec's A1/A2: single-pass gate-SUBTRACTION instead
of stale-mask or two-pass. The per-rollout gate c (init 1, identity forward) gives
c.grad = delta_S * g_b after backward; train.py divides out delta_S (eps-guard on
|delta_S|>1e-6) to get per-rollout g_b, flags cos(g_b, v_grad)>0, and subtracts
flagged rollouts from delta_S.grad. No forward detach, no second pass; quarantine
still learns flagged rollouts via its always-on path. The cross-step sample-
mismatch that made A1 awkward never arises because routing is post-backward within
the same step. Lag bound: routing on a fresh axis lags ~1 step until |delta_S|
grows past eps there (this is the A1-equivalent one-step lag, per-axis). Upgrade
to A2 (two-pass detach) only if the lag hurts. v_grad = unit(mean(g_hack-g_clean))
from extract_v_hack raw grads (gradient-space analogue of v_act). smoke
routing2_grad: ||B_q||=0.109 after 30 steps (quarantine seeded by routed grad),
deploy eval + asserts pass, exit 0.
## TODO (out of scope now)
+16 -5
View File
@@ -46,13 +46,25 @@ smoke-route *ARGS:
# Routing-v2 path (route2): distinct-basis quarantine (A_q,B_q) + per-sample
# act-mask detach-route in the FORWARD. Fires extract_v_act, the quarantine
# optimizer params, the act-mask cosine, the B_q-moved assert, and the deploy
# ablation (B_q zeroed). tau=-1 forces all-flagged so the seed path is exercised
# on tiny inputs (real runs calibrate tau against R3 fire-ratio).
# optimizer params, the act-mask cosine (route cos>0), the B_q-moved assert, and
# the deploy ablation (B_q zeroed). On the tiny-random model v_act is near-random
# so ~half the samples flag -- both the detach and the through paths fire.
smoke-route2 *ARGS:
BEARTYPE=1 CUDA_VISIBLE_DEVICES= {{ TRAIN }} smoke --intervention=route2 \
--teacher-pool-dir=out/pools/teacher_pool --mix-ratio=0.5 \
--route2-quarantine-rank=8 --route2-tau=-1.0 \
--route2-quarantine-rank=8 \
--eval-ablate-every=10 --eval-n-prompts=2 {{ ARGS }}
# route2 GRAD-mask (Arm A): distinct code path from act -- a per-rollout gate is
# spliced into the forward, then train.py recovers per-sample grad (c.grad/delta_S)
# and routes by SUBTRACTING flagged rollouts from delta_S.grad post-backward.
# Exercises the gate forward, extract_v_hack mean-diff -> v_grad, the subtraction,
# and the B_q-moved assert. Run alongside smoke-route2 (act); they don't overlap.
smoke-route2-grad *ARGS:
BEARTYPE=1 CUDA_VISIBLE_DEVICES= {{ TRAIN }} smoke --intervention=route2 \
--route2-mask=grad \
--teacher-pool-dir=out/pools/teacher_pool --mix-ratio=0.5 \
--route2-quarantine-rank=8 \
--eval-ablate-every=10 --eval-n-prompts=2 {{ ARGS }}
# Run smoke twice: first warms the v_hack cache (cache-miss path), second hits
@@ -156,7 +168,6 @@ build-substrate MODES="run_tests,exit_code,sentinel":
run-substrate INTERV="none" SEED="41" STEPS="60" REFRESH="5":
{{ TRAIN }} fast --intervention={{ INTERV }} \
--teacher-pool-dir=out/pools/substrate \
--v-hack-path=out/vhack/v_hack_pairset_prog_wide.safetensors \
--vhack-pairs-path=out/pairsets/prog_wide.json \
--vhack-refresh-every={{ REFRESH }} \
--seed={{ SEED }} --steps={{ STEPS }} --out-tag=_sub4_{{ INTERV }}_rf{{ REFRESH }}_s{{ SEED }}
+47 -21
View File
@@ -73,17 +73,25 @@ def _delta_hook(layer: nn.Linear, args: tuple, y: Tensor) -> Tensor:
delta_S = main knob; delta_S_hack = shared-basis routing quarantine (proj.py
parks the removed hack-ward grad here). Both 0 at init -> identity.
DISTINCT-BASIS route2 (A_q set):
a = Vh @ x # [..., r] activation in SVD coords
kept= U @ (delta_S * a) # the deployed adapter (delta_S_hack stays 0)
m = cos(a, v_act) > tau # [...] per-sample/token persona flag
DISTINCT-BASIS route2 (A_q set). The quarantine `quar = B_q @ (A_q @ x)` is
ALWAYS summed (item 3: in-path for every sample so unflagged hacks can
concentrate there by gradient magnitude -- absorption, Cloud 2410.04332 /
SGTM 2512.05648). Two mask sources select WHICH samples route to it:
mask_mode="act" (Arm B, single pass): flag in activation space,
m = cos(a, v_act) > 0 # [...] per-token, free at forward time
kept= where(m, kept.detach(), kept)# flagged -> delta_S gets no grad
quar= B_q @ (A_q @ x) # distinct-basis quarantine, always summed
y += kept + quar
Routing is done in the forward, not the gradient: a flagged sample detaches
the kept path so only the quarantine learns it; an unflagged sample updates
both, so a hack-like one concentrates in the quarantine where its gradient is
larger (absorption, Cloud 2410.04332 / SGTM 2512.05648). Deploy zeroes A_q,B_q.
The detach leaves the forward value unchanged; only delta_S's backward is cut.
mask_mode="grad" (Arm A): flag in gradient space. The per-sample delta_S
gradient is not free (delta_S.grad is summed over the batch), so splice a
per-rollout gate c (init 1, identity in the forward): after backward
c.grad = delta_S * g_b (the per-rollout delta_S grad). train.py divides out
delta_S, flags rollouts by cos(g_b, v_grad)>0, and subtracts the flagged
rollouts' contribution from delta_S.grad -- routing happens POST-backward, so
no forward detach here. quar still gets every rollout (always summed).
Deploy zeroes A_q,B_q (kept-only).
"""
(x,) = args
Vh = layer._antipasto_Vh # [r, d_in]
@@ -97,15 +105,30 @@ def _delta_hook(layer: nn.Linear, args: tuple, y: Tensor) -> Tensor:
h = a * (delta_S + delta_S_hack).to(a.dtype)
return y + torch.nn.functional.linear(h, U).to(y.dtype)
# --- route2: distinct-basis quarantine + per-sample act-mask detach-route ---
kept = torch.nn.functional.linear(a * delta_S.to(a.dtype), U) # [..., d_out]
v_act = layer._antipasto_v_act # [r] unit, hack-ward, in Vh coords
# per-position cosine of the SVD-coord activation with the persona direction
cos = (a @ v_act) / (a.norm(dim=-1).clamp_min(1e-6) * v_act.norm().clamp_min(1e-6))
m = cos > layer._antipasto_tau # [...] bool
kept = torch.where(m.unsqueeze(-1), kept.detach(), kept)
# --- route2: distinct-basis quarantine, always summed ---
B_q = layer._antipasto_B_q # [d_out, k_q]
quar = torch.nn.functional.linear(torch.nn.functional.linear(x, A_q), B_q) # [..., d_out]
if layer._antipasto_mask_mode == "grad":
if torch.is_grad_enabled():
# per-rollout gate [b, 1.., r], identity at c=1 so the forward value is
# unchanged. After backward c.grad = delta_S * (per-rollout delta_S grad);
# train.py divides out delta_S to recover g_b and routes post-backward.
c = torch.ones(a.shape[0], *([1] * (a.dim() - 2)), a.shape[-1],
device=a.device, dtype=a.dtype, requires_grad=True)
layer._antipasto_gate = c
kept = torch.nn.functional.linear((a * c) * delta_S.to(a.dtype), U)
else: # generate(): no grad -> no gate needed
kept = torch.nn.functional.linear(a * delta_S.to(a.dtype), U)
return y + (kept + quar).to(y.dtype)
# mask_mode="act": forward detach-route by activation cosine. cos>0 == points
# hack-ward == route to quarantine (the natural boundary, no threshold knob).
kept = torch.nn.functional.linear(a * delta_S.to(a.dtype), U) # [..., d_out]
v_act = layer._antipasto_v_act # [r] unit, hack-ward, in Vh coords
cos = (a @ v_act) / (a.norm(dim=-1).clamp_min(1e-6) * v_act.norm().clamp_min(1e-6))
m = cos > 0 # [...] bool
kept = torch.where(m.unsqueeze(-1), kept.detach(), kept)
return y + (kept + quar).to(y.dtype)
@@ -115,7 +138,7 @@ def wrap_model_with_antipasto(
cache_root: Path = Path("svd_cache"),
svd_device: torch.device | str = "cuda",
quarantine_rank: int | None = None,
route2_tau: float = 0.0,
route2_mask: str = "act",
) -> dict[str, dict]:
"""Attach AntiPaSTO hooks to every target nn.Linear in `model` (in place).
@@ -125,8 +148,10 @@ def wrap_model_with_antipasto(
`quarantine_rank` (route2 only): if set, also attach a DISTINCT-basis LoRA
quarantine per module -- A_q [k_q, d_in] (kaiming), B_q [d_out, k_q] (zeros so
quar=0 at init), and a unit `v_act` buffer [r] (filled by the act-mask
extraction in train.py) + scalar `route2_tau`. None -> shared-basis path only.
quar=0 at init), a unit `v_act` buffer [r] (filled by the act-mask extraction
in train.py), and a `_antipasto_gate` slot (grad-mask probe). None -> shared-
basis path only. `route2_mask` selects the mask source: "act" (forward cosine,
routes in-forward) or "grad" (gate probe, routes post-backward in train.py).
"""
svd_device_t = torch.device(svd_device) if isinstance(svd_device, str) else svd_device
safe = model_name.replace("/", "__")
@@ -167,7 +192,8 @@ def wrap_model_with_antipasto(
linear.register_parameter("_antipasto_B_q", B_q)
linear.register_buffer("_antipasto_v_act",
torch.zeros(r, device=dev, dtype=torch.float32), persistent=True)
linear._antipasto_tau = route2_tau
linear._antipasto_mask_mode = route2_mask # "act" | "grad"
linear._antipasto_gate = None # grad-mask probe leaf, set per forward
info["A_q"], info["B_q"], info["k_q"] = A_q, B_q, k_q
info["handle"] = linear.register_forward_hook(_delta_hook)
out[name] = info
+98 -37
View File
@@ -146,11 +146,14 @@ class Config:
# Replaces the old `arm` flag (vanilla/projected); `arm` survives as a derived
# display name (see property below) so log/run-id formatting is unchanged.
intervention: Literal["none", "erase", "route", "route2"] = "erase"
# route2-only: quarantine LoRA rank (per module, capped at r) and the act-mask
# cosine threshold. tau is the routing-fraction knob (higher -> route less ->
# less starvation, weaker seed). Calibrate against R3 (fire >2x on hack vs clean).
# route2-only: quarantine LoRA rank (per module, capped at r). Any sample whose
# mask cosine > 0 (points hack-ward) routes into the quarantine; no threshold
# knob -- the load-time noise floor already filters.
route2_quarantine_rank: int = 16
route2_tau: float = 0.0
# route2 mask source. "act" (Arm B): forward-time cos(a, v_act), routes via
# detach, single pass. "grad" (Arm A): per-rollout cos(g_b, v_grad) from a gate
# probe, routes by subtracting flagged rollouts from delta_S.grad post-backward.
route2_mask: Literal["act", "grad"] = "act"
# Scale-dependent knobs — every preset must set these to a real value;
# subclasses below override the defaults.
model: str = "Qwen/Qwen3-4B"
@@ -251,9 +254,11 @@ class Config:
@property
def arm(self) -> str:
"""Display name for run-id / BLUF / logs (results.py + plot_dynamics
classify off this). One-to-one with intervention; not a CLI flag."""
return {"none": "vanilla", "erase": "projected",
"route": "routing", "route2": "routing2"}[self.intervention]
classify off this). One-to-one with intervention; not a CLI flag.
route2 splits by mask source so the 5-arm plot can tell act from grad."""
if self.intervention == "route2":
return f"routing2_{self.route2_mask}"
return {"none": "vanilla", "erase": "projected", "route": "routing"}[self.intervention]
@dataclass(kw_only=True)
@@ -744,10 +749,11 @@ def main(cfg: Config) -> int:
model.config.use_cache = False
is_route2 = cfg.intervention == "route2"
is_route2_grad = is_route2 and cfg.route2_mask == "grad"
wrappers = wrap_model_with_antipasto(
model, model_name, CACHE_ROOT, device,
quarantine_rank=cfg.route2_quarantine_rank if is_route2 else None,
route2_tau=cfg.route2_tau,
route2_mask=cfg.route2_mask,
)
# Both knobs are trainable params. delta_S_hack only ever gets a grad under
# intervention=route (the routing split in proj.py); under none/erase its
@@ -755,8 +761,8 @@ def main(cfg: Config) -> int:
# so none/erase reproduce the pre-quarantine behaviour bit-for-bit).
delta_params = [info["delta_S"] for info in wrappers.values()]
delta_hack_params = [info["delta_S_hack"] for info in wrappers.values()]
# route2: the distinct-basis quarantine LoRA (A_q,B_q). Trained for all arms
# it exists in (route2 only); routing happens in the forward, not the grad.
# route2: the distinct-basis quarantine LoRA (A_q,B_q), trainable, deleted at
# deploy. act-mask routes in the forward (detach); grad-mask routes post-backward.
quar_params = ([info["A_q"] for info in wrappers.values()]
+ [info["B_q"] for info in wrappers.values()]) if is_route2 else []
logger.info(f"trainable delta_S: {sum(p.numel() for p in delta_params):,} "
@@ -768,44 +774,63 @@ def main(cfg: Config) -> int:
# entirely -- loading it there only to print a cos_pre diagnostic was misleading
# (and could trigger a needless ~5-min extraction). The cin/cout columns are
# hidden on vanilla, so v_hack=None just means "no subspace machinery".
v_grad = None # set only by the route2 grad-mask branch below
if cfg.intervention in ("none", "route2"):
if cfg.intervention == "none" and cfg.v_hack_path is not None:
logger.info(f"vanilla arm: ignoring --v-hack-path={cfg.v_hack_path} "
"(no projection; cin/cout diagnostics off)")
v_hack = None # route2 routes in the FORWARD via v_act, not via grad surgery
v_hack = None # route2 routes via the mask, not erase/route grad surgery
if is_route2:
# Build the activation-space persona direction (Arm B mask source) from
# the SAME persona pairs the grad extract uses, then load it into each
# module's _antipasto_v_act buffer. No oracle: pairs are the weak detector.
from .extract_vhack_grad import extract_v_act
# The persona pairs are the only "detector" (weak, self-supervised). They
# produce the mask direction; no oracle, no gt_pass. Same pairs for both
# masks so act vs grad differ only in the space the direction lives in.
if cfg.vhack_pairs_path is not None:
from .pairs_from_pool import load_pairs_json
VACT_PAIRS = load_pairs_json(cfg.vhack_pairs_path)
logger.info(f"v_act pairs: pool-derived ({cfg.vhack_pairs_path}) -> {len(VACT_PAIRS)} pairs")
MASK_PAIRS = load_pairs_json(cfg.vhack_pairs_path)
logger.info(f"route2 mask pairs: pool-derived ({cfg.vhack_pairs_path}) -> {len(MASK_PAIRS)} pairs")
else:
from .pairs import PAIRS as VACT_PAIRS
logger.info(f"v_act pairs: hand-crafted PAIRS -> {len(VACT_PAIRS)} pairs")
from .pairs import PAIRS as MASK_PAIRS
logger.info(f"route2 mask pairs: hand-crafted PAIRS -> {len(MASK_PAIRS)} pairs")
model.eval()
v_act = extract_v_act(model, tok, wrappers, VACT_PAIRS, n_heldout=2, device=device)
if cfg.route2_mask == "act":
# Arm B: activation-space mean-diff -> _antipasto_v_act buffer (the
# forward cos(a,v_act) reads it). Forward-only, cheap.
from .extract_vhack_grad import extract_v_act
v_act = extract_v_act(model, tok, wrappers, MASK_PAIRS, n_heldout=2, device=device)
for name, info in wrappers.items():
info["layer"]._antipasto_v_act.data.copy_(v_act[name].to(device))
logger.info(f"route2 act: loaded v_act into {len(v_act)} modules")
else:
# Arm A: gradient-space mean-diff. extract_v_hack gives per-pair GRPO
# gradients on delta_S; v_grad = unit(mean(g_hack - g_clean)) per
# module, oriented hack-ward (training reinforces hacks with the
# same sign, so cos(g_b, v_grad)>0 flags a reinforced-hack rollout).
from .extract_vhack_grad import extract_v_hack
_, _, raw_grads, _ = extract_v_hack(
model, tok, wrappers, MASK_PAIRS,
top_k=1, tau_axis=0.0, n_heldout=2, device=device,
)
v_grad = {}
for name in wrappers:
d = (raw_grads[f"hack/{name}"] - raw_grads[f"clean/{name}"]).mean(0)
v_grad[name] = (d / d.norm().clamp_min(1e-12)).to(device)
logger.info(f"route2 grad: built v_grad (gradient mean-diff) for {len(v_grad)} modules")
model.train()
for name, info in wrappers.items():
info["layer"]._antipasto_v_act.data.copy_(v_act[name].to(device))
logger.info(f"route2: loaded v_act into {len(v_act)} modules; "
f"quarantine_rank={cfg.route2_quarantine_rank} tau={cfg.route2_tau}")
else:
# derive default path from model + extract_top_k unless overridden.
# Auto-extract reuses the already-wrapped model — no second model load.
# Slug: works for HF names ("Qwen/Qwen3-4B" -> "Qwen3-4B") and local paths
# ("out/baked/qwen3_4b_rh25" -> "qwen3_4b_rh25").
model_slug = model_name.rstrip("/").split("/")[-1]
# Filename encodes top_k AND tau_axis because both are baked into the saved
# V (extract zeros rows where S_i/S_0 < tau_axis before saving). If a future
# ablation varies pairs.py, add a pairs hash here too.
tau_tag = f"_tau{cfg.v_hack_tau_axis:g}" if cfg.v_hack_tau_axis > 0 else ""
if cfg.v_hack_path is None:
v_hack_path = VHACK_DIR / f"v_hack_{model_slug}_k{cfg.v_hack_extract_top_k}{tau_tag}.safetensors"
# v_hack path resolution, most-specific first. The pairset (personas) is
# the source of truth: pass --vhack-pairs-path and the hack file auto-loads
# (auto-extracts if missing) -- no need to also pass --v-hack-path.
if cfg.v_hack_path is not None:
v_hack_path = cfg.v_hack_path # explicit override (e.g. randomV control)
elif cfg.vhack_pairs_path is not None:
v_hack_path = VHACK_DIR / f"v_hack_pairset_{cfg.vhack_pairs_path.stem}.safetensors"
else:
v_hack_path = cfg.v_hack_path
# no pairset given -> hand-crafted PAIRS, keyed by model + extract knobs.
# Slug works for HF names and local paths; tau_tag because tau_axis is
# baked into the saved V (extract zeros rows where S_i/S_0 < tau_axis).
model_slug = model_name.rstrip("/").split("/")[-1]
tau_tag = f"_tau{cfg.v_hack_tau_axis:g}" if cfg.v_hack_tau_axis > 0 else ""
v_hack_path = VHACK_DIR / f"v_hack_{model_slug}_k{cfg.v_hack_extract_top_k}{tau_tag}.safetensors"
if not v_hack_path.exists():
from .extract_vhack_grad import extract_v_hack
if cfg.vhack_pairs_path is not None:
@@ -1101,11 +1126,38 @@ def main(cfg: Config) -> int:
step_grad_quar[key] = (step_grad_quar[key] + p.grad.detach().clone()
if key in step_grad_quar else p.grad.detach().clone())
# route2 grad-mask: recover the per-rollout delta_S grad from the gate
# (c.grad = delta_S * g_b), flag rollouts whose grad points hack-ward
# (cos(g_b, v_grad) > 0), and subtract their contribution from delta_S.grad.
# Only axes where delta_S has moved (|delta_S| > GATE_EPS) carry a reliable
# per-rollout split; near-zero axes keep the full grad, so routing on a fresh
# axis lags ~1 step until delta_S grows there (the A1 stale-mask trade-off).
GATE_EPS = 1e-6
step_flagged: list[float] = []
def _route2_grad_filter(info) -> torch.Tensor:
g = info["delta_S"].grad # [r] summed over rollouts
cg = info["layer"]._antipasto_gate.grad.reshape(-1, g.shape[0]) # [b, r]
dS = info["delta_S"].detach() # [r]
reliable = dS.abs() > GATE_EPS # [r]
dS_safe = torch.where(reliable, dS, torch.ones_like(dS))
g_b = torch.where(reliable, cg / dS_safe, torch.zeros_like(cg)) # [b, r]
vg = v_grad[name] # [r] unit, hack-ward
cos_b = (g_b @ vg) / g_b.norm(dim=1).clamp_min(1e-12) # [b]
flagged = (cos_b > 0).float() # [b]
step_flagged.append(flagged.mean().item())
sub = torch.where(reliable, (cg * flagged.unsqueeze(1)).sum(0) / dS_safe,
torch.zeros_like(g)) # flagged rollouts' contribution
return g - sub
# Split backward into student/teacher only every cos_pre_split_every steps.
# On split steps: 2 backwards per prompt, populates step_grad_s/_t.
# On skipped steps: 1 combined backward, step_grad_s/_t stay empty and
# cos_pre_s/cos_pre_t go to NaN (mean_cos_pre_from_grads returns NaN on empty dict).
split_this_step = (step % cfg.cos_pre_split_every == 0)
# route2 has no v_hack so cos_pre is NaN regardless: force the single combined
# backward (the split would just double cost). The grad-mask reads its
# per-rollout gate from that one backward.
split_this_step = (step % cfg.cos_pre_split_every == 0) and not is_route2
# Phase timers (per-step cumulative, seconds). Each GPU phase ends in a
# CPU-blocking op (decode / .item()), so perf_counter is sync-accurate
# without explicit cuda.synchronize. Tells us whether wall-time is
@@ -1396,6 +1448,10 @@ def main(cfg: Config) -> int:
g = info["delta_S"].grad
if g is None:
continue
# grad-mask routes here: strip flagged rollouts from delta_S.grad
# (quarantine still learns them via its always-on forward path).
if is_route2_grad:
g = _route2_grad_filter(info)
step_grad_s[name] = (step_grad_s[name] + g.detach().clone()
if name in step_grad_s
else g.detach().clone())
@@ -1435,6 +1491,11 @@ def main(cfg: Config) -> int:
diag = {"mean_cos_pre": float("nan"), "mean_cos_post": float("nan"),
"frac_fired": float("nan"), "mean_cos_pre_s": float("nan"),
"mean_cos_pre_t": float("nan")}
# route2 grad-mask: report the mean per-module per-rollout flag rate so
# we can watch the mask actually fire (and rise as hacks emerge).
if is_route2_grad and step_flagged:
logger.debug(f"route2-grad flagged frac (mean over modules*prompts): "
f"{sum(step_flagged)/len(step_flagged):+.3f}")
else:
if split_this_step:
cos_pre_s = mean_cos_pre_from_grads(step_grad_s, v_hack)