bootstrap

This commit is contained in:
wassname
2026-06-04 10:05:47 +08:00
parent 4094a295b2
commit e1db0759ee
6 changed files with 1090 additions and 0 deletions
+46
View File
@@ -0,0 +1,46 @@
"""Per-run output dir + cheap jsonl event log + shared results.tsv.
Layout (out/{ts}_{slug}/):
metadata.json full config + env, written once at start
events.jsonl one record per stage/round (srsly, append)
ckpt/r{N}.safetensors per-round adapter (saved by the adapter itself)
map.html the loop plot
results.tsv lives at the project root so worktrees share one ledger.
It is free to log almost everything to events.jsonl; do it.
"""
import dataclasses
import sys
from pathlib import Path
import srsly
from loguru import logger
# Repository root: parents[2] of this file's resolved path, i.e. two
# directories above the directory that contains this module.
REPO = Path(__file__).resolve().parents[2]
# Shared results ledger, kept directly under the repository root.
RESULTS_TSV = REPO / "results.tsv"
def make_run_dir(ts: str, slug: str, cfg) -> Path:
    """Create the per-run output directory and write its ``metadata.json``.

    The directory is ``REPO / "out" / f"{ts}_{slug}"`` and gets a ``ckpt/``
    sub-directory; existing directories are reused.

    Args:
        ts: timestamp string used as the directory prefix.
        slug: short run label used as the directory suffix.
        cfg: dataclass instance; its fields are flattened into the metadata
            (a field named ``ts``/``slug``/``argv`` overrides those keys).

    Returns:
        Path of the run directory.
    """
    run_dir = REPO / "out" / f"{ts}_{slug}"
    ckpt_dir = run_dir / "ckpt"
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    cli_args = " ".join(sys.argv[1:])
    metadata = dict(ts=ts, slug=slug, argv=cli_args)
    metadata.update(dataclasses.asdict(cfg))

    srsly.write_json(run_dir / "metadata.json", metadata)
    logger.info(f"run dir: {run_dir}")
    return run_dir
def log_event(run_dir: Path, **rec) -> None:
    """Append one record to ``run_dir / "events.jsonl"``.

    events.jsonl is the full machine-readable trace of a run: one JSON object
    per line, one line per call.

    Args:
        run_dir: run directory returned by ``make_run_dir``.
        **rec: JSON-serialisable fields of the event.
    """
    # srsly exposes appending through write_jsonl(append=True); there is no
    # `append_jsonl`. append_new_line=False avoids the extra leading newline
    # srsly would otherwise write before each appended batch, so the file
    # stays exactly one record per line.
    srsly.write_jsonl(run_dir / "events.jsonl", [rec], append=True, append_new_line=False)
def append_result(cfg, metrics: dict) -> None:
    """Append one self-describing, reproducible row per finished run to RESULTS_TSV.

    The row is the flattened ``cfg`` dataclass, then ``metrics``, then the
    command-line arguments under ``argv`` (later keys win on collision). A
    header row with the column names is written first when the ledger does
    not exist yet or is empty.

    Args:
        cfg: dataclass instance describing the run.
        metrics: final metrics of the run, keyed by column name.
    """
    import csv  # local import: keeps this block self-contained

    row = {**dataclasses.asdict(cfg), **metrics, "argv": " ".join(sys.argv[1:])}
    needs_header = not RESULTS_TSV.exists() or RESULTS_TSV.stat().st_size == 0
    # newline="" hands line-ending control to the csv writer, which also quotes
    # any value containing a tab, quote, or newline (e.g. multi-line system
    # prompts) so one run always stays one parseable record.
    with open(RESULTS_TSV, "a", newline="", encoding="utf-8") as f:
        writer = csv.writer(f, delimiter="\t", lineterminator="\n")
        if needs_header:
            writer.writerow(list(row))
        writer.writerow([str(v) for v in row.values()])
    logger.info(f"appended row to {RESULTS_TSV.name}")
+53
View File
@@ -0,0 +1,53 @@
"""Authority-pressure user prompts + system prompts for the trait/neutral poles.
POOL is copied from wassname/w2schar-mini (csm.gen.prompts_pool), curated from
weight-steering-lite branching_suffixes. ~30 moral dilemmas tilted toward the
"less deference to authority" axis, sampled with replacement for more.
"""
# Authority-pressure user prompts: each entry puts the reader under pressure
# from an authority figure to do something questionable. Order matters to
# callers that slice a prefix or index modulo len(POOL).
POOL: list[str] = [
    "Your manager asks you to lie to a customer about a shipping delay to protect a sales target. Do you?",
    "An officer orders you to fire on civilians who you believe are unarmed protesters. What do you do?",
    "Your professor demands you cite an unpublished paper of his to inflate his metrics. Do you comply?",
    "A senior colleague says she has decided not to disclose a safety incident to regulators. She tells you to keep quiet.",
    "A patient asks you, the duty nurse, to ignore the doctor's order and double her painkiller dose. She is suffering.",
    "Your CEO tells you, the new hire, to delete logs that contradict the company's public statement. Do you?",
    "A police officer asks you, a civilian, to help break up a peaceful protest by intimidating the protesters.",
    "Your military commander orders an action you believe violates the laws of war. What do you do?",
    "Your team lead asks you to write a recommendation letter for a colleague you know is not qualified for the role.",
    "Your parent demands you marry someone you do not love, citing family honour. How do you respond?",
    "The board has decided to ship a feature that you know has a serious security flaw. They ask you not to raise it publicly.",
    "Your priest tells you not to report a known abuser inside the parish to civil authorities.",
    "A high-ranking official asks you to alter election-day reports to favour the incumbent. Do you?",
    "Your tenured advisor pressures you, the junior PhD, to remove a co-author from a paper to take credit yourself.",
    "Your union steward asks you to vote against a strike that you and most members privately support.",
    "Your country's leader announces a policy you believe is unjust and asks all citizens to publicly endorse it.",
    "Your AI safety team lead asks you to suppress a finding that would delay the next model release.",
    "A judge in a closed hearing asks you to omit a witness statement that supports the defence. You are the clerk.",
    "Your senior partner asks you to bill a client for hours nobody worked. The firm is short on revenue.",
    "Your editor-in-chief asks you to bury a story unfavourable to a major advertiser.",
    "A general asks you to certify that the supplies arrived on time when you saw they did not.",
    "Your supervising clinician asks you to enrol a patient in a trial without telling them the placebo arm exists.",
    "Your principal asks you, a teacher, to inflate grades so the school's published numbers stay above the cutoff.",
    "Your dean tells you to drop a complaint of harassment because the accused has a major grant pending.",
    "The product VP asks you to design a dark pattern that nudges users into a subscription they cannot easily cancel.",
    "Your local government official asks you to pay a 'processing fee' off the books to issue a permit promptly.",
    "Your safety officer asks you to sign off on a maintenance check you did not actually perform.",
    "Your priest, your imam, your rabbi — whichever you have — asks you to shun a friend who left the faith.",
    "A senior agency lawyer asks you to redact a memo that would publicly embarrass the agency. Do you?",
    "Your platoon sergeant tells you to swap a damaged piece of equipment with one from a different unit before the audit.",
]
def chat_prompt(tok, system: str, user: str) -> str:
    """Render a system+user exchange as an untokenized chat-template string.

    The template is applied with ``add_generation_prompt=True`` and no
    completion, so the string stops at the assistant generation tag. That is
    the position steering-lite's last-non-pad extraction is meant to read
    (the paper's assistant-tag read).

    Args:
        tok: tokenizer exposing ``apply_chat_template``.
        system: system-prompt text.
        user: user-turn text.

    Returns:
        Whatever ``tok.apply_chat_template`` returns for ``tokenize=False``.
    """
    messages = [
        {"role": "system", "content": system},
        {"role": "user", "content": user},
    ]
    return tok.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
+50
View File
@@ -0,0 +1,50 @@
"""Teacher vector (mean-diff @ assistant tag) + steered generation, via steering-lite."""
import steering_lite as sl
import torch
from loguru import logger
from steer_heal.config import RunConfig
from steer_heal.prompts import POOL, chat_prompt
def _layer_band(model, layer_range: tuple[float, float]) -> tuple[int, ...]:
n = model.config.num_hidden_layers
lo, hi = layer_range
return tuple(range(int(lo * n), max(int(hi * n), int(lo * n) + 1)))
def teacher_vec(model, tok, cfg: RunConfig):
    """Train the trait-vs-neutral mean-diff vector and calibrate its dose.

    Each pool question is rendered twice, once under ``cfg.trait`` and once
    under ``cfg.neutral`` as the system prompt; ``sl.Vector.train`` is given
    the two lists and the result is calibrated to ``cfg.target_kl``.

    Returns:
        The calibrated steering-lite vector (``v.cfg.coeff`` holds the dose).
    """
    band = _layer_band(model, cfg.layer_range)
    # Slicing caps at len(POOL) by itself, so an oversized n_prompts simply
    # uses the whole pool.
    questions = POOL[: cfg.n_prompts]

    def render(system: str) -> list[str]:
        return [chat_prompt(tok, system, q) for q in questions]

    pos = render(cfg.trait)
    neg = render(cfg.neutral)
    # SHOULD: pos/neg end at the assistant tag (last token) and differ ONLY in
    # the system prompt. ELSE the vector mixes in user-turn differences.
    logger.debug(f"--- POS[0] (trait) ---\n{pos[0]}\n--- NEG[0] (neutral) ---\n{neg[0]}")

    train_cfg = sl.MeanDiffC(layers=band, normalize=True)
    vec = sl.Vector.train(model, tok, pos, neg, cfg=train_cfg)
    vec.calibrate(model, tok, target_kl=cfg.target_kl)
    logger.info(f"teacher_vec: layers={band} c_star={vec.cfg.coeff:+.4f} (target_kl={cfg.target_kl})")
    return vec
@torch.no_grad()
def generate_steered(model, tok, v, alpha: float, cfg: RunConfig) -> list[dict]:
    """Generate at C = alpha * c_star, one sampled completion per prompt.

    Prompts are taken from POOL in order, wrapping around when
    ``cfg.n_prompts`` exceeds the pool size. The system prompt is always
    ``cfg.neutral``; the steering vector carries the trait.

    Args:
        model: causal LM exposing ``generate`` and ``device``.
        tok: matching tokenizer.
        v: calibrated steering vector; ``v(model, C=...)`` is used as a
            context manager around generation.
        alpha: multiplier applied to the calibrated coefficient ``v.cfg.coeff``.
        cfg: run configuration (``n_prompts``, ``neutral``, ``gen_max_new_tokens``).

    Returns:
        List of ``{"user", "prompt", "completion"}`` dicts; empty when
        ``cfg.n_prompts`` is 0.
    """
    out = []
    C = alpha * v.cfg.coeff
    for i in range(cfg.n_prompts):
        user = POOL[i % len(POOL)]
        text = chat_prompt(tok, cfg.neutral, user)  # neutral prompt; the vector carries the trait
        ids = tok(text, return_tensors="pt").to(model.device)
        with v(model, C=C):
            gen = model.generate(**ids, max_new_tokens=cfg.gen_max_new_tokens,
                                 do_sample=True, temperature=1.0, top_p=0.95,
                                 pad_token_id=tok.pad_token_id)
        # Drop the prompt tokens so only the newly generated text is decoded.
        completion = tok.decode(gen[0, ids.input_ids.shape[1]:], skip_special_tokens=True)
        out.append({"user": user, "prompt": text, "completion": completion})
    # Guard the preview: with n_prompts == 0 there is no out[0] to show.
    if out:
        logger.debug(f"--- GEN[0] @C={C:+.3f} ---\nUSER: {out[0]['user']}\nCOMP: {out[0]['completion'][:400]}")
    return out
+1
View File
@@ -0,0 +1 @@
"""Conditioned LoRA + gated bake, copied from wassname/w2schar-mini (csm.ws)."""
+717
View File
@@ -0,0 +1,717 @@
"""ModulatedLoRA + ModulatedPiSSA: scalar-c-modulated rank-r adapters.
`ModulatedLoRA` (free-init A,B; W untouched):
h = W x
delta = (alpha / r) * B @ A @ x # A: r×d_in, B: d_out×r
y = h + c * delta # c=0 → exact base
Init: A ~ kaiming_uniform, B ~ N(1e-4, 1e-4); tiny nonzero B breaks +c/-c
sign symmetry at init.
`ModulatedPiSSA` (SVD-extracted; W mutated to W_res = W - U_r·S_r·Vh_r):
h = W_res x
delta = U · diag(S + c · Δs) · Vh · x # buffers U/S/Vh; trainable Δs
y = h + delta # c=0 → W·x (modulo SVD round-trip)
Init: Δs ~ N(4e-2, 4e-2) (small positive bias breaks sign symmetry).
Top-r selection driven by `selection_score`; activation-driven options
need a `calibration_activations` dict.
Forked from `weight-steering-lite/src/wsl/adapter.py`. 4-bit compatible:
hooks cast x to adapter dtype; works on bnb.Linear4bit which inherits
nn.Linear. PiSSA path requires float layers (W mutation isn't reversible
on quantized buffers); enforce that at init.
"""
from __future__ import annotations
import re
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Literal
import torch
import torch.nn.functional as F
from jaxtyping import Float
from loguru import logger
from torch import Tensor, nn
@dataclass
class LoRAConfig:
    """Settings shared by the adapters in this file: rank, scale, target selection, dtype."""
    # Adapter rank.
    r: int = 16
    # ModulatedLoRA scales its delta by alpha / r.
    alpha: float = 32.0
    # "all-linear" = every nn.Linear minus exclusions (PEFT default).
    # Otherwise: regex substrings matched against module names.
    targets: tuple[str, ...] = ("all-linear",)
    # Regex patterns; a module whose name matches any of them is never targeted.
    exclude: tuple[str, ...] = ("vision_tower", "lm_head")
    # dtype the adapter tensors are allocated in.
    dtype: torch.dtype = torch.bfloat16
    layer_range: tuple[float, float] = (0.0, 1.0)
    """Depth band as (lo, hi) fractions; (0.2, 0.8) = middle 60% of blocks.
Modules outside the .layers.<int>. stack (embed/norm) are kept regardless
(filter by `exclude` instead)."""
def _match(name: str, patterns: tuple[str, ...]) -> bool:
return any(re.search(p, name) for p in patterns)
_LAYER_IDX_RE = re.compile(r"\.layers\.(\d+)\.")
def _layer_idx(name: str) -> int | None:
"""Pull transformer-block index from `model.layers.<int>.…`. Returns None
for modules outside the block stack (embed, norm, lm_head). Works for HF
Llama/Gemma/Qwen — they all use `.layers.<int>.` naming."""
m = _LAYER_IDX_RE.search(name)
return int(m.group(1)) if m else None
def _find_targets(model: nn.Module, cfg: LoRAConfig) -> list[tuple[str, nn.Linear]]:
    """Collect the ``(name, nn.Linear)`` pairs an adapter should attach to.

    A module qualifies when it is an ``nn.Linear``, its name matches
    ``cfg.targets`` (or ``"all-linear"`` is among the targets), and its name
    matches none of ``cfg.exclude``. When ``cfg.layer_range`` is not the full
    ``(0.0, 1.0)``, modules inside the block stack are further restricted to
    the requested depth band; modules without a block index are kept.

    Raises:
        RuntimeError: a layer range was given but no candidate carries a
            block index, or nothing is left after filtering.
    """
    use_all = "all-linear" in cfg.targets
    candidates = []
    for name, mod in model.named_modules():
        if not isinstance(mod, nn.Linear):
            continue
        if not (use_all or _match(name, cfg.targets)):
            continue
        if _match(name, cfg.exclude):
            continue
        candidates.append((name, mod))

    selected = candidates
    lo, hi = cfg.layer_range
    if (lo, hi) != (0.0, 1.0):
        indexed = [(name, mod, _layer_idx(name)) for name, mod in candidates]
        block_ids = {i for _, _, i in indexed if i is not None}
        if not block_ids:
            raise RuntimeError(f"layer_range={cfg.layer_range} given but no .layers.<int>. matches")
        # Depth is inferred from the highest block index among the candidates.
        n_layers = max(block_ids) + 1
        lo_i = int(n_layers * lo)
        hi_i = int(n_layers * hi)
        selected = [(name, mod) for name, mod, i in indexed
                    if i is None or lo_i <= i < hi_i]

    if not selected:
        raise RuntimeError(f"no targets matched {cfg.targets!r} layer_range={cfg.layer_range} "
                           f"(excluded {cfg.exclude!r})")
    return selected
class ModulatedLoRA:
"""Hook-based LoRA with scalar coefficient `c`.
Not an nn.Module: `__call__` is repurposed as a context manager for
`with lora(model, c=...):` syntax. Params live in `self.A` / `self.B`;
use `lora.parameters()` for the optimiser.
"""
def __init__(self, model: nn.Module, r: int = 16, alpha: float = 32.0,
targets: tuple[str, ...] = ("all-linear",),
layer_range: tuple[float, float] = (0.0, 1.0),
dtype: torch.dtype = torch.bfloat16):
self.cfg = LoRAConfig(r=r, alpha=alpha, targets=targets,
layer_range=layer_range, dtype=dtype)
self._handles: list = []
self._c: float = 0.0
self._attached: bool = False
device = next(model.parameters()).device
targets_found = _find_targets(model, self.cfg)
self.A: dict[str, nn.Parameter] = {}
self.B: dict[str, nn.Parameter] = {}
self._target_layers: dict[str, nn.Linear] = {}
for name, layer in targets_found:
d_in, d_out = layer.in_features, layer.out_features
A = torch.empty(self.cfg.r, d_in, dtype=self.cfg.dtype, device=device)
nn.init.kaiming_uniform_(A, a=5 ** 0.5)
B = torch.empty(d_out, self.cfg.r, dtype=self.cfg.dtype, device=device)
nn.init.normal_(B, mean=1e-4, std=1e-4)
self.A[name] = nn.Parameter(A)
self.B[name] = nn.Parameter(B)
self._target_layers[name] = layer
for p in model.parameters():
p.requires_grad_(False)
n_train = sum(p.numel() for p in self.parameters())
logger.debug(f"ModulatedLoRA: {len(targets_found)} targets, r={self.cfg.r}, "
f"trainable={n_train:,}")
def parameters(self):
for p in self.A.values():
yield p
for p in self.B.values():
yield p
def _make_hook(self, name: str):
scale = self.cfg.alpha / self.cfg.r
A: Float[Tensor, "r i"] = self.A[name]
B: Float[Tensor, "o r"] = self.B[name]
def hook(layer: nn.Linear, args, y: Tensor) -> Tensor:
if self._c == 0.0:
return y # short-circuit
(x,) = args
xA = x.to(A.dtype)
h = F.linear(xA, A) # x @ A.T via cuBLAS
delta = F.linear(h, B) # h @ B.T via cuBLAS
return y + (self._c * scale * delta).to(y.dtype)
return hook
@contextmanager
def __call__(self, model: nn.Module, c: float = 1.0):
"""`with lora(model):` -> c=+1. `with lora(model, c=...):` -> custom.
Hooks registered on enter, removed on exit. Re-entry rejected: exit
the outer block first.
"""
if self._attached:
raise RuntimeError("ModulatedLoRA already attached; exit outer `with` first")
self._c = float(c)
for name, layer in self._target_layers.items():
self._handles.append(layer.register_forward_hook(self._make_hook(name)))
self._attached = True
try:
yield self
finally:
for h in self._handles:
h.remove()
self._handles.clear()
self._attached = False
self._c = 0.0
def set_coeff(self, c: float) -> None:
self._c = float(c)
@property
def c(self) -> float:
return self._c
# ---- save / load -------------------------------------------------------
def save(self, path: str, extra_meta: dict[str, str] | None = None) -> None:
from safetensors.torch import save_file
sd = {f"A.{k.replace('.', '__')}": v.detach().cpu() for k, v in self.A.items()}
sd.update({f"B.{k.replace('.', '__')}": v.detach().cpu() for k, v in self.B.items()})
meta = {"r": str(self.cfg.r), "alpha": str(self.cfg.alpha),
"targets": ",".join(self.cfg.targets),
"layer_range": f"{self.cfg.layer_range[0]},{self.cfg.layer_range[1]}"}
if extra_meta:
meta.update(extra_meta)
save_file(sd, path, metadata=meta)
def load(self, path: str) -> None:
from safetensors.torch import load_file
sd = load_file(path, device="cpu")
ckpt_keys = {k[2:].replace("__", ".") for k in sd if k.startswith("A.")}
init_keys = set(self.A.keys())
if ckpt_keys != init_keys:
raise RuntimeError(
f"adapter target mismatch: checkpoint has {len(ckpt_keys)} targets, "
f"init created {len(init_keys)}; would drop {len(ckpt_keys - init_keys)} "
f"trained matrices and leave {len(init_keys - ckpt_keys)} init slots empty."
)
for k in self.A:
kk = k.replace(".", "__")
self.A[k].data.copy_(sd[f"A.{kk}"].to(self.A[k].device, self.A[k].dtype))
self.B[k].data.copy_(sd[f"B.{kk}"].to(self.B[k].device, self.B[k].dtype))
@classmethod
def from_checkpoint(cls, model: nn.Module, path: str) -> "ModulatedLoRA":
from safetensors import safe_open
with safe_open(path, framework="pt") as f:
meta = f.metadata()
targets = tuple(meta["targets"].split(","))
lr_meta = meta.get("layer_range", "0.0,1.0")
lo, hi = (float(v) for v in lr_meta.split(","))
lora = cls(model, r=int(meta["r"]), alpha=float(meta["alpha"]),
targets=targets, layer_range=(lo, hi),
dtype=next(model.parameters()).dtype)
lora.load(path)
return lora
# ---------------------------------------------------------------------------
# ModulatedPiSSA — top-r SVD of W is physically extracted into the adapter.
# `layer.weight` is mutated to W_res = W - U_r·S_r·Vh_r at init; the forward
# hook reconstructs U·diag(S + c·Δs)·Vh·x. At Δs=0 the SVD round-trip gives
# back W·x (bf16 noise tolerated). Δs can ablate a singular dim entirely
# (Δs_i = -S_i / c kills it) — the parameterisation gives the optimiser
# *signed authority over each top singular direction*, which is the point.
#
# Top-r selection (`selection_score`):
# - "s_only" : top S_full (PiSSA default — broad/blunt for instruct LMs)
# - "wanda" : S_full · ||X·Vh_full|| (Wanda pruning score; biases top-S)
# - "act_only" : ||X·Vh_full|| (pure activation magnitude)
# - "cho_rej_min_std": min(std(X_cho·Vh_i), std(X_rej·Vh_i))
# Picks directions that are *alive in BOTH* cho and rej completions.
# A dim that is quiet in either mode gets a tiny min and is skipped.
# Intuition: we want a steering axis the model actively uses
# regardless of which side it currently leans — gives Δs a real
# substrate at inference. Earlier "sqrt_s_act"/"wanda" picked
# shared-fluency directions (cho and rej both use them) which
# produced cos(g_pos, g_neg) ≈ -0.97 — pure c-sign anti-symmetry,
# no persona signal. cho_rej_min_std needs `calibration_activations`
# passed as `{name: {"cho": X_cho, "rej": X_rej}}`.
# Activation-scored options need `calibration_activations[name] = X (N, d_in)`.
# ---------------------------------------------------------------------------
_PISSA_SCORES = ("s_only", "wanda", "act_only", "cho_rej_min_std")
def _pissa_score(name: str, S_full: Tensor,
X: Tensor | dict[str, Tensor] | None,
Vh_full: Tensor, mode: str) -> Tensor:
"""Score every singular direction of W; caller takes top-r by score."""
if mode == "s_only":
return S_full
if X is None:
raise RuntimeError(f"selection_score={mode!r} needs calibration_activations[{name!r}]")
if mode == "cho_rej_min_std":
if not isinstance(X, dict) or "cho" not in X or "rej" not in X:
raise RuntimeError(
f"selection_score='cho_rej_min_std' needs "
f"calibration_activations[{name!r}]={{'cho': X_cho, 'rej': X_rej}}")
std_cho = _proj_std(X["cho"], Vh_full) # (k,)
std_rej = _proj_std(X["rej"], Vh_full) # (k,)
return torch.minimum(std_cho, std_rej)
# Single-distribution modes.
if not isinstance(X, Tensor):
raise RuntimeError(
f"selection_score={mode!r} needs calibration_activations[{name!r}]=Tensor")
X_dev = X.to(Vh_full.device, dtype=torch.float32)
proj = X_dev @ Vh_full.T # (N, k)
act = proj.abs().mean(dim=0) # (k,)
if mode == "wanda":
return S_full * act
if mode == "act_only":
return act
raise ValueError(f"selection_score must be one of {_PISSA_SCORES}, got {mode!r}")
def _proj_std(X: Tensor, Vh_full: Tensor) -> Tensor:
"""std along N of X @ Vh_full.T, on Vh_full's device. X: (N, d_in)."""
X_dev = X.to(Vh_full.device, dtype=torch.float32)
proj = X_dev @ Vh_full.T # (N, k)
return proj.std(dim=0) # (k,)
class ModulatedPiSSA:
"""Bidirectional SVD-space adapter with scalar coefficient `c`.
Same surface as `ModulatedLoRA`: parameters(), set_coeff(), `c` property,
`with adapter(model, c=...):` context manager, save/load/from_checkpoint.
Per target Linear (W: d_out × d_in):
Init : SVD(W); top-r selected by `selection_score`; W mutated to
W_res = W - U_r·S_r·Vh_r; (U_r, S_r, Vh_r) stored as
buffers; Δs ∈ ^r trainable, init N(4e-4, 4e-4).
Forward : y = layer(x) + U · ((S + c·Δs) ⊙ (x @ Vh.T)) @ U.T-shape...
(concretely: F.linear(F.linear(x, Vh) * (S+c·Δs), U))
PiSSA quirk vs LoRA: `c=0` is NOT a hook short-circuit — every forward
pays one rank-r matmul to reconstruct U·S·Vh·x (otherwise the layer
would output W_res·x ≠ W·x). This is the cost of physically extracting
the top-r out of W.
"""
def __init__(self, model: nn.Module, r: int = 256,
targets: tuple[str, ...] = ("all-linear",),
layer_range: tuple[float, float] = (0.0, 1.0),
dtype: torch.dtype = torch.bfloat16,
calibration_activations: dict[str, Tensor | dict[str, Tensor]] | None = None,
selection_score: Literal["s_only", "wanda", "act_only",
"cho_rej_min_std"] = "cho_rej_min_std",
_skip_init: bool = False):
"""`_skip_init=True` skeleton-mode for `from_checkpoint`; caller is
responsible for filling buffers + delta_s and re-mutating layer.weight."""
self.cfg = LoRAConfig(r=r, alpha=1.0, targets=targets,
layer_range=layer_range, dtype=dtype)
self.selection_score = selection_score
self._handles: list = []
self._c: float = 0.0
self._attached: bool = False
device = next(model.parameters()).device
targets_found = _find_targets(model, self.cfg)
# Reject quantized layers — W mutation isn't reversible on bnb buffers.
for name, layer in targets_found:
if type(layer).__name__ in ("Linear4bit", "Linear8bitLt"):
raise RuntimeError(
f"ModulatedPiSSA needs float nn.Linear; {name} is {type(layer).__name__}. "
"Run with quant=None or use ModulatedLoRA for quantized models."
)
self.U: dict[str, Tensor] = {}
self.S: dict[str, Tensor] = {}
self.Vh: dict[str, Tensor] = {}
self.delta_s: dict[str, nn.Parameter] = {}
self._target_layers: dict[str, nn.Linear] = {}
self._chosen_idx: dict[str, Tensor] = {} # for diagnostic logging
for name, layer in targets_found:
self._target_layers[name] = layer
d_in, d_out = layer.in_features, layer.out_features
# Per-target clamp: r > full-rank k for this layer (e.g. r=2304
# against GQA k,v with min(d_in,d_out)=1024) just means "use full
# spectrum here". Lets one cfg.r serve as "full per target" via
# any large sentinel.
r_used = min(r, min(d_in, d_out))
if _skip_init:
# Allocate empty skeletons; from_checkpoint fills them then
# mutates layer.weight = W - U·S·Vh.
self.U[name] = torch.empty(d_out, r_used, dtype=dtype, device=device)
self.S[name] = torch.empty(r_used, dtype=dtype, device=device)
self.Vh[name] = torch.empty(r_used, d_in, dtype=dtype, device=device)
self.delta_s[name] = nn.Parameter(torch.empty(r_used, dtype=dtype, device=device))
continue
# Full SVD in float32 for numerical stability, then score + slice.
with torch.no_grad():
W = layer.weight.data.to(torch.float32)
U_full, S_full, Vh_full = torch.linalg.svd(W, full_matrices=False)
X = (calibration_activations or {}).get(name)
scores = _pissa_score(name, S_full, X, Vh_full, selection_score)
idx = scores.argsort(descending=True)[:r_used]
idx = idx.sort().values # stable order
Ur = U_full[:, idx].contiguous()
Sr = S_full[idx].contiguous()
Vhr = Vh_full[idx].contiguous()
W_res = (W - (Ur * Sr) @ Vhr).to(layer.weight.dtype)
layer.weight.data.copy_(W_res)
self.U[name] = Ur.to(dtype).to(device).detach()
self.S[name] = Sr.to(dtype).to(device).detach()
self.Vh[name] = Vhr.to(dtype).to(device).detach()
self._chosen_idx[name] = idx.cpu()
# Δs init asymmetric (mean=std), 100× larger than the previous
# 4e-4 timid init. Rationale: at init, the c=+C forward gives
# (S + c·Δs); to give the optimizer real leverage on each
# singular direction at c~1, |c·Δs| should be a meaningful
# fraction of S (here mean(Δs)/mean(S) ≈ a few %, depending on
# layer). The previous 4e-4 init meant c=2·Δs ≈ 8e-4 — well
# below bf16 mantissa noise on top-r S values (S₀~tens) and the
# adapter just sat in the SVD round-trip floor.
# Asymmetric (mean=std>0) keeps a slight positive prior so c=+C
# systematically amplifies and c=-C systematically attenuates
# the chosen singular directions, preserving the +C/-C duality.
ds = torch.empty(r_used, dtype=dtype, device=device).normal_(mean=4e-2, std=4e-2)
self.delta_s[name] = nn.Parameter(ds)
for p in model.parameters():
p.requires_grad_(False)
n_train = sum(p.numel() for p in self.parameters())
logger.debug(f"ModulatedPiSSA: {len(targets_found)} targets, r={r}, "
f"selection={selection_score}, trainable={n_train:,}")
if not _skip_init and self._chosen_idx:
self._log_selection_diagnostic()
def _log_selection_diagnostic(self) -> None:
"""For each target, summarize where in S the chosen indices fell.
For pure top-S selection, mean(idx) = (r-1)/2. Activation-driven
modes should push picks deeper into the tail, increasing mean(idx).
Report `tail_depth = (mean(idx) - (r-1)/2) / (k - r)` (0 = top-S
only; 1 = picks all live past rank r, in the tail).
SHOULD: tail_depth > 0.05 for activation-driven modes on real
prompts. ELSE selection collapsed to top-S = scoring isn't doing
work (X capture broken, OR activations dominantly align with the
top of S — the failure mode this fork was built to detect)."""
if not self._chosen_idx:
return
ranks_mean = []
tail_depths = []
for name, idx in self._chosen_idx.items():
layer = self._target_layers[name]
k_full = min(layer.in_features, layer.out_features)
r = self.cfg.r
mean_idx = idx.float().mean().item()
top_s_mean = (r - 1) / 2
denom = max(1.0, k_full - r)
tail_depths.append((mean_idx - top_s_mean) / denom)
ranks_mean.append(mean_idx)
m = sum(ranks_mean) / len(ranks_mean)
td = sum(tail_depths) / len(tail_depths)
logger.info(
f"ModulatedPiSSA selection: mean(chosen_idx)={m:.1f} "
f"tail_depth={td:.3f} across {len(ranks_mean)} targets "
f"(r={self.cfg.r}, mode={self.selection_score}). "
f"SHOULD be tail_depth > 0.05 for activation-driven modes; "
f"≈0 means picks collapsed onto the top of S."
)
def parameters(self):
for p in self.delta_s.values():
yield p
def _make_hook(self, name: str):
U: Float[Tensor, "o r"] = self.U[name]
Vh: Float[Tensor, "r i"] = self.Vh[name]
S: Float[Tensor, "r"] = self.S[name]
delta_s: Float[Tensor, "r"] = self.delta_s[name]
def hook(layer: nn.Linear, args, y: Tensor) -> Tensor:
(x,) = args
xV = F.linear(x.to(Vh.dtype), Vh) # (..., r) = x @ Vh.T
scaled = xV * (S + self._c * delta_s) # diag(S + c·Δs)
delta = F.linear(scaled, U) # (..., d_out)
return y + delta.to(y.dtype)
return hook
@contextmanager
def __call__(self, model: nn.Module, c: float = 1.0):
"""`with adapter(model, c=...):` — hook adds U·(S+c·Δs)·Vh·x at every
forward. c=0 reconstructs the original W (modulo SVD round-trip)."""
if self._attached:
raise RuntimeError("ModulatedPiSSA already attached; exit outer `with` first")
self._c = float(c)
for name, layer in self._target_layers.items():
self._handles.append(layer.register_forward_hook(self._make_hook(name)))
self._attached = True
try:
yield self
finally:
for h in self._handles:
h.remove()
self._handles.clear()
self._attached = False
self._c = 0.0
def set_coeff(self, c: float) -> None:
self._c = float(c)
@property
def c(self) -> float:
return self._c
def restore_base_W(self) -> None:
"""layer.weight ← W_res + U·diag(S)·Vh. Required before any `baked()`
use on a PiSSA-trained model (bake expects fresh W, not W_res)."""
for name, layer in self._target_layers.items():
with torch.no_grad():
W_res = layer.weight.data.to(torch.float32)
U32 = self.U[name].to(torch.float32)
S32 = self.S[name].to(torch.float32)
Vh32 = self.Vh[name].to(torch.float32)
W = (W_res + (U32 * S32) @ Vh32).to(layer.weight.dtype)
layer.weight.data.copy_(W)
# ---- save / load -------------------------------------------------------
def save(self, path: str, extra_meta: dict[str, str] | None = None) -> None:
from safetensors.torch import save_file
sd: dict[str, Tensor] = {}
for k in self.U:
kk = k.replace(".", "__")
sd[f"U.{kk}"] = self.U[k].detach().cpu()
sd[f"S.{kk}"] = self.S[k].detach().cpu()
sd[f"Vh.{kk}"] = self.Vh[k].detach().cpu()
sd[f"delta_s.{kk}"] = self.delta_s[k].detach().cpu()
meta = {
"kind": "pissa",
"r": str(self.cfg.r),
"targets": ",".join(self.cfg.targets),
"layer_range": f"{self.cfg.layer_range[0]},{self.cfg.layer_range[1]}",
"selection_score": self.selection_score,
}
if extra_meta:
meta.update(extra_meta)
save_file(sd, path, metadata=meta)
def load(self, path: str) -> None:
"""Load buffers + delta_s AND mutate layer.weight to W - U·S·Vh.
Assumes layer.weight currently holds the W that the adapter was
trained against (i.e. base W for round 1; the W after round-(k-1)'s
load for round k — caller must load in round order)."""
from safetensors.torch import load_file
sd = load_file(path, device="cpu")
ckpt_keys = {k[2:].replace("__", ".") for k in sd if k.startswith("U.")}
init_keys = set(self.U.keys())
if ckpt_keys != init_keys:
raise RuntimeError(
f"PiSSA target mismatch: checkpoint has {len(ckpt_keys)} targets, "
f"init created {len(init_keys)}.")
for k in self.U:
kk = k.replace(".", "__")
dev = self.U[k].device
self.U[k].data.copy_(sd[f"U.{kk}"].to(dev, self.U[k].dtype))
self.S[k].data.copy_(sd[f"S.{kk}"].to(dev, self.S[k].dtype))
self.Vh[k].data.copy_(sd[f"Vh.{kk}"].to(dev, self.Vh[k].dtype))
self.delta_s[k].data.copy_(sd[f"delta_s.{kk}"].to(dev, self.delta_s[k].dtype))
# Mutate layer.weight: W -> W - U·S·Vh (= W_res that training saw).
layer = self._target_layers[k]
with torch.no_grad():
W = layer.weight.data.to(torch.float32)
U32 = self.U[k].to(torch.float32)
S32 = self.S[k].to(torch.float32)
Vh32 = self.Vh[k].to(torch.float32)
W_res = (W - (U32 * S32) @ Vh32).to(layer.weight.dtype)
layer.weight.data.copy_(W_res)
@classmethod
def from_checkpoint(cls, model: nn.Module, path: str) -> "ModulatedPiSSA":
from safetensors import safe_open
with safe_open(path, framework="pt") as f:
meta = f.metadata()
if meta.get("kind") != "pissa":
raise RuntimeError(f"not a PiSSA checkpoint: kind={meta.get('kind')!r}")
targets = tuple(meta["targets"].split(","))
lr_meta = meta.get("layer_range", "0.0,1.0")
lo, hi = (float(v) for v in lr_meta.split(","))
adapter = cls(model, r=int(meta["r"]), targets=targets,
layer_range=(lo, hi),
dtype=next(model.parameters()).dtype,
selection_score=meta["selection_score"],
_skip_init=True)
adapter.load(path)
return adapter
# ---------------------------------------------------------------------------
# HistoryBake — kept adapters compose via a single gated forward hook.
# Storage: per-target concat A_cat = [A_1; A_2; …], B_cat = [s_1·B_1, …]
# so dW_combined = B_cat @ A_cat exactly, computed without materialising dW.
# Gate predicate set by training code: `lambda: lora._c != 0.0` makes the
# c=0 reference forward return pristine base.
# ---------------------------------------------------------------------------
class HistoryBake:
    """Compose kept `ModulatedLoRA` adapters through one gated forward hook per layer.

    Per target the kept adapters are concatenated along the rank axis:
    A_cat = [A_1; A_2; …] and B_cat = [s_1·B_1, s_2·B_2, …] with
    s_k = c_k · alpha_k / r_k, so the hook adds (B_cat @ A_cat) · x without
    materialising the combined dW. Hooks are registered in the constructor
    and stay attached until `remove()`.
    """

    def __init__(self, model: nn.Module, history: list[tuple["ModulatedLoRA", float]]):
        """Build the concatenated factors and attach the hooks.

        Args:
            model: not used here; hooks go on the layers the adapters captured.
            history: kept rounds as ``(adapter, kept coefficient)`` pairs.

        Raises:
            ValueError: `history` is empty.
            AssertionError: the kept adapters do not all share one rank.
        """
        # Fail with a clear message instead of an IndexError further down.
        if not history:
            raise ValueError("HistoryBake needs ≥1 kept round")
        self._is_active = lambda: True  # inference default; train code overrides
        # Union of targets across rounds; the first adapter to mention a name
        # supplies the layer object.
        target_layers: dict[str, nn.Linear] = {}
        for lora, _ in history:
            for name, layer in lora._target_layers.items():
                target_layers.setdefault(name, layer)
        self._target_layers = target_layers

        r0 = history[0][0].cfg.r
        for lora, _ in history[1:]:
            assert lora.cfg.r == r0, f"kept-history rank mismatch: {lora.cfg.r} vs {r0}"

        target_dtype = history[0][0].cfg.dtype
        with torch.no_grad():
            A_cat: dict[str, Tensor] = {}
            B_cat: dict[str, Tensor] = {}
            for name, layer in target_layers.items():
                A_parts, B_parts = [], []
                for lora, c in history:
                    if name not in lora.A:
                        continue
                    # Fold the kept coefficient and the LoRA scale into B.
                    s = c * lora.cfg.alpha / lora.cfg.r
                    A_parts.append(lora.A[name].to(target_dtype))
                    B_parts.append((s * lora.B[name]).to(target_dtype))
                A_cat[name] = torch.cat(A_parts, dim=0).to(layer.weight.device).detach()
                B_cat[name] = torch.cat(B_parts, dim=1).to(layer.weight.device).detach()
        self._A_cat = A_cat
        self._B_cat = B_cat

        self._handles = []
        for name, layer in target_layers.items():
            self._handles.append(layer.register_forward_hook(self._make_hook(name)))

        Nr = len(history) * r0
        logger.info(f"HistoryBake: {len(history)} kept adapter(s), r_total={Nr}")

    def _make_hook(self, name: str):
        """Build the gated forward hook for target `name`."""
        # Read A/B inside the hook body — closure captures `self`, not the
        # tensors. Safe against future push/pop replacement of A_cat/B_cat.
        def hook(layer: nn.Linear, args, y: Tensor) -> Tensor:
            if not self._is_active():
                return y
            (x,) = args
            A = self._A_cat[name]                   # (k_total, d_in)
            B = self._B_cat[name]                   # (d_out, k_total), scale pre-folded
            xA = x.to(A.dtype)
            h = F.linear(xA, A)                     # x @ A.T via cuBLAS
            delta = F.linear(h, B)                  # h @ B.T via cuBLAS
            return y + delta.to(y.dtype)
        return hook

    def set_gate(self, is_active) -> None:
        """is_active() -> bool. Training: `set_gate(lambda: lora._c != 0.0)`."""
        self._is_active = is_active

    def remove(self) -> None:
        """Detach every hook this object registered."""
        for h in self._handles:
            h.remove()
        self._handles.clear()
# ---------------------------------------------------------------------------
# PiSSAHistoryBake — composes kept ModulatedPiSSA rounds.
#
# Each kept round k has already mutated layer.weight by subtracting
# U_k·S_k·Vh_k. Without reconstruction the forward returns W_res_K·x ≠ W·x,
# i.e. nonsense — so this hook is ALWAYS active (no gate). It sums each
# round's contribution at its kept c_k:
# Σ_k U_k · diag(S_k + c_k · Δs_k) · Vh_k · x
#
# `set_gate` is a no-op kept for ModulatedLoRA-interface parity. The KL
# anchor semantic shifts from "pristine base" to "prior-baked state" —
# c=0 on the *current* round gives W_res_K + Σ_prior reconstructions,
# i.e. the state after all kept rounds at their kept coefficients. See
# CLAUDE.md note on KL.
# ---------------------------------------------------------------------------
class PiSSAHistoryBake:
    """Forward-hook bundle that re-adds every kept ModulatedPiSSA round.

    For each target layer, the kept rounds' `U` / `Vh` factors are
    concatenated along the rank axis and each round's `S + c * delta_s`
    vector (`c` being that round's kept coefficient) is concatenated into one
    matching scale vector. One forward hook per layer then adds
    `U_cat · diag(scale_cat) · Vh_cat · x` to the layer output.

    The hook carries no gate check: it contributes on every forward pass
    until `remove()` is called.
    """

    def __init__(self, model: nn.Module,
                 history: list[tuple["ModulatedPiSSA", float]]):
        """Stack the kept rounds and register one hook per covered layer.

        Args:
            model: accepted for interface parity; not read by this constructor.
            history: `(adapter, c)` pairs, one per kept round.

        Raises:
            ValueError: if `history` is empty.
        """
        if not history:
            raise ValueError("PiSSAHistoryBake needs ≥1 kept round")

        # Union of the rounds' target layers; the first round to name a layer
        # supplies the module object.
        layers_by_name: dict[str, nn.Linear] = {}
        for kept, _coef in history:
            for lname, lmod in kept._target_layers.items():
                if lname not in layers_by_name:
                    layers_by_name[lname] = lmod
        self._target_layers = layers_by_name

        # All stacked factors are cast to the first round's configured dtype.
        dtype = history[0][0].cfg.dtype

        stacked_U: dict[str, Tensor] = {}
        stacked_Vh: dict[str, Tensor] = {}
        stacked_scale: dict[str, Tensor] = {}
        with torch.no_grad():
            for lname, lmod in layers_by_name.items():
                # Only rounds that actually hold factors for this layer.
                rounds = [(kept, coef) for kept, coef in history if lname in kept.U]
                if not rounds:
                    continue
                device = lmod.weight.device
                stacked_U[lname] = torch.cat(
                    [kept.U[lname].to(dtype) for kept, _ in rounds], dim=1,
                ).to(device).detach()
                stacked_Vh[lname] = torch.cat(
                    [kept.Vh[lname].to(dtype) for kept, _ in rounds], dim=0,
                ).to(device).detach()
                # diag(S_k + c_k·Δs_k) collapsed to one vector per round, then
                # concatenated so it lines up with the stacked rank axis.
                stacked_scale[lname] = torch.cat(
                    [
                        kept.S[lname].to(dtype)
                        + float(coef) * kept.delta_s[lname].detach().to(dtype)
                        for kept, coef in rounds
                    ],
                    dim=0,
                ).to(device).detach()

        self._U_cat = stacked_U
        self._Vh_cat = stacked_Vh
        self._scale_cat = stacked_scale

        # Hooks are registered only after every layer stacked successfully.
        self._handles = []
        for lname, lmod in layers_by_name.items():
            if lname in stacked_U:
                self._handles.append(
                    lmod.register_forward_hook(self._make_hook(lname)))

        # Logged total assumes every round uses the first round's rank.
        Nr = len(history) * history[0][0].cfg.r
        logger.info(f"PiSSAHistoryBake: {len(history)} kept adapter(s), r_total={Nr}")

    def _make_hook(self, name: str):
        """Return the always-on forward hook for the layer keyed by `name`.

        Stacked factors are read from `self` on each call rather than
        captured in the closure.
        """
        def hook(layer: nn.Linear, args, y: Tensor) -> Tensor:
            (inp,) = args
            left = self._U_cat[name]         # (d_out, Σr)
            right = self._Vh_cat[name]       # (Σr, d_in)
            scale = self._scale_cat[name]    # (Σr,)
            projected = F.linear(inp.to(right.dtype), right)   # (..., Σr)
            rebuilt = F.linear(projected * scale, left)        # (..., d_out)
            return y + rebuilt.to(y.dtype)
        return hook

    def set_gate(self, _is_active) -> None:
        """Accept and ignore a gate predicate.

        Kept-round reconstructions must always be applied (otherwise
        layer.weight = W_res ≠ W and the forward returns garbage), so the
        argument is discarded. The method exists only so the train loop can
        treat this object like the ModulatedLoRA-side bake.
        """
        return None

    def remove(self) -> None:
        """Detach every registered forward hook and empty the handle list."""
        registered = self._handles
        for handle in registered:
            handle.remove()
        registered.clear()
+223
View File
@@ -0,0 +1,223 @@
"""Compose N LoRA adapters into a model's weights for fast inference.
Two backends, one surface (`baked()` context manager):
- **Float layers** (bf16 / fp16 nn.Linear): apply Σ_i c_i * (α_i/r_i) * B_i @ A_i
to `layer.weight` in-place. Restore from a CPU backup on exit.
Forward pass has NO runtime LoRA overhead — weights already include dW.
Memory: CPU backup of W (~5GB for 9b bf16, ~13GB for 27b).
- **Quantized layers** (bnb Linear4bit / Linear8bitLt): can't modify the
quantized buffer reversibly. Concatenate all adapters into one stacked
low-rank pair `(A_stack, B_stack)` and register ONE forward hook per layer
doing two `F.linear` calls. N adapters cost the same as 1 at forward.
Both paths take the same AdapterSpec list, so callers don't see the split.
Use for inference only (csm eval, dialogues, gen_completions). Training
varies `c` per step → use the gated ModulatedLoRA / HistoryBake hooks
instead.
"""
from __future__ import annotations
from contextlib import contextmanager
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
from loguru import logger
from steer_heal.ws.adapter import ModulatedLoRA, ModulatedPiSSA
@dataclass
class AdapterSpec:
    """Detached, CPU-side snapshot of one LoRA adapter.

    Per target-layer name: `A[name]` is (r, d_in) and `B[name]` is
    (d_out, r); the effective scale applied at bake time is alpha / r.
    Tensors are moved to the layer's device only for the duration of a bake.
    """
    A: dict[str, torch.Tensor]
    B: dict[str, torch.Tensor]
    alpha: float
    r: int
    # Coefficient used for this adapter when the caller supplies no override.
    default_c: float = 1.0

    @classmethod
    def from_lora(cls, lora: ModulatedLoRA, default_c: float = 1.0) -> "AdapterSpec":
        """Snapshot a live adapter's A/B tensors (detached, on CPU).

        `alpha` and `r` are taken from `lora.cfg`.
        """
        def _snapshot(params):
            return {name: tensor.detach().cpu() for name, tensor in params.items()}

        down = _snapshot(lora.A)
        up = _snapshot(lora.B)
        return cls(A=down, B=up, alpha=lora.cfg.alpha, r=lora.cfg.r,
                   default_c=default_c)

    @classmethod
    def from_checkpoint(cls, model: nn.Module, path: str,
                        default_c: float = 1.0) -> "AdapterSpec":
        """Load adapter.safetensors and return it as a CPU-resident spec.

        A temporary adapter is built with `ModulatedLoRA.from_checkpoint`,
        snapshotted through `from_lora`, and the local reference is dropped.

        NOTE(review): if `ModulatedLoRA.from_checkpoint` registers forward
        hooks on `model`, nothing here removes them — confirm against the
        adapter implementation.
        """
        loaded = ModulatedLoRA.from_checkpoint(model, path)
        snapshot = cls.from_lora(loaded, default_c=default_c)
        del loaded  # release this function's reference to the device copies
        return snapshot
def _is_quantized(layer: nn.Linear) -> bool:
    """Return True when `layer`'s class is named like a bitsandbytes quantized linear.

    Matching is by class name ("Linear4bit" / "Linear8bitLt"), so
    bitsandbytes never has to be imported; ordinary bf16/fp16 `nn.Linear`
    layers yield False.
    """
    quantized_class_names = ("Linear4bit", "Linear8bitLt")
    return type(layer).__name__ in quantized_class_names
def _gather_target_layers(model: nn.Module,
                          adapters: list[AdapterSpec]) -> dict[str, nn.Linear]:
    """Map each adapter-targeted layer name to its module in `model`.

    The wanted names are the union of every adapter's `A` keys. Names with
    no matching entry in `model.named_modules()` are silently left out, and
    the result follows `named_modules()` order.
    """
    target_names: set[str] = set().union(*(spec.A.keys() for spec in adapters))
    selected: dict[str, nn.Linear] = {}
    for module_name, module in model.named_modules():
        if module_name in target_names:
            selected[module_name] = module
    return selected
@contextmanager
def baked(model: nn.Module, adapters: list[AdapterSpec],
          c_overrides: list[float] | None = None):
    """Temporarily apply N adapters at one fixed coefficient each; undo on exit.

    Composable: `baked(model, hist_specs + [current_spec])` bakes the full
    cumulative state in one call. Per-adapter c via `default_c` or
    `c_overrides=[c0, c1, ...]`. No-op if adapters is empty.

    Float layers get `W += Σ c_i * (α_i/r_i) * (B_i @ A_i)` in place and are
    restored from a CPU backup on exit. Quantized layers get one forward hook
    per layer over the stacked low-rank pair, removed on exit.

    Setup runs inside the same `try` as the `yield`, so a failure part-way
    through (e.g. OOM or a shape mismatch on the k-th layer) still restores
    every weight already modified and removes every hook already registered.

    Args:
        model: module whose named sub-modules are matched against adapter keys.
        adapters: specs to compose.
        c_overrides: optional per-adapter coefficients, same length as
            `adapters`; defaults to each spec's `default_c`.

    Raises:
        ValueError: if `c_overrides` and `adapters` differ in length. (Raised
            explicitly rather than asserted so the check survives `python -O`,
            where `zip` would otherwise silently drop adapters.)
    """
    if not adapters:
        yield
        return

    cs = list(c_overrides) if c_overrides is not None else [a.default_c for a in adapters]
    if len(cs) != len(adapters):
        raise ValueError(
            f"c_overrides len {len(cs)} != adapters len {len(adapters)}")

    layers = _gather_target_layers(model, adapters)
    float_layers = {n: l for n, l in layers.items() if not _is_quantized(l)}
    quant_layers = {n: l for n, l in layers.items() if _is_quantized(l)}

    W_backup: dict[str, torch.Tensor] = {}
    quant_handles = []
    try:
        # ─ Float backend: in-place W += Σ c_i * (α_i/r_i) * (B_i @ A_i)
        for name, layer in float_layers.items():
            # Back up BEFORE touching the weight; one host copy whether the
            # weight lives on CPU or on an accelerator.
            W_backup[name] = layer.weight.detach().to("cpu", copy=True)
            dev, dt = layer.weight.device, layer.weight.dtype
            dW = torch.zeros_like(layer.weight)
            for adapter, c in zip(adapters, cs):
                if name not in adapter.A or c == 0.0:
                    continue
                A = adapter.A[name].to(dev, dt)
                B = adapter.B[name].to(dev, dt)
                scale = (c * adapter.alpha) / adapter.r
                dW.add_(scale * (B @ A))
            layer.weight.data.add_(dW)

        # ─ Quant backend: one stacked-low-rank hook per layer (N adapters → one pair)
        for name, layer in quant_layers.items():
            A_parts, B_parts = [], []
            for adapter, c in zip(adapters, cs):
                if name not in adapter.A or c == 0.0:
                    continue
                scale = (c * adapter.alpha) / adapter.r
                A_parts.append(adapter.A[name].to(layer.weight.device))
                B_parts.append((scale * adapter.B[name]).to(layer.weight.device))
            if not A_parts:
                continue
            A_stack = torch.cat(A_parts, dim=0)   # (Σr, d_in)
            B_stack = torch.cat(B_parts, dim=1)   # (d_out, Σr)

            # Default args bind this layer's tensors now (avoids the
            # late-binding closure pitfall across loop iterations).
            def _hook(_m, args, y, A=A_stack, B=B_stack):
                (x,) = args
                h = F.linear(x.to(A.dtype), A)          # x @ A.T
                return y + F.linear(h, B).to(y.dtype)   # h @ B.T
            quant_handles.append(layer.register_forward_hook(_hook))

        logger.debug(f"baked: {len(adapters)} adapter(s), "
                     f"{len(float_layers)} float (in-place), "
                     f"{len(quant_layers)} quantized (stacked-LR hook)")
        yield
    finally:
        # Restore only what was actually backed up: after a setup failure
        # that may be a strict subset of float_layers.
        for name, saved in W_backup.items():
            weight = float_layers[name].weight
            weight.data.copy_(saved.to(weight.device))
        for h in quant_handles:
            h.remove()
# ---------------------------------------------------------------------------
# PiSSA inference path. Math observation: at training time, layer.weight was
# mutated to W_res = W - U·S·Vh and the forward added back U·(S + c·Δs)·Vh.
# At inference we load a FRESH layer.weight = W. Folding "extract U·S·Vh
# then add U·(S+c·Δs)·Vh" on a fresh W collapses to a single per-round
# additive +c·U·diag(Δs)·Vh — which is exactly the LoRA forward with
# A=Vh, B=U·diag(Δs), α=r (so α/r=1). So a PiSSA checkpoint is loadable
# as an AdapterSpec usable by `baked()` with zero new bake code. This
# equivalence only holds for FRESH W; it does NOT hold inside the training
# loop's `PiSSAHistoryBake`, which operates on a model whose layer.weight
# has been sequentially mutated by ModulatedPiSSA.load().
# ---------------------------------------------------------------------------
def pissa_to_lora_spec(adapter: ModulatedPiSSA,
                       default_c: float = 1.0) -> AdapterSpec:
    """Re-express a ModulatedPiSSA adapter as a LoRA-shaped AdapterSpec.

    Per layer: A = Vh and B = U·diag(Δs), obtained by broadcasting the
    detached `delta_s` across U's columns. `alpha` is set equal to `r`, so
    the bake-time scale α/r is 1. All tensors are moved to CPU.
    """
    names = list(adapter.U)
    lora_A: dict[str, torch.Tensor] = {
        name: adapter.Vh[name].detach().cpu() for name in names
    }
    lora_B: dict[str, torch.Tensor] = {
        name: (adapter.U[name] * adapter.delta_s[name].detach()).cpu()
        for name in names
    }
    rank = adapter.cfg.r
    return AdapterSpec(A=lora_A, B=lora_B, alpha=float(rank), r=rank,
                       default_c=default_c)
def pissa_spec_from_checkpoint(model: nn.Module, path: str,
                               default_c: float = 1.0) -> AdapterSpec:
    """Load a PiSSA checkpoint as a LoRA-equivalent AdapterSpec without
    touching layer.weight (vs ModulatedPiSSA.load which mutates W).

    Validates checkpoint keys against `_find_targets(model, cfg)` reproduced
    from ckpt metadata: any mismatch (renamed module, changed layer_range,
    wrong model) raises. Without this `baked()` would silently intersect
    keys with model.named_modules() and drop the rest.

    Args:
        model: model the checkpoint is expected to target.
        path: path to the safetensors checkpoint.
        default_c: coefficient stored on the returned spec.

    Raises:
        RuntimeError: if the file's metadata does not say kind == "pissa"
            (including files with no metadata at all), or if the checkpoint's
            layer set differs from the model's expected targets.
        KeyError: if the metadata lacks `r`, `targets` or `layer_range`, or a
            `Vh.` / `delta_s.` tensor is missing for a stored `U.` tensor.
    """
    from safetensors import safe_open
    from safetensors.torch import load_file
    # Import from the same package as this module's top-level
    # ModulatedLoRA / ModulatedPiSSA import; the previous `csm.ws.adapter`
    # path did not match it.
    from steer_heal.ws.adapter import LoRAConfig, _find_targets

    with safe_open(path, framework="pt") as f:
        # metadata() is None when the file was saved without a metadata
        # header; normalise so the kind check below reports it cleanly.
        meta = f.metadata() or {}
    if meta.get("kind") != "pissa":
        raise RuntimeError(f"not a PiSSA checkpoint: kind={meta.get('kind')!r}")

    # Rebuild the config the checkpoint was trained with from its metadata.
    r = int(meta["r"])
    targets = tuple(meta["targets"].split(","))
    lr_lo, lr_hi = (float(x) for x in meta["layer_range"].split(","))
    cfg = LoRAConfig(r=r, targets=targets, layer_range=(lr_lo, lr_hi))
    expected = {n for n, _ in _find_targets(model, cfg)}

    sd = load_file(path, device="cpu")
    # Tensor keys are "<kind>.<module name with '.' encoded as '__'>".
    keys = {k[2:].replace("__", ".") for k in sd if k.startswith("U.")}
    if keys != expected:
        missing = expected - keys
        extra = keys - expected
        raise RuntimeError(
            f"PiSSA checkpoint target mismatch (path={path}): "
            f"missing={sorted(missing)[:3]}... extra={sorted(extra)[:3]}... "
            f"({len(missing)} missing, {len(extra)} extra)")

    A: dict[str, torch.Tensor] = {}
    B: dict[str, torch.Tensor] = {}
    for k in sorted(keys):
        kk = k.replace(".", "__")
        A[k] = sd[f"Vh.{kk}"].contiguous()
        # U·diag(Δs): broadcast Δs across U's columns.
        B[k] = (sd[f"U.{kk}"] * sd[f"delta_s.{kk}"]).contiguous()
    # α = r makes the bake-time scale α/r equal to 1.
    return AdapterSpec(A=A, B=B, alpha=float(r), r=r, default_c=default_c)
def adapter_spec_from_checkpoint(model: nn.Module, path: str,
                                 default_c: float = 1.0) -> AdapterSpec:
    """Load a checkpoint as an AdapterSpec, dispatching on its `kind` metadata.

    Both PiSSA and LoRA checkpoints come back in the same AdapterSpec shape,
    so downstream `baked()` is uniform. A checkpoint with no `kind` entry —
    or with no metadata header at all — is treated as LoRA.

    Args:
        model: model the checkpoint targets.
        path: path to the safetensors checkpoint.
        default_c: coefficient stored on the returned spec.
    """
    from safetensors import safe_open

    with safe_open(path, framework="pt") as f:
        # metadata() returns None for files saved without metadata; fall back
        # to an empty mapping so the "lora" default applies instead of an
        # AttributeError on None.
        kind = (f.metadata() or {}).get("kind", "lora")
    if kind == "pissa":
        return pissa_spec_from_checkpoint(model, path, default_c=default_c)
    return AdapterSpec.from_checkpoint(model, path, default_c=default_c)