simpler test

This commit is contained in:
wassname
2026-04-27 09:47:07 +08:00
parent b60a8c3f9b
commit 24ba8deb02
10 changed files with 566 additions and 41 deletions
+1
View File
@@ -8,6 +8,7 @@ dependencies = [
"torch>=2.1", "torch>=2.1",
"einops>=0.7", "einops>=0.7",
"jaxtyping>=0.2.34", "jaxtyping>=0.2.34",
"safetensors>=0.5",
] ]
keywords = ["lora", "pytorch", "peft", "adapters", "llm"] keywords = ["lora", "pytorch", "peft", "adapters", "llm"]
classifiers = [ classifiers = [
+41 -28
View File
@@ -33,6 +33,7 @@ CFG_BY_VARIANT = {
"hra": ll.HRAConfig, "hra": ll.HRAConfig,
"eva": ll.EVAConfig, "eva": ll.EVAConfig,
"antipasto": ll.AntiPaSTOConfig, "antipasto": ll.AntiPaSTOConfig,
"road": ll.RoadConfig,
} }
@@ -41,7 +42,7 @@ class BenchmarkConfig:
"""MetaMathQA -> GSM8K benchmark config. Tyro turns this into the CLI.""" """MetaMathQA -> GSM8K benchmark config. Tyro turns this into the CLI."""
model: str = "Qwen/Qwen3-0.6B-Base" model: str = "Qwen/Qwen3-0.6B-Base"
variant: Literal["lora", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva", "antipasto"] = "lora" variant: Literal["lora", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva", "antipasto", "road"] = "lora"
mode: Literal["benchmark", "probe"] = "benchmark" mode: Literal["benchmark", "probe"] = "benchmark"
device: str = "cuda" device: str = "cuda"
torch_dtype: str = "bfloat16" torch_dtype: str = "bfloat16"
@@ -49,6 +50,7 @@ class BenchmarkConfig:
r: int = 32 r: int = 32
alpha: float = 64.0 alpha: float = 64.0
delora_lambda0: float = 0.1 delora_lambda0: float = 0.1
road_group_size: int = 64
target_name: list[str] = field(default_factory=lambda: list(DEFAULT_TARGETS)) target_name: list[str] = field(default_factory=lambda: list(DEFAULT_TARGETS))
layers: str = "all" layers: str = "all"
train_dataset: str = "meta-math/MetaMathQA" train_dataset: str = "meta-math/MetaMathQA"
@@ -118,6 +120,8 @@ def parse_layers(text: str) -> tuple[int, ...] | None:
def cfg_for_variant(args: BenchmarkConfig, dtype: torch.dtype) -> ll.AdapterConfig: def cfg_for_variant(args: BenchmarkConfig, dtype: torch.dtype) -> ll.AdapterConfig:
extra = {"lambda0": args.delora_lambda0} if args.variant == "delora" else {} extra = {"lambda0": args.delora_lambda0} if args.variant == "delora" else {}
if args.variant == "road":
extra = {"group_size": args.road_group_size}
return CFG_BY_VARIANT[args.variant]( return CFG_BY_VARIANT[args.variant](
r=args.r, r=args.r,
alpha=args.r if args.variant == "pissa" else args.alpha, alpha=args.r if args.variant == "pissa" else args.alpha,
@@ -147,7 +151,7 @@ def count_base_grad_leaks(model: torch.nn.Module) -> int:
def perturb_first_adapter(model: torch.nn.Module) -> None: def perturb_first_adapter(model: torch.nn.Module) -> None:
priority = ("lora_B", "lora_g", "lora_U", "lora_A", "lora_lambda", "lora_gate", "lora_delta_s", "lora_rot_T", "lora_m") priority = ("lora_B", "lora_g", "lora_U", "lora_A", "lora_lambda", "lora_gate", "lora_delta_s", "lora_rot_T", "lora_m", "lora_road_theta", "lora_road_alpha")
for key in priority: for key in priority:
for _, p in model.named_parameters(): for _, p in model.named_parameters():
if p.requires_grad and key in _: if p.requires_grad and key in _:
@@ -409,11 +413,12 @@ def check_probe_reload(
loaded_model, _ = load_model_and_tokenizer(args.model, getattr(torch, args.torch_dtype), args.device, args.quantization) loaded_model, _ = load_model_and_tokenizer(args.model, getattr(torch, args.torch_dtype), args.device, args.quantization)
loaded_model.eval() loaded_model.eval()
ll.load(loaded_model, str(adapter_path)) ll.load(loaded_model, str(adapter_path))
saved = torch.load(adapter_path, weights_only=True, map_location="cpu") from safetensors.torch import load_file
saved_sd = load_file(str(adapter_path), device="cpu")
loaded_state = adapter_state(loaded_model) loaded_state = adapter_state(loaded_model)
if set(saved["state"]) != set(loaded_state): if set(saved_sd) != set(loaded_state):
raise AssertionError("loaded adapter keys differ from saved adapter keys") raise AssertionError("loaded adapter keys differ from saved adapter keys")
for name, value in saved["state"].items(): for name, value in saved_sd.items():
if not torch.equal(loaded_state[name].cpu(), value): if not torch.equal(loaded_state[name].cpu(), value):
raise AssertionError(f"loaded adapter tensor differs: {name}") raise AssertionError(f"loaded adapter tensor differs: {name}")
logits_loaded = loaded_model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]).logits.detach().clone() logits_loaded = loaded_model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]).logits.detach().clone()
@@ -426,12 +431,21 @@ def check_probe_reload(
return {"reload_err": reload_err, "saved_tensors": len(saved["state"])} return {"reload_err": reload_err, "saved_tensors": len(saved["state"])}
def print_final_report(row: dict[str, Any], result_path: Path) -> None: def print_final_report(row: dict[str, Any], result_path: Path, mode: str) -> None:
print("SHOULD: grad>0, dθ>0, base_grad_leaks=0, valid/test fields present; probeΔ<0 is good but not required for tiny random smoke. ELSE adapter or eval wiring is dead/wrong.") # BLUF: status line first so log tails are immediately readable
cue = "🟢" if row.get("base_grad_leaks", 0) == 0 and row.get("grad", 0) > 0 else "🔴"
n = row.get("samples", "?")
print(f"{cue} test_acc={row['test_acc']:.4g} valid_acc={row['valid_acc']:.4g} grad={row['grad']:.3g} dθ={row['']:.3g} base_grad_leaks={row['base_grad_leaks']} N={n}")
print("SHOULD: grad>0, dθ>0, base_grad_leaks=0; test/valid_acc meaningful only in benchmark mode. ELSE adapter or eval wiring is dead/wrong.")
# ordered: most important / shortest columns first
display_keys = ["variant", "test_acc", "valid_acc", "grad", "", "base_grad_leaks", "steps", "samples", "loss0", "lossN", "commit"]
if "perturb" in row:
display_keys += ["perturb", "reload"]
display_keys += ["run_id"]
display_row = {k: row[k] for k in display_keys if k in row}
print(tabulate([display_row], headers="keys", tablefmt="tsv", floatfmt=".4g"))
print(f"argv: {' '.join(sys.argv)} N={n} mode={mode}")
print(f"out: {result_path}") print(f"out: {result_path}")
print(f"argv: {' '.join(sys.argv)}")
print(f"main metric: test_acc={row['test_acc']:.4g} valid_acc={row['valid_acc']:.4g} steps={row['steps']} samples={row['samples']}")
print(tabulate([row], headers="keys", tablefmt="tsv", floatfmt=".4g"))
def current_git_commit() -> str: def current_git_commit() -> str:
@@ -447,25 +461,24 @@ def append_results_row(
result: dict[str, Any], result: dict[str, Any],
run_commit: str, run_commit: str,
) -> tuple[Path, Path]: ) -> tuple[Path, Path]:
results_dir = args.output_dir / "results" results_dir = args.output_dir
results_dir.mkdir(parents=True, exist_ok=True) results_dir.mkdir(parents=True, exist_ok=True)
tsv_path = results_dir / "benchmark_results.tsv" tsv_path = results_dir / "summary.tsv"
lock_path = results_dir / "benchmark_results.tsv.lock" lock_path = results_dir / "summary.tsv.lock"
finished_at = datetime.now(timezone.utc).isoformat(timespec="seconds") finished_at = datetime.now(timezone.utc).isoformat(timespec="seconds")
finished_label = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ") finished_label = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
snapshot_path = results_dir / f"{result['run_id']}__{finished_label}.json" snapshot_path = results_dir / f"{result['run_id']}__{finished_label}.json"
snapshot_path.write_text(json.dumps(result, indent=2), encoding="utf-8") snapshot_path.write_text(json.dumps(result, indent=2), encoding="utf-8")
row = { row = {
"time_utc": finished_at, "test_acc": result["test_acc"],
"commit": run_commit, "valid_acc": result["valid_acc"],
"method": args.variant, "method": args.variant,
"model": args.model,
"mode": args.mode,
"valid_accuracy": result["valid_accuracy"],
"test_accuracy": result["test_accuracy"],
"steps": args.steps, "steps": args.steps,
"samples": result["train_samples"], "samples": result["train_samples"],
"wall_time_s": result["wall_time_s"], "model": args.model,
"commit": run_commit[:12],
"wall_time_s": round(result["wall_time_s"]),
"time_utc": finished_at,
"argv": " ".join(sys.argv), "argv": " ".join(sys.argv),
"result_json": str(snapshot_path), "result_json": str(snapshot_path),
"latest_result_json": str(result_path), "latest_result_json": str(result_path),
@@ -517,7 +530,7 @@ def run(args: BenchmarkConfig) -> dict[str, Any]:
valid_metrics = evaluate(model, tokenizer, datasets["valid"], args, "valid") valid_metrics = evaluate(model, tokenizer, datasets["valid"], args, "valid")
test_metrics = evaluate(model, tokenizer, datasets["test"], args, "test") test_metrics = evaluate(model, tokenizer, datasets["test"], args, "test")
adapter_path = out_dir / "adapter.pt" adapter_path = out_dir / "adapter.safetensors"
ll.save(model, str(adapter_path)) ll.save(model, str(adapter_path))
if args.mode == "probe": if args.mode == "probe":
model.eval() model.eval()
@@ -555,8 +568,8 @@ def run(args: BenchmarkConfig) -> dict[str, Any]:
"weight_decay": args.weight_decay, "weight_decay": args.weight_decay,
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
"grad_norm_clip": args.grad_norm_clip, "grad_norm_clip": args.grad_norm_clip,
"valid_accuracy": valid_metrics["accuracy"], "valid_acc": valid_metrics["accuracy"],
"test_accuracy": test_metrics["accuracy"], "test_acc": test_metrics["accuracy"],
"train": train_metrics, "train": train_metrics,
"valid": valid_metrics, "valid": valid_metrics,
"test": test_metrics, "test": test_metrics,
@@ -566,12 +579,12 @@ def run(args: BenchmarkConfig) -> dict[str, Any]:
} }
result_path = out_dir / "result.json" result_path = out_dir / "result.json"
result_path.write_text(json.dumps(result, indent=2), encoding="utf-8") result_path.write_text(json.dumps(result, indent=2), encoding="utf-8")
results_tsv_path, result_snapshot_path = append_results_row(args, result_path, result, run_commit) if args.mode == "benchmark":
result["results_tsv_path"] = str(results_tsv_path) results_tsv_path, result_snapshot_path = append_results_row(args, result_path, result, run_commit)
result["result_snapshot_path"] = str(result_snapshot_path) result["results_tsv_path"] = str(results_tsv_path)
result["result_snapshot_path"] = str(result_snapshot_path)
result["commit"] = run_commit result["commit"] = run_commit
result_path.write_text(json.dumps(result, indent=2), encoding="utf-8") result_path.write_text(json.dumps(result, indent=2), encoding="utf-8")
commit_prefix = run_commit[:12]
row = { row = {
"run_id": run_id, "run_id": run_id,
@@ -592,7 +605,7 @@ def run(args: BenchmarkConfig) -> dict[str, Any]:
if probe_metrics is not None: if probe_metrics is not None:
row["perturb"] = probe_metrics["perturb_delta"] row["perturb"] = probe_metrics["perturb_delta"]
row["reload"] = probe_metrics["reload_err"] row["reload"] = probe_metrics["reload_err"]
print_final_report(row, result_path) print_final_report(row, result_path, args.mode)
return result return result
+2
View File
@@ -20,6 +20,7 @@ from .variants.dora import DoRAConfig
from .variants.hra import HRAConfig from .variants.hra import HRAConfig
from .variants.eva import EVAConfig from .variants.eva import EVAConfig
from .variants.antipasto import AntiPaSTOConfig from .variants.antipasto import AntiPaSTOConfig
from .variants.road import RoadConfig
__all__ = [ __all__ = [
"AdapterConfig", "AdapterConfig",
@@ -32,6 +33,7 @@ __all__ = [
"HRAConfig", "HRAConfig",
"EVAConfig", "EVAConfig",
"AntiPaSTOConfig", "AntiPaSTOConfig",
"RoadConfig",
"attach", "attach",
"detach", "detach",
"save", "save",
+13 -9
View File
@@ -1,5 +1,6 @@
"""attach / detach / save / load. The whole runtime.""" """attach / detach / save / load. The whole runtime."""
from __future__ import annotations from __future__ import annotations
import json
import torch import torch
from torch import nn from torch import nn
from torch.utils.hooks import RemovableHandle from torch.utils.hooks import RemovableHandle
@@ -121,19 +122,22 @@ def save(model: nn.Module, path: str) -> None:
if state is None: if state is None:
raise RuntimeError("no adapter attached; call attach() first") raise RuntimeError("no adapter attached; call attach() first")
sd = {k: v.detach().cpu() for k, v in model.state_dict().items() if "lora_" in k} sd = {k: v.detach().cpu() for k, v in model.state_dict().items() if "lora_" in k}
blob = { metadata = {
"cfg": state["cfg"].to_dict(), "cfg": json.dumps(state["cfg"].to_dict()),
"state": sd, "base_fp": json.dumps(_base_weight_fingerprint(model)),
"base_fp": _base_weight_fingerprint(model),
} }
torch.save(blob, path) from safetensors.torch import save_file
save_file(sd, path, metadata=metadata)
def load(model: nn.Module, path: str) -> list[RemovableHandle]: def load(model: nn.Module, path: str) -> list[RemovableHandle]:
blob = torch.load(path, weights_only=True, map_location="cpu") from safetensors.torch import load_file, safe_open
cfg = AdapterConfig.from_dict(blob["cfg"]) with safe_open(path, framework="pt", device="cpu") as f:
metadata = f.metadata()
sd = load_file(path, device="cpu")
cfg = AdapterConfig.from_dict(json.loads(metadata["cfg"]))
handles = attach(model, cfg, _skip_group_init=True) # creates empty params; data-driven inits restored from state_dict handles = attach(model, cfg, _skip_group_init=True) # creates empty params; data-driven inits restored from state_dict
missing, unexpected = model.load_state_dict(blob["state"], strict=False) missing, unexpected = model.load_state_dict(sd, strict=False)
expected_lora = {k for k in model.state_dict() if "lora_" in k} expected_lora = {k for k in model.state_dict() if "lora_" in k}
missing_lora = sorted(expected_lora.intersection(missing)) missing_lora = sorted(expected_lora.intersection(missing))
if missing_lora: if missing_lora:
@@ -141,7 +145,7 @@ def load(model: nn.Module, path: str) -> list[RemovableHandle]:
unexpected_lora = [k for k in unexpected if "lora_" in k] unexpected_lora = [k for k in unexpected if "lora_" in k]
if unexpected_lora: if unexpected_lora:
raise RuntimeError(f"unexpected lora keys in checkpoint: {unexpected_lora}") raise RuntimeError(f"unexpected lora keys in checkpoint: {unexpected_lora}")
saved_fp = blob.get("base_fp", {}) saved_fp = json.loads(metadata.get("base_fp", "{}"))
if saved_fp: if saved_fp:
cur_fp = _base_weight_fingerprint(model) cur_fp = _base_weight_fingerprint(model)
diffs = [k for k in saved_fp if saved_fp[k] != cur_fp.get(k)] diffs = [k for k in saved_fp if saved_fp[k] != cur_fp.get(k)]
+1 -1
View File
@@ -1 +1 @@
from . import lora, pissa, delora, ia3, dora, hra, eva, antipasto # noqa: F401 side-effect: register from . import lora, pissa, delora, ia3, dora, hra, eva, antipasto, road # noqa: F401 side-effect: register
-1
View File
@@ -37,7 +37,6 @@ class LoRA:
@staticmethod @staticmethod
def init(layer: nn.Module, cfg) -> None: def init(layer: nn.Module, cfg) -> None:
# B is zeros => delta=0 at t=0; identity invariant holds.
return return
@staticmethod @staticmethod
+137
View File
@@ -0,0 +1,137 @@
"""ROAD: Rotation ADaptation. https://arxiv.org/abs/2409.00119
ROAD applies a learned output-space block rotation/scaling after the frozen base
layer:
y' = R y = R (W x + b)
This matches PEFT's unmerged forward path and fits lora-lite as a simple output
hook. We implement the three PEFT variants (`road_1`, `road_2`, `road_4`) and
skip merge/unmerge because this library keeps adapters as hooks.
Refs:
- peft: https://github.com/huggingface/peft/blob/6030f9160ed2fc17220f6f41382a66f1257b6a93/src/peft/tuners/road/layer.py
"""
from dataclasses import dataclass
from typing import Literal
import torch
from jaxtyping import Float
from torch import nn, Tensor as T
from ..config import AdapterConfig, register_config
from ..variant import ParamSpec, register
RoadVariant = Literal["road_1", "road_2", "road_4"]
@register_config
@dataclass
class RoadConfig(AdapterConfig):
variant: str = "road"
road_variant: RoadVariant = "road_1"
group_size: int = 64
def _road_param_size(d_out: int, road_variant: str) -> int:
if road_variant == "road_1":
return d_out // 2
if road_variant == "road_2":
return d_out
if road_variant == "road_4":
return d_out * 2
raise ValueError(f"road_variant must be 'road_1', 'road_2', or 'road_4', got {road_variant!r}")
def _validate_group_geometry(d_out: int, group_size: int) -> None:
if group_size <= 0 or group_size % 2 != 0:
raise ValueError(f"ROAD group_size must be positive and even, got {group_size}")
if d_out % group_size != 0:
raise ValueError(f"ROAD d_out={d_out} must be divisible by group_size={group_size}")
def _prepare_cols(
road_variant: str,
group_size: int,
road_theta: torch.Tensor,
road_alpha: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
if road_variant == "road_1":
# One θ/α per pair. Reuse it for both rows of each 2D rotation block.
road_theta = road_theta.reshape(-1, group_size // 2).repeat_interleave(2, dim=0).flatten()
road_alpha = road_alpha.reshape(-1, group_size // 2).repeat_interleave(2, dim=0).flatten()
first_col = road_alpha * road_theta.cos()
second_col = road_alpha * road_theta.sin()
elif road_variant == "road_2":
# One θ/α per output coordinate.
first_col = road_alpha * road_theta.cos()
second_col = road_alpha * road_theta.sin()
elif road_variant == "road_4":
# Independent θ/α for the first and second column contributions.
road_theta = road_theta.reshape(-1, 2, group_size)
road_alpha = road_alpha.reshape(-1, 2, group_size)
first_col = road_alpha[:, 0, :].flatten() * road_theta[:, 0, :].cos().flatten()
second_col = road_alpha[:, 1, :].flatten() * road_theta[:, 1, :].sin().flatten()
else:
raise ValueError(f"road_variant must be 'road_1', 'road_2', or 'road_4', got {road_variant!r}")
return first_col, second_col
def _apply_road(
road_variant: str,
group_size: int,
road_theta: torch.Tensor,
road_alpha: torch.Tensor,
y: Float[T, '*B o'],
) -> Float[T, '*B o']:
first_col, second_col = _prepare_cols(road_variant, group_size, road_theta, road_alpha)
y_grouped = y.reshape(-1, 2, group_size // 2)
y1 = y_grouped[:, 0, :]
y2 = y_grouped[:, 1, :]
rotate_half_y = torch.stack((-y2, y1), dim=1).reshape(y.shape)
return y * first_col + rotate_half_y * second_col
def _road_matrix(
road_variant: str,
group_size: int,
road_theta: torch.Tensor,
road_alpha: torch.Tensor,
) -> torch.Tensor:
"""Explicit PEFT merge matrix. Used for tests and small-debug inspection."""
first_col, second_col = _prepare_cols(road_variant, group_size, road_theta, road_alpha)
size = second_col.shape[0]
output = torch.diag(first_col)
swapped_second_col = second_col.reshape(-1, 2, group_size // 2)[:, [1, 0], :].flatten()
rotated_diag_second_col = torch.diag(swapped_second_col).reshape(-1, 2, group_size // 2, size)[:, [1, 0], :, :]
rotated_diag_second_col[:, 0, :, :] *= -1
return output + rotated_diag_second_col.reshape(size, size)
@register
class ROAD:
name = "road"
@staticmethod
def param_specs(d_in: int, d_out: int, cfg: RoadConfig) -> dict[str, ParamSpec]:
_validate_group_geometry(d_out, cfg.group_size)
size = _road_param_size(d_out, cfg.road_variant)
return {
"lora_road_theta": ParamSpec((size,), init="zeros", trainable=True),
"lora_road_alpha": ParamSpec((size,), init="ones", trainable=True),
}
@staticmethod
def init(layer: nn.Module, cfg: RoadConfig) -> None:
return
@staticmethod
def forward(
layer: nn.Module,
x: Float[T, '*B i'],
y: Float[T, '*B o'],
) -> Float[T, '*B o']:
del x
cfg = layer._lora_cfg
y_cast = y.to(layer.lora_road_theta.dtype)
return _apply_road(cfg.road_variant, cfg.group_size, layer.lora_road_theta, layer.lora_road_alpha, y_cast)
+366
View File
@@ -0,0 +1,366 @@
"""Per-variant attach + train + save + load round-trip, plus surgical regressions.
The big invariant is the parametrized train_save_load test: identity at t=0,
gradient flow on a real loss, then save -> reload onto a fresh model and
confirm the trained outputs survive the round-trip. Cheap on CPU.
"""
from __future__ import annotations
from pathlib import Path
import pytest
import torch
from torch import nn
import lora_lite as ll
CFG_BY_VARIANT = {
"lora": ll.LoRAConfig,
"pissa": ll.PiSSAConfig,
"delora": ll.DeLoRAConfig,
"ia3": ll.IA3Config,
"ia3_ff": ll.IA3FFConfig,
"dora": ll.DoRAConfig,
"hra": ll.HRAConfig,
"eva": ll.EVAConfig,
"antipasto": ll.AntiPaSTOConfig,
"road": ll.RoadConfig,
}
# Per-variant identity tolerance at t=0 (after attach, before any step).
# fp32 SVD round-trip + per-row norm = looser tolerance for pissa/dora/antipasto.
IDENTITY_TOL = {
"lora": 1e-6,
"pissa": 5e-4,
"delora": 1e-6,
"ia3": 1e-6,
"ia3_ff": 1e-6,
"dora": 5e-5,
"hra": 5e-6,
"eva": 1e-6,
"antipasto": 5e-4,
"road": 1e-6,
}
class TinyBlock(nn.Module):
def __init__(self, d: int = 64, ff: int = 128):
super().__init__()
self.q_proj = nn.Linear(d, d, bias=False)
self.k_proj = nn.Linear(d, d, bias=False)
self.v_proj = nn.Linear(d, d, bias=False)
self.o_proj = nn.Linear(d, d, bias=False)
self.gate_proj = nn.Linear(d, ff, bias=False)
self.up_proj = nn.Linear(d, ff, bias=False)
self.down_proj = nn.Linear(ff, d, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
h = self.o_proj(self.q_proj(x) + self.k_proj(x) + self.v_proj(x))
m = self.down_proj(torch.nn.functional.silu(self.gate_proj(x)) * self.up_proj(x))
return x + h + m
class TinyModel(nn.Module):
def __init__(self, n_layers: int = 4, d: int = 64, ff: int = 128, vocab: int = 100):
super().__init__()
self.embed_tokens = nn.Embedding(vocab, d)
self.layers = nn.ModuleList([TinyBlock(d, ff) for _ in range(n_layers)])
self.lm_head = nn.Linear(d, vocab, bias=False)
self.config = type("Cfg", (), {"hidden_size": d})()
def forward(self, ids: torch.Tensor) -> torch.Tensor:
x = self.embed_tokens(ids)
for block in self.layers:
x = block(x)
return self.lm_head(x)
class FakeLinearLike(nn.Module):
"""linear-like, but not nn.Linear: stand-in for bnb 4/8-bit modules."""
def __init__(self, d_in: int = 8, d_out: int = 8):
super().__init__()
self.in_features = d_in
self.out_features = d_out
self.weight = nn.Parameter(torch.empty(d_out, d_in))
nn.init.kaiming_uniform_(self.weight, a=5 ** 0.5)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.nn.functional.linear(x, self.weight)
class FakeBnbModel(nn.Module):
def __init__(self):
super().__init__()
self.config = type("Cfg", (), {"hidden_size": 8})()
self.layers = nn.ModuleList([FakeLinearLike(8, 8)])
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.layers[0](x)
def cfg_for(variant: str) -> ll.AdapterConfig:
return CFG_BY_VARIANT[variant](
r=4,
alpha=8,
dtype=torch.float32,
)
def attach_with_calib(model: nn.Module, cfg: ll.AdapterConfig, ids: torch.Tensor) -> None:
if cfg.variant == "eva":
calib = [ids for _ in range(2)]
ll.attach(model, cfg, calibration_data=calib)
else:
ll.attach(model, cfg)
def trainable_grad_norm(model: nn.Module) -> float:
return sum(
p.grad.detach().float().norm().item()
for n, p in model.named_parameters()
if "lora_" in n and p.grad is not None
)
@pytest.mark.parametrize("variant", list(CFG_BY_VARIANT))
def test_train_save_load(variant: str, tmp_path: Path):
"""Identity at t=0, one SGD step, save, reload onto fresh model, outputs match."""
torch.manual_seed(0)
model = TinyModel()
ids = torch.randint(0, 100, (2, 16))
with torch.no_grad():
y_base = model(ids).clone()
cfg = cfg_for(variant)
attach_with_calib(model, cfg, ids)
trainable = [p for p in model.parameters() if p.requires_grad]
assert trainable
assert all("lora_" in n for n, p in model.named_parameters() if p.requires_grad)
with torch.no_grad():
y_init = model(ids).clone()
assert (y_init - y_base).abs().max().item() < IDENTITY_TOL[variant]
target = torch.randn_like(y_init) * 0.1
opt = torch.optim.SGD(trainable, lr=1e-2)
opt.zero_grad()
loss = (model(ids) - target).pow(2).mean()
loss.backward()
leaked = [n for n, p in model.named_parameters() if "lora_" not in n and p.grad is not None]
assert leaked == []
assert trainable_grad_norm(model) > 0
opt.step()
with torch.no_grad():
y_trained = model(ids).clone()
path = tmp_path / "adapter.pt"
ll.save(model, str(path))
torch.manual_seed(0)
model_loaded = TinyModel()
ll.load(model_loaded, str(path)) # EVA load skips group_init; calibration_data not needed
with torch.no_grad():
y_loaded = model_loaded(ids)
assert (y_loaded - y_trained).abs().max().item() < max(IDENTITY_TOL[variant], 1e-5)
@pytest.mark.parametrize("variant", ["lora", "delora", "ia3", "hra", "road"])
def test_hook_only_variants_attach_to_non_linear_target(variant: str):
"""bnb-style targets are linear-like but not nn.Linear; hook-only variants must accept them."""
extra = {"lambda0": 0.1} if variant == "delora" else {"group_size": 8} if variant == "road" else {}
cfg = CFG_BY_VARIANT[variant](r=2, alpha=4, dtype=torch.float32, target_roles=(), **extra)
model = FakeBnbModel()
ll.attach(model, cfg)
x = torch.randn(2, 3, 8)
model(x).pow(2).mean().backward()
assert trainable_grad_norm(model) > 0
@pytest.mark.parametrize("variant", ["pissa", "dora", "antipasto"])
def test_weight_reading_variants_reject_non_linear(variant: str):
r = 4 if variant == "antipasto" else 2 # antipasto needs r % block_size==0
cfg = CFG_BY_VARIANT[variant](r=r, alpha=r, dtype=torch.float32, target_roles=())
with pytest.raises(TypeError, match="plain nn.Linear"):
ll.attach(FakeBnbModel(), cfg)
def test_save_load_strict_keys(tmp_path: Path):
import json
from safetensors.torch import load_file, save_file
torch.manual_seed(0)
model = TinyModel()
ll.attach(model, ll.LoRAConfig(r=4, alpha=8, dtype=torch.float32))
p = tmp_path / "lora.safetensors"
ll.save(model, str(p))
sd = load_file(str(p), device="cpu")
# missing key: drop first lora key
missing_sd = dict(sd)
dropped_key = next(iter(missing_sd))
del missing_sd[dropped_key]
from safetensors import safe_open
with safe_open(str(p), framework="pt", device="cpu") as f:
meta = f.metadata()
save_file(missing_sd, str(p), metadata=meta)
with pytest.raises(RuntimeError, match="missing lora keys"):
ll.load(TinyModel(), str(p))
# unexpected key: add a bogus lora key
bad_sd = dict(sd)
bad_sd["layers.0.q_proj.lora_extra"] = torch.zeros(1)
save_file(bad_sd, str(p), metadata=meta)
with pytest.raises(RuntimeError, match="unexpected lora keys"):
ll.load(TinyModel(), str(p))
def test_no_target_layers_is_loud():
cfg = ll.LoRAConfig(target_names=("definitely_missing",))
with pytest.raises(RuntimeError, match="no target layers"):
ll.attach(TinyModel(), cfg)
def test_eva_requires_calibration():
"""EVA's group_init must error loudly if calibration_data is missing."""
with pytest.raises(ValueError, match="calibration_data"):
ll.attach(TinyModel(), ll.EVAConfig(r=4, alpha=8, dtype=torch.float32))
def test_delora_default_has_live_step0_gradient():
"""Default lambda0 must be nonzero; B=0 preserves identity while B gets gradient."""
torch.manual_seed(0)
model = TinyModel(n_layers=1)
ids = torch.randint(0, 100, (2, 8))
ll.attach(model, ll.DeLoRAConfig(r=4, alpha=8, dtype=torch.float32))
assert model.layers[0].q_proj.lora_lambda.item() == pytest.approx(15.0)
loss = model(ids).pow(2).mean()
loss.backward()
b_grad = model.layers[0].q_proj.lora_B.grad.detach().abs().max().item()
assert b_grad > 0
def test_pissa_identity_with_nonunit_scale():
"""Regression: PiSSA must pre-divide S by alpha/r, not require alpha == r."""
torch.manual_seed(0)
model = TinyModel(n_layers=1)
ids = torch.randint(0, 100, (2, 8))
with torch.no_grad():
y_base = model(ids).clone()
ll.attach(model, ll.PiSSAConfig(r=4, alpha=8, dtype=torch.float32))
with torch.no_grad():
y = model(ids)
assert (y - y_base).abs().max().item() < IDENTITY_TOL["pissa"]
def test_antipasto_blockwise_rotation_matches_explicit_blockdiag():
"""The einsum/rearrange path must equal the old explicit blockdiag math."""
from lora_lite.variants.antipasto import _build_rotation
torch.manual_seed(0)
n_blocks, bs, d_in, d_out = 3, 4, 7, 5
r = n_blocks * bs
rot_T = torch.randn(n_blocks, bs * (bs - 1) // 2) * 0.1
Vh = torch.randn(r, d_in)
U = torch.randn(d_out, r)
R_blocks = _build_rotation(rot_T, bs, 0.5)
R = torch.block_diag(*list(R_blocks))
Vh_blocks = torch.reshape(Vh, (n_blocks, bs, d_in))
Vh_rot = torch.einsum("nab,nbi->nai", R_blocks, Vh_blocks).reshape(r, d_in)
U_blocks = torch.reshape(U, (d_out, n_blocks, bs))
U_rot = torch.einsum("dnb,ncb->dnc", U_blocks, R_blocks).reshape(d_out, r)
assert (Vh_rot - R @ Vh).abs().max().item() < 1e-6
assert (U_rot - U @ R.T).abs().max().item() < 1e-6
def test_dora_bias_passthrough():
"""Regression: DoRA must NOT scale bias; identity holds with bias=True at t=0."""
torch.manual_seed(0)
d = 16
layer = nn.Linear(d, d, bias=True)
x = torch.randn(2, d)
y_base = layer(x).detach()
class Wrap(nn.Module):
def __init__(self, lin):
super().__init__()
self.config = type("Cfg", (), {"hidden_size": d})()
self.layers = nn.ModuleList([lin])
def forward(self, x):
return self.layers[0](x)
model = Wrap(layer)
ll.attach(model, ll.DoRAConfig(r=2, alpha=4, dtype=torch.float32, target_roles=()))
with torch.no_grad():
y = model(x)
assert (y - y_base).abs().max().item() < 1e-5
def test_hra_forward_is_x_R_T():
"""HRA must apply x @ R^T (loop i = r-1 down to 0). Asymmetric U makes order observable."""
torch.manual_seed(0)
d = 8
layer = nn.Linear(d, d, bias=False)
x = torch.randn(2, 3, d)
class Wrap(nn.Module):
def __init__(self, lin):
super().__init__()
self.config = type("Cfg", (), {"hidden_size": d})()
self.layers = nn.ModuleList([lin])
def forward(self, x):
return self.layers[0](x)
model = Wrap(layer)
ll.attach(model, ll.HRAConfig(r=4, alpha=4, dtype=torch.float32, target_roles=()))
# break paired symmetry so order matters
with torch.no_grad():
layer.lora_U.add_(0.1 * torch.randn_like(layer.lora_U))
U = layer.lora_U
R = torch.eye(d)
for i in range(U.shape[0]):
u = U[i]
sq = (u * u).sum().clamp_min(1e-12)
R = R - (2.0 / sq) * torch.outer(R @ u, u)
with torch.no_grad():
y_adapt = model(x)
y_ref = torch.nn.functional.linear(x, layer.weight @ R)
assert (y_adapt - y_ref).abs().max().item() < 1e-5
@pytest.mark.parametrize("road_variant", ["road_1", "road_2", "road_4"])
def test_road_apply_matches_explicit_matrix(road_variant: str):
"""Fast elementwise ROAD path must match PEFT's explicit R @ y matrix construction."""
from lora_lite.variants.road import _apply_road, _road_matrix, _road_param_size
torch.manual_seed(0)
d_out = 16
group_size = 8
size = _road_param_size(d_out, road_variant)
theta = torch.randn(size) * 0.2
alpha = torch.randn(size) * 0.1 + 1.0
y = torch.randn(2, 3, d_out)
y_fast = _apply_road(road_variant, group_size, theta, alpha, y)
R = _road_matrix(road_variant, group_size, theta, alpha)
y_ref = torch.einsum("oi,...i->...o", R, y)
assert (y_fast - y_ref).abs().max().item() < 1e-6
def test_road_invalid_group_size_is_loud():
with pytest.raises(ValueError, match="positive and even"):
ll.attach(TinyModel(), ll.RoadConfig(group_size=7))
with pytest.raises(ValueError, match="divisible"):
ll.attach(TinyModel(), ll.RoadConfig(group_size=48))
+2 -1
View File
@@ -31,7 +31,7 @@ sys.modules[SPEC.name] = benchmark
SPEC.loader.exec_module(benchmark) SPEC.loader.exec_module(benchmark)
VARIANTS = ["lora", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva", "antipasto"] VARIANTS = ["lora", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva", "antipasto", "road"]
# Variants that fail loud when attached on a bnb-loaded base (read dense weight in init). # Variants that fail loud when attached on a bnb-loaded base (read dense weight in init).
# delora/eva also read weight but currently silently dequant -- they produce sane attach, # delora/eva also read weight but currently silently dequant -- they produce sane attach,
# so we don't expect a raise from them in the attach-only smoke. # so we don't expect a raise from them in the attach-only smoke.
@@ -69,6 +69,7 @@ def quick_cfg(variant: str, tmp_path: Path, quantization: str = "none") -> "benc
max_seq_length=128, max_seq_length=128,
max_new_tokens=8, max_new_tokens=8,
lr=5e-3, lr=5e-3,
road_group_size=8,
seed=0, seed=0,
log_examples=0, log_examples=0,
log_every=1000, log_every=1000,
Generated
+3 -1
View File
@@ -12,7 +12,7 @@ resolution-markers = [
] ]
[options] [options]
exclude-newer = "2026-04-21T14:06:04.428693663Z" exclude-newer = "2026-04-22T01:31:09.47902161Z"
exclude-newer-span = "P5D" exclude-newer-span = "P5D"
[[package]] [[package]]
@@ -1012,6 +1012,7 @@ dependencies = [
{ name = "einops" }, { name = "einops" },
{ name = "jaxtyping", version = "0.3.7", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "jaxtyping", version = "0.3.7", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
{ name = "jaxtyping", version = "0.3.9", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "jaxtyping", version = "0.3.9", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" },
{ name = "safetensors" },
{ name = "torch" }, { name = "torch" },
] ]
@@ -1051,6 +1052,7 @@ requires-dist = [
{ name = "einops", specifier = ">=0.7" }, { name = "einops", specifier = ">=0.7" },
{ name = "jaxtyping", specifier = ">=0.2.34" }, { name = "jaxtyping", specifier = ">=0.2.34" },
{ name = "pytest", marker = "extra == 'test'" }, { name = "pytest", marker = "extra == 'test'" },
{ name = "safetensors", specifier = ">=0.5" },
{ name = "safetensors", marker = "extra == 'benchmark'", specifier = ">=0.5" }, { name = "safetensors", marker = "extra == 'benchmark'", specifier = ">=0.5" },
{ name = "safetensors", marker = "extra == 'hf-test'", specifier = ">=0.5" }, { name = "safetensors", marker = "extra == 'hf-test'", specifier = ">=0.5" },
{ name = "tabulate", marker = "extra == 'benchmark'" }, { name = "tabulate", marker = "extra == 'benchmark'" },