feat(hra): add Householder Reflection Adaptation, hook-only/bnb-friendly + Qwen proof

This commit is contained in:
wassname
2026-04-26 17:58:56 +08:00
parent 43e620176c
commit 0d929f93b3
8 changed files with 72 additions and 10 deletions
+2 -1
View File
@@ -46,7 +46,8 @@ See [docs/spec/20260426_lora_lite_plan.md](docs/spec/20260426_lora_lite_plan.md)
| DeLoRA | yes | normalized additive adapter with learned scalar | | DeLoRA | yes | normalized additive adapter with learned scalar |
| IA3 | yes | output gate initialized to ones | | IA3 | yes | output gate initialized to ones |
| DoRA | yes, fp only | reads dense `weight` for column-norm; quantized DoRA fails loudly | | DoRA | yes, fp only | reads dense `weight` for column-norm; quantized DoRA fails loudly |
| SSVD / OFT / HRA / ROAD | no | planned after the hook-only invariant is clear | | HRA | yes | output-side Householder reflection with identity gate; works on bnb |
| SSVD / OFT / ROAD | no | planned |
| S-steer / AntiPaSTO | no | should use data-calibrated `group_init`, not plain LoRA tests | | S-steer / AntiPaSTO | no | should use data-calibrated `group_init`, not plain LoRA tests |
## Targeting ## Targeting
+2 -1
View File
@@ -76,6 +76,7 @@ Activation-aware variants implement `group_init(model, targets, cfg, calibration
|---|---|---| |---|---|---|
| IA3 | Done. Output gate `y * g`, identity at `g=1`. | Qwen proof in latest probe. | | IA3 | Done. Output gate `y * g`, identity at `g=1`. | Qwen proof in latest probe. |
| DoRA | Done for fp layers. Reads dense `weight` to compute `||V||_c`; quantized layers fail fast. | Qwen proof in latest probe. | | DoRA | Done for fp layers. Reads dense `weight` to compute `||V||_c`; quantized layers fail fast. | Qwen proof in latest probe. |
| HRA | Done. Output-side Householder with identity gate; hook-only -> works on bnb. | Qwen proof in latest probe. |
| SSVD / PiSSA-family | Fits weight-SVD init path. | reconstruction/identity invariant plus train proof. | | SSVD / PiSSA-family | Fits weight-SVD init path. | reconstruction/identity invariant plus train proof. |
| HRA / OFT / ROAD | Interesting, but weight-transform semantics need clearer hook-only formulation. | pseudocode first, then rotation/non-dead-code invariant. | | OFT / ROAD | Block-diagonal rotations; weight-transform semantics need clearer hook-only formulation. | pseudocode first, then rotation/non-dead-code invariant. |
| S-steer / AntiPaSTO | Should use `group_init` and activation evidence. | calibration consumed, hooks removed, load works without calibration. | | S-steer / AntiPaSTO | Should use `group_init` and activation evidence. | calibration consumed, hooks removed, load works without calibration. |
+2
View File
@@ -38,6 +38,7 @@ The core bet is that adapter variants should own the relationship between `(x, l
| DeLoRA | done | `src/lora_lite/variants/delora.py` | | DeLoRA | done | `src/lora_lite/variants/delora.py` |
| IA3 | done | `src/lora_lite/variants/ia3.py` | | IA3 | done | `src/lora_lite/variants/ia3.py` |
| DoRA | done, fp-only | `src/lora_lite/variants/dora.py` | | DoRA | done, fp-only | `src/lora_lite/variants/dora.py` |
| HRA | done | `src/lora_lite/variants/hra.py` (output-side Householder, hook-only -> bnb-compatible) |
| Smoke tests | done | `tests/smoke.py` | | Smoke tests | done | `tests/smoke.py` |
| bnb minimal forward smoke | done | `Linear8bitLt` and `Linear4bit` pass on CUDA with `just bnb-smoke` | | bnb minimal forward smoke | done | `Linear8bitLt` and `Linear4bit` pass on CUDA with `just bnb-smoke` |
@@ -116,6 +117,7 @@ Follow-up tasks 80 (lora/pissa/delora/ia3 at 16 steps) and 81 (dora at 16 steps)
| delora | 2 | 20482 | 0.3281 | 0.3125 | 5.261 | 4.823 | 8.322 | 0.06303 | 15.1 | 0 | `outputs/qwen_train_probe/delora_adapter.pt` | | delora | 2 | 20482 | 0.3281 | 0.3125 | 5.261 | 4.823 | 8.322 | 0.06303 | 15.1 | 0 | `outputs/qwen_train_probe/delora_adapter.pt` |
| ia3 | 2 | 3072 | 0 | 0.375 | 5.25 | 4.473 | 14.79 | 0.463 | 5.926 | 0 | `outputs/qwen_train_probe/ia3_adapter.pt` | | ia3 | 2 | 3072 | 0 | 0.375 | 5.25 | 4.473 | 14.79 | 0.463 | 5.926 | 0 | `outputs/qwen_train_probe/ia3_adapter.pt` |
| dora | 2 | 23552 | 0 | 0.3203 | 5.25 | 2.439 | 53.54 | 1.776 | 7.44 | 0 | `outputs/qwen_train_probe/dora_adapter.pt` | | dora | 2 | 23552 | 0 | 0.3203 | 5.25 | 2.439 | 53.54 | 1.776 | 7.44 | 0 | `outputs/qwen_train_probe/dora_adapter.pt` |
| hra | 2 | 12290 | 0 | 0.3438 | 5.25 | 4.07 | 22.47 | 0.05225 | 4.735 | 0 | `outputs/qwen_train_probe/hra_adapter.pt` |
Failure-mode interpretation: Failure-mode interpretation:
+6 -1
View File
@@ -53,6 +53,11 @@ def perturb_first_adapter(model: torch.nn.Module) -> None:
with torch.no_grad(): with torch.no_grad():
p.add_(0.25) p.add_(0.25)
return return
for name, p in model.named_parameters():
if "lora_gate" in name:
with torch.no_grad():
p.add_(0.25)
return
for name, p in model.named_parameters(): for name, p in model.named_parameters():
if "lora_B" in name: if "lora_B" in name:
with torch.no_grad(): with torch.no_grad():
@@ -173,7 +178,7 @@ def run_variant(args, variant: str, input_ids: torch.Tensor, labels: torch.Tenso
def main() -> None: def main() -> None:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--model", default="Qwen/Qwen3-0.6B") parser.add_argument("--model", default="Qwen/Qwen3-0.6B")
parser.add_argument("--variants", nargs="+", default=["lora", "pissa", "delora", "ia3", "dora"]) parser.add_argument("--variants", nargs="+", default=["lora", "pissa", "delora", "ia3", "dora", "hra"])
parser.add_argument("--device", default="cuda") parser.add_argument("--device", default="cuda")
parser.add_argument("--torch-dtype", default="bfloat16") parser.add_argument("--torch-dtype", default="bfloat16")
parser.add_argument("--steps", type=int, default=8) parser.add_argument("--steps", type=int, default=8)
+1 -1
View File
@@ -1 +1 @@
from . import lora, pissa, delora, ia3, dora # noqa: F401 side-effect: register from . import lora, pissa, delora, ia3, dora, hra # noqa: F401 side-effect: register
+47
View File
@@ -0,0 +1,47 @@
"""HRA: Householder Reflection Adaptation. Yuan et al. 2024 https://arxiv.org/abs/2405.17484
Output-side formulation with an identity-init gate:
y' = (1 - alpha) * y + alpha * R y (so y' = y when alpha = 0)
R = prod_{i=1..r} H_i, H_i = I - 2 u_i u_i^T / ||u_i||^2
`lora_gate` is initialized to 0 so y' = y at t=0. `lora_U` is initialized
kaiming so ||u_i||^2 is well-defined (no 0/0). Gradients flow into both U and
the gate even at init.
Hook-only, no weight access -> works on bnb 4/8-bit layers.
"""
import torch
from einops import einsum
from torch import nn
from ..variant import register, ParamSpec
@register
class HRA:
name = "hra"
@staticmethod
def param_specs(d_in, d_out, cfg):
return {
# one Householder vector per rank slot in R^{d_out}
"lora_U": ParamSpec((cfg.r, d_out), init="kaiming", trainable=True),
# identity gate; 0 -> y' = y exactly
"lora_gate": ParamSpec((), init="zeros", trainable=True),
}
@staticmethod
def init(layer: nn.Linear, cfg) -> None:
return
@staticmethod
def forward(layer: nn.Linear, x, y):
U = layer.lora_U # (r, d_out)
Ry = y
for i in range(U.shape[0]):
u = U[i]
sq = (u * u).sum().clamp_min(1e-12)
coeff = einsum(Ry, u, "... o, o -> ...") * (2.0 / sq)
Ry = Ry - coeff.unsqueeze(-1) * u
return y + layer.lora_gate * (Ry - y)
+3 -2
View File
@@ -131,6 +131,7 @@ def variant_test(variant: str, dtype=torch.float32):
"delora": 1e-6, # lambda0=0 "delora": 1e-6, # lambda0=0
"ia3": 1e-6, "ia3": 1e-6,
"dora": 5e-5, # m * V/||V|| with V=W -> rounding in norm/divide "dora": 5e-5, # m * V/||V|| with V=W -> rounding in norm/divide
"hra": 1e-6, # gate=0 -> exact identity
}[variant] * max(1.0, base_scale) }[variant] * max(1.0, base_scale)
assert err < tol, f" FAIL identity: err {err} > tol {tol}" assert err < tol, f" FAIL identity: err {err} > tol {tol}"
print(f" SHOULD: err<{tol:.1e}. PASS.") print(f" SHOULD: err<{tol:.1e}. PASS.")
@@ -168,7 +169,7 @@ def variant_test(variant: str, dtype=torch.float32):
target = torch.randn(2, 16, 100, dtype=dtype) * 0.1 target = torch.randn(2, 16, 100, dtype=dtype) * 0.1
trainable = [p for p in model.parameters() if p.requires_grad] trainable = [p for p in model.parameters() if p.requires_grad]
# delora has tightly-normalised updates; use Adam with higher lr to see signal in 20 steps # delora has tightly-normalised updates; use Adam with higher lr to see signal in 20 steps
if variant in ("delora", "ia3"): if variant in ("delora", "ia3", "hra"):
opt = torch.optim.Adam(trainable, lr=1e-1) opt = torch.optim.Adam(trainable, lr=1e-1)
elif variant == "dora": elif variant == "dora":
opt = torch.optim.Adam(trainable, lr=1e-3) # m near ||W||_c, bigger lr blows up opt = torch.optim.Adam(trainable, lr=1e-3) # m near ||W||_c, bigger lr blows up
@@ -254,7 +255,7 @@ def main():
parser.add_argument("--require-bnb", action="store_true") parser.add_argument("--require-bnb", action="store_true")
args = parser.parse_args() args = parser.parse_args()
for v in ("lora", "pissa", "delora", "ia3", "dora"): for v in ("lora", "pissa", "delora", "ia3", "dora", "hra"):
variant_test(v, dtype=torch.float32) variant_test(v, dtype=torch.float32)
structural_linear_like_test() structural_linear_like_test()
bitsandbytes_cuda_smoke(args.require_bnb) bitsandbytes_cuda_smoke(args.require_bnb)
+9 -4
View File
@@ -98,6 +98,11 @@ def perturb_first_adapter(model: nn.Module) -> None:
with torch.no_grad(): with torch.no_grad():
p.add_(0.25) p.add_(0.25)
return return
for name, p in model.named_parameters():
if "lora_gate" in name:
with torch.no_grad():
p.add_(0.25)
return
for name, p in model.named_parameters(): for name, p in model.named_parameters():
if "lora_B" in name: if "lora_B" in name:
with torch.no_grad(): with torch.no_grad():
@@ -111,7 +116,7 @@ def perturb_first_adapter(model: nn.Module) -> None:
raise AssertionError("no perturbable adapter parameter found") raise AssertionError("no perturbable adapter parameter found")
@pytest.mark.parametrize("variant", ["lora", "pissa", "delora", "ia3", "dora"]) @pytest.mark.parametrize("variant", ["lora", "pissa", "delora", "ia3", "dora", "hra"])
def test_variant_identity_hook_save_load_and_training(variant: str): def test_variant_identity_hook_save_load_and_training(variant: str):
ARTIFACT_DIR.mkdir(exist_ok=True) ARTIFACT_DIR.mkdir(exist_ok=True)
torch.manual_seed(0) torch.manual_seed(0)
@@ -129,7 +134,7 @@ def test_variant_identity_hook_save_load_and_training(variant: str):
with torch.no_grad(): with torch.no_grad():
y_init = model(ids).clone() y_init = model(ids).clone()
identity_err = (y_init - y_base).abs().max().item() identity_err = (y_init - y_base).abs().max().item()
identity_tol = {"lora": 1e-6, "pissa": 5e-4, "delora": 1e-6, "ia3": 1e-6, "dora": 5e-5}[variant] identity_tol = {"lora": 1e-6, "pissa": 5e-4, "delora": 1e-6, "ia3": 1e-6, "dora": 5e-5, "hra": 1e-6}[variant]
assert identity_err < identity_tol assert identity_err < identity_tol
before_perturb = adapter_state(model) before_perturb = adapter_state(model)
@@ -162,7 +167,7 @@ def test_variant_identity_hook_save_load_and_training(variant: str):
assert_only_lora_trainable(train_model) assert_only_lora_trainable(train_model)
target = torch.randn(2, 16, 100) * 0.1 target = torch.randn(2, 16, 100) * 0.1
trainable = [p for p in train_model.parameters() if p.requires_grad] trainable = [p for p in train_model.parameters() if p.requires_grad]
opt = torch.optim.Adam(trainable, lr=0.1) if variant in ("delora", "ia3") else ( opt = torch.optim.Adam(trainable, lr=0.1) if variant in ("delora", "ia3", "hra") else (
torch.optim.Adam(trainable, lr=1e-3) if variant == "dora" else torch.optim.SGD(trainable, lr=1e-2) torch.optim.Adam(trainable, lr=1e-3) if variant == "dora" else torch.optim.SGD(trainable, lr=1e-2)
) )
losses = [] losses = []
@@ -221,7 +226,7 @@ def test_no_target_layers_is_loud_failure():
ll.attach(TinyModel(), cfg) ll.attach(TinyModel(), cfg)
@pytest.mark.parametrize("variant", ["lora", "delora", "ia3"]) @pytest.mark.parametrize("variant", ["lora", "delora", "ia3", "hra"])
def test_structural_non_linear_target_trains_for_forward_only_variants(variant: str): def test_structural_non_linear_target_trains_for_forward_only_variants(variant: str):
torch.manual_seed(0) torch.manual_seed(0)
model = FakeBnbModel() model = FakeBnbModel()