Add LoRA-XS variant: train only r×r core R between frozen SVD factors

Bałazy et al. 2024 (arxiv 2405.17604). A=diag(Sr)Vhr, B=Ur frozen from
top-r SVD of W (W left intact); only the r×r R is trained, init normal(0,1e-5)
so the adapter ~ identity at t=0. ~25k params at r=32 (24 down_proj targets).
justfile: alpha=r (scale=1) and lr=4e-3, matching the ref LLaMA math config.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
wassname
2026-06-18 19:48:40 +08:00
parent 12e13cca79
commit c792ad3e5f
6 changed files with 98 additions and 5 deletions
+4
View File
@@ -103,6 +103,10 @@ bench-variant model variant steps="5000" r_override="" lr_override="" rotate_bas
# so a small r leaves almost nothing trainable; r=256 is the variant default # so a small r leaves almost nothing trainable; r=256 is the variant default
# and matches the published AntiPaSTO row. alpha=r (no extra scaling). # and matches the published AntiPaSTO row. alpha=r (no extra scaling).
antipasto) lr=5e-3; r=256; alpha=256 ;; antipasto) lr=5e-3; r=256; alpha=256 ;;
# LoRA-XS trains only the r*r core R between frozen SVD factors. Ref LLaMA
# math config sets lora_alpha=r (scale=1) and lr=4e-3 (run_math_tuning.sh);
# keep r=32 to share the subspace dim with LoRA/PiSSA (all-else-equal rank axis).
lora_xs) lr=4e-3; alpha=32 ;;
esac esac
# r override (e.g. low-rank sweep); alpha tracks r for antipasto. # r override (e.g. low-rank sweep); alpha tracks r for antipasto.
if [ -n "{{r_override}}" ]; then r="{{r_override}}"; alpha="{{r_override}}"; fi if [ -n "{{r_override}}" ]; then r="{{r_override}}"; alpha="{{r_override}}"; fi
+3 -2
View File
@@ -27,6 +27,7 @@ DEFAULT_TARGETS = (r"(q_proj|v_proj)$",)
CFG_BY_VARIANT = { CFG_BY_VARIANT = {
"lora": ll.LoRAConfig, "lora": ll.LoRAConfig,
"lora_xs": ll.LoRAXSConfig,
"pissa": ll.PiSSAConfig, "pissa": ll.PiSSAConfig,
"delora": ll.DeLoRAConfig, "delora": ll.DeLoRAConfig,
"ia3": ll.IA3Config, "ia3": ll.IA3Config,
@@ -44,7 +45,7 @@ class BenchmarkConfig:
"""MetaMathQA -> GSM8K benchmark config. Tyro turns this into the CLI.""" """MetaMathQA -> GSM8K benchmark config. Tyro turns this into the CLI."""
model: str = "Qwen/Qwen3.5-0.8B-Base" model: str = "Qwen/Qwen3.5-0.8B-Base"
variant: Literal["lora", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva", "antipasto", "road"] = "lora" variant: Literal["lora", "lora_xs", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva", "antipasto", "road"] = "lora"
mode: Literal["benchmark", "probe"] = "benchmark" mode: Literal["benchmark", "probe"] = "benchmark"
device: str = "cuda" device: str = "cuda"
torch_dtype: str = "bfloat16" torch_dtype: str = "bfloat16"
@@ -158,7 +159,7 @@ def count_base_grad_leaks(model: torch.nn.Module) -> int:
def perturb_first_adapter(model: torch.nn.Module) -> None: def perturb_first_adapter(model: torch.nn.Module) -> None:
priority = ("lora_B", "lora_g", "lora_c", "lora_alpha", "lora_U", "lora_A", "lora_lambda", "lora_gate", "lora_delta_s", "lora_m", "lora_road_theta", "lora_road_alpha") priority = ("lora_B", "lora_R", "lora_g", "lora_c", "lora_alpha", "lora_U", "lora_A", "lora_lambda", "lora_gate", "lora_delta_s", "lora_m", "lora_road_theta", "lora_road_alpha")
for key in priority: for key in priority:
for _, p in model.named_parameters(): for _, p in model.named_parameters():
if p.requires_grad and key in _: if p.requires_grad and key in _:
+2
View File
@@ -13,6 +13,7 @@ from . import variants # noqa: F401 triggers variant + config registration
# Expose per-variant config classes for ergonomic typed construction. # Expose per-variant config classes for ergonomic typed construction.
from .variants.lora import LoRAConfig from .variants.lora import LoRAConfig
from .variants.lora_xs import LoRAXSConfig
from .variants.pissa import PiSSAConfig from .variants.pissa import PiSSAConfig
from .variants.delora import DeLoRAConfig from .variants.delora import DeLoRAConfig
from .variants.ia3 import IA3Config, IA3FFConfig from .variants.ia3 import IA3Config, IA3FFConfig
@@ -25,6 +26,7 @@ from .variants.road import RoadConfig
__all__ = [ __all__ = [
"AdapterConfig", "AdapterConfig",
"LoRAConfig", "LoRAConfig",
"LoRAXSConfig",
"PiSSAConfig", "PiSSAConfig",
"DeLoRAConfig", "DeLoRAConfig",
"IA3Config", "IA3Config",
+1 -1
View File
@@ -1,3 +1,3 @@
from . import ( # noqa: F401 side-effect: register from . import ( # noqa: F401 side-effect: register
lora, pissa, delora, ia3, dora, hra, eva, antipasto, road, lora, lora_xs, pissa, delora, ia3, dora, hra, eva, antipasto, road,
) )
+86
View File
@@ -0,0 +1,86 @@
"""LoRA-XS: freeze W's top-r SVD as A,B; train only a small r x r matrix R between them.
Bałazy et al. 2024 https://arxiv.org/abs/2405.17604
W = U S Vh (truncated to top-r)
A = diag(Sr) Vhr (r, d_in) frozen -- singular values folded into A (ref)
B = Ur (d_out, r) frozen
R (r, r) trainable, ~0 at init
h = W x + (alpha/r) B R A x
Unlike PiSSA, W is NOT cropped: B@A equals the top-r part of W, but only the product
B R A is *added on top* of the full W. R (init normal(0, 1e-5)) makes that delta ~0 at
t=0, so the adapted layer starts out as ~the base layer. The only trainable tensor is
r*r (e.g. r=32 -> 1024 params/layer), hence "extremely small".
The reference folds all singular values into A and leaves B as the raw left singular
vectors; R sits between two frozen bases: B has (near-)orthonormal columns, while A has
orthogonal rows scaled by the singular values. Their LLaMA math-tuning
config sets lora_alpha = r (scale = 1.0) and lr ~ 4e-3 (scripts/run_math_tuning.sh).
Refs:
- paper repo: https://github.com/MohammadrezaBanaei/LoRA-XS
(utils/initialization_utils.py: init_module_weights(R, sigma=1e-5), A/B requires_grad=False;
utils/latent_utils.py forward_latent: result += scaling * lora_B(R(lora_A(x))))
"""
import torch
from jaxtyping import Float
from torch import nn, Tensor as T
from dataclasses import dataclass
from ..variant import register, ParamSpec
from ..config import AdapterConfig, register_config
@register_config
@dataclass
class LoRAXSConfig(AdapterConfig):
    """Adapter config for the LoRA-XS variant.

    Adds nothing to ``AdapterConfig`` except the variant tag; all other
    fields (rank, alpha, dtype, ...) are whatever the base class declares.
    """

    # Variant tag; same string as ``LoRAXS.name`` ("lora_xs").
    variant: str = "lora_xs"
@register
class LoRAXS:
    """LoRA-XS variant: frozen top-r SVD factors A and B, trainable r x r core R.

    The forward pass adds ``(alpha / r) * B @ R @ A @ x`` to the incoming
    output ``y``. Only ``lora_R`` is declared trainable; ``lora_A`` and
    ``lora_B`` are declared as non-trainable buffers and filled from the SVD
    of the layer weight in ``init``. The layer weight itself is only read.
    """

    name = "lora_xs"

    @staticmethod
    def param_specs(d_in, d_out, cfg):
        """Declare the adapter tensors for one ``(d_out, d_in)`` linear layer.

        Args:
            d_in: Input feature dimension of the wrapped layer.
            d_out: Output feature dimension of the wrapped layer.
            cfg: Adapter config; only ``cfg.r`` (the rank) is read here.

        Returns:
            A dict mapping attribute name to ``ParamSpec``.
        """
        return dict(
            # Frozen top-r SVD factors of W (filled in init()); W itself stays intact.
            lora_A=ParamSpec((cfg.r, d_in), init="zeros", trainable=False, as_buffer=True),
            lora_B=ParamSpec((d_out, cfg.r), init="zeros", trainable=False, as_buffer=True),
            # The only trainable tensor: r x r core, drawn from normal(0, 1e-5) so the
            # added delta is ~0 at t=0 (the reference implementation uses the same sigma).
            lora_R=ParamSpec((cfg.r, cfg.r), init=lambda t: t.normal_(0, 1e-5)),
        )

    @staticmethod
    def init(layer: nn.Module, cfg) -> None:
        """Fill ``layer.lora_A`` / ``layer.lora_B`` from the top-r SVD of ``layer.weight``.

        Args:
            layer: The wrapped layer; must be exactly ``nn.Linear``.
            cfg: Adapter config; ``cfg.r`` and ``cfg.dtype`` are read.

        Raises:
            TypeError: If ``layer`` is not a plain ``nn.Linear``.
            ValueError: If ``cfg.r`` exceeds ``min(d_out, d_in)``.
        """
        # Exact-type check, not isinstance. NOTE(review): presumably deliberate so that
        # quantized Linear subclasses are rejected too, as the message implies -- confirm.
        if type(layer) is not nn.Linear:
            raise TypeError(
                "LoRA-XS needs the dense SVD of layer.weight, so v1 only supports plain "
                "nn.Linear, not bnb 4/8-bit."
            )
        W = layer.weight.data.float()  # (d_out, d_in)
        r = cfg.r
        # The thin SVD only has min(d_out, d_in) components. Without this guard a larger
        # r makes the copy_ calls below fail with an opaque shape error, or -- when the
        # smaller dimension is 1 -- silently broadcast one singular vector into all r slots.
        max_rank = min(W.shape)
        if r > max_rank:
            raise ValueError(
                f"LoRA-XS rank r={r} exceeds min(d_out, d_in)={max_rank} for a layer "
                f"with weight shape {tuple(W.shape)}; choose r <= {max_rank}."
            )
        U, S, Vh = torch.linalg.svd(W, full_matrices=False)
        Ur, Sr, Vhr = U[:, :r], S[:r], Vh[:r, :]
        # A = diag(Sr) Vhr, B = Ur -> B@A = Ur diag(Sr) Vhr = top-r(W). W is left intact.
        A = (Sr[:, None] * Vhr).to(cfg.dtype)
        B = Ur.to(cfg.dtype)
        layer.lora_A.copy_(A)
        layer.lora_B.copy_(B)

    @staticmethod
    def forward(
        layer: nn.Module,
        x: Float[T, '*B i'],
        y: Float[T, '*B o'],
    ) -> Float[T, '*B o']:
        """Return ``y`` plus the scaled low-rank update ``(alpha / r) * B R A x``.

        Args:
            layer: Wrapped layer carrying ``lora_A``, ``lora_B``, ``lora_R`` and
                ``_lora_cfg``.
            x: Layer input, last dim ``d_in``.
            y: Output to add the delta to, last dim ``d_out`` (presumably the base
                layer's ``W x`` -- confirm against the caller).

        Returns:
            ``y + scale * delta``, with the delta cast back to ``y.dtype``.
        """
        cfg = layer._lora_cfg
        scale = cfg.alpha / cfg.r
        A = layer.lora_A  # (r, d_in), frozen
        B = layer.lora_B  # (d_out, r), frozen
        R = layer.lora_R  # (r, r), trainable
        # Run the adapter matmuls in the adapter's dtype; cast back only at the end.
        xA = x.to(A.dtype)
        h = xA @ A.T  # (*B, r)
        h = h @ R.T  # (*B, r)  <- the learned core
        delta = h @ B.T  # (*B, d_out)
        return y + (scale * delta).to(y.dtype)
+2 -2
View File
@@ -32,12 +32,12 @@ sys.modules[SPEC.name] = benchmark
SPEC.loader.exec_module(benchmark) SPEC.loader.exec_module(benchmark)
VARIANTS = ["lora", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva", VARIANTS = ["lora", "lora_xs", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva",
"antipasto", "road"] "antipasto", "road"]
# Variants that fail loud when attached on a bnb-loaded base (read dense weight in init). # Variants that fail loud when attached on a bnb-loaded base (read dense weight in init).
# delora/eva also read weight but currently silently dequant -- they produce sane attach, # delora/eva also read weight but currently silently dequant -- they produce sane attach,
# so we don't expect a raise from them in the attach-only smoke. # so we don't expect a raise from them in the attach-only smoke.
BNB_RAISERS = {"pissa", "dora", "antipasto"} BNB_RAISERS = {"pissa", "dora", "antipasto", "lora_xs"}
TINY_MODEL = "hf-internal-testing/tiny-random-LlamaForCausalLM" TINY_MODEL = "hf-internal-testing/tiny-random-LlamaForCausalLM"
HAS_CUDA = torch.cuda.is_available() HAS_CUDA = torch.cuda.is_available()