fix: route2 quar/v_act dtype mismatch on bf16 model (A_q/B_q/v_act fp32 vs bf16 x)

Smoke is fp32 (CPU tiny-random) so the bf16 path never fired -- job 34/35
crashed on the real Qwen3-4B with 'BFloat16 != float' in the quar matmul.
Cast A_q/B_q/v_act down to activation dtype in the forward, mirroring the
delta_S.to(a.dtype) pattern (fp32 master, bf16 compute, grads cast back).
Validated forward+backward in bf16 for both masks. + run-substrate MASK param.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
wassname
2026-05-31 13:35:25 +00:00
parent 25569193c5
commit 80f6b52860
2 changed files with 8 additions and 5 deletions
+3 -3
View File
@@ -165,12 +165,12 @@ build-substrate MODES="run_tests,exit_code,sentinel":
# (per-mode hacks>0 + finite first_step) + the per-step hk_<mode> columns. mix=0.125
# is the locked default (omit to inherit it). Vanilla needs no v_hack; for an
# erase/route substrate run, add --v-hack-path explicitly.
run-substrate INTERV="none" SEED="41" STEPS="60" REFRESH="5":
run-substrate INTERV="none" SEED="41" STEPS="60" REFRESH="5" MASK="act":
{{ TRAIN }} fast --intervention={{ INTERV }} \
--teacher-pool-dir=out/pools/substrate \
--vhack-pairs-path=out/pairsets/prog_wide.json \
--vhack-refresh-every={{ REFRESH }} \
--seed={{ SEED }} --steps={{ STEPS }} --out-tag=_sub4_{{ INTERV }}_rf{{ REFRESH }}_s{{ SEED }}
--vhack-refresh-every={{ REFRESH }} --route2-mask={{ MASK }} \
--seed={{ SEED }} --steps={{ STEPS }} --out-tag=_sub4_{{ INTERV }}_{{ MASK }}_rf{{ REFRESH }}_s{{ SEED }}
# CANONICAL plotting entrypoint for the substrate sweep. One command, four figures
# (per-mode by-method + by-hack, and the aggregate "total hacks per arm" + overlay,
+5 -2
View File
@@ -107,7 +107,10 @@ def _delta_hook(layer: nn.Linear, args: tuple, y: Tensor) -> Tensor:
# --- route2: distinct-basis quarantine, always summed ---
B_q = layer._antipasto_B_q # [d_out, k_q]
quar = torch.nn.functional.linear(torch.nn.functional.linear(x, A_q), B_q) # [..., d_out]
# A_q/B_q kept fp32 (master, like delta_S); cast down to x.dtype for the matmul
# (bf16 on the real model). autograd casts grads back to the fp32 params.
quar = torch.nn.functional.linear(
torch.nn.functional.linear(x, A_q.to(x.dtype)), B_q.to(x.dtype)) # [..., d_out]
if layer._antipasto_mask_mode == "grad":
if torch.is_grad_enabled():
@@ -127,7 +130,7 @@ def _delta_hook(layer: nn.Linear, args: tuple, y: Tensor) -> Tensor:
# mask_mode="act": forward detach-route by activation cosine. cos>0 == points
# hack-ward == route to quarantine (the natural boundary, no threshold knob).
kept = torch.nn.functional.linear(a * delta_S.to(a.dtype), U) # [..., d_out]
v_act = layer._antipasto_v_act # [r] unit, hack-ward, in Vh coords
v_act = layer._antipasto_v_act.to(a.dtype) # [r] unit, hack-ward, in Vh coords (fp32 buffer -> a.dtype)
cos = (a @ v_act) / (a.norm(dim=-1).clamp_min(1e-6) * v_act.norm().clamp_min(1e-6))
m = cos > 0 # [...] bool
kept = torch.where(m.unsqueeze(-1), kept.detach(), kept)