Files
lora-lite/tests/test_metamath_smoke.py
wassname c792ad3e5f Add LoRA-XS variant: train only r×r core R between frozen SVD factors
Bałazy et al. 2024 (arxiv 2405.17604). A=diag(Sr)Vhr, B=Ur frozen from
top-r SVD of W (W left intact); only the r×r R is trained, init normal(0,1e-5)
so the adapter ~ identity at t=0. ~25k params at r=32 (24 down_proj targets).
justfile: alpha=r (scale=1) and lr=4e-3, matching the ref LLaMA math config.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
2026-06-18 19:48:40 +08:00

129 lines
4.7 KiB
Python

"""End-to-end smoke: run the metamath benchmark in probe mode for every variant.
Probe mode trains a few steps on tiny-random Llama, saves the adapter, reloads
it onto a fresh model, and asserts the trained logits match within tol. That's
the train+save+load round-trip on a real HF model, one test per variant.
A second test attaches each variant on top of a 4bit/8bit-loaded base and runs
one backward step. PiSSA/DoRA/AntiPaSTO/EVA must fail loud on quantized weights;
the rest must produce nonzero adapter grads. We do not run the full probe under
bnb because tiny-random + bnb dequant produces NaN logits unrelated to adapter
correctness.
"""
from __future__ import annotations
import importlib.util
import sys
from pathlib import Path
import pytest
import torch
import lora_lite as ll
_SCRIPTS = Path(__file__).resolve().parent.parent / "scripts"
sys.path.insert(0, str(_SCRIPTS)) # benchmark does `from _cost import ...` (sibling module)
SPEC = importlib.util.spec_from_file_location(
"metamath_benchmark",
_SCRIPTS / "metamath_gsm8k_benchmark.py",
)
benchmark = importlib.util.module_from_spec(SPEC)
sys.modules[SPEC.name] = benchmark
SPEC.loader.exec_module(benchmark)
VARIANTS = ["lora", "lora_xs", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva",
"antipasto", "road"]
# Variants that fail loud when attached on a bnb-loaded base (read dense weight in init).
# delora/eva also read weight but currently silently dequant -- they produce sane attach,
# so we don't expect a raise from them in the attach-only smoke.
BNB_RAISERS = {"pissa", "dora", "antipasto", "lora_xs"}
TINY_MODEL = "hf-internal-testing/tiny-random-LlamaForCausalLM"
HAS_CUDA = torch.cuda.is_available()
HAS_BNB = importlib.util.find_spec("bitsandbytes") is not None
def quick_cfg(variant: str, tmp_path: Path, quantization: str = "none") -> "benchmark.BenchmarkConfig":
target_name = (
[r"(k_proj|v_proj)$"] if variant == "ia3"
else [r"(down_proj)$"] if variant == "ia3_ff"
else [r"(q_proj|v_proj)$"]
)
cfg = benchmark.BenchmarkConfig(
model=TINY_MODEL,
variant=variant,
mode="probe",
device="cuda" if HAS_CUDA else "cpu",
torch_dtype="float16" if quantization != "none" else "float32",
quantization=quantization,
r=4,
alpha=8,
target_name=target_name,
layers="all",
steps=2,
batch_size=2,
batch_size_eval=4,
max_train_samples=8,
max_eval_samples=4,
max_valid_samples=4,
max_test_samples=4,
max_seq_length=128,
max_new_tokens=8,
lr=5e-3,
road_group_size=8,
seed=0,
log_examples=0,
log_every=1000,
output_dir=tmp_path / "out",
)
return cfg
@pytest.mark.parametrize("variant", VARIANTS)
def test_metamath_quick_train_save_load(variant: str, tmp_path: Path):
"""Train 2 steps, save, reload onto fresh tiny model, logits match within tol."""
cfg = quick_cfg(variant, tmp_path)
result = benchmark.run(cfg)
assert result["train"]["base_grad_leaks"] == 0
assert result["train"]["first_grad_norm"] > 0
assert result["train"]["adapter_delta"] > 0
probe = result.get("probe") or {}
assert "reload_err" in probe
assert probe["reload_err"] < cfg.reload_tol
@pytest.mark.skipif(not (HAS_CUDA and HAS_BNB), reason="needs CUDA + bitsandbytes")
@pytest.mark.parametrize("quantization", ["4bit", "8bit"])
@pytest.mark.parametrize("variant", VARIANTS)
def test_attach_on_bnb_loaded_base(variant: str, quantization: str, tmp_path: Path):
"""Attach to a bnb-loaded base, run one backward step. Weight-reading variants must fail loud."""
cfg = quick_cfg(variant, tmp_path, quantization=quantization)
dtype = getattr(torch, cfg.torch_dtype)
def _do() -> float:
model, _ = benchmark.load_model_and_tokenizer(cfg.model, dtype, cfg.device, cfg.quantization)
adapter_cfg = benchmark.cfg_for_variant(cfg, dtype)
if cfg.variant == "eva":
ids = torch.randint(0, 100, (2, 8), device=cfg.device)
ll.attach(model, adapter_cfg, calibration_data=[{"input_ids": ids}])
else:
ll.attach(model, adapter_cfg)
ids = torch.randint(0, 100, (2, 8), device=cfg.device)
out = model(input_ids=ids).logits
loss = out.float().pow(2).mean()
loss.backward()
return sum(
p.grad.detach().float().norm().item()
for n, p in model.named_parameters()
if "lora_" in n and p.grad is not None
)
if variant in BNB_RAISERS:
with pytest.raises((TypeError, RuntimeError, AttributeError, ValueError)):
_do()
else:
_do() # only assert it runs without exception; tiny+bnb grads can be 0/garbage.