mirror of
https://github.com/wassname/lora-lite.git
synced 2026-06-27 16:45:56 +08:00
53 lines
1.3 KiB
Python
53 lines
1.3 KiB
Python
"""Vanilla LoRA. Hu et al. 2021 https://arxiv.org/abs/2106.09685

    h = W x + (alpha/r) B A x

Identity at t=0 because B is initialised to zero.

Refs:
- peft: https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/layer.py
  (offline: docs/refs/peft_lora_layer.py)
"""
|
|
from einops import einsum
|
|
from jaxtyping import Float
|
|
from torch import nn, Tensor as T
|
|
import torch
|
|
from dataclasses import dataclass
|
|
|
|
from ..variant import register, ParamSpec
|
|
from ..config import AdapterConfig, register_config
|
|
|
|
|
|
@register_config
@dataclass
class LoRAConfig(AdapterConfig):
    """Adapter configuration whose ``variant`` field defaults to ``"lora"``.

    All other fields come from ``AdapterConfig``; the LoRA forward pass reads
    ``r`` and ``alpha`` from its config, presumably inherited from that base
    class (not visible here -- confirm against ``..config``).
    """

    # Same string as ``LoRA.name`` in this module.
    variant: str = "lora"
|
|
|
|
|
|
@register
class LoRA:
    """Vanilla LoRA variant: adds ``(alpha / r) * B A x`` to a layer's output.

    Every hook is a static method that receives the target ``layer``. The
    layer is expected to carry ``lora_A``, ``lora_B`` and ``_lora_cfg``
    attributes by the time ``forward`` runs.
    """

    name = "lora"

    @staticmethod
    def param_specs(d_in, d_out, cfg):
        """Return the parameter specs for the two low-rank factors.

        Args:
            d_in: Input feature size; second dimension of ``lora_A``.
            d_out: Output feature size; first dimension of ``lora_B``.
            cfg: Config object providing the rank ``r``.

        Returns:
            A dict with ``lora_A`` of shape ``(cfg.r, d_in)`` using the
            ``"kaiming"`` init and ``lora_B`` of shape ``(d_out, cfg.r)``
            using the ``"zeros"`` init.
        """
        specs = {
            "lora_A": ParamSpec((cfg.r, d_in), init="kaiming"),
            # Zero-initialised B makes the low-rank update vanish at the start.
            "lora_B": ParamSpec((d_out, cfg.r), init="zeros"),
        }
        return specs

    @staticmethod
    def init(layer: nn.Module, cfg) -> None:
        """No-op hook: this variant does nothing beyond ``param_specs``."""
        return None

    @staticmethod
    def forward(
        layer: nn.Module,
        x: Float[T, '*B i'],
        y: Float[T, '*B o'],
    ) -> Float[T, '*B o']:
        """Add the scaled low-rank update to ``y``.

        Args:
            layer: Module holding ``lora_A``, ``lora_B`` and ``_lora_cfg``.
            x: Input features, last dimension ``i``.
            y: Output to be adjusted, last dimension ``o``; presumably the
                wrapped layer's own result (the ``W x`` term) -- confirm
                against the caller.

        Returns:
            ``y + (alpha / r) * (x A^T) B^T`` with the same shape as ``y``.
        """
        cfg = layer._lora_cfg
        scaling = cfg.alpha / cfg.r
        # Project down to rank r, then back up to the output width.
        down = einsum(x, layer.lora_A, "... i, r i -> ... r")
        up = einsum(down, layer.lora_B, "... r, o r -> ... o")
        update = scaling * up
        return y + update
|