mirror of
https://github.com/wassname/lora-lite.git
synced 2026-06-27 15:00:33 +08:00
c792ad3e5f
Bałazy et al. 2024 (arxiv 2405.17604). A=diag(Sr)Vhr, B=Ur frozen from top-r SVD of W (W left intact); only the r×r R is trained, init normal(0,1e-5) so the adapter ~ identity at t=0. ~25k params at r=32 (24 down_proj targets). justfile: alpha=r (scale=1) and lr=4e-3, matching the ref LLaMA math config. Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
129 lines
4.7 KiB
Python
129 lines
4.7 KiB
Python
"""End-to-end smoke: run the metamath benchmark in probe mode for every variant.
|
|
|
|
Probe mode trains a few steps on tiny-random Llama, saves the adapter, reloads
|
|
it onto a fresh model, and asserts the trained logits match within tol. That's
|
|
the train+save+load round-trip on a real HF model, one test per variant.
|
|
|
|
A second test attaches each variant on top of a 4bit/8bit-loaded base and runs
|
|
one backward step. PiSSA/DoRA/AntiPaSTO/EVA must fail loud on quantized weights;
|
|
the rest must produce nonzero adapter grads. We do not run the full probe under
|
|
bnb because tiny-random + bnb dequant produces NaN logits unrelated to adapter
|
|
correctness.
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import importlib.util
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
import lora_lite as ll
|
|
|
|
_SCRIPTS = Path(__file__).resolve().parent.parent / "scripts"
|
|
sys.path.insert(0, str(_SCRIPTS)) # benchmark does `from _cost import ...` (sibling module)
|
|
SPEC = importlib.util.spec_from_file_location(
|
|
"metamath_benchmark",
|
|
_SCRIPTS / "metamath_gsm8k_benchmark.py",
|
|
)
|
|
benchmark = importlib.util.module_from_spec(SPEC)
|
|
sys.modules[SPEC.name] = benchmark
|
|
SPEC.loader.exec_module(benchmark)
|
|
|
|
|
|
VARIANTS = ["lora", "lora_xs", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva",
|
|
"antipasto", "road"]
|
|
# Variants that fail loud when attached on a bnb-loaded base (read dense weight in init).
|
|
# delora/eva also read weight but currently silently dequant -- they produce sane attach,
|
|
# so we don't expect a raise from them in the attach-only smoke.
|
|
BNB_RAISERS = {"pissa", "dora", "antipasto", "lora_xs"}
|
|
TINY_MODEL = "hf-internal-testing/tiny-random-LlamaForCausalLM"
|
|
|
|
HAS_CUDA = torch.cuda.is_available()
|
|
HAS_BNB = importlib.util.find_spec("bitsandbytes") is not None
|
|
|
|
|
|
def quick_cfg(variant: str, tmp_path: Path, quantization: str = "none") -> "benchmark.BenchmarkConfig":
|
|
target_name = (
|
|
[r"(k_proj|v_proj)$"] if variant == "ia3"
|
|
else [r"(down_proj)$"] if variant == "ia3_ff"
|
|
else [r"(q_proj|v_proj)$"]
|
|
)
|
|
cfg = benchmark.BenchmarkConfig(
|
|
model=TINY_MODEL,
|
|
variant=variant,
|
|
mode="probe",
|
|
device="cuda" if HAS_CUDA else "cpu",
|
|
torch_dtype="float16" if quantization != "none" else "float32",
|
|
quantization=quantization,
|
|
r=4,
|
|
alpha=8,
|
|
target_name=target_name,
|
|
layers="all",
|
|
steps=2,
|
|
batch_size=2,
|
|
batch_size_eval=4,
|
|
max_train_samples=8,
|
|
max_eval_samples=4,
|
|
max_valid_samples=4,
|
|
max_test_samples=4,
|
|
max_seq_length=128,
|
|
max_new_tokens=8,
|
|
lr=5e-3,
|
|
road_group_size=8,
|
|
seed=0,
|
|
log_examples=0,
|
|
log_every=1000,
|
|
output_dir=tmp_path / "out",
|
|
)
|
|
return cfg
|
|
|
|
|
|
@pytest.mark.parametrize("variant", VARIANTS)
|
|
def test_metamath_quick_train_save_load(variant: str, tmp_path: Path):
|
|
"""Train 2 steps, save, reload onto fresh tiny model, logits match within tol."""
|
|
cfg = quick_cfg(variant, tmp_path)
|
|
result = benchmark.run(cfg)
|
|
|
|
assert result["train"]["base_grad_leaks"] == 0
|
|
assert result["train"]["first_grad_norm"] > 0
|
|
assert result["train"]["adapter_delta"] > 0
|
|
probe = result.get("probe") or {}
|
|
assert "reload_err" in probe
|
|
assert probe["reload_err"] < cfg.reload_tol
|
|
|
|
|
|
@pytest.mark.skipif(not (HAS_CUDA and HAS_BNB), reason="needs CUDA + bitsandbytes")
|
|
@pytest.mark.parametrize("quantization", ["4bit", "8bit"])
|
|
@pytest.mark.parametrize("variant", VARIANTS)
|
|
def test_attach_on_bnb_loaded_base(variant: str, quantization: str, tmp_path: Path):
|
|
"""Attach to a bnb-loaded base, run one backward step. Weight-reading variants must fail loud."""
|
|
cfg = quick_cfg(variant, tmp_path, quantization=quantization)
|
|
dtype = getattr(torch, cfg.torch_dtype)
|
|
|
|
def _do() -> float:
|
|
model, _ = benchmark.load_model_and_tokenizer(cfg.model, dtype, cfg.device, cfg.quantization)
|
|
adapter_cfg = benchmark.cfg_for_variant(cfg, dtype)
|
|
if cfg.variant == "eva":
|
|
ids = torch.randint(0, 100, (2, 8), device=cfg.device)
|
|
ll.attach(model, adapter_cfg, calibration_data=[{"input_ids": ids}])
|
|
else:
|
|
ll.attach(model, adapter_cfg)
|
|
ids = torch.randint(0, 100, (2, 8), device=cfg.device)
|
|
out = model(input_ids=ids).logits
|
|
loss = out.float().pow(2).mean()
|
|
loss.backward()
|
|
return sum(
|
|
p.grad.detach().float().norm().item()
|
|
for n, p in model.named_parameters()
|
|
if "lora_" in n and p.grad is not None
|
|
)
|
|
|
|
if variant in BNB_RAISERS:
|
|
with pytest.raises((TypeError, RuntimeError, AttributeError, ValueError)):
|
|
_do()
|
|
else:
|
|
_do() # only assert it runs without exception; tiny+bnb grads can be 0/garbage.
|
|
|