simpler test

This commit is contained in:
wassname
2026-04-27 09:47:07 +08:00
parent b60a8c3f9b
commit 24ba8deb02
10 changed files with 566 additions and 41 deletions
+41 -28
View File
@@ -33,6 +33,7 @@ CFG_BY_VARIANT = {
"hra": ll.HRAConfig,
"eva": ll.EVAConfig,
"antipasto": ll.AntiPaSTOConfig,
"road": ll.RoadConfig,
}
@@ -41,7 +42,7 @@ class BenchmarkConfig:
"""MetaMathQA -> GSM8K benchmark config. Tyro turns this into the CLI."""
model: str = "Qwen/Qwen3-0.6B-Base"
variant: Literal["lora", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva", "antipasto"] = "lora"
variant: Literal["lora", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva", "antipasto", "road"] = "lora"
mode: Literal["benchmark", "probe"] = "benchmark"
device: str = "cuda"
torch_dtype: str = "bfloat16"
@@ -49,6 +50,7 @@ class BenchmarkConfig:
r: int = 32
alpha: float = 64.0
delora_lambda0: float = 0.1
road_group_size: int = 64
target_name: list[str] = field(default_factory=lambda: list(DEFAULT_TARGETS))
layers: str = "all"
train_dataset: str = "meta-math/MetaMathQA"
@@ -118,6 +120,8 @@ def parse_layers(text: str) -> tuple[int, ...] | None:
def cfg_for_variant(args: BenchmarkConfig, dtype: torch.dtype) -> ll.AdapterConfig:
extra = {"lambda0": args.delora_lambda0} if args.variant == "delora" else {}
if args.variant == "road":
extra = {"group_size": args.road_group_size}
return CFG_BY_VARIANT[args.variant](
r=args.r,
alpha=args.r if args.variant == "pissa" else args.alpha,
@@ -147,7 +151,7 @@ def count_base_grad_leaks(model: torch.nn.Module) -> int:
def perturb_first_adapter(model: torch.nn.Module) -> None:
priority = ("lora_B", "lora_g", "lora_U", "lora_A", "lora_lambda", "lora_gate", "lora_delta_s", "lora_rot_T", "lora_m")
priority = ("lora_B", "lora_g", "lora_U", "lora_A", "lora_lambda", "lora_gate", "lora_delta_s", "lora_rot_T", "lora_m", "lora_road_theta", "lora_road_alpha")
for key in priority:
for _, p in model.named_parameters():
if p.requires_grad and key in _:
@@ -409,11 +413,12 @@ def check_probe_reload(
loaded_model, _ = load_model_and_tokenizer(args.model, getattr(torch, args.torch_dtype), args.device, args.quantization)
loaded_model.eval()
ll.load(loaded_model, str(adapter_path))
saved = torch.load(adapter_path, weights_only=True, map_location="cpu")
from safetensors.torch import load_file
saved_sd = load_file(str(adapter_path), device="cpu")
loaded_state = adapter_state(loaded_model)
if set(saved["state"]) != set(loaded_state):
if set(saved_sd) != set(loaded_state):
raise AssertionError("loaded adapter keys differ from saved adapter keys")
for name, value in saved["state"].items():
for name, value in saved_sd.items():
if not torch.equal(loaded_state[name].cpu(), value):
raise AssertionError(f"loaded adapter tensor differs: {name}")
logits_loaded = loaded_model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]).logits.detach().clone()
@@ -426,12 +431,21 @@ def check_probe_reload(
return {"reload_err": reload_err, "saved_tensors": len(saved["state"])}
def print_final_report(row: dict[str, Any], result_path: Path) -> None:
print("SHOULD: grad>0, dθ>0, base_grad_leaks=0, valid/test fields present; probeΔ<0 is good but not required for tiny random smoke. ELSE adapter or eval wiring is dead/wrong.")
def print_final_report(row: dict[str, Any], result_path: Path, mode: str) -> None:
# BLUF: status line first so log tails are immediately readable
cue = "🟢" if row.get("base_grad_leaks", 0) == 0 and row.get("grad", 0) > 0 else "🔴"
n = row.get("samples", "?")
print(f"{cue} test_acc={row['test_acc']:.4g} valid_acc={row['valid_acc']:.4g} grad={row['grad']:.3g} dθ={row['']:.3g} base_grad_leaks={row['base_grad_leaks']} N={n}")
print("SHOULD: grad>0, dθ>0, base_grad_leaks=0; test/valid_acc meaningful only in benchmark mode. ELSE adapter or eval wiring is dead/wrong.")
# ordered: most important / shortest columns first
display_keys = ["variant", "test_acc", "valid_acc", "grad", "", "base_grad_leaks", "steps", "samples", "loss0", "lossN", "commit"]
if "perturb" in row:
display_keys += ["perturb", "reload"]
display_keys += ["run_id"]
display_row = {k: row[k] for k in display_keys if k in row}
print(tabulate([display_row], headers="keys", tablefmt="tsv", floatfmt=".4g"))
print(f"argv: {' '.join(sys.argv)} N={n} mode={mode}")
print(f"out: {result_path}")
print(f"argv: {' '.join(sys.argv)}")
print(f"main metric: test_acc={row['test_acc']:.4g} valid_acc={row['valid_acc']:.4g} steps={row['steps']} samples={row['samples']}")
print(tabulate([row], headers="keys", tablefmt="tsv", floatfmt=".4g"))
def current_git_commit() -> str:
@@ -447,25 +461,24 @@ def append_results_row(
result: dict[str, Any],
run_commit: str,
) -> tuple[Path, Path]:
results_dir = args.output_dir / "results"
results_dir = args.output_dir
results_dir.mkdir(parents=True, exist_ok=True)
tsv_path = results_dir / "benchmark_results.tsv"
lock_path = results_dir / "benchmark_results.tsv.lock"
tsv_path = results_dir / "summary.tsv"
lock_path = results_dir / "summary.tsv.lock"
finished_at = datetime.now(timezone.utc).isoformat(timespec="seconds")
finished_label = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
snapshot_path = results_dir / f"{result['run_id']}__{finished_label}.json"
snapshot_path.write_text(json.dumps(result, indent=2), encoding="utf-8")
row = {
"time_utc": finished_at,
"commit": run_commit,
"test_acc": result["test_acc"],
"valid_acc": result["valid_acc"],
"method": args.variant,
"model": args.model,
"mode": args.mode,
"valid_accuracy": result["valid_accuracy"],
"test_accuracy": result["test_accuracy"],
"steps": args.steps,
"samples": result["train_samples"],
"wall_time_s": result["wall_time_s"],
"model": args.model,
"commit": run_commit[:12],
"wall_time_s": round(result["wall_time_s"]),
"time_utc": finished_at,
"argv": " ".join(sys.argv),
"result_json": str(snapshot_path),
"latest_result_json": str(result_path),
@@ -517,7 +530,7 @@ def run(args: BenchmarkConfig) -> dict[str, Any]:
valid_metrics = evaluate(model, tokenizer, datasets["valid"], args, "valid")
test_metrics = evaluate(model, tokenizer, datasets["test"], args, "test")
adapter_path = out_dir / "adapter.pt"
adapter_path = out_dir / "adapter.safetensors"
ll.save(model, str(adapter_path))
if args.mode == "probe":
model.eval()
@@ -555,8 +568,8 @@ def run(args: BenchmarkConfig) -> dict[str, Any]:
"weight_decay": args.weight_decay,
"lr_scheduler": "cosine",
"grad_norm_clip": args.grad_norm_clip,
"valid_accuracy": valid_metrics["accuracy"],
"test_accuracy": test_metrics["accuracy"],
"valid_acc": valid_metrics["accuracy"],
"test_acc": test_metrics["accuracy"],
"train": train_metrics,
"valid": valid_metrics,
"test": test_metrics,
@@ -566,12 +579,12 @@ def run(args: BenchmarkConfig) -> dict[str, Any]:
}
result_path = out_dir / "result.json"
result_path.write_text(json.dumps(result, indent=2), encoding="utf-8")
results_tsv_path, result_snapshot_path = append_results_row(args, result_path, result, run_commit)
result["results_tsv_path"] = str(results_tsv_path)
result["result_snapshot_path"] = str(result_snapshot_path)
if args.mode == "benchmark":
results_tsv_path, result_snapshot_path = append_results_row(args, result_path, result, run_commit)
result["results_tsv_path"] = str(results_tsv_path)
result["result_snapshot_path"] = str(result_snapshot_path)
result["commit"] = run_commit
result_path.write_text(json.dumps(result, indent=2), encoding="utf-8")
commit_prefix = run_commit[:12]
row = {
"run_id": run_id,
@@ -592,7 +605,7 @@ def run(args: BenchmarkConfig) -> dict[str, Any]:
if probe_metrics is not None:
row["perturb"] = probe_metrics["perturb_delta"]
row["reload"] = probe_metrics["reload_err"]
print_final_report(row, result_path)
print_final_report(row, result_path, args.mode)
return result