mirror of
https://github.com/wassname/lora-lite.git
synced 2026-06-27 17:48:59 +08:00
248 lines
8.7 KiB
Python
248 lines
8.7 KiB
Python
from __future__ import annotations
|
|
|
|
import math
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
import torch
|
|
from torch import nn
|
|
|
|
import lora_lite as ll
|
|
|
|
|
|
ARTIFACT_DIR = Path(__file__).parent / "_artifacts"
|
|
|
|
|
|
class TinyBlock(nn.Module):
|
|
def __init__(self, d: int = 64, ff: int = 128):
|
|
super().__init__()
|
|
self.q_proj = nn.Linear(d, d, bias=False)
|
|
self.k_proj = nn.Linear(d, d, bias=False)
|
|
self.v_proj = nn.Linear(d, d, bias=False)
|
|
self.o_proj = nn.Linear(d, d, bias=False)
|
|
self.gate_proj = nn.Linear(d, ff, bias=False)
|
|
self.up_proj = nn.Linear(d, ff, bias=False)
|
|
self.down_proj = nn.Linear(ff, d, bias=False)
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
h = self.o_proj(self.q_proj(x) + self.k_proj(x) + self.v_proj(x))
|
|
m = self.down_proj(torch.nn.functional.silu(self.gate_proj(x)) * self.up_proj(x))
|
|
return x + h + m
|
|
|
|
|
|
class TinyModel(nn.Module):
|
|
def __init__(self, n_layers: int = 4, d: int = 64, ff: int = 128, vocab: int = 100):
|
|
super().__init__()
|
|
self.embed_tokens = nn.Embedding(vocab, d)
|
|
self.layers = nn.ModuleList([TinyBlock(d, ff) for _ in range(n_layers)])
|
|
self.lm_head = nn.Linear(d, vocab, bias=False)
|
|
self.config = type("Cfg", (), {"hidden_size": d})()
|
|
|
|
def forward(self, ids: torch.Tensor) -> torch.Tensor:
|
|
x = self.embed_tokens(ids)
|
|
for block in self.layers:
|
|
x = block(x)
|
|
return self.lm_head(x)
|
|
|
|
|
|
class FakeLinearLike(nn.Module):
|
|
def __init__(self, d_in: int = 8, d_out: int = 8):
|
|
super().__init__()
|
|
self.in_features = d_in
|
|
self.out_features = d_out
|
|
self.weight = nn.Parameter(torch.empty(d_out, d_in))
|
|
nn.init.kaiming_uniform_(self.weight, a=5**0.5)
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
return torch.nn.functional.linear(x, self.weight)
|
|
|
|
|
|
class FakeBnbModel(nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.config = type("Cfg", (), {"hidden_size": 8})()
|
|
self.layers = nn.ModuleList([FakeLinearLike(8, 8)])
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
return self.layers[0](x)
|
|
|
|
|
|
def cfg_for_variant(variant: str, *, training: bool = False) -> ll.LoraLiteConfig:
|
|
return ll.LoraLiteConfig(
|
|
variant=variant,
|
|
r=4,
|
|
alpha=4 if variant == "pissa" else 8,
|
|
dtype=torch.float32,
|
|
variant_kwargs={"lambda0": 0.1 if training else 0.0} if variant == "delora" else {},
|
|
)
|
|
|
|
|
|
def adapter_state(model: nn.Module) -> dict[str, torch.Tensor]:
|
|
return {k: v.detach().clone() for k, v in model.state_dict().items() if "lora_" in k}
|
|
|
|
|
|
def assert_only_lora_trainable(model: nn.Module) -> None:
|
|
trainable_names = [name for name, p in model.named_parameters() if p.requires_grad]
|
|
assert trainable_names
|
|
assert all("lora_" in name for name in trainable_names)
|
|
|
|
|
|
def assert_no_base_grads(model: nn.Module) -> None:
|
|
leaked = [name for name, p in model.named_parameters() if "lora_" not in name and p.grad is not None]
|
|
assert leaked == []
|
|
|
|
|
|
def perturb_first_adapter(model: nn.Module) -> None:
|
|
for name, p in model.named_parameters():
|
|
if "lora_lambda" in name:
|
|
with torch.no_grad():
|
|
p.add_(0.25)
|
|
return
|
|
for name, p in model.named_parameters():
|
|
if "lora_B" in name:
|
|
with torch.no_grad():
|
|
p.flatten()[0].add_(0.25)
|
|
return
|
|
raise AssertionError("no perturbable adapter parameter found")
|
|
|
|
|
|
@pytest.mark.parametrize("variant", ["lora", "pissa", "delora"])
|
|
def test_variant_identity_hook_save_load_and_training(variant: str):
|
|
ARTIFACT_DIR.mkdir(exist_ok=True)
|
|
torch.manual_seed(0)
|
|
model = TinyModel()
|
|
ids = torch.randint(0, 100, (2, 16))
|
|
|
|
with torch.no_grad():
|
|
y_base = model(ids).clone()
|
|
|
|
cfg = cfg_for_variant(variant)
|
|
handles = ll.attach(model, cfg)
|
|
assert len(handles) == 28
|
|
assert_only_lora_trainable(model)
|
|
|
|
with torch.no_grad():
|
|
y_init = model(ids).clone()
|
|
identity_err = (y_init - y_base).abs().max().item()
|
|
identity_tol = {"lora": 1e-6, "pissa": 5e-4, "delora": 1e-6}[variant]
|
|
assert identity_err < identity_tol
|
|
|
|
before_perturb = adapter_state(model)
|
|
perturb_first_adapter(model)
|
|
with torch.no_grad():
|
|
perturb_delta = (model(ids) - y_init).abs().max().item()
|
|
assert perturb_delta > 1e-7
|
|
for name, value in before_perturb.items():
|
|
model.state_dict()[name].copy_(value)
|
|
|
|
path = ARTIFACT_DIR / f"{variant}_adapter.pt"
|
|
ll.save(model, str(path))
|
|
saved = torch.load(path, weights_only=True, map_location="cpu")
|
|
assert set(saved["state"]) == set(adapter_state(model))
|
|
assert any(k.startswith("layers.0.q_proj.lora_") for k in saved["state"])
|
|
|
|
torch.manual_seed(0)
|
|
model_loaded = TinyModel()
|
|
ll.load(model_loaded, str(path))
|
|
loaded_state = adapter_state(model_loaded)
|
|
for name, value in saved["state"].items():
|
|
assert torch.equal(loaded_state[name].cpu(), value)
|
|
with torch.no_grad():
|
|
y_loaded = model_loaded(ids)
|
|
assert (y_loaded - y_init).abs().max().item() < identity_tol
|
|
|
|
torch.manual_seed(0)
|
|
train_model = TinyModel()
|
|
ll.attach(train_model, cfg_for_variant(variant, training=True))
|
|
assert_only_lora_trainable(train_model)
|
|
target = torch.randn(2, 16, 100) * 0.1
|
|
trainable = [p for p in train_model.parameters() if p.requires_grad]
|
|
opt = torch.optim.Adam(trainable, lr=0.1) if variant == "delora" else torch.optim.SGD(trainable, lr=1e-2)
|
|
losses = []
|
|
first_grad_norm = math.nan
|
|
before_train = adapter_state(train_model)
|
|
for step in range(20):
|
|
opt.zero_grad()
|
|
loss = (train_model(ids) - target).pow(2).mean()
|
|
loss.backward()
|
|
assert_no_base_grads(train_model)
|
|
grad_norm = sum(
|
|
p.grad.detach().float().norm().item()
|
|
for name, p in train_model.named_parameters()
|
|
if "lora_" in name and p.grad is not None
|
|
)
|
|
assert math.isfinite(grad_norm)
|
|
if step == 0:
|
|
first_grad_norm = grad_norm
|
|
opt.step()
|
|
losses.append(loss.item())
|
|
after_train = adapter_state(train_model)
|
|
adapter_delta = sum((after_train[k] - before_train[k]).float().norm().item() for k in before_train)
|
|
drop = (losses[0] - losses[-1]) / losses[0]
|
|
assert first_grad_norm > 0
|
|
assert adapter_delta > 0
|
|
assert drop > 0.05
|
|
|
|
|
|
def test_load_fails_on_missing_and_unexpected_lora_keys():
|
|
ARTIFACT_DIR.mkdir(exist_ok=True)
|
|
torch.manual_seed(0)
|
|
model = TinyModel()
|
|
ll.attach(model, cfg_for_variant("lora"))
|
|
good_path = ARTIFACT_DIR / "lora_good.pt"
|
|
ll.save(model, str(good_path))
|
|
blob = torch.load(good_path, weights_only=True, map_location="cpu")
|
|
|
|
missing_blob = {"cfg": blob["cfg"], "state": dict(blob["state"])}
|
|
missing_blob["state"].pop(next(iter(missing_blob["state"])))
|
|
missing_path = ARTIFACT_DIR / "lora_missing.pt"
|
|
torch.save(missing_blob, missing_path)
|
|
with pytest.raises(RuntimeError, match="missing lora keys"):
|
|
ll.load(TinyModel(), str(missing_path))
|
|
|
|
unexpected_blob = {"cfg": blob["cfg"], "state": dict(blob["state"])}
|
|
unexpected_blob["state"]["layers.0.q_proj.lora_extra"] = torch.zeros(1)
|
|
unexpected_path = ARTIFACT_DIR / "lora_unexpected.pt"
|
|
torch.save(unexpected_blob, unexpected_path)
|
|
with pytest.raises(RuntimeError, match="unexpected lora keys"):
|
|
ll.load(TinyModel(), str(unexpected_path))
|
|
|
|
|
|
def test_no_target_layers_is_loud_failure():
|
|
cfg = ll.LoraLiteConfig(variant="lora", target_names=("definitely_missing",))
|
|
with pytest.raises(RuntimeError, match="no target layers"):
|
|
ll.attach(TinyModel(), cfg)
|
|
|
|
|
|
@pytest.mark.parametrize("variant", ["lora", "delora"])
|
|
def test_structural_non_linear_target_trains_for_forward_only_variants(variant: str):
|
|
torch.manual_seed(0)
|
|
model = FakeBnbModel()
|
|
x = torch.randn(2, 3, 8)
|
|
y_base = model(x).detach()
|
|
cfg = ll.LoraLiteConfig(
|
|
variant=variant,
|
|
r=2,
|
|
alpha=4,
|
|
dtype=torch.float32,
|
|
target_roles=(),
|
|
variant_kwargs={"lambda0": 0.0} if variant == "delora" else {},
|
|
)
|
|
ll.attach(model, cfg)
|
|
y_init = model(x)
|
|
assert (y_init.detach() - y_base).abs().max().item() < 1e-6
|
|
loss = y_init.pow(2).mean()
|
|
loss.backward()
|
|
assert_no_base_grads(model)
|
|
adapter_grad_norm = sum(
|
|
p.grad.detach().float().norm().item()
|
|
for name, p in model.named_parameters()
|
|
if "lora_" in name and p.grad is not None
|
|
)
|
|
assert adapter_grad_norm > 0
|
|
|
|
|
|
def test_pissa_rejects_structural_non_linear_target():
|
|
cfg = ll.LoraLiteConfig(variant="pissa", r=2, alpha=2, dtype=torch.float32, target_roles=())
|
|
with pytest.raises(TypeError, match="plain nn.Linear"):
|
|
ll.attach(FakeBnbModel(), cfg) |