This commit is contained in:
wassname
2026-04-27 11:24:19 +08:00
parent 24ba8deb02
commit a342801807
5 changed files with 18 additions and 35 deletions
+10 -10
View File
@@ -47,16 +47,16 @@ just qwen-probe # Qwen/Qwen3-0.6B train/save-load probe
## Variants
| Variant | 4bit/8bit | MetaMath acc (GSM8K %) | Notes |
|---|---|---|---|
| [LoRA](https://arxiv.org/abs/2106.09685) | yes | 63.2% | |
| [PiSSA](https://arxiv.org/abs/2404.02948) | no (edits weight) | — | |
| [DeLoRA](https://arxiv.org/abs/2503.18225) | yes | — | |
| [IA3](https://arxiv.org/pdf/2205.05638) | yes | — | output gate (ia3) or input gate (ia3_ff) |
| [DoRA](https://arxiv.org/abs/2402.09353) | no (reads weight) | — | |
| [HRA](https://arxiv.org/abs/2409.01434) | yes | — | input-side Householder; works on bnb |
| [EVA](https://arxiv.org/abs/2409.07871) | no (calibration SVD) | — | |
| [AntiPaSTO](https://arxiv.org/abs/2503.08696) | no (reads weight SVD) | — | |
| Variant | 4bit/8bit | GSM8K % |
| --------------------------------------------- | --------- | ------- |
| [LoRA](https://arxiv.org/abs/2106.09685) | yes | 63.2% |
| [PiSSA](https://arxiv.org/abs/2404.02948) | no | — |
| [DeLoRA](https://arxiv.org/abs/2503.18225) | yes | — |
| [IA3](https://arxiv.org/pdf/2205.05638) | yes | — |
| [DoRA](https://arxiv.org/abs/2402.09353) | no | — |
| [HRA](https://arxiv.org/abs/2409.01434) | yes | — |
| [EVA](https://arxiv.org/abs/2409.07871) | no | — |
| [AntiPaSTO](https://arxiv.org/abs/2503.08696) | no | — |
Our test setup: We take Qwen3-0.6B-Base and train one MetaMathQA for 5000 steps. We use a rank of 32, and itnervene on all linear layer then test on GSM8K.
+2 -23
View File
@@ -6,7 +6,7 @@ default:
check: test smoke build
test:
uv run --extra test pytest -q
uv run --extra test --extra benchmark pytest -q
smoke:
uv run --extra test --extra benchmark pytest -q tests/test_metamath_smoke.py -k test_metamath_quick_train_save_load
@@ -77,29 +77,8 @@ metamath-queue-all model="Qwen/Qwen3-0.6B-Base" steps="5000" variants="lora piss
#!/usr/bin/env bash
set -euo pipefail
for variant in {{variants}}; do
lr=1e-4
extra_args=(--target-name '(q_proj|v_proj)$' --layers all --r 32 --alpha 64)
case "$variant" in
delora)
lr=1e-3
;;
ia3)
lr=1e-3
extra_args=(--target-name '(k_proj|v_proj)$' --layers all --r 32 --alpha 64)
;;
ia3_ff)
lr=1e-3
extra_args=(--target-name '(down_proj)$' --layers all --r 32 --alpha 64)
;;
eva)
lr=1e-4
;;
antipasto)
lr=1e-4
;;
esac
pueue add \
-l "why: benchmark {{model}} ${variant} on MetaMathQA->GSM8K at {{steps}} steps; resolve: outputs/metamath_gsm8k/results/benchmark_results.tsv gets a row with accuracy commit time method argv and result JSON for ${variant}" \
-w "$PWD" -o 1 -- \
bash -c "uv run --extra benchmark python scripts/metamath_gsm8k_benchmark.py --model {{model}} --variant $variant --steps {{steps}} --lr $lr $(printf '%q ' "${extra_args[@]}")"
bash scripts/bench_variant.sh '{{model}}' "$variant" {{steps}}
done
+1
View File
@@ -6,6 +6,7 @@ readme = "README.md"
requires-python = ">=3.10"
dependencies = [
"torch>=2.1",
"numpy>=1.26",
"einops>=0.7",
"jaxtyping>=0.2.34",
"safetensors>=0.5",
+1 -1
View File
@@ -428,7 +428,7 @@ def check_probe_reload(
del loaded_model
gc.collect()
torch.cuda.empty_cache()
return {"reload_err": reload_err, "saved_tensors": len(saved["state"])}
return {"reload_err": reload_err, "saved_tensors": len(saved_sd)}
def print_final_report(row: dict[str, Any], result_path: Path, mode: str) -> None:
Generated
+4 -1
View File
@@ -12,7 +12,7 @@ resolution-markers = [
]
[options]
exclude-newer = "2026-04-22T01:31:09.47902161Z"
exclude-newer = "2026-04-22T02:16:12.377169857Z"
exclude-newer-span = "P5D"
[[package]]
@@ -1012,6 +1012,8 @@ dependencies = [
{ name = "einops" },
{ name = "jaxtyping", version = "0.3.7", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
{ name = "jaxtyping", version = "0.3.9", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" },
{ name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
{ name = "numpy", version = "2.4.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" },
{ name = "safetensors" },
{ name = "torch" },
]
@@ -1051,6 +1053,7 @@ requires-dist = [
{ name = "datasets", marker = "extra == 'benchmark'", specifier = ">=3.6" },
{ name = "einops", specifier = ">=0.7" },
{ name = "jaxtyping", specifier = ">=0.2.34" },
{ name = "numpy", specifier = ">=1.26" },
{ name = "pytest", marker = "extra == 'test'" },
{ name = "safetensors", specifier = ">=0.5" },
{ name = "safetensors", marker = "extra == 'benchmark'", specifier = ">=0.5" },