From a342801807aa838726de3c920641fda735ebf21b Mon Sep 17 00:00:00 2001 From: wassname <1103714+wassname@users.noreply.github.com> Date: Mon, 27 Apr 2026 11:24:19 +0800 Subject: [PATCH] wip --- README.md | 20 ++++++++++---------- justfile | 25 ++----------------------- pyproject.toml | 1 + scripts/metamath_gsm8k_benchmark.py | 2 +- uv.lock | 5 ++++- 5 files changed, 18 insertions(+), 35 deletions(-) diff --git a/README.md b/README.md index c69fe99..8410d57 100644 --- a/README.md +++ b/README.md @@ -47,16 +47,16 @@ just qwen-probe # Qwen/Qwen3-0.6B train/save-load probe ## Variants -| Variant | 4bit/8bit | MetaMath acc (GSM8K %) | Notes | -|---|---|---|---| -| [LoRA](https://arxiv.org/abs/2106.09685) | yes | 63.2% | | -| [PiSSA](https://arxiv.org/abs/2404.02948) | no (edits weight) | — | | -| [DeLoRA](https://arxiv.org/abs/2503.18225) | yes | — | | -| [IA3](https://arxiv.org/pdf/2205.05638) | yes | — | output gate (ia3) or input gate (ia3_ff) | -| [DoRA](https://arxiv.org/abs/2402.09353) | no (reads weight) | — | | -| [HRA](https://arxiv.org/abs/2409.01434) | yes | — | input-side Householder; works on bnb | -| [EVA](https://arxiv.org/abs/2409.07871) | no (calibration SVD) | — | | -| [AntiPaSTO](https://arxiv.org/abs/2503.08696) | no (reads weight SVD) | — | | +| Variant | 4bit/8bit | GSM8K % | +| --------------------------------------------- | --------- | ------- | +| [LoRA](https://arxiv.org/abs/2106.09685) | yes | 63.2% | +| [PiSSA](https://arxiv.org/abs/2404.02948) | no | — | +| [DeLoRA](https://arxiv.org/abs/2503.18225) | yes | — | +| [IA3](https://arxiv.org/pdf/2205.05638) | yes | — | +| [DoRA](https://arxiv.org/abs/2402.09353) | no | — | +| [HRA](https://arxiv.org/abs/2409.01434) | yes | — | +| [EVA](https://arxiv.org/abs/2409.07871) | no | — | +| [AntiPaSTO](https://arxiv.org/abs/2503.08696) | no | — | Our test setup: We take Qwen3-0.6B-Base and train one MetaMathQA for 5000 steps. We use a rank of 32, and itnervene on all linear layer then test on GSM8K. diff --git a/justfile b/justfile index 25bbaf6..5ab18ee 100644 --- a/justfile +++ b/justfile @@ -6,7 +6,7 @@ default: check: test smoke build test: - uv run --extra test pytest -q + uv run --extra test --extra benchmark pytest -q smoke: uv run --extra test --extra benchmark pytest -q tests/test_metamath_smoke.py -k test_metamath_quick_train_save_load @@ -77,29 +77,8 @@ metamath-queue-all model="Qwen/Qwen3-0.6B-Base" steps="5000" variants="lora piss #!/usr/bin/env bash set -euo pipefail for variant in {{variants}}; do - lr=1e-4 - extra_args=(--target-name '(q_proj|v_proj)$' --layers all --r 32 --alpha 64) - case "$variant" in - delora) - lr=1e-3 - ;; - ia3) - lr=1e-3 - extra_args=(--target-name '(k_proj|v_proj)$' --layers all --r 32 --alpha 64) - ;; - ia3_ff) - lr=1e-3 - extra_args=(--target-name '(down_proj)$' --layers all --r 32 --alpha 64) - ;; - eva) - lr=1e-4 - ;; - antipasto) - lr=1e-4 - ;; - esac pueue add \ -l "why: benchmark {{model}} ${variant} on MetaMathQA->GSM8K at {{steps}} steps; resolve: outputs/metamath_gsm8k/results/benchmark_results.tsv gets a row with accuracy commit time method argv and result JSON for ${variant}" \ -w "$PWD" -o 1 -- \ - bash -c "uv run --extra benchmark python scripts/metamath_gsm8k_benchmark.py --model {{model}} --variant $variant --steps {{steps}} --lr $lr $(printf '%q ' "${extra_args[@]}")" + bash scripts/bench_variant.sh '{{model}}' "$variant" {{steps}} done \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index dc6fdbc..5ed5f30 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,6 +6,7 @@ readme = "README.md" requires-python = ">=3.10" dependencies = [ "torch>=2.1", + "numpy>=1.26", "einops>=0.7", "jaxtyping>=0.2.34", "safetensors>=0.5", diff --git a/scripts/metamath_gsm8k_benchmark.py b/scripts/metamath_gsm8k_benchmark.py index 3244e73..6d41437 100644 --- a/scripts/metamath_gsm8k_benchmark.py +++ b/scripts/metamath_gsm8k_benchmark.py @@ -428,7 +428,7 @@ def check_probe_reload( del loaded_model gc.collect() torch.cuda.empty_cache() - return {"reload_err": reload_err, "saved_tensors": len(saved["state"])} + return {"reload_err": reload_err, "saved_tensors": len(saved_sd)} def print_final_report(row: dict[str, Any], result_path: Path, mode: str) -> None: diff --git a/uv.lock b/uv.lock index 55dedd2..5762346 100644 --- a/uv.lock +++ b/uv.lock @@ -12,7 +12,7 @@ resolution-markers = [ ] [options] -exclude-newer = "2026-04-22T01:31:09.47902161Z" +exclude-newer = "2026-04-22T02:16:12.377169857Z" exclude-newer-span = "P5D" [[package]] @@ -1012,6 +1012,8 @@ dependencies = [ { name = "einops" }, { name = "jaxtyping", version = "0.3.7", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "jaxtyping", version = "0.3.9", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "numpy", version = "2.4.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "safetensors" }, { name = "torch" }, ] @@ -1051,6 +1053,7 @@ requires-dist = [ { name = "datasets", marker = "extra == 'benchmark'", specifier = ">=3.6" }, { name = "einops", specifier = ">=0.7" }, { name = "jaxtyping", specifier = ">=0.2.34" }, + { name = "numpy", specifier = ">=1.26" }, { name = "pytest", marker = "extra == 'test'" }, { name = "safetensors", specifier = ">=0.5" }, { name = "safetensors", marker = "extra == 'benchmark'", specifier = ">=0.5" },