diff --git a/justfile b/justfile index a3cb4c5..dcd1797 100644 --- a/justfile +++ b/justfile @@ -103,6 +103,10 @@ bench-variant model variant steps="5000" r_override="" lr_override="" rotate_bas # so a small r leaves almost nothing trainable; r=256 is the variant default # and matches the published AntiPaSTO row. alpha=r (no extra scaling). antipasto) lr=5e-3; r=256; alpha=256 ;; + # LoRA-XS trains only the r*r core R between frozen SVD factors. Ref LLaMA + # math config sets lora_alpha=r (scale=1) and lr=4e-3 (run_math_tuning.sh); + # keep r=32 to share the subspace dim with LoRA/PiSSA (all-else-equal rank axis). + lora_xs) lr=4e-3; alpha=32 ;; esac # r override (e.g. low-rank sweep); alpha tracks r for antipasto. if [ -n "{{r_override}}" ]; then r="{{r_override}}"; alpha="{{r_override}}"; fi diff --git a/scripts/metamath_gsm8k_benchmark.py b/scripts/metamath_gsm8k_benchmark.py index b6a27a7..73670fc 100644 --- a/scripts/metamath_gsm8k_benchmark.py +++ b/scripts/metamath_gsm8k_benchmark.py @@ -27,6 +27,7 @@ DEFAULT_TARGETS = (r"(q_proj|v_proj)$",) CFG_BY_VARIANT = { "lora": ll.LoRAConfig, + "lora_xs": ll.LoRAXSConfig, "pissa": ll.PiSSAConfig, "delora": ll.DeLoRAConfig, "ia3": ll.IA3Config, @@ -44,7 +45,7 @@ class BenchmarkConfig: """MetaMathQA -> GSM8K benchmark config. Tyro turns this into the CLI.""" model: str = "Qwen/Qwen3.5-0.8B-Base" - variant: Literal["lora", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva", "antipasto", "road"] = "lora" + variant: Literal["lora", "lora_xs", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva", "antipasto", "road"] = "lora" mode: Literal["benchmark", "probe"] = "benchmark" device: str = "cuda" torch_dtype: str = "bfloat16" @@ -158,7 +159,7 @@ def count_base_grad_leaks(model: torch.nn.Module) -> int: def perturb_first_adapter(model: torch.nn.Module) -> None: - priority = ("lora_B", "lora_g", "lora_c", "lora_alpha", "lora_U", "lora_A", "lora_lambda", "lora_gate", "lora_delta_s", "lora_m", "lora_road_theta", "lora_road_alpha") + priority = ("lora_B", "lora_R", "lora_g", "lora_c", "lora_alpha", "lora_U", "lora_A", "lora_lambda", "lora_gate", "lora_delta_s", "lora_m", "lora_road_theta", "lora_road_alpha") for key in priority: for _, p in model.named_parameters(): if p.requires_grad and key in _: diff --git a/src/lora_lite/__init__.py b/src/lora_lite/__init__.py index 588e394..1ccaaa7 100644 --- a/src/lora_lite/__init__.py +++ b/src/lora_lite/__init__.py @@ -13,6 +13,7 @@ from . import variants # noqa: F401 triggers variant + config registration # Expose per-variant config classes for ergonomic typed construction. from .variants.lora import LoRAConfig +from .variants.lora_xs import LoRAXSConfig from .variants.pissa import PiSSAConfig from .variants.delora import DeLoRAConfig from .variants.ia3 import IA3Config, IA3FFConfig @@ -25,6 +26,7 @@ from .variants.road import RoadConfig __all__ = [ "AdapterConfig", "LoRAConfig", + "LoRAXSConfig", "PiSSAConfig", "DeLoRAConfig", "IA3Config", diff --git a/src/lora_lite/variants/__init__.py b/src/lora_lite/variants/__init__.py index 8e3045a..e2e5ab5 100644 --- a/src/lora_lite/variants/__init__.py +++ b/src/lora_lite/variants/__init__.py @@ -1,3 +1,3 @@ from . import ( # noqa: F401 side-effect: register - lora, pissa, delora, ia3, dora, hra, eva, antipasto, road, + lora, lora_xs, pissa, delora, ia3, dora, hra, eva, antipasto, road, ) diff --git a/src/lora_lite/variants/lora_xs.py b/src/lora_lite/variants/lora_xs.py new file mode 100644 index 0000000..f556d55 --- /dev/null +++ b/src/lora_lite/variants/lora_xs.py @@ -0,0 +1,86 @@ +"""LoRA-XS: freeze W's top-r SVD as A,B; train only a small r x r matrix R between them. + +BaƂazy et al. 2024 https://arxiv.org/abs/2405.17604 + + W = U S Vh (truncated to top-r) + A = diag(Sr) Vhr (r, d_in) frozen -- singular values folded into A (ref) + B = Ur (d_out, r) frozen + R (r, r) trainable, ~0 at init + h = W x + (alpha/r) B R A x + +Unlike PiSSA, W is NOT cropped: B@A reconstructs the top-r but stays *added on top* of +the full W, and R (init normal(0, 1e-5)) starts the adapter at ~identity. So the only +trainable tensor is r*r (e.g. r=32 -> 1024 params/layer), hence "extremely small". + +The reference folds all singular values into A and leaves B as the raw left singular +vectors; R sits between two frozen, near-orthonormal bases. Their LLaMA math-tuning +config sets lora_alpha = r (scale = 1.0) and lr ~ 4e-3 (scripts/run_math_tuning.sh). + +Refs: + - paper repo: https://github.com/MohammadrezaBanaei/LoRA-XS + (utils/initialization_utils.py: init_module_weights(R, sigma=1e-5), A/B requires_grad=False; + utils/latent_utils.py forward_latent: result += scaling * lora_B(R(lora_A(x)))) +""" +import torch +from jaxtyping import Float +from torch import nn, Tensor as T +from dataclasses import dataclass + +from ..variant import register, ParamSpec +from ..config import AdapterConfig, register_config + + +@register_config +@dataclass +class LoRAXSConfig(AdapterConfig): + variant: str = "lora_xs" + + +@register +class LoRAXS: + name = "lora_xs" + + @staticmethod + def param_specs(d_in, d_out, cfg): + return dict( + # Frozen top-r SVD factors of W (filled in init()); W itself stays intact. + lora_A=ParamSpec((cfg.r, d_in), init="zeros", trainable=False, as_buffer=True), + lora_B=ParamSpec((d_out, cfg.r), init="zeros", trainable=False, as_buffer=True), + # The only trainable tensor: r x r core, near-zero so the adapter ~ identity at t=0 + # (ref uses normal(0, 1e-5); matches the repo's near_zero philosophy). + lora_R=ParamSpec((cfg.r, cfg.r), init=lambda t: t.normal_(0, 1e-5)), + ) + + @staticmethod + def init(layer: nn.Module, cfg) -> None: + if type(layer) is not nn.Linear: + raise TypeError( + "LoRA-XS needs the dense SVD of layer.weight, so v1 only supports plain " + "nn.Linear, not bnb 4/8-bit." + ) + W = layer.weight.data.float() # (d_out, d_in) + U, S, Vh = torch.linalg.svd(W, full_matrices=False) + r = cfg.r + Ur, Sr, Vhr = U[:, :r], S[:r], Vh[:r, :] + # A = diag(Sr) Vhr, B = Ur -> B@A = Ur diag(Sr) Vhr = top-r(W). W is left intact. + A = (Sr[:, None] * Vhr).to(cfg.dtype) + B = Ur.to(cfg.dtype) + layer.lora_A.copy_(A) + layer.lora_B.copy_(B) + + @staticmethod + def forward( + layer: nn.Module, + x: Float[T, '*B i'], + y: Float[T, '*B o'], + ) -> Float[T, '*B o']: + cfg = layer._lora_cfg + scale = cfg.alpha / cfg.r + A = layer.lora_A # (r, d_in), frozen + B = layer.lora_B # (d_out, r), frozen + R = layer.lora_R # (r, r), trainable + xA = x.to(A.dtype) + h = xA @ A.T # (*B, r) + h = h @ R.T # (*B, r) <- the learned core + delta = h @ B.T # (*B, d_out) + return y + (scale * delta).to(y.dtype) diff --git a/tests/test_metamath_smoke.py b/tests/test_metamath_smoke.py index 178440e..bc9bbe5 100644 --- a/tests/test_metamath_smoke.py +++ b/tests/test_metamath_smoke.py @@ -32,12 +32,12 @@ sys.modules[SPEC.name] = benchmark SPEC.loader.exec_module(benchmark) -VARIANTS = ["lora", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva", +VARIANTS = ["lora", "lora_xs", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva", "antipasto", "road"] # Variants that fail loud when attached on a bnb-loaded base (read dense weight in init). # delora/eva also read weight but currently silently dequant -- they produce sane attach, # so we don't expect a raise from them in the attach-only smoke. -BNB_RAISERS = {"pissa", "dora", "antipasto"} +BNB_RAISERS = {"pissa", "dora", "antipasto", "lora_xs"} TINY_MODEL = "hf-internal-testing/tiny-random-LlamaForCausalLM" HAS_CUDA = torch.cuda.is_available()