mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 17:30:41 +08:00
fix: route2 quar/v_act dtype mismatch on bf16 model (A_q/B_q/v_act fp32 vs bf16 x)
Smoke is fp32 (CPU tiny-random) so the bf16 path never fired -- job 34/35 crashed on the real Qwen3-4B with 'BFloat16 != float' in the quar matmul. Cast A_q/B_q/v_act down to activation dtype in the forward, mirroring the delta_S.to(a.dtype) pattern (fp32 master, bf16 compute, grads cast back). Validated forward+backward in bf16 for both masks. + run-substrate MASK param. Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
@@ -165,12 +165,12 @@ build-substrate MODES="run_tests,exit_code,sentinel":
|
||||
# (per-mode hacks>0 + finite first_step) + the per-step hk_<mode> columns. mix=0.125
|
||||
# is the locked default (omit to inherit it). Vanilla needs no v_hack; for an
|
||||
# erase/route substrate run, add --v-hack-path explicitly.
|
||||
run-substrate INTERV="none" SEED="41" STEPS="60" REFRESH="5":
|
||||
run-substrate INTERV="none" SEED="41" STEPS="60" REFRESH="5" MASK="act":
|
||||
{{ TRAIN }} fast --intervention={{ INTERV }} \
|
||||
--teacher-pool-dir=out/pools/substrate \
|
||||
--vhack-pairs-path=out/pairsets/prog_wide.json \
|
||||
--vhack-refresh-every={{ REFRESH }} \
|
||||
--seed={{ SEED }} --steps={{ STEPS }} --out-tag=_sub4_{{ INTERV }}_rf{{ REFRESH }}_s{{ SEED }}
|
||||
--vhack-refresh-every={{ REFRESH }} --route2-mask={{ MASK }} \
|
||||
--seed={{ SEED }} --steps={{ STEPS }} --out-tag=_sub4_{{ INTERV }}_{{ MASK }}_rf{{ REFRESH }}_s{{ SEED }}
|
||||
|
||||
# CANONICAL plotting entrypoint for the substrate sweep. One command, four figures
|
||||
# (per-mode by-method + by-hack, and the aggregate "total hacks per arm" + overlay,
|
||||
|
||||
@@ -107,7 +107,10 @@ def _delta_hook(layer: nn.Linear, args: tuple, y: Tensor) -> Tensor:
|
||||
|
||||
# --- route2: distinct-basis quarantine, always summed ---
|
||||
B_q = layer._antipasto_B_q # [d_out, k_q]
|
||||
quar = torch.nn.functional.linear(torch.nn.functional.linear(x, A_q), B_q) # [..., d_out]
|
||||
# A_q/B_q kept fp32 (master, like delta_S); cast down to x.dtype for the matmul
|
||||
# (bf16 on the real model). autograd casts grads back to the fp32 params.
|
||||
quar = torch.nn.functional.linear(
|
||||
torch.nn.functional.linear(x, A_q.to(x.dtype)), B_q.to(x.dtype)) # [..., d_out]
|
||||
|
||||
if layer._antipasto_mask_mode == "grad":
|
||||
if torch.is_grad_enabled():
|
||||
@@ -127,7 +130,7 @@ def _delta_hook(layer: nn.Linear, args: tuple, y: Tensor) -> Tensor:
|
||||
# mask_mode="act": forward detach-route by activation cosine. cos>0 == points
|
||||
# hack-ward == route to quarantine (the natural boundary, no threshold knob).
|
||||
kept = torch.nn.functional.linear(a * delta_S.to(a.dtype), U) # [..., d_out]
|
||||
v_act = layer._antipasto_v_act # [r] unit, hack-ward, in Vh coords
|
||||
v_act = layer._antipasto_v_act.to(a.dtype) # [r] unit, hack-ward, in Vh coords (fp32 buffer -> a.dtype)
|
||||
cos = (a @ v_act) / (a.norm(dim=-1).clamp_min(1e-6) * v_act.norm().clamp_min(1e-6))
|
||||
m = cos > 0 # [...] bool
|
||||
kept = torch.where(m.unsqueeze(-1), kept.detach(), kept)
|
||||
|
||||
Reference in New Issue
Block a user