cleanup(lora2r): resolve user TODOs -- F.linear alias + jaxtyping hook shapes

torch.nn.functional.linear -> F.linear (import F); annotate A/B/A0/B0 with
Float[Tensor, ...] dims. Behaviorally identical -- verify_lora2r_routing green
(identity 0.00e+00, all three masks + mixed-batch + ablation).

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
wassname
2026-06-10 11:51:49 +00:00
parent 35286040ed
commit 944ada360b
+13 -11
View File
@@ -18,6 +18,8 @@ this one tensor implement the SGTM parameter partition (Cloud et al.).
from __future__ import annotations
import torch
import torch.nn.functional as F
from jaxtyping import Float
from loguru import logger
from torch import Tensor, nn
@@ -53,23 +55,23 @@ def _lora2r_hook(layer: nn.Linear, args: tuple, y: Tensor) -> Tensor:
gate (train.py) both read this same space, so the band is self-consistent
whatever the basis.
"""
(x,) = args
A = layer._lora2r_A # [2r, d_in] trainable
B = layer._lora2r_B # [d_out, 2r] trainable
A0 = layer._lora2r_A0 # frozen init copies (subtracted: net delta 0 at init)
B0 = layer._lora2r_B0
(x,) = args # x: [..., d_in]
A: Float[Tensor, "two_r d_in"] = layer._lora2r_A # trainable
B: Float[Tensor, "d_out two_r"] = layer._lora2r_B # trainable
A0: Float[Tensor, "two_r d_in"] = layer._lora2r_A0 # frozen init (subtracted: net delta 0 at init)
B0: Float[Tensor, "d_out two_r"] = layer._lora2r_B0
r = layer._lora2r_r
h = torch.nn.functional.linear(x, A.to(x.dtype)) # [..., 2r]
h = F.linear(x, A.to(x.dtype)) # [..., 2r]
if layer._lora2r_grad_probe and torch.is_grad_enabled():
c = torch.ones(h.shape[0], *([1] * (h.dim() - 2)), h.shape[-1],
device=h.device, dtype=h.dtype, requires_grad=True)
layer._lora2r_gate = c
h = h * c
h0 = torch.nn.functional.linear(x, A0.to(x.dtype)) # [..., 2r] frozen init path
dep = (torch.nn.functional.linear(h[..., :r], B[:, :r].to(x.dtype))
- torch.nn.functional.linear(h0[..., :r], B0[:, :r].to(x.dtype)))
quar = (torch.nn.functional.linear(h[..., r:], B[:, r:].to(x.dtype))
- torch.nn.functional.linear(h0[..., r:], B0[:, r:].to(x.dtype)))
h0 = F.linear(x, A0.to(x.dtype)) # [..., 2r] frozen init path
dep = (F.linear(h[..., :r], B[:, :r].to(x.dtype))
- F.linear(h0[..., :r], B0[:, :r].to(x.dtype)))
quar = (F.linear(h[..., r:], B[:, r:].to(x.dtype))
- F.linear(h0[..., r:], B0[:, r:].to(x.dtype)))
if layer._lora2r_mask is not None:
m, d = layer._lora2r_mask # [G] each
G = m.shape[0]