mirror of
https://github.com/wassname/lora-lite.git
synced 2026-06-27 15:15:55 +08:00
wip
This commit is contained in:
@@ -47,16 +47,16 @@ just qwen-probe # Qwen/Qwen3-0.6B train/save-load probe
|
||||
|
||||
## Variants
|
||||
|
||||
| Variant | 4bit/8bit | MetaMath acc (GSM8K %) | Notes |
|
||||
|---|---|---|---|
|
||||
| [LoRA](https://arxiv.org/abs/2106.09685) | yes | 63.2% | |
|
||||
| [PiSSA](https://arxiv.org/abs/2404.02948) | no (edits weight) | — | |
|
||||
| [DeLoRA](https://arxiv.org/abs/2503.18225) | yes | — | |
|
||||
| [IA3](https://arxiv.org/pdf/2205.05638) | yes | — | output gate (ia3) or input gate (ia3_ff) |
|
||||
| [DoRA](https://arxiv.org/abs/2402.09353) | no (reads weight) | — | |
|
||||
| [HRA](https://arxiv.org/abs/2409.01434) | yes | — | input-side Householder; works on bnb |
|
||||
| [EVA](https://arxiv.org/abs/2409.07871) | no (calibration SVD) | — | |
|
||||
| [AntiPaSTO](https://arxiv.org/abs/2503.08696) | no (reads weight SVD) | — | |
|
||||
| Variant | 4bit/8bit | GSM8K % |
|
||||
| --------------------------------------------- | --------- | ------- |
|
||||
| [LoRA](https://arxiv.org/abs/2106.09685) | yes | 63.2% |
|
||||
| [PiSSA](https://arxiv.org/abs/2404.02948) | no | — |
|
||||
| [DeLoRA](https://arxiv.org/abs/2503.18225) | yes | — |
|
||||
| [IA3](https://arxiv.org/pdf/2205.05638) | yes | — |
|
||||
| [DoRA](https://arxiv.org/abs/2402.09353) | no | — |
|
||||
| [HRA](https://arxiv.org/abs/2409.01434) | yes | — |
|
||||
| [EVA](https://arxiv.org/abs/2409.07871) | no | — |
|
||||
| [AntiPaSTO](https://arxiv.org/abs/2503.08696) | no | — |
|
||||
|
||||
Our test setup: We take Qwen3-0.6B-Base and train one MetaMathQA for 5000 steps. We use a rank of 32, and itnervene on all linear layer then test on GSM8K.
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ default:
|
||||
check: test smoke build
|
||||
|
||||
test:
|
||||
uv run --extra test pytest -q
|
||||
uv run --extra test --extra benchmark pytest -q
|
||||
|
||||
smoke:
|
||||
uv run --extra test --extra benchmark pytest -q tests/test_metamath_smoke.py -k test_metamath_quick_train_save_load
|
||||
@@ -77,29 +77,8 @@ metamath-queue-all model="Qwen/Qwen3-0.6B-Base" steps="5000" variants="lora piss
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
for variant in {{variants}}; do
|
||||
lr=1e-4
|
||||
extra_args=(--target-name '(q_proj|v_proj)$' --layers all --r 32 --alpha 64)
|
||||
case "$variant" in
|
||||
delora)
|
||||
lr=1e-3
|
||||
;;
|
||||
ia3)
|
||||
lr=1e-3
|
||||
extra_args=(--target-name '(k_proj|v_proj)$' --layers all --r 32 --alpha 64)
|
||||
;;
|
||||
ia3_ff)
|
||||
lr=1e-3
|
||||
extra_args=(--target-name '(down_proj)$' --layers all --r 32 --alpha 64)
|
||||
;;
|
||||
eva)
|
||||
lr=1e-4
|
||||
;;
|
||||
antipasto)
|
||||
lr=1e-4
|
||||
;;
|
||||
esac
|
||||
pueue add \
|
||||
-l "why: benchmark {{model}} ${variant} on MetaMathQA->GSM8K at {{steps}} steps; resolve: outputs/metamath_gsm8k/results/benchmark_results.tsv gets a row with accuracy commit time method argv and result JSON for ${variant}" \
|
||||
-w "$PWD" -o 1 -- \
|
||||
bash -c "uv run --extra benchmark python scripts/metamath_gsm8k_benchmark.py --model {{model}} --variant $variant --steps {{steps}} --lr $lr $(printf '%q ' "${extra_args[@]}")"
|
||||
bash scripts/bench_variant.sh '{{model}}' "$variant" {{steps}}
|
||||
done
|
||||
@@ -6,6 +6,7 @@ readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
"torch>=2.1",
|
||||
"numpy>=1.26",
|
||||
"einops>=0.7",
|
||||
"jaxtyping>=0.2.34",
|
||||
"safetensors>=0.5",
|
||||
|
||||
@@ -428,7 +428,7 @@ def check_probe_reload(
|
||||
del loaded_model
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
return {"reload_err": reload_err, "saved_tensors": len(saved["state"])}
|
||||
return {"reload_err": reload_err, "saved_tensors": len(saved_sd)}
|
||||
|
||||
|
||||
def print_final_report(row: dict[str, Any], result_path: Path, mode: str) -> None:
|
||||
|
||||
@@ -12,7 +12,7 @@ resolution-markers = [
|
||||
]
|
||||
|
||||
[options]
|
||||
exclude-newer = "2026-04-22T01:31:09.47902161Z"
|
||||
exclude-newer = "2026-04-22T02:16:12.377169857Z"
|
||||
exclude-newer-span = "P5D"
|
||||
|
||||
[[package]]
|
||||
@@ -1012,6 +1012,8 @@ dependencies = [
|
||||
{ name = "einops" },
|
||||
{ name = "jaxtyping", version = "0.3.7", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
|
||||
{ name = "jaxtyping", version = "0.3.9", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" },
|
||||
{ name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
|
||||
{ name = "numpy", version = "2.4.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" },
|
||||
{ name = "safetensors" },
|
||||
{ name = "torch" },
|
||||
]
|
||||
@@ -1051,6 +1053,7 @@ requires-dist = [
|
||||
{ name = "datasets", marker = "extra == 'benchmark'", specifier = ">=3.6" },
|
||||
{ name = "einops", specifier = ">=0.7" },
|
||||
{ name = "jaxtyping", specifier = ">=0.2.34" },
|
||||
{ name = "numpy", specifier = ">=1.26" },
|
||||
{ name = "pytest", marker = "extra == 'test'" },
|
||||
{ name = "safetensors", specifier = ">=0.5" },
|
||||
{ name = "safetensors", marker = "extra == 'benchmark'", specifier = ">=0.5" },
|
||||
|
||||
Reference in New Issue
Block a user