diff --git a/pyproject.toml b/pyproject.toml index 61cde92..dc6fdbc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,6 +8,7 @@ dependencies = [ "torch>=2.1", "einops>=0.7", "jaxtyping>=0.2.34", + "safetensors>=0.5", ] keywords = ["lora", "pytorch", "peft", "adapters", "llm"] classifiers = [ diff --git a/scripts/metamath_gsm8k_benchmark.py b/scripts/metamath_gsm8k_benchmark.py index fb7069a..3244e73 100644 --- a/scripts/metamath_gsm8k_benchmark.py +++ b/scripts/metamath_gsm8k_benchmark.py @@ -33,6 +33,7 @@ CFG_BY_VARIANT = { "hra": ll.HRAConfig, "eva": ll.EVAConfig, "antipasto": ll.AntiPaSTOConfig, + "road": ll.RoadConfig, } @@ -41,7 +42,7 @@ class BenchmarkConfig: """MetaMathQA -> GSM8K benchmark config. Tyro turns this into the CLI.""" model: str = "Qwen/Qwen3-0.6B-Base" - variant: Literal["lora", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva", "antipasto"] = "lora" + variant: Literal["lora", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva", "antipasto", "road"] = "lora" mode: Literal["benchmark", "probe"] = "benchmark" device: str = "cuda" torch_dtype: str = "bfloat16" @@ -49,6 +50,7 @@ class BenchmarkConfig: r: int = 32 alpha: float = 64.0 delora_lambda0: float = 0.1 + road_group_size: int = 64 target_name: list[str] = field(default_factory=lambda: list(DEFAULT_TARGETS)) layers: str = "all" train_dataset: str = "meta-math/MetaMathQA" @@ -118,6 +120,8 @@ def parse_layers(text: str) -> tuple[int, ...] 
| None: def cfg_for_variant(args: BenchmarkConfig, dtype: torch.dtype) -> ll.AdapterConfig: extra = {"lambda0": args.delora_lambda0} if args.variant == "delora" else {} + if args.variant == "road": + extra = {"group_size": args.road_group_size} return CFG_BY_VARIANT[args.variant]( r=args.r, alpha=args.r if args.variant == "pissa" else args.alpha, @@ -147,7 +151,7 @@ def count_base_grad_leaks(model: torch.nn.Module) -> int: def perturb_first_adapter(model: torch.nn.Module) -> None: - priority = ("lora_B", "lora_g", "lora_U", "lora_A", "lora_lambda", "lora_gate", "lora_delta_s", "lora_rot_T", "lora_m") + priority = ("lora_B", "lora_g", "lora_U", "lora_A", "lora_lambda", "lora_gate", "lora_delta_s", "lora_rot_T", "lora_m", "lora_road_theta", "lora_road_alpha") for key in priority: for _, p in model.named_parameters(): if p.requires_grad and key in _: @@ -409,11 +413,12 @@ def check_probe_reload( loaded_model, _ = load_model_and_tokenizer(args.model, getattr(torch, args.torch_dtype), args.device, args.quantization) loaded_model.eval() ll.load(loaded_model, str(adapter_path)) - saved = torch.load(adapter_path, weights_only=True, map_location="cpu") + from safetensors.torch import load_file + saved_sd = load_file(str(adapter_path), device="cpu") loaded_state = adapter_state(loaded_model) - if set(saved["state"]) != set(loaded_state): + if set(saved_sd) != set(loaded_state): raise AssertionError("loaded adapter keys differ from saved adapter keys") - for name, value in saved["state"].items(): + for name, value in saved_sd.items(): if not torch.equal(loaded_state[name].cpu(), value): raise AssertionError(f"loaded adapter tensor differs: {name}") logits_loaded = loaded_model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]).logits.detach().clone() @@ -426,12 +431,21 @@ def check_probe_reload( - return {"reload_err": reload_err, "saved_tensors": len(saved["state"])} + return {"reload_err": reload_err, "saved_tensors": len(saved_sd)} -def print_final_report(row: dict[str, Any], result_path: Path) -> None: - print("SHOULD: 
grad>0, dθ>0, base_grad_leaks=0, valid/test fields present; probeΔ<0 is good but not required for tiny random smoke. ELSE adapter or eval wiring is dead/wrong.") +def print_final_report(row: dict[str, Any], result_path: Path, mode: str) -> None: + # BLUF: status line first so log tails are immediately readable + cue = "🟢" if row.get("base_grad_leaks", 0) == 0 and row.get("grad", 0) > 0 else "🔴" + n = row.get("samples", "?") + print(f"{cue} test_acc={row['test_acc']:.4g} valid_acc={row['valid_acc']:.4g} grad={row['grad']:.3g} dθ={row['dθ']:.3g} base_grad_leaks={row['base_grad_leaks']} N={n}") + print("SHOULD: grad>0, dθ>0, base_grad_leaks=0; test/valid_acc meaningful only in benchmark mode. ELSE adapter or eval wiring is dead/wrong.") + # ordered: most important / shortest columns first + display_keys = ["variant", "test_acc", "valid_acc", "grad", "dθ", "base_grad_leaks", "steps", "samples", "loss0", "lossN", "commit"] + if "perturb" in row: + display_keys += ["perturb", "reload"] + display_keys += ["run_id"] + display_row = {k: row[k] for k in display_keys if k in row} + print(tabulate([display_row], headers="keys", tablefmt="tsv", floatfmt=".4g")) + print(f"argv: {' '.join(sys.argv)} N={n} mode={mode}") print(f"out: {result_path}") - print(f"argv: {' '.join(sys.argv)}") - print(f"main metric: test_acc={row['test_acc']:.4g} valid_acc={row['valid_acc']:.4g} steps={row['steps']} samples={row['samples']}") - print(tabulate([row], headers="keys", tablefmt="tsv", floatfmt=".4g")) def current_git_commit() -> str: @@ -447,25 +461,24 @@ def append_results_row( result: dict[str, Any], run_commit: str, ) -> tuple[Path, Path]: - results_dir = args.output_dir / "results" + results_dir = args.output_dir results_dir.mkdir(parents=True, exist_ok=True) - tsv_path = results_dir / "benchmark_results.tsv" - lock_path = results_dir / "benchmark_results.tsv.lock" + tsv_path = results_dir / "summary.tsv" + lock_path = results_dir / "summary.tsv.lock" finished_at = 
datetime.now(timezone.utc).isoformat(timespec="seconds") finished_label = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ") snapshot_path = results_dir / f"{result['run_id']}__{finished_label}.json" snapshot_path.write_text(json.dumps(result, indent=2), encoding="utf-8") row = { - "time_utc": finished_at, - "commit": run_commit, + "test_acc": result["test_acc"], + "valid_acc": result["valid_acc"], "method": args.variant, - "model": args.model, - "mode": args.mode, - "valid_accuracy": result["valid_accuracy"], - "test_accuracy": result["test_accuracy"], "steps": args.steps, "samples": result["train_samples"], - "wall_time_s": result["wall_time_s"], + "model": args.model, + "commit": run_commit[:12], + "wall_time_s": round(result["wall_time_s"]), + "time_utc": finished_at, "argv": " ".join(sys.argv), "result_json": str(snapshot_path), "latest_result_json": str(result_path), @@ -517,7 +530,7 @@ def run(args: BenchmarkConfig) -> dict[str, Any]: valid_metrics = evaluate(model, tokenizer, datasets["valid"], args, "valid") test_metrics = evaluate(model, tokenizer, datasets["test"], args, "test") - adapter_path = out_dir / "adapter.pt" + adapter_path = out_dir / "adapter.safetensors" ll.save(model, str(adapter_path)) if args.mode == "probe": model.eval() @@ -555,8 +568,8 @@ def run(args: BenchmarkConfig) -> dict[str, Any]: "weight_decay": args.weight_decay, "lr_scheduler": "cosine", "grad_norm_clip": args.grad_norm_clip, - "valid_accuracy": valid_metrics["accuracy"], - "test_accuracy": test_metrics["accuracy"], + "valid_acc": valid_metrics["accuracy"], + "test_acc": test_metrics["accuracy"], "train": train_metrics, "valid": valid_metrics, "test": test_metrics, @@ -566,12 +579,12 @@ def run(args: BenchmarkConfig) -> dict[str, Any]: } result_path = out_dir / "result.json" result_path.write_text(json.dumps(result, indent=2), encoding="utf-8") - results_tsv_path, result_snapshot_path = append_results_row(args, result_path, result, run_commit) - result["results_tsv_path"] = 
str(results_tsv_path) - result["result_snapshot_path"] = str(result_snapshot_path) + if args.mode == "benchmark": + results_tsv_path, result_snapshot_path = append_results_row(args, result_path, result, run_commit) + result["results_tsv_path"] = str(results_tsv_path) + result["result_snapshot_path"] = str(result_snapshot_path) result["commit"] = run_commit result_path.write_text(json.dumps(result, indent=2), encoding="utf-8") - commit_prefix = run_commit[:12] row = { "run_id": run_id, @@ -592,7 +605,7 @@ def run(args: BenchmarkConfig) -> dict[str, Any]: if probe_metrics is not None: row["perturb"] = probe_metrics["perturb_delta"] row["reload"] = probe_metrics["reload_err"] - print_final_report(row, result_path) + print_final_report(row, result_path, args.mode) return result diff --git a/src/lora_lite/__init__.py b/src/lora_lite/__init__.py index 438ba72..588e394 100644 --- a/src/lora_lite/__init__.py +++ b/src/lora_lite/__init__.py @@ -20,6 +20,7 @@ from .variants.dora import DoRAConfig from .variants.hra import HRAConfig from .variants.eva import EVAConfig from .variants.antipasto import AntiPaSTOConfig +from .variants.road import RoadConfig __all__ = [ "AdapterConfig", @@ -32,6 +33,7 @@ __all__ = [ "HRAConfig", "EVAConfig", "AntiPaSTOConfig", + "RoadConfig", "attach", "detach", "save", diff --git a/src/lora_lite/adapter.py b/src/lora_lite/adapter.py index 62435b6..6dbc0b1 100644 --- a/src/lora_lite/adapter.py +++ b/src/lora_lite/adapter.py @@ -1,5 +1,6 @@ """attach / detach / save / load. 
The whole runtime.""" from __future__ import annotations +import json import torch from torch import nn from torch.utils.hooks import RemovableHandle @@ -121,19 +122,22 @@ def save(model: nn.Module, path: str) -> None: if state is None: raise RuntimeError("no adapter attached; call attach() first") sd = {k: v.detach().cpu() for k, v in model.state_dict().items() if "lora_" in k} - blob = { - "cfg": state["cfg"].to_dict(), - "state": sd, - "base_fp": _base_weight_fingerprint(model), + metadata = { + "cfg": json.dumps(state["cfg"].to_dict()), + "base_fp": json.dumps(_base_weight_fingerprint(model)), } - torch.save(blob, path) + from safetensors.torch import save_file + save_file(sd, path, metadata=metadata) def load(model: nn.Module, path: str) -> list[RemovableHandle]: - blob = torch.load(path, weights_only=True, map_location="cpu") - cfg = AdapterConfig.from_dict(blob["cfg"]) + from safetensors.torch import load_file, safe_open + with safe_open(path, framework="pt", device="cpu") as f: + metadata = f.metadata() + sd = load_file(path, device="cpu") + cfg = AdapterConfig.from_dict(json.loads(metadata["cfg"])) handles = attach(model, cfg, _skip_group_init=True) # creates empty params; data-driven inits restored from state_dict - missing, unexpected = model.load_state_dict(blob["state"], strict=False) + missing, unexpected = model.load_state_dict(sd, strict=False) expected_lora = {k for k in model.state_dict() if "lora_" in k} missing_lora = sorted(expected_lora.intersection(missing)) if missing_lora: @@ -141,7 +145,7 @@ def load(model: nn.Module, path: str) -> list[RemovableHandle]: unexpected_lora = [k for k in unexpected if "lora_" in k] if unexpected_lora: raise RuntimeError(f"unexpected lora keys in checkpoint: {unexpected_lora}") - saved_fp = blob.get("base_fp", {}) + saved_fp = json.loads(metadata.get("base_fp", "{}")) if saved_fp: cur_fp = _base_weight_fingerprint(model) diffs = [k for k in saved_fp if saved_fp[k] != cur_fp.get(k)] diff --git 
a/src/lora_lite/variants/__init__.py b/src/lora_lite/variants/__init__.py index 9c371e1..fc39fad 100644 --- a/src/lora_lite/variants/__init__.py +++ b/src/lora_lite/variants/__init__.py @@ -1 +1 @@ -from . import lora, pissa, delora, ia3, dora, hra, eva, antipasto # noqa: F401 side-effect: register +from . import lora, pissa, delora, ia3, dora, hra, eva, antipasto, road # noqa: F401 side-effect: register diff --git a/src/lora_lite/variants/lora.py b/src/lora_lite/variants/lora.py index feefe1a..34eaae0 100644 --- a/src/lora_lite/variants/lora.py +++ b/src/lora_lite/variants/lora.py @@ -37,7 +37,6 @@ class LoRA: @staticmethod def init(layer: nn.Module, cfg) -> None: - # B is zeros => delta=0 at t=0; identity invariant holds. return @staticmethod diff --git a/src/lora_lite/variants/road.py b/src/lora_lite/variants/road.py new file mode 100644 index 0000000..e87bd35 --- /dev/null +++ b/src/lora_lite/variants/road.py @@ -0,0 +1,137 @@ +"""ROAD: Rotation ADaptation. https://arxiv.org/abs/2409.00119 + +ROAD applies a learned output-space block rotation/scaling after the frozen base +layer: + + y' = R y = R (W x + b) + +This matches PEFT's unmerged forward path and fits lora-lite as a simple output +hook. We implement the three PEFT variants (`road_1`, `road_2`, `road_4`) and +skip merge/unmerge because this library keeps adapters as hooks. 
+ +Refs: + - peft: https://github.com/huggingface/peft/blob/6030f9160ed2fc17220f6f41382a66f1257b6a93/src/peft/tuners/road/layer.py +""" +from dataclasses import dataclass +from typing import Literal + +import torch +from jaxtyping import Float +from torch import nn, Tensor as T + +from ..config import AdapterConfig, register_config +from ..variant import ParamSpec, register + +RoadVariant = Literal["road_1", "road_2", "road_4"] + + +@register_config +@dataclass +class RoadConfig(AdapterConfig): + variant: str = "road" + road_variant: RoadVariant = "road_1" + group_size: int = 64 + + +def _road_param_size(d_out: int, road_variant: str) -> int: + if road_variant == "road_1": + return d_out // 2 + if road_variant == "road_2": + return d_out + if road_variant == "road_4": + return d_out * 2 + raise ValueError(f"road_variant must be 'road_1', 'road_2', or 'road_4', got {road_variant!r}") + + +def _validate_group_geometry(d_out: int, group_size: int) -> None: + if group_size <= 0 or group_size % 2 != 0: + raise ValueError(f"ROAD group_size must be positive and even, got {group_size}") + if d_out % group_size != 0: + raise ValueError(f"ROAD d_out={d_out} must be divisible by group_size={group_size}") + + +def _prepare_cols( + road_variant: str, + group_size: int, + road_theta: torch.Tensor, + road_alpha: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + if road_variant == "road_1": + # One θ/α per pair. Reuse it for both rows of each 2D rotation block. + road_theta = road_theta.reshape(-1, group_size // 2).repeat_interleave(2, dim=0).flatten() + road_alpha = road_alpha.reshape(-1, group_size // 2).repeat_interleave(2, dim=0).flatten() + first_col = road_alpha * road_theta.cos() + second_col = road_alpha * road_theta.sin() + elif road_variant == "road_2": + # One θ/α per output coordinate. 
+ first_col = road_alpha * road_theta.cos() + second_col = road_alpha * road_theta.sin() + elif road_variant == "road_4": + # Independent θ/α for the first and second column contributions. + road_theta = road_theta.reshape(-1, 2, group_size) + road_alpha = road_alpha.reshape(-1, 2, group_size) + first_col = road_alpha[:, 0, :].flatten() * road_theta[:, 0, :].cos().flatten() + second_col = road_alpha[:, 1, :].flatten() * road_theta[:, 1, :].sin().flatten() + else: + raise ValueError(f"road_variant must be 'road_1', 'road_2', or 'road_4', got {road_variant!r}") + return first_col, second_col + + +def _apply_road( + road_variant: str, + group_size: int, + road_theta: torch.Tensor, + road_alpha: torch.Tensor, + y: Float[T, '*B o'], +) -> Float[T, '*B o']: + first_col, second_col = _prepare_cols(road_variant, group_size, road_theta, road_alpha) + y_grouped = y.reshape(-1, 2, group_size // 2) + y1 = y_grouped[:, 0, :] + y2 = y_grouped[:, 1, :] + rotate_half_y = torch.stack((-y2, y1), dim=1).reshape(y.shape) + return y * first_col + rotate_half_y * second_col + + +def _road_matrix( + road_variant: str, + group_size: int, + road_theta: torch.Tensor, + road_alpha: torch.Tensor, +) -> torch.Tensor: + """Explicit PEFT merge matrix. 
Used for tests and small-debug inspection.""" + first_col, second_col = _prepare_cols(road_variant, group_size, road_theta, road_alpha) + size = second_col.shape[0] + output = torch.diag(first_col) + swapped_second_col = second_col.reshape(-1, 2, group_size // 2)[:, [1, 0], :].flatten() + rotated_diag_second_col = torch.diag(swapped_second_col).reshape(-1, 2, group_size // 2, size)[:, [1, 0], :, :] + rotated_diag_second_col[:, 0, :, :] *= -1 + return output + rotated_diag_second_col.reshape(size, size) + + +@register +class ROAD: + name = "road" + + @staticmethod + def param_specs(d_in: int, d_out: int, cfg: RoadConfig) -> dict[str, ParamSpec]: + _validate_group_geometry(d_out, cfg.group_size) + size = _road_param_size(d_out, cfg.road_variant) + return { + "lora_road_theta": ParamSpec((size,), init="zeros", trainable=True), + "lora_road_alpha": ParamSpec((size,), init="ones", trainable=True), + } + + @staticmethod + def init(layer: nn.Module, cfg: RoadConfig) -> None: + return + + @staticmethod + def forward( + layer: nn.Module, + x: Float[T, '*B i'], + y: Float[T, '*B o'], + ) -> Float[T, '*B o']: + del x + cfg = layer._lora_cfg + y_cast = y.to(layer.lora_road_theta.dtype) + return _apply_road(cfg.road_variant, cfg.group_size, layer.lora_road_theta, layer.lora_road_alpha, y_cast).to(y.dtype) diff --git a/tests/test_lora_lite.py b/tests/test_lora_lite.py new file mode 100644 index 0000000..f2eb698 --- /dev/null +++ b/tests/test_lora_lite.py @@ -0,0 +1,366 @@ +"""Per-variant attach + train + save + load round-trip, plus surgical regressions. + +The big invariant is the parametrized train_save_load test: identity at t=0, +gradient flow on a real loss, then save -> reload onto a fresh model and +confirm the trained outputs survive the round-trip. Cheap on CPU. 
+""" +from __future__ import annotations + +from pathlib import Path + +import pytest +import torch +from torch import nn + +import lora_lite as ll + + +CFG_BY_VARIANT = { + "lora": ll.LoRAConfig, + "pissa": ll.PiSSAConfig, + "delora": ll.DeLoRAConfig, + "ia3": ll.IA3Config, + "ia3_ff": ll.IA3FFConfig, + "dora": ll.DoRAConfig, + "hra": ll.HRAConfig, + "eva": ll.EVAConfig, + "antipasto": ll.AntiPaSTOConfig, + "road": ll.RoadConfig, +} + +# Per-variant identity tolerance at t=0 (after attach, before any step). +# fp32 SVD round-trip + per-row norm = looser tolerance for pissa/dora/antipasto. +IDENTITY_TOL = { + "lora": 1e-6, + "pissa": 5e-4, + "delora": 1e-6, + "ia3": 1e-6, + "ia3_ff": 1e-6, + "dora": 5e-5, + "hra": 5e-6, + "eva": 1e-6, + "antipasto": 5e-4, + "road": 1e-6, +} + + +class TinyBlock(nn.Module): + def __init__(self, d: int = 64, ff: int = 128): + super().__init__() + self.q_proj = nn.Linear(d, d, bias=False) + self.k_proj = nn.Linear(d, d, bias=False) + self.v_proj = nn.Linear(d, d, bias=False) + self.o_proj = nn.Linear(d, d, bias=False) + self.gate_proj = nn.Linear(d, ff, bias=False) + self.up_proj = nn.Linear(d, ff, bias=False) + self.down_proj = nn.Linear(ff, d, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = self.o_proj(self.q_proj(x) + self.k_proj(x) + self.v_proj(x)) + m = self.down_proj(torch.nn.functional.silu(self.gate_proj(x)) * self.up_proj(x)) + return x + h + m + + +class TinyModel(nn.Module): + def __init__(self, n_layers: int = 4, d: int = 64, ff: int = 128, vocab: int = 100): + super().__init__() + self.embed_tokens = nn.Embedding(vocab, d) + self.layers = nn.ModuleList([TinyBlock(d, ff) for _ in range(n_layers)]) + self.lm_head = nn.Linear(d, vocab, bias=False) + self.config = type("Cfg", (), {"hidden_size": d})() + + def forward(self, ids: torch.Tensor) -> torch.Tensor: + x = self.embed_tokens(ids) + for block in self.layers: + x = block(x) + return self.lm_head(x) + + +class FakeLinearLike(nn.Module): + 
"""linear-like, but not nn.Linear: stand-in for bnb 4/8-bit modules.""" + + def __init__(self, d_in: int = 8, d_out: int = 8): + super().__init__() + self.in_features = d_in + self.out_features = d_out + self.weight = nn.Parameter(torch.empty(d_out, d_in)) + nn.init.kaiming_uniform_(self.weight, a=5 ** 0.5) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.linear(x, self.weight) + + +class FakeBnbModel(nn.Module): + def __init__(self): + super().__init__() + self.config = type("Cfg", (), {"hidden_size": 8})() + self.layers = nn.ModuleList([FakeLinearLike(8, 8)]) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.layers[0](x) + + +def cfg_for(variant: str) -> ll.AdapterConfig: + return CFG_BY_VARIANT[variant]( + r=4, + alpha=8, + dtype=torch.float32, + ) + + +def attach_with_calib(model: nn.Module, cfg: ll.AdapterConfig, ids: torch.Tensor) -> None: + if cfg.variant == "eva": + calib = [ids for _ in range(2)] + ll.attach(model, cfg, calibration_data=calib) + else: + ll.attach(model, cfg) + + +def trainable_grad_norm(model: nn.Module) -> float: + return sum( + p.grad.detach().float().norm().item() + for n, p in model.named_parameters() + if "lora_" in n and p.grad is not None + ) + + +@pytest.mark.parametrize("variant", list(CFG_BY_VARIANT)) +def test_train_save_load(variant: str, tmp_path: Path): + """Identity at t=0, one SGD step, save, reload onto fresh model, outputs match.""" + torch.manual_seed(0) + model = TinyModel() + ids = torch.randint(0, 100, (2, 16)) + with torch.no_grad(): + y_base = model(ids).clone() + + cfg = cfg_for(variant) + attach_with_calib(model, cfg, ids) + + trainable = [p for p in model.parameters() if p.requires_grad] + assert trainable + assert all("lora_" in n for n, p in model.named_parameters() if p.requires_grad) + + with torch.no_grad(): + y_init = model(ids).clone() + assert (y_init - y_base).abs().max().item() < IDENTITY_TOL[variant] + + target = torch.randn_like(y_init) * 0.1 + 
opt = torch.optim.SGD(trainable, lr=1e-2) + opt.zero_grad() + loss = (model(ids) - target).pow(2).mean() + loss.backward() + leaked = [n for n, p in model.named_parameters() if "lora_" not in n and p.grad is not None] + assert leaked == [] + assert trainable_grad_norm(model) > 0 + opt.step() + + with torch.no_grad(): + y_trained = model(ids).clone() + + path = tmp_path / "adapter.safetensors" + ll.save(model, str(path)) + + torch.manual_seed(0) + model_loaded = TinyModel() + ll.load(model_loaded, str(path)) # EVA load skips group_init; calibration_data not needed + with torch.no_grad(): + y_loaded = model_loaded(ids) + assert (y_loaded - y_trained).abs().max().item() < max(IDENTITY_TOL[variant], 1e-5) + + +@pytest.mark.parametrize("variant", ["lora", "delora", "ia3", "hra", "road"]) +def test_hook_only_variants_attach_to_non_linear_target(variant: str): + """bnb-style targets are linear-like but not nn.Linear; hook-only variants must accept them.""" + extra = {"lambda0": 0.1} if variant == "delora" else {"group_size": 8} if variant == "road" else {} + cfg = CFG_BY_VARIANT[variant](r=2, alpha=4, dtype=torch.float32, target_roles=(), **extra) + model = FakeBnbModel() + ll.attach(model, cfg) + x = torch.randn(2, 3, 8) + model(x).pow(2).mean().backward() + assert trainable_grad_norm(model) > 0 + + +@pytest.mark.parametrize("variant", ["pissa", "dora", "antipasto"]) +def test_weight_reading_variants_reject_non_linear(variant: str): + r = 4 if variant == "antipasto" else 2 # antipasto needs r % block_size==0 + cfg = CFG_BY_VARIANT[variant](r=r, alpha=r, dtype=torch.float32, target_roles=()) + with pytest.raises(TypeError, match="plain nn.Linear"): + ll.attach(FakeBnbModel(), cfg) + + +def test_save_load_strict_keys(tmp_path: Path): + import json + from safetensors.torch import load_file, save_file + + torch.manual_seed(0) + model = TinyModel() + ll.attach(model, ll.LoRAConfig(r=4, alpha=8, dtype=torch.float32)) + p = tmp_path / "lora.safetensors" + ll.save(model, str(p)) + sd = 
load_file(str(p), device="cpu") + + # missing key: drop first lora key + missing_sd = dict(sd) + dropped_key = next(iter(missing_sd)) + del missing_sd[dropped_key] + from safetensors import safe_open + with safe_open(str(p), framework="pt", device="cpu") as f: + meta = f.metadata() + save_file(missing_sd, str(p), metadata=meta) + with pytest.raises(RuntimeError, match="missing lora keys"): + ll.load(TinyModel(), str(p)) + + # unexpected key: add a bogus lora key + bad_sd = dict(sd) + bad_sd["layers.0.q_proj.lora_extra"] = torch.zeros(1) + save_file(bad_sd, str(p), metadata=meta) + with pytest.raises(RuntimeError, match="unexpected lora keys"): + ll.load(TinyModel(), str(p)) + + +def test_no_target_layers_is_loud(): + cfg = ll.LoRAConfig(target_names=("definitely_missing",)) + with pytest.raises(RuntimeError, match="no target layers"): + ll.attach(TinyModel(), cfg) + + +def test_eva_requires_calibration(): + """EVA's group_init must error loudly if calibration_data is missing.""" + with pytest.raises(ValueError, match="calibration_data"): + ll.attach(TinyModel(), ll.EVAConfig(r=4, alpha=8, dtype=torch.float32)) + + +def test_delora_default_has_live_step0_gradient(): + """Default lambda0 must be nonzero; B=0 preserves identity while B gets gradient.""" + torch.manual_seed(0) + model = TinyModel(n_layers=1) + ids = torch.randint(0, 100, (2, 8)) + ll.attach(model, ll.DeLoRAConfig(r=4, alpha=8, dtype=torch.float32)) + + assert model.layers[0].q_proj.lora_lambda.item() == pytest.approx(15.0) + loss = model(ids).pow(2).mean() + loss.backward() + + b_grad = model.layers[0].q_proj.lora_B.grad.detach().abs().max().item() + assert b_grad > 0 + + +def test_pissa_identity_with_nonunit_scale(): + """Regression: PiSSA must pre-divide S by alpha/r, not require alpha == r.""" + torch.manual_seed(0) + model = TinyModel(n_layers=1) + ids = torch.randint(0, 100, (2, 8)) + with torch.no_grad(): + y_base = model(ids).clone() + + ll.attach(model, ll.PiSSAConfig(r=4, alpha=8, 
dtype=torch.float32)) + with torch.no_grad(): + y = model(ids) + assert (y - y_base).abs().max().item() < IDENTITY_TOL["pissa"] + + +def test_antipasto_blockwise_rotation_matches_explicit_blockdiag(): + """The einsum/rearrange path must equal the old explicit blockdiag math.""" + from lora_lite.variants.antipasto import _build_rotation + + torch.manual_seed(0) + n_blocks, bs, d_in, d_out = 3, 4, 7, 5 + r = n_blocks * bs + rot_T = torch.randn(n_blocks, bs * (bs - 1) // 2) * 0.1 + Vh = torch.randn(r, d_in) + U = torch.randn(d_out, r) + R_blocks = _build_rotation(rot_T, bs, 0.5) + R = torch.block_diag(*list(R_blocks)) + + Vh_blocks = torch.reshape(Vh, (n_blocks, bs, d_in)) + Vh_rot = torch.einsum("nab,nbi->nai", R_blocks, Vh_blocks).reshape(r, d_in) + U_blocks = torch.reshape(U, (d_out, n_blocks, bs)) + U_rot = torch.einsum("dnb,ncb->dnc", U_blocks, R_blocks).reshape(d_out, r) + + assert (Vh_rot - R @ Vh).abs().max().item() < 1e-6 + assert (U_rot - U @ R.T).abs().max().item() < 1e-6 + + +def test_dora_bias_passthrough(): + """Regression: DoRA must NOT scale bias; identity holds with bias=True at t=0.""" + torch.manual_seed(0) + d = 16 + layer = nn.Linear(d, d, bias=True) + x = torch.randn(2, d) + y_base = layer(x).detach() + + class Wrap(nn.Module): + def __init__(self, lin): + super().__init__() + self.config = type("Cfg", (), {"hidden_size": d})() + self.layers = nn.ModuleList([lin]) + + def forward(self, x): + return self.layers[0](x) + + model = Wrap(layer) + ll.attach(model, ll.DoRAConfig(r=2, alpha=4, dtype=torch.float32, target_roles=())) + with torch.no_grad(): + y = model(x) + assert (y - y_base).abs().max().item() < 1e-5 + + +def test_hra_forward_is_x_R_T(): + """HRA must apply x @ R^T (loop i = r-1 down to 0). 
Asymmetric U makes order observable.""" + torch.manual_seed(0) + d = 8 + layer = nn.Linear(d, d, bias=False) + x = torch.randn(2, 3, d) + + class Wrap(nn.Module): + def __init__(self, lin): + super().__init__() + self.config = type("Cfg", (), {"hidden_size": d})() + self.layers = nn.ModuleList([lin]) + + def forward(self, x): + return self.layers[0](x) + + model = Wrap(layer) + ll.attach(model, ll.HRAConfig(r=4, alpha=4, dtype=torch.float32, target_roles=())) + # break paired symmetry so order matters + with torch.no_grad(): + layer.lora_U.add_(0.1 * torch.randn_like(layer.lora_U)) + + U = layer.lora_U + R = torch.eye(d) + for i in range(U.shape[0]): + u = U[i] + sq = (u * u).sum().clamp_min(1e-12) + R = R - (2.0 / sq) * torch.outer(R @ u, u) + with torch.no_grad(): + y_adapt = model(x) + y_ref = torch.nn.functional.linear(x, layer.weight @ R) + assert (y_adapt - y_ref).abs().max().item() < 1e-5 + + +@pytest.mark.parametrize("road_variant", ["road_1", "road_2", "road_4"]) +def test_road_apply_matches_explicit_matrix(road_variant: str): + """Fast elementwise ROAD path must match PEFT's explicit R @ y matrix construction.""" + from lora_lite.variants.road import _apply_road, _road_matrix, _road_param_size + + torch.manual_seed(0) + d_out = 16 + group_size = 8 + size = _road_param_size(d_out, road_variant) + theta = torch.randn(size) * 0.2 + alpha = torch.randn(size) * 0.1 + 1.0 + y = torch.randn(2, 3, d_out) + + y_fast = _apply_road(road_variant, group_size, theta, alpha, y) + R = _road_matrix(road_variant, group_size, theta, alpha) + y_ref = torch.einsum("oi,...i->...o", R, y) + + assert (y_fast - y_ref).abs().max().item() < 1e-6 + + +def test_road_invalid_group_size_is_loud(): + with pytest.raises(ValueError, match="positive and even"): + ll.attach(TinyModel(), ll.RoadConfig(group_size=7)) + with pytest.raises(ValueError, match="divisible"): + ll.attach(TinyModel(), ll.RoadConfig(group_size=48)) diff --git a/tests/test_metamath_smoke.py 
b/tests/test_metamath_smoke.py index ededfc9..5a5283a 100644 --- a/tests/test_metamath_smoke.py +++ b/tests/test_metamath_smoke.py @@ -31,7 +31,7 @@ sys.modules[SPEC.name] = benchmark SPEC.loader.exec_module(benchmark) -VARIANTS = ["lora", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva", "antipasto"] +VARIANTS = ["lora", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva", "antipasto", "road"] # Variants that fail loud when attached on a bnb-loaded base (read dense weight in init). # delora/eva also read weight but currently silently dequant -- they produce sane attach, # so we don't expect a raise from them in the attach-only smoke. @@ -69,6 +69,7 @@ def quick_cfg(variant: str, tmp_path: Path, quantization: str = "none") -> "benc max_seq_length=128, max_new_tokens=8, lr=5e-3, + road_group_size=8, seed=0, log_examples=0, log_every=1000, diff --git a/uv.lock b/uv.lock index d4cd96e..55dedd2 100644 --- a/uv.lock +++ b/uv.lock @@ -12,7 +12,7 @@ resolution-markers = [ ] [options] -exclude-newer = "2026-04-21T14:06:04.428693663Z" +exclude-newer = "2026-04-22T01:31:09.47902161Z" exclude-newer-span = "P5D" [[package]] @@ -1012,6 +1012,7 @@ dependencies = [ { name = "einops" }, { name = "jaxtyping", version = "0.3.7", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "jaxtyping", version = "0.3.9", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "safetensors" }, { name = "torch" }, ] @@ -1051,6 +1052,7 @@ requires-dist = [ { name = "einops", specifier = ">=0.7" }, { name = "jaxtyping", specifier = ">=0.2.34" }, { name = "pytest", marker = "extra == 'test'" }, + { name = "safetensors", specifier = ">=0.5" }, { name = "safetensors", marker = "extra == 'benchmark'", specifier = ">=0.5" }, { name = "safetensors", marker = "extra == 'hf-test'", specifier = ">=0.5" }, { name = "tabulate", marker = "extra == 'benchmark'" },