diff --git a/src/lora_lite/variants/antipasto.py b/src/lora_lite/variants/antipasto.py index 654a23d..e537721 100644 --- a/src/lora_lite/variants/antipasto.py +++ b/src/lora_lite/variants/antipasto.py @@ -224,7 +224,7 @@ class AntiPaSTO: raise ValueError(f"rotate_basis must be 'U', 'V', or 'none', got {rotate_basis!r}") S_eff = S + layer.lora_delta_s.to(x.dtype) # (r,) - h = einsum(x, Vh_eff, "... i, r i -> ... r") # x @ Vh_eff.T + h = x @ Vh_eff.T # x @ Vh_eff.T h = h * S_eff # diag(S_eff) - delta = einsum(h, U_eff, "... r, o r -> ... o") # @ U_eff.T + delta = h @ U_eff.T # @ U_eff.T return y + delta diff --git a/src/lora_lite/variants/delora.py b/src/lora_lite/variants/delora.py index 89cc942..0313392 100644 --- a/src/lora_lite/variants/delora.py +++ b/src/lora_lite/variants/delora.py @@ -18,7 +18,6 @@ Refs: (offline: docs/refs/peft_delora_layer.py) """ import torch -from einops import einsum from jaxtyping import Float from torch import nn, Tensor as T from dataclasses import dataclass @@ -70,11 +69,11 @@ class DeLoRA: cfg = layer._lora_cfg A = layer.lora_A # (r, d_in) B = layer.lora_B # (d_out, r) - x_scaled = x * layer.lora_wnorm # (..., d_in) - h = einsum(x_scaled, A, "... i, r i -> ... r") + x_scaled = x.to(A.dtype) * layer.lora_wnorm.to(A.dtype) # (..., d_in) + h = x_scaled @ A.T An = torch.clamp(A.norm(dim=1), min=1e-4) # (r,) Bn = torch.clamp(B.norm(dim=0), min=1e-4) # (r,) scale = (layer.lora_lambda / cfg.r) / (An * Bn) # (r,) h = h * scale - delta = einsum(h, B, "... r, o r -> ... o") - return y + delta + delta = h @ B.T + return y + delta.to(y.dtype) diff --git a/src/lora_lite/variants/dora.py b/src/lora_lite/variants/dora.py index 7ee496e..46ea112 100644 --- a/src/lora_lite/variants/dora.py +++ b/src/lora_lite/variants/dora.py @@ -9,7 +9,6 @@ Refs: (offline: docs/refs/peft_lora_dora.py) """ import torch -from einops import einsum from jaxtyping import Float from torch import nn, Tensor as T from dataclasses import dataclass @@ -59,14 +58,16 @@ class DoRA: scale = cfg.alpha / cfg.r # Paper §4.3: treat ||V+ΔV||_c as a constant (detach from grad graph) for # stability and ~2x lower memory. Match peft (lora_weight.detach + weight_norm.detach). - BA = einsum(layer.lora_B, layer.lora_A, "o r, r i -> o i") + BA = layer.lora_B @ layer.lora_A V = layer.weight + scale * BA.detach() # (d_out, d_in) v_norm = V.norm(dim=1).clamp_min(1e-12).detach() # (d_out,) # Bias passes through unscaled (matches peft). bias = getattr(layer, "bias", None) wx = y if bias is None else (y - bias) - h = einsum(x, layer.lora_A, "... i, r i -> ... r") - delta = einsum(h, layer.lora_B, "... r, o r -> ... o") - combined = wx + scale * delta + xA = x.to(layer.lora_A.dtype) + h = xA @ layer.lora_A.T + delta = h @ layer.lora_B.T + combined = wx + (scale * delta).to(wx.dtype) out = (layer.lora_m / v_norm) * combined - return out if bias is None else out + bias + out = out if bias is None else out + bias + return out.to(y.dtype) diff --git a/src/lora_lite/variants/eva.py b/src/lora_lite/variants/eva.py index 0b8bdbf..83fc2d6 100644 --- a/src/lora_lite/variants/eva.py +++ b/src/lora_lite/variants/eva.py @@ -13,7 +13,7 @@ Refs: (offline: docs/refs/peft_eva.py; example: docs/refs/peft_eva_finetuning.py) """ import torch -from einops import einsum, rearrange +from einops import rearrange from jaxtyping import Float from torch import nn, Tensor as T from typing import Iterable @@ -113,6 +113,7 @@ class EVA: ) -> Float[T, '*B o']: cfg = layer._lora_cfg scale = cfg.alpha / cfg.r - h = einsum(x, layer.lora_A, "... i, r i -> ... r") - delta = einsum(h, layer.lora_B, "... r, o r -> ... o") - return y + scale * delta + xA = x.to(layer.lora_A.dtype) + h = xA @ layer.lora_A.T + delta = h @ layer.lora_B.T + return y + (scale * delta).to(y.dtype) diff --git a/src/lora_lite/variants/lora.py b/src/lora_lite/variants/lora.py index 209b48c..9ace977 100644 --- a/src/lora_lite/variants/lora.py +++ b/src/lora_lite/variants/lora.py @@ -8,10 +8,8 @@ Refs: - peft: https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/layer.py (offline: docs/refs/peft_lora_layer.py) """ -from einops import einsum from jaxtyping import Float from torch import nn, Tensor as T -import torch from dataclasses import dataclass from ..variant import register, ParamSpec @@ -47,6 +45,7 @@ class LoRA: ) -> Float[T, '*B o']: cfg = layer._lora_cfg scale = cfg.alpha / cfg.r - h = einsum(x, layer.lora_A, "... i, r i -> ... r") - delta = einsum(h, layer.lora_B, "... r, o r -> ... o") - return y + scale * delta + xA = x.to(layer.lora_A.dtype) + h = xA @ layer.lora_A.T + delta = h @ layer.lora_B.T + return y + (scale * delta).to(y.dtype) diff --git a/src/lora_lite/variants/pissa.py b/src/lora_lite/variants/pissa.py index 7987ff2..f34c2e3 100644 --- a/src/lora_lite/variants/pissa.py +++ b/src/lora_lite/variants/pissa.py @@ -16,7 +16,6 @@ Refs: (offline: docs/refs/peft_lora_layer.py, see pissa_init path) """ import torch -from einops import einsum from jaxtyping import Float from torch import nn, Tensor as T from dataclasses import dataclass @@ -76,6 +75,7 @@ class PiSSA: ) -> Float[T, '*B o']: cfg = layer._lora_cfg scale = cfg.alpha / cfg.r - h = einsum(x, layer.lora_A, "... i, r i -> ... r") - delta = einsum(h, layer.lora_B, "... r, o r -> ... o") - return y + scale * delta + xA = x.to(layer.lora_A.dtype) + h = xA @ layer.lora_A.T + delta = h @ layer.lora_B.T + return y + (scale * delta).to(y.dtype)