mirror of
https://github.com/wassname/lora-lite.git
synced 2026-06-27 17:16:12 +08:00
External review: per-variant audit + design notes
- Two acpx external reviews (codex + opencode): * docs/audit/variants_review.md: per-variant paper-vs-impl audit * docs/audit/design_review.md: peft EVA / baukit / antipasto3 vs lora-lite * docs/audit/SUMMARY.md: aggregate verdicts + 3 risks + 5 follow-ups - docs/refs/: peft_eva.py, peft_eva_finetuning.py, baukit_nethook.py, antipasto3_svd_adapter.py for offline reference Findings: LoRA clean; PiSSA/DoRA/IA3/HRA/DeLoRA have documented partial deviations. Top risks: init/grad tradeoffs hidden by coarse tests; qwen probe lacks strict identity tol; IA3 target placement untested.
This commit is contained in:
@@ -0,0 +1,375 @@
|
||||
"""SVD adapter for bidirectional steering via block-diagonal Cayley rotations.
|
||||
|
||||
Flax NNX port.
|
||||
|
||||
kernel = U @ diag(S) @ Vh + W_res (kernel is (in, out), standard Flax convention)
|
||||
Learnable: delta_s (additive S scaling), rotation_params (block-diagonal V rotation).
|
||||
alpha scales both: S + alpha*delta_s, U @ R(alpha).
|
||||
|
||||
Why Cayley (not Givens or matrix exponential):
|
||||
Cayley gives exact analytical reversibility: R(-alpha) = R(alpha)^{-1}.
|
||||
This is critical -- at alpha=+1 and alpha=-1 the adapter is an exact inverse of
|
||||
itself, making bidirectional steering symmetric by construction.
|
||||
|
||||
At alpha=0: U_rot = U and S_scaled = S, so the layer is identical to frozen weights.
|
||||
"""
|
||||
|
||||
import math
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import Array
|
||||
from jaxtyping import Float, Int
|
||||
from einops import rearrange
|
||||
from flax import nnx
|
||||
from loguru import logger
|
||||
|
||||
|
||||
# -- Custom variable types for gradient filtering ----------------------------
|
||||
|
||||
class SVDParam(nnx.Param):
|
||||
"""Trainable SVD adapter parameter (base class -- use subclasses for per-group LR)."""
|
||||
pass
|
||||
|
||||
|
||||
class DeltaSParam(SVDParam):
|
||||
"""Trainable delta_s scaling parameters (full LR)."""
|
||||
pass
|
||||
|
||||
|
||||
class RotationParam(SVDParam):
|
||||
"""Block-diagonal rotation parameters (lower LR via rotation_lr_scale config)."""
|
||||
pass
|
||||
|
||||
|
||||
class SVDFrozen(nnx.Variable):
|
||||
"""Frozen SVD component. Not differentiated."""
|
||||
pass
|
||||
|
||||
|
||||
# -- SVD Steering Linear (replaces nnx.Linear) ------------------------------
|
||||
|
||||
class SVDSteeringLinear(nnx.Module):
|
||||
"""SVD steering adapter replacing a linear layer.
|
||||
|
||||
y = ((x @ U_rot) * S_scaled) @ Vh + x @ W_res
|
||||
|
||||
where U_rot, S_scaled depend on alpha (the steering coefficient).
|
||||
Frozen: U, S, Vh, W_res. Learnable: delta_s, rotation_params.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
U: jax.Array, # (in_dim, r) - input singular vectors
|
||||
S: jax.Array, # (r,) - singular values
|
||||
Vh: jax.Array, # (r, out_dim) - output singular vectors
|
||||
W_res: jax.Array, # (in_dim, out_dim) - residual
|
||||
rotation_block_size: int,
|
||||
max_rotation_angle: float,
|
||||
rotate_U: bool = True,
|
||||
rotate_V: bool = False,
|
||||
use_delta_s: bool = True,
|
||||
*,
|
||||
rngs: nnx.Rngs,
|
||||
):
|
||||
r = S.shape[0]
|
||||
bs = min(rotation_block_size, r)
|
||||
assert r % bs == 0, f"r={r} must be divisible by block_size={bs}"
|
||||
|
||||
# Frozen SVD components
|
||||
self.svd_U = SVDFrozen(U.astype(jnp.float32))
|
||||
self.svd_S = SVDFrozen(S.astype(jnp.float32))
|
||||
self.svd_Vh = SVDFrozen(Vh.astype(jnp.float32))
|
||||
self.svd_W_res = SVDFrozen(W_res.astype(jnp.bfloat16))
|
||||
|
||||
# Trainable: delta_s with small positive bias for symmetry breaking.
|
||||
# The +4e-4 nudges the optimizer to scale up selected dims rather than just
|
||||
# rotating them. Rotation alone cannot break sign symmetry at init.
|
||||
key_s, key_r = jax.random.split(rngs.params())
|
||||
self.delta_s = DeltaSParam(
|
||||
jax.random.truncated_normal(key_s, -2.0, 2.0, (r,)) * 4e-4 + 4e-4
|
||||
)
|
||||
|
||||
# Block-diagonal skew-symmetric rotation params (lower LR via RotationParam type)
|
||||
# Upper-triangle parameterization: store only bs*(bs-1)/2 elements per block,
|
||||
# like OFT/PSOFT. Avoids dead diagonal gradients and redundant (i,j)/(j,i) states.
|
||||
n_blocks = r // bs
|
||||
n_triu = bs * (bs - 1) // 2
|
||||
self.rotation_params = RotationParam(
|
||||
jax.random.truncated_normal(key_r, -2.0, 2.0, (n_blocks, n_triu)) * 1e-4
|
||||
)
|
||||
# Pre-compute upper-triangle indices for skew-symmetric reconstruction
|
||||
rows, cols = jnp.triu_indices(bs, k=1)
|
||||
self._triu_rows = rows
|
||||
self._triu_cols = cols
|
||||
|
||||
# Steering coefficient (mutated during 3-pass forward)
|
||||
self.alpha = nnx.Variable(jnp.float32(1.0))
|
||||
self.max_angle = max_rotation_angle
|
||||
self.block_size = bs
|
||||
self.r = r
|
||||
self.rotate_U = rotate_U
|
||||
self.rotate_V = rotate_V
|
||||
self.use_delta_s = use_delta_s
|
||||
|
||||
def __call__(self, x: Float[Array, "*batch in_features"]) -> Float[Array, "*batch out_features"]:
|
||||
alpha = self.alpha.value
|
||||
U = self.svd_U.value
|
||||
S = self.svd_S.value
|
||||
Vh = self.svd_Vh.value
|
||||
W_res = self.svd_W_res.value
|
||||
params = self.rotation_params.value # (n_blocks, n_triu)
|
||||
bs = self.block_size
|
||||
n_blocks = params.shape[0]
|
||||
|
||||
# Reconstruct skew-symmetric from upper-triangle params (like OFT/PSOFT).
|
||||
# 0.5 factor matches BOFT convention: cancels the 2x gradient from A - A^T.
|
||||
A = jnp.zeros((n_blocks, bs, bs), dtype=jnp.float32)
|
||||
A = A.at[:, self._triu_rows, self._triu_cols].set(params.astype(jnp.float32))
|
||||
A = 0.5 * (A - jnp.swapaxes(A, -1, -2))
|
||||
|
||||
# Angle clamping (element-wise tanh, bounds bidirectional symmetry error)
|
||||
a_limit = 2 * math.tan(self.max_angle / 2)
|
||||
A = a_limit * jnp.tanh(A / a_limit)
|
||||
|
||||
# Cayley transform in float32: R = (I - X)^{-1}(I + X)
|
||||
eye = jnp.eye(bs, dtype=jnp.float32)
|
||||
X = alpha * A / 2
|
||||
R_blocks = jnp.linalg.solve(
|
||||
eye[None] - X,
|
||||
eye[None] + X,
|
||||
)
|
||||
|
||||
# Apply rotation to U (input singular vectors)
|
||||
if self.rotate_U:
|
||||
U_reshaped = U.reshape(U.shape[0], n_blocks, bs)
|
||||
U_rot = jnp.einsum('dnb,nbc->dnc', U_reshaped, R_blocks)
|
||||
U_rot = U_rot.reshape(U.shape)
|
||||
else:
|
||||
U_rot = U
|
||||
|
||||
# Apply rotation to Vh (output singular vectors); off by default
|
||||
# (output rotation changes the upstream basis, making adaptation harder)
|
||||
if self.rotate_V:
|
||||
Vh_reshaped = Vh.reshape(n_blocks, bs, Vh.shape[1])
|
||||
Vh_rot = jnp.einsum('nbc,nbj->ncj', R_blocks, Vh_reshaped)
|
||||
Vh_rot = Vh_rot.reshape(Vh.shape)
|
||||
else:
|
||||
Vh_rot = Vh
|
||||
|
||||
S_scaled = S + alpha * self.delta_s.value if self.use_delta_s else S
|
||||
|
||||
dt = x.dtype
|
||||
out = (x @ U_rot.astype(dt)) * S_scaled.astype(dt)
|
||||
out = out @ Vh_rot.astype(dt)
|
||||
out = out + x @ W_res.astype(dt)
|
||||
return out
|
||||
|
||||
|
||||
def create_svd_adapter(
|
||||
kernel: jax.Array,
|
||||
r: int,
|
||||
rotation_block_size: int,
|
||||
max_rotation_angle: float,
|
||||
rngs: nnx.Rngs,
|
||||
selected_indices: jax.Array | None = None,
|
||||
rotate_U: bool = True,
|
||||
rotate_V: bool = False,
|
||||
use_delta_s: bool = True,
|
||||
) -> SVDSteeringLinear:
|
||||
"""Create SVD adapter from a kernel matrix (in_dim, out_dim).
|
||||
|
||||
If selected_indices is provided, uses those SVD dimensions.
|
||||
Otherwise uses top-r by singular value.
|
||||
"""
|
||||
kernel_f32 = kernel.astype(jnp.float32)
|
||||
U_full, S_full, Vh_full = jnp.linalg.svd(kernel_f32, full_matrices=False)
|
||||
|
||||
|
||||
r_actual = min(r, S_full.shape[0])
|
||||
# Ensure divisible by block size
|
||||
bs = min(rotation_block_size, r_actual)
|
||||
r_actual = (r_actual // bs) * bs
|
||||
if r_actual == 0:
|
||||
r_actual = bs
|
||||
|
||||
if selected_indices is not None:
|
||||
indices = selected_indices[:r_actual]
|
||||
else:
|
||||
indices = jnp.arange(r_actual)
|
||||
|
||||
U = U_full[:, indices]
|
||||
S = S_full[indices]
|
||||
Vh = Vh_full[indices, :]
|
||||
W_res = kernel_f32 - U @ jnp.diag(S) @ Vh
|
||||
|
||||
return SVDSteeringLinear(
|
||||
U, S, Vh, W_res,
|
||||
rotation_block_size=bs,
|
||||
max_rotation_angle=max_rotation_angle,
|
||||
rotate_U=rotate_U,
|
||||
rotate_V=rotate_V,
|
||||
use_delta_s=use_delta_s,
|
||||
rngs=rngs,
|
||||
)
|
||||
|
||||
|
||||
# -- Dimension selection (data-aware) ----------------------------------------
|
||||
|
||||
def score_l1_trip(
|
||||
acts_projected: Float[Array, "n k"], S: Float[Array, " k"], r: int,
|
||||
) -> Int[Array, " selected"]:
|
||||
"""L1 trip scoring: union of top dims from 4 pools (cho, rej, diff_pos, diff_neg).
|
||||
|
||||
Why not top-r by singular value? That picks globally "important" dimensions but
|
||||
ignores whether they are active in the contrastive data. This approach takes:
|
||||
r/3 cho-active, r/3 rej-active, r/6 diff_pos, r/6 diff_neg
|
||||
ensuring all signal types (absolute activation and contrastive difference) are
|
||||
represented in the selected subspace.
|
||||
"""
|
||||
k = S.shape[0]
|
||||
assert r < k
|
||||
act_cho = acts_projected[::2]
|
||||
act_rej = acts_projected[1::2]
|
||||
|
||||
l1_cho = jnp.abs(act_cho).mean(axis=0)
|
||||
l1_rej = jnp.abs(act_rej).mean(axis=0)
|
||||
diff = (act_cho - act_rej).mean(axis=0)
|
||||
|
||||
scores_cho = S * l1_cho
|
||||
scores_rej = S * l1_rej
|
||||
scores_diff_pos = S * jax.nn.relu(diff)
|
||||
scores_diff_neg = S * jax.nn.relu(-diff)
|
||||
|
||||
third = r // 3
|
||||
sixth = (r - 2 * third) // 2
|
||||
sixth_rem = r - 2 * third - 2 * sixth
|
||||
|
||||
top_cho = jnp.argsort(-scores_cho)[:third]
|
||||
top_rej = jnp.argsort(-scores_rej)[:third]
|
||||
top_diff_pos = jnp.argsort(-scores_diff_pos)[:sixth + sixth_rem]
|
||||
top_diff_neg = jnp.argsort(-scores_diff_neg)[:sixth]
|
||||
|
||||
combined = jnp.unique(jnp.concatenate([top_cho, top_rej, top_diff_pos, top_diff_neg]))
|
||||
|
||||
if combined.shape[0] < r:
|
||||
scores_union = jnp.maximum(
|
||||
jnp.maximum(scores_cho, scores_rej),
|
||||
jnp.maximum(scores_diff_pos, scores_diff_neg),
|
||||
)
|
||||
# Mask out already-selected indices
|
||||
mask = jnp.zeros(k, dtype=jnp.bool_)
|
||||
mask = mask.at[combined].set(True)
|
||||
scores_union = jnp.where(mask, -jnp.inf, scores_union)
|
||||
extra = jnp.argsort(-scores_union)[:r - combined.shape[0]]
|
||||
combined = jnp.concatenate([combined, extra])
|
||||
|
||||
return jnp.sort(combined[:r])
|
||||
|
||||
|
||||
|
||||
def polarity_interleave(acts_projected: jax.Array, indices: jax.Array) -> jax.Array:
|
||||
"""Reorder indices so consecutive pairs alternate cho/rej-favoring dims.
|
||||
|
||||
Block-diagonal rotation couples dims within each block of block_size.
|
||||
If all dims in a block favor the same direction (all cho-favoring), the block
|
||||
cannot learn bidirectional steering. Interleaving forces each block to have a
|
||||
mix of cho-favoring and rej-favoring dims, enabling bidirectional learning.
|
||||
"""
|
||||
r = indices.shape[0]
|
||||
assert r % 2 == 0
|
||||
diff_signed = (acts_projected[::2, :][:, indices] - acts_projected[1::2, :][:, indices]).mean(axis=0)
|
||||
rank_order = jnp.argsort(-diff_signed)
|
||||
n_half = r // 2
|
||||
cho_ranked = rank_order[:n_half]
|
||||
rej_ranked = rank_order[n_half:][::-1]
|
||||
interleaved = jnp.stack([cho_ranked, rej_ranked], axis=1).reshape(-1)
|
||||
return indices[interleaved]
|
||||
|
||||
|
||||
# -- Attention output adapter ------------------------------------------------
|
||||
|
||||
class SVDAttnOutAdapter(nnx.Module):
|
||||
"""Drop-in for tunix Einsum('BTNH,NHD->BTD').
|
||||
|
||||
Reshapes encoded [b,t,N,H] -> [b,t,N*H], applies SVDSteeringLinear, returns [b,t,D].
|
||||
Exposes .shape = (N, H, D) so tunix Attention.head_dim/.features still work.
|
||||
"""
|
||||
def __init__(self, svd_linear: SVDSteeringLinear, num_heads: int, head_dim: int):
|
||||
self.svd_linear = svd_linear
|
||||
# Tuple attribute: read-only metadata for tunix property access
|
||||
self.shape = (num_heads, head_dim, svd_linear.svd_Vh.value.shape[1])
|
||||
|
||||
def __call__(self, encoded: jax.Array) -> jax.Array:
|
||||
b, t, N, H = encoded.shape
|
||||
return self.svd_linear(encoded.reshape(b, t, N * H))
|
||||
|
||||
|
||||
class SVDAttnQAdapter(nnx.Module):
|
||||
"""Drop-in for tunix Einsum('BTD,NDH->BTNH') -- Q projection in GQA.
|
||||
|
||||
Reshapes weight (N, D, H) -> (D, N*H) for SVD. On forward, applies SVDSteeringLinear
|
||||
then rearranges output back to [b, t, N, H].
|
||||
Exposes .shape = (N, D, H) so tunix Attention.num_heads reads shape[0].
|
||||
"""
|
||||
def __init__(self, svd_linear: SVDSteeringLinear, num_heads: int, features: int, head_dim: int):
|
||||
self.svd_linear = svd_linear
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
self.shape = (num_heads, features, head_dim)
|
||||
|
||||
def __call__(self, x: Float[Array, "b t D"]) -> Float[Array, "b t N H"]:
|
||||
b, t, _ = x.shape
|
||||
flat = rearrange(x, 'b t d -> (b t) d')
|
||||
out = self.svd_linear(flat)
|
||||
return rearrange(out, '(b t) (N H) -> b t N H', b=b, t=t, N=self.num_heads, H=self.head_dim)
|
||||
|
||||
|
||||
class SVDAttnKVAdapter(nnx.Module):
|
||||
"""Drop-in for tunix Einsum('BSD,CKDH->CBSKH') -- KV projection in GQA.
|
||||
|
||||
K and V share one einsum with weight (C=2, K, D, H). Reshapes to (D, C*K*H) for SVD.
|
||||
On forward, applies SVDSteeringLinear then rearranges to [C, b, t, K, H].
|
||||
Output is tuple-unpacked: key_proj, value_proj = kv_einsum(x).
|
||||
Exposes .shape = (C, K, D, H) so tunix Attention.num_kv_heads reads shape[1].
|
||||
"""
|
||||
def __init__(self, svd_linear: SVDSteeringLinear, C: int, num_kv_heads: int, features: int, head_dim: int):
|
||||
self.svd_linear = svd_linear
|
||||
self.C = C
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.head_dim = head_dim
|
||||
self.shape = (C, num_kv_heads, features, head_dim)
|
||||
|
||||
def __call__(self, x: Float[Array, "b t D"]) -> Float[Array, "C b t K H"]:
|
||||
b, t, _ = x.shape
|
||||
flat = rearrange(x, 'b t d -> (b t) d')
|
||||
out = self.svd_linear(flat)
|
||||
return rearrange(out, '(b t) (C K H) -> C b t K H', b=b, t=t, C=self.C, K=self.num_kv_heads, H=self.head_dim)
|
||||
|
||||
|
||||
# -- Utilities ---------------------------------------------------------------
|
||||
|
||||
def set_alpha(model: nnx.Module, alpha: float):
|
||||
"""Set steering coefficient for all SVD adapter layers."""
|
||||
for _, value in nnx.iter_graph(model):
|
||||
if isinstance(value, SVDSteeringLinear):
|
||||
value.alpha.value = jnp.float32(alpha)
|
||||
|
||||
|
||||
def get_svd_modules(model: nnx.Module) -> list[SVDSteeringLinear]:
|
||||
"""Get all SVD steering modules in the model."""
|
||||
modules = []
|
||||
for _, value in nnx.iter_graph(model):
|
||||
if isinstance(value, SVDSteeringLinear):
|
||||
modules.append(value)
|
||||
return modules
|
||||
|
||||
|
||||
def monitor_svd_adapters(model: nnx.Module) -> dict:
|
||||
"""Monitor ||delta_s||/||S|| ratio."""
|
||||
ratios = []
|
||||
for m in get_svd_modules(model):
|
||||
S = m.svd_S.value
|
||||
ds = m.delta_s.value
|
||||
ratios.append(float(jnp.linalg.norm(ds) / jnp.linalg.norm(S)))
|
||||
return {"adapter_ratio": max(ratios) if ratios else 0.0}
|
||||
@@ -0,0 +1,471 @@
|
||||
"""
|
||||
Utilities for instrumenting a torch model.
|
||||
|
||||
Trace will hook one layer at a time.
|
||||
TraceDict will hook multiple layers at once.
|
||||
subsequence slices intervals from Sequential modules.
|
||||
get_module, replace_module, get_parameter resolve dotted names.
|
||||
set_requires_grad recursively sets requires_grad in module parameters.
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import copy
|
||||
import inspect
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class Trace(contextlib.AbstractContextManager):
|
||||
"""
|
||||
To retain the output of the named layer during the computation of
|
||||
the given network:
|
||||
|
||||
with Trace(net, 'layer.name') as ret:
|
||||
_ = net(inp)
|
||||
representation = ret.output
|
||||
|
||||
A layer module can be passed directly without a layer name, and
|
||||
its output will be retained. By default, a direct reference to
|
||||
the output object is returned, but options can control this:
|
||||
|
||||
clone=True - retains a copy of the output, which can be
|
||||
useful if you want to see the output before it might
|
||||
be modified by the network in-place later.
|
||||
detach=True - retains a detached reference or copy. (By
|
||||
default the value would be left attached to the graph.)
|
||||
retain_grad=True - request gradient to be retained on the
|
||||
output. After backward(), ret.output.grad is populated.
|
||||
|
||||
retain_input=True - also retains the input.
|
||||
retain_output=False - can disable retaining the output.
|
||||
edit_output=fn - calls the function to modify the output
|
||||
of the layer before passing it the rest of the model.
|
||||
fn can optionally accept (output, layer) arguments
|
||||
for the original output and the layer name.
|
||||
stop=True - throws a StopForward exception after the layer
|
||||
is run, which allows running just a portion of a model.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
module,
|
||||
layer=None,
|
||||
retain_output=True,
|
||||
retain_input=False,
|
||||
clone=False,
|
||||
detach=False,
|
||||
retain_grad=False,
|
||||
edit_output=None,
|
||||
stop=False,
|
||||
):
|
||||
"""
|
||||
Method to replace a forward method with a closure that
|
||||
intercepts the call, and tracks the hook so that it can be reverted.
|
||||
"""
|
||||
retainer = self
|
||||
self.layer = layer
|
||||
if layer is not None:
|
||||
module = get_module(module, layer)
|
||||
|
||||
def retain_hook(m, inputs, output):
|
||||
if edit_output:
|
||||
output = invoke_with_optional_args(
|
||||
edit_output, output=output, layer=self.layer, inputs=inputs
|
||||
)
|
||||
if retain_input:
|
||||
retainer.input = recursive_copy(
|
||||
inputs[0] if len(inputs) == 1 else inputs,
|
||||
clone=clone,
|
||||
detach=detach,
|
||||
retain_grad=False,
|
||||
) # retain_grad applies to output only.
|
||||
if retain_output:
|
||||
retainer.output = recursive_copy(
|
||||
output, clone=clone, detach=detach, retain_grad=retain_grad
|
||||
)
|
||||
# When retain_grad is set, also insert a trivial
|
||||
# copy operation. That allows in-place operations
|
||||
# to follow without error.
|
||||
if retain_grad:
|
||||
output = recursive_copy(retainer.output, clone=True, detach=False)
|
||||
if stop:
|
||||
raise StopForward()
|
||||
return output
|
||||
|
||||
self.registered_hook = module.register_forward_hook(retain_hook)
|
||||
self.stop = stop
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
self.close()
|
||||
if self.stop and issubclass(type, StopForward):
|
||||
return True
|
||||
|
||||
def close(self):
|
||||
self.registered_hook.remove()
|
||||
|
||||
|
||||
class TraceDict(OrderedDict, contextlib.AbstractContextManager):
|
||||
"""
|
||||
To retain the output of multiple named layers during the computation
|
||||
of the given network:
|
||||
|
||||
with TraceDict(net, ['layer1.name1', 'layer2.name2']) as ret:
|
||||
_ = net(inp)
|
||||
representation = ret['layer1.name1'].output
|
||||
|
||||
If edit_output is provided, it should be a function that takes
|
||||
two arguments: output, and the layer name; and then it returns the
|
||||
modified output.
|
||||
|
||||
Other arguments are the same as Trace. If stop is True, then the
|
||||
execution of the network will be stopped after the last layer
|
||||
listed (even if it would not have been the last to be executed).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
module,
|
||||
layers=None,
|
||||
retain_output=True,
|
||||
retain_input=False,
|
||||
clone=False,
|
||||
detach=False,
|
||||
retain_grad=False,
|
||||
edit_output=None,
|
||||
stop=False,
|
||||
):
|
||||
self.stop = stop
|
||||
|
||||
def flag_last_unseen(it):
|
||||
try:
|
||||
it = iter(it)
|
||||
prev = next(it)
|
||||
seen = set([prev])
|
||||
except StopIteration:
|
||||
return
|
||||
for item in it:
|
||||
if item not in seen:
|
||||
yield False, prev
|
||||
seen.add(item)
|
||||
prev = item
|
||||
yield True, prev
|
||||
|
||||
for is_last, layer in flag_last_unseen(layers):
|
||||
|
||||
def optional_dict(obj):
|
||||
if isinstance(obj, dict):
|
||||
return obj.get(layer, None)
|
||||
return obj
|
||||
|
||||
self[layer] = Trace(
|
||||
module=module,
|
||||
layer=layer,
|
||||
retain_output=optional_dict(retain_output),
|
||||
retain_input=optional_dict(retain_input),
|
||||
clone=optional_dict(clone),
|
||||
detach=optional_dict(detach),
|
||||
retain_grad=optional_dict(retain_grad),
|
||||
edit_output=optional_dict(edit_output),
|
||||
stop=stop and is_last,
|
||||
)
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
self.close()
|
||||
if self.stop and issubclass(type, StopForward):
|
||||
return True
|
||||
|
||||
def close(self):
|
||||
for layer, trace in reversed(self.items()):
|
||||
trace.close()
|
||||
|
||||
|
||||
class StopForward(Exception):
|
||||
"""
|
||||
If the only output needed from running a network is the retained
|
||||
submodule then Trace(submodule, stop=True) will stop execution
|
||||
immediately after the retained submodule by raising the StopForward()
|
||||
exception. When Trace is used as context manager, it catches that
|
||||
exception and can be used as follows:
|
||||
|
||||
with Trace(net, layername, stop=True) as tr:
|
||||
net(inp) # Only runs the network up to layername
|
||||
print(tr.output)
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def recursive_copy(x, clone=None, detach=None, retain_grad=None):
|
||||
"""
|
||||
Copies a reference to a tensor, or an object that contains tensors,
|
||||
optionally detaching and cloning the tensor(s). If retain_grad is
|
||||
true, the original tensors are marked to have grads retained.
|
||||
"""
|
||||
if not clone and not detach and not retain_grad:
|
||||
return x
|
||||
if isinstance(x, torch.Tensor):
|
||||
if retain_grad:
|
||||
if not x.requires_grad:
|
||||
x.requires_grad = True
|
||||
x.retain_grad()
|
||||
elif detach:
|
||||
x = x.detach()
|
||||
if clone:
|
||||
x = x.clone()
|
||||
return x
|
||||
# Only dicts, lists, and tuples (and subclasses) can be copied.
|
||||
if isinstance(x, dict):
|
||||
return type(x)({k: recursive_copy(v, clone=clone, detach=detach, retain_grad=retain_grad) for k, v in x.items()})
|
||||
elif isinstance(x, (list, tuple)):
|
||||
return type(x)([recursive_copy(v, clone=clone, detach=detach, retain_grad=retain_grad) for v in x])
|
||||
else:
|
||||
assert False, f"Unknown type {type(x)} cannot be broken into tensors."
|
||||
|
||||
|
||||
def subsequence(
|
||||
sequential,
|
||||
first_layer=None,
|
||||
last_layer=None,
|
||||
after_layer=None,
|
||||
upto_layer=None,
|
||||
single_layer=None,
|
||||
share_weights=False,
|
||||
):
|
||||
"""
|
||||
Creates a subsequence of a pytorch Sequential model, copying over
|
||||
modules together with parameters for the subsequence. Only
|
||||
modules from first_layer to last_layer (inclusive) are included,
|
||||
or modules between after_layer and upto_layer (exclusive).
|
||||
Handles descent into dotted layer names as long as all references
|
||||
are within nested Sequential models.
|
||||
|
||||
If share_weights is True, then references the original modules
|
||||
and their parameters without copying them. Otherwise, by default,
|
||||
makes a separate brand-new copy.
|
||||
"""
|
||||
assert (single_layer is None) or (
|
||||
first_layer is last_layer is after_layer is upto_layer is None
|
||||
)
|
||||
if single_layer is not None:
|
||||
first_layer = single_layer
|
||||
last_layer = single_layer
|
||||
first, last, after, upto = [
|
||||
None if d is None else d.split(".")
|
||||
for d in [first_layer, last_layer, after_layer, upto_layer]
|
||||
]
|
||||
return hierarchical_subsequence(
|
||||
sequential,
|
||||
first=first,
|
||||
last=last,
|
||||
after=after,
|
||||
upto=upto,
|
||||
share_weights=share_weights,
|
||||
)
|
||||
|
||||
|
||||
def hierarchical_subsequence(
|
||||
sequential, first, last, after, upto, share_weights=False, depth=0
|
||||
):
|
||||
"""
|
||||
Recursive helper for subsequence() to support descent into dotted
|
||||
layer names. In this helper, first, last, after, and upto are
|
||||
arrays of names resulting from splitting on dots. Can only
|
||||
descend into nested Sequentials.
|
||||
"""
|
||||
assert (last is None) or (upto is None)
|
||||
assert (first is None) or (after is None)
|
||||
if first is last is after is upto is None:
|
||||
return sequential if share_weights else copy.deepcopy(sequential)
|
||||
assert isinstance(sequential, torch.nn.Sequential), (
|
||||
".".join((first or last or after or upto)[:depth] or "arg") + " not Sequential"
|
||||
)
|
||||
including_children = (first is None) and (after is None)
|
||||
included_children = OrderedDict()
|
||||
# A = current level short name of A.
|
||||
# AN = full name for recursive descent if not innermost.
|
||||
(F, FN), (L, LN), (A, AN), (U, UN) = [
|
||||
(d[depth], (None if len(d) == depth + 1 else d))
|
||||
if d is not None
|
||||
else (None, None)
|
||||
for d in [first, last, after, upto]
|
||||
]
|
||||
for name, layer in sequential._modules.items():
|
||||
if name == F:
|
||||
first = None
|
||||
including_children = True
|
||||
if name == A and AN is not None: # just like F if not a leaf.
|
||||
after = None
|
||||
including_children = True
|
||||
if name == U and UN is None:
|
||||
upto = None
|
||||
including_children = False
|
||||
if including_children:
|
||||
# AR = full name for recursive descent if name matches.
|
||||
FR, LR, AR, UR = [
|
||||
n if n is None or n[depth] == name else None for n in [FN, LN, AN, UN]
|
||||
]
|
||||
chosen = hierarchical_subsequence(
|
||||
layer,
|
||||
first=FR,
|
||||
last=LR,
|
||||
after=AR,
|
||||
upto=UR,
|
||||
share_weights=share_weights,
|
||||
depth=depth + 1,
|
||||
)
|
||||
if chosen is not None:
|
||||
included_children[name] = chosen
|
||||
if name == L:
|
||||
last = None
|
||||
including_children = False
|
||||
if name == U and UN is not None: # just like L if not a leaf.
|
||||
upto = None
|
||||
including_children = False
|
||||
if name == A and AN is None:
|
||||
after = None
|
||||
including_children = True
|
||||
for name in [first, last, after, upto]:
|
||||
if name is not None:
|
||||
raise ValueError("Layer %s not found" % ".".join(name))
|
||||
# Omit empty subsequences except at the outermost level,
|
||||
# where we should not return None.
|
||||
if not len(included_children) and depth > 0:
|
||||
return None
|
||||
result = torch.nn.Sequential(included_children)
|
||||
result.training = sequential.training
|
||||
return result
|
||||
|
||||
|
||||
def set_requires_grad(requires_grad, *models):
|
||||
"""
|
||||
Sets requires_grad true or false for all parameters within the
|
||||
models passed.
|
||||
"""
|
||||
for model in models:
|
||||
if isinstance(model, torch.nn.Module):
|
||||
for param in model.parameters():
|
||||
param.requires_grad = requires_grad
|
||||
elif isinstance(model, (torch.nn.Parameter, torch.Tensor)):
|
||||
model.requires_grad = requires_grad
|
||||
else:
|
||||
assert False, "unknown type %r" % type(model)
|
||||
|
||||
|
||||
def get_module(model, name):
|
||||
"""
|
||||
Finds the named module within the given model.
|
||||
"""
|
||||
for n, m in model.named_modules():
|
||||
if n == name:
|
||||
return m
|
||||
raise LookupError(name)
|
||||
|
||||
|
||||
def get_parameter(model, name):
|
||||
"""
|
||||
Finds the named parameter within the given model.
|
||||
"""
|
||||
for n, p in model.named_parameters():
|
||||
if n == name:
|
||||
return p
|
||||
raise LookupError(name)
|
||||
|
||||
|
||||
def module_names(model):
|
||||
"""
|
||||
Lists all the module names.
|
||||
"""
|
||||
return [n for n, _ in model.named_modules()]
|
||||
|
||||
|
||||
def parameter_names(model):
|
||||
"""
|
||||
Lists all the parameter names.
|
||||
"""
|
||||
return [n for n, _ in model.named_parameters()]
|
||||
|
||||
|
||||
def replace_module(model, name, new_module):
|
||||
"""
|
||||
Replaces the named module within the given model.
|
||||
"""
|
||||
if "." in name:
|
||||
parent_name, attr_name = name.rsplit(".", 1)
|
||||
model = get_module(model, parent_name)
|
||||
# original_module = getattr(model, attr_name)
|
||||
setattr(model, attr_name, new_module)
|
||||
|
||||
|
||||
def invoke_with_optional_args(fn, *args, **kwargs):
|
||||
"""
|
||||
Invokes a function with only the arguments that it
|
||||
is written to accept, giving priority to arguments
|
||||
that match by-name, using the following rules.
|
||||
(1) arguments with matching names are passed by name.
|
||||
(2) remaining non-name-matched args are passed by order.
|
||||
(3) extra caller arguments that the function cannot
|
||||
accept are not passed.
|
||||
(4) extra required function arguments that the caller
|
||||
cannot provide cause a TypeError to be raised.
|
||||
Ordinary python calling conventions are helpful for
|
||||
supporting a function that might be revised to accept
|
||||
extra arguments in a newer version, without requiring the
|
||||
caller to pass those new arguments. This function helps
|
||||
support function callers that might be revised to supply
|
||||
extra arguments, without requiring the callee to accept
|
||||
those new arguments.
|
||||
"""
|
||||
argspec = inspect.getfullargspec(fn)
|
||||
pass_args = []
|
||||
used_kw = set()
|
||||
unmatched_pos = []
|
||||
used_pos = 0
|
||||
defaulted_pos = len(argspec.args) - (
|
||||
0 if not argspec.defaults else len(argspec.defaults)
|
||||
)
|
||||
# Pass positional args that match name first, then by position.
|
||||
for i, n in enumerate(argspec.args):
|
||||
if n in kwargs:
|
||||
pass_args.append(kwargs[n])
|
||||
used_kw.add(n)
|
||||
elif used_pos < len(args):
|
||||
pass_args.append(args[used_pos])
|
||||
used_pos += 1
|
||||
else:
|
||||
unmatched_pos.append(len(pass_args))
|
||||
pass_args.append(
|
||||
None if i < defaulted_pos else argspec.defaults[i - defaulted_pos]
|
||||
)
|
||||
# Fill unmatched positional args with unmatched keyword args in order.
|
||||
if len(unmatched_pos):
|
||||
for k, v in kwargs.items():
|
||||
if k in used_kw or k in argspec.kwonlyargs:
|
||||
continue
|
||||
pass_args[unmatched_pos[0]] = v
|
||||
used_kw.add(k)
|
||||
unmatched_pos = unmatched_pos[1:]
|
||||
if len(unmatched_pos) == 0:
|
||||
break
|
||||
else:
|
||||
if unmatched_pos[0] < defaulted_pos:
|
||||
unpassed = ", ".join(
|
||||
argspec.args[u] for u in unmatched_pos if u < defaulted_pos
|
||||
)
|
||||
raise TypeError(f"{fn.__name__}() cannot be passed {unpassed}.")
|
||||
# Pass remaining kw args if they can be accepted.
|
||||
pass_kw = {
|
||||
k: v
|
||||
for k, v in kwargs.items()
|
||||
if k not in used_kw and (k in argspec.kwonlyargs or argspec.varargs is not None)
|
||||
}
|
||||
# Pass remaining positional args if they can be accepted.
|
||||
if argspec.varargs is not None:
|
||||
pass_args += list(args[used_pos:])
|
||||
return fn(*pass_args, **pass_kw)
|
||||
@@ -0,0 +1,739 @@
|
||||
# Copyright 2024-present the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import warnings
|
||||
from collections import Counter, defaultdict
|
||||
from collections.abc import Iterable, Mapping
|
||||
from contextlib import nullcontext
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from itertools import cycle
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from tqdm import tqdm
|
||||
from transformers.pytorch_utils import Conv1D
|
||||
|
||||
from peft.tuners.tuners_utils import _find_minimal_target_modules, check_target_module_exists
|
||||
from peft.utils.constants import MIN_TARGET_MODULES_FOR_OPTIMIZATION
|
||||
from peft.utils.incremental_pca import IncrementalPCA
|
||||
from peft.utils.other import _get_submodules, get_pattern_key
|
||||
|
||||
from .config import LoraConfig
|
||||
from .layer import Embedding, LoraLayer, MultiheadAttention, _ConvNd
|
||||
|
||||
|
||||
UNSUPPORTED_LORA_MODULES = (Embedding, MultiheadAttention, _ConvNd)
|
||||
|
||||
|
||||
class _Hook:
|
||||
"""
|
||||
A base class for hooks that prepares layer inputs for EVA.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
prepare_layer_inputs_fn: Optional[callable] = None,
|
||||
gather_distributed_inputs: bool = True,
|
||||
):
|
||||
self.name = name
|
||||
self.gather_distributed_inputs = gather_distributed_inputs
|
||||
if prepare_layer_inputs_fn is None:
|
||||
self._prepare_layer_inputs_fn = self._prepare_layer_inputs_fn_default
|
||||
else:
|
||||
self._prepare_layer_inputs_fn = prepare_layer_inputs_fn
|
||||
self.model_input = None
|
||||
|
||||
@staticmethod
|
||||
def _prepare_layer_inputs_fn_default(layer_input, model_input, layer_name) -> torch.Tensor:
|
||||
if isinstance(layer_input, torch.Tensor):
|
||||
pass
|
||||
elif isinstance(layer_input, (tuple, list)):
|
||||
layer_input = layer_input[0]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"unsupported input type {type(layer_input)} for prepare_layer_inputs_fn in layer {layer_name}, "
|
||||
"please provide a custom prepare_layer_inputs_fn"
|
||||
)
|
||||
# if the input has more than 2 dimensions, we flatten all but the last dimension
|
||||
if layer_input.ndim > 2:
|
||||
layer_input = layer_input.view(-1, layer_input.size(-1))
|
||||
return layer_input
|
||||
|
||||
@torch.no_grad()
|
||||
def prepare_layer_inputs(self, layer_input):
|
||||
return self._prepare_layer_inputs_fn(layer_input, self.model_input, self.name)
|
||||
|
||||
def gather_layer_inputs(self, layer_input):
|
||||
if dist.is_initialized() and self.gather_distributed_inputs:
|
||||
world_size = dist.get_world_size()
|
||||
|
||||
# First gather sizes from all processes more efficiently
|
||||
local_size = torch.tensor([layer_input.shape[0]], device=layer_input.device)
|
||||
all_sizes = torch.empty(world_size, dtype=local_size.dtype, device=layer_input.device)
|
||||
dist.all_gather_into_tensor(all_sizes, local_size)
|
||||
all_sizes = all_sizes.tolist()
|
||||
|
||||
# Find maximum size and pad tensors
|
||||
padded_input = layer_input.new_zeros((max(all_sizes), *layer_input.shape[1:]))
|
||||
padded_input[: layer_input.shape[0]] = layer_input
|
||||
|
||||
# Gather padded tensors
|
||||
gathered_inputs = [torch.zeros_like(padded_input) for _ in range(world_size)]
|
||||
dist.all_gather(gathered_inputs, padded_input.contiguous())
|
||||
|
||||
# Remove padding for each gathered tensor
|
||||
gathered_inputs = [tensor[:size] for tensor, size in zip(gathered_inputs, all_sizes)]
|
||||
|
||||
# Concatenate along batch dimension
|
||||
return torch.cat(gathered_inputs, dim=0)
|
||||
return layer_input
|
||||
|
||||
|
||||
class SVDHook(_Hook):
|
||||
"""
|
||||
A forward hook for calculating incremental SVD on layer inputs. The hook is designed to be registered to a PyTorch
|
||||
module using the `register_forward_hook` method.
|
||||
|
||||
This hook performs a step of incremental Singular Value Decomposition (SVD) on the inputs of a specified layer
|
||||
during the forward pass of a neural network. The hook also tracks convergence of the computed components using
|
||||
cosine similarity between the current and previous components.
|
||||
|
||||
Args:
|
||||
name (str): Name of the layer to which this hook is attached.
|
||||
n_components (int): Number of principal components to compute.
|
||||
sim_thresh (Union[float, torch.Tensor]): Similarity threshold for convergence.
|
||||
prepare_layer_inputs_fn (Optional[callable]): Function to prepare layer inputs for SVD.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
n_components: int,
|
||||
sim_thresh: Union[float, torch.Tensor],
|
||||
**base_class_kwargs,
|
||||
):
|
||||
super().__init__(**base_class_kwargs)
|
||||
self.n_components = n_components
|
||||
self.sim_thresh = sim_thresh
|
||||
if isinstance(sim_thresh, torch.Tensor) and len(sim_thresh.shape) > 0:
|
||||
check1 = sim_thresh.size(0) == n_components or sim_thresh.size(0) == 1
|
||||
check2 = len(sim_thresh.shape) == 1
|
||||
if not (check1 and check2):
|
||||
raise ValueError(
|
||||
"if sim_thresh is a tensor with more than 0 dimensions it must have shape (n_components,) or (1,)"
|
||||
)
|
||||
self.svd = IncrementalPCA(
|
||||
n_components=n_components,
|
||||
copy=True,
|
||||
lowrank=True,
|
||||
lowrank_seed=42,
|
||||
)
|
||||
self.model_input = None
|
||||
self.converged = torch.zeros((n_components,), dtype=torch.bool)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, model, input, output):
|
||||
previous_components = None
|
||||
if hasattr(self.svd, "components_"):
|
||||
previous_components = self.svd.components_.clone().detach()
|
||||
states = self.prepare_layer_inputs(input)
|
||||
states = self.gather_layer_inputs(states)
|
||||
# check if batch sizes is more than the number of components
|
||||
if states.size(0) < self.n_components:
|
||||
print(f"skipping SVD for {self.name} because there are less than {self.n_components} examples")
|
||||
return
|
||||
self.svd.partial_fit(states.to(torch.float32))
|
||||
# add if statement to check if we are in the first step where previous_components is None
|
||||
if previous_components is None:
|
||||
return
|
||||
components = self.svd.components_
|
||||
if len(components.shape) == 1:
|
||||
components = components.reshape(1, -1)
|
||||
previous_components = previous_components.reshape(1, -1)
|
||||
# consider as converged if enough components have converged via cossim
|
||||
sim = torch.nn.functional.cosine_similarity(components, previous_components)
|
||||
self.converged = sim >= self.sim_thresh
|
||||
|
||||
|
||||
# This is used to determine if inputs of two different layers are equal. For such cases, SVD
|
||||
# needs to be done for only for one of the equal inputs.
|
||||
class HashHook(_Hook):
|
||||
"""
|
||||
A forward hook for hashing layer inputs. The hook is designed to be registered to a PyTorch module using the
|
||||
`register_forward_hook` method.
|
||||
|
||||
This hook hashes the inputs of a specified layer during the forward pass of a neural network and stores the hash
|
||||
values for later analysis or comparison.
|
||||
|
||||
Args:
|
||||
name (str): Name of the layer to which this hook is attached. hashed_inputs (list): List of hashed inputs.
|
||||
prepare_layer_inputs_fn (Optional[callable]): Function to prepare layer inputs for hashing.
|
||||
"""
|
||||
|
||||
def __init__(self, **base_class_kwargs):
|
||||
super().__init__(**base_class_kwargs)
|
||||
self.hashed_inputs = []
|
||||
|
||||
@staticmethod
|
||||
def hash_fn(tensor):
|
||||
return hash(tuple(tensor.view(-1).tolist()))
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, model, input, output):
|
||||
x = self.prepare_layer_inputs(input)
|
||||
x = self.gather_layer_inputs(x)
|
||||
self.hashed_inputs.append(self.hash_fn(x.cpu()))
|
||||
|
||||
|
||||
def find_equal_values(dictionary: dict) -> dict:
|
||||
"""
|
||||
Find keys in a dictionary that have the same value.
|
||||
|
||||
This function takes a dictionary and returns a new dictionary containing keys that have the same value. The keys in
|
||||
the output dictionary are the values from the input dictionary, and the values are lists of keys that share the
|
||||
same value.
|
||||
"""
|
||||
value_dict = defaultdict(list)
|
||||
for k, v in dictionary.items():
|
||||
value_dict[v].append(k)
|
||||
return {k: v for k, v in value_dict.items() if len(v) > 1}
|
||||
|
||||
|
||||
def get_device_with_meta_params(model: torch.nn.Module) -> torch.device:
|
||||
"""
|
||||
Get the device of the model's parameters. Useful if some parameters are on meta device.
|
||||
"""
|
||||
devices = list({p.device for p in model.parameters() if p.device.type != "meta"})
|
||||
if len(devices) > 1:
|
||||
warnings.warn(f"Could not determine device, model has multiple devices: {devices}")
|
||||
return
|
||||
return devices[0]
|
||||
|
||||
|
||||
def move_inputs_to_device(inputs, device: Union[str, torch.device]):
|
||||
"""
|
||||
Move the inputs to the specified device. Adapted from hf.Trainer.
|
||||
"""
|
||||
if hasattr(inputs, "to"):
|
||||
return inputs.to(device)
|
||||
if isinstance(inputs, Mapping):
|
||||
return type(inputs)({k: move_inputs_to_device(v, device) for k, v in inputs.items()})
|
||||
elif isinstance(inputs, (tuple, list)):
|
||||
return type(inputs)(move_inputs_to_device(v, device) for v in inputs)
|
||||
else:
|
||||
warnings.warn(f"input of type {type(inputs)} could not be moved to the correct device")
|
||||
return inputs
|
||||
|
||||
|
||||
def prepare_model_inputs_fn_language_modeling(model_input, peft_config: LoraConfig):
|
||||
"""
|
||||
Get the indices of the items that should be used for SVD.
|
||||
|
||||
Attributes:
|
||||
model_input (dict): The model inputs.
|
||||
peft_config (LoraConfig): The configuration for the LoRA layers.
|
||||
"""
|
||||
if not isinstance(model_input, dict):
|
||||
raise ValueError("When using `prepare_model_inputs_fn_language_modeling` inputs must be a dictionary")
|
||||
mask = model_input.get("attention_mask", torch.ones_like(model_input["input_ids"])).bool()
|
||||
if peft_config.eva_config.use_label_mask and hasattr(model_input, "labels"):
|
||||
mask = torch.logical_and(mask, model_input["labels"] != peft_config.eva_config.label_mask_value)
|
||||
return mask.nonzero()
|
||||
|
||||
|
||||
def prepare_layer_inputs_fn_language_modeling(layer_input, model_input, layer_name) -> torch.Tensor:
|
||||
"""
|
||||
if not all items in the input should be used for SVD, this function can be used to get the indices of the items
|
||||
that should be used.
|
||||
|
||||
Attributes:
|
||||
layer_input (torch.Tensor): The layer inputs.
|
||||
model_input (torch.Tensor):
|
||||
The model inputs or if `prepare_model_inputs_fn` is not None the output of this function.
|
||||
layer_name (str): The name of the layer.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The input to the SVD.
|
||||
"""
|
||||
# if layer inputs are not a tensor, we simply get the first item
|
||||
if isinstance(layer_input, torch.Tensor):
|
||||
pass
|
||||
elif isinstance(layer_input, (tuple, list)):
|
||||
layer_input = layer_input[0]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"unsupported input type {type(layer_input)} for prepare_layer_inputs_fn in layer {layer_name}, "
|
||||
"please provide a custom prepare_layer_inputs_fn"
|
||||
)
|
||||
# in this case model_input is the output of `prepare_model_inputs_fn_language_modeling`
|
||||
return layer_input[model_input.T.unbind()]
|
||||
|
||||
|
||||
def forward_fn_dict(model, inputs):
|
||||
return model(**inputs)
|
||||
|
||||
|
||||
def _get_eva_state_dict(
|
||||
model: torch.nn.Module,
|
||||
dataloader: Iterable,
|
||||
peft_config: Optional[LoraConfig],
|
||||
target_module_check_fn: callable,
|
||||
forward_fn: Optional[callable],
|
||||
prepare_model_inputs_fn: Optional[callable],
|
||||
prepare_layer_inputs_fn: Union[callable, dict[str, callable], None],
|
||||
gather_distributed_inputs: bool,
|
||||
show_progress_bar: bool,
|
||||
) -> dict:
|
||||
# Computes the rank distribution for each layer based on the explained variance ratio.
|
||||
# when rank_pattern flag is False, all values in max_components are the same
|
||||
def _get_rank_distribution(hooks, layer_hook_map, equal_inputs_map, rank_budget, max_components):
|
||||
exp_vars = {k: h[0].svd.explained_variance_ratio_[: max_components[k]] for k, h in hooks.items()}
|
||||
keys, values = zip(*[(k, c) for k, name in layer_hook_map.items() for c in exp_vars[name]])
|
||||
idx = torch.stack(values).argsort(descending=True)
|
||||
counts = Counter([keys[i] for i in idx[:rank_budget]])
|
||||
counts = {k: counts.get(k, 0) for k in layer_hook_map.keys()} # add layers with 0 rank
|
||||
for k, k_hook in equal_inputs_map.items():
|
||||
# ensure hook layers have the highest rank if they are equal to another layer
|
||||
rank, rank_hook = counts[k], counts[k_hook]
|
||||
if rank_hook >= rank:
|
||||
continue
|
||||
counts[k_hook], counts[k] = rank, rank_hook
|
||||
return counts
|
||||
|
||||
# dataloader is not empty
|
||||
if len(dataloader) == 0:
|
||||
raise ValueError("dataloader is empty")
|
||||
|
||||
# check if dist is initialized
|
||||
if dist.is_initialized() and gather_distributed_inputs:
|
||||
warnings.warn(
|
||||
"torch.distributed is initialized and `gather_distributed_inputs` is True, "
|
||||
"therefore EVA initialization will gather tensors from all ranks. "
|
||||
"Ensure the model does not receive the same inputs on different ranks."
|
||||
)
|
||||
|
||||
# for unusually high rho values, define an upper limit
|
||||
rho_threshold = 1000
|
||||
rho = peft_config.eva_config.rho
|
||||
if rho > rho_threshold:
|
||||
max_dim = max(max(p.shape) for p in model.parameters())
|
||||
rho_ceil = max_dim // peft_config.r
|
||||
rho = min(rho, rho_ceil)
|
||||
|
||||
training = model.training
|
||||
device = get_device_with_meta_params(model)
|
||||
model.eval()
|
||||
|
||||
# get model inputs
|
||||
inputs = next(iter(dataloader))
|
||||
if device is not None:
|
||||
inputs = move_inputs_to_device(inputs, device)
|
||||
if prepare_model_inputs_fn is not None:
|
||||
model_inputs_for_hooks = prepare_model_inputs_fn(inputs, peft_config)
|
||||
else:
|
||||
model_inputs_for_hooks = deepcopy(inputs)
|
||||
|
||||
hooks = {}
|
||||
max_components = {}
|
||||
rank_budget = 0
|
||||
for name, module in model.named_modules():
|
||||
if not target_module_check_fn(name, module):
|
||||
continue
|
||||
if isinstance(prepare_layer_inputs_fn, Mapping):
|
||||
fn = prepare_layer_inputs_fn.pop(name, None)
|
||||
else:
|
||||
fn = prepare_layer_inputs_fn
|
||||
hook = HashHook(name=name, prepare_layer_inputs_fn=fn, gather_distributed_inputs=gather_distributed_inputs)
|
||||
hook.model_input = model_inputs_for_hooks
|
||||
handle = module.register_forward_hook(hook)
|
||||
hooks[name] = (hook, handle)
|
||||
layer_rank = peft_config.rank_pattern.get(
|
||||
get_pattern_key(peft_config.rank_pattern.keys(), name), peft_config.r
|
||||
)
|
||||
max_components[name] = round(layer_rank * rho)
|
||||
rank_budget += layer_rank
|
||||
if isinstance(prepare_layer_inputs_fn, Mapping) and len(prepare_layer_inputs_fn) > 0:
|
||||
raise ValueError(
|
||||
"prepare_layer_inputs_fn is a mapping but the following module names were not found in the model: "
|
||||
f"{prepare_layer_inputs_fn.keys()}"
|
||||
)
|
||||
|
||||
# forward for one batch to check which layer inputs are equal to avoid unneeded svd calculations
|
||||
forward_fn(model, inputs)
|
||||
hash_dict = {k: h[0].hashed_inputs[0] for k, h in hooks.items()}
|
||||
# equal input maps groups layers which receive the same input. One layer is defined as the key and receives an svd
|
||||
# hook. For the remaining layers the svd results can be skipped.
|
||||
equal_inputs = list(find_equal_values(hash_dict).values())
|
||||
equal_inputs_map = {vv: v[0] for v in equal_inputs for vv in v[1:]}
|
||||
# for layers with equal inputs we need to make sure that the max_components are the same
|
||||
for names in equal_inputs:
|
||||
max_value = max(max_components[n] for n in names)
|
||||
for n in names:
|
||||
max_components[n] = max_value
|
||||
|
||||
# initialize svd hooks
|
||||
for name in list(hooks.keys()):
|
||||
hook, handle = hooks.pop(name)
|
||||
handle.remove()
|
||||
if name in equal_inputs_map:
|
||||
continue
|
||||
hook = SVDHook(
|
||||
n_components=max_components[name],
|
||||
sim_thresh=peft_config.eva_config.tau,
|
||||
name=name,
|
||||
prepare_layer_inputs_fn=hook._prepare_layer_inputs_fn,
|
||||
gather_distributed_inputs=gather_distributed_inputs,
|
||||
)
|
||||
module = model.get_submodule(name)
|
||||
handle = module.register_forward_hook(hook)
|
||||
hooks[name] = (hook, handle) # adding the old handle here so we dont get errors in the first forward pass
|
||||
layer_hook_map = {**dict(zip(hooks.keys(), hooks.keys())), **equal_inputs_map}
|
||||
|
||||
# start svd calculation
|
||||
if show_progress_bar and (not dist.is_initialized() or dist.get_rank() == 0):
|
||||
pbar = tqdm(iter(cycle(dataloader)), position=0, leave=False)
|
||||
use_tqdm = True
|
||||
else:
|
||||
pbar = iter(cycle(dataloader))
|
||||
use_tqdm = False
|
||||
convergence_dict = {k: False for k in hooks.keys()}
|
||||
rank_dist = max_components.copy()
|
||||
for inputs in pbar:
|
||||
if device is not None:
|
||||
inputs = move_inputs_to_device(inputs, device)
|
||||
if prepare_model_inputs_fn is not None:
|
||||
model_inputs_for_hooks = prepare_model_inputs_fn(inputs, peft_config)
|
||||
else:
|
||||
model_inputs_for_hooks = deepcopy(inputs)
|
||||
|
||||
for name in list(hooks.keys()):
|
||||
hook, handle = hooks[name]
|
||||
# check if all components that are needed for the rank distribution have converged
|
||||
converged = torch.all(hook.converged[: rank_dist[name]])
|
||||
# if a layer has switched from not converged to converged in the current step
|
||||
if (not convergence_dict[name]) and converged and handle:
|
||||
handle.remove()
|
||||
handle = None
|
||||
convergence_dict[name] = True
|
||||
continue
|
||||
# if a layer has switched from converged to not converged in the current step
|
||||
elif convergence_dict[name] and not converged:
|
||||
module = model.get_submodule(name)
|
||||
handle = module.register_forward_hook(hook)
|
||||
convergence_dict[name] = False
|
||||
hook.model_input = model_inputs_for_hooks
|
||||
hooks[name] = (hook, handle)
|
||||
|
||||
if use_tqdm:
|
||||
layer_converged = list(convergence_dict.values()) + [
|
||||
convergence_dict[v] for v in equal_inputs_map.values()
|
||||
]
|
||||
pbar.set_description(f"{sum(layer_converged)}/{len(layer_converged)} layers have converged")
|
||||
|
||||
if all(convergence_dict.values()):
|
||||
break
|
||||
|
||||
forward_fn(model, inputs)
|
||||
|
||||
# in case some hooks have to skip the svd calculation because the number of tokens is less than the number of
|
||||
# components
|
||||
if not all(hasattr(h[0].svd, "components_") for h in hooks.values()):
|
||||
continue
|
||||
|
||||
rank_dist = _get_rank_distribution(hooks, layer_hook_map, equal_inputs_map, rank_budget, max_components)
|
||||
|
||||
# check all custom hooks have been removed
|
||||
remaining_hooks = {n for n, m in model.named_modules() for v in m._forward_hooks.values() if isinstance(v, _Hook)}
|
||||
if len(remaining_hooks) > 0:
|
||||
raise ValueError(
|
||||
f"Found active hooks added by EVA that weren't properly removed: {remaining_hooks}. "
|
||||
"Please report this issue at https://github.com/huggingface/peft/issues"
|
||||
)
|
||||
|
||||
eva_state_dict = {}
|
||||
for name, rank in rank_dist.items():
|
||||
hook = hooks[layer_hook_map[name]][0]
|
||||
if not torch.all(hook.converged[:rank]):
|
||||
raise ValueError(
|
||||
f"Layer {name} has not converged but was assigned rank {rank}. "
|
||||
"Please report this issue at https://github.com/huggingface/peft/issues"
|
||||
)
|
||||
u = hook.svd.components_[:rank]
|
||||
if peft_config.eva_config.whiten:
|
||||
u /= hook.svd.singular_values_[:rank].sqrt().reshape(-1, 1)
|
||||
eva_state_dict[name] = u
|
||||
|
||||
# restore model state
|
||||
model.train(training)
|
||||
|
||||
# move tensors to device
|
||||
if device is not None:
|
||||
eva_state_dict = {k: v.to(device) for k, v in eva_state_dict.items()}
|
||||
|
||||
return eva_state_dict
|
||||
|
||||
|
||||
def _load_eva_state_dict(
|
||||
model: torch.nn.Module,
|
||||
eva_state_dict: dict,
|
||||
adapter_name: str,
|
||||
):
|
||||
peft_config = model.peft_config[adapter_name]
|
||||
update_layer_kwargs = {
|
||||
"adapter_name": adapter_name,
|
||||
"lora_dropout": peft_config.lora_dropout,
|
||||
"use_rslora": peft_config.use_rslora,
|
||||
"use_dora": peft_config.use_dora,
|
||||
"lora_bias": peft_config.lora_bias,
|
||||
}
|
||||
missing_eva_inits = []
|
||||
new_target_modules = []
|
||||
other_module_names = []
|
||||
rank_pattern = {}
|
||||
alpha_pattern = {}
|
||||
for name, module in model.named_modules():
|
||||
name_in_base_model = name.replace("base_model.model.", "")
|
||||
if not isinstance(module, LoraLayer):
|
||||
other_module_names.append(name_in_base_model)
|
||||
continue
|
||||
# Regexp matching - Find key which matches current target_name in patterns provided
|
||||
r = peft_config.rank_pattern.get(get_pattern_key(peft_config.rank_pattern.keys(), name), peft_config.r)
|
||||
alpha = peft_config.alpha_pattern.get(
|
||||
get_pattern_key(peft_config.alpha_pattern.keys(), name), peft_config.lora_alpha
|
||||
)
|
||||
if name in eva_state_dict:
|
||||
w = eva_state_dict.pop(name)
|
||||
new_rank = w.size(0)
|
||||
if new_rank == 0:
|
||||
parent, _, target_name = _get_submodules(model, name)
|
||||
setattr(parent, target_name, module.get_base_layer())
|
||||
continue
|
||||
elif new_rank != r:
|
||||
if peft_config.eva_config.adjust_scaling_factors:
|
||||
alpha *= new_rank / r
|
||||
if new_rank != r or module.lora_A[adapter_name].weight.device.type == "meta":
|
||||
module.update_layer(r=new_rank, lora_alpha=alpha, init_lora_weights="eva", **update_layer_kwargs)
|
||||
module.lora_A[adapter_name].weight.copy_(w)
|
||||
new_target_modules.append(name_in_base_model)
|
||||
else:
|
||||
module.update_layer(r=r, lora_alpha=alpha, init_lora_weights=True, **update_layer_kwargs)
|
||||
missing_eva_inits.append(name_in_base_model)
|
||||
new_rank = r
|
||||
# update rank pattern and alpha pattern
|
||||
if new_rank != peft_config.r:
|
||||
rank_pattern[name_in_base_model] = new_rank
|
||||
if alpha != peft_config.lora_alpha:
|
||||
alpha_pattern[name_in_base_model] = alpha
|
||||
|
||||
# update target modules if some lora layers have been removed due to their EVA rank being 0
|
||||
new_target_modules = new_target_modules + missing_eva_inits
|
||||
if len(new_target_modules) >= MIN_TARGET_MODULES_FOR_OPTIMIZATION:
|
||||
new_target_modules = _find_minimal_target_modules(new_target_modules, other_module_names)
|
||||
model.peft_config[adapter_name].target_modules = new_target_modules
|
||||
|
||||
# set rank pattern obtained from EVA
|
||||
model.peft_config[adapter_name].rank_pattern = rank_pattern
|
||||
|
||||
# when adjust_scaling_factors is True, lora scaling factors have been adjusted after the rank redistribution
|
||||
model.peft_config[adapter_name].alpha_pattern = alpha_pattern
|
||||
|
||||
if missing_eva_inits:
|
||||
warnings.warn(
|
||||
"the following layers were initialized with init_lora_weights=True because they "
|
||||
f"were not found in the eva state_dict: {missing_eva_inits}\ncurrently the "
|
||||
f"following lora modules are not supported by EVA: {UNSUPPORTED_LORA_MODULES}"
|
||||
)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def get_eva_state_dict(
|
||||
model: torch.nn.Module,
|
||||
dataloader: Iterable,
|
||||
peft_config: Optional[LoraConfig] = None,
|
||||
forward_fn: Optional[callable] = forward_fn_dict,
|
||||
prepare_model_inputs_fn: Optional[callable] = prepare_model_inputs_fn_language_modeling,
|
||||
prepare_layer_inputs_fn: Union[callable, dict[str, callable], None] = prepare_layer_inputs_fn_language_modeling,
|
||||
adapter_name: str = "default",
|
||||
gather_distributed_inputs: bool = True,
|
||||
show_progress_bar: bool = True,
|
||||
) -> dict:
|
||||
"""
|
||||
Compute the SVD for each layer in the model.
|
||||
|
||||
This function computes the Singular Value Decomposition (SVD) for each layer in the model. It uses the incremental
|
||||
PCA method to compute the SVD components. The function also checks for convergence of the computed components using
|
||||
cosine similarity. The rank distribution for each layer is determined based on the explained variance ratio.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): The model to compute the SVD for. Does not need to be a PeftModel.
|
||||
dataloader (Iterable): The dataloader to use for the forward pass.
|
||||
peft_config (Optional[LoraConfig]):
|
||||
The configuration for the LoRA layers. Only required if `model` is not a PeftModel.
|
||||
forward_fn (callable):
|
||||
The forward function to use for the forward pass. Takes two arguments: `model` and `inputs`. Default
|
||||
behavior is `return model(**inputs)`
|
||||
prepare_model_inputs_fn (Optional[callable]):
|
||||
This function receives the model inputs and the peft_config and passes the output to
|
||||
`prepare_layer_inputs_fn`. Can be used to modify the input to the SVD computation based on the original
|
||||
model inputs. For example for language modeling the attention mask is used to determine which indices are
|
||||
padding tokens and should not be used for SVD. Any function defined here expects two arguments:
|
||||
`model_input` and `peft_config`. `peft.tuners.lora.eva.prepare_model_inputs_fn_language_modeling` is used
|
||||
by default.
|
||||
prepare_layer_inputs_fn (Union[callable, Dict[str, callable], None]):
|
||||
This function receives the layer inputs, the model inputs (potentially modified by
|
||||
`prepare_model_inputs_fn`) and the name of the layer and returns the inputs that should be used for SVD for
|
||||
that particular layer. Any custom function defined here expects three arguments: `layer_input`,
|
||||
`model_input`, and `layer_name` and should return a 2d tensor. The default logic can be found in
|
||||
peft.tuners.lora.eva.prepare_layer_inputs_fn_language_modeling and works for language modeling. In this
|
||||
case model_inputs is the mask used to determine which indices should be used for SVD (created by
|
||||
`prepare_model_inputs_fn_language_modeling`).
|
||||
adapter_name (str): The name of the adapter to compute the SVD for.
|
||||
gather_distributed_inputs (bool):
|
||||
Whether to gather the layer inputs from all ranks. Default is True meaning in a distributed setting the
|
||||
layer inputs will be gathered from all ranks for the SVD computation. For non-distributed settings this
|
||||
argument is ignored. Set to False if you are using a non-distributed dataloader in a distributed setting.
|
||||
show_progress_bar (bool): Whether to show a progress bar. Default is True.
|
||||
|
||||
Returns:
|
||||
eva_state_dict (dict): The state dictionary containing the SVD components for each layer.
|
||||
"""
|
||||
|
||||
def target_module_check_fn_peft_model(name, module, unsupported_lora_modules):
|
||||
"check if a module is an adapter module via base_layer attribute"
|
||||
return hasattr(module, "base_layer") and not isinstance(module, unsupported_lora_modules)
|
||||
|
||||
def target_module_check_fn_default(name, module, peft_config):
|
||||
"check if a module is an adapter module via target_modules"
|
||||
is_target_module = True
|
||||
if peft_config.target_modules is not None:
|
||||
is_target_module = check_target_module_exists(peft_config, name)
|
||||
# Conv1D for GPT2 support
|
||||
return isinstance(module, (torch.nn.Linear, Conv1D)) and is_target_module
|
||||
|
||||
is_peft_model = hasattr(model, "peft_config")
|
||||
|
||||
# get peft_config
|
||||
if is_peft_model and peft_config is None:
|
||||
peft_config = model.peft_config[adapter_name]
|
||||
elif peft_config is None:
|
||||
raise ValueError("peft_config is required if model is not a PeftModel")
|
||||
|
||||
# setup context and target module check function
|
||||
if is_peft_model:
|
||||
ctx = model.disable_adapter()
|
||||
target_module_check_fn = partial(
|
||||
target_module_check_fn_peft_model, unsupported_lora_modules=UNSUPPORTED_LORA_MODULES
|
||||
)
|
||||
else:
|
||||
ctx = nullcontext()
|
||||
target_module_check_fn = partial(target_module_check_fn_default, peft_config=peft_config)
|
||||
|
||||
with ctx:
|
||||
eva_state_dict = _get_eva_state_dict(
|
||||
model=model,
|
||||
dataloader=dataloader,
|
||||
peft_config=peft_config,
|
||||
target_module_check_fn=target_module_check_fn,
|
||||
forward_fn=forward_fn,
|
||||
prepare_model_inputs_fn=prepare_model_inputs_fn,
|
||||
prepare_layer_inputs_fn=prepare_layer_inputs_fn,
|
||||
gather_distributed_inputs=gather_distributed_inputs,
|
||||
show_progress_bar=show_progress_bar,
|
||||
)
|
||||
return eva_state_dict
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def initialize_lora_eva_weights(
|
||||
model: torch.nn.Module,
|
||||
dataloader: Optional[Iterable] = None,
|
||||
eva_state_dict: Optional[dict] = None,
|
||||
forward_fn: Optional[callable] = forward_fn_dict,
|
||||
prepare_model_inputs_fn: Optional[callable] = prepare_model_inputs_fn_language_modeling,
|
||||
prepare_layer_inputs_fn: Union[callable, dict[str, callable], None] = prepare_layer_inputs_fn_language_modeling,
|
||||
adapter_name: str = "default",
|
||||
gather_distributed_inputs: bool = True,
|
||||
show_progress_bar: bool = True,
|
||||
):
|
||||
"""
|
||||
Initialize the weights of the LoRA layers using the EVA method.
|
||||
|
||||
This function initializes the weights of the LoRA layers using the EVA method. It computes the SVD for each adapter
|
||||
layer and updates the weights accordingly.
|
||||
|
||||
Args:
|
||||
model (PeftModel): The peft model to compute the SVD for.
|
||||
dataloader (Optional[Iterable]):
|
||||
The dataloader to use for the forward pass. If None, eva_state_dict needs to be provided.
|
||||
eva_state_dict (Optional[dict]):
|
||||
The state_dict to load into the model. If None, a dataloader needs to be provided and the state_dict will
|
||||
be computed using `get_eva_state_dict`.
|
||||
forward_fn (callable):
|
||||
The forward function to use for the forward pass. Takes two arguments: `model` and `inputs`. Default
|
||||
behavior is `return model(**inputs)`
|
||||
prepare_model_inputs_fn (Optional[callable]):
|
||||
This function receives the model inputs and the peft_config and passes the output to
|
||||
`prepare_layer_inputs_fn`. Can be used to modify the input to the SVD computation based on the original
|
||||
model inputs. For example for language modeling the attention mask is used to determine which indices are
|
||||
padding tokens and should not be used for SVD. Any function defined here expects two arguments:
|
||||
`model_input` and `peft_config`. `peft.tuners.lora.eva.prepare_model_inputs_fn_language_modeling` is used
|
||||
by default.
|
||||
prepare_layer_inputs_fn (Union[callable, Dict[str, callable], None]):
|
||||
This function receives the layer inputs, the model inputs (potentially modified by
|
||||
`prepare_model_inputs_fn`) and the name of the layer and returns the inputs that should be used for SVD for
|
||||
that particular layer. Any custom function defined here expects three arguments: `layer_input`,
|
||||
`model_input`, and `layer_name` and should return a 2d tensor. The default logic can be found in
|
||||
peft.tuners.lora.eva.prepare_layer_inputs_fn_language_modeling and works for language modeling. In this
|
||||
case model_inputs is the mask used to determine which indices should be used for SVD (created by
|
||||
`prepare_model_inputs_fn_language_modeling`).
|
||||
adapter_name (str): The name of the adapter to initialize the weights for.
|
||||
gather_distributed_inputs (bool):
|
||||
Whether to gather the layer inputs from all ranks. Default is True meaning in a distributed setting the
|
||||
layer inputs will be gathered from all ranks for the SVD computation. For non-distributed settings this
|
||||
argument is ignored. Set to False if you are using a non-distributed dataloader in a distributed setting.
|
||||
show_progress_bar (bool): Whether to show a progress bar. Default is True.
|
||||
|
||||
Returns:
|
||||
model (torch.nn.Module): The model with the initialized LoRA weights.
|
||||
"""
|
||||
if not hasattr(model, "peft_config"):
|
||||
raise ValueError("model must be a PeftModel")
|
||||
|
||||
# eva currently only works with a single active adapter
|
||||
# Important: when removing this requirement, make sure eva init works correctly if the new rank is 0.
|
||||
if len(model.active_adapters) > 1:
|
||||
raise ValueError("`initialize_lora_eva_weights` currently only works with a single active adapter")
|
||||
|
||||
# initialize_lora_eva_weights only works with `init_lora_weights='eva'`
|
||||
if model.peft_config[adapter_name].init_lora_weights != "eva":
|
||||
raise ValueError("`initialize_lora_eva_weights` can only be used with `init_lora_weights='eva'`")
|
||||
|
||||
# compute svd
|
||||
if eva_state_dict is None:
|
||||
if dataloader is None:
|
||||
raise ValueError("dataloader is required if eva_state_dict is not provided")
|
||||
eva_state_dict = get_eva_state_dict(
|
||||
model=model,
|
||||
dataloader=dataloader,
|
||||
forward_fn=forward_fn,
|
||||
prepare_model_inputs_fn=prepare_model_inputs_fn,
|
||||
prepare_layer_inputs_fn=prepare_layer_inputs_fn,
|
||||
adapter_name=adapter_name,
|
||||
gather_distributed_inputs=gather_distributed_inputs,
|
||||
show_progress_bar=show_progress_bar,
|
||||
)
|
||||
|
||||
_load_eva_state_dict(model, eva_state_dict, adapter_name)
|
||||
@@ -0,0 +1,96 @@
|
||||
# Copyright 2024-present the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
|
||||
from utils import DataCollator, TokenizerMetaMath
|
||||
|
||||
from peft import EvaConfig, LoraConfig, get_peft_model, initialize_lora_eva_weights
|
||||
|
||||
|
||||
DEVICE = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
|
||||
|
||||
# config
|
||||
model_name = "meta-llama/Llama-3.1-8B"
|
||||
max_seq_len = 512
|
||||
rank = 16
|
||||
alpha = 1
|
||||
rho = 2.0
|
||||
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"]
|
||||
svd_batch_size = 4 # can be different from the batch size used in finetuning
|
||||
batch_size = 4
|
||||
learning_rate = 5e-4
|
||||
gradient_accumulation_steps = 8
|
||||
num_epochs = 1
|
||||
output_dir = "outputs"
|
||||
bf16 = True
|
||||
|
||||
|
||||
# load model and tokenizer
|
||||
model = AutoModelForCausalLM.from_pretrained(model_name)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
|
||||
# load dataset
|
||||
dataset = load_dataset("meta-math/MetaMathQA")
|
||||
dataset = dataset.map(
|
||||
TokenizerMetaMath(model_name),
|
||||
batched=True,
|
||||
remove_columns=dataset["train"].column_names,
|
||||
)
|
||||
dataset.set_format(type="torch")
|
||||
|
||||
# data collator
|
||||
data_collator = DataCollator(tokenizer.eos_token_id, max_length=max_seq_len)
|
||||
|
||||
# dataloader
|
||||
dataloader = DataLoader(
|
||||
dataset["train"],
|
||||
batch_size=svd_batch_size,
|
||||
collate_fn=data_collator,
|
||||
)
|
||||
|
||||
# setup peft config
|
||||
eva_config = EvaConfig(rho=rho)
|
||||
peft_config = LoraConfig(
|
||||
r=rank, lora_alpha=alpha, target_modules=target_modules, init_lora_weights="eva", eva_config=eva_config
|
||||
)
|
||||
|
||||
# move model to accelerator
|
||||
model = model.to(DEVICE)
|
||||
|
||||
# to optimize memory usage during eva initialization, set low_cpu_mem_usage=True
|
||||
peft_model = get_peft_model(model, peft_config, low_cpu_mem_usage=True)
|
||||
initialize_lora_eva_weights(peft_model, dataloader)
|
||||
|
||||
# setup training arguments
|
||||
training_args = TrainingArguments(
|
||||
per_device_train_batch_size=batch_size,
|
||||
learning_rate=learning_rate,
|
||||
gradient_accumulation_steps=gradient_accumulation_steps,
|
||||
num_train_epochs=num_epochs,
|
||||
output_dir=output_dir,
|
||||
remove_unused_columns=False,
|
||||
bf16=bf16,
|
||||
)
|
||||
|
||||
# continue with standard finetuning
|
||||
trainer = Trainer(
|
||||
model=peft_model,
|
||||
args=training_args,
|
||||
train_dataset=dataset["train"],
|
||||
data_collator=data_collator,
|
||||
)
|
||||
trainer.train()
|
||||
Reference in New Issue
Block a user