tyro and benchmark

This commit is contained in:
wassname
2026-04-27 06:23:30 +08:00
parent 67a6daf6aa
commit b179771cc6
20 changed files with 1504 additions and 325 deletions
+3 -2
View File
@@ -16,7 +16,7 @@ pip install -e git+https://github.com/wassname/lora-lite.git#egg=lora-lite
import torch, lora_lite as ll import torch, lora_lite as ll
model = MyTransformer() model = MyTransformer()
cfg = ll.LoraLiteConfig(variant="lora", r=8, alpha=16, dtype=torch.bfloat16) cfg = ll.LoRAConfig(r=8, alpha=16, dtype=torch.bfloat16)
ll.attach(model, cfg) ll.attach(model, cfg)
opt = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=1e-4) opt = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=1e-4)
@@ -54,7 +54,8 @@ See [docs/spec/20260426_lora_lite_plan.md](docs/spec/20260426_lora_lite_plan.md)
By default, `lora-lite` targets linear-like modules with `in_features`, `out_features`, and `weight`, excluding `lm_head` and `embed_tokens`. By default, `lora-lite` targets linear-like modules with `in_features`, `out_features`, and `weight`, excluding `lm_head` and `embed_tokens`.
Useful `LoraLiteConfig` fields: Useful `AdapterConfig` fields (shared across all variants; subclasses add
variant-specific knobs like `lambda0` on `DeLoRAConfig`):
- `target_roles`: subset of `("reader", "writer", "inner")`; `()` means all. - `target_roles`: subset of `("reader", "writer", "inner")`; `()` means all.
- `target_names`: regex includes. - `target_names`: regex includes.
+47 -4
View File
@@ -19,13 +19,56 @@ build:
uv build uv build
uv run --extra build twine check dist/* uv run --extra build twine check dist/*
qwen-probe variants="lora pissa delora ia3" steps="8": qwen-probe variants="lora pissa delora ia3" steps="5":
uv run --extra test --extra hf-test python scripts/qwen_train_probe.py --variants {{variants}} --steps {{steps}} #!/usr/bin/env bash
set -euo pipefail
for variant in {{variants}}; do
uv run --extra benchmark python scripts/metamath_gsm8k_benchmark.py \
--mode probe \
--model Qwen/Qwen3-0.6B-Base \
--variant "$variant" \
--steps {{steps}} \
--batch-size 1 \
--batch-size-eval 10 \
--max-train-samples 32 \
--max-eval-samples 10 \
--max-new-tokens 32 \
--max-seq-length 384 \
--r 4 \
--alpha 8 \
--layers 0 \
--lr 5e-3 \
--target-name 'model\.layers\.0\.self_attn\.(q_proj|v_proj)$'
done
qwen-queue variants="lora pissa delora ia3" steps="16": qwen-queue variants="lora pissa delora ia3" steps="16":
#!/usr/bin/env bash #!/usr/bin/env bash
set -euo pipefail set -euo pipefail
pueue add \ pueue add \
-l "why: verify Qwen0.6B train/save-load proof for {{variants}} at {{steps}} steps; resolve: publish only if exact targets, lora-only grads, loss drop, reload identity" \ -l "why: verify Qwen0.6B train/save-load proof for {{variants}} at {{steps}} steps via benchmark probe mode; resolve: publish only if exact layer0 q/v targets, lora-only grads, perturb>0, reload<tol" \
-w "$PWD" -o 1 -- \ -w "$PWD" -o 1 -- \
uv run --extra test --extra hf-test python scripts/qwen_train_probe.py --variants {{variants}} --steps {{steps}} just qwen-probe "{{variants}}" "{{steps}}"
metamath-smoke variant="lora" steps="2" max_train_samples="8" max_eval_samples="8" model="hf-internal-testing/tiny-random-LlamaForCausalLM" device="cpu":
uv run --extra benchmark python scripts/metamath_gsm8k_benchmark.py \
--model {{model}} \
--variant {{variant}} \
--steps {{steps}} \
--batch-size 2 \
--max-train-samples {{max_train_samples}} \
--max-eval-samples {{max_eval_samples}} \
--max-new-tokens 8 \
--max-seq-length 128 \
--r 2 \
--alpha 4 \
--layers 0 \
--torch-dtype float32 \
--device {{device}}
metamath-queue variant="lora" steps="5000" model="Qwen/Qwen3-0.6B-Base":
#!/usr/bin/env bash
set -euo pipefail
pueue add \
-l "why: HF-style MetaMathQA->GSM8K benchmark for {{model}} {{variant}} at {{steps}} steps; resolve: result JSON under outputs/metamath_gsm8k proves grad>0 dθ>0 base_grad_leaks=0 and reports valid/test accuracy" \
-w "$PWD" -o 1 -- \
uv run --extra benchmark python scripts/metamath_gsm8k_benchmark.py --model {{model}} --variant {{variant}} --steps {{steps}}
+1
View File
@@ -28,6 +28,7 @@ build = ["twine>=6"]
test = ["pytest", "tabulate", "beartype>=0.18"] test = ["pytest", "tabulate", "beartype>=0.18"]
hf-test = ["accelerate>=1.6", "safetensors>=0.5", "transformers>=4.51"] hf-test = ["accelerate>=1.6", "safetensors>=0.5", "transformers>=4.51"]
bnb-test = ["bitsandbytes>=0.46"] bnb-test = ["bitsandbytes>=0.46"]
benchmark = ["accelerate>=1.6", "datasets>=3.6", "safetensors>=0.5", "tabulate", "transformers>=4.51", "tyro>=0.9"]
[build-system] [build-system]
requires = ["setuptools>=68"] requires = ["setuptools>=68"]
-214
View File
@@ -1,214 +0,0 @@
from __future__ import annotations
import argparse
import gc
import math
from pathlib import Path
import torch
from tabulate import tabulate
from transformers import AutoModelForCausalLM, AutoTokenizer
import lora_lite as ll
PROMPT = "LoRA-lite probe: Paris is the capital of France. The answer is"
EXPECTED_TARGETS = {
"model.layers.0.self_attn.q_proj",
"model.layers.0.self_attn.v_proj",
}
def cfg_for_variant(variant: str, dtype: torch.dtype, r: int, alpha: float) -> ll.LoraLiteConfig:
return ll.LoraLiteConfig(
variant=variant,
r=r,
alpha=r if variant == "pissa" else alpha,
dtype=dtype,
target_roles=(),
target_names=(r"model\.layers\.0\.self_attn\.(q_proj|v_proj)$",),
layers=(0,),
variant_kwargs={"lambda0": 0.1} if variant == "delora" else {},
)
def adapter_state(model: torch.nn.Module) -> dict[str, torch.Tensor]:
return {k: v.detach().clone() for k, v in model.state_dict().items() if "lora_" in k}
def assert_only_lora_trainable(model: torch.nn.Module) -> None:
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
assert trainable
assert all("lora_" in name for name in trainable), trainable[:20]
def assert_no_base_grads(model: torch.nn.Module) -> None:
leaked = [name for name, p in model.named_parameters() if "lora_" not in name and p.grad is not None]
assert leaked == [], leaked[:20]
def perturb_first_adapter(model: torch.nn.Module) -> None:
"""Nudge one trainable adapter parameter so forward output changes.
Walks through trainable lora_* params in a priority order designed to keep
the perturbation small and well-defined per variant:
- identity-breakers first (lora_lambda, lora_gate) where adding to a scalar
directly scales the delta;
- then "outer" matrices set to zero at init (lora_B, lora_g) where bumping
one entry creates a rank-1 perturbation;
- lora_U for HRA (Householder vectors -- bumping breaks the paired
cancellation and tilts the rotation away from identity);
- lora_A for EVA / LoRA-style variants where A is trainable and B starts
at zero, so we still need a way to break identity once any perturbation
propagates.
"""
priority = ("lora_lambda", "lora_gate", "lora_B", "lora_g", "lora_U", "lora_A")
for key in priority:
for name, p in model.named_parameters():
if not p.requires_grad:
continue
if key in name:
with torch.no_grad():
if p.ndim == 0:
p.add_(0.25)
else:
p.flatten()[0].add_(0.25)
return
raise AssertionError("no perturbable adapter parameter found")
def load_model(model_id: str, dtype: torch.dtype, device: str):
model = AutoModelForCausalLM.from_pretrained(model_id, dtype=dtype).to(device)
model.config.use_cache = False
return model
def run_variant(args, variant: str, input_ids: torch.Tensor, labels: torch.Tensor, dtype: torch.dtype):
model = load_model(args.model, dtype, args.device)
model.train()
cfg = cfg_for_variant(variant, dtype, args.r, args.alpha)
with torch.no_grad():
logits_base = model(input_ids=input_ids).logits.detach().clone()
ll.attach(model, cfg)
attached_targets = set(getattr(model, "_lora_lite_attached")["targets"])
assert attached_targets == EXPECTED_TARGETS, attached_targets
assert_only_lora_trainable(model)
with torch.no_grad():
logits_init = model(input_ids=input_ids).logits.detach().clone()
identity_err = (logits_init - logits_base).abs().max().item()
clean_adapter = adapter_state(model)
perturb_first_adapter(model)
with torch.no_grad():
perturb_delta = (model(input_ids=input_ids).logits - logits_init).abs().max().item()
assert perturb_delta > 1e-7, perturb_delta
for name, value in clean_adapter.items():
model.state_dict()[name].copy_(value)
opt = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=args.lr)
with torch.no_grad():
loss0 = model(input_ids=input_ids, labels=labels).loss.item()
before_train = adapter_state(model)
first_grad_norm = math.nan
loss_last = math.nan
for step in range(args.steps):
opt.zero_grad()
loss = model(input_ids=input_ids, labels=labels).loss
loss.backward()
assert_no_base_grads(model)
grad_norm = sum(
p.grad.detach().float().norm().item()
for name, p in model.named_parameters()
if "lora_" in name and p.grad is not None
)
assert math.isfinite(grad_norm), grad_norm
if step == 0:
first_grad_norm = grad_norm
opt.step()
loss_last = loss.item()
after_train = adapter_state(model)
adapter_delta = sum((after_train[k] - before_train[k]).float().norm().item() for k in before_train)
assert first_grad_norm > 0, first_grad_norm
assert adapter_delta > 0, adapter_delta
assert loss_last < loss0, (loss0, loss_last)
model.eval()
with torch.no_grad():
logits_trained = model(input_ids=input_ids).logits.detach().clone()
out_path = args.out_dir / f"{variant}_adapter.pt"
ll.save(model, str(out_path))
saved = torch.load(out_path, weights_only=True, map_location="cpu")
assert set(saved["state"]) == set(after_train)
del model
gc.collect()
torch.cuda.empty_cache()
loaded_model = load_model(args.model, dtype, args.device)
loaded_model.eval()
ll.load(loaded_model, str(out_path))
loaded_state = adapter_state(loaded_model)
for name, value in saved["state"].items():
assert torch.equal(loaded_state[name].cpu(), value)
with torch.no_grad():
logits_loaded = loaded_model(input_ids=input_ids).logits.detach().clone()
reload_err = (logits_loaded - logits_trained).abs().max().item()
assert reload_err < args.reload_tol, reload_err
del loaded_model
gc.collect()
torch.cuda.empty_cache()
return {
"variant": variant,
"targets": len(attached_targets),
"trainable": sum(v.numel() for v in after_train.values()),
"id_err": identity_err,
"perturb": perturb_delta,
"loss0": loss0,
"lossN": loss_last,
"drop%": 100 * (loss0 - loss_last) / loss0,
"grad": first_grad_norm,
"": adapter_delta,
"reload": reload_err,
"out": str(out_path),
}
def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument("--model", default="Qwen/Qwen3-0.6B")
parser.add_argument("--variants", nargs="+", default=["lora", "pissa", "delora", "ia3", "dora", "hra"])
parser.add_argument("--device", default="cuda")
parser.add_argument("--torch-dtype", default="bfloat16")
parser.add_argument("--steps", type=int, default=8)
parser.add_argument("--lr", type=float, default=5e-3)
parser.add_argument("--r", type=int, default=4)
parser.add_argument("--alpha", type=float, default=8.0)
parser.add_argument("--reload-tol", type=float, default=2e-2)
parser.add_argument("--out-dir", type=Path, default=Path("outputs/qwen_train_probe"))
args = parser.parse_args()
if args.device == "cuda" and not torch.cuda.is_available():
raise RuntimeError("CUDA is required for the default Qwen probe. Pass --device cpu explicitly for local debugging.")
args.out_dir.mkdir(parents=True, exist_ok=True)
dtype = getattr(torch, args.torch_dtype)
tokenizer = AutoTokenizer.from_pretrained(args.model)
input_ids = tokenizer(PROMPT, return_tensors="pt").input_ids.to(args.device)
labels = input_ids.clone()
print("SHOULD: exact q_proj/v_proj layer-0 targets, lora-only grads, lossN<loss0, perturb>0, reload<tol. ELSE hook/target/train/save bug.")
rows = [run_variant(args, variant, input_ids, labels, dtype) for variant in args.variants]
print(tabulate(rows, headers="keys", tablefmt="tsv", floatfmt=".4g"))
print("ALL QWEN PROBES PASS")
if __name__ == "__main__":
main()
+22 -3
View File
@@ -6,13 +6,32 @@ if _os.environ.get("BEARTYPE"):
from beartype.claw import beartype_this_package as _bt from beartype.claw import beartype_this_package as _bt
_bt() _bt()
from .config import LoraLiteConfig from .config import AdapterConfig
from .adapter import attach, detach, save, load from .adapter import attach, detach, save, load
from .variant import REGISTRY, register, ParamSpec, Variant from .variant import REGISTRY, register, ParamSpec, Variant
from . import variants # noqa: F401 triggers variant registration from . import variants # noqa: F401 triggers variant + config registration
# Expose per-variant config classes for ergonomic typed construction.
from .variants.lora import LoRAConfig
from .variants.pissa import PiSSAConfig
from .variants.delora import DeLoRAConfig
from .variants.ia3 import IA3Config, IA3FFConfig
from .variants.dora import DoRAConfig
from .variants.hra import HRAConfig
from .variants.eva import EVAConfig
from .variants.antipasto import AntiPaSTOConfig
__all__ = [ __all__ = [
"LoraLiteConfig", "AdapterConfig",
"LoRAConfig",
"PiSSAConfig",
"DeLoRAConfig",
"IA3Config",
"IA3FFConfig",
"DoRAConfig",
"HRAConfig",
"EVAConfig",
"AntiPaSTOConfig",
"attach", "attach",
"detach", "detach",
"save", "save",
+5 -5
View File
@@ -4,7 +4,7 @@ import torch
from torch import nn from torch import nn
from torch.utils.hooks import RemovableHandle from torch.utils.hooks import RemovableHandle
from .config import LoraLiteConfig from .config import AdapterConfig
from .variant import REGISTRY from .variant import REGISTRY
from .target import find_targets from .target import find_targets
@@ -14,7 +14,7 @@ _ATTACHED_ATTR = "_lora_lite_attached"
def _hook(layer, args, y): def _hook(layer, args, y):
(x,) = args (x,) = args
cfg: LoraLiteConfig = layer._lora_cfg cfg: AdapterConfig = layer._lora_cfg
x_cast = x.to(cfg.dtype) x_cast = x.to(cfg.dtype)
out = layer._lora_variant.forward(layer, x_cast, y) out = layer._lora_variant.forward(layer, x_cast, y)
return out.to(y.dtype) return out.to(y.dtype)
@@ -22,13 +22,13 @@ def _hook(layer, args, y):
def _pre_hook(layer, args): def _pre_hook(layer, args):
(x,) = args (x,) = args
cfg: LoraLiteConfig = layer._lora_cfg cfg: AdapterConfig = layer._lora_cfg
x_cast = x.to(cfg.dtype) x_cast = x.to(cfg.dtype)
x_new = layer._lora_variant.forward_input(layer, x_cast) x_new = layer._lora_variant.forward_input(layer, x_cast)
return (x_new.to(x.dtype),) return (x_new.to(x.dtype),)
def attach(model: nn.Module, cfg: LoraLiteConfig, calibration_data=None, *, _skip_group_init: bool = False) -> list[RemovableHandle]: def attach(model: nn.Module, cfg: AdapterConfig, calibration_data=None, *, _skip_group_init: bool = False) -> list[RemovableHandle]:
if cfg.variant not in REGISTRY: if cfg.variant not in REGISTRY:
raise KeyError(f"unknown variant {cfg.variant!r}; registered: {list(REGISTRY)}") raise KeyError(f"unknown variant {cfg.variant!r}; registered: {list(REGISTRY)}")
variant = REGISTRY[cfg.variant] variant = REGISTRY[cfg.variant]
@@ -131,7 +131,7 @@ def save(model: nn.Module, path: str) -> None:
def load(model: nn.Module, path: str) -> list[RemovableHandle]: def load(model: nn.Module, path: str) -> list[RemovableHandle]:
blob = torch.load(path, weights_only=True, map_location="cpu") blob = torch.load(path, weights_only=True, map_location="cpu")
cfg = LoraLiteConfig.from_dict(blob["cfg"]) cfg = AdapterConfig.from_dict(blob["cfg"])
handles = attach(model, cfg, _skip_group_init=True) # creates empty params; data-driven inits restored from state_dict handles = attach(model, cfg, _skip_group_init=True) # creates empty params; data-driven inits restored from state_dict
missing, unexpected = model.load_state_dict(blob["state"], strict=False) missing, unexpected = model.load_state_dict(blob["state"], strict=False)
expected_lora = {k for k in model.state_dict() if "lora_" in k} expected_lora = {k for k in model.state_dict() if "lora_" in k}
+44 -11
View File
@@ -1,13 +1,35 @@
from dataclasses import dataclass, field, asdict """AdapterConfig: per-variant typed dataclass.
from typing import Any, Literal
Replaces the older `LoraLiteConfig` + `variant_kwargs` dict. Each variant
ships its own subclass under `variants/*.py` (e.g. `DeLoRAConfig`), adding
strongly-typed fields so users discover the knobs via IDE / dataclass
introspection instead of stringly-typed dict lookups.
Wire-up:
- `AdapterConfig` holds the universal fields (variant name, rank, alpha,
dtype, targeting filters).
- Subclasses override the `variant` default and add new fields.
- `register_config(cls)` adds the subclass to `_CONFIG_REGISTRY` so
`from_dict` can route to the right type at load time.
Save format:
to_dict() emits a flat dict including `variant`; from_dict() uses that
field to look up the right subclass.
"""
from dataclasses import dataclass, asdict
from typing import Literal
import torch import torch
Role = Literal["reader", "writer", "inner"] Role = Literal["reader", "writer", "inner"]
@dataclass @dataclass
class LoraLiteConfig: class AdapterConfig:
variant: str = "lora" """Base config. Subclass per variant; do not instantiate directly."""
# variant name (subclass overrides default)
variant: str = "?"
# rank-style hyperparams shared across most variants
r: int = 8 r: int = 8
alpha: float | int = 16.0 alpha: float | int = 16.0
dtype: torch.dtype = torch.bfloat16 dtype: torch.dtype = torch.bfloat16
@@ -18,19 +40,30 @@ class LoraLiteConfig:
exclude_names: tuple[str, ...] = ("lm_head", "embed_tokens") exclude_names: tuple[str, ...] = ("lm_head", "embed_tokens")
layers: tuple[int, ...] | None = None layers: tuple[int, ...] | None = None
# variant-specific bag (e.g. lambda0 for DeLoRA)
variant_kwargs: dict[str, Any] = field(default_factory=dict)
def to_dict(self) -> dict: def to_dict(self) -> dict:
d = asdict(self) d = asdict(self)
d["dtype"] = str(self.dtype).removeprefix("torch.") d["dtype"] = str(self.dtype).removeprefix("torch.")
return d return d
@classmethod @classmethod
def from_dict(cls, d: dict) -> "LoraLiteConfig": def from_dict(cls, d: dict) -> "AdapterConfig":
# to_dict always serializes dtype as str; torch.save preserves tuples.
# If you build the dict by hand, pass the right types -- fail loud otherwise.
d = dict(d) d = dict(d)
name = d["variant"]
sub = _CONFIG_REGISTRY[name]
d["dtype"] = getattr(torch, d["dtype"]) d["dtype"] = getattr(torch, d["dtype"])
return cls(**d) return sub(**d)
# Registry of variant_name -> config subclass. Populated by `register_config`
# decorators in each `variants/*.py` module at import time.
_CONFIG_REGISTRY: dict[str, type[AdapterConfig]] = {}
def register_config(cls: type[AdapterConfig]) -> type[AdapterConfig]:
"""Decorator: register `cls` under its `variant` default value."""
name = cls.__dataclass_fields__["variant"].default
if name in _CONFIG_REGISTRY:
raise ValueError(f"config for variant {name!r} already registered")
_CONFIG_REGISTRY[name] = cls
return cls
+3 -3
View File
@@ -4,7 +4,7 @@ from typing import Callable, Protocol, Any
import torch import torch
from torch import nn from torch import nn
from .config import LoraLiteConfig from .config import AdapterConfig
@dataclass @dataclass
@@ -44,10 +44,10 @@ class Variant(Protocol):
name: str name: str
@staticmethod @staticmethod
def param_specs(d_in: int, d_out: int, cfg: LoraLiteConfig) -> dict[str, ParamSpec]: ... def param_specs(d_in: int, d_out: int, cfg: AdapterConfig) -> dict[str, ParamSpec]: ...
@staticmethod @staticmethod
def init(layer: nn.Linear, cfg: LoraLiteConfig) -> None: ... def init(layer: nn.Linear, cfg: AdapterConfig) -> None: ...
@staticmethod @staticmethod
def forward(layer: nn.Linear, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: def forward(layer: nn.Linear, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+20 -5
View File
@@ -29,13 +29,15 @@ precision (fp32 SVD round-tripped to cfg.dtype).
WHICH BASIS IS ROTATED: WHICH BASIS IS ROTATED:
By default we rotate Vh (the INPUT singular basis). This is what AntiPaSTO3 By default we rotate Vh (the INPUT singular basis). This is what AntiPaSTO3
calls `rotate_V=True` in adapter terms (V == Vh.T columns). To rotate U calls `rotate_V=True` in adapter terms (V == Vh.T columns). To rotate U
(output basis) instead, pass variant_kwargs={'rotate_basis': 'U'}. (output basis) instead, pass `rotate_basis='U'` on the AntiPaSTOConfig.
Rotating both is not implemented (one rotation is enough to span the Rotating both is not implemented (one rotation is enough to span the
identifiable steering directions; two is degenerate). identifiable steering directions; two is degenerate).
REQUIRES even rank divisible by `block_size` (default 4). r=8, bs=4 -> 2 blocks. REQUIRES even rank divisible by `block_size` (default 4). r=8, bs=4 -> 2 blocks.
""" """
import math import math
from dataclasses import dataclass
from typing import Literal
import torch import torch
from einops import einsum from einops import einsum
@@ -43,6 +45,19 @@ from jaxtyping import Float
from torch import nn, Tensor as T from torch import nn, Tensor as T
from ..variant import register, ParamSpec from ..variant import register, ParamSpec
from ..config import AdapterConfig, register_config
@register_config
@dataclass
class AntiPaSTOConfig(AdapterConfig):
variant: str = "antipasto"
# Block size for the block-diagonal Cayley rotation. r must be divisible by it.
block_size: int = 4
# Cayley map saturation: bounds rotation angle to ~max_rotation_angle radians.
max_rotation_angle: float = 0.5
# Which singular basis to rotate: 'V' (input) or 'U' (output).
rotate_basis: Literal["V", "U"] = "V"
def _cayley(skew: torch.Tensor) -> torch.Tensor: def _cayley(skew: torch.Tensor) -> torch.Tensor:
@@ -81,7 +96,7 @@ class AntiPaSTO:
@staticmethod @staticmethod
def param_specs(d_in, d_out, cfg): def param_specs(d_in, d_out, cfg):
r = cfg.r r = cfg.r
bs = int(cfg.variant_kwargs.get("block_size", 4)) bs = int(cfg.block_size)
if r % bs != 0: if r % bs != 0:
raise ValueError(f"AntiPaSTO requires r={r} divisible by block_size={bs}") raise ValueError(f"AntiPaSTO requires r={r} divisible by block_size={bs}")
n_blocks = r // bs n_blocks = r // bs
@@ -123,9 +138,9 @@ class AntiPaSTO:
y: Float[T, '*B o'], y: Float[T, '*B o'],
) -> Float[T, '*B o']: ) -> Float[T, '*B o']:
cfg = layer._lora_cfg cfg = layer._lora_cfg
bs = int(cfg.variant_kwargs.get("block_size", 4)) bs = int(cfg.block_size)
max_angle = float(cfg.variant_kwargs.get("max_rotation_angle", 0.5)) max_angle = float(cfg.max_rotation_angle)
rotate_basis = cfg.variant_kwargs.get("rotate_basis", "V") rotate_basis = cfg.rotate_basis
U = layer.lora_U.to(x.dtype) # (d_out, r) U = layer.lora_U.to(x.dtype) # (d_out, r)
S = layer.lora_S.to(x.dtype) # (r,) S = layer.lora_S.to(x.dtype) # (r,)
+12 -1
View File
@@ -35,8 +35,19 @@ import torch
from einops import einsum from einops import einsum
from jaxtyping import Float from jaxtyping import Float
from torch import nn, Tensor as T from torch import nn, Tensor as T
from dataclasses import dataclass
from ..variant import register, ParamSpec from ..variant import register, ParamSpec
from ..config import AdapterConfig, register_config
@register_config
@dataclass
class DeLoRAConfig(AdapterConfig):
variant: str = "delora"
# Initial scale for the per-layer learnable lambda. peft default is 15.0;
# we default to 0.0 (identity at t=0 even before B is zero-initialized).
lambda0: float = 0.0
@register @register
@@ -45,7 +56,7 @@ class DeLoRA:
@staticmethod @staticmethod
def param_specs(d_in, d_out, cfg): def param_specs(d_in, d_out, cfg):
lam0 = float(cfg.variant_kwargs.get("lambda0", 0.0)) lam0 = float(cfg.lambda0)
return { return {
# peft DeLoRA default: A=kaiming, B=zeros (docs/refs/peft_delora_layer.py:138-140). # peft DeLoRA default: A=kaiming, B=zeros (docs/refs/peft_delora_layer.py:138-140).
# Identity at t=0 from B=0 -> delta=0 regardless of lambda. With B=0 the # Identity at t=0 from B=0 -> delta=0 regardless of lambda. With B=0 the
+8
View File
@@ -22,8 +22,16 @@ import torch
from einops import einsum from einops import einsum
from jaxtyping import Float from jaxtyping import Float
from torch import nn, Tensor as T from torch import nn, Tensor as T
from dataclasses import dataclass
from ..variant import register, ParamSpec from ..variant import register, ParamSpec
from ..config import AdapterConfig, register_config
@register_config
@dataclass
class DoRAConfig(AdapterConfig):
variant: str = "dora"
@register @register
+8
View File
@@ -36,13 +36,21 @@ from einops import einsum
from jaxtyping import Float from jaxtyping import Float
from torch import nn, Tensor as T from torch import nn, Tensor as T
from typing import Iterable from typing import Iterable
from dataclasses import dataclass
from ..variant import register, ParamSpec from ..variant import register, ParamSpec
from ..config import AdapterConfig, register_config
CalibrationBatch = dict | tuple | list | T CalibrationBatch = dict | tuple | list | T
CalibrationData = Iterable[CalibrationBatch] CalibrationData = Iterable[CalibrationBatch]
@register_config
@dataclass
class EVAConfig(AdapterConfig):
variant: str = "eva"
@register @register
class EVA: class EVA:
name = "eva" name = "eva"
+8
View File
@@ -32,8 +32,16 @@ import torch
from einops import einsum from einops import einsum
from jaxtyping import Float from jaxtyping import Float
from torch import nn, Tensor as T from torch import nn, Tensor as T
from dataclasses import dataclass
from ..variant import register, ParamSpec from ..variant import register, ParamSpec
from ..config import AdapterConfig, register_config
@register_config
@dataclass
class HRAConfig(AdapterConfig):
variant: str = "hra"
@register @register
+16 -2
View File
@@ -16,8 +16,8 @@ In both cases g is initialized to 1 -> identity at t=0.
To match the paper exactly on a Llama/Qwen-style block requires TWO attach To match the paper exactly on a Llama/Qwen-style block requires TWO attach
passes (one per variant), since each variant uses one hook type: passes (one per variant), since each variant uses one hook type:
cfg_attn = LoraLiteConfig(variant="ia3", target_names=(r"\\.k_proj$", r"\\.v_proj$")) cfg_attn = IA3Config( target_names=(r"\\.k_proj$", r"\\.v_proj$"))
cfg_ffn = LoraLiteConfig(variant="ia3_ff", target_names=(r"\\.down_proj$",)) cfg_ffn = IA3FFConfig( target_names=(r"\\.down_proj$",))
Reference implementation: Reference implementation:
- peft IA3 layer (is_feedforward toggles input-vs-output gating, see - peft IA3 layer (is_feedforward toggles input-vs-output gating, see
@@ -27,8 +27,22 @@ Reference implementation:
import torch import torch
from jaxtyping import Float from jaxtyping import Float
from torch import nn, Tensor as T from torch import nn, Tensor as T
from dataclasses import dataclass
from ..variant import register, ParamSpec from ..variant import register, ParamSpec
from ..config import AdapterConfig, register_config
@register_config
@dataclass
class IA3Config(AdapterConfig):
variant: str = "ia3"
@register_config
@dataclass
class IA3FFConfig(AdapterConfig):
variant: str = "ia3_ff"
@register @register
+8
View File
@@ -13,8 +13,16 @@ from einops import einsum
from jaxtyping import Float from jaxtyping import Float
from torch import nn, Tensor as T from torch import nn, Tensor as T
import torch import torch
from dataclasses import dataclass
from ..variant import register, ParamSpec from ..variant import register, ParamSpec
from ..config import AdapterConfig, register_config
@register_config
@dataclass
class LoRAConfig(AdapterConfig):
variant: str = "lora"
@register @register
+8
View File
@@ -24,8 +24,16 @@ import torch
from einops import einsum from einops import einsum
from jaxtyping import Float from jaxtyping import Float
from torch import nn, Tensor as T from torch import nn, Tensor as T
from dataclasses import dataclass
from ..variant import register, ParamSpec from ..variant import register, ParamSpec
from ..config import AdapterConfig, register_config
@register_config
@dataclass
class PiSSAConfig(AdapterConfig):
variant: str = "pissa"
@register @register
+30 -15
View File
@@ -95,6 +95,19 @@ class FakeBnbModel(nn.Module):
return self.layers[0](x) return self.layers[0](x)
_CFG_BY_VARIANT = {
"lora": ll.LoRAConfig,
"pissa": ll.PiSSAConfig,
"delora": ll.DeLoRAConfig,
"ia3": ll.IA3Config,
"ia3_ff": ll.IA3FFConfig,
"dora": ll.DoRAConfig,
"hra": ll.HRAConfig,
"eva": ll.EVAConfig,
"antipasto": ll.AntiPaSTOConfig,
}
def variant_test(variant: str, dtype=torch.float32): def variant_test(variant: str, dtype=torch.float32):
print(f"\n=== variant={variant} dtype={dtype} ===") print(f"\n=== variant={variant} dtype={dtype} ===")
torch.manual_seed(0) torch.manual_seed(0)
@@ -104,13 +117,14 @@ def variant_test(variant: str, dtype=torch.float32):
with torch.no_grad(): with torch.no_grad():
y_base = model(ids).clone() y_base = model(ids).clone()
cfg = ll.LoraLiteConfig( cfg_cls = _CFG_BY_VARIANT[variant]
variant=variant, extra = {"lambda0": 15.0} if variant == "delora" else {}
cfg = cfg_cls(
r=4, r=4,
alpha=4 if variant == "pissa" else 8, # PiSSA needs scale==1 for clean recon alpha=4 if variant == "pissa" else 8, # PiSSA needs scale==1 for clean recon
dtype=dtype, dtype=dtype,
# delora identity holds via B=0 init (peft semantics); use peft default lambda0=15. # delora identity holds via B=0 init (peft semantics); use peft default lambda0=15.
variant_kwargs={"lambda0": 15.0} if variant == "delora" else {}, **extra,
) )
handles = ll.attach(model, cfg) handles = ll.attach(model, cfg)
n_targets = len(handles) n_targets = len(handles)
@@ -164,9 +178,8 @@ def variant_test(variant: str, dtype=torch.float32):
model = TinyModel().to(dtype) model = TinyModel().to(dtype)
train_cfg = cfg train_cfg = cfg
if variant == "delora": if variant == "delora":
train_cfg = ll.LoraLiteConfig( train_cfg = ll.DeLoRAConfig(
variant=cfg.variant, r=cfg.r, alpha=cfg.alpha, dtype=cfg.dtype, r=cfg.r, alpha=cfg.alpha, dtype=cfg.dtype, lambda0=0.1,
variant_kwargs={"lambda0": 0.1},
) )
ll.attach(model, train_cfg) ll.attach(model, train_cfg)
target = torch.randn(2, 16, 100, dtype=dtype) * 0.1 target = torch.randn(2, 16, 100, dtype=dtype) * 0.1
@@ -200,7 +213,7 @@ def structural_linear_like_test():
model = FakeBnbModel() model = FakeBnbModel()
x = torch.randn(2, 3, 8) x = torch.randn(2, 3, 8)
y_base = model(x).detach() y_base = model(x).detach()
ll.attach(model, ll.LoraLiteConfig(variant="lora", r=2, alpha=4, dtype=torch.float32, target_roles=())) ll.attach(model, ll.LoRAConfig(r=2, alpha=4, dtype=torch.float32, target_roles=()))
layer = model.layers[0] layer = model.layers[0]
assert hasattr(layer, "lora_A") and hasattr(layer, "lora_B") assert hasattr(layer, "lora_A") and hasattr(layer, "lora_B")
y = model(x) y = model(x)
@@ -262,10 +275,12 @@ def bitsandbytes_cuda_smoke(require_bnb: bool):
model = BnbModel(layer_cls) model = BnbModel(layer_cls)
x = torch.randn(2, 3, 8, device="cuda") x = torch.randn(2, 3, 8, device="cuda")
y_base = model(x).detach() y_base = model(x).detach()
cfg = ll.LoraLiteConfig( cfg_cls = _CFG_BY_VARIANT[variant]
variant=variant, r=2, alpha=4, dtype=torch.float16, target_roles=(), extra = {"lambda0": 0.1} if variant == "delora" else {}
# In fp16 + bnb, peft default lambda0=15 + B=0 + clamp(min=1e-4) gives\n # scale=lambda/(r*1e-4) ~ 75000 > fp16 max -> inf*0 = NaN. Use small\n # lambda0 for the fp16 test.\n variant_kwargs={"lambda0": 0.1} if variant == "delora" else {}, # In fp16 + bnb, peft default lambda0=15 + B=0 + clamp(min=1e-4) gives
) # scale=lambda/(r*1e-4) ~ 75000 > fp16 max -> inf*0 = NaN. Use small
# lambda0 for the fp16 test.
cfg = cfg_cls(r=2, alpha=4, dtype=torch.float16, target_roles=(), **extra)
ll.attach(model, cfg) ll.attach(model, cfg)
y = model(x) y = model(x)
err = (y.detach() - y_base).abs().max().item() err = (y.detach() - y_base).abs().max().item()
@@ -281,7 +296,7 @@ def bitsandbytes_cuda_smoke(require_bnb: bool):
for variant in bnb_fail: for variant in bnb_fail:
model = BnbModel(layer_cls) model = BnbModel(layer_cls)
cfg = ll.LoraLiteConfig(variant=variant, r=2, alpha=2, dtype=torch.float16, target_roles=()) cfg = _CFG_BY_VARIANT[variant](r=2, alpha=2, dtype=torch.float16, target_roles=())
try: try:
ll.attach(model, cfg) ll.attach(model, cfg)
except (TypeError, RuntimeError, AttributeError, ValueError) as e: except (TypeError, RuntimeError, AttributeError, ValueError) as e:
@@ -300,7 +315,7 @@ def eva_smoke():
with torch.no_grad(): with torch.no_grad():
y_base = model(ids).clone() y_base = model(ids).clone()
cfg = ll.LoraLiteConfig(variant="eva", r=4, alpha=8, dtype=torch.float32) cfg = ll.EVAConfig(r=4, alpha=8, dtype=torch.float32)
# 4 calibration batches of random ids # 4 calibration batches of random ids
calib = [torch.randint(0, 100, (2, 16)) for _ in range(4)] calib = [torch.randint(0, 100, (2, 16)) for _ in range(4)]
ll.attach(model, cfg, calibration_data=calib) ll.attach(model, cfg, calibration_data=calib)
@@ -380,7 +395,7 @@ def dora_bias_smoke():
return self.layers[0](x) return self.layers[0](x)
model = Wrap(layer) model = Wrap(layer)
cfg = ll.LoraLiteConfig(variant="dora", r=2, alpha=4, dtype=torch.float32, target_roles=()) cfg = ll.DoRAConfig(r=2, alpha=4, dtype=torch.float32, target_roles=())
ll.attach(model, cfg) ll.attach(model, cfg)
with torch.no_grad(): with torch.no_grad():
y_adapt = model(x) y_adapt = model(x)
@@ -404,7 +419,7 @@ def hra_forward_order_smoke():
layer = nn.Linear(d, d, bias=False) layer = nn.Linear(d, d, bias=False)
x = torch.randn(2, 3, d) x = torch.randn(2, 3, d)
cfg = ll.LoraLiteConfig(variant="hra", r=4, alpha=4, dtype=torch.float32, target_roles=()) cfg = ll.HRAConfig(r=4, alpha=4, dtype=torch.float32, target_roles=())
class Wrap(nn.Module): class Wrap(nn.Module):
def __init__(self_, lin): def __init__(self_, lin):
super().__init__() super().__init__()
+40 -28
View File
@@ -67,13 +67,28 @@ class FakeBnbModel(nn.Module):
return self.layers[0](x) return self.layers[0](x)
def cfg_for_variant(variant: str, *, training: bool = False) -> ll.LoraLiteConfig: _CFG_BY_VARIANT = {
return ll.LoraLiteConfig( "lora": ll.LoRAConfig,
variant=variant, "pissa": ll.PiSSAConfig,
"delora": ll.DeLoRAConfig,
"ia3": ll.IA3Config,
"ia3_ff": ll.IA3FFConfig,
"dora": ll.DoRAConfig,
"hra": ll.HRAConfig,
"eva": ll.EVAConfig,
"antipasto": ll.AntiPaSTOConfig,
}
def cfg_for_variant(variant: str, *, training: bool = False) -> ll.AdapterConfig:
# DeLoRA keeps identity via B=0, so nonzero lambda is needed for the
# perturb-output check to distinguish a live adapter from dead code.
extra = {"lambda0": 0.1} if variant == "delora" else {}
return _CFG_BY_VARIANT[variant](
r=4, r=4,
alpha=4 if variant == "pissa" else 8, alpha=4 if variant == "pissa" else 8,
dtype=torch.float32, dtype=torch.float32,
variant_kwargs={"lambda0": 0.1 if training else 0.0} if variant == "delora" else {}, **extra,
) )
@@ -93,25 +108,21 @@ def assert_no_base_grads(model: nn.Module) -> None:
def perturb_first_adapter(model: nn.Module) -> None: def perturb_first_adapter(model: nn.Module) -> None:
for name, p in model.named_parameters(): """Nudge one trainable adapter parameter so forward output changes.
if "lora_lambda" in name:
Priority order matters: with B=0 init (DeLoRA, EVA, LoRA), perturbing a
scalar gate or lambda alone keeps delta=0, so we hit a matrix entry first.
"""
priority = ("lora_B", "lora_g", "lora_U", "lora_A", "lora_lambda", "lora_gate")
for key in priority:
for name, p in model.named_parameters():
if not p.requires_grad or key not in name:
continue
with torch.no_grad(): with torch.no_grad():
p.add_(0.25) if p.ndim == 0:
return p.add_(0.25)
for name, p in model.named_parameters(): else:
if "lora_gate" in name: p.flatten()[0].add_(0.25)
with torch.no_grad():
p.add_(0.25)
return
for name, p in model.named_parameters():
if "lora_B" in name:
with torch.no_grad():
p.flatten()[0].add_(0.25)
return
for name, p in model.named_parameters():
if "lora_g" in name:
with torch.no_grad():
p.flatten()[0].add_(0.25)
return return
raise AssertionError("no perturbable adapter parameter found") raise AssertionError("no perturbable adapter parameter found")
@@ -134,7 +145,7 @@ def test_variant_identity_hook_save_load_and_training(variant: str):
with torch.no_grad(): with torch.no_grad():
y_init = model(ids).clone() y_init = model(ids).clone()
identity_err = (y_init - y_base).abs().max().item() identity_err = (y_init - y_base).abs().max().item()
identity_tol = {"lora": 1e-6, "pissa": 5e-4, "delora": 1e-6, "ia3": 1e-6, "dora": 5e-5, "hra": 1e-6}[variant] identity_tol = {"lora": 1e-6, "pissa": 5e-4, "delora": 1e-6, "ia3": 1e-6, "dora": 5e-5, "hra": 5e-6}[variant]
assert identity_err < identity_tol assert identity_err < identity_tol
before_perturb = adapter_state(model) before_perturb = adapter_state(model)
@@ -221,7 +232,7 @@ def test_load_fails_on_missing_and_unexpected_lora_keys():
def test_no_target_layers_is_loud_failure(): def test_no_target_layers_is_loud_failure():
cfg = ll.LoraLiteConfig(variant="lora", target_names=("definitely_missing",)) cfg = ll.LoRAConfig(target_names=("definitely_missing",))
with pytest.raises(RuntimeError, match="no target layers"): with pytest.raises(RuntimeError, match="no target layers"):
ll.attach(TinyModel(), cfg) ll.attach(TinyModel(), cfg)
@@ -232,16 +243,17 @@ def test_structural_non_linear_target_trains_for_forward_only_variants(variant:
model = FakeBnbModel() model = FakeBnbModel()
x = torch.randn(2, 3, 8) x = torch.randn(2, 3, 8)
y_base = model(x).detach() y_base = model(x).detach()
cfg = ll.LoraLiteConfig( extra = {"lambda0": 0.1} if variant == "delora" else {}
variant=variant, cfg = _CFG_BY_VARIANT[variant](
r=2, r=2,
alpha=4, alpha=4,
dtype=torch.float32, dtype=torch.float32,
target_roles=(), target_roles=(),
variant_kwargs={"lambda0": 0.0} if variant == "delora" else {}, **extra,
) )
ll.attach(model, cfg) ll.attach(model, cfg)
y_init = model(x) y_init = model(x)
# delora: lambda0=0.1 is small but B=0 still makes delta=0 at t=0, so identity holds.
assert (y_init.detach() - y_base).abs().max().item() < 1e-6 assert (y_init.detach() - y_base).abs().max().item() < 1e-6
loss = y_init.pow(2).mean() loss = y_init.pow(2).mean()
loss.backward() loss.backward()
@@ -256,6 +268,6 @@ def test_structural_non_linear_target_trains_for_forward_only_variants(variant:
@pytest.mark.parametrize("variant", ["pissa", "dora"]) @pytest.mark.parametrize("variant", ["pissa", "dora"])
def test_weight_reading_variants_reject_structural_non_linear_target(variant: str): def test_weight_reading_variants_reject_structural_non_linear_target(variant: str):
cfg = ll.LoraLiteConfig(variant=variant, r=2, alpha=2, dtype=torch.float32, target_roles=()) cfg = _CFG_BY_VARIANT[variant](r=2, alpha=2, dtype=torch.float32, target_roles=())
with pytest.raises(TypeError, match="plain nn.Linear"): with pytest.raises(TypeError, match="plain nn.Linear"):
ll.attach(FakeBnbModel(), cfg) ll.attach(FakeBnbModel(), cfg)
+37
View File
@@ -0,0 +1,37 @@
import importlib.util
import sys
from pathlib import Path
SCRIPT_PATH = Path(__file__).parents[1] / "scripts" / "metamath_gsm8k_benchmark.py"
SPEC = importlib.util.spec_from_file_location("metamath_gsm8k_benchmark", SCRIPT_PATH)
benchmark = importlib.util.module_from_spec(SPEC)
assert SPEC.loader is not None
sys.modules[SPEC.name] = benchmark
SPEC.loader.exec_module(benchmark)
extract_answer = benchmark.extract_answer
score_predictions = benchmark.score_predictions
def test_extract_answer_handles_gsm8k_numeric_forms():
assert extract_answer("#### 42") == "42"
assert extract_answer("The answer is 1,234.") == "1234"
assert extract_answer("So x = -17") == "-17"
def test_score_predictions_uses_continuation_answers_only():
predictions = [
"We compute it. The answer is 42.",
"No final number here",
"Prompt said #### 5. But the continuation answer is 6.",
]
references = [
"reasoning\n#### 42",
"reasoning\n#### 9",
"reasoning\n#### 5",
]
scored = score_predictions(predictions, references)
assert scored["correct"] == 1
assert scored["total"] == 3
assert scored["accuracy"] == 1 / 3
Generated
+1184 -32
View File
File diff suppressed because it is too large Load Diff