mirror of
https://github.com/wassname/lora-lite.git
synced 2026-06-27 18:24:15 +08:00
tyro and benchmark
This commit is contained in:
@@ -16,7 +16,7 @@ pip install -e git+https://github.com/wassname/lora-lite.git#egg=lora-lite
|
||||
import torch, lora_lite as ll
|
||||
|
||||
model = MyTransformer()
|
||||
cfg = ll.LoraLiteConfig(variant="lora", r=8, alpha=16, dtype=torch.bfloat16)
|
||||
cfg = ll.LoRAConfig(r=8, alpha=16, dtype=torch.bfloat16)
|
||||
ll.attach(model, cfg)
|
||||
|
||||
opt = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=1e-4)
|
||||
@@ -54,7 +54,8 @@ See [docs/spec/20260426_lora_lite_plan.md](docs/spec/20260426_lora_lite_plan.md)
|
||||
|
||||
By default, `lora-lite` targets linear-like modules with `in_features`, `out_features`, and `weight`, excluding `lm_head` and `embed_tokens`.
|
||||
|
||||
Useful `LoraLiteConfig` fields:
|
||||
Useful `AdapterConfig` fields (shared across all variants; subclasses add
|
||||
variant-specific knobs like `lambda0` on `DeLoRAConfig`):
|
||||
|
||||
- `target_roles`: subset of `("reader", "writer", "inner")`; `()` means all.
|
||||
- `target_names`: regex includes.
|
||||
|
||||
@@ -19,13 +19,56 @@ build:
|
||||
uv build
|
||||
uv run --extra build twine check dist/*
|
||||
|
||||
qwen-probe variants="lora pissa delora ia3" steps="8":
|
||||
uv run --extra test --extra hf-test python scripts/qwen_train_probe.py --variants {{variants}} --steps {{steps}}
|
||||
qwen-probe variants="lora pissa delora ia3" steps="5":
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
for variant in {{variants}}; do
|
||||
uv run --extra benchmark python scripts/metamath_gsm8k_benchmark.py \
|
||||
--mode probe \
|
||||
--model Qwen/Qwen3-0.6B-Base \
|
||||
--variant "$variant" \
|
||||
--steps {{steps}} \
|
||||
--batch-size 1 \
|
||||
--batch-size-eval 10 \
|
||||
--max-train-samples 32 \
|
||||
--max-eval-samples 10 \
|
||||
--max-new-tokens 32 \
|
||||
--max-seq-length 384 \
|
||||
--r 4 \
|
||||
--alpha 8 \
|
||||
--layers 0 \
|
||||
--lr 5e-3 \
|
||||
--target-name 'model\.layers\.0\.self_attn\.(q_proj|v_proj)$'
|
||||
done
|
||||
|
||||
qwen-queue variants="lora pissa delora ia3" steps="16":
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
pueue add \
|
||||
-l "why: verify Qwen0.6B train/save-load proof for {{variants}} at {{steps}} steps; resolve: publish only if exact targets, lora-only grads, loss drop, reload identity" \
|
||||
-l "why: verify Qwen0.6B train/save-load proof for {{variants}} at {{steps}} steps via benchmark probe mode; resolve: publish only if exact layer0 q/v targets, lora-only grads, perturb>0, reload<tol" \
|
||||
-w "$PWD" -o 1 -- \
|
||||
uv run --extra test --extra hf-test python scripts/qwen_train_probe.py --variants {{variants}} --steps {{steps}}
|
||||
just qwen-probe "{{variants}}" "{{steps}}"
|
||||
|
||||
metamath-smoke variant="lora" steps="2" max_train_samples="8" max_eval_samples="8" model="hf-internal-testing/tiny-random-LlamaForCausalLM" device="cpu":
|
||||
uv run --extra benchmark python scripts/metamath_gsm8k_benchmark.py \
|
||||
--model {{model}} \
|
||||
--variant {{variant}} \
|
||||
--steps {{steps}} \
|
||||
--batch-size 2 \
|
||||
--max-train-samples {{max_train_samples}} \
|
||||
--max-eval-samples {{max_eval_samples}} \
|
||||
--max-new-tokens 8 \
|
||||
--max-seq-length 128 \
|
||||
--r 2 \
|
||||
--alpha 4 \
|
||||
--layers 0 \
|
||||
--torch-dtype float32 \
|
||||
--device {{device}}
|
||||
|
||||
metamath-queue variant="lora" steps="5000" model="Qwen/Qwen3-0.6B-Base":
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
pueue add \
|
||||
-l "why: HF-style MetaMathQA->GSM8K benchmark for {{model}} {{variant}} at {{steps}} steps; resolve: result JSON under outputs/metamath_gsm8k proves grad>0 dθ>0 base_grad_leaks=0 and reports valid/test accuracy" \
|
||||
-w "$PWD" -o 1 -- \
|
||||
uv run --extra benchmark python scripts/metamath_gsm8k_benchmark.py --model {{model}} --variant {{variant}} --steps {{steps}}
|
||||
@@ -28,6 +28,7 @@ build = ["twine>=6"]
|
||||
test = ["pytest", "tabulate", "beartype>=0.18"]
|
||||
hf-test = ["accelerate>=1.6", "safetensors>=0.5", "transformers>=4.51"]
|
||||
bnb-test = ["bitsandbytes>=0.46"]
|
||||
benchmark = ["accelerate>=1.6", "datasets>=3.6", "safetensors>=0.5", "tabulate", "transformers>=4.51", "tyro>=0.9"]
|
||||
|
||||
[build-system]
|
||||
requires = ["setuptools>=68"]
|
||||
|
||||
@@ -1,214 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import gc
|
||||
import math
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from tabulate import tabulate
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
import lora_lite as ll
|
||||
|
||||
|
||||
# Single fixed prompt used for every forward/backward pass of the probe.
PROMPT = "LoRA-lite probe: Paris is the capital of France. The answer is"

# Fully-qualified module names the adapter must attach to (and nothing else).
EXPECTED_TARGETS = {"model.layers.0.self_attn.q_proj", "model.layers.0.self_attn.v_proj"}
|
||||
|
||||
|
||||
def cfg_for_variant(variant: str, dtype: torch.dtype, r: int, alpha: float) -> ll.LoraLiteConfig:
    """Build the probe's adapter config for ``variant``.

    The config always targets layer 0 and the ``q_proj``/``v_proj`` name regex,
    with no role filter. Two variants get special handling:

    - ``"pissa"`` uses ``alpha == r`` instead of the caller's ``alpha``;
    - ``"delora"`` passes ``lambda0=0.1`` through ``variant_kwargs``.
    """
    if variant == "delora":
        extra_kwargs = {"lambda0": 0.1}
    else:
        extra_kwargs = {}

    effective_alpha = alpha
    if variant == "pissa":
        effective_alpha = r

    return ll.LoraLiteConfig(
        variant=variant,
        r=r,
        alpha=effective_alpha,
        dtype=dtype,
        target_roles=(),
        target_names=(r"model\.layers\.0\.self_attn\.(q_proj|v_proj)$",),
        layers=(0,),
        variant_kwargs=extra_kwargs,
    )
|
||||
|
||||
|
||||
def adapter_state(model: torch.nn.Module) -> dict[str, torch.Tensor]:
    """Return a detached, cloned snapshot of every state-dict entry whose key contains ``"lora_"``."""
    snapshot = {}
    for key, tensor in model.state_dict().items():
        if "lora_" not in key:
            continue
        # Clone so later in-place updates to the model do not alter the snapshot.
        snapshot[key] = tensor.detach().clone()
    return snapshot
|
||||
|
||||
|
||||
def assert_only_lora_trainable(model: torch.nn.Module) -> None:
    """Assert that at least one parameter is trainable and that every trainable one is a ``lora_`` parameter.

    Raises:
        AssertionError: if nothing is trainable, or if a trainable parameter's
            name lacks ``"lora_"`` (the message lists up to 20 trainable names).
    """
    trainable_names = []
    for param_name, param in model.named_parameters():
        if param.requires_grad:
            trainable_names.append(param_name)
    assert trainable_names
    assert not any("lora_" not in param_name for param_name in trainable_names), trainable_names[:20]
|
||||
|
||||
|
||||
def assert_no_base_grads(model: torch.nn.Module) -> None:
    """Assert that no parameter outside the ``lora_`` namespace has a gradient.

    Raises:
        AssertionError: listing up to 20 offending parameter names.
    """
    leaked = []
    for param_name, param in model.named_parameters():
        if param.grad is not None and "lora_" not in param_name:
            leaked.append(param_name)
    assert not leaked, leaked[:20]
|
||||
|
||||
|
||||
def perturb_first_adapter(model: torch.nn.Module) -> None:
    """Nudge one trainable adapter parameter so forward output changes.

    Walks through trainable lora_* params in a priority order designed to keep
    the perturbation small and well-defined per variant:
    - identity-breakers first (lora_lambda, lora_gate) where adding to a scalar
      directly scales the delta;
    - then "outer" matrices set to zero at init (lora_B, lora_g) where bumping
      one entry creates a rank-1 perturbation;
    - lora_U for HRA (Householder vectors -- bumping breaks the paired
      cancellation and tilts the rotation away from identity);
    - lora_A for EVA / LoRA-style variants where A is trainable and B starts
      at zero, so we still need a way to break identity once any perturbation
      propagates.

    Exactly one parameter is modified: 0.25 is added in place to a 0-dim
    parameter, or to the first element of a higher-rank one. Frozen and empty
    parameters are skipped.

    Raises:
        AssertionError: if no trainable, non-empty parameter matches any key.
    """
    priority = ("lora_lambda", "lora_gate", "lora_B", "lora_g", "lora_U", "lora_A")
    # Collect candidates once; empty tensors have no element to bump.
    candidates = [
        (name, p)
        for name, p in model.named_parameters()
        if p.requires_grad and p.numel() > 0
    ]
    for key in priority:
        for name, p in candidates:
            if key not in name:
                continue
            with torch.no_grad():
                if p.ndim == 0:
                    p.add_(0.25)
                else:
                    # Index the first element directly. ``p.flatten()`` returns
                    # a *copy* for non-contiguous parameters, so an in-place
                    # add on it would silently leave ``p`` unchanged.
                    p[(0,) * p.ndim].add_(0.25)
            return
    raise AssertionError("no perturbable adapter parameter found")
|
||||
|
||||
|
||||
def load_model(model_id: str, dtype: torch.dtype, device: str):
    """Load a causal LM with ``from_pretrained``, move it to ``device`` and turn off ``config.use_cache``."""
    loaded = AutoModelForCausalLM.from_pretrained(model_id, dtype=dtype)
    loaded = loaded.to(device)
    loaded.config.use_cache = False
    return loaded
|
||||
|
||||
|
||||
def run_variant(args, variant: str, input_ids: torch.Tensor, labels: torch.Tensor, dtype: torch.dtype):
    """Run the attach -> perturb -> train -> save -> reload probe for one variant.

    Stages, each guarded by assertions:
    1. attach the adapter and check the attached target set and trainable params;
    2. perturb one adapter parameter, require the logits to move, then restore;
    3. train ``args.steps`` AdamW steps on the single prompt and require a
       positive first-step gradient, a changed adapter and a lower final loss;
    4. save the adapter, reload it into a fresh base model and require the
       stored tensors to match exactly and the logits to agree within
       ``args.reload_tol``.

    Returns:
        A dict of summary metrics for this variant (one row of the final table).
    """
    net = load_model(args.model, dtype, args.device)
    net.train()
    cfg = cfg_for_variant(variant, dtype, args.r, args.alpha)

    # Reference logits before any adapter is attached.
    with torch.no_grad():
        logits_base = net(input_ids=input_ids).logits.detach().clone()

    ll.attach(net, cfg)
    attached_targets = set(getattr(net, "_lora_lite_attached")["targets"])
    assert attached_targets == EXPECTED_TARGETS, attached_targets
    assert_only_lora_trainable(net)

    # How far the freshly attached adapter moves the output.
    with torch.no_grad():
        logits_init = net(input_ids=input_ids).logits.detach().clone()
    identity_err = (logits_init - logits_base).abs().max().item()

    # Perturb one adapter entry, confirm the output reacts, then restore.
    clean_adapter = adapter_state(net)
    perturb_first_adapter(net)
    with torch.no_grad():
        logits_perturbed = net(input_ids=input_ids).logits
        perturb_delta = (logits_perturbed - logits_init).abs().max().item()
    assert perturb_delta > 1e-7, perturb_delta
    live_state = net.state_dict()
    for name, value in clean_adapter.items():
        live_state[name].copy_(value)

    trainable_params = [p for p in net.parameters() if p.requires_grad]
    opt = torch.optim.AdamW(trainable_params, lr=args.lr)
    with torch.no_grad():
        loss0 = net(input_ids=input_ids, labels=labels).loss.item()

    before_train = adapter_state(net)
    first_grad_norm = math.nan
    loss_last = math.nan
    for step in range(args.steps):
        opt.zero_grad()
        loss = net(input_ids=input_ids, labels=labels).loss
        loss.backward()
        assert_no_base_grads(net)
        # Sum of per-parameter gradient norms over the adapter parameters.
        grad_norm = 0
        for name, p in net.named_parameters():
            if "lora_" in name and p.grad is not None:
                grad_norm += p.grad.detach().float().norm().item()
        assert math.isfinite(grad_norm), grad_norm
        if step == 0:
            first_grad_norm = grad_norm
        opt.step()
        loss_last = loss.item()

    after_train = adapter_state(net)
    adapter_delta = 0
    for k in before_train:
        adapter_delta += (after_train[k] - before_train[k]).float().norm().item()
    assert first_grad_norm > 0, first_grad_norm
    assert adapter_delta > 0, adapter_delta
    assert loss_last < loss0, (loss0, loss_last)

    net.eval()
    with torch.no_grad():
        logits_trained = net(input_ids=input_ids).logits.detach().clone()

    out_path = args.out_dir / f"{variant}_adapter.pt"
    ll.save(net, str(out_path))
    saved = torch.load(out_path, weights_only=True, map_location="cpu")
    assert set(saved["state"]) == set(after_train)

    # Free the trained model before loading a second copy.
    del net
    gc.collect()
    torch.cuda.empty_cache()

    loaded_model = load_model(args.model, dtype, args.device)
    loaded_model.eval()
    ll.load(loaded_model, str(out_path))
    loaded_state = adapter_state(loaded_model)
    for name, value in saved["state"].items():
        assert torch.equal(loaded_state[name].cpu(), value)
    with torch.no_grad():
        logits_loaded = loaded_model(input_ids=input_ids).logits.detach().clone()
    reload_err = (logits_loaded - logits_trained).abs().max().item()
    assert reload_err < args.reload_tol, reload_err

    del loaded_model
    gc.collect()
    torch.cuda.empty_cache()

    trainable_count = sum(v.numel() for v in after_train.values())
    drop_percent = 100 * (loss0 - loss_last) / loss0
    return {
        "variant": variant,
        "targets": len(attached_targets),
        "trainable": trainable_count,
        "id_err": identity_err,
        "perturb": perturb_delta,
        "loss0": loss0,
        "lossN": loss_last,
        "drop%": drop_percent,
        "grad": first_grad_norm,
        "dθ": adapter_delta,
        "reload": reload_err,
        "out": str(out_path),
    }
|
||||
|
||||
|
||||
def main() -> None:
    """CLI entry point: probe every requested variant and print a TSV summary.

    Parses the command line, validates the device and dtype, tokenizes the
    fixed ``PROMPT`` (labels are a copy of the input ids), runs
    ``run_variant`` for each entry of ``--variants`` and prints the table.

    Raises:
        RuntimeError: if a CUDA device is requested but CUDA is unavailable.
        SystemExit: via ``parser.error`` if ``--torch-dtype`` is not a torch dtype.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="Qwen/Qwen3-0.6B")
    parser.add_argument("--variants", nargs="+", default=["lora", "pissa", "delora", "ia3", "dora", "hra"])
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--torch-dtype", default="bfloat16")
    parser.add_argument("--steps", type=int, default=8)
    parser.add_argument("--lr", type=float, default=5e-3)
    parser.add_argument("--r", type=int, default=4)
    parser.add_argument("--alpha", type=float, default=8.0)
    parser.add_argument("--reload-tol", type=float, default=2e-2)
    parser.add_argument("--out-dir", type=Path, default=Path("outputs/qwen_train_probe"))
    args = parser.parse_args()

    # Match indexed devices too ("cuda:0"), not only the bare "cuda" string.
    if args.device.startswith("cuda") and not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for the default Qwen probe. Pass --device cpu explicitly for local debugging.")

    # Reject names that are missing from torch or are not dtypes (e.g. "nn").
    dtype = getattr(torch, args.torch_dtype, None)
    if not isinstance(dtype, torch.dtype):
        parser.error(f"--torch-dtype must name a torch dtype (e.g. bfloat16, float32), got {args.torch_dtype!r}")

    args.out_dir.mkdir(parents=True, exist_ok=True)
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    input_ids = tokenizer(PROMPT, return_tensors="pt").input_ids.to(args.device)
    labels = input_ids.clone()

    print("SHOULD: exact q_proj/v_proj layer-0 targets, lora-only grads, lossN<loss0, perturb>0, reload<tol. ELSE hook/target/train/save bug.")
    rows = [run_variant(args, variant, input_ids, labels, dtype) for variant in args.variants]
    print(tabulate(rows, headers="keys", tablefmt="tsv", floatfmt=".4g"))
    print("ALL QWEN PROBES PASS")


if __name__ == "__main__":
    main()
|
||||
@@ -6,13 +6,32 @@ if _os.environ.get("BEARTYPE"):
|
||||
from beartype.claw import beartype_this_package as _bt
|
||||
_bt()
|
||||
|
||||
from .config import LoraLiteConfig
|
||||
from .config import AdapterConfig
|
||||
from .adapter import attach, detach, save, load
|
||||
from .variant import REGISTRY, register, ParamSpec, Variant
|
||||
from . import variants # noqa: F401 triggers variant registration
|
||||
from . import variants # noqa: F401 triggers variant + config registration
|
||||
|
||||
# Expose per-variant config classes for ergonomic typed construction.
|
||||
from .variants.lora import LoRAConfig
|
||||
from .variants.pissa import PiSSAConfig
|
||||
from .variants.delora import DeLoRAConfig
|
||||
from .variants.ia3 import IA3Config, IA3FFConfig
|
||||
from .variants.dora import DoRAConfig
|
||||
from .variants.hra import HRAConfig
|
||||
from .variants.eva import EVAConfig
|
||||
from .variants.antipasto import AntiPaSTOConfig
|
||||
|
||||
__all__ = [
|
||||
"LoraLiteConfig",
|
||||
"AdapterConfig",
|
||||
"LoRAConfig",
|
||||
"PiSSAConfig",
|
||||
"DeLoRAConfig",
|
||||
"IA3Config",
|
||||
"IA3FFConfig",
|
||||
"DoRAConfig",
|
||||
"HRAConfig",
|
||||
"EVAConfig",
|
||||
"AntiPaSTOConfig",
|
||||
"attach",
|
||||
"detach",
|
||||
"save",
|
||||
|
||||
@@ -4,7 +4,7 @@ import torch
|
||||
from torch import nn
|
||||
from torch.utils.hooks import RemovableHandle
|
||||
|
||||
from .config import LoraLiteConfig
|
||||
from .config import AdapterConfig
|
||||
from .variant import REGISTRY
|
||||
from .target import find_targets
|
||||
|
||||
@@ -14,7 +14,7 @@ _ATTACHED_ATTR = "_lora_lite_attached"
|
||||
|
||||
def _hook(layer, args, y):
|
||||
(x,) = args
|
||||
cfg: LoraLiteConfig = layer._lora_cfg
|
||||
cfg: AdapterConfig = layer._lora_cfg
|
||||
x_cast = x.to(cfg.dtype)
|
||||
out = layer._lora_variant.forward(layer, x_cast, y)
|
||||
return out.to(y.dtype)
|
||||
@@ -22,13 +22,13 @@ def _hook(layer, args, y):
|
||||
|
||||
def _pre_hook(layer, args):
|
||||
(x,) = args
|
||||
cfg: LoraLiteConfig = layer._lora_cfg
|
||||
cfg: AdapterConfig = layer._lora_cfg
|
||||
x_cast = x.to(cfg.dtype)
|
||||
x_new = layer._lora_variant.forward_input(layer, x_cast)
|
||||
return (x_new.to(x.dtype),)
|
||||
|
||||
|
||||
def attach(model: nn.Module, cfg: LoraLiteConfig, calibration_data=None, *, _skip_group_init: bool = False) -> list[RemovableHandle]:
|
||||
def attach(model: nn.Module, cfg: AdapterConfig, calibration_data=None, *, _skip_group_init: bool = False) -> list[RemovableHandle]:
|
||||
if cfg.variant not in REGISTRY:
|
||||
raise KeyError(f"unknown variant {cfg.variant!r}; registered: {list(REGISTRY)}")
|
||||
variant = REGISTRY[cfg.variant]
|
||||
@@ -131,7 +131,7 @@ def save(model: nn.Module, path: str) -> None:
|
||||
|
||||
def load(model: nn.Module, path: str) -> list[RemovableHandle]:
|
||||
blob = torch.load(path, weights_only=True, map_location="cpu")
|
||||
cfg = LoraLiteConfig.from_dict(blob["cfg"])
|
||||
cfg = AdapterConfig.from_dict(blob["cfg"])
|
||||
handles = attach(model, cfg, _skip_group_init=True) # creates empty params; data-driven inits restored from state_dict
|
||||
missing, unexpected = model.load_state_dict(blob["state"], strict=False)
|
||||
expected_lora = {k for k in model.state_dict() if "lora_" in k}
|
||||
|
||||
+44
-11
@@ -1,13 +1,35 @@
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from typing import Any, Literal
|
||||
"""AdapterConfig: per-variant typed dataclass.
|
||||
|
||||
Replaces the older `LoraLiteConfig` + `variant_kwargs` dict. Each variant
|
||||
ships its own subclass under `variants/*.py` (e.g. `DeLoRAConfig`), adding
|
||||
strongly-typed fields so users discover the knobs via IDE / dataclass
|
||||
introspection instead of stringly-typed dict lookups.
|
||||
|
||||
Wire-up:
|
||||
- `AdapterConfig` holds the universal fields (variant name, rank, alpha,
|
||||
dtype, targeting filters).
|
||||
- Subclasses override the `variant` default and add new fields.
|
||||
- `register_config(cls)` adds the subclass to `_CONFIG_REGISTRY` so
|
||||
`from_dict` can route to the right type at load time.
|
||||
|
||||
Save format:
|
||||
to_dict() emits a flat dict including `variant`; from_dict() uses that
|
||||
field to look up the right subclass.
|
||||
"""
|
||||
from dataclasses import dataclass, asdict
|
||||
from typing import Literal
|
||||
import torch
|
||||
|
||||
Role = Literal["reader", "writer", "inner"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoraLiteConfig:
|
||||
variant: str = "lora"
|
||||
class AdapterConfig:
|
||||
"""Base config. Subclass per variant; do not instantiate directly."""
|
||||
# variant name (subclass overrides default)
|
||||
variant: str = "?"
|
||||
|
||||
# rank-style hyperparams shared across most variants
|
||||
r: int = 8
|
||||
alpha: float | int = 16.0
|
||||
dtype: torch.dtype = torch.bfloat16
|
||||
@@ -18,19 +40,30 @@ class LoraLiteConfig:
|
||||
exclude_names: tuple[str, ...] = ("lm_head", "embed_tokens")
|
||||
layers: tuple[int, ...] | None = None
|
||||
|
||||
# variant-specific bag (e.g. lambda0 for DeLoRA)
|
||||
variant_kwargs: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
d = asdict(self)
|
||||
d["dtype"] = str(self.dtype).removeprefix("torch.")
|
||||
return d
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d: dict) -> "LoraLiteConfig":
|
||||
# to_dict always serializes dtype as str; torch.save preserves tuples.
|
||||
# If you build the dict by hand, pass the right types -- fail loud otherwise.
|
||||
def from_dict(cls, d: dict) -> "AdapterConfig":
|
||||
d = dict(d)
|
||||
name = d["variant"]
|
||||
sub = _CONFIG_REGISTRY[name]
|
||||
d["dtype"] = getattr(torch, d["dtype"])
|
||||
return cls(**d)
|
||||
return sub(**d)
|
||||
|
||||
|
||||
# Registry of variant_name -> config subclass. Populated by `register_config`
|
||||
# decorators in each `variants/*.py` module at import time.
|
||||
_CONFIG_REGISTRY: dict[str, type[AdapterConfig]] = {}
|
||||
|
||||
|
||||
def register_config(cls: type[AdapterConfig]) -> type[AdapterConfig]:
    """Class decorator: file ``cls`` in ``_CONFIG_REGISTRY`` under the default of its ``variant`` field.

    The key is read from ``cls.__dataclass_fields__``.

    Returns:
        ``cls`` itself, so the decorator can be stacked.

    Raises:
        ValueError: if a config is already registered under that variant name.
    """
    variant_field = cls.__dataclass_fields__["variant"]
    name = variant_field.default
    if name not in _CONFIG_REGISTRY:
        _CONFIG_REGISTRY[name] = cls
        return cls
    raise ValueError(f"config for variant {name!r} already registered")
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import Callable, Protocol, Any
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from .config import LoraLiteConfig
|
||||
from .config import AdapterConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -44,10 +44,10 @@ class Variant(Protocol):
|
||||
name: str
|
||||
|
||||
@staticmethod
|
||||
def param_specs(d_in: int, d_out: int, cfg: LoraLiteConfig) -> dict[str, ParamSpec]: ...
|
||||
def param_specs(d_in: int, d_out: int, cfg: AdapterConfig) -> dict[str, ParamSpec]: ...
|
||||
|
||||
@staticmethod
|
||||
def init(layer: nn.Linear, cfg: LoraLiteConfig) -> None: ...
|
||||
def init(layer: nn.Linear, cfg: AdapterConfig) -> None: ...
|
||||
|
||||
@staticmethod
|
||||
def forward(layer: nn.Linear, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
@@ -29,13 +29,15 @@ precision (fp32 SVD round-tripped to cfg.dtype).
|
||||
WHICH BASIS IS ROTATED:
|
||||
By default we rotate Vh (the INPUT singular basis). This is what AntiPaSTO3
|
||||
calls `rotate_V=True` in adapter terms (V == Vh.T columns). To rotate U
|
||||
(output basis) instead, pass variant_kwargs={'rotate_basis': 'U'}.
|
||||
(output basis) instead, pass `rotate_basis='U'` on the AntiPaSTOConfig.
|
||||
Rotating both is not implemented (one rotation is enough to span the
|
||||
identifiable steering directions; two is degenerate).
|
||||
|
||||
REQUIRES even rank divisible by `block_size` (default 4). r=8, bs=4 -> 2 blocks.
|
||||
"""
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
import torch
|
||||
from einops import einsum
|
||||
@@ -43,6 +45,19 @@ from jaxtyping import Float
|
||||
from torch import nn, Tensor as T
|
||||
|
||||
from ..variant import register, ParamSpec
|
||||
from ..config import AdapterConfig, register_config
|
||||
|
||||
|
||||
@register_config
|
||||
@dataclass
|
||||
class AntiPaSTOConfig(AdapterConfig):
|
||||
variant: str = "antipasto"
|
||||
# Block size for the block-diagonal Cayley rotation. r must be divisible by it.
|
||||
block_size: int = 4
|
||||
# Cayley map saturation: bounds rotation angle to ~max_rotation_angle radians.
|
||||
max_rotation_angle: float = 0.5
|
||||
# Which singular basis to rotate: 'V' (input) or 'U' (output).
|
||||
rotate_basis: Literal["V", "U"] = "V"
|
||||
|
||||
|
||||
def _cayley(skew: torch.Tensor) -> torch.Tensor:
|
||||
@@ -81,7 +96,7 @@ class AntiPaSTO:
|
||||
@staticmethod
|
||||
def param_specs(d_in, d_out, cfg):
|
||||
r = cfg.r
|
||||
bs = int(cfg.variant_kwargs.get("block_size", 4))
|
||||
bs = int(cfg.block_size)
|
||||
if r % bs != 0:
|
||||
raise ValueError(f"AntiPaSTO requires r={r} divisible by block_size={bs}")
|
||||
n_blocks = r // bs
|
||||
@@ -123,9 +138,9 @@ class AntiPaSTO:
|
||||
y: Float[T, '*B o'],
|
||||
) -> Float[T, '*B o']:
|
||||
cfg = layer._lora_cfg
|
||||
bs = int(cfg.variant_kwargs.get("block_size", 4))
|
||||
max_angle = float(cfg.variant_kwargs.get("max_rotation_angle", 0.5))
|
||||
rotate_basis = cfg.variant_kwargs.get("rotate_basis", "V")
|
||||
bs = int(cfg.block_size)
|
||||
max_angle = float(cfg.max_rotation_angle)
|
||||
rotate_basis = cfg.rotate_basis
|
||||
|
||||
U = layer.lora_U.to(x.dtype) # (d_out, r)
|
||||
S = layer.lora_S.to(x.dtype) # (r,)
|
||||
|
||||
@@ -35,8 +35,19 @@ import torch
|
||||
from einops import einsum
|
||||
from jaxtyping import Float
|
||||
from torch import nn, Tensor as T
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ..variant import register, ParamSpec
|
||||
from ..config import AdapterConfig, register_config
|
||||
|
||||
|
||||
@register_config
|
||||
@dataclass
|
||||
class DeLoRAConfig(AdapterConfig):
|
||||
variant: str = "delora"
|
||||
# Initial scale for the per-layer learnable lambda. peft default is 15.0;
|
||||
# we default to 0.0 (identity at t=0 even before B is zero-initialized).
|
||||
lambda0: float = 0.0
|
||||
|
||||
|
||||
@register
|
||||
@@ -45,7 +56,7 @@ class DeLoRA:
|
||||
|
||||
@staticmethod
|
||||
def param_specs(d_in, d_out, cfg):
|
||||
lam0 = float(cfg.variant_kwargs.get("lambda0", 0.0))
|
||||
lam0 = float(cfg.lambda0)
|
||||
return {
|
||||
# peft DeLoRA default: A=kaiming, B=zeros (docs/refs/peft_delora_layer.py:138-140).
|
||||
# Identity at t=0 from B=0 -> delta=0 regardless of lambda. With B=0 the
|
||||
|
||||
@@ -22,8 +22,16 @@ import torch
|
||||
from einops import einsum
|
||||
from jaxtyping import Float
|
||||
from torch import nn, Tensor as T
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ..variant import register, ParamSpec
|
||||
from ..config import AdapterConfig, register_config
|
||||
|
||||
|
||||
@register_config
|
||||
@dataclass
|
||||
class DoRAConfig(AdapterConfig):
|
||||
variant: str = "dora"
|
||||
|
||||
|
||||
@register
|
||||
|
||||
@@ -36,13 +36,21 @@ from einops import einsum
|
||||
from jaxtyping import Float
|
||||
from torch import nn, Tensor as T
|
||||
from typing import Iterable
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ..variant import register, ParamSpec
|
||||
from ..config import AdapterConfig, register_config
|
||||
|
||||
CalibrationBatch = dict | tuple | list | T
|
||||
CalibrationData = Iterable[CalibrationBatch]
|
||||
|
||||
|
||||
@register_config
|
||||
@dataclass
|
||||
class EVAConfig(AdapterConfig):
|
||||
variant: str = "eva"
|
||||
|
||||
|
||||
@register
|
||||
class EVA:
|
||||
name = "eva"
|
||||
|
||||
@@ -32,8 +32,16 @@ import torch
|
||||
from einops import einsum
|
||||
from jaxtyping import Float
|
||||
from torch import nn, Tensor as T
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ..variant import register, ParamSpec
|
||||
from ..config import AdapterConfig, register_config
|
||||
|
||||
|
||||
@register_config
|
||||
@dataclass
|
||||
class HRAConfig(AdapterConfig):
|
||||
variant: str = "hra"
|
||||
|
||||
|
||||
@register
|
||||
|
||||
@@ -16,8 +16,8 @@ In both cases g is initialized to 1 -> identity at t=0.
|
||||
To match the paper exactly on a Llama/Qwen-style block requires TWO attach
|
||||
passes (one per variant), since each variant uses one hook type:
|
||||
|
||||
cfg_attn = LoraLiteConfig(variant="ia3", target_names=(r"\\.k_proj$", r"\\.v_proj$"))
|
||||
cfg_ffn = LoraLiteConfig(variant="ia3_ff", target_names=(r"\\.down_proj$",))
|
||||
cfg_attn = IA3Config( target_names=(r"\\.k_proj$", r"\\.v_proj$"))
|
||||
cfg_ffn = IA3FFConfig( target_names=(r"\\.down_proj$",))
|
||||
|
||||
Reference implementation:
|
||||
- peft IA3 layer (is_feedforward toggles input-vs-output gating, see
|
||||
@@ -27,8 +27,22 @@ Reference implementation:
|
||||
import torch
|
||||
from jaxtyping import Float
|
||||
from torch import nn, Tensor as T
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ..variant import register, ParamSpec
|
||||
from ..config import AdapterConfig, register_config
|
||||
|
||||
|
||||
@register_config
|
||||
@dataclass
|
||||
class IA3Config(AdapterConfig):
|
||||
variant: str = "ia3"
|
||||
|
||||
|
||||
@register_config
|
||||
@dataclass
|
||||
class IA3FFConfig(AdapterConfig):
|
||||
variant: str = "ia3_ff"
|
||||
|
||||
|
||||
@register
|
||||
|
||||
@@ -13,8 +13,16 @@ from einops import einsum
|
||||
from jaxtyping import Float
|
||||
from torch import nn, Tensor as T
|
||||
import torch
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ..variant import register, ParamSpec
|
||||
from ..config import AdapterConfig, register_config
|
||||
|
||||
|
||||
@register_config
|
||||
@dataclass
|
||||
class LoRAConfig(AdapterConfig):
|
||||
variant: str = "lora"
|
||||
|
||||
|
||||
@register
|
||||
|
||||
@@ -24,8 +24,16 @@ import torch
|
||||
from einops import einsum
|
||||
from jaxtyping import Float
|
||||
from torch import nn, Tensor as T
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ..variant import register, ParamSpec
|
||||
from ..config import AdapterConfig, register_config
|
||||
|
||||
|
||||
@register_config
|
||||
@dataclass
|
||||
class PiSSAConfig(AdapterConfig):
|
||||
variant: str = "pissa"
|
||||
|
||||
|
||||
@register
|
||||
|
||||
+30
-15
@@ -95,6 +95,19 @@ class FakeBnbModel(nn.Module):
|
||||
return self.layers[0](x)
|
||||
|
||||
|
||||
_CFG_BY_VARIANT = {
|
||||
"lora": ll.LoRAConfig,
|
||||
"pissa": ll.PiSSAConfig,
|
||||
"delora": ll.DeLoRAConfig,
|
||||
"ia3": ll.IA3Config,
|
||||
"ia3_ff": ll.IA3FFConfig,
|
||||
"dora": ll.DoRAConfig,
|
||||
"hra": ll.HRAConfig,
|
||||
"eva": ll.EVAConfig,
|
||||
"antipasto": ll.AntiPaSTOConfig,
|
||||
}
|
||||
|
||||
|
||||
def variant_test(variant: str, dtype=torch.float32):
|
||||
print(f"\n=== variant={variant} dtype={dtype} ===")
|
||||
torch.manual_seed(0)
|
||||
@@ -104,13 +117,14 @@ def variant_test(variant: str, dtype=torch.float32):
|
||||
with torch.no_grad():
|
||||
y_base = model(ids).clone()
|
||||
|
||||
cfg = ll.LoraLiteConfig(
|
||||
variant=variant,
|
||||
cfg_cls = _CFG_BY_VARIANT[variant]
|
||||
extra = {"lambda0": 15.0} if variant == "delora" else {}
|
||||
cfg = cfg_cls(
|
||||
r=4,
|
||||
alpha=4 if variant == "pissa" else 8, # PiSSA needs scale==1 for clean recon
|
||||
dtype=dtype,
|
||||
# delora identity holds via B=0 init (peft semantics); use peft default lambda0=15.
|
||||
variant_kwargs={"lambda0": 15.0} if variant == "delora" else {},
|
||||
**extra,
|
||||
)
|
||||
handles = ll.attach(model, cfg)
|
||||
n_targets = len(handles)
|
||||
@@ -164,9 +178,8 @@ def variant_test(variant: str, dtype=torch.float32):
|
||||
model = TinyModel().to(dtype)
|
||||
train_cfg = cfg
|
||||
if variant == "delora":
|
||||
train_cfg = ll.LoraLiteConfig(
|
||||
variant=cfg.variant, r=cfg.r, alpha=cfg.alpha, dtype=cfg.dtype,
|
||||
variant_kwargs={"lambda0": 0.1},
|
||||
train_cfg = ll.DeLoRAConfig(
|
||||
r=cfg.r, alpha=cfg.alpha, dtype=cfg.dtype, lambda0=0.1,
|
||||
)
|
||||
ll.attach(model, train_cfg)
|
||||
target = torch.randn(2, 16, 100, dtype=dtype) * 0.1
|
||||
@@ -200,7 +213,7 @@ def structural_linear_like_test():
|
||||
model = FakeBnbModel()
|
||||
x = torch.randn(2, 3, 8)
|
||||
y_base = model(x).detach()
|
||||
ll.attach(model, ll.LoraLiteConfig(variant="lora", r=2, alpha=4, dtype=torch.float32, target_roles=()))
|
||||
ll.attach(model, ll.LoRAConfig(r=2, alpha=4, dtype=torch.float32, target_roles=()))
|
||||
layer = model.layers[0]
|
||||
assert hasattr(layer, "lora_A") and hasattr(layer, "lora_B")
|
||||
y = model(x)
|
||||
@@ -262,10 +275,12 @@ def bitsandbytes_cuda_smoke(require_bnb: bool):
|
||||
model = BnbModel(layer_cls)
|
||||
x = torch.randn(2, 3, 8, device="cuda")
|
||||
y_base = model(x).detach()
|
||||
cfg = ll.LoraLiteConfig(
|
||||
variant=variant, r=2, alpha=4, dtype=torch.float16, target_roles=(),
|
||||
# In fp16 + bnb, peft default lambda0=15 + B=0 + clamp(min=1e-4) gives\n # scale=lambda/(r*1e-4) ~ 75000 > fp16 max -> inf*0 = NaN. Use small\n # lambda0 for the fp16 test.\n variant_kwargs={"lambda0": 0.1} if variant == "delora" else {},
|
||||
)
|
||||
cfg_cls = _CFG_BY_VARIANT[variant]
|
||||
extra = {"lambda0": 0.1} if variant == "delora" else {}
|
||||
# In fp16 + bnb, peft default lambda0=15 + B=0 + clamp(min=1e-4) gives
|
||||
# scale=lambda/(r*1e-4) ~ 75000 > fp16 max -> inf*0 = NaN. Use small
|
||||
# lambda0 for the fp16 test.
|
||||
cfg = cfg_cls(r=2, alpha=4, dtype=torch.float16, target_roles=(), **extra)
|
||||
ll.attach(model, cfg)
|
||||
y = model(x)
|
||||
err = (y.detach() - y_base).abs().max().item()
|
||||
@@ -281,7 +296,7 @@ def bitsandbytes_cuda_smoke(require_bnb: bool):
|
||||
|
||||
for variant in bnb_fail:
|
||||
model = BnbModel(layer_cls)
|
||||
cfg = ll.LoraLiteConfig(variant=variant, r=2, alpha=2, dtype=torch.float16, target_roles=())
|
||||
cfg = _CFG_BY_VARIANT[variant](r=2, alpha=2, dtype=torch.float16, target_roles=())
|
||||
try:
|
||||
ll.attach(model, cfg)
|
||||
except (TypeError, RuntimeError, AttributeError, ValueError) as e:
|
||||
@@ -300,7 +315,7 @@ def eva_smoke():
|
||||
with torch.no_grad():
|
||||
y_base = model(ids).clone()
|
||||
|
||||
cfg = ll.LoraLiteConfig(variant="eva", r=4, alpha=8, dtype=torch.float32)
|
||||
cfg = ll.EVAConfig(r=4, alpha=8, dtype=torch.float32)
|
||||
# 4 calibration batches of random ids
|
||||
calib = [torch.randint(0, 100, (2, 16)) for _ in range(4)]
|
||||
ll.attach(model, cfg, calibration_data=calib)
|
||||
@@ -380,7 +395,7 @@ def dora_bias_smoke():
|
||||
return self.layers[0](x)
|
||||
|
||||
model = Wrap(layer)
|
||||
cfg = ll.LoraLiteConfig(variant="dora", r=2, alpha=4, dtype=torch.float32, target_roles=())
|
||||
cfg = ll.DoRAConfig(r=2, alpha=4, dtype=torch.float32, target_roles=())
|
||||
ll.attach(model, cfg)
|
||||
with torch.no_grad():
|
||||
y_adapt = model(x)
|
||||
@@ -404,7 +419,7 @@ def hra_forward_order_smoke():
|
||||
layer = nn.Linear(d, d, bias=False)
|
||||
x = torch.randn(2, 3, d)
|
||||
|
||||
cfg = ll.LoraLiteConfig(variant="hra", r=4, alpha=4, dtype=torch.float32, target_roles=())
|
||||
cfg = ll.HRAConfig(r=4, alpha=4, dtype=torch.float32, target_roles=())
|
||||
class Wrap(nn.Module):
|
||||
def __init__(self_, lin):
|
||||
super().__init__()
|
||||
|
||||
+40
-28
@@ -67,13 +67,28 @@ class FakeBnbModel(nn.Module):
|
||||
return self.layers[0](x)
|
||||
|
||||
|
||||
# Maps each adapter variant name used by the tests to its typed config class.
_CFG_BY_VARIANT = {
    "lora": ll.LoRAConfig,
    "pissa": ll.PiSSAConfig,
    "delora": ll.DeLoRAConfig,
    "ia3": ll.IA3Config,
    "ia3_ff": ll.IA3FFConfig,
    "dora": ll.DoRAConfig,
    "hra": ll.HRAConfig,
    "eva": ll.EVAConfig,
    "antipasto": ll.AntiPaSTOConfig,
}


def cfg_for_variant(variant: str, *, training: bool = False) -> ll.AdapterConfig:
    """Build the small float32 test config for ``variant``.

    Args:
        variant: Key into ``_CFG_BY_VARIANT``; an unknown name raises
            ``KeyError``.
        training: Accepted for backward compatibility with existing callers.
            It no longer influences the returned config: DeLoRA always gets
            ``lambda0=0.1`` (see the comment below).

    Returns:
        An instance of the variant's config class with ``r=4``,
        ``dtype=torch.float32`` and ``alpha=4`` for PiSSA, ``alpha=8``
        otherwise.
    """
    # DeLoRA keeps identity via B=0, so nonzero lambda is needed for the
    # perturb-output check to distinguish a live adapter from dead code.
    extra = {"lambda0": 0.1} if variant == "delora" else {}
    return _CFG_BY_VARIANT[variant](
        r=4,
        alpha=4 if variant == "pissa" else 8,
        dtype=torch.float32,
        **extra,
    )
|
||||
|
||||
|
||||
@@ -93,25 +108,21 @@ def assert_no_base_grads(model: nn.Module) -> None:
|
||||
|
||||
|
||||
def perturb_first_adapter(model: nn.Module) -> None:
    """Nudge one trainable adapter parameter so forward output changes.

    Priority order matters: with B=0 init (DeLoRA, EVA, LoRA), perturbing a
    scalar gate or lambda alone keeps delta=0, so we hit a matrix entry first.

    Exactly one parameter is modified in place (``+0.25`` on a 0-dim
    parameter, or on the first element of any other parameter), after which
    the function returns. Parameters with ``requires_grad=False`` are skipped.

    Raises:
        AssertionError: if no trainable parameter matches any known adapter
            key.
    """
    priority = ("lora_B", "lora_g", "lora_U", "lora_A", "lora_lambda", "lora_gate")
    for key in priority:
        # Keys are matched as substrings of the parameter name, so "lora_g"
        # would also match "lora_gate" and promote the gate ahead of the
        # matrices. Names that belong to a longer key are left to that key.
        longer_keys = tuple(k for k in priority if k != key and key in k)
        for name, p in model.named_parameters():
            if not p.requires_grad or key not in name:
                continue
            if any(k in name for k in longer_keys):
                continue
            with torch.no_grad():
                if p.ndim == 0:
                    p.add_(0.25)
                else:
                    # Integer indexing always yields a view, whereas
                    # flatten() may copy a non-contiguous parameter and
                    # silently drop the update.
                    p[(0,) * p.ndim].add_(0.25)
            return
    raise AssertionError("no perturbable adapter parameter found")
|
||||
|
||||
@@ -134,7 +145,7 @@ def test_variant_identity_hook_save_load_and_training(variant: str):
|
||||
with torch.no_grad():
|
||||
y_init = model(ids).clone()
|
||||
identity_err = (y_init - y_base).abs().max().item()
|
||||
identity_tol = {"lora": 1e-6, "pissa": 5e-4, "delora": 1e-6, "ia3": 1e-6, "dora": 5e-5, "hra": 1e-6}[variant]
|
||||
identity_tol = {"lora": 1e-6, "pissa": 5e-4, "delora": 1e-6, "ia3": 1e-6, "dora": 5e-5, "hra": 5e-6}[variant]
|
||||
assert identity_err < identity_tol
|
||||
|
||||
before_perturb = adapter_state(model)
|
||||
@@ -221,7 +232,7 @@ def test_load_fails_on_missing_and_unexpected_lora_keys():
|
||||
|
||||
|
||||
def test_no_target_layers_is_loud_failure():
    """attach() must raise RuntimeError when the name filter selects nothing."""
    # "definitely_missing" is a target name that no TinyModel module is
    # expected to match, so attach has nothing to adapt.
    cfg = ll.LoRAConfig(target_names=("definitely_missing",))
    with pytest.raises(RuntimeError, match="no target layers"):
        ll.attach(TinyModel(), cfg)
|
||||
|
||||
@@ -232,16 +243,17 @@ def test_structural_non_linear_target_trains_for_forward_only_variants(variant:
|
||||
model = FakeBnbModel()
|
||||
x = torch.randn(2, 3, 8)
|
||||
y_base = model(x).detach()
|
||||
cfg = ll.LoraLiteConfig(
|
||||
variant=variant,
|
||||
extra = {"lambda0": 0.1} if variant == "delora" else {}
|
||||
cfg = _CFG_BY_VARIANT[variant](
|
||||
r=2,
|
||||
alpha=4,
|
||||
dtype=torch.float32,
|
||||
target_roles=(),
|
||||
variant_kwargs={"lambda0": 0.0} if variant == "delora" else {},
|
||||
**extra,
|
||||
)
|
||||
ll.attach(model, cfg)
|
||||
y_init = model(x)
|
||||
# delora: lambda0=0.1 is small but B=0 still makes delta=0 at t=0, so identity holds.
|
||||
assert (y_init.detach() - y_base).abs().max().item() < 1e-6
|
||||
loss = y_init.pow(2).mean()
|
||||
loss.backward()
|
||||
@@ -256,6 +268,6 @@ def test_structural_non_linear_target_trains_for_forward_only_variants(variant:
|
||||
|
||||
@pytest.mark.parametrize("variant", ["pissa", "dora"])
def test_weight_reading_variants_reject_structural_non_linear_target(variant: str):
    """PiSSA and DoRA must refuse to attach to FakeBnbModel with a TypeError.

    The expected error message mentions "plain nn.Linear".
    """
    # target_roles=() selects all roles, so the fake layer is targeted.
    cfg = _CFG_BY_VARIANT[variant](r=2, alpha=2, dtype=torch.float32, target_roles=())
    with pytest.raises(TypeError, match="plain nn.Linear"):
        ll.attach(FakeBnbModel(), cfg)
|
||||
@@ -0,0 +1,37 @@
|
||||
import importlib.util
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
# The benchmark lives in scripts/ (one directory above this test file's
# folder), so it is loaded by file path instead of a regular import.
SCRIPT_PATH = Path(__file__).parents[1] / "scripts" / "metamath_gsm8k_benchmark.py"
SPEC = importlib.util.spec_from_file_location("metamath_gsm8k_benchmark", SCRIPT_PATH)
# spec_from_file_location can return None (or a spec without a loader);
# fail here with a clear message before the spec is used.
assert SPEC is not None and SPEC.loader is not None, f"cannot load {SCRIPT_PATH}"
benchmark = importlib.util.module_from_spec(SPEC)
# Register the module before executing it, as in the importlib recipe for
# importing a source file directly.
sys.modules[SPEC.name] = benchmark
SPEC.loader.exec_module(benchmark)

# Re-export the two helpers under test for brevity.
extract_answer = benchmark.extract_answer
score_predictions = benchmark.score_predictions
|
||||
|
||||
|
||||
def test_extract_answer_handles_gsm8k_numeric_forms():
    """extract_answer returns the bare number for each supported answer form.

    Covers a ``####`` marker, a comma-grouped number followed by a period,
    and a negative value.
    """
    expected_by_text = {
        "#### 42": "42",
        "The answer is 1,234.": "1234",
        "So x = -17": "-17",
    }
    for text, expected in expected_by_text.items():
        assert extract_answer(text) == expected
|
||||
|
||||
|
||||
def test_score_predictions_uses_continuation_answers_only():
    """score_predictions counts only the first of three cases as correct.

    The third prediction mentions "#### 5" (equal to its reference) but then
    states 6 as its answer; the expected count of 1 requires it to be scored
    as wrong.
    """
    # (prediction, reference) pairs, kept side by side for readability.
    cases = [
        ("We compute it. The answer is 42.", "reasoning\n#### 42"),
        ("No final number here", "reasoning\n#### 9"),
        ("Prompt said #### 5. But the continuation answer is 6.", "reasoning\n#### 5"),
    ]
    predictions = [prediction for prediction, _ in cases]
    references = [reference for _, reference in cases]

    result = score_predictions(predictions, references)

    assert result["correct"] == 1
    assert result["total"] == 3
    assert result["accuracy"] == 1 / 3
|
||||
Reference in New Issue
Block a user