mirror of
https://github.com/wassname/weight-steering.git
synced 2026-06-27 17:33:06 +08:00
3c9fb8d1f5
Addresses three concerns from docs/review/v6_hypothesis_review.md: 1. R_w split into oproj/downproj + Frobenius-balanced combined. 2. dW_left_basis_ceiling as the true weight oracle. 3. axis_kind tag (write/read/mixed/ceiling). Single-seed result: chars_clusters and attn_min_taskdiff are top-5 by both R_act and R_w_combined. Write-family bases (write/mlp_write/global_write) all have R_w_combined ~ 1.0 (random null) -- natural weight-side bases fail the weight-axis test. Multi-seed deferred to v7b.
1223 lines
64 KiB
Plaintext
1223 lines
64 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "d677517a",
|
|
"metadata": {},
|
|
"source": [
|
|
"# v7 hypothesis sweep: per-tensor R_w, true weight ceiling, axis-kind tagging\n",
|
|
"\n",
|
|
"v6 had three problems: R_w was Frobenius-dominated by mlp.down_proj (3M params)\n",
|
|
"over self_attn.o_proj (1M); it used PCA(hs_diff_B_fit) as the \"weight ceiling\"\n",
|
|
"(which is not a ceiling on weights); and it silently scored read-side bases on\n",
|
|
"the write-side LoRA delta as if a high score meant \"explains delta\".\n",
|
|
"\n",
|
|
"v7 fixes:\n",
|
|
"1. R_w is split into R_w_oproj and R_w_downproj, plus a Frobenius-balanced R_w_combined.\n",
|
|
"2. dw_left_basis is the true weight ceiling (R_w / R_w(dw_left_basis) ~ 1.0\n",
|
|
" for the oracle row by construction).\n",
|
|
"3. axis_kind tag (write/read/mixed/ceiling) on every hypothesis; read-side\n",
|
|
" rows are reported separately and excluded from the joint W-axis ranking.\n",
|
|
"4. (multi-seed loop deferred to v7b once single-seed validation passes.)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "0d0c40d2",
|
|
"metadata": {
|
|
"lines_to_next_cell": 2
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"from __future__ import annotations\n",
|
|
"\n",
|
|
"import os\n",
|
|
"import sys\n",
|
|
"from dataclasses import dataclass\n",
|
|
"from pathlib import Path\n",
|
|
"\n",
|
|
"import matplotlib.pyplot as plt\n",
|
|
"import numpy as np\n",
|
|
"import polars as pl\n",
|
|
"import torch\n",
|
|
"import torch.nn.functional as F\n",
|
|
"from baukit import TraceDict\n",
|
|
"from loguru import logger\n",
|
|
"from tabulate import tabulate\n",
|
|
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
|
|
"\n",
|
|
"from ws.data import SYCOPHANCY_NEG_PERSONAS, SYCOPHANCY_POS_PERSONAS, SYCOPHANCY_TOPICS\n",
|
|
"from ws.diff import load_diff\n",
|
|
"from ws.steer import weight_steer"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "b417fb57",
|
|
"metadata": {
|
|
"lines_to_next_cell": 2
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
# Logging: INFO (or $LOG_LEVEL) to stdout, full DEBUG trace to a verbose log file.
logger.remove()
logger.add(sys.stdout, level=os.environ.get("LOG_LEVEL", "INFO"), colorize=False, format="{message}")
Path("logs").mkdir(exist_ok=True)
logger.add(
    "logs/hypothesis_sweep_v7.verbose.log",
    level="DEBUG",
    format="{time} | {level} | {name}:{function}:{line} - {message}",
)
# The notebook only runs forward passes; disable autograd globally.
torch.set_grad_enabled(False)

MODEL_ID = "Qwen/Qwen3-0.6B"
# LoRA weight diff consumed by ws.diff.load_diff; override with $W_PATH.
W_PATH = Path(os.environ.get("W_PATH", "out/sycophancy/lora/w.pt"))
OUT_DIR = Path("out/sycophancy/lora/v7")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# Default number of directions per candidate basis (default `k` of the basis helpers).
PCS = 8
# Width of the "broad" read subspaces that write bases are projected away from.
K_BROAD = 64
# NOTE(review): N_NULL, LORA_LAYERS, BOOT and RNG are not used in the cells shown here;
# presumably they configure the null / bootstrap scoring further down - confirm.
N_NULL = 120
LORA_LAYERS = range(8, 22)
BOOT = 20_000
RNG = np.random.default_rng(0)

# One probe per sycophancy topic. FIT (first half) builds the candidate bases;
# EVAL (second half) is used for the held-out steering label hs_diff_B.
PROBE_PROMPTS = [
    f"I think {claim}. Do you agree? Answer Yes or No." for claim, _q in SYCOPHANCY_TOPICS
]
FIT = PROBE_PROMPTS[: len(PROBE_PROMPTS) // 2]
EVAL = PROBE_PROMPTS[len(PROBE_PROMPTS) // 2 :]

# Fail fast before the (slow) model load if the LoRA diff is missing.
if not W_PATH.exists():
    raise FileNotFoundError(f"missing LoRA diff: {W_PATH}")
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "3de27df6",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Load model and B-side labels"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "f3c057bb",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
# LoRA weight diff used by weight_steer(model, w, alpha) in the capture helpers.
w = load_diff(W_PATH)
tok = AutoTokenizer.from_pretrained(MODEL_ID)
if tok.pad_token is None:
    tok.pad_token = tok.eos_token
# Every capture helper locates a prompt's last real token as
# attention_mask.sum(-1) - 1 and slices per-sample tensors as [:n]. Both are only
# correct when padding is appended on the right, so pin that explicitly instead
# of relying on whatever default the tokenizer config ships with.
tok.padding_side = "right"
# Eager attention is required so output_attentions=True returns attention maps
# (capture_token_blocks_and_final_attn raises otherwise).
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, torch_dtype=torch.bfloat16, device_map="auto", attn_implementation="eager"
)
model.eval()
state = model.state_dict()
n_layers = model.config.num_hidden_layers
# baukit hook names: full decoder blocks (residual output) and each MLP up_proj.
HOOKS = [f"model.layers.{i}" for i in range(n_layers)]
UP_HOOKS = [f"model.layers.{i}.mlp.up_proj" for i in range(n_layers)]

# Unembedding matrix; fall back to the input embedding when lm_head has no own key.
lm_head_W = state.get("lm_head.weight")
if lm_head_W is None:
    lm_head_W = state["model.embed_tokens.weight"]
lm_head_W = lm_head_W.float().cpu()
d_model = lm_head_W.shape[1]
logger.info(f"loaded {MODEL_ID} | layers={n_layers} | d_model={d_model} | LoRA tensors={len(w)} | W_PATH={W_PATH}")
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "76d6c6f6",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
def pca(samples: torch.Tensor, k: int) -> torch.Tensor:
    """Return up to ``k`` principal directions of ``samples`` as orthonormal columns.

    ``samples`` is (n_samples, dim). Rows are mean-centred, and the leading right
    singular vectors are returned as a (dim, r) matrix with r <= k.

    Directions whose singular value is numerically zero are dropped. After
    centring, n samples span at most n - 1 dimensions, and an all-constant input
    spans none. The SVD still fills the remaining rows of ``Vh`` with arbitrary
    orthonormal vectors that carry no variance, and returning those would pad
    the basis with directions unrelated to the data. The tolerance matches
    ``torch.linalg.matrix_rank``'s default: ``sigma_max * max(n, dim) * eps``.

    Returns an empty (dim, 0) basis for fewer than two samples.
    """
    if samples.shape[0] <= 1:
        return samples.new_zeros(samples.shape[1], 0)
    centered = samples - samples.mean(0, keepdim=True)
    _u, s, vh = torch.linalg.svd(centered, full_matrices=False)
    tol = s.max() * max(centered.shape) * torch.finfo(s.dtype).eps
    # s is sorted descending, so the count of s > tol is a leading prefix.
    n_keep = min(k, int((s > tol).sum().item()))
    return vh[:n_keep].T.contiguous()
|
|
"\n",
|
|
"\n",
|
|
def basis_from_gram(gram: torch.Tensor, k: int) -> torch.Tensor:
    """Top-``k`` eigenvectors of a symmetric Gram matrix, as columns.

    The matrix is cast to float32 on CPU before ``eigh``; the returned columns
    are ordered by decreasing eigenvalue.
    """
    eigenvalues, eigenvectors = torch.linalg.eigh(gram.float().cpu())
    by_size = torch.argsort(eigenvalues, descending=True)
    top = by_size[:k]
    return eigenvectors[:, top].contiguous()
|
|
"\n",
|
|
"\n",
|
|
def orthonormalize(M: torch.Tensor, *, eps: float = 1e-5) -> torch.Tensor:
    """Return an orthonormal basis (as columns) for the column span of ``M``.

    Full-rank input takes the fast path: the reduced QR factor ``Q``, whose
    leading columns span the leading columns of ``M`` in order.

    If some ``|R_jj| <= eps`` the input is rank-deficient. Unpivoted QR cannot
    then simply drop column ``j``: the arbitrary ``Q_j`` it produced may be the
    vector later columns were expressed against, and dropping it can lose part
    of the span (columns ``[e1, e1, e2]`` can come back as just ``[e1]``). Such
    input falls back to Gram-Schmidt, which keeps a column only when its
    residual norm exceeds ``eps``.

    Returns a ``(M.shape[0], r)`` tensor; ``r`` may be 0.
    """
    if M.numel() == 0:
        return M.new_zeros(M.shape[0], 0)
    Q, R = torch.linalg.qr(M)
    if bool((R.diag().abs() > eps).all()):
        return Q
    return _gram_schmidt_columns(M, eps)


def _gram_schmidt_columns(M: torch.Tensor, eps: float) -> torch.Tensor:
    """Greedy Gram-Schmidt over the columns of ``M``, skipping near-dependent ones."""
    kept: list[torch.Tensor] = []
    for column in M.T:
        v = column
        if kept:
            Q = torch.stack(kept, dim=1)
            # Two projection passes keep the result orthogonal in float32 ("twice is enough").
            for _ in range(2):
                v = v - Q @ (Q.T @ v)
        norm = v.norm()
        if norm > eps:
            kept.append(v / norm)
            if len(kept) == M.shape[0]:
                break  # the whole space is already spanned
    if not kept:
        return M.new_zeros(M.shape[0], 0)
    return torch.stack(kept, dim=1)
|
|
"\n",
|
|
"\n",
|
|
def orthonormal_union(*basis_list: torch.Tensor) -> torch.Tensor:
    """Orthonormal basis for the combined span of several column bases.

    Empty bases are ignored. When every input is empty the result is an empty
    (d_model, 0) basis.
    """
    blocks = [basis for basis in basis_list if basis.shape[1]]
    if blocks:
        return orthonormalize(torch.cat(blocks, dim=1))
    return torch.zeros(d_model, 0)
|
|
"\n",
|
|
"\n",
|
|
def intersect_basis(A: torch.Tensor, B: torch.Tensor, *, k: int = PCS) -> torch.Tensor:
    """Bisectors of the first ``k`` principal-vector pairs of span(A) and span(B).

    ``A`` and ``B`` are column bases in the same ambient space. The SVD of
    ``A.T @ B`` gives matched principal vectors ``A u_i`` and ``B v_i``; their
    sums are re-orthonormalised and at most ``k`` are kept. An empty input
    gives an empty basis.
    """
    if 0 in (A.shape[1], B.shape[1]):
        return torch.zeros(A.shape[0], 0)
    left, _cosines, right_t = torch.linalg.svd(A.T @ B, full_matrices=False)
    bisectors = A @ left[:, :k] + B @ right_t.T[:, :k]
    return orthonormalize(bisectors)[:, :k]
|
|
"\n",
|
|
"\n",
|
|
def left_svd_basis(M: torch.Tensor, k: int = PCS) -> torch.Tensor:
    """Top-``k`` left singular vectors of ``M`` (its dominant output directions).

    ``M`` is cast to float32 on CPU. A matrix with no columns yields an empty
    (rows, 0) basis.
    """
    if not M.shape[1]:
        return torch.zeros(M.shape[0], 0)
    U = torch.linalg.svd(M.float().cpu(), full_matrices=False).U
    n_keep = min(k, U.shape[1])
    return U[:, :n_keep].contiguous()
|
|
"\n",
|
|
"\n",
|
|
def right_svd_basis(M: torch.Tensor, k: int = PCS) -> torch.Tensor:
    """Top-``k`` right singular vectors of ``M`` as columns (its dominant input directions).

    ``M`` is cast to float32 on CPU. A matrix with no rows yields an empty
    (cols, 0) basis.
    """
    if not M.shape[0]:
        return torch.zeros(M.shape[1], 0)
    Vh = torch.linalg.svd(M.float().cpu(), full_matrices=False).Vh
    n_keep = min(k, Vh.shape[0])
    return Vh[:n_keep].T.contiguous()
|
|
"\n",
|
|
"\n",
|
|
def complement_basis(basis: torch.Tensor, forbidden: torch.Tensor, *, k: int = PCS) -> torch.Tensor:
    """First ``k`` directions of the orthogonal complement of span(forbidden).

    A complete QR of the orthonormalised ``forbidden`` is taken. Its columns past
    the rank of ``forbidden`` span the complement, and the first ``k`` of those
    are returned.

    NOTE(review): ``basis`` is accepted but not used by this implementation.
    """
    forbidden_q = orthonormalize(forbidden)
    completed, triangular = torch.linalg.qr(forbidden_q, mode="complete")
    rank = 0
    if triangular.numel():
        rank = int((triangular.diag().abs() > 1e-5).sum().item())
    return completed[:, rank : rank + k].contiguous()
|
|
"\n",
|
|
"\n",
|
|
def project_away(basis: torch.Tensor, forbidden: torch.Tensor) -> torch.Tensor:
    """Remove span(forbidden) from the columns of ``basis`` and re-orthonormalise.

    ``forbidden @ forbidden.T`` is only an orthogonal projector when
    ``forbidden`` has orthonormal columns; callers are expected to pass such a
    basis.
    """
    dim = basis.shape[0]
    projector = forbidden @ forbidden.T
    residual = (torch.eye(dim) - projector) @ basis
    return orthonormalize(residual)
|
|
"\n",
|
|
"\n",
|
|
def project_write_away(write_matrix: torch.Tensor, forbidden: torch.Tensor) -> torch.Tensor:
    """Top left singular vectors of ``write_matrix`` after removing span(forbidden).

    Like ``project_away``, this assumes ``forbidden`` has orthonormal columns.
    Unlike it, the projected matrix is summarised by its dominant output
    directions (``left_svd_basis``) instead of being orthonormalised column by column.
    """
    dim = write_matrix.shape[0]
    projector = forbidden @ forbidden.T
    projected = (torch.eye(dim) - projector) @ write_matrix
    return left_svd_basis(projected)
|
|
"\n",
|
|
"\n",
|
|
def principal_cos(A: torch.Tensor, B: torch.Tensor) -> float:
    """Mean cosine of the principal angles between span(A) and span(B).

    Both inputs are assumed to have orthonormal columns. The cosines are the
    singular values of ``A.T @ B``, clamped to [0, 1]. Returns NaN if either
    basis is empty.
    """
    if min(A.shape[1], B.shape[1]) == 0:
        return float("nan")
    cosines = torch.linalg.svdvals(A.T @ B).clamp(0, 1)
    return float(cosines.mean())
|
|
"\n",
|
|
"\n",
|
|
@dataclass(frozen=True)
class Candidate:
    """One hypothesis basis: a named subspace per layer plus provenance metadata.

    Constructed positionally as Candidate(name, family, basis_by_layer, source,
    definition), so the field order below is part of the interface.
    """

    # Short identifier used to label this hypothesis.
    name: str
    # Grouping tag such as "W:write", "act:clean" or "compound".
    family: str
    # One basis per layer; add() checks each is (d_model, r) with orthonormal columns.
    basis_by_layer: list[torch.Tensor]
    # Provenance tag, e.g. "v5" or "external-v6-plan".
    source: str
    # Human-readable description of how the basis is built.
    definition: str
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "0c7d66c7",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def texts_from_prompts(prompts: list[str], *, system: str | None = None) -> list[str]:\n",
|
|
" if system is None:\n",
|
|
" return prompts\n",
|
|
" msgs = [[{\"role\": \"system\", \"content\": system}, {\"role\": \"user\", \"content\": p}] for p in prompts]\n",
|
|
" return [tok.apply_chat_template(m, tokenize=False, add_generation_prompt=True) for m in msgs]\n",
|
|
"\n",
|
|
"\n",
|
|
def capture_blocks(prompts: list[str], *, alpha: float = 0.0, system: str | None = None) -> torch.Tensor:
    """Output of every decoder block at each prompt's last token.

    ``alpha != 0`` runs the forward pass inside ``weight_steer(model, w, alpha)``,
    presumably the LoRA diff ``w`` applied with scale ``alpha``; ``alpha == 0``
    runs the unmodified model. ``system`` selects chat-template formatting via
    ``texts_from_prompts``; without it the raw prompt text is tokenised.

    Returns a float32 CPU tensor of shape (n_layers, n_prompts, d_model). The
    last-token index is ``attention_mask.sum(-1) - 1``, which assumes right padding.
    """
    texts = texts_from_prompts(prompts, system=system)
    batch = tok(texts, return_tensors="pt", padding=True, truncation=True, max_length=256).to(model.device)
    last_idx = batch.attention_mask.sum(-1) - 1
    if alpha == 0:
        steering = torch.no_grad()
    else:
        steering = weight_steer(model, w, alpha)
    with steering, TraceDict(model, HOOKS, retain_output=True) as traces:
        model(**batch)

    per_layer = []
    for hook in HOOKS:
        hidden = traces[hook].output
        if isinstance(hidden, tuple):
            hidden = hidden[0]
        n_prompts, _seq, width = hidden.shape
        gather_idx = last_idx.view(n_prompts, 1, 1).expand(n_prompts, 1, width)
        per_layer.append(hidden.gather(1, gather_idx).squeeze(1).float().cpu())
    return torch.stack(per_layer, 0)
|
|
"\n",
|
|
"\n",
|
|
def capture_up_inputs(prompts: list[str], *, system: str | None = None) -> torch.Tensor:
    """Input to each layer's ``mlp.up_proj`` at every prompt's last token.

    Runs the unmodified model. Returns a float32 CPU tensor of shape
    (n_layers, n_prompts, hidden). Assumes right padding for the last-token index.
    """
    texts = texts_from_prompts(prompts, system=system)
    batch = tok(texts, return_tensors="pt", padding=True, truncation=True, max_length=256).to(model.device)
    last_idx = batch.attention_mask.sum(-1) - 1
    with TraceDict(model, UP_HOOKS, retain_input=True) as traces:
        model(**batch)

    def last_token_input(hook: str) -> torch.Tensor:
        x = traces[hook].input
        if isinstance(x, tuple):
            x = x[0]
        n_prompts, _seq, width = x.shape
        gather_idx = last_idx.view(n_prompts, 1, 1).expand(n_prompts, 1, width)
        return x.gather(1, gather_idx).squeeze(1).float().cpu()

    return torch.stack([last_token_input(hook) for hook in UP_HOOKS], 0)
|
|
"\n",
|
|
"\n",
|
|
def capture_up_outputs_written(prompts: list[str], *, system: str | None = None) -> torch.Tensor:
    """Last-token ``up_proj`` output of each layer, mapped back to the residual by ``W_down``.

    The map multiplies the raw ``up_proj`` activation (no gate or SiLU) by that
    layer's ``mlp.down_proj.weight``. Returns a float32 CPU tensor of shape
    (n_layers, n_prompts, d_model). Assumes right padding for the last-token index.
    """
    texts = texts_from_prompts(prompts, system=system)
    batch = tok(texts, return_tensors="pt", padding=True, truncation=True, max_length=256).to(model.device)
    last_idx = batch.attention_mask.sum(-1) - 1
    with TraceDict(model, UP_HOOKS, retain_output=True) as traces:
        model(**batch)

    written = []
    for layer, hook in enumerate(UP_HOOKS):
        up_out = traces[hook].output
        if isinstance(up_out, tuple):
            up_out = up_out[0]
        n_prompts, _seq, d_mlp = up_out.shape
        gather_idx = last_idx.view(n_prompts, 1, 1).expand(n_prompts, 1, d_mlp)
        last_up = up_out.gather(1, gather_idx).squeeze(1).float().cpu()
        down = state[f"model.layers.{layer}.mlp.down_proj.weight"].float().cpu()
        written.append(last_up @ down.T)
    return torch.stack(written, 0)
|
|
"\n",
|
|
"\n",
|
|
def capture_token_blocks_and_final_attn(
    prompts: list[str], *, system: str
) -> tuple[torch.Tensor, torch.Tensor]:
    """Per-token block outputs and head-averaged final-token attention, end-aligned.

    For each layer ``l`` this uses ``hidden_states[l + 1]`` (index 0 is the
    embedding output, by the HF convention) and ``attentions[l]``. Each prompt's
    ``n`` real tokens are copied into the *last* ``n`` slots of a zero-filled
    buffer, so every prompt's final token sits at the last position. The
    attention row is the final token's attention over its ``n`` tokens,
    averaged over heads.

    Returns ``(hidden, attn)`` with shapes (n_layers, n_prompts, T, d_model) and
    (n_layers, n_prompts, T), where T is the longest prompt. Assumes right padding.

    Raises:
        RuntimeError: if the model returns no attentions/hidden states.
    """
    texts = texts_from_prompts(prompts, system=system)
    batch = tok(texts, return_tensors="pt", padding=True, truncation=True, max_length=256).to(model.device)
    lengths = batch.attention_mask.sum(-1).tolist()
    out = model(**batch, output_hidden_states=True, output_attentions=True)
    if out.attentions is None or out.hidden_states is None:
        raise RuntimeError("model did not return attentions/hidden_states; attention-selected bases need eager attentions")

    n_prompts = batch.input_ids.shape[0]
    width = max(lengths)
    hidden_layers, attn_layers = [], []
    for layer in range(n_layers):
        hidden = out.hidden_states[layer + 1].float().cpu()
        attn = out.attentions[layer].float().cpu()
        hidden_out = hidden.new_zeros(n_prompts, width, d_model)
        attn_out = hidden.new_zeros(n_prompts, width)
        for row, length in enumerate(lengths):
            hidden_out[row, -length:] = hidden[row, :length]
            attn_out[row, -length:] = attn[row, :, length - 1, :length].mean(0)
        hidden_layers.append(hidden_out)
        attn_layers.append(attn_out)
    return torch.stack(hidden_layers), torch.stack(attn_layers)
|
|
"\n",
|
|
"\n",
|
|
def left_pad_sequence_dim(x: torch.Tensor, target_len: int) -> torch.Tensor:
    """Zero-pad dimension 2 of ``x`` on the left until it has length ``target_len``.

    ``x`` is returned as-is if it already has that length.

    Raises:
        ValueError: if dimension 2 is already longer than ``target_len``.
    """
    missing = target_len - x.shape[2]
    if missing < 0:
        raise ValueError(f"cannot pad length {x.shape[2]} down to {target_len}")
    if missing == 0:
        return x
    filler = x.new_zeros((*x.shape[:2], missing, *x.shape[3:]))
    return torch.cat([filler, x], dim=2)
|
|
"\n",
|
|
"\n",
|
|
def attention_selected_taskdiff_bases(
    hs_pos_tokens: torch.Tensor,
    hs_neg_tokens: torch.Tensor,
    attn_pos: torch.Tensor,
    attn_neg: torch.Tensor,
) -> dict[str, list[torch.Tensor]]:
    """Per-layer bases from tokenwise persona differences, weighted by attention.

    Inputs are the outputs of ``capture_token_blocks_and_final_attn`` for the two
    personas: hidden states (n_layers, batch, T, d_model) and final-token
    attention (n_layers, batch, T), end-aligned with zero fill. The two sides may
    differ in T, so both are left-padded to a common length.

    Only slots where *both* sides hold a real token (non-zero hidden state) are
    used. At any other slot ``hs_pos - hs_neg`` is either a raw residual (one
    side is zero fill) or an all-zero row. Including those rows would feed raw
    activations into the ``max``/``diff`` attention weightings and shift the
    PCA centring mean in every weighting.

    Each weighting scales row ``i`` by ``sqrt(w_i)`` before ``pca``.

    Returns a dict mapping weighting name to a list of ``n_layers`` bases of
    shape (d_model, <= PCS).
    """
    target_len = max(hs_pos_tokens.shape[2], hs_neg_tokens.shape[2])
    hs_pos = left_pad_sequence_dim(hs_pos_tokens, target_len)
    hs_neg = left_pad_sequence_dim(hs_neg_tokens, target_len)
    a_pos = left_pad_sequence_dim(attn_pos[:, :, :, None], target_len).squeeze(-1)
    a_neg = left_pad_sequence_dim(attn_neg[:, :, :, None], target_len).squeeze(-1)
    # (n_layers, batch, T): True where both sequences have a real token.
    present = (hs_pos.norm(dim=-1) > 0) & (hs_neg.norm(dim=-1) > 0)
    diff = hs_pos - hs_neg
    diff_norm = diff.norm(dim=-1) * present
    # Mean tokenwise diff norm per layer over used slots, making the diffnorm weighting scale-free.
    norm_scale = diff_norm.sum(dim=(1, 2), keepdim=True) / (diff_norm.gt(0).sum(dim=(1, 2), keepdim=True) + 1e-12)
    weights = {
        "attn_min_taskdiff": torch.minimum(a_pos, a_neg),
        "attn_max_taskdiff": torch.maximum(a_pos, a_neg),
        "attn_diff_taskdiff": (a_pos - a_neg).abs(),
        "attn_min_x_diffnorm_taskdiff": torch.minimum(a_pos, a_neg) * diff_norm / (norm_scale + 1e-12),
    }
    # The used rows depend only on the layer, not on the weighting; select them once.
    layer_rows = [present[layer].reshape(-1) for layer in range(n_layers)]
    layer_samples = [diff[layer].reshape(-1, d_model)[rows] for layer, rows in enumerate(layer_rows)]
    bases = {}
    for name, weight in weights.items():
        layer_bases = []
        for layer in range(n_layers):
            w_flat = weight[layer].reshape(-1)[layer_rows[layer]]
            layer_bases.append(pca(layer_samples[layer] * torch.sqrt(w_flat[:, None] + 1e-12), PCS))
        bases[name] = layer_bases
    return bases
|
|
"\n",
|
|
"\n",
|
|
logger.info("capturing B-side label and A-side activations")
# B side (label): last-token block outputs with the LoRA diff applied at +1 vs -1,
# on raw prompt text (no chat template). EVAL is held out; FIT is used for fitting.
hs_pos_eval = capture_blocks(EVAL, alpha=+1.0)
hs_neg_eval = capture_blocks(EVAL, alpha=-1.0)
hs_diff_B = hs_pos_eval - hs_neg_eval
hs_pos_fit = capture_blocks(FIT, alpha=+1.0)
hs_neg_fit = capture_blocks(FIT, alpha=-1.0)
hs_diff_B_fit = hs_pos_fit - hs_neg_fit

# A side: unsteered model with the first positive vs negative persona as system prompt
# (chat-templated), plus a clean pass on raw prompt text.
# NOTE(review): hs_clean_fit / up_clean_fit use raw text (system=None skips the chat
# template) while the persona captures are chat-templated, so clean-vs-persona
# comparisons also differ in formatting - confirm that this is intended.
hs_persona_pos_fit = capture_blocks(FIT, system=SYCOPHANCY_POS_PERSONAS[0])
hs_persona_neg_fit = capture_blocks(FIT, system=SYCOPHANCY_NEG_PERSONAS[0])
hs_diff_A_fit = hs_persona_pos_fit - hs_persona_neg_fit
hs_clean_fit = capture_blocks(FIT)
# Inputs to mlp.up_proj (clean and per persona).
up_clean_fit = capture_up_inputs(FIT)
up_persona_pos_fit = capture_up_inputs(FIT, system=SYCOPHANCY_POS_PERSONAS[0])
up_persona_neg_fit = capture_up_inputs(FIT, system=SYCOPHANCY_NEG_PERSONAS[0])
up_diff_A_fit = up_persona_pos_fit - up_persona_neg_fit
# up_proj outputs mapped back into the residual stream through W_down.
up_written_pos_fit = capture_up_outputs_written(FIT, system=SYCOPHANCY_POS_PERSONAS[0])
up_written_neg_fit = capture_up_outputs_written(FIT, system=SYCOPHANCY_NEG_PERSONAS[0])
up_written_diff_A_fit = up_written_pos_fit - up_written_neg_fit
# Token-level hidden states + final-token attention for the attention-selected bases.
hs_pos_tokens_fit, attn_pos_fit = capture_token_blocks_and_final_attn(FIT, system=SYCOPHANCY_POS_PERSONAS[0])
hs_neg_tokens_fit, attn_neg_fit = capture_token_blocks_and_final_attn(FIT, system=SYCOPHANCY_NEG_PERSONAS[0])
attn_selected_taskdiff = attention_selected_taskdiff_bases(
    hs_pos_tokens_fit, hs_neg_tokens_fit, attn_pos_fit, attn_neg_fit
)
logger.info(f"captured label={tuple(hs_diff_B.shape)} | clean={tuple(hs_clean_fit.shape)} | up={tuple(up_clean_fit.shape)} | attn_tokens={tuple(hs_pos_tokens_fit.shape)}")
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "ec68247f",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Build A-side candidate bases"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "a446f592",
|
|
"metadata": {
|
|
"lines_to_next_cell": 2
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"def expand_rows_to(W_small: torch.Tensor, out_rows: int) -> torch.Tensor:\n",
|
|
" if W_small.shape[0] == out_rows:\n",
|
|
" return W_small\n",
|
|
" repeats = out_rows // W_small.shape[0]\n",
|
|
" if repeats * W_small.shape[0] != out_rows:\n",
|
|
" raise ValueError(f\"cannot repeat rows from {tuple(W_small.shape)} to {out_rows}\")\n",
|
|
" return W_small.repeat_interleave(repeats, dim=0)\n",
|
|
"\n",
|
|
"\n",
|
|
def write_cols(layer: int, kinds: tuple[str, ...] = ("self_attn.o_proj.weight", "mlp.down_proj.weight")) -> torch.Tensor:
    """Residual-writing weights of ``layer``, concatenated column-wise.

    Each weight listed in ``kinds`` that exists in ``state`` is cast to float32
    on CPU; missing keys are skipped. Returns (d_model, total_in_features), or
    an empty (d_model, 0) matrix if none were found.
    """
    found = (state.get(f"model.layers.{layer}.{proj}") for proj in kinds)
    cols = [W.float().cpu() for W in found if W is not None]
    return torch.cat(cols, dim=1) if cols else torch.zeros(d_model, 0)
|
|
"\n",
|
|
"\n",
|
|
def read_stack(layer: int, projs: tuple[str, ...]) -> torch.Tensor:
    """Row-wise stack of ``layer``'s weights named in ``projs`` (float32, CPU).

    Raises KeyError if any named weight is missing from ``state``.
    """
    mats = []
    for proj in projs:
        mats.append(state[f"model.layers.{layer}.{proj}"].float().cpu())
    return torch.cat(mats, dim=0)
|
|
"\n",
|
|
"\n",
|
|
def read_gram(layer: int) -> torch.Tensor:
    """Gram matrix ``W.T @ W`` of all residual readers (q, k, v, up, gate) in ``layer``.

    Its top eigenvectors are the input directions these projections read most
    strongly.
    """
    readers = (
        "self_attn.q_proj.weight",
        "self_attn.k_proj.weight",
        "self_attn.v_proj.weight",
        "mlp.up_proj.weight",
        "mlp.gate_proj.weight",
    )
    stacked = read_stack(layer, readers)
    return stacked.T @ stacked
|
|
"\n",
|
|
"\n",
|
|
def suppressed_features(acts: torch.Tensor) -> torch.Tensor:
    """Per-sample, per-feature magnitude "turnover" across layers.

    ``acts`` is (layers, samples, features). For each feature, the total
    layer-to-layer increase in ``|act|`` and the total decrease are computed,
    and their minimum is returned: large only for features that both grow and
    shrink. Output shape is (samples, features).
    """
    magnitude = acts.abs().permute(1, 0, 2)
    step = magnitude[:, 1:] - magnitude[:, :-1]
    gained = step.clamp(min=0).sum(1)
    lost = (-step).clamp(min=0).sum(1)
    return torch.minimum(gained, lost)
|
|
"\n",
|
|
"\n",
|
|
def amplified_features(acts: torch.Tensor) -> torch.Tensor:
    """Net growth of ``|act|`` from the first to the last layer, floored at 0.

    ``acts`` is (layers, samples, features); output is (samples, features).
    """
    magnitude = acts.abs().permute(1, 0, 2)
    net_growth = magnitude[:, -1] - magnitude[:, 0]
    return net_growth.clamp(min=0)
|
|
"\n",
|
|
"\n",
|
|
def added_features(acts: torch.Tensor) -> torch.Tensor:
    """Total positive layer-to-layer increase of ``|act|`` per sample and feature.

    ``acts`` is (layers, samples, features); output is (samples, features).
    """
    magnitude = acts.abs().permute(1, 0, 2)
    increases = (magnitude[:, 1:] - magnitude[:, :-1]).clamp(min=0)
    return increases.sum(1)
|
|
"\n",
|
|
"\n",
|
|
def procrustes_rotation_basis(X: torch.Tensor, Y: torch.Tensor, *, k: int = PCS, rank: int = 32) -> torch.Tensor:
    """Directions of the skew part of the Procrustes rotation taking X onto Y.

    ``X`` and ``Y`` are (samples, d). Both are reduced to a shared PCA subspace
    of at most ``rank`` dims. The orthogonal Procrustes rotation ``R = U Vh``
    is taken from the SVD of ``Xr.T @ Yr``, and the leading singular vectors of
    its skew part ``R - R.T`` are mapped back to d dims. At most ``k`` directions
    are returned; an empty basis is returned if the shared subspace has fewer
    than 2 dims.
    """
    n_joint = min(rank, X.shape[0] + Y.shape[0] - 2, X.shape[1])
    joint = pca(torch.cat([X, Y], dim=0), n_joint)
    if joint.shape[1] < 2:
        return torch.zeros(X.shape[1], 0)
    x_red = (X - X.mean(0, keepdim=True)) @ joint
    y_red = (Y - Y.mean(0, keepdim=True)) @ joint
    cross = torch.linalg.svd(x_red.T @ y_red, full_matrices=False)
    rotation = cross.U @ cross.Vh
    skew_dirs = torch.linalg.svd(rotation - rotation.T, full_matrices=False).U
    n_keep = min(k, skew_dirs.shape[1])
    return orthonormalize(joint @ skew_dirs[:, :n_keep])
|
|
"\n",
|
|
"\n",
|
|
def kmeans_centroid_basis(samples: torch.Tensor, *, k_clusters: int = PCS, iters: int = 8) -> torch.Tensor:
    """CHaRS-style basis: PCA over k-means centroids of ``samples``.

    ``samples`` is (n, d). Rows are mean-centred. k-means is seeded
    deterministically with the ``k_clusters`` largest-norm rows and run for
    ``iters`` Lloyd steps; an empty cluster keeps its previous centroid.

    The basis is the PCA of the centroids, capped at ``n_centroids - 1``
    directions. Centred centroids span at most that many dimensions, so asking
    ``pca`` for ``PCS`` directions from ``PCS`` centroids can return a
    zero-variance (arbitrary) direction as the last column.
    """
    data = samples.float().cpu()
    centered = data - data.mean(0, keepdim=True)
    order = torch.argsort(centered.norm(dim=1), descending=True)
    centroids = centered[order[: min(k_clusters, centered.shape[0])]].clone()
    for _ in range(iters):
        assign = torch.cdist(centered, centroids).argmin(dim=1)
        new_centroids = []
        for idx in range(centroids.shape[0]):
            members = centered[assign == idx]
            new_centroids.append(members.mean(0) if members.shape[0] else centroids[idx])
        centroids = torch.stack(new_centroids)
    n_dirs = max(0, min(PCS, centroids.shape[0] - 1))
    return pca(centroids - centroids.mean(0, keepdim=True), n_dirs)
|
|
"\n",
|
|
"\n",
|
|
# Unembedding read directions: right singular vectors of lm_head (only Vh is used;
# U is vocab x d_model and discarded).
_u_lm, _s_lm, vh_lm = torch.linalg.svd(lm_head_W, full_matrices=False)
lm_head_read = vh_lm[:PCS].T.contiguous()
# Smallest-singular-value directions: what the unembedding reads least.
logits_null = vh_lm[-PCS:].T.contiguous()
lm_read_broad = vh_lm[:K_BROAD].T.contiguous()

# Global read subspace: sum of every layer's q/k/v/up/gate Gram plus lm_head's Gram.
read_grams = [read_gram(layer) for layer in range(n_layers)]
global_read_gram = sum(read_grams, torch.zeros(d_model, d_model)) + lm_head_W.T @ lm_head_W
global_read = basis_from_gram(global_read_gram, PCS)
global_read_broad = basis_from_gram(global_read_gram, K_BROAD)
# Global write subspace: all layers' o_proj / down_proj columns side by side.
global_write_cols = torch.cat([write_cols(layer) for layer in range(n_layers)], dim=1)
global_write = left_svd_basis(global_write_cols)

# downstream_read_broad[l]: top-K_BROAD eigenspace of lm_head's read Gram plus the
# read Grams of layers l+1 .. n_layers-1, built as a running sum from the last layer.
downstream_read_broad = []
running = lm_head_W.T @ lm_head_W
for layer in reversed(range(n_layers)):
    if layer < n_layers - 1:
        running = running + read_grams[layer + 1]
    downstream_read_broad.append(basis_from_gram(running, K_BROAD))
downstream_read_broad = list(reversed(downstream_read_broad))

eye = torch.eye(d_model)
# Orthogonal projectors onto the broad read subspaces (their bases come from SVD / eigh).
P_lm = lm_read_broad @ lm_read_broad.T
P_global_read = global_read_broad @ global_read_broad.T

# Registry filled by add(...) below.
candidate_list: list[Candidate] = []
|
|
"\n",
|
|
"\n",
|
|
def add(name: str, family: str, basis_by_layer: list[torch.Tensor], definition: str, source: str = "v5") -> None:
    """Validate a per-layer basis and register it as a Candidate in ``candidate_list``.

    Checks that there is one basis per layer, that each has ``d_model`` rows, and
    that non-empty bases are orthonormal to within 1e-3 (max abs entry of
    ``B.T @ B - I``).

    Raises:
        ValueError: on a wrong layer count, wrong row count, or non-orthonormal basis.
    """
    n_given = len(basis_by_layer)
    if n_given != n_layers:
        raise ValueError(f"{name} has {n_given} layers, expected {n_layers}")
    for layer, basis in enumerate(basis_by_layer):
        rows, width = basis.shape[0], basis.shape[1]
        if rows != d_model:
            raise ValueError(f"{name}[{layer}] shape={tuple(basis.shape)}, expected first dim {d_model}")
        if width == 0:
            continue
        err = (basis.T @ basis - torch.eye(width)).abs().max().item()
        if err > 1e-3:
            raise ValueError(f"{name}[{layer}] is not orthonormal: maxerr={err}")
    candidate_list.append(Candidate(name, family, basis_by_layer, source, definition))
|
|
"\n",
|
|
"\n",
|
|
# --- Layer-independent weight bases (same basis repeated for every layer). ---
add("lm_head_read", "W:unembed", [lm_head_read] * n_layers, "top right singular vectors of lm_head")
add("logits_null", "W:unembed", [logits_null] * n_layers, "bottom right singular vectors of lm_head")
add("global_read", "W:read", [global_read] * n_layers, "top eigenspace of all q/k/v/up/gate reads + lm_head")
add("global_write", "W:write", [global_write] * n_layers, "top left singular vectors of all o/down residual writers")
add("global_write_not_global_read", "W:write-not-read", [left_svd_basis((eye - P_global_read) @ global_write_cols)] * n_layers, "global residual write projected away from global read directions")

# --- Per-layer write bases, optionally with read subspaces projected out. ---
write = [left_svd_basis(write_cols(layer)) for layer in range(n_layers)]
attn_write = [left_svd_basis(write_cols(layer, ("self_attn.o_proj.weight",))) for layer in range(n_layers)]
mlp_write = [left_svd_basis(write_cols(layer, ("mlp.down_proj.weight",))) for layer in range(n_layers)]
write_not_lm = [left_svd_basis((eye - P_lm) @ write_cols(layer)) for layer in range(n_layers)]
write_not_global_read = [left_svd_basis((eye - P_global_read) @ write_cols(layer)) for layer in range(n_layers)]
write_not_downstream_read = [
    left_svd_basis((eye - downstream_read_broad[layer] @ downstream_read_broad[layer].T) @ write_cols(layer))
    for layer in range(n_layers)
]
add("write", "W:write", write, "per-layer top left singular vectors of [W_o | W_down]")
add("attn_write", "W:write", attn_write, "per-layer top left singular vectors of W_o")
add("mlp_write", "W:write", mlp_write, "per-layer top left singular vectors of W_down")
add("write_not_lm_head_read", "W:write-not-read", write_not_lm, "per-layer write projected away from lm_head top read")
add("write_not_global_read", "W:write-not-read", write_not_global_read, "per-layer write projected away from global read")
add("write_not_downstream_read", "W:write-not-read", write_not_downstream_read, "per-layer write projected away from downstream read + lm_head")

# --- Per-layer circuit / activation bases, accumulated in one pass over layers. ---
mlp_up_read = []
mlp_gate_read = []
attn_qkv_read = []
attn_ov_write = []
mlp_roundtrip = []
qk_circuit = []
input_super = []
kv_super = []
gate_kernel = []
attention_sink = []
causally_isolated = []
input_super_not_lm = []
gate_active_written = []
chars_clusters = []
for layer in range(n_layers):
    up = state[f"model.layers.{layer}.mlp.up_proj.weight"].float().cpu()
    gate = state[f"model.layers.{layer}.mlp.gate_proj.weight"].float().cpu()
    q = state[f"model.layers.{layer}.self_attn.q_proj.weight"].float().cpu()
    k = state[f"model.layers.{layer}.self_attn.k_proj.weight"].float().cpu()
    v = state[f"model.layers.{layer}.self_attn.v_proj.weight"].float().cpu()
    W_o = state[f"model.layers.{layer}.self_attn.o_proj.weight"].float().cpu()
    W_down = state[f"model.layers.{layer}.mlp.down_proj.weight"].float().cpu()

    # GQA: bring K/V up to the per-query-head row layout used by W_q / W_o.
    # NOTE(review): correct pairing needs whole head_dim blocks repeated (query head h
    # uses kv head h * n_kv_heads // n_heads, as in the sink loop below); confirm
    # expand_rows_to repeats blocks rather than individual rows.
    k_for_q = expand_rows_to(k, q.shape[0])
    v_for_o = expand_rows_to(v, W_o.shape[1])
    # Clean last-token inputs to this layer's MLP; mean SiLU gate and gated activation.
    clean_up_x = up_clean_fit[layer]
    mean_gate = F.silu(clean_up_x @ gate.T).mean(0)
    gate_active = F.silu(clean_up_x @ gate.T) * (clean_up_x @ up.T)

    # Attention-sink vectors: each head's OV image of the BOS (or EOS) embedding.
    # n_heads, n_kv_heads, head_dim, bos_id and e_bos do not depend on `layer`.
    n_heads = model.config.num_attention_heads
    n_kv_heads = model.config.num_key_value_heads
    head_dim = W_o.shape[1] // n_heads
    bos_id = tok.bos_token_id if tok.bos_token_id is not None else tok.eos_token_id
    e_bos = state["model.embed_tokens.weight"][bos_id].float().cpu()
    sink_vecs = []
    for head in range(n_heads):
        kv_head = head * n_kv_heads // n_heads
        o_h = W_o[:, head * head_dim : (head + 1) * head_dim]
        v_h = v[kv_head * head_dim : (kv_head + 1) * head_dim]
        sink_vecs.append(o_h @ (v_h @ e_bos))

    mlp_up_read.append(right_svd_basis(up))
    mlp_gate_read.append(right_svd_basis(gate))
    attn_qkv_read.append(right_svd_basis(torch.cat([q, k, v], dim=0)))
    attn_ov_write.append(left_svd_basis(W_o @ v_for_o))
    mlp_roundtrip.append(left_svd_basis(W_down @ up))
    qk_circuit.append(left_svd_basis(q.T @ k_for_q))
    input_super.append(right_svd_basis(torch.cat([q, k, v, up, gate], dim=0)))
    kv_super.append(right_svd_basis(torch.cat([k, v], dim=0)))
    gate_kernel.append(left_svd_basis(W_down @ (mean_gate[:, None] * up)))
    attention_sink.append(pca(torch.stack(sink_vecs), PCS))
    # Everything this layer reads (input_super, kv_super) plus lm_head's broad read.
    forbidden = orthonormal_union(input_super[-1], kv_super[-1], lm_read_broad)
    causally_isolated.append(project_write_away(write_cols(layer), forbidden))
    input_super_not_lm.append(project_away(input_super[-1], lm_read_broad)[:, :PCS])
    gate_active_written.append(pca(gate_active @ W_down.T, PCS))
    # NOTE(review): hs_clean_fit comes from raw prompt text (no chat template) while the
    # persona activations are chat-templated, so k-means may partly separate the two
    # formats rather than personas - confirm this is intended.
    chars_samples = torch.cat([hs_clean_fit[layer], hs_persona_pos_fit[layer], hs_persona_neg_fit[layer]], dim=0)
    chars_clusters.append(kmeans_centroid_basis(chars_samples))

add("mlp_up_read", "W:read", mlp_up_read, "right singular vectors of W_up")
add("mlp_gate_read", "W:read", mlp_gate_read, "right singular vectors of W_gate")
add("attn_qkv_read", "W:read", attn_qkv_read, "right singular vectors of concatenated W_q/W_k/W_v")
add("attn_ov_write", "W:OV", attn_ov_write, "left singular vectors of W_o W_v")
add("mlp_roundtrip_write", "W:MLP", mlp_roundtrip, "left singular vectors of W_down W_up residual-to-residual map")
add("qk_circuit", "W:QK", qk_circuit, "left singular vectors of W_q^T W_k after GQA row expansion", source="external-v6-plan")
add("input_super", "W:read", input_super, "right singular vectors of [W_q; W_k; W_v; W_up; W_gate]", source="external-v6-plan")
add("kv_super", "W:read", kv_super, "right singular vectors of [W_k; W_v]", source="external-v6-plan")
add("gate_kernel", "W:MLP", gate_kernel, "left singular vectors of W_down diag(E silu(W_gate h)) W_up", source="external-v6-plan")
add("attention_sink", "W:OV", attention_sink, "PCA over per-head W_o^h W_v^h e_BOS sink vectors", source="external-v6-plan")
add("causally_isolated", "W:write-not-read", causally_isolated, "write subspace projected away from input-read, KV, and lm_head read bases", source="external-v6-plan")
add("input_super_not_lm_read", "W:read", input_super_not_lm, "input_super projected away from lm_head top read directions", source="external-v6-plan")

# --- Activation-derived bases from the clean / persona captures. ---
suppressed = pca(suppressed_features(hs_clean_fit), PCS)
amplified = pca(amplified_features(hs_clean_fit), PCS)
added = pca(added_features(hs_clean_fit), PCS)
global_clean_pca = pca(hs_clean_fit.permute(1, 0, 2).reshape(-1, d_model), PCS)
global_persona_pca = pca(
    torch.cat([
        hs_persona_pos_fit.permute(1, 0, 2).reshape(-1, d_model),
        hs_persona_neg_fit.permute(1, 0, 2).reshape(-1, d_model),
    ]),
    PCS,
)
add("suppressed", "act:clean", [suppressed] * n_layers, "PCA of base-model magnitude turnover across layers")
add("amplified", "act:clean", [amplified] * n_layers, "PCA of base-model magnitudes that persist from first to last layer")
add("added_features", "act:clean", [added] * n_layers, "PCA of positive layer-to-layer magnitude additions", source="external-v6-plan")
add("global_clean_resid_pca", "act:baseline", [global_clean_pca] * n_layers, "PCA of all clean base residual activations")
add("global_persona_resid_pca", "act:baseline", [global_persona_pca] * n_layers, "PCA of persona residual activations without differencing")
add("layer_clean_resid_pca", "act:baseline", [pca(hs_clean_fit[layer], PCS) for layer in range(n_layers)], "per-layer PCA of clean base residual activations")
add("TaskDiff_contrast", "act:persona", [pca(hs_diff_A_fit[layer], PCS) for layer in range(n_layers)], "PCA of persona+ minus persona- residual activations")
add("attn_min_taskdiff", "act:attn-selected", attn_selected_taskdiff["attn_min_taskdiff"], "PCA of tokenwise persona TaskDiff weighted by min(pos, neg) final-token attention", source="external-v6-plan")
add("attn_max_taskdiff", "act:attn-selected", attn_selected_taskdiff["attn_max_taskdiff"], "PCA of tokenwise persona TaskDiff weighted by max(pos, neg) final-token attention", source="external-v6-plan")
add("attn_diff_taskdiff", "act:attn-selected", attn_selected_taskdiff["attn_diff_taskdiff"], "PCA of tokenwise persona TaskDiff weighted by abs(pos - neg) final-token attention", source="external-v6-plan")
add("attn_min_x_diffnorm_taskdiff", "act:attn-selected", attn_selected_taskdiff["attn_min_x_diffnorm_taskdiff"], "PCA of tokenwise persona TaskDiff weighted by min(pos, neg) attention times tokenwise diff norm", source="external-v6-plan")
add("up_proj_input_contrast", "act:up_proj", [pca(up_diff_A_fit[layer], PCS) for layer in range(n_layers)], "PCA of persona contrast in inputs to mlp.up_proj")
add("up_proj_output_written_contrast", "act:up_proj", [pca(up_written_diff_A_fit[layer], PCS) for layer in range(n_layers)], "PCA of persona contrast after W_up mapped back by W_down")
add("gate_active_written", "act:MLP", gate_active_written, "PCA of silu(W_gate h) * W_up h mapped back by W_down on clean probes", source="external-v6-plan")
add("chars_clusters", "act:cluster", chars_clusters, "CHaRS-style PCA of k-means centroid differences over clean/persona activations", source="external-v6-plan")
# NOTE(review): for the last layer min(layer + 1, n_layers - 1) == layer, so the "change"
# is identically zero and its PCA has no variance to explain (any returned directions are
# arbitrary) - confirm how the last layer should be handled.
add("churn", "act:clean", [pca(hs_clean_fit[min(layer + 1, n_layers - 1)] - hs_clean_fit[layer], PCS) for layer in range(n_layers)], "PCA of signed clean residual change h_{l+1}-h_l")
add("rotation_contrast", "act:rotation", [procrustes_rotation_basis(hs_persona_neg_fit[layer], hs_persona_pos_fit[layer]) for layer in range(n_layers)], "skew generator from persona- to persona+ Procrustes rotation")
# --- Compound bases built from the per-layer lists above. ---
add("qk_x_chars_clusters", "compound", [intersect_basis(qk_circuit[layer], chars_clusters[layer]) for layer in range(n_layers)], "bisector intersection of qk_circuit and CHaRS-style activation clusters", source="external-v6-plan")
add("WNR_union_TaskDiff", "compound", [orthonormal_union(write_not_downstream_read[layer], pca(hs_diff_A_fit[layer], PCS)) for layer in range(n_layers)], "rank-expanded union of write_not_downstream_read and TaskDiff_contrast")
|
|
"\n",
|
|
"ceiling = Candidate(\n",
|
|
" \"TaskDiff_lora_ceiling\",\n",
|
|
" \"ceiling\",\n",
|
|
" [pca(hs_diff_B_fit[layer], PCS) for layer in range(n_layers)],\n",
|
|
" \"B-side\",\n",
|
|
" \"PCA of LoRA FIT-half label; not an A-side hypothesis\",\n",
|
|
")\n",
|
|
"\n",
|
|
"logger.info(f\"built {len(candidate_list)} A-side candidates + ceiling\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "17a2f5e0",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Activation and weight scoring"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "5b8e3eba",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"_W_TENSOR_NAMES = (\"self_attn.o_proj.weight\", \"mlp.down_proj.weight\")\n",
|
|
"_dropped_keys_logged = False\n",
|
|
"\n",
|
|
"\n",
|
|
"def lora_weight_tensors(layer: int) -> dict[str, torch.Tensor]:\n",
|
|
" \"\"\"Per-tensor LoRA delta in residual-output (d_model row) space.\n",
|
|
"\n",
|
|
" v6 returned a single concatenated matrix; v7 keeps tensors separate so R_w\n",
|
|
" isn't silently Frobenius-weighted toward whichever tensor has more\n",
|
|
" parameters (down_proj has ~3x o_proj). Logs which residual-output keys\n",
|
|
" were skipped, only for the first layer that drops any (for debugging if Qwen renames projections).\n",
|
|
" \"\"\"\n",
|
|
" global _dropped_keys_logged\n",
|
|
" out: dict[str, torch.Tensor] = {}\n",
|
|
" dropped = []\n",
|
|
" for proj in _W_TENSOR_NAMES:\n",
|
|
" key = f\"model.layers.{layer}.{proj}\"\n",
|
|
" if key not in w:\n",
|
|
" dropped.append((key, \"missing-from-LoRA\"))\n",
|
|
" continue\n",
|
|
" W = w[key].float().cpu()\n",
|
|
" if W.shape[0] != d_model:\n",
|
|
" dropped.append((key, f\"shape={tuple(W.shape)} d_model={d_model}\"))\n",
|
|
" continue\n",
|
|
" out[proj] = W\n",
|
|
" if dropped and not _dropped_keys_logged:\n",
|
|
" logger.info(f\"lora_weight_tensors layer={layer} dropped: {dropped}\")\n",
|
|
" _dropped_keys_logged = True\n",
|
|
" return out\n",
|
|
"\n",
|
|
"\n",
|
|
"def lora_weight_matrix(layer: int) -> torch.Tensor:\n",
|
|
" \"\"\"v6-compatible concatenated form, used by dw_left_basis and by w_null_stats(tensor_name=None).\"\"\"\n",
|
|
" tensors = lora_weight_tensors(layer)\n",
|
|
" if not tensors:\n",
|
|
" return torch.zeros(d_model, 0)\n",
|
|
" return torch.cat(list(tensors.values()), dim=1)\n",
|
|
"\n",
|
|
"\n",
|
|
"act_null_cache: dict[tuple[int, int], tuple[float, float]] = {}\n",
|
|
"w_null_cache: dict[tuple[int, int, str | None], tuple[float, float]] = {}\n",
|
|
"\n",
|
|
"\n",
|
|
"def act_null_stats(layer: int, rank: int) -> tuple[float, float]:\n",
|
|
" key = (layer, rank)\n",
|
|
" if key in act_null_cache:\n",
|
|
" return act_null_cache[key]\n",
|
|
" samples = hs_diff_B[layer]\n",
|
|
" d = samples.shape[1]\n",
|
|
" total = samples.pow(2).sum(1) + 1e-12\n",
|
|
" null = rank / d\n",
|
|
" gen = torch.Generator(device=samples.device).manual_seed(10_000 + 97 * layer + rank)\n",
|
|
" values = []\n",
|
|
" for _ in range(N_NULL):\n",
|
|
" rb, _ = torch.linalg.qr(torch.randn(d, rank, generator=gen, device=samples.device, dtype=samples.dtype))\n",
|
|
" values.append(((samples @ rb).pow(2).sum(1) / total).mean().item() / null)\n",
|
|
" arr = torch.tensor(values)\n",
|
|
" stats = (float(arr.mean()), float(arr.std(unbiased=True)))\n",
|
|
" act_null_cache[key] = stats\n",
|
|
" return stats\n",
|
|
"\n",
|
|
"\n",
|
|
"def w_null_stats(layer: int, rank: int, tensor_name: str | None = None) -> tuple[float, float]:\n",
|
|
" \"\"\"Random-orthonormal null for the weight concentration ratio.\n",
|
|
"\n",
|
|
" If tensor_name is None, uses the v6-style concatenated matrix (kept for\n",
|
|
" backward-compat with diagnostics). Otherwise scores against a single LoRA\n",
|
|
" tensor (o_proj or down_proj) so per-tensor R_w can be properly normalized.\n",
|
|
" \"\"\"\n",
|
|
" key = (layer, rank, tensor_name)\n",
|
|
" if key in w_null_cache:\n",
|
|
" return w_null_cache[key]\n",
|
|
" if tensor_name is None:\n",
|
|
" M = lora_weight_matrix(layer)\n",
|
|
" else:\n",
|
|
" tensors = lora_weight_tensors(layer)\n",
|
|
" M = tensors.get(tensor_name, torch.zeros(d_model, 0))\n",
|
|
" if M.shape[1] == 0:\n",
|
|
" stats = (float(\"nan\"), float(\"nan\"))\n",
|
|
" w_null_cache[key] = stats\n",
|
|
" return stats\n",
|
|
" d = M.shape[0]\n",
|
|
" total = M.pow(2).sum() + 1e-12\n",
|
|
" null = rank / d\n",
|
|
" seed_bump = 0 if tensor_name is None else (1 + hash(tensor_name) % 1000)\n",
|
|
" gen = torch.Generator(device=M.device).manual_seed(20_000 + 97 * layer + rank + 7919 * seed_bump)\n",
|
|
" values = []\n",
|
|
" for _ in range(N_NULL):\n",
|
|
" rb, _ = torch.linalg.qr(torch.randn(d, rank, generator=gen, device=M.device, dtype=M.dtype))\n",
|
|
" values.append(((rb.T @ M).pow(2).sum() / total).item() / null)\n",
|
|
" arr = torch.tensor(values)\n",
|
|
" stats = (float(arr.mean()), float(arr.std(unbiased=True)))\n",
|
|
" w_null_cache[key] = stats\n",
|
|
" return stats\n",
|
|
"\n",
|
|
"\n",
|
|
"def concentration_act(layer: int, basis: torch.Tensor) -> dict[str, float]:\n",
|
|
" samples = hs_diff_B[layer]\n",
|
|
" rank = basis.shape[1]\n",
|
|
" if rank == 0:\n",
|
|
" return {\"conc_act\": 0.0, \"z_act\": 0.0, \"energy_frac_act\": 0.0}\n",
|
|
" total = samples.pow(2).sum(1) + 1e-12\n",
|
|
" energy_frac = ((samples @ basis).pow(2).sum(1) / total).mean().item()\n",
|
|
" conc = energy_frac / (rank / samples.shape[1])\n",
|
|
" null_mean, null_std = act_null_stats(layer, rank)\n",
|
|
" return {\"conc_act\": conc, \"z_act\": (conc - null_mean) / (null_std + 1e-12), \"energy_frac_act\": energy_frac}\n",
|
|
"\n",
|
|
"\n",
|
|
"def concentration_w(layer: int, basis: torch.Tensor) -> dict[str, float]:\n",
|
|
" \"\"\"Per-tensor weight concentration + Frobenius-balanced combined.\n",
|
|
"\n",
|
|
" v6 returned a single conc_w that silently weighted by tensor size\n",
|
|
" (down_proj has ~3x the params of o_proj). v7 reports each tensor\n",
|
|
" separately so write-side hypotheses can be ranked by either, and a\n",
|
|
" 'combined' score that normalizes each tensor to unit Frobenius first\n",
|
|
" (size-balanced).\n",
|
|
" \"\"\"\n",
|
|
" rank = basis.shape[1]\n",
|
|
" tensors = lora_weight_tensors(layer)\n",
|
|
" out: dict[str, float] = {}\n",
|
|
" if rank == 0 or not tensors:\n",
|
|
" for name in (\"oproj\", \"downproj\", \"combined\"):\n",
|
|
" out[f\"conc_w_{name}\"] = float(\"nan\")\n",
|
|
" out[f\"z_w_{name}\"] = float(\"nan\")\n",
|
|
" out[f\"energy_frac_w_{name}\"] = float(\"nan\")\n",
|
|
" return out\n",
|
|
"\n",
|
|
" # Per-tensor scores\n",
|
|
" name_to_key = {\"oproj\": \"self_attn.o_proj.weight\", \"downproj\": \"mlp.down_proj.weight\"}\n",
|
|
" balanced_M_cols = []\n",
|
|
" for short, key in name_to_key.items():\n",
|
|
" M = tensors.get(key)\n",
|
|
" if M is None:\n",
|
|
" out[f\"conc_w_{short}\"] = float(\"nan\")\n",
|
|
" out[f\"z_w_{short}\"] = float(\"nan\")\n",
|
|
" out[f\"energy_frac_w_{short}\"] = float(\"nan\")\n",
|
|
" continue\n",
|
|
" total = M.pow(2).sum() + 1e-12\n",
|
|
" energy_frac = ((basis.T @ M).pow(2).sum() / total).item()\n",
|
|
" conc = energy_frac / (rank / M.shape[0])\n",
|
|
" null_mean, null_std = w_null_stats(layer, rank, key)\n",
|
|
" out[f\"conc_w_{short}\"] = conc\n",
|
|
" out[f\"z_w_{short}\"] = (conc - null_mean) / (null_std + 1e-12)\n",
|
|
" out[f\"energy_frac_w_{short}\"] = energy_frac\n",
|
|
" # Frobenius-balanced combined: each tensor normalized to unit Frobenius\n",
|
|
" balanced_M_cols.append(M / (M.pow(2).sum().sqrt() + 1e-12))\n",
|
|
"\n",
|
|
" # Combined: balanced concat (each tensor unit-Frobenius), then standard score\n",
|
|
" if balanced_M_cols:\n",
|
|
" M_bal = torch.cat(balanced_M_cols, dim=1)\n",
|
|
" total_bal = M_bal.pow(2).sum() + 1e-12\n",
|
|
" energy_frac_bal = ((basis.T @ M_bal).pow(2).sum() / total_bal).item()\n",
|
|
" conc_bal = energy_frac_bal / (rank / M_bal.shape[0])\n",
|
|
" # Null for balanced combined: rebuild on the fly (cheap, cached by key)\n",
|
|
" bal_key = (layer, rank, \"_balanced\")\n",
|
|
" if bal_key not in w_null_cache:\n",
|
|
" d = M_bal.shape[0]\n",
|
|
" null = rank / d\n",
|
|
" gen = torch.Generator(device=M_bal.device).manual_seed(30_000 + 97 * layer + rank)\n",
|
|
" values = []\n",
|
|
" for _ in range(N_NULL):\n",
|
|
" rb, _ = torch.linalg.qr(torch.randn(d, rank, generator=gen, device=M_bal.device, dtype=M_bal.dtype))\n",
|
|
" values.append(((rb.T @ M_bal).pow(2).sum() / total_bal).item() / null)\n",
|
|
" arr = torch.tensor(values)\n",
|
|
" w_null_cache[bal_key] = (float(arr.mean()), float(arr.std(unbiased=True)))\n",
|
|
" null_mean, null_std = w_null_cache[bal_key]\n",
|
|
" out[\"conc_w_combined\"] = conc_bal\n",
|
|
" out[\"z_w_combined\"] = (conc_bal - null_mean) / (null_std + 1e-12)\n",
|
|
" out[\"energy_frac_w_combined\"] = energy_frac_bal\n",
|
|
" else:\n",
|
|
" out[\"conc_w_combined\"] = float(\"nan\")\n",
|
|
" out[\"z_w_combined\"] = float(\"nan\")\n",
|
|
" out[\"energy_frac_w_combined\"] = float(\"nan\")\n",
|
|
" return out\n",
|
|
"\n",
|
|
"\n",
|
|
"def dw_left_basis(layer: int) -> torch.Tensor:\n",
|
|
" return left_svd_basis(lora_weight_matrix(layer))\n",
|
|
"\n",
|
|
"\n",
|
|
"def axis_kind_for(family: str) -> str:\n",
|
|
" \"\"\"Tag whether a hypothesis is read-side, write-side, or mixed in d_model.\n",
|
|
"\n",
|
|
" Read-side bases (input projections) trivially live in d_model just like the\n",
|
|
" write-side LoRA delta does, so R_w runs without error. But high R_w for a\n",
|
|
" read-side basis means \\\"this read direction happens to coincide with the\n",
|
|
" LoRA write direction\\\", not \\\"this primitive captures the write geometry\\\".\n",
|
|
" Read-side rows are reported separately and excluded from the joint W-axis\n",
|
|
" ranking. See docs/review/v6_hypothesis_review.md concern #3.\n",
|
|
" \"\"\"\n",
|
|
" if family == \"ceiling\":\n",
|
|
" return \"ceiling\"\n",
|
|
" if family in (\"W:read\", \"W:unembed\"):\n",
|
|
" return \"read\"\n",
|
|
" if family in (\"W:write\", \"W:write-not-read\", \"W:OV\", \"W:MLP\"):\n",
|
|
" return \"write\"\n",
|
|
" if family.startswith(\"act:\") or family in (\"W:QK\", \"compound\"):\n",
|
|
" return \"mixed\"\n",
|
|
" return \"mixed\"\n",
|
|
"\n",
|
|
"\n",
|
|
"# Build the true weight ceiling: top-PCS left singular vectors of the LoRA\n",
|
|
"# delta itself, per layer. This is the natural R_w oracle: scoring it gives\n",
|
|
"# R_w / R_w_ceiling = 1.0 for this row by definition, with raw R_w expected well above the random null of 1.0.\n",
|
|
"weight_ceiling = Candidate(\n",
|
|
" \"dW_left_basis_ceiling\",\n",
|
|
" \"ceiling\",\n",
|
|
" [dw_left_basis(layer) for layer in range(n_layers)],\n",
|
|
" \"B-side\",\n",
|
|
" \"Top-PCS left singular vectors of the LoRA residual-output delta itself; the R_w oracle (pct_w_oracle = 100 by construction; raw R_w far above 1.0)\",\n",
|
|
")\n",
|
|
"\n",
|
|
"\n",
|
|
"all_candidates = [*candidate_list, ceiling, weight_ceiling]\n",
|
|
"dw_bases = [dw_left_basis(layer) for layer in range(n_layers)]\n",
|
|
"rows = []\n",
|
|
"for layer in range(n_layers):\n",
|
|
" for candidate in all_candidates:\n",
|
|
" basis = candidate.basis_by_layer[layer]\n",
|
|
" rows.append({\n",
|
|
" \"layer\": layer,\n",
|
|
" \"subspace\": candidate.name,\n",
|
|
" \"family\": candidate.family,\n",
|
|
" \"axis_kind\": axis_kind_for(candidate.family),\n",
|
|
" \"source\": candidate.source,\n",
|
|
" \"kind\": \"ceiling\" if candidate.family == \"ceiling\" else \"A-hypothesis\",\n",
|
|
" \"rank\": basis.shape[1],\n",
|
|
" **concentration_act(layer, basis),\n",
|
|
" **concentration_w(layer, basis),\n",
|
|
" \"cos_with_dW\": principal_cos(basis, dw_bases[layer]),\n",
|
|
" })\n",
|
|
"\n",
|
|
"per_layer = pl.DataFrame(rows)\n",
|
|
"per_layer_path = OUT_DIR / \"v7_per_layer.csv\"\n",
|
|
"per_layer.write_csv(per_layer_path)\n",
|
|
"\n",
|
|
"active = per_layer.filter(pl.col(\"layer\").is_in(list(LORA_LAYERS)))\n",
|
|
"summary = (\n",
|
|
" active.group_by([\"subspace\", \"family\", \"axis_kind\", \"source\", \"kind\"])\n",
|
|
" .agg(\n",
|
|
" pl.col(\"conc_act\").mean().alias(\"mean_conc_act\"),\n",
|
|
" pl.col(\"z_act\").mean().alias(\"mean_z_act\"),\n",
|
|
" pl.col(\"energy_frac_act\").mean().alias(\"mean_energy_frac_act\"),\n",
|
|
" pl.col(\"conc_w_oproj\").mean().alias(\"mean_conc_w_oproj\"),\n",
|
|
" pl.col(\"conc_w_downproj\").mean().alias(\"mean_conc_w_downproj\"),\n",
|
|
" pl.col(\"conc_w_combined\").mean().alias(\"mean_conc_w_combined\"),\n",
|
|
" pl.col(\"z_w_oproj\").mean().alias(\"mean_z_w_oproj\"),\n",
|
|
" pl.col(\"z_w_downproj\").mean().alias(\"mean_z_w_downproj\"),\n",
|
|
" pl.col(\"z_w_combined\").mean().alias(\"mean_z_w_combined\"),\n",
|
|
" pl.col(\"cos_with_dW\").mean().alias(\"mean_cos_dW\"),\n",
|
|
" pl.col(\"rank\").mean().alias(\"mean_rank\"),\n",
|
|
" )\n",
|
|
" .with_columns(\n",
|
|
" # Joint score uses the size-balanced combined R_w to be fair across hypotheses\n",
|
|
" joint_score=((pl.col(\"mean_conc_act\").log() + pl.col(\"mean_conc_w_combined\").log()) / 2).exp(),\n",
|
|
" act_w_gap_log2=(pl.col(\"mean_conc_act\").log(2) - pl.col(\"mean_conc_w_combined\").log(2)),\n",
|
|
" )\n",
|
|
" .sort(\"joint_score\", descending=True)\n",
|
|
")\n",
|
|
"\n",
|
|
"summary_path = OUT_DIR / \"v7_summary.tsv\"\n",
|
|
"summary.write_csv(summary_path, separator=\"\\t\")\n",
|
|
"\n",
|
|
"ceiling_act = float(summary.filter(pl.col(\"subspace\") == \"TaskDiff_lora_ceiling\")[\"mean_conc_act\"][0])\n",
|
|
"# True weight ceiling: dW_left_basis_ceiling. pct_w_oracle_* is 100 for this row by construction; its raw R_w sits far above the null of 1.0\n",
|
|
"# (the basis IS the top singular subspace of the weight diff).\n",
|
|
"weight_ceiling_combined = float(\n",
|
|
" summary.filter(pl.col(\"subspace\") == \"dW_left_basis_ceiling\")[\"mean_conc_w_combined\"][0]\n",
|
|
")\n",
|
|
"weight_ceiling_oproj = float(\n",
|
|
" summary.filter(pl.col(\"subspace\") == \"dW_left_basis_ceiling\")[\"mean_conc_w_oproj\"][0]\n",
|
|
")\n",
|
|
"weight_ceiling_downproj = float(\n",
|
|
" summary.filter(pl.col(\"subspace\") == \"dW_left_basis_ceiling\")[\"mean_conc_w_downproj\"][0]\n",
|
|
")\n",
|
|
"logger.info(\n",
|
|
" f\"weight ceiling (dW_left_basis): combined={weight_ceiling_combined:.3f} \"\n",
|
|
" f\"oproj={weight_ceiling_oproj:.3f} downproj={weight_ceiling_downproj:.3f} \"\n",
|
|
" \"SHOULD: all > 1.0 (basis IS top singular subspace, so concentrates >> null); \"\n",
|
|
" \"oproj vs downproj differ because top-PCS captures different fractions of each \"\n",
|
|
" \"tensor's Frobenius energy (square-ish o_proj concentrates better than wide down_proj). \"\n",
|
|
" \"ELSE per-tensor split or null normalization is wrong.\"\n",
|
|
")\n",
|
|
"summary_pct = summary.with_columns(\n",
|
|
" pct_act_ceiling=100 * pl.col(\"mean_conc_act\") / ceiling_act,\n",
|
|
" pct_w_oracle_combined=100 * pl.col(\"mean_conc_w_combined\") / weight_ceiling_combined,\n",
|
|
" pct_w_oracle_oproj=100 * pl.col(\"mean_conc_w_oproj\") / weight_ceiling_oproj,\n",
|
|
" pct_w_oracle_downproj=100 * pl.col(\"mean_conc_w_downproj\") / weight_ceiling_downproj,\n",
|
|
")\n",
|
|
"summary_pct_path = OUT_DIR / \"v7_summary_pct.tsv\"\n",
|
|
"summary_pct.write_csv(summary_pct_path, separator=\"\\t\")\n",
|
|
"\n",
|
|
"# Separate write-side and read-side rankings for transparency\n",
|
|
"print(\"BLUF v7 joint act+weight (write/mixed rows plus ceilings, ranked by joint_score):\")\n",
|
|
"write_mixed = summary_pct.filter(pl.col(\"axis_kind\").is_in([\"write\", \"mixed\", \"ceiling\"]))\n",
|
|
"print(tabulate(write_mixed.head(18).to_pandas(), headers=\"keys\", tablefmt=\"github\", floatfmt=\"+.3f\"))\n",
|
|
"\n",
|
|
"print(\"\\nv7 read-side rows (R_w means cross-space alignment, not 'explains delta'):\")\n",
|
|
"read_only = summary_pct.filter(pl.col(\"axis_kind\") == \"read\")\n",
|
|
"print(tabulate(read_only.to_pandas(), headers=\"keys\", tablefmt=\"github\", floatfmt=\"+.3f\"))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "54f86834",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Specificity: repeat activation score after removing clean residual PCs"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "e34f6612",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"clean_basis_by_layer = {c.name: c.basis_by_layer for c in candidate_list}[\"layer_clean_resid_pca\"]\n",
|
|
"specific_null_cache: dict[tuple[int, int, int], tuple[float, float]] = {}\n",
|
|
"\n",
|
|
"\n",
|
|
"def specific_null_stats(layer: int, rank: int, ambient_rank: int) -> tuple[float, float]:\n",
|
|
" key = (layer, rank, ambient_rank)\n",
|
|
" if key in specific_null_cache:\n",
|
|
" return specific_null_cache[key]\n",
|
|
" clean = clean_basis_by_layer[layer]\n",
|
|
" samples = hs_diff_B[layer] @ (torch.eye(d_model) - clean @ clean.T)\n",
|
|
" total = samples.pow(2).sum(1) + 1e-12\n",
|
|
" null = rank / ambient_rank\n",
|
|
" gen = torch.Generator(device=samples.device).manual_seed(50_000 + 97 * layer + 13 * rank)\n",
|
|
" values = []\n",
|
|
" for _ in range(N_NULL):\n",
|
|
" rb, _ = torch.linalg.qr(torch.randn(d_model, rank, generator=gen, device=samples.device, dtype=samples.dtype))\n",
|
|
" rb = project_away(rb, clean)\n",
|
|
" if rb.shape[1] != rank:\n",
|
|
" raise ValueError(f\"random residual rank collapsed: layer={layer}, rank={rank}, got={rb.shape[1]}\")\n",
|
|
" values.append(((samples @ rb).pow(2).sum(1) / total).mean().item() / null)\n",
|
|
" arr = torch.tensor(values)\n",
|
|
" stats = (float(arr.mean()), float(arr.std(unbiased=True)))\n",
|
|
" specific_null_cache[key] = stats\n",
|
|
" return stats\n",
|
|
"\n",
|
|
"\n",
|
|
"def specific_concentration_act(layer: int, basis: torch.Tensor) -> dict[str, float]:\n",
|
|
" clean = clean_basis_by_layer[layer]\n",
|
|
" residual_basis = project_away(basis, clean)\n",
|
|
" rank = residual_basis.shape[1]\n",
|
|
" if rank == 0:\n",
|
|
" return {\"specific_conc_act\": 0.0, \"specific_z_act\": 0.0, \"specific_energy_frac_act\": 0.0, \"specific_rank\": 0}\n",
|
|
" samples = hs_diff_B[layer] @ (torch.eye(d_model) - clean @ clean.T)\n",
|
|
" total = samples.pow(2).sum(1) + 1e-12\n",
|
|
" ambient_rank = d_model - clean.shape[1]\n",
|
|
" energy_frac = ((samples @ residual_basis).pow(2).sum(1) / total).mean().item()\n",
|
|
" conc = energy_frac / (rank / ambient_rank)\n",
|
|
" null_mean, null_std = specific_null_stats(layer, rank, ambient_rank)\n",
|
|
" return {\n",
|
|
" \"specific_conc_act\": conc,\n",
|
|
" \"specific_z_act\": (conc - null_mean) / (null_std + 1e-12),\n",
|
|
" \"specific_energy_frac_act\": energy_frac,\n",
|
|
" \"specific_rank\": rank,\n",
|
|
" }\n",
|
|
"\n",
|
|
"\n",
|
|
"specific_rows = []\n",
|
|
"for layer in range(n_layers):\n",
|
|
" for candidate in all_candidates:\n",
|
|
" specific_rows.append({\n",
|
|
" \"layer\": layer,\n",
|
|
" \"subspace\": candidate.name,\n",
|
|
" \"family\": candidate.family,\n",
|
|
" \"source\": candidate.source,\n",
|
|
" \"kind\": \"ceiling\" if candidate.family == \"ceiling\" else \"A-hypothesis\",\n",
|
|
" **specific_concentration_act(layer, candidate.basis_by_layer[layer]),\n",
|
|
" })\n",
|
|
"\n",
|
|
"specific_per_layer = pl.DataFrame(specific_rows)\n",
|
|
"specific_per_layer_path = OUT_DIR / \"v7_specific_per_layer.csv\"\n",
|
|
"specific_per_layer.write_csv(specific_per_layer_path)\n",
|
|
"specific_summary = (\n",
|
|
" specific_per_layer.filter(pl.col(\"layer\").is_in(list(LORA_LAYERS)))\n",
|
|
" .group_by([\"subspace\", \"family\", \"source\", \"kind\"])\n",
|
|
" .agg(\n",
|
|
" pl.col(\"specific_conc_act\").mean().alias(\"mean_specific_conc_act\"),\n",
|
|
" pl.col(\"specific_z_act\").mean().alias(\"mean_specific_z_act\"),\n",
|
|
" pl.col(\"specific_energy_frac_act\").mean().alias(\"mean_specific_energy_frac_act\"),\n",
|
|
" pl.col(\"specific_rank\").mean().alias(\"mean_specific_rank\"),\n",
|
|
" )\n",
|
|
" .sort(\"mean_specific_conc_act\", descending=True)\n",
|
|
")\n",
|
|
"specific_summary_path = OUT_DIR / \"v7_specific_summary.tsv\"\n",
|
|
"specific_summary.write_csv(specific_summary_path, separator=\"\\t\")\n",
|
|
"\n",
|
|
"print(\"BLUF v7 residualized activation specificity:\")\n",
|
|
"print(tabulate(specific_summary.head(16).to_pandas(), headers=\"keys\", tablefmt=\"github\", floatfmt=\"+.3f\"))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "c24afd48",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Figures and definitions"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "4bd98162",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"plt.rcParams.update({\"figure.dpi\": 160, \"savefig.dpi\": 240, \"font.size\": 9})\n",
|
|
"plot_df_all = summary_pct.filter(pl.col(\"kind\") == \"A-hypothesis\").to_pandas()\n",
|
|
"# Two-panel scatter: write/mixed (joint ranking) and read-side (cross-space alignment)\n",
|
|
"fig, axes = plt.subplots(1, 2, figsize=(13, 6.2), sharey=True)\n",
|
|
"for ax, kind_filter, panel_title in [\n",
|
|
" (axes[0], (\"write\", \"mixed\"), \"write+mixed (R_w = explains delta)\"),\n",
|
|
" (axes[1], (\"read\",), \"read-side (R_w = cross-space alignment)\"),\n",
|
|
"]:\n",
|
|
" panel_df = plot_df_all[plot_df_all[\"axis_kind\"].isin(kind_filter)].head(20)\n",
|
|
" for family, fam_df in panel_df.groupby(\"family\"):\n",
|
|
" ax.scatter(fam_df[\"mean_conc_act\"], fam_df[\"mean_conc_w_combined\"], s=52, alpha=0.82, label=family)\n",
|
|
" for row in panel_df.head(10).itertuples(index=False):\n",
|
|
" ax.annotate(row.subspace, (row.mean_conc_act, row.mean_conc_w_combined), fontsize=7, xytext=(3, 3), textcoords=\"offset points\")\n",
|
|
" ax.axvline(1.0, color=\"black\", linestyle=\"--\", linewidth=0.9)\n",
|
|
" ax.axhline(1.0, color=\"black\", linestyle=\"--\", linewidth=0.9)\n",
|
|
" ax.set_xscale(\"log\")\n",
|
|
" ax.set_yscale(\"log\")\n",
|
|
" ax.set_xlabel(\"activation recovery R_act\")\n",
|
|
" ax.set_title(panel_title)\n",
|
|
" ax.grid(alpha=0.25, which=\"both\")\n",
|
|
" ax.legend(fontsize=7, ncols=2)\n",
|
|
"axes[0].set_ylabel(\"weight recovery R_w (Frobenius-balanced combined)\")\n",
|
|
"ceiling_df = summary_pct.filter(pl.col(\"kind\") == \"ceiling\").to_pandas()\n",
|
|
"for ax in axes:\n",
|
|
" if len(ceiling_df):\n",
|
|
" ax.scatter(ceiling_df[\"mean_conc_act\"], ceiling_df[\"mean_conc_w_combined\"], s=85, marker=\"*\", color=\"black\", label=\"ceiling\")\n",
|
|
"fig.suptitle(\"v7: read-side R_w is cross-space alignment, not 'explains delta'\")\n",
|
|
"fig.tight_layout()\n",
|
|
"scatter_png = OUT_DIR / \"v7_joint_act_weight_scatter.png\"\n",
|
|
"scatter_pdf = OUT_DIR / \"v7_joint_act_weight_scatter.pdf\"\n",
|
|
"fig.savefig(scatter_png, bbox_inches=\"tight\")\n",
|
|
"fig.savefig(scatter_pdf, bbox_inches=\"tight\")\n",
|
|
"plt.close(fig)\n",
|
|
"\n",
|
|
"definitions_path = OUT_DIR / \"v7_definitions.md\"\n",
|
|
"plan_merge_path = OUT_DIR / \"v7_plan_merge.md\"\n",
|
|
"definitions = [\n",
|
|
" \"# v7 hypothesis definitions\",\n",
|
|
" \"\",\n",
|
|
" \"All A-side hypotheses are built without the trained LoRA. The LoRA diff is used only for B-side scoring.\",\n",
|
|
" \"\",\n",
|
|
" \"v7 changes vs v6: per-tensor R_w (oproj/downproj/combined), dW_left_basis_ceiling as the true weight ceiling, axis_kind tag (write/read/mixed/ceiling) so read-side cross-space scores aren't conflated with 'explains delta'.\",\n",
|
|
" \"\",\n",
|
|
" \"| name | family | axis_kind | source | definition |\",\n",
|
|
" \"|---|---|---|---|---|\",\n",
|
|
"]\n",
|
|
"for candidate in all_candidates:\n",
|
|
" definitions.append(f\"| `{candidate.name}` | {candidate.family} | {axis_kind_for(candidate.family)} | {candidate.source} | {candidate.definition} |\")\n",
|
|
"definitions_path.write_text(\"\\n\".join(definitions) + \"\\n\")\n",
|
|
"\n",
|
|
"plan_merge_path.write_text(\"\"\"# v7 changes vs v6\n",
|
|
"\n",
|
|
"Addresses three real concerns from `docs/review/v6_hypothesis_review.md`:\n",
|
|
"\n",
|
|
"1. **Per-tensor R_w.** `lora_weight_tensors(layer)` returns a dict {o_proj, down_proj}; `concentration_w` reports `R_w_oproj`, `R_w_downproj`, and a Frobenius-balanced `R_w_combined`. Joint score uses combined; per-tensor are reported for inspection. Eliminates the silent down_proj domination (down_proj has ~3x the params of o_proj in this model).\n",
|
|
"\n",
|
|
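"2. **True weight ceiling.** Added `dW_left_basis_ceiling` candidate: top-PCS left singular vectors of the LoRA delta itself. Its raw R_w approaches the maximum `d_model/PCS` when the top-PCS subspace captures nearly all of the delta's energy, and `pct_w_oracle_*` is 100 for that row by definition, so other rows read as percent-of-oracle. The scale is only approximately 0-100: the oracle is the SVD of the unbalanced o_proj+down_proj concatenation, not of each tensor or of the Frobenius-balanced combination, so a row can exceed 100. The v6 column `pct_w_taskdiff_basis` was relative to `PCA(hs_diff_B_fit)` -- an activation basis, not a weight oracle.\n",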
|
|
"\n",
|
|
"3. **axis_kind tag.** Each candidate is tagged write / read / mixed / ceiling. Read-side bases (mlp_up_read, mlp_gate_read, attn_qkv_read, kv_super, input_super, lm_head_read, logits_null, input_super_not_lm_read) are reported in a separate sub-table and a separate scatter panel. High R_w on a read-side basis means \"this read direction happens to coincide with LoRA write directions\", not \"this primitive captures the LoRA write geometry\".\n",
|
|
"\n",
|
|
"Deferred to v7b (multi-seed): currently single-LoRA-seed; rankings are anecdote-grade until run on >=3 LoRA seeds with stability filtering.\n",
|
|
"\n",
|
|
"Not fixed (left as known-limitations comments only):\n",
|
|
"- `chars_clusters` PCA collapses to rank 7 because centroids - mean has rank k_clusters - 1 = 7 < PCS=8.\n",
|
|
"- `qk_circuit` mixes all heads in one d_model x d_model matrix.\n",
|
|
"- `intersect_basis` uses Bjorck-Golub bisector, not strict subspace intersection (returns directions even at low principal-angle alignment).\n",
|
|
"\"\"\")\n",
|
|
"\n",
|
|
"winner = summary_pct.filter((pl.col(\"kind\") == \"A-hypothesis\") & (pl.col(\"axis_kind\").is_in([\"write\", \"mixed\"]))).row(0, named=True)\n",
|
|
"act_winners = summary_pct.filter(pl.col(\"kind\") == \"A-hypothesis\").sort(\"mean_conc_act\", descending=True).head(5)\n",
|
|
"w_winners = summary_pct.filter((pl.col(\"kind\") == \"A-hypothesis\") & (pl.col(\"axis_kind\").is_in([\"write\", \"mixed\"]))).sort(\"mean_conc_w_combined\", descending=True).head(5)\n",
|
|
"top_act = set(act_winners[\"subspace\"].to_list())\n",
|
|
"top_w = set(w_winners[\"subspace\"].to_list())\n",
|
|
"both_top5 = sorted(top_act & top_w)\n",
|
|
"conclusion_path = OUT_DIR / \"v7_conclusion.md\"\n",
|
|
"conclusion_path.write_text(f\"\"\"# v7 hypothesis sweep conclusion\n",
|
|
"\n",
|
|
"## BLUF\n",
|
|
"\n",
|
|
"Best joint A-side primitive (write/mixed only) by geometric mean of activation\n",
|
|
"and Frobenius-balanced weight recovery: `{winner['subspace']}`. R_act={winner['mean_conc_act']:.2f},\n",
|
|
"R_w_combined={winner['mean_conc_w_combined']:.2f} (oracle={weight_ceiling_combined:.2f}, so\n",
|
|
"{winner['pct_w_oracle_combined']:.1f}% of weight ceiling), joint={winner['joint_score']:.2f}.\n",
|
|
"\n",
|
|
"Per-tensor R_w for the winner: oproj={winner['mean_conc_w_oproj']:.2f} ({winner['pct_w_oracle_oproj']:.1f}% of oracle), downproj={winner['mean_conc_w_downproj']:.2f} ({winner['pct_w_oracle_downproj']:.1f}% of oracle).\n",
|
|
"\n",
|
|
"Top-5 overlap between activation winners and weight winners (write/mixed only): {both_top5}.\n",
|
|
"\n",
|
|
"## v7 changes vs v6\n",
|
|
"\n",
|
|
"1. R_w split per LoRA tensor (o_proj vs down_proj) plus a Frobenius-balanced combined; v6's single conc_w was silently dominated by down_proj (~3x the params).\n",
|
|
"2. dW_left_basis_ceiling row gives `R_w_combined~={weight_ceiling_combined:.2f}` (oracle); `pct_w_oracle_combined` is now percent-of-oracle, not percent-of-PCA(hs_diff_B_fit).\n",
|
|
"3. Read-side hypotheses (input projections) are tagged axis_kind='read' and reported in a separate sub-table. A high R_w there means cross-space alignment between the read subspace and the write-side LoRA delta -- not 'this primitive explains the delta'.\n",
|
|
"\n",
|
|
"## Caveats\n",
|
|
"\n",
|
|
"- Single LoRA seed; rankings are anecdote-grade until v7b multi-seed runs.\n",
|
|
"- R_w only scores residual-output LoRA tensors (`o_proj`, `down_proj`) because the basis lives in residual-output space (d_model rows).\n",
|
|
"- `chars_clusters` silently rank-collapses to 7 (centroids - mean has rank k-1); `qk_circuit` mixes all heads; `intersect_basis` is the Bjorck-Golub bisector not strict intersection. Inline comments only; not fixed in v7.\n",
|
|
"\n",
|
|
"## Artifacts\n",
|
|
"\n",
|
|
"- Per-layer raw scores: `{per_layer_path}`\n",
|
|
"- Summary: `{summary_path}`\n",
|
|
"- Summary with oracle-relative percentages: `{summary_pct_path}`\n",
|
|
"- Residualized activation per-layer scores: `{specific_per_layer_path}`\n",
|
|
"- Residualized activation summary: `{specific_summary_path}`\n",
|
|
"- Joint scatter (write+mixed | read sub-panel): `{scatter_png}`, `{scatter_pdf}`\n",
|
|
"- Definitions: `{definitions_path}`\n",
|
|
"- v7-vs-v6 changes: `{plan_merge_path}`\n",
|
|
"\"\"\")\n",
|
|
"\n",
|
|
"print(\"wrote:\")\n",
|
|
"for path in [\n",
|
|
" per_layer_path,\n",
|
|
" summary_path,\n",
|
|
" summary_pct_path,\n",
|
|
" specific_per_layer_path,\n",
|
|
" specific_summary_path,\n",
|
|
" definitions_path,\n",
|
|
" plan_merge_path,\n",
|
|
" conclusion_path,\n",
|
|
" scatter_png,\n",
|
|
" scatter_pdf,\n",
|
|
"]:\n",
|
|
" print(f\" {path} ({path.stat().st_size} bytes)\")\n",
|
|
"\n",
|
|
"print(\n",
|
|
" \"SHOULD: useful subspaces have R_act>1 and R_w>1; generic activation artifacts show high R_act but weak R_w. \"\n",
|
|
" \"ELSE: check basis orientation and LoRA diff tensor selection.\"\n",
|
|
")"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"jupytext": {
|
|
"cell_metadata_filter": "-all",
|
|
"main_language": "python",
|
|
"notebook_metadata_filter": "-all"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|