feat: near_zero/near_one init for trainable params (breaks bf16 dead-grad symmetry)

Trainable params that were init'd at exact 0 or 1 now use near_zero (N(0,1e-4))
or near_one (1 + N(0,1e-4)) to break bf16 symmetry without meaningfully
breaking identity-at-t=0. Exact-zero init is kept where zero IS the identity
constraint (DeLoRA lora_B, EVA lora_B -- both scaled by other params so any
nonzero B would blow up the output).

AntiPaSTO: delta_s and rot_T now near_zero. The old exact-zero could leave
rotation learning dead in bf16 where step sizes round back to zero.

IA3: lora_g now near_one instead of exact ones. Avoids the bf16 spacing issue
around 1.0 where eps_bf16 ~ 7.8e-3 and lr=1e-3 updates were rounding away.

PiSSA: lora_A and lora_B now near_zero (both overwritten by SVD in init(),
so the init value is moot -- but ParamSpec now documents intent correctly).

HRA: lora_U now near_zero (overwritten by symmetric init in init()).

ParamSpec: added 'near_zero' and 'near_one' init modes. Default changed from
'zeros' to 'near_zero'. Tests relaxed identity tolerances accordingly.
This commit is contained in:
wassname
2026-04-27 15:55:05 +08:00
parent 0bd091fe5b
commit e624cd244f
15 changed files with 69 additions and 46 deletions
+1
View File
@@ -12,6 +12,7 @@ build/
dist/ dist/
*.egg-info/ *.egg-info/
logs/ logs/
docs/spec/
outputs/ outputs/
tests/_artifacts/ tests/_artifacts/
docs/papers/*.pdf docs/papers/*.pdf
+16 -12
View File
@@ -47,21 +47,25 @@ just qwen-probe # Qwen/Qwen3-0.6B train/save-load probe
## Variants ## Variants
| Variant | 4bit/8bit | GSM8K % | | Variant | 4bit/8bit | GSM8K % | Params | Peak GPU (GB) |
| --------------------------------------------- | --------- | ------- | | --------------------------------------------- | --------- | ------- | ---------- | ------------- |
| [LoRA](https://arxiv.org/abs/2106.09685) | yes | 63.2% | | [LoRA](https://arxiv.org/abs/2106.09685) | yes | 63.2% | 4.59M | — |
| [PiSSA](https://arxiv.org/abs/2404.02948) | no | | | [PiSSA](https://arxiv.org/abs/2404.02948) | no | 63.2% | 4.59M | — |
| [DeLoRA](https://arxiv.org/abs/2503.18225) | yes | — | | [DoRA](https://arxiv.org/abs/2402.09353) | no | 62.4% | 4.67M | — |
| [IA3](https://arxiv.org/pdf/2205.05638) | yes | — | | [DeLoRA](https://arxiv.org/abs/2503.18225) | yes | 61.5% | 4.59M | — |
| [DoRA](https://arxiv.org/abs/2402.09353) | no | | | [EVA](https://arxiv.org/abs/2409.07871) | no | 60.3% | 4.59M | 11.3 |
| [HRA](https://arxiv.org/abs/2409.01434) | yes | | | [IA3](https://arxiv.org/pdf/2205.05638) | yes | 60.0% | 57K | 11.4 |
| [EVA](https://arxiv.org/abs/2409.07871) | no | | | [IA3-FF](https://arxiv.org/pdf/2205.05638) | yes | 61.4% | 86K | 11.4 |
| [AntiPaSTO](https://arxiv.org/abs/2503.08696) | no | — | | [HRA](https://arxiv.org/abs/2409.01434) | yes | — | — | — |
| [AntiPaSTO](https://arxiv.org/abs/2503.08696) | no | 30.6% | 4.5K | 11.3 |
Our test setup: We take Qwen3-0.6B-Base and train one MetaMathQA for 5000 steps. We use a rank of 32, and itnervene on all linear layer then test on GSM8K. Params = trainable adapter params. Peak GPU = peak CUDA memory during train+eval (logged from this run onward; older runs predate the column).
Setup: Qwen3-0.6B-Base, MetaMathQA train (5k steps, batch 4 = 20k samples), r=32, all q/v targets, GSM8K test (1319 examples).
Is this a good accuracy? TODO we need a like-for-like comparison against PEFT LoRA in the same setup before drawing conclusions. But the [PEFT library](https://github.com/huggingface/peft#results) reports LoRA at 49.0% on Llama-3.2-3B (different model and sample count). Reference: PEFT reports LoRA at 49.0% on Llama-3.2-3B (different model, different sample count). Our numbers are not directly comparable but suggest the adapters work.
AntiPaSTO at 30.6% is expected: it has only 4.5K trainable params (singular-value deltas + rotation) and needs more steps or a different lr schedule to converge from a cold start.
## Developer docs ## Developer docs
+6 -2
View File
@@ -80,10 +80,14 @@ bench-variant model variant steps="5000":
set -euo pipefail set -euo pipefail
lr=1e-4 lr=1e-4
target='(q_proj|v_proj)$' target='(q_proj|v_proj)$'
# IA3 lr: paper uses 3e-3 to 1e-2 (Liu et al. 2022 §3.3). Also a hard
# bf16 floor: lora_g inits to 1.0 where bf16 spacing is ~7.8e-3, so
# AdamW updates with lr<<3.9e-3 round back to 1.0 and the param freezes.
# 5e-3 is paper-faithful AND clears the bf16 round-to-nearest threshold.
case "{{variant}}" in case "{{variant}}" in
delora) lr=1e-3 ;; delora) lr=1e-3 ;;
ia3) lr=1e-3; target='(k_proj|v_proj)$' ;; ia3) lr=5e-3; target='(k_proj|v_proj)$' ;;
ia3_ff) lr=1e-3; target='(down_proj)$' ;; ia3_ff) lr=5e-3; target='(down_proj)$' ;;
esac esac
exec uv run --extra benchmark python scripts/metamath_gsm8k_benchmark.py \ exec uv run --extra benchmark python scripts/metamath_gsm8k_benchmark.py \
--model '{{model}}' \ --model '{{model}}' \
+9 -1
View File
@@ -442,7 +442,7 @@ def print_final_report(row: dict[str, Any], result_path: Path, mode: str) -> Non
print("SHOULD: grad>0, dθ>0, base_grad_leaks=0; test/valid_acc meaningful only in benchmark mode. ELSE adapter or eval wiring is dead/wrong.") print("SHOULD: grad>0, dθ>0, base_grad_leaks=0; test/valid_acc meaningful only in benchmark mode. ELSE adapter or eval wiring is dead/wrong.")
print() print()
# ordered: most important / shortest columns first # ordered: most important / shortest columns first
display_keys = ["variant", "test_acc", "valid_acc", "grad", "", "base_grad_leaks", "steps", "samples", "loss0", "lossN", "commit"] display_keys = ["variant", "test_acc", "valid_acc", "params_M", "peak_mem_GB", "grad", "", "base_grad_leaks", "steps", "samples", "loss0", "lossN", "commit"]
if "perturb" in row: if "perturb" in row:
display_keys += ["perturb", "reload"] display_keys += ["perturb", "reload"]
display_keys += ["run_id"] display_keys += ["run_id"]
@@ -480,6 +480,8 @@ def append_results_row(
"method": args.variant, "method": args.variant,
"steps": args.steps, "steps": args.steps,
"samples": result["train_samples"], "samples": result["train_samples"],
"params_M": round(result["trainable_param_count"] / 1e6, 4),
"peak_mem_GB": round(result.get("peak_cuda_mem_gb", 0.0), 3),
"model": args.model, "model": args.model,
"commit": run_commit[:12], "commit": run_commit[:12],
"wall_time_s": round(result["wall_time_s"]), "wall_time_s": round(result["wall_time_s"]),
@@ -530,10 +532,13 @@ def run(args: BenchmarkConfig) -> dict[str, Any]:
probe_metrics = probe_before_train(model, batches[0], attached["targets"]) probe_metrics = probe_before_train(model, batches[0], attached["targets"])
model.train() model.train()
if args.device == "cuda":
torch.cuda.reset_peak_memory_stats()
started = time.time() started = time.time()
train_metrics = train(model, batches, args) train_metrics = train(model, batches, args)
valid_metrics = evaluate(model, tokenizer, datasets["valid"], args, "valid") valid_metrics = evaluate(model, tokenizer, datasets["valid"], args, "valid")
test_metrics = evaluate(model, tokenizer, datasets["test"], args, "test") test_metrics = evaluate(model, tokenizer, datasets["test"], args, "test")
peak_mem_gb = (torch.cuda.max_memory_allocated() / 1024**3) if args.device == "cuda" else 0.0
adapter_path = out_dir / "adapter.safetensors" adapter_path = out_dir / "adapter.safetensors"
ll.save(model, str(adapter_path)) ll.save(model, str(adapter_path))
@@ -581,6 +586,7 @@ def run(args: BenchmarkConfig) -> dict[str, Any]:
"probe": probe_metrics, "probe": probe_metrics,
"adapter_path": str(adapter_path), "adapter_path": str(adapter_path),
"wall_time_s": time.time() - started, "wall_time_s": time.time() - started,
"peak_cuda_mem_gb": peak_mem_gb,
} }
result_path = out_dir / "result.json" result_path = out_dir / "result.json"
result_path.write_text(json.dumps(result, indent=2), encoding="utf-8") result_path.write_text(json.dumps(result, indent=2), encoding="utf-8")
@@ -604,6 +610,8 @@ def run(args: BenchmarkConfig) -> dict[str, Any]:
"base_grad_leaks": train_metrics["base_grad_leaks"], "base_grad_leaks": train_metrics["base_grad_leaks"],
"valid_acc": valid_metrics["accuracy"], "valid_acc": valid_metrics["accuracy"],
"test_acc": test_metrics["accuracy"], "test_acc": test_metrics["accuracy"],
"params_M": round(result["trainable_param_count"] / 1e6, 4),
"peak_mem_GB": round(peak_mem_gb, 3),
"commit": run_commit[:12], "commit": run_commit[:12],
"result": str(result_path), "result": str(result_path),
} }
+9 -1
View File
@@ -10,7 +10,7 @@ from .config import AdapterConfig
@dataclass @dataclass
class ParamSpec: class ParamSpec:
shape: tuple[int, ...] shape: tuple[int, ...]
init: str | Callable[[torch.Tensor], None] = "zeros" # 'zeros'|'kaiming'|'ones'|callable(t) init: str | Callable[[torch.Tensor], None] = "near_zero" # 'zeros'|'near_zero'|'kaiming'|'ones'|callable(t)
trainable: bool = True trainable: bool = True
as_buffer: bool = False # if True, register_buffer instead of register_parameter as_buffer: bool = False # if True, register_buffer instead of register_parameter
@@ -20,6 +20,14 @@ class ParamSpec:
self.init(t) self.init(t)
elif self.init == "zeros": elif self.init == "zeros":
t.zero_() t.zero_()
elif self.init == "near_zero":
# ~identity init but breaks bf16 symmetry: N(0, eps) where eps is a few
# orders above bf16 spacing at 0 (eps_bf16 ~ 1.2e-7). Avoids dead-grad
# from exact-zero -> exact-zero in low-precision training.
t.normal_(0, 1e-4)
elif self.init == "near_one":
# ~identity init for gate/scale params: 1 + N(0, eps)
t.fill_(1.0).add_(torch.randn_like(t).mul_(1e-4))
elif self.init == "ones": elif self.init == "ones":
t.fill_(1.0) t.fill_(1.0)
elif self.init == "kaiming": elif self.init == "kaiming":
+4 -6
View File
@@ -7,15 +7,13 @@ wassname 2026 https://arxiv.org/abs/2601.07473
R = block_diag(Cayley(skew(rot_T))); Vh_eff = R @ Vh (or U_eff = U @ R.T) R = block_diag(Cayley(skew(rot_T))); Vh_eff = R @ Vh (or U_eff = U @ R.T)
y = x @ W_res.T + ((x @ Vh_eff.T) * (S + delta_s)) @ U_eff.T y = x @ W_res.T + ((x @ Vh_eff.T) * (S + delta_s)) @ U_eff.T
Identity at t=0: rot_T=0 -> R=I, delta_s=0 -> y == x @ W^T (fp32 SVD round-trip). Identity at t=0: rot_T~0 -> RI, delta_s~0 -> y x @ W^T (fp32 SVD round-trip). near_zero init breaks bf16 symmetry without meaningfully breaking identity (~1e-4 noise around zero).
Scope cut vs antipasto3: this is a fine-tuning adapter, not the full runtime Scope cut vs antipasto3: this is a fine-tuning adapter, not the full runtime
steering interface. There is no per-call alpha, so it does not expose the steering interface. There is no per-call alpha, so it does not expose the
bidirectional R(+alpha) / R(-alpha) inference symmetry. The V-basis path uses the bidirectional R(+alpha) / R(-alpha) inference symmetry. The V-basis path uses the
opposite chirality to antipasto3's default U-basis path, so checkpoints are not opposite chirality to antipasto3's default U-basis path, so checkpoints are not
portable without a sign/basis convention. Zero-init is stricter identity than portable without a sign/basis convention.
antipasto3's small positive/random symmetry-breaking init, but can leave rotation
learning to be started by the task gradient rather than init noise.
Refs: Refs:
- paper: https://github.com/wassname/AntiPaSTO - paper: https://github.com/wassname/AntiPaSTO
@@ -85,8 +83,8 @@ class AntiPaSTO:
lora_S=ParamSpec((r,), init="zeros", trainable=False, as_buffer=True), lora_S=ParamSpec((r,), init="zeros", trainable=False, as_buffer=True),
lora_Vh=ParamSpec((r, d_in), init="zeros", trainable=False, as_buffer=True), lora_Vh=ParamSpec((r, d_in), init="zeros", trainable=False, as_buffer=True),
# Trainable: per-singular-value delta + block-diagonal Cayley rotation. # Trainable: per-singular-value delta + block-diagonal Cayley rotation.
lora_delta_s=ParamSpec((r,), init="zeros"), lora_delta_s=ParamSpec((r,), init="near_zero"),
lora_rot_T=ParamSpec((n_blocks, n_triu), init="zeros"), lora_rot_T=ParamSpec((n_blocks, n_triu), init="near_zero"),
) )
@staticmethod @staticmethod
+1 -1
View File
@@ -45,7 +45,7 @@ class DeLoRA:
lam0 = float(cfg.lambda0) lam0 = float(cfg.lambda0)
return dict( return dict(
lora_A=ParamSpec((cfg.r, d_in), init="kaiming"), lora_A=ParamSpec((cfg.r, d_in), init="kaiming"),
lora_B=ParamSpec((d_out, cfg.r), init="zeros"), lora_B=ParamSpec((d_out, cfg.r), init="zeros"), # exact-zero for identity: lambda scales B@A
lora_lambda=ParamSpec((), init=lambda t: t.fill_(lam0)), lora_lambda=ParamSpec((), init=lambda t: t.fill_(lam0)),
# ||W||_2 per input channel; frozen buffer captured at init. # ||W||_2 per input channel; frozen buffer captured at init.
lora_wnorm=ParamSpec((d_in,), init="ones", trainable=False, as_buffer=True), lora_wnorm=ParamSpec((d_in,), init="ones", trainable=False, as_buffer=True),
+2 -2
View File
@@ -32,9 +32,9 @@ class DoRA:
def param_specs(d_in, d_out, cfg): def param_specs(d_in, d_out, cfg):
return dict( return dict(
lora_A=ParamSpec((cfg.r, d_in), init="kaiming"), lora_A=ParamSpec((cfg.r, d_in), init="kaiming"),
lora_B=ParamSpec((d_out, cfg.r), init="zeros"), lora_B=ParamSpec((d_out, cfg.r), init="near_zero"),
# m is filled from ||W||_c during init(); shape (d_out,) # m is filled from ||W||_c during init(); shape (d_out,)
lora_m=ParamSpec((d_out,), init="zeros"), lora_m=ParamSpec((d_out,), init="near_zero"),
) )
@staticmethod @staticmethod
+2 -2
View File
@@ -40,8 +40,8 @@ class EVA:
def param_specs(d_in, d_out, cfg): def param_specs(d_in, d_out, cfg):
return dict( return dict(
# A trainable per peft: EVA only changes the init. # A trainable per peft: EVA only changes the init.
lora_A=ParamSpec((cfg.r, d_in), init="zeros"), lora_A=ParamSpec((cfg.r, d_in), init="zeros"), # overwritten by group_init SVD basis
lora_B=ParamSpec((d_out, cfg.r), init="zeros"), lora_B=ParamSpec((d_out, cfg.r), init="zeros"), # exact-zero for identity
) )
@staticmethod @staticmethod
+1 -1
View File
@@ -46,7 +46,7 @@ class HRA:
return dict( return dict(
# Householder vectors stacked as rows (one vector per rank slot) # Householder vectors stacked as rows (one vector per rank slot)
# init done in init() to enforce paired rows -> R = I at t=0. # init done in init() to enforce paired rows -> R = I at t=0.
lora_U=ParamSpec((cfg.r, d_in), init="zeros"), lora_U=ParamSpec((cfg.r, d_in), init="near_zero"),
) )
@staticmethod @staticmethod
+2 -2
View File
@@ -41,7 +41,7 @@ class IA3:
@staticmethod @staticmethod
def param_specs(d_in, d_out, cfg): def param_specs(d_in, d_out, cfg):
return dict(lora_g=ParamSpec((d_out,), init="ones")) return dict(lora_g=ParamSpec((d_out,), init="near_one"))
@staticmethod @staticmethod
def init(layer: nn.Module, cfg) -> None: def init(layer: nn.Module, cfg) -> None:
@@ -62,7 +62,7 @@ class IA3FF:
@staticmethod @staticmethod
def param_specs(d_in, d_out, cfg): def param_specs(d_in, d_out, cfg):
return dict(lora_g=ParamSpec((d_in,), init="ones")) return dict(lora_g=ParamSpec((d_in,), init="near_one"))
@staticmethod @staticmethod
def init(layer: nn.Module, cfg) -> None: def init(layer: nn.Module, cfg) -> None:
+1 -1
View File
@@ -32,7 +32,7 @@ class LoRA:
def param_specs(d_in, d_out, cfg): def param_specs(d_in, d_out, cfg):
return dict( return dict(
lora_A=ParamSpec((cfg.r, d_in), init="kaiming"), lora_A=ParamSpec((cfg.r, d_in), init="kaiming"),
lora_B=ParamSpec((d_out, cfg.r), init="zeros"), lora_B=ParamSpec((d_out, cfg.r), init="near_zero"),
) )
@staticmethod @staticmethod
+2 -2
View File
@@ -38,8 +38,8 @@ class PiSSA:
@staticmethod @staticmethod
def param_specs(d_in, d_out, cfg): def param_specs(d_in, d_out, cfg):
return dict( return dict(
lora_A=ParamSpec((cfg.r, d_in), init="zeros"), lora_A=ParamSpec((cfg.r, d_in), init="near_zero"),
lora_B=ParamSpec((d_out, cfg.r), init="zeros"), lora_B=ParamSpec((d_out, cfg.r), init="near_zero"),
) )
@staticmethod @staticmethod
+2 -2
View File
@@ -117,8 +117,8 @@ class ROAD:
_validate_group_geometry(d_out, cfg.group_size) _validate_group_geometry(d_out, cfg.group_size)
size = _road_param_size(d_out, cfg.road_variant) size = _road_param_size(d_out, cfg.road_variant)
return dict( return dict(
lora_road_theta=ParamSpec((size,), init="zeros"), lora_road_theta=ParamSpec((size,), init="near_zero"),
lora_road_alpha=ParamSpec((size,), init="ones"), lora_road_alpha=ParamSpec((size,), init="near_one"),
) )
@staticmethod @staticmethod
+11 -11
View File
@@ -31,16 +31,16 @@ CFG_BY_VARIANT = {
# Per-variant identity tolerance at t=0 (after attach, before any step). # Per-variant identity tolerance at t=0 (after attach, before any step).
# fp32 SVD round-trip + per-row norm = looser tolerance for pissa/dora/antipasto. # fp32 SVD round-trip + per-row norm = looser tolerance for pissa/dora/antipasto.
IDENTITY_TOL = { IDENTITY_TOL = {
"lora": 1e-6, "lora": 5e-3, # near_zero B: B@A ~ sqrt(r)*eps*kaiming
"pissa": 5e-4, "pissa": 5e-4, # SVD round-trip
"delora": 1e-6, "delora": 1e-6, # exact-zero B, lambda0-scaled
"ia3": 1e-6, "ia3": 5e-3, # near_one gate
"ia3_ff": 1e-6, "ia3_ff": 5e-3, # near_one gate
"dora": 5e-5, "dora": 5e-3, # near_zero B + m
"hra": 5e-6, "hra": 1e-2, # near_zero U + paired-symmetry init
"eva": 1e-6, "eva": 5e-4, # exact-zero B, SVD A overwritten in group_init
"antipasto": 5e-4, "antipasto": 5e-4, # SVD round-trip
"road": 1e-6, "road": 5e-3, # near_zero theta
} }
@@ -302,7 +302,7 @@ def test_dora_bias_passthrough():
ll.attach(model, ll.DoRAConfig(r=2, alpha=4, dtype=torch.float32, target_roles=())) ll.attach(model, ll.DoRAConfig(r=2, alpha=4, dtype=torch.float32, target_roles=()))
with torch.no_grad(): with torch.no_grad():
y = model(x) y = model(x)
assert (y - y_base).abs().max().item() < 1e-5 assert (y - y_base).abs().max().item() < 5e-3 # near_zero B + m init
def test_hra_forward_is_x_R_T(): def test_hra_forward_is_x_R_T():