tidy tests to subset of metamath

This commit is contained in:
wassname
2026-04-27 09:20:07 +08:00
parent 1a93df10b2
commit 727ef6ea73
6 changed files with 145 additions and 1211 deletions
+18 -13
View File
@@ -45,6 +45,7 @@ class BenchmarkConfig:
mode: Literal["benchmark", "probe"] = "benchmark"
device: str = "cuda"
torch_dtype: str = "bfloat16"
quantization: Literal["none", "4bit", "8bit"] = "none"
r: int = 32
alpha: float = 64.0
delora_lambda0: float = 0.1
@@ -146,7 +147,7 @@ def count_base_grad_leaks(model: torch.nn.Module) -> int:
def perturb_first_adapter(model: torch.nn.Module) -> None:
priority = ("lora_B", "lora_g", "lora_U", "lora_A", "lora_lambda", "lora_gate")
priority = ("lora_B", "lora_g", "lora_U", "lora_A", "lora_lambda", "lora_gate", "lora_delta_s", "lora_rot_T", "lora_m")
for key in priority:
for _, p in model.named_parameters():
if p.requires_grad and key in _:
@@ -159,7 +160,7 @@ def perturb_first_adapter(model: torch.nn.Module) -> None:
raise AssertionError("no perturbable adapter parameter found")
def load_model_and_tokenizer(model_id: str, dtype: torch.dtype, device: str):
def load_model_and_tokenizer(model_id: str, dtype: torch.dtype, device: str, quantization: str = "none"):
from transformers import AutoModelForCausalLM, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)
@@ -167,7 +168,16 @@ def load_model_and_tokenizer(model_id: str, dtype: torch.dtype, device: str):
raise RuntimeError(f"tokenizer for {model_id} has no eos_token")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
model = AutoModelForCausalLM.from_pretrained(model_id, dtype=dtype).to(device)
if quantization == "none":
model = AutoModelForCausalLM.from_pretrained(model_id, dtype=dtype).to(device)
else:
from transformers import BitsAndBytesConfig
bnb_cfg = BitsAndBytesConfig(
load_in_4bit=quantization == "4bit",
load_in_8bit=quantization == "8bit",
bnb_4bit_compute_dtype=dtype if quantization == "4bit" else None,
)
model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_cfg, device_map=device)
model.config.use_cache = False
return model, tokenizer
@@ -372,13 +382,8 @@ def evaluate(model, tokenizer, dataset, args: BenchmarkConfig, split: str) -> di
@torch.no_grad()
def probe_before_train(model, batch: dict[str, torch.Tensor | int], attached_targets: list[str]) -> dict[str, Any]:
expected_targets = {
"model.layers.0.self_attn.q_proj",
"model.layers.0.self_attn.v_proj",
}
attached_set = set(attached_targets)
if attached_set != expected_targets:
raise AssertionError(f"probe expected layer-0 q/v only, got {sorted(attached_set)}")
if not attached_targets:
raise AssertionError("probe: no targets attached")
logits_init = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]).logits.detach().clone()
clean_adapter = adapter_state(model)
perturb_first_adapter(model)
@@ -387,7 +392,7 @@ def probe_before_train(model, batch: dict[str, torch.Tensor | int], attached_tar
raise AssertionError(f"adapter perturbation did not affect logits: {perturb_delta}")
for name, value in clean_adapter.items():
model.state_dict()[name].copy_(value)
return {"expected_targets": sorted(expected_targets), "perturb_delta": perturb_delta}
return {"attached_targets": sorted(attached_targets), "perturb_delta": perturb_delta}
@torch.no_grad()
@@ -401,7 +406,7 @@ def check_probe_reload(
del cfg # cfg is saved in the checkpoint; keep the call-site explicit.
gc.collect()
torch.cuda.empty_cache()
loaded_model, _ = load_model_and_tokenizer(args.model, getattr(torch, args.torch_dtype), args.device)
loaded_model, _ = load_model_and_tokenizer(args.model, getattr(torch, args.torch_dtype), args.device, args.quantization)
loaded_model.eval()
ll.load(loaded_model, str(adapter_path))
saved = torch.load(adapter_path, weights_only=True, map_location="cpu")
@@ -489,7 +494,7 @@ def run(args: BenchmarkConfig) -> dict[str, Any]:
out_dir.mkdir(parents=True, exist_ok=True)
datasets = load_datasets(args)
model, tokenizer = load_model_and_tokenizer(args.model, dtype, args.device)
model, tokenizer = load_model_and_tokenizer(args.model, dtype, args.device, args.quantization)
batches, skipped_train_prompt_too_long = make_train_batches(datasets["train"], tokenizer, args)
cfg = cfg_for_variant(args, dtype)
if args.variant == "eva":