mirror of
https://github.com/wassname/lora-lite.git
synced 2026-06-27 14:00:19 +08:00
Add LoRA-XS variant: train only r×r core R between frozen SVD factors
Bałazy et al. 2024 (arxiv 2405.17604). A=diag(Sr)Vhr, B=Ur frozen from top-r SVD of W (W left intact); only the r×r R is trained, init normal(0,1e-5) so the adapter ~ identity at t=0. ~25k params at r=32 (24 down_proj targets). justfile: alpha=r (scale=1) and lr=4e-3, matching the ref LLaMA math config. Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
@@ -103,6 +103,10 @@ bench-variant model variant steps="5000" r_override="" lr_override="" rotate_bas
|
||||
# so a small r leaves almost nothing trainable; r=256 is the variant default
|
||||
# and matches the published AntiPaSTO row. alpha=r (no extra scaling).
|
||||
antipasto) lr=5e-3; r=256; alpha=256 ;;
|
||||
# LoRA-XS trains only the r*r core R between frozen SVD factors. Ref LLaMA
|
||||
# math config sets lora_alpha=r (scale=1) and lr=4e-3 (run_math_tuning.sh);
|
||||
# keep r=32 to share the subspace dim with LoRA/PiSSA (all-else-equal rank axis).
|
||||
lora_xs) lr=4e-3; alpha=32 ;;
|
||||
esac
|
||||
# r override (e.g. low-rank sweep); alpha tracks r, as antipasto and lora_xs expect alpha=r.
|
||||
if [ -n "{{r_override}}" ]; then r="{{r_override}}"; alpha="{{r_override}}"; fi
|
||||
|
||||
@@ -27,6 +27,7 @@ DEFAULT_TARGETS = (r"(q_proj|v_proj)$",)
|
||||
|
||||
CFG_BY_VARIANT = {
|
||||
"lora": ll.LoRAConfig,
|
||||
"lora_xs": ll.LoRAXSConfig,
|
||||
"pissa": ll.PiSSAConfig,
|
||||
"delora": ll.DeLoRAConfig,
|
||||
"ia3": ll.IA3Config,
|
||||
@@ -44,7 +45,7 @@ class BenchmarkConfig:
|
||||
"""MetaMathQA -> GSM8K benchmark config. Tyro turns this into the CLI."""
|
||||
|
||||
model: str = "Qwen/Qwen3.5-0.8B-Base"
|
||||
variant: Literal["lora", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva", "antipasto", "road"] = "lora"
|
||||
variant: Literal["lora", "lora_xs", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva", "antipasto", "road"] = "lora"
|
||||
mode: Literal["benchmark", "probe"] = "benchmark"
|
||||
device: str = "cuda"
|
||||
torch_dtype: str = "bfloat16"
|
||||
@@ -158,7 +159,7 @@ def count_base_grad_leaks(model: torch.nn.Module) -> int:
|
||||
|
||||
|
||||
def perturb_first_adapter(model: torch.nn.Module) -> None:
|
||||
priority = ("lora_B", "lora_g", "lora_c", "lora_alpha", "lora_U", "lora_A", "lora_lambda", "lora_gate", "lora_delta_s", "lora_m", "lora_road_theta", "lora_road_alpha")
|
||||
priority = ("lora_B", "lora_R", "lora_g", "lora_c", "lora_alpha", "lora_U", "lora_A", "lora_lambda", "lora_gate", "lora_delta_s", "lora_m", "lora_road_theta", "lora_road_alpha")
|
||||
for key in priority:
|
||||
for _, p in model.named_parameters():
|
||||
if p.requires_grad and key in _:
|
||||
|
||||
@@ -13,6 +13,7 @@ from . import variants # noqa: F401 triggers variant + config registration
|
||||
|
||||
# Expose per-variant config classes for ergonomic typed construction.
|
||||
from .variants.lora import LoRAConfig
|
||||
from .variants.lora_xs import LoRAXSConfig
|
||||
from .variants.pissa import PiSSAConfig
|
||||
from .variants.delora import DeLoRAConfig
|
||||
from .variants.ia3 import IA3Config, IA3FFConfig
|
||||
@@ -25,6 +26,7 @@ from .variants.road import RoadConfig
|
||||
__all__ = [
|
||||
"AdapterConfig",
|
||||
"LoRAConfig",
|
||||
"LoRAXSConfig",
|
||||
"PiSSAConfig",
|
||||
"DeLoRAConfig",
|
||||
"IA3Config",
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
from . import ( # noqa: F401 side-effect: register
|
||||
lora, pissa, delora, ia3, dora, hra, eva, antipasto, road,
|
||||
lora, lora_xs, pissa, delora, ia3, dora, hra, eva, antipasto, road,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,86 @@
|
||||
"""LoRA-XS: freeze W's top-r SVD as A,B; train only a small r x r matrix R between them.
|
||||
|
||||
Bałazy et al. 2024 https://arxiv.org/abs/2405.17604
|
||||
|
||||
W = U S Vh (truncated to top-r)
|
||||
A = diag(Sr) Vhr (r, d_in) frozen -- singular values folded into A (ref)
|
||||
B = Ur (d_out, r) frozen
|
||||
R (r, r) trainable, ~0 at init
|
||||
h = W x + (alpha/r) B R A x
|
||||
|
||||
Unlike PiSSA, W is NOT cropped: B@A equals top-r(W), but the adapter term B R A is *added*
|
||||
on top of the full W, and R (init normal(0, 1e-5)) keeps that term ~0 at t=0, so the layer
|
||||
starts ~ equal to the base layer. The only trainable tensor is r*r (e.g. r=32 -> 1024
|
||||
params/layer), hence "extremely small".
|
||||
|
||||
The reference folds all singular values into A and leaves B as the raw left singular
|
||||
vectors; R sits between two frozen factors: B has (near-)orthonormal columns, while A's
|
||||
rows are orthogonal but scaled by the singular values. Their LLaMA math-tuning
|
||||
config sets lora_alpha = r (scale = 1.0) and lr ~ 4e-3 (scripts/run_math_tuning.sh).
|
||||
|
||||
Refs:
|
||||
- paper repo: https://github.com/MohammadrezaBanaei/LoRA-XS
|
||||
(utils/initialization_utils.py: init_module_weights(R, sigma=1e-5), A/B requires_grad=False;
|
||||
utils/latent_utils.py forward_latent: result += scaling * lora_B(R(lora_A(x))))
|
||||
"""
|
||||
import torch
|
||||
from jaxtyping import Float
|
||||
from torch import nn, Tensor as T
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ..variant import register, ParamSpec
|
||||
from ..config import AdapterConfig, register_config
|
||||
|
||||
|
||||
@register_config
@dataclass
class LoRAXSConfig(AdapterConfig):
    """Adapter config that selects the ``lora_xs`` variant.

    Only the ``variant`` tag is overridden here; every other field comes from
    ``AdapterConfig``.
    """

    variant: str = "lora_xs"
|
||||
|
||||
|
||||
@register
class LoRAXS:
    """LoRA-XS variant: frozen top-r SVD factors of W around a trainable r x r core.

    The layer output becomes ``y + (alpha / r) * B R A x`` where ``A`` and ``B``
    are frozen buffers filled from the SVD of ``layer.weight`` and ``R`` is the
    only trainable tensor.
    """

    name = "lora_xs"

    @staticmethod
    def param_specs(d_in, d_out, cfg):
        """Declare the tensors attached to a wrapped layer.

        Returns a dict with two frozen buffers (``lora_A`` of shape (r, d_in),
        ``lora_B`` of shape (d_out, r)) and the trainable ``lora_R`` of shape
        (r, r).
        """
        return dict(
            # Frozen top-r SVD factors of W (filled in init()); W itself stays intact.
            lora_A=ParamSpec((cfg.r, d_in), init="zeros", trainable=False, as_buffer=True),
            lora_B=ParamSpec((d_out, cfg.r), init="zeros", trainable=False, as_buffer=True),
            # The only trainable tensor: r x r core, near-zero so the adapter term
            # B R A x is ~0 at t=0 and the layer starts ~ equal to the base layer
            # (ref uses normal(0, 1e-5); matches the repo's near_zero philosophy).
            lora_R=ParamSpec((cfg.r, cfg.r), init=lambda t: t.normal_(0, 1e-5)),
        )

    @staticmethod
    def init(layer: nn.Module, cfg) -> None:
        """Fill ``lora_A`` / ``lora_B`` from the top-r SVD of ``layer.weight``.

        Raises:
            TypeError: if ``layer`` is not exactly ``nn.Linear`` (the exact-type
                check also rejects subclasses, since the dense weight is needed).
            ValueError: if ``cfg.r`` exceeds ``min(d_out, d_in)``; the reduced SVD
                has only that many singular directions.
        """
        if type(layer) is not nn.Linear:
            raise TypeError(
                "LoRA-XS needs the dense SVD of layer.weight, so v1 only supports plain "
                "nn.Linear, not bnb 4/8-bit."
            )
        r = cfg.r
        d_out, d_in = layer.weight.shape
        max_rank = min(d_out, d_in)
        if r > max_rank:
            # Without this guard the [:r] slices below silently return fewer than
            # r components and copy_ fails with an opaque shape-mismatch error.
            raise ValueError(
                f"LoRA-XS rank r={r} exceeds min(d_out, d_in)={max_rank} for a weight of "
                f"shape {(d_out, d_in)}; choose r <= {max_rank}."
            )
        W = layer.weight.data.float()  # (d_out, d_in); SVD is done in float32
        U, S, Vh = torch.linalg.svd(W, full_matrices=False)
        Ur, Sr, Vhr = U[:, :r], S[:r], Vh[:r, :]
        # A = diag(Sr) Vhr, B = Ur  ->  B@A = Ur diag(Sr) Vhr = top-r(W). W is left intact.
        A = (Sr[:, None] * Vhr).to(cfg.dtype)
        B = Ur.to(cfg.dtype)
        layer.lora_A.copy_(A)
        layer.lora_B.copy_(B)

    @staticmethod
    def forward(
        layer: nn.Module,
        x: Float[T, '*B i'],
        y: Float[T, '*B o'],
    ) -> Float[T, '*B o']:
        """Return ``y`` plus the scaled adapter term ``(alpha / r) * B R A x``.

        ``x`` is the layer input and ``y`` the base layer output; the result is
        cast back to ``y.dtype``.
        """
        cfg = layer._lora_cfg
        scale = cfg.alpha / cfg.r
        A = layer.lora_A  # (r, d_in), frozen
        B = layer.lora_B  # (d_out, r), frozen
        R = layer.lora_R  # (r, r), trainable
        xA = x.to(A.dtype)
        h = xA @ A.T       # (*B, r)
        h = h @ R.T        # (*B, r)   <- the learned core
        delta = h @ B.T    # (*B, d_out)
        return y + (scale * delta).to(y.dtype)
|
||||
@@ -32,12 +32,12 @@ sys.modules[SPEC.name] = benchmark
|
||||
SPEC.loader.exec_module(benchmark)
|
||||
|
||||
|
||||
VARIANTS = ["lora", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva",
|
||||
VARIANTS = ["lora", "lora_xs", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva",
|
||||
"antipasto", "road"]
|
||||
# Variants that fail loud when attached on a bnb-loaded base (read dense weight in init).
|
||||
# delora/eva also read weight but currently silently dequant -- they produce sane attach,
|
||||
# so we don't expect a raise from them in the attach-only smoke.
|
||||
BNB_RAISERS = {"pissa", "dora", "antipasto"}
|
||||
BNB_RAISERS = {"pissa", "dora", "antipasto", "lora_xs"}
|
||||
TINY_MODEL = "hf-internal-testing/tiny-random-LlamaForCausalLM"
|
||||
|
||||
HAS_CUDA = torch.cuda.is_available()
|
||||
|
||||
Reference in New Issue
Block a user