Precompute SVD for S-space projection for efficiency

This commit is contained in:
wassname
2026-04-10 09:46:42 +08:00
parent a1a8648865
commit 382ffc4315
2 changed files with 41 additions and 45 deletions
+23 -25
View File
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "markdown",
"id": "76397472",
"id": "b8a288c6",
"metadata": {},
"source": [
"# Brukino's AntiPaSTO Appetizer: Guided CoT Eval & Frenet-Serret Curvature\n",
@@ -20,7 +20,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "703c8e38",
"id": "e8c081e6",
"metadata": {},
"outputs": [],
"source": [
@@ -44,28 +44,35 @@
{
"cell_type": "code",
"execution_count": null,
"id": "e6f56583",
"id": "7b5b34e4",
"metadata": {},
"outputs": [],
"source": [
"def get_collective_s_space_weight(model):\n",
"def get_s_space_projector(model, top_k=256):\n",
" \"\"\"\n",
" Gathers all weight matrices that write to the residual stream\n",
" (o_proj from attention and down_proj from MLP) across all layers,\n",
" and concatenates them to form a collective \"write\" transformation.\n",
" Then computes the SVD to extract the top_k modes.\n",
" Returns: U (hidden_size, top_k), S (top_k,)\n",
" \"\"\"\n",
" Ws = []\n",
" for layer in model.model.layers:\n",
" # In Qwen2, o_proj and down_proj weights are shape [hidden_size, in_features]\n",
" # We want a combined matrix of shape [hidden_size, sum(in_features)]\n",
" # so that SVD gives U of shape [hidden_size, top_k].\n",
" # o_proj.weight is [hidden_size, num_heads * head_dim]\n",
" # down_proj.weight is [hidden_size, intermediate_size]\n",
" Ws.append(layer.self_attn.o_proj.weight.detach().cpu())\n",
" Ws.append(layer.mlp.down_proj.weight.detach().cpu())\n",
" return torch.cat(Ws, dim=1).to(model.device)\n",
" W = torch.cat(Ws, dim=1).to(model.device)\n",
" \n",
" # SVD on the collective weight matrix\n",
" U, S, _ = torch.linalg.svd(W.float(), full_matrices=False)\n",
" \n",
" # Crop to top_k modes\n",
" U = U[:, :top_k]\n",
" S = S[:top_k]\n",
" return U, S\n",
"\n",
"def project_to_s_space(hidden_states, W, top_k=256):\n",
"def project_to_s_space(hidden_states, U, S):\n",
" \"\"\"\n",
" Projects the residual stream into the 'super' S-space of all residual writers.\n",
" \n",
@@ -75,16 +82,7 @@
" hidden states from the residual stream, and the U from all residual writers, \n",
" we can project the residual stream into S-space, which can be thought of as \n",
" something like the coordinate space of learned modes of behaviors.\n",
" \n",
" W: The concatenated weight matrix of all residual writers, shape (sum(in_features), hidden_size)\n",
" \"\"\"\n",
" # SVD on the collective weight matrix\n",
" U, S, _ = torch.linalg.svd(W.float(), full_matrices=False)\n",
" \n",
" # Crop to top_k modes (there will be a lot of overlap/redundancy)\n",
" U = U[:, :top_k]\n",
" S = S[:top_k]\n",
" \n",
" # Project: x_S = (x @ U)\n",
" x_S = hidden_states.to(torch.float32) @ U\n",
" \n",
@@ -125,11 +123,11 @@
{
"cell_type": "code",
"execution_count": null,
"id": "7cbb2a21",
"id": "ab5130a9",
"metadata": {},
"outputs": [],
"source": [
"def guided_eval(model, tokenizer, prompt_text, n_think=32, device=\"cuda\", s_space_weight=None):\n",
"def guided_eval(model, tokenizer, prompt_text, n_think=32, device=\"cuda\", s_space_U=None, s_space_S=None):\n",
" messages = [{\"role\": \"user\", \"content\": prompt_text}]\n",
" \n",
" inputs = tokenizer.apply_chat_template(\n",
@@ -190,8 +188,8 @@
" start_idx = prompt_ids.shape[1]\n",
" cot_hiddens = final_layer_hiddens[start_idx : start_idx + generated_ids.shape[0]]\n",
" \n",
" if s_space_weight is not None:\n",
" trajectory = project_to_s_space(cot_hiddens, s_space_weight)\n",
" if s_space_U is not None and s_space_S is not None:\n",
" trajectory = project_to_s_space(cot_hiddens, s_space_U, s_space_S)\n",
" else:\n",
" trajectory = cot_hiddens\n",
" \n",
@@ -207,7 +205,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "812de2b3",
"id": "96c84bd3",
"metadata": {
"lines_to_next_cell": 2
},
@@ -241,11 +239,11 @@
"results = {}\n",
"\n",
"# Project using the collective residual stream writers transformation\n",
"s_space_W = get_collective_s_space_weight(model) \n",
"s_space_U, s_space_S = get_s_space_projector(model)\n",
"\n",
"for p_key, p_prefix in PERSONAS.items():\n",
" print(f\"\\n--- Running: {p_key} ---\")\n",
" res = guided_eval(model, tokenizer, p_prefix + prompt_base, n_think=N_THINK_TOKENS, device=DEVICE, s_space_weight=s_space_W)\n",
" res = guided_eval(model, tokenizer, p_prefix + prompt_base, n_think=N_THINK_TOKENS, device=DEVICE, s_space_U=s_space_U, s_space_S=s_space_S)\n",
" results[p_key] = res\n",
" print(f\"Logratio (Yes/No): {res['logratio']:.3f}\")\n",
" print(f\"Prompt:\\n```md\\n{res['prompt']}```\")\n",