From 944ada360b5121a2fa0f30d339adc5bc47eb039f Mon Sep 17 00:00:00 2001 From: wassname <1103714+wassname@users.noreply.github.com> Date: Wed, 10 Jun 2026 11:51:49 +0000 Subject: [PATCH] cleanup(lora2r): resolve user TODOs -- F.linear alias + jaxtyping hook shapes torch.nn.functional.linear -> F.linear (import F); annotate A/B/A0/B0 with Float[Tensor, ...] dims. Behaviorally identical -- verify_lora2r_routing green (identity 0.00e+00, all three masks + mixed-batch + ablation). Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com> --- src/vgrout/lora2r.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/src/vgrout/lora2r.py b/src/vgrout/lora2r.py index 290a8e1..8e7eacb 100644 --- a/src/vgrout/lora2r.py +++ b/src/vgrout/lora2r.py @@ -18,6 +18,8 @@ this one tensor implement the SGTM parameter partition (Cloud et al.). from __future__ import annotations import torch +import torch.nn.functional as F +from jaxtyping import Float from loguru import logger from torch import Tensor, nn @@ -53,23 +55,23 @@ def _lora2r_hook(layer: nn.Linear, args: tuple, y: Tensor) -> Tensor: gate (train.py) both read this same space, so the band is self-consistent whatever the basis. """ - (x,) = args - A = layer._lora2r_A # [2r, d_in] trainable - B = layer._lora2r_B # [d_out, 2r] trainable - A0 = layer._lora2r_A0 # frozen init copies (subtracted: net delta 0 at init) - B0 = layer._lora2r_B0 + (x,) = args # x: [..., d_in] + A: Float[Tensor, "two_r d_in"] = layer._lora2r_A # trainable + B: Float[Tensor, "d_out two_r"] = layer._lora2r_B # trainable + A0: Float[Tensor, "two_r d_in"] = layer._lora2r_A0 # frozen init (subtracted: net delta 0 at init) + B0: Float[Tensor, "d_out two_r"] = layer._lora2r_B0 r = layer._lora2r_r - h = torch.nn.functional.linear(x, A.to(x.dtype)) # [..., 2r] + h = F.linear(x, A.to(x.dtype)) # [..., 2r] if layer._lora2r_grad_probe and torch.is_grad_enabled(): c = torch.ones(h.shape[0], *([1] * (h.dim() - 2)), h.shape[-1], device=h.device, dtype=h.dtype, requires_grad=True) layer._lora2r_gate = c h = h * c - h0 = torch.nn.functional.linear(x, A0.to(x.dtype)) # [..., 2r] frozen init path - dep = (torch.nn.functional.linear(h[..., :r], B[:, :r].to(x.dtype)) - - torch.nn.functional.linear(h0[..., :r], B0[:, :r].to(x.dtype))) - quar = (torch.nn.functional.linear(h[..., r:], B[:, r:].to(x.dtype)) - - torch.nn.functional.linear(h0[..., r:], B0[:, r:].to(x.dtype))) + h0 = F.linear(x, A0.to(x.dtype)) # [..., 2r] frozen init path + dep = (F.linear(h[..., :r], B[:, :r].to(x.dtype)) + - F.linear(h0[..., :r], B0[:, :r].to(x.dtype))) + quar = (F.linear(h[..., r:], B[:, r:].to(x.dtype)) + - F.linear(h0[..., r:], B0[:, r:].to(x.dtype))) if layer._lora2r_mask is not None: m, d = layer._lora2r_mask # [G] each G = m.shape[0]