mirror of
https://github.com/wassname/lora-lite.git
synced 2026-06-27 16:45:56 +08:00
antipasto: delta_s init 4e-4+N(0,4e-4) from antipasto3, rotate_basis='none' option
This commit is contained in:
@@ -21,12 +21,11 @@ class ParamSpec:
|
|||||||
elif self.init == "zeros":
|
elif self.init == "zeros":
|
||||||
t.zero_()
|
t.zero_()
|
||||||
elif self.init == "near_zero":
|
elif self.init == "near_zero":
|
||||||
# ~identity init but breaks bf16 symmetry: N(0, eps) where eps is a few
|
# avoid exact-zero dead zone; N(0, 1e-4) is small enough to be
|
||||||
# orders above bf16 spacing at 0 (eps_bf16 ~ 1.2e-7). Avoids dead-grad
|
# ~identity but nonzero so gradients always have somewhere to go
|
||||||
# from exact-zero -> exact-zero in low-precision training.
|
|
||||||
t.normal_(0, 1e-4)
|
t.normal_(0, 1e-4)
|
||||||
elif self.init == "near_one":
|
elif self.init == "near_one":
|
||||||
# ~identity init for gate/scale params: 1 + N(0, eps)
|
# avoid exact-one dead zone; 1 + N(0, 1e-4)
|
||||||
t.fill_(1.0).add_(torch.randn_like(t).mul_(1e-4))
|
t.fill_(1.0).add_(torch.randn_like(t).mul_(1e-4))
|
||||||
elif self.init == "ones":
|
elif self.init == "ones":
|
||||||
t.fill_(1.0)
|
t.fill_(1.0)
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ wassname 2026 https://arxiv.org/abs/2601.07473
|
|||||||
R = block_diag(Cayley(skew(rot_T))); Vh_eff = R @ Vh (or U_eff = U @ R.T)
|
R = block_diag(Cayley(skew(rot_T))); Vh_eff = R @ Vh (or U_eff = U @ R.T)
|
||||||
y = x @ W_res.T + ((x @ Vh_eff.T) * (S + delta_s)) @ U_eff.T
|
y = x @ W_res.T + ((x @ Vh_eff.T) * (S + delta_s)) @ U_eff.T
|
||||||
|
|
||||||
Identity at t=0: rot_T~0 -> R≈I, delta_s~0 -> y ≈ x @ W^T (fp32 SVD round-trip). near_zero init breaks bf16 symmetry without meaningfully breaking identity (~1e-4 noise around zero).
|
Identity at t=0: rot_T=0 -> R=I, delta_s~4e-4 -> y ≈ x @ W^T (fp32 SVD round-trip, tiny positive bias on delta_s breaks sign symmetry).
|
||||||
|
|
||||||
Scope cut vs antipasto3: this is a fine-tuning adapter, not the full runtime
|
Scope cut vs antipasto3: this is a fine-tuning adapter, not the full runtime
|
||||||
steering interface. There is no per-call alpha, so it does not expose the
|
steering interface. There is no per-call alpha, so it does not expose the
|
||||||
@@ -41,8 +41,8 @@ class AntiPaSTOConfig(AdapterConfig):
|
|||||||
block_size: int = 4
|
block_size: int = 4
|
||||||
# Cayley map saturation: bounds rotation angle to ~max_rotation_angle radians.
|
# Cayley map saturation: bounds rotation angle to ~max_rotation_angle radians.
|
||||||
max_rotation_angle: float = 0.5
|
max_rotation_angle: float = 0.5
|
||||||
# Which singular basis to rotate: 'V' (input) or 'U' (output).
|
# Which singular basis to rotate: 'V' (input), 'U' (output), or 'none'.
|
||||||
rotate_basis: Literal["V", "U"] = "V"
|
rotate_basis: Literal["V", "U", "none"] = "V"
|
||||||
|
|
||||||
|
|
||||||
def _cayley(skew: torch.Tensor) -> torch.Tensor:
|
def _cayley(skew: torch.Tensor) -> torch.Tensor:
|
||||||
@@ -75,17 +75,21 @@ class AntiPaSTO:
|
|||||||
bs = int(cfg.block_size)
|
bs = int(cfg.block_size)
|
||||||
if r % bs != 0:
|
if r % bs != 0:
|
||||||
raise ValueError(f"AntiPaSTO requires r={r} divisible by block_size={bs}")
|
raise ValueError(f"AntiPaSTO requires r={r} divisible by block_size={bs}")
|
||||||
n_blocks = r // bs
|
specs = dict(
|
||||||
n_triu = bs * (bs - 1) // 2
|
|
||||||
return dict(
|
|
||||||
# Frozen SVD components captured at init.
|
# Frozen SVD components captured at init.
|
||||||
lora_U=ParamSpec((d_out, r), init="zeros", trainable=False, as_buffer=True),
|
lora_U=ParamSpec((d_out, r), init="zeros", trainable=False, as_buffer=True),
|
||||||
lora_S=ParamSpec((r,), init="zeros", trainable=False, as_buffer=True),
|
lora_S=ParamSpec((r,), init="zeros", trainable=False, as_buffer=True),
|
||||||
lora_Vh=ParamSpec((r, d_in), init="zeros", trainable=False, as_buffer=True),
|
lora_Vh=ParamSpec((r, d_in), init="zeros", trainable=False, as_buffer=True),
|
||||||
# Trainable: per-singular-value delta + block-diagonal Cayley rotation.
|
# Trainable: per-singular-value delta.
|
||||||
lora_delta_s=ParamSpec((r,), init="near_zero"),
|
# antipasto3 uses 4e-4 + N(0, 4e-4): small positive bias breaks sign
|
||||||
lora_rot_T=ParamSpec((n_blocks, n_triu), init="near_zero"),
|
# symmetry (rotation alone can't); zero-init works but trains slower.
|
||||||
|
lora_delta_s=ParamSpec((r,), init=lambda t: t.normal_(0, 4e-4).add_(4e-4)),
|
||||||
)
|
)
|
||||||
|
if cfg.rotate_basis != "none":
|
||||||
|
n_blocks = r // bs
|
||||||
|
n_triu = bs * (bs - 1) // 2
|
||||||
|
specs["lora_rot_T"] = ParamSpec((n_blocks, n_triu), init="zeros")
|
||||||
|
return specs
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def init(layer: nn.Module, cfg) -> None:
|
def init(layer: nn.Module, cfg) -> None:
|
||||||
@@ -105,6 +109,8 @@ class AntiPaSTO:
|
|||||||
W_res = (W - (Ur * Sr) @ Vhr).to(layer.weight.dtype)
|
W_res = (W - (Ur * Sr) @ Vhr).to(layer.weight.dtype)
|
||||||
layer.weight.data.copy_(W_res)
|
layer.weight.data.copy_(W_res)
|
||||||
|
|
||||||
|
# FIXME antipasto needs an init from data too
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(
|
def forward(
|
||||||
layer: nn.Module,
|
layer: nn.Module,
|
||||||
|
|||||||
@@ -32,9 +32,9 @@ class DoRA:
|
|||||||
def param_specs(d_in, d_out, cfg):
|
def param_specs(d_in, d_out, cfg):
|
||||||
return dict(
|
return dict(
|
||||||
lora_A=ParamSpec((cfg.r, d_in), init="kaiming"),
|
lora_A=ParamSpec((cfg.r, d_in), init="kaiming"),
|
||||||
lora_B=ParamSpec((d_out, cfg.r), init="near_zero"),
|
lora_B=ParamSpec((d_out, cfg.r), init="zeros"),
|
||||||
# m is filled from ||W||_c during init(); shape (d_out,)
|
# m is filled from ||W||_c during init(); shape (d_out,)
|
||||||
lora_m=ParamSpec((d_out,), init="near_zero"),
|
lora_m=ParamSpec((d_out,), init="zeros"),
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ class HRA:
|
|||||||
return dict(
|
return dict(
|
||||||
# Householder vectors stacked as rows (one vector per rank slot)
|
# Householder vectors stacked as rows (one vector per rank slot)
|
||||||
# init done in init() to enforce paired rows -> R = I at t=0.
|
# init done in init() to enforce paired rows -> R = I at t=0.
|
||||||
lora_U=ParamSpec((cfg.r, d_in), init="near_zero"),
|
lora_U=ParamSpec((cfg.r, d_in), init="zeros"),
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ class IA3:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def param_specs(d_in, d_out, cfg):
|
def param_specs(d_in, d_out, cfg):
|
||||||
return dict(lora_g=ParamSpec((d_out,), init="near_one"))
|
return dict(lora_g=ParamSpec((d_out,), init="ones"))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def init(layer: nn.Module, cfg) -> None:
|
def init(layer: nn.Module, cfg) -> None:
|
||||||
@@ -62,7 +62,7 @@ class IA3FF:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def param_specs(d_in, d_out, cfg):
|
def param_specs(d_in, d_out, cfg):
|
||||||
return dict(lora_g=ParamSpec((d_in,), init="near_one"))
|
return dict(lora_g=ParamSpec((d_in,), init="ones"))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def init(layer: nn.Module, cfg) -> None:
|
def init(layer: nn.Module, cfg) -> None:
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ class LoRA:
|
|||||||
def param_specs(d_in, d_out, cfg):
|
def param_specs(d_in, d_out, cfg):
|
||||||
return dict(
|
return dict(
|
||||||
lora_A=ParamSpec((cfg.r, d_in), init="kaiming"),
|
lora_A=ParamSpec((cfg.r, d_in), init="kaiming"),
|
||||||
lora_B=ParamSpec((d_out, cfg.r), init="near_zero"),
|
lora_B=ParamSpec((d_out, cfg.r), init="zeros"),
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@@ -38,8 +38,8 @@ class PiSSA:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def param_specs(d_in, d_out, cfg):
|
def param_specs(d_in, d_out, cfg):
|
||||||
return dict(
|
return dict(
|
||||||
lora_A=ParamSpec((cfg.r, d_in), init="near_zero"),
|
lora_A=ParamSpec((cfg.r, d_in), init="zeros"),
|
||||||
lora_B=ParamSpec((d_out, cfg.r), init="near_zero"),
|
lora_B=ParamSpec((d_out, cfg.r), init="zeros"),
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@@ -117,8 +117,8 @@ class ROAD:
|
|||||||
_validate_group_geometry(d_out, cfg.group_size)
|
_validate_group_geometry(d_out, cfg.group_size)
|
||||||
size = _road_param_size(d_out, cfg.road_variant)
|
size = _road_param_size(d_out, cfg.road_variant)
|
||||||
return dict(
|
return dict(
|
||||||
lora_road_theta=ParamSpec((size,), init="near_zero"),
|
lora_road_theta=ParamSpec((size,), init="zeros"),
|
||||||
lora_road_alpha=ParamSpec((size,), init="near_one"),
|
lora_road_alpha=ParamSpec((size,), init="ones"),
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
Reference in New Issue
Block a user