mirror of
https://github.com/wassname/lora-lite.git
synced 2026-06-27 16:15:50 +08:00
tidy tests to subset of metamath
This commit is contained in:
@@ -45,6 +45,7 @@ class BenchmarkConfig:
|
||||
mode: Literal["benchmark", "probe"] = "benchmark"
|
||||
device: str = "cuda"
|
||||
torch_dtype: str = "bfloat16"
|
||||
quantization: Literal["none", "4bit", "8bit"] = "none"
|
||||
r: int = 32
|
||||
alpha: float = 64.0
|
||||
delora_lambda0: float = 0.1
|
||||
@@ -146,7 +147,7 @@ def count_base_grad_leaks(model: torch.nn.Module) -> int:
|
||||
|
||||
|
||||
def perturb_first_adapter(model: torch.nn.Module) -> None:
|
||||
priority = ("lora_B", "lora_g", "lora_U", "lora_A", "lora_lambda", "lora_gate")
|
||||
priority = ("lora_B", "lora_g", "lora_U", "lora_A", "lora_lambda", "lora_gate", "lora_delta_s", "lora_rot_T", "lora_m")
|
||||
for key in priority:
|
||||
for _, p in model.named_parameters():
|
||||
if p.requires_grad and key in _:
|
||||
@@ -159,7 +160,7 @@ def perturb_first_adapter(model: torch.nn.Module) -> None:
|
||||
raise AssertionError("no perturbable adapter parameter found")
|
||||
|
||||
|
||||
def load_model_and_tokenizer(model_id: str, dtype: torch.dtype, device: str):
|
||||
def load_model_and_tokenizer(model_id: str, dtype: torch.dtype, device: str, quantization: str = "none"):
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
@@ -167,7 +168,16 @@ def load_model_and_tokenizer(model_id: str, dtype: torch.dtype, device: str):
|
||||
raise RuntimeError(f"tokenizer for {model_id} has no eos_token")
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
tokenizer.padding_side = "left"
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, dtype=dtype).to(device)
|
||||
if quantization == "none":
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, dtype=dtype).to(device)
|
||||
else:
|
||||
from transformers import BitsAndBytesConfig
|
||||
bnb_cfg = BitsAndBytesConfig(
|
||||
load_in_4bit=quantization == "4bit",
|
||||
load_in_8bit=quantization == "8bit",
|
||||
bnb_4bit_compute_dtype=dtype if quantization == "4bit" else None,
|
||||
)
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_cfg, device_map=device)
|
||||
model.config.use_cache = False
|
||||
return model, tokenizer
|
||||
|
||||
@@ -372,13 +382,8 @@ def evaluate(model, tokenizer, dataset, args: BenchmarkConfig, split: str) -> di
|
||||
|
||||
@torch.no_grad()
|
||||
def probe_before_train(model, batch: dict[str, torch.Tensor | int], attached_targets: list[str]) -> dict[str, Any]:
|
||||
expected_targets = {
|
||||
"model.layers.0.self_attn.q_proj",
|
||||
"model.layers.0.self_attn.v_proj",
|
||||
}
|
||||
attached_set = set(attached_targets)
|
||||
if attached_set != expected_targets:
|
||||
raise AssertionError(f"probe expected layer-0 q/v only, got {sorted(attached_set)}")
|
||||
if not attached_targets:
|
||||
raise AssertionError("probe: no targets attached")
|
||||
logits_init = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]).logits.detach().clone()
|
||||
clean_adapter = adapter_state(model)
|
||||
perturb_first_adapter(model)
|
||||
@@ -387,7 +392,7 @@ def probe_before_train(model, batch: dict[str, torch.Tensor | int], attached_tar
|
||||
raise AssertionError(f"adapter perturbation did not affect logits: {perturb_delta}")
|
||||
for name, value in clean_adapter.items():
|
||||
model.state_dict()[name].copy_(value)
|
||||
return {"expected_targets": sorted(expected_targets), "perturb_delta": perturb_delta}
|
||||
return {"attached_targets": sorted(attached_targets), "perturb_delta": perturb_delta}
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -401,7 +406,7 @@ def check_probe_reload(
|
||||
del cfg # cfg is saved in the checkpoint; keep the call-site explicit.
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
loaded_model, _ = load_model_and_tokenizer(args.model, getattr(torch, args.torch_dtype), args.device)
|
||||
loaded_model, _ = load_model_and_tokenizer(args.model, getattr(torch, args.torch_dtype), args.device, args.quantization)
|
||||
loaded_model.eval()
|
||||
ll.load(loaded_model, str(adapter_path))
|
||||
saved = torch.load(adapter_path, weights_only=True, map_location="cpu")
|
||||
@@ -489,7 +494,7 @@ def run(args: BenchmarkConfig) -> dict[str, Any]:
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
datasets = load_datasets(args)
|
||||
model, tokenizer = load_model_and_tokenizer(args.model, dtype, args.device)
|
||||
model, tokenizer = load_model_and_tokenizer(args.model, dtype, args.device, args.quantization)
|
||||
batches, skipped_train_prompt_too_long = make_train_batches(datasets["train"], tokenizer, args)
|
||||
cfg = cfg_for_variant(args, dtype)
|
||||
if args.variant == "eva":
|
||||
|
||||
Reference in New Issue
Block a user