diff --git a/README.md b/README.md
index 71f379e..c69fe99 100644
--- a/README.md
+++ b/README.md
@@ -2,7 +2,17 @@
 
 Hackable PyTorch adapters for LoRA-family and small PEFT experiments.
 
-`lora-lite` uses forward hooks instead of module replacement. Adapter parameters are plain `nn.Parameter`s on the target layer, e.g. `model.layers[5].self_attn.q_proj.lora_A`.
+## Hackable code
+
+
+To keep it simple and hackable we make these choices:
+
+- Simple forward hooks, no module replacement or custom modules.
+- Simple code over fast performance
+- No merge/unmerge
+- Single test where we train on MetaMathQA and test on GSM8K for each variant
+
+Take a look at [lora.py](src/lora_lite/variants/lora.py)
 
 ## Install
 
@@ -22,9 +32,9 @@ ll.attach(model, cfg)
 opt = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=1e-4)
 # train...
 
-ll.save(model, "adapter.pt")
+ll.save(model, "adapter.safetensors")
 ll.detach(model)
-ll.load(model, "adapter.pt")
+ll.load(model, "adapter.safetensors")
 ```
 
 ## Does it work?
@@ -35,49 +45,28 @@ just bnb-smoke # required CUDA bitsandbytes 4bit/8bit smoke
 just qwen-probe # Qwen/Qwen3-0.6B train/save-load probe
 ```
 
-See [docs/spec/20260426_lora_lite_plan.md](docs/spec/20260426_lora_lite_plan.md) for verification history and exact results.
-
 ## Variants
 
-| Variant | Support | Notes |
-|---|---|---|
-| LoRA | yes | additive low-rank adapter |
-| PiSSA | yes, fp only | mutates `weight` into `W_res`; quantized PiSSA intentionally fails |
-| DeLoRA | yes | normalized additive adapter with learned scalar |
-| IA3 | yes | output gate (`ia3`) or input gate (`ia3_ff`); init to ones |
-| DoRA | yes, fp only | reads dense `weight` for column-norm; quantized DoRA fails loudly |
-| HRA | yes | input-side Householder product via pre-hook; works on bnb |
-| EVA | yes, fp only | LoRA forward; `lora_A` init from PCA on calibration activations |
-| AntiPaSTO | yes, fp only | top-r weight SVD with learnable singular-value deltas + Cayley rotation |
-| SSVD / OFT / ROAD | no | planned |
+| Variant | 4bit/8bit | MetaMath acc (GSM8K %) | Notes |
+|---|---|---|---|
+| [LoRA](https://arxiv.org/abs/2106.09685) | yes | 63.2% | |
+| [PiSSA](https://arxiv.org/abs/2404.02948) | no (edits weight) | — | |
+| [DeLoRA](https://arxiv.org/abs/2503.18225) | yes | — | |
+| [IA3](https://arxiv.org/pdf/2205.05638) | yes | — | output gate (ia3) or input gate (ia3_ff) |
+| [DoRA](https://arxiv.org/abs/2402.09353) | no (reads weight) | — | |
+| [HRA](https://arxiv.org/abs/2409.01434) | yes | — | input-side Householder; works on bnb |
+| [EVA](https://arxiv.org/abs/2409.07871) | no (calibration SVD) | — | |
+| [AntiPaSTO](https://arxiv.org/abs/2503.08696) | no (reads weight SVD) | — | |
 
-## Targeting
+Our test setup: We take Qwen3-0.6B-Base and train on MetaMathQA for 5000 steps. We use a rank of 32 and intervene on all linear layers, then test on GSM8K.
 
-By default, `lora-lite` targets linear-like modules with `in_features`, `out_features`, and `weight`, excluding `lm_head` and `embed_tokens`.
 
-Useful `AdapterConfig` fields (shared across all variants; subclasses add
-variant-specific knobs like `lambda0` on `DeLoRAConfig`):
+Is this a good accuracy? TODO: we need a like-for-like comparison against PEFT LoRA in the same setup before drawing conclusions. But the [PEFT library](https://github.com/huggingface/peft#results) reports LoRA at 49.0% on Llama-3.2-3B (different model and sample count).
 
-- `target_roles`: subset of `("reader", "writer", "inner")`; `()` means all.
-- `target_names`: regex includes.
-- `exclude_names`: regex excludes.
-- `layers`: layer indices, matching `.layers..` in module names.
-
-This structural targeting is why LoRA, DeLoRA, and IA3 can run on bnb-style `Linear4bit`/`Linear8bitLt` modules. PiSSA is different because it edits the base weight.
-
-## Save format
-
-Adapters are just:
-
-```python
-torch.save({"cfg": cfg.to_dict(), "state": lora_state_dict}, "adapter.pt")
-```
-
-`lora_state_dict` contains full-path keys with `"lora_"` in the name. Missing or unexpected adapter keys fail on load.
 
 ## Developer docs
 
-See [docs/developer_guide.md](docs/developer_guide.md) for the variant API, data-calibrated init, and adapter roadmap.
+See [docs/developer_guide.md](docs/developer_guide.md) for the variant API, data-calibrated init, and save/load format.
 
 ## Citation
 
@@ -89,3 +78,4 @@ See [docs/developer_guide.md](docs/developer_guide.md) for the variant API, data
 url = {https://github.com/wassname/lora-lite/}
 }
 ```
+
diff --git a/justfile b/justfile
index 981c96f..25bbaf6 100644
--- a/justfile
+++ b/justfile
@@ -9,10 +9,10 @@ test:
     uv run --extra test pytest -q
 
 smoke:
-    uv run --extra test python tests/smoke.py
+    uv run --extra test --extra benchmark pytest -q tests/test_metamath_smoke.py -k test_metamath_quick_train_save_load
 
 bnb-smoke:
-    uv run --extra test --extra bnb-test python tests/smoke.py --require-bnb
+    uv run --extra test --extra benchmark --extra bnb-test pytest -q tests/test_metamath_smoke.py -k test_attach_on_bnb_loaded_base
 
 build:
     rm -rf dist
@@ -101,5 +101,5 @@ metamath-queue-all model="Qwen/Qwen3-0.6B-Base" steps="5000" variants="lora piss
         pueue add \
             -l "why: benchmark {{model}} ${variant} on MetaMathQA->GSM8K at {{steps}} steps; resolve: outputs/metamath_gsm8k/results/benchmark_results.tsv gets a row with accuracy commit time method argv and result JSON for ${variant}" \
             -w "$PWD" -o 1 -- \
-            uv run --extra benchmark python scripts/metamath_gsm8k_benchmark.py --model {{model}} --variant "$variant" --steps {{steps}} --lr "$lr" "${extra_args[@]}"
+            bash -c "uv run --extra benchmark python scripts/metamath_gsm8k_benchmark.py --model {{model}} --variant $variant --steps {{steps}} --lr $lr $(printf '%q ' "${extra_args[@]}")"
     done
\ No newline at end of file