antipasto: delta_s init 4e-4+N(0,4e-4) from antipasto3, rotate_basis='none' option

This commit is contained in:
wassname
2026-04-27 16:27:12 +08:00
parent 7df786e80b
commit 88f107a423
8 changed files with 28 additions and 23 deletions
+3 -4
View File
@@ -21,12 +21,11 @@ class ParamSpec:
elif self.init == "zeros":
t.zero_()
elif self.init == "near_zero":
# ~identity init but breaks bf16 symmetry: N(0, eps) where eps is a few
# orders above bf16 spacing at 0 (eps_bf16 ~ 1.2e-7). Avoids dead-grad
# from exact-zero -> exact-zero in low-precision training.
# avoid exact-zero dead zone; N(0, 1e-4) is small enough to be
# ~identity but nonzero so gradients always have somewhere to go
t.normal_(0, 1e-4)
elif self.init == "near_one":
# ~identity init for gate/scale params: 1 + N(0, eps)
# avoid exact-one dead zone; 1 + N(0, 1e-4)
t.fill_(1.0).add_(torch.randn_like(t).mul_(1e-4))
elif self.init == "ones":
t.fill_(1.0)
+15 -9
View File
@@ -7,7 +7,7 @@ wassname 2026 https://arxiv.org/abs/2601.07473
R = block_diag(Cayley(skew(rot_T))); Vh_eff = R @ Vh (or U_eff = U @ R.T)
y = x @ W_res.T + ((x @ Vh_eff.T) * (S + delta_s)) @ U_eff.T
Identity at t=0: rot_T~0 -> RI, delta_s~0 -> y ≈ x @ W^T (fp32 SVD round-trip). near_zero init breaks bf16 symmetry without meaningfully breaking identity (~1e-4 noise around zero).
Identity at t=0: rot_T=0 -> R=I, delta_s~4e-4 -> y ≈ x @ W^T (fp32 SVD round-trip, tiny positive bias on delta_s breaks sign symmetry).
Scope cut vs antipasto3: this is a fine-tuning adapter, not the full runtime
steering interface. There is no per-call alpha, so it does not expose the
@@ -41,8 +41,8 @@ class AntiPaSTOConfig(AdapterConfig):
block_size: int = 4
# Cayley map saturation: bounds rotation angle to ~max_rotation_angle radians.
max_rotation_angle: float = 0.5
# Which singular basis to rotate: 'V' (input) or 'U' (output).
rotate_basis: Literal["V", "U"] = "V"
# Which singular basis to rotate: 'V' (input), 'U' (output), or 'none'.
rotate_basis: Literal["V", "U", "none"] = "V"
def _cayley(skew: torch.Tensor) -> torch.Tensor:
@@ -75,17 +75,21 @@ class AntiPaSTO:
bs = int(cfg.block_size)
if r % bs != 0:
raise ValueError(f"AntiPaSTO requires r={r} divisible by block_size={bs}")
n_blocks = r // bs
n_triu = bs * (bs - 1) // 2
return dict(
specs = dict(
# Frozen SVD components captured at init.
lora_U=ParamSpec((d_out, r), init="zeros", trainable=False, as_buffer=True),
lora_S=ParamSpec((r,), init="zeros", trainable=False, as_buffer=True),
lora_Vh=ParamSpec((r, d_in), init="zeros", trainable=False, as_buffer=True),
# Trainable: per-singular-value delta + block-diagonal Cayley rotation.
lora_delta_s=ParamSpec((r,), init="near_zero"),
lora_rot_T=ParamSpec((n_blocks, n_triu), init="near_zero"),
# Trainable: per-singular-value delta.
# antipasto3 uses 4e-4 + N(0, 4e-4): small positive bias breaks sign
# symmetry (rotation alone can't); zero-init works but trains slower.
lora_delta_s=ParamSpec((r,), init=lambda t: t.normal_(0, 4e-4).add_(4e-4)),
)
if cfg.rotate_basis != "none":
n_blocks = r // bs
n_triu = bs * (bs - 1) // 2
specs["lora_rot_T"] = ParamSpec((n_blocks, n_triu), init="zeros")
return specs
@staticmethod
def init(layer: nn.Module, cfg) -> None:
@@ -105,6 +109,8 @@ class AntiPaSTO:
W_res = (W - (Ur * Sr) @ Vhr).to(layer.weight.dtype)
layer.weight.data.copy_(W_res)
# FIXME antipasto needs an init from data too
@staticmethod
def forward(
layer: nn.Module,
+2 -2
View File
@@ -32,9 +32,9 @@ class DoRA:
def param_specs(d_in, d_out, cfg):
return dict(
lora_A=ParamSpec((cfg.r, d_in), init="kaiming"),
lora_B=ParamSpec((d_out, cfg.r), init="near_zero"),
lora_B=ParamSpec((d_out, cfg.r), init="zeros"),
# m is filled from ||W||_c during init(); shape (d_out,)
lora_m=ParamSpec((d_out,), init="near_zero"),
lora_m=ParamSpec((d_out,), init="zeros"),
)
@staticmethod
+1 -1
View File
@@ -46,7 +46,7 @@ class HRA:
return dict(
# Householder vectors stacked as rows (one vector per rank slot)
# init done in init() to enforce paired rows -> R = I at t=0.
lora_U=ParamSpec((cfg.r, d_in), init="near_zero"),
lora_U=ParamSpec((cfg.r, d_in), init="zeros"),
)
@staticmethod
+2 -2
View File
@@ -41,7 +41,7 @@ class IA3:
@staticmethod
def param_specs(d_in, d_out, cfg):
return dict(lora_g=ParamSpec((d_out,), init="near_one"))
return dict(lora_g=ParamSpec((d_out,), init="ones"))
@staticmethod
def init(layer: nn.Module, cfg) -> None:
@@ -62,7 +62,7 @@ class IA3FF:
@staticmethod
def param_specs(d_in, d_out, cfg):
return dict(lora_g=ParamSpec((d_in,), init="near_one"))
return dict(lora_g=ParamSpec((d_in,), init="ones"))
@staticmethod
def init(layer: nn.Module, cfg) -> None:
+1 -1
View File
@@ -32,7 +32,7 @@ class LoRA:
def param_specs(d_in, d_out, cfg):
return dict(
lora_A=ParamSpec((cfg.r, d_in), init="kaiming"),
lora_B=ParamSpec((d_out, cfg.r), init="near_zero"),
lora_B=ParamSpec((d_out, cfg.r), init="zeros"),
)
@staticmethod
+2 -2
View File
@@ -38,8 +38,8 @@ class PiSSA:
@staticmethod
def param_specs(d_in, d_out, cfg):
return dict(
lora_A=ParamSpec((cfg.r, d_in), init="near_zero"),
lora_B=ParamSpec((d_out, cfg.r), init="near_zero"),
lora_A=ParamSpec((cfg.r, d_in), init="zeros"),
lora_B=ParamSpec((d_out, cfg.r), init="zeros"),
)
@staticmethod
+2 -2
View File
@@ -117,8 +117,8 @@ class ROAD:
_validate_group_geometry(d_out, cfg.group_size)
size = _road_param_size(d_out, cfg.road_variant)
return dict(
lora_road_theta=ParamSpec((size,), init="near_zero"),
lora_road_alpha=ParamSpec((size,), init="near_one"),
lora_road_theta=ParamSpec((size,), init="zeros"),
lora_road_alpha=ParamSpec((size,), init="ones"),
)
@staticmethod