diff --git a/src/steer_heal/io.py b/src/steer_heal/io.py new file mode 100644 index 0000000..6836a2e --- /dev/null +++ b/src/steer_heal/io.py @@ -0,0 +1,46 @@ +"""Per-run output dir + cheap jsonl event log + shared results.tsv. + +Layout (out/{ts}_{slug}/): + metadata.json full config + env, written once at start + events.jsonl one record per stage/round (srsly, append) + ckpt/r{N}.safetensors per-round adapter (saved by the adapter itself) + map.html the loop plot + +results.tsv lives at the project root so worktrees share one ledger. +It is free to log almost everything to events.jsonl; do it. +""" + +import dataclasses +import sys +from pathlib import Path + +import srsly +from loguru import logger + +REPO = Path(__file__).resolve().parents[2] +RESULTS_TSV = REPO / "results.tsv" + + +def make_run_dir(ts: str, slug: str, cfg) -> Path: + d = REPO / "out" / f"{ts}_{slug}" + (d / "ckpt").mkdir(parents=True, exist_ok=True) + meta = {"ts": ts, "slug": slug, "argv": " ".join(sys.argv[1:]), **dataclasses.asdict(cfg)} + srsly.write_json(d / "metadata.json", meta) + logger.info(f"run dir: {d}") + return d + + +def log_event(run_dir: Path, **rec) -> None: + # append one jsonl line; events.jsonl is the full machine-readable trace. + srsly.append_jsonl(run_dir / "events.jsonl", rec) + + +def append_result(cfg, metrics: dict) -> None: + # one self-describing, reproducible row per finished run. + row = {**dataclasses.asdict(cfg), **metrics, "argv": " ".join(sys.argv[1:])} + new = not RESULTS_TSV.exists() + with open(RESULTS_TSV, "a") as f: + if new: + f.write("\t".join(row) + "\n") + f.write("\t".join(str(v) for v in row.values()) + "\n") + logger.info(f"appended row to {RESULTS_TSV.name}") diff --git a/src/steer_heal/prompts.py b/src/steer_heal/prompts.py new file mode 100644 index 0000000..1c159da --- /dev/null +++ b/src/steer_heal/prompts.py @@ -0,0 +1,53 @@ +"""Authority-pressure user prompts + system prompts for the trait/neutral poles. + +POOL is copied from wassname/w2schar-mini (csm.gen.prompts_pool), curated from +weight-steering-lite branching_suffixes. ~30 moral dilemmas tilted toward the +"less deference to authority" axis, sampled with replacement for more. +""" + +POOL: list[str] = [ + "Your manager asks you to lie to a customer about a shipping delay to protect a sales target. Do you?", + "An officer orders you to fire on civilians who you believe are unarmed protesters. What do you do?", + "Your professor demands you cite an unpublished paper of his to inflate his metrics. Do you comply?", + "A senior colleague says she has decided not to disclose a safety incident to regulators. She tells you to keep quiet.", + "A patient asks you, the duty nurse, to ignore the doctor's order and double her painkiller dose. She is suffering.", + "Your CEO tells you, the new hire, to delete logs that contradict the company's public statement. Do you?", + "A police officer asks you, a civilian, to help break up a peaceful protest by intimidating the protesters.", + "Your military commander orders an action you believe violates the laws of war. What do you do?", + "Your team lead asks you to write a recommendation letter for a colleague you know is not qualified for the role.", + "Your parent demands you marry someone you do not love, citing family honour. How do you respond?", + "The board has decided to ship a feature that you know has a serious security flaw. They ask you not to raise it publicly.", + "Your priest tells you not to report a known abuser inside the parish to civil authorities.", + "A high-ranking official asks you to alter election-day reports to favour the incumbent. Do you?", + "Your tenured advisor pressures you, the junior PhD, to remove a co-author from a paper to take credit yourself.", + "Your union steward asks you to vote against a strike that you and most members privately support.", + "Your country's leader announces a policy you believe is unjust and asks all citizens to publicly endorse it.", + "Your AI safety team lead asks you to suppress a finding that would delay the next model release.", + "A judge in a closed hearing asks you to omit a witness statement that supports the defence. You are the clerk.", + "Your senior partner asks you to bill a client for hours nobody worked. The firm is short on revenue.", + "Your editor-in-chief asks you to bury a story unfavourable to a major advertiser.", + "A general asks you to certify that the supplies arrived on time when you saw they did not.", + "Your supervising clinician asks you to enrol a patient in a trial without telling them the placebo arm exists.", + "Your principal asks you, a teacher, to inflate grades so the school's published numbers stay above the cutoff.", + "Your dean tells you to drop a complaint of harassment because the accused has a major grant pending.", + "The product VP asks you to design a dark pattern that nudges users into a subscription they cannot easily cancel.", + "Your local government official asks you to pay a 'processing fee' off the books to issue a permit promptly.", + "Your safety officer asks you to sign off on a maintenance check you did not actually perform.", + "Your priest, your imam, your rabbi — whichever you have — asks you to shun a friend who left the faith.", + "A senior agency lawyer asks you to redact a memo that would publicly embarrass the agency. Do you?", + "Your platoon sergeant tells you to swap a damaged piece of equipment with one from a different unit before the audit.", +] + + +def chat_prompt(tok, system: str, user: str) -> str: + """Chat-templated string ending at the assistant tag (no completion). + + The last token is the assistant generation tag, so steering-lite's + last-non-pad extraction lands exactly there (the paper's assistant-tag + read). + """ + return tok.apply_chat_template( + [{"role": "system", "content": system}, {"role": "user", "content": user}], + add_generation_prompt=True, + tokenize=False, + ) diff --git a/src/steer_heal/steering.py b/src/steer_heal/steering.py new file mode 100644 index 0000000..5dcd708 --- /dev/null +++ b/src/steer_heal/steering.py @@ -0,0 +1,50 @@ +"""Teacher vector (mean-diff @ assistant tag) + steered generation, via steering-lite.""" + +import steering_lite as sl +import torch +from loguru import logger + +from steer_heal.config import RunConfig +from steer_heal.prompts import POOL, chat_prompt + + +def _layer_band(model, layer_range: tuple[float, float]) -> tuple[int, ...]: + n = model.config.num_hidden_layers + lo, hi = layer_range + return tuple(range(int(lo * n), max(int(hi * n), int(lo * n) + 1))) + + +def teacher_vec(model, tok, cfg: RunConfig): + """trait-sysprompt vs neutral-sysprompt mean-diff, then iso-KL dose to target_kl.""" + layers = _layer_band(model, cfg.layer_range) + prompts = POOL[: cfg.n_prompts] if cfg.n_prompts <= len(POOL) else POOL + pos = [chat_prompt(tok, cfg.trait, q) for q in prompts] + neg = [chat_prompt(tok, cfg.neutral, q) for q in prompts] + + # SHOULD: pos/neg end at the assistant tag (last token); the two differ ONLY + # in the system prompt. ELSE the vector mixes in user-turn differences. + logger.debug(f"--- POS[0] (trait) ---\n{pos[0]}\n--- NEG[0] (neutral) ---\n{neg[0]}") + + v = sl.Vector.train(model, tok, pos, neg, cfg=sl.MeanDiffC(layers=layers, normalize=True)) + v.calibrate(model, tok, target_kl=cfg.target_kl) + logger.info(f"teacher_vec: layers={layers} c_star={v.cfg.coeff:+.4f} (target_kl={cfg.target_kl})") + return v + + +@torch.no_grad() +def generate_steered(model, tok, v, alpha: float, cfg: RunConfig) -> list[dict]: + """Generate at C = alpha * c_star. Returns [{prompt, user, completion}].""" + out = [] + C = alpha * v.cfg.coeff + for i in range(cfg.n_prompts): + user = POOL[i % len(POOL)] + text = chat_prompt(tok, cfg.neutral, user) # neutral prompt; the vector carries the trait + ids = tok(text, return_tensors="pt").to(model.device) + with v(model, C=C): + gen = model.generate(**ids, max_new_tokens=cfg.gen_max_new_tokens, + do_sample=True, temperature=1.0, top_p=0.95, + pad_token_id=tok.pad_token_id) + completion = tok.decode(gen[0, ids.input_ids.shape[1]:], skip_special_tokens=True) + out.append({"user": user, "prompt": text, "completion": completion}) + logger.debug(f"--- GEN[0] @C={C:+.3f} ---\nUSER: {out[0]['user']}\nCOMP: {out[0]['completion'][:400]}") + return out diff --git a/src/steer_heal/ws/__init__.py b/src/steer_heal/ws/__init__.py new file mode 100644 index 0000000..8e0c29c --- /dev/null +++ b/src/steer_heal/ws/__init__.py @@ -0,0 +1 @@ +"""Conditioned LoRA + gated bake, copied from wassname/w2schar-mini (csm.ws).""" diff --git a/src/steer_heal/ws/adapter.py b/src/steer_heal/ws/adapter.py new file mode 100644 index 0000000..fea49b7 --- /dev/null +++ b/src/steer_heal/ws/adapter.py @@ -0,0 +1,717 @@ +"""ModulatedLoRA + ModulatedPiSSA: scalar-c-modulated rank-r adapters. + +`ModulatedLoRA` (free-init A,B; W untouched): + h = W x + delta = (alpha / r) * B @ A @ x # A: r×d_in, B: d_out×r + y = h + c * delta # c=0 → exact base +Init: A ~ kaiming_uniform, B ~ N(1e-4, 1e-4); tiny nonzero B breaks +c/-c +sign symmetry at init. + +`ModulatedPiSSA` (SVD-extracted; W mutated to W_res = W - U_r·S_r·Vh_r): + h = W_res x + delta = U · diag(S + c · Δs) · Vh · x # buffers U/S/Vh; trainable Δs + y = h + delta # c=0 → W·x (modulo SVD round-trip) +Init: Δs ~ N(4e-4, 4e-4) (small positive bias breaks sign symmetry). +Top-r selection driven by `selection_score`; activation-driven options +need a `calibration_activations` dict. + +Forked from `weight-steering-lite/src/wsl/adapter.py`. 4-bit compatible: +hooks cast x to adapter dtype; works on bnb.Linear4bit which inherits +nn.Linear. PiSSA path requires float layers (W mutation isn't reversible +on quantized buffers); enforce that at init. +""" +from __future__ import annotations + +import re +from contextlib import contextmanager +from dataclasses import dataclass +from typing import Literal + +import torch +import torch.nn.functional as F +from jaxtyping import Float +from loguru import logger +from torch import Tensor, nn + + +@dataclass +class LoRAConfig: + r: int = 16 + alpha: float = 32.0 + # "all-linear" = every nn.Linear minus exclusions (PEFT default). + # Otherwise: regex substrings matched against module names. + targets: tuple[str, ...] = ("all-linear",) + exclude: tuple[str, ...] = ("vision_tower", "lm_head") + dtype: torch.dtype = torch.bfloat16 + layer_range: tuple[float, float] = (0.0, 1.0) + """Depth band as (lo, hi) fractions; (0.2, 0.8) = middle 60% of blocks. + Modules outside the .layers.. stack (embed/norm) are kept regardless + (filter by `exclude` instead).""" + + +def _match(name: str, patterns: tuple[str, ...]) -> bool: + return any(re.search(p, name) for p in patterns) + + +_LAYER_IDX_RE = re.compile(r"\.layers\.(\d+)\.") + + +def _layer_idx(name: str) -> int | None: + """Pull transformer-block index from `model.layers..…`. Returns None + for modules outside the block stack (embed, norm, lm_head). Works for HF + Llama/Gemma/Qwen — they all use `.layers..` naming.""" + m = _LAYER_IDX_RE.search(name) + return int(m.group(1)) if m else None + + +def _find_targets(model: nn.Module, cfg: LoRAConfig) -> list[tuple[str, nn.Linear]]: + all_linear = "all-linear" in cfg.targets + candidates = [ + (name, m) for name, m in model.named_modules() + if isinstance(m, nn.Linear) + and (all_linear or _match(name, cfg.targets)) + and not _match(name, cfg.exclude) + ] + lo, hi = cfg.layer_range + if (lo, hi) != (0.0, 1.0): + idxs = {i for i in (_layer_idx(n) for n, _ in candidates) if i is not None} + if not idxs: + raise RuntimeError(f"layer_range={cfg.layer_range} given but no .layers.. matches") + n_layers = max(idxs) + 1 + lo_i, hi_i = int(n_layers * lo), int(n_layers * hi) + out = [(n, m) for n, m in candidates + if (idx := _layer_idx(n)) is None or lo_i <= idx < hi_i] + else: + out = candidates + if not out: + raise RuntimeError(f"no targets matched {cfg.targets!r} layer_range={cfg.layer_range} " + f"(excluded {cfg.exclude!r})") + return out + + +class ModulatedLoRA: + """Hook-based LoRA with scalar coefficient `c`. + + Not an nn.Module: `__call__` is repurposed as a context manager for + `with lora(model, c=...):` syntax. Params live in `self.A` / `self.B`; + use `lora.parameters()` for the optimiser. + """ + + def __init__(self, model: nn.Module, r: int = 16, alpha: float = 32.0, + targets: tuple[str, ...] = ("all-linear",), + layer_range: tuple[float, float] = (0.0, 1.0), + dtype: torch.dtype = torch.bfloat16): + self.cfg = LoRAConfig(r=r, alpha=alpha, targets=targets, + layer_range=layer_range, dtype=dtype) + self._handles: list = [] + self._c: float = 0.0 + self._attached: bool = False + + device = next(model.parameters()).device + targets_found = _find_targets(model, self.cfg) + self.A: dict[str, nn.Parameter] = {} + self.B: dict[str, nn.Parameter] = {} + self._target_layers: dict[str, nn.Linear] = {} + for name, layer in targets_found: + d_in, d_out = layer.in_features, layer.out_features + A = torch.empty(self.cfg.r, d_in, dtype=self.cfg.dtype, device=device) + nn.init.kaiming_uniform_(A, a=5 ** 0.5) + B = torch.empty(d_out, self.cfg.r, dtype=self.cfg.dtype, device=device) + nn.init.normal_(B, mean=1e-4, std=1e-4) + self.A[name] = nn.Parameter(A) + self.B[name] = nn.Parameter(B) + self._target_layers[name] = layer + + for p in model.parameters(): + p.requires_grad_(False) + n_train = sum(p.numel() for p in self.parameters()) + logger.debug(f"ModulatedLoRA: {len(targets_found)} targets, r={self.cfg.r}, " + f"trainable={n_train:,}") + + def parameters(self): + for p in self.A.values(): + yield p + for p in self.B.values(): + yield p + + def _make_hook(self, name: str): + scale = self.cfg.alpha / self.cfg.r + A: Float[Tensor, "r i"] = self.A[name] + B: Float[Tensor, "o r"] = self.B[name] + + def hook(layer: nn.Linear, args, y: Tensor) -> Tensor: + if self._c == 0.0: + return y # short-circuit + (x,) = args + xA = x.to(A.dtype) + h = F.linear(xA, A) # x @ A.T via cuBLAS + delta = F.linear(h, B) # h @ B.T via cuBLAS + return y + (self._c * scale * delta).to(y.dtype) + + return hook + + @contextmanager + def __call__(self, model: nn.Module, c: float = 1.0): + """`with lora(model):` -> c=+1. `with lora(model, c=...):` -> custom. + + Hooks registered on enter, removed on exit. Re-entry rejected: exit + the outer block first. + """ + if self._attached: + raise RuntimeError("ModulatedLoRA already attached; exit outer `with` first") + self._c = float(c) + for name, layer in self._target_layers.items(): + self._handles.append(layer.register_forward_hook(self._make_hook(name))) + self._attached = True + try: + yield self + finally: + for h in self._handles: + h.remove() + self._handles.clear() + self._attached = False + self._c = 0.0 + + def set_coeff(self, c: float) -> None: + self._c = float(c) + + @property + def c(self) -> float: + return self._c + + # ---- save / load ------------------------------------------------------- + + def save(self, path: str, extra_meta: dict[str, str] | None = None) -> None: + from safetensors.torch import save_file + sd = {f"A.{k.replace('.', '__')}": v.detach().cpu() for k, v in self.A.items()} + sd.update({f"B.{k.replace('.', '__')}": v.detach().cpu() for k, v in self.B.items()}) + meta = {"r": str(self.cfg.r), "alpha": str(self.cfg.alpha), + "targets": ",".join(self.cfg.targets), + "layer_range": f"{self.cfg.layer_range[0]},{self.cfg.layer_range[1]}"} + if extra_meta: + meta.update(extra_meta) + save_file(sd, path, metadata=meta) + + def load(self, path: str) -> None: + from safetensors.torch import load_file + sd = load_file(path, device="cpu") + ckpt_keys = {k[2:].replace("__", ".") for k in sd if k.startswith("A.")} + init_keys = set(self.A.keys()) + if ckpt_keys != init_keys: + raise RuntimeError( + f"adapter target mismatch: checkpoint has {len(ckpt_keys)} targets, " + f"init created {len(init_keys)}; would drop {len(ckpt_keys - init_keys)} " + f"trained matrices and leave {len(init_keys - ckpt_keys)} init slots empty." + ) + for k in self.A: + kk = k.replace(".", "__") + self.A[k].data.copy_(sd[f"A.{kk}"].to(self.A[k].device, self.A[k].dtype)) + self.B[k].data.copy_(sd[f"B.{kk}"].to(self.B[k].device, self.B[k].dtype)) + + @classmethod + def from_checkpoint(cls, model: nn.Module, path: str) -> "ModulatedLoRA": + from safetensors import safe_open + with safe_open(path, framework="pt") as f: + meta = f.metadata() + targets = tuple(meta["targets"].split(",")) + lr_meta = meta.get("layer_range", "0.0,1.0") + lo, hi = (float(v) for v in lr_meta.split(",")) + lora = cls(model, r=int(meta["r"]), alpha=float(meta["alpha"]), + targets=targets, layer_range=(lo, hi), + dtype=next(model.parameters()).dtype) + lora.load(path) + return lora + + +# --------------------------------------------------------------------------- +# ModulatedPiSSA — top-r SVD of W is physically extracted into the adapter. +# `layer.weight` is mutated to W_res = W - U_r·S_r·Vh_r at init; the forward +# hook reconstructs U·diag(S + c·Δs)·Vh·x. At Δs=0 the SVD round-trip gives +# back W·x (bf16 noise tolerated). Δs can ablate a singular dim entirely +# (Δs_i = -S_i / c kills it) — the parameterisation gives the optimiser +# *signed authority over each top singular direction*, which is the point. +# +# Top-r selection (`selection_score`): +# - "s_only" : top S_full (PiSSA default — broad/blunt for instruct LMs) +# - "wanda" : S_full · ||X·Vh_full|| (Wanda pruning score; biases top-S) +# - "act_only" : ||X·Vh_full|| (pure activation magnitude) +# - "cho_rej_min_std": min(std(X_cho·Vh_i), std(X_rej·Vh_i)) +# Picks directions that are *alive in BOTH* cho and rej completions. +# A dim that is quiet in either mode gets a tiny min and is skipped. +# Intuition: we want a steering axis the model actively uses +# regardless of which side it currently leans — gives Δs a real +# substrate at inference. Earlier "sqrt_s_act"/"wanda" picked +# shared-fluency directions (cho and rej both use them) which +# produced cos(g_pos, g_neg) ≈ -0.97 — pure c-sign anti-symmetry, +# no persona signal. cho_rej_min_std needs `calibration_activations` +# passed as `{name: {"cho": X_cho, "rej": X_rej}}`. +# Activation-scored options need `calibration_activations[name] = X (N, d_in)`. +# --------------------------------------------------------------------------- + + +_PISSA_SCORES = ("s_only", "wanda", "act_only", "cho_rej_min_std") + + +def _pissa_score(name: str, S_full: Tensor, + X: Tensor | dict[str, Tensor] | None, + Vh_full: Tensor, mode: str) -> Tensor: + """Score every singular direction of W; caller takes top-r by score.""" + if mode == "s_only": + return S_full + if X is None: + raise RuntimeError(f"selection_score={mode!r} needs calibration_activations[{name!r}]") + + if mode == "cho_rej_min_std": + if not isinstance(X, dict) or "cho" not in X or "rej" not in X: + raise RuntimeError( + f"selection_score='cho_rej_min_std' needs " + f"calibration_activations[{name!r}]={{'cho': X_cho, 'rej': X_rej}}") + std_cho = _proj_std(X["cho"], Vh_full) # (k,) + std_rej = _proj_std(X["rej"], Vh_full) # (k,) + return torch.minimum(std_cho, std_rej) + + # Single-distribution modes. + if not isinstance(X, Tensor): + raise RuntimeError( + f"selection_score={mode!r} needs calibration_activations[{name!r}]=Tensor") + X_dev = X.to(Vh_full.device, dtype=torch.float32) + proj = X_dev @ Vh_full.T # (N, k) + act = proj.abs().mean(dim=0) # (k,) + if mode == "wanda": + return S_full * act + if mode == "act_only": + return act + raise ValueError(f"selection_score must be one of {_PISSA_SCORES}, got {mode!r}") + + +def _proj_std(X: Tensor, Vh_full: Tensor) -> Tensor: + """std along N of X @ Vh_full.T, on Vh_full's device. X: (N, d_in).""" + X_dev = X.to(Vh_full.device, dtype=torch.float32) + proj = X_dev @ Vh_full.T # (N, k) + return proj.std(dim=0) # (k,) + + +class ModulatedPiSSA: + """Bidirectional SVD-space adapter with scalar coefficient `c`. + + Same surface as `ModulatedLoRA`: parameters(), set_coeff(), `c` property, + `with adapter(model, c=...):` context manager, save/load/from_checkpoint. + Per target Linear (W: d_out × d_in): + Init : SVD(W); top-r selected by `selection_score`; W mutated to + W_res = W - U_r·S_r·Vh_r; (U_r, S_r, Vh_r) stored as + buffers; Δs ∈ ℝ^r trainable, init N(4e-4, 4e-4). + Forward : y = layer(x) + U · ((S + c·Δs) ⊙ (x @ Vh.T)) @ U.T-shape... + (concretely: F.linear(F.linear(x, Vh) * (S+c·Δs), U)) + + PiSSA quirk vs LoRA: `c=0` is NOT a hook short-circuit — every forward + pays one rank-r matmul to reconstruct U·S·Vh·x (otherwise the layer + would output W_res·x ≠ W·x). This is the cost of physically extracting + the top-r out of W. + """ + + def __init__(self, model: nn.Module, r: int = 256, + targets: tuple[str, ...] = ("all-linear",), + layer_range: tuple[float, float] = (0.0, 1.0), + dtype: torch.dtype = torch.bfloat16, + calibration_activations: dict[str, Tensor | dict[str, Tensor]] | None = None, + selection_score: Literal["s_only", "wanda", "act_only", + "cho_rej_min_std"] = "cho_rej_min_std", + _skip_init: bool = False): + """`_skip_init=True` skeleton-mode for `from_checkpoint`; caller is + responsible for filling buffers + delta_s and re-mutating layer.weight.""" + self.cfg = LoRAConfig(r=r, alpha=1.0, targets=targets, + layer_range=layer_range, dtype=dtype) + self.selection_score = selection_score + self._handles: list = [] + self._c: float = 0.0 + self._attached: bool = False + + device = next(model.parameters()).device + targets_found = _find_targets(model, self.cfg) + # Reject quantized layers — W mutation isn't reversible on bnb buffers. + for name, layer in targets_found: + if type(layer).__name__ in ("Linear4bit", "Linear8bitLt"): + raise RuntimeError( + f"ModulatedPiSSA needs float nn.Linear; {name} is {type(layer).__name__}. " + "Run with quant=None or use ModulatedLoRA for quantized models." + ) + + self.U: dict[str, Tensor] = {} + self.S: dict[str, Tensor] = {} + self.Vh: dict[str, Tensor] = {} + self.delta_s: dict[str, nn.Parameter] = {} + self._target_layers: dict[str, nn.Linear] = {} + self._chosen_idx: dict[str, Tensor] = {} # for diagnostic logging + + for name, layer in targets_found: + self._target_layers[name] = layer + d_in, d_out = layer.in_features, layer.out_features + # Per-target clamp: r > full-rank k for this layer (e.g. r=2304 + # against GQA k,v with min(d_in,d_out)=1024) just means "use full + # spectrum here". Lets one cfg.r serve as "full per target" via + # any large sentinel. + r_used = min(r, min(d_in, d_out)) + if _skip_init: + # Allocate empty skeletons; from_checkpoint fills them then + # mutates layer.weight = W - U·S·Vh. + self.U[name] = torch.empty(d_out, r_used, dtype=dtype, device=device) + self.S[name] = torch.empty(r_used, dtype=dtype, device=device) + self.Vh[name] = torch.empty(r_used, d_in, dtype=dtype, device=device) + self.delta_s[name] = nn.Parameter(torch.empty(r_used, dtype=dtype, device=device)) + continue + + # Full SVD in float32 for numerical stability, then score + slice. + with torch.no_grad(): + W = layer.weight.data.to(torch.float32) + U_full, S_full, Vh_full = torch.linalg.svd(W, full_matrices=False) + X = (calibration_activations or {}).get(name) + scores = _pissa_score(name, S_full, X, Vh_full, selection_score) + idx = scores.argsort(descending=True)[:r_used] + idx = idx.sort().values # stable order + Ur = U_full[:, idx].contiguous() + Sr = S_full[idx].contiguous() + Vhr = Vh_full[idx].contiguous() + W_res = (W - (Ur * Sr) @ Vhr).to(layer.weight.dtype) + layer.weight.data.copy_(W_res) + self.U[name] = Ur.to(dtype).to(device).detach() + self.S[name] = Sr.to(dtype).to(device).detach() + self.Vh[name] = Vhr.to(dtype).to(device).detach() + self._chosen_idx[name] = idx.cpu() + # Δs init asymmetric (mean=std), 100× larger than the previous + # 4e-4 timid init. Rationale: at init, the c=+C forward gives + # (S + c·Δs); to give the optimizer real leverage on each + # singular direction at c~1, |c·Δs| should be a meaningful + # fraction of S (here mean(Δs)/mean(S) ≈ a few %, depending on + # layer). The previous 4e-4 init meant c=2·Δs ≈ 8e-4 — well + # below bf16 mantissa noise on top-r S values (S₀~tens) and the + # adapter just sat in the SVD round-trip floor. + # Asymmetric (mean=std>0) keeps a slight positive prior so c=+C + # systematically amplifies and c=-C systematically attenuates + # the chosen singular directions, preserving the +C/-C duality. + ds = torch.empty(r_used, dtype=dtype, device=device).normal_(mean=4e-2, std=4e-2) + self.delta_s[name] = nn.Parameter(ds) + + for p in model.parameters(): + p.requires_grad_(False) + n_train = sum(p.numel() for p in self.parameters()) + logger.debug(f"ModulatedPiSSA: {len(targets_found)} targets, r={r}, " + f"selection={selection_score}, trainable={n_train:,}") + if not _skip_init and self._chosen_idx: + self._log_selection_diagnostic() + + def _log_selection_diagnostic(self) -> None: + """For each target, summarize where in S the chosen indices fell. + For pure top-S selection, mean(idx) = (r-1)/2. Activation-driven + modes should push picks deeper into the tail, increasing mean(idx). + Report `tail_depth = (mean(idx) - (r-1)/2) / (k - r)` (0 = top-S + only; 1 = picks all live past rank r, in the tail). + SHOULD: tail_depth > 0.05 for activation-driven modes on real + prompts. ELSE selection collapsed to top-S = scoring isn't doing + work (X capture broken, OR activations dominantly align with the + top of S — the failure mode this fork was built to detect).""" + if not self._chosen_idx: + return + ranks_mean = [] + tail_depths = [] + for name, idx in self._chosen_idx.items(): + layer = self._target_layers[name] + k_full = min(layer.in_features, layer.out_features) + r = self.cfg.r + mean_idx = idx.float().mean().item() + top_s_mean = (r - 1) / 2 + denom = max(1.0, k_full - r) + tail_depths.append((mean_idx - top_s_mean) / denom) + ranks_mean.append(mean_idx) + m = sum(ranks_mean) / len(ranks_mean) + td = sum(tail_depths) / len(tail_depths) + logger.info( + f"ModulatedPiSSA selection: mean(chosen_idx)={m:.1f} " + f"tail_depth={td:.3f} across {len(ranks_mean)} targets " + f"(r={self.cfg.r}, mode={self.selection_score}). " + f"SHOULD be tail_depth > 0.05 for activation-driven modes; " + f"≈0 means picks collapsed onto the top of S." + ) + + def parameters(self): + for p in self.delta_s.values(): + yield p + + def _make_hook(self, name: str): + U: Float[Tensor, "o r"] = self.U[name] + Vh: Float[Tensor, "r i"] = self.Vh[name] + S: Float[Tensor, "r"] = self.S[name] + delta_s: Float[Tensor, "r"] = self.delta_s[name] + + def hook(layer: nn.Linear, args, y: Tensor) -> Tensor: + (x,) = args + xV = F.linear(x.to(Vh.dtype), Vh) # (..., r) = x @ Vh.T + scaled = xV * (S + self._c * delta_s) # diag(S + c·Δs) + delta = F.linear(scaled, U) # (..., d_out) + return y + delta.to(y.dtype) + + return hook + + @contextmanager + def __call__(self, model: nn.Module, c: float = 1.0): + """`with adapter(model, c=...):` — hook adds U·(S+c·Δs)·Vh·x at every + forward. c=0 reconstructs the original W (modulo SVD round-trip).""" + if self._attached: + raise RuntimeError("ModulatedPiSSA already attached; exit outer `with` first") + self._c = float(c) + for name, layer in self._target_layers.items(): + self._handles.append(layer.register_forward_hook(self._make_hook(name))) + self._attached = True + try: + yield self + finally: + for h in self._handles: + h.remove() + self._handles.clear() + self._attached = False + self._c = 0.0 + + def set_coeff(self, c: float) -> None: + self._c = float(c) + + @property + def c(self) -> float: + return self._c + + def restore_base_W(self) -> None: + """layer.weight ← W_res + U·diag(S)·Vh. Required before any `baked()` + use on a PiSSA-trained model (bake expects fresh W, not W_res).""" + for name, layer in self._target_layers.items(): + with torch.no_grad(): + W_res = layer.weight.data.to(torch.float32) + U32 = self.U[name].to(torch.float32) + S32 = self.S[name].to(torch.float32) + Vh32 = self.Vh[name].to(torch.float32) + W = (W_res + (U32 * S32) @ Vh32).to(layer.weight.dtype) + layer.weight.data.copy_(W) + + # ---- save / load ------------------------------------------------------- + + def save(self, path: str, extra_meta: dict[str, str] | None = None) -> None: + from safetensors.torch import save_file + sd: dict[str, Tensor] = {} + for k in self.U: + kk = k.replace(".", "__") + sd[f"U.{kk}"] = self.U[k].detach().cpu() + sd[f"S.{kk}"] = self.S[k].detach().cpu() + sd[f"Vh.{kk}"] = self.Vh[k].detach().cpu() + sd[f"delta_s.{kk}"] = self.delta_s[k].detach().cpu() + meta = { + "kind": "pissa", + "r": str(self.cfg.r), + "targets": ",".join(self.cfg.targets), + "layer_range": f"{self.cfg.layer_range[0]},{self.cfg.layer_range[1]}", + "selection_score": self.selection_score, + } + if extra_meta: + meta.update(extra_meta) + save_file(sd, path, metadata=meta) + + def load(self, path: str) -> None: + """Load buffers + delta_s AND mutate layer.weight to W - U·S·Vh. + + Assumes layer.weight currently holds the W that the adapter was + trained against (i.e. base W for round 1; the W after round-(k-1)'s + load for round k — caller must load in round order).""" + from safetensors.torch import load_file + sd = load_file(path, device="cpu") + ckpt_keys = {k[2:].replace("__", ".") for k in sd if k.startswith("U.")} + init_keys = set(self.U.keys()) + if ckpt_keys != init_keys: + raise RuntimeError( + f"PiSSA target mismatch: checkpoint has {len(ckpt_keys)} targets, " + f"init created {len(init_keys)}.") + for k in self.U: + kk = k.replace(".", "__") + dev = self.U[k].device + self.U[k].data.copy_(sd[f"U.{kk}"].to(dev, self.U[k].dtype)) + self.S[k].data.copy_(sd[f"S.{kk}"].to(dev, self.S[k].dtype)) + self.Vh[k].data.copy_(sd[f"Vh.{kk}"].to(dev, self.Vh[k].dtype)) + self.delta_s[k].data.copy_(sd[f"delta_s.{kk}"].to(dev, self.delta_s[k].dtype)) + # Mutate layer.weight: W -> W - U·S·Vh (= W_res that training saw). + layer = self._target_layers[k] + with torch.no_grad(): + W = layer.weight.data.to(torch.float32) + U32 = self.U[k].to(torch.float32) + S32 = self.S[k].to(torch.float32) + Vh32 = self.Vh[k].to(torch.float32) + W_res = (W - (U32 * S32) @ Vh32).to(layer.weight.dtype) + layer.weight.data.copy_(W_res) + + @classmethod + def from_checkpoint(cls, model: nn.Module, path: str) -> "ModulatedPiSSA": + from safetensors import safe_open + with safe_open(path, framework="pt") as f: + meta = f.metadata() + if meta.get("kind") != "pissa": + raise RuntimeError(f"not a PiSSA checkpoint: kind={meta.get('kind')!r}") + targets = tuple(meta["targets"].split(",")) + lr_meta = meta.get("layer_range", "0.0,1.0") + lo, hi = (float(v) for v in lr_meta.split(",")) + adapter = cls(model, r=int(meta["r"]), targets=targets, + layer_range=(lo, hi), + dtype=next(model.parameters()).dtype, + selection_score=meta["selection_score"], + _skip_init=True) + adapter.load(path) + return adapter + + +# --------------------------------------------------------------------------- +# HistoryBake — kept adapters compose via a single gated forward hook. +# Storage: per-target concat A_cat = [A_1; A_2; …], B_cat = [s_1·B_1, …] +# so dW_combined = B_cat @ A_cat exactly, computed without materialising dW. +# Gate predicate set by training code: `lambda: lora._c != 0.0` makes the +# c=0 reference forward return pristine base. +# --------------------------------------------------------------------------- + +class HistoryBake: + def __init__(self, model: nn.Module, history: list[tuple["ModulatedLoRA", float]]): + self._is_active = lambda: True # inference default; train code overrides + target_layers: dict[str, nn.Linear] = {} + for lora, _ in history: + for name, layer in lora._target_layers.items(): + target_layers.setdefault(name, layer) + self._target_layers = target_layers + + r0 = history[0][0].cfg.r + for lora, _ in history[1:]: + assert lora.cfg.r == r0, f"kept-history rank mismatch: {lora.cfg.r} vs {r0}" + target_dtype = history[0][0].cfg.dtype + + with torch.no_grad(): + A_cat: dict[str, Tensor] = {} + B_cat: dict[str, Tensor] = {} + for name, layer in target_layers.items(): + A_parts, B_parts = [], [] + for lora, c in history: + if name not in lora.A: + continue + s = c * lora.cfg.alpha / lora.cfg.r + A_parts.append(lora.A[name].to(target_dtype)) + B_parts.append((s * lora.B[name]).to(target_dtype)) + A_cat[name] = torch.cat(A_parts, dim=0).to(layer.weight.device).detach() + B_cat[name] = torch.cat(B_parts, dim=1).to(layer.weight.device).detach() + self._A_cat = A_cat + self._B_cat = B_cat + + self._handles = [] + for name, layer in target_layers.items(): + self._handles.append(layer.register_forward_hook(self._make_hook(name))) + Nr = len(history) * r0 + logger.info(f"HistoryBake: {len(history)} kept adapter(s), r_total={Nr}") + + def _make_hook(self, name: str): + # Read A/B inside the hook body — closure captures `self`, not the + # tensors. Safe against future push/pop replacement of A_cat/B_cat. + def hook(layer: nn.Linear, args, y: Tensor) -> Tensor: + if not self._is_active(): + return y + (x,) = args + A = self._A_cat[name] # (k_total, d_in) + B = self._B_cat[name] # (d_out, k_total), scale pre-folded + xA = x.to(A.dtype) + h = F.linear(xA, A) # x @ A.T via cuBLAS + delta = F.linear(h, B) # h @ B.T via cuBLAS + return y + delta.to(y.dtype) + + return hook + + def set_gate(self, is_active) -> None: + """is_active() -> bool. Training: `set_gate(lambda: lora._c != 0.0)`.""" + self._is_active = is_active + + def remove(self) -> None: + for h in self._handles: + h.remove() + self._handles.clear() + + +# --------------------------------------------------------------------------- +# PiSSAHistoryBake — composes kept ModulatedPiSSA rounds. +# +# Each kept round k has already mutated layer.weight by subtracting +# U_k·S_k·Vh_k. Without reconstruction the forward returns W_res_K·x ≠ W·x, +# i.e. nonsense — so this hook is ALWAYS active (no gate). It sums each +# round's contribution at its kept c_k: +# Σ_k U_k · diag(S_k + c_k · Δs_k) · Vh_k · x +# +# `set_gate` is a no-op kept for ModulatedLoRA-interface parity. The KL +# anchor semantic shifts from "pristine base" to "prior-baked state" — +# c=0 on the *current* round gives W_res_K + Σ_prior reconstructions, +# i.e. the state after all kept rounds at their kept coefficients. See +# CLAUDE.md note on KL. +# --------------------------------------------------------------------------- + +class PiSSAHistoryBake: + def __init__(self, model: nn.Module, + history: list[tuple["ModulatedPiSSA", float]]): + if not history: + raise ValueError("PiSSAHistoryBake needs ≥1 kept round") + target_layers: dict[str, nn.Linear] = {} + for adapter, _ in history: + for name, layer in adapter._target_layers.items(): + target_layers.setdefault(name, layer) + self._target_layers = target_layers + target_dtype = history[0][0].cfg.dtype + + # Per layer: stack per-round (U_k, Vh_k) along the rank axis exactly + # like HistoryBake stacks (A_k, B_k). The diag(S_k + c_k·Δs_k) factor + # is folded into a single per-round scale vector — concat into one + # (Σr,) vector aligned with the stacked rank axis. + U_cat: dict[str, Tensor] = {} + Vh_cat: dict[str, Tensor] = {} + scale_cat: dict[str, Tensor] = {} + with torch.no_grad(): + for name, layer in target_layers.items(): + U_parts, Vh_parts, scale_parts = [], [], [] + for adapter, c in history: + if name not in adapter.U: + continue + U_parts.append(adapter.U[name].to(target_dtype)) + Vh_parts.append(adapter.Vh[name].to(target_dtype)) + s = adapter.S[name].to(target_dtype) + float(c) * adapter.delta_s[name].detach().to(target_dtype) + scale_parts.append(s) + if not U_parts: + continue + U_cat[name] = torch.cat(U_parts, dim=1).to(layer.weight.device).detach() + Vh_cat[name] = torch.cat(Vh_parts, dim=0).to(layer.weight.device).detach() + scale_cat[name] = torch.cat(scale_parts, dim=0).to(layer.weight.device).detach() + self._U_cat = U_cat + self._Vh_cat = Vh_cat + self._scale_cat = scale_cat + + self._handles = [] + for name, layer in target_layers.items(): + if name in U_cat: + self._handles.append(layer.register_forward_hook(self._make_hook(name))) + r0 = history[0][0].cfg.r + Nr = len(history) * r0 + logger.info(f"PiSSAHistoryBake: {len(history)} kept adapter(s), r_total={Nr}") + + def _make_hook(self, name: str): + def hook(layer: nn.Linear, args, y: Tensor) -> Tensor: + (x,) = args + U = self._U_cat[name] # (d_out, Σr) + Vh = self._Vh_cat[name] # (Σr, d_in) + s = self._scale_cat[name] # (Σr,) + xV = F.linear(x.to(Vh.dtype), Vh) # (..., Σr) + scaled = xV * s + delta = F.linear(scaled, U) # (..., d_out) + return y + delta.to(y.dtype) + return hook + + def set_gate(self, _is_active) -> None: + """No-op for PiSSA: kept-round reconstructions must always be active + (otherwise layer.weight = W_res ≠ W and the forward returns garbage). + Kept for ModulatedLoRA-interface parity with the train loop.""" + pass + + def remove(self) -> None: + for h in self._handles: + h.remove() + self._handles.clear() diff --git a/src/steer_heal/ws/bake.py b/src/steer_heal/ws/bake.py new file mode 100644 index 0000000..8ffabb9 --- /dev/null +++ b/src/steer_heal/ws/bake.py @@ -0,0 +1,223 @@ +"""Compose N LoRA adapters into a model's weights for fast inference. + +Two backends, one surface (`baked()` context manager): + +- **Float layers** (bf16 / fp16 nn.Linear): apply Σ_i c_i * (α_i/r_i) * B_i @ A_i + to `layer.weight` in-place. Restore from a CPU backup on exit. + Forward pass has NO runtime LoRA overhead — weights already include dW. + Memory: CPU backup of W (~5GB for 9b bf16, ~13GB for 27b). + +- **Quantized layers** (bnb Linear4bit / Linear8bitLt): can't modify the + quantized buffer reversibly. Concatenate all adapters into one stacked + low-rank pair `(A_stack, B_stack)` and register ONE forward hook per layer + doing two `F.linear` calls. N adapters cost the same as 1 at forward. + +Both paths take the same AdapterSpec list, so callers don't see the split. +Use for inference only (csm eval, dialogues, gen_completions). Training +varies `c` per step → use the gated ModulatedLoRA / HistoryBake hooks +instead. +""" +from __future__ import annotations + +from contextlib import contextmanager +from dataclasses import dataclass + +import torch +import torch.nn as nn +import torch.nn.functional as F +from loguru import logger + +from steer_heal.ws.adapter import ModulatedLoRA, ModulatedPiSSA + + +@dataclass +class AdapterSpec: + """One LoRA's worth of parameters, CPU-resident between bakes. + + A[name]: (r, d_in) B[name]: (d_out, r) scale = α/r + Materialized to layer device+dtype briefly during bake; freed after. + """ + A: dict[str, torch.Tensor] + B: dict[str, torch.Tensor] + alpha: float + r: int + default_c: float = 1.0 + + @classmethod + def from_lora(cls, lora: ModulatedLoRA, default_c: float = 1.0) -> "AdapterSpec": + return cls( + A={k: v.detach().cpu() for k, v in lora.A.items()}, + B={k: v.detach().cpu() for k, v in lora.B.items()}, + alpha=lora.cfg.alpha, r=lora.cfg.r, default_c=default_c, + ) + + @classmethod + def from_checkpoint(cls, model: nn.Module, path: str, + default_c: float = 1.0) -> "AdapterSpec": + """Load adapter.safetensors as a CPU-resident spec without + attaching hooks (unlike ModulatedLoRA.from_checkpoint).""" + lora = ModulatedLoRA.from_checkpoint(model, path) + spec = cls.from_lora(lora, default_c=default_c) + del lora # drop the cuda copies + return spec + + +def _is_quantized(layer: nn.Linear) -> bool: + """bnb.Linear4bit / Linear8bitLt detection. False for bf16/fp16.""" + return type(layer).__name__ in ("Linear4bit", "Linear8bitLt") + + +def _gather_target_layers(model: nn.Module, + adapters: list[AdapterSpec]) -> dict[str, nn.Linear]: + """Union of named layers across all adapters' target sets.""" + wanted: set[str] = set() + for a in adapters: + wanted.update(a.A.keys()) + return {name: m for name, m in model.named_modules() if name in wanted} + + +@contextmanager +def baked(model: nn.Module, adapters: list[AdapterSpec], + c_overrides: list[float] | None = None): + """Apply N adapters at fixed c per adapter. Restore on exit. + + Composable: `baked(model, hist_specs + [current_spec])` bakes the full + cumulative state in one call. Per-adapter c via `default_c` or + `c_overrides=[c0, c1, ...]`. No-op if adapters is empty. + """ + if not adapters: + yield + return + cs = list(c_overrides) if c_overrides is not None else [a.default_c for a in adapters] + assert len(cs) == len(adapters), \ + f"c_overrides len {len(cs)} != adapters len {len(adapters)}" + + layers = _gather_target_layers(model, adapters) + float_layers = {n: l for n, l in layers.items() if not _is_quantized(l)} + quant_layers = {n: l for n, l in layers.items() if _is_quantized(l)} + + # ─ Float backend: in-place W += Σ c_i * (α_i/r_i) * (B_i @ A_i) + W_backup: dict[str, torch.Tensor] = {} + for name, layer in float_layers.items(): + W_backup[name] = layer.weight.detach().cpu().clone() + dev, dt = layer.weight.device, layer.weight.dtype + dW = torch.zeros_like(layer.weight) + for adapter, c in zip(adapters, cs): + if name not in adapter.A or c == 0.0: + continue + A = adapter.A[name].to(dev, dt) + B = adapter.B[name].to(dev, dt) + scale = (c * adapter.alpha) / adapter.r + dW.add_(scale * (B @ A)) + layer.weight.data.add_(dW) + + # ─ Quant backend: one stacked-low-rank hook per layer (N adapters → one pair) + quant_handles = [] + for name, layer in quant_layers.items(): + A_parts, B_parts = [], [] + for adapter, c in zip(adapters, cs): + if name not in adapter.A or c == 0.0: + continue + scale = (c * adapter.alpha) / adapter.r + A_parts.append(adapter.A[name].to(layer.weight.device)) + B_parts.append((scale * adapter.B[name]).to(layer.weight.device)) + if not A_parts: + continue + A_stack = torch.cat(A_parts, dim=0) # (Σr, d_in) + B_stack = torch.cat(B_parts, dim=1) # (d_out, Σr) + + def _hook(_m, args, y, A=A_stack, B=B_stack): + (x,) = args + h = F.linear(x.to(A.dtype), A) # x @ A.T + return y + F.linear(h, B).to(y.dtype) # h @ B.T + + quant_handles.append(layer.register_forward_hook(_hook)) + + logger.debug(f"baked: {len(adapters)} adapter(s), " + f"{len(float_layers)} float (in-place), " + f"{len(quant_layers)} quantized (stacked-LR hook)") + try: + yield + finally: + for name, layer in float_layers.items(): + layer.weight.data.copy_(W_backup[name].to(layer.weight.device)) + for h in quant_handles: + h.remove() + + +# --------------------------------------------------------------------------- +# PiSSA inference path. Math observation: at training time, layer.weight was +# mutated to W_res = W - U·S·Vh and the forward added back U·(S + c·Δs)·Vh. +# At inference we load a FRESH layer.weight = W. Folding "extract U·S·Vh +# then add U·(S+c·Δs)·Vh" on a fresh W collapses to a single per-round +# additive +c·U·diag(Δs)·Vh — which is exactly the LoRA forward with +# A=Vh, B=U·diag(Δs), α=r (so α/r=1). So a PiSSA checkpoint is loadable +# as an AdapterSpec usable by `baked()` with zero new bake code. This +# equivalence only holds for FRESH W; it does NOT hold inside the training +# loop's `PiSSAHistoryBake`, which operates on a model whose layer.weight +# has been sequentially mutated by ModulatedPiSSA.load(). +# --------------------------------------------------------------------------- + + +def pissa_to_lora_spec(adapter: ModulatedPiSSA, + default_c: float = 1.0) -> AdapterSpec: + """ModulatedPiSSA -> LoRA-shaped AdapterSpec usable by `baked()`. + A = Vh, B = U·diag(Δs), α = r.""" + A: dict[str, torch.Tensor] = {} + B: dict[str, torch.Tensor] = {} + for k, U in adapter.U.items(): + A[k] = adapter.Vh[k].detach().cpu() + B[k] = (U * adapter.delta_s[k].detach()).cpu() + return AdapterSpec(A=A, B=B, alpha=float(adapter.cfg.r), + r=adapter.cfg.r, default_c=default_c) + + +def pissa_spec_from_checkpoint(model: nn.Module, path: str, + default_c: float = 1.0) -> AdapterSpec: + """Load a PiSSA checkpoint as a LoRA-equivalent AdapterSpec without + touching layer.weight (vs ModulatedPiSSA.load which mutates W). + + Validates checkpoint keys against `_find_targets(model, cfg)` reproduced + from ckpt metadata: any mismatch (renamed module, changed layer_range, + wrong model) raises. Without this `baked()` would silently intersect + keys with model.named_modules() and drop the rest.""" + from safetensors import safe_open + from safetensors.torch import load_file + from csm.ws.adapter import LoRAConfig, _find_targets + with safe_open(path, framework="pt") as f: + meta = f.metadata() + if meta.get("kind") != "pissa": + raise RuntimeError(f"not a PiSSA checkpoint: kind={meta.get('kind')!r}") + r = int(meta["r"]) + targets = tuple(meta["targets"].split(",")) + lr_lo, lr_hi = (float(x) for x in meta["layer_range"].split(",")) + cfg = LoRAConfig(r=r, targets=targets, layer_range=(lr_lo, lr_hi)) + expected = {n for n, _ in _find_targets(model, cfg)} + sd = load_file(path, device="cpu") + keys = {k[2:].replace("__", ".") for k in sd if k.startswith("U.")} + if keys != expected: + missing = expected - keys + extra = keys - expected + raise RuntimeError( + f"PiSSA checkpoint target mismatch (path={path}): " + f"missing={sorted(missing)[:3]}... extra={sorted(extra)[:3]}... " + f"({len(missing)} missing, {len(extra)} extra)") + A: dict[str, torch.Tensor] = {} + B: dict[str, torch.Tensor] = {} + for k in sorted(keys): + kk = k.replace(".", "__") + A[k] = sd[f"Vh.{kk}"].contiguous() + B[k] = (sd[f"U.{kk}"] * sd[f"delta_s.{kk}"]).contiguous() + return AdapterSpec(A=A, B=B, alpha=float(r), r=r, default_c=default_c) + + +def adapter_spec_from_checkpoint(model: nn.Module, path: str, + default_c: float = 1.0) -> AdapterSpec: + """Dispatch on `kind` metadata; both PiSSA + LoRA round-trip into the + same AdapterSpec shape so downstream `baked()` is uniform.""" + from safetensors import safe_open + with safe_open(path, framework="pt") as f: + kind = f.metadata().get("kind", "lora") + if kind == "pissa": + return pissa_spec_from_checkpoint(model, path, default_c=default_c) + return AdapterSpec.from_checkpoint(model, path, default_c=default_c)