Add LoRA-XS variant: train only r×r core R between frozen SVD factors

Bałazy et al. 2024 (arxiv 2405.17604). A=diag(Sr)Vhr, B=Ur frozen from
top-r SVD of W (W left intact); only the r×r R is trained, init normal(0,1e-5)
so the adapter ~ identity at t=0. ~25k params at r=32 (24 down_proj targets).
justfile: alpha=r (scale=1) and lr=4e-3, matching the ref LLaMA math config.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
wassname
2026-06-18 19:48:40 +08:00
parent 12e13cca79
commit c792ad3e5f
6 changed files with 98 additions and 5 deletions
+3 -2
View File
@@ -27,6 +27,7 @@ DEFAULT_TARGETS = (r"(q_proj|v_proj)$",)
CFG_BY_VARIANT = {
"lora": ll.LoRAConfig,
"lora_xs": ll.LoRAXSConfig,
"pissa": ll.PiSSAConfig,
"delora": ll.DeLoRAConfig,
"ia3": ll.IA3Config,
@@ -44,7 +45,7 @@ class BenchmarkConfig:
"""MetaMathQA -> GSM8K benchmark config. Tyro turns this into the CLI."""
model: str = "Qwen/Qwen3.5-0.8B-Base"
variant: Literal["lora", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva", "antipasto", "road"] = "lora"
variant: Literal["lora", "lora_xs", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva", "antipasto", "road"] = "lora"
mode: Literal["benchmark", "probe"] = "benchmark"
device: str = "cuda"
torch_dtype: str = "bfloat16"
@@ -158,7 +159,7 @@ def count_base_grad_leaks(model: torch.nn.Module) -> int:
def perturb_first_adapter(model: torch.nn.Module) -> None:
priority = ("lora_B", "lora_g", "lora_c", "lora_alpha", "lora_U", "lora_A", "lora_lambda", "lora_gate", "lora_delta_s", "lora_m", "lora_road_theta", "lora_road_alpha")
priority = ("lora_B", "lora_R", "lora_g", "lora_c", "lora_alpha", "lora_U", "lora_A", "lora_lambda", "lora_gate", "lora_delta_s", "lora_m", "lora_road_theta", "lora_road_alpha")
for key in priority:
for _, p in model.named_parameters():
if p.requires_grad and key in _: