diff --git a/justfile b/justfile index 64ffc80..8fc5ec3 100644 --- a/justfile +++ b/justfile @@ -165,12 +165,12 @@ build-substrate MODES="run_tests,exit_code,sentinel": # (per-mode hacks>0 + finite first_step) + the per-step hk_ columns. mix=0.125 # is the locked default (omit to inherit it). Vanilla needs no v_hack; for an # erase/route substrate run, add --v-hack-path explicitly. -run-substrate INTERV="none" SEED="41" STEPS="60" REFRESH="5": +run-substrate INTERV="none" SEED="41" STEPS="60" REFRESH="5" MASK="act": {{ TRAIN }} fast --intervention={{ INTERV }} \ --teacher-pool-dir=out/pools/substrate \ --vhack-pairs-path=out/pairsets/prog_wide.json \ - --vhack-refresh-every={{ REFRESH }} \ - --seed={{ SEED }} --steps={{ STEPS }} --out-tag=_sub4_{{ INTERV }}_rf{{ REFRESH }}_s{{ SEED }} + --vhack-refresh-every={{ REFRESH }} --route2-mask={{ MASK }} \ + --seed={{ SEED }} --steps={{ STEPS }} --out-tag=_sub4_{{ INTERV }}_{{ MASK }}_rf{{ REFRESH }}_s{{ SEED }} # CANONICAL plotting entrypoint for the substrate sweep. One command, four figures # (per-mode by-method + by-hack, and the aggregate "total hacks per arm" + overlay, diff --git a/src/projected_grpo/antipasto.py b/src/projected_grpo/antipasto.py index aedef9b..af180dd 100644 --- a/src/projected_grpo/antipasto.py +++ b/src/projected_grpo/antipasto.py @@ -107,7 +107,10 @@ def _delta_hook(layer: nn.Linear, args: tuple, y: Tensor) -> Tensor: # --- route2: distinct-basis quarantine, always summed --- B_q = layer._antipasto_B_q # [d_out, k_q] - quar = torch.nn.functional.linear(torch.nn.functional.linear(x, A_q), B_q) # [..., d_out] + # A_q/B_q kept fp32 (master, like delta_S); cast down to x.dtype for the matmul + # (bf16 on the real model). autograd casts grads back to the fp32 params. + quar = torch.nn.functional.linear( + torch.nn.functional.linear(x, A_q.to(x.dtype)), B_q.to(x.dtype)) # [..., d_out] if layer._antipasto_mask_mode == "grad": if torch.is_grad_enabled(): @@ -127,7 +130,7 @@ def _delta_hook(layer: nn.Linear, args: tuple, y: Tensor) -> Tensor: # mask_mode="act": forward detach-route by activation cosine. cos>0 == points # hack-ward == route to quarantine (the natural boundary, no threshold knob). kept = torch.nn.functional.linear(a * delta_S.to(a.dtype), U) # [..., d_out] - v_act = layer._antipasto_v_act # [r] unit, hack-ward, in Vh coords + v_act = layer._antipasto_v_act.to(a.dtype) # [r] unit, hack-ward, in Vh coords (fp32 buffer -> a.dtype) cos = (a @ v_act) / (a.norm(dim=-1).clamp_min(1e-6) * v_act.norm().clamp_min(1e-6)) m = cos > 0 # [...] bool kept = torch.where(m.unsqueeze(-1), kept.detach(), kept)