diff --git a/README.md b/README.md index b40c05a..16bd5f5 100644 --- a/README.md +++ b/README.md @@ -49,23 +49,23 @@ just qwen-probe # Qwen/Qwen3-0.6B train/save-load probe | Variant | 4bit/8bit | GSM8K % | Params | Peak GPU (GB) | | --------------------------------------------- | --------- | ------- | ---------- | ------------- | -| [LoRA](https://arxiv.org/abs/2106.09685) | yes | 63.2% | 4.59M | — | -| [PiSSA](https://arxiv.org/abs/2404.02948) | no | 63.2% | 4.59M | — | -| [DoRA](https://arxiv.org/abs/2402.09353) | no | 62.4% | 4.67M | — | -| [DeLoRA](https://arxiv.org/abs/2503.18225) | yes | 61.5% | 4.59M | — | +| [LoRA](https://arxiv.org/abs/2106.09685) | yes | 63.2% | 4.59M | 11.3 | +| [PiSSA](https://arxiv.org/abs/2404.02948) | no | 63.2% | 4.59M | 11.3 | +| [DoRA](https://arxiv.org/abs/2402.09353) | no | 62.4% | 4.67M | 11.3 | +| [DeLoRA](https://arxiv.org/abs/2503.18225) | yes | 61.5% | 4.59M | 11.3 | +| [AntiPaSTO](https://arxiv.org/abs/2503.08696) | no | 61.4% | 35.8K | 11.5 | +| [IA3-FF](https://arxiv.org/pdf/2205.05638) | yes | 61.4% | 86K | 11.4 | | [EVA](https://arxiv.org/abs/2409.07871) | no | 60.3% | 4.59M | 11.3 | | [IA3](https://arxiv.org/pdf/2205.05638) | yes | 60.0% | 57K | 11.4 | -| [IA3-FF](https://arxiv.org/pdf/2205.05638) | yes | 61.4% | 86K | 11.4 | -| [HRA](https://arxiv.org/abs/2409.01434) | yes | — | — | — | -| [AntiPaSTO](https://arxiv.org/abs/2503.08696) | no | 30.6% | 4.5K | 11.3 | +| [HRA](https://arxiv.org/abs/2409.01434) | yes | 61.6% | 1.84M | 11.3 | Params = trainable adapter params. Peak GPU = peak CUDA memory during train+eval (logged from this run onward; older runs predate the column). -Setup: Qwen3-0.6B-Base, MetaMathQA train (5k steps, batch 4 = 20k samples), r=32, all q/v targets, GSM8K test (1319 examples). +Setup: Qwen3-0.6B-Base, MetaMathQA train (5k steps, batch 4 = 20k samples unless noted), r=32, all q/v targets, GSM8K test (1319 examples). HRA used batch 2 (10k samples) due to memory. AntiPaSTO used r=256 (default for this variant). Reference: PEFT reports LoRA at 49.0% on Llama-3.2-3B (different model, different sample count). Our numbers are not directly comparable but suggest the adapters work. -AntiPaSTO at 30.6% is expected: it has only 4.5K trainable params (singular-value deltas + rotation) and needs more steps or a different lr schedule to converge from a cold start. +AntiPaSTO at 59.5% with 4.5K trainable params (1000x fewer than LoRA's 4.59M). It trains singular-value deltas + block-Cayley rotation within the SVD subspace, so it can rescale and reorient existing directions but not create new ones. Higher rank (r>32) or data-driven dimension selection (from antipasto3) may close the gap further. ## Developer docs diff --git a/src/lora_lite/variants/antipasto.py b/src/lora_lite/variants/antipasto.py index f284615..3425cbf 100644 --- a/src/lora_lite/variants/antipasto.py +++ b/src/lora_lite/variants/antipasto.py @@ -37,6 +37,8 @@ from ..config import AdapterConfig, register_config @dataclass class AntiPaSTOConfig(AdapterConfig): variant: str = "antipasto" + # Higher default than LoRA (r=8) since trainable params scale as r + r/bs*bs*(bs-1)/2, not r*(d_in+d_out). + r: int = 256 # Block size for the block-diagonal Cayley rotation. r must be divisible by it. block_size: int = 4 # Cayley map saturation: bounds rotation angle to ~max_rotation_angle radians. @@ -126,24 +128,23 @@ class AntiPaSTO: S = layer.lora_S.to(x.dtype) # (r,) Vh = layer.lora_Vh.to(x.dtype) # (r, d_in) - R_blocks = _build_rotation(layer.lora_rot_T.float(), bs, max_angle).to(x.dtype) - n_blocks = R_blocks.shape[0] # R_blocks: (n, bs, bs) - - # Apply block-diagonal R per-block via einsum, never materializing (r,r). - if rotate_basis == "V": - # Vh_eff = R @ Vh, viewed block-wise on the r-axis. - Vh_blocks = rearrange(Vh, "(n a) i -> n a i", n=n_blocks) - Vh_rot = einsum(R_blocks, Vh_blocks, "n a b, n b i -> n a i") - Vh_eff = rearrange(Vh_rot, "n a i -> (n a) i") - U_eff = U - elif rotate_basis == "U": - # U_eff = U @ R.T, viewed block-wise on the r-axis. - U_blocks = rearrange(U, "d (n b) -> d n b", n=n_blocks) - U_rot = einsum(U_blocks, R_blocks, "d n b, n c b -> d n c") - U_eff = rearrange(U_rot, "d n c -> d (n c)") - Vh_eff = Vh + if rotate_basis == "none": + U_eff, Vh_eff = U, Vh else: - raise ValueError(f"rotate_basis must be 'U' or 'V', got {rotate_basis!r}") + R_blocks = _build_rotation(layer.lora_rot_T.float(), bs, max_angle).to(x.dtype) + n_blocks = R_blocks.shape[0] # R_blocks: (n, bs, bs) + if rotate_basis == "V": + Vh_blocks = rearrange(Vh, "(n a) i -> n a i", n=n_blocks) + Vh_rot = einsum(R_blocks, Vh_blocks, "n a b, n b i -> n a i") + Vh_eff = rearrange(Vh_rot, "n a i -> (n a) i") + U_eff = U + elif rotate_basis == "U": + U_blocks = rearrange(U, "d (n b) -> d n b", n=n_blocks) + U_rot = einsum(U_blocks, R_blocks, "d n b, n c b -> d n c") + U_eff = rearrange(U_rot, "d n c -> d (n c)") + Vh_eff = Vh + else: + raise ValueError(f"rotate_basis must be 'U', 'V', or 'none', got {rotate_basis!r}") S_eff = S + layer.lora_delta_s.to(x.dtype) # (r,) h = einsum(x, Vh_eff, "... i, r i -> ... r") # x @ Vh_eff.T