mirror of
https://github.com/wassname/lora-lite.git
synced 2026-06-27 16:30:44 +08:00
simpler test
This commit is contained in:
@@ -8,6 +8,7 @@ dependencies = [
|
||||
"torch>=2.1",
|
||||
"einops>=0.7",
|
||||
"jaxtyping>=0.2.34",
|
||||
"safetensors>=0.5",
|
||||
]
|
||||
keywords = ["lora", "pytorch", "peft", "adapters", "llm"]
|
||||
classifiers = [
|
||||
|
||||
@@ -33,6 +33,7 @@ CFG_BY_VARIANT = {
|
||||
"hra": ll.HRAConfig,
|
||||
"eva": ll.EVAConfig,
|
||||
"antipasto": ll.AntiPaSTOConfig,
|
||||
"road": ll.RoadConfig,
|
||||
}
|
||||
|
||||
|
||||
@@ -41,7 +42,7 @@ class BenchmarkConfig:
|
||||
"""MetaMathQA -> GSM8K benchmark config. Tyro turns this into the CLI."""
|
||||
|
||||
model: str = "Qwen/Qwen3-0.6B-Base"
|
||||
variant: Literal["lora", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva", "antipasto"] = "lora"
|
||||
variant: Literal["lora", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva", "antipasto", "road"] = "lora"
|
||||
mode: Literal["benchmark", "probe"] = "benchmark"
|
||||
device: str = "cuda"
|
||||
torch_dtype: str = "bfloat16"
|
||||
@@ -49,6 +50,7 @@ class BenchmarkConfig:
|
||||
r: int = 32
|
||||
alpha: float = 64.0
|
||||
delora_lambda0: float = 0.1
|
||||
road_group_size: int = 64
|
||||
target_name: list[str] = field(default_factory=lambda: list(DEFAULT_TARGETS))
|
||||
layers: str = "all"
|
||||
train_dataset: str = "meta-math/MetaMathQA"
|
||||
@@ -118,6 +120,8 @@ def parse_layers(text: str) -> tuple[int, ...] | None:
|
||||
|
||||
def cfg_for_variant(args: BenchmarkConfig, dtype: torch.dtype) -> ll.AdapterConfig:
|
||||
extra = {"lambda0": args.delora_lambda0} if args.variant == "delora" else {}
|
||||
if args.variant == "road":
|
||||
extra = {"group_size": args.road_group_size}
|
||||
return CFG_BY_VARIANT[args.variant](
|
||||
r=args.r,
|
||||
alpha=args.r if args.variant == "pissa" else args.alpha,
|
||||
@@ -147,7 +151,7 @@ def count_base_grad_leaks(model: torch.nn.Module) -> int:
|
||||
|
||||
|
||||
def perturb_first_adapter(model: torch.nn.Module) -> None:
|
||||
priority = ("lora_B", "lora_g", "lora_U", "lora_A", "lora_lambda", "lora_gate", "lora_delta_s", "lora_rot_T", "lora_m")
|
||||
priority = ("lora_B", "lora_g", "lora_U", "lora_A", "lora_lambda", "lora_gate", "lora_delta_s", "lora_rot_T", "lora_m", "lora_road_theta", "lora_road_alpha")
|
||||
for key in priority:
|
||||
for _, p in model.named_parameters():
|
||||
if p.requires_grad and key in _:
|
||||
@@ -409,11 +413,12 @@ def check_probe_reload(
|
||||
loaded_model, _ = load_model_and_tokenizer(args.model, getattr(torch, args.torch_dtype), args.device, args.quantization)
|
||||
loaded_model.eval()
|
||||
ll.load(loaded_model, str(adapter_path))
|
||||
saved = torch.load(adapter_path, weights_only=True, map_location="cpu")
|
||||
from safetensors.torch import load_file
|
||||
saved_sd = load_file(str(adapter_path), device="cpu")
|
||||
loaded_state = adapter_state(loaded_model)
|
||||
if set(saved["state"]) != set(loaded_state):
|
||||
if set(saved_sd) != set(loaded_state):
|
||||
raise AssertionError("loaded adapter keys differ from saved adapter keys")
|
||||
for name, value in saved["state"].items():
|
||||
for name, value in saved_sd.items():
|
||||
if not torch.equal(loaded_state[name].cpu(), value):
|
||||
raise AssertionError(f"loaded adapter tensor differs: {name}")
|
||||
logits_loaded = loaded_model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]).logits.detach().clone()
|
||||
@@ -426,12 +431,21 @@ def check_probe_reload(
|
||||
return {"reload_err": reload_err, "saved_tensors": len(saved["state"])}
|
||||
|
||||
|
||||
def print_final_report(row: dict[str, Any], result_path: Path) -> None:
|
||||
print("SHOULD: grad>0, dθ>0, base_grad_leaks=0, valid/test fields present; probeΔ<0 is good but not required for tiny random smoke. ELSE adapter or eval wiring is dead/wrong.")
|
||||
def print_final_report(row: dict[str, Any], result_path: Path, mode: str) -> None:
|
||||
# BLUF: status line first so log tails are immediately readable
|
||||
cue = "🟢" if row.get("base_grad_leaks", 0) == 0 and row.get("grad", 0) > 0 else "🔴"
|
||||
n = row.get("samples", "?")
|
||||
print(f"{cue} test_acc={row['test_acc']:.4g} valid_acc={row['valid_acc']:.4g} grad={row['grad']:.3g} dθ={row['dθ']:.3g} base_grad_leaks={row['base_grad_leaks']} N={n}")
|
||||
print("SHOULD: grad>0, dθ>0, base_grad_leaks=0; test/valid_acc meaningful only in benchmark mode. ELSE adapter or eval wiring is dead/wrong.")
|
||||
# ordered: most important / shortest columns first
|
||||
display_keys = ["variant", "test_acc", "valid_acc", "grad", "dθ", "base_grad_leaks", "steps", "samples", "loss0", "lossN", "commit"]
|
||||
if "perturb" in row:
|
||||
display_keys += ["perturb", "reload"]
|
||||
display_keys += ["run_id"]
|
||||
display_row = {k: row[k] for k in display_keys if k in row}
|
||||
print(tabulate([display_row], headers="keys", tablefmt="tsv", floatfmt=".4g"))
|
||||
print(f"argv: {' '.join(sys.argv)} N={n} mode={mode}")
|
||||
print(f"out: {result_path}")
|
||||
print(f"argv: {' '.join(sys.argv)}")
|
||||
print(f"main metric: test_acc={row['test_acc']:.4g} valid_acc={row['valid_acc']:.4g} steps={row['steps']} samples={row['samples']}")
|
||||
print(tabulate([row], headers="keys", tablefmt="tsv", floatfmt=".4g"))
|
||||
|
||||
|
||||
def current_git_commit() -> str:
|
||||
@@ -447,25 +461,24 @@ def append_results_row(
|
||||
result: dict[str, Any],
|
||||
run_commit: str,
|
||||
) -> tuple[Path, Path]:
|
||||
results_dir = args.output_dir / "results"
|
||||
results_dir = args.output_dir
|
||||
results_dir.mkdir(parents=True, exist_ok=True)
|
||||
tsv_path = results_dir / "benchmark_results.tsv"
|
||||
lock_path = results_dir / "benchmark_results.tsv.lock"
|
||||
tsv_path = results_dir / "summary.tsv"
|
||||
lock_path = results_dir / "summary.tsv.lock"
|
||||
finished_at = datetime.now(timezone.utc).isoformat(timespec="seconds")
|
||||
finished_label = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
|
||||
snapshot_path = results_dir / f"{result['run_id']}__{finished_label}.json"
|
||||
snapshot_path.write_text(json.dumps(result, indent=2), encoding="utf-8")
|
||||
row = {
|
||||
"time_utc": finished_at,
|
||||
"commit": run_commit,
|
||||
"test_acc": result["test_acc"],
|
||||
"valid_acc": result["valid_acc"],
|
||||
"method": args.variant,
|
||||
"model": args.model,
|
||||
"mode": args.mode,
|
||||
"valid_accuracy": result["valid_accuracy"],
|
||||
"test_accuracy": result["test_accuracy"],
|
||||
"steps": args.steps,
|
||||
"samples": result["train_samples"],
|
||||
"wall_time_s": result["wall_time_s"],
|
||||
"model": args.model,
|
||||
"commit": run_commit[:12],
|
||||
"wall_time_s": round(result["wall_time_s"]),
|
||||
"time_utc": finished_at,
|
||||
"argv": " ".join(sys.argv),
|
||||
"result_json": str(snapshot_path),
|
||||
"latest_result_json": str(result_path),
|
||||
@@ -517,7 +530,7 @@ def run(args: BenchmarkConfig) -> dict[str, Any]:
|
||||
valid_metrics = evaluate(model, tokenizer, datasets["valid"], args, "valid")
|
||||
test_metrics = evaluate(model, tokenizer, datasets["test"], args, "test")
|
||||
|
||||
adapter_path = out_dir / "adapter.pt"
|
||||
adapter_path = out_dir / "adapter.safetensors"
|
||||
ll.save(model, str(adapter_path))
|
||||
if args.mode == "probe":
|
||||
model.eval()
|
||||
@@ -555,8 +568,8 @@ def run(args: BenchmarkConfig) -> dict[str, Any]:
|
||||
"weight_decay": args.weight_decay,
|
||||
"lr_scheduler": "cosine",
|
||||
"grad_norm_clip": args.grad_norm_clip,
|
||||
"valid_accuracy": valid_metrics["accuracy"],
|
||||
"test_accuracy": test_metrics["accuracy"],
|
||||
"valid_acc": valid_metrics["accuracy"],
|
||||
"test_acc": test_metrics["accuracy"],
|
||||
"train": train_metrics,
|
||||
"valid": valid_metrics,
|
||||
"test": test_metrics,
|
||||
@@ -566,12 +579,12 @@ def run(args: BenchmarkConfig) -> dict[str, Any]:
|
||||
}
|
||||
result_path = out_dir / "result.json"
|
||||
result_path.write_text(json.dumps(result, indent=2), encoding="utf-8")
|
||||
results_tsv_path, result_snapshot_path = append_results_row(args, result_path, result, run_commit)
|
||||
result["results_tsv_path"] = str(results_tsv_path)
|
||||
result["result_snapshot_path"] = str(result_snapshot_path)
|
||||
if args.mode == "benchmark":
|
||||
results_tsv_path, result_snapshot_path = append_results_row(args, result_path, result, run_commit)
|
||||
result["results_tsv_path"] = str(results_tsv_path)
|
||||
result["result_snapshot_path"] = str(result_snapshot_path)
|
||||
result["commit"] = run_commit
|
||||
result_path.write_text(json.dumps(result, indent=2), encoding="utf-8")
|
||||
commit_prefix = run_commit[:12]
|
||||
|
||||
row = {
|
||||
"run_id": run_id,
|
||||
@@ -592,7 +605,7 @@ def run(args: BenchmarkConfig) -> dict[str, Any]:
|
||||
if probe_metrics is not None:
|
||||
row["perturb"] = probe_metrics["perturb_delta"]
|
||||
row["reload"] = probe_metrics["reload_err"]
|
||||
print_final_report(row, result_path)
|
||||
print_final_report(row, result_path, args.mode)
|
||||
return result
|
||||
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ from .variants.dora import DoRAConfig
|
||||
from .variants.hra import HRAConfig
|
||||
from .variants.eva import EVAConfig
|
||||
from .variants.antipasto import AntiPaSTOConfig
|
||||
from .variants.road import RoadConfig
|
||||
|
||||
__all__ = [
|
||||
"AdapterConfig",
|
||||
@@ -32,6 +33,7 @@ __all__ = [
|
||||
"HRAConfig",
|
||||
"EVAConfig",
|
||||
"AntiPaSTOConfig",
|
||||
"RoadConfig",
|
||||
"attach",
|
||||
"detach",
|
||||
"save",
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""attach / detach / save / load. The whole runtime."""
|
||||
from __future__ import annotations
|
||||
import json
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.utils.hooks import RemovableHandle
|
||||
@@ -121,19 +122,22 @@ def save(model: nn.Module, path: str) -> None:
|
||||
if state is None:
|
||||
raise RuntimeError("no adapter attached; call attach() first")
|
||||
sd = {k: v.detach().cpu() for k, v in model.state_dict().items() if "lora_" in k}
|
||||
blob = {
|
||||
"cfg": state["cfg"].to_dict(),
|
||||
"state": sd,
|
||||
"base_fp": _base_weight_fingerprint(model),
|
||||
metadata = {
|
||||
"cfg": json.dumps(state["cfg"].to_dict()),
|
||||
"base_fp": json.dumps(_base_weight_fingerprint(model)),
|
||||
}
|
||||
torch.save(blob, path)
|
||||
from safetensors.torch import save_file
|
||||
save_file(sd, path, metadata=metadata)
|
||||
|
||||
|
||||
def load(model: nn.Module, path: str) -> list[RemovableHandle]:
|
||||
blob = torch.load(path, weights_only=True, map_location="cpu")
|
||||
cfg = AdapterConfig.from_dict(blob["cfg"])
|
||||
from safetensors.torch import load_file, safe_open
|
||||
with safe_open(path, framework="pt", device="cpu") as f:
|
||||
metadata = f.metadata()
|
||||
sd = load_file(path, device="cpu")
|
||||
cfg = AdapterConfig.from_dict(json.loads(metadata["cfg"]))
|
||||
handles = attach(model, cfg, _skip_group_init=True) # creates empty params; data-driven inits restored from state_dict
|
||||
missing, unexpected = model.load_state_dict(blob["state"], strict=False)
|
||||
missing, unexpected = model.load_state_dict(sd, strict=False)
|
||||
expected_lora = {k for k in model.state_dict() if "lora_" in k}
|
||||
missing_lora = sorted(expected_lora.intersection(missing))
|
||||
if missing_lora:
|
||||
@@ -141,7 +145,7 @@ def load(model: nn.Module, path: str) -> list[RemovableHandle]:
|
||||
unexpected_lora = [k for k in unexpected if "lora_" in k]
|
||||
if unexpected_lora:
|
||||
raise RuntimeError(f"unexpected lora keys in checkpoint: {unexpected_lora}")
|
||||
saved_fp = blob.get("base_fp", {})
|
||||
saved_fp = json.loads(metadata.get("base_fp", "{}"))
|
||||
if saved_fp:
|
||||
cur_fp = _base_weight_fingerprint(model)
|
||||
diffs = [k for k in saved_fp if saved_fp[k] != cur_fp.get(k)]
|
||||
|
||||
@@ -1 +1 @@
|
||||
from . import lora, pissa, delora, ia3, dora, hra, eva, antipasto # noqa: F401 side-effect: register
|
||||
from . import lora, pissa, delora, ia3, dora, hra, eva, antipasto, road # noqa: F401 side-effect: register
|
||||
|
||||
@@ -37,7 +37,6 @@ class LoRA:
|
||||
|
||||
@staticmethod
|
||||
def init(layer: nn.Module, cfg) -> None:
|
||||
# B is zeros => delta=0 at t=0; identity invariant holds.
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -0,0 +1,137 @@
|
||||
"""ROAD: Rotation ADaptation. https://arxiv.org/abs/2409.00119
|
||||
|
||||
ROAD applies a learned output-space block rotation/scaling after the frozen base
|
||||
layer:
|
||||
|
||||
y' = R y = R (W x + b)
|
||||
|
||||
This matches PEFT's unmerged forward path and fits lora-lite as a simple output
|
||||
hook. We implement the three PEFT variants (`road_1`, `road_2`, `road_4`) and
|
||||
skip merge/unmerge because this library keeps adapters as hooks.
|
||||
|
||||
Refs:
|
||||
- peft: https://github.com/huggingface/peft/blob/6030f9160ed2fc17220f6f41382a66f1257b6a93/src/peft/tuners/road/layer.py
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
import torch
|
||||
from jaxtyping import Float
|
||||
from torch import nn, Tensor as T
|
||||
|
||||
from ..config import AdapterConfig, register_config
|
||||
from ..variant import ParamSpec, register
|
||||
|
||||
RoadVariant = Literal["road_1", "road_2", "road_4"]
|
||||
|
||||
|
||||
@register_config
|
||||
@dataclass
|
||||
class RoadConfig(AdapterConfig):
|
||||
variant: str = "road"
|
||||
road_variant: RoadVariant = "road_1"
|
||||
group_size: int = 64
|
||||
|
||||
|
||||
def _road_param_size(d_out: int, road_variant: str) -> int:
|
||||
if road_variant == "road_1":
|
||||
return d_out // 2
|
||||
if road_variant == "road_2":
|
||||
return d_out
|
||||
if road_variant == "road_4":
|
||||
return d_out * 2
|
||||
raise ValueError(f"road_variant must be 'road_1', 'road_2', or 'road_4', got {road_variant!r}")
|
||||
|
||||
|
||||
def _validate_group_geometry(d_out: int, group_size: int) -> None:
|
||||
if group_size <= 0 or group_size % 2 != 0:
|
||||
raise ValueError(f"ROAD group_size must be positive and even, got {group_size}")
|
||||
if d_out % group_size != 0:
|
||||
raise ValueError(f"ROAD d_out={d_out} must be divisible by group_size={group_size}")
|
||||
|
||||
|
||||
def _prepare_cols(
|
||||
road_variant: str,
|
||||
group_size: int,
|
||||
road_theta: torch.Tensor,
|
||||
road_alpha: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if road_variant == "road_1":
|
||||
# One θ/α per pair. Reuse it for both rows of each 2D rotation block.
|
||||
road_theta = road_theta.reshape(-1, group_size // 2).repeat_interleave(2, dim=0).flatten()
|
||||
road_alpha = road_alpha.reshape(-1, group_size // 2).repeat_interleave(2, dim=0).flatten()
|
||||
first_col = road_alpha * road_theta.cos()
|
||||
second_col = road_alpha * road_theta.sin()
|
||||
elif road_variant == "road_2":
|
||||
# One θ/α per output coordinate.
|
||||
first_col = road_alpha * road_theta.cos()
|
||||
second_col = road_alpha * road_theta.sin()
|
||||
elif road_variant == "road_4":
|
||||
# Independent θ/α for the first and second column contributions.
|
||||
road_theta = road_theta.reshape(-1, 2, group_size)
|
||||
road_alpha = road_alpha.reshape(-1, 2, group_size)
|
||||
first_col = road_alpha[:, 0, :].flatten() * road_theta[:, 0, :].cos().flatten()
|
||||
second_col = road_alpha[:, 1, :].flatten() * road_theta[:, 1, :].sin().flatten()
|
||||
else:
|
||||
raise ValueError(f"road_variant must be 'road_1', 'road_2', or 'road_4', got {road_variant!r}")
|
||||
return first_col, second_col
|
||||
|
||||
|
||||
def _apply_road(
|
||||
road_variant: str,
|
||||
group_size: int,
|
||||
road_theta: torch.Tensor,
|
||||
road_alpha: torch.Tensor,
|
||||
y: Float[T, '*B o'],
|
||||
) -> Float[T, '*B o']:
|
||||
first_col, second_col = _prepare_cols(road_variant, group_size, road_theta, road_alpha)
|
||||
y_grouped = y.reshape(-1, 2, group_size // 2)
|
||||
y1 = y_grouped[:, 0, :]
|
||||
y2 = y_grouped[:, 1, :]
|
||||
rotate_half_y = torch.stack((-y2, y1), dim=1).reshape(y.shape)
|
||||
return y * first_col + rotate_half_y * second_col
|
||||
|
||||
|
||||
def _road_matrix(
|
||||
road_variant: str,
|
||||
group_size: int,
|
||||
road_theta: torch.Tensor,
|
||||
road_alpha: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Explicit PEFT merge matrix. Used for tests and small-debug inspection."""
|
||||
first_col, second_col = _prepare_cols(road_variant, group_size, road_theta, road_alpha)
|
||||
size = second_col.shape[0]
|
||||
output = torch.diag(first_col)
|
||||
swapped_second_col = second_col.reshape(-1, 2, group_size // 2)[:, [1, 0], :].flatten()
|
||||
rotated_diag_second_col = torch.diag(swapped_second_col).reshape(-1, 2, group_size // 2, size)[:, [1, 0], :, :]
|
||||
rotated_diag_second_col[:, 0, :, :] *= -1
|
||||
return output + rotated_diag_second_col.reshape(size, size)
|
||||
|
||||
|
||||
@register
|
||||
class ROAD:
|
||||
name = "road"
|
||||
|
||||
@staticmethod
|
||||
def param_specs(d_in: int, d_out: int, cfg: RoadConfig) -> dict[str, ParamSpec]:
|
||||
_validate_group_geometry(d_out, cfg.group_size)
|
||||
size = _road_param_size(d_out, cfg.road_variant)
|
||||
return {
|
||||
"lora_road_theta": ParamSpec((size,), init="zeros", trainable=True),
|
||||
"lora_road_alpha": ParamSpec((size,), init="ones", trainable=True),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def init(layer: nn.Module, cfg: RoadConfig) -> None:
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
layer: nn.Module,
|
||||
x: Float[T, '*B i'],
|
||||
y: Float[T, '*B o'],
|
||||
) -> Float[T, '*B o']:
|
||||
del x
|
||||
cfg = layer._lora_cfg
|
||||
y_cast = y.to(layer.lora_road_theta.dtype)
|
||||
return _apply_road(cfg.road_variant, cfg.group_size, layer.lora_road_theta, layer.lora_road_alpha, y_cast)
|
||||
@@ -0,0 +1,366 @@
|
||||
"""Per-variant attach + train + save + load round-trip, plus surgical regressions.
|
||||
|
||||
The big invariant is the parametrized train_save_load test: identity at t=0,
|
||||
gradient flow on a real loss, then save -> reload onto a fresh model and
|
||||
confirm the trained outputs survive the round-trip. Cheap on CPU.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
import lora_lite as ll
|
||||
|
||||
|
||||
CFG_BY_VARIANT = {
|
||||
"lora": ll.LoRAConfig,
|
||||
"pissa": ll.PiSSAConfig,
|
||||
"delora": ll.DeLoRAConfig,
|
||||
"ia3": ll.IA3Config,
|
||||
"ia3_ff": ll.IA3FFConfig,
|
||||
"dora": ll.DoRAConfig,
|
||||
"hra": ll.HRAConfig,
|
||||
"eva": ll.EVAConfig,
|
||||
"antipasto": ll.AntiPaSTOConfig,
|
||||
"road": ll.RoadConfig,
|
||||
}
|
||||
|
||||
# Per-variant identity tolerance at t=0 (after attach, before any step).
|
||||
# fp32 SVD round-trip + per-row norm = looser tolerance for pissa/dora/antipasto.
|
||||
IDENTITY_TOL = {
|
||||
"lora": 1e-6,
|
||||
"pissa": 5e-4,
|
||||
"delora": 1e-6,
|
||||
"ia3": 1e-6,
|
||||
"ia3_ff": 1e-6,
|
||||
"dora": 5e-5,
|
||||
"hra": 5e-6,
|
||||
"eva": 1e-6,
|
||||
"antipasto": 5e-4,
|
||||
"road": 1e-6,
|
||||
}
|
||||
|
||||
|
||||
class TinyBlock(nn.Module):
|
||||
def __init__(self, d: int = 64, ff: int = 128):
|
||||
super().__init__()
|
||||
self.q_proj = nn.Linear(d, d, bias=False)
|
||||
self.k_proj = nn.Linear(d, d, bias=False)
|
||||
self.v_proj = nn.Linear(d, d, bias=False)
|
||||
self.o_proj = nn.Linear(d, d, bias=False)
|
||||
self.gate_proj = nn.Linear(d, ff, bias=False)
|
||||
self.up_proj = nn.Linear(d, ff, bias=False)
|
||||
self.down_proj = nn.Linear(ff, d, bias=False)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
h = self.o_proj(self.q_proj(x) + self.k_proj(x) + self.v_proj(x))
|
||||
m = self.down_proj(torch.nn.functional.silu(self.gate_proj(x)) * self.up_proj(x))
|
||||
return x + h + m
|
||||
|
||||
|
||||
class TinyModel(nn.Module):
|
||||
def __init__(self, n_layers: int = 4, d: int = 64, ff: int = 128, vocab: int = 100):
|
||||
super().__init__()
|
||||
self.embed_tokens = nn.Embedding(vocab, d)
|
||||
self.layers = nn.ModuleList([TinyBlock(d, ff) for _ in range(n_layers)])
|
||||
self.lm_head = nn.Linear(d, vocab, bias=False)
|
||||
self.config = type("Cfg", (), {"hidden_size": d})()
|
||||
|
||||
def forward(self, ids: torch.Tensor) -> torch.Tensor:
|
||||
x = self.embed_tokens(ids)
|
||||
for block in self.layers:
|
||||
x = block(x)
|
||||
return self.lm_head(x)
|
||||
|
||||
|
||||
class FakeLinearLike(nn.Module):
|
||||
"""linear-like, but not nn.Linear: stand-in for bnb 4/8-bit modules."""
|
||||
|
||||
def __init__(self, d_in: int = 8, d_out: int = 8):
|
||||
super().__init__()
|
||||
self.in_features = d_in
|
||||
self.out_features = d_out
|
||||
self.weight = nn.Parameter(torch.empty(d_out, d_in))
|
||||
nn.init.kaiming_uniform_(self.weight, a=5 ** 0.5)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return torch.nn.functional.linear(x, self.weight)
|
||||
|
||||
|
||||
class FakeBnbModel(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.config = type("Cfg", (), {"hidden_size": 8})()
|
||||
self.layers = nn.ModuleList([FakeLinearLike(8, 8)])
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.layers[0](x)
|
||||
|
||||
|
||||
def cfg_for(variant: str) -> ll.AdapterConfig:
|
||||
return CFG_BY_VARIANT[variant](
|
||||
r=4,
|
||||
alpha=8,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
|
||||
|
||||
def attach_with_calib(model: nn.Module, cfg: ll.AdapterConfig, ids: torch.Tensor) -> None:
|
||||
if cfg.variant == "eva":
|
||||
calib = [ids for _ in range(2)]
|
||||
ll.attach(model, cfg, calibration_data=calib)
|
||||
else:
|
||||
ll.attach(model, cfg)
|
||||
|
||||
|
||||
def trainable_grad_norm(model: nn.Module) -> float:
|
||||
return sum(
|
||||
p.grad.detach().float().norm().item()
|
||||
for n, p in model.named_parameters()
|
||||
if "lora_" in n and p.grad is not None
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("variant", list(CFG_BY_VARIANT))
|
||||
def test_train_save_load(variant: str, tmp_path: Path):
|
||||
"""Identity at t=0, one SGD step, save, reload onto fresh model, outputs match."""
|
||||
torch.manual_seed(0)
|
||||
model = TinyModel()
|
||||
ids = torch.randint(0, 100, (2, 16))
|
||||
with torch.no_grad():
|
||||
y_base = model(ids).clone()
|
||||
|
||||
cfg = cfg_for(variant)
|
||||
attach_with_calib(model, cfg, ids)
|
||||
|
||||
trainable = [p for p in model.parameters() if p.requires_grad]
|
||||
assert trainable
|
||||
assert all("lora_" in n for n, p in model.named_parameters() if p.requires_grad)
|
||||
|
||||
with torch.no_grad():
|
||||
y_init = model(ids).clone()
|
||||
assert (y_init - y_base).abs().max().item() < IDENTITY_TOL[variant]
|
||||
|
||||
target = torch.randn_like(y_init) * 0.1
|
||||
opt = torch.optim.SGD(trainable, lr=1e-2)
|
||||
opt.zero_grad()
|
||||
loss = (model(ids) - target).pow(2).mean()
|
||||
loss.backward()
|
||||
leaked = [n for n, p in model.named_parameters() if "lora_" not in n and p.grad is not None]
|
||||
assert leaked == []
|
||||
assert trainable_grad_norm(model) > 0
|
||||
opt.step()
|
||||
|
||||
with torch.no_grad():
|
||||
y_trained = model(ids).clone()
|
||||
|
||||
path = tmp_path / "adapter.pt"
|
||||
ll.save(model, str(path))
|
||||
|
||||
torch.manual_seed(0)
|
||||
model_loaded = TinyModel()
|
||||
ll.load(model_loaded, str(path)) # EVA load skips group_init; calibration_data not needed
|
||||
with torch.no_grad():
|
||||
y_loaded = model_loaded(ids)
|
||||
assert (y_loaded - y_trained).abs().max().item() < max(IDENTITY_TOL[variant], 1e-5)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("variant", ["lora", "delora", "ia3", "hra", "road"])
|
||||
def test_hook_only_variants_attach_to_non_linear_target(variant: str):
|
||||
"""bnb-style targets are linear-like but not nn.Linear; hook-only variants must accept them."""
|
||||
extra = {"lambda0": 0.1} if variant == "delora" else {"group_size": 8} if variant == "road" else {}
|
||||
cfg = CFG_BY_VARIANT[variant](r=2, alpha=4, dtype=torch.float32, target_roles=(), **extra)
|
||||
model = FakeBnbModel()
|
||||
ll.attach(model, cfg)
|
||||
x = torch.randn(2, 3, 8)
|
||||
model(x).pow(2).mean().backward()
|
||||
assert trainable_grad_norm(model) > 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("variant", ["pissa", "dora", "antipasto"])
|
||||
def test_weight_reading_variants_reject_non_linear(variant: str):
|
||||
r = 4 if variant == "antipasto" else 2 # antipasto needs r % block_size==0
|
||||
cfg = CFG_BY_VARIANT[variant](r=r, alpha=r, dtype=torch.float32, target_roles=())
|
||||
with pytest.raises(TypeError, match="plain nn.Linear"):
|
||||
ll.attach(FakeBnbModel(), cfg)
|
||||
|
||||
|
||||
def test_save_load_strict_keys(tmp_path: Path):
|
||||
import json
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = TinyModel()
|
||||
ll.attach(model, ll.LoRAConfig(r=4, alpha=8, dtype=torch.float32))
|
||||
p = tmp_path / "lora.safetensors"
|
||||
ll.save(model, str(p))
|
||||
sd = load_file(str(p), device="cpu")
|
||||
|
||||
# missing key: drop first lora key
|
||||
missing_sd = dict(sd)
|
||||
dropped_key = next(iter(missing_sd))
|
||||
del missing_sd[dropped_key]
|
||||
from safetensors import safe_open
|
||||
with safe_open(str(p), framework="pt", device="cpu") as f:
|
||||
meta = f.metadata()
|
||||
save_file(missing_sd, str(p), metadata=meta)
|
||||
with pytest.raises(RuntimeError, match="missing lora keys"):
|
||||
ll.load(TinyModel(), str(p))
|
||||
|
||||
# unexpected key: add a bogus lora key
|
||||
bad_sd = dict(sd)
|
||||
bad_sd["layers.0.q_proj.lora_extra"] = torch.zeros(1)
|
||||
save_file(bad_sd, str(p), metadata=meta)
|
||||
with pytest.raises(RuntimeError, match="unexpected lora keys"):
|
||||
ll.load(TinyModel(), str(p))
|
||||
|
||||
|
||||
def test_no_target_layers_is_loud():
|
||||
cfg = ll.LoRAConfig(target_names=("definitely_missing",))
|
||||
with pytest.raises(RuntimeError, match="no target layers"):
|
||||
ll.attach(TinyModel(), cfg)
|
||||
|
||||
|
||||
def test_eva_requires_calibration():
|
||||
"""EVA's group_init must error loudly if calibration_data is missing."""
|
||||
with pytest.raises(ValueError, match="calibration_data"):
|
||||
ll.attach(TinyModel(), ll.EVAConfig(r=4, alpha=8, dtype=torch.float32))
|
||||
|
||||
|
||||
def test_delora_default_has_live_step0_gradient():
|
||||
"""Default lambda0 must be nonzero; B=0 preserves identity while B gets gradient."""
|
||||
torch.manual_seed(0)
|
||||
model = TinyModel(n_layers=1)
|
||||
ids = torch.randint(0, 100, (2, 8))
|
||||
ll.attach(model, ll.DeLoRAConfig(r=4, alpha=8, dtype=torch.float32))
|
||||
|
||||
assert model.layers[0].q_proj.lora_lambda.item() == pytest.approx(15.0)
|
||||
loss = model(ids).pow(2).mean()
|
||||
loss.backward()
|
||||
|
||||
b_grad = model.layers[0].q_proj.lora_B.grad.detach().abs().max().item()
|
||||
assert b_grad > 0
|
||||
|
||||
|
||||
def test_pissa_identity_with_nonunit_scale():
|
||||
"""Regression: PiSSA must pre-divide S by alpha/r, not require alpha == r."""
|
||||
torch.manual_seed(0)
|
||||
model = TinyModel(n_layers=1)
|
||||
ids = torch.randint(0, 100, (2, 8))
|
||||
with torch.no_grad():
|
||||
y_base = model(ids).clone()
|
||||
|
||||
ll.attach(model, ll.PiSSAConfig(r=4, alpha=8, dtype=torch.float32))
|
||||
with torch.no_grad():
|
||||
y = model(ids)
|
||||
assert (y - y_base).abs().max().item() < IDENTITY_TOL["pissa"]
|
||||
|
||||
|
||||
def test_antipasto_blockwise_rotation_matches_explicit_blockdiag():
|
||||
"""The einsum/rearrange path must equal the old explicit blockdiag math."""
|
||||
from lora_lite.variants.antipasto import _build_rotation
|
||||
|
||||
torch.manual_seed(0)
|
||||
n_blocks, bs, d_in, d_out = 3, 4, 7, 5
|
||||
r = n_blocks * bs
|
||||
rot_T = torch.randn(n_blocks, bs * (bs - 1) // 2) * 0.1
|
||||
Vh = torch.randn(r, d_in)
|
||||
U = torch.randn(d_out, r)
|
||||
R_blocks = _build_rotation(rot_T, bs, 0.5)
|
||||
R = torch.block_diag(*list(R_blocks))
|
||||
|
||||
Vh_blocks = torch.reshape(Vh, (n_blocks, bs, d_in))
|
||||
Vh_rot = torch.einsum("nab,nbi->nai", R_blocks, Vh_blocks).reshape(r, d_in)
|
||||
U_blocks = torch.reshape(U, (d_out, n_blocks, bs))
|
||||
U_rot = torch.einsum("dnb,ncb->dnc", U_blocks, R_blocks).reshape(d_out, r)
|
||||
|
||||
assert (Vh_rot - R @ Vh).abs().max().item() < 1e-6
|
||||
assert (U_rot - U @ R.T).abs().max().item() < 1e-6
|
||||
|
||||
|
||||
def test_dora_bias_passthrough():
|
||||
"""Regression: DoRA must NOT scale bias; identity holds with bias=True at t=0."""
|
||||
torch.manual_seed(0)
|
||||
d = 16
|
||||
layer = nn.Linear(d, d, bias=True)
|
||||
x = torch.randn(2, d)
|
||||
y_base = layer(x).detach()
|
||||
|
||||
class Wrap(nn.Module):
|
||||
def __init__(self, lin):
|
||||
super().__init__()
|
||||
self.config = type("Cfg", (), {"hidden_size": d})()
|
||||
self.layers = nn.ModuleList([lin])
|
||||
|
||||
def forward(self, x):
|
||||
return self.layers[0](x)
|
||||
|
||||
model = Wrap(layer)
|
||||
ll.attach(model, ll.DoRAConfig(r=2, alpha=4, dtype=torch.float32, target_roles=()))
|
||||
with torch.no_grad():
|
||||
y = model(x)
|
||||
assert (y - y_base).abs().max().item() < 1e-5
|
||||
|
||||
|
||||
def test_hra_forward_is_x_R_T():
|
||||
"""HRA must apply x @ R^T (loop i = r-1 down to 0). Asymmetric U makes order observable."""
|
||||
torch.manual_seed(0)
|
||||
d = 8
|
||||
layer = nn.Linear(d, d, bias=False)
|
||||
x = torch.randn(2, 3, d)
|
||||
|
||||
class Wrap(nn.Module):
|
||||
def __init__(self, lin):
|
||||
super().__init__()
|
||||
self.config = type("Cfg", (), {"hidden_size": d})()
|
||||
self.layers = nn.ModuleList([lin])
|
||||
|
||||
def forward(self, x):
|
||||
return self.layers[0](x)
|
||||
|
||||
model = Wrap(layer)
|
||||
ll.attach(model, ll.HRAConfig(r=4, alpha=4, dtype=torch.float32, target_roles=()))
|
||||
# break paired symmetry so order matters
|
||||
with torch.no_grad():
|
||||
layer.lora_U.add_(0.1 * torch.randn_like(layer.lora_U))
|
||||
|
||||
U = layer.lora_U
|
||||
R = torch.eye(d)
|
||||
for i in range(U.shape[0]):
|
||||
u = U[i]
|
||||
sq = (u * u).sum().clamp_min(1e-12)
|
||||
R = R - (2.0 / sq) * torch.outer(R @ u, u)
|
||||
with torch.no_grad():
|
||||
y_adapt = model(x)
|
||||
y_ref = torch.nn.functional.linear(x, layer.weight @ R)
|
||||
assert (y_adapt - y_ref).abs().max().item() < 1e-5
|
||||
|
||||
|
||||
@pytest.mark.parametrize("road_variant", ["road_1", "road_2", "road_4"])
|
||||
def test_road_apply_matches_explicit_matrix(road_variant: str):
|
||||
"""Fast elementwise ROAD path must match PEFT's explicit R @ y matrix construction."""
|
||||
from lora_lite.variants.road import _apply_road, _road_matrix, _road_param_size
|
||||
|
||||
torch.manual_seed(0)
|
||||
d_out = 16
|
||||
group_size = 8
|
||||
size = _road_param_size(d_out, road_variant)
|
||||
theta = torch.randn(size) * 0.2
|
||||
alpha = torch.randn(size) * 0.1 + 1.0
|
||||
y = torch.randn(2, 3, d_out)
|
||||
|
||||
y_fast = _apply_road(road_variant, group_size, theta, alpha, y)
|
||||
R = _road_matrix(road_variant, group_size, theta, alpha)
|
||||
y_ref = torch.einsum("oi,...i->...o", R, y)
|
||||
|
||||
assert (y_fast - y_ref).abs().max().item() < 1e-6
|
||||
|
||||
|
||||
def test_road_invalid_group_size_is_loud():
|
||||
with pytest.raises(ValueError, match="positive and even"):
|
||||
ll.attach(TinyModel(), ll.RoadConfig(group_size=7))
|
||||
with pytest.raises(ValueError, match="divisible"):
|
||||
ll.attach(TinyModel(), ll.RoadConfig(group_size=48))
|
||||
@@ -31,7 +31,7 @@ sys.modules[SPEC.name] = benchmark
|
||||
SPEC.loader.exec_module(benchmark)
|
||||
|
||||
|
||||
VARIANTS = ["lora", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva", "antipasto"]
|
||||
VARIANTS = ["lora", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva", "antipasto", "road"]
|
||||
# Variants that fail loud when attached on a bnb-loaded base (read dense weight in init).
|
||||
# delora/eva also read weight but currently silently dequant -- they produce sane attach,
|
||||
# so we don't expect a raise from them in the attach-only smoke.
|
||||
@@ -69,6 +69,7 @@ def quick_cfg(variant: str, tmp_path: Path, quantization: str = "none") -> "benc
|
||||
max_seq_length=128,
|
||||
max_new_tokens=8,
|
||||
lr=5e-3,
|
||||
road_group_size=8,
|
||||
seed=0,
|
||||
log_examples=0,
|
||||
log_every=1000,
|
||||
|
||||
@@ -12,7 +12,7 @@ resolution-markers = [
|
||||
]
|
||||
|
||||
[options]
|
||||
exclude-newer = "2026-04-21T14:06:04.428693663Z"
|
||||
exclude-newer = "2026-04-22T01:31:09.47902161Z"
|
||||
exclude-newer-span = "P5D"
|
||||
|
||||
[[package]]
|
||||
@@ -1012,6 +1012,7 @@ dependencies = [
|
||||
{ name = "einops" },
|
||||
{ name = "jaxtyping", version = "0.3.7", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
|
||||
{ name = "jaxtyping", version = "0.3.9", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" },
|
||||
{ name = "safetensors" },
|
||||
{ name = "torch" },
|
||||
]
|
||||
|
||||
@@ -1051,6 +1052,7 @@ requires-dist = [
|
||||
{ name = "einops", specifier = ">=0.7" },
|
||||
{ name = "jaxtyping", specifier = ">=0.2.34" },
|
||||
{ name = "pytest", marker = "extra == 'test'" },
|
||||
{ name = "safetensors", specifier = ">=0.5" },
|
||||
{ name = "safetensors", marker = "extra == 'benchmark'", specifier = ">=0.5" },
|
||||
{ name = "safetensors", marker = "extra == 'hf-test'", specifier = ">=0.5" },
|
||||
{ name = "tabulate", marker = "extra == 'benchmark'" },
|
||||
|
||||
Reference in New Issue
Block a user