Refactor to return U, S, Vh from SVD and perform cropping outside the function

This commit is contained in:
wassname
2026-04-10 10:00:28 +08:00
parent e44dc0e74e
commit c8a59851ed
2 changed files with 23 additions and 23 deletions
+14 -14
View File
@@ -2,7 +2,7 @@
"cells": [ "cells": [
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "62eec772", "id": "e0ccb9f3",
"metadata": {}, "metadata": {},
"source": [ "source": [
"# Brukino's AntiPaSTO Appetizer: Guided CoT Eval & Frenet-Serret Curvature\n", "# Brukino's AntiPaSTO Appetizer: Guided CoT Eval & Frenet-Serret Curvature\n",
@@ -20,7 +20,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "47c7efe2", "id": "f61786b3",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@@ -44,17 +44,17 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "8c590e6c", "id": "b130793b",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"def get_s_space_projector(model, top_k=1024):\n", "def get_s_space_svd(model):\n",
" \"\"\"\n", " \"\"\"\n",
" Gathers all weight matrices that write to the residual stream\n", " Gathers all weight matrices that write to the residual stream\n",
" (o_proj from attention and down_proj from MLP) across all layers,\n", " (o_proj from attention and down_proj from MLP) across all layers,\n",
" and concatenates them to form a collective \"write\" transformation.\n", " and concatenates them to form a collective \"write\" transformation.\n",
" Then computes the SVD to extract the top_k modes.\n", " Then computes and returns the full SVD.\n",
" Returns: U (hidden_size, top_k), S (top_k,)\n", " Returns: U, S, Vh\n",
" \"\"\"\n", " \"\"\"\n",
" Ws = []\n", " Ws = []\n",
" for layer in model.model.layers:\n", " for layer in model.model.layers:\n",
@@ -65,12 +65,9 @@
" W = torch.cat(Ws, dim=1).to(model.device)\n", " W = torch.cat(Ws, dim=1).to(model.device)\n",
" \n", " \n",
" # SVD on the collective weight matrix\n", " # SVD on the collective weight matrix\n",
" U, S, _ = torch.linalg.svd(W.float(), full_matrices=False)\n", " U, S, Vh = torch.linalg.svd(W.float(), full_matrices=False)\n",
" \n", " \n",
" # Crop to top_k modes\n", " return U, S, Vh\n",
" U = U[:, :top_k]\n",
" S = S[:top_k]\n",
" return U, S\n",
"\n", "\n",
"def project_to_s_space(hidden_states, U, S):\n", "def project_to_s_space(hidden_states, U, S):\n",
" \"\"\"\n", " \"\"\"\n",
@@ -123,7 +120,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "6ae905b2", "id": "812eed4d",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@@ -205,7 +202,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "30e7fb4e", "id": "57d967a9",
"metadata": { "metadata": {
"lines_to_next_cell": 2 "lines_to_next_cell": 2
}, },
@@ -239,7 +236,10 @@
"results = {}\n", "results = {}\n",
"\n", "\n",
"# Project using the collective residual stream writers transformation\n", "# Project using the collective residual stream writers transformation\n",
"s_space_U, s_space_S = get_s_space_projector(model)\n", "U, S, Vh = get_s_space_svd(model)\n",
"top_k = 1024\n",
"s_space_U = U[:, :top_k]\n",
"s_space_S = S[:top_k]\n",
"\n", "\n",
"for p_key, p_prefix in PERSONAS.items():\n", "for p_key, p_prefix in PERSONAS.items():\n",
" print(f\"\\n--- Running: {p_key} ---\")\n", " print(f\"\\n--- Running: {p_key} ---\")\n",
+9 -9
View File
@@ -39,13 +39,13 @@ NUM_EXAMPLES = 3
# %% # %%
def get_s_space_projector(model, top_k=1024): def get_s_space_svd(model):
""" """
Gathers all weight matrices that write to the residual stream Gathers all weight matrices that write to the residual stream
(o_proj from attention and down_proj from MLP) across all layers, (o_proj from attention and down_proj from MLP) across all layers,
and concatenates them to form a collective "write" transformation. and concatenates them to form a collective "write" transformation.
Then computes the SVD to extract the top_k modes. Then computes and returns the full SVD.
Returns: U (hidden_size, top_k), S (top_k,) Returns: U, S, Vh
""" """
Ws = [] Ws = []
for layer in model.model.layers: for layer in model.model.layers:
@@ -56,12 +56,9 @@ def get_s_space_projector(model, top_k=1024):
W = torch.cat(Ws, dim=1).to(model.device) W = torch.cat(Ws, dim=1).to(model.device)
# SVD on the collective weight matrix # SVD on the collective weight matrix
U, S, _ = torch.linalg.svd(W.float(), full_matrices=False) U, S, Vh = torch.linalg.svd(W.float(), full_matrices=False)
# Crop to top_k modes return U, S, Vh
U = U[:, :top_k]
S = S[:top_k]
return U, S
def project_to_s_space(hidden_states, U, S): def project_to_s_space(hidden_states, U, S):
""" """
@@ -216,7 +213,10 @@ plt.figure(figsize=(10, 6))
results = {} results = {}
# Project using the collective residual stream writers transformation # Project using the collective residual stream writers transformation
s_space_U, s_space_S = get_s_space_projector(model) U, S, Vh = get_s_space_svd(model)
top_k = 1024
s_space_U = U[:, :top_k]
s_space_S = S[:top_k]
for p_key, p_prefix in PERSONAS.items(): for p_key, p_prefix in PERSONAS.items():
print(f"\n--- Running: {p_key} ---") print(f"\n--- Running: {p_key} ---")