mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 17:00:59 +08:00
cleanup(lora2r): resolve user TODOs -- F.linear alias + jaxtyping hook shapes
torch.nn.functional.linear -> F.linear (import F); annotate A/B/A0/B0 with Float[Tensor, ...] dims. Behaviorally identical -- verify_lora2r_routing green (identity 0.00e+00, all three masks + mixed-batch + ablation). Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
+13
-11
@@ -18,6 +18,8 @@ this one tensor implement the SGTM parameter partition (Cloud et al.).
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from jaxtyping import Float
|
||||
from loguru import logger
|
||||
from torch import Tensor, nn
|
||||
|
||||
@@ -53,23 +55,23 @@ def _lora2r_hook(layer: nn.Linear, args: tuple, y: Tensor) -> Tensor:
|
||||
gate (train.py) both read this same space, so the band is self-consistent
|
||||
whatever the basis.
|
||||
"""
|
||||
(x,) = args
|
||||
A = layer._lora2r_A # [2r, d_in] trainable
|
||||
B = layer._lora2r_B # [d_out, 2r] trainable
|
||||
A0 = layer._lora2r_A0 # frozen init copies (subtracted: net delta 0 at init)
|
||||
B0 = layer._lora2r_B0
|
||||
(x,) = args # x: [..., d_in]
|
||||
A: Float[Tensor, "two_r d_in"] = layer._lora2r_A # trainable
|
||||
B: Float[Tensor, "d_out two_r"] = layer._lora2r_B # trainable
|
||||
A0: Float[Tensor, "two_r d_in"] = layer._lora2r_A0 # frozen init (subtracted: net delta 0 at init)
|
||||
B0: Float[Tensor, "d_out two_r"] = layer._lora2r_B0
|
||||
r = layer._lora2r_r
|
||||
h = torch.nn.functional.linear(x, A.to(x.dtype)) # [..., 2r]
|
||||
h = F.linear(x, A.to(x.dtype)) # [..., 2r]
|
||||
if layer._lora2r_grad_probe and torch.is_grad_enabled():
|
||||
c = torch.ones(h.shape[0], *([1] * (h.dim() - 2)), h.shape[-1],
|
||||
device=h.device, dtype=h.dtype, requires_grad=True)
|
||||
layer._lora2r_gate = c
|
||||
h = h * c
|
||||
h0 = torch.nn.functional.linear(x, A0.to(x.dtype)) # [..., 2r] frozen init path
|
||||
dep = (torch.nn.functional.linear(h[..., :r], B[:, :r].to(x.dtype))
|
||||
- torch.nn.functional.linear(h0[..., :r], B0[:, :r].to(x.dtype)))
|
||||
quar = (torch.nn.functional.linear(h[..., r:], B[:, r:].to(x.dtype))
|
||||
- torch.nn.functional.linear(h0[..., r:], B0[:, r:].to(x.dtype)))
|
||||
h0 = F.linear(x, A0.to(x.dtype)) # [..., 2r] frozen init path
|
||||
dep = (F.linear(h[..., :r], B[:, :r].to(x.dtype))
|
||||
- F.linear(h0[..., :r], B0[:, :r].to(x.dtype)))
|
||||
quar = (F.linear(h[..., r:], B[:, r:].to(x.dtype))
|
||||
- F.linear(h0[..., r:], B0[:, r:].to(x.dtype)))
|
||||
if layer._lora2r_mask is not None:
|
||||
m, d = layer._lora2r_mask # [G] each
|
||||
G = m.shape[0]
|
||||
|
||||
Reference in New Issue
Block a user