simpler test

This commit is contained in:
wassname
2026-04-27 09:47:07 +08:00
parent b60a8c3f9b
commit 24ba8deb02
10 changed files with 566 additions and 41 deletions
+1
View File
@@ -8,6 +8,7 @@ dependencies = [
"torch>=2.1",
"einops>=0.7",
"jaxtyping>=0.2.34",
"safetensors>=0.5",
]
keywords = ["lora", "pytorch", "peft", "adapters", "llm"]
classifiers = [
+41 -28
View File
@@ -33,6 +33,7 @@ CFG_BY_VARIANT = {
"hra": ll.HRAConfig,
"eva": ll.EVAConfig,
"antipasto": ll.AntiPaSTOConfig,
"road": ll.RoadConfig,
}
@@ -41,7 +42,7 @@ class BenchmarkConfig:
"""MetaMathQA -> GSM8K benchmark config. Tyro turns this into the CLI."""
model: str = "Qwen/Qwen3-0.6B-Base"
variant: Literal["lora", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva", "antipasto"] = "lora"
variant: Literal["lora", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva", "antipasto", "road"] = "lora"
mode: Literal["benchmark", "probe"] = "benchmark"
device: str = "cuda"
torch_dtype: str = "bfloat16"
@@ -49,6 +50,7 @@ class BenchmarkConfig:
r: int = 32
alpha: float = 64.0
delora_lambda0: float = 0.1
road_group_size: int = 64
target_name: list[str] = field(default_factory=lambda: list(DEFAULT_TARGETS))
layers: str = "all"
train_dataset: str = "meta-math/MetaMathQA"
@@ -118,6 +120,8 @@ def parse_layers(text: str) -> tuple[int, ...] | None:
def cfg_for_variant(args: BenchmarkConfig, dtype: torch.dtype) -> ll.AdapterConfig:
extra = {"lambda0": args.delora_lambda0} if args.variant == "delora" else {}
if args.variant == "road":
extra = {"group_size": args.road_group_size}
return CFG_BY_VARIANT[args.variant](
r=args.r,
alpha=args.r if args.variant == "pissa" else args.alpha,
@@ -147,7 +151,7 @@ def count_base_grad_leaks(model: torch.nn.Module) -> int:
def perturb_first_adapter(model: torch.nn.Module) -> None:
priority = ("lora_B", "lora_g", "lora_U", "lora_A", "lora_lambda", "lora_gate", "lora_delta_s", "lora_rot_T", "lora_m")
priority = ("lora_B", "lora_g", "lora_U", "lora_A", "lora_lambda", "lora_gate", "lora_delta_s", "lora_rot_T", "lora_m", "lora_road_theta", "lora_road_alpha")
for key in priority:
for _, p in model.named_parameters():
if p.requires_grad and key in _:
@@ -409,11 +413,12 @@ def check_probe_reload(
loaded_model, _ = load_model_and_tokenizer(args.model, getattr(torch, args.torch_dtype), args.device, args.quantization)
loaded_model.eval()
ll.load(loaded_model, str(adapter_path))
saved = torch.load(adapter_path, weights_only=True, map_location="cpu")
from safetensors.torch import load_file
saved_sd = load_file(str(adapter_path), device="cpu")
loaded_state = adapter_state(loaded_model)
if set(saved["state"]) != set(loaded_state):
if set(saved_sd) != set(loaded_state):
raise AssertionError("loaded adapter keys differ from saved adapter keys")
for name, value in saved["state"].items():
for name, value in saved_sd.items():
if not torch.equal(loaded_state[name].cpu(), value):
raise AssertionError(f"loaded adapter tensor differs: {name}")
logits_loaded = loaded_model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]).logits.detach().clone()
@@ -426,12 +431,21 @@ def check_probe_reload(
return {"reload_err": reload_err, "saved_tensors": len(saved["state"])}
def print_final_report(row: dict[str, Any], result_path: Path) -> None:
print("SHOULD: grad>0, dθ>0, base_grad_leaks=0, valid/test fields present; probeΔ<0 is good but not required for tiny random smoke. ELSE adapter or eval wiring is dead/wrong.")
def print_final_report(row: dict[str, Any], result_path: Path, mode: str) -> None:
# BLUF: status line first so log tails are immediately readable
cue = "🟢" if row.get("base_grad_leaks", 0) == 0 and row.get("grad", 0) > 0 else "🔴"
n = row.get("samples", "?")
print(f"{cue} test_acc={row['test_acc']:.4g} valid_acc={row['valid_acc']:.4g} grad={row['grad']:.3g} dθ={row['']:.3g} base_grad_leaks={row['base_grad_leaks']} N={n}")
print("SHOULD: grad>0, dθ>0, base_grad_leaks=0; test/valid_acc meaningful only in benchmark mode. ELSE adapter or eval wiring is dead/wrong.")
# ordered: most important / shortest columns first
display_keys = ["variant", "test_acc", "valid_acc", "grad", "", "base_grad_leaks", "steps", "samples", "loss0", "lossN", "commit"]
if "perturb" in row:
display_keys += ["perturb", "reload"]
display_keys += ["run_id"]
display_row = {k: row[k] for k in display_keys if k in row}
print(tabulate([display_row], headers="keys", tablefmt="tsv", floatfmt=".4g"))
print(f"argv: {' '.join(sys.argv)} N={n} mode={mode}")
print(f"out: {result_path}")
print(f"argv: {' '.join(sys.argv)}")
print(f"main metric: test_acc={row['test_acc']:.4g} valid_acc={row['valid_acc']:.4g} steps={row['steps']} samples={row['samples']}")
print(tabulate([row], headers="keys", tablefmt="tsv", floatfmt=".4g"))
def current_git_commit() -> str:
@@ -447,25 +461,24 @@ def append_results_row(
result: dict[str, Any],
run_commit: str,
) -> tuple[Path, Path]:
results_dir = args.output_dir / "results"
results_dir = args.output_dir
results_dir.mkdir(parents=True, exist_ok=True)
tsv_path = results_dir / "benchmark_results.tsv"
lock_path = results_dir / "benchmark_results.tsv.lock"
tsv_path = results_dir / "summary.tsv"
lock_path = results_dir / "summary.tsv.lock"
finished_at = datetime.now(timezone.utc).isoformat(timespec="seconds")
finished_label = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
snapshot_path = results_dir / f"{result['run_id']}__{finished_label}.json"
snapshot_path.write_text(json.dumps(result, indent=2), encoding="utf-8")
row = {
"time_utc": finished_at,
"commit": run_commit,
"test_acc": result["test_acc"],
"valid_acc": result["valid_acc"],
"method": args.variant,
"model": args.model,
"mode": args.mode,
"valid_accuracy": result["valid_accuracy"],
"test_accuracy": result["test_accuracy"],
"steps": args.steps,
"samples": result["train_samples"],
"wall_time_s": result["wall_time_s"],
"model": args.model,
"commit": run_commit[:12],
"wall_time_s": round(result["wall_time_s"]),
"time_utc": finished_at,
"argv": " ".join(sys.argv),
"result_json": str(snapshot_path),
"latest_result_json": str(result_path),
@@ -517,7 +530,7 @@ def run(args: BenchmarkConfig) -> dict[str, Any]:
valid_metrics = evaluate(model, tokenizer, datasets["valid"], args, "valid")
test_metrics = evaluate(model, tokenizer, datasets["test"], args, "test")
adapter_path = out_dir / "adapter.pt"
adapter_path = out_dir / "adapter.safetensors"
ll.save(model, str(adapter_path))
if args.mode == "probe":
model.eval()
@@ -555,8 +568,8 @@ def run(args: BenchmarkConfig) -> dict[str, Any]:
"weight_decay": args.weight_decay,
"lr_scheduler": "cosine",
"grad_norm_clip": args.grad_norm_clip,
"valid_accuracy": valid_metrics["accuracy"],
"test_accuracy": test_metrics["accuracy"],
"valid_acc": valid_metrics["accuracy"],
"test_acc": test_metrics["accuracy"],
"train": train_metrics,
"valid": valid_metrics,
"test": test_metrics,
@@ -566,12 +579,12 @@ def run(args: BenchmarkConfig) -> dict[str, Any]:
}
result_path = out_dir / "result.json"
result_path.write_text(json.dumps(result, indent=2), encoding="utf-8")
results_tsv_path, result_snapshot_path = append_results_row(args, result_path, result, run_commit)
result["results_tsv_path"] = str(results_tsv_path)
result["result_snapshot_path"] = str(result_snapshot_path)
if args.mode == "benchmark":
results_tsv_path, result_snapshot_path = append_results_row(args, result_path, result, run_commit)
result["results_tsv_path"] = str(results_tsv_path)
result["result_snapshot_path"] = str(result_snapshot_path)
result["commit"] = run_commit
result_path.write_text(json.dumps(result, indent=2), encoding="utf-8")
commit_prefix = run_commit[:12]
row = {
"run_id": run_id,
@@ -592,7 +605,7 @@ def run(args: BenchmarkConfig) -> dict[str, Any]:
if probe_metrics is not None:
row["perturb"] = probe_metrics["perturb_delta"]
row["reload"] = probe_metrics["reload_err"]
print_final_report(row, result_path)
print_final_report(row, result_path, args.mode)
return result
+2
View File
@@ -20,6 +20,7 @@ from .variants.dora import DoRAConfig
from .variants.hra import HRAConfig
from .variants.eva import EVAConfig
from .variants.antipasto import AntiPaSTOConfig
from .variants.road import RoadConfig
__all__ = [
"AdapterConfig",
@@ -32,6 +33,7 @@ __all__ = [
"HRAConfig",
"EVAConfig",
"AntiPaSTOConfig",
"RoadConfig",
"attach",
"detach",
"save",
+13 -9
View File
@@ -1,5 +1,6 @@
"""attach / detach / save / load. The whole runtime."""
from __future__ import annotations
import json
import torch
from torch import nn
from torch.utils.hooks import RemovableHandle
@@ -121,19 +122,22 @@ def save(model: nn.Module, path: str) -> None:
if state is None:
raise RuntimeError("no adapter attached; call attach() first")
sd = {k: v.detach().cpu() for k, v in model.state_dict().items() if "lora_" in k}
blob = {
"cfg": state["cfg"].to_dict(),
"state": sd,
"base_fp": _base_weight_fingerprint(model),
metadata = {
"cfg": json.dumps(state["cfg"].to_dict()),
"base_fp": json.dumps(_base_weight_fingerprint(model)),
}
torch.save(blob, path)
from safetensors.torch import save_file
save_file(sd, path, metadata=metadata)
def load(model: nn.Module, path: str) -> list[RemovableHandle]:
blob = torch.load(path, weights_only=True, map_location="cpu")
cfg = AdapterConfig.from_dict(blob["cfg"])
from safetensors.torch import load_file, safe_open
with safe_open(path, framework="pt", device="cpu") as f:
metadata = f.metadata()
sd = load_file(path, device="cpu")
cfg = AdapterConfig.from_dict(json.loads(metadata["cfg"]))
handles = attach(model, cfg, _skip_group_init=True) # creates empty params; data-driven inits restored from state_dict
missing, unexpected = model.load_state_dict(blob["state"], strict=False)
missing, unexpected = model.load_state_dict(sd, strict=False)
expected_lora = {k for k in model.state_dict() if "lora_" in k}
missing_lora = sorted(expected_lora.intersection(missing))
if missing_lora:
@@ -141,7 +145,7 @@ def load(model: nn.Module, path: str) -> list[RemovableHandle]:
unexpected_lora = [k for k in unexpected if "lora_" in k]
if unexpected_lora:
raise RuntimeError(f"unexpected lora keys in checkpoint: {unexpected_lora}")
saved_fp = blob.get("base_fp", {})
saved_fp = json.loads(metadata.get("base_fp", "{}"))
if saved_fp:
cur_fp = _base_weight_fingerprint(model)
diffs = [k for k in saved_fp if saved_fp[k] != cur_fp.get(k)]
+1 -1
View File
@@ -1 +1 @@
from . import lora, pissa, delora, ia3, dora, hra, eva, antipasto # noqa: F401 side-effect: register
from . import lora, pissa, delora, ia3, dora, hra, eva, antipasto, road # noqa: F401 side-effect: register
-1
View File
@@ -37,7 +37,6 @@ class LoRA:
@staticmethod
def init(layer: nn.Module, cfg) -> None:
# B is zeros => delta=0 at t=0; identity invariant holds.
return
@staticmethod
+137
View File
@@ -0,0 +1,137 @@
"""ROAD: Rotation ADaptation. https://arxiv.org/abs/2409.00119
ROAD applies a learned output-space block rotation/scaling after the frozen base
layer:
y' = R y = R (W x + b)
This matches PEFT's unmerged forward path and fits lora-lite as a simple output
hook. We implement the three PEFT variants (`road_1`, `road_2`, `road_4`) and
skip merge/unmerge because this library keeps adapters as hooks.
Refs:
- peft: https://github.com/huggingface/peft/blob/6030f9160ed2fc17220f6f41382a66f1257b6a93/src/peft/tuners/road/layer.py
"""
from dataclasses import dataclass
from typing import Literal
import torch
from jaxtyping import Float
from torch import nn, Tensor as T
from ..config import AdapterConfig, register_config
from ..variant import ParamSpec, register
RoadVariant = Literal["road_1", "road_2", "road_4"]
@register_config
@dataclass
class RoadConfig(AdapterConfig):
variant: str = "road"
road_variant: RoadVariant = "road_1"
group_size: int = 64
def _road_param_size(d_out: int, road_variant: str) -> int:
if road_variant == "road_1":
return d_out // 2
if road_variant == "road_2":
return d_out
if road_variant == "road_4":
return d_out * 2
raise ValueError(f"road_variant must be 'road_1', 'road_2', or 'road_4', got {road_variant!r}")
def _validate_group_geometry(d_out: int, group_size: int) -> None:
if group_size <= 0 or group_size % 2 != 0:
raise ValueError(f"ROAD group_size must be positive and even, got {group_size}")
if d_out % group_size != 0:
raise ValueError(f"ROAD d_out={d_out} must be divisible by group_size={group_size}")
def _prepare_cols(
road_variant: str,
group_size: int,
road_theta: torch.Tensor,
road_alpha: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
if road_variant == "road_1":
# One θ/α per pair. Reuse it for both rows of each 2D rotation block.
road_theta = road_theta.reshape(-1, group_size // 2).repeat_interleave(2, dim=0).flatten()
road_alpha = road_alpha.reshape(-1, group_size // 2).repeat_interleave(2, dim=0).flatten()
first_col = road_alpha * road_theta.cos()
second_col = road_alpha * road_theta.sin()
elif road_variant == "road_2":
# One θ/α per output coordinate.
first_col = road_alpha * road_theta.cos()
second_col = road_alpha * road_theta.sin()
elif road_variant == "road_4":
# Independent θ/α for the first and second column contributions.
road_theta = road_theta.reshape(-1, 2, group_size)
road_alpha = road_alpha.reshape(-1, 2, group_size)
first_col = road_alpha[:, 0, :].flatten() * road_theta[:, 0, :].cos().flatten()
second_col = road_alpha[:, 1, :].flatten() * road_theta[:, 1, :].sin().flatten()
else:
raise ValueError(f"road_variant must be 'road_1', 'road_2', or 'road_4', got {road_variant!r}")
return first_col, second_col
def _apply_road(
road_variant: str,
group_size: int,
road_theta: torch.Tensor,
road_alpha: torch.Tensor,
y: Float[T, '*B o'],
) -> Float[T, '*B o']:
first_col, second_col = _prepare_cols(road_variant, group_size, road_theta, road_alpha)
y_grouped = y.reshape(-1, 2, group_size // 2)
y1 = y_grouped[:, 0, :]
y2 = y_grouped[:, 1, :]
rotate_half_y = torch.stack((-y2, y1), dim=1).reshape(y.shape)
return y * first_col + rotate_half_y * second_col
def _road_matrix(
road_variant: str,
group_size: int,
road_theta: torch.Tensor,
road_alpha: torch.Tensor,
) -> torch.Tensor:
"""Explicit PEFT merge matrix. Used for tests and small-debug inspection."""
first_col, second_col = _prepare_cols(road_variant, group_size, road_theta, road_alpha)
size = second_col.shape[0]
output = torch.diag(first_col)
swapped_second_col = second_col.reshape(-1, 2, group_size // 2)[:, [1, 0], :].flatten()
rotated_diag_second_col = torch.diag(swapped_second_col).reshape(-1, 2, group_size // 2, size)[:, [1, 0], :, :]
rotated_diag_second_col[:, 0, :, :] *= -1
return output + rotated_diag_second_col.reshape(size, size)
@register
class ROAD:
name = "road"
@staticmethod
def param_specs(d_in: int, d_out: int, cfg: RoadConfig) -> dict[str, ParamSpec]:
_validate_group_geometry(d_out, cfg.group_size)
size = _road_param_size(d_out, cfg.road_variant)
return {
"lora_road_theta": ParamSpec((size,), init="zeros", trainable=True),
"lora_road_alpha": ParamSpec((size,), init="ones", trainable=True),
}
@staticmethod
def init(layer: nn.Module, cfg: RoadConfig) -> None:
return
@staticmethod
def forward(
layer: nn.Module,
x: Float[T, '*B i'],
y: Float[T, '*B o'],
) -> Float[T, '*B o']:
del x
cfg = layer._lora_cfg
y_cast = y.to(layer.lora_road_theta.dtype)
return _apply_road(cfg.road_variant, cfg.group_size, layer.lora_road_theta, layer.lora_road_alpha, y_cast)
+366
View File
@@ -0,0 +1,366 @@
"""Per-variant attach + train + save + load round-trip, plus surgical regressions.
The big invariant is the parametrized train_save_load test: identity at t=0,
gradient flow on a real loss, then save -> reload onto a fresh model and
confirm the trained outputs survive the round-trip. Cheap on CPU.
"""
from __future__ import annotations
from pathlib import Path
import pytest
import torch
from torch import nn
import lora_lite as ll
CFG_BY_VARIANT = {
"lora": ll.LoRAConfig,
"pissa": ll.PiSSAConfig,
"delora": ll.DeLoRAConfig,
"ia3": ll.IA3Config,
"ia3_ff": ll.IA3FFConfig,
"dora": ll.DoRAConfig,
"hra": ll.HRAConfig,
"eva": ll.EVAConfig,
"antipasto": ll.AntiPaSTOConfig,
"road": ll.RoadConfig,
}
# Per-variant identity tolerance at t=0 (after attach, before any step).
# fp32 SVD round-trip + per-row norm = looser tolerance for pissa/dora/antipasto.
IDENTITY_TOL = {
"lora": 1e-6,
"pissa": 5e-4,
"delora": 1e-6,
"ia3": 1e-6,
"ia3_ff": 1e-6,
"dora": 5e-5,
"hra": 5e-6,
"eva": 1e-6,
"antipasto": 5e-4,
"road": 1e-6,
}
class TinyBlock(nn.Module):
def __init__(self, d: int = 64, ff: int = 128):
super().__init__()
self.q_proj = nn.Linear(d, d, bias=False)
self.k_proj = nn.Linear(d, d, bias=False)
self.v_proj = nn.Linear(d, d, bias=False)
self.o_proj = nn.Linear(d, d, bias=False)
self.gate_proj = nn.Linear(d, ff, bias=False)
self.up_proj = nn.Linear(d, ff, bias=False)
self.down_proj = nn.Linear(ff, d, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
h = self.o_proj(self.q_proj(x) + self.k_proj(x) + self.v_proj(x))
m = self.down_proj(torch.nn.functional.silu(self.gate_proj(x)) * self.up_proj(x))
return x + h + m
class TinyModel(nn.Module):
def __init__(self, n_layers: int = 4, d: int = 64, ff: int = 128, vocab: int = 100):
super().__init__()
self.embed_tokens = nn.Embedding(vocab, d)
self.layers = nn.ModuleList([TinyBlock(d, ff) for _ in range(n_layers)])
self.lm_head = nn.Linear(d, vocab, bias=False)
self.config = type("Cfg", (), {"hidden_size": d})()
def forward(self, ids: torch.Tensor) -> torch.Tensor:
x = self.embed_tokens(ids)
for block in self.layers:
x = block(x)
return self.lm_head(x)
class FakeLinearLike(nn.Module):
"""linear-like, but not nn.Linear: stand-in for bnb 4/8-bit modules."""
def __init__(self, d_in: int = 8, d_out: int = 8):
super().__init__()
self.in_features = d_in
self.out_features = d_out
self.weight = nn.Parameter(torch.empty(d_out, d_in))
nn.init.kaiming_uniform_(self.weight, a=5 ** 0.5)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.nn.functional.linear(x, self.weight)
class FakeBnbModel(nn.Module):
def __init__(self):
super().__init__()
self.config = type("Cfg", (), {"hidden_size": 8})()
self.layers = nn.ModuleList([FakeLinearLike(8, 8)])
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.layers[0](x)
def cfg_for(variant: str) -> ll.AdapterConfig:
return CFG_BY_VARIANT[variant](
r=4,
alpha=8,
dtype=torch.float32,
)
def attach_with_calib(model: nn.Module, cfg: ll.AdapterConfig, ids: torch.Tensor) -> None:
if cfg.variant == "eva":
calib = [ids for _ in range(2)]
ll.attach(model, cfg, calibration_data=calib)
else:
ll.attach(model, cfg)
def trainable_grad_norm(model: nn.Module) -> float:
return sum(
p.grad.detach().float().norm().item()
for n, p in model.named_parameters()
if "lora_" in n and p.grad is not None
)
@pytest.mark.parametrize("variant", list(CFG_BY_VARIANT))
def test_train_save_load(variant: str, tmp_path: Path):
"""Identity at t=0, one SGD step, save, reload onto fresh model, outputs match."""
torch.manual_seed(0)
model = TinyModel()
ids = torch.randint(0, 100, (2, 16))
with torch.no_grad():
y_base = model(ids).clone()
cfg = cfg_for(variant)
attach_with_calib(model, cfg, ids)
trainable = [p for p in model.parameters() if p.requires_grad]
assert trainable
assert all("lora_" in n for n, p in model.named_parameters() if p.requires_grad)
with torch.no_grad():
y_init = model(ids).clone()
assert (y_init - y_base).abs().max().item() < IDENTITY_TOL[variant]
target = torch.randn_like(y_init) * 0.1
opt = torch.optim.SGD(trainable, lr=1e-2)
opt.zero_grad()
loss = (model(ids) - target).pow(2).mean()
loss.backward()
leaked = [n for n, p in model.named_parameters() if "lora_" not in n and p.grad is not None]
assert leaked == []
assert trainable_grad_norm(model) > 0
opt.step()
with torch.no_grad():
y_trained = model(ids).clone()
path = tmp_path / "adapter.pt"
ll.save(model, str(path))
torch.manual_seed(0)
model_loaded = TinyModel()
ll.load(model_loaded, str(path)) # EVA load skips group_init; calibration_data not needed
with torch.no_grad():
y_loaded = model_loaded(ids)
assert (y_loaded - y_trained).abs().max().item() < max(IDENTITY_TOL[variant], 1e-5)
@pytest.mark.parametrize("variant", ["lora", "delora", "ia3", "hra", "road"])
def test_hook_only_variants_attach_to_non_linear_target(variant: str):
"""bnb-style targets are linear-like but not nn.Linear; hook-only variants must accept them."""
extra = {"lambda0": 0.1} if variant == "delora" else {"group_size": 8} if variant == "road" else {}
cfg = CFG_BY_VARIANT[variant](r=2, alpha=4, dtype=torch.float32, target_roles=(), **extra)
model = FakeBnbModel()
ll.attach(model, cfg)
x = torch.randn(2, 3, 8)
model(x).pow(2).mean().backward()
assert trainable_grad_norm(model) > 0
@pytest.mark.parametrize("variant", ["pissa", "dora", "antipasto"])
def test_weight_reading_variants_reject_non_linear(variant: str):
r = 4 if variant == "antipasto" else 2 # antipasto needs r % block_size==0
cfg = CFG_BY_VARIANT[variant](r=r, alpha=r, dtype=torch.float32, target_roles=())
with pytest.raises(TypeError, match="plain nn.Linear"):
ll.attach(FakeBnbModel(), cfg)
def test_save_load_strict_keys(tmp_path: Path):
import json
from safetensors.torch import load_file, save_file
torch.manual_seed(0)
model = TinyModel()
ll.attach(model, ll.LoRAConfig(r=4, alpha=8, dtype=torch.float32))
p = tmp_path / "lora.safetensors"
ll.save(model, str(p))
sd = load_file(str(p), device="cpu")
# missing key: drop first lora key
missing_sd = dict(sd)
dropped_key = next(iter(missing_sd))
del missing_sd[dropped_key]
from safetensors import safe_open
with safe_open(str(p), framework="pt", device="cpu") as f:
meta = f.metadata()
save_file(missing_sd, str(p), metadata=meta)
with pytest.raises(RuntimeError, match="missing lora keys"):
ll.load(TinyModel(), str(p))
# unexpected key: add a bogus lora key
bad_sd = dict(sd)
bad_sd["layers.0.q_proj.lora_extra"] = torch.zeros(1)
save_file(bad_sd, str(p), metadata=meta)
with pytest.raises(RuntimeError, match="unexpected lora keys"):
ll.load(TinyModel(), str(p))
def test_no_target_layers_is_loud():
cfg = ll.LoRAConfig(target_names=("definitely_missing",))
with pytest.raises(RuntimeError, match="no target layers"):
ll.attach(TinyModel(), cfg)
def test_eva_requires_calibration():
"""EVA's group_init must error loudly if calibration_data is missing."""
with pytest.raises(ValueError, match="calibration_data"):
ll.attach(TinyModel(), ll.EVAConfig(r=4, alpha=8, dtype=torch.float32))
def test_delora_default_has_live_step0_gradient():
"""Default lambda0 must be nonzero; B=0 preserves identity while B gets gradient."""
torch.manual_seed(0)
model = TinyModel(n_layers=1)
ids = torch.randint(0, 100, (2, 8))
ll.attach(model, ll.DeLoRAConfig(r=4, alpha=8, dtype=torch.float32))
assert model.layers[0].q_proj.lora_lambda.item() == pytest.approx(15.0)
loss = model(ids).pow(2).mean()
loss.backward()
b_grad = model.layers[0].q_proj.lora_B.grad.detach().abs().max().item()
assert b_grad > 0
def test_pissa_identity_with_nonunit_scale():
"""Regression: PiSSA must pre-divide S by alpha/r, not require alpha == r."""
torch.manual_seed(0)
model = TinyModel(n_layers=1)
ids = torch.randint(0, 100, (2, 8))
with torch.no_grad():
y_base = model(ids).clone()
ll.attach(model, ll.PiSSAConfig(r=4, alpha=8, dtype=torch.float32))
with torch.no_grad():
y = model(ids)
assert (y - y_base).abs().max().item() < IDENTITY_TOL["pissa"]
def test_antipasto_blockwise_rotation_matches_explicit_blockdiag():
"""The einsum/rearrange path must equal the old explicit blockdiag math."""
from lora_lite.variants.antipasto import _build_rotation
torch.manual_seed(0)
n_blocks, bs, d_in, d_out = 3, 4, 7, 5
r = n_blocks * bs
rot_T = torch.randn(n_blocks, bs * (bs - 1) // 2) * 0.1
Vh = torch.randn(r, d_in)
U = torch.randn(d_out, r)
R_blocks = _build_rotation(rot_T, bs, 0.5)
R = torch.block_diag(*list(R_blocks))
Vh_blocks = torch.reshape(Vh, (n_blocks, bs, d_in))
Vh_rot = torch.einsum("nab,nbi->nai", R_blocks, Vh_blocks).reshape(r, d_in)
U_blocks = torch.reshape(U, (d_out, n_blocks, bs))
U_rot = torch.einsum("dnb,ncb->dnc", U_blocks, R_blocks).reshape(d_out, r)
assert (Vh_rot - R @ Vh).abs().max().item() < 1e-6
assert (U_rot - U @ R.T).abs().max().item() < 1e-6
def test_dora_bias_passthrough():
"""Regression: DoRA must NOT scale bias; identity holds with bias=True at t=0."""
torch.manual_seed(0)
d = 16
layer = nn.Linear(d, d, bias=True)
x = torch.randn(2, d)
y_base = layer(x).detach()
class Wrap(nn.Module):
def __init__(self, lin):
super().__init__()
self.config = type("Cfg", (), {"hidden_size": d})()
self.layers = nn.ModuleList([lin])
def forward(self, x):
return self.layers[0](x)
model = Wrap(layer)
ll.attach(model, ll.DoRAConfig(r=2, alpha=4, dtype=torch.float32, target_roles=()))
with torch.no_grad():
y = model(x)
assert (y - y_base).abs().max().item() < 1e-5
def test_hra_forward_is_x_R_T():
"""HRA must apply x @ R^T (loop i = r-1 down to 0). Asymmetric U makes order observable."""
torch.manual_seed(0)
d = 8
layer = nn.Linear(d, d, bias=False)
x = torch.randn(2, 3, d)
class Wrap(nn.Module):
def __init__(self, lin):
super().__init__()
self.config = type("Cfg", (), {"hidden_size": d})()
self.layers = nn.ModuleList([lin])
def forward(self, x):
return self.layers[0](x)
model = Wrap(layer)
ll.attach(model, ll.HRAConfig(r=4, alpha=4, dtype=torch.float32, target_roles=()))
# break paired symmetry so order matters
with torch.no_grad():
layer.lora_U.add_(0.1 * torch.randn_like(layer.lora_U))
U = layer.lora_U
R = torch.eye(d)
for i in range(U.shape[0]):
u = U[i]
sq = (u * u).sum().clamp_min(1e-12)
R = R - (2.0 / sq) * torch.outer(R @ u, u)
with torch.no_grad():
y_adapt = model(x)
y_ref = torch.nn.functional.linear(x, layer.weight @ R)
assert (y_adapt - y_ref).abs().max().item() < 1e-5
@pytest.mark.parametrize("road_variant", ["road_1", "road_2", "road_4"])
def test_road_apply_matches_explicit_matrix(road_variant: str):
"""Fast elementwise ROAD path must match PEFT's explicit R @ y matrix construction."""
from lora_lite.variants.road import _apply_road, _road_matrix, _road_param_size
torch.manual_seed(0)
d_out = 16
group_size = 8
size = _road_param_size(d_out, road_variant)
theta = torch.randn(size) * 0.2
alpha = torch.randn(size) * 0.1 + 1.0
y = torch.randn(2, 3, d_out)
y_fast = _apply_road(road_variant, group_size, theta, alpha, y)
R = _road_matrix(road_variant, group_size, theta, alpha)
y_ref = torch.einsum("oi,...i->...o", R, y)
assert (y_fast - y_ref).abs().max().item() < 1e-6
def test_road_invalid_group_size_is_loud():
with pytest.raises(ValueError, match="positive and even"):
ll.attach(TinyModel(), ll.RoadConfig(group_size=7))
with pytest.raises(ValueError, match="divisible"):
ll.attach(TinyModel(), ll.RoadConfig(group_size=48))
+2 -1
View File
@@ -31,7 +31,7 @@ sys.modules[SPEC.name] = benchmark
SPEC.loader.exec_module(benchmark)
VARIANTS = ["lora", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva", "antipasto"]
VARIANTS = ["lora", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva", "antipasto", "road"]
# Variants that fail loud when attached on a bnb-loaded base (read dense weight in init).
# delora/eva also read weight but currently silently dequant -- they produce sane attach,
# so we don't expect a raise from them in the attach-only smoke.
@@ -69,6 +69,7 @@ def quick_cfg(variant: str, tmp_path: Path, quantization: str = "none") -> "benc
max_seq_length=128,
max_new_tokens=8,
lr=5e-3,
road_group_size=8,
seed=0,
log_examples=0,
log_every=1000,
Generated
+3 -1
View File
@@ -12,7 +12,7 @@ resolution-markers = [
]
[options]
exclude-newer = "2026-04-21T14:06:04.428693663Z"
exclude-newer = "2026-04-22T01:31:09.47902161Z"
exclude-newer-span = "P5D"
[[package]]
@@ -1012,6 +1012,7 @@ dependencies = [
{ name = "einops" },
{ name = "jaxtyping", version = "0.3.7", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
{ name = "jaxtyping", version = "0.3.9", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" },
{ name = "safetensors" },
{ name = "torch" },
]
@@ -1051,6 +1052,7 @@ requires-dist = [
{ name = "einops", specifier = ">=0.7" },
{ name = "jaxtyping", specifier = ">=0.2.34" },
{ name = "pytest", marker = "extra == 'test'" },
{ name = "safetensors", specifier = ">=0.5" },
{ name = "safetensors", marker = "extra == 'benchmark'", specifier = ">=0.5" },
{ name = "safetensors", marker = "extra == 'hf-test'", specifier = ">=0.5" },
{ name = "tabulate", marker = "extra == 'benchmark'" },