diff --git a/src/vgrout/lora2r.py b/src/vgrout/lora2r.py index 290a8e1..8e7eacb 100644 --- a/src/vgrout/lora2r.py +++ b/src/vgrout/lora2r.py @@ -18,6 +18,8 @@ this one tensor implement the SGTM parameter partition (Cloud et al.). from __future__ import annotations import torch +import torch.nn.functional as F +from jaxtyping import Float from loguru import logger from torch import Tensor, nn @@ -53,23 +55,23 @@ def _lora2r_hook(layer: nn.Linear, args: tuple, y: Tensor) -> Tensor: gate (train.py) both read this same space, so the band is self-consistent whatever the basis. """ - (x,) = args - A = layer._lora2r_A # [2r, d_in] trainable - B = layer._lora2r_B # [d_out, 2r] trainable - A0 = layer._lora2r_A0 # frozen init copies (subtracted: net delta 0 at init) - B0 = layer._lora2r_B0 + (x,) = args # x: [..., d_in] + A: Float[Tensor, "two_r d_in"] = layer._lora2r_A # trainable + B: Float[Tensor, "d_out two_r"] = layer._lora2r_B # trainable + A0: Float[Tensor, "two_r d_in"] = layer._lora2r_A0 # frozen init (subtracted: net delta 0 at init) + B0: Float[Tensor, "d_out two_r"] = layer._lora2r_B0 r = layer._lora2r_r - h = torch.nn.functional.linear(x, A.to(x.dtype)) # [..., 2r] + h = F.linear(x, A.to(x.dtype)) # [..., 2r] if layer._lora2r_grad_probe and torch.is_grad_enabled(): c = torch.ones(h.shape[0], *([1] * (h.dim() - 2)), h.shape[-1], device=h.device, dtype=h.dtype, requires_grad=True) layer._lora2r_gate = c h = h * c - h0 = torch.nn.functional.linear(x, A0.to(x.dtype)) # [..., 2r] frozen init path - dep = (torch.nn.functional.linear(h[..., :r], B[:, :r].to(x.dtype)) - - torch.nn.functional.linear(h0[..., :r], B0[:, :r].to(x.dtype))) - quar = (torch.nn.functional.linear(h[..., r:], B[:, r:].to(x.dtype)) - - torch.nn.functional.linear(h0[..., r:], B0[:, r:].to(x.dtype))) + h0 = F.linear(x, A0.to(x.dtype)) # [..., 2r] frozen init path + dep = (F.linear(h[..., :r], B[:, :r].to(x.dtype)) + - F.linear(h0[..., :r], B0[:, :r].to(x.dtype))) + quar = (F.linear(h[..., r:], B[:, r:].to(x.dtype)) + - F.linear(h0[..., r:], B0[:, r:].to(x.dtype))) if layer._lora2r_mask is not None: m, d = layer._lora2r_mask # [G] each G = m.shape[0]