mirror of
https://github.com/wassname/lora-lite.git
synced 2026-06-27 16:15:50 +08:00
tyro and benchmark
This commit is contained in:
@@ -1,214 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import gc
|
||||
import math
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from tabulate import tabulate
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
import lora_lite as ll
|
||||
|
||||
|
||||
PROMPT = "LoRA-lite probe: Paris is the capital of France. The answer is"
|
||||
EXPECTED_TARGETS = {
|
||||
"model.layers.0.self_attn.q_proj",
|
||||
"model.layers.0.self_attn.v_proj",
|
||||
}
|
||||
|
||||
|
||||
def cfg_for_variant(variant: str, dtype: torch.dtype, r: int, alpha: float) -> ll.LoraLiteConfig:
|
||||
return ll.LoraLiteConfig(
|
||||
variant=variant,
|
||||
r=r,
|
||||
alpha=r if variant == "pissa" else alpha,
|
||||
dtype=dtype,
|
||||
target_roles=(),
|
||||
target_names=(r"model\.layers\.0\.self_attn\.(q_proj|v_proj)$",),
|
||||
layers=(0,),
|
||||
variant_kwargs={"lambda0": 0.1} if variant == "delora" else {},
|
||||
)
|
||||
|
||||
|
||||
def adapter_state(model: torch.nn.Module) -> dict[str, torch.Tensor]:
|
||||
return {k: v.detach().clone() for k, v in model.state_dict().items() if "lora_" in k}
|
||||
|
||||
|
||||
def assert_only_lora_trainable(model: torch.nn.Module) -> None:
|
||||
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
|
||||
assert trainable
|
||||
assert all("lora_" in name for name in trainable), trainable[:20]
|
||||
|
||||
|
||||
def assert_no_base_grads(model: torch.nn.Module) -> None:
|
||||
leaked = [name for name, p in model.named_parameters() if "lora_" not in name and p.grad is not None]
|
||||
assert leaked == [], leaked[:20]
|
||||
|
||||
|
||||
def perturb_first_adapter(model: torch.nn.Module) -> None:
|
||||
"""Nudge one trainable adapter parameter so forward output changes.
|
||||
|
||||
Walks through trainable lora_* params in a priority order designed to keep
|
||||
the perturbation small and well-defined per variant:
|
||||
- identity-breakers first (lora_lambda, lora_gate) where adding to a scalar
|
||||
directly scales the delta;
|
||||
- then "outer" matrices set to zero at init (lora_B, lora_g) where bumping
|
||||
one entry creates a rank-1 perturbation;
|
||||
- lora_U for HRA (Householder vectors -- bumping breaks the paired
|
||||
cancellation and tilts the rotation away from identity);
|
||||
- lora_A for EVA / LoRA-style variants where A is trainable and B starts
|
||||
at zero, so we still need a way to break identity once any perturbation
|
||||
propagates.
|
||||
"""
|
||||
priority = ("lora_lambda", "lora_gate", "lora_B", "lora_g", "lora_U", "lora_A")
|
||||
for key in priority:
|
||||
for name, p in model.named_parameters():
|
||||
if not p.requires_grad:
|
||||
continue
|
||||
if key in name:
|
||||
with torch.no_grad():
|
||||
if p.ndim == 0:
|
||||
p.add_(0.25)
|
||||
else:
|
||||
p.flatten()[0].add_(0.25)
|
||||
return
|
||||
raise AssertionError("no perturbable adapter parameter found")
|
||||
|
||||
|
||||
def load_model(model_id: str, dtype: torch.dtype, device: str):
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, dtype=dtype).to(device)
|
||||
model.config.use_cache = False
|
||||
return model
|
||||
|
||||
|
||||
def run_variant(args, variant: str, input_ids: torch.Tensor, labels: torch.Tensor, dtype: torch.dtype):
|
||||
model = load_model(args.model, dtype, args.device)
|
||||
model.train()
|
||||
cfg = cfg_for_variant(variant, dtype, args.r, args.alpha)
|
||||
|
||||
with torch.no_grad():
|
||||
logits_base = model(input_ids=input_ids).logits.detach().clone()
|
||||
|
||||
ll.attach(model, cfg)
|
||||
attached_targets = set(getattr(model, "_lora_lite_attached")["targets"])
|
||||
assert attached_targets == EXPECTED_TARGETS, attached_targets
|
||||
assert_only_lora_trainable(model)
|
||||
|
||||
with torch.no_grad():
|
||||
logits_init = model(input_ids=input_ids).logits.detach().clone()
|
||||
identity_err = (logits_init - logits_base).abs().max().item()
|
||||
|
||||
clean_adapter = adapter_state(model)
|
||||
perturb_first_adapter(model)
|
||||
with torch.no_grad():
|
||||
perturb_delta = (model(input_ids=input_ids).logits - logits_init).abs().max().item()
|
||||
assert perturb_delta > 1e-7, perturb_delta
|
||||
for name, value in clean_adapter.items():
|
||||
model.state_dict()[name].copy_(value)
|
||||
|
||||
opt = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=args.lr)
|
||||
with torch.no_grad():
|
||||
loss0 = model(input_ids=input_ids, labels=labels).loss.item()
|
||||
|
||||
before_train = adapter_state(model)
|
||||
first_grad_norm = math.nan
|
||||
loss_last = math.nan
|
||||
for step in range(args.steps):
|
||||
opt.zero_grad()
|
||||
loss = model(input_ids=input_ids, labels=labels).loss
|
||||
loss.backward()
|
||||
assert_no_base_grads(model)
|
||||
grad_norm = sum(
|
||||
p.grad.detach().float().norm().item()
|
||||
for name, p in model.named_parameters()
|
||||
if "lora_" in name and p.grad is not None
|
||||
)
|
||||
assert math.isfinite(grad_norm), grad_norm
|
||||
if step == 0:
|
||||
first_grad_norm = grad_norm
|
||||
opt.step()
|
||||
loss_last = loss.item()
|
||||
|
||||
after_train = adapter_state(model)
|
||||
adapter_delta = sum((after_train[k] - before_train[k]).float().norm().item() for k in before_train)
|
||||
assert first_grad_norm > 0, first_grad_norm
|
||||
assert adapter_delta > 0, adapter_delta
|
||||
assert loss_last < loss0, (loss0, loss_last)
|
||||
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
logits_trained = model(input_ids=input_ids).logits.detach().clone()
|
||||
|
||||
out_path = args.out_dir / f"{variant}_adapter.pt"
|
||||
ll.save(model, str(out_path))
|
||||
saved = torch.load(out_path, weights_only=True, map_location="cpu")
|
||||
assert set(saved["state"]) == set(after_train)
|
||||
|
||||
del model
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
loaded_model = load_model(args.model, dtype, args.device)
|
||||
loaded_model.eval()
|
||||
ll.load(loaded_model, str(out_path))
|
||||
loaded_state = adapter_state(loaded_model)
|
||||
for name, value in saved["state"].items():
|
||||
assert torch.equal(loaded_state[name].cpu(), value)
|
||||
with torch.no_grad():
|
||||
logits_loaded = loaded_model(input_ids=input_ids).logits.detach().clone()
|
||||
reload_err = (logits_loaded - logits_trained).abs().max().item()
|
||||
assert reload_err < args.reload_tol, reload_err
|
||||
|
||||
del loaded_model
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return {
|
||||
"variant": variant,
|
||||
"targets": len(attached_targets),
|
||||
"trainable": sum(v.numel() for v in after_train.values()),
|
||||
"id_err": identity_err,
|
||||
"perturb": perturb_delta,
|
||||
"loss0": loss0,
|
||||
"lossN": loss_last,
|
||||
"drop%": 100 * (loss0 - loss_last) / loss0,
|
||||
"grad": first_grad_norm,
|
||||
"dθ": adapter_delta,
|
||||
"reload": reload_err,
|
||||
"out": str(out_path),
|
||||
}
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model", default="Qwen/Qwen3-0.6B")
|
||||
parser.add_argument("--variants", nargs="+", default=["lora", "pissa", "delora", "ia3", "dora", "hra"])
|
||||
parser.add_argument("--device", default="cuda")
|
||||
parser.add_argument("--torch-dtype", default="bfloat16")
|
||||
parser.add_argument("--steps", type=int, default=8)
|
||||
parser.add_argument("--lr", type=float, default=5e-3)
|
||||
parser.add_argument("--r", type=int, default=4)
|
||||
parser.add_argument("--alpha", type=float, default=8.0)
|
||||
parser.add_argument("--reload-tol", type=float, default=2e-2)
|
||||
parser.add_argument("--out-dir", type=Path, default=Path("outputs/qwen_train_probe"))
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.device == "cuda" and not torch.cuda.is_available():
|
||||
raise RuntimeError("CUDA is required for the default Qwen probe. Pass --device cpu explicitly for local debugging.")
|
||||
|
||||
args.out_dir.mkdir(parents=True, exist_ok=True)
|
||||
dtype = getattr(torch, args.torch_dtype)
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model)
|
||||
input_ids = tokenizer(PROMPT, return_tensors="pt").input_ids.to(args.device)
|
||||
labels = input_ids.clone()
|
||||
|
||||
print("SHOULD: exact q_proj/v_proj layer-0 targets, lora-only grads, lossN<loss0, perturb>0, reload<tol. ELSE hook/target/train/save bug.")
|
||||
rows = [run_variant(args, variant, input_ids, labels, dtype) for variant in args.variants]
|
||||
print(tabulate(rows, headers="keys", tablefmt="tsv", floatfmt=".4g"))
|
||||
print("ALL QWEN PROBES PASS")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user