mirror of
https://github.com/wassname/lora-lite.git
synced 2026-06-27 17:30:56 +08:00
simpler test
This commit is contained in:
@@ -8,6 +8,7 @@ dependencies = [
|
|||||||
"torch>=2.1",
|
"torch>=2.1",
|
||||||
"einops>=0.7",
|
"einops>=0.7",
|
||||||
"jaxtyping>=0.2.34",
|
"jaxtyping>=0.2.34",
|
||||||
|
"safetensors>=0.5",
|
||||||
]
|
]
|
||||||
keywords = ["lora", "pytorch", "peft", "adapters", "llm"]
|
keywords = ["lora", "pytorch", "peft", "adapters", "llm"]
|
||||||
classifiers = [
|
classifiers = [
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ CFG_BY_VARIANT = {
|
|||||||
"hra": ll.HRAConfig,
|
"hra": ll.HRAConfig,
|
||||||
"eva": ll.EVAConfig,
|
"eva": ll.EVAConfig,
|
||||||
"antipasto": ll.AntiPaSTOConfig,
|
"antipasto": ll.AntiPaSTOConfig,
|
||||||
|
"road": ll.RoadConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -41,7 +42,7 @@ class BenchmarkConfig:
|
|||||||
"""MetaMathQA -> GSM8K benchmark config. Tyro turns this into the CLI."""
|
"""MetaMathQA -> GSM8K benchmark config. Tyro turns this into the CLI."""
|
||||||
|
|
||||||
model: str = "Qwen/Qwen3-0.6B-Base"
|
model: str = "Qwen/Qwen3-0.6B-Base"
|
||||||
variant: Literal["lora", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva", "antipasto"] = "lora"
|
variant: Literal["lora", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva", "antipasto", "road"] = "lora"
|
||||||
mode: Literal["benchmark", "probe"] = "benchmark"
|
mode: Literal["benchmark", "probe"] = "benchmark"
|
||||||
device: str = "cuda"
|
device: str = "cuda"
|
||||||
torch_dtype: str = "bfloat16"
|
torch_dtype: str = "bfloat16"
|
||||||
@@ -49,6 +50,7 @@ class BenchmarkConfig:
|
|||||||
r: int = 32
|
r: int = 32
|
||||||
alpha: float = 64.0
|
alpha: float = 64.0
|
||||||
delora_lambda0: float = 0.1
|
delora_lambda0: float = 0.1
|
||||||
|
road_group_size: int = 64
|
||||||
target_name: list[str] = field(default_factory=lambda: list(DEFAULT_TARGETS))
|
target_name: list[str] = field(default_factory=lambda: list(DEFAULT_TARGETS))
|
||||||
layers: str = "all"
|
layers: str = "all"
|
||||||
train_dataset: str = "meta-math/MetaMathQA"
|
train_dataset: str = "meta-math/MetaMathQA"
|
||||||
@@ -118,6 +120,8 @@ def parse_layers(text: str) -> tuple[int, ...] | None:
|
|||||||
|
|
||||||
def cfg_for_variant(args: BenchmarkConfig, dtype: torch.dtype) -> ll.AdapterConfig:
|
def cfg_for_variant(args: BenchmarkConfig, dtype: torch.dtype) -> ll.AdapterConfig:
|
||||||
extra = {"lambda0": args.delora_lambda0} if args.variant == "delora" else {}
|
extra = {"lambda0": args.delora_lambda0} if args.variant == "delora" else {}
|
||||||
|
if args.variant == "road":
|
||||||
|
extra = {"group_size": args.road_group_size}
|
||||||
return CFG_BY_VARIANT[args.variant](
|
return CFG_BY_VARIANT[args.variant](
|
||||||
r=args.r,
|
r=args.r,
|
||||||
alpha=args.r if args.variant == "pissa" else args.alpha,
|
alpha=args.r if args.variant == "pissa" else args.alpha,
|
||||||
@@ -147,7 +151,7 @@ def count_base_grad_leaks(model: torch.nn.Module) -> int:
|
|||||||
|
|
||||||
|
|
||||||
def perturb_first_adapter(model: torch.nn.Module) -> None:
|
def perturb_first_adapter(model: torch.nn.Module) -> None:
|
||||||
priority = ("lora_B", "lora_g", "lora_U", "lora_A", "lora_lambda", "lora_gate", "lora_delta_s", "lora_rot_T", "lora_m")
|
priority = ("lora_B", "lora_g", "lora_U", "lora_A", "lora_lambda", "lora_gate", "lora_delta_s", "lora_rot_T", "lora_m", "lora_road_theta", "lora_road_alpha")
|
||||||
for key in priority:
|
for key in priority:
|
||||||
for _, p in model.named_parameters():
|
for _, p in model.named_parameters():
|
||||||
if p.requires_grad and key in _:
|
if p.requires_grad and key in _:
|
||||||
@@ -409,11 +413,12 @@ def check_probe_reload(
|
|||||||
loaded_model, _ = load_model_and_tokenizer(args.model, getattr(torch, args.torch_dtype), args.device, args.quantization)
|
loaded_model, _ = load_model_and_tokenizer(args.model, getattr(torch, args.torch_dtype), args.device, args.quantization)
|
||||||
loaded_model.eval()
|
loaded_model.eval()
|
||||||
ll.load(loaded_model, str(adapter_path))
|
ll.load(loaded_model, str(adapter_path))
|
||||||
saved = torch.load(adapter_path, weights_only=True, map_location="cpu")
|
from safetensors.torch import load_file
|
||||||
|
saved_sd = load_file(str(adapter_path), device="cpu")
|
||||||
loaded_state = adapter_state(loaded_model)
|
loaded_state = adapter_state(loaded_model)
|
||||||
if set(saved["state"]) != set(loaded_state):
|
if set(saved_sd) != set(loaded_state):
|
||||||
raise AssertionError("loaded adapter keys differ from saved adapter keys")
|
raise AssertionError("loaded adapter keys differ from saved adapter keys")
|
||||||
for name, value in saved["state"].items():
|
for name, value in saved_sd.items():
|
||||||
if not torch.equal(loaded_state[name].cpu(), value):
|
if not torch.equal(loaded_state[name].cpu(), value):
|
||||||
raise AssertionError(f"loaded adapter tensor differs: {name}")
|
raise AssertionError(f"loaded adapter tensor differs: {name}")
|
||||||
logits_loaded = loaded_model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]).logits.detach().clone()
|
logits_loaded = loaded_model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]).logits.detach().clone()
|
||||||
@@ -426,12 +431,21 @@ def check_probe_reload(
|
|||||||
return {"reload_err": reload_err, "saved_tensors": len(saved["state"])}
|
return {"reload_err": reload_err, "saved_tensors": len(saved["state"])}
|
||||||
|
|
||||||
|
|
||||||
def print_final_report(row: dict[str, Any], result_path: Path) -> None:
|
def print_final_report(row: dict[str, Any], result_path: Path, mode: str) -> None:
|
||||||
print("SHOULD: grad>0, dθ>0, base_grad_leaks=0, valid/test fields present; probeΔ<0 is good but not required for tiny random smoke. ELSE adapter or eval wiring is dead/wrong.")
|
# BLUF: status line first so log tails are immediately readable
|
||||||
|
cue = "🟢" if row.get("base_grad_leaks", 0) == 0 and row.get("grad", 0) > 0 else "🔴"
|
||||||
|
n = row.get("samples", "?")
|
||||||
|
print(f"{cue} test_acc={row['test_acc']:.4g} valid_acc={row['valid_acc']:.4g} grad={row['grad']:.3g} dθ={row['dθ']:.3g} base_grad_leaks={row['base_grad_leaks']} N={n}")
|
||||||
|
print("SHOULD: grad>0, dθ>0, base_grad_leaks=0; test/valid_acc meaningful only in benchmark mode. ELSE adapter or eval wiring is dead/wrong.")
|
||||||
|
# ordered: most important / shortest columns first
|
||||||
|
display_keys = ["variant", "test_acc", "valid_acc", "grad", "dθ", "base_grad_leaks", "steps", "samples", "loss0", "lossN", "commit"]
|
||||||
|
if "perturb" in row:
|
||||||
|
display_keys += ["perturb", "reload"]
|
||||||
|
display_keys += ["run_id"]
|
||||||
|
display_row = {k: row[k] for k in display_keys if k in row}
|
||||||
|
print(tabulate([display_row], headers="keys", tablefmt="tsv", floatfmt=".4g"))
|
||||||
|
print(f"argv: {' '.join(sys.argv)} N={n} mode={mode}")
|
||||||
print(f"out: {result_path}")
|
print(f"out: {result_path}")
|
||||||
print(f"argv: {' '.join(sys.argv)}")
|
|
||||||
print(f"main metric: test_acc={row['test_acc']:.4g} valid_acc={row['valid_acc']:.4g} steps={row['steps']} samples={row['samples']}")
|
|
||||||
print(tabulate([row], headers="keys", tablefmt="tsv", floatfmt=".4g"))
|
|
||||||
|
|
||||||
|
|
||||||
def current_git_commit() -> str:
|
def current_git_commit() -> str:
|
||||||
@@ -447,25 +461,24 @@ def append_results_row(
|
|||||||
result: dict[str, Any],
|
result: dict[str, Any],
|
||||||
run_commit: str,
|
run_commit: str,
|
||||||
) -> tuple[Path, Path]:
|
) -> tuple[Path, Path]:
|
||||||
results_dir = args.output_dir / "results"
|
results_dir = args.output_dir
|
||||||
results_dir.mkdir(parents=True, exist_ok=True)
|
results_dir.mkdir(parents=True, exist_ok=True)
|
||||||
tsv_path = results_dir / "benchmark_results.tsv"
|
tsv_path = results_dir / "summary.tsv"
|
||||||
lock_path = results_dir / "benchmark_results.tsv.lock"
|
lock_path = results_dir / "summary.tsv.lock"
|
||||||
finished_at = datetime.now(timezone.utc).isoformat(timespec="seconds")
|
finished_at = datetime.now(timezone.utc).isoformat(timespec="seconds")
|
||||||
finished_label = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
|
finished_label = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
|
||||||
snapshot_path = results_dir / f"{result['run_id']}__{finished_label}.json"
|
snapshot_path = results_dir / f"{result['run_id']}__{finished_label}.json"
|
||||||
snapshot_path.write_text(json.dumps(result, indent=2), encoding="utf-8")
|
snapshot_path.write_text(json.dumps(result, indent=2), encoding="utf-8")
|
||||||
row = {
|
row = {
|
||||||
"time_utc": finished_at,
|
"test_acc": result["test_acc"],
|
||||||
"commit": run_commit,
|
"valid_acc": result["valid_acc"],
|
||||||
"method": args.variant,
|
"method": args.variant,
|
||||||
"model": args.model,
|
|
||||||
"mode": args.mode,
|
|
||||||
"valid_accuracy": result["valid_accuracy"],
|
|
||||||
"test_accuracy": result["test_accuracy"],
|
|
||||||
"steps": args.steps,
|
"steps": args.steps,
|
||||||
"samples": result["train_samples"],
|
"samples": result["train_samples"],
|
||||||
"wall_time_s": result["wall_time_s"],
|
"model": args.model,
|
||||||
|
"commit": run_commit[:12],
|
||||||
|
"wall_time_s": round(result["wall_time_s"]),
|
||||||
|
"time_utc": finished_at,
|
||||||
"argv": " ".join(sys.argv),
|
"argv": " ".join(sys.argv),
|
||||||
"result_json": str(snapshot_path),
|
"result_json": str(snapshot_path),
|
||||||
"latest_result_json": str(result_path),
|
"latest_result_json": str(result_path),
|
||||||
@@ -517,7 +530,7 @@ def run(args: BenchmarkConfig) -> dict[str, Any]:
|
|||||||
valid_metrics = evaluate(model, tokenizer, datasets["valid"], args, "valid")
|
valid_metrics = evaluate(model, tokenizer, datasets["valid"], args, "valid")
|
||||||
test_metrics = evaluate(model, tokenizer, datasets["test"], args, "test")
|
test_metrics = evaluate(model, tokenizer, datasets["test"], args, "test")
|
||||||
|
|
||||||
adapter_path = out_dir / "adapter.pt"
|
adapter_path = out_dir / "adapter.safetensors"
|
||||||
ll.save(model, str(adapter_path))
|
ll.save(model, str(adapter_path))
|
||||||
if args.mode == "probe":
|
if args.mode == "probe":
|
||||||
model.eval()
|
model.eval()
|
||||||
@@ -555,8 +568,8 @@ def run(args: BenchmarkConfig) -> dict[str, Any]:
|
|||||||
"weight_decay": args.weight_decay,
|
"weight_decay": args.weight_decay,
|
||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
"grad_norm_clip": args.grad_norm_clip,
|
"grad_norm_clip": args.grad_norm_clip,
|
||||||
"valid_accuracy": valid_metrics["accuracy"],
|
"valid_acc": valid_metrics["accuracy"],
|
||||||
"test_accuracy": test_metrics["accuracy"],
|
"test_acc": test_metrics["accuracy"],
|
||||||
"train": train_metrics,
|
"train": train_metrics,
|
||||||
"valid": valid_metrics,
|
"valid": valid_metrics,
|
||||||
"test": test_metrics,
|
"test": test_metrics,
|
||||||
@@ -566,12 +579,12 @@ def run(args: BenchmarkConfig) -> dict[str, Any]:
|
|||||||
}
|
}
|
||||||
result_path = out_dir / "result.json"
|
result_path = out_dir / "result.json"
|
||||||
result_path.write_text(json.dumps(result, indent=2), encoding="utf-8")
|
result_path.write_text(json.dumps(result, indent=2), encoding="utf-8")
|
||||||
results_tsv_path, result_snapshot_path = append_results_row(args, result_path, result, run_commit)
|
if args.mode == "benchmark":
|
||||||
result["results_tsv_path"] = str(results_tsv_path)
|
results_tsv_path, result_snapshot_path = append_results_row(args, result_path, result, run_commit)
|
||||||
result["result_snapshot_path"] = str(result_snapshot_path)
|
result["results_tsv_path"] = str(results_tsv_path)
|
||||||
|
result["result_snapshot_path"] = str(result_snapshot_path)
|
||||||
result["commit"] = run_commit
|
result["commit"] = run_commit
|
||||||
result_path.write_text(json.dumps(result, indent=2), encoding="utf-8")
|
result_path.write_text(json.dumps(result, indent=2), encoding="utf-8")
|
||||||
commit_prefix = run_commit[:12]
|
|
||||||
|
|
||||||
row = {
|
row = {
|
||||||
"run_id": run_id,
|
"run_id": run_id,
|
||||||
@@ -592,7 +605,7 @@ def run(args: BenchmarkConfig) -> dict[str, Any]:
|
|||||||
if probe_metrics is not None:
|
if probe_metrics is not None:
|
||||||
row["perturb"] = probe_metrics["perturb_delta"]
|
row["perturb"] = probe_metrics["perturb_delta"]
|
||||||
row["reload"] = probe_metrics["reload_err"]
|
row["reload"] = probe_metrics["reload_err"]
|
||||||
print_final_report(row, result_path)
|
print_final_report(row, result_path, args.mode)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ from .variants.dora import DoRAConfig
|
|||||||
from .variants.hra import HRAConfig
|
from .variants.hra import HRAConfig
|
||||||
from .variants.eva import EVAConfig
|
from .variants.eva import EVAConfig
|
||||||
from .variants.antipasto import AntiPaSTOConfig
|
from .variants.antipasto import AntiPaSTOConfig
|
||||||
|
from .variants.road import RoadConfig
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AdapterConfig",
|
"AdapterConfig",
|
||||||
@@ -32,6 +33,7 @@ __all__ = [
|
|||||||
"HRAConfig",
|
"HRAConfig",
|
||||||
"EVAConfig",
|
"EVAConfig",
|
||||||
"AntiPaSTOConfig",
|
"AntiPaSTOConfig",
|
||||||
|
"RoadConfig",
|
||||||
"attach",
|
"attach",
|
||||||
"detach",
|
"detach",
|
||||||
"save",
|
"save",
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
"""attach / detach / save / load. The whole runtime."""
|
"""attach / detach / save / load. The whole runtime."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
import json
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.utils.hooks import RemovableHandle
|
from torch.utils.hooks import RemovableHandle
|
||||||
@@ -121,19 +122,22 @@ def save(model: nn.Module, path: str) -> None:
|
|||||||
if state is None:
|
if state is None:
|
||||||
raise RuntimeError("no adapter attached; call attach() first")
|
raise RuntimeError("no adapter attached; call attach() first")
|
||||||
sd = {k: v.detach().cpu() for k, v in model.state_dict().items() if "lora_" in k}
|
sd = {k: v.detach().cpu() for k, v in model.state_dict().items() if "lora_" in k}
|
||||||
blob = {
|
metadata = {
|
||||||
"cfg": state["cfg"].to_dict(),
|
"cfg": json.dumps(state["cfg"].to_dict()),
|
||||||
"state": sd,
|
"base_fp": json.dumps(_base_weight_fingerprint(model)),
|
||||||
"base_fp": _base_weight_fingerprint(model),
|
|
||||||
}
|
}
|
||||||
torch.save(blob, path)
|
from safetensors.torch import save_file
|
||||||
|
save_file(sd, path, metadata=metadata)
|
||||||
|
|
||||||
|
|
||||||
def load(model: nn.Module, path: str) -> list[RemovableHandle]:
|
def load(model: nn.Module, path: str) -> list[RemovableHandle]:
|
||||||
blob = torch.load(path, weights_only=True, map_location="cpu")
|
from safetensors.torch import load_file, safe_open
|
||||||
cfg = AdapterConfig.from_dict(blob["cfg"])
|
with safe_open(path, framework="pt", device="cpu") as f:
|
||||||
|
metadata = f.metadata()
|
||||||
|
sd = load_file(path, device="cpu")
|
||||||
|
cfg = AdapterConfig.from_dict(json.loads(metadata["cfg"]))
|
||||||
handles = attach(model, cfg, _skip_group_init=True) # creates empty params; data-driven inits restored from state_dict
|
handles = attach(model, cfg, _skip_group_init=True) # creates empty params; data-driven inits restored from state_dict
|
||||||
missing, unexpected = model.load_state_dict(blob["state"], strict=False)
|
missing, unexpected = model.load_state_dict(sd, strict=False)
|
||||||
expected_lora = {k for k in model.state_dict() if "lora_" in k}
|
expected_lora = {k for k in model.state_dict() if "lora_" in k}
|
||||||
missing_lora = sorted(expected_lora.intersection(missing))
|
missing_lora = sorted(expected_lora.intersection(missing))
|
||||||
if missing_lora:
|
if missing_lora:
|
||||||
@@ -141,7 +145,7 @@ def load(model: nn.Module, path: str) -> list[RemovableHandle]:
|
|||||||
unexpected_lora = [k for k in unexpected if "lora_" in k]
|
unexpected_lora = [k for k in unexpected if "lora_" in k]
|
||||||
if unexpected_lora:
|
if unexpected_lora:
|
||||||
raise RuntimeError(f"unexpected lora keys in checkpoint: {unexpected_lora}")
|
raise RuntimeError(f"unexpected lora keys in checkpoint: {unexpected_lora}")
|
||||||
saved_fp = blob.get("base_fp", {})
|
saved_fp = json.loads(metadata.get("base_fp", "{}"))
|
||||||
if saved_fp:
|
if saved_fp:
|
||||||
cur_fp = _base_weight_fingerprint(model)
|
cur_fp = _base_weight_fingerprint(model)
|
||||||
diffs = [k for k in saved_fp if saved_fp[k] != cur_fp.get(k)]
|
diffs = [k for k in saved_fp if saved_fp[k] != cur_fp.get(k)]
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
from . import lora, pissa, delora, ia3, dora, hra, eva, antipasto # noqa: F401 side-effect: register
|
from . import lora, pissa, delora, ia3, dora, hra, eva, antipasto, road # noqa: F401 side-effect: register
|
||||||
|
|||||||
@@ -37,7 +37,6 @@ class LoRA:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def init(layer: nn.Module, cfg) -> None:
|
def init(layer: nn.Module, cfg) -> None:
|
||||||
# B is zeros => delta=0 at t=0; identity invariant holds.
|
|
||||||
return
|
return
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@@ -0,0 +1,137 @@
|
|||||||
|
"""ROAD: Rotation ADaptation. https://arxiv.org/abs/2409.00119
|
||||||
|
|
||||||
|
ROAD applies a learned output-space block rotation/scaling after the frozen base
|
||||||
|
layer:
|
||||||
|
|
||||||
|
y' = R y = R (W x + b)
|
||||||
|
|
||||||
|
This matches PEFT's unmerged forward path and fits lora-lite as a simple output
|
||||||
|
hook. We implement the three PEFT variants (`road_1`, `road_2`, `road_4`) and
|
||||||
|
skip merge/unmerge because this library keeps adapters as hooks.
|
||||||
|
|
||||||
|
Refs:
|
||||||
|
- peft: https://github.com/huggingface/peft/blob/6030f9160ed2fc17220f6f41382a66f1257b6a93/src/peft/tuners/road/layer.py
|
||||||
|
"""
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from jaxtyping import Float
|
||||||
|
from torch import nn, Tensor as T
|
||||||
|
|
||||||
|
from ..config import AdapterConfig, register_config
|
||||||
|
from ..variant import ParamSpec, register
|
||||||
|
|
||||||
|
RoadVariant = Literal["road_1", "road_2", "road_4"]
|
||||||
|
|
||||||
|
|
||||||
|
@register_config
|
||||||
|
@dataclass
|
||||||
|
class RoadConfig(AdapterConfig):
|
||||||
|
variant: str = "road"
|
||||||
|
road_variant: RoadVariant = "road_1"
|
||||||
|
group_size: int = 64
|
||||||
|
|
||||||
|
|
||||||
|
def _road_param_size(d_out: int, road_variant: str) -> int:
|
||||||
|
if road_variant == "road_1":
|
||||||
|
return d_out // 2
|
||||||
|
if road_variant == "road_2":
|
||||||
|
return d_out
|
||||||
|
if road_variant == "road_4":
|
||||||
|
return d_out * 2
|
||||||
|
raise ValueError(f"road_variant must be 'road_1', 'road_2', or 'road_4', got {road_variant!r}")
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_group_geometry(d_out: int, group_size: int) -> None:
|
||||||
|
if group_size <= 0 or group_size % 2 != 0:
|
||||||
|
raise ValueError(f"ROAD group_size must be positive and even, got {group_size}")
|
||||||
|
if d_out % group_size != 0:
|
||||||
|
raise ValueError(f"ROAD d_out={d_out} must be divisible by group_size={group_size}")
|
||||||
|
|
||||||
|
|
||||||
|
def _prepare_cols(
|
||||||
|
road_variant: str,
|
||||||
|
group_size: int,
|
||||||
|
road_theta: torch.Tensor,
|
||||||
|
road_alpha: torch.Tensor,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
if road_variant == "road_1":
|
||||||
|
# One θ/α per pair. Reuse it for both rows of each 2D rotation block.
|
||||||
|
road_theta = road_theta.reshape(-1, group_size // 2).repeat_interleave(2, dim=0).flatten()
|
||||||
|
road_alpha = road_alpha.reshape(-1, group_size // 2).repeat_interleave(2, dim=0).flatten()
|
||||||
|
first_col = road_alpha * road_theta.cos()
|
||||||
|
second_col = road_alpha * road_theta.sin()
|
||||||
|
elif road_variant == "road_2":
|
||||||
|
# One θ/α per output coordinate.
|
||||||
|
first_col = road_alpha * road_theta.cos()
|
||||||
|
second_col = road_alpha * road_theta.sin()
|
||||||
|
elif road_variant == "road_4":
|
||||||
|
# Independent θ/α for the first and second column contributions.
|
||||||
|
road_theta = road_theta.reshape(-1, 2, group_size)
|
||||||
|
road_alpha = road_alpha.reshape(-1, 2, group_size)
|
||||||
|
first_col = road_alpha[:, 0, :].flatten() * road_theta[:, 0, :].cos().flatten()
|
||||||
|
second_col = road_alpha[:, 1, :].flatten() * road_theta[:, 1, :].sin().flatten()
|
||||||
|
else:
|
||||||
|
raise ValueError(f"road_variant must be 'road_1', 'road_2', or 'road_4', got {road_variant!r}")
|
||||||
|
return first_col, second_col
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_road(
|
||||||
|
road_variant: str,
|
||||||
|
group_size: int,
|
||||||
|
road_theta: torch.Tensor,
|
||||||
|
road_alpha: torch.Tensor,
|
||||||
|
y: Float[T, '*B o'],
|
||||||
|
) -> Float[T, '*B o']:
|
||||||
|
first_col, second_col = _prepare_cols(road_variant, group_size, road_theta, road_alpha)
|
||||||
|
y_grouped = y.reshape(-1, 2, group_size // 2)
|
||||||
|
y1 = y_grouped[:, 0, :]
|
||||||
|
y2 = y_grouped[:, 1, :]
|
||||||
|
rotate_half_y = torch.stack((-y2, y1), dim=1).reshape(y.shape)
|
||||||
|
return y * first_col + rotate_half_y * second_col
|
||||||
|
|
||||||
|
|
||||||
|
def _road_matrix(
|
||||||
|
road_variant: str,
|
||||||
|
group_size: int,
|
||||||
|
road_theta: torch.Tensor,
|
||||||
|
road_alpha: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Explicit PEFT merge matrix. Used for tests and small-debug inspection."""
|
||||||
|
first_col, second_col = _prepare_cols(road_variant, group_size, road_theta, road_alpha)
|
||||||
|
size = second_col.shape[0]
|
||||||
|
output = torch.diag(first_col)
|
||||||
|
swapped_second_col = second_col.reshape(-1, 2, group_size // 2)[:, [1, 0], :].flatten()
|
||||||
|
rotated_diag_second_col = torch.diag(swapped_second_col).reshape(-1, 2, group_size // 2, size)[:, [1, 0], :, :]
|
||||||
|
rotated_diag_second_col[:, 0, :, :] *= -1
|
||||||
|
return output + rotated_diag_second_col.reshape(size, size)
|
||||||
|
|
||||||
|
|
||||||
|
@register
|
||||||
|
class ROAD:
|
||||||
|
name = "road"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def param_specs(d_in: int, d_out: int, cfg: RoadConfig) -> dict[str, ParamSpec]:
|
||||||
|
_validate_group_geometry(d_out, cfg.group_size)
|
||||||
|
size = _road_param_size(d_out, cfg.road_variant)
|
||||||
|
return {
|
||||||
|
"lora_road_theta": ParamSpec((size,), init="zeros", trainable=True),
|
||||||
|
"lora_road_alpha": ParamSpec((size,), init="ones", trainable=True),
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def init(layer: nn.Module, cfg: RoadConfig) -> None:
|
||||||
|
return
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def forward(
|
||||||
|
layer: nn.Module,
|
||||||
|
x: Float[T, '*B i'],
|
||||||
|
y: Float[T, '*B o'],
|
||||||
|
) -> Float[T, '*B o']:
|
||||||
|
del x
|
||||||
|
cfg = layer._lora_cfg
|
||||||
|
y_cast = y.to(layer.lora_road_theta.dtype)
|
||||||
|
return _apply_road(cfg.road_variant, cfg.group_size, layer.lora_road_theta, layer.lora_road_alpha, y_cast)
|
||||||
@@ -0,0 +1,366 @@
|
|||||||
|
"""Per-variant attach + train + save + load round-trip, plus surgical regressions.
|
||||||
|
|
||||||
|
The big invariant is the parametrized train_save_load test: identity at t=0,
|
||||||
|
gradient flow on a real loss, then save -> reload onto a fresh model and
|
||||||
|
confirm the trained outputs survive the round-trip. Cheap on CPU.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
import lora_lite as ll
|
||||||
|
|
||||||
|
|
||||||
|
CFG_BY_VARIANT = {
|
||||||
|
"lora": ll.LoRAConfig,
|
||||||
|
"pissa": ll.PiSSAConfig,
|
||||||
|
"delora": ll.DeLoRAConfig,
|
||||||
|
"ia3": ll.IA3Config,
|
||||||
|
"ia3_ff": ll.IA3FFConfig,
|
||||||
|
"dora": ll.DoRAConfig,
|
||||||
|
"hra": ll.HRAConfig,
|
||||||
|
"eva": ll.EVAConfig,
|
||||||
|
"antipasto": ll.AntiPaSTOConfig,
|
||||||
|
"road": ll.RoadConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Per-variant identity tolerance at t=0 (after attach, before any step).
|
||||||
|
# fp32 SVD round-trip + per-row norm = looser tolerance for pissa/dora/antipasto.
|
||||||
|
IDENTITY_TOL = {
|
||||||
|
"lora": 1e-6,
|
||||||
|
"pissa": 5e-4,
|
||||||
|
"delora": 1e-6,
|
||||||
|
"ia3": 1e-6,
|
||||||
|
"ia3_ff": 1e-6,
|
||||||
|
"dora": 5e-5,
|
||||||
|
"hra": 5e-6,
|
||||||
|
"eva": 1e-6,
|
||||||
|
"antipasto": 5e-4,
|
||||||
|
"road": 1e-6,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TinyBlock(nn.Module):
|
||||||
|
def __init__(self, d: int = 64, ff: int = 128):
|
||||||
|
super().__init__()
|
||||||
|
self.q_proj = nn.Linear(d, d, bias=False)
|
||||||
|
self.k_proj = nn.Linear(d, d, bias=False)
|
||||||
|
self.v_proj = nn.Linear(d, d, bias=False)
|
||||||
|
self.o_proj = nn.Linear(d, d, bias=False)
|
||||||
|
self.gate_proj = nn.Linear(d, ff, bias=False)
|
||||||
|
self.up_proj = nn.Linear(d, ff, bias=False)
|
||||||
|
self.down_proj = nn.Linear(ff, d, bias=False)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
h = self.o_proj(self.q_proj(x) + self.k_proj(x) + self.v_proj(x))
|
||||||
|
m = self.down_proj(torch.nn.functional.silu(self.gate_proj(x)) * self.up_proj(x))
|
||||||
|
return x + h + m
|
||||||
|
|
||||||
|
|
||||||
|
class TinyModel(nn.Module):
|
||||||
|
def __init__(self, n_layers: int = 4, d: int = 64, ff: int = 128, vocab: int = 100):
|
||||||
|
super().__init__()
|
||||||
|
self.embed_tokens = nn.Embedding(vocab, d)
|
||||||
|
self.layers = nn.ModuleList([TinyBlock(d, ff) for _ in range(n_layers)])
|
||||||
|
self.lm_head = nn.Linear(d, vocab, bias=False)
|
||||||
|
self.config = type("Cfg", (), {"hidden_size": d})()
|
||||||
|
|
||||||
|
def forward(self, ids: torch.Tensor) -> torch.Tensor:
|
||||||
|
x = self.embed_tokens(ids)
|
||||||
|
for block in self.layers:
|
||||||
|
x = block(x)
|
||||||
|
return self.lm_head(x)
|
||||||
|
|
||||||
|
|
||||||
|
class FakeLinearLike(nn.Module):
|
||||||
|
"""linear-like, but not nn.Linear: stand-in for bnb 4/8-bit modules."""
|
||||||
|
|
||||||
|
def __init__(self, d_in: int = 8, d_out: int = 8):
|
||||||
|
super().__init__()
|
||||||
|
self.in_features = d_in
|
||||||
|
self.out_features = d_out
|
||||||
|
self.weight = nn.Parameter(torch.empty(d_out, d_in))
|
||||||
|
nn.init.kaiming_uniform_(self.weight, a=5 ** 0.5)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
return torch.nn.functional.linear(x, self.weight)
|
||||||
|
|
||||||
|
|
||||||
|
class FakeBnbModel(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.config = type("Cfg", (), {"hidden_size": 8})()
|
||||||
|
self.layers = nn.ModuleList([FakeLinearLike(8, 8)])
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
return self.layers[0](x)
|
||||||
|
|
||||||
|
|
||||||
|
def cfg_for(variant: str) -> ll.AdapterConfig:
|
||||||
|
return CFG_BY_VARIANT[variant](
|
||||||
|
r=4,
|
||||||
|
alpha=8,
|
||||||
|
dtype=torch.float32,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def attach_with_calib(model: nn.Module, cfg: ll.AdapterConfig, ids: torch.Tensor) -> None:
|
||||||
|
if cfg.variant == "eva":
|
||||||
|
calib = [ids for _ in range(2)]
|
||||||
|
ll.attach(model, cfg, calibration_data=calib)
|
||||||
|
else:
|
||||||
|
ll.attach(model, cfg)
|
||||||
|
|
||||||
|
|
||||||
|
def trainable_grad_norm(model: nn.Module) -> float:
|
||||||
|
return sum(
|
||||||
|
p.grad.detach().float().norm().item()
|
||||||
|
for n, p in model.named_parameters()
|
||||||
|
if "lora_" in n and p.grad is not None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("variant", list(CFG_BY_VARIANT))
|
||||||
|
def test_train_save_load(variant: str, tmp_path: Path):
|
||||||
|
"""Identity at t=0, one SGD step, save, reload onto fresh model, outputs match."""
|
||||||
|
torch.manual_seed(0)
|
||||||
|
model = TinyModel()
|
||||||
|
ids = torch.randint(0, 100, (2, 16))
|
||||||
|
with torch.no_grad():
|
||||||
|
y_base = model(ids).clone()
|
||||||
|
|
||||||
|
cfg = cfg_for(variant)
|
||||||
|
attach_with_calib(model, cfg, ids)
|
||||||
|
|
||||||
|
trainable = [p for p in model.parameters() if p.requires_grad]
|
||||||
|
assert trainable
|
||||||
|
assert all("lora_" in n for n, p in model.named_parameters() if p.requires_grad)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
y_init = model(ids).clone()
|
||||||
|
assert (y_init - y_base).abs().max().item() < IDENTITY_TOL[variant]
|
||||||
|
|
||||||
|
target = torch.randn_like(y_init) * 0.1
|
||||||
|
opt = torch.optim.SGD(trainable, lr=1e-2)
|
||||||
|
opt.zero_grad()
|
||||||
|
loss = (model(ids) - target).pow(2).mean()
|
||||||
|
loss.backward()
|
||||||
|
leaked = [n for n, p in model.named_parameters() if "lora_" not in n and p.grad is not None]
|
||||||
|
assert leaked == []
|
||||||
|
assert trainable_grad_norm(model) > 0
|
||||||
|
opt.step()
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
y_trained = model(ids).clone()
|
||||||
|
|
||||||
|
path = tmp_path / "adapter.pt"
|
||||||
|
ll.save(model, str(path))
|
||||||
|
|
||||||
|
torch.manual_seed(0)
|
||||||
|
model_loaded = TinyModel()
|
||||||
|
ll.load(model_loaded, str(path)) # EVA load skips group_init; calibration_data not needed
|
||||||
|
with torch.no_grad():
|
||||||
|
y_loaded = model_loaded(ids)
|
||||||
|
assert (y_loaded - y_trained).abs().max().item() < max(IDENTITY_TOL[variant], 1e-5)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("variant", ["lora", "delora", "ia3", "hra", "road"])
|
||||||
|
def test_hook_only_variants_attach_to_non_linear_target(variant: str):
|
||||||
|
"""bnb-style targets are linear-like but not nn.Linear; hook-only variants must accept them."""
|
||||||
|
extra = {"lambda0": 0.1} if variant == "delora" else {"group_size": 8} if variant == "road" else {}
|
||||||
|
cfg = CFG_BY_VARIANT[variant](r=2, alpha=4, dtype=torch.float32, target_roles=(), **extra)
|
||||||
|
model = FakeBnbModel()
|
||||||
|
ll.attach(model, cfg)
|
||||||
|
x = torch.randn(2, 3, 8)
|
||||||
|
model(x).pow(2).mean().backward()
|
||||||
|
assert trainable_grad_norm(model) > 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("variant", ["pissa", "dora", "antipasto"])
|
||||||
|
def test_weight_reading_variants_reject_non_linear(variant: str):
|
||||||
|
r = 4 if variant == "antipasto" else 2 # antipasto needs r % block_size==0
|
||||||
|
cfg = CFG_BY_VARIANT[variant](r=r, alpha=r, dtype=torch.float32, target_roles=())
|
||||||
|
with pytest.raises(TypeError, match="plain nn.Linear"):
|
||||||
|
ll.attach(FakeBnbModel(), cfg)
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_load_strict_keys(tmp_path: Path):
|
||||||
|
import json
|
||||||
|
from safetensors.torch import load_file, save_file
|
||||||
|
|
||||||
|
torch.manual_seed(0)
|
||||||
|
model = TinyModel()
|
||||||
|
ll.attach(model, ll.LoRAConfig(r=4, alpha=8, dtype=torch.float32))
|
||||||
|
p = tmp_path / "lora.safetensors"
|
||||||
|
ll.save(model, str(p))
|
||||||
|
sd = load_file(str(p), device="cpu")
|
||||||
|
|
||||||
|
# missing key: drop first lora key
|
||||||
|
missing_sd = dict(sd)
|
||||||
|
dropped_key = next(iter(missing_sd))
|
||||||
|
del missing_sd[dropped_key]
|
||||||
|
from safetensors import safe_open
|
||||||
|
with safe_open(str(p), framework="pt", device="cpu") as f:
|
||||||
|
meta = f.metadata()
|
||||||
|
save_file(missing_sd, str(p), metadata=meta)
|
||||||
|
with pytest.raises(RuntimeError, match="missing lora keys"):
|
||||||
|
ll.load(TinyModel(), str(p))
|
||||||
|
|
||||||
|
# unexpected key: add a bogus lora key
|
||||||
|
bad_sd = dict(sd)
|
||||||
|
bad_sd["layers.0.q_proj.lora_extra"] = torch.zeros(1)
|
||||||
|
save_file(bad_sd, str(p), metadata=meta)
|
||||||
|
with pytest.raises(RuntimeError, match="unexpected lora keys"):
|
||||||
|
ll.load(TinyModel(), str(p))
|
||||||
|
|
||||||
|
|
||||||
|
def test_no_target_layers_is_loud():
|
||||||
|
cfg = ll.LoRAConfig(target_names=("definitely_missing",))
|
||||||
|
with pytest.raises(RuntimeError, match="no target layers"):
|
||||||
|
ll.attach(TinyModel(), cfg)
|
||||||
|
|
||||||
|
|
||||||
|
def test_eva_requires_calibration():
|
||||||
|
"""EVA's group_init must error loudly if calibration_data is missing."""
|
||||||
|
with pytest.raises(ValueError, match="calibration_data"):
|
||||||
|
ll.attach(TinyModel(), ll.EVAConfig(r=4, alpha=8, dtype=torch.float32))
|
||||||
|
|
||||||
|
|
||||||
|
def test_delora_default_has_live_step0_gradient():
|
||||||
|
"""Default lambda0 must be nonzero; B=0 preserves identity while B gets gradient."""
|
||||||
|
torch.manual_seed(0)
|
||||||
|
model = TinyModel(n_layers=1)
|
||||||
|
ids = torch.randint(0, 100, (2, 8))
|
||||||
|
ll.attach(model, ll.DeLoRAConfig(r=4, alpha=8, dtype=torch.float32))
|
||||||
|
|
||||||
|
assert model.layers[0].q_proj.lora_lambda.item() == pytest.approx(15.0)
|
||||||
|
loss = model(ids).pow(2).mean()
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
b_grad = model.layers[0].q_proj.lora_B.grad.detach().abs().max().item()
|
||||||
|
assert b_grad > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_pissa_identity_with_nonunit_scale():
|
||||||
|
"""Regression: PiSSA must pre-divide S by alpha/r, not require alpha == r."""
|
||||||
|
torch.manual_seed(0)
|
||||||
|
model = TinyModel(n_layers=1)
|
||||||
|
ids = torch.randint(0, 100, (2, 8))
|
||||||
|
with torch.no_grad():
|
||||||
|
y_base = model(ids).clone()
|
||||||
|
|
||||||
|
ll.attach(model, ll.PiSSAConfig(r=4, alpha=8, dtype=torch.float32))
|
||||||
|
with torch.no_grad():
|
||||||
|
y = model(ids)
|
||||||
|
assert (y - y_base).abs().max().item() < IDENTITY_TOL["pissa"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_antipasto_blockwise_rotation_matches_explicit_blockdiag():
|
||||||
|
"""The einsum/rearrange path must equal the old explicit blockdiag math."""
|
||||||
|
from lora_lite.variants.antipasto import _build_rotation
|
||||||
|
|
||||||
|
torch.manual_seed(0)
|
||||||
|
n_blocks, bs, d_in, d_out = 3, 4, 7, 5
|
||||||
|
r = n_blocks * bs
|
||||||
|
rot_T = torch.randn(n_blocks, bs * (bs - 1) // 2) * 0.1
|
||||||
|
Vh = torch.randn(r, d_in)
|
||||||
|
U = torch.randn(d_out, r)
|
||||||
|
R_blocks = _build_rotation(rot_T, bs, 0.5)
|
||||||
|
R = torch.block_diag(*list(R_blocks))
|
||||||
|
|
||||||
|
Vh_blocks = torch.reshape(Vh, (n_blocks, bs, d_in))
|
||||||
|
Vh_rot = torch.einsum("nab,nbi->nai", R_blocks, Vh_blocks).reshape(r, d_in)
|
||||||
|
U_blocks = torch.reshape(U, (d_out, n_blocks, bs))
|
||||||
|
U_rot = torch.einsum("dnb,ncb->dnc", U_blocks, R_blocks).reshape(d_out, r)
|
||||||
|
|
||||||
|
assert (Vh_rot - R @ Vh).abs().max().item() < 1e-6
|
||||||
|
assert (U_rot - U @ R.T).abs().max().item() < 1e-6
|
||||||
|
|
||||||
|
|
||||||
|
def test_dora_bias_passthrough():
|
||||||
|
"""Regression: DoRA must NOT scale bias; identity holds with bias=True at t=0."""
|
||||||
|
torch.manual_seed(0)
|
||||||
|
d = 16
|
||||||
|
layer = nn.Linear(d, d, bias=True)
|
||||||
|
x = torch.randn(2, d)
|
||||||
|
y_base = layer(x).detach()
|
||||||
|
|
||||||
|
class Wrap(nn.Module):
|
||||||
|
def __init__(self, lin):
|
||||||
|
super().__init__()
|
||||||
|
self.config = type("Cfg", (), {"hidden_size": d})()
|
||||||
|
self.layers = nn.ModuleList([lin])
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.layers[0](x)
|
||||||
|
|
||||||
|
model = Wrap(layer)
|
||||||
|
ll.attach(model, ll.DoRAConfig(r=2, alpha=4, dtype=torch.float32, target_roles=()))
|
||||||
|
with torch.no_grad():
|
||||||
|
y = model(x)
|
||||||
|
assert (y - y_base).abs().max().item() < 1e-5
|
||||||
|
|
||||||
|
|
||||||
|
def test_hra_forward_is_x_R_T():
|
||||||
|
"""HRA must apply x @ R^T (loop i = r-1 down to 0). Asymmetric U makes order observable."""
|
||||||
|
torch.manual_seed(0)
|
||||||
|
d = 8
|
||||||
|
layer = nn.Linear(d, d, bias=False)
|
||||||
|
x = torch.randn(2, 3, d)
|
||||||
|
|
||||||
|
class Wrap(nn.Module):
|
||||||
|
def __init__(self, lin):
|
||||||
|
super().__init__()
|
||||||
|
self.config = type("Cfg", (), {"hidden_size": d})()
|
||||||
|
self.layers = nn.ModuleList([lin])
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.layers[0](x)
|
||||||
|
|
||||||
|
model = Wrap(layer)
|
||||||
|
ll.attach(model, ll.HRAConfig(r=4, alpha=4, dtype=torch.float32, target_roles=()))
|
||||||
|
# break paired symmetry so order matters
|
||||||
|
with torch.no_grad():
|
||||||
|
layer.lora_U.add_(0.1 * torch.randn_like(layer.lora_U))
|
||||||
|
|
||||||
|
U = layer.lora_U
|
||||||
|
R = torch.eye(d)
|
||||||
|
for i in range(U.shape[0]):
|
||||||
|
u = U[i]
|
||||||
|
sq = (u * u).sum().clamp_min(1e-12)
|
||||||
|
R = R - (2.0 / sq) * torch.outer(R @ u, u)
|
||||||
|
with torch.no_grad():
|
||||||
|
y_adapt = model(x)
|
||||||
|
y_ref = torch.nn.functional.linear(x, layer.weight @ R)
|
||||||
|
assert (y_adapt - y_ref).abs().max().item() < 1e-5
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("road_variant", ["road_1", "road_2", "road_4"])
|
||||||
|
def test_road_apply_matches_explicit_matrix(road_variant: str):
|
||||||
|
"""Fast elementwise ROAD path must match PEFT's explicit R @ y matrix construction."""
|
||||||
|
from lora_lite.variants.road import _apply_road, _road_matrix, _road_param_size
|
||||||
|
|
||||||
|
torch.manual_seed(0)
|
||||||
|
d_out = 16
|
||||||
|
group_size = 8
|
||||||
|
size = _road_param_size(d_out, road_variant)
|
||||||
|
theta = torch.randn(size) * 0.2
|
||||||
|
alpha = torch.randn(size) * 0.1 + 1.0
|
||||||
|
y = torch.randn(2, 3, d_out)
|
||||||
|
|
||||||
|
y_fast = _apply_road(road_variant, group_size, theta, alpha, y)
|
||||||
|
R = _road_matrix(road_variant, group_size, theta, alpha)
|
||||||
|
y_ref = torch.einsum("oi,...i->...o", R, y)
|
||||||
|
|
||||||
|
assert (y_fast - y_ref).abs().max().item() < 1e-6
|
||||||
|
|
||||||
|
|
||||||
|
def test_road_invalid_group_size_is_loud():
|
||||||
|
with pytest.raises(ValueError, match="positive and even"):
|
||||||
|
ll.attach(TinyModel(), ll.RoadConfig(group_size=7))
|
||||||
|
with pytest.raises(ValueError, match="divisible"):
|
||||||
|
ll.attach(TinyModel(), ll.RoadConfig(group_size=48))
|
||||||
@@ -31,7 +31,7 @@ sys.modules[SPEC.name] = benchmark
|
|||||||
SPEC.loader.exec_module(benchmark)
|
SPEC.loader.exec_module(benchmark)
|
||||||
|
|
||||||
|
|
||||||
VARIANTS = ["lora", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva", "antipasto"]
|
VARIANTS = ["lora", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva", "antipasto", "road"]
|
||||||
# Variants that fail loud when attached on a bnb-loaded base (read dense weight in init).
|
# Variants that fail loud when attached on a bnb-loaded base (read dense weight in init).
|
||||||
# delora/eva also read weight but currently silently dequant -- they produce sane attach,
|
# delora/eva also read weight but currently silently dequant -- they produce sane attach,
|
||||||
# so we don't expect a raise from them in the attach-only smoke.
|
# so we don't expect a raise from them in the attach-only smoke.
|
||||||
@@ -69,6 +69,7 @@ def quick_cfg(variant: str, tmp_path: Path, quantization: str = "none") -> "benc
|
|||||||
max_seq_length=128,
|
max_seq_length=128,
|
||||||
max_new_tokens=8,
|
max_new_tokens=8,
|
||||||
lr=5e-3,
|
lr=5e-3,
|
||||||
|
road_group_size=8,
|
||||||
seed=0,
|
seed=0,
|
||||||
log_examples=0,
|
log_examples=0,
|
||||||
log_every=1000,
|
log_every=1000,
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ resolution-markers = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
[options]
|
[options]
|
||||||
exclude-newer = "2026-04-21T14:06:04.428693663Z"
|
exclude-newer = "2026-04-22T01:31:09.47902161Z"
|
||||||
exclude-newer-span = "P5D"
|
exclude-newer-span = "P5D"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -1012,6 +1012,7 @@ dependencies = [
|
|||||||
{ name = "einops" },
|
{ name = "einops" },
|
||||||
{ name = "jaxtyping", version = "0.3.7", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
|
{ name = "jaxtyping", version = "0.3.7", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
|
||||||
{ name = "jaxtyping", version = "0.3.9", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" },
|
{ name = "jaxtyping", version = "0.3.9", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" },
|
||||||
|
{ name = "safetensors" },
|
||||||
{ name = "torch" },
|
{ name = "torch" },
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -1051,6 +1052,7 @@ requires-dist = [
|
|||||||
{ name = "einops", specifier = ">=0.7" },
|
{ name = "einops", specifier = ">=0.7" },
|
||||||
{ name = "jaxtyping", specifier = ">=0.2.34" },
|
{ name = "jaxtyping", specifier = ">=0.2.34" },
|
||||||
{ name = "pytest", marker = "extra == 'test'" },
|
{ name = "pytest", marker = "extra == 'test'" },
|
||||||
|
{ name = "safetensors", specifier = ">=0.5" },
|
||||||
{ name = "safetensors", marker = "extra == 'benchmark'", specifier = ">=0.5" },
|
{ name = "safetensors", marker = "extra == 'benchmark'", specifier = ">=0.5" },
|
||||||
{ name = "safetensors", marker = "extra == 'hf-test'", specifier = ">=0.5" },
|
{ name = "safetensors", marker = "extra == 'hf-test'", specifier = ">=0.5" },
|
||||||
{ name = "tabulate", marker = "extra == 'benchmark'" },
|
{ name = "tabulate", marker = "extra == 'benchmark'" },
|
||||||
|
|||||||
Reference in New Issue
Block a user