mirror of
https://github.com/wassname/lora-lite.git
synced 2026-06-27 15:15:55 +08:00
Add LoRA-XS variant: train only r×r core R between frozen SVD factors
Bałazy et al. 2024 (arxiv 2405.17604). A=diag(Sr)Vhr, B=Ur frozen from top-r SVD of W (W left intact); only the r×r R is trained, init normal(0,1e-5) so the adapter ~ identity at t=0. ~25k params at r=32 (24 down_proj targets). justfile: alpha=r (scale=1) and lr=4e-3, matching the ref LLaMA math config. Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
@@ -103,6 +103,10 @@ bench-variant model variant steps="5000" r_override="" lr_override="" rotate_bas
|
|||||||
# so a small r leaves almost nothing trainable; r=256 is the variant default
|
# so a small r leaves almost nothing trainable; r=256 is the variant default
|
||||||
# and matches the published AntiPaSTO row. alpha=r (no extra scaling).
|
# and matches the published AntiPaSTO row. alpha=r (no extra scaling).
|
||||||
antipasto) lr=5e-3; r=256; alpha=256 ;;
|
antipasto) lr=5e-3; r=256; alpha=256 ;;
|
||||||
|
# LoRA-XS trains only the r*r core R between frozen SVD factors. Ref LLaMA
|
||||||
|
# math config sets lora_alpha=r (scale=1) and lr=4e-3 (run_math_tuning.sh);
|
||||||
|
# keep r=32 to share the subspace dim with LoRA/PiSSA (all-else-equal rank axis).
|
||||||
|
lora_xs) lr=4e-3; alpha=32 ;;
|
||||||
esac
|
esac
|
||||||
# r override (e.g. low-rank sweep); alpha tracks r for antipasto.
|
# r override (e.g. low-rank sweep); alpha tracks r for antipasto.
|
||||||
if [ -n "{{r_override}}" ]; then r="{{r_override}}"; alpha="{{r_override}}"; fi
|
if [ -n "{{r_override}}" ]; then r="{{r_override}}"; alpha="{{r_override}}"; fi
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ DEFAULT_TARGETS = (r"(q_proj|v_proj)$",)
|
|||||||
|
|
||||||
CFG_BY_VARIANT = {
|
CFG_BY_VARIANT = {
|
||||||
"lora": ll.LoRAConfig,
|
"lora": ll.LoRAConfig,
|
||||||
|
"lora_xs": ll.LoRAXSConfig,
|
||||||
"pissa": ll.PiSSAConfig,
|
"pissa": ll.PiSSAConfig,
|
||||||
"delora": ll.DeLoRAConfig,
|
"delora": ll.DeLoRAConfig,
|
||||||
"ia3": ll.IA3Config,
|
"ia3": ll.IA3Config,
|
||||||
@@ -44,7 +45,7 @@ class BenchmarkConfig:
|
|||||||
"""MetaMathQA -> GSM8K benchmark config. Tyro turns this into the CLI."""
|
"""MetaMathQA -> GSM8K benchmark config. Tyro turns this into the CLI."""
|
||||||
|
|
||||||
model: str = "Qwen/Qwen3.5-0.8B-Base"
|
model: str = "Qwen/Qwen3.5-0.8B-Base"
|
||||||
variant: Literal["lora", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva", "antipasto", "road"] = "lora"
|
variant: Literal["lora", "lora_xs", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva", "antipasto", "road"] = "lora"
|
||||||
mode: Literal["benchmark", "probe"] = "benchmark"
|
mode: Literal["benchmark", "probe"] = "benchmark"
|
||||||
device: str = "cuda"
|
device: str = "cuda"
|
||||||
torch_dtype: str = "bfloat16"
|
torch_dtype: str = "bfloat16"
|
||||||
@@ -158,7 +159,7 @@ def count_base_grad_leaks(model: torch.nn.Module) -> int:
|
|||||||
|
|
||||||
|
|
||||||
def perturb_first_adapter(model: torch.nn.Module) -> None:
|
def perturb_first_adapter(model: torch.nn.Module) -> None:
|
||||||
priority = ("lora_B", "lora_g", "lora_c", "lora_alpha", "lora_U", "lora_A", "lora_lambda", "lora_gate", "lora_delta_s", "lora_m", "lora_road_theta", "lora_road_alpha")
|
priority = ("lora_B", "lora_R", "lora_g", "lora_c", "lora_alpha", "lora_U", "lora_A", "lora_lambda", "lora_gate", "lora_delta_s", "lora_m", "lora_road_theta", "lora_road_alpha")
|
||||||
for key in priority:
|
for key in priority:
|
||||||
for _, p in model.named_parameters():
|
for _, p in model.named_parameters():
|
||||||
if p.requires_grad and key in _:
|
if p.requires_grad and key in _:
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from . import variants # noqa: F401 triggers variant + config registration
|
|||||||
|
|
||||||
# Expose per-variant config classes for ergonomic typed construction.
|
# Expose per-variant config classes for ergonomic typed construction.
|
||||||
from .variants.lora import LoRAConfig
|
from .variants.lora import LoRAConfig
|
||||||
|
from .variants.lora_xs import LoRAXSConfig
|
||||||
from .variants.pissa import PiSSAConfig
|
from .variants.pissa import PiSSAConfig
|
||||||
from .variants.delora import DeLoRAConfig
|
from .variants.delora import DeLoRAConfig
|
||||||
from .variants.ia3 import IA3Config, IA3FFConfig
|
from .variants.ia3 import IA3Config, IA3FFConfig
|
||||||
@@ -25,6 +26,7 @@ from .variants.road import RoadConfig
|
|||||||
__all__ = [
|
__all__ = [
|
||||||
"AdapterConfig",
|
"AdapterConfig",
|
||||||
"LoRAConfig",
|
"LoRAConfig",
|
||||||
|
"LoRAXSConfig",
|
||||||
"PiSSAConfig",
|
"PiSSAConfig",
|
||||||
"DeLoRAConfig",
|
"DeLoRAConfig",
|
||||||
"IA3Config",
|
"IA3Config",
|
||||||
|
|||||||
@@ -1,3 +1,3 @@
|
|||||||
from . import ( # noqa: F401 side-effect: register
|
from . import ( # noqa: F401 side-effect: register
|
||||||
lora, pissa, delora, ia3, dora, hra, eva, antipasto, road,
|
lora, lora_xs, pissa, delora, ia3, dora, hra, eva, antipasto, road,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -0,0 +1,86 @@
|
|||||||
|
"""LoRA-XS: freeze W's top-r SVD as A,B; train only a small r x r matrix R between them.
|
||||||
|
|
||||||
|
Bałazy et al. 2024 https://arxiv.org/abs/2405.17604
|
||||||
|
|
||||||
|
W = U S Vh (truncated to top-r)
|
||||||
|
A = diag(Sr) Vhr (r, d_in) frozen -- singular values folded into A (ref)
|
||||||
|
B = Ur (d_out, r) frozen
|
||||||
|
R (r, r) trainable, ~0 at init
|
||||||
|
h = W x + (alpha/r) B R A x
|
||||||
|
|
||||||
|
Unlike PiSSA, W is NOT cropped: B@A reconstructs the top-r but stays *added on top* of
|
||||||
|
the full W, and R (init normal(0, 1e-5)) starts the adapter at ~identity. So the only
|
||||||
|
trainable tensor is r*r (e.g. r=32 -> 1024 params/layer), hence "extremely small".
|
||||||
|
|
||||||
|
The reference folds all singular values into A and leaves B as the raw left singular
|
||||||
|
vectors; R sits between two frozen, near-orthonormal bases. Their LLaMA math-tuning
|
||||||
|
config sets lora_alpha = r (scale = 1.0) and lr ~ 4e-3 (scripts/run_math_tuning.sh).
|
||||||
|
|
||||||
|
Refs:
|
||||||
|
- paper repo: https://github.com/MohammadrezaBanaei/LoRA-XS
|
||||||
|
(utils/initialization_utils.py: init_module_weights(R, sigma=1e-5), A/B requires_grad=False;
|
||||||
|
utils/latent_utils.py forward_latent: result += scaling * lora_B(R(lora_A(x))))
|
||||||
|
"""
|
||||||
|
import torch
|
||||||
|
from jaxtyping import Float
|
||||||
|
from torch import nn, Tensor as T
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from ..variant import register, ParamSpec
|
||||||
|
from ..config import AdapterConfig, register_config
|
||||||
|
|
||||||
|
|
||||||
|
@register_config
|
||||||
|
@dataclass
|
||||||
|
class LoRAXSConfig(AdapterConfig):
|
||||||
|
variant: str = "lora_xs"
|
||||||
|
|
||||||
|
|
||||||
|
@register
|
||||||
|
class LoRAXS:
|
||||||
|
name = "lora_xs"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def param_specs(d_in, d_out, cfg):
|
||||||
|
return dict(
|
||||||
|
# Frozen top-r SVD factors of W (filled in init()); W itself stays intact.
|
||||||
|
lora_A=ParamSpec((cfg.r, d_in), init="zeros", trainable=False, as_buffer=True),
|
||||||
|
lora_B=ParamSpec((d_out, cfg.r), init="zeros", trainable=False, as_buffer=True),
|
||||||
|
# The only trainable tensor: r x r core, near-zero so the adapter ~ identity at t=0
|
||||||
|
# (ref uses normal(0, 1e-5); matches the repo's near_zero philosophy).
|
||||||
|
lora_R=ParamSpec((cfg.r, cfg.r), init=lambda t: t.normal_(0, 1e-5)),
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def init(layer: nn.Module, cfg) -> None:
|
||||||
|
if type(layer) is not nn.Linear:
|
||||||
|
raise TypeError(
|
||||||
|
"LoRA-XS needs the dense SVD of layer.weight, so v1 only supports plain "
|
||||||
|
"nn.Linear, not bnb 4/8-bit."
|
||||||
|
)
|
||||||
|
W = layer.weight.data.float() # (d_out, d_in)
|
||||||
|
U, S, Vh = torch.linalg.svd(W, full_matrices=False)
|
||||||
|
r = cfg.r
|
||||||
|
Ur, Sr, Vhr = U[:, :r], S[:r], Vh[:r, :]
|
||||||
|
# A = diag(Sr) Vhr, B = Ur -> B@A = Ur diag(Sr) Vhr = top-r(W). W is left intact.
|
||||||
|
A = (Sr[:, None] * Vhr).to(cfg.dtype)
|
||||||
|
B = Ur.to(cfg.dtype)
|
||||||
|
layer.lora_A.copy_(A)
|
||||||
|
layer.lora_B.copy_(B)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def forward(
|
||||||
|
layer: nn.Module,
|
||||||
|
x: Float[T, '*B i'],
|
||||||
|
y: Float[T, '*B o'],
|
||||||
|
) -> Float[T, '*B o']:
|
||||||
|
cfg = layer._lora_cfg
|
||||||
|
scale = cfg.alpha / cfg.r
|
||||||
|
A = layer.lora_A # (r, d_in), frozen
|
||||||
|
B = layer.lora_B # (d_out, r), frozen
|
||||||
|
R = layer.lora_R # (r, r), trainable
|
||||||
|
xA = x.to(A.dtype)
|
||||||
|
h = xA @ A.T # (*B, r)
|
||||||
|
h = h @ R.T # (*B, r) <- the learned core
|
||||||
|
delta = h @ B.T # (*B, d_out)
|
||||||
|
return y + (scale * delta).to(y.dtype)
|
||||||
@@ -32,12 +32,12 @@ sys.modules[SPEC.name] = benchmark
|
|||||||
SPEC.loader.exec_module(benchmark)
|
SPEC.loader.exec_module(benchmark)
|
||||||
|
|
||||||
|
|
||||||
VARIANTS = ["lora", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva",
|
VARIANTS = ["lora", "lora_xs", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva",
|
||||||
"antipasto", "road"]
|
"antipasto", "road"]
|
||||||
# Variants that fail loud when attached on a bnb-loaded base (read dense weight in init).
|
# Variants that fail loud when attached on a bnb-loaded base (read dense weight in init).
|
||||||
# delora/eva also read weight but currently silently dequant -- they produce sane attach,
|
# delora/eva also read weight but currently silently dequant -- they produce sane attach,
|
||||||
# so we don't expect a raise from them in the attach-only smoke.
|
# so we don't expect a raise from them in the attach-only smoke.
|
||||||
BNB_RAISERS = {"pissa", "dora", "antipasto"}
|
BNB_RAISERS = {"pissa", "dora", "antipasto", "lora_xs"}
|
||||||
TINY_MODEL = "hf-internal-testing/tiny-random-LlamaForCausalLM"
|
TINY_MODEL = "hf-internal-testing/tiny-random-LlamaForCausalLM"
|
||||||
|
|
||||||
HAS_CUDA = torch.cuda.is_available()
|
HAS_CUDA = torch.cuda.is_available()
|
||||||
|
|||||||
Reference in New Issue
Block a user