mirror of
https://github.com/wassname/lora-lite.git
synced 2026-06-27 15:15:55 +08:00
wip
This commit is contained in:
@@ -47,16 +47,16 @@ just qwen-probe # Qwen/Qwen3-0.6B train/save-load probe
|
|||||||
|
|
||||||
## Variants
|
## Variants
|
||||||
|
|
||||||
| Variant | 4bit/8bit | MetaMath acc (GSM8K %) | Notes |
|
| Variant | 4bit/8bit | GSM8K % |
|
||||||
|---|---|---|---|
|
| --------------------------------------------- | --------- | ------- |
|
||||||
| [LoRA](https://arxiv.org/abs/2106.09685) | yes | 63.2% | |
|
| [LoRA](https://arxiv.org/abs/2106.09685) | yes | 63.2% |
|
||||||
| [PiSSA](https://arxiv.org/abs/2404.02948) | no (edits weight) | — | |
|
| [PiSSA](https://arxiv.org/abs/2404.02948) | no | — |
|
||||||
| [DeLoRA](https://arxiv.org/abs/2503.18225) | yes | — | |
|
| [DeLoRA](https://arxiv.org/abs/2503.18225) | yes | — |
|
||||||
| [IA3](https://arxiv.org/pdf/2205.05638) | yes | — | output gate (ia3) or input gate (ia3_ff) |
|
| [IA3](https://arxiv.org/pdf/2205.05638) | yes | — |
|
||||||
| [DoRA](https://arxiv.org/abs/2402.09353) | no (reads weight) | — | |
|
| [DoRA](https://arxiv.org/abs/2402.09353) | no | — |
|
||||||
| [HRA](https://arxiv.org/abs/2409.01434) | yes | — | input-side Householder; works on bnb |
|
| [HRA](https://arxiv.org/abs/2409.01434) | yes | — |
|
||||||
| [EVA](https://arxiv.org/abs/2409.07871) | no (calibration SVD) | — | |
|
| [EVA](https://arxiv.org/abs/2409.07871) | no | — |
|
||||||
| [AntiPaSTO](https://arxiv.org/abs/2503.08696) | no (reads weight SVD) | — | |
|
| [AntiPaSTO](https://arxiv.org/abs/2503.08696) | no | — |
|
||||||
|
|
||||||
Our test setup: We take Qwen3-0.6B-Base and train one MetaMathQA for 5000 steps. We use a rank of 32, and itnervene on all linear layer then test on GSM8K.
|
Our test setup: We take Qwen3-0.6B-Base and train one MetaMathQA for 5000 steps. We use a rank of 32, and itnervene on all linear layer then test on GSM8K.
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ default:
|
|||||||
check: test smoke build
|
check: test smoke build
|
||||||
|
|
||||||
test:
|
test:
|
||||||
uv run --extra test pytest -q
|
uv run --extra test --extra benchmark pytest -q
|
||||||
|
|
||||||
smoke:
|
smoke:
|
||||||
uv run --extra test --extra benchmark pytest -q tests/test_metamath_smoke.py -k test_metamath_quick_train_save_load
|
uv run --extra test --extra benchmark pytest -q tests/test_metamath_smoke.py -k test_metamath_quick_train_save_load
|
||||||
@@ -77,29 +77,8 @@ metamath-queue-all model="Qwen/Qwen3-0.6B-Base" steps="5000" variants="lora piss
|
|||||||
#!/usr/bin/env bash
|
#!/usr/bin/env bash
|
||||||
set -euo pipefail
|
set -euo pipefail
|
||||||
for variant in {{variants}}; do
|
for variant in {{variants}}; do
|
||||||
lr=1e-4
|
|
||||||
extra_args=(--target-name '(q_proj|v_proj)$' --layers all --r 32 --alpha 64)
|
|
||||||
case "$variant" in
|
|
||||||
delora)
|
|
||||||
lr=1e-3
|
|
||||||
;;
|
|
||||||
ia3)
|
|
||||||
lr=1e-3
|
|
||||||
extra_args=(--target-name '(k_proj|v_proj)$' --layers all --r 32 --alpha 64)
|
|
||||||
;;
|
|
||||||
ia3_ff)
|
|
||||||
lr=1e-3
|
|
||||||
extra_args=(--target-name '(down_proj)$' --layers all --r 32 --alpha 64)
|
|
||||||
;;
|
|
||||||
eva)
|
|
||||||
lr=1e-4
|
|
||||||
;;
|
|
||||||
antipasto)
|
|
||||||
lr=1e-4
|
|
||||||
;;
|
|
||||||
esac
|
|
||||||
pueue add \
|
pueue add \
|
||||||
-l "why: benchmark {{model}} ${variant} on MetaMathQA->GSM8K at {{steps}} steps; resolve: outputs/metamath_gsm8k/results/benchmark_results.tsv gets a row with accuracy commit time method argv and result JSON for ${variant}" \
|
-l "why: benchmark {{model}} ${variant} on MetaMathQA->GSM8K at {{steps}} steps; resolve: outputs/metamath_gsm8k/results/benchmark_results.tsv gets a row with accuracy commit time method argv and result JSON for ${variant}" \
|
||||||
-w "$PWD" -o 1 -- \
|
-w "$PWD" -o 1 -- \
|
||||||
bash -c "uv run --extra benchmark python scripts/metamath_gsm8k_benchmark.py --model {{model}} --variant $variant --steps {{steps}} --lr $lr $(printf '%q ' "${extra_args[@]}")"
|
bash scripts/bench_variant.sh '{{model}}' "$variant" {{steps}}
|
||||||
done
|
done
|
||||||
@@ -6,6 +6,7 @@ readme = "README.md"
|
|||||||
requires-python = ">=3.10"
|
requires-python = ">=3.10"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"torch>=2.1",
|
"torch>=2.1",
|
||||||
|
"numpy>=1.26",
|
||||||
"einops>=0.7",
|
"einops>=0.7",
|
||||||
"jaxtyping>=0.2.34",
|
"jaxtyping>=0.2.34",
|
||||||
"safetensors>=0.5",
|
"safetensors>=0.5",
|
||||||
|
|||||||
@@ -428,7 +428,7 @@ def check_probe_reload(
|
|||||||
del loaded_model
|
del loaded_model
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
return {"reload_err": reload_err, "saved_tensors": len(saved["state"])}
|
return {"reload_err": reload_err, "saved_tensors": len(saved_sd)}
|
||||||
|
|
||||||
|
|
||||||
def print_final_report(row: dict[str, Any], result_path: Path, mode: str) -> None:
|
def print_final_report(row: dict[str, Any], result_path: Path, mode: str) -> None:
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ resolution-markers = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
[options]
|
[options]
|
||||||
exclude-newer = "2026-04-22T01:31:09.47902161Z"
|
exclude-newer = "2026-04-22T02:16:12.377169857Z"
|
||||||
exclude-newer-span = "P5D"
|
exclude-newer-span = "P5D"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -1012,6 +1012,8 @@ dependencies = [
|
|||||||
{ name = "einops" },
|
{ name = "einops" },
|
||||||
{ name = "jaxtyping", version = "0.3.7", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
|
{ name = "jaxtyping", version = "0.3.7", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
|
||||||
{ name = "jaxtyping", version = "0.3.9", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" },
|
{ name = "jaxtyping", version = "0.3.9", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" },
|
||||||
|
{ name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
|
||||||
|
{ name = "numpy", version = "2.4.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" },
|
||||||
{ name = "safetensors" },
|
{ name = "safetensors" },
|
||||||
{ name = "torch" },
|
{ name = "torch" },
|
||||||
]
|
]
|
||||||
@@ -1051,6 +1053,7 @@ requires-dist = [
|
|||||||
{ name = "datasets", marker = "extra == 'benchmark'", specifier = ">=3.6" },
|
{ name = "datasets", marker = "extra == 'benchmark'", specifier = ">=3.6" },
|
||||||
{ name = "einops", specifier = ">=0.7" },
|
{ name = "einops", specifier = ">=0.7" },
|
||||||
{ name = "jaxtyping", specifier = ">=0.2.34" },
|
{ name = "jaxtyping", specifier = ">=0.2.34" },
|
||||||
|
{ name = "numpy", specifier = ">=1.26" },
|
||||||
{ name = "pytest", marker = "extra == 'test'" },
|
{ name = "pytest", marker = "extra == 'test'" },
|
||||||
{ name = "safetensors", specifier = ">=0.5" },
|
{ name = "safetensors", specifier = ">=0.5" },
|
||||||
{ name = "safetensors", marker = "extra == 'benchmark'", specifier = ">=0.5" },
|
{ name = "safetensors", marker = "extra == 'benchmark'", specifier = ">=0.5" },
|
||||||
|
|||||||
Reference in New Issue
Block a user