diff --git a/src/lora_lite/variants/antipasto.py b/src/lora_lite/variants/antipasto.py index e6b652d..db55f24 100644 --- a/src/lora_lite/variants/antipasto.py +++ b/src/lora_lite/variants/antipasto.py @@ -1,39 +1,18 @@ """AntiPaSTO: SVD steering with learnable singular-value deltas + block-diagonal Cayley rotation. -Paper: https://arxiv.org/pdf/2601.07473 (wassname, AntiPaSTO -- SVD-based PEFT) -Repo: https://github.com/wassname/AntiPaSTO -Lite port of the AntiPaSTO3 SVD adapter from - https://github.com/wassname/antipasto3 (offline: docs/refs/antipasto3_svd_adapter.py) +wassname 2026 https://arxiv.org/abs/2601.07473 -Decomposition (PyTorch nn.Linear convention, weight (d_out, d_in)): + W = U diag(S) Vh + W_res (top-r SVD; W_res = W - U_r S_r Vh_r) + learn: delta_s (r,), rot_T (n_blocks, bs(bs-1)/2) + R = block_diag(Cayley(skew(rot_T))); Vh_eff = R @ Vh (or U_eff = U @ R.T) + y = x @ W_res.T + ((x @ Vh_eff.T) * (S + delta_s)) @ U_eff.T - W = U diag(S) Vh + W_res (top-r SVD; W_res = W - U_r diag(S_r) Vh_r) +Identity at t=0: rot_T=0 -> R=I, delta_s=0 -> y == x @ W^T (fp32 SVD round-trip). -We freeze U, S, Vh, W_res and learn: - - delta_s : (r,) -- additive delta to singular values - - rot_T : (n_blocks, bs(bs-1)/2) -- upper-triangle of skew matrix per block - -Forward (matches base layer convention exactly at t=0): - - R = block_diag(Cayley(skew(rot_T))) # (r, r) effective - Vh_rot = R @ Vh # rotates input basis - S_eff = S + delta_s # learnable spectrum - delta_y = ((x @ Vh_rot.T) * S_eff) @ U.T # rank-r path - base_y = x @ W_res.T # frozen residual - y_total = base_y + delta_y # == original output at t=0 - -At init: rot_T = 0 -> R = I -> Vh_rot = Vh, delta_s = 0 -> S_eff = S, so -delta_y reconstructs the truncated SVD term and y_total == x @ W^T to numerical -precision (fp32 SVD round-tripped to cfg.dtype). - -WHICH BASIS IS ROTATED: - By default we rotate Vh (the INPUT singular basis). 
This is what AntiPaSTO3 - calls `rotate_V=True` in adapter terms (V == Vh.T columns). To rotate U - (output basis) instead, pass `rotate_basis='U'` on the AntiPaSTOConfig. - Rotating both is not implemented (one rotation is enough to span the - identifiable steering directions; two is degenerate). - -REQUIRES even rank divisible by `block_size` (default 4). r=8, bs=4 -> 2 blocks. +Refs: + - paper code: https://github.com/wassname/AntiPaSTO + - lite port of: https://github.com/wassname/antipasto3 + (offline: docs/refs/antipasto3_svd_adapter.py) """ import math from dataclasses import dataclass @@ -102,7 +81,7 @@ class AntiPaSTO: n_blocks = r // bs n_triu = bs * (bs - 1) // 2 return { - # Frozen SVD components captured at init (buffers travel with state_dict). + # Frozen SVD components captured at init. "lora_U": ParamSpec((d_out, r), init="zeros", trainable=False, as_buffer=True), "lora_S": ParamSpec((r,), init="zeros", trainable=False, as_buffer=True), "lora_Vh": ParamSpec((r, d_in), init="zeros", trainable=False, as_buffer=True), @@ -126,8 +105,6 @@ class AntiPaSTO: layer.lora_U.copy_(Ur.to(layer.lora_U.dtype)) layer.lora_S.copy_(Sr.to(layer.lora_S.dtype)) layer.lora_Vh.copy_(Vhr.to(layer.lora_Vh.dtype)) - # W_res is the residual after rank-r truncation. Forward adds back - # the truncated path so total == W exactly at init (mod dtype). W_res = (W - (Ur * Sr) @ Vhr).to(layer.weight.dtype) layer.weight.data.copy_(W_res) diff --git a/src/lora_lite/variants/delora.py b/src/lora_lite/variants/delora.py index a4b213f..a7bf54f 100644 --- a/src/lora_lite/variants/delora.py +++ b/src/lora_lite/variants/delora.py @@ -1,34 +1,19 @@ -"""DeLoRA: per-input-channel weight-norm scaling, per-rank A/B normalization. +"""DeLoRA: per-channel weight-norm scaling, per-rank A/B normalization. Bini et al. 2025 (ICLR'25) https://arxiv.org/abs/2503.18225 -Paper Eq. 8: W' = W + (lambda * ||W||_F / r) B Xi A -where Xi_{i,i} = 1 / (||b_i|| ||a_i||) makes each rank-1 component unit-norm. 
+ W' = W + (lambda * ||W||_F / r) B Xi A, Xi_{i,i} = 1 / (||b_i|| ||a_i||) -Implementation follows the peft upstream (which the DeLoRA authors maintain), -which differs from the paper notation in two ways that are equivalent at the -forward level but matter for gradients/numerics: - 1. ||W|| is captured PER INPUT CHANNEL (shape (d_in,)), not as a scalar - Frobenius norm. Used to scale `x` element-wise on the input dim. - See docs/refs/peft_delora_layer.py:150 (init) and :250 (forward). - 2. Per-rank normalization applied via division (1/||A_i||*||B^j||) inside - the diagonal scaling, instead of as F.normalize on A,B themselves. - This keeps the gradient flowing through the un-normalized parameters. +Per peft upstream: ||W|| is per-input-channel (not scalar Frobenius), and +per-rank norms divide inside the diag (not via F.normalize on A,B) so +gradients flow through un-normalized parameters. -Identity at t=0: lambda0=0 -> delta is exactly zero (bit-identity). +Identity at t=0: lambda0=0 -> delta=0 (bit-identity). -KNOWN GRADIENT ISSUE (flagged by external review 2026-04-26): - With lambda0=0 the *forward* is identity but `A,B` get zero gradient on step 0 - (delta is proportional to lambda). Only `lora_lambda` moves first step. - The paper's true initialization (frozen-copy trick, Eq. 9) achieves both - identity AND non-zero A/B gradients; we do NOT implement it here. 
- -Reference implementations: - - DeLoRA paper authors (ExplainableML/DeLoRA) -- their fork of peft: - https://github.com/ExplainableML/DeLoRA/blob/main/peft/src/peft/tuners/delora.py +Refs: + - paper code: https://github.com/ExplainableML/DeLoRA/blob/main/peft/src/peft/tuners/delora.py (offline: docs/refs/orig_delora.py) - - peft DeLoRA (upstreamed): - https://github.com/huggingface/peft/blob/main/src/peft/tuners/delora/layer.py + - peft: https://github.com/huggingface/peft/blob/main/src/peft/tuners/delora/layer.py (offline: docs/refs/peft_delora_layer.py) """ import torch @@ -45,8 +30,8 @@ from ..config import AdapterConfig, register_config @dataclass class DeLoRAConfig(AdapterConfig): variant: str = "delora" - # Initial scale for the per-layer learnable lambda. peft default is 15.0; - # we default to 0.0 (identity at t=0 even before B is zero-initialized). + # 0.0 = bit-identity at t=0, but A,B get zero grad until lambda moves + # (delta ∝ lambda). peft default is 15.0. lambda0: float = 0.0 @@ -58,27 +43,19 @@ class DeLoRA: def param_specs(d_in, d_out, cfg): lam0 = float(cfg.lambda0) return { - # peft DeLoRA default: A=kaiming, B=zeros (docs/refs/peft_delora_layer.py:138-140). - # Identity at t=0 from B=0 -> delta=0 regardless of lambda. With B=0 the - # delta is a function of B alone on step 0; gradient flows into B (nonzero) - # and into A only after B becomes nonzero (step 2+). Matches peft. "lora_A": ParamSpec((cfg.r, d_in), init="kaiming", trainable=True), "lora_B": ParamSpec((d_out, cfg.r), init="zeros", trainable=True), "lora_lambda": ParamSpec( (), init=lambda t: t.fill_(lam0), trainable=True ), - # ||W||_2 per input channel (shape (d_in,)); frozen buffer captured at init - # per peft DeLoRA (docs/refs/peft_delora_layer.py:150). + # ||W||_2 per input channel; frozen buffer captured at init. 
"lora_wnorm": ParamSpec((d_in,), init="ones", trainable=False, as_buffer=True), } @staticmethod def init(layer: nn.Module, cfg) -> None: - # DeLoRA needs ||W||_2 per input column. Plain nn.Linear: just read weight. - # bnb Linear8bitLt: weight is fp16 until first forward (then int8 + SCB), - # so capturing here works; quality is correct only because we read pre-quant. - # bnb Linear4bit / fully quantized layers: would give garbage. Use lora/ia3/hra - # for those. + # Reads weight pre-quant -- OK for nn.Linear and bnb 8bit (fp16 until 1st fwd). + # bnb Linear4bit gives garbage; use lora/ia3/hra for those. with torch.no_grad(): W = layer.weight.data.float() wnorm = W.norm(dim=0).detach().to(layer.lora_wnorm.dtype) @@ -94,9 +71,6 @@ class DeLoRA: cfg = layer._lora_cfg A = layer.lora_A # (r, d_in) B = layer.lora_B # (d_out, r) - # peft delora forward (docs/refs/peft_delora_layer.py:248-260): - # h = (x * w_norm) @ A.T; scale per-rank = (lambda/r) / (||A_i|| * ||B^j||); - # delta = (h * scale) @ B.T x_scaled = x * layer.lora_wnorm # (..., d_in) h = einsum(x_scaled, A, "... i, r i -> ... r") An = torch.clamp(A.norm(dim=1), min=1e-4) # (r,) diff --git a/src/lora_lite/variants/dora.py b/src/lora_lite/variants/dora.py index fea2b85..699d57f 100644 --- a/src/lora_lite/variants/dora.py +++ b/src/lora_lite/variants/dora.py @@ -2,20 +2,10 @@ W' = m * V / ||V||_c where V = W + (alpha/r) B A (||.||_c = per-output-row L2 norm) -At t=0: B=0 -> V=W -> y_new = (m_init / ||W||_c) (Wx + 0) = Wx when m_init = ||W||_c. +Identity at t=0: B=0 and m=||W||_c -> y_new = Wx. Requires dense weight (nn.Linear only). -Limitation: requires materializing the dense weight to compute ||V||_c. v1 supports -plain nn.Linear only; bnb 4/8-bit layers raise loudly. - -DEVIATION (numerical): - - We differentiate through ||V||_c every forward. The paper's sec. 4.3 suggests - a 'cost-saving' variant that detaches ||V|| in backward (gradient only flows - through V); we do NOT do that. 
Real impact: slower step, slightly different - gradient direction. Faithful to the eq.5 forward, not the optimized one. - -Reference implementations (for review/cross-check): - - peft DoRA (separate file under lora/): - https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/dora.py +Refs: + - peft: https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/dora.py (offline: docs/refs/peft_lora_dora.py) """ import torch @@ -71,8 +61,7 @@ class DoRA: BA = einsum(layer.lora_B, layer.lora_A, "o r, r i -> o i") V = layer.weight + scale * BA # (d_out, d_in) v_norm = V.norm(dim=1).clamp_min(1e-12) # (d_out,) - # Bias passes through UNSCALED -- only Wx + scale*BAx is normalized. - # Matches peft DoRA forward (docs/refs/peft_lora_dora.py:157-161). + # Bias passes through unscaled (matches peft). bias = getattr(layer, "bias", None) wx = y if bias is None else (y - bias) h = einsum(x, layer.lora_A, "... i, r i -> ... r") diff --git a/src/lora_lite/variants/eva.py b/src/lora_lite/variants/eva.py index 7ef8223..c3e71a6 100644 --- a/src/lora_lite/variants/eva.py +++ b/src/lora_lite/variants/eva.py @@ -1,35 +1,16 @@ -"""EVA: Explained-Variance Adaptation. Paischer et al. 2024. +"""EVA: Explained-Variance Adaptation. Paischer et al. 2024 https://arxiv.org/abs/2410.07170 -Paper: https://arxiv.org/abs/2410.07170 (also referred to as ICLR'25 EVA). +LoRA forward `y + scale*(B@A@x)`; init A = top-r right singular vectors of the +layer-input distribution on a small calibration set (instead of kaiming). -Idea: instead of random A and zero B (LoRA) or SVD of W (PiSSA), initialize -`lora_A` to the top-r right singular vectors of the LAYER INPUT distribution -on a small calibration set. Forward = `y + scale * (B @ A @ x)` exactly like -LoRA; with `lora_B = 0` the adapter is identity at t=0. Only B trains -afterwards (A frozen). The result: each rank slot points along a direction -that actually carries information at this layer. +Identity at t=0: B=0. 
-This is a stripped-down EVA; we do NOT implement: - - rank redistribution across layers via explained-variance ratios - (peft EVA computes an explained_variance_ratio per layer then redistributes - the global rank budget; we use a uniform `cfg.r` per layer). - - Incremental PCA over many micro-batches (we run one full SVD on the - pooled calibration activations per layer). - - Equal-input deduplication (peft hashes inputs to share SVD across QKV). +Stripped down: uniform per-layer rank, single full SVD on pooled inputs, no QKV +input dedup. (peft does rank redistribution + IncrementalPCA + hash dedup.) -API stress-test: this variant requires data-driven init, so it implements -`group_init(model, targets, cfg, calibration_data)` to drive a single forward -pass on `calibration_data` with hooks that capture each target's input. - -Identity at t=0: `lora_B = 0` -> delta = 0 -> y unchanged. - -References: - - peft EVA (full impl, with IncrementalPCA + redistribution): - https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/eva.py - (offline: docs/refs/peft_eva.py) - - peft fine-tuning script demonstrating initialize_lora_eva_weights: - https://github.com/huggingface/peft/blob/main/examples/eva_finetuning/eva_finetuning.py - (offline: docs/refs/peft_eva_finetuning.py) +Refs: + - peft: https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/eva.py + (offline: docs/refs/peft_eva.py; example: docs/refs/peft_eva_finetuning.py) """ import torch from einops import einsum @@ -58,13 +39,8 @@ class EVA: @staticmethod def param_specs(d_in, d_out, cfg): return { - # A is trainable Parameter (peft semantics): EVA only changes the INIT. - # peft copies SVD vectors into the LoRA A weight, which remains a regular - # nn.Linear.weight Parameter (docs/refs/peft_eva.py:529). - # On step 0 only B has nonzero grad (delta=0 since B=0); A starts moving - # once B becomes nonzero, same gradient pattern as DeLoRA. + # A trainable per peft: EVA only changes the init. 
"lora_A": ParamSpec((cfg.r, d_in), init="zeros", trainable=True), - # B is zero-init -> identity at t=0. "lora_B": ParamSpec((d_out, cfg.r), init="zeros", trainable=True), } diff --git a/src/lora_lite/variants/hra.py b/src/lora_lite/variants/hra.py index 2ed794c..909009e 100644 --- a/src/lora_lite/variants/hra.py +++ b/src/lora_lite/variants/hra.py @@ -1,31 +1,19 @@ """HRA: Householder Reflection Adaptation. Yuan et al. 2024 https://arxiv.org/abs/2405.17484 -Paper formulation (Sec. 3): adapt each frozen weight as + W' = W R, R = prod_{i=1..r} H_i, H_i = I - 2 u_i u_i^T / ||u_i||^2 - W' = W R, R = prod_{i=1..r} H_i, H_i = I - 2 u_i u_i^T / ||u_i||^2 +R is in input space (d_in x d_in); applied via a `forward_input` pre-hook so the +frozen base layer (bnb 4/8-bit OK) computes W (R x). -so the layer output becomes y' = W' x = W (R x). R is in INPUT space (d_in x d_in). +Identity at t=0: peft-style symmetric init -- U pairs (U[0]=U[1], ...) so adjacent +H_i H_i = I cancel, R = I exactly. Requires even r. Paired rows diverge after step 1. -We implement this via a `forward_input` pre-hook that returns `R x`, then the -frozen base layer (including bnb 4/8-bit Linear) computes `W (R x)` itself. +Note: paper's orthogonality regularizer (Eq. 6) is loss-side; add it in your loop. -Identity at t=0 (PEFT-style symmetric init, requires even r): - Rows are kaiming-init in pairs: U[0]=U[1], U[2]=U[3], ... Adjacent pairs of - Householder reflections with identical vectors cancel exactly - (H_i H_i = I), so R = I at init -> y' = y to bit-precision. - After the first gradient step the paired rows diverge and the chain becomes a - general orthogonal matrix; gradient flows into U from step 0 (no dead-grad). - Odd r is rejected (matches peft warning behaviour). - -OMITTED: paper also adds an orthogonality regularizer (Eq. 6 / Sec. 3.3), -a loss-side term. Add it in your training loop if you want regularized HRA. 
- -Reference implementations (for review/cross-check): - - HRA paper authors (DaShenZi721/HRA), llama variant of OFT layer with HRA: - https://github.com/DaShenZi721/HRA/blob/master/llama/peft/oft/layer_GS_HRA.py +Refs: + - paper code: https://github.com/DaShenZi721/HRA/blob/master/llama/peft/oft/layer_GS_HRA.py (offline: docs/refs/orig_hra_layer.py) - - peft HRA layer, reset_hra_parameters (lines 100-108): - https://github.com/huggingface/peft/blob/main/src/peft/tuners/hra/layer.py + - peft: https://github.com/huggingface/peft/blob/main/src/peft/tuners/hra/layer.py (offline: docs/refs/peft_hra_layer.py) """ import torch @@ -63,10 +51,7 @@ class HRA: @staticmethod def init(layer: nn.Module, cfg) -> None: - # Symmetric init per peft (docs/refs/peft_hra_layer.py:101-108): - # half = kaiming(r//2, d_in); U = repeat_interleave(half, 2, dim=0) - # Adjacent pairs (H_2k H_2k+1) cancel since H^2 = I, so R = I exactly, - # while gradient still flows into U from step 0. + # Symmetric init: kaiming(r//2, d_in) repeat-interleaved -> R = I, grad alive. with torch.no_grad(): r, d_in = layer.lora_U.shape half = torch.empty(r // 2, d_in, dtype=layer.lora_U.dtype, device=layer.lora_U.device) @@ -79,17 +64,7 @@ class HRA: layer: nn.Module, x: Float[T, '*B i'], ) -> Float[T, '*B i']: - """Apply x -> x R^T where R = H_0 H_1 ... H_{r-1}, H_i = I - 2 u_i u_i^T / ||u_i||^2. - - peft applies `W @ R` so y = F.linear(x, W@R) = x @ R^T @ W^T. Our pre-hook - produces `x @ R^T = x @ H_{r-1} ... H_0`, then the base layer computes - `(x R^T) @ W^T = (x R^T W^T)`, matching peft (docs/refs/peft_hra_layer.py:225-264). - - Iterate i = r-1 down to 0: each step right-multiplies x by H_i, building - x H_{r-1} H_{r-2} ... H_0 = x R^T. At symmetric init H_{2k} H_{2k+1} = I - regardless of order, so identity-at-t=0 holds either way; the order only - matters once paired rows diverge. - """ + """x -> x R^T = x H_{r-1} ... H_0. 
Iterate i = r-1 down to 0 to match peft.""" U = layer.lora_U # (r, d_in) Rx = x for i in range(U.shape[0] - 1, -1, -1): diff --git a/src/lora_lite/variants/ia3.py b/src/lora_lite/variants/ia3.py index b55437a..839ef64 100644 --- a/src/lora_lite/variants/ia3.py +++ b/src/lora_lite/variants/ia3.py @@ -1,28 +1,18 @@ -"""IA3-style elementwise gating. Liu et al. 2022 https://arxiv.org/abs/2205.05638 +"""IA3 elementwise gating. Liu et al. 2022 https://arxiv.org/abs/2205.05638 -Two registered variants, matching the paper's two regimes: +* `ia3` -- output-side: y_new = y * g, g shape (d_out,). For k_proj, v_proj. +* `ia3_ff` -- input-side: y_new = base(x * g), g shape (d_in,). For down_proj/fc2. -* `ia3` -- OUTPUT-side gating, parameter shape (d_out,). - y_new = y * g. Use for attention projections (k_proj, v_proj). +Identity at t=0: g=1. -* `ia3_ff` -- INPUT-side gating, parameter shape (d_in,). - y_new = base_layer(x * g). Use for FFN-down layers (down_proj, - fc2). Equivalent to the paper's "gate the FFN intermediate (post- - activation)" position because down_proj's input IS that - intermediate hidden state. - -In both cases g is initialized to 1 -> identity at t=0. 
- -To match the paper exactly on a Llama/Qwen-style block requires TWO attach -passes (one per variant), since each variant uses one hook type: +Example (paper's Llama/Qwen block needs both passes): cfg_attn = IA3Config( target_names=(r"\\.k_proj$", r"\\.v_proj$")) cfg_ffn = IA3FFConfig( target_names=(r"\\.down_proj$",)) -Reference implementation: - - peft IA3 layer (is_feedforward toggles input-vs-output gating, see - docs/refs/peft_ia3_layer.py:177-188 forward and :214 update_layer): - https://github.com/huggingface/peft/blob/main/src/peft/tuners/ia3/layer.py +Refs: + - peft: https://github.com/huggingface/peft/blob/main/src/peft/tuners/ia3/layer.py + (offline: docs/refs/peft_ia3_layer.py) """ import torch from jaxtyping import Float diff --git a/src/lora_lite/variants/lora.py b/src/lora_lite/variants/lora.py index ad256ff..feefe1a 100644 --- a/src/lora_lite/variants/lora.py +++ b/src/lora_lite/variants/lora.py @@ -2,12 +2,11 @@ h = W x + (alpha/r) B A x -Identity at t=0 from B=0. Faithful to the paper. +Identity at t=0 from B=0. -Reference implementations (for review/cross-check): - - peft Linear.update_layer + lora_A/B init, forward: - https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/layer.py - (see docs/refs/peft_lora_layer.py for offline copy) +Refs: + - peft: https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/layer.py + (offline: docs/refs/peft_lora_layer.py) """ from einops import einsum from jaxtyping import Float diff --git a/src/lora_lite/variants/pissa.py b/src/lora_lite/variants/pissa.py index a52ebcb..61ae6e7 100644 --- a/src/lora_lite/variants/pissa.py +++ b/src/lora_lite/variants/pissa.py @@ -1,24 +1,18 @@ """PiSSA: top-r SVD of W into A,B; replace W with W_res = W - B@A. Meng et al. 2024 https://arxiv.org/abs/2404.02948 -W_eff(t=0) = W_res + B@A = W (numerically; bf16 round-trip not bit-exact). 
-DEVIATION FROM PAPER (documented): - Paper sets adapter scale = 1 (no alpha/r factor); we keep LoRA's alpha/r - pipeline so callers must pass alpha=r to get paper-faithful identity. - Saved adapter does NOT include W_res (would double checkpoint size). Instead - `adapter.save` records a fingerprint of the post-init base weights and - `adapter.load` re-runs PiSSA init then verifies the fingerprint matches - -- so loading onto a different base weight raises loudly instead of - silently producing wrong outputs. + W = U S Vh (truncated to top-r) + B = U sqrt(S), A = sqrt(S) Vh, W_res = W - B A -Reference implementations (for review/cross-check): - - PiSSA original (NeurIPS'24 spotlight) init script (SVD on dequant W): - https://github.com/MuLabPKU/PiSSA/blob/main/utils/init_pissa.py +Identity at t=0: W_res + B@A == W (bf16 round-trip, not bit-exact). +Pass alpha=r for paper-faithful scale=1. + +Refs: + - paper code: https://github.com/MuLabPKU/PiSSA/blob/main/utils/init_pissa.py (offline: docs/refs/orig_pissa_init.py) - - peft PiSSA flavor (init_lora_weights='pissa') in: - https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/layer.py - (offline: docs/refs/peft_lora_layer.py, see pissa_init / loftq_init paths) + - peft: https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/layer.py + (offline: docs/refs/peft_lora_layer.py, see pissa_init path) """ import torch from einops import einsum @@ -64,10 +58,8 @@ class PiSSA: A = (sqrtS[:, None] * Vhr).to(cfg.dtype) layer.lora_B.data.copy_(B) layer.lora_A.data.copy_(A) - # Compute BA in fp32 for the subtraction so W_res is accurate. + # fp32 subtraction so W_res stays accurate. BA = (B.float() @ A.float()) - # NOTE: PiSSA uses scale=1 (alpha==r) implicitly via init. We let the user pick; - # for fidelity at t=0, the convention is alpha==r so scale==1. Document in README. scale = cfg.alpha / cfg.r layer.weight.data.copy_((W - scale * BA).to(layer.weight.dtype))